# Import Packages

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW, Adam
from base import get_basic_model, get_model_passt
from models.mobileViTCA import make_model
from dataset import train_set, test_set, validation_set
from vanilla_kd import VanillaKD

# Loading Models

## Teacher Model

In [None]:
# Teacher: the wrapper returned by `get_basic_model`, with its `.net` replaced
# by a PaSST model (arch "passt_s_swa_p16_128_ap476") that has a 4-class head.
teacher = get_basic_model(mode="logits")
teacher.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=4)

# Initialise from the ESC-50 fine-tuned PaSST checkpoint (fold 2).
# ESC-50 has 50 classes while this model has n_classes=4, so the checkpoint's
# classifier-head tensors cannot be copied in: a strict `load_state_dict`
# raises on the size mismatch. Copy only tensors whose name AND shape match the
# target model, and report everything that was skipped.
# NOTE(review): other tensors (e.g. positional embeddings, if the input length
# differs from the checkpoint's) may also be skipped -- check the printed list.
ESC50_CKPT_URL = ('https://github.com/kkoutini/PaSST/releases/download/v.0.0.6/'
                  'esc50-passt-s-n-f128-p16-s10-fold2-acc.977.pt')
# map_location="cpu" lets the checkpoint load on machines without a GPU.
state_dict = torch.hub.load_state_dict_from_url(url=ESC50_CKPT_URL, map_location="cpu")

model_state = teacher.net.state_dict()
compatible = {
    key: tensor
    for key, tensor in state_dict.items()
    if key in model_state and tensor.shape == model_state[key].shape
}
if not compatible:
    # Nothing matched (e.g. keys carry a prefix): fail instead of silently
    # keeping the un-initialised weights.
    raise RuntimeError("No tensors in the ESC-50 checkpoint match teacher.net; "
                       "check the checkpoint's key names.")
skipped = sorted(set(state_dict) - set(compatible))
teacher.net.load_state_dict(compatible, strict=False)
print(f"Loaded {len(compatible)}/{len(state_dict)} checkpoint tensors; skipped: {skipped}")

## Student Model

In [None]:
# Student: the "xxs" network from models.mobileViTCA.make_model, sized for the
# same 4 output classes as the teacher head.
student = make_model(
    net='xxs',
    num_classes=4,
    patch_size=(4, 4),
    patch_stride=(3, 3),
)

## Model Hyperparameters

In [None]:
# Learning rate shared by both optimizers.
Lr = 1e-3
# One AdamW optimizer per network (teacher first, then student), same lr for both.
teacher_optim, student_optim = [AdamW(model.parameters(), lr=Lr) for model in (teacher, student)]

# Dataset

In [None]:
def _read_split_csv(stem):
    """Read ``<stem>.cvs`` into a DataFrame, falling back to ``<stem>.csv``.

    The split files are referenced with a ".cvs" extension, which looks like a
    typo for ".csv". The original name is tried first so existing setups keep
    working; the ".csv" spelling is only used when the first read fails.

    Raises:
        FileNotFoundError: if neither file exists (raised while handling the
            first failure, so both missing names appear in the traceback).
    """
    try:
        return pd.read_csv(f'{stem}.cvs')
    except FileNotFoundError:
        return pd.read_csv(f'{stem}.csv')


train_df = _read_split_csv('train')
val_df = _read_split_csv('Val')
test_df = _read_split_csv('test')

In [None]:
# Wrap each split's DataFrame in a dataset. All three splits are built with
# `train_set`, distinguished only by the split-name string.
# NOTE(review): `validation_set` and `test_set` are imported but unused here --
# confirm `train_set` is the intended class for the 'Val' and 'Test' splits
# (e.g. that it applies no training-only processing to them).
train_ds, val_ds, test_ds = (
    train_set(frame, split_name)
    for frame, split_name in ((train_df, 'Train'), (val_df, 'Val'), (test_df, 'Test'))
)

In [None]:
# Batch the three splits. Only the training loader shuffles. Evaluation gains
# nothing from random order, and a fixed order makes val/test results
# reproducible and lets predictions be matched back to dataset indices.
BATCH_SIZE = 32
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Knowledge Distillation

In [None]:
num_epoch = 100

# Distillation instance.
# The fourth positional argument is the loader the distiller evaluates on while
# training (it is `val_loader` in KD_Lib's VanillaKD -- confirm against the
# local vanilla_kd module). Pass the validation split so the test split stays
# held out for the final evaluation only.
distiller = VanillaKD(teacher, student, train_loader, val_loader, teacher_optim, student_optim)

# Train the teacher first, then distil it into the student.
distiller.train_teacher(epochs=num_epoch, plot_losses=True, save_model=True)
distiller.train_student(epochs=num_epoch, plot_losses=True, save_model=True)