# Imports

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.models import VGG16_Weights
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

# Constants

In [None]:

# Position inside `vgg16().features` after which the channel mask is applied.
# NOTE(review): in torchvision's VGG16 the last Conv2d appears to sit at index 28 and
# index 29 is its ReLU, so the mask acts on the post-activation maps - confirm against
# the torchvision layer listing.
LAST_CONV_IDX = 29
NUM_CHANNELS = 512  # Channel count of the masked feature maps; also sizes the first FC layer (NUM_CHANNELS * 7 * 7)
NUM_CLASSES = 2  # Width of the model's final output; the dataset labels hold two scores (quality, authenticity)

# Dataset class definition

In [None]:
class ImageQualityDataset(Dataset):
    """Image / score pairs backed by a CSV annotation table.

    Columns are addressed by position, not by name: column 0 is read as the
    quality score, column 1 as the authenticity score and column 3 as the
    image path, which is resolved relative to the current working directory.
    """

    # Positional layout of the annotation table.
    _QUALITY_COL = 0
    _AUTHENTICITY_COL = 1
    _IMAGE_PATH_COL = 3

    def __init__(self, csv_file, transform=None):
        """
        Load the annotation table.

        Args:
            csv_file (string): Path to the CSV file with annotations.
            transform (callable, optional): Applied to every loaded image before it is returned.
        """
        self.transform = transform
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        """Number of rows in the annotation table."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx (int): Row index of the sample. A tensor index is converted with ``tolist()``.

        Returns:
            tuple: ``(image, labels)`` where ``image`` is the RGB ``PIL.Image`` (or whatever
                ``transform`` returns when one was given) and ``labels`` is a float tensor
                ``[quality, authenticity]``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        relative_path = self.data.iloc[idx, self._IMAGE_PATH_COL]
        absolute_path = os.path.join(os.getcwd(), relative_path)
        image = Image.open(absolute_path).convert('RGB')

        scores = [
            self.data.iloc[idx, self._QUALITY_COL],
            self.data.iloc[idx, self._AUTHENTICITY_COL],
        ]
        labels = torch.tensor(scores, dtype=torch.float)

        if self.transform:
            image = self.transform(image)

        return image, labels

# Model definition

In [None]:
class VGG16_Feature_Extractor(nn.Module):
    """Frozen VGG16 convolutional stack with a new fully connected head.

    The forward pass can zero selected channels of the feature maps produced at
    ``last_conv_idx`` and exposes the activations of every head stage so they
    can be analysed separately.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        """
        Build the frozen backbone and the trainable head.

        Args:
            num_classes (int): Width of the final linear layer.
        """
        super().__init__()

        # Pre-trained convolutional stack; its weights are never updated.
        self.features = models.vgg16(weights=VGG16_Weights.DEFAULT).features
        self.features.requires_grad_(False)

        # Position in `self.features` after which the channel mask is applied.
        self.last_conv_idx = LAST_CONV_IDX

        # Head: three Linear -> ReLU -> Dropout stages followed by the output layer.
        # The first stage expects 7x7 feature maps with NUM_CHANNELS channels.
        self.fc1 = self._fc_stage(NUM_CHANNELS * 7 * 7, 4096)
        self.fc2 = self._fc_stage(4096, 4096)
        self.fc3 = self._fc_stage(4096, 1000)
        self.final = nn.Linear(1000, num_classes)

    @staticmethod
    def _fc_stage(in_features, out_features):
        """Return one Linear -> in-place ReLU -> Dropout stage of the head."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(inplace=True),
            nn.Dropout(),
        )

    def forward(self, x, conv_masks=None):
        """
        Run the model, optionally masking channels at ``last_conv_idx``.

        Args:
            x (torch.Tensor): Input batch.
            conv_masks (torch.Tensor, optional): Per-channel multiplier, reshaped to
                ``(1, C, 1, 1)`` and multiplied into the feature maps (1 keeps a
                channel, 0 prunes it). ``None`` leaves the maps untouched.

        Returns:
            tuple: ``(last_conv, fc1_out, fc2_out, fc3_out, output)``. Note that
                ``last_conv`` is the output of the *entire* ``features`` stack, i.e. it
                includes any layers that follow the masked position.
        """
        split = self.last_conv_idx + 1

        # Backbone layers up to and including the masked position.
        for layer_idx in range(split):
            x = self.features[layer_idx](x)

        if conv_masks is not None:
            # Broadcast the per-channel mask over batch and spatial dimensions.
            channel_mask = conv_masks.view(1, -1, 1, 1).to(x.device)
            x = x * channel_mask

        # Remaining backbone layers.
        for layer_idx in range(split, len(self.features)):
            x = self.features[layer_idx](x)

        last_conv = x

        # Head, keeping every intermediate activation.
        fc1_out = self.fc1(last_conv.view(last_conv.size(0), -1))
        fc2_out = self.fc2(fc1_out)
        fc3_out = self.fc3(fc2_out)
        output = self.final(fc3_out)

        return last_conv, fc1_out, fc2_out, fc3_out, output


# Feature extraction functions

In [None]:
def extract_features(model, dataloader, channels_to_prune=None):
    """
    Run the model over a dataloader and collect the activations of every stage.

    Args:
        model: Model whose forward pass accepts ``(inputs, conv_masks)`` and returns the five
            tensors ``(last_conv, fc1, fc2, fc3, output)``
        dataloader: Iterable yielding ``(inputs, targets)`` batches
        channels_to_prune: Indices of channels to zero in the mask (optional; ``None`` keeps all)

    Returns:
        Dictionary with keys 'last_conv', 'fc1', 'fc2', 'fc3', 'output' and 'labels', each a
        numpy array concatenated over all batches
    """
    # Mask convention: 1 = keep the channel, 0 = prune it.
    conv_masks = torch.ones(NUM_CHANNELS)
    if channels_to_prune is not None:
        conv_masks[channels_to_prune] = 0

    stage_names = ('last_conv', 'fc1', 'fc2', 'fc3', 'output')
    collected = {stage: [] for stage in stage_names}
    collected['labels'] = []

    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Extracting features"):
            # Inputs follow the device of the first backbone layer.
            inputs = inputs.to(model.features[0].weight.device)

            last_conv, fc1, fc2, fc3, output = model(inputs, conv_masks)

            for stage, activation in zip(stage_names, (last_conv, fc1, fc2, fc3, output)):
                collected[stage].append(activation.cpu().numpy())
            collected['labels'].append(targets.numpy())

    # Stitch the per-batch arrays together along the sample axis.
    return {key: np.concatenate(chunks) for key, chunks in collected.items()}

# Cell 6: Similarity and Correlation Analysis Functions
def compute_similarity_matrix(features):
    """
    Pairwise cosine similarity between the feature vectors of all samples.

    Args:
        features: numpy array, either 4D ``(n_samples, n_channels, height, width)`` or
                 2D ``(n_samples, n_features)``

    Returns:
        similarity_matrix: numpy array of shape (n_samples, n_samples)
    """
    sample_count = features.shape[0]

    # 4D (convolutional) features are flattened to one vector per sample; anything else
    # is handed to cosine_similarity unchanged.
    is_conv_output = features.ndim == 4
    embeddings = features.reshape(sample_count, -1) if is_conv_output else features

    return cosine_similarity(embeddings)

def compute_quality_difference_matrix(quality_scores):
    """
    Compute the matrix of absolute quality differences between all pairs of samples.

    Args:
        quality_scores: array-like of shape (n_samples,) containing quality scores
            (a column of shape (n_samples, 1) is accepted as well)

    Returns:
        difference_matrix: float64 numpy array of shape (n_samples, n_samples) whose entry
            ``[i, j]`` is ``abs(quality_scores[i] - quality_scores[j])``
    """
    scores = np.asarray(quality_scores)
    n_samples = scores.shape[0]
    # Collapse to one score per sample; raises ValueError if a sample carries several values.
    column = scores.reshape(n_samples)

    # Broadcasting builds all n x n differences at once instead of a Python double loop.
    # The subtraction runs in the input dtype and is widened afterwards, so the values
    # match an element-by-element computation exactly.
    pairwise = np.abs(column[:, None] - column[None, :])
    return pairwise.astype(np.float64)

def get_upper_triangle(matrix):
    """
    Return the entries strictly above the main diagonal as a flat array.

    Args:
        matrix: 2D numpy array

    Returns:
        upper_triangle: 1D array of the above-diagonal values in row-major order
    """
    # k=1 shifts the triangle one step up, which leaves the diagonal out.
    rows, cols = np.triu_indices_from(matrix, k=1)
    return matrix[rows, cols]

def calculate_fit(similarity_matrix, quality_diff_matrix):
    """
    Score how well feature similarity tracks quality difference.

    Only the entries above the diagonal of each matrix are compared, so every
    pair of samples is counted once and self-pairs are ignored.

    Args:
        similarity_matrix: numpy array of shape (n_samples, n_samples)
        quality_diff_matrix: numpy array of shape (n_samples, n_samples)

    Returns:
        fit: the *negated* Spearman correlation between the two sets of pairwise values
        p_value: p-value reported by ``spearmanr`` for that correlation
    """
    pairwise_similarity = get_upper_triangle(similarity_matrix)
    pairwise_difference = get_upper_triangle(quality_diff_matrix)

    rho, p_value = spearmanr(pairwise_similarity, pairwise_difference)

    # Similar features should go with small quality differences, i.e. a negative rho.
    # The sign is flipped so that a larger returned value means a better fit.
    fit = -rho
    return fit, p_value

def extract_quality_scores(labels):
    """
    Select the quality column of a label array.

    Args:
        labels: numpy array of shape (n_samples, 2); column 0 is taken to be the quality score

    Returns:
        quality_scores: numpy array of shape (n_samples,)
    """
    quality_column = 0
    return labels[:, quality_column]

# Channel pruning functions

In [None]:
def compute_channel_impact(model, dataloader, num_channels=NUM_CHANNELS):
    """
    Measure how pruning each single channel changes the fc2 fit.

    A cached result is returned straight from disk when
    ``Ranking_arrays/sim_matrix_channel_importance.npy`` exists; this function itself
    never writes that file.

    Args:
        model: PyTorch model to analyze
        dataloader: DataLoader containing the data. The quality-difference matrix is built
            once from the unpruned pass and reused for every pruned pass.
            NOTE(review): this presumes the loader yields samples in the same order on
            every pass (i.e. no shuffling) - confirm at the call site.
        num_channels: Number of channels to test, pruned one at a time

    Returns:
        channel_impacts: numpy array of shape (num_channels, 2) with rows
            ``[channel index, impact]`` sorted by ascending impact, where
            impact = pruned fit - baseline fit (positive: pruning the channel improved the fit)
    """
    print("Analyzing impact of pruning last_conv channels on fc2 features...")

    cache_file = 'Ranking_arrays/sim_matrix_channel_importance.npy'
    if os.path.exists(cache_file):
        print("Loading cached channel impacts...")
        # NOTE(review): allow_pickle=True can execute code from an untrusted file.
        return np.load(cache_file, allow_pickle=True)

    def fc2_fit(features, target_diff):
        """Fit between fc2 cosine similarity and the given quality-difference matrix."""
        fit, _ = calculate_fit(compute_similarity_matrix(features['fc2']), target_diff)
        return fit

    # Reference pass with every channel kept.
    reference = extract_features(model, dataloader)
    quality_diff = compute_quality_difference_matrix(
        extract_quality_scores(reference['labels'])
    )
    baseline_fit = fc2_fit(reference, quality_diff)

    print(f"Baseline fit (no pruning): {baseline_fit:.4f}")

    impact_rows = []
    for channel_idx in tqdm(range(num_channels), desc="Pruning channels"):
        # One full pass over the data with only this channel zeroed.
        ablated = extract_features(model, dataloader, channels_to_prune=[channel_idx])

        pruned_fit = fc2_fit(ablated, quality_diff)
        print(f"Pruned channel {channel_idx}: {pruned_fit:.4f}")

        # Positive impact: the fit got better without this channel.
        channel_impact = pruned_fit - baseline_fit
        impact_rows.append([channel_idx, channel_impact])
        print(f"Channel {channel_idx} impact: {channel_impact:.4f}")
        print("--------------------------------------------------")

    print("\nChannel Pruning Impact Analysis Results:")
    print(f"Baseline fit (no pruning): {baseline_fit:.4f}")

    # Order the rows by impact, lowest first.
    impact_table = np.array(impact_rows)
    ascending = np.argsort(impact_table[:, 1])
    return impact_table[ascending]

def evaluate_multi_channel_pruning(model, dataloader, channels_to_prune):
    """
    Compute the fc2 fit obtained when a set of channels is pruned together.

    Labels and features are taken from the same pass over the dataloader, so the
    result does not depend on the order in which samples are yielded.

    Args:
        model: PyTorch model to analyze
        dataloader: DataLoader containing the data
        channels_to_prune: Channel indices to prune; ``None`` evaluates the unpruned model

    Returns:
        fc2_fit: Fit of fc2 features after pruning
    """
    features = extract_features(model, dataloader, channels_to_prune=channels_to_prune)

    target_diff = compute_quality_difference_matrix(
        extract_quality_scores(features['labels'])
    )
    fc2_fit, _ = calculate_fit(compute_similarity_matrix(features['fc2']), target_diff)

    return fc2_fit

def find_optimal_pruning(model, dataloader, channel_impacts, method='greedy', original_fit=None):
    """
    Choose a set of channels to prune so that the fc2 fit improves.

    Args:
        model: PyTorch model to analyze
        dataloader: DataLoader containing the data
        channel_impacts: Per-channel impact scores, in one of two layouts:
            * 1-D array where position ``i`` holds the impact of channel ``i``. Do NOT pass
              a column of an impact-sorted table here - positions would no longer be
              channel indices.
            * 2-D array of shape (n, 2) with rows ``[channel index, impact]`` in any order,
              as returned by ``compute_channel_impact``.
        method: Pruning strategy
            * 'threshold': prune every channel whose impact is > 0.
            * 'greedy': visit channels by descending impact and keep each one whose
              addition to the prune set strictly improves the fit. Every channel is
              tried (one evaluation pass each); the search does not stop early.
        original_fit: Baseline fit with no pruning. When ``None`` it is measured with one
            pass over ``dataloader``.

    Returns:
        optimal_channels: List of channel indices (Python ints) to prune
        optimal_fit: Fit achieved after pruning (the baseline fit if nothing is pruned)
        fit_history: Fit after each accepted channel (greedy method; empty for threshold)

    Raises:
        ValueError: If ``method`` is unknown or ``channel_impacts`` has an unsupported shape.
    """
    # Fail fast, before any expensive pass over the data.
    if method not in ('greedy', 'threshold'):
        raise ValueError(f"Unknown method: {method}")

    impacts = np.asarray(channel_impacts)
    if impacts.ndim == 1:
        channel_ids = np.arange(impacts.shape[0])
        impact_scores = impacts
    elif impacts.ndim == 2 and impacts.shape[1] == 2:
        channel_ids = impacts[:, 0].astype(int)
        impact_scores = impacts[:, 1]
    else:
        raise ValueError(
            f"channel_impacts must be 1-D or of shape (n, 2), got shape {impacts.shape}"
        )

    if original_fit is None:
        # Unpruned reference fit.
        baseline_fit = evaluate_multi_channel_pruning(model, dataloader, channels_to_prune=None)
    else:
        # A caller-supplied baseline saves the reference pass.
        baseline_fit = original_fit

    print(f"Baseline fit (no pruning): {baseline_fit:.4f}")
    fit_history = []

    if method == 'threshold':
        # Prune all channels with positive impact
        channels_to_prune = channel_ids[impact_scores > 0].tolist()
        print(f"Threshold method selected {len(channels_to_prune)} channels to prune")

        if channels_to_prune:
            prune_fit = evaluate_multi_channel_pruning(model, dataloader, channels_to_prune)
        else:
            # Nothing selected: the model is unchanged, so its fit is the baseline.
            prune_fit = baseline_fit

    else:  # method == 'greedy'
        channels_to_prune = []
        best_fit = baseline_fit

        # Candidates ordered from most to least promising.
        ranked_channels = channel_ids[np.argsort(impact_scores)[::-1]].tolist()

        for channel in tqdm(ranked_channels, desc="Greedy pruning"):
            candidate_set = channels_to_prune + [channel]
            candidate_fit = evaluate_multi_channel_pruning(model, dataloader, candidate_set)

            # Keep the channel only if it strictly improves on the best fit so far.
            if candidate_fit > best_fit:
                print(f"Adding channel {channel} to prune list improved fit to {candidate_fit:.4f}")
                channels_to_prune = candidate_set
                best_fit = candidate_fit
                fit_history.append(candidate_fit)

        prune_fit = best_fit
        print(f"Greedy search found {len(channels_to_prune)} channels to prune")

    return channels_to_prune, prune_fit, fit_history


# Setup and initialization

In [None]:
def setup_environment():
    """
    Initialize device, transformations, and datasets.

    Returns:
        tuple: ``(device, analysis_dataloader, test_dataloader)`` where
            ``analysis_dataloader`` iterates the 80% split in a fixed order and
            ``test_dataloader`` iterates the remaining samples in shuffled order.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Resize / crop to 224x224 and normalize with fixed per-channel mean and std.
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Dataset setup
    ANNOTATIONS_FILE = 'Dataset/AIGCIQA2023/mos_data.csv'
    dataset = ImageQualityDataset(csv_file=ANNOTATIONS_FILE, transform=data_transforms)

    # Split the dataset. The test split takes whatever the train split leaves over:
    # truncating both 80% and 20% independently can add up to fewer samples than the
    # dataset holds, in which case random_split rejects the lengths.
    TRAIN_RATIO = 0.8
    train_size = int(TRAIN_RATIO * len(dataset))
    test_size = len(dataset) - train_size
    # NOTE(review): the split is not seeded, so it differs between runs; results cached
    # from an earlier run were computed on a different split.
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    # Create dataloaders
    BATCH_SIZE = 64
    NUM_WORKERS = 10
    analysis_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

    return device, analysis_dataloader, test_dataloader

# MAIN 

In [None]:
def main():
    """
    Build the data pipeline and the model, then rank channels by pruning impact.

    Returns:
        tuple: ``(model, analysis_dataloader, test_dataloader, channel_impacts)``
    """
    device, analysis_dataloader, test_dataloader = setup_environment()

    # Model lives on the device selected by setup_environment.
    model = VGG16_Feature_Extractor(num_classes=NUM_CLASSES)
    model = model.to(device)

    # Per-channel impact on the fc2 fit, measured on the analysis split.
    channel_impacts = compute_channel_impact(
        model,
        analysis_dataloader,
        num_channels=NUM_CHANNELS,
    )

    # The ranking is not persisted here. To cache it for later runs:
    # np.save('Ranking_arrays/sim_matrix_channel_importance.npy', channel_impacts)

    return model, analysis_dataloader, test_dataloader, channel_impacts

# Execute main function to initialize variables
model, analysis_dataloader, test_dataloader, channel_impacts = main()

# `channel_impacts` holds rows of (channel index, impact) sorted by impact, while
# find_optimal_pruning treats position i of a 1-D impact vector as channel i.
# Passing the sorted impact column directly would therefore select array positions
# instead of channels, so the scores are first put back into channel order.
impact_channel_ids = channel_impacts[:, 0].astype(int)
impacts_by_channel = np.zeros(NUM_CHANNELS)
impacts_by_channel[impact_channel_ids] = channel_impacts[:, 1]

# Cell 10: Threshold-based Pruning Evaluation
# Find channels to prune using threshold method
negative_impact_pruning_set, neg_impact_pruned_fit, _ = find_optimal_pruning(
    model,
    analysis_dataloader,
    impacts_by_channel,
    method='threshold'
)

print(f"Negative impact pruning set: {negative_impact_pruning_set}")
print(f"Negative impact pruning fit: {neg_impact_pruned_fit:.4f}")

# Save the negative impact pruning set (the target directory may not exist yet)
os.makedirs('Pruning_sets', exist_ok=True)
np.save('Pruning_sets/negative_impact_pruning_set.npy', negative_impact_pruning_set)

# Cell 11: Greedy Pruning Evaluation
# Find channels to prune using greedy method
# NOTE(review): the greedy set and its fit history are only printed, never saved,
# although later cells load 'RSA_greedy_search_*' files from disk - confirm where
# those files are produced.
greedy_search_pruning_set, greedy_pruned_fit, fit_history = find_optimal_pruning(
    model,
    analysis_dataloader,
    impacts_by_channel,
    method='greedy'
)

print(f"Greedy search pruning set: {greedy_search_pruning_set}")
print(f"Greedy search pruning fit: {greedy_pruned_fit:.4f}")
print(f"Greedy search fit history: {fit_history}")

# Cell 12: Model Evaluation with Different Pruning Strategies
# All three evaluations reuse `model`, the instance the channel impacts were measured on.
# Building a new VGG16_Feature_Extractor per evaluation would give each one different,
# randomly initialised fully connected layers, so the fits would not be comparable
# (and `device` is not defined at module level to place such models anyway).

# Evaluate base model (no pruning)
base_model_eval = evaluate_multi_channel_pruning(
    model,
    test_dataloader,
    channels_to_prune=None
)

# Evaluate model with negative impact pruning.
# The in-memory impact table is used; it equals the cached file whenever that file exists,
# and this notebook does not write the file itself.
# Pick the channel indices (column 0) whose impact (column 1) is greater than 0
negative_impact_channels_to_prune = channel_impacts[channel_impacts[:, 1] > 0][:, 0]
# Transform the float64 to int
negative_impact_channels_to_prune = negative_impact_channels_to_prune.astype(int)

neg_impact_pruning_eval = evaluate_multi_channel_pruning(
    model,
    test_dataloader,
    negative_impact_channels_to_prune
)

# Evaluate model with greedy pruning
# NOTE(review): this set is read from disk and is not written anywhere in this notebook;
# allow_pickle=True can execute code from an untrusted file.
greedy_impact_channels_to_prune = np.load('Pruning_sets/RSA_greedy_search_pruning_set.npy', allow_pickle=True)

greedy_impact_pruning_eval = evaluate_multi_channel_pruning(
    model,
    test_dataloader,
    greedy_impact_channels_to_prune
)

print(f"Base model fit: {base_model_eval:.4f}")
print(f"Negative impact pruning fit: {neg_impact_pruning_eval:.4f}")
print(f"Greedy search pruning fit: {greedy_impact_pruning_eval:.4f}")

# Cell 13: Visualization of Pruning Performance
# Fit values recorded during the greedy search, read back from disk
pruning_search_history = np.load('Fit_histories/RSA_greedy_search_fit_history.npy')

# One curve: fit score against the position in the recorded history
figure = plt.figure(figsize=(10, 6))
axes = figure.add_subplot(1, 1, 1)
axes.plot(pruning_search_history, label='Greedy Search')
axes.set_xlabel('Number of Channels Pruned')
axes.set_ylabel('Fit Score')
axes.set_title('Greedy Search Pruning Performance')
axes.legend()
axes.grid(True)
plt.show()