In [1]:
# Vit Transformer

In [2]:
import os, random
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 0 = all messages, 1 = INFO, 2 = WARNING, 3 = ERROR
os.environ["PYTHONHASHSEED"] = "42"
# For CUDA GEMMs determinism (must be set before torch import)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"   # best choice for A100

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import umap
import numpy as np
import cv2
import random
import plotly.express as px
import pandas as pd
from clearml import Task, OutputModel
from datetime import datetime
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, silhouette_score, davies_bouldin_score, f1_score, classification_report
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network   import MLPClassifier
from collections import Counter

# Modular Files
from data_loader import get_dataloaders, load_dataset_files
from transformer_model import ViTForSimCLR, MAEModule
from utils import CFG    #, zvm_thresholds_scaled

E0000 00:00:1755893284.705332   69578 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755893284.712007   69578 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755893284.728585   69578 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755893284.728604   69578 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755893284.728606   69578 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755893284.728608   69578 computation_placer.cc:177] computation placer already registered. Please check linka

In [4]:
seed = 42

In [5]:
# For deterministic purposes
def set_seed(seed=seed):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True, warn_only=False)

    # Deterministic attention path
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)

    # Ampere precision determinism
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        torch.set_float32_matmul_precision("highest")
    except AttributeError:
        pass

set_seed(seed)

Tier 1 = strict + reproducible: 

I enable deterministic algorithms, disable TF32 and fast attention, and set the cuBLAS workspace so results are bit-stable on the same setup; if I need invariance across num_workers, I also use index-keyed per-sample seeding. I use Tier 1 for ablations I care about, bug repros (NaNs/collapse), and final numbers/demos. 

I won't implement a Tier 0 (less strict) because training time is not negatively impacted going from 0 to 1. 

In [6]:
'''
Tier 1 should show: 
    det-algos: True, 
    TF32 both False, 
    SDPA flash=False, 
    mem=False, 
    math=True, 
    CUBLAS_WORKSPACE_CONFIG=":4096:8"
'''

print("det-algos:", torch.are_deterministic_algorithms_enabled())
print("TF32 (matmul/cudnn):", torch.backends.cuda.matmul.allow_tf32,
                              torch.backends.cudnn.allow_tf32)
print("SDPA flash/mem/math:",
      torch.backends.cuda.flash_sdp_enabled(),
      torch.backends.cuda.mem_efficient_sdp_enabled(),
      torch.backends.cuda.math_sdp_enabled())
print("CUBLAS_WORKSPACE_CONFIG:", os.environ.get("CUBLAS_WORKSPACE_CONFIG"))


det-algos: True
TF32 (matmul/cudnn): False False
SDPA flash/mem/math: False False True
CUBLAS_WORKSPACE_CONFIG: :4096:8


In [7]:
# For dataloader / to be deterministic   
def worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [8]:
# training and validation set need deterministic generators
g_train = torch.Generator().manual_seed(seed)
g_val   = torch.Generator().manual_seed(seed + 1)

In [9]:
# Check that CUDA is available
assert torch.cuda.is_available(), "CUDA not available!"
DEVICE = torch.device("cuda:0")
print("Using GPU:", torch.cuda.get_device_name(DEVICE))

Using GPU: NVIDIA A100-SXM4-40GB


In [10]:
print(f"[Seed Check] Python random: {random.randint(0, 100)}")
print(f"[Seed Check] NumPy: {np.random.randint(0, 100)}")
print(f"[Seed Check] PyTorch: {torch.randint(0, 100, (1,)).item()}")
if torch.cuda.is_available():
    print(f"[Seed Check] PyTorch CUDA: {torch.randint(0, 100, (1,), device='cuda').item()}")


[Seed Check] Python random: 81
[Seed Check] NumPy: 51
[Seed Check] PyTorch: 42
[Seed Check] PyTorch CUDA: 43


In [11]:
# from clearml import Dataset
# import os

# # Sanity Check data
# ds = Dataset.get(dataset_id="d00ed3a421684bbfa96c03b77bc8ae98")
# local_copy = ds.get_local_copy()

# print("Dataset downloaded to:", local_copy)
# print("Files in dataset:")
# print(os.listdir(local_copy))

In [12]:
# base = "/home/jupyter/.clearml/cache/storage_manager/datasets/ds_d00ed3a421684bbfa96c03b77bc8ae98"

# spectrograms = np.load(os.path.join(base, "spectrograms.npy"))
# persistence = np.load(os.path.join(base, "persistence_spectra.npy"))
# labels = np.load(os.path.join(base, "labels.npy"))
# label_encoder = np.load(os.path.join(base, "label_encoder.npz"))['classes']

# print("Spectrograms shape:", spectrograms.shape)
# print("Persistence shape:", persistence.shape)
# print("Labels shape:", labels.shape)
# print("Classes:", label_encoder)


In [13]:
# =============================
# === Load Data and Model ====
# =============================
data = load_dataset_files("d00ed3a421684bbfa96c03b77bc8ae98")

#====================================================
# Pretrain loaders (with SimCLR augmentations)
train_loader, val_loader, class_names = get_dataloaders(
    data,
    # batch_size=BATCH_SIZE,
    batch_size=CFG.BATCH_SIZE,
    # resize_to=IMAGE_SIZE,
    resize_to=CFG.IMAGE_SIZE,
    augment=True,
    val_split=0.1,
    # num_workers=NUM_WORKERS,
    num_workers=CFG.NUM_WORKERS,
    worker_init_fn=worker_init_fn,
    generator=g_train,
    #generator=torch.Generator().manual_seed(seed)
)

# Eval loaders (no augmentations) for linear‐probe & clustering metrics
train_eval_loader, val_eval_loader, _ = get_dataloaders(
    data,
    # batch_size=BATCH_SIZE,
    batch_size=CFG.BATCH_SIZE,
    # resize_to=IMAGE_SIZE,
    resize_to=CFG.IMAGE_SIZE,
    augment=False,
    val_split=0.3,
    # num_workers=NUM_WORKERS,
    num_workers=CFG.NUM_WORKERS,
    worker_init_fn=worker_init_fn,
    generator=g_val,
   # generator=torch.Generator().manual_seed(seed)
)
#====================================================
# Debug print to verify
print("Final class names:", class_names)

Loading spectrograms.npy
Loading persistence_spectra.npy
Loading labels.npy
Loading metadata.npz
Loading label_encoder.npz
Final class names: ['AR_drone', 'Bebop_drone', 'Phantom', 'ambient', 'mavic_pro_2']


In [14]:
# current 
H, W = (CFG.IMAGE_SIZE, CFG.IMAGE_SIZE) if isinstance(CFG.IMAGE_SIZE, int) else CFG.IMAGE_SIZE
P = int(CFG.PATCH_SIZE)
n_h, n_w = H // P, W // P
N = n_h * n_w

# baseline at 256 image/16 patch
BASE_H = int(getattr(CFG, "BASELINE_IMAGE_SIZE", 256))
BASE_W = int(getattr(CFG, "BASELINE_IMAGE_WIDTH", BASE_H))
BASE_P = int(getattr(CFG, "BASELINE_PATCH_SIZE", 16))
N0 = (BASE_H // BASE_P) * (BASE_W // BASE_P)

# + 1 for CLS token
attn_mult = float(((N + 1) / (N0 + 1)) ** 2)

In [15]:
# =========================
# === ClearML Setup ===
# =========================
task = Task.init(
    project_name="Signal Fingerprinting/ViT Fingerprinting",
    task_name="ViT Training_94",
    auto_connect_frameworks={"pytorch": False}
)
task.add_tags(['Full Dataset',f'Lambda={CFG.LAMBDA_MAE}',f'Temp={CFG.TEMPERATURE}'])

logger = task.get_logger()

task_params = {
    "Epochs":               CFG.EPOCHS,
    "Base learning_rate":   CFG.BASE_LR,           
    "Encoder learning_rate": CFG.ENCODER_LR,       
    "Projector learning_rate": CFG.PROJECTOR_LR,   
    "Eta min":              CFG.ETA_MIN,
    "Weight decay":         CFG.WEIGHT_DECAY,
    "Warmup fraction":      CFG.WARMUP_FRAC,
    "Warmup epochs":        CFG.WARMUP_EPOCHS,
    "Starting Lambda MAE":  CFG.LAMBDA_MAE,
    "Temperature":          CFG.TEMPERATURE,
    "Num_workers":          CFG.NUM_WORKERS,
    "Device":               CFG.DEVICE,
    "ViT Embedding Size":   CFG.EMBED_DIM,
    "Attention Heads":      CFG.NUM_HEADS,
    "ViT Layers":           CFG.DEPTH,
    "Batch Size":           CFG.BATCH_SIZE,
    "Patch Size":           CFG.PATCH_SIZE,
    "Image Size":           CFG.IMAGE_SIZE,
    "Projection Head Hidden Size |Hidden Dim|": CFG.HIDDEN_DIM,
    "Projection Head Output Size |Projection Dim|": CFG.PROJECTION_DIM,
    "Dataset_id":           "d00ed3a421684bbfa96c03b77bc8ae98",
    "Class Names":          list(class_names),
}
task.connect(task_params)

# =========== token budget =============
task.connect({
    "Patches (h×w)": f"{n_h}×{n_w}",
    "Tokens (N)": int(N),
    "Baseline": f"{BASE_H}×{BASE_W}/{BASE_P}  (N0={N0})",
    "Attn cost vs baseline": attn_mult,
}, name="token_budget")
# =======================================

with open("transformer_model.py", "r") as f:
    model_def = f.read()

model_config = {
    "serial": f"{datetime.now().strftime('%H:%M:%S')}",
    "framework": "pytorch",
    "args": {
        "embedding_dim": CFG.EMBED_DIM,
        "projection_dim": CFG.PROJECTION_DIM
    },
    "input_signature": f"[B, 6, {CFG.IMAGE_SIZE}, {CFG.IMAGE_SIZE}]",
    "output_signature": f"[B, {CFG.EMBED_DIM}] + [B, {CFG.PROJECTION_DIM}]",
    "def": model_def
}

output_model = OutputModel(task=task, name="ViT Transformer", config_dict=model_config)
output_model.set_upload_destination("gs://ewa-clearml")

ClearML Task: created new task id=de0f4b207bee4a38a173d87cd403a981
ClearML results page: https://app.clear.ml/projects/8d94fdfcc93849c699ecfc70878d8dc3/experiments/de0f4b207bee4a38a173d87cd403a981/output/log
ClearML results page: https://app.clear.ml/projects/8d94fdfcc93849c699ecfc70878d8dc3/experiments/de0f4b207bee4a38a173d87cd403a981/output/log
ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start


In [16]:
# Get underlying dataset from Subset
subset: torch.utils.data.Subset = train_loader.dataset
full_dataset = subset.dataset  # This is SimCLRDualViewDataset
base_dataset = full_dataset.base_dataset  # This is RFSpectrogramDataset

# Get labels
labels = base_dataset.labels

# Count 
class_counts = Counter(labels)
print("Training class distribution:")
for class_id, count in sorted(class_counts.items()):
    class_name = class_names[class_id] if class_names else str(class_id)
    print(f"{class_name:15}: {count} samples")

Training class distribution:
AR_drone       : 162 samples
Bebop_drone    : 168 samples
Phantom        : 390 samples
ambient        : 10104 samples
mavic_pro_2    : 597 samples


In [17]:
# ViT model and MAE 

# needed to pass from utils.py
model = ViTForSimCLR(
    in_channels=6,
    img_size=CFG.IMAGE_SIZE,
    patch_size=CFG.PATCH_SIZE,
    emb_dim=CFG.EMBED_DIM,
    depth=CFG.DEPTH,
    num_heads=CFG.NUM_HEADS,
    proj_hidden=CFG.HIDDEN_DIM,     
    proj_out=CFG.PROJECTION_DIM,    
).to(CFG.DEVICE)

mae_module = MAEModule().to(DEVICE)

# Adjust encoder and projector LR
# This gives the SimCLR head much larger steps (PROJECTOR_LR = BASE_LR) 
# while keeping the ViT backbone more conservative (ENCODER_LR = BASE_LR × 0.1)  / check util.py for setting
optimizer = torch.optim.AdamW([
    {'params': model.encoder.parameters(),   'lr': CFG.ENCODER_LR,   'weight_decay': CFG.WEIGHT_DECAY},
    {'params': model.projector.parameters(), 'lr': CFG.PROJECTOR_LR, 'weight_decay': 0.0},
], betas=(0.9, 0.999))

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.EPOCHS, eta_min=CFG.ETA_MIN)


enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True



In [18]:
from torchinfo import summary

# ViTForSimCLR expects 6-channel 
arch = summary(
    model,
    input_size=(1, 6, CFG.IMAGE_SIZE, CFG.IMAGE_SIZE),
    col_names=("input_size", "output_size", "num_params"),
    depth=30   
)

In [19]:
print(arch)

Layer (type:depth-idx)                             Input Shape               Output Shape              Param #
ViTForSimCLR                                       [1, 6, 252, 252]          [1, 128]                  --
├─PatchEmbed: 1-1                                  [1, 6, 252, 252]          [1, 324, 128]             --
│    └─Conv2d: 2-1                                 [1, 6, 252, 252]          [1, 128, 18, 18]          150,656
├─ViTEncoder: 1-2                                  [1, 324, 128]             [1, 325, 128]             41,728
│    └─Dropout: 2-2                                [1, 325, 128]             [1, 325, 128]             --
│    └─TransformerEncoder: 2-3                     [1, 325, 128]             [1, 325, 128]             --
│    │    └─ModuleList: 3-1                        --                        --                        --
│    │    │    └─TransformerEncoderLayer: 4-1      [1, 325, 128]             [1, 325, 128]             --
│    │    │    │    └─MultiheadA

In [20]:
# print the MAE module’s sub-module structure
print(mae_module)

# total number of parameters
total = sum(p.numel() for p in mae_module.parameters())
print(f"MAE module parameters: {total:,}")

MAEModule(
  (enc_to_dec): Linear(in_features=128, out_features=256, bias=True)
  (decoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
    )
  )
)
MAE module parameters: 1,612,800


In [21]:
# =============================
# === SimCLR Contrastive Loss ===
# =============================
def contrastive_loss(z1, z2, temperature=0.5):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    N = z1.shape[0]
    z = torch.cat([z1, z2], dim=0)  # [2N, D]

    # Cosine similarity matrix
    sim_matrix = torch.matmul(z, z.T) / temperature  # [2N, 2N]

    # Create mask to exclude self-similarities
    mask = torch.eye(2 * N, device=z.device).bool()
    sim_matrix = sim_matrix.masked_fill(mask, -1e9)

    # Positive pairs are [i, i+N] and [i+N, i]
    positives = torch.cat([torch.arange(N, 2 * N), torch.arange(0, N)]).to(z.device)

    # Labels: each sample's positive is at index `positives[i]`
    loss = F.cross_entropy(sim_matrix, positives)
    

    return loss

In [22]:
# For troubleshooting simclr loss (sanity check, looks good)
fixed = next(iter(train_loader))
for step in range(100):
    v1, v2, _ = fixed
    diff = (v1 - v2).abs().mean().item()
    print("Mean abs diff between views:", diff)
    
    
    v1, v2 = v1.to(DEVICE), v2.to(DEVICE)
    z1, z2 = model(v1)[1], model(v2)[1]
    
    loss = contrastive_loss(z1, z2, CFG.TEMPERATURE)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if step % 10 == 0:
        print(f"Step {step:3d} | SimCLR loss: {loss.item():.4f}")
        
    if step % 20 == 0:
        for i, g in enumerate(optimizer.param_groups):
            print(f"group {i} lr = {g['lr']:.2e}")

Mean abs diff between views: 0.5014529824256897
Step   0 | SimCLR loss: 30.5340
group 0 lr = 1.00e-03
group 1 lr = 1.00e-02
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Step  10 | SimCLR loss: 33.0656
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views: 0.5014529824256897
Mean abs diff between views:

In [23]:
# For troubleshooting simclr loss (Sanity check too, looks good)
v1, v2, _ = next(iter(train_loader))
diff = (v1 - v2).abs().mean().item()
print("Mean abs diff between views:", diff)

Mean abs diff between views: 0.5103262662887573


In [24]:
def extract_embeddings(loader, model, device):
    model.eval()
    feats, labs = [], []
    with torch.no_grad():
        for view1, view2, labels in loader:
            x = view1.to(device)                      
            cls_tok, _ = model(x)
            feats.append(cls_tok.cpu().numpy())
            labs.append(labels.numpy())
    return np.vstack(feats), np.concatenate(labs)

In [25]:
# This is for computing var_z_mean for checking for collapse
def emb_var_mean_from_numpy(X: np.ndarray, l2_normalize: bool = True):

    if l2_normalize:
        norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-12
        X = X / norms
    # 
    var_per_dim = X.var(axis=0, ddof=0)
    mean_var = float(var_per_dim.mean())
    return mean_var, var_per_dim

In [26]:
def evaluate_linear_probe(model, train_loader, val_loader, device):
    # Extract embeddings + labels
    X_train, y_train = extract_embeddings(train_loader, model, device)
    X_val,   y_val   = extract_embeddings(val_loader,   model, device)
    
#=========== This is for computing z_var_mean inside of linear probe instead of on its own ============
    # ---- z_var_mean on CLS (L2-normalized) using the SAME X_val ----
    zvm, var_per_dim = emb_var_mean_from_numpy(X_val, l2_normalize=True)
    
#======================================================================================================

    # Train a simple logistic regression on train set    
    scaler = StandardScaler().fit(X_train)
    X_train_scaled = scaler.transform(X_train)
    X_val_scaled   = scaler.transform(X_val)
    
    # To handle class imbalance, F1-score was 20.31%
    probe = LogisticRegression(
        max_iter=2_000, 
        random_state=seed,  # for repeatability
    ).fit(X_train_scaled, y_train)
    
    # Now using scaled for better performance on evaling 
    y_pred = probe.predict(X_val_scaled)
    # Accuracy & F1 score / using macro for individual class performance
    acc    = accuracy_score(y_val, y_pred)
    f1  = f1_score(y_val, y_pred, average='macro')

    # Clustering metrics on the *validation* embeddings
    sil = silhouette_score(X_val_scaled, y_val)  
    db  = davies_bouldin_score(X_val_scaled,   y_val)

    # return acc,f1, sil, db, y_val, y_pred
    return acc,f1, sil, db, y_val, y_pred, zvm, var_per_dim  # this is to return z_var_mean
    # return acc,f1, sil, db, y_val, y_pred, zvm_cls, var_cls, zvm_proj, var_proj 

In [27]:
# MLP Probe 
def evaluate_mlp_probe(model, train_loader, val_loader, device):
    # This is basically the same as linear probe, but swapping out for MLP
    X_train, y_train = extract_embeddings(train_loader, model, device)
    X_val,   y_val   = extract_embeddings(val_loader,   model, device)
    # X_train, _, y_train = extract_embeddings_both(train_loader, model, device)
    # X_val, _,  y_val   = extract_embeddings_both(val_loader,   model, device)

    scaler     = StandardScaler().fit(X_train)
    X_train_sc = scaler.transform(X_train)
    X_val_sc   = scaler.transform(X_val)

    # small 1-hidden-layer MLP / keep it simple
    probe = MLPClassifier(
        hidden_layer_sizes=(256,),    # one hidden layer of 256 units
        activation='relu',            # ReLU non-linearity
        alpha=1e-4,                   # L2 penalty on weights
        max_iter=2_000,                 # bumping up to match log reg
        random_state=seed,  # for repeatability
    )
    probe.fit(X_train_sc, y_train)
    y_pred = probe.predict(X_val_sc)

    acc = accuracy_score(y_val, y_pred)
    f1  = f1_score(y_val, y_pred, average='macro')
    return acc, f1

In [28]:
def log_umap_plotly(emb_2d, y_int, epoch_idx, logger, class_names, title_prefix="UMAP of ViT [CLS] Embeddings"):
    
    labels_text = [class_names[int(i)] for i in y_int]

    df = pd.DataFrame({
        "UMAP-1": emb_2d[:, 0],
        "UMAP-2": emb_2d[:, 1],
        "label":  labels_text,
    })
    
    palette = {
        'AR_drone':   '#0072B2',  #blue
        'Bebop_drone':'#E69F00',  # orange
        'Phantom':    '#009E73',  # green
        'mavic_pro_2':'#D55E00',  # light orange
        'ambient':    '#7A7A7A',  # gray
    }

    fig = px.scatter(
        df, x="UMAP-1", 
        y="UMAP-2",
        color="label",                     
        color_discrete_map=palette, # get rid of this for different classes / or not use color scheme
        opacity=0.8,
        title=f"{title_prefix} — Epoch {epoch_idx}"
    )
    fig.update_traces(marker=dict(size=5))
    fig.update_layout(legend_title_text="Class")

    logger.report_plotly(
        title=title_prefix,
        series=f"Epoch {epoch_idx}",
        iteration=epoch_idx,
        figure=fig
    )

In [29]:
UMAP_DIR = "umap_images"
os.makedirs(UMAP_DIR, exist_ok=True)

In [30]:
#TEMPERATURE = CFG.TEMPERATURE
EPOCHS = CFG.EPOCHS
LAMBDA_MAE = CFG.LAMBDA_MAE

# temperature state 
tau_current = float(CFG.TEMPERATURE)    
TAU_MIN     = 1e-4                  

In [None]:
# =============================
# === Training Loop ===
# =============================
best_combined_loss = float('inf')
patience, patience_counter = 20, 0  # Early stopping was 7, now 15
embeddings_list, labels_list = [], []
best_f1 = 0.0
lambda_mae = LAMBDA_MAE
#===============================
for epoch in range(EPOCHS):
    model.train()
    total_loss, total_simclr, total_mae = 0, 0, 0
    

    print(f"Epoch: {epoch+1}/{EPOCHS}")

    for view1, view2, labels in train_loader:
        view1, view2 = view1.to(DEVICE), view2.to(DEVICE)

        # Get both CLS token and projection head outputs
        z1_cls, z1_proj = model(view1)  # z1_cls: [B, 128], z1_proj: [B, 64]
        z2_cls, z2_proj = model(view2)

        # SimCLR contrastive loss (use projected features)
        simclr_loss = contrastive_loss(z1_proj, z2_proj, temperature=tau_current)

        # Masked Autoencoder (MAE) loss using view1
        recon_loss = mae_module(view1, model)

        # Total loss
        # combined_loss = simclr_loss + LAMBDA_MAE * recon_loss
        combined_loss = simclr_loss + lambda_mae * recon_loss

        optimizer.zero_grad()
        combined_loss.backward()
        
        optimizer.step()

        total_loss += combined_loss.item()
        total_simclr += simclr_loss.item()
        total_mae += recon_loss.item()

    avg_loss = total_loss / len(train_loader)
    avg_simclr = total_simclr / len(train_loader)
    avg_mae = total_mae / len(train_loader)
    scheduler.step()
    
    # report to ClearML
    logger.report_scalar("Combined Loss","Combined Total", avg_loss, epoch+1)
    logger.report_scalar("SimCLR Loss","SimCLR", avg_simclr, epoch+1)
    logger.report_scalar("MAE Loss","MAE", avg_mae, epoch+1)

    print(f"Epoch {epoch+1}/{EPOCHS} | Total: {avg_loss:.4f} | SimCLR: {avg_simclr:.4f} | MAE: {avg_mae:.4f}")
    
    # ===== every epoch do a linear-probe eval and UMAP & MLP-probe eval =====
    if (epoch + 1) % 1 == 0:   
        
        optimizer.zero_grad(set_to_none=True) # added for oom
        torch.cuda.empty_cache() # added to clean up oom
        torch.cuda.reset_peak_memory_stats()
        model.eval()    # added for probes
        
        # Added z_var_mean ability
        acc, f1, sil, db, y_true, y_pred, zvm, var_per_dim = evaluate_linear_probe(
            model, train_eval_loader, val_eval_loader, DEVICE
        )
        
        print(f"[Linear Probe @ epoch {epoch+1}]  Acc={acc:.4f}, F1={f1:.4f}, Sil={sil:.4f}, DB={db:.4f}")

        logger.report_scalar("Probe Accuracy","Linear Accuracy", acc, epoch+1)
        logger.report_scalar("Probe F1-Score","Linear F1-Score", f1,  epoch+1)
        logger.report_scalar("Linear Probe Silhouette Score","Silhouette Score",sil,epoch+1)
        logger.report_scalar("Linear Probe DB Score","Davies-Bouldin", db, epoch+1)
        logger.report_scalar("MAE Weight", "lambda_mae", lambda_mae, epoch+1)
        
        
        # --- MLP‐probe eval ---
        acc_mlp, f1_mlp = evaluate_mlp_probe(
            model, train_eval_loader, val_eval_loader, DEVICE
        )
        print(f"[MLP @ epoch {epoch+1}]    Acc={acc_mlp:.4f}, F1={f1_mlp:.4f}")
        
        logger.report_scalar("Probe Accuracy","MLP Accuracy", acc_mlp, epoch+1)
        logger.report_scalar("Probe F1-Score","MLP F1-Score",  f1_mlp,  epoch+1)
        
        #--- Z-var Mean for Collapse ---
        print(f"[Z-var Mean @ epoch {epoch+1} |zvm < 1 ⚠️|] z_var_mean={zvm:.4e} | n={len(y_true)} | "f"min={var_per_dim.min():.4e}, max={var_per_dim.max():.4e}")
        
        # ClearML logging 
        logger.report_scalar("Collapse Check |zvm < 1 ⚠️|", "Z Var Mean", zvm * 1e4, epoch+1)
             
    
        #=============== Save Model =====================
        if f1 >= best_f1 + 1e-4:
            
            # ensure patience counter resets to prevent early stopping
            patience_counter = 0
            
            cr_str = classification_report(
                y_true, y_pred,
                target_names=class_names,
                digits=4,
                zero_division=0
            )
            print(cr_str)

            best_f1 = f1
            # eval_counter = 0
            torch.save(model.state_dict(), "best_f1_model.pth")
        
            # Save model
            best_path   = f"best_model_epoch_{epoch+1}.pth"
            torch.save(model.state_dict(), best_path)

            
            output_model.update_weights(
                weights_filename=best_path,  
                target_filename=f"model_best_f1.pth"
            )
            output_model.wait_for_uploads() 

        #================= UMAP =================================
            # --- UMAP visualization on validation embeddings ---
            X_val, y_val = extract_embeddings(val_eval_loader, model, DEVICE)
            # X_val, _, y_val = extract_embeddings_both(val_eval_loader, model, DEVICE)
            reducer      = umap.UMAP(n_neighbors=50, min_dist=0.2, metric='cosine')    # changed from 30 to 50 , 0.1 to 0.2
            umap_emb     = reducer.fit_transform(X_val)  # shape [N,2]
            
            
            #----------- Plotly approach for UMAP w/ labels -------------
            # log interactive UMAP with class labels / made function to be cleaner
            log_umap_plotly(umap_emb, y_val, epoch+1, logger, class_names)
            
    # Early stopping
    if avg_loss < best_combined_loss:
        best_combined_loss = avg_loss
        patience_counter  = 0
        
    else:
        patience_counter += 1
        print(f"Patience Counter: {patience_counter}")
        
        # when the early-stop count is at 10 OR GREATER, lower by 0.001 each epoch
        if patience_counter >= 10:
            
            # stop using MAE / focus solely on SimCLR loss
            lambda_mae = 0
            
            new_tau = max(tau_current - 0.001, TAU_MIN)
            if new_tau != tau_current:
                print(f"Lowering temperature: {tau_current:.6f} → {new_tau:.6f} (patience={patience_counter})")
            tau_current = new_tau
            # allow new temp to impact performance / then, if no improvement for 5 epochs it will decrease
            patience_counter = 5
        
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
            
    logger.report_scalar("SimCLR Temperature", "tau_current", tau_current, epoch+1)

Epoch: 1/150
Epoch 1/150 | Total: 6.5248 | SimCLR: 6.5166 | MAE: 0.8234
[Linear Probe @ epoch 1]  Acc=0.9904, F1=0.8653, Sil=0.8046, DB=5.5595
[MLP @ epoch 1]    Acc=0.9863, F1=0.7380
[Z-var Mean @ epoch 1 |zvm < 1 ⚠️|] z_var_mean=2.4918e-05 | n=3426 | min=4.0419e-07, max=2.1776e-04
              precision    recall  f1-score   support

    AR_drone     0.6111    0.7333    0.6667        45
 Bebop_drone     0.7273    0.6038    0.6598        53
     Phantom     1.0000    1.0000    1.0000       109
     ambient     1.0000    1.0000    1.0000      3031
 mavic_pro_2     1.0000    1.0000    1.0000       188

    accuracy                         0.9904      3426
   macro avg     0.8677    0.8674    0.8653      3426
weighted avg     0.9907    0.9904    0.9904      3426

Epoch: 2/150


In [None]:
# # =============================
# # === Save Model Snapshot ====
# # =============================
# torch.save(model.state_dict(), "vit_simclr.pth")
# output_model.update_weights("vit_simclr.pth")

In [None]:
task.close()

In [None]:
# fix mae to not go to zero but be zero by end of epoch
# undersample classes
# scale ViT
# 