In [1]:
import torch
import numpy as np
import wandb

from modeling import ASTPretrained
from modeling.learner import KDLearner, Learner
from modeling.models import StudentAST, interpolate_params
from modeling.utils import parse_config
from modeling.dataset import get_loader

In [2]:
# Single seed shared by every random number generator this notebook touches.
SEED = 123

# Seed NumPy's legacy global RNG and PyTorch's CPU / CUDA generators.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Ask cuDNN for deterministic kernels and turn off its auto-tuner, which may
# otherwise pick different algorithms from one run to the next.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def main(config):
    """Run one knowledge-distillation training session, tracked as a W&B run.

    The function opens a Weights & Biases run, builds the "train" and "valid"
    loaders, creates the student, restores the teacher from
    ``../weights/averaged_weights_bce.pth``, loads the thresholds stored in
    ``../weights/acc_model_thresh.npy``, initialises the student's base model
    from ``interpolate_params(student, teacher)`` and finally calls
    ``KDLearner.fit()``.

    The W&B run is always closed: normally on success, and with a non-zero
    exit code if anything raises (including interrupting the cell), so a
    failed attempt does not leave a dangling run in the kernel.

    Args:
        config: Configuration object; it is forwarded unchanged to
            ``wandb.init``, ``get_loader`` and ``KDLearner``.

    Returns:
        None. The function has no return value; its effects are whatever
        ``KDLearner.fit()`` and W&B logging produce.

    Raises:
        Any exception raised while building the loaders/models or during
        training is re-raised after the W&B run has been marked as failed.
    """
    wandb.init(config=config, anonymous="allow")

    try:
        train_dl = get_loader(config, subset="train")
        valid_dl = get_loader(config, subset="valid")

        student = StudentAST(11, hidden_size=192, num_heads=3)

        # Deserialize onto the CPU so a checkpoint written from a GPU also
        # loads on CPU-only machines; load_state_dict copies the values into
        # the teacher's existing parameters.
        weight = torch.load(
            "../weights/averaged_weights_bce.pth", map_location="cpu"
        )
        teacher = ASTPretrained(n_classes=11, download_weights=False)
        teacher.load_state_dict(weight)
        # NOTE(review): presumably decision thresholds tuned for the teacher;
        # they are only passed through to KDLearner here - confirm there.
        thresholds = np.load("../weights/acc_model_thresh.npy")

        # Start the student from weights derived from the teacher instead of
        # a random initialisation.
        interpolated_weights = interpolate_params(student, teacher)
        student.base_model.load_state_dict(interpolated_weights)

        learn = KDLearner(train_dl, valid_dl, student, teacher, thresholds, config)
        learn.fit()
    except BaseException:
        # BaseException so that KeyboardInterrupt (stopping the cell) also
        # closes the run; exit_code=1 marks it as failed in W&B.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

In [None]:
# Location of the knowledge-distillation config, relative to the notebook.
CONFIG_PATH = "../configs/KD_config.yaml"
# Parse the config; the same object is handed to W&B, the loaders and the learner inside main().
config = parse_config(CONFIG_PATH)

# Kick off the distillation run.
# NOTE(review): main() has no return statement, so `params` is bound to None.
params = main(config)

[34m[1mwandb[0m: Currently logged in as: [33mk-pintaric[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/2 [00:00<?, ?it/s]

Distilling knowledge...


  0%|          | 0/14110 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]