# **Temporal Separation with Entropy Regularization for Knowledge Distillation**

# **Setup**

---
**Install Libraries**

In [1]:
!pip install snntorch dagshub mlflow pynvml --quiet

---
**GitHub Code**

In [None]:
from google.colab import userdata
import os

# Sets environ variables for GitHub
os.environ['GITHUB_TOKEN'] = userdata.get('GITHUB_TOKEN')
os.environ['USER'] = userdata.get('USER')

# Clones the repo and changes dir
!git clone https://${GITHUB_TOKEN}@github.com/${USER}/tser-kd.git
%cd tser-kd/

---
**Set Seed for Experiment**

In [None]:
from tser_kd.utils import setup_seed

# Seed value handed to the project's setup_seed helper for this experiment
SEED = 42
setup_seed(SEED)

---
**Device Selection**

In [4]:
import torch

# Pick the compute device: the GPU when CUDA is available, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

---
**MLFlow Setup**

In [None]:
import mlflow
from mlflow import MlflowClient
import dagshub

# Export the MLFlow credentials read from the Colab secrets
for env_key, secret_key in (
    ('MLFLOW_TRACKING_USERNAME', 'USER'),
    ('MLFLOW_TRACKING_PASSWORD', 'MLFLOW_TRACKING_PASSWORD'),
):
    os.environ[env_key] = userdata.get(secret_key)

# Initialise DagsHub for the repository, with its MLFlow integration enabled
dagshub.init(
    repo_owner='matteogianferrari',
    repo_name='tser-kd',
    mlflow=True,
)

# Point MLFlow at the DagsHub tracking server
TRACKING_URI = "https://dagshub.com/matteogianferrari/tser-kd.mlflow"
mlflow.set_tracking_uri(TRACKING_URI)

# Name of the MLFlow experiment the runs are stored under
experiment_name = "TSER-KD"

# **Hyperparameters**

In [6]:
# Hyperparameter dictionary for the experiment, grouped by the component each
# entry configures. The whole dictionary is logged to MLFlow as run parameters.
h_dict = {
    # Loss
    "TAU": 5.0,
    "ALPHA": 0.1,
    "GAMMA": 1e-3,
    # Leaky Neuron
    "BETA": 0.5,
    "V_th": 1.0,
    # Training
    "MAX_EPOCHS": 300,
    "BATCH_SIZE": 256,
    # LR
    "LR_SCHEDULER": "CosineAnnealingLR",
    "BASE_LR": 0.0005,
    # Early Stopping
    "ES_PATIENCE": 50,
    "ES_DELTA": 5e-4,
    # Optimizer
    "OPTIMIZER": "SGD",
    "MOMENTUM": 0.9,
    # GPU
    "HARDWARE": "A100",
    # Encoder
    "ENCODER": "Static",
    "T": 2,
    # Dataset
    "AUTO_AUG": True,
    "CUTOUT": True,
}

# **CIFAR10 Dataset**

---
**Data Loaders Creation**

In [7]:
from tser_kd.dataset import load_cifar10_data
from torch.utils.data import DataLoader


# Load the CIFAR10 datasets with the augmentation flags taken from the hyperparameters
train_dataset, val_dataset, num_classes = load_cifar10_data(
    auto_aug=h_dict['AUTO_AUG'],
    cutout=h_dict['CUTOUT'],
)

# Settings shared by the train and validation DataLoaders
loader_kwargs = dict(batch_size=h_dict['BATCH_SIZE'], num_workers=2)

# Only the training loader shuffles its samples
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# **Teacher and Student Models**

---
**Models Creation**

In [8]:
from tser_kd.model.teacher import make_teacher_model
from tser_kd.model.student import make_student_model


# ANN (teacher)
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto the selected device, so the load
# also works on a runtime without the GPU the checkpoint was saved from.
t_state_dict = torch.load("data/teacher_models/resnet34_9714.pth", map_location=device)
t_model = make_teacher_model(
    arch='resnet-34',
    in_channels=3,
    num_classes=num_classes,
    device=device,
    state_dict=t_state_dict
)

# SNN (student)
# NOTE(review): this checkpoint is loaded but not used, because state_dict=None is
# passed below. Confirm whether the student should start from it or the load can go.
s_state_dict = torch.load("data/student_models/sresnet19_9225.pth", map_location=device)
s_model = make_student_model(
    in_channels=3,
    num_classes=num_classes,
    beta=h_dict['BETA'],
    device=device,
    state_dict=None
)

# **Training**

In [10]:
import torch.optim as optim
import torch.nn as nn
from tser_kd.training import EarlyStopping
from tser_kd.dataset import RateEncoder, StaticEncoder
from tser_kd.model import TSERKDLoss
from tser_kd.training.lr_scheduler import WarmupCosineLR


# Optimizer, selected by name from the hyperparameters.
# The AdamW and Adam branches also read h_dict['WEIGHT_DECAY'].
optimizer_name = h_dict["OPTIMIZER"]
if optimizer_name == 'AdamW':
    optimizer = optim.AdamW(s_model.parameters(), lr=h_dict['BASE_LR'], weight_decay=h_dict['WEIGHT_DECAY'])
elif optimizer_name == 'Adam':
    optimizer = optim.Adam(s_model.parameters(), lr=h_dict['BASE_LR'], weight_decay=h_dict['WEIGHT_DECAY'])
elif optimizer_name == 'SGD':
    optimizer = optim.SGD(s_model.parameters(), lr=h_dict['BASE_LR'], momentum=h_dict["MOMENTUM"])
else:
    # Fail fast: without this branch an unknown name would leave `optimizer`
    # undefined, or silently reuse a stale one left in the kernel by a previous run
    raise ValueError(f"Unsupported OPTIMIZER: {optimizer_name!r}")

# LR scheduler, selected by name from the hyperparameters.
# Each branch reads its own extra keys from h_dict (e.g. LR_PATIENCE, LR_FACTOR,
# LR_STEP, WARMUP_EPOCHS, MAX_LR), which must be present when that branch is selected.
scheduler_name = h_dict["LR_SCHEDULER"]
if scheduler_name == 'ReduceLROnPlateau':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=h_dict["LR_PATIENCE"], factor=h_dict["LR_FACTOR"])
elif scheduler_name == 'CosineAnnealingLR':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=h_dict["MAX_EPOCHS"])
elif scheduler_name == 'StepLR':
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=h_dict["LR_STEP"], gamma=h_dict["LR_FACTOR"])
elif scheduler_name == 'WarmupCosineLR':
    scheduler = WarmupCosineLR(
        optimizer=optimizer,
        warmup_epochs=h_dict["WARMUP_EPOCHS"],
        total_epochs=h_dict["MAX_EPOCHS"],
        base_lr=h_dict["BASE_LR"],
        max_lr=h_dict["MAX_LR"]
    )
else:
    # Fail fast: without this branch an unknown name would leave `scheduler`
    # undefined, or silently reuse a stale one left in the kernel by a previous run
    raise ValueError(f"Unsupported LR_SCHEDULER: {scheduler_name!r}")

# Training criterion: the TSER-KD loss configured from the hyperparameters
train_criterion = TSERKDLoss(
    alpha=h_dict["ALPHA"],
    gamma=h_dict["GAMMA"],
    tau=h_dict["TAU"],
)

# Evaluation criterion: plain cross-entropy
eval_criterion = nn.CrossEntropyLoss()

# Early-stopping callback, configured with the checkpoint path "best_ckpt.pth"
es_callback = EarlyStopping(
    patience=h_dict["ES_PATIENCE"],
    delta=h_dict["ES_DELTA"],
    path="best_ckpt.pth",
)

# Gradient scaler for CUDA automatic mixed precision
scaler = torch.amp.GradScaler(device='cuda')

# Input encoder, selected by name from the hyperparameters; both use T time steps.
# The "Rate" branch also reads h_dict["GAIN"].
encoder_name = h_dict["ENCODER"]
if encoder_name == "Rate":
    encoder = RateEncoder(num_steps=h_dict["T"], gain=h_dict["GAIN"])
elif encoder_name == "Static":
    encoder = StaticEncoder(num_steps=h_dict["T"])
else:
    # Fail fast: without this branch an unknown name would leave `encoder`
    # undefined, or silently reuse a stale one left in the kernel by a previous run
    raise ValueError(f"Unsupported ENCODER: {encoder_name!r}")

In [None]:
import pynvml
from tser_kd.eval import run_eval
from tser_kd.training import run_kd_train


# Sets the MLFlow experiment
mlflow.set_experiment(experiment_name)

epoch_i = 0
START_EPOCH = 0
# LR in effect for the first epoch; refreshed after every scheduler step
curr_lr = optimizer.param_groups[0]["lr"]

# Train the model and log with MLFlow
with mlflow.start_run(run_id=None, log_system_metrics=True):
    # Log hyperparameters up front, so the run is still described when training
    # crashes or is interrupted before the loop finishes
    mlflow.log_params(h_dict)

    for epoch_i in range(h_dict["MAX_EPOCHS"]):
        # One epoch of knowledge-distillation training of the student against the teacher
        train_total_loss, train_ce_loss, train_kl_loss, train_e_reg, train_acc, epoch_time, train_batch_time = run_kd_train(
            epoch_i, train_loader, s_model, t_model, train_criterion, optimizer, device, scaler, encoder
        )

        # Evaluation of the student on the validation loader
        val_loss, val_acc1, val_acc5, val_batch_time = run_eval(
            val_loader, s_model, eval_criterion, device, encoder
        )

        # Logging
        print(
            f"Time: {epoch_time:.1f}s | Train Total Loss: {train_total_loss:.4f} | Train Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_loss:.4f} | Val Acc1: {val_acc1:.2f}% | Val Acc5: {val_acc5:.2f}% | LR: {curr_lr:.6f}"
        )

        mlflow.log_metrics({
            "learning_rate": curr_lr, "train_tserkd_loss": train_total_loss, "train_tsce_loss": train_ce_loss, "train_tskl_loss": train_kl_loss,
            "train_e_reg": train_e_reg, "train_acc": train_acc, "val_ce_loss": val_loss, "val_acc1": val_acc1, "val_acc5": val_acc5, "epoch_time": epoch_time,
            "train_batch_time": train_batch_time, "val_batch_time": val_batch_time
        }, step=epoch_i)

        # Updates the LR (ReduceLROnPlateau needs the monitored metric)
        if h_dict["LR_SCHEDULER"] == 'ReduceLROnPlateau':
            scheduler.step(val_loss)
        else:
            scheduler.step()

        # LR that will be used, and reported, for the next epoch
        curr_lr = optimizer.param_groups[0]["lr"]

        # ES check: stop training when the callback signals it
        if es_callback(val_loss, epoch_i, s_model):
            break

    # Log test performance
    # Reloads "best_ckpt.pth" (presumably written by the early-stopping callback -
    # confirm); map_location keeps the reload working on a CPU-only runtime.
    # NOTE(review): the "test" metrics are computed on val_loader, the same data
    # used for validation during training.
    s_model.load_state_dict(torch.load("best_ckpt.pth", map_location=device))
    test_ce_loss, test_acc1, test_acc5, _ = run_eval(val_loader, s_model, eval_criterion, device, encoder)
    mlflow.log_metrics({"test_ce_loss": test_ce_loss, "test_acc1": test_acc1, "test_acc5": test_acc5})