# **ResNet19 Teacher Model - TSER for KD**

# **Setup**

---
**Install Libraries**

In [None]:
!pip install snntorch dagshub mlflow pynvml --quiet

---
**GitHub Code**

In [None]:
from google.colab import userdata
import os

# Sets environ variables for GitHub
os.environ['GITHUB_TOKEN'] = userdata.get('GITHUB_TOKEN')
os.environ['USER'] = userdata.get('USER')

# Clones the repo and changes dir
!git clone https://${GITHUB_TOKEN}@github.com/${USER}/tser-kd.git
%cd tser-kd/

---
**Set Seed for Experiment**

In [None]:
from tser_kd.utils import setup_seed

# Fixed seed for the experiment
SEED = 42
setup_seed(SEED)

---
**Device Selection**

In [4]:
import torch

# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

---
**MLFlow Setup**

In [None]:
import mlflow
from mlflow import MlflowClient
import dagshub

# Export the MLflow tracking credentials taken from the Colab secrets
os.environ.update({
    'MLFLOW_TRACKING_USERNAME': userdata.get('USER'),
    'MLFLOW_TRACKING_PASSWORD': userdata.get('MLFLOW_TRACKING_PASSWORD'),
})

# Initialise DagsHub for the repository with the MLflow integration enabled
dagshub.init(repo_owner='matteogianferrari', repo_name='tser-kd', mlflow=True)

# Point MLflow at the tracking URI of the DagsHub repository
TRACKING_URI = "https://dagshub.com/matteogianferrari/tser-kd.mlflow"
mlflow.set_tracking_uri(TRACKING_URI)

# Name of the MLflow experiment; it is passed to mlflow.set_experiment in the training cell
experiment_name = "TSER-KD Teacher"

# **Hyperparameters**

In [12]:
# Hyperparameter dictionary: read by the data, optimizer, scheduler and training
# cells, and logged as the run parameters to MLflow
h_dict = {
    # Training
    "MAX_EPOCHS": 256,
    "BATCH_SIZE": 2048,
    # Learning rate and its scheduler
    "LR_SCHEDULER": "ReduceLROnPlateau",
    "BASE_LR": 1e-3,
    "LR_PATIENCE": 6,
    "LR_FACTOR": 0.666,
    # Early stopping
    "ES_PATIENCE": 15,
    "ES_DELTA": 5e-4,
    # Optimizer
    "OPTIMIZER": "AdamW",
    "WEIGHT_DECAY": 5e-4,
    # GPU
    "HARDWARE": "L4",
    # Dataset augmentation switches
    "AUTO_AUG": True,
    "CUTOUT": True,
}

# **CIFAR10 Dataset**

---
**Data Loaders Creation**

In [13]:
from tser_kd.dataset import load_cifar10_data
from torch.utils.data import DataLoader


# Load the CIFAR-10 datasets, forwarding the augmentation switches from the hyperparameters
train_dataset, val_dataset, num_classes = load_cifar10_data(
    auto_aug=h_dict['AUTO_AUG'],
    cutout=h_dict['CUTOUT'],
)

# Settings shared by both loaders; only the training loader shuffles its samples
loader_kwargs = dict(batch_size=h_dict['BATCH_SIZE'], num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# **Teacher ResNet-19**


---
**ResNet-19**

In [14]:
from tser_kd.model.teacher import make_teacher_model


# ANN teacher: ResNet-19 built for 3-channel inputs, with the class count
# reported by the dataset loader and the selected device passed through
t_model = make_teacher_model(
    arch='resnet-19',
    in_channels=3,
    num_classes=num_classes,
    device=device,
)

# **Training**

---
**Objects Creation**

In [17]:
import torch.optim as optim
import torch.nn as nn
from tser_kd.training import EarlyStopping


# Optimizer: chosen by name from the hyperparameter dictionary
optimizer_name = h_dict["OPTIMIZER"]
if optimizer_name == 'AdamW':
    optimizer = optim.AdamW(t_model.parameters(), lr=h_dict['BASE_LR'], weight_decay=h_dict['WEIGHT_DECAY'])
elif optimizer_name == 'Adam':
    optimizer = optim.Adam(t_model.parameters(), lr=h_dict['BASE_LR'], weight_decay=h_dict['WEIGHT_DECAY'])
elif optimizer_name == 'SGD':
    # This branch also reads h_dict["MOMENTUM"]; that key has to be added to the
    # hyperparameter dictionary when SGD is selected
    optimizer = optim.SGD(t_model.parameters(), lr=h_dict['BASE_LR'], momentum=h_dict["MOMENTUM"], weight_decay=h_dict['WEIGHT_DECAY'])
else:
    # Fail loudly instead of leaving `optimizer` undefined, or silently reusing
    # a stale one left in the kernel by a previous run of this cell
    raise ValueError(
        f"Unsupported OPTIMIZER {optimizer_name!r}; expected 'AdamW', 'Adam' or 'SGD'"
    )

# LR scheduler: chosen by name from the hyperparameter dictionary
scheduler_name = h_dict["LR_SCHEDULER"]
if scheduler_name == 'ReduceLROnPlateau':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=h_dict["LR_PATIENCE"], factor=h_dict["LR_FACTOR"])
elif scheduler_name == 'CosineAnnealingLR':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=h_dict["MAX_EPOCHS"])
else:
    # Same reasoning as for the optimizer: never continue with a missing or stale scheduler
    raise ValueError(
        f"Unsupported LR_SCHEDULER {scheduler_name!r}; expected 'ReduceLROnPlateau' or 'CosineAnnealingLR'"
    )

# Loss
criterion = nn.CrossEntropyLoss()

# Early stopping; the callback is constructed with the checkpoint path "best_ckpt.pth"
es_callback = EarlyStopping(patience=h_dict["ES_PATIENCE"], delta=h_dict["ES_DELTA"], path="best_ckpt.pth")

# Gradient scaler
# NOTE(review): the scaler device is hard-coded to 'cuda' although `device` may
# fall back to 'cpu' -- confirm that run_train copes with a CPU-only session
scaler = torch.amp.GradScaler(device='cuda')

---
**Training Loop**

In [None]:
import pynvml
from tser_kd.training import run_train
from tser_kd.eval import run_eval


# Sets the MLFlow experiment
mlflow.set_experiment(experiment_name)

epoch_i = 0
# LR in effect for the upcoming epoch (reported alongside that epoch's metrics)
curr_lr = optimizer.param_groups[0]["lr"]

# Train the model and log with MLFlow
with mlflow.start_run(log_system_metrics=True):
    # Log hyperparameters first, so the run keeps them even when training is
    # interrupted or crashes before reaching the end of the loop
    mlflow.log_params(h_dict)

    for epoch_i in range(h_dict["MAX_EPOCHS"]):
        train_loss, train_acc, epoch_time, train_batch_time = run_train(
            epoch_i, train_loader, t_model, criterion, optimizer, device, scaler
        )

        val_loss, val_acc1, val_acc5, val_batch_time = run_eval(val_loader, t_model, criterion, device)

        # Logging
        print(
            f"Time: {epoch_time:.1f}s | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_loss:.4f} | Val Acc1: {val_acc1:.2f}% | Val Acc5: {val_acc5:.2f}% | LR: {curr_lr:.6f}"
        )

        mlflow.log_metrics({
            "learning_rate": curr_lr, "train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss,
            "val_acc1": val_acc1, "val_acc5": val_acc5, "epoch_time": epoch_time,
            "train_batch_time": train_batch_time, "val_batch_time": val_batch_time
        }, step=epoch_i)

        # Updates the LR (ReduceLROnPlateau is stepped with the monitored validation loss)
        if h_dict["LR_SCHEDULER"] == 'ReduceLROnPlateau':
            scheduler.step(val_loss)
        else:
            scheduler.step()

        # Refresh the LR that will be used, and reported, for the next epoch
        curr_lr = optimizer.param_groups[0]["lr"]

        # ES check: stop training as soon as the callback returns a truthy value
        if es_callback(val_loss, epoch_i, t_model):
            break

    # Log test performance
    # The checkpoint is read from the same path given to the EarlyStopping callback;
    # map_location places the loaded tensors on the device selected for this session
    t_model.load_state_dict(torch.load("best_ckpt.pth", map_location=device))
    # NOTE: these "test_*" metrics are computed on val_loader
    test_loss, test_acc1, test_acc5, _ = run_eval(val_loader, t_model, criterion, device)
    mlflow.log_metrics({"test_loss": test_loss, "test_acc1": test_acc1, "test_acc5": test_acc5})