# Import libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
import torch.hub
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
from PIL import Image
import os
import numpy as np

# Dataset creation using the PyTorch `Dataset` class

In [None]:
class ImageAuthenticityDataset(Dataset):
    """
    Dataset of images paired with their authenticity MOS and its standard deviation.

    Reads an annotation CSV with the columns ``name``, ``authenticity_mos`` and
    ``authenticity_std``. Relative image names are resolved against
    ``<directory of the CSV>/Image/allimg``; absolute names are used as-is.
    """

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the CSV file with annotations.
            transform (callable, optional): Optional transform applied to the RGB PIL image.
        """
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        # Directory of the CSV file; relative image names are resolved against it.
        self.dir_path = os.path.dirname(csv_file)

    def __len__(self):
        """Returns the number of samples (CSV rows) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves an image together with its authenticity MOS and standard deviation.

        Args:
            idx (int or 0-d torch.Tensor): Index of the sample to retrieve.

        Returns:
            tuple: ``(image, label, std)`` where
                image: the RGB image, passed through ``transform`` if one was
                    given (otherwise a ``PIL.Image.Image``);
                label (torch.Tensor): shape ``(1,)`` float tensor holding the authenticity MOS;
                std (torch.Tensor): shape ``(1,)`` float tensor holding the authenticity standard deviation.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Cast to str: pandas parses purely numeric file names as integers,
        # which os.path functions reject.
        img_name = str(self.data.loc[idx, 'name'])
        if os.path.isabs(img_name):
            img_path = img_name
        else:
            img_path = os.path.join(self.dir_path, "Image/allimg", img_name)

        # The context manager guarantees the file handle is closed; convert()
        # returns a new, fully loaded image that outlives the `with` block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        authenticity_mos = self.data.loc[idx, 'authenticity_mos']
        authenticity_std = self.data.loc[idx, 'authenticity_std']

        label = torch.tensor([authenticity_mos], dtype=torch.float)
        std = torch.tensor([authenticity_std], dtype=torch.float)

        if self.transform:
            image = self.transform(image)

        return image, label, std

    def head(self, n=5):
        """
        Returns the first n rows of the dataset's DataFrame as a string.

        Args:
            n (int, optional): Number of rows to return. Defaults to 5.

        Returns:
            str: String representation of the first n rows of the DataFrame.
        """
        return self.data.head(n).to_string()


In [None]:
# Annotation CSV for AIGCIQA2023 (relative path, resolved against the current working directory).
annotations_file = "../../Dataset/AIGCIQA2023/AIGCIQA_2023.csv"


# Definitions of the models

In [None]:
class AuthenticityPredictor(nn.Module):
    """
    Predicts an image authenticity score with a Barlow Twins ResNet-50 backbone.

    The backbone is fetched through ``torch.hub`` (network access on first use).
    Its convolutional trunk feeds a global average pool and a small MLP
    regression head that maps the 2048-d pooled features to one score.

    Args:
        freeze_backbone (bool): If True, backbone parameters receive no
            gradients and the backbone is kept in eval mode (see ``train``),
            so only the regression head is trained. Defaults to True.
    """

    def __init__(self, freeze_backbone=True):
        super().__init__()
        self.freeze_backbone = freeze_backbone

        barlow_twins_resnet = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')

        if freeze_backbone:
            for param in barlow_twins_resnet.parameters():
                param.requires_grad = False

        # Keep every child except the last two (in torchvision's ResNet these
        # are avgpool and fc); avgpool is re-attached separately below.
        self.features = nn.Sequential(*list(barlow_twins_resnet.children())[:-2])
        self.avgpool = barlow_twins_resnet.avgpool

        self.regression_head = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
        )

        # Apply the frozen-backbone mode policy right away, so a freshly built
        # model already matches what train()/eval() enforce.
        self.train(self.training)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Batch of preprocessed images.

        Returns:
            tuple: ``(predictions, features)``. ``predictions`` has shape
            ``(batch, 1)``. ``features`` is the flattened pooled backbone
            output that was fed to the regression head.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        predictions = self.regression_head(x)
        return predictions, x

    def train(self, mode=True):
        """
        Set train/eval mode, keeping a frozen backbone permanently in eval mode.

        Setting ``requires_grad=False`` stops gradient updates, but BatchNorm
        layers in train mode would still normalize with batch statistics and
        overwrite their running mean/var, so the "frozen" features would drift.
        Forcing the backbone to eval mode makes it truly frozen. The head's
        Dropout still follows ``mode``.

        Returns:
            AuthenticityPredictor: ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.features.eval()
        return self
    

# Training and evaluation functions

In [None]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, device='cuda'):
    """
    Trains the model, reporting the mean per-sample loss of each phase every epoch.

    Args:
        model (nn.Module): The model to train; its forward must return
            ``(predictions, features)``.
        dataloaders (dict): Must contain ``'train'`` and ``'val'`` loaders
            yielding ``(inputs, label, std)`` batches; other keys are ignored.
        criterion (callable): Loss called as ``criterion(output, label, std)``.
        optimizer (optim.Optimizer): The optimizer.
        num_epochs (int): Number of epochs to train for. Defaults to 10.
        device (str): Device to use for training ('cuda' or 'cpu'). Defaults to 'cuda'.

    Returns:
        nn.Module: The trained model (the same object, updated in place).
    """
    model.to(device)
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            print(f'{phase} phase')
            is_train = phase == 'train'
            # train(False) is exactly what eval() does.
            model.train(is_train)

            running_loss = 0.0
            num_samples = 0

            for inputs, label, std in dataloaders[phase]:
                inputs = inputs.to(device)
                label = label.to(device)
                std = std.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Build the autograd graph only when we will backpropagate.
                with torch.set_grad_enabled(is_train):
                    output, _ = model(inputs)
                    loss = criterion(output, label, std)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                num_samples += batch_size

            # Average over the samples actually processed: stays correct with
            # drop_last=True and avoids ZeroDivisionError on an empty split.
            epoch_loss = running_loss / num_samples if num_samples else float('nan')

            print(f'{phase} Loss: {epoch_loss:.4f}')

    print("Finished Training")
    return model

def test_model(model, dataloader, criterion, device='cuda'): 
    model.eval()
    model.to(device)
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        # Ensure the dataloader yields three items: inputs, label, std
        for inputs, label, std_dev in dataloader: # Unpack std_dev
            inputs = inputs.to(device)
            label = label.to(device)
            std_dev = std_dev.to(device) # Move std_dev to device

            output, _ = model(inputs)

            # Pass std_dev to the criterion if it's WeightedMSELoss
            # If criterion can be other types, you might need to check its type
            # or ensure all used criteria have a consistent interface or signature.
            # For this specific notebook context, criterion IS WeightedMSELoss.
            loss = criterion(output, label, std_dev)

            running_loss += loss.item() * inputs.size(0)

    test_loss = running_loss / len(dataloader.dataset)
    print(f'Test Loss: {test_loss:.4f}')
    return test_loss

# Handling annotation discordance

In [None]:
# --- Weight Creation Functions ---
def scale_stsd(sds, factor=20):
    """
    Multiply standard-deviation values by a constant factor.

    Useful to move SDs into a range where the weighting functions produce a
    meaningful spread of weights.

    Args:
        sds (torch.Tensor or array-like): Per-sample standard deviations.
            Non-tensor input is converted to a float32 tensor.
        factor (float): Multiplier applied to every SD. Tune it to the
            distribution of your SDs.

    Returns:
        torch.Tensor: ``sds * factor``.
    """
    as_tensor = sds if torch.is_tensor(sds) else torch.tensor(sds, dtype=torch.float32)
    return as_tensor * factor

def inverse_sd_weights(sds, epsilon=1e-6):
    """
    Weight each sample by the reciprocal of its standard deviation.

    The raw weights ``1 / (sd + epsilon)`` are rescaled so that they sum to the
    number of elements N. This keeps the weighted loss on the same scale as an
    unweighted one.

    Args:
        sds (torch.Tensor or array-like): Per-sample standard deviations.
            Non-tensor input is converted to a float32 tensor.
        epsilon (float): Added to every SD to avoid division by zero and to
            bound the weight given to near-zero SDs.

    Returns:
        torch.Tensor: Weights with the same shape as ``sds``. Empty input is
        returned as an empty tensor.
    """
    if not torch.is_tensor(sds):
        sds = torch.tensor(sds, dtype=torch.float32)

    raw_weights = 1.0 / (sds + epsilon)

    count = raw_weights.numel()
    if count == 0:
        return raw_weights
    return (raw_weights / torch.sum(raw_weights)) * count

def exponential_decay_weights(sds, alpha=1.0):
    """
    Calculates weights that decay exponentially with standard deviation.

    The unnormalized weight of a sample is ``exp(-alpha * sd)``. Weights are
    normalized over all elements so that they sum to N (the number of elements).

    Args:
        sds (torch.Tensor or array-like): Per-sample standard deviations.
            Non-tensor input is converted to a float32 tensor.
        alpha (float): Decay rate. Higher alpha means faster decay for higher
            SDs. Tune it to the typical range of your SDs.

    Returns:
        torch.Tensor: Weights with the same shape as ``sds``. They stay finite
        even when every ``exp(-alpha * sd)`` would underflow to 0.
    """
    if not isinstance(sds, torch.Tensor):
        sds = torch.tensor(sds, dtype=torch.float32)

    # Log of the unnormalized weight exp(-alpha * sd).
    logits = -alpha * sds
    if not torch.is_floating_point(logits):
        logits = logits.float()
    if logits.numel() == 0:
        return logits

    # softmax(logits) * N == exp(logits) / sum(exp(logits)) * N, but softmax
    # subtracts the max first. Large SDs therefore no longer underflow every
    # weight to 0, which would turn the normalization into 0/0 = NaN.
    weights = torch.softmax(logits.reshape(-1), dim=0).reshape(logits.shape)
    return weights * logits.numel()

def gaussian_weights(sds, sigma_g=1.0):
    """
    Calculates weights using a Gaussian function of the SD, centered at SD=0.

    The unnormalized weight of a sample is ``exp(-sd**2 / (2 * sigma_g**2))``.
    Weights are normalized over all elements so that they sum to N (the number
    of elements).

    Args:
        sds (torch.Tensor or array-like): Per-sample standard deviations.
            Non-tensor input is converted to a float32 tensor.
        sigma_g (float): Width of the Gaussian curve. A smaller sigma_g gives a
            sharper decline in weight as SD moves away from 0. Setting it near
            the mean SD of the data is a reasonable starting point.

    Returns:
        torch.Tensor: Weights with the same shape as ``sds``. They stay finite
        even when every Gaussian term would underflow to 0.
    """
    if not isinstance(sds, torch.Tensor):
        sds = torch.tensor(sds, dtype=torch.float32)

    # Log of the unnormalized Gaussian weight.
    logits = -(sds**2) / (2 * sigma_g**2)
    if not torch.is_floating_point(logits):
        logits = logits.float()
    if logits.numel() == 0:
        return logits

    # softmax(logits) * N == exp(logits) / sum(exp(logits)) * N, computed
    # stably. With SDs pre-scaled by 20, exp(-sd^2/2) underflows float32 for
    # every sample, and the naive normalization would return NaN.
    weights = torch.softmax(logits.reshape(-1), dim=0).reshape(logits.shape)
    return weights * logits.numel()

# --- Custom Weighted Loss Function ---

class WeightedMSELoss(nn.Module):
    """
    Mean squared error with per-sample weights derived from annotation SDs.

    Each sample's standard deviation is scaled by ``scale_stsd`` and turned
    into a weight by one of the weighting schemes. The loss is then
    ``reduce(weights * (predictions - targets) ** 2)``.
    """
    def __init__(self, reduction='mean'):
        """
        Args:
            reduction (str): Specifies the reduction to apply to the output:
                             'none' | 'mean' | 'sum'.
                             'none': no reduction will be applied.
                             'mean': the sum of the output will be divided by the number of elements in the output.
                             'sum': the output will be summed.
                             Default: 'mean'

        Raises:
            ValueError: If ``reduction`` is not one of the values above.
        """
        super().__init__()
        if reduction not in ['mean', 'sum', 'none']:
            raise ValueError(f"reduction must be one of 'mean', 'sum', or 'none', but got {reduction}")
        self.reduction = reduction

    def compute_weights(self, sds, method='inverse_sd', **kwargs):
        """
        Computes per-sample weights from standard deviations.

        Args:
            sds (torch.Tensor): Standard deviation values for each sample.
            method (str): 'inverse_sd', 'exponential_decay' or 'gaussian'.
            **kwargs: ``factor`` is forwarded to ``scale_stsd``. Every other
                keyword (``epsilon``, ``alpha`` or ``sigma_g``) is forwarded to
                the selected weighting function.

        Returns:
            torch.Tensor: Computed weights.

        Raises:
            ValueError: If ``method`` is unknown.
        """
        weight_fns = {
            'inverse_sd': inverse_sd_weights,
            'exponential_decay': exponential_decay_weights,
            'gaussian': gaussian_weights,
        }

        # scale_stsd only accepts `factor`, and the weighting functions reject
        # it. Route each keyword to the function that understands it, so that
        # passing any keyword does not raise TypeError.
        scale_kwargs = {'factor': kwargs.pop('factor')} if 'factor' in kwargs else {}
        sds = scale_stsd(sds, **scale_kwargs)

        try:
            weight_fn = weight_fns[method]
        except KeyError:
            raise ValueError(f"Unknown weight computation method: {method}") from None
        return weight_fn(sds, **kwargs)

    def forward(self, predictions, targets, sample_std, method='inverse_sd', **kwargs):
        """
        Args:
            predictions (torch.Tensor): Model predictions.
            targets (torch.Tensor): Ground truth values.
            sample_std (torch.Tensor): Standard deviation values for each sample.
            method (str): Method to compute weights. Options: 'inverse_sd', 'exponential_decay', 'gaussian'.
            **kwargs: Forwarded to ``compute_weights`` (``factor`` plus the
                weighting function's own parameters).

        Returns:
            torch.Tensor: Weighted loss, reduced according to ``self.reduction``.
        """
        weights = self.compute_weights(sample_std, method=method, **kwargs)

        squared_diff = (predictions - targets) ** 2
        weighted_squared_diff = weights * squared_diff

        if self.reduction == 'mean':
            return torch.mean(weighted_squared_diff)
        elif self.reduction == 'sum':
            return torch.sum(weighted_squared_diff)
        else:
            return weighted_squared_diff

In [None]:
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing: resize the shorter side to 256, center-crop 224x224, convert to
# a tensor, and normalize with the standard ImageNet channel mean/std.
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Same annotation CSV path as set in an earlier cell; redefined here.
annotations_file = '../../Dataset/AIGCIQA2023/AIGCIQA_2023.csv'

# Create the dataset
dataset = ImageAuthenticityDataset(csv_file=annotations_file, transform=data_transforms)

# Seed the RNGs BEFORE random_split so the train/val/test split is reproducible.
# Do not reorder these statements: moving the seeding (or consuming RNG before
# the split) changes which samples land in each subset.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)

# Split 70% train / 20% val / remainder (~10%) test.
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])


# Create data loaders (only the training loader shuffles).
BATCH_SIZE = 64 # Number of samples per batch
EPOCHS = 20

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)


# Loaders keyed by phase. train_model iterates only 'train' and 'val';
# 'test' is kept here for the separate evaluation step.
dataloaders = {
    'train': train_dataloader,
    'val': val_dataloader,
    'test': test_dataloader
}

# Default freeze_backbone=True: only the regression head is trainable.
model = AuthenticityPredictor()
# Training loss: MSE weighted by each sample's annotation SD.
criterion = WeightedMSELoss(reduction='mean')
# Evaluation loss: MAE. nn.L1Loss is called as (input, target) and does not
# accept the per-sample std argument.
test_criterion = nn.L1Loss()
# Optimize only the regression head's parameters.
optimizer = optim.Adam(model.regression_head.parameters(), lr=0.001)

# Train the model
# model = train_model(model, dataloaders, criterion, optimizer, num_epochs=EPOCHS, device=device)

# Save the model
# torch.save(model.state_dict(), 'Weights/BarlowTwins_SD_weight_real_authenticity_finetuned.pth')

