# Import libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.models import resnet152, ResNet152_Weights
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
from PIL import Image
import os
import numpy as np

# Database creations using pytorch Dataset 

In [2]:
class ImageAuthenticityDataset(Dataset):
    """Dataset for image quality assessment."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the CSV file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.dir_path = os.path.dirname(csv_file)  # Directory of the CSV file

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx,):
        """
        Retrieves an image and its labels by index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple (image, labels) where:
                image (PIL.Image): The image.
                labels (torch.Tensor): Tensor containing quality and authenticity scores.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # TODO: to be fixed, right now is folder dependent
        img_name = self.data.iloc[idx, 3].replace("./", "../../")
        image = Image.open(img_name).convert('RGB')
        authenticity = self.data.iloc[idx, 1]  # Authenticity column
        labels = torch.tensor([authenticity], dtype=torch.float)


        if self.transform:
            image = self.transform(image)

        return image, labels


# Definitions of the models

In [3]:
class AuthenticityPredictor(nn.Module):
    def __init__(self, freeze_backbone=True):
        super().__init__()
        # Load pre-trained ResNet-152 instead of VGG16
        resnet = resnet152(weights=ResNet152_Weights.DEFAULT)
        
        # Freeze backbone if requested
        if freeze_backbone:
            for param in resnet.parameters():
                param.requires_grad = False
                
        # Store the backbone (excluding the final fc layer)
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = resnet.avgpool
        
        self.regression_head = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.5),  # Reduced dropout ratio
                nn.Linear(512, 128),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(128, 1)
            )    
        
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        predictions = self.regression_head(x)
        return predictions, x  # Return predictions and features

# Training and evaluation functions

In [4]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, device='cuda'):
    """
    Trains the model.

    Args:
        model (nn.Module): The model to train.
        dataloaders (dict): A dictionary containing the training and validation data loaders.
        criterion (nn.Module): The loss function.
        optimizer (optim.Optimizer): The optimizer.
        num_epochs (int): Number of epochs to train for. Defaults to 10.
        device (str): Device to use for training ('cuda' or 'cpu'). Defaults to 'cuda'.

    Returns:
        nn.Module: The trained model.
    """
    model.to(device)
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:  # Iterate over training and validation phases
            print(f'{phase} phase')
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0

            for inputs, labels in dataloaders[phase]:  # Iterate over data in the current phase
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):  # Enable gradients only during training
                    outputs, _ = model(inputs)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)

            print(f'{phase} Loss: {epoch_loss:.4f}') # Print loss for the current phase

    print("Finished Training")
    return model

def test_model(model, dataloader, criterion, device='cuda'):

    """
    Tests the model on the test dataset.

    Args:
        model (nn.Module): The trained model.
        dataloader (DataLoader): The test data loader.
        criterion (nn.Module): The loss function.
        device (str): Device to use for testing ('cuda' or 'cpu'). Defaults to 'cuda'.

    Returns:
        float: The average loss on the test dataset.
    """
    model.eval()  # Set the model to evaluation mode
    model.to(device)
    running_loss = 0.0

    with torch.no_grad():  # Disable gradient calculation
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs, _ = model(inputs)
            
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)

    test_loss = running_loss / len(dataloader.dataset)
    print(f'Test Loss: {test_loss:.4f}')
    return test_loss


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data transformations for the ImageNet dataset
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

annotations_file = '../../Dataset/AIGCIQA2023/real_images_annotations.csv'

# Create the dataset
dataset = ImageAuthenticityDataset(csv_file=annotations_file, transform=data_transforms)

# Set random seeds for reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)

# Split the dataset into training, validation, and test sets
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])


# Create data loaders
BATCH_SIZE = 64 # Set to 1 for handling individual images
EPOCHS = 20

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)


# Create a dictionary containing the data loaders
dataloaders = {
    'train': train_dataloader,
    'val': val_dataloader,
    'test': test_dataloader
}

model = AuthenticityPredictor()
criterion = nn.MSELoss()  # Mean Squared Error Loss (regression)
optimizer = optim.Adam(model.regression_head.parameters(), lr=0.001)

# Train the model
model = train_model(model, dataloaders, criterion, optimizer, num_epochs=EPOCHS, device=device)

# Save the model
torch.save(model.state_dict(), 'Weights/ResNet-152_real_authenticity_finetuned.pth')

# Test the model
test_loss = test_model(model, test_dataloader, criterion, device=device)


Epoch 0/19
----------
train phase
train Loss: 1742.0669
val phase
val Loss: 648.2523
Epoch 1/19
----------
train phase
train Loss: 751.5873
val phase
val Loss: 732.0970
Epoch 2/19
----------
train phase
train Loss: 521.9006
val phase
val Loss: 424.2814
Epoch 3/19
----------
train phase
train Loss: 415.0892
val phase
val Loss: 420.2628
Epoch 4/19
----------
train phase
train Loss: 332.9170
val phase
val Loss: 287.7996
Epoch 5/19
----------
train phase
train Loss: 250.1784
val phase
val Loss: 209.0623
Epoch 6/19
----------
train phase
train Loss: 200.5630
val phase
val Loss: 183.7928
Epoch 7/19
----------
train phase
train Loss: 162.3822
val phase
val Loss: 164.4508
Epoch 8/19
----------
train phase
train Loss: 156.5727
val phase
val Loss: 161.7968
Epoch 9/19
----------
train phase
train Loss: 127.5748
val phase
val Loss: 140.0128
Epoch 10/19
----------
train phase
train Loss: 124.6554
val phase
val Loss: 124.3749
Epoch 11/19
----------
train phase
train Loss: 122.6817
val phase
val Loss