In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

import os
import shutil
import argparse
import numpy as np


import models
import torchvision
import torchvision.transforms as transforms
from utils import cal_param_size, cal_multi_adds


from bisect import bisect_right
import time
import math

In [25]:
# ---- Paths ----
data_path = './data'                                   # dataset root handed to torchvision
checkpoint_dir = './checkpoint'                        # where student checkpoints are written
tcheckpoint = './checkpoint/wrn_40_2_aux.pth.tar'      # pre-trained teacher weights

# ---- Data and architectures ----
dataset = 'cifar100'
num_classes = 100
arch = 'wrn_16_2_aux'      # student: attribute looked up on the `models` package
tarch = 'wrn_40_2_aux'     # teacher: attribute looked up on the `models` package

# ---- Optimisation schedule ----
epochs = 240
batch_size = 64
init_lr = 0.05
weight_decay = 5e-4
lr_type = 'multistep'
milestones = [150, 180, 210]   # epochs after which the learning rate is multiplied by 0.1
sgdr_t = 300
warmup_epoch = 0               # number of linear warm-up epochs (0 disables warm-up)
kd_T = 3                       # temperature passed to the distillation loss

# ---- Run control ----
resume = False
evaluate = False
manual_seed = 0

# ---- Hardware ----
num_workers = 8                # intended worker count for data loading
gpu_id = '3'
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

In [26]:
# Seed NumPy, the PyTorch CPU generator and every CUDA generator with the same value.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(manual_seed)

# Show four decimal places when tensors are printed.
torch.set_printoptions(precision=4)

In [27]:
# Run identifier shared by the log file name and `log_dir`. It is built once so
# the two can never drift apart:
#   <project>_tarch_<teacher>_arch_<student>_dataset_<dataset>_seed<seed>
_run_name = '_'.join([
    str(os.path.basename('kushagrabhushan/TrAIL/WideResNet').split('.')[0]),
    'tarch', tarch,
    'arch', arch,
    'dataset', dataset,
    'seed' + str(manual_seed),
])

log_txt = 'result/' + _run_name + '.txt'
log_dir = _run_name

# train()/test() append to `log_txt` with open(..., 'a+'), which fails if the
# parent directory does not exist yet, so create it up front.
os.makedirs(os.path.dirname(log_txt), exist_ok=True)

In [28]:
num_classes = 100

# Per-channel mean / std applied to both the training and the test split.
_normalize = transforms.Normalize([0.5071, 0.4867, 0.4408],
                                  [0.2675, 0.2565, 0.2761])

# Training split: random 32x32 crop (after 4-pixel padding) and random horizontal flip.
trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True,
                                        transform=transforms.Compose([
                                            transforms.RandomCrop(32, padding=4),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            _normalize,
                                        ]))

# Test split: tensor conversion and normalisation only.
testset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            _normalize,
                                        ]))

# Options common to both loaders. `num_workers` comes from the configuration
# cell; previously it was defined there but never handed to the loaders, so all
# batches were produced in the main process.
_loader_kwargs = dict(batch_size=batch_size,
                      num_workers=num_workers,
                      pin_memory=torch.cuda.is_available())

trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)


Files already downloaded and verified
Files already downloaded and verified


In [29]:
# Load the teacher checkpoint onto the CPU first; the weights are copied into
# the CUDA teacher model by load_state_dict() below.
print('load pre-trained teacher weights from: {}'.format(tcheckpoint))     
checkpoint = torch.load(tcheckpoint, map_location=torch.device('cpu'))

# Student network: constructor looked up by name on the `models` package.
# NOTE(review): the student is built before the teacher; construction
# presumably draws from the seeded RNG, so keep this order to keep the same
# initial student weights.
model = getattr(models, arch)
net = model(num_classes=num_classes).cuda()
net =  torch.nn.DataParallel(net)

# Teacher network: weights are loaded BEFORE wrapping in DataParallel, i.e. the
# state dict is applied to the bare module. It is put in eval mode here.
tmodel = getattr(models, tarch)
tnet = tmodel(num_classes=num_classes).cuda()
tnet.load_state_dict(checkpoint['net'])
tnet.eval()
tnet =  torch.nn.DataParallel(tnet)

# Probe forward pass on a dummy batch of two 3x32x32 inputs. The second return
# value is a sequence with one entry per auxiliary branch; its length is reused
# by train()/test() to size their per-branch accumulators.
_, ss_logits = net(torch.randn(2, 3, 32, 32))
num_auxiliary_branches = len(ss_logits)
cudnn.benchmark = True

load pre-trained teacher weights from: ./checkpoint/wrn_40_2_aux.pth.tar


In [30]:
class DistillKL(nn.Module):
    """Temperature-scaled KL-divergence loss between student and teacher logits.

    Both logit tensors are divided by the temperature ``T`` before the softmax
    over dimension 1. The batch-mean KL divergence is then multiplied by
    ``T ** 2``.
    """

    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, y_s, y_t):
        """Return the scaled KL divergence of student logits ``y_s`` from teacher logits ``y_t``."""
        temperature = self.T
        student_log_probs = F.log_softmax(y_s / temperature, dim=1)
        teacher_probs = F.softmax(y_t / temperature, dim=1)
        divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        return divergence * (temperature ** 2)

In [31]:
def correct_num(output, target, topk=(1,)):
    """Count how many targets fall inside the top-k predictions.

    Note that this returns raw *counts*, not a precision/accuracy ratio; the
    caller divides by the number of samples.

    Args:
        output: 2-D score tensor; the top-k is taken along dimension 1.
        target: tensor of class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``, holding the
        number of rows whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)

    # Indices of the maxk highest scores per row, best first.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    # correct[r, j] is True when the j-th ranked prediction of row r equals its target.
    correct = pred.eq(target.view(-1, 1).expand_as(pred))

    return [correct[:, :k].float().sum() for k in topk]


In [32]:
def adjust_lr(optimizer, epoch, args, step=0, all_iters_per_epoch=0):
    """Set the learning rate of every parameter group and return it.

    During warm-up (``epoch < args['warmup_epoch']``) the rate rises linearly
    with the global iteration count until it reaches ``args['init_lr']``.
    Afterwards it follows a multi-step schedule: ``init_lr`` is multiplied by
    0.1 for every milestone that the warm-up-adjusted epoch has reached.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with ``'lr'``).
        epoch: current epoch index (0-based).
        args: dict with keys ``'warmup_epoch'``, ``'init_lr'`` and ``'milestones'``.
        step: iteration index inside the current epoch (used during warm-up only).
        all_iters_per_epoch: iterations per epoch (used during warm-up only;
            must be positive there, otherwise the division raises
            ``ZeroDivisionError``).

    Returns:
        The learning rate that was written into the optimizer.
    """
    warmup_epochs = args['warmup_epoch']
    if epoch < warmup_epochs:
        # Read the warm-up length from `args` rather than from a module-level
        # global, so the ramp always matches the condition tested above.
        done_iters = float(1 + step + epoch * all_iters_per_epoch)
        cur_lr = args['init_lr'] * done_iters / (warmup_epochs * all_iters_per_epoch)
    else:
        # Milestones are expressed relative to the end of warm-up.
        epoch = epoch - warmup_epochs
        cur_lr = args['init_lr'] * 0.1 ** bisect_right(args['milestones'], epoch)

    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr

In [33]:
def train(epoch, criterion_list, optimizer):
    """Run one training epoch of the student `net` with distillation from `tnet`.

    Relies on module-level state: `net`, `tnet`, `trainloader`, `args`,
    `warmup_epoch`, `num_auxiliary_branches`, `num_classes` and `log_txt`.

    Args:
        epoch: current epoch index; drives the learning-rate schedule.
        criterion_list: indexable container; element 0 is the classification
            loss, element 1 the distillation (divergence) loss.
        optimizer: optimizer whose parameter groups are updated in place.

    Side effects:
        Updates the parameters of `net`, prints one progress line per batch,
        and appends an epoch summary to the file at `log_txt`.
    """
    # Running averages of the losses over the epoch (each batch contributes
    # loss / number_of_batches).
    train_loss = 0.
    train_loss_cls = 0.
    train_loss_div = 0.

    # Per-branch counters of correct predictions.
    #   ss_*    : against the joint (class, rotation) label of each rotated copy
    #   class_* : against the class label after summing over the 4 rotation slots
    ss_top1_num = [0] * num_auxiliary_branches
    ss_top5_num = [0] * num_auxiliary_branches
    class_top1_num = [0] * num_auxiliary_branches
    class_top5_num = [0] * num_auxiliary_branches
    # Counters for the main classifier output on the unrotated copy.
    top1_num = 0
    top5_num = 0
    total = 0

    # Outside warm-up the learning rate is set once per epoch; during warm-up it
    # is set per iteration inside the loop below.
    if epoch >= warmup_epoch:
        lr = adjust_lr(optimizer, epoch, args)

    start_time = time.time()
    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]

    net.train()
    for batch_idx, (input, target) in enumerate(trainloader):
        batch_start_time = time.time()
        input = input.float().cuda()
        target = target.cuda()

        # Build four copies of every sample, rotated by k quarter turns
        # (k = 0..3) in the last two dimensions. Stacking on dim 1 and
        # flattening keeps the four copies of a sample adjacent, so the batch
        # becomes 4x larger and row 4*j + k is rotation k of sample j.
        size = input.shape[1:]
        input = torch.stack([torch.rot90(input, k, (2, 3)) for k in range(4)], 1).view(-1, *size)
        # Joint label for each rotated copy: class_index * 4 + rotation_index.
        labels = torch.stack([target*4+i for i in range(4)], 1).view(-1)

        if epoch < warmup_epoch:
            lr = adjust_lr(optimizer, epoch, args, batch_idx, len(trainloader))

        optimizer.zero_grad()
        # NOTE(review): the meaning of `grad=True` is defined by the model's
        # forward(), which is not visible here — confirm against `models`.
        logits, ss_logits = net(input, grad=True)
        # The teacher only provides targets; no gradients are tracked for it.
        with torch.no_grad():
            t_logits, t_ss_logits = tnet(input)

        loss_cls = torch.tensor(0.).cuda()
        loss_div = torch.tensor(0.).cuda()

        # Classification loss on every 4th row only, i.e. the unrotated (k = 0)
        # copy of each sample, against the original class labels.
        loss_cls = loss_cls + criterion_cls(logits[0::4], target)
        # Distillation on each auxiliary branch output (all rotated copies) ...
        for i in range(len(ss_logits)):
            loss_div = loss_div + criterion_div(ss_logits[i], t_ss_logits[i].detach())

        # ... and on the main classifier output (all rotated copies).
        loss_div = loss_div + criterion_div(logits, t_logits.detach())


        loss = loss_cls + loss_div
        loss.backward()
        optimizer.step()


        train_loss += loss.item() / len(trainloader)
        train_loss_cls += loss_cls.item() / len(trainloader)
        train_loss_div += loss_div.item() / len(trainloader)

        # Branch accuracy on the joint (class, rotation) labels.
        for i in range(len(ss_logits)):
            top1, top5 = correct_num(ss_logits[i], labels, topk=(1, 5))
            ss_top1_num[i] += top1
            ss_top5_num[i] += top5

        # Collapse each branch output to per-class scores: split dim 1 into
        # consecutive groups of 4 and sum within each group.
        class_logits = [torch.stack(torch.split(ss_logits[i], split_size_or_sections=4, dim=1), dim=1).sum(dim=2) for i in range(len(ss_logits))]
        # Class label repeated 4 times, once per rotated copy, in batch order.
        multi_target = target.view(-1, 1).repeat(1, 4).view(-1)
        for i in range(len(class_logits)):
            top1, top5 = correct_num(class_logits[i], multi_target, topk=(1, 5))
            class_top1_num[i] += top1
            class_top5_num[i] += top5

        # Keep only the unrotated copy of each sample for the main-head accuracy.
        logits = logits.view(-1, 4, num_classes)[:, 0, :]
        top1, top5 = correct_num(logits, target, topk=(1, 5))
        top1_num += top1
        top5_num += top5
        total += target.size(0)

        print('Epoch:{}, batch_idx:{}/{}, lr:{:.5f}, Duration:{:.2f}, Top-1 Acc:{:.4f}'.format(
            epoch, batch_idx, len(trainloader), lr, time.time()-batch_start_time, (top1_num/(total)).item()))


    # Branch counters cover 4 rotated copies per sample, hence the total*4
    # denominator; the main-head accuracy is appended as the last list entry.
    # The top-5 lists are computed but not reported below.
    ss_acc1 = [round((ss_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)]
    ss_acc5 = [round((ss_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)]
    class_acc1 = [round((class_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top1_num/(total)).item(), 4)]
    class_acc5 = [round((class_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top5_num/(total)).item(), 4)]

    print('Train epoch:{}\nTrain Top-1 ss_accuracy: {}\nTrain Top-1 class_accuracy: {}\n'.format(epoch, str(ss_acc1), str(class_acc1)))

    # Append the epoch summary to the run log.
    with open(log_txt, 'a+') as f:
        f.write('Epoch:{}\t lr:{:.5f}\t duration:{:.3f}'
                '\n train_loss:{:.5f}\t train_loss_cls:{:.5f}\t train_loss_div:{:.5f}'
                '\nTrain Top-1 ss_accuracy: {}\nTrain Top-1 class_accuracy: {}\n'
                .format(epoch, lr, time.time() - start_time,
                        train_loss, train_loss_cls, train_loss_div,
                        str(ss_acc1), str(class_acc1)))


In [34]:
def test(epoch, criterion_cls, net):
    """Evaluate ``net`` on ``testloader`` and append the results to ``log_txt``.

    Relies on module-level state: ``testloader``, ``num_auxiliary_branches``,
    ``num_classes``, ``log_txt`` and ``correct_num``. ``net`` must return
    ``(logits, ss_logits)`` where ``ss_logits`` has exactly
    ``num_auxiliary_branches`` entries.

    Args:
        epoch: epoch index, used only for the progress and log messages.
        criterion_cls: classification loss applied to the unrotated samples.
        net: network to evaluate; it is switched to eval mode.

    Returns:
        Top-1 accuracy of the main classifier on the unrotated test images,
        rounded to four decimals.
    """
    test_loss_cls = 0.

    # Per-branch counters of correct predictions.
    #   ss_*    : against the joint (class, rotation) label of each rotated copy
    #   class_* : against the class label after summing over the 4 rotation slots
    ss_top1_num = [0] * num_auxiliary_branches
    ss_top5_num = [0] * num_auxiliary_branches
    class_top1_num = [0] * num_auxiliary_branches
    class_top5_num = [0] * num_auxiliary_branches
    # Counters for the main classifier output on the unrotated copy.
    top1_num = 0
    top5_num = 0
    total = 0

    net.eval()
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(testloader):
            batch_start_time = time.time()
            images, target = inputs.cuda(), target.cuda()

            # Four copies per sample, rotated by k quarter turns (k = 0..3);
            # row 4*j + k of the flattened batch is rotation k of sample j.
            size = images.shape[1:]
            images = torch.stack([torch.rot90(images, k, (2, 3)) for k in range(4)], 1).view(-1, *size)
            # Joint label for each rotated copy: class_index * 4 + rotation_index.
            labels = torch.stack([target*4+i for i in range(4)], 1).view(-1)

            logits, ss_logits = net(images)
            # Loss on every 4th row only, i.e. the unrotated copy of each sample.
            loss_cls = criterion_cls(logits[0::4], target)
            test_loss_cls += loss_cls.item() / len(testloader)

            # Branch accuracy on the joint (class, rotation) labels.
            for i, branch_logits in enumerate(ss_logits):
                top1, top5 = correct_num(branch_logits, labels, topk=(1, 5))
                ss_top1_num[i] += top1
                ss_top5_num[i] += top5

            # Collapse each branch output to per-class scores: split dim 1 into
            # consecutive groups of 4 and sum within each group.
            class_logits = [torch.stack(torch.split(branch_logits, split_size_or_sections=4, dim=1), dim=1).sum(dim=2)
                            for branch_logits in ss_logits]
            # Class label repeated 4 times, once per rotated copy, in batch order.
            multi_target = target.view(-1, 1).repeat(1, 4).view(-1)
            for i, branch_class_logits in enumerate(class_logits):
                top1, top5 = correct_num(branch_class_logits, multi_target, topk=(1, 5))
                class_top1_num[i] += top1
                class_top5_num[i] += top5

            # Keep only the unrotated copy of each sample for the main-head accuracy.
            logits = logits.view(-1, 4, num_classes)[:, 0, :]
            top1, top5 = correct_num(logits, target, topk=(1, 5))
            top1_num += top1
            top5_num += top5
            total += target.size(0)

            print('Epoch:{}, batch_idx:{}/{}, Duration:{:.2f}, Top-1 Acc:{:.4f}'.format(
                epoch, batch_idx, len(testloader), time.time()-batch_start_time, (top1_num/(total)).item()))

        # Branch counters cover 4 rotated copies per sample, hence total*4. All
        # four lists are sized from `num_auxiliary_branches`, the same source
        # as the accumulators above. The top-5 lists are kept for parity with
        # train() but are not reported.
        ss_acc1 = [round((ss_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)]
        ss_acc5 = [round((ss_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)]
        class_acc1 = [round((class_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top1_num/(total)).item(), 4)]
        class_acc5 = [round((class_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top5_num/(total)).item(), 4)]
        with open(log_txt, 'a+') as f:
            f.write('test epoch:{}\t test_loss_cls:{:.5f}\nTop-1 ss_accuracy: {}\nTop-1 class_accuracy: {}\n'
                    .format(epoch, test_loss_cls, str(ss_acc1), str(class_acc1)))
        print('test epoch:{}\nTest Top-1 ss_accuracy: {}\nTest Top-1 class_accuracy: {}\n'.format(epoch, str(ss_acc1), str(class_acc1)))

    # Last entry is the main classifier's top-1 accuracy.
    return class_acc1[-1]

In [40]:
# ---- Driver cell: evaluate, or (resume and) train the student, then test the best model ----
best_acc = 0.
start_epoch = 0
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(kd_T)
args = {'warmup_epoch': warmup_epoch, 'init_lr': init_lr, 'milestones': milestones}
# NOTE: these two flags are assigned here and therefore override whatever the
# configuration cell set for `evaluate` / `resume`.
evaluate = False
resume = True

# Checkpoint locations, computed once. The directory is created up front so
# that torch.save() below cannot fail on a fresh checkout.
os.makedirs(checkpoint_dir, exist_ok=True)
ckpt_path = os.path.join(checkpoint_dir, str(model.__name__) + '.pth.tar')
best_ckpt_path = os.path.join(checkpoint_dir, str(model.__name__) + '_best.pth.tar')

if evaluate:
    print('load pre-trained weights from: {}'.format(ckpt_path))
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    net.module.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    test(start_epoch, criterion_cls, net)
else:
    print('Evaluate Teacher:')
    acc = test(0, criterion_cls, tnet)
    print('Teacher Acc:', acc)

    trainable_list = nn.ModuleList([])
    trainable_list.append(net)
    # The initial lr is taken from the config; adjust_lr() rewrites it before
    # the first optimizer step in any case.
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=init_lr, momentum=0.9, weight_decay=weight_decay, nesterov=True)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.cuda()

    if resume:
        print('load pre-trained weights from: {}'.format(ckpt_path))
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        net.module.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # checkpoint['acc'] is the accuracy of the LAST saved epoch, not the
        # best one. Restoring best_acc from it would let a worse model
        # overwrite the *_best checkpoint after a resume.
        if 'best_acc' in checkpoint:
            best_acc = checkpoint['best_acc']
        elif os.path.isfile(best_ckpt_path):
            # Older checkpoints have no 'best_acc'; recover it from the best file.
            best_acc = max(checkpoint['acc'],
                           torch.load(best_ckpt_path, map_location=torch.device('cpu'))['acc'])
        else:
            best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1

    for epoch in range(start_epoch, epochs):
        train(epoch, criterion_list, optimizer)
        acc = test(epoch, criterion_cls, net)

        # Decide before saving so the running best is stored in the checkpoint.
        is_best = best_acc < acc
        if is_best:
            best_acc = acc

        state = {
            'net': net.module.state_dict(),
            'acc': acc,
            'best_acc': best_acc,
            'epoch': epoch,
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, ckpt_path)

        if is_best:
            shutil.copyfile(ckpt_path, best_ckpt_path)

    print('Evaluate the best model:')
    print('load pre-trained weights from: {}'.format(best_ckpt_path))
    evaluate = True
    checkpoint = torch.load(best_ckpt_path, map_location=torch.device('cpu'))
    net.module.load_state_dict(checkpoint['net'])
    start_epoch = checkpoint['epoch']
    top1_acc = test(start_epoch, criterion_cls, net)

    with open(log_txt, 'a+') as f:
        f.write('best_accuracy: {} \n'.format(best_acc))
    print('best_accuracy: {} \n'.format(best_acc))
    # Copy the log only after the file is closed, so the line written above is
    # flushed to disk. shutil.copy replaces the shell `cp`.
    shutil.copy(log_txt, checkpoint_dir)

Evaluate Teacher:
Epoch:0, batch_idx:0/157, Duration:0.05, Top-1 Acc:0.7812
Epoch:0, batch_idx:1/157, Duration:0.04, Top-1 Acc:0.8047
Epoch:0, batch_idx:2/157, Duration:0.04, Top-1 Acc:0.7865
Epoch:0, batch_idx:3/157, Duration:0.04, Top-1 Acc:0.8008
Epoch:0, batch_idx:4/157, Duration:0.04, Top-1 Acc:0.8000
Epoch:0, batch_idx:5/157, Duration:0.04, Top-1 Acc:0.7995
Epoch:0, batch_idx:6/157, Duration:0.04, Top-1 Acc:0.7924
Epoch:0, batch_idx:7/157, Duration:0.04, Top-1 Acc:0.8027
Epoch:0, batch_idx:8/157, Duration:0.04, Top-1 Acc:0.7830
Epoch:0, batch_idx:9/157, Duration:0.04, Top-1 Acc:0.7891
Epoch:0, batch_idx:10/157, Duration:0.04, Top-1 Acc:0.7926
Epoch:0, batch_idx:11/157, Duration:0.04, Top-1 Acc:0.7956
Epoch:0, batch_idx:12/157, Duration:0.04, Top-1 Acc:0.7993
Epoch:0, batch_idx:13/157, Duration:0.04, Top-1 Acc:0.7991
Epoch:0, batch_idx:14/157, Duration:0.04, Top-1 Acc:0.7948
Epoch:0, batch_idx:15/157, Duration:0.04, Top-1 Acc:0.7900
Epoch:0, batch_idx:16/157, Duration:0.04, Top-1 

Epoch:0, batch_idx:139/157, Duration:0.04, Top-1 Acc:0.7932
Epoch:0, batch_idx:140/157, Duration:0.04, Top-1 Acc:0.7933
Epoch:0, batch_idx:141/157, Duration:0.04, Top-1 Acc:0.7928
Epoch:0, batch_idx:142/157, Duration:0.04, Top-1 Acc:0.7922
Epoch:0, batch_idx:143/157, Duration:0.04, Top-1 Acc:0.7926
Epoch:0, batch_idx:144/157, Duration:0.04, Top-1 Acc:0.7926
Epoch:0, batch_idx:145/157, Duration:0.04, Top-1 Acc:0.7925
Epoch:0, batch_idx:146/157, Duration:0.04, Top-1 Acc:0.7927
Epoch:0, batch_idx:147/157, Duration:0.04, Top-1 Acc:0.7927
Epoch:0, batch_idx:148/157, Duration:0.04, Top-1 Acc:0.7936
Epoch:0, batch_idx:149/157, Duration:0.04, Top-1 Acc:0.7934
Epoch:0, batch_idx:150/157, Duration:0.04, Top-1 Acc:0.7936
Epoch:0, batch_idx:151/157, Duration:0.04, Top-1 Acc:0.7939
Epoch:0, batch_idx:152/157, Duration:0.04, Top-1 Acc:0.7935
Epoch:0, batch_idx:153/157, Duration:0.04, Top-1 Acc:0.7938
Epoch:0, batch_idx:154/157, Duration:0.04, Top-1 Acc:0.7941
Epoch:0, batch_idx:155/157, Duration:0.0

Epoch:234, batch_idx:114/157, Duration:0.02, Top-1 Acc:0.7793
Epoch:234, batch_idx:115/157, Duration:0.02, Top-1 Acc:0.7792
Epoch:234, batch_idx:116/157, Duration:0.02, Top-1 Acc:0.7786
Epoch:234, batch_idx:117/157, Duration:0.02, Top-1 Acc:0.7790
Epoch:234, batch_idx:118/157, Duration:0.02, Top-1 Acc:0.7789
Epoch:234, batch_idx:119/157, Duration:0.02, Top-1 Acc:0.7786
Epoch:234, batch_idx:120/157, Duration:0.02, Top-1 Acc:0.7782
Epoch:234, batch_idx:121/157, Duration:0.02, Top-1 Acc:0.7778
Epoch:234, batch_idx:122/157, Duration:0.02, Top-1 Acc:0.7772
Epoch:234, batch_idx:123/157, Duration:0.02, Top-1 Acc:0.7772
Epoch:234, batch_idx:124/157, Duration:0.02, Top-1 Acc:0.7771
Epoch:234, batch_idx:125/157, Duration:0.02, Top-1 Acc:0.7770
Epoch:234, batch_idx:126/157, Duration:0.02, Top-1 Acc:0.7773
Epoch:234, batch_idx:127/157, Duration:0.02, Top-1 Acc:0.7773
Epoch:234, batch_idx:128/157, Duration:0.02, Top-1 Acc:0.7773
Epoch:234, batch_idx:129/157, Duration:0.02, Top-1 Acc:0.7774
Epoch:23