In [None]:
! cp -r /kaggle/input/test-time-registers-code/* /kaggle/working/
! pip -qqq install ftfy

import gzip
import shutil

# Gzip the plain-text BPE vocabulary into a sibling ".txt.gz" file
# (presumably the file name the bundled CLIP tokenizer opens — confirm).
src = "/kaggle/working/clip/clip/vocab/bpe_simple_vocab_16e6.txt"
dst = "/kaggle/working/clip/clip/vocab/bpe_simple_vocab_16e6.txt.gz"
with open(src, "rb") as raw_vocab, gzip.open(dst, "wb") as gz_vocab:
    # Stream the bytes across so the whole file never has to sit in memory.
    shutil.copyfileobj(raw_vocab, gz_vocab)

import sys
import os
from shared.utils import (
    load_images,
    plot_images_with_max_per_row,
    plot_attn_maps,
    filter_layers,
)
from shared.algorithms import find_register_neurons
import numpy as np
import matplotlib.pyplot as plt
import torch
import yaml
from tqdm import tqdm
import pandas as pd
import random
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
import json

# Make the `clip/` and `dinov2/` folders importable; the paths are relative to the
# current working directory.
sys.path.append("clip/")
sys.path.append("dinov2/")

# Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) so that image
# sampling and training start from a fixed state.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [None]:
def load_images(orthonet_path, count=10, images_only=True):
    """Randomly sample images listed in the OrthoNet ``train.csv``.

    File names come from the ``filenames`` column of ``<orthonet_path>/train.csv`` and
    are resolved against ``<orthonet_path>/orthonet data/orthonet data/``.

    Args:
        orthonet_path: Root directory of the OrthoNet dataset.
        count: Number of images to draw without replacement using the global
            ``random`` module. If it exceeds the number of listed files, every
            listed file is returned in CSV order.
        images_only: If True return only the images; otherwise also return the
            paths they were loaded from.

    Returns:
        A list of fully decoded RGB ``PIL.Image`` objects, or ``(images, paths)``
        when ``images_only`` is False.
    """
    csv_path = os.path.join(orthonet_path, 'train.csv')
    image_filenames = pd.read_csv(csv_path)['filenames'].tolist()

    # Sample without replacement, or take everything if fewer files are available.
    if count > len(image_filenames):
        sampled_filenames = image_filenames
    else:
        sampled_filenames = random.sample(image_filenames, count)

    # Images are in orthonet_path/orthonet data/orthonet data/
    images_dir = os.path.join(orthonet_path, 'orthonet data', 'orthonet data')
    sampled_paths = [os.path.join(images_dir, name) for name in sampled_filenames]

    image_files = []
    for full_path in sampled_paths:
        # Image.open is lazy and keeps the file handle open until the pixels are
        # read. Decode inside a context manager so each handle is released
        # immediately instead of accumulating one open file per sampled image.
        # convert("RGB") returns an independent in-memory copy and matches what
        # OrthoNetDataset feeds to `preprocess`.
        with Image.open(full_path) as raw_image:
            image_files.append(raw_image.convert("RGB"))

    print("Loaded {} images".format(len(image_files)))

    if images_only:
        return image_files
    return image_files, sampled_paths

In [None]:
# Replace `shared.utils.load_images` with the notebook's OrthoNet-specific loader.
# NOTE(review): this only affects code that looks the function up as
# `utils.load_images` at call time; a module that already ran
# `from shared.utils import load_images` keeps its own reference — confirm that
# `find_register_neurons` actually picks up this patch.
from shared import utils
utils.load_images = load_images

In [None]:
# Backbone to analyse; the model-loading cell accepts "dinov2" or "clip".
MODEL = "dinov2"

In [None]:
# Load the selected backbone. Each branch imports its loader lazily, builds the
# per-model settings dict and passes it to the loader, which returns `state`.
# The keys "detect_outliers_layer", "register_norm_threshold", "highest_layer" and
# "top_k" are read back by later cells.
if MODEL == "dinov2":
    from dinov2_state import load_dinov2_state

    config = {
        "backbone_size": "vitl14",
        "device": "cuda:0",
        "detect_outliers_layer": -2,
        "register_norm_threshold": 150,
        "highest_layer": 19,
        "top_k": 50,
    }

    state = load_dinov2_state(config)
    
elif MODEL == "clip":
    from clip_state import load_clip_state

    config = {
        "model_name": "ViT-B-16",
        "pretrained": "laion2b_s34b_b88k",
        "device": "cuda",
        "highest_layer": 5,
        "detect_outliers_layer": -1,
        "register_norm_threshold": 30,
        "top_k": 20,
    }

    state = load_clip_state(config)
else:
    # Fail fast on a typo in MODEL instead of hitting a NameError on `state` later.
    raise ValueError(f"Model {MODEL} not supported")

In [None]:
# Dataset root and the square input resolution used to derive the patch grid.
IMAGE_PATH = "/kaggle/input/orthonet-data"
IMAGE_SIZE = 224  # NOTE(review): assumed to match the size produced by `preprocess` — confirm

# Unpack the loaded model state into notebook-level names used by every later cell.
run_model = state["run_model"]
model = state["model"]
preprocess = state["preprocess"]
hook_manager = state["hook_manager"]
num_layers = state["num_layers"]
num_heads = state["num_heads"]
patch_size = state["patch_size"]
# Rebinds `config` to whatever the loader stored under "config".
config = state["config"]
# Number of patches along each side of the input image.
patch_height = IMAGE_SIZE // patch_size
patch_width = IMAGE_SIZE // patch_size
# NOTE(review): hard-coded rather than read from config["device"] — confirm they agree.
device = "cuda"

In [None]:
# Sample two images and keep the first one as the running example.
image = load_images(IMAGE_PATH, count=2)[0]

In [None]:
# Baseline forward pass on the example image: hooks are reset and finalized with no
# intervention registered, then attention maps and layer outputs are read back.
processed_image = preprocess(image).unsqueeze(0).to(device)
hook_manager.reinit()
hook_manager.finalize()
representation = run_model(model, processed_image)

attention_maps = hook_manager.get_attention_maps()  # shape (L, H, N, N)
layer_outputs = hook_manager.get_layer_outputs()  # shape (L, N, D)

# L2 norm of every token after index 0 (index 0 is skipped — presumably the CLS token),
# arranged on the patch grid. The reshape only succeeds when there are exactly
# patch_height * patch_width such tokens.
patch_norms = np.linalg.norm(layer_outputs[:, 1:], axis=2).reshape(
    num_layers, patch_height, patch_width
)

In [None]:
# Show the example image on its own, without axis decorations.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.axis("off")
ax.set_title("Input")
plt.show()

In [None]:


# Plot patch norms across all layers. Notice that outliers appear in the later layers.

# NOTE(review): this rebinds the notebook-wide name `plt`. Later cells still call
# `plt.figure` / `plt.subplots` on it, so the helper presumably returns the pyplot
# module itself — confirm before changing the helper.
plt = plot_images_with_max_per_row(patch_norms, max_per_row=4, image_title="Layer")
plt.tight_layout()
plt.show()

In [None]:
# Plot attention maps

# Row 0 of each attention matrix (the query at token index 0, presumably CLS) over all
# tokens after index 0, reshaped to one patch-grid map per layer and head.
cls_attn_maps = attention_maps[:, :, 0, 1:].reshape(
    num_layers, num_heads, patch_height, patch_width
)
# NOTE(review): rebinds `plt` to the helper's return value (presumably pyplot) — confirm.
plt = plot_attn_maps(cls_attn_maps)
plt.show()

In [None]:
# Baseline (no intervention) layer-wise statistics averaged over 10 random images.
rand_images = load_images(IMAGE_PATH, count=10)
# Running per-layer sums; turned into means after the loop.
max_patch_norms = [0] * num_layers
max_attn_norms = [0] * num_layers

# NOTE: the loop variable overwrites the notebook-level `image` set by an earlier cell.
for image in tqdm(rand_images, desc="Processing images"):
    processed_image = preprocess(image).unsqueeze(0).to(device)
    hook_manager.reinit()
    hook_manager.finalize()
    representation = run_model(model, processed_image)
    attention_maps = hook_manager.get_attention_maps()  # shape (L, H, N, N)
    layer_outputs = hook_manager.get_layer_outputs()  # shape (L, N, D)

    # Per layer: largest L2 norm among the patch tokens (indices 1..H*W).
    patch_norms = np.max(
        np.linalg.norm(layer_outputs[:, 1 : patch_height * patch_width + 1], axis=2),
        axis=1,
    )
    # Per layer: attention from token 0 to the patch tokens, averaged over heads,
    # then the maximum over patches.
    attn_norms = np.max(
        np.mean(attention_maps[:, :, 0, 1 : patch_height * patch_width + 1], axis=1),
        axis=1,
    )

    for j in range(num_layers):
        max_attn_norms[j] += attn_norms[j]
        max_patch_norms[j] += patch_norms[j]

# Convert the running sums into per-layer means over the sampled images.
max_attn_norms = [x / len(rand_images) for x in max_attn_norms]
max_patch_norms = [x / len(rand_images) for x in max_patch_norms]

In [None]:
# Two side-by-side line plots of the per-layer averages:
# patch norms on the left, CLS attention on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), dpi=100)

# (axes, values, marker, colour, title, y-label) for each panel.
panels = (
    (ax1, max_patch_norms, "o", "steelblue", "Average Max Patch Norms (Layer output)", "Norm"),
    (ax2, max_attn_norms, "^", "orange", "Average Max Attention (CLS)", "Attention Value"),
)
for axis, values, marker, color, title, ylabel in panels:
    axis.plot(
        range(num_layers),
        values,
        marker=marker,
        markersize=8,
        color=color,
        linestyle="-",
        linewidth=2,
    )
    axis.set_title(title)
    axis.set_xlabel("Layer")
    axis.set_ylabel(ylabel)
    axis.set_xticks(range(0, num_layers, 2))  # Show every second tick
    axis.tick_params(axis="both")
    axis.tick_params(axis="x", which="major", pad=15)
    axis.margins(x=0.1)
    axis.grid(True, linestyle="-")

In [None]:
#########################################
#              PARAMETERS               #
#########################################

# Layer used to detect outliers based on the patch norms. Set to the last layer (-1)
# for CLIP and to the second-to-last layer (-2) for DINOv2 large.
detect_outliers_layer = config["detect_outliers_layer"]
register_norm_threshold = config[
    "register_norm_threshold"
]  # threshold for detecting register neurons

# Echo the chosen layer so the value in use is visible in the cell output.
print(detect_outliers_layer)

In [None]:
# Search for register neurons using the settings read from `config`.
# NOTE(review): `processed_image_cnt=1000` presumably caps how many dataset images are
# scanned, and the helper presumably loads them from `image_path` itself — confirm
# against `shared.algorithms`.
register_neurons = find_register_neurons(
    model_state=state,
    image_path=IMAGE_PATH,
    detect_outliers_layer=detect_outliers_layer,
    processed_image_cnt=1000,
    device="cuda",
    register_norm_threshold=register_norm_threshold,
    apply_sparsity_filter=True,
)

In [None]:
# Cache the search result in the working directory so the search can be skipped on a re-run.
torch.save(register_neurons, "register_neurons.pt")

In [None]:
# Reload the cached register neurons from the file written by the previous cell.
# NOTE(review): depending on the PyTorch version, `torch.load` may default to
# weights-only unpickling; confirm the saved object loads under that mode.
register_neurons = torch.load("register_neurons.pt")

In [None]:
#########################################
#              PARAMETERS               #
#########################################
# Number of register neurons to keep (a slice bound in the next cell) and the layer
# limit handed to `filter_layers`.
top_k = config["top_k"]
highest_layer = config["highest_layer"]
# Passed as `num_registers` to `intervene_register_neurons` and `run_model`.
num_registers = 1

print(highest_layer)

In [None]:
# Restrict the candidates with `filter_layers`, then group the first `top_k` entries
# into a {layer: [neuron, ...]} mapping for the intervention hook.
filtered_register_neurons = filter_layers(register_neurons, highest_layer=highest_layer)

neurons_to_ablate = {}
for layer, neuron, _score in filtered_register_neurons[:top_k]:
    # setdefault creates the per-layer list the first time a layer is seen.
    neurons_to_ablate.setdefault(layer, []).append(neuron)
print(neurons_to_ablate)

In [None]:
# Show the preprocessing pipeline that is applied to every image before the model.
print(preprocess)

In [None]:
# Token-level statistics with the register intervention enabled, over 50 random images.
# NOTE(review): the following cell starts with a verbatim copy of this loop on a fresh
# sample and overwrites the four lists below, so this cell's results are not used.
random_images = load_images(IMAGE_PATH, count=50)
image_norms = []
register_norms = []
image_attentions = []
register_attentions = []
for i in tqdm(range(len(random_images)), desc="Processing random images"):
    # NOTE: rebinds the notebook-level `image` to a preprocessed tensor.
    image = preprocess(random_images[i]).unsqueeze(0).to(device)

    # Reset the hooks, register the intervention on the selected neurons, then run.
    hook_manager.reinit()
    hook_manager.intervene_register_neurons(
        neurons_to_ablate=neurons_to_ablate,
        num_registers=num_registers,
        normal_values="zero",
        scale=1,
    )
    hook_manager.finalize()
    representation = run_model(model, image, num_registers=num_registers)
    attention_maps = hook_manager.get_attention_maps()
    layer_outputs = hook_manager.get_layer_outputs()

    # L2 norm of every token at the outlier-detection layer.
    layer_norms = np.linalg.norm(layer_outputs[detect_outliers_layer], axis=1)

    # Tokens 1..H*W are image patches; anything after them is treated as register tokens.
    image_patch_norms = layer_norms[1 : patch_height * patch_width + 1]
    register_patch_norms = layer_norms[patch_height * patch_width + 1 :]

    image_norms.extend(image_patch_norms.tolist())
    register_norms.extend(register_patch_norms.tolist())

    # Last layer: largest attention (over heads and tokens) from token 0 to the image
    # patches and to the register tokens respectively.
    image_attentions.append(
        np.max(attention_maps[-1, :, 0, 1 : patch_height * patch_width + 1])
    )
    register_attentions.append(
        np.max(attention_maps[-1, :, 0, patch_height * patch_width + 1 :])
    )

In [None]:
# Process images with register intervention.
#
# Token-level statistics (at `detect_outliers_layer` / the last layer) and layer-wise
# statistics are gathered from ONE forward pass per image. The hooks and inputs are
# identical for both kinds of statistics, so running the model twice per image only
# doubled the runtime.
random_images = load_images(IMAGE_PATH, count=50)
num_patches = patch_height * patch_width  # image-patch tokens sit at indices 1..num_patches

image_norms = []
register_norms = []
image_attentions = []
register_attentions = []

# Running per-layer sums; turned into means after the loop.
max_patch_norms = [0] * num_layers
max_attn_norms = [0] * num_layers

for image in tqdm(random_images, desc="Processing random images"):
    processed_image = preprocess(image).unsqueeze(0).to(device)

    # Reset the hooks, register the intervention on the selected neurons, then run.
    hook_manager.reinit()
    hook_manager.intervene_register_neurons(
        neurons_to_ablate=neurons_to_ablate,
        num_registers=num_registers,
        normal_values="zero",
        scale=1,
    )
    hook_manager.finalize()
    representation = run_model(model, processed_image, num_registers=num_registers)
    attention_maps = hook_manager.get_attention_maps()  # shape (L, H, N, N)
    layer_outputs = hook_manager.get_layer_outputs()  # shape (L, N, D)

    # --- token-level statistics -------------------------------------------------
    # L2 norm of every token at the outlier-detection layer, split into image
    # patches and the tokens that follow them (the register tokens).
    layer_norms = np.linalg.norm(layer_outputs[detect_outliers_layer], axis=1)
    image_norms.extend(layer_norms[1 : num_patches + 1].tolist())
    register_norms.extend(layer_norms[num_patches + 1 :].tolist())

    # Last layer: largest attention (over heads and tokens) from token 0.
    image_attentions.append(np.max(attention_maps[-1, :, 0, 1 : num_patches + 1]))
    register_attentions.append(np.max(attention_maps[-1, :, 0, num_patches + 1 :]))

    # --- layer-wise statistics --------------------------------------------------
    # Max patch norm per layer.
    patch_norms = np.max(
        np.linalg.norm(layer_outputs[:, 1 : num_patches + 1], axis=2),
        axis=1,
    )
    # Max (over patches) of the head-averaged attention from token 0, per layer.
    attn_norms = np.max(
        np.mean(attention_maps[:, :, 0, 1 : num_patches + 1], axis=1),
        axis=1,
    )

    for j in range(num_layers):
        max_attn_norms[j] += attn_norms[j]
        max_patch_norms[j] += patch_norms[j]

# Average over all images
max_attn_norms = [x / len(random_images) for x in max_attn_norms]
max_patch_norms = [x / len(random_images) for x in max_patch_norms]

# Create the visualization: patch norms on the left, CLS attention on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), dpi=100)

# === Subplot 1: Patch Norms ===
ax1.plot(
    range(num_layers),
    max_patch_norms,
    marker="o",
    markersize=8,
    color="steelblue",
    linestyle="-",
    linewidth=2,
)
ax1.set_title("Average Max Patch Norms (Layer output) - With Register Ablation")
ax1.set_xlabel("Layer")
ax1.set_ylabel("Norm")
ax1.set_xticks(range(0, num_layers, 2))  # show every second tick
ax1.tick_params(axis="both")
ax1.tick_params(axis="x", which="major", pad=15)
ax1.margins(x=0.1)
ax1.grid(True, linestyle="-")

# === Subplot 2: Attention Norms ===
ax2.plot(
    range(num_layers),
    max_attn_norms,
    marker="^",
    markersize=8,
    color="orange",
    linestyle="-",
    linewidth=2,
)
ax2.set_title("Average Max Attention (CLS) - With Register Ablation")
ax2.set_xlabel("Layer")
ax2.set_ylabel("Attention Value")
ax2.set_xticks(range(0, num_layers, 2))  # show every second tick
ax2.tick_params(axis="both")
ax2.tick_params(axis="x", which="major", pad=15)
ax2.margins(x=0.1)
ax2.grid(True, linestyle="-")

plt.tight_layout()
plt.show()

In [None]:
# 2x2 grid of histograms: token norms on the top row, CLS attention on the bottom row;
# image patches in the left column, register tokens in the right column.
plt.figure(figsize=(10, 12))

# (values, bar colour, x-axis label, title) for subplot positions 1..4.
histogram_specs = [
    (image_norms, "blue", "Norm Value", "Image Patch Norms"),
    (register_norms, "green", "Norm Value", "Register Patch Norms"),
    (image_attentions, "blue", "Attention Value", "Image Patch Attentions"),
    (register_attentions, "green", "Attention Value", "Register Patch Attentions"),
]
for position, (values, color, xlabel, title) in enumerate(histogram_specs, start=1):
    plt.subplot(2, 2, position)
    plt.hist(values, bins=50, alpha=0.7, color=color)
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.title(title)
    plt.grid(True, alpha=0.3)
    # Mark the median with a dashed red line and show its value in the legend.
    plt.axvline(
        x=np.median(values),
        color="r",
        linestyle="--",
        label=f"Median: {np.median(values):.2f}",
    )
    plt.legend()

plt.tight_layout()
plt.show()

# Print some statistics for comparison
print(
    f"Image norms - Min: {min(image_norms):.2f}, Max: {max(image_norms):.2f}, Mean: {np.mean(image_norms):.2f}"
)
print(
    f"Register norms - Min: {min(register_norms):.2f}, Max: {max(register_norms):.2f}, Mean: {np.mean(register_norms):.2f}"
)
print(
    f"Image attentions - Min: {min(image_attentions):.2f}, Max: {max(image_attentions):.2f}, Mean: {np.mean(image_attentions):.2f}"
)
print(
    f"Register attentions - Min: {min(register_attentions):.2f}, Max: {max(register_attentions):.2f}, Mean: {np.mean(register_attentions):.2f}"
)

In [None]:
# Draw a fresh example image; `image` was overwritten by the loops in earlier cells.
image = load_images(IMAGE_PATH, count=1)[0]

In [None]:
# Show the newly sampled example image without axis decorations.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.axis("off")
ax.set_title("Input")
plt.show()

In [None]:
# Run the same image twice: once untouched and once with the register intervention,
# keeping both sets of hook outputs for the comparison plots.
processed_image = preprocess(image).unsqueeze(0).to(device)

# Pass 1: baseline — hooks reset and finalized without any intervention.
hook_manager.reinit()
hook_manager.finalize()
representation = run_model(model, processed_image)
original_attention_maps = hook_manager.get_attention_maps()
original_layer_outputs = hook_manager.get_layer_outputs()

# Pass 2: intervention on the selected neurons.
# NOTE(review): unlike the other cells, `normal_values="zero"` and `scale=1` are not
# passed here, so this call relies on the method's defaults — confirm they match.
hook_manager.reinit()
hook_manager.intervene_register_neurons(
    neurons_to_ablate=neurons_to_ablate, num_registers=num_registers
)
hook_manager.finalize()
representation = run_model(model, processed_image, num_registers=num_registers)
ablated_attention_maps = hook_manager.get_attention_maps()
ablated_layer_outputs = hook_manager.get_layer_outputs()
ablated_neuron_activations = hook_manager.get_neuron_activations()

In [None]:
# Norm map of output patch embeddings - baseline and ablated comparison for all layers

# Needed to attach a colorbar of matching height next to each heat map.
from mpl_toolkits.axes_grid1 import make_axes_locatable


def reshape_with_extras(flat_array, patch_height, patch_width):
    """Lay per-token values out on the patch grid.

    Returns ``(grid, extras)`` where ``grid`` has shape ``(patch_height, patch_width)``
    and ``extras`` holds any values beyond the grid (``None`` if there are none).
    """
    grid_size = patch_height * patch_width
    if len(flat_array) == grid_size:
        return flat_array.reshape((patch_height, patch_width)), None
    return (
        flat_array[:grid_size].reshape((patch_height, patch_width)),
        flat_array[grid_size:],
    )


# One row per layer: original on the left, ablated on the right.
fig, axs = plt.subplots(num_layers, 2, figsize=(16, 4 * num_layers))
fig.subplots_adjust(hspace=0.5, wspace=0.3)

for layer in range(num_layers):
    columns = (
        ("Original", original_layer_outputs),
        ("Ablated", ablated_layer_outputs),
    )
    for column, (variant, outputs) in enumerate(columns):
        # L2 norm of every token after index 0; only the grid part is drawn.
        norms_flat = np.linalg.norm(outputs[layer, 1:], axis=1)
        norm_grid, _extras = reshape_with_extras(norms_flat, patch_height, patch_width)

        panel = axs[layer, column]
        heatmap = panel.imshow(norm_grid, cmap="viridis")
        panel.set_title(f"Layer {layer} - {variant} (output)", fontsize=14)
        panel.set_xlabel("Patch X", fontsize=12)
        panel.set_ylabel("Patch Y", fontsize=12)

        # Colorbar in its own axes to the right of the heat map.
        colorbar_axes = make_axes_locatable(panel).append_axes(
            "right", size="5%", pad=0.1
        )
        colorbar = fig.colorbar(heatmap, cax=colorbar_axes)
        colorbar.set_label(f"Layer {layer} Output Norm", fontsize=12)

fig.suptitle("Norm Maps of Image Patches Across All Layers", fontsize=20, y=1.00)
plt.tight_layout()
plt.show()

In [None]:
# Original attention maps
# Attention from token 0 to the H*W image-patch tokens, one grid per layer and head.
# NOTE(review): rebinds `plt` to the helper's return value (presumably pyplot) — confirm.
plt = plot_attn_maps(
    original_attention_maps[:, :, 0, 1 : patch_height * patch_width + 1].reshape(
        (num_layers, num_heads, patch_height, patch_width)
    )
)
plt.show()

In [None]:
# Ablated attention maps
# Same slice as for the original maps: tokens after the image patches (the register
# tokens) are left out of the plot.
# NOTE(review): rebinds `plt` to the helper's return value (presumably pyplot) — confirm.
plt = plot_attn_maps(
    ablated_attention_maps[:, :, 0, 1 : patch_height * patch_width + 1].reshape(
        (num_layers, num_heads, patch_height, patch_width)
    )
)
plt.show()

In [None]:
# Build the train/validation/test frames from the dataset root in IMAGE_PATH, so the
# location is defined in exactly one place.
full_train_df = pd.read_csv(os.path.join(IMAGE_PATH, "train.csv"))
# 80/20 stratified split: every label keeps the same proportion in both parts.
train_df, val_df = train_test_split(
    full_train_df, test_size=0.2, stratify=full_train_df["labels"], random_state=42
)
test_df = pd.read_csv(os.path.join(IMAGE_PATH, "test.csv"))
# Same nested folder that `load_images` reads from.
img_dir = os.path.join(IMAGE_PATH, "orthonet data", "orthonet data")

# Stable label -> index mapping derived from the sorted training labels.
classes = sorted(train_df["labels"].unique())
class_to_idx = {c: i for i, c in enumerate(classes)}
num_classes = len(classes)

In [None]:
# Sanity check: show the {layer: [neurons]} mapping the datasets below will use.
print(neurons_to_ablate)

In [None]:
class OrthoNetDataset(Dataset):
    """In-memory dataset of backbone representations for OrthoNet images.

    All images listed in ``dataframe`` are pushed through ``model`` once, at
    construction time, and the resulting representations are kept on the CPU.
    ``__getitem__`` then only indexes into that cache.

    The constructor relies on notebook-level globals: ``device``, ``preprocess``,
    ``hook_manager``, ``run_model``, ``neurons_to_ablate`` and ``num_registers``.

    Args:
        dataframe: Frame with a ``filenames`` and a ``labels`` column.
        img_dir: Directory the file names are resolved against.
        class_to_idx: Mapping from label value to integer class index.
        model: Backbone used to compute the representations; it is switched to
            eval mode and moved to ``device``.
        ttr: If True, intervene on ``neurons_to_ablate`` and run the model with
            ``num_registers`` extra tokens. If False, run the unmodified model
            exactly like the baseline cells do (``run_model(model, image)``).
    """

    def __init__(self, dataframe, img_dir, class_to_idx, model, ttr=False):
        self.dataframe = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.class_to_idx = class_to_idx
        self.model = model
        self.device = device
        self.ttr = ttr

        # Inference only: freeze dropout/batch-norm behaviour and place the model.
        self.model.eval()
        self.model.to(self.device)

        self.representations = []
        self.labels = []

        with torch.no_grad():  # no gradients are needed for feature extraction
            for _, row in tqdm(self.dataframe.iterrows(),
                               total=len(self.dataframe),
                               desc="Processing images"):
                img_path = os.path.join(self.img_dir, row["filenames"])
                # Decode inside a context manager so the file handle is released
                # right away; convert() returns an independent in-memory copy.
                with Image.open(img_path) as raw_img:
                    img = raw_img.convert("RGB")
                processed_image = preprocess(img).unsqueeze(0).to(self.device)

                hook_manager.reinit()
                if self.ttr:
                    hook_manager.intervene_register_neurons(
                        neurons_to_ablate=neurons_to_ablate,
                        num_registers=num_registers,
                        normal_values="zero",
                        scale=1,
                    )
                    hook_manager.finalize()
                    representation = run_model(
                        self.model, processed_image, num_registers=num_registers
                    )
                else:
                    # Baseline: no intervention and no extra register tokens, so
                    # the "no_ttr" features come from the unmodified model.
                    hook_manager.finalize()
                    representation = run_model(self.model, processed_image)

                self.representations.append(representation.squeeze(0).cpu())
                self.labels.append(self.class_to_idx[row["labels"]])

    def __len__(self):
        """Number of rows in the (re-indexed) dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the cached ``(representation, class_index)`` pair at ``idx``."""
        return self.representations[idx], self.labels[idx]

In [None]:
# Pre-compute representations for every split, first without and then with the
# register intervention. Construction order is train, val, test.
train_dataset_no_ttr, val_dataset_no_ttr, test_dataset_no_ttr = (
    OrthoNetDataset(split_df, img_dir, class_to_idx, model, ttr=False)
    for split_df in (train_df, val_df, test_df)
)

train_dataset_ttr, val_dataset_ttr, test_dataset_ttr = (
    OrthoNetDataset(split_df, img_dir, class_to_idx, model, ttr=True)
    for split_df in (train_df, val_df, test_df)
)

In [None]:
# Only the training loaders shuffle. Validation and test loaders iterate in a fixed
# order: their loss is averaged per batch, so a changing last (smaller) batch would
# make it vary between runs, and shuffling them would also consume the global RNG
# that drives the training order.
train_loader_no_ttr = DataLoader(train_dataset_no_ttr, batch_size=32, shuffle=True, num_workers=4)
train_loader_ttr = DataLoader(train_dataset_ttr, batch_size=32, shuffle=True, num_workers=4)
val_loader_no_ttr = DataLoader(val_dataset_no_ttr, batch_size=32, shuffle=False, num_workers=4)
val_loader_ttr = DataLoader(val_dataset_ttr, batch_size=32, shuffle=False, num_workers=4)
test_loader_no_ttr = DataLoader(test_dataset_no_ttr, batch_size=32, shuffle=False, num_workers=4)
test_loader_ttr = DataLoader(test_dataset_ttr, batch_size=32, shuffle=False, num_workers=4)

In [None]:
class LinearClassifier(nn.Module):
    """Linear probe: one fully connected layer from features to class logits.

    Args:
        input_dim: Size of the incoming feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=512, num_classes=10):
        super().__init__()
        # The attribute name `fc` is part of the saved state-dict keys.
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the unnormalised class scores for the features ``x``."""
        logits = self.fc(x)
        return logits

In [None]:
def compute_metrics(y_true, y_pred, y_prob):
    """Compute classification metrics for one pass over a dataset.

    Args:
        y_true: Sequence of integer class indices, length ``n``.
        y_pred: Sequence of predicted class indices, length ``n``.
        y_prob: Array-like of per-class scores, one row per sample, indexed with
            ``axis=1`` as the class axis.

    Returns:
        Dict with ``accuracy``, ``top3_accuracy``, macro ``f1``, ``precision`` and
        ``recall`` as plain Python floats (so the dict is JSON-serialisable), and
        ``auc_roc`` as a float or ``None`` when scikit-learn rejects the inputs.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_prob = np.asarray(y_prob)

    metrics = {}
    metrics["accuracy"] = float((y_true == y_pred).mean())

    # Indices of the three highest-scoring classes per sample; a sample counts as a
    # hit when its true class is among them.
    top3 = np.argsort(-y_prob, axis=1)[:, :3]
    metrics["top3_accuracy"] = float((top3 == y_true[:, None]).any(axis=1).mean())

    metrics["f1"] = float(f1_score(y_true, y_pred, average="macro", zero_division=0))
    metrics["precision"] = float(
        precision_score(y_true, y_pred, average="macro", zero_division=0)
    )
    metrics["recall"] = float(
        recall_score(y_true, y_pred, average="macro", zero_division=0)
    )

    try:
        metrics["auc_roc"] = float(roc_auc_score(y_true, y_prob, multi_class="ovr"))
    except ValueError:
        # Raised by scikit-learn for unusable inputs (e.g. a class missing from
        # y_true). Keep the best-effort behaviour, but let unrelated errors and
        # KeyboardInterrupt propagate.
        metrics["auc_roc"] = None

    return metrics

In [None]:
def train(
    model, train_loader, val_loader, criterion, optimizer, scheduler, method, epochs
):
    """Train ``model`` and checkpoint the epoch with the best validation accuracy.

    Uses the notebook-level ``device`` and ``compute_metrics``.

    Args:
        model: Classifier to optimise.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        criterion: Loss taking ``(outputs, labels)``.
        optimizer: Optimiser over ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        method: Suffix used in the output file names.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(train_history, val_history)``: one metrics dict per epoch each.

    Side effects:
        Writes ``best_model_<method>.pth``, ``train_metrics_<method>.json`` and
        ``val_metrics_<method>.json`` to the working directory.
    """
    train_history, val_history = [], []

    best_val_acc = 0.0
    best_model_path = "best_model_" + method + ".pth"

    for epoch in range(epochs):
        # ---- training pass -----------------------------------------------------
        model.train()
        total_loss = 0
        all_labels, all_preds, all_probs = [], [], []

        for images, labels in tqdm(
            train_loader, desc=f"Train Epoch {epoch+1}/{epochs}"
        ):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            # Cap the gradient norm to keep single updates bounded.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            probs = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            preds = outputs.argmax(dim=1).detach().cpu().numpy()
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds)
            all_probs.extend(probs)

        scheduler.step()
        # Mean of the per-batch losses.
        avg_train_loss = total_loss / len(train_loader)
        train_metrics = compute_metrics(all_labels, all_preds, np.array(all_probs))
        train_metrics["loss"] = avg_train_loss
        train_history.append(train_metrics)

        # ---- validation pass ---------------------------------------------------
        model.eval()
        total_loss = 0
        all_labels, all_preds, all_probs = [], [], []
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)
                total_loss += loss.item()

                probs = torch.softmax(outputs, dim=1).cpu().numpy()
                preds = outputs.argmax(dim=1).cpu().numpy()
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds)
                all_probs.extend(probs)

        avg_val_loss = total_loss / len(val_loader)
        val_metrics = compute_metrics(all_labels, all_preds, np.array(all_probs))
        val_metrics["loss"] = avg_val_loss
        val_history.append(val_metrics)

        print(
            f"Epoch {epoch+1}: "
            f"Train Loss={avg_train_loss:.4f}, Train Acc={train_metrics['accuracy']:.4f} | "
            f"Val Loss={avg_val_loss:.4f}, Val Acc={val_metrics['accuracy']:.4f}"
        )

        # Always checkpoint the first epoch so `best_model_path` exists even when
        # validation accuracy never rises above 0.0; afterwards only improvements
        # overwrite it.
        if epoch == 0 or val_metrics["accuracy"] > best_val_acc:
            best_val_acc = val_metrics["accuracy"]
            torch.save(model.state_dict(), best_model_path)
            # Report the quantity the checkpoint is selected on (accuracy).
            print(
                f"Best model saved at epoch {epoch+1} with "
                f"Val Acc={best_val_acc:.4f}, Val Loss={avg_val_loss:.4f}"
            )

    with open("train_metrics_" + method + ".json", "w") as f:
        json.dump(train_history, f, indent=4)
    with open("val_metrics_" + method + ".json", "w") as f:
        json.dump(val_history, f, indent=4)

    print(f"\nBest Validation Accuracy: {best_val_acc:.4f}")
    return train_history, val_history

In [None]:
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return its metrics dict.

    Uses the notebook-level ``device`` and ``compute_metrics``; prints the accuracy.
    """
    model.eval()
    seen_labels, seen_preds, seen_probs = [], [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(test_loader, desc="Testing"):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)

            # Collect everything on the CPU as NumPy for the metric helpers.
            seen_probs.extend(torch.softmax(logits, dim=1).cpu().numpy())
            seen_preds.extend(logits.argmax(dim=1).cpu().numpy())
            seen_labels.extend(batch_labels.cpu().numpy())

    metrics = compute_metrics(seen_labels, seen_preds, np.array(seen_probs))

    print(f"Test Acc ={metrics['accuracy']:.4f}")

    return metrics

In [None]:
def plot_metrics(train_history, test_history, metric="accuracy"):
    """Plot one metric per epoch for the training and the second history.

    Args:
        train_history: List of per-epoch metric dicts from training.
        test_history: List of per-epoch metric dicts drawn as the "Validation" curve.
        metric: Key to read from every dict.
    """
    # Extract both series first so a missing key fails before a figure is created.
    curves = [
        ([epoch_metrics[metric] for epoch_metrics in train_history], f"Train {metric}"),
        ([epoch_metrics[metric] for epoch_metrics in test_history], f"Validation {metric}"),
    ]

    plt.figure(figsize=(7, 5))
    for values, legend_label in curves:
        plt.plot(values, label=legend_label)
    plt.xlabel("Epoch")
    plt.ylabel(metric.capitalize())
    plt.title(f"{metric.capitalize()} over epochs")
    plt.legend()
    plt.show()

In [None]:
# Linear probe for the baseline (no_ttr) representations.
# NOTE(review): input_dim=1024 is hard-coded and must equal the size of the
# representation returned by `run_model` for the selected backbone — confirm if
# MODEL or the backbone size changes.
classifier = LinearClassifier(input_dim=1024, num_classes=num_classes)
classifier = classifier.to(device)

print(classifier)

In [None]:
# Optimisation setup for the baseline probe.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(classifier.parameters(), lr=1e-3, weight_decay=0.05)
epochs = 50

# Cosine learning-rate decay over the whole run, down to `eta_min`.
T_max = epochs
eta_min = 1e-5
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=T_max, eta_min=eta_min
)

method = "no_ttr"
train_history, val_history = train(
    classifier,
    train_loader_no_ttr,
    val_loader_no_ttr,
    criterion,
    optimizer,
    scheduler,
    method,
    epochs,
)

# One training-vs-validation curve per recorded metric.
for metric_name in (
    "loss",
    "accuracy",
    "top3_accuracy",
    "f1",
    "precision",
    "recall",
    "auc_roc",
):
    plot_metrics(train_history, val_history, metric=metric_name)

In [None]:
# Restore the best baseline checkpoint and score it on the held-out test split.
method = "no_ttr"
checkpoint_path = "best_model_" + method + ".pth"
classifier.load_state_dict(torch.load(checkpoint_path))

test_metrics = test(classifier, test_loader_no_ttr)
# Persist the test metrics next to the training/validation histories.
with open("test_metrics_" + method + ".json", "w") as metrics_file:
    json.dump(test_metrics, metrics_file, indent=4)

In [None]:
# Fresh linear probe for the representations computed with the register intervention;
# this replaces the baseline `classifier` object.
# NOTE(review): input_dim=1024 is hard-coded — confirm it matches the representation size.
classifier = LinearClassifier(input_dim=1024, num_classes=num_classes)
classifier = classifier.to(device)

print(classifier)

In [None]:
# Optimisation setup for the probe trained on the intervention (ttr) features;
# identical hyper-parameters to the baseline run.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(classifier.parameters(), lr=1e-3, weight_decay=0.05)
epochs = 50

# Cosine learning-rate decay over the whole run, down to `eta_min`.
T_max = epochs
eta_min = 1e-5
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=T_max, eta_min=eta_min
)

method = "ttr"
train_history, val_history = train(
    classifier,
    train_loader_ttr,
    val_loader_ttr,
    criterion,
    optimizer,
    scheduler,
    method,
    epochs,
)

# One training-vs-validation curve per recorded metric.
for metric_name in (
    "loss",
    "accuracy",
    "top3_accuracy",
    "f1",
    "precision",
    "recall",
    "auc_roc",
):
    plot_metrics(train_history, val_history, metric=metric_name)

In [None]:
# Restore the best ttr checkpoint and score it on the held-out test split.
method = "ttr"
checkpoint_path = "best_model_" + method + ".pth"
classifier.load_state_dict(torch.load(checkpoint_path))

test_metrics = test(classifier, test_loader_ttr)
# Persist the test metrics next to the training/validation histories.
with open("test_metrics_" + method + ".json", "w") as metrics_file:
    json.dump(test_metrics, metrics_file, indent=4)