### Preliminaries

In [177]:
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from torchvision.datasets import ImageFolder

In [178]:
import importlib
import token_factory as tf
import dataloader_factory as dl
import model_factory as md

# Re-import the local factory modules so edits to them are picked up without
# restarting the kernel (order: tokens, dataloaders, models).
for _factory_module in (tf, dl, md):
    importlib.reload(_factory_module)

<module 'model_factory' from '/Users/dundale/Downloads/bpi-ssl/algorithms/mean_teacher/model_factory.py'>

### Configuration

In [None]:
# Single source of truth for the run; every cell below reads settings by key.
# Empty-string / zero entries are placeholders to fill in before "Training Main".
config = dict(
    # --- General ---
    training_session=1,          # appended to the saved checkpoint file name
    seed=27,                     # random seed (used as random_state for the train/validation split)

    # --- Mean-Teacher model ---
    pre_trained=False,           # passed to model_factory and the text tokenizer
    learning_rate=3e-4,          # Adam learning rate for the student
    alpha=0.99,                  # EMA decay used when updating the teacher
    lambda_u=1.0,                # weight of the unsupervised consistency loss
    epochs=20,

    # --- Dataset ---
    input_type="",               # one of "image", "text", "tabular"
    labeled_dataset_path="",
    unlabeled_dataset_path="",
    validation_set_percentage=0, # test_size for train_test_split; must be set > 0 (e.g. 0.2)
    batch_size=64,

    # --- Image input ---
    image_size=(0, 0),

    # --- Text input ---
    text_column="",
    text_target_column="",

    # --- Tabular input ---
    categorical_columns=[],
    numeric_columns=[],
    tabular_target_column="",
    is_tabular_target_categorical=False,  # False => treated as regression (MSE / MAE)
)

### Training Logic

In [180]:
def update_ema(student_model, teacher_model, alpha=None):
    """Pull the teacher's parameters toward the student's with an exponential moving average.

    For each parameter pair (matched positionally via ``parameters()``) this sets
    ``teacher = alpha * teacher + (1 - alpha) * student`` in place.

    Only parameters are averaged; buffers such as normalization running
    statistics are not touched by this function.

    Args:
        student_model: Model whose parameters are being optimized.
        teacher_model: Model updated in place; must expose the same parameter
            layout as ``student_model``.
        alpha: EMA decay. ``None`` (default) falls back to ``config["alpha"]``.
    """
    if alpha is None:
        alpha = config["alpha"]
    # The EMA is not part of any autograd graph; updating in place under no_grad
    # avoids allocating a new tensor for every parameter on every step.
    with torch.no_grad():
        for student_param, teacher_param in zip(student_model.parameters(), teacher_model.parameters()):
            teacher_param.mul_(alpha).add_(student_param, alpha=1 - alpha)

In [181]:
def train_one_epoch(student_model, teacher_model, labeled_loader, unlabeled_loader, optimizer, device, epoch, is_regression):
    """Run one Mean-Teacher epoch and return the summed training loss.

    Labeled and unlabeled batches are consumed in lockstep with ``zip``, so the
    epoch ends as soon as the shorter of the two loaders is exhausted. Per step:

    1. supervised loss on the labeled batch (MSE for regression, cross-entropy
       otherwise);
    2. consistency loss between the student's output on the unlabeled
       ``strong`` view and the teacher's output on the ``weak`` view (MSE on raw
       outputs for regression, MSE on softmax probabilities otherwise);
    3. one optimizer step on ``supervised + config["lambda_u"] * consistency``;
    4. an EMA refresh of the teacher via ``update_ema``.

    Args:
        student_model: Model being optimized.
        teacher_model: EMA copy of the student; only changed through ``update_ema``.
        labeled_loader: Iterable of ``(x, y)`` batches.
        unlabeled_loader: Iterable of ``(x_weak, x_strong)`` batches.
        optimizer: Optimizer over ``student_model``'s parameters.
        device: Device every batch is moved to.
        epoch: Current epoch number (unused here; kept for the caller's signature).
        is_regression: Selects the regression losses; ``y`` then gets a trailing
            dimension via ``unsqueeze(1)`` (assumes a ``(batch, 1)`` regression head).

    Returns:
        float: Sum of the per-batch total losses (also printed).
    """
    from collections.abc import Mapping

    def to_device(batch):
        # Pre-trained text pipelines yield dict-like batches of tensors;
        # every other input type yields a single tensor.
        if isinstance(batch, Mapping):
            return {key: value.to(device) for key, value in batch.items()}
        return batch.to(device)

    student_model.train()
    # The teacher also runs in train mode, so any stochastic layers (e.g. dropout)
    # stay active in its forward passes.
    teacher_model.train()

    lambda_u = config["lambda_u"]
    total_loss = 0.0
    for (x_labeled, y_labeled), (x_unlabeled_weak, x_unlabeled_strong) in zip(labeled_loader, unlabeled_loader):
        x_labeled = to_device(x_labeled)
        x_unlabeled_weak = to_device(x_unlabeled_weak)
        x_unlabeled_strong = to_device(x_unlabeled_strong)

        y_labeled = y_labeled.to(device)
        if is_regression:
            y_labeled = y_labeled.float().unsqueeze(1)

        # Supervised loss
        logits_labeled = student_model(x_labeled)
        if is_regression:
            supervised_loss = F.mse_loss(logits_labeled, y_labeled)
        else:
            supervised_loss = F.cross_entropy(logits_labeled, y_labeled)

        # Unsupervised (consistency) loss; teacher targets carry no gradient.
        with torch.no_grad():
            teacher_out = teacher_model(x_unlabeled_weak)
        student_out = student_model(x_unlabeled_strong)
        if is_regression:
            unsupervised_loss = F.mse_loss(student_out, teacher_out)
        else:
            unsupervised_loss = F.mse_loss(torch.softmax(student_out, dim=1), torch.softmax(teacher_out, dim=1))

        loss = supervised_loss + lambda_u * unsupervised_loss
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The teacher follows the freshly updated student.
        update_ema(student_model, teacher_model)

    print(f"Total Loss: {total_loss:.4f}")
    return total_loss

In [182]:
def evaluate(model, validation_loader, device, is_regression):
    """Evaluate ``model`` on ``validation_loader`` in eval mode and print a summary.

    Args:
        model: Model to evaluate (switched to ``eval()``).
        validation_loader: Iterable of ``(x, y)`` batches; ``x`` is a tensor or a
            dict-like batch of tensors.
        device: Device the batches are moved to.
        is_regression: If true, report MAE and MSE loss; otherwise report
            accuracy (argmax over dim 1) and cross-entropy loss.

    Returns:
        ``(mae, total_loss)`` for regression or ``(accuracy, total_loss)`` for
        classification, where ``total_loss`` is the sum of per-batch mean losses.
    """
    from collections.abc import Mapping

    model.eval()

    all_predictions, all_labels = [], []
    total_loss = 0.0

    with torch.no_grad():
        for x, y in validation_loader:
            # Dict-like batches (pre-trained text) are moved tensor by tensor.
            if isinstance(x, Mapping):
                x = {k: v.to(device) for k, v in x.items()}
            else:
                x = x.to(device)
            y = y.to(device)

            logits = model(x)
            if is_regression:
                # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
                # one to a 0-d tensor, which list.extend cannot iterate.
                predictions = logits.reshape(-1)
                targets = y.float().reshape(-1)
                loss = F.mse_loss(predictions, targets)
            else:
                targets = y
                loss = F.cross_entropy(logits, y.long())
                predictions = torch.argmax(logits, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())
            total_loss += loss.item()

    if is_regression:
        mae = np.mean(np.abs(np.array(all_predictions) - np.array(all_labels)))
        print(f"Validation MAE: {mae:.4f} | Loss: {total_loss:.4f}")
        return mae, total_loss
    else:
        accuracy = np.mean(np.array(all_predictions) == np.array(all_labels))
        print(f"Validation Accuracy: {accuracy:.4f} | Loss: {total_loss:.4f}")
        return accuracy, total_loss

In [183]:
def train_mean_teacher(student_model, labeled_loader, unlabeled_loader, validation_loader, device):
    """Train ``student_model`` with Mean Teacher and checkpoint its best epoch.

    The teacher is a deep copy of the student with gradients disabled; it is only
    changed through the EMA updates performed in ``train_one_epoch``. After each
    of ``config["epochs"]`` epochs the *student* is evaluated, and its
    ``state_dict`` is saved whenever the validation metric improves (lower MAE
    for regression, higher accuracy for classification).

    Args:
        student_model: Model to train (already on ``device``).
        labeled_loader: Labeled ``(x, y)`` batches.
        unlabeled_loader: Unlabeled ``(x_weak, x_strong)`` batches.
        validation_loader: Labeled batches used for model selection.
        device: Device batches are moved to.
    """
    from pathlib import Path

    teacher_model = copy.deepcopy(student_model)
    for param in teacher_model.parameters():
        param.requires_grad = False

    optimizer = optim.Adam(student_model.parameters(), lr=config["learning_rate"])
    # Only a non-categorical tabular target is treated as regression.
    is_regression = config["input_type"] == "tabular" and not config["is_tabular_target_categorical"]

    # Start accuracy at -inf (not 0) so the first epoch is always checkpointed,
    # even when validation accuracy is exactly 0.
    best_val_accuracy, best_mae = float("-inf"), float("inf")
    # Single quotes inside the f-string keep this valid before Python 3.12.
    best_model_path = f"../../models/mean_teacher/best_model_{config['input_type']}_{config['training_session']}.pt"
    # torch.save does not create missing parent directories.
    Path(best_model_path).parent.mkdir(parents=True, exist_ok=True)

    for epoch in range(1, config["epochs"] + 1):
        print(f"--- Start of Epoch {epoch} ---")

        train_one_epoch(
            student_model, teacher_model, labeled_loader, unlabeled_loader, optimizer, device, epoch, is_regression
        )

        if is_regression:
            mae, _ = evaluate(student_model, validation_loader, device, is_regression)
            if mae < best_mae:
                best_mae = mae
                torch.save(student_model.state_dict(), best_model_path)
                print(f"✅ Best model saved to {best_model_path} | MAE: {mae:.4f}")
        else:
            validation_accuracy, _ = evaluate(student_model, validation_loader, device, is_regression)
            if validation_accuracy > best_val_accuracy:
                best_val_accuracy = validation_accuracy
                torch.save(student_model.state_dict(), best_model_path)
                print(f"✅ Best model saved to {best_model_path} | Accuracy: {validation_accuracy:.4f}")

    print("--- End of Training ---")

### Training Main

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so weight initialisation and data shuffling are
# reproducible for a given config["seed"] (the splits below also use it).
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

from token_factory import token_factory
from dataloader_factory import dataloader_factory
from model_factory import model_factory

if config["input_type"] == "image":
    # Load labeled and unlabeled dataset
    labeled_dataset = ImageFolder(root=config["labeled_dataset_path"])
    unlabeled_dataset = ImageFolder(root=config["unlabeled_dataset_path"])

    # Split labeled dataset into train and validation sets (stratified by class)
    indices = list(range(len(labeled_dataset)))
    labels = [sample[1] for sample in labeled_dataset.samples]
    train_indices, validation_indices = train_test_split(
        indices,
        test_size=config["validation_set_percentage"],
        stratify=labels,
        random_state=config["seed"]
    )

    # Obtain the base transform for image inputs
    tokenizer = token_factory(
        "image",
        image_size=config["image_size"]
    )

    train = tokenizer(Subset(labeled_dataset, train_indices))
    validation = tokenizer(Subset(labeled_dataset, validation_indices))
    unlabeled = tokenizer(unlabeled_dataset)

    # Create dataloaders
    labeled_loader, unlabeled_loader, validation_loader = dataloader_factory(
        "image",
        train=train, validation=validation,
        unlabeled=unlabeled, batch_size=config["batch_size"]
    )

    # Create the image model (ResNet or CNN, chosen by model_factory)
    model = model_factory(
        "image",
        num_classes=len(labeled_dataset.classes),
        pretrained=config["pre_trained"]
    ).to(device)

    train_mean_teacher(
        model, labeled_loader, unlabeled_loader, validation_loader, device
    )

elif config["input_type"] == "text":
    # Load labeled and unlabeled dataset
    labeled_dataframe = pd.read_csv(config["labeled_dataset_path"])
    unlabeled_dataframe = pd.read_csv(config["unlabeled_dataset_path"])

    # Split labeled dataset into train and validation sets (stratified by class)
    train_dataframe, validation_dataframe = train_test_split(
        labeled_dataframe,
        test_size=config["validation_set_percentage"],
        stratify=labeled_dataframe[config["text_target_column"]],
        random_state=config["seed"]
    )

    # Obtain the tokenizer for text inputs
    tokenizer = token_factory(
        "text",
        text_column=config["text_column"],
        target_column=config["text_target_column"],
        pretrained=config["pre_trained"],
    )

    # Fit only on training dataframe (if not pre-trained)
    tokenizer.fit(train_dataframe)

    # Transform the train and validation dataframes
    X_train = tokenizer.transform(train_dataframe)
    y_train = tokenizer.transform_target(train_dataframe)

    X_validation = tokenizer.transform(validation_dataframe)
    y_validation = tokenizer.transform_target(validation_dataframe)

    # Unlabeled text will be transformed later in dataloader_factory
    X_unlabeled = unlabeled_dataframe[config["text_column"]].tolist()

    # Create dataloaders
    labeled_loader, unlabeled_loader, validation_loader = dataloader_factory(
        "text",
        X_train=X_train, y_train=y_train,
        X_validation=X_validation, y_validation=y_validation,
        X_unlabeled=X_unlabeled, tokenizer=tokenizer,
        batch_size=config["batch_size"]
    )

    # Create the text model; the input (TF-IDF) dimension is only needed when
    # not using a pre-trained model.
    num_classes = len(np.unique(y_train.numpy()))
    input_dim = X_train.shape[1] if not config["pre_trained"] else None
    model = model_factory(
        "text",
        num_classes=num_classes,
        pretrained=config["pre_trained"],
        tfidf_dim=input_dim
    ).to(device)

    train_mean_teacher(
        model, labeled_loader, unlabeled_loader, validation_loader, device
    )

elif config["input_type"] == "tabular":
    is_regression = not config["is_tabular_target_categorical"]

    # Load labeled and unlabeled dataset
    labeled_dataframe = pd.read_csv(config["labeled_dataset_path"])
    unlabeled_dataframe = pd.read_csv(config["unlabeled_dataset_path"])

    # Split labeled dataset into train and validation sets. Stratification
    # needs discrete classes: a continuous regression target makes
    # train_test_split raise, so only stratify for classification.
    train_dataframe, validation_dataframe = train_test_split(
        labeled_dataframe,
        test_size=config["validation_set_percentage"],
        stratify=None if is_regression else labeled_dataframe[config["tabular_target_column"]],
        random_state=config["seed"]
    )

    # Obtain the tokenizer for tabular inputs
    tokenizer = token_factory(
        "tabular",
        categorical_columns=config["categorical_columns"],
        numeric_columns=config["numeric_columns"],
        target_column=config["tabular_target_column"],
        is_target_categorical=config["is_tabular_target_categorical"]
    )

    # Fit only on training dataframe
    tokenizer.fit(train_dataframe)

    # Transform all dataframes with the fitted tokenizer
    X_train = tokenizer.transform(train_dataframe)
    y_train = tokenizer.transform_target(train_dataframe)

    X_validation = tokenizer.transform(validation_dataframe)
    y_validation = tokenizer.transform_target(validation_dataframe)

    X_unlabeled = tokenizer.transform(unlabeled_dataframe)

    # Convert features to tensors
    X_train = torch.tensor(X_train, dtype=torch.float32)
    X_validation = torch.tensor(X_validation, dtype=torch.float32)
    X_unlabeled = torch.tensor(X_unlabeled, dtype=torch.float32)

    # Convert targets the same way for train and validation: np.asarray accepts
    # both NumPy arrays and pandas Series from transform_target.
    target_dtype = torch.float32 if is_regression else torch.long
    y_train = torch.tensor(np.asarray(y_train), dtype=target_dtype)
    y_validation = torch.tensor(np.asarray(y_validation), dtype=target_dtype)

    # Create dataloaders
    labeled_loader, unlabeled_loader, validation_loader = dataloader_factory(
        "tabular",
        X_train=X_train, y_train=y_train,
        X_validation=X_validation, y_validation=y_validation,
        X_unlabeled=X_unlabeled, batch_size=config["batch_size"]
    )

    # Create MLP model. The input width must match the *transformed* features
    # (encoding categorical columns can change the column count), not the raw CSV.
    input_dim = X_train.shape[1]
    num_classes = labeled_dataframe[config["tabular_target_column"]].nunique()
    model = model_factory(
        "tabular",
        input_dim=input_dim,
        num_classes=num_classes,
        regression=is_regression
    ).to(device)

    train_mean_teacher(
        model, labeled_loader, unlabeled_loader, validation_loader, device
    )

else:
    raise ValueError(f"Unsupported input type: {config['input_type']}")