In [None]:
from pathlib import Path
import numpy as np
import random
import torch
from torchinfo import summary

from network import TeacherNet, StudentNet
from trainer import Trainer
from tester import test

In [None]:
# Fix the seeds of every RNG used during training for reproducibility
def set_seed(seed: int = 999) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Seed value. Defaults to 999, the value this notebook has always used.

    Also asks cuDNN for deterministic kernels and disables the cuDNN autotuner,
    whose algorithm choice can differ from run to run.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU rather than only the current one; silently ignored
    # when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The benchmark autotuner may pick different algorithms between runs.
    torch.backends.cudnn.benchmark = False

In [None]:
config = {
    'batch_size': 64,
    'teacher_lr': 1e-3,
    'student_lr': 1e-2,
    'student_scheduler': 'cycle',  # 'cycle' or 'step'
    'teacher_epoch': 15,
    'student_epoch': 30,
    'max_temp': 10,
    'min_temp': 5,
    'aug': True,
    # Fall back to CPU so the notebook can still run on a machine without a GPU.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    'teacher_log_dir': Path('runs/teacher/'),
    'student_log_dir': Path('runs/student/'),

    # Stage 1: prune from the unpruned student checkpoint.
    'student_pruned_1_log_dir': Path('runs/student_pruned_1/'),
    'from_pruned': False,
    'prune_rate': 0.5,

    # Stage 2: prune from the stage-1 (already pruned) checkpoint.
    # The stage-2 training cell uses from_pruned=True and prune_rate=0.6.
    'student_pruned_2_log_dir': Path('runs/student_pruned_2/'),
}

# Fail fast on a typo in an option that only accepts a fixed set of values.
assert config['student_scheduler'] in ('cycle', 'step'), config['student_scheduler']

### Train teacher model

In [None]:
# Seed before building the trainer so anything random it does (e.g. weight init) is reproducible.
set_seed()
trainer = Trainer(config, config['teacher_log_dir'])

# presumably (batch, channels, height, width) — the dummy input torchinfo feeds the model
teacher_summary = summary(trainer.teacher_model, input_size=(1, 3, 224, 224))
print(teacher_summary)

In [None]:
# Train the teacher. Relies on `trainer` built in the previous cell.
trainer.run(option='teacher')

### Test teacher model

In [None]:
# Reload the saved teacher weights into a fresh network and evaluate it.
# map_location places the tensors directly on config['device'], so the checkpoint
# also loads when it was saved from a different device (e.g. GPU -> CPU-only machine).
teacher_ckpt = config['teacher_log_dir'] / 'training_result/teacher_model.pth'
teacher_net = TeacherNet().to(config['device'])
teacher_net.load_state_dict(torch.load(str(teacher_ckpt), map_location=config['device']))
teacher_net.eval()

test(teacher_net, config['teacher_log_dir'])

### Train student model

In [None]:
# Build a fresh trainer for the student run, which uses the pretrained teacher
# (presumably as the distillation target — see max_temp/min_temp in config).
set_seed()
trainer = Trainer(config, config['student_log_dir'])

# Load the pretrained teacher weights; map_location keeps the load independent of
# the device the checkpoint was saved from.
teacher_ckpt = config['teacher_log_dir'] / 'training_result/teacher_model.pth'
trainer.teacher_model.load_state_dict(torch.load(str(teacher_ckpt), map_location=config['device']))

print(summary(trainer.student_model, input_size=(1, 3, 224, 224)))

In [None]:
# Train the (unpruned) student. Relies on `trainer` built in the previous cell.
trainer.run(option='student')

### Test student model

In [None]:
# Reload the saved (unpruned) student weights into a fresh network and evaluate it.
# map_location places the tensors directly on config['device'].
student_ckpt = config['student_log_dir'] / 'training_result/student_model.pth'
student_net = StudentNet().to(config['device'])
student_net.load_state_dict(torch.load(str(student_ckpt), map_location=config['device']))
student_net.eval()

test(student_net, config['student_log_dir'])

### Train 50% pruned student model (starting from the unpruned student checkpoint)

In [None]:
# Stage 1: start from the unpruned student checkpoint and prune it by config['prune_rate'].
set_seed()
trainer = Trainer(config, config['student_pruned_1_log_dir'])

# Load the pretrained teacher weights (placed directly on the configured device).
teacher_ckpt = config['teacher_log_dir'] / 'training_result/teacher_model.pth'
trainer.teacher_model.load_state_dict(torch.load(str(teacher_ckpt), map_location=config['device']))

# The source checkpoint comes from an unpruned StudentNet, so it must be loaded into
# an unpruned student. from_pruned=True here (e.g. a stale value left in `config` by a
# later cell) would presumably make load_state_dict fail on mismatched pruning keys.
assert not config['from_pruned'], "stage 1 prunes the unpruned student; set config['from_pruned'] = False"
student_ckpt = config['student_log_dir'] / 'training_result/student_model.pth'
trainer.student_model.load_state_dict(torch.load(str(student_ckpt), map_location=config['device']))

trainer.student_model.pruning(config['prune_rate'])

print(summary(trainer.student_model, input_size=(1, 3, 224, 224)))

In [None]:
# Fine-tune the stage-1 pruned student. Relies on `trainer` built in the previous cell.
trainer.run(option='student')

### Test 50% pruned student model

In [None]:
# Reload the stage-1 pruned student into a fresh pruned network and evaluate it.
# map_location places the tensors directly on config['device'].
pruned_1_ckpt = config['student_pruned_1_log_dir'] / 'training_result/student_model.pth'
pruned_1_net = StudentNet(prune=True).to(config['device'])
pruned_1_net.load_state_dict(torch.load(str(pruned_1_ckpt), map_location=config['device']))
pruned_1_net.eval()

test(pruned_1_net, config['student_pruned_1_log_dir'])

### Train 60% pruned student model (starting from the 50% pruned checkpoint)

In [None]:
# Stage 2: prune the stage-1 (already pruned) student again with prune_rate 0.6.
# Use a per-stage copy instead of mutating the shared `config`, so earlier cells
# still see the stage-1 settings if they are re-run afterwards.
stage2_config = {**config, 'from_pruned': True, 'prune_rate': 0.6}

set_seed()
trainer = Trainer(stage2_config, stage2_config['student_pruned_2_log_dir'])

# Load the pretrained teacher weights (placed directly on the configured device).
teacher_ckpt = stage2_config['teacher_log_dir'] / 'training_result/teacher_model.pth'
trainer.teacher_model.load_state_dict(
    torch.load(str(teacher_ckpt), map_location=stage2_config['device']))

# The stage-1 checkpoint was saved from a pruned student. pruning(0) presumably recreates
# the pruning parameters/buffers its state_dict contains so the load matches. remove_pruning()
# then (presumably) makes the stage-1 masks permanent before pruning again at the stage-2 rate.
trainer.student_model.pruning(0)
pruned_1_ckpt = stage2_config['student_pruned_1_log_dir'] / 'training_result/student_model.pth'
trainer.student_model.load_state_dict(
    torch.load(str(pruned_1_ckpt), map_location=stage2_config['device']))
trainer.student_model.remove_pruning()
trainer.student_model.pruning(stage2_config['prune_rate'])

print(summary(trainer.student_model, input_size=(1, 3, 224, 224)))

In [None]:
# Fine-tune the stage-2 pruned student. Relies on `trainer` built in the previous cell.
trainer.run(option='student')

### Test 60% pruned student model

In [None]:
# Reload the stage-2 pruned student into a fresh pruned network and evaluate it.
# map_location places the tensors directly on config['device'].
pruned_2_ckpt = config['student_pruned_2_log_dir'] / 'training_result/student_model.pth'
pruned_2_net = StudentNet(prune=True).to(config['device'])
pruned_2_net.load_state_dict(torch.load(str(pruned_2_ckpt), map_location=config['device']))
pruned_2_net.eval()

test(pruned_2_net, config['student_pruned_2_log_dir'])