# Knowledge Distillation

Computes the Knowledge Distillation (KD) loss and trains a student network against a pre-trained teacher network using that loss.

Hyperparameters:

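* alpha: weight of the distillation (teacher-matching) term; the hard-label cross-entropy term is weighted by `1 - alpha`. Defaults to 0.1.
* temperature: value by which both the student and teacher logits are divided before the softmax. Defaults to 3.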

Reference:

* https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
* https://github.com/IntelLabs/distiller/blob/master/distiller/knowledge_distillation.py#L135

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Loss Function

In [2]:
def criterion_KD(outputs, labels, teacher_outputs, params: "dict | None" = None):
    """Return the knowledge-distillation loss for one batch.

    The loss is the weighted sum of two terms:

    * ``nn.KLDivLoss`` between the temperature-softened student
      log-probabilities and the temperature-softened teacher probabilities,
      scaled by ``alpha * temperature ** 2``;
    * cross-entropy between the raw student ``outputs`` and ``labels``,
      scaled by ``1 - alpha``.

    Parameters
    ----------
    outputs: tensor
        Student logits; the softmax is taken over ``dim=1``.
    labels: tensor
        Targets handed to ``F.cross_entropy``.
    teacher_outputs: tensor
        Teacher logits; the softmax is taken over ``dim=1``.
    params: dict, optional
        May contain 'alpha' (default 0.1) and 'temperature' (default 3).
        Any other keys are ignored. The dict is only read, never modified.

    Returns
    -------
    tensor
        Scalar loss.
    """
    # A `{}` default would be one dict shared by every call; use None instead.
    if params is None:
        params = {}

    # Read with defaults instead of writing the defaults back into the
    # caller's dict.
    alpha = params.get('alpha', 0.1)
    temperature = params.get('temperature', 3)

    # NOTE: KLDivLoss is used with its default reduction, which averages over
    # every element (batch and class dimensions). PyTorch recommends
    # reduction='batchmean' for the mathematical KL value; the default is kept
    # here so the loss scale (and therefore tuned `alpha` values) is unchanged.
    soft_loss = nn.KLDivLoss()(
        F.log_softmax(outputs / temperature, dim=1),
        F.softmax(teacher_outputs / temperature, dim=1)
    )
    hard_loss = F.cross_entropy(outputs, labels)

    return soft_loss * (alpha * temperature * temperature) + \
        hard_loss * (1. - alpha)

## Train

In [3]:
def train_KD(student, teacher, criterion_kd, optimizer, dataloader, params: "dict | None" = None):
    """Train `student` for one pass over `dataloader` using KD.

    The (pre-trained) `teacher` network is only evaluated; its outputs are
    computed without gradient tracking and handed to `criterion_kd` together
    with the student's outputs and the targets.

    See https://keras.io/examples/vision/knowledge_distillation/ .

    Parameters
    ----------
    student: net
        Network being trained; put into train mode.
    teacher: net
        Network providing the outputs the student is distilled from; put into
        eval mode.
    criterion_kd: function
        Loss function, called as
        ``criterion_kd(outputs, targets, teacher_outputs, params=params)``.
        See `criterion_KD`.
    optimizer: optimizer
        Optimizer stepping the student's parameters.
    dataloader: data loader
        Yields ``(inputs, targets)`` batches; must expose ``.dataset``.
    params: dict, optional
        May contain 'epoch' (default 0), 'cuda' (default False),
        'log' (default False) and 'logFile'.
        May also contain 'alpha' and 'temperature', which are forwarded to
        `criterion_kd`.
        Missing keys are filled in on the dict that was passed. In particular,
        when 'log' is true and no 'logFile' is given, 'test.csv' is opened for
        writing and stored under 'logFile'; the caller is responsible for
        closing it.
    """
    # A `{}` default would be a single dict shared (and mutated) by every call.
    if params is None:
        params = {}

    params.setdefault('epoch', 0)
    params.setdefault('cuda', False)
    params.setdefault('log', False)
    if 'logFile' not in params:
        # Only open a file when logging was actually requested.
        params['logFile'] = open('test.csv', 'w') if params['log'] else None

    epoch, cuda, log, logFile = params['epoch'], params['cuda'], params['log'], params['logFile']

    student.train()  # tells student to do training
    teacher.eval()  # tells teacher to eval

    # for log
    nProcessed = 0
    nTrain = len(dataloader.dataset)
    nBatches = len(dataloader)

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # sets gradient to 0
        optimizer.zero_grad()

        # forward: the teacher is frozen, so run it without autograd to avoid
        # building (and back-propagating through) its graph
        outputs = student(inputs)
        with torch.no_grad():
            teacher_outputs = teacher(inputs)

        # loss, backward, and opt
        loss = criterion_kd(outputs, targets, teacher_outputs, params=params)
        loss.backward()
        optimizer.step()

        # for log
        nProcessed += len(inputs)
        pred = outputs.detach().max(1)[1]  # index of the largest output per sample
        incorrect = pred.ne(targets).sum().item()  # ne: not equal
        err = 100. * incorrect / len(inputs)  # plain float, formats as a number
        partialEpoch = epoch + batch_idx / nBatches

        if log and (logFile is not None):  # saves at csv file
            logFile.write('{},{},{}\n'.format(partialEpoch, loss.item(), err))
            logFile.flush()

        else:  # print at STDOUT
            print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tError: {:.6f}'.format(
                partialEpoch, nProcessed, nTrain, 100. * batch_idx / nBatches, loss.item(), err
            ), end='\r')

# main

In [4]:
if __name__ == "__main__":
    # Script entry point: distils an ImageNet-pretrained ResNet-152 (teacher)
    # into a ResNet-18 (student) on CIFAR-10, logging train/test metrics to
    # CSV files under `base_path`.
    import os  # used for makedirs / path.join below; not imported at file top

    import torch.optim as optim

    import torchvision.models as models
    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    from torch.utils.data import DataLoader  # TODO: DistributedDataParallel

    import import_ipynb
    from machineLearning import test

    # Hyperparams
    numWorkers = 4
    cuda = True

    base_path = './kd'
    os.makedirs(base_path, exist_ok=True)

    epochs = 2
    batchSz = 256

    # Datasets
    # Per-channel mean and std used for normalisation (hard-coded).
    normMean = [0.49139968, 0.48215841, 0.44653091]
    normStd = [0.24703223, 0.24348513, 0.26158784]
    normTransform = transforms.Normalize(normMean, normStd)

    trainTransform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normTransform
    ])
    testTransform = transforms.Compose([
        transforms.ToTensor(),
        normTransform
    ])

    # num_workers: number of worker processes used for data loading
    # pin_memory: speeds up host-to-device transfer when enabled
    kwargs = {'num_workers': numWorkers, 'pin_memory': cuda}

    # loaders
    trainLoader = DataLoader(
        dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform),
        batch_size=batchSz, shuffle=True, **kwargs
    )
    testLoader = DataLoader(
        dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform),
        batch_size=batchSz, shuffle=False, **kwargs
    )

    # Net
    teacher = models.resnet152(pretrained=True)
    student = models.resnet18()

    if cuda:
        teacher = nn.DataParallel(teacher)  # multi-GPUs
        teacher.cuda()
        student = nn.DataParallel(student)  # multi-GPUs
        student.cuda()

    criterion = nn.CrossEntropyLoss()  # student's loss function
    optimizer = optim.SGD(student.parameters(), lr=1e-1, momentum=0.9)  # student's

    # Train & Test
    # The log files are opened in a `with` block so they are closed (and
    # flushed) even if training raises.
    with open(os.path.join(base_path, 'train.csv'), 'w') as trainFile, \
            open(os.path.join(base_path, 'test.csv'), 'w') as testFile:
        for epoch in range(epochs):
            train_KD(student, teacher, criterion_KD, optimizer, trainLoader, params={
                'epoch': epoch,
                'cuda': cuda,
                'log': True,
                'logFile': trainFile,
                'alpha': 0.5,
                'temperature': 5
            })
            # NOTE(review): 'epoch' is always 0 here while train_KD receives
            # `epoch` -- confirm against `machineLearning.test` whether this
            # is intentional.
            test(student, criterion, testLoader, params={
                'epoch': 0,
                'cuda': cuda,
                'log': True,
                'logFile': testFile
            })

importing Jupyter notebook from machineLearning.ipynb
Files already downloaded and verified
Files already downloaded and verified
