# Knowledge Distillation

Computes the knowledge distillation (KD) loss and uses it to train a student network from a pre-trained teacher network.

Reference:

* https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
* https://github.com/IntelLabs/distiller/blob/master/distiller/knowledge_distillation.py#L135

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Loss Function for KD

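Uses `reduction='batchmean'` instead of `'mean'` in `KLDivLoss()`: `'mean'` averages over every element (batch × classes), while `'batchmean'` divides the summed loss by the batch size, which matches the mathematical definition of the KL divergence. See https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html .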

In [2]:
def criterion_KD(
    outputs,
    labels,
    teacher_outputs,
    alpha: float = 0.1,
    temperature: float = 3.
):
    """Knowledge-distillation loss (Hinton et al.).

    Combines a soft-target term with the ordinary hard-label cross-entropy.
    The soft-target term is the KL divergence between the
    temperature-softened teacher and student distributions:

        alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))
        + (1 - alpha) * CE(student, labels)

    Parameters
    ----------
    outputs : Tensor
        Student logits; softmax is taken over dim 1.
    labels : Tensor
        Ground-truth targets accepted by `F.cross_entropy`.
    teacher_outputs : Tensor
        Teacher logits, same shape as `outputs`. They are detached here, so
        no gradient ever flows back into the teacher.
    alpha : float
        Weight of the soft-target (KD) term; the hard-label term is weighted
        by ``1 - alpha``.
    temperature : float
        Softmax temperature ``T``. The KD term is multiplied by ``T**2`` (as
        in Hinton et al.) so that its gradient scale stays comparable to the
        hard-label term.

    Returns
    -------
    Tensor
        Scalar loss (both terms use batch-level reductions).
    """
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        # KLDivLoss expects the input in log-space and the target as probabilities
        F.log_softmax(outputs / temperature, dim=1),
        # The teacher is a fixed target: detach it so backward() never
        # propagates into (or accumulates .grad on) the teacher's parameters.
        F.softmax(teacher_outputs.detach() / temperature, dim=1)
    )
    hard_loss = F.cross_entropy(outputs, labels)
    return soft_loss * (alpha * temperature * temperature) + hard_loss * (1. - alpha)

## Train for KD

### TODO

- [ ] log elapsed training time

In [3]:
def train_KD(
    student, teacher, criterion_kd, optimizer, dataloader,
    epoch: int = 0, cuda: bool = False, log: bool = False, log_file=None,
    **params
):
    """Train `student` for one pass over `dataloader` using KD from `teacher`.

    The (pre-trained) `teacher` is only run forward, in eval mode and without
    gradient tracking. Only the parameters held by `optimizer` are updated.

    See https://keras.io/examples/vision/knowledge_distillation/ .

    Parameters
    ----------
    student : nn.Module
        Network being trained.
    teacher : nn.Module
        Network providing the soft targets; it is not updated.
    criterion_kd : callable
        Called as ``criterion_kd(outputs, targets, teacher_outputs, **params)``
        and must return a scalar loss tensor. See `criterion_KD`.
    optimizer : torch.optim.Optimizer
        Optimizer stepping the student (presumably built over
        ``student.parameters()``).
    dataloader : torch.utils.data.DataLoader
        Yields ``(inputs, targets)`` batches.
    epoch : int
        Current epoch; used only for logging.
    cuda : bool
        If True, each batch is moved to the GPU with ``.cuda()``.
    log : bool
        If True and `log_file` is given, writes one CSV row
        ``partial_epoch,loss,error_percent`` per batch to `log_file`.
        Otherwise prints a progress line to STDOUT.
    log_file : text stream, optional
        Destination of the CSV rows; flushed after every row.
    **params
        Extra keyword arguments forwarded to `criterion_kd`
        (e.g. ``alpha``, ``temperature``).
    """
    student.train()  # student in training mode
    teacher.eval()  # teacher in inference mode

    # for log
    nProcessed = 0
    nTrain = len(dataloader.dataset)
    nBatches = len(dataloader)

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()

        # forward, backward, and opt
        outputs = student(inputs)
        # The teacher is frozen: skip building its autograd graph. Otherwise
        # backward() would waste memory/compute on it and keep accumulating
        # .grad on its parameters (nothing ever zeroes them).
        with torch.no_grad():
            teacher_outputs = teacher(inputs)
        loss = criterion_kd(outputs, targets, teacher_outputs, **params)
        loss.backward()
        optimizer.step()

        # for log: convert to plain Python numbers once
        batchSz = len(inputs)
        nProcessed += batchSz
        pred = outputs.detach().argmax(dim=1)  # index of the largest logit
        incorrect = pred.ne(targets).sum().item()
        err = 100. * incorrect / batchSz
        partialEpoch = epoch + batch_idx / nBatches
        lossValue = loss.item()

        if log and (log_file is not None):  # saves as csv rows
            log_file.write('{},{},{}\n'.format(partialEpoch, lossValue, err))
            log_file.flush()
        else:  # print at STDOUT
            print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tError: {:.6f}'.format(
                partialEpoch, nProcessed, nTrain, 100. * batch_idx / nBatches, lossValue, err
            ))

# main

In [4]:
if __name__ == "__main__":
    import os

    import torch
    import torch.optim as optim

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    from torch.utils.data import DataLoader  # TODO: DistributedDataParallel

    import import_ipynb  # noqa: F401 -- lets the .ipynb modules below be imported
    from ml import train, test, save
    import nets

    """Hyperparams"""
    numWorkers = 4
    # Use the GPU when present; fall back to CPU instead of crashing on .cuda().
    cuda = torch.cuda.is_available()

    base_path = './kd_test'
    os.makedirs(base_path, exist_ok=True)

    teacher_path = os.path.join(base_path, 'teacher')
    os.makedirs(teacher_path, exist_ok=True)
    teacher_trainPath = os.path.join(teacher_path, 'train.csv')
    teacher_testPath = os.path.join(teacher_path, 'test.csv')
    teacher_netPath = os.path.join(teacher_path, 'net.pth')

    student_path = os.path.join(base_path, 'student')
    os.makedirs(student_path, exist_ok=True)
    trainPath = os.path.join(student_path, 'train.csv')
    testPath = os.path.join(student_path, 'test.csv')

    epochs = 2
    teacher_epochs = 10
    batchSz = 256

    """Datasets"""
    # # gets mean and std
    # transform = transforms.Compose([transforms.ToTensor()])
    # dataset = dset.CIFAR10(root='cifar', train=True, download=True, transform=transform)
    # normMean, normStd = dist.get_norm(dataset)
    normMean = [0.49139968, 0.48215841, 0.44653091]
    normStd = [0.24703223, 0.24348513, 0.26158784]
    normTransform = transforms.Normalize(normMean, normStd)

    trainTransform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normTransform
    ])
    testTransform = transforms.Compose([
        transforms.ToTensor(),
        normTransform
    ])

    # num_workers: number of worker processes used for data loading
    # pin_memory: page-locked host memory speeds up host-to-GPU copies
    kwargs = {'num_workers': numWorkers, 'pin_memory': cuda}

    # loaders
    trainLoader = DataLoader(
        dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform),
        batch_size=batchSz, shuffle=True, **kwargs
    )
    testLoader = DataLoader(
        dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform),
        batch_size=batchSz, shuffle=False, **kwargs
    )

    """Nets"""
    num_classes = 10

    """Define `teacher`"""
    # """Transfer Learning"""
    # teacher = nets.resnext101_32x8d(pretrained=True)

    # # fixes parameters for `teacher`
    # for param in teacher.parameters():
    #     param.requires_grad = False

    # # replaces nodes (num_classes) of output layer (fc)
    # num_ftrs = teacher.fc.in_features
    # teacher.fc = nn.Linear(num_ftrs, num_classes)

    # # params to learn
    # params_to_update = []
    # for param in teacher.parameters():
    #     if param.requires_grad:
    #         params_to_update.append(param)

    teacher = nets.resnet18(num_classes=num_classes)

    teacher_criterion = nn.CrossEntropyLoss()
    teacher_optimizer = optim.SGD(teacher.parameters(), lr=1e-1, momentum=0.9)

    """Define `student`"""
    student = nets.resnet18(num_classes=num_classes)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(student.parameters(), lr=1e-1, momentum=0.9)

    if cuda:
        # if multi-gpus
        if torch.cuda.device_count() > 1:
            teacher = nn.DataParallel(teacher)
            student = nn.DataParallel(student)

        # use cuda
        teacher.cuda()
        student.cuda()

    """Train & Test `teacher`"""
    # `with` guarantees the CSV logs are closed even if training raises.
    with open(teacher_trainPath, 'w') as teacher_trainFile, \
            open(teacher_testPath, 'w') as teacher_testFile:
        for epoch in range(teacher_epochs):
            train(
                teacher, teacher_criterion, teacher_optimizer, trainLoader,
                epoch=epoch, cuda=cuda, log=True, log_file=teacher_trainFile
            )
            test(
                teacher, teacher_criterion, testLoader,
                epoch=epoch, cuda=cuda, log=True, log_file=teacher_testFile
            )
            save(epoch, teacher, teacher_optimizer, teacher_netPath)

    """Train & Test `student`"""
    with open(trainPath, 'w') as trainFile, open(testPath, 'w') as testFile:
        for epoch in range(epochs):
            train_KD(
                student, teacher, criterion_KD, optimizer, trainLoader,
                epoch=epoch, cuda=cuda, log=True, log_file=trainFile,
                alpha=0.9, temperature=4
            )
            test(
                student, criterion, testLoader,
                epoch=epoch, cuda=cuda, log=True, log_file=testFile
            )

importing Jupyter notebook from ml.ipynb
importing Jupyter notebook from nets.ipynb
Files already downloaded and verified
Files already downloaded and verified
