In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torchvision.transforms.v2 as transforms
from torchvision.transforms.v2 import functional as v2F
from torch.utils.data import Dataset, DataLoader, Sampler

from ultralytics import YOLO
from transformers import CLIPVisionModel, CLIPImageProcessor

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pandas as pd
import random
import platform
import gc
import psutil
import time

# from models import *
from custom_logging import *
from mean_teacher import *

# Hardware/system stuff

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
def get_system_type():
    """Classify the host platform.

    Returns:
        "windows" for native Windows, "wsl" for a Linux kernel whose release
        string mentions "microsoft" or "wsl" (case-insensitive), "linux" for
        any other Linux kernel, and "other" for everything else.
    """
    os_name = platform.system()

    if os_name == "Windows":
        return "windows"
    if os_name != "Linux":
        return "other"

    # WSL reports itself as Linux; only the kernel release string gives it away.
    kernel_release = platform.uname().release.lower()
    if any(marker in kernel_release for marker in ("microsoft", "wsl")):
        return "wsl"
    return "linux"

def get_num_workers():
    """Return the DataLoader worker count to use on the current platform.

    Native Linux gets 10 workers, WSL gets 4, and Windows or any unrecognised
    platform gets 0 (data loading happens in the main process).
    """
    workers_by_system = {"linux": 10, "wsl": 4, "windows": 0}
    return workers_by_system.get(get_system_type(), 0)

# Display the detected platform together with the worker count chosen for it.
get_system_type(), get_num_workers()

('linux', 10)

In [4]:
def get_gpu_memory_usage():
    """Return the value of ``torch.cuda.memory_allocated()`` converted to MB.

    Returns 0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / 1024 / 1024  # bytes -> MB

def print_memory_usage():
    """Print the resident set size (RSS) of this process and of each direct child.

    The children are labelled "Worker i" in the output; they are whatever
    ``psutil.Process.children()`` reports for the current process.

    A child that cannot be queried (it exited, became a zombie, or access was
    denied between listing and querying) is skipped. Only psutil errors are
    suppressed, so KeyboardInterrupt and programming errors still propagate.
    """
    process = psutil.Process(os.getpid())
    print(f"Main process RSS: {process.memory_info().rss / 1024**3:.2f} GB")

    # Check worker processes
    for i, child in enumerate(process.children()):
        try:
            child_rss = child.memory_info().rss
        except psutil.Error:
            # Best effort: the child went away or is not inspectable; move on.
            continue
        print(f"Worker {i} RSS: {child_rss / 1024**3:.2f} GB")

def get_system_memory_usage():
    """Return the ``percent`` field of ``psutil.virtual_memory()``."""
    memory_stats = psutil.virtual_memory()
    return memory_stats.percent

### Config

In [5]:
# Timestamped run name: it is reused for the logger's experiment name and for the
# checkpoint directory, so each execution of this cell gets its own folders.
run_name = f"resnet_clip_yolo_mean_teacher_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# NOTE(review): `Logger` and `datetime` are not imported explicitly in this notebook;
# presumably they arrive through the star-imports in the first cell — confirm.
logger = Logger(log_dir="./logs", experiment_name=run_name)

class Config:
    # Central run settings. Every value below is a CLASS attribute, so `vars(Config())`
    # (the instance __dict__) is empty until something assigns on the instance.
    device = device  # the module-level torch.device chosen earlier
    use_amp = True  # presumably toggles mixed-precision training — confirm in the training loop
    batch_size = 40 # previously 20; 40 is noted as "basically the max"

    num_classes = 3

    initial_lr = 5e-5  # learning rate of the classifier-head parameter group
    lr_backbone = 1e-5  # base learning rate for backbone groups (scaled per backbone in the optimizer cell)
    consistency_weight = 2.2 # previously 0.6
    ema_decay = 0.99  # presumably the teacher's EMA decay — confirm in mean_teacher

    warmup_steps = 40000 # scaling (in num batches) from 0 to consistency_weight for consistency loss

    cur_epoch = 0  # load_checkpoint sets this to the saved epoch + 1 when resuming
    num_epochs = 4
    freeze_until_epoch = 0

    checkpoint_path = f"./checkpoints/{run_name}"
    log_interval = 5
    
    # Create directories
    # NOTE: this runs once, when the class body is executed (not per instance).
    os.makedirs(checkpoint_path, exist_ok=True)

# Settings object referenced by the cells below (config.device, config.batch_size, ...).
config = Config()

# Hand the configuration to the run's logger.
logger.log_config(config)

Logging to: ./logs/resnet_clip_yolo_mean_teacher_20251029_112330


In [6]:
# Dry-run switch read by MeanTeacherDataset: None loads every CSV row; an integer N
# makes each dataset sample at most N rows (with a fixed random_state of 42).
dry_run = None
# dry_run = config.batch_size * 100

# MODELS

In [7]:
class YOLOv11(nn.Module):
    """Feature extractor made of the first ten layers of the `yolov11l-face.pt` model.

    The notebook's commented sanity check reports an output of shape
    [B, 512, 20, 20].
    """

    def __init__(self):
        super().__init__()
        self.model = YOLO("yolov11l-face.pt").model
        # Layers 0-9, i.e. everything up to and including SPPF (layer 9).
        # Taking [:7] instead would stop after C3k2 (layer 6).
        leading_layers = list(self.model.model.children())[:10]
        self.feature_model = nn.Sequential(*leading_layers)

    def forward(self, x):
        """Run `x` through the truncated layer stack and return the feature map."""
        return self.feature_model(x)

# model = YOLOv11().to(device)
# features = model(images) 
# features.shape # torch.Size([B, 512, 20, 20])

In [8]:
# Shared CLIP image processor; the dataset applies it to each image.
CLIP_PROCESSOR = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")


class CLIP(nn.Module):
    """CLIP ViT-B/32 vision encoder that returns the pooled image embedding."""

    def __init__(self):
        super().__init__()
        self.clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        # Width of the encoder's final hidden state (before any projection head).
        self.clip_output_dim = self.clip_model.config.hidden_size

    def forward(self, x):
        """Encode processor outputs.

        Args:
            x: mapping of keyword inputs for the vision model
               (e.g. {"pixel_values": ...}); it is unpacked with ``**``.

        Returns:
            The model's ``pooler_output``. The notebook's sanity check reports
            [B, 768] (the hidden size), which is also what the classifier's
            first linear layer expects — not 512.
        """
        vision_outputs = self.clip_model(**x)
        return vision_outputs.pooler_output

# img = CLIP_PROCESSOR(Image.open("image.jpg"), return_tensors="pt").to(device)
# img['pixel_values'].shape # torch.Size([1, 3, 224, 224])
# model = CLIP().to(device)
# outputs = model(img) 
# outputs.shape # torch.Size([B, 768])

In [9]:
class ResNet(nn.Module):
    """ResNet-50 trunk with its last three child modules removed."""

    def __init__(self):
        super().__init__()
        # ResNet-152 (models.resnet152 / ResNet152_Weights.DEFAULT) was the earlier choice.
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Drop the last three children (for torchvision's ResNet: layer4, avgpool, fc)
        # so the module returns a spatial feature map instead of class logits.
        trunk_layers = list(self.resnet.children())[:-3]
        self.feature_extractor = nn.Sequential(*trunk_layers)

    def forward(self, x):
        """Return the trunk's feature map for `x`."""
        return self.feature_extractor(x)  # [B, 1024, H/16, W/16]
    
# model = ResNet().to(device)
# outputs = model(torch.randn(16, 3, 224, 224).to(device)) 
# outputs.shape # torch.Size([B, 1024, 14, 14])

In [10]:
class BiggerClassifier(nn.Module):
    """Classifier head on top of concatenated CLIP and YOLO image features.

    Args:
        output_dim: number of output logits (default 3).
    """

    def __init__(self, output_dim=3):
        super().__init__()
        self.clip = CLIP() # CLIP outputs: [B, 768]
        self.yolo = YOLOv11() # YOLO outputs: [B, 512, 20, 20]
        # NOTE(review): the ResNet branch is instantiated but not used by forward().
        # It is kept so state_dict keys and the optimizer's parameter groups stay
        # compatible with existing checkpoints; drop it (and retrain) to save memory.
        self.resnet = ResNet() # ResNet outputs: [B, 1024, H/16, W/16]

        # Global average pooling turns the YOLO feature map into one vector per image.
        self.yolo_pool = nn.AdaptiveAvgPool2d((1, 1))

        # 768 (CLIP) + 512 (YOLO); add 1024 if the ResNet branch is re-enabled.
        self.fc1 = nn.Linear(768 + 512, 2048)
        self.activation1 = nn.GELU()
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(2048, 1024)
        self.activation2 = nn.GELU()
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(1024, output_dim)

    def forward(self, clip_inputs, img_tensor):
        """Compute class logits.

        Args:
            clip_inputs: mapping of CLIP processor outputs, passed to the CLIP branch.
            img_tensor: image batch passed to the YOLO branch.

        Returns:
            Logits of shape [B, output_dim].
        """
        clip_features = self.clip(clip_inputs)  # [B, 768]
        # [B, 512, h, w] -> [B, 512, 1, 1] -> [B, 512]
        yolo_features = self.yolo_pool(self.yolo(img_tensor)).flatten(1)

        combined_features = torch.cat([clip_features, yolo_features], dim=1)

        # The dropout layers were defined but never applied; apply them after each
        # hidden activation. They are no-ops in eval() mode.
        x = self.dropout1(self.activation1(self.fc1(combined_features)))
        x = self.dropout2(self.activation2(self.fc2(x)))
        return self.fc3(x)

# DATA

### transforms

In [11]:
# PIL image -> float32 image tensor scaled to [0, 1].
to_tensor = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
])

yolo_intermediate_input_size = 700
yolo_final_input_size = 640


def _jitter_noise_clamp(jitter, hue, noise_std):
    """Build the photometric pipeline: ColorJitter -> additive Gaussian noise -> clamp to [0, 1]."""
    return transforms.Compose([
        transforms.ColorJitter(brightness=jitter, contrast=jitter, saturation=jitter, hue=hue),
        transforms.Lambda(lambda img: img + torch.randn_like(img) * noise_std),
        transforms.Lambda(lambda img: torch.clamp(img, 0, 1)),
    ])


# Geometric augmentation shared by both branches.
base_transform = transforms.Compose([
    to_tensor,
    # An int `size` is a shorter-edge resize that keeps the aspect ratio; nothing is padded here.
    transforms.Resize(
        size=yolo_intermediate_input_size,
        interpolation=transforms.InterpolationMode.BILINEAR,
        antialias=True,
    ),
    transforms.RandomRotation(
        degrees=(-15, 15),
        interpolation=transforms.InterpolationMode.BILINEAR,
        expand=True,
        fill=0,
    ),
    transforms.RandomCrop(yolo_final_input_size),
    transforms.RandomHorizontalFlip(p=0.5),  # random flip
])


yolo_weak_transform = _jitter_noise_clamp(jitter=0.05, hue=0.02, noise_std=0.005)
yolo_strong_transform = _jitter_noise_clamp(jitter=0.1, hue=0.05, noise_std=0.01)

# The CLIP branch starts from the already-augmented base image, resized to 224.
clip_base_transform = transforms.Compose([
    transforms.Resize(
        size=224,
        interpolation=transforms.InterpolationMode.BILINEAR,
        antialias=True,
    ),
])
clip_weak_transform = _jitter_noise_clamp(jitter=0.05, hue=0.02, noise_std=0.005)
# NOTE(review): the strong CLIP noise std equals the weak one (0.005), whereas the strong
# YOLO pipeline uses 0.01 — confirm this asymmetry is intentional.
clip_strong_transform = _jitter_noise_clamp(jitter=0.1, hue=0.05, noise_std=0.005)


In [12]:
# Deterministic validation pipeline for the YOLO branch:
# shorter-edge resize to the intermediate size (aspect ratio kept, no padding),
# then a centre crop to the final square input.
yolo_val_transform = transforms.Compose([
    to_tensor,
    transforms.Resize(
        size=yolo_intermediate_input_size,
        interpolation=transforms.InterpolationMode.BILINEAR,
        antialias=True,
    ),
    transforms.CenterCrop(yolo_final_input_size),
    transforms.Lambda(lambda img: torch.clamp(img, 0, 1)),
])

# Deterministic validation pipeline for the CLIP branch: resize to 256, centre-crop 224.
clip_val_transform = transforms.Compose([
    to_tensor,
    transforms.Resize(
        size=256,
        interpolation=transforms.InterpolationMode.BILINEAR,
        antialias=True,
    ),
    transforms.CenterCrop(224),
    transforms.Lambda(lambda img: torch.clamp(img, 0, 1)),
])

### datasets/loaders

In [13]:
bad_images = []


class MeanTeacherDataset(Dataset):
    """Image dataset for Mean Teacher training, driven by an annotation CSV.

    Rows with label 0 or 2 are treated as labeled and rows with label 1 as the
    unlabeled pool (this is how ``supervised`` and ``upsample`` split the data).

    ``__getitem__`` reads columns by position: 0 = relative path, 1 = label,
    2 = width, 3 = height, 4 = size_kb, 5 = source. ``get_id`` and ``upsample``
    additionally require a ``unique_id`` column.

    Args:
        csv_file: annotation CSV, relative to ``root_dir``.
        root_dir: dataset root; ``~`` is expanded.
        val: if True, use the deterministic validation transforms.
        supervised: if True, keep all labeled rows and sub-sample the unlabeled
            rows so that labeled rows make up about ``supervised_ratio`` of the result.
        supervised_ratio: target labeled fraction in (0, 1]; only used when
            ``supervised`` is True.
        upsample: optional int; labeled rows are repeated this many times, each
            copy getting a distinct ``unique_id`` suffix.

    Raises:
        ValueError: if ``supervised`` is True and ``supervised_ratio`` is not in (0, 1].
    """

    def __init__(self, csv_file, root_dir, val=False, supervised=False, supervised_ratio=0.5, upsample=None):
        # Fail early with a clear message: a ratio of 0 would divide by zero below and a
        # ratio above 1 would request a negative number of unlabeled rows.
        if supervised and not 0 < supervised_ratio <= 1:
            raise ValueError(f"supervised_ratio must be in (0, 1], got {supervised_ratio}")

        self.root_dir = os.path.expanduser(root_dir) # root_dir
        self.annotations = pd.read_csv(os.path.join(self.root_dir, csv_file))
        if dry_run:
            self.annotations = self.annotations.sample(n=min(dry_run, len(self.annotations)), random_state=42)
            self.annotations = self.annotations.reset_index(drop=True)  # Reset index after sampling
        if supervised: # only provide supervised data
            labeled = self.annotations[self.annotations['label'].isin([0, 2])]
            unlabeled = self.annotations[self.annotations['label'] == 1]
            n_labeled = len(labeled)
            total_len = int(n_labeled / supervised_ratio)
            target_len = total_len - n_labeled
            if len(unlabeled) > target_len:
                # Deliberately unseeded: each construction draws a different unlabeled subset.
                unlabeled = unlabeled.sample(n=target_len) #, random_state=42)
            else:
                print(f"not enough unlabeled. asked for {target_len}, only have {len(unlabeled)}")
            self.annotations = pd.concat([labeled, unlabeled], ignore_index=True)
        if upsample: # upsample supervised labels. Either an int, or default as None
            labeled = self.annotations[self.annotations['label'].isin([0, 2])].copy()
            unlabeled = self.annotations[self.annotations['label'] == 1].copy()
            n_labeled = len(labeled)
            n_unlabeled = len(unlabeled)
            print(f"Upsampling labeled data: {n_labeled} samples × {upsample} = {n_labeled * upsample}")
            print(f"Unlabeled data: {n_unlabeled} samples")
            # Create upsampled copies with unique IDs
            labeled_copies = []
            for i in range(upsample):
                labeled_copy = labeled.copy()
                # Append suffix to unique_id for each copy
                labeled_copy['unique_id'] = labeled_copy['unique_id'].astype(str) + f'_copy{i}'
                labeled_copies.append(labeled_copy)
            labeled_upsampled = pd.concat(labeled_copies, ignore_index=True)
            self.annotations = pd.concat([labeled_upsampled, unlabeled], ignore_index=True)
            self.annotations = self.annotations.sample(frac=1, random_state=42).reset_index(drop=True)
            print(f"Final dataset size: {len(self.annotations)} samples")
            print(f"Labeled ratio: {len(labeled_upsampled) / len(self.annotations):.2%}")
        # Per-item transform durations (training mode only, capped at 1000 entries).
        # With DataLoader workers these are recorded in the worker's copy of the dataset.
        self.transform_times = []
        self.val = val

        self.error_log_path = f'dataset_errors_{"val" if val else "train"}.log'

    def __len__(self):
        """Number of annotation rows after any sampling or upsampling."""
        return len(self.annotations)

    def get_id(self, idx):
        """Return the ``unique_id`` of row ``idx``."""
        return self.annotations.at[idx, 'unique_id']

    @staticmethod
    def _to_clip_inputs(image):
        """Run CLIP_PROCESSOR on one image already scaled to [0, 1] and drop the batch axis."""
        clip_inputs = CLIP_PROCESSOR(images=image, return_tensors="pt", do_rescale=False)
        clip_inputs['pixel_values'] = clip_inputs['pixel_values'].squeeze(0)  # [1, 3, 224, 224] -> [3, 224, 224]
        return clip_inputs

    def __getitem__(self, idx):
        """Load and transform one sample.

        Returns:
            ``(clip_weak, clip_strong, yolo_weak, yolo_strong, label, metadata)``.
            In validation mode the weak and strong entries are the same object.
            If anything fails, the error is appended to the dataset error log and
            an all-zero dummy sample with label 1 is returned instead of raising.
        """
        try:
            # 1. process metadata

            original_label = self.annotations.iloc[idx, 1]
            label = torch.tensor(original_label, dtype=torch.long) # TODO: confirm no label offset is needed

            img_path = os.path.join(self.root_dir, self.annotations.iloc[idx, 0])
            img_path = img_path.replace('\\', '/')
            # convert() loads the pixel data into a new image, so the file handle can be
            # released right away instead of waiting for garbage collection.
            with Image.open(img_path) as img_file:
                image_pil = img_file.convert('RGB')

            metadata = {
                'img_path': img_path,
                'label': original_label,
                'width': self.annotations.iloc[idx, 2],
                'height': self.annotations.iloc[idx, 3],
                'size_kb': self.annotations.iloc[idx, 4],
                'source': str(self.annotations.iloc[idx, 5]),
            }

            # 2. process images

            if not self.val: # for training loop
                item_transform_start = time.time()

                # One shared geometric augmentation; both branches and both views start from it.
                base_image = base_transform(image_pil)

                # YOLO branch
                yolo_weak_image = yolo_weak_transform(base_image)
                yolo_strong_image = yolo_strong_transform(base_image)

                # CLIP branch
                clip_base_image = clip_base_transform(base_image)
                clip_weak_image = clip_weak_transform(clip_base_image)
                clip_strong_image = clip_strong_transform(clip_base_image)

                clip_weak_image = self._to_clip_inputs(clip_weak_image)
                clip_strong_image = self._to_clip_inputs(clip_strong_image)

                item_transform_time = time.time() - item_transform_start
                if len(self.transform_times) < 1000:
                    self.transform_times.append(item_transform_time)

                return clip_weak_image, clip_strong_image, yolo_weak_image, yolo_strong_image, label, metadata

            else: # for validation
                yolo_image = yolo_val_transform(image_pil)
                clip_image = self._to_clip_inputs(clip_val_transform(image_pil))
                return clip_image, clip_image, yolo_image, yolo_image, label, metadata

        except Exception as e:

            with open(os.path.join(logger.get_log_dir(), self.error_log_path), 'a') as f:
                f.write(f"Error loading image at index {idx}: {e}\n\n")

            # Create dummy CLIP inputs (matching CLIP_PROCESSOR output format)
            dummy_clip_input = {'pixel_values': torch.zeros(3, 224, 224)}
            dummy_yolo_tensor = torch.zeros(3, 640, 640)
            dummy_label = torch.tensor(1, dtype=torch.long)
            dummy_metadata = {
                'img_path': f'error_at_idx_{idx}',
                'label': 1,
                'width': 640,
                'height': 640,
                'size_kb': 0.0,
                'source': -1,
            }

            return dummy_clip_input, dummy_clip_input, dummy_yolo_tensor, dummy_yolo_tensor, dummy_label, dummy_metadata

In [14]:
class RandomVersionSampler(Sampler):
    """Sampler that yields one randomly chosen dataset index per unique ID.

    Rows sharing the same ``get_id`` value are treated as versions of one item;
    each pass over the sampler picks one version per ID and shuffles the picks.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

        # Map each ID to the dataset indices that carry it (insertion-ordered).
        index_groups = {}
        for position in range(len(base_dataset)):
            index_groups.setdefault(base_dataset.get_id(position), []).append(position)
        self.grouped = index_groups
        self.ids = list(index_groups)

    def __iter__(self):
        # One random version per ID, drawn in ID order ...
        picks = []
        for img_id in self.ids:
            picks.append(random.choice(self.grouped[img_id]))
        # ... then shuffled so batches mix IDs.
        random.shuffle(picks)
        return iter(picks)

    def __len__(self):
        """Number of unique IDs, i.e. indices yielded per pass."""
        return len(self.ids)

In [15]:
from torch.utils.data._utils.collate import default_collate

def custom_collate_fn(batch):
    """Collate dataset samples into a batch while keeping metadata as plain Python lists.

    Args:
        batch: list of 6-tuples
            (clip_weak, clip_strong, yolo_weak, yolo_strong, label, metadata).

    Returns:
        The first five fields collated with ``default_collate``, followed by a
        dict mapping each metadata key to the list of per-sample values.
    """
    # Positions 0-4 hold tensors (or mappings of tensors): stack them the standard way.
    clip_weak, clip_strong, yolo_weak, yolo_strong, labels = (
        default_collate([sample[position] for sample in batch])
        for position in range(5)
    )

    # Position 5 is the metadata dict: regroup as dict-of-lists, without tensor conversion.
    metadata_keys = ('img_path', 'label', 'width', 'height', 'size_kb', 'source')
    metadata = {key: [sample[5][key] for sample in batch] for key in metadata_keys}

    return clip_weak, clip_strong, yolo_weak, yolo_strong, labels, metadata

##### The actual datasets and dataloaders

In [16]:
# Datasets: a labeled-heavy subset of the training CSV, the full training CSV, and validation.
supervised_train_dataset = MeanTeacherDataset(
    csv_file="train_2.csv", root_dir="~/Workspace/data-v2/train", supervised=True,
)
train_dataset = MeanTeacherDataset(
    csv_file="train_2.csv", root_dir="~/Workspace/data-v2/train", val=False,
)
val_dataset = MeanTeacherDataset(
    csv_file="val_2.csv", root_dir="~/Workspace/data-v2/val", val=True,
)


def _make_loader(dataset, sampler=None, prefetch=2):
    """Build a DataLoader with the notebook's shared settings.

    pin_memory is left at its default (an earlier note says WSL does not support
    it well) and persistent workers are disabled.
    """
    worker_count = get_num_workers()
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        sampler=sampler,
        collate_fn=custom_collate_fn,
        num_workers=worker_count,
        persistent_workers=False,
        # prefetch_factor must be None when loading happens in the main process.
        prefetch_factor=prefetch if worker_count > 0 else None,
    )


supervised_sampler = RandomVersionSampler(supervised_train_dataset)
supervised_train_dataloader = _make_loader(supervised_train_dataset, sampler=supervised_sampler)

sampler = RandomVersionSampler(train_dataset)
train_dataloader = _make_loader(train_dataset, sampler=sampler)

val_dataloader = _make_loader(val_dataset, prefetch=3)

len(train_dataset) - len(train_dataloader) * config.batch_size # > 0

487327

In [17]:
# Batches per epoch for the supervised loader and for the full training loader.
len(supervised_train_dataloader), len(train_dataloader)

(18472, 28399)

In [18]:
def _show_chw_image(chw_tensor):
    """Render a [3, H, W] float image tensor (nominally in [0, 1]) with matplotlib."""
    hwc = chw_tensor.permute(1, 2, 0).cpu().numpy()
    # Clip before the uint8 cast: values slightly outside [0, 1] (e.g. after undoing a
    # normalisation) would otherwise wrap around and show up as speckles.
    hwc = (np.clip(hwc, 0.0, 1.0) * 255).astype('uint8')
    plt.imshow(Image.fromarray(hwc))
    plt.show()


def benchmark_and_sanity_check(base_dataset, num_batches = 50):
    """Time iteration over a dataset or DataLoader and display a few sample images.

    Args:
        base_dataset: either a DataLoader (iterated batch by batch; its
            ``.dataset`` is used for the indexed sample and transform timings) or a
            bare dataset (iterated sample by sample, each counted as one sample).
        num_batches: maximum number of iterations to time.

    Prints timing statistics, then shows the indexed sample ``dataset[0]`` (YOLO
    view) and the first CLIP / YOLO image of the last timed iteration.
    """
    from itertools import islice

    # A DataLoader exposes the wrapped dataset as `.dataset`; a bare dataset is used as is.
    dataset = getattr(base_dataset, "dataset", base_dataset)

    total_samples = 0
    batches_seen = 0
    clip_tensors = yolo_tensors = None
    print(f"Benchmarking dataloader for {num_batches} batches...")

    start_time = time.time()
    # islice stops after num_batches items without fetching (and timing) an extra one.
    for _, teacher_clip_inputs, _, yolo_strong_tensors, labels, _ in islice(base_dataset, num_batches):
        clip_tensors = teacher_clip_inputs['pixel_values']
        yolo_tensors = yolo_strong_tensors
        # A 0-dim label means a single un-batched sample.
        total_samples += labels.shape[0] if labels.dim() > 0 else 1
        batches_seen += 1
    total_time = time.time() - start_time

    if batches_seen == 0:
        print("No batches were produced; nothing to benchmark.")
        return

    avg_time_per_batch = total_time / batches_seen
    avg_time_per_sample = total_time / total_samples
    samples_per_second = total_samples / total_time if total_time > 0 else float('inf')

    print(f"Total time: {total_time:.2f} seconds.", f"Total samples: {total_samples}")
    print(f"Average time per batch: {avg_time_per_batch:.4f} seconds.", f"Average time per sample: {avg_time_per_sample:.4f} seconds")
    print(f"Throughput: {samples_per_second:.2f} samples/second")

    # sanity check dataset
    _, _, yolo, _, _, _ = dataset[0]
    _show_chw_image(yolo)  # [640, 640, 3] once permuted

    # sanity check clip img: undo the CLIP normalisation before display
    first_clip = clip_tensors if clip_tensors.dim() == 3 else clip_tensors[0]
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)
    _show_chw_image(first_clip.cpu() * std + mean)  # [224, 224, 3] once permuted

    # sanity check yolo img
    first_yolo = yolo_tensors if yolo_tensors.dim() == 3 else yolo_tensors[0]
    _show_chw_image(first_yolo)

    # Timings are only recorded in training mode and, with DataLoader workers, inside the
    # worker processes' copies of the dataset, so the list may legitimately be empty here.
    transform_times = getattr(dataset, "transform_times", [])
    if len(transform_times) > 0:
        print("mean transform compute time:", np.mean(transform_times))
    else:
        print("mean transform compute time: n/a (no timings recorded in this process)")

# benchmark_and_sanity_check(train_dataset, 10)
# benchmark_and_sanity_check(val_dataset, 10)


# TRAINING

In [19]:
# Student network: the model that is optimised by gradient descent.
model = BiggerClassifier().to(config.device)

# Teacher network: starts as an exact copy of the student and is kept in train() mode
# (switch to .eval() to change that).
teacher_model = copy.deepcopy(model).to(config.device).train()

# The teacher never receives gradients.
for p in teacher_model.parameters():
    p.requires_grad_(False)

### learn rate, scheduler, optimiser

In [20]:
# Split the student's parameters into per-backbone groups by substring match on the
# parameter name (first match wins: clip, then yolo, then resnet); the rest is the head.
clip_params, yolo_params, resnet_params, classifier_params = [], [], [], []
_backbone_buckets = (
    ('clip', clip_params),
    ('yolo', yolo_params),
    ('resnet', resnet_params),  # earlier layers could be separated out here if needed
)

for name, param in model.named_parameters():
    for tag, bucket in _backbone_buckets:
        if tag in name:
            bucket.append(param)
            break
    else:
        classifier_params.append(param)

print("Freezing backbone parameters initially...")
for param in clip_params + yolo_params: # + resnet_params:
    param.requires_grad = False

# One learning rate per group: backbones use fractions of lr_backbone, the head uses initial_lr.
optimizer = torch.optim.AdamW([
    {'params': clip_params, 'lr': config.lr_backbone*0.1, 'name': 'clip'},
    {'params': yolo_params, 'lr': config.lr_backbone*0.5, 'name': 'yolo'},
    {'params': resnet_params, 'lr': config.lr_backbone, 'name': 'resnet'},
    {'params': classifier_params, 'lr': config.initial_lr, 'name': 'classifier'},
])

# mode='max': the monitored metric is expected to increase.
# Alternative tried: CosineAnnealingLR(optimizer, T_max=config.num_epochs, eta_min=1e-6).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.7, patience=1,
)
print('scheduler lr:', scheduler.get_last_lr())


Freezing backbone parameters initially...
scheduler lr: [1.0000000000000002e-06, 5e-06, 1e-05, 5e-05]


In [21]:
# import torch.nn as nn

# for name, module in model.named_modules():
#     if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
#         print(f"Found BatchNorm: {name} ({type(module).__name__})")

### loss and criterion

In [22]:
class AsymmetricFocalLoss(nn.Module):
    """
    Asymmetric Focal Loss variant that combines focal loss with asymmetric penalties.
    Useful when you also want to handle class imbalance.

    Per-sample loss:
        alpha[t] * (1 - p_t) ** gamma * CE(logits, t) * sum_c(p_c * penalty[t, c])
    where ``t`` is the true class and ``p`` the softmax of the logits; the result
    is averaged over the batch.

    Args:
        gamma: focal exponent; 0 disables the focal down-weighting.
        alpha: optional per-class weights of shape [C] (tensor or sequence).
        confusion_penalty_matrix: optional [C, C] matrix indexed as
            ``[true_class, predicted_class]``; defaults to all ones for 3 classes.

    ``alpha`` and ``confusion_penalty_matrix`` are registered as non-persistent
    buffers, so ``.to(device)`` moves them with the module while ``state_dict()``
    stays empty.
    """
    def __init__(self, gamma=2.0, alpha=None, confusion_penalty_matrix=None):
        super().__init__()
        self.gamma = gamma

        if confusion_penalty_matrix is None:
            confusion_penalty_matrix = torch.tensor([
                [1.0, 1.0, 1.0],   # True: BAD
                [1.0, 1.0, 1.0],   # True: NEUTRAL
                [1.0, 1.0, 1.0]    # True: GOOD
            ])

        # Buffers (rather than plain attributes) follow the module across devices; as plain
        # attributes the default CPU matrix could not be indexed by CUDA targets.
        self.register_buffer(
            "alpha",
            None if alpha is None else torch.as_tensor(alpha),
            persistent=False,
        )
        self.register_buffer(
            "confusion_penalty_matrix",
            torch.as_tensor(confusion_penalty_matrix),
            persistent=False,
        )

    def forward(self, logits, targets):
        """Compute the mean loss.

        Args:
            logits: [B, C] unnormalised class scores.
            targets: [B] integer class indices.

        Returns:
            Scalar tensor with the batch-mean loss.
        """
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        probs = F.softmax(logits, dim=1)
        p_t = probs.gather(1, targets.view(-1, 1)).squeeze(1)

        # Focal term: down-weights samples the model already classifies confidently.
        focal_weight = (1 - p_t) ** self.gamma

        # Expected penalty under the predicted distribution:
        # row `targets[i]` of the matrix holds the penalty for each predicted class.
        # The .to() calls are no-ops unless the module was never moved to the input's device.
        penalty_for_true = self.confusion_penalty_matrix.to(logits.device)[targets]  # [B, C]
        expected_penalty = (probs * penalty_for_true).sum(dim=1)  # [B]

        loss = focal_weight * ce_loss * expected_penalty

        # Optional per-class weighting
        if self.alpha is not None:
            alpha_t = self.alpha.to(logits.device).gather(0, targets)
            loss = alpha_t * loss

        return loss.mean()


In [23]:
# Plain cross-entropy.
ce_criterion = nn.CrossEntropyLoss().to(config.device)

# Asymmetric focal loss.
# alpha: per-class weights in class order (0 = BAD, 1 = NEUTRAL, 2 = GOOD).
# confusion_penalty_matrix[true, pred]: rows are the true class, columns the predicted class.
af_criterion = AsymmetricFocalLoss(
    gamma=1.2,
    alpha=torch.tensor([1.1, 0.9, 1.2], device=config.device),
    confusion_penalty_matrix=torch.tensor(
        [
            [1.0, 1.05, 1.15],   # true BAD
            [0.88, 1.0, 0.90],   # true NEUTRAL
            [1.1, 1.05, 1.0],    # true GOOD
        ],
        device=config.device,
    ),
).to(config.device)


In [24]:
# i=0
# logits = torch.tensor([[10, 1, 1]]).to(device) * 1.5  # 5 samples, 3 classes
# # logits = torch.randn(5,3).to(device) * 1.5  # 5 samples, 3 classes
# targets = torch.tensor([0] * 10).to(device)

# # Forward pass
# probs = F.softmax(logits[i].unsqueeze(0), dim=1)
# loss = af_criterion(logits[i].unsqueeze(0), targets[i].unsqueeze(0))
# ce_loss = ce_criterion(logits[i].unsqueeze(0), targets[i].unsqueeze(0))

# print(probs, targets[i].unsqueeze(0).item())
# print(loss.item())
# print(ce_loss.item())

### util

In [25]:
# scaler = GradScaler()
# Performance tracker from the project's helper modules (presumably custom_logging or
# mean_teacher — confirm). save_checkpoint reads its best_accuracy,
# epochs_without_improvement and accuracy_history attributes.
monitor = PerformanceMonitor()

In [26]:
def _snapshot_config(config):
    """Collect the public, non-callable settings of `config` from its instance AND its class."""
    snapshot = {}
    for key in dir(config):
        if key.startswith('_'):
            continue
        value = getattr(config, key)
        # Skip methods, and skip torch.device: device placement is a property of the
        # current machine, not a hyper-parameter that should be restored on resume.
        if callable(value) or isinstance(value, torch.device):
            continue
        snapshot[key] = value
    return snapshot


def save_checkpoint(model, teacher_model, optimizer, scheduler, epoch, global_step, config, 
                    val_accuracy, monitor, loss_history, cls_loss_history):
    """Write a full training checkpoint into ``config.checkpoint_path``.

    The file name encodes epoch, global step, validation accuracy and a timestamp.
    The checkpoint holds student/teacher/optimizer/scheduler state, the loss
    histories, the monitor's state and a snapshot of the configuration.

    Returns:
        Path of the written checkpoint file.

    Raises:
        Exception: whatever ``torch.save`` raised, re-raised after being printed.
    """
    checkpoint_dir = config.checkpoint_path
    os.makedirs(checkpoint_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f"checkpoint_epoch{epoch}_step{global_step}_acc{val_accuracy:.4f}_{timestamp}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

    checkpoint = {
        'epoch': epoch,
        'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'teacher_state_dict': teacher_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_accuracy': val_accuracy,
        'loss_history': loss_history,
        'cls_loss_history': cls_loss_history,
        'monitor_state': {
            'best_accuracy': monitor.best_accuracy,
            'epochs_without_improvement': monitor.epochs_without_improvement,
            'accuracy_history': monitor.accuracy_history,
        },
        # vars(config) only sees instance attributes; the Config class keeps its settings
        # as class attributes, which made this snapshot an empty dict.
        'config': _snapshot_config(config),

        'save_timestamp': datetime.now().isoformat(),
        'pytorch_version': torch.__version__,
    }

    try:
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved successfully: {checkpoint_path}")
        print(f"Epoch: {epoch} | Step: {global_step} | Val Acc: {val_accuracy:.4f}")
        return checkpoint_path
    except Exception as e:
        print(f"Error saving checkpoint: {e}")
        raise

def load_checkpoint(checkpoint_path, model, teacher_model, optimizer, scheduler, config, 
                    monitor, device='cuda', strict=True):
    """Restore training state from a checkpoint written by ``save_checkpoint``.

    Loads student/teacher/optimizer/scheduler state in place, restores the
    monitor's state, copies saved config values onto ``config`` (except paths and
    the device) and sets ``config.cur_epoch`` to the saved epoch + 1.

    Args:
        checkpoint_path: path of the checkpoint file.
        device: ``map_location`` used when loading tensors.
        strict: forwarded to ``load_state_dict`` for both models.

    Returns:
        ``(epoch, global_step, loss_history, cls_loss_history, val_accuracy)``.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    # Recent PyTorch releases default to weights_only=True, which rejects arbitrary
    # Python objects stored in the checkpoint — confirm against the installed version.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    teacher_model.load_state_dict(checkpoint['teacher_state_dict'], strict=strict)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore training state
    epoch = checkpoint['epoch']
    global_step = checkpoint['global_step']
    val_accuracy = checkpoint['val_accuracy']
    loss_history = checkpoint['loss_history']
    cls_loss_history = checkpoint['cls_loss_history']

    # Restore performance monitor
    monitor_state = checkpoint['monitor_state']
    monitor.best_accuracy = monitor_state['best_accuracy']
    monitor.accuracy_history = monitor_state['accuracy_history']
    # Saved by save_checkpoint but previously never restored, which reset the
    # no-improvement counter on every resume. Tolerate its absence in the file.
    if 'epochs_without_improvement' in monitor_state:
        monitor.epochs_without_improvement = monitor_state['epochs_without_improvement']

    # Update config with saved values. Paths stay as configured for the current run and
    # the device stays whatever the current machine provides.
    protected_keys = ('checkpoint_path', 'log_dir', 'device')
    for key, value in checkpoint['config'].items():
        if hasattr(config, key) and key not in protected_keys:
            setattr(config, key, value)

    config.cur_epoch = epoch + 1

    print("Checkpoint loaded successfully")
    print(f"Resuming from Epoch: {epoch + 1} | Step: {global_step}")
    print(f"Previous Val Acc: {val_accuracy:.4f} | Best Acc: {monitor.best_accuracy:.4f}")
    print(f"Loaded len(loss_history) = {len(loss_history)}")

    return epoch, global_step, loss_history, cls_loss_history, val_accuracy

# Example usage in your training loop:
"""
# At the end of each epoch, after validation:
if val_accuracy > monitor.best_accuracy:
    save_checkpoint(
        model, teacher_model, optimizer, scheduler,
        epoch, global_step, config, val_accuracy,
        monitor, loss_history, cls_loss_history
    )

# To resume training:
checkpoint_path = "./checkpoints/your_run_name/checkpoint_epoch10_step50000.pth"
epoch, global_step, loss_history, cls_loss_history, val_accuracy = load_checkpoint(
    checkpoint_path, model, teacher_model, optimizer, scheduler,
    config, monitor, device=config.device
)

# Then continue training loop from config.cur_epoch
"""

'\n# At the end of each epoch, after validation:\nif val_accuracy > monitor.best_accuracy:\n    save_best_checkpoint(\n        model, teacher_model, optimizer, scheduler,\n        epoch, global_step, config, val_accuracy,\n        monitor, loss_history, cls_loss_history\n    )\n\n# To resume training:\ncheckpoint_path = "./checkpoints/your_run_name/checkpoint_epoch10_step50000.pth"\nepoch, global_step, loss_history, cls_loss_history, val_accuracy = load_checkpoint(\n    checkpoint_path, model, teacher_model, optimizer, scheduler,\n    config, monitor, device=config.device\n)\n\n# Then continue training loop from config.cur_epoch\n'

In [27]:
def validate(model, val_loader, config, epoch, criterion=None):
    """Run one full pass over ``val_loader`` and collect per-image / per-batch results.

    The model is switched to eval mode for the duration of the pass and is put
    back into whatever mode (train/eval) it was in before the call, so calling
    this in the middle of a training epoch does not silently change how the
    model behaves for the remaining batches.

    Args:
        model: Network invoked as ``model(clip_inputs, yolo_tensors)``; its
            output is treated as per-class logits (softmax/argmax over dim 1,
            with columns 0/1/2 stored as bad/neutral/good).
        val_loader: Iterable of batches unpacked as
            ``(clip_inputs, _, yolo_tensors, _, labels, metadata)``.
            ``metadata`` must provide ``img_path``, ``width``, ``height``,
            ``size_kb`` and ``source`` sequences indexed per image.
        config: Object exposing ``device``.
        epoch: Label for this pass; used in the progress-bar text and stored
            in the ``epoch`` column of the batch frame.
        criterion: Optional loss callable ``criterion(outputs, labels)``
            returning a scalar tensor. Defaults to the notebook-level
            ``af_criterion``.

    Returns:
        Tuple ``(image_df, batch_df, total_loss)``: one row per image, one row
        per batch, and the sum of the per-batch losses.
    """
    if criterion is None:
        criterion = af_criterion

    was_training = model.training
    torch.cuda.empty_cache()
    model.eval()

    # cumulative lists for dataframes
    image_data = []
    batch_data = []
    total_loss = 0.0

    try:
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Validation Epoch {epoch}', total=len(val_loader), leave=False)

            for batch_num, batch_input in enumerate(val_pbar):
                clip_inputs, _, yolo_tensors, _, labels, metadata = batch_input
                clip_inputs['pixel_values'] = clip_inputs['pixel_values'].to(config.device)
                yolo_tensors = yolo_tensors.to(config.device)
                labels = labels.to(config.device)

                # Forward pass
                outputs = model(clip_inputs, yolo_tensors)
                # Read the scalar loss once; it is reused for logging and the running total.
                batch_loss = criterion(outputs, labels).item()
                probs = F.softmax(outputs, dim=1).cpu().numpy()
                predictions = outputs.argmax(dim=1).cpu().numpy()
                labels = labels.cpu().numpy()

                # Add batch-level data
                batch_data.append({
                    'batch_num': batch_num,
                    'epoch': epoch,
                    'loss': batch_loss
                })

                # Add image-level data. Severe good<->bad confusions are not
                # counted here; they can be derived from the returned frame.
                for i in range(len(labels)):
                    image_data.append({
                        'batch_num': batch_num,
                        'img_path': metadata['img_path'][i],
                        'label': int(labels[i]),
                        'width': int(metadata['width'][i]),
                        'height': int(metadata['height'][i]),
                        'size_kb': float(metadata['size_kb'][i]),
                        'source': metadata['source'][i],
                        'prediction': int(predictions[i]),
                        'bad': float(probs[i, 0]),
                        'neutral': float(probs[i, 1]),
                        'good': float(probs[i, 2])
                    })

                val_pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

                total_loss += batch_loss
    finally:
        # Restore the caller's train/eval mode even if validation fails part-way.
        model.train(was_training)

    image_df = pd.DataFrame(image_data)
    batch_df = pd.DataFrame(batch_data)

    torch.cuda.empty_cache()
    gc.collect()

    return image_df, batch_df, total_loss

In [28]:
def analyse(image_df, batch_df, config, monitor, epoch):
    """Turn the frames produced by ``validate()`` into metrics and update ``monitor``.

    Args:
        image_df: Per-image frame with at least ``label`` and ``prediction``
            columns (class ids 0=bad, 1=neutral, 2=good).
        batch_df: Per-batch frame with a ``loss`` column.
        config: Object exposing ``num_classes`` (must be >= 3, because the
            recall metrics below index classes 0, 1 and 2).
        monitor: Object with ``accuracy_history`` (list), ``best_accuracy`` and
            ``epochs_without_improvement``; it is MUTATED in place on every call.
        epoch: Label used only in the printed summary.

    Returns:
        Tuple ``(accuracy, metrics, cm)``: headline accuracy, a flat dict of
        ``val/...`` metrics, and the 3x3 confusion matrix (rows = true label).
    """
    # Headline accuracy ignores rows whose label is 1 (neutral).
    labeled_df = image_df[image_df['label'] != 1]
    total = len(labeled_df)
    correct = (labeled_df['label'] == labeled_df['prediction']).sum()
    # With no bad/good rows this would be 0/0 -> NaN (and a RuntimeWarning),
    # which would then be pushed into monitor.accuracy_history. Report 0.0
    # instead, matching the per-class recall fallback below.
    accuracy = correct / total if total > 0 else 0.0

    # Calculate per-class metrics
    class_names = ['bad', 'neutral', 'good']
    labels = [0, 1, 2]
    all_labels = image_df['label'].values
    all_preds = image_df['prediction'].values

    # Confusion matrix and classification report (over ALL rows, neutral included)
    cm = confusion_matrix(y_true=all_labels, y_pred=all_preds, labels=labels)
    report = classification_report(y_true=all_labels, y_pred=all_preds, labels=labels, target_names=class_names,
                                   output_dict=True, zero_division=0)

    # Calculate per-class recall; a class with no samples gets 0.0
    class_recall = []
    for class_idx in range(config.num_classes):
        class_mask = image_df['label'] == class_idx
        class_count = class_mask.sum()
        if class_count > 0:
            class_correct = (class_mask & (image_df['prediction'] == class_idx)).sum()
            class_recall.append(class_correct / class_count)
        else:
            class_recall.append(0.0)

    # Compile metrics dictionary
    avg_val_loss = batch_df['loss'].mean()

    metrics = {
        "val/accuracy": accuracy,
        "val/loss": avg_val_loss,
        "val/recall_bad": class_recall[0],
        "val/recall_neutral": class_recall[1],
        "val/recall_good": class_recall[2],
    }

    # Add precision and F1 scores from classification report
    for class_name in class_names:
        if class_name in report:
            metrics[f"val/precision_{class_name}"] = report[class_name]['precision']
            metrics[f"val/f1_{class_name}"] = report[class_name]['f1-score']

    # Add severe misclassification counts (good<->bad confusions)
    good_as_bad = image_df[(image_df['label'] == 2) & (image_df['prediction'] == 0)]
    bad_as_good = image_df[(image_df['label'] == 0) & (image_df['prediction'] == 2)]
    metrics["val/bad_as_good_count"] = len(bad_as_good)
    metrics["val/good_as_bad_count"] = len(good_as_bad)

    # Update monitor: record this accuracy and track the best value seen so far
    monitor.accuracy_history.append(accuracy)
    if accuracy > monitor.best_accuracy:
        monitor.best_accuracy = accuracy
        monitor.epochs_without_improvement = 0
    else:
        monitor.epochs_without_improvement += 1

    print(f"Validation Summary - Epoch {epoch}")
    print(f"Accuracy: {accuracy:.4f} | Loss: {avg_val_loss:.4f}")

    return accuracy, metrics, cm

##### train_one_epoch()

In [29]:
# Gradient scaler used by the training loop's mixed-precision backward pass
# (scaler.scale(loss).backward() / unscale_ / step / update).
scaler = GradScaler()

In [30]:
def analyse_epoch(image_df, batch_df, config, epoch):
    """Summarise one training epoch as a flat ``epoch/...`` metrics dict.

    Args:
        image_df: Optional per-image frame with ``label``, ``prediction`` and
            ``consistency_loss`` columns, or ``None`` to skip image-level metrics.
        batch_df: Per-batch frame with ``loss``, ``cls_loss``,
            ``consistency_loss`` and ``consistency_weight`` columns.
        config: Object exposing ``num_classes``.
        epoch: Accepted for signature symmetry; not used in the computation.

    Returns:
        Dict of batch-level means, plus (when ``image_df`` is given) overall
        and per-class accuracy, mean per-image consistency loss, and the two
        severe good<->bad misclassification counts.
    """
    # Batch-level averages: (metric key, source column), in reporting order.
    batch_level = (
        ("epoch/loss", 'loss'),
        ("epoch/cls_loss", 'cls_loss'),
        ("epoch/consistency_loss", 'consistency_loss'),
        ("epoch/consistency_weight", 'consistency_weight'),
    )
    metrics = {metric_key: batch_df[column].mean() for metric_key, column in batch_level}

    # Nothing more to do when per-image tracking was disabled.
    if image_df is None:
        return metrics

    true_labels = image_df['label']
    predicted = image_df['prediction']
    train_accuracy = (true_labels == predicted).sum() / len(image_df)

    # Per-class accuracy; classes absent from this epoch get no entry at all.
    for class_idx in range(config.num_classes):
        in_class = true_labels == class_idx
        class_size = in_class.sum()
        if class_size > 0:
            class_hits = (in_class & (predicted == class_idx)).sum()
            metrics[f"epoch/accuracy_class_{class_idx}"] = class_hits / class_size

    metrics["epoch/avg_image_consistency_loss"] = image_df['consistency_loss'].mean()
    metrics["epoch/train_accuracy"] = train_accuracy

    # Severe misclassifications: label 2 predicted as 0, and label 0 predicted as 2.
    good_as_bad = image_df[(true_labels == 2) & (predicted == 0)]
    bad_as_good = image_df[(true_labels == 0) & (predicted == 2)]
    metrics["epoch/bad_as_good_count"] = len(bad_as_good)
    metrics["epoch/good_as_bad_count"] = len(good_as_bad)

    return metrics

In [31]:
def plot_running_loss(loss_history, save_path, window_size=10, log_interval=None):
    """Plot the raw loss and its moving average, then save the figure to disk.

    Does nothing (no file written, nothing printed) while fewer than
    ``window_size`` values have been recorded.

    Args:
        loss_history: List of loss values, one per logged training step.
        save_path: Path the PNG is written to.
        window_size: Number of consecutive values averaged for the smoothed curve.
        log_interval: Number of training steps between two recorded values;
            used to scale the x-axis. Defaults to the notebook-level
            ``config.log_interval``.
    """
    if len(loss_history) < window_size:
        return

    if log_interval is None:
        log_interval = config.log_interval

    # Moving average over each full window ending at index `end - 1`.
    moving_avg = [
        sum(loss_history[end - window_size:end]) / window_size
        for end in range(window_size, len(loss_history) + 1)
    ]

    # Entry i was recorded at step i * log_interval; the smoothed curve starts
    # at the first index that has a full window behind it.
    x_raw = [i * log_interval for i in range(len(loss_history))]
    x_moving = x_raw[window_size - 1:]

    current_avg = moving_avg[-1]
    min_avg = min(moving_avg)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Raw loss in a light colour, moving average in bold on top.
        ax.plot(x_raw, loss_history, alpha=0.3, color='blue', label='Raw Loss')
        ax.plot(x_moving, moving_avg, color='red', linewidth=2, label=f'Moving Avg (window={window_size})')

        # Reference lines for the latest and the best smoothed value.
        ax.axhline(y=current_avg, color='green', linestyle='--', alpha=0.5, label=f'Current: {current_avg:.4f}')
        ax.axhline(y=min_avg, color='orange', linestyle='--', alpha=0.5, label=f'Min: {min_avg:.4f}')

        ax.set_xlabel('Batch')
        ax.set_ylabel('Loss')
        ax.set_title('Training Loss Over Time')
        ax.grid(True, alpha=0.3)
        # The legend must be built AFTER every labelled artist exists; created
        # earlier, it would omit the 'Current' / 'Min' reference lines.
        ax.legend()

        fig.tight_layout()
        fig.savefig(save_path, dpi=100)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    print(f"  Loss graph saved to: {save_path}")

# TRAINING LOOP

In [32]:
# checkpoint_path = "checkpoints/resnet_clip_yolo_mean_teacher_20251021_012456/checkpoint_epoch0_step8000_acc0.4842_20251021_025438.pth"

# epoch, global_step, loss_history, cls_loss_history, val_accuracy = load_checkpoint(
#     checkpoint_path, model, teacher_model, optimizer, scheduler,
#     config, monitor, device=config.device
# )


In [33]:
criterion = ce_criterion

# Validate roughly 7 times per pass over train_dataloader. The training loop
# only reaches its validation check on steps that are multiples of
# config.log_interval, so round down to such a multiple.
validation_frequency = len(train_dataloader) // 7
validation_frequency -= validation_frequency % config.log_interval
# With a small dataset the rounding above yields 0, and the loop's
# `global_step % validation_frequency` would raise ZeroDivisionError.
validation_frequency = max(validation_frequency, config.log_interval)
validation_frequency

4055

In [34]:
# ema_decay = 0.999
# try warmup_steps = 60000 or more
# low consistency weight = 0.2 NOT REALLY
# teacher receives student buffers for batchnorms NOPE
# try freezing backbone
# lower consistency weight on 0 and 2
# "semi supervised" start epoch

# next - try high consistency weight - the magnitudes should be roughly equal ... not like 1:50

In [35]:
# Fresh-run training state.
# NOTE(review): this cell sits AFTER the (currently commented-out) load_checkpoint
# cell, so on a resumed run executing it would reset the restored global_step
# and loss histories - skip it when resuming.
global_step = 0  # batches processed so far; incremented once per training batch
track_images = False  # per-image tracking; its uses in the training loop are commented out
loss_history = []  # total loss, appended once per logged step
cls_loss_history = []  # classification loss, appended once per logged step

In [None]:
# TODO: I have no way to reduce the classification loss based on the teacher model.
# Specifically, the case where the teacher is certain the given label is wrong
# is not handled yet. I really need to do this.

# Also: do an actual run WITHOUT the teacher and just see what happens.



In [None]:
# monitor mode collapse: rolling buffer of the student's predicted classes for
# the most recent logged batches (at most max_recent_batches entries)
recent_predictions = [] 
max_recent_batches = 15

# if epoch == config.freeze_until_epoch:
# Only fires when this cell starts with global_step already at/after 500
# (e.g. state restored from a checkpoint); on a fresh run the training loop
# performs the same unfreezing itself when it reaches step 500.
# NOTE(review): clip_params / yolo_params are presumably the backbone parameter
# lists built where the optimizer is set up - confirm.
if global_step >= 500:
    # print(f"Unfreezing backbone at epoch {epoch}")
    print(f"unfreezing backbone at step {global_step}")
    for param in clip_params + yolo_params:
        param.requires_grad = True

# Main training loop. Each iteration trains the student for one epoch (the
# teacher is updated by EMA after every batch), logs every config.log_interval
# steps, validates the TEACHER every validation_frequency steps and again at
# the end of the epoch, then checkpoints and steps the LR scheduler.
for epoch in range(config.cur_epoch, config.num_epochs):
    collapse_flag = False # model pred mode collapse (reset per epoch, so the warning prints at most once per epoch)
    epoch_start = time.time()

    # region Train one epoch ########################################
    model.train()
    # NOTE(review): validate() puts the model it is given (the teacher) into
    # eval mode. Unless validate() restores the previous mode, the teacher stays
    # in eval from the first intermittent validation until this line runs again
    # at the start of the next epoch.
    teacher_model.train()
    # Cumulative lists for dataframes
    # image_data = [] if track_images else None
    batch_data = []

    # Epochs 0-5 rebuild the dataset, sampler and loader from scratch each
    # epoch; later epochs reuse the pre-built train_dataloader.
    if epoch <= 5:
        # train_loader = supervised_train_dataloader
        # NOTE(review): `ratios` is only printed below; the supervised_ratio
        # argument that would consume it is commented out.
        ratios = [0.66, 0.55, 0.55, 0.48, 0.40, 0.33, 0.01]
        supervised_train_dataset = MeanTeacherDataset(
            csv_file = "train_2.csv", 
            root_dir = "~/Workspace/data-v2/train",
            # supervised = True,
            # supervised_ratio = ratios[epoch],
            upsample = 3,
        )
        supervised_sampler = RandomVersionSampler(supervised_train_dataset)
        train_loader = DataLoader(
            supervised_train_dataset, 
            batch_size=config.batch_size, 
            sampler=supervised_sampler,
            collate_fn=custom_collate_fn,
            num_workers=get_num_workers(),
            persistent_workers=False, # True if get_num_workers() > 0 else False,
            # pin_memory=False, # WSL does not support pin_memory well
            prefetch_factor=2 if get_num_workers() > 0 else None,
        )
        print(ratios[epoch], len(train_loader))
    else: 
        train_loader = train_dataloader
    # Classification loss: ce_criterion for epoch 0 only, af_criterion afterwards.
    if epoch >= 1:
        criterion = af_criterion
    else:
        criterion = ce_criterion
    
    epoch_pbar = tqdm(train_loader, desc=f'Epoch {epoch}', total=len(train_loader))
    for batch_num, batch_input in enumerate(epoch_pbar):
        # Unfreeze the CLIP/YOLO backbone parameters once, at global step 500.
        if global_step == 500 : #len(train_dataloader)//4:
            print(f"unfreezing backbone at step {global_step}")
            for param in clip_params + yolo_params:
                param.requires_grad = True
                
        # Timer starts after the batch has been fetched, so data-loading time
        # is not part of batch_time / img/s below.
        batch_start = time.time()

        # 1. Grab data from dataloader
        # The teacher sees the weakly augmented views, the student the strong ones.
        clip_weak, clip_strong, yolo_weak, yolo_strong, labels, metadata = batch_input
        clip_weak['pixel_values'] = clip_weak['pixel_values'].to(device, non_blocking=True)
        clip_strong['pixel_values'] = clip_strong['pixel_values'].to(device, non_blocking=True)
        yolo_weak = yolo_weak.to(device, non_blocking=True)
        yolo_strong = yolo_strong.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        # 2. Forward pass 
        # Teacher: no gradients and autocast explicitly disabled.
        with torch.no_grad():
            with autocast(device_type='cuda', enabled=False):
                teacher_outputs = teacher_model(clip_weak, yolo_weak)
        # Student (and the loss computation): under autocast.
        with autocast(device_type='cuda'):
            student_outputs = model(clip_strong, yolo_strong)
            cls_loss = criterion(student_outputs, labels)

            # Compute per-sample consistency losses
            # entropy_threshold decays linearly from 0.7 and is floored at 0.25
            # (the floor is reached at global_step 40500).
            consistency_losses = compute_consistency_loss(student_outputs, teacher_outputs, entropy_threshold=max(0.25, 0.7 - global_step / 90000))

            # Create per-image consistency weights with warmup # currently sigmoid scheduling
            # warmup_factor ramps from exp(-5) (~0.007) at step 0 up to 1.0 at
            # config.warmup_steps and stays at 1.0 afterwards.
            if global_step > config.warmup_steps:
                warmup_factor = 1.0
            else:
                phase = 1.0 - global_step / config.warmup_steps # p(x) = 1 - x/30000
                warmup_factor = np.exp(-5.0 * phase * phase) # f(x) = e^(-5*p(x)*p(x))
            # warmup_factor = min(1.0, global_step / config.warmup_steps) # linear
            base_weight = config.consistency_weight * warmup_factor
            # higher consistency weights for unlabeled data. lower for labeled. 
            # Label 1 is treated as "unlabeled": 2x base weight; labels 0/2: 0.5x.
            consistency_weights = torch.zeros_like(consistency_losses)
            unlabeled_mask = (labels == 1)
            labeled_mask = ~unlabeled_mask
            consistency_weights[unlabeled_mask] = base_weight * 2.0
            consistency_weights[labeled_mask] = base_weight / 2.0
            
            # Apply per-image weights and compute mean
            weighted_consistency = (consistency_losses * consistency_weights).mean()
            loss = cls_loss + weighted_consistency

            # if not torch.isfinite(loss):
            #     print(f"[WARN] Non-finite loss at step {global_step}, skipping batch.")
            #     print("student_outputs finite:", torch.isfinite(student_outputs).all().item())
            #     print("teacher_outputs finite:", torch.isfinite(teacher_outputs).all().item())
            #     print(metadata)
            #     continue
        
        # 3. Backward pass
        # Gradients are unscaled before clipping so that max_norm applies to
        # the real gradient magnitudes rather than the loss-scaled ones.
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)
        scaler.update()

        # (Disabled) full-precision variant of steps 2-3, kept for reference:
        # # forward pass
        # student_outputs = model(clip_strong, yolo_strong)
        # with torch.no_grad():
        #     teacher_outputs = teacher_model(clip_weak, yolo_weak)
        # cls_loss = criterion(student_outputs, labels)
        # consistency_losses = compute_consistency_loss(student_outputs, teacher_outputs)

        # if global_step > config.warmup_steps:
        #     warmup_factor = 1.0
        # else:
        #     current = max(0.0, min(float(global_step), config.warmup_steps))
        #     phase = 1.0 - current / config.warmup_steps
        #     warmup_factor = np.exp(-5.0 * phase * phase)

        # base_weight = config.consistency_weight * warmup_factor
        # consistency_weights = torch.full_like(consistency_losses, base_weight)
        # weighted_consistency = (consistency_losses * consistency_weights).mean()
        # loss = cls_loss + weighted_consistency

        # # --- Optional: check for NaNs ---
        # if not torch.isfinite(loss):
        #     print(f"[WARN] Non-finite loss at step {global_step}, skipping batch.")
        #     print("student_outputs finite:", torch.isfinite(student_outputs).all().item())
        #     print("teacher_outputs finite:", torch.isfinite(teacher_outputs).all().item())
        #     print(metadata)
        #     continue

        # # ======================================
        # # 3. Backward pass
        # # ======================================
        # optimizer.zero_grad(set_to_none=True)
        # loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        # optimizer.step()


        
        # 4. Update teacher model with EMA
        # if batch_num % 3 == 0:
            # update_ema_variables(model, teacher_model, alpha=config.ema_decay**3, global_step=global_step)
        update_ema_variables(model, teacher_model, alpha=config.ema_decay, global_step=global_step)
        
        # Everything below (collapse monitor, metrics, logging AND the
        # intermittent-validation check) only runs on steps that are multiples
        # of config.log_interval; all other steps stop here.
        if global_step % config.log_interval != 0:
            global_step += 1
            continue
        
        # monitor mode collapse
        # The buffer only receives logged batches, so it spans up to
        # max_recent_batches * config.log_interval training steps.
        with torch.no_grad():
            current_predictions = F.softmax(student_outputs, dim=1).argmax(dim=1).cpu().numpy()
        recent_predictions.append(current_predictions)
        if len(recent_predictions) > max_recent_batches:
            recent_predictions.pop(0)
        all_recent_preds = np.concatenate(recent_predictions)
        unique, counts = np.unique(all_recent_preds, return_counts=True)
        dominant_class = unique[np.argmax(counts)]
        dominant_ratio = counts.max() / len(all_recent_preds)
        if collapse_flag==False and dominant_ratio > 0.95:  # If >95% of predictions are same class
            print(f"[WARN] Possible mode collapse at step {global_step}!")
            print(f"       Class {dominant_class} represents {dominant_ratio:.1%} of last {len(all_recent_preds)} predictions")
            print(f"       Distribution: {dict(zip(unique, counts))}")
            collapse_flag = True

        # 5. Calculate metrics #############################################
        batch_time = time.time() - batch_start
        # Uses the configured batch size, so a smaller final batch is over-counted.
        images_per_second = config.batch_size / batch_time
        
        # Move to CPU for storage
        loss_cpu = loss.item()
        cls_loss_cpu = cls_loss.item()
        weighted_consistency_cpu = weighted_consistency.item()
        
        loss_history.append(loss_cpu)
        cls_loss_history.append(cls_loss_cpu)

        # Add batch-level data
        # 'consistency_weight' is the weight of the FIRST sample only, so it is
        # 2x or 0.5x base_weight depending on that sample's label.
        batch_data.append({
            'batch_num': batch_num,
            'epoch': epoch,
            'global_step': global_step,
            'loss': loss_cpu,
            'cls_loss': cls_loss_cpu,
            'consistency_loss': weighted_consistency_cpu,
            'consistency_weight': consistency_weights[0].item(), # TODO: remove after sanity check
            # 'batch_time': batch_time,
            # 'images_per_second': images_per_second,
        })
        
        # Optionally track per-image data
        # if track_images:
        #     probs = F.softmax(student_outputs, dim=1).cpu().numpy()
        #     predictions = student_outputs.argmax(dim=1).cpu().numpy()
        #     labels_cpu = labels.cpu().numpy()
        #     consistency_losses_cpu = consistency_losses.cpu().numpy()
            
        #     for i in range(len(labels)):
        #         image_data.append({
        #             'batch_num': batch_num,
        #             'global_step': global_step,
        #             'img_path': metadata['img_path'][i],
        #             'label': int(labels_cpu[i]),
        #             'width': int(metadata['width'][i]),
        #             'height': int(metadata['height'][i]),
        #             'size_kb': float(metadata['size_kb'][i]),
        #             'source': metadata['source'][i],
        #             'prediction': int(predictions[i]),
        #             'bad': float(probs[i, 0]),
        #             'neutral': float(probs[i, 1]),
        #             'good': float(probs[i, 2]),
        #             'consistency_loss': float(consistency_losses_cpu[i]),
        #         })
        
        # Update progress bar
        epoch_pbar.set_postfix({
            'loss': f'{loss_cpu:.4f}',
            'cls': f'{cls_loss_cpu:.4f}',
            'cons': f'{weighted_consistency_cpu:.4f}',
            'img/s': f'{images_per_second:.1f}'
        })
        
        # Log to tensorboard/CSV at intervals
        # (This condition is always true here: non-logging steps already
        # `continue`d above.)
        if global_step % config.log_interval == 0:
            current_lr = optimizer.param_groups[-1]['lr'] # should correspond to my FC classifier layers
            
            train_metrics = {
                "train/loss": loss_cpu,
                "train/cls_loss": cls_loss_cpu,
                "train/consistency": weighted_consistency_cpu,
                "train/consistency_weight": consistency_weights[0].item(),
                "train/learning_rate": current_lr,
                # "train/images_per_second": images_per_second,
                "system/gpu_memory_mb": get_gpu_memory_usage()
            }
            
            logger.log_metrics(train_metrics, global_step)
            logger.log_train_step(global_step, epoch, {
                'loss': loss_cpu,
                'cls_loss': cls_loss_cpu,
                'consistency_loss': weighted_consistency_cpu,
                'learning_rate': current_lr,
                'consistency_weight': consistency_weights[0].item()
            })

        # 6. Intermittent validation
        # Evaluates the TEACHER. Both this branch and the end-of-epoch block
        # call scheduler.step(val_accuracy).
        if global_step > 0 and global_step % validation_frequency == 0:
            intermittent_epoch = "step" + str(global_step)
            print("Running intermittent validation...")
            val_image_df, val_batch_df, val_loss = validate(teacher_model, val_dataloader, config, intermittent_epoch)
            val_accuracy, val_metrics, cm = analyse(val_image_df, val_batch_df, config, monitor, intermittent_epoch)
            scheduler.step(val_accuracy)
            print('scheduler lr:', scheduler.get_last_lr())

            # logger.log_metrics(val_metrics, epoch)
            logger.log_validation(intermittent_epoch, val_metrics)
            logger.log_confusion_matrix(cm, ['bad', 'neutral', 'good'], intermittent_epoch)
            
            print(f"  Validation Accuracy: {val_accuracy:.4f}")
            print(f"  Best Accuracy: {monitor.best_accuracy:.4f}")

            # analyse() has already raised monitor.best_accuracy when this run
            # improved on it, so this is true exactly when the current accuracy
            # sets or ties the best value.
            if val_accuracy >= monitor.best_accuracy:    
                save_checkpoint(
                    model, teacher_model, optimizer, scheduler, epoch, global_step, 
                    config, val_accuracy, monitor, loss_history, cls_loss_history
                )
            
            plot_running_loss(loss_history, os.path.join(logger.get_log_dir(), f'loss_graph.png'))
            plot_running_loss(cls_loss_history, os.path.join(logger.get_log_dir(), f'cls_loss_graph.png'))
        global_step += 1

    # Create dataframes
    # image_df = pd.DataFrame(image_data) if track_images else None
    batch_df = pd.DataFrame(batch_data)

    train_image_df = None # image_df
    train_batch_df = batch_df

    # endregion #####################################################
    
    # Analyze epoch
    # NOTE(review): epoch-level metrics are logged with `epoch` as the step,
    # whereas the per-step train metrics above use global_step.
    epoch_metrics = analyse_epoch(train_image_df, train_batch_df, config, epoch)
    logger.log_metrics(epoch_metrics, epoch)
    
    print(f"Epoch {epoch} Summary:")
    print(f"  Time: {(time.time() - epoch_start):.1f}s")
    print(f"  Avg Loss: {epoch_metrics['epoch/loss']:.4f}")
    print(f"  Avg Classification Loss: {epoch_metrics['epoch/cls_loss']:.4f}")
    print(f"  Avg Consistency Loss: {epoch_metrics['epoch/consistency_loss']:.4f}")
    # print(f"  Throughput: {epoch_metrics['epoch/avg_throughput']:.1f} images/second")

    print("="*80)
    
    # Validation
    print("Running end of epoch validation...")
    val_image_df, val_batch_df, val_loss = validate(teacher_model, val_dataloader, config, epoch)
    val_accuracy, val_metrics, cm = analyse(val_image_df, val_batch_df, config, monitor, epoch)
    
    logger.log_metrics(val_metrics, epoch)
    logger.log_validation(epoch, val_metrics)
    logger.log_confusion_matrix(cm, ['bad', 'neutral', 'good'], epoch)
    
    print(f"  Validation Accuracy: {val_accuracy:.4f}")
    print(f"  Best Accuracy: {monitor.best_accuracy:.4f}")
    print("="*80)
    print()
    print("="*80)
    
    # End-of-epoch checkpoint is written unconditionally (not only on improvement).
    save_checkpoint(
        model, teacher_model, optimizer, scheduler, epoch, global_step, 
        config, val_accuracy, monitor, loss_history, cls_loss_history
    )
    
    scheduler.step(val_accuracy)
    print('scheduler lr:', scheduler.get_last_lr())

# Final loss curves. The consistency term is recovered as total minus
# classification loss (loss = cls_loss + weighted_consistency in the loop).
for _history, _graph_file in (
    (loss_history, 'loss_graph.png'),
    (cls_loss_history, 'cls_loss_graph.png'),
):
    plot_running_loss(_history, os.path.join(logger.get_log_dir(), _graph_file))

cons_loss_history = [
    total_loss_value - cls_loss_value
    for total_loss_value, cls_loss_value in zip(loss_history, cls_loss_history)
]
plot_running_loss(cons_loss_history, os.path.join(logger.get_log_dir(), 'cons_loss_graph.png'))
logger.close()

Upsampling labeled data: 512053 samples × 3 = 1536159
Unlabeled data: 1111234 samples
Final dataset size: 2647393 samples
Labeled ratio: 58.03%
0.66 42483


Epoch 0:   1%|          | 500/42483 [04:59<6:02:07,  1.93it/s, loss=0.7852, cls=0.7849, cons=0.0003, img/s=91.3]  

unfreezing backbone at step 500


Epoch 0:  10%|▉         | 4055/42483 [44:42<7:14:13,  1.47it/s, loss=0.5731, cls=0.5723, cons=0.0008, img/s=67.3]

Running intermittent validation...




Validation Summary - Epoch step4055
Accuracy: 0.3875 | Loss: 0.6555
scheduler lr: [1.0000000000000002e-06, 5e-06, 1e-05, 5e-05]
  Validation Accuracy: 0.3875
  Best Accuracy: 0.3875


Epoch 0:  10%|▉         | 4056/42483 [48:45<785:07:56, 73.55s/it, loss=0.5731, cls=0.5723, cons=0.0008, img/s=67.3]

Checkpoint saved successfully: ./checkpoints/resnet_clip_yolo_mean_teacher_20251029_112330/checkpoint_epoch0_step4055_acc0.3875_20251029_121230.pth
Epoch: 0 | Step: 4055 | Val Acc: 0.3875
  Loss graph saved to: ./logs/resnet_clip_yolo_mean_teacher_20251029_112330/loss_graph.png
  Loss graph saved to: ./logs/resnet_clip_yolo_mean_teacher_20251029_112330/cls_loss_graph.png


Epoch 0:  19%|█▉        | 8110/42483 [1:33:53<6:22:24,  1.50it/s, loss=0.8665, cls=0.7257, cons=0.1407, img/s=68.3]

Running intermittent validation...




Validation Summary - Epoch step8110
Accuracy: 0.3863 | Loss: 3.0379
scheduler lr: [1.0000000000000002e-06, 5e-06, 1e-05, 5e-05]
  Validation Accuracy: 0.3863
  Best Accuracy: 0.3875
  Loss graph saved to: ./logs/resnet_clip_yolo_mean_teacher_20251029_112330/loss_graph.png


Epoch 0:  19%|█▉        | 8111/42483 [1:37:55<700:12:51, 73.34s/it, loss=0.8665, cls=0.7257, cons=0.1407, img/s=68.3]

  Loss graph saved to: ./logs/resnet_clip_yolo_mean_teacher_20251029_112330/cls_loss_graph.png


Epoch 0:  29%|██▊       | 12165/42483 [2:23:32<5:40:55,  1.48it/s, loss=0.8743, cls=0.7121, cons=0.1621, img/s=68.1] 

Running intermittent validation...




Validation Summary - Epoch step12165
Accuracy: 0.4257 | Loss: 4.1521
scheduler lr: [1.0000000000000002e-06, 5e-06, 1e-05, 5e-05]
  Validation Accuracy: 0.4257
  Best Accuracy: 0.4257


Epoch 0:  29%|██▊       | 12166/42483 [2:27:39<628:31:00, 74.63s/it, loss=0.8743, cls=0.7121, cons=0.1621, img/s=68.1]

Checkpoint saved successfully: ./checkpoints/resnet_clip_yolo_mean_teacher_20251029_112330/checkpoint_epoch0_step12165_acc0.4257_20251029_135124.pth
Epoch: 0 | Step: 12165 | Val Acc: 0.4257
  Loss graph saved to: ./logs/resnet_clip_yolo_mean_teacher_20251029_112330/loss_graph.png
  Loss graph saved to: ./logs/resnet_clip_yolo_mean_teacher_20251029_112330/cls_loss_graph.png


Epoch 0:  33%|███▎      | 14176/42483 [2:50:11<5:27:35,  1.44it/s, loss=1.2942, cls=0.7701, cons=0.5241, img/s=59.7]  

### the lands down under

In [None]:
# Ad-hoc validation of the teacher (epoch label -1).
# validate() returns (image_df, batch_df, total_loss); unpacking only two
# names raises "ValueError: too many values to unpack".
# NOTE: analyse() also appends this result to monitor.accuracy_history.
val_image_df, val_batch_df, val_loss = validate(teacher_model, val_dataloader, config, -1)
val_accuracy, val_metrics, cm = analyse(val_image_df, val_batch_df, config, monitor, -1)
cm

                                                                                     

In [None]:
# Severe confusions on the validation set: label 2 (good) predicted as 0 (bad),
# and label 0 (bad) predicted as 2 (good).
good_as_bad = val_image_df.loc[val_image_df['label'].eq(2) & val_image_df['prediction'].eq(0)]
bad_as_good = val_image_df.loc[val_image_df['label'].eq(0) & val_image_df['prediction'].eq(2)]

Validation Summary - Epoch -1
Accuracy: 0.4842 | Loss: 1.1626


array([[5280, 3550, 2577],
       [7105, 9833, 6210],
       [ 982, 1276, 2591]])

In [None]:
# List the path of every bad-labelled image that was predicted good, one per line.
for bad_as_good_path in bad_as_good["img_path"]:
    print(bad_as_good_path)

In [None]:
from collections import Counter
import numpy as np

# Sanity check on batch composition: count labels 0/1/2 in each of the first
# `num_batches` batches drawn from supervised_train_dataloader.
num_batches = 100
hist_over_batches = []
for batches_seen, (_, _, _, _, labels, _) in enumerate(supervised_train_dataloader, start=1):
    label_counts = Counter(labels.numpy().tolist())
    # Counter returns 0 for labels that do not occur in the batch.
    hist_over_batches.append([label_counts[class_id] for class_id in (0, 1, 2)])
    if batches_seen >= num_batches:
        break

hist = np.array(hist_over_batches)
# A batch holds a single class when exactly two of its three counts are zero.
single_class_batches = (hist == 0).sum(axis=1) == 2
print("Per-batch means (bad,neut,good):", hist.mean(axis=0))
print("Fraction of batches with a single class:", np.mean(single_class_batches))


Per-batch means (bad,neut,good): [     9.0645      11.903      2.8387]
Fraction of batches with a single class: 0.0


In [None]:

# Snapshot of system-wide RAM usage (total / available / used ...), shown as the cell output.
psutil.virtual_memory()

svmem(total=46286409728, available=40099004416, percent=13.4, used=5640769536, free=37585137664, active=1640988672, inactive=6467284992, buffers=300068864, cached=2760433664, shared=89407488, slab=213061632)

In [None]:
# NOTE(review): `ab` is not defined, so this cell raises NameError (see the
# recorded output). Presumably a deliberate tripwire that stops "Run All"
# before the scratch cells below - confirm, and if so replace it with an
# explicit, self-describing stop.
ab

NameError: name 'ab' is not defined

Validate more frequently: my epochs are obscenely long.
Can I cut out the ResNet for now?

I should start with 1 epoch on labeled data only,
with the backbone mostly frozen,

and then let the consistency loss start creeping up after that first epoch.

# STUFF

In [None]:
# ignore below

In [None]:
from types import SimpleNamespace

# Held-out test evaluation of the teacher model.
test_dataset = MeanTeacherDataset(
    csv_file = "test_2.csv", 
    root_dir = "~/Workspace/data-v2/test",
    val = True,
)
test_dataloader = DataLoader(
    test_dataset, 
    batch_size=config.batch_size, 
    collate_fn=custom_collate_fn,
    num_workers=get_num_workers(),
    persistent_workers=False, # True if get_num_workers() > 0 else False,
    # pin_memory=False, # WSL does not support pin_memory well
    prefetch_factor=3 if get_num_workers() > 0 else None,
)
# Bind the summed loss to its own name instead of overwriting the training
# loop's `loss` variable.
test_image_df, test_batch_df, test_loss = validate(teacher_model, test_dataloader, config, -1)

# analyse() appends to monitor.accuracy_history and may overwrite
# monitor.best_accuracy / epochs_without_improvement. Passing the real
# validation monitor would leak test-set results into the state that drives
# checkpoint selection and gets saved with checkpoints, so a throwaway
# stand-in with the same three attributes is used here.
test_monitor = SimpleNamespace(
    accuracy_history=[],
    best_accuracy=float("-inf"),
    epochs_without_improvement=0,
)
test_accuracy, test_metrics, cm = analyse(test_image_df, test_batch_df, config, test_monitor, -1)
cm

                                                                                     

Validation Summary - Epoch -1
Accuracy: 0.3433 | Loss: 0.6820


array([[ 3924,  7204,   279],
       [ 3703, 18692,   753],
       [  388,  2805,  1656]])

In [None]:
# Manual, unconditional checkpoint of the current in-memory training state.
# NOTE(review): epoch / global_step / val_accuracy are whatever the last
# executed cells left behind - confirm they describe the same model state
# before relying on this file.
save_checkpoint(
        model, teacher_model, optimizer, scheduler, epoch, global_step, 
        config, val_accuracy, monitor, loss_history, cls_loss_history
    )

Checkpoint saved successfully: ./checkpoints/resnet_clip_yolo_mean_teacher_20251027_112613/checkpoint_epoch2_step54258_acc0.3441_20251027_221157.pth
Epoch: 2 | Step: 54258 | Val Acc: 0.3441


'./checkpoints/resnet_clip_yolo_mean_teacher_20251027_112613/checkpoint_epoch2_step54258_acc0.3441_20251027_221157.pth'

In [None]:

def benchmark(num_workers, num_batches=100,
              csv_file="train_2.csv", root_dir="/mnt/d/data-v2/train"):
    """Time one ``train_one_epoch`` pass for a given DataLoader worker count.

    A throw-away student/teacher pair and optimizer are built so the notebook's
    real training objects are not touched. The module-level ``dry_run`` is
    temporarily set to ``num_batches * config.batch_size`` and always restored,
    even if training raises.

    Args:
        num_workers: Number of DataLoader worker processes to benchmark.
        num_batches: Number of batches the run is sized for (used only to
            derive the temporary ``dry_run`` value).
        csv_file: CSV passed to ``MeanTeacherDataset``.
        root_dir: Image root directory passed to ``MeanTeacherDataset``.

    Returns:
        Wall-clock seconds spent inside ``train_one_epoch`` (model and
        dataloader construction are excluded).

    Relies on the notebook globals ``config``, ``criterion``, ``logger`` and
    ``dry_run``.
    """
    global dry_run

    # Local import: `copy` is not among the notebook's visible top-level imports.
    import copy

    # NOTE(review): presumably `dry_run` is read by MeanTeacherDataset to cap the
    # number of samples -- confirm it reads this notebook-level global.
    original_dry_run = dry_run
    dry_run = num_batches * config.batch_size

    try:
        bench_dataset = MeanTeacherDataset(
            csv_file=csv_file,
            root_dir=root_dir,
            val=False,
        )
        bench_sampler = RandomVersionSampler(bench_dataset)
        bench_dataloader = DataLoader(
            bench_dataset,
            batch_size=config.batch_size,
            sampler=bench_sampler,
            num_workers=num_workers,
            # persistent_workers is only valid when worker processes exist.
            persistent_workers=num_workers > 0,
        )

        # Fresh model/teacher/optimizer for every run so results are comparable.
        temp_model = BiggerClassifier().to(config.device)
        temp_teacher = copy.deepcopy(temp_model).to(config.device)
        temp_teacher.eval()
        for p in temp_teacher.parameters():
            p.requires_grad = False

        temp_optimizer = torch.optim.AdamW([
            {'params': temp_model.parameters(), 'lr': config.initial_lr}
        ])

        print(f"Benchmarking with num_workers={num_workers}...")
        epoch = 0
        start_time = time.perf_counter()  # monotonic clock, suited to interval timing

        train_one_epoch(
            temp_model, temp_teacher, bench_dataloader, temp_optimizer, criterion,
            config, epoch, logger, 0,
            track_images=False,  # per-image tracking is not needed for timing
        )

        # CUDA kernels run asynchronously; wait for queued work so the measured
        # interval covers the GPU work that was actually launched.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        total_time = time.perf_counter() - start_time
    finally:
        # Always restore the global, even when dataset creation or training fails.
        dry_run = original_dry_run
        # Drop references before emptying the CUDA cache so memory can be released.
        temp_model = temp_teacher = temp_optimizer = None
        bench_dataloader = bench_sampler = bench_dataset = None
        torch.cuda.empty_cache()

    return total_time



In [None]:
# Sweep num_workers over 0..15, timing one benchmark run per setting, then
# print a summary table and the setting with the highest image throughput.
bench_batches = 100
banner = "=" * 60

print("\n" + banner)
print("BENCHMARKING NUM_WORKERS")
print(banner)

results = []
for num_workers in range(16):
    elapsed_time = benchmark(num_workers, num_batches=bench_batches)

    # Throughput assumes the run processed exactly `bench_batches` batches.
    batches_per_second = bench_batches / elapsed_time
    images_per_second = bench_batches * config.batch_size / elapsed_time

    results.append(dict(
        num_workers=num_workers,
        total_time=elapsed_time,
        batches_per_second=batches_per_second,
        images_per_second=images_per_second,
    ))

    print(f"\nnum_workers={num_workers:2d} | "
            f"Time: {elapsed_time:6.2f}s | "
            f"Batches/s: {batches_per_second:5.2f} | "
            f"Images/s: {images_per_second:6.1f}")

# Summary table of every run.
print("\n" + banner)
print("BENCHMARK SUMMARY")
print(banner)
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False))

# Report the fastest configuration (first one wins on ties).
if results:
    best_result = max(results, key=lambda row: row['images_per_second'])
    print(f"\nOptimal num_workers: {best_result['num_workers']} "
          f"({best_result['images_per_second']:.1f} images/s)")


In [None]:
def profile_training_step(model, teacher_model, dataloader, optimizer, criterion, config):
    """Profile up to 50 mean-teacher training steps with ``torch.profiler``.

    The profiler schedule skips 40 steps, warms up for 2 and records 3; the
    trace is written to ``./logs/profiler`` for TensorBoard.

    Side effects: this runs real training steps. ``model`` is put in train
    mode and updated through ``optimizer.step()``, and ``teacher_model`` is
    updated through ``update_ema_variables``.

    Args:
        model: Student network, called as ``model(clip_strong, yolo_strong)``.
        teacher_model: Teacher network, called under ``torch.no_grad()``.
        dataloader: Yields ``(clip_weak, clip_strong, yolo_weak, yolo_strong,
            labels, metadata)`` batches.
        optimizer: Optimizer stepping the student parameters.
        criterion: Classification loss applied to ``(student_outputs, labels)``.
        config: Provides ``device``, ``warmup_steps``, ``consistency_weight``
            and ``ema_decay``.
    """
    # Use the configured device rather than a notebook-global `device`.
    device = config.device

    model.train()
    # Clear gradients left over from earlier training so they are not
    # accumulated into the first profiled step.
    optimizer.zero_grad()

    print("Starting profiling...")
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=40, warmup=2, active=3, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs/profiler'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        for i, batch in enumerate(dataloader):
            if i >= 50:
                break

            clip_weak, clip_strong, yolo_weak, yolo_strong, labels, metadata = batch
            # Move inputs to the device; the CLIP dicts are updated in place.
            clip_weak['pixel_values'] = clip_weak['pixel_values'].to(device, non_blocking=True)
            clip_strong['pixel_values'] = clip_strong['pixel_values'].to(device, non_blocking=True)
            yolo_weak = yolo_weak.to(device, non_blocking=True)
            yolo_strong = yolo_strong.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # AMP (autocast / GradScaler) is currently disabled in this profile run.
            student_outputs = model(clip_strong, yolo_strong)
            cls_loss = criterion(student_outputs, labels)

            with torch.no_grad():
                teacher_outputs = teacher_model(clip_weak, yolo_weak)

            consistency_losses = compute_consistency_loss(student_outputs, teacher_outputs)
            # Linear warm-up of the consistency weight; with no warm-up configured
            # (warmup_steps <= 0) use the full weight instead of dividing by zero.
            if config.warmup_steps > 0:
                warmup_factor = min(1.0, i / config.warmup_steps)
            else:
                warmup_factor = 1.0
            base_weight = config.consistency_weight * warmup_factor
            consistency_weights = torch.full_like(consistency_losses, base_weight)
            weighted_consistency = (consistency_losses * consistency_weights).mean()

            loss = cls_loss + weighted_consistency
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            optimizer.step()
            optimizer.zero_grad()

            # NOTE(review): the loop index is passed as `global_step`; if
            # update_ema_variables ramps its decay by step, restarting at 0 may
            # pull the teacher strongly towards the student -- confirm.
            update_ema_variables(model, teacher_model, alpha=config.ema_decay, global_step=i)

            # Advance the profiler schedule by one step.
            prof.step()

    print("Profiling complete! Check ./logs/profiler for results")
    print("View with: tensorboard --logdir=./logs/profiler")

# Usage: profile the notebook's current training objects. Note that
# profile_training_step calls optimizer.step() and update_ema_variables, so this
# modifies `model` and `teacher_model`.
profile_training_step(model, teacher_model, train_dataloader, optimizer, criterion, config)

Starting profiling...
Profiling complete! Check ./logs/profiler for results
View with: tensorboard --logdir=./logs/profiler


environment: wsl


4

# NOTES

remember to collect the original file name without suffix when making your csv

If my teacher model strongly misclassifies an image, write its file name to a log.

`compute_consistency_loss` should evolve over epochs: later on, use a higher confidence threshold, and maybe move to sigmoid/binary thresholding.
