In [None]:
import os
import sys
import time
import yaml
import argparse
import math
import numpy as np
import madry
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10, ImageFolder, ImageNet
from torchvision.models import resnet50, alexnet
from rep_align.models.model_factory import get_model
from rep_align.utils.lr_scheduling import *
from rep_align.data.nsd_data import NSDVoxels
import wandb
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
WANDB = False

# Seed the torch and numpy RNGs so the random init of the new readout layer,
# DataLoader shuffling and torchvision augmentations are repeatable across
# Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Most recent forward-hook outputs, keyed by the name given to get_activations.
activations = {}
def get_activations(name):
    """Build a forward hook that stores a module's output in ``activations[name]``.

    The output is stored as-is (not detached), so a loss computed from it can
    still backpropagate into the hooked module. Each forward pass overwrites
    the previous entry for ``name``.
    """
    def hook(_module, _inputs, output):
        activations.update({name: output})
    return hook

def evaluate_clf(model, data, device='cuda'):
    """Return the top-1 accuracy of ``model`` over ``data``.

    Args:
        model: classifier returning logits with classes on the last dimension.
            It is switched to eval mode and left there.
        data: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Fraction of samples whose arg-max logit equals the label.
    """
    model.eval()
    seen = 0
    correct = 0
    with torch.no_grad():
        for images, labels in data:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=-1)
            seen += len(labels)
            correct += (predicted == labels).sum().item()
    return correct / seen

def evaluate(student_model, teacher_model, data, loss_fn_predict, 
             loss_fn_kd, alpha_kd, epoch, device='cuda'):
    """Compute the student's validation losses over ``data`` for one epoch.

    For every ``(img, target)`` batch, the student and teacher are run with
    gradients disabled. The neural-prediction loss is computed from
    ``activations['student_feats']``, which must be filled by a forward hook on
    the student. The KD loss compares student and teacher outputs, and the
    total is ``prediction + alpha_kd * kd``.

    Returns:
        ``(prediction_loss, kd_loss, total_loss)``: per-sample averages over the
        whole of ``data``. They are also logged to wandb when ``WANDB`` is set.
    """
    student_model.eval()
    n_samples = 0
    sum_predict = sum_kd = sum_total = 0
    with tqdm(data, unit='batch') as progress, torch.no_grad():
        progress.set_description(f'eval {epoch}')
        for img, target in progress:
            img, target = img.to(device), target.to(device)
            # Student first: its forward hook refreshes activations['student_feats'].
            output_student = student_model(img)
            output_teacher = teacher_model(img)
            loss_predict = get_predict_loss(student_model, activations['student_feats'], target, loss_fn_predict)
            loss_kd = get_kd_loss(output_student, output_teacher, loss_fn_kd)
            loss_total = loss_predict + alpha_kd*loss_kd

            batch_n = len(target)
            n_samples += batch_n
            sum_predict += loss_predict.item()*batch_n
            sum_kd += loss_kd.item()*batch_n
            sum_total += loss_total.item()*batch_n
            progress.set_postfix(
                predict_loss='{:.4f}'.format(sum_predict/n_samples),
                kd_loss='{:.4f}'.format(sum_kd/n_samples),
                total_loss='{:.4f}'.format(sum_total/n_samples),
            )

    prediction_loss, kd_loss, total_loss = (
        running/n_samples for running in (sum_predict, sum_kd, sum_total)
    )
    if WANDB:
        wandb.log({
            'val_prediction_loss': prediction_loss,
            'val_kd_loss': kd_loss,
            'val_total_loss': total_loss,
            'epoch': epoch,
        })
    return prediction_loss, kd_loss, total_loss

def get_predict_loss(student_model, student_feats, target, loss_fn_predict):
    """Score the student's neural readout against ``target``.

    ``student_feats`` is flattened from dim 1 onward and fed through
    ``student_model.neural_predict``. The result and ``target`` are then passed
    to ``loss_fn_predict``.
    """
    flat_feats = torch.flatten(student_feats, start_dim=1)
    predicted = student_model.neural_predict(flat_feats)
    return loss_fn_predict(predicted, target)

def get_kd_loss(student_out, teacher_out, loss_fn_kd):
    """Distillation loss between student and teacher logits.

    Both outputs are converted to log-probabilities over the last dimension
    before calling ``loss_fn_kd(student, teacher)``. With
    ``nn.KLDivLoss(log_target=True)``, as configured in this notebook, this is
    KL(teacher || student).
    """
    log_p_student, log_p_teacher = (
        F.log_softmax(logits, dim=-1) for logits in (student_out, teacher_out)
    )
    return loss_fn_kd(log_p_student, log_p_teacher)

def lwf_train_epoch(student_model, teacher_model, data, 
                    loss_fn_predict, loss_fn_kd, alpha_kd, 
                    optimizer, lr_scheduler, epoch, device='cuda'):
    """Train the student for one epoch on ``prediction + alpha_kd * kd`` loss.

    The prediction loss uses ``activations['student_feats']``, which must be
    filled by a forward hook on the student. The KD loss keeps the student's
    outputs close to the frozen teacher's (Learning-without-Forgetting style).
    After every optimizer step, ``lr_scheduler`` is advanced and its
    ``get_lr()`` is written into every optimizer param group.

    Returns:
        ``(prediction_loss, kd_loss, total_loss)``: per-sample averages over the
        epoch. They are also logged to wandb when ``WANDB`` is set.
    """
    student_model.train()
    # The teacher is a fixed reference: keep it in eval mode so layers such as
    # BatchNorm/Dropout (if any) behave deterministically and do not update state.
    teacher_model.eval()
    n, running_total_loss, running_loss_prediction, running_loss_kd = 0, 0, 0, 0
    with tqdm(data, unit='batch') as epoch_progress:
        epoch_progress.set_description(f'train {epoch}')
        for img, target in epoch_progress:
            img = img.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            # Student forward refreshes activations['student_feats'] via its hook.
            output_student = student_model(img)
            # Teacher outputs are only targets for the KD term; no graph is needed.
            with torch.no_grad():
                output_teacher = teacher_model(img)
            loss_predict = get_predict_loss(student_model, activations['student_feats'], target, loss_fn_predict)
            loss_kd = get_kd_loss(output_student, output_teacher, loss_fn_kd)
            total_loss = loss_predict + alpha_kd*loss_kd
            total_loss.backward()
            optimizer.step()

            lr_scheduler.step()
            new_lr = lr_scheduler.get_lr()
            for group in optimizer.param_groups:
                group['lr'] = new_lr
            # Log once per optimizer step, not once per param group.
            if WANDB:
                wandb.log({'lr': new_lr, 'lr_step': lr_scheduler.current_step})

            n += len(target)
            running_loss_prediction += loss_predict.item()*len(target)
            running_loss_kd += loss_kd.item()*len(target)
            running_total_loss += total_loss.item()*len(target)
            epoch_progress.set_postfix(
                predict_loss='{:.4f}'.format((running_loss_prediction/n)), 
                kd_loss='{:.4f}'.format((running_loss_kd/n)), 
                total_loss='{:.4f}'.format((running_total_loss/n))
            )

    prediction_loss = running_loss_prediction/n
    kd_loss = running_loss_kd/n
    total_loss = running_total_loss/n
    if WANDB:
        wandb.log(
            {'train_prediction_loss': prediction_loss,
             'train_kd_loss': kd_loss,
             'train_total_loss': total_loss, 
             'epoch': epoch}
        )
    return prediction_loss, kd_loss, total_loss

def run(student_model, teacher_model, neural_train_data, neural_val_data, clf_data, loss_fn_prediction, loss_fn_kd, 
          alpha_kd, optimizer, lr_scheduler, n_epochs, out_path=None, device='cuda'):
    """Train the student for ``n_epochs``, checkpointing the best validation epoch.

    Each epoch runs ``lwf_train_epoch`` on ``neural_train_data``, ``evaluate`` on
    ``neural_val_data`` and ``evaluate_clf`` on ``clf_data``. Whenever the
    validation total loss improves and ``out_path`` is given, student, teacher
    and optimizer state are saved to ``<out_path>/best_val_loss.pt``.

    Returns:
        dict with 'train' (per-epoch ``[prediction, kd, total]`` losses) and
        'val' (per-epoch ``[prediction, kd, total, clf_accuracy]``).
    """
    if out_path is not None:
        # torch.save does not create missing directories.
        os.makedirs(out_path, exist_ok=True)

    if WANDB:
        wandb.init(project='repalign')

    hist = {'train': [], 'val': []}
    best_loss = np.inf
    try:
        for epoch in range(n_epochs):
            train_prediction_loss, train_kd_loss, train_total_loss = lwf_train_epoch(
                student_model, teacher_model, neural_train_data, loss_fn_prediction, 
                loss_fn_kd, alpha_kd, optimizer, lr_scheduler, epoch, device=device
            )

            val_prediction_loss, val_kd_loss, val_total_loss = evaluate(
                student_model, teacher_model, neural_val_data, 
                loss_fn_prediction, loss_fn_kd, alpha_kd, epoch, device=device
            )

            val_clf_acc = evaluate_clf(student_model, clf_data, device=device)

            hist['train'].append([train_prediction_loss, train_kd_loss, train_total_loss])
            hist['val'].append([val_prediction_loss, val_kd_loss, val_total_loss, val_clf_acc])

            if val_total_loss < best_loss:
                best_loss = val_total_loss
                if out_path is not None:
                    model_ckpt_data = {
                        'student_state_dict': student_model.state_dict(),
                        'teacher_state_dict': teacher_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'hist': hist,
                        'wandb_run_name': None if not WANDB else wandb.run.name,
                    }
                    torch.save(model_ckpt_data, os.path.join(out_path, 'best_val_loss.pt'))
    finally:
        # Close the wandb run even if training fails or is interrupted, so the
        # next run in the same kernel does not inherit a dangling run.
        if WANDB:
            wandb.finish()

    return hist

In [None]:
# Per-channel CIFAR-10 statistics, shared by every input pipeline below.
# NOTE(review): the NSD images are normalised with CIFAR-10 statistics too,
# presumably to match what the CIFAR-10-trained backbone saw — confirm.
CIFAR_10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_10_STD = [0.2470, 0.2435, 0.2616]

cifar_10_train_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_10_MEAN, CIFAR_10_STD),
])

cifar_10_val_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_10_MEAN, CIFAR_10_STD),
])

num_workers = 8
batch_size = 128
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

# NSD (image, voxel-response) pairs for the selected ROIs.
nsd_data_root_train = '/DATA/nsd_sample/train'
nsd_data_root_val = '/DATA/nsd_sample/val'
voxel_rois = ['V4']
nsd_train_data = NSDVoxels(nsd_data_root_train, voxel_rois, transforms=cifar_10_train_transform)
nsd_val_data = NSDVoxels(nsd_data_root_val, voxel_rois, transforms=cifar_10_val_transform)
nsd_train_dataloader = DataLoader(nsd_train_data, shuffle=True, **loader_kwargs)
nsd_val_dataloader = DataLoader(nsd_val_data, **loader_kwargs)

# Length of the per-sample regression target; sizes the student's neural readout layer.
n_voxels = nsd_train_data[0][1].shape[0]

# CIFAR-10 test split, used to track how much classification accuracy is retained.
cifar10_val_data = CIFAR10('/DATA/cifar10', train=False, transform=cifar_10_val_transform, download=False)
cifar10_val_dataloader = DataLoader(cifar10_val_data, **loader_kwargs)

In [None]:
model_type = 'resnet18_cifar10'

# Student (fine-tuned) and teacher (frozen) start from the same CIFAR-10 checkpoint.
student_model = get_model(model_type, n_classes=10)
teacher_model = get_model(model_type, n_classes=10)

# Load onto the CPU first so a GPU-saved checkpoint also loads on CPU-only
# machines; both models are moved to `device` below.
state_dict = torch.load('./model_ckpts/resnet18_cifar10_base/best_loss.pt', map_location='cpu')['state_dict']
student_model.load_state_dict(state_dict)
teacher_model.load_state_dict(state_dict)

for param in teacher_model.parameters():
    param.requires_grad = False
teacher_model.eval()

# Linear readout from flattened layer4 features to voxel responses.
# NOTE(review): 2048 must equal the flattened size of the student's layer4
# output at the training input resolution (the loss flattens
# activations['student_feats'] before this layer) — confirm for this backbone.
student_model.neural_predict = nn.Linear(2048, n_voxels, bias=True)

student_model.to(device)
teacher_model.to(device)

original_acc = evaluate_clf(teacher_model, cifar10_val_dataloader, device=device)
print(f'Original Accuracy: {original_acc}')

In [None]:
run_config = {
    'lr_scheduling_spec': 'constant',
    'lr_init': 1e-4,
    'optimizer_spec': 'adam',
    'optimizer_params': {'weight_decay': 0.01},
    'alpha_kd': 1.0,
    'n_epochs': 10,
    'out_path': './tune_tmp'
}

# Learning-rate schedule (scheduler classes come from rep_align.utils.lr_scheduling).
# `initial_lr` is the LR the schedule starts from; for the warmup schedule this
# presumes it begins at 'warmup_start_lr' — confirm against the scheduler.
lr_spec = run_config['lr_scheduling_spec']
if lr_spec == 'constant':
    lr_scheduler = ConstantLR(run_config['lr_init'])
    initial_lr = run_config['lr_init']
elif lr_spec == 'linear_warmup_cosine_decay':
    lr_scheduler = LinearWarmupCosineDecayLR(run_config['warmup_start_lr'], 
                                             run_config['base_lr'], 
                                             run_config['warmup_steps']*len(nsd_train_dataloader), 
                                             run_config['max_steps']*len(nsd_train_dataloader), 
                                             run_config['eta_min'])
    initial_lr = run_config['warmup_start_lr']
else:
    raise ValueError(f"Unknown lr_scheduling_spec: {lr_spec!r}")

# The training loop only writes the scheduler's LR into the param groups
# *after* each optimizer step. Without an explicit `lr`, the first step would
# run at the optimizer's built-in default (1e-3 for Adam). An explicit 'lr'
# inside optimizer_params still takes precedence.
optimizer_kwargs = {'lr': initial_lr, **run_config['optimizer_params']}
optimizer_classes = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam}
if run_config['optimizer_spec'] not in optimizer_classes:
    raise ValueError(f"Unknown optimizer_spec: {run_config['optimizer_spec']!r}")
optimizer = optimizer_classes[run_config['optimizer_spec']](student_model.parameters(), **optimizer_kwargs)

loss_fn_prediction = nn.MSELoss()
loss_fn_kd = nn.KLDivLoss(reduction='batchmean', log_target=True)

In [None]:
alpha_kd = run_config['alpha_kd']

# The training/eval loops read the student's layer4 output from
# activations['student_feats'], which this hook populates on every forward pass.
student_hook = student_model.layer4.register_forward_hook(get_activations('student_feats'))
teacher_hook = teacher_model.layer4.register_forward_hook(get_activations('teacher_feats'))

try:
    run_hist = run(student_model, teacher_model, nsd_train_dataloader, nsd_val_dataloader, 
                   cifar10_val_dataloader, loss_fn_prediction, loss_fn_kd, alpha_kd, 
                   optimizer, lr_scheduler, run_config['n_epochs'], out_path=run_config['out_path'], device=device)
finally:
    # Always detach the hooks: if training is interrupted and this cell is
    # re-run, stale hooks would otherwise accumulate on the models.
    student_hook.remove()
    teacher_hook.remove()

In [None]:
from rep_align.eval.interp_index import evaluate_interp_index

# Evaluate the student's and teacher's layer4 features on the CIFAR-10
# validation set (presumably an interpretability index — see rep_align.eval.interp_index).
interp_kwargs = dict(train_data=None, val_data=cifar10_val_data, device=device)
ii_results_student = evaluate_interp_index(student_model, 'layer4', **interp_kwargs)
ii_results_teacher = evaluate_interp_index(teacher_model, 'layer4', **interp_kwargs)