In [None]:
!pip install gspread oauth2client
!pip install hilbertcurve
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from tqdm import tqdm
from google.colab import drive
import matplotlib.pyplot as plt
import datetime
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
from IPython.display import Javascript, display

# Mount Google Drive (Colab only) so the /content/drive/... data and experiment paths below resolve.
drive.mount('/content/drive')

import random
import numpy as np

def set_seed(seed=42):
    """Seed every RNG used by this notebook and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus the current CUDA device), then turns cuDNN benchmark mode off
    and deterministic mode on.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(42)

##########################
# Configuration
##########################
DATA_DIR = "/content/drive/MyDrive/BRACS/MyTransformer"

TRAIN_CSV = os.path.join(DATA_DIR, "train.csv")
VAL_CSV = os.path.join(DATA_DIR, "val.csv")
TEST_CSV = os.path.join(DATA_DIR, "test.csv")

# One <slide_id>.pt embedding file per slide, per split.
TRAIN_EMBED_DIR = os.path.join(DATA_DIR, "train")
VAL_EMBED_DIR = os.path.join(DATA_DIR, "val")
TEST_EMBED_DIR = os.path.join(DATA_DIR, "test")

NUM_CLASSES = 3    # Benign / Atypical / Malignant
EMBED_DIM = 1536   # width of each precomputed tile embedding
MAX_TILES = 600    # every slide is truncated / zero-padded to this many tiles
BATCH_SIZE = 16
LR = 1e-4
EPOCHS = 70

# Tile ordering / positional-encoding strategy used by WSIDataset. One of:
#   "spiral", "spiralwencoding",
#   "rasterscan", "rasterscanwencoding",
#   "hilbertwencoding",
#   "2dsinusodial"  (keep this exact spelling; it is matched literally)
# The "...wencoding" variants add a 1D sinusoidal encoding after reordering.
ORDER_METHOD = "spiralwencoding"
# Scaling factor (alpha) for the positional encoding WSIDataset adds to the tile embeddings.
ALPHA_ENCOD = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Create an experiment folder with timestamp
timestamp = datetime.datetime.now().strftime("%d.%m-%H:%M")
EXP_DIR = f"/content/drive/MyDrive/BRACS/experiments/{ORDER_METHOD}/{ALPHA_ENCOD}-{timestamp}-2d-6head4layer-70epoch-drop20-focalLoss"
print("Experiment directory set to: ", EXP_DIR)
CHECKPOINT_DIR = os.path.join(EXP_DIR, "checkpoints")
PLOTS_DIR = os.path.join(EXP_DIR, "plots")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

# Per-epoch metric histories (one entry appended per epoch by the training loop).
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# Macro-averaged classification metrics
train_precisions = []
train_recalls = []
train_f1s = []

val_precisions = []
val_recalls = []
val_f1s = []

##########################
# 2D Sinusoidal Position Embedding Functions
##########################
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine/cosine positional-embedding table for a square grid.

    Args:
        embed_dim: Width of each embedding row; must be divisible by 4 so each
            axis gets an even number of sin/cos channels.
        grid_size: Number of cells along each side of the grid.
        cls_token: If True, prepend an all-zero row for a CLS token.

    Returns:
        np.ndarray of shape ``[grid_size * grid_size, embed_dim]`` (or
        ``[1 + grid_size * grid_size, embed_dim]`` with ``cls_token``). Row
        ``r * grid_size + c`` belongs to the cell in row ``r``, column ``c``.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid(cols, rows): plane 0 holds every cell's column (x) index and
    # plane 1 its row (y) index; for grid_size=3:
    #   plane 0 = [[0,1,2],[0,1,2],[0,1,2]]   plane 1 = [[0,0,0],[1,1,1],[2,2,2]]
    col_idx, row_idx = np.meshgrid(axis, axis)
    coords = np.stack([col_idx, row_idx], axis=0).reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate per-axis 1D sin/cos embeddings for every grid position.

    ``grid[0]`` and ``grid[1]`` are the two coordinate planes (any shape; they
    are flattened). The first half of each output row encodes ``grid[0]`` and
    the second half encodes ``grid[1]``.

    Returns:
        np.ndarray of shape ``[grid[0].size, embed_dim]``.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with fixed sine/cosine features.

    Args:
        embed_dim: Output width; must be even (half sine, half cosine).
        pos: Array of positions of any shape; it is flattened.

    Returns:
        np.ndarray of shape ``[pos.size, embed_dim]``. The first half holds
        ``sin(pos * omega)`` and the second half ``cos(pos * omega)``, with
        ``omega_i = 1 / 10000 ** (i / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    frequencies = 1.0 / 10000 ** (np.arange(half_dim, dtype=float) / (embed_dim / 2.0))
    angles = np.outer(pos.reshape(-1), frequencies)  # [M, half_dim]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

##########################
# Dataset
##########################
def get_1d_sincos_pos_embed(embed_dim, seq_len):
    """Generate 1D sinusoidal positional embeddings (interleaved sin/cos layout).

    Even channels hold ``sin(pos * w_k)`` and odd channels ``cos(pos * w_k)``,
    with ``w_k = 10000 ** (-2k / embed_dim)``.

    Args:
        embed_dim: Embedding dimension (e.g. 1536). Odd values are supported;
            the last (even-indexed) channel is then a sine with no cosine partner.
        seq_len: Number of positions (e.g. number of tiles); 0 gives an empty tensor.

    Returns:
        torch.Tensor of shape ``[seq_len, embed_dim]`` (float32).
    """
    positions = torch.arange(seq_len).unsqueeze(1).float()  # [seq_len, 1]
    div_term = torch.exp(
        torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
    )  # [ceil(embed_dim / 2)]
    angles = positions * div_term  # [seq_len, ceil(embed_dim / 2)]

    pos_encoding = torch.zeros(seq_len, embed_dim)
    pos_encoding[:, 0::2] = torch.sin(angles)
    # With an odd embed_dim there is one fewer odd channel than frequencies,
    # so the cosine half drops the last frequency.
    pos_encoding[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])
    return pos_encoding

class WSIDataset(Dataset):
    """Slide-level dataset of precomputed tile embeddings.

    Each row of ``df`` provides ``slide_id`` and ``label``. For every slide the
    file ``{embed_dir}/{slide_id}.pt`` is loaded; it must contain
    ``"tile_embeds"`` (``[N, D]``) and may contain ``"coords"`` (``[N, 2]``
    pixel ``(x, y)`` per tile). When coordinates are present, the tiles are
    reordered and/or given positional information according to
    ``order_method``. The result is then truncated or zero-padded to
    ``max_tiles`` rows.

    Supported ``order_method`` values:
        "spiral", "spiralwencoding"          -- spiral traversal of the tile grid
        "rasterscan", "rasterscanwencoding"  -- row-major traversal (y, then x)
        "hilbertwencoding"                   -- Hilbert-curve traversal
        "2dsinusodial"                       -- no reordering, 2D sin/cos encoding
    The "...wencoding" variants add a 1D sinusoidal encoding to the reordered
    sequence. That encoding and the 2D encoding are both scaled by ``ALPHA_ENCOD``.
    """

    ORDER_METHODS = (
        "spiral",
        "spiralwencoding",
        "rasterscan",
        "rasterscanwencoding",
        "2dsinusodial",
        "hilbertwencoding",
    )
    # Pixel size of one tile edge: floor(coords / TILE_SIZE) gives grid indices.
    TILE_SIZE = 256.0
    # Side length of the square grid covered by the ``pos_embed_2d`` table.
    GRID_SIZE_2D = 800

    def __init__(self, df, embed_dir, max_tiles=600, order_method="rasterscan", transform=None):
        if order_method not in self.ORDER_METHODS:
            # Previously an unknown (e.g. misspelled) method silently trained with no ordering.
            raise ValueError(
                f"Unknown order_method {order_method!r}; expected one of {self.ORDER_METHODS}"
            )
        self.df = df
        self.embed_dir = embed_dir
        self.max_tiles = max_tiles
        self.transform = transform
        self.order_method = order_method
        # The full 800x800 table is ~3.9 GB in float32 at EMBED_DIM=1536, so it
        # is only materialized on demand (see ``pos_embed_2d``). __getitem__
        # computes the rows it needs directly from the tile coordinates.
        self._pos_embed_2d = None

    @property
    def pos_embed_2d(self):
        """Full ``[GRID_SIZE_2D ** 2, EMBED_DIM]`` 2D sin/cos table (built lazily, then cached).

        Row ``y * GRID_SIZE_2D + x`` encodes grid cell ``(x, y)``.
        """
        if self._pos_embed_2d is None:
            table = get_2d_sincos_pos_embed(
                embed_dim=EMBED_DIM, grid_size=self.GRID_SIZE_2D, cls_token=False
            )
            self._pos_embed_2d = torch.from_numpy(table).float()
        return self._pos_embed_2d

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(tile_embeds [max_tiles, D], label)`` for slide ``idx``."""
        row = self.df.iloc[idx]
        slide_id = row['slide_id']
        label = int(row['label'])

        # Load onto the CPU so DataLoader worker processes never need CUDA, even
        # if the embeddings were saved from GPU tensors.
        embed_path = os.path.join(self.embed_dir, f"{slide_id}.pt")
        data = torch.load(embed_path, map_location="cpu")
        tile_embeds = data["tile_embeds"]  # [N, D]

        if "coords" in data:
            tile_embeds = self._order_and_encode(tile_embeds, data["coords"])

        tile_embeds = self._pad_or_truncate(tile_embeds)

        if self.transform:
            tile_embeds = self.transform(tile_embeds)

        return tile_embeds, label

    def _order_and_encode(self, tile_embeds, coords):
        """Reorder tiles and/or add positional encodings according to ``order_method``."""
        method = self.order_method
        if method in ("spiral", "spiralwencoding"):
            tile_embeds = self.get_spiral_order_embeddings(coords, tile_embeds)
        elif method in ("rasterscan", "rasterscanwencoding"):
            tile_embeds = tile_embeds[self._raster_order(coords)]
        elif method == "hilbertwencoding":
            # Never plot here: __getitem__ runs once per sample, typically in worker processes.
            tile_embeds = self.get_hilbert_order_embeddings(coords, tile_embeds, visualize=False)
        elif method == "2dsinusodial":
            return tile_embeds + ALPHA_ENCOD * self._pos_embed_2d_for_coords(coords)

        if method.endswith("wencoding"):
            pos_encoding = get_1d_sincos_pos_embed(tile_embeds.shape[1], tile_embeds.shape[0])
            tile_embeds = tile_embeds + (ALPHA_ENCOD * pos_encoding)
        return tile_embeds

    def _pad_or_truncate(self, tile_embeds):
        """Keep the first ``max_tiles`` rows, or append zero rows up to ``max_tiles``."""
        num_tiles = tile_embeds.shape[0]
        if num_tiles >= self.max_tiles:
            return tile_embeds[: self.max_tiles]
        pad_embeds = torch.zeros(self.max_tiles - num_tiles, tile_embeds.shape[1])
        return torch.cat([tile_embeds, pad_embeds], dim=0)

    @staticmethod
    def _raster_order(coords):
        """Indices that sort tiles row-major: by y (row) first, then by x (column).

        A lexicographic sort replaces the old ``x + y * 100000`` key, which
        overflows when coords are stored as int32 (y above ~21474 px) and is
        not stable for duplicate positions.
        """
        coords_np = torch.as_tensor(coords).numpy()
        # np.lexsort uses the LAST key as the primary sort key.
        return torch.from_numpy(np.lexsort((coords_np[:, 0], coords_np[:, 1])))

    def _normalized_grid(self, coords_):
        """Map pixel coords to integer grid indices shifted so the minimum is (0, 0).

        Returns:
            (norm_x, norm_y, width, height): int64 arrays of per-tile column/row
            indices and the bounding-box size in cells.
        """
        grid = torch.floor(coords_ / self.TILE_SIZE).numpy().astype(np.int64)
        x_idx, y_idx = grid[:, 0], grid[:, 1]
        norm_x = x_idx - x_idx.min()
        norm_y = y_idx - y_idx.min()
        width = int(norm_x.max()) + 1
        height = int(norm_y.max()) + 1
        return norm_x, norm_y, width, height

    def _pos_embed_2d_for_coords(self, coords_):
        """Per-tile 2D sin/cos encoding of grid cell ``floor(coords / 256)``.

        For cells inside the 800x800 grid this equals
        ``pos_embed_2d[y * 800 + x]``: the first half of each row encodes x and
        the second half encodes y. It is computed directly from the cell
        indices, so the full table is never materialized and cells outside
        the grid no longer wrap around or raise IndexError.
        """
        grid = torch.floor(coords_ / self.TILE_SIZE).numpy()  # [N, 2] grid (x, y)
        emb = get_2d_sincos_pos_embed_from_grid(EMBED_DIM, np.stack([grid[:, 0], grid[:, 1]]))
        return torch.from_numpy(emb).float()

    def apply_2d_position_embed(self, tile_embeds, coords_):
        """Return ``tile_embeds`` plus the 2D sin/cos embedding of each tile's grid cell.

        Grid cells are ``floor(coords / 256)``; see ``_pos_embed_2d_for_coords``.
        """
        return tile_embeds + self._pos_embed_2d_for_coords(coords_)

    def get_hilbert_order_embeddings(self, coords_, tile_embeds, visualize=True):
        """Arrange tile embeddings in Hilbert-curve order, optionally plotting the traversal.

        Args:
            coords_ (torch.Tensor): Tensor of shape [N, 2] with pixel (x, y) coordinates.
            tile_embeds (torch.Tensor): Tensor of shape [N, D].
            visualize (bool): Whether to plot the Hilbert traversal.

        Returns:
            torch.Tensor: Tensor of shape [N, D] arranged in Hilbert order
            (an empty input is returned unchanged).
        """
        import hilbertcurve.hilbertcurve as hilbert_lib

        if tile_embeds.shape[0] == 0:
            return tile_embeds[:0]
        norm_x, norm_y, width, height = self._normalized_grid(coords_)

        # Smallest power-of-two side that covers the bounding box.
        grid_size = 1
        while grid_size < max(width, height):
            grid_size *= 2
        # HilbertCurve needs at least one iteration; a single-cell slide would give p = 0.
        p = max(1, int(np.log2(grid_size)))
        hilbert_curve = hilbert_lib.HilbertCurve(p, 2)

        # Every normalized cell lies inside [0, grid_size), so each tile gets a distance.
        hilbert_indices = np.array([
            hilbert_curve.distance_from_point([x, y])
            for x, y in zip(norm_x.tolist(), norm_y.tolist())
        ])
        sorted_indices = np.argsort(hilbert_indices)
        hilbert_order_embeds = tile_embeds[torch.from_numpy(sorted_indices)]

        if visualize:
            import matplotlib.pyplot as plt

            sorted_coords = np.vstack((norm_x[sorted_indices], norm_y[sorted_indices])).T

            plt.figure(figsize=(8, 8))
            plt.scatter(norm_x, norm_y, c="blue", alpha=0.5, label="Original Tiles")
            plt.plot(sorted_coords[:, 0], sorted_coords[:, 1], color="red", linewidth=1.5, linestyle="dashed", label="Hilbert Path")
            plt.scatter(sorted_coords[:, 0], sorted_coords[:, 1], c="red", label="Hilbert Tiles")

            # Annotate the first ten visited tiles with their visit order.
            for i, (x, y) in enumerate(sorted_coords[:10]):
                plt.text(x, y, str(i), fontsize=10, color="black")

            plt.xlabel("X Coordinate (Normalized)")
            plt.ylabel("Y Coordinate (Normalized)")
            plt.title("Hilbert Curve Traversal of Tiles")
            plt.gca().invert_yaxis()
            plt.legend()
            plt.grid(True)
            plt.show()

        return hilbert_order_embeds

    def get_spiral_order_embeddings(self, coords_, tile_embeds):
        """Arrange tile embeddings in spiral order without padding for missing cells.

        Args:
            coords_ (torch.Tensor): Tensor of shape [N, 2] with pixel (x, y) coordinates.
            tile_embeds (torch.Tensor): Tensor of shape [N, D].

        Returns:
            torch.Tensor: Tensor of shape [num_tiles, D] in spiral order. If
            several tiles fall in the same grid cell, only the last one is
            kept. An empty input yields an empty ``[0, D]`` tensor.
        """
        if tile_embeds.shape[0] == 0:
            # Nothing to order, and min()/max() below would fail on empty arrays.
            return tile_embeds[:0]
        norm_x, norm_y, width, height = self._normalized_grid(coords_)

        # Occupied grid cell -> index of its tile (later duplicates overwrite earlier ones).
        cell_to_tile = {cell: idx for idx, cell in enumerate(zip(norm_x.tolist(), norm_y.tolist()))}

        order = [
            cell_to_tile[cell]
            for cell in self.generate_spiral_coords(width, height)
            if cell in cell_to_tile
        ]
        return tile_embeds[torch.tensor(order, dtype=torch.long)]

    def generate_spiral_coords(self, width, height):
        """Generate (x, y) grid coordinates in clockwise spiral order from the top-left corner.

        Args:
            width (int): Number of columns in the grid.
            height (int): Number of rows in the grid.

        Returns:
            list of tuples: Every (x, y) cell of the grid exactly once, in spiral order.
        """
        spiral_order = []
        top, bottom = 0, height - 1
        left, right = 0, width - 1

        while top <= bottom and left <= right:
            # Left to right along the top row
            for x in range(left, right + 1):
                spiral_order.append((x, top))
            top += 1

            # Top to bottom along the right column
            for y in range(top, bottom + 1):
                spiral_order.append((right, y))
            right -= 1

            if top <= bottom:
                # Right to left along the bottom row
                for x in range(right, left - 1, -1):
                    spiral_order.append((x, bottom))
                bottom -= 1

            if left <= right:
                # Bottom to top along the left column
                for y in range(bottom, top - 1, -1):
                    spiral_order.append((left, y))
                left += 1

        return spiral_order

##########################
# Sinusoidal Positional Encoding
##########################
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed 1D sine/cosine positional encoding added to a ``[batch, seq, d_model]`` input.

    Precomputes a ``[max_len, d_model]`` table (even channels sine, odd
    channels cosine) and registers it as the ``pe`` buffer, so it moves with
    the module and is saved in its state dict.
    """

    def __init__(self, d_model, max_len=601):  # default 601 = 600 tiles + 1 CLS
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer cosine channel than frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:seq_len, :]

##########################
# Transformer Model with Dropout and MLP Head
##########################
class WSITransformer(nn.Module):
    """Transformer encoder over a slide's tile embeddings with a CLS-token MLP classifier.

    A learned CLS token is prepended to the ``[batch, num_tiles, embed_dim]``
    input, the sequence passes through ``num_layers`` encoder layers, and the
    CLS output is classified by a two-layer MLP with dropout.

    No positional information is added here. Depending on ORDER_METHOD, the
    dataset either adds a 1D/2D sinusoidal encoding itself or relies on tile
    ordering alone. ``pos_encoder`` is still constructed so that existing
    checkpoints, whose state dicts contain its ``pe`` buffer, keep loading
    with ``strict=True``.
    """

    def __init__(self, embed_dim=1536, num_heads=8, num_layers=4, ff_dim=4096, num_classes=3, max_tiles=600, dropout=0.1):
        super().__init__()

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.cls_token, std=0.02)

        # Kept for checkpoint compatibility; not applied in forward().
        self.pos_encoder = SinusoidalPositionalEncoding(d_model=embed_dim, max_len=max_tiles+1)

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim,
                                                   nhead=num_heads,
                                                   dim_feedforward=ff_dim,
                                                   dropout=dropout,
                                                   batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head: MLP with dropout
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(embed_dim, embed_dim//2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(embed_dim//2, num_classes)
        )

    def forward(self, x, src_key_padding_mask=None):
        """Classify a batch of slides.

        Args:
            x: ``[batch, num_tiles, embed_dim]`` tile embeddings.
            src_key_padding_mask: Optional ``[batch, num_tiles]`` mask; True
                (non-zero) marks padded tiles that attention should ignore. The
                CLS position is never masked. If omitted, every position is
                attended to, including zero-padding.

        Returns:
            ``[batch, num_classes]`` logits.
        """
        b = x.shape[0]

        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        if src_key_padding_mask is not None:
            padding = src_key_padding_mask.to(torch.bool)
            cls_visible = torch.zeros(b, 1, dtype=torch.bool, device=padding.device)
            src_key_padding_mask = torch.cat([cls_visible, padding], dim=1)

        x = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)

        # The CLS output summarizes the slide.
        cls_out = x[:, 0, :]
        logits = self.classifier(cls_out)
        return logits

########################
### Loss
########################

class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., 2017, https://arxiv.org/abs/1708.02002).

    Per sample: ``FL = -alpha_t * (1 - p_t) ** gamma * log(p_t)``, where ``p_t``
    is the softmax probability of the true class and ``alpha_t`` is its weight.
    """

    def __init__(self, alpha=None, gamma=2, reduction='mean'):
        """
        Args:
            alpha: None (no weighting), a scalar applied to every sample, or a
                per-class list/tuple/array/tensor of weights indexed by target.
            gamma (float): Focusing parameter for the modulating factor (1 - p_t).
                With gamma=0 this becomes per-sample weighted cross-entropy.
            reduction (str): 'none' | 'mean' | 'sum'.

        Raises:
            ValueError: if ``reduction`` is not one of the supported values.
        """
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")
        if isinstance(alpha, (list, tuple, np.ndarray)):
            alpha = torch.tensor(alpha, dtype=torch.float32)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Logits of shape [batch_size, num_classes].
            targets: Ground-truth class indices of shape [batch_size].

        Returns:
            Scalar loss for 'mean'/'sum', or the per-sample loss [batch_size] for 'none'.
        """
        # Unweighted CE = -log(p_t). Class weights must not be passed here:
        # exp(-weighted_ce) would be p_t ** alpha_t instead of p_t.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            if isinstance(self.alpha, torch.Tensor) and self.alpha.dim() > 0:
                alpha_t = self.alpha.to(device=inputs.device, dtype=focal_loss.dtype)[targets]
            else:
                alpha_t = self.alpha  # scalar weight (Python number or 0-d tensor)
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

##########################
# Data Preparation
##########################
def print_label_distribution(df, split_name):
    """Print how many rows of ``df`` carry each value of its ``label`` column.

    Labels 0/1/2 are shown as Benign/Atypical/Malignant; any other value is
    shown as ``Unknown (<label>)``. Labels are listed in ascending order and
    the report ends with blank lines.
    """
    names = {0: 'Benign', 1: 'Atypical', 2: 'Malignant'}
    per_label = df['label'].value_counts()

    print(f"=== {split_name} Set Label Distribution ===")
    for lbl in sorted(per_label.index):
        pretty = names.get(lbl, f"Unknown ({lbl})")
        print(f"{pretty} (Label {lbl}): {per_label[lbl]}")
    print("\n")

train_df = pd.read_csv(TRAIN_CSV)
val_df = pd.read_csv(VAL_CSV)
test_df = pd.read_csv(TEST_CSV)

# Each split is paired with its own CSV and embedding directory. The
# validation split drives checkpoint selection and LR scheduling; the test
# split is scored only once, after training. (These were previously swapped,
# which selected the model on the test set.)
train_dataset = WSIDataset(train_df, TRAIN_EMBED_DIR, max_tiles=MAX_TILES, order_method=ORDER_METHOD)
val_dataset = WSIDataset(val_df, VAL_EMBED_DIR, max_tiles=MAX_TILES, order_method=ORDER_METHOD)
test_dataset = WSIDataset(test_df, TEST_EMBED_DIR, max_tiles=MAX_TILES, order_method=ORDER_METHOD)

print("ORDER METHOD", ORDER_METHOD )

print_label_distribution(train_df, "Train")
print_label_distribution(val_df, "Validation")
print_label_distribution(test_df, "Test")

print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

##########################
# Training Setup
##########################
model = WSITransformer(embed_dim=EMBED_DIM,
                       num_heads=6,
                       num_layers=4,
                       ff_dim=4096,
                       num_classes=NUM_CLASSES,
                       max_tiles=MAX_TILES,
                       dropout=0.20).to(device)


# 'balanced' weights = n_samples / (n_classes * count_per_class): rarer classes weigh more.
class_weights = class_weight.compute_class_weight('balanced', classes=np.unique(train_df['label']), y=train_df['label'])
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)

# Initialize FocalLoss with class weights
print(class_weights)
criterion = FocalLoss(gamma=1, alpha=class_weights, reduction='mean')

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)

# Halve the LR after 10 epochs without validation-loss improvement. The
# scheduler's `verbose` flag is deprecated in recent PyTorch releases, so it
# is not passed; read the LR from optimizer.param_groups instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

# Start below any reachable accuracy so the first epoch always writes
# best_model.pth, which the test-set evaluation loads afterwards.
best_val_acc = float('-inf')

##########################
# Evaluation Function
##########################
def evaluate(model, loader, criterion, compute_metrics=True):
    """Run ``model`` over ``loader`` without gradients and summarize performance.

    Batches are moved to the module-level ``device``. The per-batch loss is
    multiplied by the batch size before averaging; this assumes ``criterion``
    averages over the batch.

    Args:
        model: Classifier returning logits ``[batch, num_classes]``.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function ``criterion(logits, labels)``.
        compute_metrics: If True, also compute macro precision/recall/F1 and
            return the raw predictions and labels.

    Returns:
        ``(avg_loss, accuracy)``, or, with ``compute_metrics``,
        ``(avg_loss, accuracy, precision, recall, f1, all_preds, all_labels)``.
        Loss and accuracy are 0 for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    preds_acc, labels_acc = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            batch_logits = model(batch_inputs)
            loss_sum += criterion(batch_logits, batch_labels).item() * batch_inputs.size(0)

            batch_preds = torch.max(batch_logits, dim=1).indices
            n_correct += (batch_preds == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

            if compute_metrics:
                preds_acc.extend(batch_preds.cpu().numpy())
                labels_acc.extend(batch_labels.cpu().numpy())

    avg_loss = loss_sum / n_seen if n_seen > 0 else 0
    accuracy = n_correct / n_seen if n_seen > 0 else 0
    if not compute_metrics:
        return avg_loss, accuracy

    macro = dict(average='macro', zero_division=0)
    precision = precision_score(labels_acc, preds_acc, **macro)
    recall = recall_score(labels_acc, preds_acc, **macro)
    f1 = f1_score(labels_acc, preds_acc, **macro)
    return avg_loss, accuracy, precision, recall, f1, preds_acc, labels_acc


##########################
# Training Loop
##########################
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    correct = 0
    total = 0
    all_train_preds = []
    all_train_labels = []

    for tile_embeds, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        tile_embeds = tile_embeds.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(tile_embeds)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # The criterion averages over the batch; re-weight by batch size so that
        # epoch_loss / total is a per-slide mean.
        epoch_loss += loss.item() * tile_embeds.size(0)
        _, predicted = torch.max(logits, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        all_train_preds.extend(predicted.cpu().numpy())
        all_train_labels.extend(labels.cpu().numpy())

    train_loss = epoch_loss / total if total > 0 else 0
    train_acc = correct / total if total > 0 else 0
    train_precision = precision_score(all_train_labels, all_train_preds, average='macro', zero_division=0)
    train_recall = recall_score(all_train_labels, all_train_preds, average='macro', zero_division=0)
    train_f1 = f1_score(all_train_labels, all_train_preds, average='macro', zero_division=0)

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    train_precisions.append(train_precision)
    train_recalls.append(train_recall)
    train_f1s.append(train_f1)

    # Raw validation predictions/labels are not needed per epoch.
    val_loss, val_acc, val_precision, val_recall, val_f1, _, _ = evaluate(model, val_loader, criterion)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    val_precisions.append(val_precision)
    val_recalls.append(val_recall)
    val_f1s.append(val_f1)

    # Step the scheduler with validation loss; it may lower the LR here.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    print(f"Epoch [{epoch+1}/{EPOCHS}]")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}, Train F1: {train_f1:.4f}")
    print(f"Val Loss:   {val_loss:.4f}, Val Acc:   {val_acc:.4f}, Val Precision: {val_precision:.4f}, Val Recall: {val_recall:.4f}, Val F1: {val_f1:.4f}")
    print(f"Learning Rate: {current_lr:.2e}")

    # Keep the best-by-validation-accuracy weights plus a snapshot of every epoch.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, "best_model.pth"))
    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"model_epoch_{epoch+1}.pth"))

##########################
# Evaluate on Test Set
##########################
model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, "best_model.pth"), map_location=device))
print(f"Best Val: {best_val_acc}")
# evaluate() returns (..., predictions, labels) in that order.
test_loss, test_acc, test_precision, test_recall, test_f1, test_preds, test_labels = evaluate(model, test_loader, criterion)
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test Precision: {test_precision:.4f}, Test Recall: {test_recall:.4f}, Test F1: {test_f1:.4f}")

# Define Label Names
label_names = ['Benign', 'Atypical', 'Malignant']

# Rows = true class, columns = predicted class (sklearn convention). Passing
# `labels` keeps the matrix NUM_CLASSES x NUM_CLASSES even when a class never
# occurs in this split, so it always lines up with the tick labels.
cm = confusion_matrix(test_labels, test_preds, labels=list(range(NUM_CLASSES)))

# Plot Confusion Matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=label_names, yticklabels=label_names)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.savefig(os.path.join(PLOTS_DIR, 'confusion_matrix.png'))
plt.close()


# Per-epoch training curves: one PNG per metric, train vs. validation.
epochs_axis = range(1, EPOCHS + 1)
curve_specs = [
    # (train series, val series, train legend, val legend, y-axis label, title, file name)
    (train_losses, val_losses, 'Train Loss', 'Validation Loss', 'Loss', 'Loss Curve', 'loss_curve.png'),
    (train_accuracies, val_accuracies, 'Train Accuracy', 'Validation Accuracy', 'Accuracy', 'Accuracy Curve', 'accuracy_curve.png'),
    (train_precisions, val_precisions, 'Train Precision', 'Validation Precision', 'Precision', 'Precision Curve', 'precision_curve.png'),
    (train_recalls, val_recalls, 'Train Recall', 'Validation Recall', 'Recall', 'Recall Curve', 'recall_curve.png'),
    (train_f1s, val_f1s, 'Train F1-Score', 'Validation F1-Score', 'F1-Score', 'F1-Score Curve', 'f1_score_curve.png'),
]

for train_series, val_series, train_legend, val_legend, y_label, title, file_name in curve_specs:
    plt.figure(figsize=(10, 5))
    plt.plot(epochs_axis, train_series, label=train_legend)
    plt.plot(epochs_axis, val_series, label=val_legend)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(PLOTS_DIR, file_name))
    plt.close()


# Write the configuration, per-epoch metrics (tab-separated) and final test
# results to a plain-text report in the experiment folder.
per_epoch_series = (
    train_losses, train_accuracies, train_precisions, train_recalls, train_f1s,
    val_losses, val_accuracies, val_precisions, val_recalls, val_f1s,
)

with open(os.path.join(EXP_DIR, "experiment_details.txt"), "w") as f:
    f.write(
        "Experiment Configuration:\n"
        f"Timestamp: {timestamp}\n"
        f"NUM_CLASSES: {NUM_CLASSES}\n"
        f"EMBED_DIM: {EMBED_DIM}\n"
        f"MAX_TILES: {MAX_TILES}\n"
        f"BATCH_SIZE: {BATCH_SIZE}\n"
        f"LR: {LR}\n"
        f"EPOCHS: {EPOCHS}\n"
        f"Model Architecture:\n{model}\n"
        "\nTraining Metrics:\n"
        "Epoch\tTrain Loss\tTrain Acc\tTrain Precision\tTrain Recall\tTrain F1\tVal Loss\tVal Acc\tVal Precision\tVal Recall\tVal F1\n"
    )
    for epoch_idx in range(EPOCHS):
        cells = [str(epoch_idx + 1)] + [f"{series[epoch_idx]:.4f}" for series in per_epoch_series]
        f.write("\t".join(cells) + "\n")

    f.write(
        f"\nORDER_METHOD: {ORDER_METHOD}\n"
        f"ALPHA_ENCOD: {ALPHA_ENCOD}\n"
        f"\nBest Validation Accuracy: {best_val_acc:.4f}\n"
        "Test Metrics:\n"
        f"Test Loss: {test_loss:.4f}\n"
        f"Test Acc: {test_acc:.4f}\n"
        f"Test Precision: {test_precision:.4f}\n"
        f"Test Recall: {test_recall:.4f}\n"
        f"Test F1-Score: {test_f1:.4f}\n"
    )


##########################
# Save Final Test Metrics to Excel
##########################
import pandas as pd
import os
from openpyxl import load_workbook
from openpyxl.styles import PatternFill, Font

# One summary row for this run: configuration identifiers plus the final
# test metrics (rounded to 4 decimals).
final_metrics = {
    "Order": ORDER_METHOD,
    "Alpha": ALPHA_ENCOD,
    "Timestamp": timestamp,
    "Best Val Accuracy": round(float(best_val_acc), 4),
    "Test Loss": round(float(test_loss), 4),
    "Test Accuracy": round(float(test_acc), 4),
    "Test Precision": round(float(test_precision), 4),
    "Test Recall": round(float(test_recall), 4),
    "Test F1-Score": round(float(test_f1), 4)
}

# Convert to single-row DataFrame
df_final_metrics = pd.DataFrame([final_metrics])

# Shared spreadsheet that accumulates one row per experiment run
metrics_path = os.path.join("/content/drive/MyDrive/BRACS/experiments/", "final_test_metrics.xlsx")

# If the file already exists, append this run's row to its contents
if os.path.exists(metrics_path):
    df_existing = pd.read_excel(metrics_path)
    df_final = pd.concat([df_existing, df_final_metrics], ignore_index=True)
else:
    df_final = df_final_metrics

# First: save the plain DataFrame without styles
df_final.to_excel(metrics_path, index=False)

# Then reopen with openpyxl and apply styling
wb = load_workbook(metrics_path)
ws = wb.active

# Background colour per Order Method value
fill_colors = {
    "spiral":           "ADD8E6",  # Light Blue
    "spiralwencoding":  "87CEEB",  # Sky Blue
    "2dsinusodial":     "98FB98",  # Pale Green
    "rasterscan":       "FFB6C1",  # Light Pink
    "rasterscanwencoding": "FFD700",  # Gold
    "hilbertwencoding": "9370DB",  # Medium Purple
}

# Build style objects once; openpyxl styles are immutable and can be shared
# between cells, so there is no need to allocate new ones per cell.
body_font = Font(name='Arial', size=11)
header_font = Font(name='Arial', bold=True, size=12)
row_fills = {
    method: PatternFill(start_color=color, end_color=color, fill_type="solid")
    for method, color in fill_colors.items()
}

# Find the "Order" column by its header instead of assuming column A, so the
# colouring keeps working if the column layout of the sheet changes.
# Falls back to column 1 (the original assumption) if the header is missing.
header_values = [cell.value for cell in ws[1]]
order_col = header_values.index("Order") + 1 if "Order" in header_values else 1

# Apply font and per-method background to every data row (row 1 is the header)
for row_idx in range(2, ws.max_row + 1):
    row_fill = row_fills.get(ws.cell(row=row_idx, column=order_col).value)

    for col_idx in range(1, ws.max_column + 1):
        cell = ws.cell(row=row_idx, column=col_idx)
        cell.font = body_font
        # Unknown order methods keep the default (unfilled) background
        if row_fill is not None:
            cell.fill = row_fill

# Header row: bold Arial
for col_idx in range(1, ws.max_column + 1):
    ws.cell(row=1, column=col_idx).font = header_font

# Save workbook
wb.save(metrics_path)

print("Final test metrics saved with Arial font and colored by Order Method.")

##########################
# End Save Final Test Metrics to Excel
##########################

# Restore the weights checkpointed at the best validation epoch, then score
# the held-out test split with them.
# NOTE(review): the recorded cell output prints "Best Val"/"Test Loss" twice,
# so this appears to repeat an evaluation already run earlier in this cell;
# the metrics written to the report/Excel above come from that earlier run.
# Confirm and keep only one evaluation.
best_checkpoint_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
model.load_state_dict(torch.load(best_checkpoint_path, map_location=device))
print(f"Best Val: {best_val_acc}")

# Test-set metrics plus the raw labels/predictions returned by evaluate()
(test_loss, test_acc, test_precision, test_recall, test_f1,
 test_labels, test_preds) = evaluate(model, test_loader, criterion)
print(
    f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, "
    f"Test Precision: {test_precision:.4f}, Test Recall: {test_recall:.4f}, "
    f"Test F1: {test_f1:.4f}"
)

# Audible notification that the run has finished
beep()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
cuda
Experiment directory set to:  /content/drive/MyDrive/BRACS/experiments/spiral/1-30.04-18:40-2d-6head4layer-70epoch-drop20-focalLoss
ORDER METHOD spiral
=== Train Set Label Distribution ===
Benign (Label 0): 202
Atypical (Label 1): 52
Malignant (Label 2): 140


=== Validation Set Label Distribution ===
Benign (Label 0): 30
Atypical (Label 1): 14
Malignant (Label 2): 21


=== Test Set Label Distribution ===
Benign (Label 0): 32
Atypical (Label 1): 22
Malignant (Label 2): 32


394
86
65




tensor([0.6502, 2.5256, 0.9381], device='cuda:0')


Epoch 1/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [1/70]
Train Loss: 0.9220, Train Acc: 0.3426, Train Precision: 0.3566, Train Recall: 0.3691, Train F1: 0.3297
Val Loss:   0.9146, Val Acc:   0.2558, Val Precision: 0.0853, Val Recall: 0.3333, Val F1: 0.1358


Epoch 2/70: 100%|██████████| 25/25 [00:19<00:00,  1.27it/s]


Epoch [2/70]
Train Loss: 0.8317, Train Acc: 0.2766, Train Precision: 0.3733, Train Recall: 0.3618, Train F1: 0.2768
Val Loss:   0.9463, Val Acc:   0.2558, Val Precision: 0.0853, Val Recall: 0.3333, Val F1: 0.1358


Epoch 3/70: 100%|██████████| 25/25 [00:19<00:00,  1.30it/s]


Epoch [3/70]
Train Loss: 0.7763, Train Acc: 0.3198, Train Precision: 0.3651, Train Recall: 0.3469, Train F1: 0.2971
Val Loss:   0.8818, Val Acc:   0.2558, Val Precision: 0.0853, Val Recall: 0.3333, Val F1: 0.1358


Epoch 4/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [4/70]
Train Loss: 0.6692, Train Acc: 0.4239, Train Precision: 0.5681, Train Recall: 0.4908, Train F1: 0.4354
Val Loss:   0.8945, Val Acc:   0.4767, Val Precision: 0.3532, Val Recall: 0.5028, Val F1: 0.3986


Epoch 5/70: 100%|██████████| 25/25 [00:21<00:00,  1.19it/s]


Epoch [5/70]
Train Loss: 0.6809, Train Acc: 0.5533, Train Precision: 0.5136, Train Recall: 0.5090, Train F1: 0.4980
Val Loss:   0.7483, Val Acc:   0.4651, Val Precision: 0.3748, Val Recall: 0.5066, Val F1: 0.3959


Epoch 6/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [6/70]
Train Loss: 0.5940, Train Acc: 0.4518, Train Precision: 0.6307, Train Recall: 0.5389, Train F1: 0.4628
Val Loss:   0.8473, Val Acc:   0.3837, Val Precision: 0.5923, Val Recall: 0.4290, Val F1: 0.3709


Epoch 7/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [7/70]
Train Loss: 0.4887, Train Acc: 0.6168, Train Precision: 0.6704, Train Recall: 0.6703, Train F1: 0.6018
Val Loss:   0.7565, Val Acc:   0.4884, Val Precision: 0.4303, Val Recall: 0.5417, Val F1: 0.4202


Epoch 8/70: 100%|██████████| 25/25 [00:21<00:00,  1.19it/s]


Epoch [8/70]
Train Loss: 0.4482, Train Acc: 0.6523, Train Precision: 0.7049, Train Recall: 0.7026, Train F1: 0.6358
Val Loss:   0.7892, Val Acc:   0.5233, Val Precision: 0.6200, Val Recall: 0.5540, Val F1: 0.4722


Epoch 9/70: 100%|██████████| 25/25 [00:20<00:00,  1.21it/s]


Epoch [9/70]
Train Loss: 0.4535, Train Acc: 0.6447, Train Precision: 0.6899, Train Recall: 0.6950, Train F1: 0.6243
Val Loss:   1.0861, Val Acc:   0.6628, Val Precision: 0.6536, Val Recall: 0.6553, Val F1: 0.6534


Epoch 10/70: 100%|██████████| 25/25 [00:21<00:00,  1.16it/s]


Epoch [10/70]
Train Loss: 0.4168, Train Acc: 0.6904, Train Precision: 0.7143, Train Recall: 0.7372, Train F1: 0.6669
Val Loss:   0.8789, Val Acc:   0.6860, Val Precision: 0.6773, Val Recall: 0.6761, Val F1: 0.6763


Epoch 11/70: 100%|██████████| 25/25 [00:21<00:00,  1.18it/s]


Epoch [11/70]
Train Loss: 0.4641, Train Acc: 0.7005, Train Precision: 0.7086, Train Recall: 0.7314, Train F1: 0.6734
Val Loss:   0.9310, Val Acc:   0.5233, Val Precision: 0.6501, Val Recall: 0.5634, Val F1: 0.5222


Epoch 12/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [12/70]
Train Loss: 0.5006, Train Acc: 0.5888, Train Precision: 0.6645, Train Recall: 0.6426, Train F1: 0.5774
Val Loss:   0.7453, Val Acc:   0.4767, Val Precision: 0.5070, Val Recall: 0.5028, Val F1: 0.4486


Epoch 13/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [13/70]
Train Loss: 0.4056, Train Acc: 0.7284, Train Precision: 0.7202, Train Recall: 0.7422, Train F1: 0.6923
Val Loss:   1.0743, Val Acc:   0.4767, Val Precision: 0.5406, Val Recall: 0.4744, Val F1: 0.4446


Epoch 14/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [14/70]
Train Loss: 0.3346, Train Acc: 0.7766, Train Precision: 0.7531, Train Recall: 0.8046, Train F1: 0.7415
Val Loss:   1.4667, Val Acc:   0.5116, Val Precision: 0.5420, Val Recall: 0.4962, Val F1: 0.4747


Epoch 15/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [15/70]
Train Loss: 0.2889, Train Acc: 0.7792, Train Precision: 0.7620, Train Recall: 0.8085, Train F1: 0.7452
Val Loss:   1.0114, Val Acc:   0.5233, Val Precision: 0.6941, Val Recall: 0.5492, Val F1: 0.4693


Epoch 16/70: 100%|██████████| 25/25 [00:19<00:00,  1.30it/s]


Epoch [16/70]
Train Loss: 0.2545, Train Acc: 0.7893, Train Precision: 0.7650, Train Recall: 0.8165, Train F1: 0.7535
Val Loss:   0.8486, Val Acc:   0.5116, Val Precision: 0.5969, Val Recall: 0.5246, Val F1: 0.4557


Epoch 17/70: 100%|██████████| 25/25 [00:19<00:00,  1.28it/s]


Epoch [17/70]
Train Loss: 0.3563, Train Acc: 0.6802, Train Precision: 0.7064, Train Recall: 0.7104, Train F1: 0.6536
Val Loss:   0.6797, Val Acc:   0.7326, Val Precision: 0.7581, Val Recall: 0.7415, Val F1: 0.7314


Epoch 18/70: 100%|██████████| 25/25 [00:21<00:00,  1.18it/s]


Epoch [18/70]
Train Loss: 0.2575, Train Acc: 0.8299, Train Precision: 0.7899, Train Recall: 0.8419, Train F1: 0.7902
Val Loss:   1.2720, Val Acc:   0.6744, Val Precision: 0.6703, Val Recall: 0.6562, Val F1: 0.6582


Epoch 19/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [19/70]
Train Loss: 0.3305, Train Acc: 0.8096, Train Precision: 0.7753, Train Recall: 0.8298, Train F1: 0.7737
Val Loss:   0.6248, Val Acc:   0.6744, Val Precision: 0.7349, Val Recall: 0.7036, Val F1: 0.6759


Epoch 20/70: 100%|██████████| 25/25 [00:19<00:00,  1.28it/s]


Epoch [20/70]
Train Loss: 0.3160, Train Acc: 0.7563, Train Precision: 0.7543, Train Recall: 0.7988, Train F1: 0.7280
Val Loss:   0.8861, Val Acc:   0.6163, Val Precision: 0.6418, Val Recall: 0.6278, Val F1: 0.6116


Epoch 21/70: 100%|██████████| 25/25 [00:19<00:00,  1.27it/s]


Epoch [21/70]
Train Loss: 0.2635, Train Acc: 0.8173, Train Precision: 0.7729, Train Recall: 0.8190, Train F1: 0.7738
Val Loss:   1.0427, Val Acc:   0.6744, Val Precision: 0.6689, Val Recall: 0.6562, Val F1: 0.6569


Epoch 22/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [22/70]
Train Loss: 0.1975, Train Acc: 0.8756, Train Precision: 0.8257, Train Recall: 0.8825, Train F1: 0.8406
Val Loss:   1.0816, Val Acc:   0.6744, Val Precision: 0.6625, Val Recall: 0.6515, Val F1: 0.6504


Epoch 23/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [23/70]
Train Loss: 0.2274, Train Acc: 0.8706, Train Precision: 0.8209, Train Recall: 0.8807, Train F1: 0.8337
Val Loss:   1.0787, Val Acc:   0.7093, Val Precision: 0.6947, Val Recall: 0.6686, Val F1: 0.6581


Epoch 24/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [24/70]
Train Loss: 0.1961, Train Acc: 0.8782, Train Precision: 0.8272, Train Recall: 0.8878, Train F1: 0.8415
Val Loss:   1.2938, Val Acc:   0.6628, Val Precision: 0.6834, Val Recall: 0.6174, Val F1: 0.5993


Epoch 25/70: 100%|██████████| 25/25 [00:19<00:00,  1.30it/s]


Epoch [25/70]
Train Loss: 0.1344, Train Acc: 0.9213, Train Precision: 0.8753, Train Recall: 0.9324, Train F1: 0.8942
Val Loss:   1.6110, Val Acc:   0.6512, Val Precision: 0.6415, Val Recall: 0.6259, Val F1: 0.6259


Epoch 26/70: 100%|██████████| 25/25 [00:19<00:00,  1.27it/s]


Epoch [26/70]
Train Loss: 0.1103, Train Acc: 0.9492, Train Precision: 0.9101, Train Recall: 0.9593, Train F1: 0.9291
Val Loss:   1.9329, Val Acc:   0.6279, Val Precision: 0.5708, Val Recall: 0.5767, Val F1: 0.5449


Epoch 27/70: 100%|██████████| 25/25 [00:19<00:00,  1.25it/s]


Epoch [27/70]
Train Loss: 0.2242, Train Acc: 0.8909, Train Precision: 0.8395, Train Recall: 0.8957, Train F1: 0.8540
Val Loss:   1.7705, Val Acc:   0.5930, Val Precision: 0.5623, Val Recall: 0.5549, Val F1: 0.5424


Epoch 28/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [28/70]
Train Loss: 0.1509, Train Acc: 0.9188, Train Precision: 0.8695, Train Recall: 0.9157, Train F1: 0.8864
Val Loss:   0.9532, Val Acc:   0.6163, Val Precision: 0.6165, Val Recall: 0.6042, Val F1: 0.6067


Epoch 29/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [29/70]
Train Loss: 0.1925, Train Acc: 0.8756, Train Precision: 0.8219, Train Recall: 0.8774, Train F1: 0.8373
Val Loss:   0.7398, Val Acc:   0.6279, Val Precision: 0.6748, Val Recall: 0.6572, Val F1: 0.6251


Epoch 30/70: 100%|██████████| 25/25 [00:19<00:00,  1.30it/s]


Epoch [30/70]
Train Loss: 0.1341, Train Acc: 0.9239, Train Precision: 0.8776, Train Recall: 0.9300, Train F1: 0.8966
Val Loss:   1.0415, Val Acc:   0.6512, Val Precision: 0.6396, Val Recall: 0.6402, Val F1: 0.6397


Epoch 31/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [31/70]
Train Loss: 0.0816, Train Acc: 0.9492, Train Precision: 0.9122, Train Recall: 0.9553, Train F1: 0.9294
Val Loss:   1.3895, Val Acc:   0.7093, Val Precision: 0.7109, Val Recall: 0.6828, Val F1: 0.6854


Epoch 32/70: 100%|██████████| 25/25 [00:19<00:00,  1.28it/s]


Epoch [32/70]
Train Loss: 0.0400, Train Acc: 0.9721, Train Precision: 0.9451, Train Recall: 0.9723, Train F1: 0.9574
Val Loss:   1.3971, Val Acc:   0.6860, Val Precision: 0.7029, Val Recall: 0.6619, Val F1: 0.6648


Epoch 33/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [33/70]
Train Loss: 0.0546, Train Acc: 0.9670, Train Precision: 0.9442, Train Recall: 0.9668, Train F1: 0.9545
Val Loss:   1.1173, Val Acc:   0.7209, Val Precision: 0.7282, Val Recall: 0.7074, Val F1: 0.7115


Epoch 34/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [34/70]
Train Loss: 0.0312, Train Acc: 0.9797, Train Precision: 0.9556, Train Recall: 0.9846, Train F1: 0.9684
Val Loss:   1.7508, Val Acc:   0.6744, Val Precision: 0.6489, Val Recall: 0.6326, Val F1: 0.6201


Epoch 35/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [35/70]
Train Loss: 0.0343, Train Acc: 0.9772, Train Precision: 0.9614, Train Recall: 0.9782, Train F1: 0.9693
Val Loss:   1.3838, Val Acc:   0.6860, Val Precision: 0.6697, Val Recall: 0.6619, Val F1: 0.6623


Epoch 36/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [36/70]
Train Loss: 0.0349, Train Acc: 0.9848, Train Precision: 0.9751, Train Recall: 0.9853, Train F1: 0.9801
Val Loss:   1.5772, Val Acc:   0.6744, Val Precision: 0.6740, Val Recall: 0.6420, Val F1: 0.6406


Epoch 37/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [37/70]
Train Loss: 0.0064, Train Acc: 0.9949, Train Precision: 0.9877, Train Recall: 0.9960, Train F1: 0.9917
Val Loss:   1.9411, Val Acc:   0.6744, Val Precision: 0.6835, Val Recall: 0.6326, Val F1: 0.6211


Epoch 38/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [38/70]
Train Loss: 0.0043, Train Acc: 0.9975, Train Precision: 0.9937, Train Recall: 0.9983, Train F1: 0.9960
Val Loss:   2.0731, Val Acc:   0.6860, Val Precision: 0.7139, Val Recall: 0.6430, Val F1: 0.6306


Epoch 39/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [39/70]
Train Loss: 0.0203, Train Acc: 0.9949, Train Precision: 0.9921, Train Recall: 0.9952, Train F1: 0.9936
Val Loss:   1.5867, Val Acc:   0.6977, Val Precision: 0.6829, Val Recall: 0.6818, Val F1: 0.6820


Epoch 40/70: 100%|██████████| 25/25 [00:19<00:00,  1.27it/s]


Epoch [40/70]
Train Loss: 0.0101, Train Acc: 0.9949, Train Precision: 0.9877, Train Recall: 0.9967, Train F1: 0.9921
Val Loss:   1.8777, Val Acc:   0.6860, Val Precision: 0.6686, Val Recall: 0.6667, Val F1: 0.6670


Epoch 41/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [41/70]
Train Loss: 0.0330, Train Acc: 0.9848, Train Precision: 0.9686, Train Recall: 0.9846, Train F1: 0.9761
Val Loss:   1.7314, Val Acc:   0.6512, Val Precision: 0.6464, Val Recall: 0.6449, Val F1: 0.6421


Epoch 42/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [42/70]
Train Loss: 0.0629, Train Acc: 0.9619, Train Precision: 0.9276, Train Recall: 0.9698, Train F1: 0.9451
Val Loss:   2.0423, Val Acc:   0.6395, Val Precision: 0.6602, Val Recall: 0.5966, Val F1: 0.5825


Epoch 43/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [43/70]
Train Loss: 0.0368, Train Acc: 0.9822, Train Precision: 0.9693, Train Recall: 0.9808, Train F1: 0.9747
Val Loss:   1.9379, Val Acc:   0.6279, Val Precision: 0.6259, Val Recall: 0.6004, Val F1: 0.6000


Epoch 44/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [44/70]
Train Loss: 0.0050, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   1.6655, Val Acc:   0.6744, Val Precision: 0.6825, Val Recall: 0.6515, Val F1: 0.6558


Epoch 45/70: 100%|██████████| 25/25 [00:19<00:00,  1.25it/s]


Epoch [45/70]
Train Loss: 0.0064, Train Acc: 0.9949, Train Precision: 0.9877, Train Recall: 0.9960, Train F1: 0.9917
Val Loss:   1.7892, Val Acc:   0.6628, Val Precision: 0.6588, Val Recall: 0.6269, Val F1: 0.6213


Epoch 46/70: 100%|██████████| 25/25 [00:19<00:00,  1.26it/s]


Epoch [46/70]
Train Loss: 0.0020, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   1.9088, Val Acc:   0.6628, Val Precision: 0.6733, Val Recall: 0.6269, Val F1: 0.6221


Epoch 47/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [47/70]
Train Loss: 0.0020, Train Acc: 0.9975, Train Precision: 0.9937, Train Recall: 0.9983, Train F1: 0.9960
Val Loss:   1.9533, Val Acc:   0.6628, Val Precision: 0.6733, Val Recall: 0.6269, Val F1: 0.6221


Epoch 48/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [48/70]
Train Loss: 0.0014, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.0439, Val Acc:   0.6512, Val Precision: 0.6652, Val Recall: 0.6165, Val F1: 0.6124


Epoch 49/70: 100%|██████████| 25/25 [00:19<00:00,  1.25it/s]


Epoch [49/70]
Train Loss: 0.0011, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.0365, Val Acc:   0.6512, Val Precision: 0.6652, Val Recall: 0.6165, Val F1: 0.6124


Epoch 50/70: 100%|██████████| 25/25 [00:19<00:00,  1.26it/s]


Epoch [50/70]
Train Loss: 0.0027, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   1.8734, Val Acc:   0.6512, Val Precision: 0.6551, Val Recall: 0.6212, Val F1: 0.6195


Epoch 51/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [51/70]
Train Loss: 0.0025, Train Acc: 0.9975, Train Precision: 0.9937, Train Recall: 0.9983, Train F1: 0.9960
Val Loss:   1.9982, Val Acc:   0.6628, Val Precision: 0.6751, Val Recall: 0.6316, Val F1: 0.6311


Epoch 52/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [52/70]
Train Loss: 0.0009, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.0446, Val Acc:   0.6628, Val Precision: 0.6751, Val Recall: 0.6316, Val F1: 0.6311


Epoch 53/70: 100%|██████████| 25/25 [00:19<00:00,  1.30it/s]


Epoch [53/70]
Train Loss: 0.0008, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.0331, Val Acc:   0.6860, Val Precision: 0.6896, Val Recall: 0.6525, Val F1: 0.6508


Epoch 54/70: 100%|██████████| 25/25 [00:19<00:00,  1.25it/s]


Epoch [54/70]
Train Loss: 0.0009, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.0797, Val Acc:   0.6628, Val Precision: 0.6751, Val Recall: 0.6316, Val F1: 0.6311


Epoch 55/70: 100%|██████████| 25/25 [00:19<00:00,  1.25it/s]


Epoch [55/70]
Train Loss: 0.0023, Train Acc: 0.9975, Train Precision: 0.9984, Train Recall: 0.9976, Train F1: 0.9980
Val Loss:   2.0124, Val Acc:   0.6744, Val Precision: 0.6688, Val Recall: 0.6373, Val F1: 0.6308


Epoch 56/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [56/70]
Train Loss: 0.0005, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   1.9999, Val Acc:   0.6744, Val Precision: 0.6877, Val Recall: 0.6420, Val F1: 0.6412


Epoch 57/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [57/70]
Train Loss: 0.0005, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   1.9989, Val Acc:   0.6744, Val Precision: 0.6877, Val Recall: 0.6420, Val F1: 0.6412


Epoch 58/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [58/70]
Train Loss: 0.0006, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.0083, Val Acc:   0.6744, Val Precision: 0.6742, Val Recall: 0.6420, Val F1: 0.6400


Epoch 59/70: 100%|██████████| 25/25 [00:19<00:00,  1.26it/s]


Epoch [59/70]
Train Loss: 0.0008, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.0854, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 60/70: 100%|██████████| 25/25 [00:19<00:00,  1.30it/s]


Epoch [60/70]
Train Loss: 0.0004, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.0992, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 61/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [61/70]
Train Loss: 0.0004, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1086, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 62/70: 100%|██████████| 25/25 [00:19<00:00,  1.30it/s]


Epoch [62/70]
Train Loss: 0.0006, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1400, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 63/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [63/70]
Train Loss: 0.0004, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1514, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 64/70: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]


Epoch [64/70]
Train Loss: 0.0005, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1513, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 65/70: 100%|██████████| 25/25 [00:19<00:00,  1.27it/s]


Epoch [65/70]
Train Loss: 0.0004, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1507, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 66/70: 100%|██████████| 25/25 [00:20<00:00,  1.23it/s]


Epoch [66/70]
Train Loss: 0.0004, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1652, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 67/70: 100%|██████████| 25/25 [00:19<00:00,  1.27it/s]


Epoch [67/70]
Train Loss: 0.0004, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1678, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 68/70: 100%|██████████| 25/25 [00:19<00:00,  1.26it/s]


Epoch [68/70]
Train Loss: 0.0005, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1593, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 69/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [69/70]
Train Loss: 0.0005, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1600, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506


Epoch 70/70: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [70/70]
Train Loss: 0.0005, Train Acc: 1.0000, Train Precision: 1.0000, Train Recall: 1.0000, Train F1: 1.0000
Val Loss:   2.1594, Val Acc:   0.6860, Val Precision: 0.6972, Val Recall: 0.6525, Val F1: 0.6506
Best Val: 0.7325581395348837
Test Loss: 0.9613, Test Acc: 0.6462, Test Precision: 0.6643, Test Recall: 0.5984, Test F1: 0.5975
Final test metrics saved with Arial font and colored by Order Method.
Best Val: 0.7325581395348837
Test Loss: 0.9613, Test Acc: 0.6462, Test Precision: 0.6643, Test Recall: 0.5984, Test F1: 0.5975


<IPython.core.display.Javascript object>

In [None]:

# #####################################################
#           # METHOD WITH VISUALIZER FOR SPIRAL ORDER
# #####################################################
    # def get_spiral_order_embeddings(self, coords_, tile_embeds):
    #     """
    #     Arrange tile embeddings in a spiral order and visualize traversal.

    #     Args:
    #         coords_ (torch.Tensor): Tensor of shape [N, 2] with (x, y) coordinates.
    #         tile_embeds (torch.Tensor): Tensor of shape [N, EMBED_DIM].

    #     Returns:
    #         torch.Tensor: Tensor of shape [num_tiles, EMBED_DIM] arranged in spiral order.
    #     """
    #     # Convert coordinates to integer grid indices
    #     coords = torch.floor(coords_ / 256.0)  # Tile size

    #     x_coords = coords[:, 0].numpy()
    #     y_coords = coords[:, 1].numpy()

    #     # Normalize coordinates to start from (0,0)
    #     min_x, max_x = int(np.min(x_coords)), int(np.max(x_coords))
    #     min_y, max_y = int(np.min(y_coords)), int(np.max(y_coords))
    #     norm_x = x_coords - min_x
    #     norm_y = y_coords - min_y

    #     width = int(max_x - min_x + 1)
    #     height = int(max_y - min_y + 1)

    #     # Create a grid mapping from (x, y) to embedding index
    #     grid = {}
    #     for idx, (x, y) in enumerate(zip(norm_x, norm_y)):
    #         grid[(x, y)] = idx

    #     # Generate spiral order coordinates
    #     spiral_coords = self.generate_spiral_coords(width, height)

    #     # Collect tile embeddings in spiral order
    #     spiral_order = []
    #     sorted_coords = []
    #     for coord in spiral_coords:
    #         idx = grid.get(coord)
    #         if idx is not None:
    #             spiral_order.append(tile_embeds[idx])
    #             sorted_coords.append((coord[0] + min_x, coord[1] + min_y))  # Convert back to original coordinates

    #     # Convert list to tensor
    #     if len(spiral_order) == 0:
    #         print("!----- Spiral list is empty")
    #         return torch.zeros(0, self.embed_dim)

    #     spiral_embeds = torch.stack(spiral_order)  # [num_tiles, EMBED_DIM]

    #     # VISUALIZATION: Plot Spiral Traversal Inline
    #     plt.figure(figsize=(8, 8))
    #     sorted_coords = np.array(sorted_coords)  # Convert list to NumPy array

    #     plt.scatter(coords[:, 0], coords[:, 1], c="blue", alpha=0.5, label="Original Tiles")
    #     plt.plot(sorted_coords[:, 0], sorted_coords[:, 1], color="red", linewidth=1.5, linestyle="dashed", label="Spiral Path")
    #     plt.scatter(sorted_coords[:, 0], sorted_coords[:, 1], c="red", label="Spiral Tiles")

    #     for i, (x, y) in enumerate(sorted_coords[:10]):  # Label first 10 tiles
    #         plt.text(x, y, str(i), fontsize=10, color="black")

    #     plt.xlabel("X Coordinate")
    #     plt.ylabel("Y Coordinate")
    #     plt.title("Spiral Traversal of Tiles")
    #     plt.gca().invert_yaxis()  # Flip y-axis to match image coordinates
    #     plt.legend()
    #     plt.grid(True)
    #     plt.show()

    #     return spiral_embeds  # No need to return coordinates since they are visualized
# #####################################################
#        # END METHOD WITH VISUALIZER FOR SPIRAL ORDER
# #####################################################

In [None]:



####################################
####  EMBEDDING VISUALIZER  ########
####################################
def visualize_embeddings(dataloader, model, num_samples=500):
    """Project model outputs for up to ``num_samples`` items to 2-D with t-SNE and plot them.

    The model is run in eval mode without gradients over batches from
    ``dataloader``. Its raw outputs are collected together with the batch
    labels, then truncated to at most ``num_samples`` rows. The outputs are
    what ``model(...)`` returns, presumably class logits for this classifier
    (TODO confirm). The rows are reduced to 2-D with t-SNE and drawn as a
    scatter plot coloured by label. The figure is saved to
    ``PLOTS_DIR/embedding_tsne.png`` and shown.

    Args:
        dataloader: iterable of ``(tile_embeds, label)`` batches of tensors.
            This assumes the dataset yields exactly these two items; verify
            against the Dataset.
        model: torch module called as ``model(tile_embeds)``.
        num_samples: maximum number of samples to project.
    """
    # TSNE is not among the notebook's top-level imports; import it here.
    from sklearn.manifold import TSNE

    model.eval()
    outputs = []
    labels = []
    collected = 0  # number of samples gathered so far (handles a short last batch)

    with torch.no_grad():
        for tile_embeds, label in dataloader:
            if collected >= num_samples:
                break  # Limit to `num_samples`

            tile_embeds = tile_embeds.to(device)
            logits = model(tile_embeds)  # Get model outputs

            outputs.append(logits.cpu().numpy())
            labels.append(label.cpu().numpy())
            collected += tile_embeds.shape[0]

    if not outputs:
        print("visualize_embeddings: dataloader produced no samples, nothing to plot.")
        return

    # The last batch may overshoot the cap; trim to exactly num_samples rows.
    embeddings = np.vstack(outputs)[:num_samples]
    labels = np.hstack(labels)[:num_samples]

    n_points = embeddings.shape[0]
    if n_points < 2:
        print("visualize_embeddings: need at least 2 samples for t-SNE, nothing to plot.")
        return

    # scikit-learn requires perplexity < n_samples; clamp it for small sets.
    perplexity = min(30, n_points - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(10, 8))
    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=labels, cmap="viridis", alpha=0.7)
    plt.colorbar()
    plt.title("t-SNE Visualization of Embeddings")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")
    plt.savefig(os.path.join(PLOTS_DIR, 'embedding_tsne.png'))
    plt.show()

# Call the function
visualize_embeddings(test_loader, model)

#######################################
####  EMBEDDING VISUALIZER END ########
#######################################

In [None]:
%reset