In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics import YOLO 
from transformers import CLIPVisionModel, CLIPImageProcessor
from torch.utils.data import Dataset, DataLoader
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from models import *
from custom_logging import *
from mean_teacher import *


In [17]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [18]:
def get_gpu_memory_usage():
    """Return ``torch.cuda.memory_allocated()`` scaled down by 1024 * 1024.

    Returns:
        The allocated-memory figure divided by 1024 twice (the original
        comment labels the unit as MB), or ``0`` when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0
    allocated = torch.cuda.memory_allocated()
    return allocated / 1024 / 1024

In [None]:
# Central hyper-parameter container for this notebook. Every value is a class
# attribute, so `config = Config()` below is just a convenient handle on them.
class Config:
    # Device selected in the earlier cell (looked up from module scope).
    device = device
    # Size of the classifier output; validate() sizes its per-class counters with it.
    num_classes = 3
    # Learning rate of the parameter group that is neither CLIP nor YOLO.
    initial_lr = 1e-4
    # Base learning rate of the CLIP / YOLO parameter groups.
    lr_backbone = 3e-5
    batch_size = 32
    num_epochs = 4
    # Epoch at which the training loop sets requires_grad=True on the backbone params.
    freeze_until_epoch = 0
    checkpoint_path = "./checkpoints"
    # The training loop logs whenever global_step is a multiple of this value.
    log_interval = 10
    # NOTE(review): not read anywhere in the visible part of the notebook — confirm usage.
    early_stopping_patience = 3
    # Enables the autocast + GradScaler branch of the training loop.
    use_amp = True
    # Upper bound of the consistency-loss weight (reached once warm-up finishes).
    consistency_weight = 0.5  
    # Passed as `alpha` to update_ema_variables() for the teacher update.
    ema_decay = 0.99
    warmup_steps = 10000 # linear scaling (batches) from 0 to consistency_weight for consistency loss
    
    # Create directories
    # Runs once, at class-definition time, as a side effect of executing this cell.
    os.makedirs(checkpoint_path, exist_ok=True)

config = Config()

In [20]:
# Every run gets its own timestamped experiment name under ./logs.
_run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
logger = Logger(
    log_dir="./logs",
    experiment_name=f"clip_yolo_mean_teacher_{_run_stamp}",
)
logger.log_config(config)

Logging to: ./logs/clip_yolo_mean_teacher_20250915_005919


# MODELS

In [None]:
class YOLOv11(nn.Module):
    """Feature extractor made of the leading layers of a pretrained YOLO model."""

    def __init__(self):
        super().__init__()
        self.model = YOLO("yolov11l-face.pt").model
        # Keep the first ten child layers. Per the original notes this stops after
        # SPPF (layer 9); slicing with [:7] instead would stop after C3k2 (layer 6).
        leading_layers = list(self.model.model.children())[:10]
        self.feature_model = nn.Sequential(*leading_layers)

    def forward(self, x):
        """Run ``x`` through the truncated layer stack and return its output."""
        features = self.feature_model(x)
        return features

# model = YOLOv11().to(device)
# features = model(images) 
# features.shape # torch.Size([B, 512, 20, 20])

In [22]:
# One shared image processor for all CLIP inputs; the collate function uses it
# to preprocess a whole batch of images at once.
CLIP_PROCESSOR = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-base-patch32"
)

class CLIP(nn.Module):
    """CLIP vision model wrapper that returns the pooled image representation."""

    def __init__(self):
        super().__init__()
        self.clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        # Hidden size of the vision model, i.e. the width of the final hidden
        # state before any projection (not the projection itself).
        self.clip_output_dim = self.clip_model.config.hidden_size

    def forward(self, x):
        """Feed already-processed inputs to the vision model.

        Args:
            x: Mapping that is unpacked as keyword arguments into the vision
                model (the rest of this notebook passes the CLIP_PROCESSOR
                output, which carries ``pixel_values``).

        Returns:
            The model's ``pooler_output``. Other notes in this notebook give its
            shape as [B, 768]; the previous ``512`` comment here conflicted
            with them.
        """
        return self.clip_model(**x).pooler_output

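# model = CLIP().to(device)
# images = [Image.open("image.jpg"), Image.open("image.jpg")]
# inputs = CLIP_PROCESSOR(images=images, return_tensors="pt").to(device)  # forward() unpacks processor output, not raw PIL images
# outputs = model(inputs)
# outputs.shape # torch.Size([B, 768])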

In [23]:
class BiggerClassifier(nn.Module):
    """Classifier head on top of concatenated CLIP and gated YOLO features.

    Args:
        output_dim: Number of logits produced by the final linear layer.
    """

    def __init__(self, output_dim=3):
        super().__init__()
        # Backbones. Per the notes in this notebook CLIP yields [B, 768] and the
        # truncated YOLO yields a [B, 512, 20, 20] feature map.
        self.clip = CLIP()
        self.yolo = YOLOv11()

        # Collapses the YOLO feature map's spatial dimensions to 1x1.
        self.yolo_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Learnable self-gate: maps the pooled YOLO vector to one sigmoid value
        # per sample, which then scales that same vector.
        self.yolo_gate = nn.Sequential(
            nn.Linear(512, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # MLP head; input width is 768 (CLIP) + 512 (pooled YOLO) = 1280.
        self.fc1 = nn.Linear(768 + 512, 1024)
        self.activation1 = nn.GELU()
        self.fc2 = nn.Linear(1024, 1024)
        self.activation2 = nn.GELU()
        self.fc3 = nn.Linear(1024, output_dim)

    def forward(self, clip_inputs, img_tensor):
        """Return class logits for one batch.

        Args:
            clip_inputs: Input forwarded unchanged to the CLIP wrapper.
            img_tensor: Image batch forwarded to the YOLO feature extractor.

        Returns:
            Output of the last linear layer, one row per sample.
        """
        clip_features = self.clip(clip_inputs)

        # Pool to [B, C, 1, 1], then drop the two trailing singleton dimensions.
        pooled_map = self.yolo_pool(self.yolo(img_tensor))
        yolo_features = pooled_map.squeeze(-1).squeeze(-1)
        gated_yolo_features = yolo_features * self.yolo_gate(yolo_features)

        fused = torch.cat([clip_features, gated_yolo_features], dim=1)

        hidden = self.activation1(self.fc1(fused))
        hidden = self.activation2(self.fc2(hidden))
        return self.fc3(hidden)


# DATA

In [24]:
# Converts an input image to a float32 tensor image; scale=True asks
# torchvision to rescale the value range during the dtype conversion.
to_tensor = transforms.Compose(
    [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)]
)

In [25]:
yolo_intermediate_input_size = 700
yolo_final_input_size = 640


def _yolo_resize():
    """Build the resize step shared by all YOLO pipelines.

    With an integer ``size`` torchvision resizes the shorter edge and keeps the
    aspect ratio. Nothing is padded here: the square 640x640 input comes from
    the crop that follows this step in each pipeline.
    """
    return transforms.Resize(
        size=yolo_intermediate_input_size,
        interpolation=transforms.InterpolationMode.BILINEAR,
        antialias=True,
    )


# NOTE: with the extra augmentations disabled, the weak and strong pipelines
# are the same sequence of steps (each draws its own random crop).
yolo_weak_transform = transforms.Compose([
    _yolo_resize(),
    # transforms.RandomRotation(degrees=(-5, 5), interpolation=transforms.InterpolationMode.BILINEAR, expand=True, fill=0),
    transforms.RandomCrop(yolo_final_input_size),
    # transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.02),
    # transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.005), transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
])

yolo_strong_transform = transforms.Compose([
    _yolo_resize(),
    # transforms.RandomRotation(degrees=(-15, 15), interpolation=transforms.InterpolationMode.BILINEAR, expand=True, fill=0),
    transforms.RandomCrop(yolo_final_input_size),
    # transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    # transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.01), transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
])

# Validation: same resize, but a deterministic center crop. The leading
# ToImage/ToDtype pair repeats the conversion `to_tensor` performs.
yolo_val_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    _yolo_resize(),
    transforms.CenterCrop(yolo_final_input_size),
])

In [26]:
# CLIP augmentation on the GPU was not worked out, so CLIP inputs get no
# augmentation at all: only a deterministic resize followed by a center crop.
clip_val_transform = transforms.Compose([
    transforms.Resize(
        size=256,
        interpolation=transforms.InterpolationMode.BILINEAR,
        antialias=True,
    ),
    transforms.CenterCrop(224),
])

In [None]:
def custom_collate_fn(batch):
    """Collate dataset samples into batched CLIP inputs and stacked tensors.

    The CLIP images of all samples are handed to ``CLIP_PROCESSOR`` in a single
    call; every other field is combined with ``torch.stack``.

    Args:
        batch: List of samples. Training samples have 5 fields
            (clip image, clip image, weak YOLO tensor, strong YOLO tensor, label);
            validation samples have 3 (clip image, YOLO tensor, label).

    Returns:
        Training: ``(clip_inputs, clip_inputs, yolo_weak, yolo_strong, labels,
        transform_time)`` where both CLIP entries are the same object.
        Validation: ``(clip_inputs, yolo_tensors, labels, transform_time)``.
        ``transform_time`` is the wall-clock time spent in this function.

    Raises:
        ValueError: If the samples have neither 5 nor 3 fields.
    """
    started_at = time.time()
    fields_per_sample = len(batch[0])

    if fields_per_sample not in (5, 3):
        raise ValueError(f"Unexpected batch format with {len(batch[0])} items per sample")

    def column(position):
        # All samples' values for one field, in batch order.
        return [sample[position] for sample in batch]

    clip_inputs = CLIP_PROCESSOR(images=column(0), return_tensors="pt")

    if fields_per_sample == 5:
        # Field 1 repeats the CLIP image and is deliberately not read.
        yolo_weak_tensors = torch.stack(column(2))
        yolo_strong_tensors = torch.stack(column(3))
        labels = torch.stack(column(4))
        transform_time = time.time() - started_at
        return clip_inputs, clip_inputs, yolo_weak_tensors, yolo_strong_tensors, labels, transform_time

    yolo_tensors = torch.stack(column(1))
    labels = torch.stack(column(2))
    transform_time = time.time() - started_at
    return clip_inputs, yolo_tensors, labels, transform_time


In [28]:
class TrainImageDataset(Dataset):
    """Training dataset yielding one CLIP image and two YOLO views per sample.

    Args:
        csv_file: Annotation CSV, resolved relative to ``root_dir``. Column 0 is
            read as the image path (relative to ``root_dir``), column 1 as the label.
        root_dir: Directory containing the CSV and the images.
        yolo_weak_transform: Optional transform producing the weak YOLO view.
        yolo_strong_transform: Optional transform producing the strong YOLO view.
        clip_transform: Optional transform applied to the PIL image for CLIP.
        device: Stored on the instance but no longer used when loading samples;
            all returned tensors are CPU tensors and the training loop moves
            them to the target device.
    """

    def __init__(self, csv_file, root_dir, 
                 yolo_weak_transform=None, yolo_strong_transform=None,
                 clip_transform=None,
                 device=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(os.path.join(self.root_dir, csv_file))
        self.yolo_weak_transform = yolo_weak_transform
        self.yolo_strong_transform = yolo_strong_transform
        self.clip_transform = clip_transform
        self.device = device
        self.clip_processor = CLIP_PROCESSOR
        # Per-item transform timings (seconds).
        # NOTE(review): with num_workers > 0 each DataLoader worker presumably
        # appends to its own copy of the dataset, so the main-process list may
        # stay empty — confirm before relying on it for timing statistics.
        self.transform_times = []
        
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(clip_image, clip_image, yolo_weak_image, yolo_strong_image, label)``.
            The label is the CSV value plus one. If anything fails, the error is
            printed and an all-zero placeholder sample with label 1 is returned.
        """
        # Bound before the try block so the error handler can always report
        # something, even when building the path itself fails.
        img_path = f"<row {idx}>"
        try:
            img_path = os.path.join(self.root_dir, self.annotations.iloc[idx, 0])
            img_path = img_path.replace('\\', '/')
            image_pil = Image.open(img_path).convert('RGB')

            item_transform_start = time.time()
            
            clip_image = self.clip_transform(image_pil) if self.clip_transform else image_pil

            # Both YOLO views start from the same CPU tensor.
            image_tensor = to_tensor(image_pil)
            yolo_weak_image = self.yolo_weak_transform(image_tensor) if self.yolo_weak_transform else image_tensor
            yolo_strong_image = self.yolo_strong_transform(image_tensor) if self.yolo_strong_transform else image_tensor

            item_transform_time = time.time() - item_transform_start
            if len(self.transform_times) < 1000:  # Only the first 1000 timings are kept
                self.transform_times.append(item_transform_time)
                
            original_label = self.annotations.iloc[idx, 1]
            label = torch.tensor(original_label + 1, dtype=torch.long)
            
            return clip_image, clip_image, yolo_weak_image, yolo_strong_image, label

        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            dummy_clip_tensor = torch.zeros(3, 224, 224)
            # Stay on the CPU: the collate function stacks this with the CPU
            # tensors of regular samples, and moving tensors to CUDA inside a
            # DataLoader worker is not safe.
            dummy_yolo_tensor = torch.zeros(3, 640, 640)
            return dummy_clip_tensor, dummy_clip_tensor, dummy_yolo_tensor, dummy_yolo_tensor, torch.tensor(1, dtype=torch.long)


# Training split: weak/strong YOLO views plus the deterministic CLIP resize/crop.
train_dataset = TrainImageDataset(
    root_dir="./data/train",  # Windows-style alternative: ".\\data\\train"
    csv_file="train.csv",
    clip_transform=clip_val_transform,
    yolo_weak_transform=yolo_weak_transform,
    yolo_strong_transform=yolo_strong_transform,
    device=device,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=12,
    collate_fn=custom_collate_fn,
)

# Number of batches per epoch.
len(train_loader)

7820

In [29]:
class ValidationImageDataset(Dataset):
    """Validation dataset yielding one CLIP image and one YOLO tensor per sample.

    Args:
        csv_file: Annotation CSV, resolved relative to ``root_dir``. Column 0 is
            read as the image path (relative to ``root_dir``), column 1 as the label.
        root_dir: Directory containing the CSV and the images.
        clip_transform: Optional transform applied to the PIL image for CLIP.
        yolo_transform: Optional transform applied to the image tensor for YOLO.
        device: Stored on the instance; not used when loading samples.
    """

    def __init__(self, csv_file, root_dir, 
                 clip_transform=None, yolo_transform=None,
                 device=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(os.path.join(self.root_dir, csv_file))
        self.clip_transform = clip_transform
        self.yolo_transform = yolo_transform
        self.clip_processor = CLIP_PROCESSOR
        self.device = device
        self.transform_times = []

        # Resolve every path and label once up front so __getitem__ only has
        # to index plain Python lists.
        row_count = len(self.annotations)
        self.file_paths = [
            os.path.join(self.root_dir, self.annotations.iloc[row, 0])
            for row in range(row_count)
        ]
        self.labels = [self.annotations.iloc[row, 1] for row in range(row_count)]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(clip_image, yolo_image, label)``.

        The label is the stored value plus one so it can be used as a class
        index (0=BAD, 1=UNLABELED, 2=GOOD). If anything fails, the error is
        printed and an all-zero placeholder sample with label 1 is returned.
        """
        try:
            img_path = self.file_paths[idx].replace('\\', '/')
            image_pil = Image.open(img_path).convert('RGB')

            timer_start = time.time()

            if self.clip_transform:
                clip_image = self.clip_transform(image_pil)
            else:
                clip_image = image_pil

            image_tensor = to_tensor(image_pil)
            if self.yolo_transform:
                yolo_image = self.yolo_transform(image_tensor)
            else:
                yolo_image = image_tensor

            elapsed = time.time() - timer_start
            # Only the first 1000 timings are recorded.
            if len(self.transform_times) < 1000:
                self.transform_times.append(elapsed)

            label = torch.tensor(self.labels[idx] + 1, dtype=torch.long)
            return clip_image, yolo_image, label

        except Exception as e:
            print(f"Error in val dataset loading image {self.file_paths[idx]}: {e}")
            # Placeholder sample so one unreadable file does not abort the run.
            placeholder_label = torch.tensor(1, dtype=torch.long)
            return torch.zeros(3, 224, 224), torch.zeros(3, 640, 640), placeholder_label
        
# Validation split: deterministic CLIP and YOLO preprocessing.
val_dataset = ValidationImageDataset(
    root_dir="./data/val",  # Windows-style alternative: ".\\data\\val"
    csv_file="val.csv",
    yolo_transform=yolo_val_transform,
    clip_transform=clip_val_transform,
)

# NOTE(review): shuffle=True is kept as-is; validation metrics are aggregated
# over the whole loader, so shuffle=False would presumably work equally well
# and make per-batch progress output repeatable — confirm before changing.
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=12,
    collate_fn=custom_collate_fn,
)

# Number of validation batches.
len(val_loader)

978

In [30]:
# Smoke test: pull one training batch and print the shape of every field.
for first_batch in train_loader:
    if len(first_batch) == 6:  # training batches carry six fields
        (clip_weak_batch, clip_strong_batch,
         yolo_weak_batch, yolo_strong_batch,
         label_batch, transform_time) = first_batch
        print(clip_weak_batch['pixel_values'].shape, clip_strong_batch['pixel_values'].shape)
        print(yolo_weak_batch.shape, yolo_strong_batch.shape)
        print(label_batch.shape)
        print(f"single transform time in collate: {transform_time:.4f}s")
    break  # one batch is enough

torch.Size([32, 3, 224, 224]) torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 640, 640]) torch.Size([32, 3, 640, 640])
torch.Size([32])
single transform time in collate: 0.4663s


# TRAINING

In [31]:
# Student network, trained by the optimizer.
model = BiggerClassifier().to(config.device)

# The teacher starts as an exact copy of the student. It stays in eval mode and
# none of its parameters take part in autograd.
teacher_model = copy.deepcopy(model).to(config.device)
teacher_model.eval()
for teacher_param in teacher_model.parameters():
    teacher_param.requires_grad = False

In [32]:
class AsymmetricFocalLoss(nn.Module):
    """Focal loss with optional per-class weights and confusion-dependent penalties.

    Per sample the loss is ``(1 - p_t) ** gamma * CE * penalty[true, predicted]``,
    optionally multiplied by ``alpha[true]``; the batch mean is returned.
    Useful when class imbalance also has to be handled. The original author
    suspected this loss of being unstable.

    Args:
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Optional 1-D tensor of per-class weights, indexed by the true class.
        confusion_penalty_matrix: Optional matrix indexed as
            ``[true_class, predicted_class]`` (rows: BAD, NEUTRAL, GOOD).
            Defaults to all ones, i.e. no extra penalty.

    Both tensors are registered as non-persistent buffers: they follow
    ``module.to(device)`` but are not added to ``state_dict()``.
    """
    def __init__(self, gamma=2.0, alpha=None, confusion_penalty_matrix=None):
        super().__init__()
        self.gamma = gamma

        if confusion_penalty_matrix is None:
            confusion_penalty_matrix = torch.tensor([
                [1.0, 1.0, 1.0],   # True: BAD 
                [1.0, 1.0, 1.0],   # True: NEUTRAL
                [1.0, 1.0, 1.0]    # True: GOOD
            ])

        # Buffers (rather than plain attributes) so that moving the module also
        # moves these tensors; otherwise the default CPU matrix cannot be
        # indexed with targets that live on another device.
        self.register_buffer("alpha", alpha, persistent=False)
        self.register_buffer("confusion_penalty_matrix", confusion_penalty_matrix, persistent=False)
        
    def forward(self, logits, targets):
        """Return the mean weighted focal loss.

        Args:
            logits: Class scores; softmax and argmax are taken over dim 1.
            targets: Integer class indices, one per row of ``logits``.
        """
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

        # Probability the model assigns to the true class of each sample.
        probs = F.softmax(logits, dim=1)
        p_t = probs.gather(1, targets.view(-1, 1)).squeeze(1)
        
        # Focal term: down-weights samples that are already classified confidently.
        focal_weight = (1 - p_t) ** self.gamma
        
        # Penalty looked up from the (true class, predicted class) pair.
        pred_classes = torch.argmax(logits, dim=1)
        penalties = self.confusion_penalty_matrix[targets, pred_classes]
        
        loss = focal_weight * ce_loss * penalties
        
        if self.alpha is not None:
            alpha_t = self.alpha.gather(0, targets)
            loss = alpha_t * loss
            
        return loss.mean()


In [33]:
# Split the student's parameters into three groups by substring match on the
# parameter name, so each group can get its own learning rate.
# NOTE(review): the 'yolo' match also captures the `yolo_gate` layers of the
# classifier head, which therefore train with the backbone learning rate.
# The grouping is left unchanged because saved optimizer states depend on it.
clip_params = []
yolo_params = []
classifier_params = []

for name, param in model.named_parameters():
    if 'clip' in name:
        clip_params.append(param)
    elif 'yolo' in name:
        yolo_params.append(param)
    else:
        classifier_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': clip_params, 'lr': config.lr_backbone*0.1, 'weight_decay': 1e-4}, # CLIP is no longer getting data augmentations, so lowering LR
    {'params': yolo_params, 'lr': config.lr_backbone, 'weight_decay': 1e-4},
    {'params': classifier_params, 'lr': config.initial_lr, 'weight_decay': 1e-4},
])

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=1)
scaler = GradScaler() if config.use_amp else None
monitor = PerformanceMonitor()

# Training criterion. A plain nn.CrossEntropyLoss used to be assigned to
# `criterion` just before this and was immediately overwritten; that dead
# assignment has been removed.
criterion = AsymmetricFocalLoss(
    gamma=1.5,  # Reduced from 2.0 for less aggressive focusing
    alpha=torch.tensor([1.2, 0.8, 1.2]).to(config.device),  
    confusion_penalty_matrix=torch.tensor([
        [1.0, 1.05, 1.15], 
        [0.95, 1.0, 0.95],
        [1.1, 1.05, 1.0]  
    ]).to(config.device)
).to(config.device)


In [34]:
def load_checkpoint(checkpoint_path, config, device='cuda'):
    """Rebuild student, teacher, optimizer and scheduler from a checkpoint file.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        config: Object providing ``lr_backbone`` and ``initial_lr``.
        device: Device the checkpoint is mapped to and the models are moved to.

    Returns:
        Dict with keys ``model``, ``teacher_model``, ``optimizer``, ``scheduler``,
        ``start_epoch`` (saved epoch + 1, or 0 if none was saved),
        ``best_accuracy`` (saved ``val_accuracy``, or 0) and ``checkpoint``
        (the raw loaded dict).

    Each piece of state is only restored if its key is present in the
    checkpoint; otherwise the freshly constructed object is returned. The
    constructor defaults used here (CLIP learning rate, scheduler period) mirror
    the training setup cell, so a checkpoint without optimizer/scheduler state
    resumes with the same hyper-parameters as a fresh run.
    """
    print(f"Loading checkpoint from: {checkpoint_path}")
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    if 'config' in checkpoint:
        checkpoint_config = checkpoint['config']
        print("Checkpoint configuration:")
        for key, value in checkpoint_config.items():
            print(f"  {key}: {value}")
    
    model = BiggerClassifier().to(device)
    teacher_model = BiggerClassifier().to(device)
    
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Loaded student model weights")
    
    if 'teacher_state_dict' in checkpoint:
        teacher_model.load_state_dict(checkpoint['teacher_state_dict'])
        print("Loaded teacher model weights")
    
    # Match the fresh-start teacher: inference mode and no gradients. Without
    # eval() the restored teacher would run in training mode.
    teacher_model.eval()
    for p in teacher_model.parameters():
        p.requires_grad = False
    
    # Same name-based grouping as the training setup cell. Saved optimizer
    # states rely on the groups being built identically.
    clip_params = []
    yolo_params = []
    classifier_params = []
    
    for name, param in model.named_parameters():
        if 'clip' in name:
            clip_params.append(param)
        elif 'yolo' in name:
            yolo_params.append(param)
        else:
            classifier_params.append(param)
    
    optimizer = torch.optim.AdamW([
        # CLIP uses a tenth of the backbone learning rate, as in the training setup.
        {'params': clip_params, 'lr': config.lr_backbone*0.1, 'weight_decay': 1e-4},
        {'params': yolo_params, 'lr': config.lr_backbone, 'weight_decay': 1e-4},
        {'params': classifier_params, 'lr': config.initial_lr, 'weight_decay': 1e-4},
    ])
    
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Loaded optimizer state")
    
    # T_0=1 matches the scheduler built in the training setup cell.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=1)
    
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Loaded scheduler state")
    
    start_epoch = checkpoint.get('epoch', -1) + 1
    best_accuracy = checkpoint.get('val_accuracy', 0)
    
    print(f"\nCheckpoint info:")
    print(f"  Saved at epoch: {checkpoint.get('epoch', 'unknown')}")
    print(f"  Best validation accuracy: {best_accuracy:.4f}")
    print(f"  Will resume from epoch: {start_epoch}")
    
    return {
        'model': model,
        'teacher_model': teacher_model,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'start_epoch': start_epoch,
        'best_accuracy': best_accuracy,
        'checkpoint': checkpoint
    }

In [35]:
def validate(model, val_loader, config, monitor, epoch):
    """Evaluate ``model`` on ``val_loader``, save probability histograms, update ``monitor``.

    Args:
        model: Called as ``model(clip_inputs, yolo_tensors)``; switched to eval
            mode for the duration and back to train mode before returning.
        val_loader: Yields ``(clip_inputs, yolo_tensors, labels, transform_time)``.
        config: Provides ``device`` and ``num_classes``.
        monitor: Its ``accuracy_history``, ``best_accuracy`` and
            ``epochs_without_improvement`` attributes are updated in place.
        epoch: Used in the progress-bar text, plot title and histogram file name.

    Returns:
        ``(accuracy, metrics, cm)``: overall accuracy, a dict of ``val/...``
        scalars (accuracy, mean per-batch loss, per-class recall, precision and
        F1) and the confusion matrix.

    Side effects:
        Writes ``softmax_histograms_epoch_{epoch}.png`` into the directory
        returned by the module-level ``logger.get_log_dir()``.
    """
    class_names = ['bad', 'neutral', 'good']
    histogram_colors = ['red', 'gray', 'green']
    # Passed explicitly to scikit-learn so the confusion matrix and report
    # always cover every class, even one that is absent from this run.
    class_ids = list(range(len(class_names)))

    model.eval()
    total, correct = 0, 0
    class_correct = [0] * config.num_classes
    class_total = [0] * config.num_classes
    all_preds, all_labels = [], []
    val_loss = 0
    # One list of softmax probabilities per class, in class_names order.
    class_probs = [[] for _ in class_names]

    total_batches = len(val_loader)
    
    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f'Validation Epoch {epoch}', total=total_batches, leave=False)
        for batch_data in val_pbar:
            clip_inputs, yolo_tensors, labels, transform_time = batch_data
            clip_inputs['pixel_values'] = clip_inputs['pixel_values'].to(config.device)
            yolo_tensors = yolo_tensors.to(config.device)
            labels = labels.to(config.device)
            
            outputs = model(clip_inputs, yolo_tensors)
            loss = F.cross_entropy(outputs, labels)
            val_loss += loss.item()

            softmax_probs = F.softmax(outputs, dim=1)
            for class_id, bucket in zip(class_ids, class_probs):
                bucket.extend(softmax_probs[:, class_id].cpu().numpy())
            
            preds = outputs.argmax(dim=1)
            preds_cpu = preds.cpu().numpy()
            labels_cpu = labels.cpu().numpy()
            all_preds.extend(preds_cpu)
            all_labels.extend(labels_cpu)
            
            total += labels.size(0)
            correct += (preds == labels).sum().item()
            
            for pred, label in zip(preds_cpu, labels_cpu):
                class_total[label] += 1
                class_correct[label] += int(pred == label)
            
            val_pbar.set_postfix({'loss': f'{loss.item():.4f}', 
                                  'acc': f'{correct/total:.4f}'})
    
    # One histogram of predicted probabilities per class.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, class_name, color, probs in zip(axes, class_names, histogram_colors, class_probs):
        ax.hist(probs, bins=50, alpha=0.7, color=color, edgecolor='black')
        ax.set_title(f'{class_name.upper()} Class Probabilities')
        ax.set_xlabel('Softmax Probability')
        ax.set_ylabel('Frequency')
        ax.set_xlim([0, 1])
        ax.grid(True, alpha=0.3)
    
    plt.suptitle(f'Softmax Probability Distributions - Epoch {epoch}')
    plt.tight_layout()
    histogram_path = os.path.join(logger.get_log_dir(), f'softmax_histograms_epoch_{epoch}.png')
    plt.savefig(histogram_path, dpi=100)
    plt.close(fig)

    accuracy = correct / total 
    # Mean of the per-batch losses (a smaller final batch weighs as much as a full one).
    avg_val_loss = val_loss / total_batches
    
    class_recall = [class_correct[i] / class_total[i] if class_total[i] else 0 for i in range(config.num_classes)]
    
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    report = classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0,  # classes without predictions/samples score 0
    )
    
    metrics = {
        "val/accuracy": accuracy,
        "val/loss": avg_val_loss,
        "val/recall_bad": class_recall[0],
        "val/recall_neutral": class_recall[1],
        "val/recall_good": class_recall[2],
    }
    
    for class_name in class_names:
        if class_name in report:
            metrics[f"val/precision_{class_name}"] = report[class_name]['precision']
            metrics[f"val/f1_{class_name}"] = report[class_name]['f1-score']
    
    monitor.accuracy_history.append(accuracy)
    
    if accuracy > monitor.best_accuracy:
        monitor.best_accuracy = accuracy
        monitor.epochs_without_improvement = 0
    else:
        monitor.epochs_without_improvement += 1
    
    model.train()
    return accuracy, metrics, cm

In [36]:
def plot_running_loss(loss_history, save_path, window_size=50, batches_per_epoch=None):
    """
    Plot the raw loss and its moving average, then save the figure to a file.

    Nothing is plotted or written when ``loss_history`` holds fewer than
    ``window_size`` values.

    Args:
        loss_history: List of loss values, one per batch
        save_path: Path to save the plot
        window_size: Window size for moving average
        batches_per_epoch: Number of batches per epoch, used to draw epoch
            boundary markers. When None, ``len(train_loader)`` is used if a
            module-level ``train_loader`` exists; otherwise no markers are drawn.
    """
    if len(loss_history) < window_size:
        return
    
    # moving_avg[k] averages the window that ends at batch index window_size - 1 + k.
    moving_avg = [
        sum(loss_history[end - window_size:end]) / window_size
        for end in range(window_size, len(loss_history) + 1)
    ]
    
    plt.figure(figsize=(12, 6))
    
    # Raw loss in a light colour, moving average on top in bold.
    plt.plot(loss_history, alpha=0.3, color='blue', label='Raw Loss')
    x_moving = list(range(window_size - 1, len(loss_history)))
    plt.plot(x_moving, moving_avg, color='red', linewidth=2, label=f'Moving Avg (window={window_size})')
    
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Time')
    plt.grid(True, alpha=0.3)
    
    # Reference lines for the latest and the lowest moving average.
    if moving_avg:
        current_avg = moving_avg[-1]
        min_avg = min(moving_avg)
        plt.axhline(y=current_avg, color='green', linestyle='--', alpha=0.5, label=f'Current: {current_avg:.4f}')
        plt.axhline(y=min_avg, color='orange', linestyle='--', alpha=0.5, label=f'Min: {min_avg:.4f}')
    
    # Build the legend only after every labelled artist exists; created earlier,
    # it would omit the 'Current' and 'Min' reference lines.
    plt.legend()
    
    if batches_per_epoch is None and 'train_loader' in globals() and len(train_loader) > 0:
        batches_per_epoch = len(train_loader)
    
    if batches_per_epoch:
        num_epochs = len(loss_history) // batches_per_epoch
        for epoch in range(1, num_epochs + 1):
            epoch_batch = epoch * batches_per_epoch
            if epoch_batch < len(loss_history):
                plt.axvline(x=epoch_batch, color='gray', linestyle=':', alpha=0.5)
                plt.text(epoch_batch, plt.ylim()[1] * 0.95, f'Epoch {epoch}', 
                        rotation=90, verticalalignment='top', fontsize=8, alpha=0.7)
                
    plt.tight_layout()
    plt.savefig(save_path, dpi=100)
    plt.close()
    
    # print(f"  Loss graph saved to: {save_path}")

In [37]:
print("\nStarting training...")
print(f"Tensorboard logs: tensorboard --logdir {logger.get_log_dir()}")

# Progress counters for a fresh run.
start_epoch, global_step = 0, 0
best_accuracy = 0

# Timing diagnostics and the per-batch loss trace filled in by the training loop.
data_loading_times, gpu_compute_times, transform_times = [], [], []
loss_history = []

batches_per_half_epoch = len(train_loader) // 2


# To resume from a checkpoint instead, uncomment the lines below.
# loaded = load_checkpoint("checkpoints/best_model_acc0.7807.pth", config)
# model = loaded['model']
# teacher_model = loaded['teacher_model']
# optimizer = loaded['optimizer']
# scheduler = loaded['scheduler']
# start_epoch = loaded['start_epoch']
# best_accuracy = loaded['best_accuracy']
# monitor.best_accuracy = best_accuracy
# global_step = start_epoch * len(train_loader)



Starting training...
Tensorboard logs: tensorboard --logdir ./logs/clip_yolo_mean_teacher_20250915_005919


In [38]:
for epoch in range(start_epoch, config.num_epochs):
    if epoch == config.freeze_until_epoch:
        print(f"Unfreezing backbone at epoch {epoch}")
        for param in clip_params + yolo_params:
            param.requires_grad = True
    
    model.train()
    epoch_start = time.time()
    epoch_loss = 0
    epoch_cls_loss = 0
    epoch_consistency_loss = 0

    epoch_pbar = tqdm(train_loader, desc=f'Epoch {epoch}', total=len(train_loader))
    
    for batch_idx, batch_data in enumerate(epoch_pbar):
        batch_start = time.time()
        
        clip_weak, clip_strong, yolo_weak, yolo_strong, labels, collate_transform_time = batch_data
        clip_weak['pixel_values'] = clip_weak['pixel_values'].to(config.device)
        clip_strong['pixel_values'] = clip_strong['pixel_values'].to(config.device) 
        yolo_weak = yolo_weak.to(config.device)
        yolo_strong = yolo_strong.to(config.device)
        labels = labels.to(config.device)
        
        data_load_time = time.time() - batch_start

        dataset_transform_time = 0
        if hasattr(train_dataset, 'transform_times') and train_dataset.transform_times:
            # Get recent average from dataset
            recent_transforms = train_dataset.transform_times[-100:]
            dataset_transform_time = np.mean(recent_transforms) * config.batch_size  # Scale by batch size
        
        total_transform_time = dataset_transform_time + collate_transform_time

        compute_start = time.time()
        
        if config.use_amp:
            with autocast(device_type='cuda'):

                student_outputs = model(clip_strong, yolo_strong)
                cls_loss = criterion(student_outputs, labels)
                
                with torch.no_grad():
                    teacher_outputs = teacher_model(clip_weak, yolo_weak)
                
                # Compute per-sample consistency losses
                consistency_losses = compute_consistency_loss(student_outputs, teacher_outputs)
                
                # Create per-image consistency weights (can be customized per image if needed)
                warmup_factor = min(1.0, global_step / config.warmup_steps)
                base_weight = config.consistency_weight * warmup_factor
                consistency_weights = torch.zeros_like(consistency_losses)
                unlabeled_mask = (labels == 1)
                labeled_mask = ~unlabeled_mask
                consistency_weights[unlabeled_mask] = base_weight * 2.0
                consistency_weights[labeled_mask] = base_weight / 2.0
                # consistency_weights = torch.full_like(consistency_losses, config.consistency_weight * warmup_factor)
                
                # Apply per-image weights and compute mean
                weighted_consistency = (consistency_losses * consistency_weights).mean()
                loss = cls_loss + weighted_consistency
            
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            student_outputs = model(clip_strong, yolo_strong)
            cls_loss = criterion(student_outputs, labels)
            
            with torch.no_grad():
                teacher_outputs = teacher_model(clip_weak, yolo_weak)
            
            # Compute per-sample consistency losses
            consistency_losses = compute_consistency_loss(student_outputs, teacher_outputs)
            
            # Create per-image consistency weights (can be customized per image if needed)
            warmup_factor = min(1.0, global_step / config.warmup_steps)
            base_weight = config.consistency_weight * warmup_factor
            consistency_weights = torch.zeros_like(consistency_losses)
            unlabeled_mask = (labels == 1)
            labeled_mask = ~unlabeled_mask
            consistency_weights[unlabeled_mask] = base_weight * 2.0
            consistency_weights[labeled_mask] = base_weight / 2.0
            # consistency_weights = torch.full_like(consistency_losses, config.consistency_weight * warmup_factor)
            
            # Apply per-image weights and compute mean
            weighted_consistency = (consistency_losses * consistency_weights).mean()
            loss = cls_loss + weighted_consistency
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        
        update_ema_variables(model, teacher_model, alpha=config.ema_decay, global_step=global_step)
        
        gpu_compute_time = time.time() - compute_start
        total_time = time.time() - batch_start
        
        if len(data_loading_times) < 1000:
            data_loading_times.append(data_load_time)
            gpu_compute_times.append(gpu_compute_time)
            transform_times.append(total_transform_time)
        
        loss_history.append(loss.item())

        epoch_loss += loss.item()
        epoch_cls_loss += cls_loss.item()
        epoch_consistency_loss += weighted_consistency.item()
        
        images_per_second = config.batch_size / total_time
        
        epoch_pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'img/s': f'{images_per_second:.1f}',
            'data_ms': f'{data_load_time*1000:.1f}',
            'trans_ms': f'{total_transform_time*1000:.1f}',
            'gpu_ms': f'{gpu_compute_time*1000:.1f}'
        })

        if global_step > 0 and global_step % 1000 == 0:
            loss_graph_path = os.path.join(logger.get_log_dir(), f'loss_graph.png')
            plot_running_loss(loss_history, loss_graph_path)
        
        if global_step % config.log_interval == 0:
            avg_data_time = np.mean(data_loading_times[-100:]) if data_loading_times else 0
            avg_gpu_time = np.mean(gpu_compute_times[-100:]) if gpu_compute_times else 0
            avg_transform_time = np.mean(transform_times[-100:]) if transform_times else 0
            current_lr = optimizer.param_groups[0]['lr']
            
            train_metrics = {
                "train/loss": loss.item(),
                "train/cls_loss": cls_loss.item(),
                "train/consistency": weighted_consistency.item(),
                "train/consistency_weight": consistency_weights[0].item(),  # Log first sample's weight as example
                "train/learning_rate": current_lr,
                "train/images_per_second": images_per_second,
                "timing/data_load_ms": avg_data_time * 1000,
                "timing/transform_ms": avg_transform_time * 1000,
                "timing/gpu_compute_ms": avg_gpu_time * 1000,
                "timing/data_load_percentage": (avg_data_time / (avg_data_time + avg_gpu_time + avg_transform_time)) * 100 if (avg_data_time + avg_gpu_time + avg_transform_time) > 0 else 0,
                "timing/transform_percentage": (avg_transform_time / (avg_data_time + avg_gpu_time + avg_transform_time)) * 100 if (avg_data_time + avg_gpu_time + avg_transform_time) > 0 else 0, 
                "timing/gpu_percentage": (avg_gpu_time / (avg_data_time + avg_gpu_time + avg_transform_time)) * 100 if (avg_data_time + avg_gpu_time + avg_transform_time) > 0 else 0,
                                "system/gpu_memory_mb": get_gpu_memory_usage()
            }
            
            logger.log_metrics(train_metrics, global_step)
            
            logger.log_train_step(global_step, epoch, {
                'loss': loss.item(),
                'cls_loss': cls_loss.item(),
                'consistency_loss': weighted_consistency.item(),
                'learning_rate': current_lr,
                'consistency_weight': consistency_weights[0].item()
            })
            
            logger.log_system_metrics(global_step, {
                'images_per_second': images_per_second,
                'data_load_ms': avg_data_time * 1000,
                'transform_ms': avg_transform_time * 1000,
                'gpu_compute_ms': avg_gpu_time * 1000,
                'queue_size': 0,
                'gpu_memory_mb': get_gpu_memory_usage()
            })

        global_step += 1
            
    epoch_time = time.time() - epoch_start
    avg_epoch_loss = epoch_loss / len(train_loader) if len(train_loader) > 0 else 0
    avg_epoch_cls_loss = epoch_cls_loss / len(train_loader) if len(train_loader) > 0 else 0
    avg_epoch_consistency_loss = epoch_consistency_loss / len(train_loader) if len(train_loader) > 0 else 0
    avg_epoch_transform_time = np.mean(transform_times) if transform_times else 0
    
    print(f"\nEpoch {epoch} Summary:")
    print(f"  Time: {epoch_time/60:.1f} minutes")
    print(f"  Avg Loss: {avg_epoch_loss:.4f}")
    print(f"  Avg Classification Loss: {avg_epoch_cls_loss:.4f}")
    print(f"  Avg Consistency Loss: {avg_epoch_consistency_loss:.4f}")
    print(f"  Avg Transform Time: {avg_epoch_transform_time*1000:.1f}ms")
    print(f"  Throughput: {len(train_dataset) / epoch_time:.1f} images/second")
    
    print("Running end of epoch validation...")
    val_acc, val_metrics, cm = validate(teacher_model, val_loader, config, monitor, epoch)
    
    logger.log_metrics(val_metrics, epoch)
    logger.log_validation(epoch, val_metrics)
    logger.log_confusion_matrix(cm, ['bad', 'neutral', 'good'], epoch)
    
    print(f"  Validation Accuracy: {val_acc:.4f}")
    print(f"  Best Accuracy: {monitor.best_accuracy:.4f}")
    
    if val_acc > best_accuracy:
        best_accuracy = val_acc
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'teacher_state_dict': teacher_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_accuracy': val_acc,
            'config': vars(config)
        }
        checkpoint_path = os.path.join(config.checkpoint_path, f"best_model_acc{val_acc:.4f}.pth")
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved best model: {checkpoint_path}")
    
    scheduler.step()
    
    if monitor.epochs_without_improvement >= config.early_stopping_patience:
        print(f"Early stopping at epoch {epoch}")
        break


Unfreezing backbone at epoch 0


Epoch 0:   4%|▎         | 274/7820 [02:34<1:10:47,  1.78it/s, loss=0.3101, img/s=71.7, data_ms=66.0, trans_ms=225.1, gpu_ms=380.1]    


KeyboardInterrupt: 

In [None]:
# Write the accumulated `loss_history` curve into this run's log directory,
# then close the logger.
final_loss_graph_path = os.path.join(
    logger.get_log_dir(),
    'final_loss_graph.png',
)
plot_running_loss(
    loss_history,
    final_loss_graph_path,
    window_size=100,
)
logger.close()

In [None]:
# ~1.9 it/s is about as good as it gets (basically GPU-bottlenecked)
# 1.6 it/s is without random rotation


# INFERENCE 

In [None]:
# Evaluate the teacher model on the test split.
# The root directory is assembled with os.path.join so the cell runs on both
# Windows and POSIX; a literal ".\\data\\test" only resolves on Windows
# (on Windows the joined value is the very same string).
test_dataset = ValidationImageDataset(
    csv_file="test.csv",
    root_dir=os.path.join(".", "data", "test"),
    clip_transform=clip_val_transform,
    yolo_transform=yolo_val_transform,
)
# batch_size / shuffle are passed by keyword so their meaning is explicit;
# shuffle=False keeps the test order fixed.
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=custom_collate_fn,
)
print(f"Test batches: {len(test_loader)}")

# -1 goes in the slot where the training loop passes the epoch number
# (presumably a sentinel for "not a training epoch").
# NOTE(review): the training `monitor` is passed in as well - confirm that
# validate() does not update best-accuracy / early-stopping state from test data.
test_acc, test_metrics, cm = validate(teacher_model, test_loader, config, monitor, -1)
print(f"Test accuracy: {test_acc:.4f}")

# Inference functions
def load_for_inference(checkpoint_path, device='cuda'):
    """Load a trained classifier from a checkpoint and put it in eval mode.

    The teacher weights are preferred when the checkpoint has them; otherwise
    the student weights are used.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        A ``BiggerClassifier`` with the selected weights loaded, in eval mode.

    Raises:
        KeyError: If the checkpoint has neither ``'teacher_state_dict'`` nor
            ``'model_state_dict'``. This is checked before the model is
            constructed, so an unusable checkpoint fails fast.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Pick the weights first so a bad checkpoint is rejected before paying
    # for model construction.
    if 'teacher_state_dict' in checkpoint:
        state_dict = checkpoint['teacher_state_dict']
        source = "teacher"
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        source = "student"
    else:
        raise KeyError("No model found in checkpoint")

    model = BiggerClassifier().to(device)
    model.load_state_dict(state_dict)
    print(f"Loaded {source} model for inference")

    model.eval()
    return model

def infer(pil_imgs, model=None, checkpoint_path=None, device='cuda'):
    """
    Run the classifier on one or more PIL images and return class probabilities.

    Args:
        pil_imgs: A single PIL image, or a list/tuple of PIL images.
        model: Loaded model (if None, it is loaded from checkpoint_path).
            A model passed in is used as given: this function neither moves
            it to `device` nor switches it to eval mode.
        checkpoint_path: Path to checkpoint (only used if model is None)
        device: Device the YOLO input batch is moved to, and the device used
            when loading from checkpoint_path.

    Returns:
        Softmax over dim 1 of the model output
        (expected shape [batch_size, num_classes]).

    Raises:
        ValueError: If neither model nor checkpoint_path is provided, or if
            pil_imgs is an empty list/tuple.
    """
    if model is None and checkpoint_path is None:
        raise ValueError("Either model or checkpoint_path must be provided")

    # Accept a single image as well as a list/tuple of images.
    if isinstance(pil_imgs, (list, tuple)):
        pil_imgs = list(pil_imgs)
    else:
        pil_imgs = [pil_imgs]

    # Reject an empty batch up front instead of failing inside torch.stack.
    if not pil_imgs:
        raise ValueError("pil_imgs must contain at least one image")

    if model is None:
        model = load_for_inference(checkpoint_path, device)

    # Prepare images
    clip_images = []
    yolo_tensors = []

    for img in pil_imgs:
        # Ensure RGB
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Apply validation transforms
        clip_images.append(clip_val_transform(img))
        yolo_tensors.append(yolo_val_transform(img))

    # Stack YOLO tensors; the CLIP inputs are handed to the model as a plain list.
    yolo_batch = torch.stack(yolo_tensors).to(device)

    # Run inference
    with torch.no_grad():
        logits = model(clip_images, yolo_batch)
        probs = F.softmax(logits, dim=1)

    return probs


In [None]:

# Example usage:
# model = load_for_inference("./checkpoints/best_model_acc0.8328.pth")
# probs = infer([Image.open("test.jpg")], model=model)
# print(f"Predictions: {probs}")
# print(f"Predicted class: {probs.argmax(dim=1)}")  # 0=bad, 1=neutral, 2=good


# COMMENTS


SHIT I DID
1. the consistency loss has been corrected - now is on a per image basis
2. now using asymmetric focal loss
3. teacher model now in eval() always 
4. clip processor should be called not in forward pass, but in dataset
5. back to CE loss, for sanity checking
6. CLIP no longer gets data augmentation, to save on the disgusting CPU transform time
7. training loss graph is now shown, from beginning of training to current time, every 1000 batches 
8. also tracking the distribution of the preds in validate() as 3 histograms
9. the cosine scheduler now makes sense, using T_0=2, T_mult=1 (cycles of a constant 2 epochs)
10. consistency loss is halved on labeled data, and doubled on unlabeled data. 
11. Asymm focal loss has been softened, but doesn't seem to do better than CE? I need to validate more frequently... 
12. I think I'll keep running with the asymm focal loss for now? Currently: penalise predicting BAD/GOOD on UNLABELED data less. GOOD b/c we don't want to miss any goods. BAD b/c... . **the intended result of this is more true neutrals predicted bad/good in the confusion matrix**
13. transform pipeline is correct now. sanity checked. clip and yolo now see the same rotation
14. slapped a resnet on. can change between resnet50, 152, etc.


FUTURE STEPS
1. the consistency loss can be weighted by teacher confidence (already implemented in mean_teacher.py)
2. the start epochs can be fully supervised data..  disable teacher consistency? disable unlabeled data? 
    [meh i'm basically already doing that, long warmup period and stuff]
4. try resuming from a checkpoint, and also try a loss schedule of CE loss -> asymmetric focal loss, e.g. CE for epochs 0-3 and asymmetric focal for epochs 4-5
7. would like finer-grained control over learning rates, and also over how layers thaw... which should I freeze more, btw: the front or the back of the pretrained models? I should completely freeze the first X layers, then use some sort of linearly increasing LR for the later ones, I guess
8. I need to move away from cosine anneal. **I need to get my LR shit right**. ***top priority***

11. MUST: validate more frequently on the new incoming dataset


# SANDBOX


<function torchvision.transforms.functional.to_tensor(pic: Union[PIL.Image.Image, numpy.ndarray]) -> torch.Tensor>

In [None]:
from PIL import Image