In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Code

In [None]:
# Star import stays first so the explicit imports below win on any name clash.
from width import *

import itertools

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment 1:

How does the ratio of `hidden_width` to `auxiliary_logits` affect the performance of the student network?

In [None]:
def run(hidden_width: int, auxiliary_logits: int, seed: int = 42,
        epochs: int = 5, lr: float = 0.001, batch_size: int = 128):
  """Train one teacher/student pair and evaluate both before and after training.

  Steps, in order: seed via `set_seed`, build the models with `init_models`,
  wrap both in `torch.compile`, create one Adam optimizer per model, load
  FashionMNIST through `Dataset.load_FashionMNIST`, measure both models on the
  test loader, train the teacher and then the student for `epochs` epochs
  each, and measure both models again.

  Args:
    hidden_width: Forwarded to `init_models`.
    auxiliary_logits: Forwarded to `init_models`.
    seed: Passed to `set_seed` and to the dataset loader.
    epochs: Number of epochs given to both `train_teacher` and `train_student`.
    lr: Learning rate used for both Adam optimizers.
    batch_size: Batch size passed to the dataset loader.

  Returns:
    A tuple `(baseline_teacher, baseline_student, results_teacher,
    results_student)`. Each element is whatever `Trainer.performance` returns;
    its structure is not defined in this notebook (later cells read the keys
    'acc_1hot' and 'acc_5hot' from it).
  """
  # Seed (the return value of set_seed is not needed here)
  set_seed(seed)

  # Initialize Network
  teacher, student = init_models(hidden_width=hidden_width, auxiliary_logits=auxiliary_logits)
  teacher = torch.compile(teacher)
  student = torch.compile(student)

  # Training Parameters
  optimizer_teacher = torch.optim.Adam(teacher.parameters(), lr=lr)
  optimizer_student = torch.optim.Adam(student.parameters(), lr=lr)
  criterion_teacher = nn.CrossEntropyLoss()
  criterion_student = nn.MSELoss()

  train_loader, test_loader = Dataset.load_FashionMNIST(batch_size=batch_size, seed=seed)

  # Run Training
  trainer = Trainer(student, teacher, train_loader, optimizer_teacher, optimizer_student, criterion_teacher, criterion_student, device)

  # Performance of the untrained models, used as a reference point.
  baseline_teacher = trainer.performance(trainer.teacher, test_loader)
  baseline_student = trainer.performance(trainer.student, test_loader)

  print('Start Training')
  trainer.train_teacher(epochs)
  trainer.train_student(epochs)
  print('Finished Training')

  results_teacher = trainer.performance(trainer.teacher, test_loader)
  results_student = trainer.performance(trainer.student, test_loader)

  # Drop the local references to models and optimizers before emptying the
  # CUDA cache, so repeated calls in a sweep do not accumulate GPU memory.
  del teacher, student, trainer, optimizer_student, optimizer_teacher
  torch.cuda.empty_cache()

  return baseline_teacher, baseline_student, results_teacher, results_student

In [None]:
# Sweep every (hidden_width, auxiliary_logits) combination once.
hiddens = [2 ** j for j in range(8, 13)]
auxiliaries = [3, 10, 50, 100]

# itertools.product varies the last iterable fastest, so `itt` (and therefore
# `results`) is ordered hidden-major: all auxiliaries for hiddens[0] come first.
itt = list(itertools.product(hiddens, auxiliaries))
# Each configuration gets its own seed, equal to its position in the sweep.
seeds = range(len(itt))

results = []
# zip() has no length, so pass `total` explicitly to get a real progress bar.
for (hidden, auxiliary), seed in tqdm(zip(itt, seeds), total=len(itt)):
  result = run(hidden_width=hidden, auxiliary_logits=auxiliary, seed=seed)
  results.append(result)

In [None]:
# Arrange the post-training accuracies of the sweep into 2-D grids.
#
# `results` is a flat list with one entry per (hidden_width, auxiliary_logits)
# configuration, ordered hidden-major (itertools.product order). Each entry is
# the 4-tuple returned by run():
#   (baseline_teacher, baseline_student, results_teacher, results_student).
# The entries are therefore grouped into consecutive chunks of
# len(auxiliaries) so that row = hidden width and column = auxiliary_logits.
n_hidden = len(hiddens)
n_auxiliary = len(auxiliaries)
assert len(results) == n_hidden * n_auxiliary, (
  "expected one result per (hidden_width, auxiliary_logits) configuration"
)

batches_acc_1hot = []
batches_acc_5hot = []
batches_acc_1hot_teacher = []

for row in range(n_hidden):
  batch_acc_1hot = []
  batch_acc_5hot = []
  batch_acc_1hot_teacher = []
  # All runs that share hiddens[row], one per auxiliary_logits value.
  for result in results[row * n_auxiliary:(row + 1) * n_auxiliary]:
    baseline_teacher, baseline_student, results_teacher, results_student = result

    # Student Metrics
    batch_acc_1hot.append(results_student['acc_1hot'])
    batch_acc_5hot.append(results_student['acc_5hot'])

    # Teacher Metrics
    batch_acc_1hot_teacher.append(results_teacher['acc_1hot'])

  batches_acc_1hot.append(batch_acc_1hot)
  batches_acc_5hot.append(batch_acc_5hot)
  batches_acc_1hot_teacher.append(batch_acc_1hot_teacher)

# Shape of each array: (len(hiddens), len(auxiliaries)).
data_acc_1hot = np.stack(batches_acc_1hot)
data_acc_5hot = np.stack(batches_acc_5hot)
data_acc_1hot_teacher = np.stack(batches_acc_1hot_teacher)

In [None]:
# Post-training accuracy versus hidden width.
# The x-axis is `hiddens`, so the data_acc_* arrays must have one row per
# hidden width. Every column is drawn as a scatter series; the row mean is
# drawn as a line.
assert data_acc_1hot.shape[0] == len(hiddens), "expected one row of accuracies per hidden width"

for i in range(data_acc_1hot.shape[1]):
  plt.scatter(hiddens, data_acc_1hot[:, i], color='blue')
for i in range(data_acc_5hot.shape[1]):
  plt.scatter(hiddens, data_acc_5hot[:, i], color='mediumblue', linestyle='--')
for i in range(data_acc_1hot_teacher.shape[1]):
  plt.scatter(hiddens, data_acc_1hot_teacher[:, i], color='red')

plt.plot(hiddens, data_acc_1hot.mean(axis=1), color='blue', label='acc_1hot (student)')
plt.plot(hiddens, data_acc_5hot.mean(axis=1), color='mediumblue', linestyle='--', label='acc_5hot (student)')
plt.plot(hiddens, data_acc_1hot_teacher.mean(axis=1), color='red', label='acc_1hot (teacher)')

# Hidden widths are powers of two, so a base-2 log axis spaces them evenly.
plt.xscale('log', base=2)

plt.xlabel('hidden_width')
plt.ylabel('After training performance')
# The sweep has no single auxiliary_logits / hidden_width ratio, so the title
# describes what the lines show instead of interpolating an undefined `ratio`.
plt.title("Student's Performance on Fashion MNIST Validation Set (lines: mean over auxiliary_logits)")


plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()

In [None]:
# Line plot of student (1-hot and 5-hot) and teacher accuracy against hidden width.
# NOTE(review): `lst`, `student_acc_1hot`, `student_acc_5hot`, `teacher_acc_1hot`
# and `ratio` are not defined in any cell visible in this notebook, so this cell
# presumably relies on leftover kernel state from an earlier version of the
# experiment. Confirm where they come from, or remove the cell.
plt.plot(lst, student_acc_1hot, label='acc_1hot (student)', color='blue', marker='o')
plt.plot(lst, student_acc_5hot, label='acc_5hot (student)', color='blue', linestyle='--', marker='o')
plt.plot(lst, teacher_acc_1hot, label='acc_1hot (teacher)', color='red', marker='o')

# Base-2 log scale on the x-axis.
plt.xscale('log', base=2)

plt.xlabel('hidden_width')
plt.ylabel('After training performance')
plt.title(f"Student's Performance on Fashion MNIST Validation Set (s.t. auxiliary_logits / hidden_width  = {ratio})")

# Place the legend outside the axes, to the right of the plot.
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()