# SHOT vs Coral Pipeline

This notebook pre-trains a CNN once on a split of the primary (source) dataset, then adapts the pretrained model to a target domain with Information Maximization followed by pseudo-label fine-tuning. The adaptation is repeated for multiple random seeds; accuracies on several validation sets are recorded and summarized.

## 1. Imports and Model Definition

Load libraries, define the CNN architecture, and import helpers.

In [None]:
import os, sys
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from pkldataset import PKLDataset
from helpers import train_model, eval_model, set_seed

class CNN(nn.Module):
    """1-D convolutional classifier split into a feature extractor and a head.

    Three ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2)`` stages feed a
    flatten + linear layer that yields a 128-dimensional feature vector
    (``feature_extractor``); a dropout + linear layer (``classifier_head``)
    maps those features to class logits.

    Args:
        input_length: Length of each input signal. The convolutions keep the
            length (kernel 31, padding 15) and each pooling stage halves it,
            so the flattened size is ``64 * (input_length // 8)``.
        num_classes: Number of output logits.
        input_channels: Number of channels of the input signal.
    """

    def __init__(self, input_length: int = 2800, num_classes: int = 10, input_channels: int = 1):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(input_channels, 16, kernel_size=31, padding=15),
            nn.BatchNorm1d(16), nn.ReLU(inplace=True), nn.MaxPool1d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(16, 32, kernel_size=31, padding=15),
            nn.BatchNorm1d(32), nn.ReLU(inplace=True), nn.MaxPool1d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=31, padding=15),
            nn.BatchNorm1d(64), nn.ReLU(inplace=True), nn.MaxPool1d(2)
        )
        # Three MaxPool1d(2) stages, each flooring, reduce the length by 8.
        conv_output_length = input_length // 8
        self.feature_extractor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * conv_output_length, 128),
            nn.ReLU(inplace=True)
        )
        self.classifier_head = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``x``.

        Delegates to :meth:`extract_features` so the logits and the exposed
        features always come from the same computation.
        """
        return self.classifier_head(self.extract_features(x))

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 128-dimensional features that feed ``classifier_head``.

        A 2-D input is treated as single-channel and gets a channel axis
        inserted at dimension 1 before the convolutions.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return self.feature_extractor(x)


## 2. Configuration

- **Primary dataset**: path to pretraining folder  
- **Transfer set(s)**: domain(s) for adaptation  
- **Validation sets**: held-out datasets for evaluation  
- **Seeds**: random seeds for reproducibility

In [None]:
# Source-domain folder used for pretraining (split into train/val later on).
train_path_1 = r"C:\Users\gus07\Desktop\data hiwi\preprocessing\HC\T197\RP"

# Target-domain folders the pretrained model is adapted to.
transfer_sets = ["../datasets/RPDC185/train_500"]

# Held-out folders every adapted model is evaluated on.
val_paths = [
    f"../datasets/RPDC{domain_id}/val_1000"
    for domain_id in (185, 188, 191, 194, 197)
]

# One adaptation run is performed per seed.
seeds = [101, 202, 303, 404, 505, 606, 707, 808, 909, 1001]

# Use the GPU when one is available; supervised loss for source training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
criterion = nn.CrossEntropyLoss()

# results[transfer_set][val_path] collects one accuracy per seed.
results = {
    target: {val_path: [] for val_path in val_paths}
    for target in transfer_sets
}


## 3. Phase 1: Source Pretraining

Train once on the training split of the primary dataset with a fixed seed, evaluate on its held-out split, and keep the pretrained weights in memory for Phase 2.

In [None]:
# One-time source training: fit the CNN on the source split with a fixed seed,
# report its held-out accuracy, and snapshot the weights for the adaptation runs.
import copy

seed0 = 42
set_seed(seed0)

# Split source dataset
train_ds, val_ds = PKLDataset.split_dataset(train_path_1)
loader_tr = DataLoader(train_ds, batch_size=64, shuffle=True)
loader_va = DataLoader(val_ds, batch_size=64, shuffle=False)

# Initialize and train
source_model = CNN().to(device)
opt0 = optim.Adam(source_model.parameters(), lr=1e-3, weight_decay=1e-5)
# NOTE(review): step_size=50 exceeds num_epochs=10, so the learning rate never
# decays here if train_model steps the scheduler once per epoch — confirm.
sch0 = optim.lr_scheduler.StepLR(opt0, step_size=50, gamma=0.1)
source_model = train_model(
    source_model, loader_tr, criterion, opt0, sch0,
    num_epochs=10, device=device
)
print(f"Source (phase1) val-acc: {eval_model(source_model, loader_va, device):.2f}%")

# Save pretrained state. state_dict() hands back references to the live
# parameter/buffer tensors, so deep-copy it: the snapshot then stays fixed even
# if source_model is trained or modified again later.
pretrained_state = copy.deepcopy(source_model.state_dict())


## 4. Phase 2: Self-Supervised Adaptation

For each seed and transfer set:
1. Load the pretrained weights into a fresh model and freeze the classifier head
2. Run Information Maximization for `im_only` epochs
3. Generate pseudo-labels and unfreeze head
4. Fine-tune with pseudo-label loss for remaining epochs
5. Evaluate on validation sets

In [None]:
# Phase 2 driver: for every seed and every transfer set, restart from the
# pretrained weights, adapt on the target data without using its labels, and
# append the accuracy on each validation set to `results`.
for seed in seeds:
    print(f"\n>>> Adaptation seed {seed}")
    set_seed(seed)

    for t in transfer_sets:
        # Reload from pretrained
        student = CNN().to(device)
        student.load_state_dict(pretrained_state)
        student.train()

        # Freeze head: during Information Maximization only the conv stages and
        # the feature extractor are optimized.
        for p in student.classifier_head.parameters():
            p.requires_grad = False
        feat_params = list(student.conv1.parameters()) + list(student.conv2.parameters()) + \
                      list(student.conv3.parameters()) + list(student.feature_extractor.parameters())

        # A single cosine schedule spans both phases (IM + pseudo-label).
        shot_optimizer = optim.Adam(feat_params, lr=1e-4, weight_decay=1e-5)
        total_epochs = 100
        shot_scheduler = optim.lr_scheduler.CosineAnnealingLR(shot_optimizer, T_max=total_epochs, eta_min=1e-6)

        # Information Maximization settings
        lambda_shot = 1.0          # weight of the marginal-entropy (diversity) term
        im_only = 50               # number of leading epochs that use the IM loss
        pseudo_labels_all = None   # filled in at the end of the IM phase
        pseudo_confidence = None   # filled in at the end of the IM phase

        # Shuffled loader for IM training; unshuffled loader so that pseudo-labels
        # are stored in dataset order.
        dataset_t = PKLDataset(t)
        unlab_loader = DataLoader(dataset_t, batch_size=64, shuffle=True)
        full_loader_t = DataLoader(dataset_t, batch_size=64, shuffle=False)

        for epoch in range(total_epochs):
            student.train()
            # Information Maximization phase
            if epoch < im_only:
                running_Hc = running_Hm = 0.0
                # The second element yielded by the loader is discarded.
                for x, _ in unlab_loader:
                    x = x.to(device)
                    logits = student(x)
                    probs = torch.softmax(logits, dim=1)
                    logp = torch.log(probs + 1e-12)  # epsilon avoids log(0)
                    # Hc: mean per-sample prediction entropy (lower = more confident).
                    Hc = -(probs * logp).sum(1).mean()
                    # Hm: entropy of the batch-averaged prediction (higher = more diverse).
                    p_bar = probs.mean(0)
                    Hm = -(p_bar * torch.log(p_bar + 1e-12)).sum()
                    # Minimize conditional entropy while maximizing marginal entropy.
                    loss = Hc - lambda_shot * Hm

                    shot_optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(feat_params, 1.0)
                    shot_optimizer.step()

                    # Accumulate sample-weighted sums for the per-epoch averages below.
                    running_Hc += Hc.item() * x.size(0)
                    running_Hm += Hm.item() * x.size(0)

                if (epoch+1) % 20 == 0 or epoch == 0:
                    N = len(unlab_loader.dataset)
                    print(f"  [IM] Epoch {epoch+1}/{total_epochs}  H_cond={(running_Hc/N):.4f}  H_marg={(running_Hm/N):.4f}")

                shot_scheduler.step()

                # Generate pseudo-labels at end of IM
                if epoch == im_only - 1:
                    # Eval mode for the labelling pass; train mode is restored below.
                    student.eval()
                    feats, probs_all = [], []
                    with torch.no_grad():
                        for x, _ in full_loader_t:
                            x = x.to(device)
                            f = student.extract_features(x)
                            p = torch.softmax(student.classifier_head(f), dim=1)
                            feats.append(f.cpu())
                            probs_all.append(p.cpu())
                    feats = torch.cat(feats)
                    probs_all = torch.cat(probs_all)

                    # Class centroids: feature means weighted by the predicted
                    # probability of each class.
                    C = probs_all.size(1)
                    centroids = torch.zeros(C, feats.size(1))
                    for k in range(C):
                        w = probs_all[:, k].unsqueeze(1)
                        denom = w.sum()
                        if denom > 0:
                            centroids[k] = (w * feats).sum(0) / denom
                    # Pseudo-label = index of the nearest centroid.
                    dists = torch.cdist(feats, centroids)
                    pseudo_labels_all = torch.argmin(dists, dim=1)
                    # NOTE(review): the confidence is the classifier's max softmax
                    # probability, while the label comes from the nearest centroid;
                    # the two can refer to different classes — confirm this is intended.
                    pseudo_confidence = probs_all.max(1).values

                    # Unfreeze head and add to optimizer
                    for p in student.classifier_head.parameters():
                        p.requires_grad = True
                    # The head starts at 0.2x the current (already decayed) LR of group 0.
                    # NOTE(review): this group is added after the scheduler was
                    # constructed — confirm the installed PyTorch version schedules
                    # it as intended.
                    shot_optimizer.add_param_group({
                        'params': student.classifier_head.parameters(),
                        'lr': shot_optimizer.param_groups[0]['lr'] * 0.2,
                        'weight_decay': 1e-5
                    })
                    student.train()
            # Pseudo-Label fine-tuning phase
            else:
                if pseudo_labels_all is None:
                    raise RuntimeError("No pseudo-labels!")
                running_pl = 0.0
                # Shuffle via an explicit permutation (fresh every epoch) so each
                # batch can be mapped back to its dataset indices, and hence to its
                # stored pseudo-labels.
                idxs = torch.randperm(len(dataset_t)).tolist()
                pl_subset = Subset(dataset_t, idxs)
                pl_loader = DataLoader(pl_subset, batch_size=64, shuffle=False)

                for b_idx, (x, _) in enumerate(pl_loader):
                    # Batch b_idx holds the samples at idxs[start:end].
                    start = b_idx * pl_loader.batch_size
                    end = start + x.size(0)
                    batch_idx = idxs[start:end]
                    conf = pseudo_confidence[batch_idx].to(device)
                    labels = pseudo_labels_all[batch_idx].to(device)

                    x = x.to(device)
                    f = student.extract_features(x)
                    logits = student.classifier_head(f)

                    # Confidence-weighted cross-entropy over samples with conf >= 0.3;
                    # if no sample in the batch passes, fall back to the plain mean
                    # over the whole batch.
                    per_sample = nn.CrossEntropyLoss(reduction='none')(logits, labels)
                    mask = conf >= 0.3
                    loss = (per_sample[mask] * conf[mask]).mean() if mask.sum() > 0 else per_sample.mean()

                    shot_optimizer.zero_grad()
                    loss.backward()
                    # NOTE(review): only param group 0 (the feature parameters) is
                    # clipped; the head's gradients (group 1) are left unclipped —
                    # confirm this is intended.
                    torch.nn.utils.clip_grad_norm_(shot_optimizer.param_groups[0]['params'], 1.0)
                    shot_optimizer.step()

                    running_pl += loss.item() * x.size(0)

                if (epoch+1) % 20 == 0 or epoch == im_only:
                    N = len(dataset_t)
                    print(f"  [PL] Epoch {epoch+1}/{total_epochs}  Pseudo-CE Loss: {(running_pl/N):.4f}")

                shot_scheduler.step()

        # Evaluation on validation sets: one accuracy per (transfer set, validation
        # set) is appended for this seed.
        student.eval()
        for vp in val_paths:
            vl = DataLoader(PKLDataset(vp), batch_size=64, shuffle=False)
            acc = eval_model(student, vl, device)
            results[t][vp].append(acc)
            print(f"  → {vp}: {acc:.2f}%")


## 5. Final Summary

Compute the mean and sample standard deviation (`ddof=1`) of accuracy across seeds for each transfer → validation pair.

In [None]:
# Report, for every (transfer set, validation set) pair, the mean accuracy and
# the sample standard deviation (ddof=1) over the per-seed results.
print("\n=== Mean ± Std Dev over seeds ===")
for t in transfer_sets:
    per_val_scores = results[t]
    for vp in val_paths:
        arr = np.array(per_val_scores[vp])
        mean_acc = arr.mean()
        std_acc = arr.std(ddof=1)
        print(f"{t} → {vp}: mean={mean_acc:.2f}%, std={std_acc:.2f}%")
