# Import libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights 
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
from PIL import Image
import os
import numpy as np

# Dataset creation using the PyTorch Dataset class

In [None]:
class ImageAuthenticityDataset(Dataset):
    """Dataset that pairs each image listed in a CSV file with its authenticity score."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the CSV file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.dir_path = os.path.dirname(csv_file)  # Directory of the CSV file

    def __len__(self):
        """Returns the number of samples (CSV rows) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves an image and its label by index.

        Args:
            idx (int or torch.Tensor): Index of the sample to retrieve. A tensor
                index is converted with ``tolist()`` before the lookup.

        Returns:
            tuple: A tuple (image, labels) where:
                image: The RGB image, passed through ``self.transform`` when a
                    transform was given (otherwise a PIL.Image).
                labels (torch.Tensor): Float tensor with a single element, the
                    authenticity score of the sample.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # TODO: to be fixed, right now is folder dependent
        # Column 3 holds the image path; "./" is rewritten to "../../".
        img_name = self.data.iloc[idx, 3].replace("./", "../../")

        # Close the file handle deterministically: convert() returns a new,
        # fully loaded image, so nothing refers to the open file afterwards.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        authenticity = self.data.iloc[idx, 1]  # Authenticity column
        labels = torch.tensor([authenticity], dtype=torch.float)

        if self.transform:
            image = self.transform(image)

        return image, labels


# Definition of the model

In [None]:
class AuthenticityPredictor(nn.Module):
    """EfficientNet-B3 feature extractor followed by an MLP that regresses one score.

    Args:
        freeze_backbone (bool): When True, the parameters of the backbone's
            ``features`` module are excluded from gradient computation.
    """

    def __init__(self, freeze_backbone=True):
        super().__init__()
        backbone = efficientnet_b3(weights=EfficientNet_B3_Weights.DEFAULT)

        # Optionally keep the pretrained convolutional stack fixed
        if freeze_backbone:
            backbone.features.requires_grad_(False)

        # Reuse the backbone's convolutional stack and pooling layer; its
        # original classifier is not used.
        self.features = backbone.features
        self.avgpool = backbone.avgpool

        # Regression head: Linear -> ReLU -> Dropout for each consecutive pair
        # of widths, then a final projection to a single authenticity value.
        widths = (1536, 1024, 512, 128)
        head_layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            head_layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.5)]
        head_layers.append(nn.Linear(widths[-1], 1))
        self.regression_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return ``(predictions, features)`` for the input batch ``x``.

        ``features`` is the pooled backbone output flattened from dimension 1
        onwards; ``predictions`` is the regression head applied to it.
        """
        pooled = self.avgpool(self.features(x))
        features = torch.flatten(pooled, 1)
        return self.regression_head(features), features

# Setup 

In [None]:
# Run on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet-style preprocessing: resize, centre-crop to 300 px, convert to a
# tensor and normalise each channel with the mean/std values below
data_transforms = transforms.Compose(
    [
        transforms.Resize(320),
        transforms.CenterCrop(300),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

annotations_file = '../../Dataset/AIGCIQA2023/real_images_annotations.csv'

# Build the full dataset from the annotation CSV
dataset = ImageAuthenticityDataset(csv_file=annotations_file, transform=data_transforms)

# Fix the random seeds so the split below is reproducible
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)

# 70 % train / 20 % validation / remainder test
n_samples = len(dataset)
train_size = int(0.7 * n_samples)
val_size = int(0.2 * n_samples)
test_size = n_samples - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

# One image per batch: the importance computation works on individual images
BATCH_SIZE = 1

# Build the three loaders in one pass; only the training loader is shuffled
dataloaders = {
    split_name: DataLoader(
        subset,
        batch_size=BATCH_SIZE,
        shuffle=(split_name == 'train'),
        num_workers=0,
    )
    for split_name, subset in (
        ('train', train_dataset),
        ('val', val_dataset),
        ('test', test_dataset),
    )
}

# Keep the individual loader names available alongside the dictionary
train_dataloader = dataloaders['train']
val_dataloader = dataloaders['val']
test_dataloader = dataloaders['test']


## Compute, for each image, the importance scores of the last conv layer's channels

In [None]:
import tqdm
from tqdm.notebook import tqdm  # Use tqdm.notebook for Jupyter notebooks

def compute_obj_x_obj_feature_map_importance(
    model,
    dataloader,
    device,
    layer_name,
    save_path='Ranking_arrays/Dual_scores_obj_x_obj_authenticity_importance_scores.npy',
):
    """Computes, per image, the importance of each output channel of a convolution layer.

    For every image the model is first evaluated unchanged. Then, one output
    channel at a time, the layer's weights (and bias, if any) for that channel
    are set to zero, the model is evaluated again, and two deltas are recorded:
    the change in absolute prediction error and the change in the prediction
    itself. The original weights are restored after every channel, even if the
    forward pass raises or the run is interrupted.

    Args:
        model (nn.Module): The trained model. Its forward pass must return a
            tuple whose first element is the prediction.
        dataloader (DataLoader): Dataloader yielding (inputs, labels) pairs,
            expected to use batch_size=1.
        device (torch.device or str): Device to run computation on.
        layer_name (str): Name of the layer to analyze, as reported by
            ``model.named_modules()``. The layer must expose ``out_channels``
            and ``weight``.
        save_path (str): File the results are cached in. If it already exists
            the scores are loaded from it and nothing is recomputed, whatever
            model, dataloader or layer is passed in. A missing ``.npy``
            suffix is appended.

    Returns:
        numpy.ndarray: Object array of shape [num_images]. Element ``i`` is a
            float array of shape [2, num_channels]: row 0 holds the change in
            absolute error, row 1 the change in prediction, for image ``i``.

    Raises:
        ValueError: If ``layer_name`` does not name a module of ``model``.
    """
    # np.save appends ".npy" when it is missing; normalise up front so the
    # cache check and the save refer to the same file.
    if not save_path.endswith('.npy'):
        save_path += '.npy'

    # Check if importance scores are already computed
    if os.path.exists(save_path):
        print("Per-object importance scores already computed, loading from file")
        # NOTE: allow_pickle is required for the object array written below;
        # only load files produced by this function.
        return np.load(save_path, allow_pickle=True)

    model.eval()
    model.to(device)

    layer = dict(model.named_modules()).get(layer_name)
    if layer is None:
        raise ValueError(f"Layer {layer_name} not found in model")

    num_channels = layer.out_channels
    has_bias = layer.bias is not None

    # Make sure batch_size=1 in the dataloader
    if dataloader.batch_size != 1:
        print("Warning: Dataloader batch size should be 1 for per-object importance computation")

    # Create the directory for saving results if it doesn't exist
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # One entry per image; object dtype because each entry is itself an array
    num_images = len(dataloader)
    importance_array = np.empty(num_images, dtype=object)

    # Process each image individually with a progress bar
    for img_idx, (inputs, labels) in enumerate(tqdm(dataloader, desc="Processing images")):
        inputs = inputs.to(device)
        y_true = labels.to(device)

        # Baseline prediction and absolute error with the unmodified layer
        with torch.no_grad():
            baseline_outputs, _ = model(inputs)
            y_pred = baseline_outputs.item()
            residual_error = torch.abs(baseline_outputs - y_true).item()

        # Per-channel scores for this image
        channel_residual_scores = np.zeros(num_channels)
        pred_residual_scores = np.zeros(num_channels)

        for channel_idx in tqdm(range(num_channels), desc=f"Channels for image {img_idx}", leave=False):
            # Back up the channel's parameters without recording autograd history
            with torch.no_grad():
                backup_weights = layer.weight[channel_idx, ...].clone()
                backup_bias = layer.bias[channel_idx].clone() if has_bias else None

            try:
                with torch.no_grad():
                    # Zero out the channel_idx-th output channel
                    layer.weight[channel_idx, ...] = 0
                    if has_bias:
                        layer.bias[channel_idx] = 0

                    # Prediction with the channel removed
                    pruned_outputs, _ = model(inputs)
                    y_pred_pruned = pruned_outputs.item()
                    residual_error_pruned = torch.abs(pruned_outputs - y_true).item()
            finally:
                # Always restore, so a failure or interrupt cannot leave the
                # model with a permanently zeroed channel.
                with torch.no_grad():
                    layer.weight[channel_idx, ...] = backup_weights
                    if has_bias:
                        layer.bias[channel_idx] = backup_bias

            # How the prediction error changes when the feature map is removed
            channel_residual_scores[channel_idx] = residual_error_pruned - residual_error
            # How the predicted score changes when the feature map is removed
            pred_residual_scores[channel_idx] = y_pred_pruned - y_pred

        # Store both score rows for this image
        importance_array[img_idx] = np.array([channel_residual_scores, pred_residual_scores])

    # Save results as a numpy array
    np.save(save_path, importance_array)
    print(f"Saved importance scores for {num_images} images to '{save_path}'")

    return importance_array

In [None]:
# Name of the module whose output channels are ablated one at a time
# (passed as layer_name to compute_obj_x_obj_feature_map_importance)
LAYER = 'features.8.0'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Load the model
model = AuthenticityPredictor()
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint)
# onto DEVICE, so loading also works on a CPU-only machine.
state_dict = torch.load('Weights/EfficentNetB3_real_authenticity_finetuned.pth', map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()
model.to(DEVICE)

# Compute per-object feature importance
obj_x_obj_importance = compute_obj_x_obj_feature_map_importance(model, dataloaders['test'], DEVICE, LAYER)



