In [None]:
import os
from copy import deepcopy
from scipy.io import loadmat
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import Normalize, Compose, Resize, CenterCrop, ToTensor
from torchvision import utils as torch_utils
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torchvision
from tqdm.notebook import tqdm
from sklearn.decomposition import PCA

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError:  # Google Colab does not have PyTorch Lightning installed by default, so install it on demand
    import importlib
    import subprocess
    import sys

    # `python -m pip` with the kernel's own interpreter installs into the environment this
    # notebook actually runs in (a bare `pip` on PATH may belong to a different Python).
    # The argument list bypasses the shell, so the `>=` specifier is passed through literally.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "pytorch-lightning>=1.4"])
    # Refresh the import system's directory caches so the freshly installed package is found.
    importlib.invalidate_caches()
    import pytorch_lightning as pl

In [None]:
### Load and check the image data

PATH_TO_DATA = '../../data/selection1866'

file_1 = loadmat(os.path.join(PATH_TO_DATA, 'img1.mat'))
raw_img_1 = file_1['img']

plt.imshow(raw_img_1)
plt.title("Image 1 (trio) | Dims: {}".format(raw_img_1.shape))
plt.show()

# Leftmost 500 columns: the same slice the dataset-building cell takes from every image.
img_1_tile_1 = raw_img_1[:, :500]

plt.imshow(img_1_tile_1)
plt.title("Leftmost tile of image 1 | Dims: {}".format(img_1_tile_1.shape))
plt.show()

# Visualise the transformations we will apply (previewed here at 96x96)
transform = Compose([
    Resize(96),  # Resize shortest edge to 96
    CenterCrop((96, 96)),  # Crop to (96, 96)
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Map [0, 1] -> [-1, 1]
])

rgb_img = np.stack([img_1_tile_1] * 3, axis=-1)  # Replicate grayscale into 3 (RGB) channels
tensor = torch.tensor(rgb_img, dtype=torch.float32).permute(2, 0, 1)  # Shape (C, H, W)

# Min-max scale to [0, 1] exactly like the dataset-building cell, so the preview reflects the
# real preprocessing. A fixed mapping such as (x + 2) / 4 only works if the raw values happen
# to lie in [-2, 2] and silently clips everything else. A constant tile maps to zeros, not NaN.
tensor_min = tensor.min()
tensor_range = tensor.max() - tensor_min
if tensor_range > 0:
    tensor = (tensor - tensor_min) / tensor_range
else:
    tensor = torch.zeros_like(tensor)
tensor = torch.clamp(tensor, 0.0, 1.0)  # Guard against floating-point overshoot
processed_img = transform(tensor)  # Resize, crop, normalize

# Undo Normalize (x * 0.5 + 0.5) so imshow receives values in [0, 1]
plt.imshow((processed_img * 0.5 + 0.5).permute(1, 2, 0).clamp(0, 1).numpy())
plt.title("Processed Image | Dims: {}".format(processed_img.shape))
plt.axis('off')
plt.show()

In [None]:
### Preprocess images for SimCLR

file_list = sorted(f for f in os.listdir(PATH_TO_DATA) if f.endswith('.mat'))
if not file_list:
    raise FileNotFoundError(f"No .mat files found in {PATH_TO_DATA}")

# Prepare images for SimCLR; todo: STL10 is 96x96
transform = Compose([
    Resize(224),  # Resize shortest edge to 224 (aspect ratio preserved)
    CenterCrop((224, 224)),  # Keep the central 224x224 window; the longer edge is trimmed on both sides
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # !! Normalize expects input is already in the range [0, 1]
])

img_tensors, labels = [], []
for idx, filename in enumerate(file_list):
    data = loadmat(os.path.join(PATH_TO_DATA, filename))

    img = data['img'][:, :500]  # Take leftmost part of the image
    rgb_img = np.stack([img] * 3, axis=-1)  # Convert grayscale to RGB for SimCLR
    tensor = torch.tensor(rgb_img, dtype=torch.float32).permute(2, 0, 1)  # Shape (C, H, W)

    # Min-max scale the tensor to [0, 1]. A constant image would divide by zero and produce
    # NaNs that propagate into every downstream feature, so it is mapped to zeros instead.
    tensor_min = tensor.min()
    tensor_range = tensor.max() - tensor_min
    if tensor_range > 0:
        tensor = (tensor - tensor_min) / tensor_range
    else:
        tensor = torch.zeros_like(tensor)

    # Clamp to [0, 1] to ensure no outliers due to numerical precision
    tensor = torch.clamp(tensor, 0.0, 1.0)

    transformed_tensor = transform(tensor)  # Resize, crop and normalize for SimCLR
    img_tensors.append(transformed_tensor)
    labels.append(idx)  # Label = position of the file in the sorted list

# Stack once; `image_dataset` is kept as an alias of the same TensorDataset.
dataset = TensorDataset(torch.stack(img_tensors), torch.tensor(labels))
image_dataset = dataset

images, labels = dataset.tensors
print("Labels:", labels[:10])
print("Processed dataset shape:", images.shape)  # (N, 3, 224, 224)
print(f"Min pixel value (processed): {torch.min(images)}")
print(f"Max pixel value (processed): {torch.max(images)}")

# Show a sample of processed images
img_grid = torch_utils.make_grid(images[:12], nrow=6, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0).numpy()
plt.figure(figsize=(10, 5))
plt.title('Processed images: sample')
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

In [None]:
### Extract feature representations of our images from a pretrained SimCLR model

MODEL_CHECKPOINT_PATH = "../../models/tutorial17/SimCLR.ckpt"

# os.cpu_count() returns None when the CPU count cannot be determined; DataLoader needs an
# int, so fall back to 0 (load batches in the main process).
NUM_WORKERS = os.cpu_count() or 0
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
print("Number of workers:", NUM_WORKERS)

class SimCLR(pl.LightningModule):
    """SimCLR module used here only to rebuild a pretrained network from a checkpoint.

    The encoder f(.) is a torchvision ResNet-18. Its final linear layer is wrapped into the
    projection head g(.) = Linear -> ReLU -> Linear. Only ``__init__`` is defined; the
    notebook uses this class through ``load_from_checkpoint``.
    """

    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=500):
        super().__init__()

        # Store the constructor arguments in self.hparams (read back just below).
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'

        projection_width = 4 * hidden_dim

        # Base model f(.): the ResNet classifier is sized to emit `projection_width` features.
        backbone = torchvision.models.resnet18(num_classes=projection_width)

        # Projection head g(.): keep the ResNet's own linear layer as the first stage.
        backbone.fc = nn.Sequential(
            backbone.fc,  # Linear(ResNet output, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(projection_width, hidden_dim),
        )
        self.convnet = backbone

# Function to register hooks and capture outputs from intermediate layers
def register_hooks(model, layers, handles=None):
    """Attach forward hooks that record the detached output of each named submodule.

    Args:
        model: ``nn.Module`` whose submodules are looked up by their ``named_modules()`` name.
        layers: Iterable of submodule names (e.g. ``'layer1'``, ``'fc'``).
        handles: Optional list. If given, the ``RemovableHandle`` of every hook registered
            here is appended to it, so the caller can detach the hooks later with
            ``handle.remove()``. Defaults to ``None`` (handles are discarded).

    Returns:
        dict: Empty at first. After each forward pass of ``model`` it maps every name in
        ``layers`` to that submodule's latest output, detached from the autograd graph.

    Raises:
        KeyError: If a name in ``layers`` is not a submodule of ``model``.
    """
    features = {}
    # Build the name -> module lookup once instead of once per requested layer.
    modules_by_name = dict(model.named_modules())

    def make_hook(layer_name):
        # Closure factory: each hook is bound to its own layer_name (no late-binding surprise).
        def hook(module, inputs, output):
            features[layer_name] = output.detach()
        return hook

    for layer_name in layers:
        if layer_name not in modules_by_name:
            raise KeyError(f"Layer {layer_name!r} not found in model")
        handle = modules_by_name[layer_name].register_forward_hook(make_hook(layer_name))
        if handles is not None:
            handles.append(handle)

    return features

# Run the pretrained SimCLR model on the experiment images, and capture features from final layer and intermediate layers
@torch.no_grad()
def extract_simclr_features(model, dataset, layers_to_capture, batch_size=64):
    """Encode ``dataset`` with the SimCLR encoder and collect per-layer activations.

    The projection head g(.) is dropped (``fc`` is replaced by ``nn.Identity``), so the
    returned final features are the encoder output f(.).

    Args:
        model: Object with a ``convnet`` attribute (the SimCLR module). It is deep-copied,
            so the caller's model is not modified and no hooks are left on it.
        dataset: Dataset yielding ``(image, label)`` pairs.
        layers_to_capture: Names of ``convnet`` submodules to hook (e.g. ``'layer1'``, ``'fc'``).
        batch_size: Images per forward pass (default 64).

    Returns:
        tuple: ``(TensorDataset(feats, labels), intermediate_features)``.
        ``intermediate_features`` maps each layer name to a CPU tensor with one row per image
        (that layer's output flattened after the batch axis).

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    # Private copy of the encoder without the projection head
    network = deepcopy(model.convnet)
    network.fc = nn.Identity()  # Removing projection head g(.)
    network.eval()
    network.to(device)

    # Hooks live on the private copy and disappear with it when this function returns.
    features = register_hooks(network, layers_to_capture)

    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=NUM_WORKERS,
                             shuffle=False, drop_last=False)
    feats, labels = [], []
    intermediate_features = {layer: [] for layer in layers_to_capture}

    for batch_imgs, batch_labels in tqdm(data_loader):
        batch_feats = network(batch_imgs.to(device))  # No graph is built under @torch.no_grad

        feats.append(batch_feats.cpu())
        labels.append(batch_labels)

        for layer in layers_to_capture:
            # The final layer already outputs (B, D), but conv layers output (B, C, H, W).
            # Flatten everything after the batch axis so each image is one row, ready for PCA.
            # `flatten` (unlike `view`) also works if a layer returns a non-contiguous tensor.
            intermediate_features[layer].append(features[layer].flatten(start_dim=1).cpu())

    if not feats:
        raise ValueError("dataset is empty: nothing to encode")

    # Concatenate results for each layer
    feats = torch.cat(feats, dim=0)
    labels = torch.cat(labels, dim=0)
    intermediate_features = {layer: torch.cat(chunks, dim=0) for layer, chunks in intermediate_features.items()}

    return TensorDataset(feats, labels), intermediate_features

# Load the pretrained SimCLR model. `map_location` remaps the stored tensors onto `device`,
# so a checkpoint saved on a GPU also loads on a CPU-only machine.
model = SimCLR.load_from_checkpoint(MODEL_CHECKPOINT_PATH, map_location=device)
model.eval()

# Extract SimCLR representations and intermediate features
layers_to_capture = ['layer1', 'layer2', 'layer3', 'layer4', 'fc']
final_layer, intermediate_features = extract_simclr_features(model, dataset, layers_to_capture)
final_layer_feats, labels = final_layer.tensors
layer1_feats = intermediate_features['layer1']
layer2_feats = intermediate_features['layer2']
layer3_feats = intermediate_features['layer3']
layer4_feats = intermediate_features['layer4']

In [None]:
### Visualise SimCLR feature representations in pixel space

# First preprocessed image of the dataset: a (C, H, W) slice of the stacked `images` tensor.
img_1 = images[0, ...]

# ResNet stages whose feature maps will be captured and plotted
layers_to_capture = [f"layer{stage}" for stage in range(1, 5)]

# Function to register hooks for specified layers
def register_hooks(model, layers, handles=None):
    """Attach forward hooks that record the detached output of each named submodule.

    NOTE: this redefines ``register_hooks`` from the feature-extraction cell. Keep the two
    definitions identical, or delete this copy.

    Args:
        model: ``nn.Module`` whose submodules are looked up by their ``named_modules()`` name.
        layers: Iterable of submodule names (e.g. ``'layer1'``, ``'fc'``).
        handles: Optional list. If given, the ``RemovableHandle`` of every hook registered
            here is appended to it, so the caller can detach the hooks later with
            ``handle.remove()``. Defaults to ``None`` (handles are discarded).

    Returns:
        dict: Empty at first. After each forward pass of ``model`` it maps every name in
        ``layers`` to that submodule's latest output, detached from the autograd graph.

    Raises:
        KeyError: If a name in ``layers`` is not a submodule of ``model``.
    """
    features = {}
    # Build the name -> module lookup once instead of once per requested layer.
    modules_by_name = dict(model.named_modules())

    def make_hook(layer_name):
        # Closure factory: each hook is bound to its own layer_name (no late-binding surprise).
        def hook(module, inputs, output):
            features[layer_name] = output.detach()
        return hook

    for layer_name in layers:
        if layer_name not in modules_by_name:
            raise KeyError(f"Layer {layer_name!r} not found in model")
        handle = modules_by_name[layer_name].register_forward_hook(make_hook(layer_name))
        if handles is not None:
            handles.append(handle)

    return features

# img_1 must be a single image (C, H, W) or an already-batched (1, C, H, W) tensor.
if img_1.ndim == 3:  # (C, H, W): add a batch dimension
    img_tensor = img_1.unsqueeze(0)
elif img_1.ndim == 4:  # Already batched: no change needed
    img_tensor = img_1
else:
    raise ValueError("img_1 must have shape (C, H, W) or (1, C, H, W)")

# No extra rescaling: images taken from `dataset` were already mapped to [-1, 1] by the
# preprocessing transform (min-max to [0, 1], then Normalize(mean=0.5, std=0.5)).
# Applying x * 2 - 1 on top of that would distort the input rather than fix it.

# Copy the SimCLR encoder and drop its projection head
model.eval()
network = deepcopy(model.convnet)
network.fc = torch.nn.Identity()  # Remove projection head
network.eval()
network.to(device)

# Register hooks and capture features
features = register_hooks(network, layers_to_capture)

# Pass the image through the network. This is inference only, so don't build an autograd graph.
img_tensor = img_tensor.to(device)
with torch.no_grad():
    feats = network(img_tensor)  # Forward pass; hooks fill `features`

# Visualize all feature maps layer-by-layer, row-by-row
def plot_feature_maps(features, max_maps_per_layer=16):
    """
    Plot feature maps layer-by-layer.
    Each layer's feature maps appear in a separate figure.

    Args:
        features: dict mapping layer name -> captured output tensor.
            - Batched spatial outputs (B, C, H, W): only the first image is shown.
            - Unbatched (C, H, W) outputs are shown as they are.
            - Outputs without spatial dimensions (e.g. a flat ``fc`` vector) are skipped with a message.
        max_maps_per_layer: Maximum number of channels drawn per layer (default 16).
    """
    for layer_name, layer_output in features.items():
        if layer_output.ndim == 4:
            layer_maps = layer_output[0].cpu()  # First image in the batch: (C, H, W)
        elif layer_output.ndim == 3:
            layer_maps = layer_output.cpu()
        else:
            print(f"Skipping {layer_name}: output shape {tuple(layer_output.shape)} has no spatial maps")
            continue

        # Limit the number of maps displayed per layer
        maps_to_plot = min(layer_maps.shape[0], max_maps_per_layer)
        if maps_to_plot == 0:
            continue

        # One figure per layer, one row of maps; squeeze=False keeps `axes` 2-D even for one map
        fig, axes = plt.subplots(1, maps_to_plot, figsize=(maps_to_plot * 2, 4), squeeze=False)
        for i, ax in enumerate(axes[0]):
            ax.imshow(layer_maps[i].numpy(), cmap='viridis')
            ax.axis("off")
            ax.set_title(f"Map {i}")
        fig.suptitle(f"Feature Maps for {layer_name}")
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # Free the figure; many layers would otherwise keep figures alive

# Call the function with the captured features
plot_feature_maps(features)

# Show the original image. The tensor is normalized to [-1, 1], so undo Normalize(0.5, 0.5)
# (x * 0.5 + 0.5). imshow expects floats in [0, 1] and would otherwise clip the negative half.
display_img = (img_tensor.squeeze(0).cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
plt.figure(figsize=(4, 4))
plt.imshow(display_img.permute(1, 2, 0).numpy(), cmap='gray' if img_tensor.shape[1] == 1 else None)
plt.axis("off")
plt.title("Original Image")
plt.show()

In [None]:
### Visualise feature representations from SimCLR layers

import torch
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
from copy import deepcopy
import random

def total_variation_loss(image):
    """
    Anisotropic total-variation penalty used to smooth the optimized image.

    Sums the absolute differences between vertically adjacent (dim 2) and horizontally
    adjacent (dim 3) pixels of a 4-D (N, C, H, W) tensor.
    """
    vertical_tv = (image[:, :, 1:, :] - image[:, :, :-1, :]).abs().sum()
    horizontal_tv = (image[:, :, :, 1:] - image[:, :, :, :-1]).abs().sum()
    return vertical_tv + horizontal_tv


def visualize_random_filters(model, layers, num_filters=10, input_size=(96, 96), iterations=200, lr=0.005,
                             filter_indices=None):
    """
    Visualize patterns that maximize activation for a selection of filters in specified layers.

    Args:
        model: The SimCLR model (PyTorch). Its ``convnet`` is deep-copied and left untouched.
        layers: List of layer names to visualize (e.g., ['layer1', 'layer2', 'layer3', 'layer4', 'fc']).
        num_filters: Number of random filters to visualize per layer when ``filter_indices`` is
            not given (capped at the layer's channel count).
        input_size: Size of the input image (default: 96x96).
        iterations: Number of optimization steps (default: 200).
        lr: Learning rate for optimization (default: 0.005).
        filter_indices: Optional explicit list of channel indices to visualize in every layer.
            If ``None`` (default), ``num_filters`` indices are sampled at random per layer.

    Raises:
        ValueError: If an entry of ``filter_indices`` is out of range for a layer.
    """
    # Prepare the model. Only the input image is optimized, so freeze the weights; otherwise
    # every backward() accumulates parameter gradients that are never used or cleared.
    network = deepcopy(model.convnet)
    network.fc = torch.nn.Identity()  # Remove the projection head
    network.eval().to(device)
    network.requires_grad_(False)

    modules_by_name = dict(network.named_modules())
    cols = 10

    for layer_name in layers:
        activations = {}

        def hook(module, input, output):
            # Keep the graph (no detach): the loss below must backprop through this output.
            activations["layer_output"] = output

        handle = modules_by_name[layer_name].register_forward_hook(hook)
        try:
            # Determine the number of filters (channels) in the layer
            with torch.no_grad():
                network(torch.randn(1, 3, *input_size, device=device))
            total_filters = activations["layer_output"].shape[1]

            if filter_indices is None:
                selected = random.sample(range(total_filters), min(num_filters, total_filters))
                how = "randomly chosen"
            else:
                selected = list(filter_indices)
                bad = [f for f in selected if not 0 <= f < total_filters]
                if bad:
                    raise ValueError(f"Filter indices {bad} out of range for {layer_name} ({total_filters} filters)")
                how = "user-specified"
            print(f"Visualizing {len(selected)} filters from {layer_name} ({how}: {selected})")

            # One grid cell per filter, `cols` per row; cells without a filter stay blank.
            rows = max(1, -(-len(selected) // cols))  # ceil(len / cols)
            fig, axes = plt.subplots(rows, cols, figsize=(15, 1.4 * rows + 1), squeeze=False)
            fig.suptitle(f"{layer_name} Filter Visualizations", fontsize=16)
            for ax in axes.flat:
                ax.axis("off")

            for idx, filter_index in enumerate(selected):
                # Start with a random noise image and optimize it directly
                input_img = torch.randn(1, 3, *input_size, requires_grad=True, device=device)
                optimizer = torch.optim.Adam([input_img], lr=lr)

                for _ in range(iterations):
                    optimizer.zero_grad()
                    network(input_img)  # Forward pass; the hook stores the target layer's output
                    layer_output = activations["layer_output"]

                    # Maximize the filter's mean activation, with a TV penalty for smoothness
                    loss = -layer_output[0, filter_index].mean() + 0.01 * total_variation_loss(input_img)
                    loss.backward()  # Gradient w.r.t. the input image only (weights are frozen)
                    optimizer.step()

                # Min-max normalize to [0, 1] for display (a constant image maps to zeros, not NaN)
                result = input_img.detach().cpu().squeeze()
                span = result.max() - result.min()
                result = (result - result.min()) / span if span > 0 else torch.zeros_like(result)

                row, col = divmod(idx, cols)
                axes[row, col].imshow(ToPILImage()(result), cmap="gray")
                axes[row, col].set_title(f"Filter {filter_index}", fontsize=8)

            fig.tight_layout()
            plt.show()
        finally:
            handle.remove()  # Don't leave this layer's hook firing while later layers are processed

# Example usage: the final-layer dimensions with the highest activation variance
# (as listed in the variability ranking of the final-layer cell further down).
layers_to_visualize = ["fc"]  # 'fc' is the encoder output once the projection head is removed
visualize_random_filters(
    model, layers=layers_to_visualize, num_filters=10, input_size=(96, 96), iterations=500, lr=0.01,
    filter_indices=[106, 380, 26, 498, 373, 438, 65, 208, 333, 142],
)

# Layer3
# Filters sorted by activation variability (top 10):
# Filter 150: Variance = 28.695124
# Filter 116: Variance = 22.668385
# Filter 236: Variance = 12.087730
# Filter 134: Variance = 12.006798
# Filter 244: Variance = 10.250396
# Filter 239: Variance = 9.221184
# Filter 110: Variance = 7.636278
# Filter 145: Variance = 7.632918
# Filter 151: Variance = 4.537350
# Filter 195: Variance = 4.303740

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage, Compose, Resize, CenterCrop, Normalize
from copy import deepcopy

# Resize -> center crop -> normalize; the same steps as the transform used to build the dataset.
# NOTE(review): this object is not referenced elsewhere in the visible notebook.
transform_visualization = Compose(
    [
        Resize(224),
        CenterCrop((224, 224)),
        Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def preprocess_for_display(tensor):
    """
    Invert ``Normalize(mean=0.5, std=0.5)`` for display.

    Maps values back with ``x * 0.5 + 0.5`` and clamps the result into [0, 1]. ``tensor``
    must broadcast against a (3, 1, 1) shape, i.e. be (3, H, W) or (N, 3, H, W).
    """
    half = torch.tensor([0.5, 0.5, 0.5], device=tensor.device).reshape(3, 1, 1)
    return torch.clamp(tensor * half + half, 0.0, 1.0)

def total_variation_loss(image):
    """Total variation loss to smooth the optimized image.

    Adds up |a - b| over every vertically adjacent (dim 2) and horizontally adjacent
    (dim 3) pixel pair of a 4-D (N, C, H, W) tensor.
    """
    down_diffs = image[:, :, :-1, :] - image[:, :, 1:, :]
    right_diffs = image[:, :, :, :-1] - image[:, :, :, 1:]
    return down_diffs.abs().sum() + right_diffs.abs().sum()

def visualize_filter_with_real_images(model, layer_name, filter_index, dataset, input_size=(224, 224), iterations=300, lr=0.01,
                                      num_examples=50, batch_size=64):
    """
    For a given filter in a specified layer, this function:
      - Computes activations over all images in the dataset.
      - Computes and prints filter variability.
      - Retrieves the top and bottom ``num_examples`` (default 50) activating images.
      - Synthesizes a feature via gradient ascent.
      - Displays a grid: the synthetic feature (first row), then the top images, then the
        bottom images (10 per row).

    Args:
        model: Trained SimCLR model (its ``convnet`` is deep-copied, not modified).
        layer_name: The target layer (e.g., "layer3"). It must produce [N, C, H, W] outputs.
        filter_index: Index of the filter (channel) to visualize.
        dataset: A PyTorch tensor of shape [N, C, H, W] with preprocessed (normalized) images.
        input_size: Input image size for synthetic feature generation.
        iterations: Number of gradient ascent iterations.
        lr: Learning rate.
        num_examples: Number of top / bottom images to show (capped at N).
        batch_size: Images per forward pass when scoring the dataset.

    Raises:
        ValueError: If ``dataset`` is empty, the layer output is not 4-D, or
            ``filter_index`` is out of range.
    """
    if dataset.shape[0] == 0:
        raise ValueError("dataset is empty")

    # Prefer CUDA, then Apple MPS, then CPU
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    network = deepcopy(model.convnet)
    network.fc = torch.nn.Identity()  # Remove projection head
    network.eval().to(device)
    # Only the synthetic input is optimized; frozen weights stop backward() from
    # accumulating parameter gradients that are never used or cleared.
    network.requires_grad_(False)

    # Hook to capture activations. The output is stored as-is (not detached) so the
    # gradient-ascent loss below can backpropagate through it. The dataset pass runs under
    # no_grad, so no graph is kept there.
    activations = {}
    def hook(module, input, output):
        activations["layer_output"] = output

    target_layer = dict(network.named_modules())[layer_name]
    handle = target_layer.register_forward_hook(hook)
    try:
        # Score the dataset in mini-batches rather than one forward pass over all N images
        chunks = []
        with torch.no_grad():
            for start in range(0, dataset.shape[0], batch_size):
                network(dataset[start:start + batch_size].to(device))
                chunks.append(activations["layer_output"].cpu())
        layer_output = torch.cat(chunks, dim=0)  # Expected shape: [N, num_filters, H, W]
        if layer_output.ndim != 4:
            raise ValueError(
                f"{layer_name} output has shape {tuple(layer_output.shape)}; expected [N, C, H, W]. "
                "Use visualize_final_layer_dimension for non-spatial layers."
            )
        num_filters = layer_output.shape[1]
        if not 0 <= filter_index < num_filters:
            raise ValueError(f"filter_index {filter_index} out of range for {layer_name} ({num_filters} filters)")

        # Synthetic feature: gradient ascent on the filter's mean activation, plus TV smoothing.
        # The hook stays registered so each forward pass refreshes activations["layer_output"].
        input_img = torch.randn(1, 3, *input_size, requires_grad=True, device=device)
        optimizer = torch.optim.Adam([input_img], lr=lr)
        for _ in range(iterations):
            optimizer.zero_grad()
            network(input_img)
            layer_out = activations["layer_output"]
            loss = -layer_out[0, filter_index].mean() + 0.01 * total_variation_loss(input_img)
            loss.backward()
            optimizer.step()
    finally:
        handle.remove()  # Remove hook even if something above failed

    synthetic_img = preprocess_for_display(input_img.detach().cpu().squeeze())

    # Spatial max-pool of every filter map: one score per (image, filter)
    pooled = layer_output.amax(dim=(2, 3))  # [N, num_filters]

    # Filter variability = variance of the pooled score across images
    variances = pooled.var(dim=0)
    print("Top 10 filters sorted by activation variability:")
    for f in torch.argsort(variances, descending=True)[:10].tolist():
        print(f"Filter {f}: Variance = {variances[f].item():.6f}")

    # Scores for the chosen filter, min-max normalized to [0, 1] (ranking is unaffected)
    pooled_activations = pooled[:, filter_index]
    pooled_activations = (pooled_activations - pooled_activations.min()) / (
        pooled_activations.max() - pooled_activations.min() + 1e-8
    )
    print(f"Pooled Activations Stats for Filter {filter_index}:")
    print(f"  Min: {pooled_activations.min().item():.6f}, Max: {pooled_activations.max().item():.6f}, Mean: {pooled_activations.mean().item():.6f}")

    # Rank images; never ask for more examples than the dataset has
    k = min(num_examples, pooled_activations.numel())
    top_indices = torch.argsort(pooled_activations, descending=True)[:k]
    bottom_indices = torch.argsort(pooled_activations)[:k]

    print("Top image indices:", top_indices.numpy())
    print("Bottom image indices:", bottom_indices.numpy())

    # Retrieve top and bottom images (undo normalization for display)
    top_images = preprocess_for_display(dataset[top_indices.to(dataset.device)]).cpu()
    bottom_images = preprocess_for_display(dataset[bottom_indices.to(dataset.device)]).cpu()

    # Layout: 1 row for the synthetic image, then `block_rows` rows each for top and bottom images
    cols = 10
    block_rows = max(1, -(-k // cols))  # ceil(k / cols)
    n_rows = 1 + 2 * block_rows
    fig, axes = plt.subplots(n_rows, cols, figsize=(20, 2 * n_rows), squeeze=False)
    fig.suptitle(f"Filter {filter_index} in {layer_name}: Synthetic + Top/Bottom Activating Images", fontsize=16)
    for ax in axes.flat:
        ax.axis("off")  # Cells without an image stay blank

    axes[0, 0].imshow(ToPILImage()(synthetic_img), cmap="gray")
    axes[0, 0].set_title("Synthetic Feature", fontsize=8)

    for i in range(k):
        r, c = divmod(i, cols)
        axes[1 + r, c].imshow(ToPILImage()(top_images[i]), cmap="gray")
        axes[1 + r, c].set_title(f"Top {i+1}", fontsize=8)
        axes[1 + block_rows + r, c].imshow(ToPILImage()(bottom_images[i]), cmap="gray")
        axes[1 + block_rows + r, c].set_title(f"Low {i+1}", fontsize=8)

    plt.tight_layout()
    plt.show()

# Example usage:
layer_to_visualize = "layer3"
filter_to_visualize = 150  # Choose a filter/dimension with high variability
visualize_filter_with_real_images(model, layer_to_visualize, filter_to_visualize, dataset=images)


# Layer3
# Filters sorted by activation variability (top 10):
# Filter 150: Variance = 28.695124
# Filter 116: Variance = 22.668385
# Filter 236: Variance = 12.087730
# Filter 134: Variance = 12.006798
# Filter 244: Variance = 10.250396
# Filter 239: Variance = 9.221184
# Filter 110: Variance = 7.636278
# Filter 145: Variance = 7.632918
# Filter 151: Variance = 4.537350
# Filter 195: Variance = 4.303740

# Layer4
# Filters sorted by activation variability (top 10):
# Filter 65: Variance = 3.064611
# Filter 228: Variance = 3.017383
# Filter 117: Variance = 2.004943
# Filter 503: Variance = 1.593511
# Filter 146: Variance = 1.398939
# Filter 37: Variance = 1.290163
# Filter 272: Variance = 1.241190
# Filter 6: Variance = 1.238872
# Filter 210: Variance = 1.196525
# Filter 504: Variance = 1.120707

In [None]:
### Visualise features from final layer

# Values recorded from a previous run (may be stale) -- top final layer dimensions by average activation:
# Dimension 106: Average Activation = 1.319144
# Dimension 498: Average Activation = 1.194227
# Dimension 333: Average Activation = 1.170060
# Dimension 380: Average Activation = 1.162012
# Dimension 438: Average Activation = 1.102969
# Dimension 130: Average Activation = 1.085204
# Dimension 26: Average Activation = 1.064653
# Dimension 315: Average Activation = 1.018300
# Dimension 119: Average Activation = 1.000116
# Dimension 84: Average Activation = 0.959936
#
# Values recorded from a previous run (may be stale) -- top final layer dimensions by activation variability:
# Dimension 106: Variance = 0.623146
# Dimension 380: Variance = 0.511653
# Dimension 26: Variance = 0.483480
# Dimension 498: Variance = 0.483273
# Dimension 373: Variance = 0.468444
# Dimension 438: Variance = 0.416003
# Dimension 65: Variance = 0.368167
# Dimension 208: Variance = 0.364436
# Dimension 333: Variance = 0.362497
# Dimension 142: Variance = 0.349626

top_k = 10  # Number of dimensions reported in each ranking

# final_layer_feats: tensor of shape [N, D]; reduce over the image axis
avg_activations = final_layer_feats.mean(dim=0)  # Average activation of each dimension
variability = final_layer_feats.var(dim=0)       # Variance of each dimension across images

# Rank dimensions in descending order
top_indices = torch.argsort(avg_activations, descending=True)[:top_k]
top_variability_indices = torch.argsort(variability, descending=True)[:top_k]

print("Top final layer dimensions by activation variability:")
for idx in top_variability_indices:
    print(f"Dimension {idx.item()}: Variance = {variability[idx].item():.6f}")


print("Top final layer dimensions by average activation:")
for i in top_indices:
    print(f"Dimension {i.item()}: Average Activation = {avg_activations[i].item():.6f}")


def visualize_final_layer_dimension(
    model, layer_name, filter_index, dataset,
    input_size=(224, 224), iterations=300, lr=0.01, num_examples=10, batch_size=64,
):
    """
    Visualize how a specific dimension (filter_index) in the final layer
    responds to your dataset and produce a synthetic feature.

    Args:
        model: Trained SimCLR model (its ``convnet`` is deep-copied, not modified).
        layer_name: Name of a layer with a 2-D [N, D] output. Use 'fc': here it is the
            ``nn.Identity`` that replaces the projection head, so submodules such as
            'fc.0' do not exist in the copied network.
        filter_index: The dimension to visualize (0 <= filter_index < final_layer_dim).
        dataset: A PyTorch tensor of shape [N, C, H, W] (preprocessed images).
        input_size: Size for the synthetic image.
        iterations: Number of gradient ascent iterations.
        lr: Learning rate for optimization.
        num_examples: Number of top / bottom images to show (capped at N).
        batch_size: Images per forward pass when scoring the dataset.

    Raises:
        ValueError: If ``dataset`` is empty, the layer output is not 2-D, or
            ``filter_index`` is out of range.
    """
    if dataset.shape[0] == 0:
        raise ValueError("dataset is empty")

    # Prefer CUDA, then Apple MPS, then CPU
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Copy the model and remove the projection head; freeze weights (only the input is optimized)
    network = deepcopy(model.convnet)
    network.fc = torch.nn.Identity()
    network.eval().to(device)
    network.requires_grad_(False)

    # Hook to capture final layer activations. Stored without detach so gradient ascent can
    # backpropagate through it; the dataset pass runs under no_grad, so nothing is tracked there.
    activations = {}
    def hook(module, input, output):
        activations["final_output"] = output

    target_layer = dict(network.named_modules())[layer_name]
    handle = target_layer.register_forward_hook(hook)
    try:
        # Score the dataset in mini-batches
        chunks = []
        with torch.no_grad():
            for start in range(0, dataset.shape[0], batch_size):
                network(dataset[start:start + batch_size].to(device))
                chunks.append(activations["final_output"].cpu())
        final_output = torch.cat(chunks, dim=0)  # Expected shape: [N, D]
        if final_output.ndim != 2:
            raise ValueError(
                f"{layer_name} output has shape {tuple(final_output.shape)}; expected [N, D]. "
                "Use visualize_filter_with_real_images for spatial layers."
            )
        if not 0 <= filter_index < final_output.shape[1]:
            raise ValueError(f"filter_index {filter_index} out of range ({final_output.shape[1]} dimensions)")

        # Synthetic feature: gradient ascent on the chosen dimension, plus TV smoothing
        input_img = torch.randn(1, 3, *input_size, requires_grad=True, device=device)
        optimizer = torch.optim.Adam([input_img], lr=lr)
        for _ in range(iterations):
            optimizer.zero_grad()
            network(input_img)
            synthetic_output = activations["final_output"]  # shape: [1, D], still attached to the graph
            loss = -synthetic_output[0, filter_index] + 0.01 * total_variation_loss(input_img)
            loss.backward()
            optimizer.step()
    finally:
        handle.remove()

    synthetic_img = preprocess_for_display(input_img.detach().cpu().squeeze())

    # Each image has a single scalar for the chosen dimension
    dim_activations = final_output[:, filter_index]

    # Normalize activations to [0,1] for sorting
    dim_activations = (dim_activations - dim_activations.min()) / (
        dim_activations.max() - dim_activations.min() + 1e-8
    )

    # Rank images; never ask for more examples than the dataset has
    k = min(num_examples, dim_activations.numel())
    top_indices = torch.argsort(dim_activations, descending=True)[:k]
    bottom_indices = torch.argsort(dim_activations)[:k]

    print(f"Final Layer Dimension {filter_index} Stats:")
    print(f"  Min: {dim_activations.min().item():.6f}, Max: {dim_activations.max().item():.6f}, Mean: {dim_activations.mean().item():.6f}")
    print("Top image indices:", top_indices.numpy())
    print("Bottom image indices:", bottom_indices.numpy())

    # Prepare images for display (undo normalization)
    top_images = preprocess_for_display(dataset[top_indices.to(dataset.device)]).cpu()
    bottom_images = preprocess_for_display(dataset[bottom_indices.to(dataset.device)]).cpu()

    # Layout: 1 row for the synthetic image, then `block_rows` rows each for top and bottom images
    cols = 10
    block_rows = max(1, -(-k // cols))  # ceil(k / cols)
    n_rows = 1 + 2 * block_rows
    fig, axes = plt.subplots(n_rows, cols, figsize=(20, 10 * n_rows / 3), squeeze=False)
    fig.suptitle(f"Final Layer Dimension {filter_index}: Synthetic + Top/Bottom Activating Images", fontsize=14)
    for ax in axes.flat:
        ax.axis("off")  # Cells without an image stay blank

    axes[0, 0].imshow(ToPILImage()(synthetic_img), cmap="gray")
    axes[0, 0].set_title("Synthetic Feature", fontsize=8)

    for i in range(k):
        r, c = divmod(i, cols)
        axes[1 + r, c].imshow(ToPILImage()(top_images[i]), cmap="gray")
        axes[1 + r, c].set_title(f"Top {i+1}", fontsize=8)
        axes[1 + block_rows + r, c].imshow(ToPILImage()(bottom_images[i]), cmap="gray")
        axes[1 + block_rows + r, c].set_title(f"Low {i+1}", fontsize=8)

    plt.tight_layout()
    plt.show()

# Example usage:
#  - layer_name "fc" is the encoder output here (the projection head was replaced by Identity).
#  - filter_index is any dimension in [0, D-1].
layer_to_visualize = "fc"
filter_to_visualize = 333    # Choose a dimension in the final layer
visualize_final_layer_dimension(model, layer_to_visualize, filter_to_visualize, dataset=images)