In [13]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
import copy

device = torch.device('cude' if torch.cuda.is_available() else 'cpu')

# hyperparameters
input_size = 784
batch_size = 64
hidden_size = 100
num_classes = 10
learning_rate = 0.001
num_epochs = 10



In [22]:
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../data',
                                          train=False,
                                          transform=transforms.ToTensor())


# Data Loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

In [23]:
def distillation(y, labels, teacher_scores, T, alpha):
#     print(F.cross_entropy(y, labels))
    return nn.KLDivLoss()(F.log_softmax(y/T), F.softmax(teacher_scores/T)) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha) + torch.nn.MSELoss()()

In [24]:
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

In [25]:
class AE(torch.nn.Module):
    def __init__(self):
        super().__init__()
          
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(100, 10),
        )
          

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(10, 100),
        )
  
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [26]:
# class AEAndReconstructed(torch.nn.Module):
#     def __init__(self, model):
#         super().__init__()
#         self.model = model
#         self.reconstructed_model = copy.deepcopy(model)
#         self.ae = AE()
  
#     def forward(self, x):
#         reconstructed = self.ae(self.model.fc1.weight.data.transpose(1, 0))
#         new_weight = torch.transpose(reconstructed.reshape(784, 100), 1, 0)
#         self.reconstructed_model.fc1.weight.data = new_weight
#         return self.reconstructed_model(x)

In [27]:
def get_reconstructed(images, model, ae):
    reconstructed = ae(model.fc1.weight.data.transpose(1, 0))
    new_weight = torch.transpose(reconstructed.reshape(784, 100), 1, 0)
    reconstructed_model = copy.deepcopy(model)
    reconstructed_model.fc1.weight.data = new_weight
    labels = reconstructed_model(images)
    return labels, new_weight.flatten(), model.fc1.weight.data.flatten()
    

In [35]:
model = MLP(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes)
model.load_state_dict(torch.load(f'../model/model_1.ckpt'))
ae = AE()
epochs = 50
# Using an Adam Optimizer with lr = 0.1
optimizer = torch.optim.Adam(ae.parameters(),
                             lr = 1e-1,
                             weight_decay = 1e-8)

for epoch in range(epochs):
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        student_output, student_weight, original_weight = get_reconstructed(images, model, ae)
        teacher_output = model(images)
        teacher_output = teacher_output.detach()
        loss = distillation(student_output, labels, teacher_output, T=100, alpha=0)
#         loss.requres_grad = True
        loss.backward()
        optimizer.step()
        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss))


  return nn.KLDivLoss()(F.log_softmax(y/T), F.softmax(teacher_scores/T)) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
  return nn.KLDivLoss()(F.log_softmax(y/T), F.softmax(teacher_scores/T)) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)










KeyboardInterrupt: 

In [None]:
model2 = MLP(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes)
model2.load_state_dict(torch.load(f'../model/model_1.ckpt'))

In [None]:
model2.fc2.weight.data

In [None]:
reconstructed.model.fc2.weight.data