In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import copy

In [None]:
# Experiment settings for the whole notebook. The training helpers below read
# this dict as a module-level global instead of receiving it as an argument,
# so this cell must run before any of them is called.
config = {
    # General
    "training_session": 1,  # suffix of the checkpoint filename built in train_mean_teacher

    # Model
    "pre_trained": True,  # forwarded to model_factory(pretrained=...)
    "learning_rate": 3e-4,  # Adam learning rate for the student
    "alpha": 0.99,  # EMA decay applied to the teacher in update_ema
    "lambda_u": 1.0,  # weight of the unsupervised (consistency) loss term
    "epochs": 10,  # number of epochs passed to train_mean_teacher

    # Dataset
    # NOTE(review): the three dataset keys below are not read in this notebook;
    # the whole dict is handed to dataloader_factory, which presumably consumes
    # them -- confirm against utilities/mt_dataloader_factory.
    "input_type": "image",  # selects the branch of the driver cell: "image", "text" or "tabular"
    "dataset_path": "../datasets/rvl-cdip-kyc",
    "num_labels_per_class": 400,
    "batch_size": 64,

    # Image input
    "image_classes": ["form", "invoice", "memo", "letter"],  # its length is used as num_classes
    "image_size": 224,  # used for both height and width: token_factory gets (image_size, image_size)

    # Text input
    # (no keys defined yet; the "text" branch of the driver cell is a placeholder)

    # Tabular input
    # Not read anywhere in the visible cells; the "tabular" branch of the
    # driver cell is still a placeholder.
    "categorical_columns": ["Cbal", "Chist", "Cpur", "Sbal", "MSG", "Oparties", "Prop", "inPlans", "Htype", "JobType", "telephone", "foreign"],
    "numeric_columns": ["Cdur", "Camt", "Edur", "InRate", "Rdur", "age", "NumCred", "Ndepend"],
    "target_column": "creditScore",
}

In [3]:
def update_ema(student_model, teacher_model, alpha=None):
    """Move the teacher's parameters toward the student's with an exponential moving average.

    Parameters are paired by their position in ``parameters()``; each teacher
    parameter becomes ``alpha * teacher + (1 - alpha) * student``.

    Args:
        student_model: Module whose current parameters are the update source.
        teacher_model: Module whose parameters are overwritten in place.
        alpha: EMA decay. ``None`` (the default) falls back to
            ``config["alpha"]``, which keeps existing two-argument calls
            behaving as before.

    Returns:
        None. ``teacher_model`` is modified in place.

    Note:
        Only ``parameters()`` are averaged. Buffers (for example batch-norm
        running statistics) are not touched by this function.
    """
    decay = config["alpha"] if alpha is None else alpha

    # In-place update under no_grad: avoids allocating a new tensor per
    # parameter on every step and avoids rebinding `.data`, so the teacher's
    # parameter storage stays the same across steps.
    with torch.no_grad():
        for student_param, teacher_param in zip(student_model.parameters(), teacher_model.parameters()):
            teacher_param.mul_(decay).add_(student_param.detach(), alpha=1 - decay)

In [4]:
def train_one_epoch(student_model, teacher_model, lb_loader, ulb_loader, optimizer, device, epoch):
    """Run one Mean Teacher training pass over the paired labeled/unlabeled loaders.

    Each step combines a cross-entropy loss on the labeled batch with an MSE
    consistency loss between the student's and the teacher's softmax outputs
    on the unlabeled batch, steps the optimizer, then refreshes the teacher
    through ``update_ema``.

    Args:
        student_model: Network being optimized.
        teacher_model: EMA network; only run under ``torch.no_grad()`` here.
        lb_loader: Iterable yielding ``(inputs, labels)`` pairs.
        ulb_loader: Iterable yielding two tensors per batch; the first goes to
            the teacher and the second to the student (presumably weak and
            strong augmentations of the same samples -- confirm against the
            dataloader factory).
        optimizer: Optimizer over the student's parameters.
        device: Device every batch tensor is moved to.
        epoch: Epoch number, used only in the printed summary.

    Returns:
        None. Prints the sum of the per-step losses for the epoch.

    Note:
        The loaders are consumed with ``zip``, so the epoch stops as soon as
        the shorter of the two is exhausted.
    """
    # Both networks are put in training mode, exactly as the loop expects.
    for network in (student_model, teacher_model):
        network.train()

    total_loss = 0
    for labeled_batch, unlabeled_batch in zip(lb_loader, ulb_loader):
        inputs, targets = (tensor.to(device) for tensor in labeled_batch)
        teacher_view, student_view = (tensor.to(device) for tensor in unlabeled_batch)

        # Supervised term: student on the labeled batch.
        supervised = F.cross_entropy(student_model(inputs), targets)

        # Consistency target: teacher probabilities, kept out of the graph.
        with torch.no_grad():
            teacher_probs = F.softmax(teacher_model(teacher_view), dim=1)

        # Consistency term: student probabilities vs. the teacher's.
        student_probs = F.softmax(student_model(student_view), dim=1)
        consistency = F.mse_loss(student_probs, teacher_probs)

        loss = supervised + config["lambda_u"] * consistency

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Teacher follows the freshly updated student.
        update_ema(student_model, teacher_model)

    print(f'Epoch {epoch} | Total Loss: {total_loss:.4f}')

In [5]:

def evaluate(model, val_loader, device):
    """Score ``model`` on ``val_loader`` and print accuracy and loss.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(acc, total_loss)``: ``acc`` is the fraction of samples whose
        argmax prediction equals the label, and ``total_loss`` is the sum
        (not the mean) of the per-batch cross-entropy values.
    """
    model.eval()
    predicted, expected = [], []
    total_loss = 0.0

    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            total_loss += F.cross_entropy(logits, batch_labels).item()

            predicted += logits.argmax(dim=1).cpu().tolist()
            expected += batch_labels.cpu().tolist()

    acc = (np.asarray(predicted) == np.asarray(expected)).mean()
    print(f"Validation Accuracy: {acc:.4f} | Loss: {total_loss:.4f}")
    return acc, total_loss

In [6]:
def train_mean_teacher(student_model, lb_loader, ulb_loader, val_loader, device="cuda", epochs=10):
    """Train ``student_model`` with a Mean Teacher and checkpoint the best student.

    A frozen deep copy of the student is used as the teacher. After every
    epoch the student is evaluated on ``val_loader``; whenever its accuracy
    beats the best seen so far, its ``state_dict`` is written to
    ``../models/mean_teacher/best_model_<training_session>.pt``.

    Args:
        student_model: Network to train; also the template for the teacher.
        lb_loader: Labeled loader forwarded to ``train_one_epoch``.
        ulb_loader: Unlabeled loader forwarded to ``train_one_epoch``.
        val_loader: Validation loader forwarded to ``evaluate``.
        device: Device passed through to training and evaluation.
        epochs: Number of epochs to run.

    Returns:
        Tuple ``(student_model, teacher_model)`` in their state after the last
        epoch. The best-accuracy weights are only available from the
        checkpoint file, not from the returned student.
    """
    # Local import: `os` is only imported in a later notebook cell, so this
    # function must not rely on cell execution order.
    import os

    teacher_model = copy.deepcopy(student_model)
    # The teacher is only ever updated through the EMA, never by the optimizer.
    for param in teacher_model.parameters():
        param.requires_grad = False

    optimizer = optim.Adam(student_model.parameters(), lr=config["learning_rate"])

    best_val_accuracy = 0.0
    # Single quotes inside the f-string: reusing the outer quote character is
    # only legal from Python 3.12 on and is a SyntaxError before that.
    best_model_path = f"../models/mean_teacher/best_model_{config['training_session']}.pt"
    # Create the checkpoint directory up front so a missing folder fails (or is
    # fixed) before any training time is spent, not at the first torch.save.
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    for epoch in range(1, epochs + 1):
        train_one_epoch(student_model, teacher_model, lb_loader, ulb_loader, optimizer, device, epoch)

        val_accuracy, _ = evaluate(student_model, val_loader, device)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(student_model.state_dict(), best_model_path)
            print(f"✅ Best model saved to {best_model_path} | Accuracy: {val_accuracy:.4f}")

    return student_model, teacher_model

In [7]:
import sys
import os
sys.path.append(os.path.abspath(".."))

import importlib
import utilities.mt_token_factory as tf
import utilities.mt_dataloader_factory as dl
import utilities.mt_model_factory as md

importlib.reload(tf)
importlib.reload(dl)
importlib.reload(md)

  from .autonotebook import tqdm as notebook_tqdm


<module 'utilities.mt_model_factory' from '/Users/dundale/Downloads/bpi-ssl/utilities/mt_model_factory.py'>

In [None]:
# Driver cell: pick the device, build the model and loaders for the configured
# input type, and launch Mean Teacher training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if config["input_type"] == "image":
    from utilities.mt_token_factory import token_factory
    from utilities.mt_model_factory import model_factory
    from utilities.mt_dataloader_factory import dataloader_factory

    # Square inputs: the same configured size is used for both dimensions.
    base_transform = token_factory("image", image_size=(config["image_size"], config["image_size"]))

    # Create model and dataloaders for image input
    model = model_factory(
        "image",
        num_classes=len(config["image_classes"]),
        pretrained=config["pre_trained"]
    ).to(device)
    lb_loader, ulb_loader, val_loader = dataloader_factory(config, base_transform)

    # Train Mean Teacher
    trained_student, trained_teacher = train_mean_teacher(
        model, lb_loader, ulb_loader, val_loader, device, config["epochs"]
    )

elif config["input_type"] == "text":
    # TODO: not implemented yet -- this branch currently does nothing.
    ...
elif config["input_type"] == "tabular":
    # TODO: not implemented yet -- this branch currently does nothing.
    ...
else:
    # Single quotes inside the f-string: reusing the outer quote character is
    # a SyntaxError on Python versions older than 3.12.
    raise ValueError(f"Unsupported input type: {config['input_type']}")



Epoch 1 | Total Loss: 20.9808
Validation Accuracy: 0.4025 | Loss: 17.0293
✅ Best model saved to ../models/mean_teacher/best_model_1.pt | Accuracy: 0.4025
Epoch 2 | Total Loss: 9.5580
Validation Accuracy: 0.4138 | Loss: 16.0269
✅ Best model saved to ../models/mean_teacher/best_model_1.pt | Accuracy: 0.4138
Epoch 3 | Total Loss: 5.7303
