In [1]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from torchvision import transforms
from PIL import Image
import os

In [2]:
# Prefer the Google Drive (Colab) locations when they exist; otherwise fall back to
# paths relative to the current working directory.
_DRIVE_ROOT = "/content/drive/MyDrive"
_DRIVE_TRAIN_DIR = f"{_DRIVE_ROOT}/richhf/train"
_DRIVE_CSV = f"{_DRIVE_ROOT}/richhf/RichHF_18K.csv"

training_dataset_dir = _DRIVE_TRAIN_DIR if os.path.exists(_DRIVE_TRAIN_DIR) else "./train"
# Check the CSV file itself (not only its parent folder) so a Drive mount that lacks
# the CSV falls back to the local copy instead of pointing at a missing file.
train_data_csv = _DRIVE_CSV if os.path.exists(_DRIVE_CSV) else "./RichHF_18K.csv"
# The checkpoint is written to the Drive root whenever Drive is mounted.
model_output_name = f"{_DRIVE_ROOT}/meta_iqa_model.pth" if os.path.exists(_DRIVE_ROOT) else "./meta_iqa_model.pth"


In [3]:
class MetaIQAFeatureExtractor(nn.Module):
    """Fully connected regressor mapping a flattened image to a single quality score.

    Architecture:
      1. fc1 -> BatchNorm1d -> SiLU -> Dropout
      2. fc2 -> BatchNorm1d -> SiLU -> Dropout, summed with a parallel linear
         projection (``residual_connection``) of stage 1's output
      3. fc3 head producing one value per sample

    Args:
        input_size: Number of features per sample after flattening
            (default 3*512*512, i.e. a 3x512x512 image).
        hidden_size: Width of the first hidden layer.
        output_size: Width of the second hidden layer.
        dropout: Dropout probability applied after both hidden stages.

    Attribute names (fc1, bn1, relu1, ...) and their creation order are kept
    stable. The names determine ``state_dict`` keys of saved checkpoints, and
    the order determines the seeded weight initialisation.
    """

    def __init__(self, input_size=3*512*512, hidden_size=1024, output_size=512, dropout=0.3):
        super().__init__()
        self.input_size = input_size

        # Stage 1.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.relu1 = nn.SiLU()  # Swish activation (attribute name kept for compatibility)
        self.dropout1 = nn.Dropout(dropout)

        # Stage 2 plus a projection of stage 1's output so the two can be summed.
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.residual_connection = nn.Linear(hidden_size, output_size)
        self.bn2 = nn.BatchNorm1d(output_size)
        self.relu2 = nn.SiLU()
        self.dropout2 = nn.Dropout(dropout)

        # Regression head: one scalar score per sample.
        self.fc3 = nn.Linear(output_size, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of predicted quality scores.

        ``x`` may have any shape whose leading dimension is the batch. All
        remaining dimensions are flattened and must multiply to ``input_size``.
        """
        # torch.flatten (unlike .view) also accepts non-contiguous inputs,
        # e.g. permuted or sliced tensors.
        x = torch.flatten(x, start_dim=1)

        # Stage 1.
        hidden = self.dropout1(self.relu1(self.bn1(self.fc1(x))))

        # Stage 2 with the projected shortcut.
        residual = self.residual_connection(hidden)
        out = self.dropout2(self.relu2(self.bn2(self.fc2(hidden))))
        # Out-of-place add: in-place ops on autograd intermediates can trigger
        # "modified by an inplace operation" errors during backward.
        out = out + residual

        return self.fc3(out)

In [4]:
class IQADataset(Dataset):
    """Map-style dataset yielding ``(image, quality_score)`` pairs from a CSV index.

    Each CSV row names an image file (without extension) inside ``root_dir`` and
    carries its quality score. Images are loaded lazily in ``__getitem__``.

    Args:
        csv_file: Path to the CSV index, read with ``pandas.read_csv``.
        root_dir: Directory containing the image files.
        transform: Optional callable applied to the loaded PIL RGB image.
        filename_col: Column holding the image name, as an integer position or
            a column label. Defaults to the first column.
        score_col: Column holding the quality score, as an integer position or
            a column label. Defaults to position 5 (the sixth column).
        extension: Suffix appended to the image name to form the file name.
    """

    def __init__(self, csv_file, root_dir, transform=None, filename_col=0, score_col=5, extension=".png"):
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.extension = extension
        # Resolve columns to integer positions once so __getitem__ can use iloc.
        self._filename_pos = self._column_position(filename_col)
        self._score_pos = self._column_position(score_col)

    def _column_position(self, col):
        """Return the integer position of ``col`` (given as a label or a position)."""
        if isinstance(col, str):
            return self.data_frame.columns.get_loc(col)
        return col

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load row ``idx``; return ``(image, score)`` with ``score`` as a float32 tensor.

        Raises:
            FileNotFoundError: if the image file for the row does not exist.
        """
        img_name = self.data_frame.iloc[idx, self._filename_pos]
        quality_score = self.data_frame.iloc[idx, self._score_pos]

        img_path = os.path.join(self.root_dir, f"{img_name}{self.extension}")

        # Explicit check gives a clearer message than PIL's own error.
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image {img_path} not found")

        # The context manager guarantees the file handle is closed. convert()
        # forces a full load and returns a new image independent of the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(quality_score, dtype=torch.float)


In [5]:
def train_meta_iqa(model, train_loader, optimizer, criterion, device, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``train_loader``, printing each epoch's loss.

    Args:
        model: Module mapping a batch of inputs to ``(batch, 1)`` predictions.
        train_loader: Iterable of ``(inputs, targets)`` batches; ``targets`` is 1-D.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function. The per-sample average below assumes it
            returns a batch *mean* (e.g. ``nn.MSELoss()`` default reduction).
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        list[float]: Mean per-sample loss for each epoch.

    Raises:
        ValueError: If an epoch yields no samples (e.g. an empty dataset, or
            fewer samples than ``batch_size`` with ``drop_last=True``).
    """
    model.train()
    history = []
    for epoch in range(epochs):
        running_loss = 0.0
        samples_seen = 0
        for images, quality_scores in train_loader:
            images, quality_scores = images.to(device), quality_scores.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, quality_scores.unsqueeze(1))  # targets -> (batch, 1)
            loss.backward()
            optimizer.step()

            batch_len = images.size(0)
            running_loss += loss.item() * batch_len
            samples_seen += batch_len

        # Divide by the samples actually processed rather than len(dataset).
        # The two differ when the loader drops a partial final batch.
        if samples_seen == 0:
            raise ValueError("train_loader yielded no samples; cannot compute epoch loss")
        epoch_loss = running_loss / samples_seen
        history.append(epoch_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")
    return history


In [None]:
if __name__ == "__main__":
    # Hyperparameters and configuration
    seed = 42
    image_size = 512  # images are resized to image_size x image_size
    # RGB channels x height x width. Derived from image_size so it always matches the transform.
    input_size = 3 * image_size * image_size
    hidden_size = 1024
    output_size = 512
    epochs = 10
    batch_size = 32
    learning_rate = 1e-4

    # Seed torch's RNG (weight init, dropout, DataLoader shuffling) for reproducible runs.
    torch.manual_seed(seed)

    # Prepare the dataset and dataloader
    transform = transforms.Compose([transforms.Resize((image_size, image_size)), transforms.ToTensor()])

    dataset = IQADataset(csv_file=train_data_csv, root_dir=training_dataset_dir, transform=transform)
    # drop_last=True: BatchNorm1d in training mode raises on a final batch of size 1.
    # At most batch_size-1 samples are skipped per epoch, and shuffling varies which ones.
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Initialize the model, optimizer, and loss function
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): fc1 alone holds input_size * hidden_size (~805M) weights, about 3.2 GB in
    # float32 before Adam's extra state. Consider a smaller image_size or a convolutional backbone.
    model = MetaIQAFeatureExtractor(input_size=input_size, hidden_size=hidden_size, output_size=output_size).to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()  # Mean Squared Error Loss

    # Train the model
    train_meta_iqa(model, train_loader, optimizer, criterion, device, epochs)

    # Save the trained model
    torch.save(model.state_dict(), model_output_name)
    print("Training complete and model saved!")

Epoch [1/10], Loss: 0.0685
Epoch [2/10], Loss: 0.0204
Epoch [3/10], Loss: 0.0114
