In [1]:
from toolbox.models import ResNet112, ResNet56
from toolbox.data_loader import Cifar100

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import tensorly as tl
import numpy as np
import matplotlib.pyplot as plt

from rich import print as pprint
from tqdm import trange

# Seed torch's global RNG so the student's random initialisation (and therefore
# the printed losses below) is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"


In [14]:
model_path = r"toolbox/Cifar100_ResNet112.pth"

# TEACHER: pretrained ResNet112 loaded from the checkpoint's "weights" entry.
teacher = ResNet112(100).to(device)
# map_location lets a checkpoint saved from a GPU load onto whatever `device` is.
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
teacher.load_state_dict(checkpoint["weights"])
teacher.eval()
# The teacher only provides targets; freeze it so no gradients are ever
# accumulated into its parameters by losses computed from its outputs.
teacher.requires_grad_(False)

# STUDENT: freshly initialised ResNet56 (no checkpoint), left in training mode.
student = ResNet56(100).to(device)
student.train()


criterion = nn.L1Loss()

Data = Cifar100()
trainloader, testloader = Data.trainloader, Data.testloader


Files already downloaded and verified
Files already downloaded and verified


In [20]:

# Run tensorly on torch tensors so the helpers below accept the models' outputs
# directly (on their current device) and their results stay differentiable.
tl.set_backend("pytorch")
def tucker(feature_map, rank=None):
    """Tucker-decompose a 4-D feature map with tensorly.

    Args:
        feature_map: 4-D tensor whose first axis is treated as the batch axis
            (assumed (batch, channels, height, width) -- TODO confirm for the
            toolbox models' outputs).
        rank: Tucker rank for each of the four modes. Defaults to
            ``[batch_size, 32, 8, 8]``, i.e. full rank along the batch axis.
            NOTE(review): each entry should presumably not exceed the matching
            dimension of ``feature_map``; tensorly's handling of larger ranks
            is version-dependent.

    Returns:
        ``(core, factors)`` as produced by ``tl.decomposition.tucker``.

    Raises:
        ValueError: if ``feature_map`` is not 4-D.
    """
    if feature_map.ndim != 4:
        raise ValueError(
            f"tucker expects a 4-D feature map, got shape {tuple(feature_map.shape)}"
        )
    if rank is None:
        # Keep the batch mode at full rank; compress the remaining modes.
        rank = [feature_map.shape[0], 32, 8, 8]
    core, factors = tl.decomposition.tucker(feature_map, rank=list(rank))
    return core, factors

def compute_core(feature_map, factors):
    """Project ``feature_map`` onto a given set of Tucker factor matrices.

    Each factor is transposed and applied along modes 0-3 in turn, which gives
    the Tucker core of ``feature_map`` expressed in the basis of ``factors``
    (e.g. a teacher's factors applied to a student's features).

    Args:
        feature_map: 4-D tensor to project.
        factors: factor matrices, one per mode, as returned by ``tucker``.

    Returns:
        The projected core tensor.
    """
    projections = [factor.t() for factor in factors]
    return tl.tenalg.multi_mode_dot(feature_map, projections, modes=list(range(4)))

def FT(x):
    """Flatten every sample in ``x`` and rescale it to unit L2 norm.

    Args:
        x: tensor whose first axis indexes samples.

    Returns:
        A 2-D tensor of shape ``(x.shape[0], -1)`` whose rows have unit L2 norm
        (up to ``F.normalize``'s epsilon for all-zero rows).
    """
    per_sample = x.reshape(x.shape[0], -1)
    return F.normalize(per_sample, p=2.0, dim=1)

# Inspect one training batch: project the student's features onto the
# teacher's Tucker factors and compare the two cores.
L1_SCALE = 125  # NOTE(review): scale factor from the original experiment; rationale not documented

inputs, targets = next(iter(trainloader))
inputs, targets = inputs.to(device), targets.to(device)

# The teacher side is a fixed target: disable autograd so no graph is built
# through its forward pass or the Tucker SVDs, and no gradients reach it.
with torch.no_grad():
    teacher_outputs = teacher(inputs)
    teacher_core, teacher_factors = tucker(teacher_outputs[2])

student_outputs = student(inputs)
# Project the student's feature map with the teacher's factors (instead of
# decomposing it separately) so both cores share the same factor basis.
student_core = compute_core(student_outputs[2], teacher_factors)

print("abs", tl.norm(teacher_core - student_core))

# L2 distance between the per-sample unit-normalised cores.
normalized_l2 = tl.norm(FT(teacher_core) - FT(student_core))
print("normalized l2", normalized_l2)

# Scaled L1 loss on the same normalised cores (the candidate training loss).
loss_amount = L1_SCALE * criterion(FT(student_core), FT(teacher_core))
print("scaled l1", loss_amount)



abs tensor(1706.0847, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
tensor(15.9687, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
tensor(2.4207, device='cuda:0', grad_fn=<MulBackward0>)
