### Preliminaries

Essential libraries for data handling, model training, and dataset preparation.

In [177]:
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from torchvision.datasets import ImageFolder

Imports the custom factory modules used in the pipeline:
- `token_factory`: Handles preprocessing/tokenization based on input type  
- `dataloader_factory`: Builds dataloaders for labeled, unlabeled, and validation sets  
- `model_factory`: Constructs models based on input type and task

The modules are reloaded with `importlib.reload()` so that edits to their source files take effect without restarting the kernel, which is convenient when iterating in a notebook.

In [None]:
import importlib
import token_factory as tf
import dataloader_factory as dl
import model_factory as md

importlib.reload(tf)
importlib.reload(dl)
importlib.reload(md)
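
A minimal, self-contained sketch of what `importlib.reload()` does: it writes a throwaway module to a temporary directory, imports it, edits the file, and reloads it. The module name `demo_factory` and its `VERSION` attribute are hypothetical stand-ins for the real factory modules.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def demo_reload() -> tuple:
    """Import a module, edit its source on disk, reload it, and return (before, after)."""
    previous_flag = sys.dont_write_bytecode
    # Skip writing .pyc files so the reload is guaranteed to re-read the edited source
    sys.dont_write_bytecode = True
    with tempfile.TemporaryDirectory() as tmp:
        module_path = Path(tmp) / "demo_factory.py"  # hypothetical module
        module_path.write_text("VERSION = 1\n")
        sys.path.insert(0, tmp)
        try:
            demo_factory = importlib.import_module("demo_factory")
            before = demo_factory.VERSION

            # Simulate editing the module while the interpreter keeps running
            module_path.write_text("VERSION = 2\n")
            importlib.invalidate_caches()
            demo_factory = importlib.reload(demo_factory)
            after = demo_factory.VERSION
        finally:
            # Clean up so the demo leaves no trace in the import system
            sys.path.remove(tmp)
            sys.modules.pop("demo_factory", None)
            sys.dont_write_bytecode = previous_flag
    return before, after


print(demo_reload())  # (1, 2)
```

One caveat from the `importlib.reload` documentation: names bound elsewhere with `from module import name` keep pointing at the old objects. In this notebook the `from token_factory import token_factory` style imports sit in the main training cell, which runs after this cell, so they pick up the reloaded definitions as long as the cells are executed in order.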

### Configuration

This dictionary contains all the key settings for the training session. It supports different input types (**images**, **text**, and **tabular data**) and lets you tune the hyperparameters of the Mean Teacher algorithm.

Sections:
- **General**: Includes session ID and random seed for reproducibility.
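- **Mean Teacher Settings**: Controls whether to start from a pre-trained model (`pre_trained`) and the core hyperparameters: learning rate, consistency loss weight (`lambda_u`), EMA decay (`alpha`), and number of epochs.
- **Dataset Paths & Structure**:
  - `labeled_dataset_path` and `unlabeled_dataset_path` point to the labeled and unlabeled data: a folder in `ImageFolder` layout (one subfolder per class) for images, or a CSV file for text and tabular data. There is no separate validation path; the validation set is split off from the labeled data.
  - `input_type`: Specify "image", "text", or "tabular".
  - `validation_set_percentage`: The share of the labeled data held out for validation. It is passed to `train_test_split` as `test_size`, so use a fraction between 0 and 1 (e.g., `0.2` for 20%).
  - `batch_size`: Number of samples per batch for all dataloaders.
- **Input-Specific Options**:
  - For **image data**: `image_size` sets the resolution images are resized to.
  - For **text data**: `text_column` names the column containing the raw text, and `text_target_column` names the column containing its label.
  - For **tabular data**: `categorical_columns` and `numeric_columns` list the input features, and `tabular_target_column` names the target. Set `is_tabular_target_categorical` to `True` for classification or `False` for regression.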
<br><br>
> 💡 Detailed explanations of the configuration variables can be found in README.md.

In [None]:
config = {
    # General
    "training_session": 1,
    "seed": 27,

    # Mean-Teacher Model
    "pre_trained": False,
    "learning_rate": 3e-4,
    "alpha": 0.99,
    "lambda_u": 1.0,
    "epochs": 20,

    # Dataset
    "input_type": "",                        # "image", "text", or "tabular"
    "labeled_dataset_path": "",
    "unlabeled_dataset_path": "",
    "validation_set_percentage": 0,          # fraction of labeled data held out, in (0, 1), e.g. 0.2
    "batch_size": 64,

    # Image input
    "image_size": (0, 0),

    # Text input
    "text_column": "",
    "text_target_column": "",

    # Tabular input
    "categorical_columns": [],
    "numeric_columns": [],
    "tabular_target_column": "",
    "is_tabular_target_categorical": False,  # True = classification, False = regression
}
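
Because most of these fields ship empty, a quick sanity check before training can catch mistakes early. The sketch below is a hypothetical helper (`validate_config` is not part of the original pipeline), and the rules it enforces are assumptions drawn from how the values are used later; for example, `validation_set_percentage` is passed to scikit-learn's `train_test_split` as `test_size`, so a fraction in (0, 1) is expected.

```python
def validate_config(config: dict) -> list:
    """Hypothetical helper: return a list of problems; an empty list means the config looks usable."""
    problems = []

    input_type = config.get("input_type")
    if input_type not in {"image", "text", "tabular"}:
        problems.append(f"input_type must be 'image', 'text', or 'tabular', got {input_type!r}")

    # EMA decay: 0 copies the student every step; values near 1 give a slow, smooth teacher
    if not 0.0 <= config.get("alpha", -1.0) < 1.0:
        problems.append("alpha (EMA decay) should be in [0, 1), typically close to 1")

    if not 0.0 < config.get("validation_set_percentage", 0) < 1.0:
        problems.append("validation_set_percentage should be a fraction in (0, 1), e.g. 0.2")

    for key in ("labeled_dataset_path", "unlabeled_dataset_path"):
        if not config.get(key):
            problems.append(f"{key} is empty")

    # Input-specific fields only matter for the selected input type
    if input_type == "image" and min(config.get("image_size", (0, 0))) <= 0:
        problems.append("image_size must contain positive dimensions")
    if input_type == "text" and not (config.get("text_column") and config.get("text_target_column")):
        problems.append("text_column and text_target_column must be set")
    if input_type == "tabular":
        if not (config.get("categorical_columns") or config.get("numeric_columns")):
            problems.append("at least one categorical or numeric column is required")
        if not config.get("tabular_target_column"):
            problems.append("tabular_target_column must be set")

    return problems
```

For example, calling `validate_config(config)` before the main training cell and stopping if the returned list is non-empty would flag the empty defaults above instead of failing midway through data loading.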

### Training Utilities

***update_ema(...)***

This function performs an **Exponential Moving Average (EMA)** update on the teacher model's weights, which is a core part of the **Mean Teacher** algorithm.

How it works:
- The teacher model is not trained directly.
- Instead, it is updated slowly over time to be a smoothed version of the student model.
- After every optimizer step, each teacher parameter is blended with the corresponding student parameter: `teacher = alpha * teacher + (1 - alpha) * student`.

Parameters:
- `student_model`: The main model being trained with labeled and unlabeled data.
- `teacher_model`: A secondary model that receives EMA-updated weights from the student.
- `alpha`: The smoothing factor, usually close to 1 (e.g., 0.99). It is not passed as an argument; the function reads it from `config["alpha"]`.
<br><br>
> 💡 Because the teacher averages the student's weights over many steps, its predictions are less noisy than the student's, which gives the student more stable consistency targets.

In [180]:
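def update_ema(student_model, teacher_model):
    alpha = config["alpha"]
    # Update the teacher in place without recording the operation in the autograd graph
    with torch.no_grad():
        for student_param, teacher_param in zip(student_model.parameters(), teacher_model.parameters()):
            # teacher = alpha * teacher + (1 - alpha) * student
            teacher_param.mul_(alpha).add_(student_param, alpha=1 - alpha)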
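
To see how slowly the teacher follows the student, here is a scalar illustration of the same update rule in plain Python. It assumes a single "parameter" whose student value stays fixed at 1.0 while the teacher starts at 0.0; under that assumption the teacher after n updates equals `1 - alpha**n`, so with `alpha = 0.99` it has covered only about 63% of the gap after 100 steps. The function names are illustrative only.

```python
def ema_step(teacher: float, student: float, alpha: float) -> float:
    """One EMA update: keep `alpha` of the old teacher value, take (1 - alpha) from the student."""
    return alpha * teacher + (1 - alpha) * student


def run_ema(steps: int, alpha: float, teacher: float = 0.0, student: float = 1.0) -> float:
    """Apply `steps` EMA updates toward a constant student value and return the final teacher value."""
    for _ in range(steps):
        teacher = ema_step(teacher, student, alpha)
    return teacher


for alpha in (0.9, 0.99, 0.999):
    # For this setup the result matches the closed form 1 - alpha**100
    print(alpha, round(run_ema(100, alpha), 3))
```

Note that `update_ema` iterates over `parameters()` only, so buffers such as BatchNorm running statistics are not averaged. In this notebook the teacher's buffers are instead updated by its own forward passes, because `train_one_epoch` keeps the teacher in training mode.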

***train_one_epoch(...)***

This function performs **one training epoch** for both the student and teacher models using the **Mean Teacher** algorithm. It handles both **supervised learning** on labeled data and **consistency-based learning** on unlabeled data.

How it works:

1. **Load Data**:
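   - Iterate over the labeled and unlabeled dataloaders in lockstep with `zip`, taking one batch from each per step. Because `zip` stops at the shorter loader, each epoch processes only as many batches as the smaller of the two loaders provides.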
   - For unlabeled data, two versions are used:
     - `weak`: slightly augmented input for the teacher.
     - `strong`: more heavily augmented input for the student.

2. **Device Handling**:
   - Inputs are moved to the GPU or CPU.
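   - When `input_type` is "text" and `pre_trained` is `True`, inputs are dictionaries of tensors (as produced by a BERT-style tokenizer), so each tensor in the dictionary is moved individually.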

3. **Compute Supervised Loss**:
   - Use labeled data to compute a loss between predictions and ground-truth labels.
   - Chooses between MSE (for regression) and cross-entropy (for classification).

4. **Compute Unsupervised Loss (Consistency Loss)**:
   - The teacher generates predictions (pseudo-labels) from the weakly augmented input.
   - The student is trained to match these pseudo-labels using the strongly augmented version.
   - Uses MSE between softmax probabilities for classification, and MSE between the raw outputs for regression.

5. **Combine Losses**:
   - The total loss is `supervised_loss + lambda_u * unsupervised_loss`.
   - `lambda_u` controls the influence of the unsupervised component.

6. **Backpropagation and Optimization**:
   - Gradients are computed from the total loss and used to update the student model.

7. **Teacher Update via EMA**:
   - The teacher model is updated to slowly follow the student model using Exponential Moving Average.
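
Steps 3 to 5 can be traced on a single toy example in plain Python. This sketch covers the classification path only and uses made-up logits for one labeled sample and one unlabeled sample (weak view for the teacher, strong view for the student). The helper names are hypothetical; in the real code, `F.cross_entropy` and `F.mse_loss` additionally average over the batch.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Negative log-probability of the true class (the per-sample value F.cross_entropy averages)."""
    return -math.log(softmax(logits)[label])


def mse(a, b):
    """Mean squared error between two equal-length vectors (F.mse_loss with 'mean' reduction)."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def mean_teacher_loss(student_labeled_logits, label, teacher_weak_logits, student_strong_logits, lambda_u):
    """Supervised cross-entropy plus lambda_u times the softmax-MSE consistency term."""
    supervised = cross_entropy(student_labeled_logits, label)
    # The teacher's probabilities act as a fixed soft target (no gradient flows through them)
    pseudo_labels = softmax(teacher_weak_logits)
    unsupervised = mse(softmax(student_strong_logits), pseudo_labels)
    return supervised + lambda_u * unsupervised


loss = mean_teacher_loss([2.0, 0.5, -1.0], 0, [1.0, 1.0, 0.0], [0.0, 2.0, 0.0], lambda_u=1.0)
print(f"total loss: {loss:.4f}")
```

When the student's strong-view probabilities match the teacher's weak-view probabilities, the consistency term is zero and only the supervised loss remains; `lambda_u` scales how hard disagreements are penalized.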

Parameters:
- `student_model`: The main model that is actively trained.
- `teacher_model`: A stable copy of the student model, updated with EMA.
- `labeled_loader`: Dataloader for labeled samples.
- `unlabeled_loader`: Dataloader that returns both weak and strong augmented unlabeled inputs.
- `optimizer`: Optimizer for student model.
- `device`: Target device (CPU/GPU).
- `epoch`: Current training epoch (for logging/debugging).
- `is_regression`: Boolean flag to switch loss functions for regression tasks.

<br>

> 💡 This function is designed to support **both regression and classification**, and works with all three input types: images, text, and tabular data.

In [181]:
def train_one_epoch(student_model, teacher_model, labeled_loader, unlabeled_loader, optimizer, device, epoch, is_regression):
    student_model.train()
    teacher_model.train()

    total_loss = 0
    for (x_labeled, y_labeled), (x_unlabeled_weak, x_unlabeled_strong) in zip(labeled_loader, unlabeled_loader):
        if config["input_type"] == "text" and config["pre_trained"]:
            x_labeled = {k: v.to(device) for k, v in x_labeled.items()}
            x_unlabeled_weak = {k: v.to(device) for k, v in x_unlabeled_weak.items()}
            x_unlabeled_strong = {k: v.to(device) for k, v in x_unlabeled_strong.items()}
        else:
            x_labeled = x_labeled.to(device)
            x_unlabeled_weak = x_unlabeled_weak.to(device)
            x_unlabeled_strong = x_unlabeled_strong.to(device)

        # Regression targets need float dtype and shape (batch, 1) to match the model output
        if is_regression:
            y_labeled = y_labeled.float().unsqueeze(1).to(device)
        else:
            y_labeled = y_labeled.to(device)

        # Supervised loss
        logits_labeled = student_model(x_labeled)
        supervised_loss = F.mse_loss(logits_labeled, y_labeled) if is_regression else F.cross_entropy(logits_labeled, y_labeled)

        # Unsupervised loss (consistency)
        if is_regression:
            with torch.no_grad():
                pseudo_labels = teacher_model(x_unlabeled_weak)
            logits_unlabeled_strong = student_model(x_unlabeled_strong)
            unsupervised_loss = F.mse_loss(logits_unlabeled_strong, pseudo_labels)
        else:
            with torch.no_grad():
                logits_ulb_w = teacher_model(x_unlabeled_weak)
                pseudo_labels = torch.softmax(logits_ulb_w, dim=1)
            logits_unlabeled_strong = student_model(x_unlabeled_strong)
            unsupervised_loss = F.mse_loss(torch.softmax(logits_unlabeled_strong, dim=1), pseudo_labels)

        # Total loss
        loss = supervised_loss + config["lambda_u"] * unsupervised_loss
        total_loss += loss.item()

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # EMA update
        update_ema(student_model, teacher_model)

    print(f"Epoch {epoch} | Total Loss: {total_loss:.4f}")
    return total_loss
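
Because the loop pairs batches with `zip`, an epoch ends as soon as the shorter loader is exhausted; when there is far more unlabeled than labeled data, most unlabeled batches go unused in a given epoch. The sketch below shows this with plain lists standing in for dataloaders, plus one common alternative: a hypothetical `paired_batches` helper that restarts the labeled loader so every unlabeled batch is visited. This alternative is an assumption about what you might want, not how the notebook currently behaves.

```python
def paired_batches(labeled, unlabeled):
    """Yield (labeled_batch, unlabeled_batch) pairs, restarting `labeled` whenever it runs out.

    Re-iterating with iter() (rather than itertools.cycle, which caches the first pass)
    means a shuffling DataLoader would produce a fresh order on every restart.
    Assumes `labeled` yields at least one batch.
    """
    labeled_iter = iter(labeled)
    for unlabeled_batch in unlabeled:
        try:
            labeled_batch = next(labeled_iter)
        except StopIteration:
            labeled_iter = iter(labeled)
            labeled_batch = next(labeled_iter)
        yield labeled_batch, unlabeled_batch


labeled_batches = ["L0", "L1"]                      # stands in for a short labeled loader
unlabeled_batches = ["U0", "U1", "U2", "U3", "U4"]  # stands in for a longer unlabeled loader

print(len(list(zip(labeled_batches, unlabeled_batches))))  # 2
print(list(paired_batches(labeled_batches, unlabeled_batches)))
```

With this pairing, the epoch length is set by the unlabeled loader and labeled batches are reused, which is one way to make every epoch see all unlabeled data.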

***evaluate(...)***

This function evaluates the performance of a trained model on a **validation dataset**. It works for both **regression** and **classification** tasks.

How it works:

1. **Switches to Evaluation Mode**:
   - `model.eval()` disables dropout and makes batch normalization use its running statistics, and `torch.no_grad()` turns off gradient tracking to save memory and compute.

2. **Iterates Through Validation Batches**:
   - Inputs are passed through the model.
   - Loss is computed using:
     - **MSE** (Mean Squared Error) for regression
     - **Cross-Entropy** for classification

3. **Generates Predictions**:
   - For regression: uses raw predicted values.
   - For classification: uses the class with the highest predicted score.

4. **Computes Metrics**:
   - For regression: Mean Absolute Error (MAE) and the loss summed over all validation batches
   - For classification: Accuracy and the loss summed over all validation batches

Parameters:
- `model`: The model to evaluate (student or teacher).
- `validation_loader`: Dataloader for the validation dataset.
- `device`: The target device (CPU or GPU).
- `is_regression`: Boolean flag that determines which evaluation metrics and loss function to use.
<br><br>
> 💡 Automatically handles different input types (text, tabular, image), and supports pretrained tokenizers for text inputs.

In [182]:
def evaluate(model, validation_loader, device, is_regression):
    model.eval()

    all_predictions, all_labels = [], []
    total_loss = 0.00

    with torch.no_grad():
        for x, y in validation_loader:
            if config["input_type"] == "text" and config["pre_trained"]:
                x = {k: v.to(device) for k, v in x.items()}
            else:
                x = x.to(device)

            y = y.to(device)

            logits = model(x)
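            if is_regression:
                # reshape(-1) turns (batch, 1) into (batch,); unlike squeeze(), it stays 1-D when the batch has one sample
                predictions = logits.reshape(-1)
                loss = F.mse_loss(predictions, y.float())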
            else:
                loss = F.cross_entropy(logits, y.long())
                predictions = torch.argmax(logits, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
            total_loss += loss.item()

    if is_regression:
        mae = np.mean(np.abs(np.array(all_predictions) - np.array(all_labels)))
        print(f"Validation MAE: {mae:.4f} | Loss: {total_loss:.4f}")
        return mae, total_loss
    else:
        accuracy = np.mean(np.array(all_predictions) == np.array(all_labels))
        print(f"Validation Accuracy: {accuracy:.4f} | Loss: {total_loss:.4f}")
        return accuracy, total_loss
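
Once predictions and labels are collected, the metric step reduces to a few lines. The plain-Python sketch below mirrors what `evaluate` computes (argmax over class scores, then accuracy for classification, or mean absolute error for regression); the function names and numbers are illustrative only.

```python
def argmax(scores):
    """Index of the largest score; ties resolve to the first index."""
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy(predictions, labels):
    """Fraction of positions where the predicted class equals the true class."""
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def mean_absolute_error(predictions, targets):
    """Average absolute difference between predicted and true values."""
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


batch_logits = [[2.0, 0.1, -1.0], [0.2, 0.3, 0.1], [0.0, -0.5, 1.5]]
predicted = [argmax(row) for row in batch_logits]  # [0, 1, 2]
print(accuracy(predicted, [0, 2, 2]))
print(mean_absolute_error([2.5, 0.0, 2.0], [3.0, -0.5, 2.0]))
```

MAE is reported in the target's own units, which makes it easier to interpret than the MSE loss used for training.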

***train_mean_teacher(...)***

This is the **main training function** for the **Mean Teacher** algorithm. It trains a student model using both labeled and unlabeled data, while maintaining a teacher model that is updated via Exponential Moving Average (EMA).

How it works:

1. **Initialize Teacher Model**:
   - A copy of the student model is made using `deepcopy`.
   - The teacher’s weights are frozen (`requires_grad = False`) so it's only updated via EMA.

2. **Set Up Training**:
   - Uses the Adam optimizer with a configurable learning rate.
   - Treats the task as **regression** only when `input_type` is "tabular" and `is_tabular_target_categorical` is `False`; all other cases are treated as **classification**.

3. **Training Loop**:
   For each epoch:
   - Calls `train_one_epoch()` to train the student model and update the teacher.
   - Evaluates the student model on the validation set.

4. **Model Checkpointing**:
   - For **regression**, the best model is saved based on the **lowest MAE (Mean Absolute Error)**.
   - For **classification**, the best model is saved based on the **highest accuracy**.
   - Saves the student's `state_dict` to a path that includes the input type and training session ID, so different sessions do not overwrite each other's checkpoints.

5. **Logging**:
   - Prints epoch-level progress and saves the model only when improvement is detected.

Parameters:
- `student_model`: The model actively trained.
- `labeled_loader`: Dataloader with labeled samples.
- `unlabeled_loader`: Dataloader with weak/strong augmented unlabeled samples.
- `validation_loader`: Dataloader for model evaluation.
- `device`: Target device (CPU or GPU).
<br><br>
> 💡 This function integrates all core parts of the Mean Teacher loop: supervised learning, unsupervised consistency loss, teacher EMA updates, and validation tracking.

In [223]:
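import os

def train_mean_teacher(student_model, labeled_loader, unlabeled_loader, validation_loader, device):
    # The teacher starts as an exact copy of the student and is only ever updated via EMA
    teacher_model = copy.deepcopy(student_model)
    for param in teacher_model.parameters():
        param.requires_grad = False

    optimizer = optim.Adam(student_model.parameters(), lr=config["learning_rate"])
    is_regression = config["input_type"] == "tabular" and not config["is_tabular_target_categorical"]

    # -inf so that the first evaluation always produces a checkpoint, even at 0% accuracy
    best_val_accuracy, best_mae = float("-inf"), float("inf")
    # Single quotes inside the f-string keep it valid on Python versions before 3.12
    best_model_path = f"../../models/mean_teacher/best_model_{config['input_type']}_{config['training_session']}.pt"
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)  # torch.save does not create missing directories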
    for epoch in range(1, config["epochs"] + 1):
        print(f"--- Start of Epoch {epoch} ---")
        
        train_one_epoch(
            student_model, teacher_model, labeled_loader, unlabeled_loader, optimizer, device, epoch, is_regression
        )

        if is_regression:
            mae, _ = evaluate(student_model, validation_loader, device, is_regression)
            if mae < best_mae:
                best_mae = mae
                torch.save(student_model.state_dict(), best_model_path)
                print(f"✅ Best model saved to {best_model_path} | MAE: {mae:.4f}")
        else:
            validation_accuracy, _ = evaluate(student_model, validation_loader, device, is_regression)
            if validation_accuracy > best_val_accuracy:
                best_val_accuracy = validation_accuracy
                torch.save(student_model.state_dict(), best_model_path)
                print(f"✅ Best model saved to {best_model_path} | Accuracy: {validation_accuracy:.4f}")

    print("--- End of Training ---")
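
The checkpointing rule is a "best so far" comparison whose direction depends on the task. The sketch below replays it on a list of per-epoch metrics without any model and returns the epochs at which a checkpoint would be written; `checkpoint_epochs` is a hypothetical helper and the metric values are made up for illustration.

```python
def checkpoint_epochs(metrics, lower_is_better):
    """Return the 1-based epochs at which a new best metric would trigger a save."""
    best = float("inf") if lower_is_better else float("-inf")
    saved = []
    for epoch, value in enumerate(metrics, start=1):
        improved = value < best if lower_is_better else value > best
        if improved:  # strict comparison: a tie with the current best does not overwrite the file
            best = value
            saved.append(epoch)
    return saved


print(checkpoint_epochs([0.50, 0.40, 0.60, 0.60, 0.70], lower_is_better=False))  # accuracy: [1, 3, 5]
print(checkpoint_epochs([1.00, 0.80, 0.90, 0.50], lower_is_better=True))         # MAE: [1, 2, 4]
```

Since the file at `best_model_path` is overwritten on each improvement, only the weights from the last epoch in this list survive at the end of training.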

### Main Training Loop

This is the **main execution block** that sets up the correct data pipeline and model based on the `input_type` defined in the configuration. It supports three types of input data:

- `"image"` – for document or visual classification tasks  
- `"text"` – for NLP tasks like intent classification or document tagging  
- `"tabular"` – for structured data like customer profiles or financial metrics

General workflow:

1. **Detect Device**  
   Automatically selects GPU if available, otherwise falls back to CPU.

2. **Load Input Data**  
   Reads from CSV files or image folders depending on input type.

3. **Train-Validation Split**  
   Splits the labeled data into training and validation sets, holding out the fraction given by `config["validation_set_percentage"]`. The split is stratified by label, so both sets keep the same class proportions.

4. **Tokenizer Factory**  
   Uses a unified `token_factory()` to preprocess:
   - Images (resize, normalize)
   - Text (tokenize using TF-IDF or a pre-trained model like BERT)
   - Tabular (handle categorical/numerical features, scaling)

5. **Dataloader Creation**  
   Uses `dataloader_factory()` to return:
   - `labeled_loader`: for supervised learning
   - `unlabeled_loader`: for consistency learning
   - `validation_loader`: for performance tracking

6. **Model Factory**  
   Automatically builds the right model:
   - CNN or ResNet for images
   - MLP or BERT for text
   - MLP for tabular data

7. **Run Mean Teacher Training**  
   Starts the full semi-supervised training cycle using the `train_mean_teacher()` function.
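
Step 3 relies on scikit-learn's `train_test_split(..., stratify=...)`, which keeps each class's share roughly the same in both partitions. The plain-Python sketch below illustrates the idea with a hypothetical `stratified_split` helper; it is not scikit-learn's algorithm, which allocates per-class counts differently. It also hints at why stratification needs discrete labels: a class with a single sample cannot appear on both sides, and a continuous regression target produces many such classes, which scikit-learn rejects with a `ValueError`.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_fraction, seed=0):
    """Split sample indices so each class contributes about `test_fraction` of its samples to validation."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)  # fixed seed for a reproducible split, like random_state
    train, validation = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_validation = round(len(indices) * test_fraction)
        validation.extend(indices[:n_validation])
        train.extend(indices[n_validation:])
    return sorted(train), sorted(validation)


labels = [0] * 10 + [1] * 10
train_idx, val_idx = stratified_split(labels, test_fraction=0.2, seed=27)
print(len(train_idx), len(val_idx))         # 16 4
print(sorted(labels[i] for i in val_idx))  # [0, 0, 1, 1]
```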


In [224]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from token_factory import token_factory
from dataloader_factory import dataloader_factory
from model_factory import model_factory

if config["input_type"] == "image":
    # Load labeled and unlabeled dataset 
    labeled_dataset = ImageFolder(root=config["labeled_dataset_path"])
    unlabeled_dataset = ImageFolder(root=config["unlabeled_dataset_path"])

    # Split labeled dataset into train and validation sets
    indices = list(range(len(labeled_dataset)))
    labels = [sample[1] for sample in labeled_dataset.samples]
    train_indices, validation_indices = train_test_split(
        indices,
        test_size=config["validation_set_percentage"],
        stratify=labels,
        random_state=config["seed"]
    )

    # Obtain the base transform for image inputs
    tokenizer = token_factory(
        "image", 
        image_size=config["image_size"]
    )

    train = tokenizer(Subset(labeled_dataset, train_indices))
    validation = tokenizer(Subset(labeled_dataset, validation_indices))
    unlabeled = tokenizer(unlabeled_dataset)

    # Create dataloaders
    labeled_loader, unlabeled_loader, validation_loader = dataloader_factory(
        "image",
        train=train, validation=validation, 
        unlabeled=unlabeled, batch_size=config["batch_size"]
    )

    # Create ResNet or CNN model
    model = model_factory(
        "image", 
        num_classes=len(labeled_dataset.classes),
        pretrained=config["pre_trained"]
    ).to(device)

    train_mean_teacher(
        model, labeled_loader, unlabeled_loader, validation_loader, device
    )

elif config["input_type"] == "text":
    # Load labeled and unlabeled dataset
    labeled_dataframe = pd.read_csv(config["labeled_dataset_path"])
    unlabeled_dataframe = pd.read_csv(config["unlabeled_dataset_path"])

    # Split labeled dataset into train and validation sets
    train_dataframe, validation_dataframe = train_test_split(
        labeled_dataframe,
        test_size=config["validation_set_percentage"],
        stratify=labeled_dataframe[config["text_target_column"]],
        random_state=config["seed"]
    )

    # Obtain the tokenizer for text inputs
    tokenizer = token_factory(
        "text",
        text_column=config["text_column"],
        target_column=config["text_target_column"],
        pretrained=config["pre_trained"],
    )

    # Fit only on training dataframe (if not pre-trained)
    tokenizer.fit(train_dataframe)  

    # Transform remaining dataframes
    X_train = tokenizer.transform(train_dataframe)
    y_train = tokenizer.transform_target(train_dataframe)

    X_validation = tokenizer.transform(validation_dataframe)
    y_validation = tokenizer.transform_target(validation_dataframe)

    # Unlabeled text will be transformed later in dataloader_factory
    X_unlabeled = unlabeled_dataframe[config["text_column"]].tolist()

    # Create dataloaders
    labeled_loader, unlabeled_loader, validation_loader = dataloader_factory(
        "text",
        X_train=X_train, y_train=y_train,
        X_validation=X_validation, y_validation=y_validation,
        X_unlabeled=X_unlabeled, tokenizer=tokenizer, 
        batch_size=config["batch_size"]
    )

    # Create a BERT model (pre_trained=True) or an MLP on TF-IDF features
    num_classes = len(np.unique(y_train.numpy())) 
    input_dim = X_train.shape[1] if not config["pre_trained"] else None 
    model = model_factory(
        "text",
        num_classes=num_classes,
        pretrained=config["pre_trained"],
        tfidf_dim=input_dim
    ).to(device)

    train_mean_teacher(
        model, labeled_loader, unlabeled_loader, validation_loader, device
    )

elif config["input_type"] == "tabular":
    is_regression = not config["is_tabular_target_categorical"]

    # Load labeled and unlabeled dataset
    labeled_dataframe = pd.read_csv(config["labeled_dataset_path"])
    unlabeled_dataframe = pd.read_csv(config["unlabeled_dataset_path"])

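    # Split labeled dataset into train and validation sets.
    # Stratify only for classification: a continuous regression target typically has classes
    # with a single member, which makes train_test_split raise a ValueError.
    train_dataframe, validation_dataframe = train_test_split(
        labeled_dataframe,
        test_size=config["validation_set_percentage"],
        stratify=None if is_regression else labeled_dataframe[config["tabular_target_column"]],
        random_state=config["seed"]
    )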

    # Obtain the tokenizer for tabular inputs
    tokenizer = token_factory(
        "tabular", 
        categorical_columns=config["categorical_columns"],
        numeric_columns=config["numeric_columns"],
        target_column=config["tabular_target_column"],
        is_target_categorical=config["is_tabular_target_categorical"]
    )

    # Fit only on training dataframe
    tokenizer.fit(train_dataframe)

    # Transform remaining dataframes
    X_train = tokenizer.transform(train_dataframe)
    y_train = tokenizer.transform_target(train_dataframe)

    X_validation = tokenizer.transform(validation_dataframe)
    y_validation = tokenizer.transform_target(validation_dataframe)

    X_unlabeled = tokenizer.transform(unlabeled_dataframe)
    
    # Convert to tensors
    X_train = torch.tensor(X_train, dtype=torch.float32)
    X_validation = torch.tensor(X_validation, dtype=torch.float32)

    # Regression targets are floats; classification targets are integer class indices.
    # np.asarray accepts both NumPy arrays and pandas Series.
    target_dtype = torch.float32 if is_regression else torch.long
    y_train = torch.tensor(np.asarray(y_train), dtype=target_dtype)
    y_validation = torch.tensor(np.asarray(y_validation), dtype=target_dtype)

    X_unlabeled = torch.tensor(X_unlabeled, dtype=torch.float32)

    # Create dataloaders
    labeled_loader, unlabeled_loader, validation_loader = dataloader_factory(
        "tabular", 
        X_train=X_train, y_train=y_train, 
        X_validation=X_validation, y_validation=y_validation, 
        X_unlabeled=X_unlabeled, batch_size=config["batch_size"]
    )
    
    # Create MLP model
    # Use the number of features after preprocessing (e.g., encoding), not the raw column count
    input_dim = X_train.shape[1]
    num_classes = labeled_dataframe[config["tabular_target_column"]].nunique()
    model = model_factory(
        "tabular",
        input_dim=input_dim,
        num_classes=num_classes,
        regression=is_regression
    ).to(device)

    train_mean_teacher(
        model, labeled_loader, unlabeled_loader, validation_loader, device
    )

else:
    raise ValueError(f"Unsupported input type: {config['input_type']}")