In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torchvision.transforms.v2 as transforms
from torchvision.transforms.v2 import functional as v2F
from torch.utils.data import Dataset, DataLoader, Sampler

from ultralytics import YOLO
from transformers import CLIPVisionModel, CLIPImageProcessor

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pandas as pd
import random
import platform
import gc
import psutil
import time

# REMOVED: from mean_teacher import *
from custom_logging import *

# Prefer the GPU when PyTorch can see one; everything else runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

def get_system_type():
    """Classify the host OS as "wsl", "linux", "windows" or "other".

    WSL reports itself as Linux, so it is told apart by looking for "microsoft"
    or "wsl" in the kernel release string.
    """
    system = platform.system()
    if system == "Windows":
        return "windows"
    if system != "Linux":
        return "other"
    kernel_release = platform.uname().release.lower()
    if "microsoft" in kernel_release or "wsl" in kernel_release:
        return "wsl"
    return "linux"

def get_num_workers():
    """Return the DataLoader worker count to use on this platform.

    Linux gets 8 workers and WSL 4. Windows and any unrecognised system get 0,
    so data is loaded in the main process.
    """
    workers_by_system = {"linux": 8, "wsl": 4, "windows": 0}
    return workers_by_system.get(get_system_type(), 0)

# Display the detected platform and the DataLoader worker count derived from it.
get_system_type(), get_num_workers()

def get_gpu_memory_usage():
    """Return tensor memory currently allocated by PyTorch on the current CUDA device, in MB.

    Returns 0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1024 * 1024)  # MB

def print_memory_usage():
    """Print the resident set size (RSS) of this process and of each direct child, in GB.

    Children are typically DataLoader worker processes. A child that exits or
    becomes inaccessible while being inspected is skipped.
    """
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    print(f"Main process RSS: {mem_info.rss / 1024**3:.2f} GB")

    # Check worker processes
    for i, child in enumerate(process.children()):
        try:
            child_mem = child.memory_info()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # A worker can exit between children() and memory_info(); skip it. The
            # previous bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            continue
        print(f"Worker {i} RSS: {child_mem.rss / 1024**3:.2f} GB")

def get_system_memory_usage():
    """Return system-wide RAM usage as a percentage (psutil ``virtual_memory().percent``)."""
    memory_stats = psutil.virtual_memory()
    return memory_stats.percent

# CHANGED: Removed "mean_teacher" from run name
# Timestamped run name; it also names the checkpoint directory (Config.checkpoint_path).
# NOTE(review): `datetime` and `Logger` are not imported explicitly in this file —
# presumably they arrive via `from custom_logging import *`; confirm.
run_name = f"sanitycheck_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

logger = Logger(log_dir="./logs", experiment_name=run_name)

class Config:
    """Run-wide hyperparameters and paths, stored as class attributes.

    Because these are class attributes, ``vars(config)`` on an instance only lists
    attributes that were later re-assigned on that instance.
    Creating the class also creates ``checkpoint_path`` on disk.
    """
    device = device  # the module-level torch.device selected earlier in the notebook
    use_amp = True  # not read in the visible code; presumably consumed by the training loop
    batch_size = 40

    num_classes = 3

    initial_lr = 5e-5  # learning rate of the classifier-head parameter group
    lr_backbone = 1e-5  # base backbone LR; the optimizer scales it x0.1 (CLIP), x0.5 (YOLO), x1 (ResNet)
    # REMOVED: consistency_weight (not needed without mean teacher)
    # REMOVED: ema_decay (not needed without mean teacher)
    # REMOVED: warmup_steps (not needed without mean teacher)

    cur_epoch = 0
    num_epochs = 4
    freeze_until_epoch = 0  # not read in the visible code — TODO confirm where backbones get unfrozen

    checkpoint_path = f"./checkpoints/{run_name}"
    log_interval = 5  # batches between logged loss points (loss-plot x scale, validation alignment)
    
    # Create directories
    os.makedirs(checkpoint_path, exist_ok=True)

config = Config()

logger.log_config(config)

# Set to a row count to build every SingleModelDataset from a fixed random subset
# (quick smoke runs); None uses the full CSVs.
dry_run = None
# dry_run = config.batch_size * 100

class YOLOv11(torch.nn.Module):
    """Feature extractor built from the first ten layers of an ultralytics YOLO model.

    ``self.model`` keeps the full underlying network; ``self.feature_model`` wraps
    its first ten child layers (the same module objects, not copies) in a Sequential.
    The weights file is presumably a face-detection checkpoint, judging by its name.
    """

    def __init__(self):
        super().__init__()
        self.model = YOLO("yolov11l-face.pt").model
        leading_layers = list(self.model.model.children())[:10]
        self.feature_model = torch.nn.Sequential(*leading_layers)

    def forward(self, x):
        feature_map = self.feature_model(x)
        return feature_map

# Shared Hugging Face preprocessor for the CLIP branch. The dataset calls it with
# do_rescale=False because its inputs are already float tensors clamped to [0, 1].
CLIP_PROCESSOR = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

class CLIP(torch.nn.Module):
    """CLIP vision encoder that returns the pooled image embedding.

    ``forward`` receives the processor output mapping (e.g. ``{'pixel_values': ...}``)
    and unpacks it as keyword arguments into the Hugging Face vision model.
    """

    def __init__(self):
        super().__init__()
        self.clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        # Width of the pooled output; not read inside this class.
        self.clip_output_dim = self.clip_model.config.hidden_size

    def forward(self, x):
        return self.clip_model(**x).pooler_output

class ResNet(torch.nn.Module):
    """Pretrained torchvision ResNet-50 with its last three child modules removed.

    With torchvision's child order, that drops layer4, avgpool and fc, so the
    extractor ends after layer3 and returns a spatial feature map.
    """

    def __init__(self):
        super().__init__()
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        backbone_children = list(self.resnet.children())
        self.feature_extractor = nn.Sequential(*backbone_children[:-3])

    def forward(self, x):
        return self.feature_extractor(x)

class BiggerClassifier(torch.nn.Module):
    """Classify images from concatenated CLIP and pooled YOLO features via a 3-layer MLP.

    Args:
        output_dim: number of output logits (classes).

    ``forward(clip_inputs, img_tensor)`` takes the CLIP processor mapping and the
    YOLO image batch, and returns logits of shape ``(batch, output_dim)``.
    """

    def __init__(self, output_dim=3):
        super().__init__()
        self.clip = CLIP()
        self.yolo = YOLOv11()
        # NOTE(review): the ResNet branch is built (its weights are in state_dict and in
        # the optimizer's 'resnet' group) but forward() never uses it. It is kept so
        # existing checkpoints still load with strict=True.
        self.resnet = ResNet()

        self.yolo_pool = torch.nn.AdaptiveAvgPool2d((1, 1))

        # Input width = CLIP pooled embedding + pooled YOLO channels (presumably
        # 768 + 512 for the configured checkpoints — TODO confirm if weights change).
        self.fc1 = torch.nn.Linear(768 + 512, 2048)
        self.activation1 = torch.nn.GELU()
        self.dropout1 = torch.nn.Dropout(0.3)
        self.fc2 = torch.nn.Linear(2048, 1024)
        self.activation2 = torch.nn.GELU()
        self.dropout2 = torch.nn.Dropout(0.3)
        self.fc3 = torch.nn.Linear(1024, output_dim)

    def forward(self, clip_inputs, img_tensor):
        clip_features = self.clip(clip_inputs)
        # Global-average-pool the YOLO feature map to one vector per image.
        yolo_features = self.yolo_pool(self.yolo(img_tensor)).flatten(1)

        combined_features = torch.cat([clip_features, yolo_features], dim=1)

        # Dropout layers were constructed but previously never applied; they are
        # active only in train() mode and are no-ops in eval().
        x = self.dropout1(self.activation1(self.fc1(combined_features)))
        x = self.dropout2(self.activation2(self.fc2(x)))
        return self.fc3(x)


# PIL image -> tv_tensors Image with float32 values scaled to [0, 1].
to_tensor = transforms.Compose([
    transforms.ToImage(), 
    transforms.ToDtype(torch.float32, scale=True),
])

yolo_intermediate_input_size = 700
yolo_final_input_size = 640

# Shared geometric augmentation for training. Resize to 700, rotate +-15 degrees
# (expand=True grows the canvas, with zero fill in the new corners), take a random
# 640x640 crop, then a random horizontal flip. Both YOLO and CLIP views derive from it.
base_transform = transforms.Compose([
    to_tensor,
    transforms.Resize(size=yolo_intermediate_input_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
    transforms.RandomRotation(degrees=(-15, 15), interpolation=transforms.InterpolationMode.BILINEAR, expand=True, fill=0),
    transforms.RandomCrop(yolo_final_input_size),
    transforms.RandomHorizontalFlip(p=0.5),
])

# CHANGED: Only keeping weak augmentation (removed strong augmentation)
# Light photometric jitter plus Gaussian noise (std 0.005), clamped back to [0, 1].
# NOTE(review): lambdas are not picklable, which matters only for spawn-based
# DataLoader workers; get_num_workers() returns 0 on Windows.
yolo_weak_transform = transforms.Compose([    
    transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.02),
    transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.005), 
    transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
])
# REMOVED: yolo_strong_transform

# Downscales the 640x640 augmented crop to 224x224 for CLIP.
clip_base_transform = transforms.Compose([ 
    transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
])

# CHANGED: Only keeping weak augmentation (removed strong augmentation)
# Same photometric jitter + noise + clamp as the YOLO branch, applied independently.
clip_weak_transform = transforms.Compose([    
    transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.02),
    transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.005), 
    transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
])
# REMOVED: clip_strong_transform

# Deterministic validation views: resize + center crop, no rotation/flip/noise.
yolo_val_transform = transforms.Compose([
    to_tensor,
    transforms.Resize(size=yolo_intermediate_input_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
    transforms.CenterCrop(yolo_final_input_size), 
    transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
])

# NOTE(review): validation CLIP input is resized/cropped straight from the original
# image (256 -> 224), whereas training derives it from the 640 crop; confirm the
# framing mismatch is intended.
clip_val_transform = transforms.Compose([ 
    to_tensor,
    transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
    transforms.CenterCrop(224),
    transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
])


# Not referenced in the visible code.
bad_images = []

# CHANGED: Simplified dataset to only return single augmented version
class SingleModelDataset(Dataset):
    """Image dataset that yields a CLIP view, a YOLO view, the label and metadata.

    Rows come from ``csv_file`` inside ``root_dir``. ``__getitem__`` reads columns by
    position: 0 = image path relative to ``root_dir``, 1 = label, 2 = width,
    3 = height, 4 = size in KB, 5 = source. The ``'label'`` and ``'unique_id'``
    columns are also accessed by name.

    Args:
        csv_file: annotation CSV file name, relative to ``root_dir``.
        root_dir: dataset root directory; ``~`` is expanded.
        val: if True, use the deterministic validation transforms instead of augmentation.
        supervised: if True, keep every row labelled 0 or 2 and subsample rows
            labelled 1 so that 0/2 rows make up ``supervised_ratio`` of the result.
        supervised_ratio: target fraction of 0/2-labelled rows; must be in (0, 1].
        upsample: if set, repeat the 0/2-labelled rows this many times (each copy gets
            a distinct ``unique_id`` suffix) and shuffle the result.

    Raises:
        ValueError: if ``supervised`` is set and ``supervised_ratio`` is outside (0, 1].
    """

    def __init__(self, csv_file, root_dir, val=False, supervised=False, supervised_ratio=0.5, upsample=None):
        self.root_dir = os.path.expanduser(root_dir)
        self.annotations = pd.read_csv(os.path.join(self.root_dir, csv_file))
        if dry_run:
            # Quick-iteration mode: a fixed random subset of at most `dry_run` rows.
            self.annotations = self.annotations.sample(n=min(dry_run, len(self.annotations)), random_state=42)
            self.annotations = self.annotations.reset_index(drop=True)
        if supervised:
            if not 0 < supervised_ratio <= 1:
                # 0 would divide by zero below; > 1 would request a negative sample size.
                raise ValueError(f"supervised_ratio must be in (0, 1], got {supervised_ratio}")
            labeled = self.annotations[self.annotations['label'].isin([0, 2])]
            unlabeled = self.annotations[self.annotations['label'] == 1]
            n_labeled = len(labeled)
            total_len = int(n_labeled / supervised_ratio)
            target_len = total_len - n_labeled
            if len(unlabeled) > target_len:
                # Seeded like the dry-run subset and the upsample shuffle, so the chosen
                # subset is reproducible across runs.
                unlabeled = unlabeled.sample(n=target_len, random_state=42)
            else:
                print(f"not enough unlabeled. asked for {target_len}, only have {len(unlabeled)}")
            self.annotations = pd.concat([labeled, unlabeled], ignore_index=True)
        if upsample:
            labeled = self.annotations[self.annotations['label'].isin([0, 2])].copy()
            unlabeled = self.annotations[self.annotations['label'] == 1].copy()
            n_labeled = len(labeled)
            n_unlabeled = len(unlabeled)
            print(f"Upsampling labeled data: {n_labeled} samples × {upsample} = {n_labeled * upsample}")
            print(f"Unlabeled data: {n_unlabeled} samples")
            labeled_copies = []
            for i in range(upsample):
                labeled_copy = labeled.copy()
                # Distinct ids so RandomVersionSampler treats every copy as its own item.
                labeled_copy['unique_id'] = labeled_copy['unique_id'].astype(str) + f'_copy{i}'
                labeled_copies.append(labeled_copy)
            labeled_upsampled = pd.concat(labeled_copies, ignore_index=True)
            self.annotations = pd.concat([labeled_upsampled, unlabeled], ignore_index=True)
            self.annotations = self.annotations.sample(frac=1, random_state=42).reset_index(drop=True)
            print(f"Final dataset size: {len(self.annotations)} samples")
            print(f"Labeled ratio: {len(labeled_upsampled) / len(self.annotations):.2%}")

        # Per-item augmentation timings; filled only inside whichever process runs
        # __getitem__ (DataLoader workers keep their own copies).
        self.transform_times = []
        self.val = val

        self.error_log_path = f'dataset_errors_{"val" if val else "train"}.log'

    def __len__(self):
        return len(self.annotations)

    def get_id(self, idx):
        return self.annotations.at[idx, 'unique_id']

    def __getitem__(self, idx):
        """Return ``(clip_inputs, yolo_tensor, label, metadata)`` for row ``idx``.

        On any loading/transform error the error is appended to a log file under the
        logger's directory, and a black placeholder sample labelled 1 is returned.
        """
        try:
            # 1. process metadata
            original_label = self.annotations.iloc[idx, 1]
            label = torch.tensor(original_label, dtype=torch.long)

            img_path = os.path.join(self.root_dir, self.annotations.iloc[idx, 0])
            img_path = img_path.replace('\\', '/')
            # Context manager closes the file handle deterministically; convert() returns
            # an independent, fully loaded copy.
            with Image.open(img_path) as img:
                image_pil = img.convert('RGB')

            metadata = {
                'img_path': img_path,
                'label': original_label,
                'width': self.annotations.iloc[idx, 2],
                'height': self.annotations.iloc[idx, 3],
                'size_kb': self.annotations.iloc[idx, 4],
                'source': str(self.annotations.iloc[idx, 5]),
            }

            # 2. process images
            if not self.val:
                item_transform_start = time.time()

                # One shared geometric augmentation, then independent photometric
                # jitter for each branch.
                base_image = base_transform(image_pil)
                yolo_image = yolo_weak_transform(base_image)

                clip_image = clip_weak_transform(clip_base_transform(base_image))
                clip_image = CLIP_PROCESSOR(images=clip_image, return_tensors="pt", do_rescale=False)
                clip_image['pixel_values'] = clip_image['pixel_values'].squeeze(0)

                item_transform_time = time.time() - item_transform_start
                if len(self.transform_times) < 1000:
                    self.transform_times.append(item_transform_time)

                return clip_image, yolo_image, label, metadata

            # Validation: deterministic views taken directly from the original image.
            yolo_image = yolo_val_transform(image_pil)

            clip_image = clip_val_transform(image_pil)
            clip_image = CLIP_PROCESSOR(images=clip_image, return_tensors="pt", do_rescale=False)
            clip_image['pixel_values'] = clip_image['pixel_values'].squeeze(0)
            return clip_image, yolo_image, label, metadata

        except Exception as e:
            with open(os.path.join(logger.get_log_dir(), self.error_log_path), 'a') as f:
                f.write(f"Error loading image at index {idx}: {e}\n\n")

            # Placeholder keeps the batch shape intact. NOTE(review): it is labelled 1,
            # so in training it is learned as class 1, and its 'source' is the int -1
            # whereas real rows carry a str.
            dummy_clip_input = {'pixel_values': torch.zeros(3, 224, 224)}
            dummy_yolo_tensor = torch.zeros(3, 640, 640)
            dummy_label = torch.tensor(1, dtype=torch.long)
            dummy_metadata = {
                'img_path': f'error_at_idx_{idx}',
                'label': 1,
                'width': 640,
                'height': 640,
                'size_kb': 0.0,
                'source': -1,
            }

            return dummy_clip_input, dummy_yolo_tensor, dummy_label, dummy_metadata

class RandomVersionSampler(Sampler):
    """Yield one randomly chosen row per distinct ``unique_id`` each epoch, in shuffled order.

    Rows that share an id are treated as versions of one item. ``__len__`` is the
    number of distinct ids, not the number of rows. Uses the global ``random`` module.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

        rows_by_id = {}
        for row in range(len(base_dataset)):
            rows_by_id.setdefault(base_dataset.get_id(row), []).append(row)
        self.grouped = rows_by_id
        # dicts preserve insertion order, so ids stay in first-seen order
        self.ids = list(rows_by_id)

    def __iter__(self):
        picks = []
        for img_id in self.ids:
            picks.append(random.choice(self.grouped[img_id]))
        random.shuffle(picks)
        return iter(picks)

    def __len__(self):
        return len(self.ids)


from torch.utils.data._utils.collate import default_collate

# CHANGED: Simplified collate function for single model
def custom_collate_fn(batch):
    """Collate ``(clip_inputs, yolo_tensor, label, metadata)`` samples into one batch.

    The first three fields go through ``default_collate``. The metadata dicts are
    transposed into a dict of per-field Python lists, with values left uncollated.
    """
    def column(position):
        return [sample[position] for sample in batch]

    clip_inputs, yolo_tensors, labels = (default_collate(column(pos)) for pos in range(3))

    metadata_rows = column(3)
    metadata_fields = ('img_path', 'label', 'width', 'height', 'size_kb', 'source')
    metadata = {field: [row[field] for row in metadata_rows] for field in metadata_fields}

    return clip_inputs, yolo_tensors, labels, metadata

# CHANGED: Use SingleModelDataset instead of MeanTeacherDataset
# All 0/2-labelled training rows, plus label-1 rows subsampled to the default
# supervised_ratio (0.5) of the result.
supervised_train_dataset = SingleModelDataset(
    csv_file = "train_2.csv", 
    root_dir = "~/Workspace/data-v2/train",
    supervised = True,
)
# Full training CSV with augmentation.
train_dataset = SingleModelDataset(
    csv_file = "train_2.csv", 
    root_dir = "~/Workspace/data-v2/train",
    val = False,
)
# Validation CSV with deterministic transforms.
val_dataset = SingleModelDataset(
    csv_file = "val_2.csv", 
    root_dir = "~/Workspace/data-v2/val",
    val = True,
)

# The samplers pick one row per unique_id and shuffle, so no shuffle=True here.
# prefetch_factor must be None when num_workers == 0.
supervised_sampler = RandomVersionSampler(supervised_train_dataset)
supervised_train_dataloader = DataLoader(
    supervised_train_dataset, 
    batch_size=config.batch_size, 
    sampler=supervised_sampler,
    collate_fn=custom_collate_fn,
    num_workers=get_num_workers(),
    persistent_workers=False,
    prefetch_factor=2 if get_num_workers() > 0 else None,
)

sampler = RandomVersionSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=config.batch_size, 
    sampler=sampler,
    collate_fn=custom_collate_fn,
    num_workers=get_num_workers(),
    persistent_workers=False,
    prefetch_factor=2 if get_num_workers() > 0 else None,
)

# Sequential (unshuffled) validation order.
val_dataloader = DataLoader(
    val_dataset, 
    batch_size=config.batch_size, 
    collate_fn=custom_collate_fn,
    num_workers=get_num_workers(),
    persistent_workers=False,
    prefetch_factor=3 if get_num_workers() > 0 else None,
)

# NOTE(review): len(train_dataloader) counts batches over the sampler's distinct
# unique_ids, not over all rows. If ids repeat, this is more than the last partial batch.
len(train_dataset) - len(train_dataloader) * config.batch_size

len(supervised_train_dataloader), len(train_dataloader)



# CHANGED: Only create single model (no teacher model)
model = BiggerClassifier().to(config.device)
# REMOVED: teacher_model creation and EMA setup

# Route parameters into optimizer groups by the submodule attribute name in their
# qualified name; head layers (fc1..fc3) fall through to 'classifier'.
clip_params = []
yolo_params = []
resnet_params = []
classifier_params = []

for name, param in model.named_parameters():
    if 'clip' in name:
        clip_params.append(param)
    elif 'yolo' in name:
        yolo_params.append(param)
    elif 'resnet' in name:
        resnet_params.append(param)
    else:
        classifier_params.append(param)

# ResNet parameters are not frozen, but BiggerClassifier.forward() does not use that
# branch, so they never receive gradients. Frozen parameters stay in the optimizer
# groups, presumably so that a later unfreeze only needs requires_grad=True.
print("Freezing backbone parameters initially...")
for param in clip_params + yolo_params:
    param.requires_grad = False

optimizer = torch.optim.AdamW([
    {'params': clip_params, 'lr': config.lr_backbone*0.1, 'name': 'clip'},
    {'params': yolo_params, 'lr': config.lr_backbone*0.5, 'name': 'yolo'},
    {'params': resnet_params, 'lr': config.lr_backbone, 'name': 'resnet'},
    {'params': classifier_params, 'lr': config.initial_lr, 'name': 'classifier'}
])

# mode='max': step() with a metric to maximise (presumably validation accuracy).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=1)
print('scheduler lr:', scheduler.get_last_lr())


class AsymmetricFocalLoss(nn.Module):
    """Focal cross-entropy with optional per-class weights and a confusion-penalty term.

    For a sample with true class ``t`` and softmax probabilities ``p``::

        loss = alpha[t] * (1 - p[t]) ** gamma * CE * sum_j p[j] * penalty[t, j]

    The batch mean is returned. Row ``t`` of ``confusion_penalty_matrix`` scales how
    costly probability mass on each class is when the truth is ``t``. The default
    all-ones 3x3 matrix makes that factor 1.

    Args:
        gamma: focusing exponent; 0 disables focal down-weighting.
        alpha: optional per-class weights (tensor or sequence), indexed by true class.
        confusion_penalty_matrix: optional (C, C) tensor or nested sequence.
    """

    def __init__(self, gamma=2.0, alpha=None, confusion_penalty_matrix=None):
        super().__init__()
        self.gamma = gamma

        if alpha is not None and not torch.is_tensor(alpha):
            alpha = torch.tensor(alpha, dtype=torch.float32)
        if confusion_penalty_matrix is None:
            # All ones: sum_j p_j * 1 == 1, i.e. no extra penalty (assumes 3 classes).
            confusion_penalty_matrix = torch.ones(3, 3)
        elif not torch.is_tensor(confusion_penalty_matrix):
            confusion_penalty_matrix = torch.tensor(confusion_penalty_matrix, dtype=torch.float32)

        # Registered as buffers so module.to(device/dtype) moves them too. As plain
        # attributes, a default CPU matrix indexed with CUDA targets raised a device error.
        self.register_buffer('alpha', alpha)
        self.register_buffer('confusion_penalty_matrix', confusion_penalty_matrix)

    def forward(self, logits, targets):
        """Return the scalar mean loss for ``logits`` (N, C) and integer ``targets`` (N,)."""
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        probs = F.softmax(logits, dim=1)
        p_t = probs.gather(1, targets.view(-1, 1)).squeeze(1)

        focal_weight = (1 - p_t) ** self.gamma

        # Row of the penalty matrix for each sample's true class, weighted by the
        # predicted distribution.
        penalty_for_true = self.confusion_penalty_matrix[targets]
        expected_penalty = (probs * penalty_for_true).sum(dim=1)

        loss = focal_weight * ce_loss * expected_penalty

        if self.alpha is not None:
            alpha_t = self.alpha.gather(0, targets)
            loss = alpha_t * loss

        return loss.mean()


# Plain cross-entropy (used as the training `criterion` further down).
ce_criterion = nn.CrossEntropyLoss().to(config.device)

# Focal loss used for the reported validation loss in validate().
# Class order 0/1/2 = bad/neutral/good (per analyse()'s class_names). Penalty rows
# are indexed by the true class and columns weight predicted mass, e.g. putting mass
# on "good" when the truth is "bad" costs 1.2.
af_criterion = AsymmetricFocalLoss(
    gamma=1.2,
    alpha=torch.tensor([1.1, 0.9, 1.2]).to(config.device),  
    confusion_penalty_matrix=torch.tensor([
        [1.0, 1.1, 1.2], 
        [0.85, 1.0, 0.85],
        [1.15, 1.1, 1.0]
    ]).to(config.device)
).to(config.device)

# CHANGED: Simplified PerformanceMonitor (removed consistency tracking)
class PerformanceMonitor:
    """Mutable record of validation accuracy, used for best-model tracking and checkpointing."""

    def __init__(self):
        # best accuracy so far, consecutive validations without a new best, full history
        self.best_accuracy, self.epochs_without_improvement, self.accuracy_history = 0.0, 0, []

monitor = PerformanceMonitor()

# CHANGED: Simplified checkpoint saving (removed teacher model)
def _config_snapshot(config):
    """Return public, non-callable config settings, class-level and instance-level merged.

    ``vars(config)`` alone only sees instance attributes, so class attributes such as
    those defined on ``Config`` would otherwise be missing. Instance values win.
    """
    merged = {}
    for cls in reversed(type(config).__mro__):
        merged.update(vars(cls))
    merged.update(vars(config))
    return {
        key: value for key, value in merged.items()
        if not key.startswith('_')
        and not callable(value)
        and not isinstance(value, (classmethod, staticmethod, property))
    }


def save_checkpoint(model, optimizer, scheduler, epoch, global_step, config, 
                    val_accuracy, monitor, loss_history):
    """Serialize the full training state to a timestamped file in ``config.checkpoint_path``.

    Args:
        model, optimizer, scheduler: objects whose ``state_dict()`` is saved.
        epoch, global_step: training position, stored and embedded in the file name.
        config: run configuration; its public settings are snapshotted.
        val_accuracy: validation accuracy for this checkpoint (stored as a float).
        monitor: object with ``best_accuracy``, ``epochs_without_improvement`` and
            ``accuracy_history``.
        loss_history: logged training losses, stored as given.

    Returns:
        Path of the written checkpoint file.

    Raises:
        Any exception from ``torch.save`` is printed and re-raised.
    """
    checkpoint_dir = config.checkpoint_path
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Plain Python numbers: numpy scalars (what the pandas-based metrics produce) are
    # rejected by torch.load(..., weights_only=True), the default in recent PyTorch.
    val_accuracy = float(val_accuracy)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f"checkpoint_epoch{epoch}_step{global_step}_acc{val_accuracy:.4f}_{timestamp}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

    checkpoint = {
        'epoch': epoch,
        'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_accuracy': val_accuracy,
        'loss_history': loss_history,
        'monitor_state': {
            'best_accuracy': float(monitor.best_accuracy),
            'epochs_without_improvement': int(monitor.epochs_without_improvement),
            'accuracy_history': [float(acc) for acc in monitor.accuracy_history],
        },
        'config': _config_snapshot(config),
        'save_timestamp': datetime.now().isoformat(),
        # str(): torch.__version__ is a str subclass that the weights-only loader may reject.
        'pytorch_version': str(torch.__version__),
    }

    try:
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved successfully: {checkpoint_path}")
        print(f"Epoch: {epoch} | Step: {global_step} | Val Acc: {val_accuracy:.4f}")
        return checkpoint_path
    except Exception as e:
        print(f"Error saving checkpoint: {e}")
        raise

# CHANGED: Simplified checkpoint loading (removed teacher model)
def load_checkpoint(checkpoint_path, model, optimizer, scheduler, config, 
                    monitor, device='cuda', strict=True):
    """Restore training state written by ``save_checkpoint`` into the given objects.

    Loads model/optimizer/scheduler state and the monitor's best accuracy, history and
    no-improvement counter. It copies saved config values onto attributes ``config``
    already has, except run/machine-specific ones (``checkpoint_path``, ``log_dir``,
    ``device``). It then sets ``config.cur_epoch`` to the epoch after the saved one.

    Args:
        checkpoint_path: file written by ``save_checkpoint``.
        model, optimizer, scheduler, config, monitor: objects restored in place.
        device: ``map_location`` for ``torch.load``.
        strict: passed to ``model.load_state_dict``.

    Returns:
        ``(epoch, global_step, loss_history, val_accuracy)`` from the checkpoint.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    # NOTE(review): recent PyTorch defaults to weights_only=True; checkpoints holding
    # numpy scalars (older saves) may need weights_only=False for trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch = checkpoint['epoch']
    global_step = checkpoint['global_step']
    val_accuracy = checkpoint['val_accuracy']
    loss_history = checkpoint['loss_history']

    monitor_state = checkpoint['monitor_state']
    monitor.best_accuracy = monitor_state['best_accuracy']
    monitor.accuracy_history = monitor_state['accuracy_history']
    # save_checkpoint stores this counter too; restoring it keeps plateau tracking
    # continuous across a resume (0 if an older checkpoint lacks it).
    monitor.epochs_without_improvement = monitor_state.get('epochs_without_improvement', 0)

    # Settings tied to this run/machine rather than to the trained model.
    run_specific_keys = ('checkpoint_path', 'log_dir', 'device')
    for key, value in checkpoint['config'].items():
        if hasattr(config, key) and key not in run_specific_keys:
            setattr(config, key, value)

    config.cur_epoch = epoch + 1

    print(f"Checkpoint loaded successfully")
    print(f"Resuming from Epoch: {epoch + 1} | Step: {global_step}")
    print(f"Previous Val Acc: {val_accuracy:.4f} | Best Acc: {monitor.best_accuracy:.4f}")
    print(f"Loaded len(loss_history) = {len(loss_history)}")

    return epoch, global_step, loss_history, val_accuracy

def validate(model, val_loader, config, epoch, criterion=None):
    """Run one no-grad pass over ``val_loader`` and collect per-image and per-batch results.

    Args:
        model: classifier called as ``model(clip_inputs, yolo_tensors)``.
        val_loader: DataLoader yielding ``(clip_inputs, yolo_tensors, labels, metadata)``
            batches in the format built by ``custom_collate_fn``.
        config: object providing ``device``.
        epoch: epoch number, shown in the progress bar and stored in ``batch_df``.
        criterion: loss used for the reported validation loss; defaults to the
            module-level ``af_criterion``.

    Returns:
        ``(image_df, batch_df, total_loss)``:
        - ``image_df``: one row per image with metadata, the prediction and the class
          probabilities ('bad', 'neutral', 'good').
        - ``batch_df``: one row per batch with its mean loss.
        - ``total_loss``: the sum of the batch losses.

    The model's train/eval mode is restored on exit, so a mid-epoch validation does
    not leave training running in eval mode.
    """
    if criterion is None:
        criterion = af_criterion

    was_training = model.training
    torch.cuda.empty_cache()
    model.eval()

    image_data = []
    batch_data = []
    total_loss = 0.0

    try:
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Validation Epoch {epoch}', total=len(val_loader), leave=False)

            for batch_num, (clip_inputs, yolo_tensors, labels, metadata) in enumerate(val_pbar):
                clip_inputs['pixel_values'] = clip_inputs['pixel_values'].to(config.device)
                yolo_tensors = yolo_tensors.to(config.device)
                labels = labels.to(config.device)

                outputs = model(clip_inputs, yolo_tensors)
                batch_loss = criterion(outputs, labels).item()
                probs = F.softmax(outputs, dim=1).cpu().numpy()
                predictions = outputs.argmax(dim=1).cpu().numpy()
                labels = labels.cpu().numpy()

                batch_data.append({
                    'batch_num': batch_num,
                    'epoch': epoch,
                    'loss': batch_loss,
                })

                for i, (true_label, pred_label) in enumerate(zip(labels, predictions)):
                    image_data.append({
                        'batch_num': batch_num,
                        'img_path': metadata['img_path'][i],
                        'label': int(true_label),
                        'width': int(metadata['width'][i]),
                        'height': int(metadata['height'][i]),
                        'size_kb': float(metadata['size_kb'][i]),
                        'source': metadata['source'][i],
                        'prediction': int(pred_label),
                        'bad': float(probs[i, 0]),
                        'neutral': float(probs[i, 1]),
                        'good': float(probs[i, 2])
                    })

                val_pbar.set_postfix({'loss': f'{batch_loss:.4f}'})
                total_loss += batch_loss
    finally:
        model.train(was_training)

    image_df = pd.DataFrame(image_data)
    batch_df = pd.DataFrame(batch_data)

    torch.cuda.empty_cache()
    gc.collect()

    return image_df, batch_df, total_loss

def analyse(image_df, batch_df, config, monitor, epoch):
    """Compute validation metrics from ``validate()`` output and update ``monitor``.

    Args:
        image_df: one row per validation image with ``label`` and ``prediction`` columns.
        batch_df: one row per batch with a ``loss`` column.
        config: object providing ``num_classes``.
        monitor: object whose ``accuracy_history``, ``best_accuracy`` and
            ``epochs_without_improvement`` are updated in place.
        epoch: epoch number, used only in the printed summary.

    Returns:
        ``(accuracy, metrics, cm)``:
        - ``accuracy``: a Python float over rows labelled 0 or 2 only (0.0 if there
          are none).
        - ``metrics``: a dict of ``val/*`` values.
        - ``cm``: the confusion matrix over labels [0, 1, 2].
    """
    # Headline accuracy excludes class-1 rows.
    labeled_df = image_df[image_df['label'] != 1]
    total = len(labeled_df)
    correct = int((labeled_df['label'] == labeled_df['prediction']).sum())
    # Guard against a split with no 0/2 rows (e.g. a tiny dry run). A plain float also
    # keeps monitor/checkpoint values loadable with torch.load(weights_only=True).
    accuracy = correct / total if total else 0.0

    # Calculate per-class metrics
    class_names = ['bad', 'neutral', 'good']
    labels = [0, 1, 2]
    all_labels = image_df['label'].values
    all_preds = image_df['prediction'].values

    # Confusion matrix and classification report
    cm = confusion_matrix(y_true=all_labels, y_pred=all_preds, labels=labels)
    report = classification_report(y_true=all_labels, y_pred=all_preds, labels=labels, target_names=class_names,
                                   output_dict=True, zero_division=0)

    # Calculate per-class recall
    class_recall = []
    for class_idx in range(config.num_classes):
        class_mask = image_df['label'] == class_idx
        if class_mask.sum() > 0:
            class_correct = (class_mask & (image_df['prediction'] == class_idx)).sum()
            class_recall.append(class_correct / class_mask.sum())
        else:
            class_recall.append(0.0)

    # Compile metrics dictionary
    avg_val_loss = batch_df['loss'].mean()

    metrics = {
        "val/accuracy": accuracy,
        "val/loss": avg_val_loss,
        "val/recall_bad": class_recall[0],
        "val/recall_neutral": class_recall[1],
        "val/recall_good": class_recall[2],
    }

    # Add precision and F1 scores from classification report
    for class_name in class_names:
        if class_name in report:
            metrics[f"val/precision_{class_name}"] = report[class_name]['precision']
            metrics[f"val/f1_{class_name}"] = report[class_name]['f1-score']

    # Severe misclassifications between the two extreme classes
    good_as_bad = image_df[(image_df['label'] == 2) & (image_df['prediction'] == 0)]
    bad_as_good = image_df[(image_df['label'] == 0) & (image_df['prediction'] == 2)]
    metrics["val/bad_as_good_count"] = len(bad_as_good)
    metrics["val/good_as_bad_count"] = len(good_as_bad)

    # Update monitor
    monitor.accuracy_history.append(accuracy)
    if accuracy > monitor.best_accuracy:
        monitor.best_accuracy = accuracy
        monitor.epochs_without_improvement = 0
    else:
        monitor.epochs_without_improvement += 1

    print(f"Validation Summary - Epoch {epoch}")
    print(f"Accuracy: {accuracy:.4f} | Loss: {avg_val_loss:.4f}")

    return accuracy, metrics, cm


# Loss scaler for mixed-precision training (presumably used when Config.use_amp is set).
# NOTE(review): `GradScaler` is not imported explicitly in this file — presumably it
# arrives via `from custom_logging import *`; confirm which implementation it is.
scaler = GradScaler()

# CHANGED: Simplified epoch analysis (removed consistency loss tracking)
def analyse_epoch(image_df, batch_df, config, epoch):
    """Summarise one training epoch into ``epoch/*`` metrics.

    Args:
        image_df: per-sample rows with ``label`` and ``prediction`` columns, or None
            to report the loss only.
        batch_df: per-batch rows with a ``loss`` column.
        config: object providing ``num_classes``.
        epoch: unused; kept for signature symmetry with ``analyse``.

    Returns:
        dict of metric name -> value. It contains the mean loss, per-class accuracy
        for classes present, overall accuracy, and bad<->good confusion counts.
    """
    epoch_metrics = {"epoch/loss": batch_df['loss'].mean()}
    if image_df is None:
        return epoch_metrics

    labels = image_df['label']
    preds = image_df['prediction']

    for cls in range(config.num_classes):
        in_class = labels == cls
        n_in_class = in_class.sum()
        if n_in_class > 0:
            n_hit = (in_class & (preds == cls)).sum()
            epoch_metrics[f"epoch/accuracy_class_{cls}"] = n_hit / n_in_class

    epoch_metrics["epoch/train_accuracy"] = (labels == preds).sum() / len(image_df)

    # Severe misclassifications between the two extreme classes
    epoch_metrics["epoch/bad_as_good_count"] = int(((labels == 0) & (preds == 2)).sum())
    epoch_metrics["epoch/good_as_bad_count"] = int(((labels == 2) & (preds == 0)).sum())

    return epoch_metrics

def plot_running_loss(loss_history, save_path, window_size=10, log_interval=None):
    """Plot the raw loss history plus a trailing moving average and save it to a file.

    Args:
        loss_history: Sequence of loss values, one per logged training step.
        save_path: Path of the image file to write.
        window_size: Number of consecutive points averaged for the moving average.
        log_interval: Global steps between consecutive ``loss_history`` entries,
            used to place points on the x-axis. Defaults to ``config.log_interval``
            from the notebook namespace (the previous implicit behaviour).

    Returns:
        None. Does nothing when fewer than ``window_size`` points are available.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if len(loss_history) < window_size:
        return
    if log_interval is None:
        log_interval = config.log_interval

    # Element j averages loss_history[j : j + window_size], i.e. the window that
    # ends at index j + window_size - 1.
    moving_avg = [
        sum(loss_history[start:start + window_size]) / window_size
        for start in range(len(loss_history) - window_size + 1)
    ]

    # loss_history[i] is recorded at global step i * log_interval.
    x_raw = [i * log_interval for i in range(len(loss_history))]
    x_moving = x_raw[window_size - 1:]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(x_raw, loss_history, alpha=0.3, color='blue', label='Raw Loss')
    ax.plot(x_moving, moving_avg, color='red', linewidth=2, label=f'Moving Avg (window={window_size})')

    # Reference lines for the latest and best smoothed loss.
    current_avg = moving_avg[-1]
    min_avg = min(moving_avg)
    ax.axhline(y=current_avg, color='green', linestyle='--', alpha=0.5, label=f'Current: {current_avg:.4f}')
    ax.axhline(y=min_avg, color='orange', linestyle='--', alpha=0.5, label=f'Min: {min_avg:.4f}')

    ax.set_xlabel('Global step')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss Over Time')
    # Build the legend only after every labelled artist exists; calling it earlier
    # silently dropped the "Current"/"Min" entries.
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=100)
    plt.close(fig)

    print(f"  Loss graph saved to: {save_path}")


Logging to: ./logs/sanitycheck_20251030_013129
Freezing backbone parameters initially...
scheduler lr: [1.0000000000000002e-06, 5e-06, 1e-05, 5e-05]


In [15]:
# Resume point for the training cell below, which iterates
# range(config.cur_epoch, config.num_epochs): these values run exactly one epoch (epoch 4).
config.cur_epoch = 4
config.num_epochs = 5

In [16]:


# Train (or resume) epochs config.cur_epoch .. config.num_epochs - 1.
# Relies on objects created in earlier cells: model, optimizer, scheduler, scaler, logger,
# monitor, train_dataloader, val_dataloader, ce_criterion, af_criterion, clip_params,
# yolo_params, SingleModelDataset, RandomVersionSampler, custom_collate_fn, validate,
# analyse, save_checkpoint.

UNFREEZE_STEP = 500                                 # global step at which backbones are unfrozen
VALIDATIONS_PER_EPOCH = 7                           # target intermittent validations per epoch
SUPERVISED_RATIOS = [0.66, 0.55, 0.44, 0.33, 0.66]  # supervised_ratio indexed by epoch
CLASS_NAMES = ['bad', 'neutral', 'good']
num_workers = get_num_workers()

criterion = ce_criterion

# Validation only runs on logged steps, so round the interval down to a multiple of
# log_interval -- but never down to 0, which would make `global_step % validation_frequency`
# raise ZeroDivisionError.
# NOTE(review): the interval is measured on `train_dataloader`, not on the per-epoch
# `train_loader` iterated below, so the number of intermittent validations actually run
# per epoch depends on that loader's length.
validation_frequency = len(train_dataloader) // VALIDATIONS_PER_EPOCH
validation_frequency -= validation_frequency % config.log_interval
validation_frequency = max(validation_frequency, config.log_interval)
print(f"validation_frequency: {validation_frequency}")

# NOTE(review): global_step restarts at 0 every time this cell runs, so when resuming at a
# later epoch the train/* step numbers overlap those logged by earlier runs.
global_step = 0
track_images = False
loss_history = []

# Mode-collapse monitor: predictions from the most recent logged batches.
recent_predictions = []
max_recent_batches = 15

# Only matters if global_step is ever restored (> 0) instead of reset above; the in-loop
# check handles reaching UNFREEZE_STEP during training.
if global_step >= UNFREEZE_STEP:
    print(f"unfreezing backbone at step {global_step}")
    for param in clip_params + yolo_params:
        param.requires_grad = True

# Mixed precision is only enabled on CUDA.
use_amp = device.type == "cuda"

for epoch in range(config.cur_epoch, config.num_epochs):
    collapse_flag = False
    epoch_start = time.time()

    # region Train one epoch ########################################
    model.train()
    batch_data = []

    # Rebuild the supervised loader with this epoch's ratio; epochs past the end of
    # SUPERVISED_RATIOS (previously an IndexError at epoch 5) use the default loader.
    if epoch < len(SUPERVISED_RATIOS):
        supervised_ratio = SUPERVISED_RATIOS[epoch]
        supervised_train_dataset = SingleModelDataset(
            csv_file = "train_2.csv",
            root_dir = "~/Workspace/data-v2/train",
            supervised = True,
            supervised_ratio = supervised_ratio,
        )
        supervised_sampler = RandomVersionSampler(supervised_train_dataset)
        train_loader = DataLoader(
            supervised_train_dataset,
            batch_size=config.batch_size,
            sampler=supervised_sampler,
            collate_fn=custom_collate_fn,
            num_workers=num_workers,
            persistent_workers=False,
            prefetch_factor=3 if num_workers > 0 else None,
        )
        print(supervised_ratio, len(train_loader))
    else:
        train_loader = train_dataloader

    # Plain cross-entropy for epoch 0, af_criterion from epoch 1 on.
    criterion = af_criterion if epoch >= 1 else ce_criterion

    epoch_pbar = tqdm(train_loader, desc=f'Epoch {epoch}', total=len(train_loader))
    for batch_num, (clip_inputs, yolo_tensors, labels, metadata) in enumerate(epoch_pbar):
        if global_step == UNFREEZE_STEP:
            print(f"unfreezing backbone at step {global_step}")
            for param in clip_params + yolo_params:
                param.requires_grad = True

        batch_start = time.time()

        # 1. Move the batch to the device (single view per image).
        clip_inputs['pixel_values'] = clip_inputs['pixel_values'].to(device, non_blocking=True)
        yolo_tensors = yolo_tensors.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # 2. Forward pass.
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(clip_inputs, yolo_tensors)
            loss = criterion(outputs, labels)

        # 3. Backward pass; unscale before clipping so max_norm applies to the true gradients.
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)
        scaler.update()

        # Everything below only runs on logged steps.
        if global_step % config.log_interval != 0:
            global_step += 1
            continue

        # 4. Mode-collapse check over the last `max_recent_batches` logged batches.
        # argmax of the logits equals argmax of their softmax; no softmax needed.
        current_predictions = outputs.detach().argmax(dim=1).cpu().numpy()
        recent_predictions.append(current_predictions)
        if len(recent_predictions) > max_recent_batches:
            recent_predictions.pop(0)
        all_recent_preds = np.concatenate(recent_predictions)
        unique, counts = np.unique(all_recent_preds, return_counts=True)
        dominant_class = unique[np.argmax(counts)]
        dominant_ratio = counts.max() / len(all_recent_preds)
        if not collapse_flag and dominant_ratio > 0.95:
            print(f"[WARN] Possible mode collapse at step {global_step}!")
            print(f"       Class {dominant_class} represents {dominant_ratio:.1%} of last {len(all_recent_preds)} predictions")
            print(f"       Distribution: {dict(zip(unique, counts))}")
            collapse_flag = True

        # 5. Metrics. NOTE: batch_start is taken after the batch is fetched and CUDA runs
        # asynchronously, so img/s is only a rough throughput figure.
        batch_time = time.time() - batch_start
        images_per_second = config.batch_size / batch_time

        loss_cpu = loss.item()
        loss_history.append(loss_cpu)
        batch_data.append({
            'batch_num': batch_num,
            'epoch': epoch,
            'global_step': global_step,
            'loss': loss_cpu,
        })

        epoch_pbar.set_postfix({
            'loss': f'{loss_cpu:.4f}',
            'img/s': f'{images_per_second:.1f}'
        })

        # Log to tensorboard/CSV (only log_interval steps reach this point).
        current_lr = optimizer.param_groups[-1]['lr']
        train_metrics = {
            "train/loss": loss_cpu,
            "train/learning_rate": current_lr,
            "system/gpu_memory_mb": get_gpu_memory_usage()
        }
        logger.log_metrics(train_metrics, global_step)
        logger.log_train_step(global_step, epoch, {
            'loss': loss_cpu,
            'learning_rate': current_lr,
        })

        # 6. Intermittent validation (validation_frequency is a multiple of log_interval).
        if global_step > 0 and global_step % validation_frequency == 0:
            intermittent_epoch = "step" + str(global_step)
            print("Running intermittent validation...")
            val_image_df, val_batch_df, val_loss = validate(model, val_dataloader, config, intermittent_epoch)
            val_accuracy, val_metrics, cm = analyse(val_image_df, val_batch_df, config, monitor, intermittent_epoch)
            # validate() presumably switches the model to eval mode; restore the training
            # mode set at the start of the epoch before optimisation resumes.
            model.train()
            scheduler.step(val_accuracy)
            print('scheduler lr:', scheduler.get_last_lr())

            logger.log_validation(intermittent_epoch, val_metrics)
            logger.log_confusion_matrix(cm, CLASS_NAMES, intermittent_epoch)

            print(f"  Validation Accuracy: {val_accuracy:.4f}")
            print(f"  Best Accuracy: {monitor.best_accuracy:.4f}")

            if val_accuracy >= monitor.best_accuracy:
                save_checkpoint(
                    model, optimizer, scheduler, epoch, global_step,
                    config, val_accuracy, monitor, loss_history
                )

            plot_running_loss(loss_history, os.path.join(logger.get_log_dir(), 'loss_graph.png'))
        global_step += 1

    # Image-level predictions are not collected during training; only batch losses.
    batch_df = pd.DataFrame(batch_data)
    train_image_df = None
    train_batch_df = batch_df
    # endregion #####################################################

    epoch_metrics = analyse_epoch(train_image_df, train_batch_df, config, epoch)
    logger.log_metrics(epoch_metrics, global_step)

    print(f"Epoch {epoch} Summary:")
    print(f"  Time: {(time.time() - epoch_start):.1f}s")
    print(f"  Avg Loss: {epoch_metrics['epoch/loss']:.4f}")
    print("="*80)

    # End-of-epoch validation.
    print("Running end of epoch validation...")
    val_image_df, val_batch_df, val_loss = validate(model, val_dataloader, config, epoch)
    val_accuracy, val_metrics, cm = analyse(val_image_df, val_batch_df, config, monitor, epoch)

    # NOTE(review): val metrics are logged at step=epoch while train metrics use
    # global_step, so both share one x-axis with different units.
    logger.log_metrics(val_metrics, epoch)
    logger.log_validation(epoch, val_metrics)
    logger.log_confusion_matrix(cm, CLASS_NAMES, epoch)

    print(f"  Validation Accuracy: {val_accuracy:.4f}")
    print(f"  Best Accuracy: {monitor.best_accuracy:.4f}")
    print("="*80)
    print()
    print("="*80)

    save_checkpoint(
        model, optimizer, scheduler, epoch, global_step,
        config, val_accuracy, monitor, loss_history
    )

    scheduler.step(val_accuracy)
    print('scheduler lr:', scheduler.get_last_lr())

plot_running_loss(loss_history, os.path.join(logger.get_log_dir(), 'loss_graph.png'))
logger.close()



0.66 13271


Epoch 4:   4%|▍         | 500/13271 [03:32<1:24:51,  2.51it/s, loss=0.2256, img/s=110.6]

unfreezing backbone at step 500


Epoch 4:  31%|███       | 4055/13271 [28:01<1:04:53,  2.37it/s, loss=0.2070, img/s=109.4]

Running intermittent validation...


Epoch 4:  31%|███       | 4056/13271 [31:43<172:05:30, 67.23s/it, loss=0.2070, img/s=109.4]

Validation Summary - Epoch step4055
Accuracy: 0.4600 | Loss: 0.8823
scheduler lr: [2.8247524899999995e-08, 9.886633714999994e-08, 1.9773267429999988e-07, 9.886633714999995e-07]
  Validation Accuracy: 0.4600
  Best Accuracy: 0.5033
  Loss graph saved to: ./logs/sanitycheck_20251030_013129/loss_graph.png


Epoch 4:  61%|██████    | 8110/13271 [59:33<35:01,  2.46it/s, loss=0.5156, img/s=109.6]    

Running intermittent validation...


Epoch 4:  61%|██████    | 8111/13271 [1:03:32<103:22:52, 72.13s/it, loss=0.5156, img/s=109.6]

Validation Summary - Epoch step8110
Accuracy: 0.4554 | Loss: 0.8821
scheduler lr: [2.8247524899999995e-08, 6.920643600499995e-08, 1.384128720099999e-07, 6.920643600499996e-07]
  Validation Accuracy: 0.4554
  Best Accuracy: 0.5033
  Loss graph saved to: ./logs/sanitycheck_20251030_013129/loss_graph.png


Epoch 4:  92%|█████████▏| 12165/13271 [1:31:27<07:12,  2.56it/s, loss=0.1865, img/s=106.3]   

Running intermittent validation...


Epoch 4:  92%|█████████▏| 12166/13271 [1:35:16<21:12:38, 69.10s/it, loss=0.1865, img/s=106.3]

Validation Summary - Epoch step12165
Accuracy: 0.4608 | Loss: 0.8802
scheduler lr: [2.8247524899999995e-08, 6.920643600499995e-08, 1.384128720099999e-07, 6.920643600499996e-07]
  Validation Accuracy: 0.4608
  Best Accuracy: 0.5033
  Loss graph saved to: ./logs/sanitycheck_20251030_013129/loss_graph.png


Epoch 4: 100%|██████████| 13271/13271 [1:42:48<00:00,  2.15it/s, loss=0.1953, img/s=362.8]   


Epoch 4 Summary:
  Time: 6170.7s
  Avg Loss: 0.2208
Running end of epoch validation...


                                                                                   

Validation Summary - Epoch 4
Accuracy: 0.4527 | Loss: 0.8883
  Validation Accuracy: 0.4527
  Best Accuracy: 0.5033

Checkpoint saved successfully: ./checkpoints/sanitycheck_20251030_013129/checkpoint_epoch4_step13271_acc0.4527_20251030_160832.pth
Epoch: 4 | Step: 13271 | Val Acc: 0.4527
scheduler lr: [2.8247524899999995e-08, 4.844450520349996e-08, 9.688901040699992e-08, 4.844450520349997e-07]
  Loss graph saved to: ./logs/sanitycheck_20251030_013129/loss_graph.png


In [17]:
import copy

# Re-validate the final weights. `analyse` updates the monitor it is given
# (accuracy_history, best_accuracy, epochs_without_improvement), so hand it a copy:
# this ad-hoc re-run must not alter the training monitor that save_checkpoint persists.
val_image_df, val_batch_df, val_loss = validate(model, val_dataloader, config, -1)
val_accuracy, val_metrics, cm = analyse(val_image_df, val_batch_df, config, copy.deepcopy(monitor), -1)
cm

                                                                                    

Validation Summary - Epoch -1
Accuracy: 0.4527 | Loss: 0.8883


array([[ 5111,  5563,   733],
       [ 6436, 15358,  1354],
       [  511,  2090,  2248]])

In [18]:


# good_as_bad = val_image_df[(val_image_df['label'] == 2) & (val_image_df['prediction'] == 0)]
# bad_as_good = val_image_df[(val_image_df['label'] == 0) & (val_image_df['prediction'] == 2)]

# for s in bad_as_good["img_path"]:
#     print(s)



In [19]:
import copy

# Evaluate the final weights on the held-out test split.
num_workers = get_num_workers()
test_dataset = SingleModelDataset(
    csv_file = "test_2.csv",
    root_dir = "~/Workspace/data-v2/test",
    val = True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    collate_fn=custom_collate_fn,
    num_workers=num_workers,
    persistent_workers=False,
    prefetch_factor=3 if num_workers > 0 else None,
)

# `analyse` updates the monitor it is given (accuracy_history, best_accuracy,
# epochs_without_improvement). Use a copy so test accuracy never enters the training
# monitor -- otherwise it could become best_accuracy and leak into model selection.
# Distinct names keep the training `loss` and validation `cm` from being overwritten.
test_image_df, test_batch_df, test_loss = validate(model, test_dataloader, config, "test")
test_accuracy, test_metrics, test_cm = analyse(test_image_df, test_batch_df, config, copy.deepcopy(monitor), "test")
test_cm

                                                                                     

Validation Summary - Epoch -1
Accuracy: 0.3584 | Loss: 1.4201


array([[12886, 13842,   786],
       [16691, 38291,  3604],
       [ 2390, 12738,  3736]])

In [20]:
# good_as_bad = test_image_df[(test_image_df['label'] == 2) & (test_image_df['prediction'] == 0)]
# bad_as_good = test_image_df[(test_image_df['label'] == 0) & (test_image_df['prediction'] == 2)]

# for s in good_as_bad["img_path"]:
#     print(s)

In [21]:
# Save the final state. The epoch label comes from config.num_epochs (5 in this run)
# instead of a hard-coded literal that had to be kept in sync with the resume config.
save_checkpoint(
    model, optimizer, scheduler, config.num_epochs, global_step,
    config, val_accuracy, monitor, loss_history,
)

Checkpoint saved successfully: ./checkpoints/sanitycheck_20251030_013129/checkpoint_epoch5_step13271_acc0.4527_20251030_170239.pth
Epoch: 5 | Step: 13271 | Val Acc: 0.4527


'./checkpoints/sanitycheck_20251030_013129/checkpoint_epoch5_step13271_acc0.4527_20251030_170239.pth'