# Fun with distillation
### Let's see how much we can learn by distilling

The datasets we will use are CIFAR-10, CIFAR-100 and WikiText-2.

In [1]:
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision import transforms
from torch.utils import data
import torchvision.models as models
from utils import progress_bar
import torch
from torch import nn
import torch.optim as optim
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning.metrics.functional import accuracy

In [2]:
class ImageClassificationModel(pl.LightningModule):
    """Image classifier: an encoder followed by a new linear classification head.

    Args:
        encoder: ``nn.Module`` whose last direct child exposes ``out_features``
            (e.g. a torchvision ResNet, whose last child is its ``fc`` layer).
            Its output is fed into the new ``nn.Linear`` head.
        classes: Number of output classes of the new head.
        lr: Learning rate for Adam. Defaults to 0.02 (the value previously
            hard-coded); note this is large for Adam, especially when
            fine-tuning pretrained weights, so consider e.g. 1e-3.
    """

    def __init__(self, encoder, classes=10, lr=0.02):
        super().__init__()
        # Size the head from the encoder's final layer output.
        out_features = list(encoder.children())[-1].out_features
        self.lr = lr
        self.net = nn.Sequential(encoder, nn.Linear(out_features, classes))

    def forward(self, x):
        """Return unnormalized class logits for a batch of images ``x``."""
        # in lightning, forward defines the prediction/inference actions
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self.net(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and top-1 accuracy for one batch."""
        x, y = batch
        logits = self.net(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        # Top-1 accuracy computed directly, avoiding the deprecated
        # `pytorch_lightning.metrics` module (moved to torchmetrics).
        acc = (preds == y).float().mean()

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters (encoder and head) with Adam at ``self.lr``."""
        return optim.Adam(self.net.parameters(), lr=self.lr)


In [3]:
def get_dataloaders(root='./data', batch_size=32, test_batch_size=100, num_workers=1):
    """Build CIFAR-10 train and test DataLoaders.

    Training images get random-crop (4px padding) and horizontal-flip
    augmentation; test images are only converted to tensors and normalized.
    Both use the ImageNet channel mean/std (presumably because the encoder is
    ImageNet-pretrained).

    Args:
        root: Directory where CIFAR-10 is stored / downloaded to.
        batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the unshuffled test loader.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(trainloader, testloader)`` tuple.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
    # Must be the held-out split (train=False); with train=True, validation
    # metrics would be measured on the very images the model trains on.
    testset = CIFAR10(root=root, train=False, download=True, transform=transform_test)

    trainloader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader

In [4]:
SEED = 42
# Lightning 1.x trains for 1000 epochs when max_epochs is left unset.
MAX_EPOCHS = 10

# Seed python/numpy/torch so shuffling, augmentation and head init are reproducible.
pl.seed_everything(SEED)
trainloader, testloader = get_dataloaders()

# Supervised classifier (pretrained ResNet-18 encoder + linear head), not an autoencoder.
classifier = ImageClassificationModel(models.resnet18(pretrained=True))
autoencoder = classifier  # old name kept in case later cells still refer to it
trainer = pl.Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(classifier, trainloader, testloader)

Files already downloaded and verified


GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name | Type       | Params
------------------------------------
0 | net  | Sequential | 11.7 M
------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Traceback (most recent call last):
  File "/Users/chriszhu/.pyenv/versions/3.7.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Users/chriszhu/.pyenv/versions/3.7.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/chriszhu/.pyenv/versions/3.7.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/chriszhu/.pyenv/versions/3.7.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe





1

1000