# Knowledge Distillation
Created in PyTorch by [Laia Tarrés](https://www.linkedin.com/in/laia-tarres-9a5369138) for the [Postgraduate Course in Artificial Intelligence with Deep Learning](https://www.talent.upc.edu/ing/estudis/formacio/curs/310400/postgrau-artificial-intelligence-deep-learning/) ([UPC School](https://www.talent.upc.edu/ing/), 2021).

Updated by [Gerard I. Gállego](https://www.linkedin.com/in/gerard-gallego/).

*Based on other notebooks that use distillation [1](https://colab.research.google.com/github/sayakpaul/Knowledge-Distillation-in-Keras/blob/master/Distillation_with_Transfer_Learning.ipynb#scrollTo=b1jE623hh781), [2](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/knowledge_distillation.ipynb), [3](https://colab.research.google.com/drive/1-yHSQTljXyca2aSFhpM2y4n9B4M-KWso#scrollTo=3JApQdNz19bT), [4](https://koushik-nov01.medium.com/knowledge-distillation-with-pytorch-40febcf77440) for educational purposes.*

Modern state-of-the-art neural network architectures are HUGE.

Unfortunately, when it comes to the number of parameters, more is not always better. More parameters often lead to better results, but they also bring massive computational costs.

Deploying such massive models is a significant challenge for machine learning engineers: in practice, small and fast models are often much more useful than huge ones.

Because of this, researchers and engineers have put significant energy into compressing models.

Three main methods have emerged to reduce these costs by compressing models:

*   Weight pruning
*   Quantization
*   Knowledge distillation


Today we will focus on Knowledge Distillation. Knowledge Distillation is a procedure for model compression, in which a small (student) model is trained to match a large pre-trained (teacher) model.

Knowledge is transferred from the teacher model to the student by minimizing a loss function, aimed at matching the teacher outputs as well as ground-truth labels.
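Written out, the objective used later in this notebook is a weighted sum of the two terms. Here $z_s$ and $z_t$ are the student and teacher logits, $y$ is the ground-truth label, $T$ is the temperature, $\alpha$ is the weight of the hard-label term, and $\sigma$ is the softmax. This sketch mirrors the notebook's choice of MSE for the distillation term; Hinton et al. (2015) instead use a cross-entropy term on the soft targets, scaled by $T^2$.

$$
\mathcal{L} = \alpha\,\mathrm{CE}\big(y,\ \sigma(z_s)\big) + (1-\alpha)\,\mathrm{MSE}\big(\sigma(z_s/T),\ \sigma(z_t/T)\big)
$$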

**Reference:**

- [Hinton et al. (2015)](https://arxiv.org/abs/1503.02531)

# Setup

In [1]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [2]:
# Necessary imports
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
from timeit import default_timer as timer

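# For reproducibility: fix the random seed and make cuDNN deterministic
# (benchmark=True lets cuDNN pick algorithms at runtime, which can be non-deterministic)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True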

Define the hyperparameters, and remember to set the runtime accelerator to GPU.

In [4]:
hparams = {
    'batch_size': 32,
    'num_epochs': 3,
    'num_classes': 10,
    'learning_rate': 1e-4,
    'learning_rate_dist': 5e-3,
    'log_interval': 2000,
}
hparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
assert hparams['device'] == 'cuda', "Set the runtime accelerator to GPU"

# Define MNIST dataset and dataloaders

In [5]:
train_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)
# Create train and test dataloaders
train_loader = DataLoader(dataset=train_dataset, batch_size=hparams['batch_size'], shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=hparams['batch_size'], shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 57.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.69MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.9MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.35MB/s]


# Define the Teacher Model

In [6]:
class TeacherModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=64,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.fc1 = nn.Linear(256 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x
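The input size of `fc1` (`256 * 7 * 7`) follows from the spatial arithmetic: each 3x3 convolution with stride 1 and padding 1 keeps the size, and each 2x2 max-pool with stride 2 halves it, so 28 -> 14 -> 7. A minimal sketch that checks this with the standard output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$ (the helper name `out_size` is hypothetical, and dilation is assumed to be 1):

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output length along one spatial dimension for a conv or pooling layer (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

size = 28                                             # MNIST images are 28x28
size = out_size(size, kernel=3, stride=1, padding=1)  # conv1: 28 -> 28
size = out_size(size, kernel=2, stride=2)             # pool:  28 -> 14
size = out_size(size, kernel=3, stride=1, padding=1)  # conv2: 14 -> 14
size = out_size(size, kernel=2, stride=2)             # pool:  14 -> 7

flat_features = 256 * size * size                     # 256 channels after conv2
print(size, flat_features)  # 7 12544
```

The same arithmetic gives `16 * 7 * 7` for the student, which only changes the channel count.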

## Exercise 1: Declare the teacher model, and list the number of parameters.

In [8]:
# # TODO: Declare the teacher model
# teacher_model = ...
# # TODO: List its parameters, given an input of [bs, nchannels, height, width], using the summary function from torchinfo
# summary(...)

# Declare the teacher model
teacher_model = TeacherModel(in_channels=1, num_classes=hparams['num_classes'])

# List its parameters using torchinfo.summary
# Input size: (batch_size, channels, height, width)
summary(teacher_model, input_size=(1, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
TeacherModel                             [1, 10]                   --
├─Conv2d: 1-1                            [1, 64, 28, 28]           640
├─MaxPool2d: 1-2                         [1, 64, 14, 14]           --
├─Conv2d: 1-3                            [1, 256, 14, 14]          147,712
├─MaxPool2d: 1-4                         [1, 256, 7, 7]            --
├─Linear: 1-5                            [1, 10]                   125,450
Total params: 273,802
Trainable params: 273,802
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 29.58
Input size (MB): 0.00
Forward/backward pass size (MB): 0.80
Params size (MB): 1.10
Estimated Total Size (MB): 1.90

# Define the student model.

In [9]:
class StudentModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x

## Exercise 2: Declare the student model, and list the number of parameters.

In [10]:
# Declare the student model
student_model = StudentModel(in_channels=1, num_classes=10)

# Show the model summary
summary(student_model, input_size=(1, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
StudentModel                             [1, 10]                   --
├─Conv2d: 1-1                            [1, 8, 28, 28]            80
├─MaxPool2d: 1-2                         [1, 8, 14, 14]            --
├─Conv2d: 1-3                            [1, 16, 14, 14]           1,168
├─MaxPool2d: 1-4                         [1, 16, 7, 7]             --
├─Linear: 1-5                            [1, 10]                   7,850
Total params: 9,098
Trainable params: 9,098
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.30
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.04
Estimated Total Size (MB): 0.11
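The parameter counts reported by `torchinfo` can be checked by hand: a `Conv2d` layer has `out_channels * in_channels * k * k` weights plus `out_channels` biases, and a `Linear` layer has `in_features * out_features` weights plus `out_features` biases. A small sketch (the helper names are hypothetical) that reproduces both totals and the resulting compression ratio:

```python
def conv2d_params(in_ch, out_ch, k):
    # weights (out_ch x in_ch x k x k) + one bias per output channel
    return out_ch * in_ch * k * k + out_ch

def linear_params(in_f, out_f):
    # weight matrix (out_f x in_f) + one bias per output unit
    return in_f * out_f + out_f

teacher_total = (conv2d_params(1, 64, 3)
                 + conv2d_params(64, 256, 3)
                 + linear_params(256 * 7 * 7, 10))
student_total = (conv2d_params(1, 8, 3)
                 + conv2d_params(8, 16, 3)
                 + linear_params(16 * 7 * 7, 10))

print(teacher_total, student_total)             # 273802 9098
print(round(teacher_total / student_total, 1))  # 30.1
```

So the student has roughly 30 times fewer parameters than the teacher.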

Let's define two helper functions: `check_accuracy` computes the accuracy of a model over a whole data loader, and `correct_predictions` counts the correct predictions in a single batch.

In [11]:
def check_accuracy(loader, model, device):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)


    model.train()
    return (num_correct/num_samples).item()

def correct_predictions(predicted_batch, label_batch):
    pred = predicted_batch.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    acum = pred.eq(label_batch.view_as(pred)).sum().item()
    return acum

Let's define a basic training pipeline for any network.

# Exercise 3: Define the basic training pipeline

In [13]:
def train_model(model, criterion, optimizer, train_loader, epochs):
    for epoch in range(epochs):
        model.train()
        losses = []
        device = hparams['device']
        model.to(device)

        pbar = tqdm(train_loader, total=len(train_loader), position=0, leave=True, desc=f"Epoch {epoch}")
        for data, targets in pbar:
            data = data.to(device)
            targets = targets.to(device)

            # # TODO: forward method
            # scores = ...
            # loss = ...

            # losses.append(loss.item())

            # # TODO: backward pass
            # loss...
            # optimizer...
            # optimizer...
            # ✅ Forward pass
            scores = model(data)
            loss = criterion(scores, targets)

            losses.append(loss.item())

            # ✅ Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_loss = sum(losses) / len(losses)
        acc = check_accuracy(test_loader, model, device)
        print(f"Loss:{avg_loss:.2f}\tAccuracy:{acc:.2f}")

    return model

def test_model(model, test_loader):
    model.eval()
    device = hparams['device']
    acc = 0
    logsoftmax = nn.LogSoftmax(dim=-1)
    beg_t = timer()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = logsoftmax(model(data))
            # Count the correct predictions in the batch
            acc += correct_predictions(output, target)
    end_t = timer()
    test_time = end_t - beg_t
    # Accuracy (%) over the whole test set
    test_acc = 100. * acc / len(test_loader.dataset)
    print('Test set:  Accuracy: {}/{} ({:.0f}%) Time to test: {} seconds.'.format(
        acc, len(test_loader.dataset), test_acc, round(test_time, 2),
    ))
    return test_acc

## Train the teacher model

In [17]:
# TODO: Declare the teacher model, criterion and optimizer
# teacher_model = ...
# criterion = nn...
# optimizer = torch.optim.Adam(..., ...)

teacher_model = TeacherModel(in_channels=1, num_classes=hparams['num_classes'])

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=1e-3)
teacher_model = train_model(teacher_model, criterion, optimizer, train_loader, epochs=hparams['num_epochs'])
test_model(teacher_model, test_loader)

Epoch 0: 100%|██████████| 1875/1875 [00:11<00:00, 158.64it/s]


Loss:0.12	Accuracy:0.99


Epoch 1: 100%|██████████| 1875/1875 [00:11<00:00, 163.40it/s]


Loss:0.04	Accuracy:0.99


Epoch 2: 100%|██████████| 1875/1875 [00:10<00:00, 174.68it/s]


Loss:0.03	Accuracy:0.99
Test set:  Accuracy: 9905/10000 (99%) Time to test: 1.66 seconds.


99.05

# Exercise 4: Perform knowledge distillation (transfer knowledge from the teacher to the student)

In this example, we have two losses that are combined to obtain the loss that will be backpropagated in order to train the student.

We have:


*   **Classification loss (student loss)**: the usual loss against the ground-truth labels. The network outputs raw logits (it has no final softmax), so we apply `CrossEntropyLoss`, which applies the log-softmax internally.
*   **Distillation loss**: compares the softened softmax outputs of the student and the teacher. Since both models output raw logits, we first apply the softmax with temperature (dividing the logits by `temp`) and then `MSELoss`.
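The combined objective described above can be written as a standalone function, which makes it easy to test in isolation. This is a minimal sketch assuming raw logits as inputs; the name `distillation_objective` is hypothetical, and it mirrors the loss computed inside `train_step` below (cross-entropy on the logits plus MSE between the temperature-softened softmax outputs).

```python
import torch
import torch.nn.functional as F

def distillation_objective(student_logits, teacher_logits, targets, temp, alpha):
    """alpha * CE(student, labels) + (1 - alpha) * MSE(softened student, softened teacher)."""
    # Hard-label term: cross_entropy applies log-softmax to the raw logits internally
    student_loss = F.cross_entropy(student_logits, targets)
    # Soft-label term: compare the temperature-scaled probability distributions
    soft_student = F.softmax(student_logits / temp, dim=1)
    soft_teacher = F.softmax(teacher_logits / temp, dim=1)
    distillation_loss = F.mse_loss(soft_student, soft_teacher)
    return alpha * student_loss + (1 - alpha) * distillation_loss

# Toy batch: 4 samples, 10 classes
torch.manual_seed(0)
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
targets = torch.tensor([3, 1, 0, 7])
loss = distillation_objective(student_logits, teacher_logits, targets, temp=6.0, alpha=0.2)
print(loss.item())
```

Two sanity checks follow directly from the formula: with `alpha=1` the function reduces to plain cross-entropy, and when student and teacher logits are identical the distillation term vanishes.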

In [19]:
def train_step(teacher, student, optimizer, classification_loss_fn, distillation_loss_fn, temp, alpha, epoch, device):
    losses = []
    pbar = tqdm(train_loader, total=len(train_loader), position=0, leave=True, desc=f"Epoch {epoch}")

    for data, targets in pbar:
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # with torch.no_grad():
        #     # TODO: Compute teacher soft predictions
        #     teacher_preds = ...

        # # TODO: Get the student predictions (logits)
        # student_preds = ...

        # # TODO: compute the classification loss
        # student_loss = ...
        # # TODO: Compute the distillation loss. Remember, that we are comparing the outputs of the softmax for both predictions
        # distillation_loss = ...( F.softmax(... / temp, dim=1), F.softmax(.../ temp, dim=1) )

        # loss = alpha * student_loss + (1 - alpha) * distillation_loss
        # losses.append(loss.item())

        # # TODO: backward pass and update optimizer.
        # loss...
        # optimizer...
        # optimizer...
        with torch.no_grad():
            # ✅ Teacher soft predictions (logits)
            teacher_preds = teacher(data)

        # ✅ Student predictions (logits)
        student_preds = student(data)

        # ✅ Classification loss (between student logits and labels)
        student_loss = classification_loss_fn(student_preds, targets)

        # ✅ Distillation loss (between softened softmax outputs)
        distillation_loss = distillation_loss_fn(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )

        # ✅ Combine losses
        loss = alpha * student_loss + (1 - alpha) * distillation_loss
        losses.append(loss.item())

        # ✅ Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_loss = sum(losses) / len(losses)
    return avg_loss

def train_distillation(teacher, student, optimizer, classification_loss_fn, distillation_loss_fn, epochs, temp=7, alpha=0.3, device='cuda'):
    teacher = teacher.to(device)
    student = student.to(device)
    teacher.eval()
    student.train()
    for epoch in range(epochs):
        loss = train_step(
            teacher,
            student,
            optimizer,
            classification_loss_fn,
            distillation_loss_fn,
            temp,
            alpha,
            epoch,
            device
        )
        acc = check_accuracy(test_loader, student, device)
        print(f"Loss:{loss:.2f}\tAccuracy:{acc:.2f}")

Let's check what temperature scaling does, and how it "flattens" the softmax outputs.

In [20]:
import numpy as np
logits=np.array([1.,2.,3.,-1.])
print(f'Logits: {logits}')
logits_exp = np.exp(logits)
print(f'Logits exp: {logits_exp}')
logits_exp_normalized = np.exp(logits)/sum(np.exp(logits)) # this is exactly the softmax
print(f'Logits exp normalized: {logits_exp_normalized}')

#Let's try with a few values of T:
T = [1.,5.,7.,10.]

for t in T:
  logits_exp_normalized_t = np.exp(logits/t)/sum(np.exp(logits/t))
  print(f'Temperature[{t}] - {logits_exp_normalized_t}')

Logits: [ 1.  2.  3. -1.]
Logits exp: [ 2.71828183  7.3890561  20.08553692  0.36787944]
Logits exp normalized: [0.08894682 0.24178252 0.65723302 0.01203764]
Temperature[1.0] - [0.08894682 0.24178252 0.65723302 0.01203764]
Temperature[5.0] - [0.22812574 0.2786334  0.34032361 0.15291725]
Temperature[7.0] - [0.23608545 0.27233991 0.31416179 0.17741285]
Temperature[10.0] - [0.24123681 0.2666079  0.2946473  0.19750799]
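One way to quantify this flattening is the entropy of the softened distribution: it grows with $T$ and stays below $\log K$, the entropy of the uniform distribution over $K$ classes, which it approaches as $T \to \infty$. A short sketch using the same logits as above (the helper names are hypothetical):

```python
import numpy as np

def softmax_with_temperature(logits, t):
    # Subtracting the max before exponentiating avoids overflow and does not change the result
    z = logits / t
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def entropy(p):
    # Shannon entropy in nats
    return float(-(p * np.log(p)).sum())

logits = np.array([1., 2., 3., -1.])
temperatures = [1., 5., 7., 10.]
entropies = [entropy(softmax_with_temperature(logits, t)) for t in temperatures]

for t, h in zip(temperatures, entropies):
    print(f"T={t:>4}: entropy={h:.3f} nats")
print(f"uniform upper bound: log(4) = {np.log(len(logits)):.3f} nats")
```

Higher entropy means the teacher's "dark knowledge" (the relative probabilities of the wrong classes) carries more weight in the soft targets.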


## Exercise 5: Call the distillation function

You should declare the two types of losses:

1.   The appropriate loss for classification, given that the model has no final activation layer (it outputs logits). Categorical cross-entropy is recommended.
2.   The appropriate loss for distillation, which compares the softened softmax outputs. MSE is recommended.

In [22]:
# TODO: declare what you need to call the train_distillation function
#
# Declare the student model
student_model = StudentModel(in_channels=1, num_classes=hparams['num_classes'])

# Declare the two types of losses
classification_loss_fn = nn.CrossEntropyLoss()
distillation_loss_fn = nn.MSELoss()

# Declare the optimizer
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)
train_distillation(teacher_model, student_model, optimizer, classification_loss_fn, distillation_loss_fn, epochs=hparams['num_epochs'], temp=6, alpha=0.2, device = hparams['device'])
test_model(student_model, test_loader)

Epoch 0: 100%|██████████| 1875/1875 [00:13<00:00, 142.37it/s]


Loss:0.06	Accuracy:0.97


Epoch 1: 100%|██████████| 1875/1875 [00:11<00:00, 157.73it/s]


Loss:0.02	Accuracy:0.98


Epoch 2: 100%|██████████| 1875/1875 [00:14<00:00, 130.29it/s]


Loss:0.01	Accuracy:0.99
Test set:  Accuracy: 9853/10000 (99%) Time to test: 1.17 seconds.


98.53

## Exercise 6: For comparison, let's train the student model from scratch.

In [23]:
# TODO: declare a new student model
# student_model= ....to(hparams['device'])
# criterion = nn...
# optimizer = torch.optim.Adam(..., ...)
# Declare a new student model (no knowledge distillation)
student_model = StudentModel(in_channels=1, num_classes=10).to(hparams['device'])

# Classification loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)
student_model = train_model(student_model, criterion, optimizer, train_loader, hparams['num_epochs'])
test_model(student_model, test_loader)

Epoch 0: 100%|██████████| 1875/1875 [00:10<00:00, 183.31it/s]


Loss:0.25	Accuracy:0.97


Epoch 1: 100%|██████████| 1875/1875 [00:10<00:00, 183.85it/s]


Loss:0.08	Accuracy:0.98


Epoch 2: 100%|██████████| 1875/1875 [00:12<00:00, 150.84it/s]


Loss:0.06	Accuracy:0.98
Test set:  Accuracy: 9819/10000 (98%) Time to test: 1.13 seconds.


98.19

# Conclusions

Yay! You have seen a didactic implementation of distillation. Although you probably haven't seen huge improvements in accuracy (in this run, the distilled student reached 98.53% on the test set vs. 98.19% for the same student trained from scratch), check how much faster the student is than the teacher at inference (test) time: 1.17 s vs. 1.66 s here. Now imagine how big an impact that has when working with huge networks and datasets far bigger than MNIST.
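Note that `test_model` times the whole loop over `test_loader`, including data loading and host-to-device copies, so it understates the speed difference between the two networks. A more direct comparison times only the forward passes on a fixed batch. A minimal sketch, assuming the model and input live on the same device; the helper `time_forward` is hypothetical, and the two small `nn.Sequential` networks are stand-ins with the same layer widths as `TeacherModel` and `StudentModel`, not the trained models themselves:

```python
import time
import torch
from torch import nn

@torch.no_grad()
def time_forward(model, x, n_warmup=5, n_runs=20):
    """Average wall-clock seconds per forward pass on a fixed input batch."""
    model.eval()
    for _ in range(n_warmup):          # warm-up runs (lazy initialization, cuDNN setup)
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()       # CUDA kernels run asynchronously; wait before timing
    start = time.perf_counter()
    for _ in range(n_runs):
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs

def make_cnn(c1, c2, num_classes=10):
    # Same structure as TeacherModel/StudentModel: conv-relu-pool twice, then a linear layer
    return nn.Sequential(
        nn.Conv2d(1, c1, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(c1, c2, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Flatten(), nn.Linear(c2 * 7 * 7, num_classes),
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(64, 1, 28, 28, device=device)
teacher_like = make_cnn(64, 256).to(device)
student_like = make_cnn(8, 16).to(device)

t_teacher = time_forward(teacher_like, x)
t_student = time_forward(student_like, x)
print(f"teacher: {t_teacher * 1e3:.2f} ms/batch, student: {t_student * 1e3:.2f} ms/batch")
```

To measure the real models, pass `teacher_model` and `student_model` (already on `hparams['device']`) together with a batch from `test_loader`.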

There are many uses for distillation, but one of the most impactful has been DistilBERT, a distilled version of the famous BERT transformer.

# Extra

Run some further experiments with distillation. Which combination gives you the best results when varying the following?

*   `temp`
*   `alpha`
*   the distillation loss (here we use MSE to compare the temperature-scaled softmax outputs, but we could also use a KL-divergence loss such as `divergence_loss_fn` on the log-softmax outputs, among other loss functions).
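If you try the KL-divergence option, note that `nn.KLDivLoss(reduction='batchmean', log_target=True)` expects log-probabilities for both its input and its target, so it cannot simply replace `mse_loss_fn` in `train_step`, which passes softmax probabilities. A minimal sketch of a KL-based distillation term, assuming raw logits as inputs; the function name `kl_distillation_loss` is hypothetical, and the $T^2$ factor follows Hinton et al. (2015), who rescale the soft-target term so its gradient magnitude stays comparable as $T$ changes:

```python
import torch
import torch.nn.functional as F
from torch import nn

kl_div_loss_fn = nn.KLDivLoss(reduction="batchmean", log_target=True)

def kl_distillation_loss(student_logits, teacher_logits, temp):
    # Both arguments must be log-probabilities because log_target=True
    log_p_student = F.log_softmax(student_logits / temp, dim=1)
    log_p_teacher = F.log_softmax(teacher_logits / temp, dim=1)
    # KL(teacher || student), summed over classes and averaged over the batch, rescaled by T^2
    return kl_div_loss_fn(log_p_student, log_p_teacher) * temp ** 2

torch.manual_seed(0)
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
print(kl_distillation_loss(student_logits, teacher_logits, temp=6.0).item())
```

To plug this into the existing pipeline, `train_step` would have to call it on the raw logits instead of passing `F.softmax` outputs to `distillation_loss_fn`.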

In [27]:
# # TODO: declare what you need to call the train_distillation function
# student_model = ...
# student_loss_fn = ...
# mse_loss_fn = ...
# #divergence_loss_fn = nn.KLDivLoss(reduction="batchmean", log_target=True)
# optimizer = ...
# Experiment: declare student and losses
student_model = StudentModel(in_channels=1, num_classes=hparams['num_classes']).to(hparams['device'])

# Classification loss (typical for labels)
student_loss_fn = nn.CrossEntropyLoss()

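# Distillation loss: choose one
mse_loss_fn = nn.MSELoss()
# or, for log-softmax outputs (KLDivLoss with log_target=True expects log-probabilities
# for both arguments, so train_step would need F.log_softmax instead of F.softmax):
kl_div_loss_fn = nn.KLDivLoss(reduction='batchmean', log_target=True)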

# Optimizer
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)
#train_distillation(teacher_model, student_model, optimizer, student_loss_fn, mse_loss_fn, epochs=hparams['num_epochs'], temp=..., alpha=..., device = hparams['device'])

train_distillation(
    teacher_model,
    student_model,
    optimizer,
    student_loss_fn,
    mse_loss_fn,
    epochs=hparams['num_epochs'],
    temp=6,       # try e.g. 2, 4, 6, 8
    alpha=0.2,    # try values around 0.1 to 0.7
    device=hparams['device']
)


test_model(student_model, test_loader)

Epoch 0: 100%|██████████| 1875/1875 [00:12<00:00, 152.57it/s]


Loss:0.06	Accuracy:0.97


Epoch 1: 100%|██████████| 1875/1875 [00:11<00:00, 159.94it/s]


Loss:0.02	Accuracy:0.98


Epoch 2: 100%|██████████| 1875/1875 [00:11<00:00, 159.33it/s]


Loss:0.02	Accuracy:0.98
Test set:  Accuracy: 9819/10000 (98%) Time to test: 1.51 seconds.


98.19

# Some cool examples where distillation is used

## DistilBERT
In the following example, you can experiment with one of the most famous applications of distillation: DistilBERT. In this [paper](https://arxiv.org/abs/1910.01108), the authors showed that distillation can produce a smaller version of BERT, with fewer parameters and lower computational cost.

For comparison (inference times as reported in the paper):

*   BERT (base) has 110 million parameters and an inference time of 668 s.
*   DistilBERT has 66 million parameters and an inference time of 410 s.

That is, a 40% reduction in the number of parameters and a faster network, while retaining 97% of BERT's language understanding capabilities.

In [28]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

!pip install -q transformers datasets

from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

In [None]:
distilbert_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
distilbert = AutoModel.from_pretrained("distilbert-base-uncased")

def get_sents_representations(sents):
    encoded_input = distilbert_tokenizer(sents, return_tensors='pt', padding=True, truncation=True)

    # Inference only: no need to track gradients
    with torch.no_grad():
        # Last hidden states: [batch, seq_len, hidden_size]
        distilbert_output = distilbert(**encoded_input)[0]
    # Use the first token ([CLS]) of each sequence as the sentence representation
    sentence_repr = distilbert_output[:, 0]

    return distilbert_output, sentence_repr


In [None]:
#@title  { run: "auto", vertical-output: true }

#@markdown Project 5 DistilBERT sentence representations to 2-D
sentence_1 = "Hello, my name is Joe" #@param {type:"string"}
sentence_2 = "Hi, I'm Joey" #@param {type:"string"}
sentence_3 = "Goodbye, see you at 5pm" #@param {type:"string"}
sentence_4 = "Bye, see you later" #@param {type:"string"}
sentence_5 = "Attention is All You Need" #@param {type:"string"}


sentences = [sentence_1, sentence_2, sentence_3, sentence_4, sentence_5]

distilbert_output, sentence_repr = get_sents_representations(sentences)

print(f"DistilBERT output: {distilbert_output.shape}")
print(f"Sentence representations: {sentence_repr.shape}")
print("\n")

pca = PCA(n_components=2)
sentence_repr_2d = pca.fit_transform(sentence_repr.detach().numpy())

fig, ax = plt.subplots()
plt.scatter(sentence_repr_2d[:,0], sentence_repr_2d[:,1])
plt.title("Sentence representations (PCA projection)")
plt.xlim(sentence_repr_2d[:,0].min() - 1, sentence_repr_2d[:,0].max() + 4)
plt.ylim(sentence_repr_2d[:,1].min() - 1, sentence_repr_2d[:,1].max() + 1)

for x, y, s in zip(sentence_repr_2d[:,0], sentence_repr_2d[:,1], sentences):
    plt.text(x+0.15, y+0.15, s)

plt.show()

## TinyGAN
In the following example, you can experiment with a computer-vision application: GANs.

Here, the large teacher is BigGAN and the distilled student is TinyGAN, whose code is cloned below.

For comparison:

*   BigGAN has 50.1 million parameters in its generator, which performs 8.32 FLOPs.
*   TinyGAN has 3.1 million parameters in its generator, which performs 0.44 FLOPs.

That is, a generator about 16 times smaller without losing performance.

In [None]:
!git clone https://github.com/terarachang/ACCV_TinyGAN

In [None]:
%cd ACCV_TinyGAN

In [None]:
import torch
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
from model import Generator
from utils import *
G = Generator(image_size=128, conv_dim=32, z_dim=128, c_dim=128, repeat_num=5)
restore_model(30, 'gan/models', G, None, None, None)
G.to(device)
G.eval()

Run these two cells as many times as you like to see different results.

In [None]:
z_dim = 128
n_row = 5
n_samples = n_row * n_row
noise = torch.FloatTensor(truncated_normal(n_samples * z_dim)).view(n_samples, z_dim).to(device)

label = np.random.choice(398, n_row, replace=False) # sample from all animal classes
print(label)
label_t = torch.tensor(label).repeat(n_row).to(device)

# Generate one image per noise vector, conditioned on the tiled labels (n_row classes, n_row samples each)
with torch.no_grad():
  out = G(noise, label_t).detach().cpu()
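The `repeat` call tiles the sampled classes, so with `n_row = 5` the 25 labels are `[c0, ..., c4, c0, ..., c4, ...]`. Because `save_image(..., nrow=n_row)` lays the images out `n_row` per row, each column of the resulting grid shows a single class. A small check of that layout with a 3x3 grid and hypothetical class IDs:

```python
import torch

n_row = 3
label = torch.tensor([10, 42, 7])   # hypothetical class IDs
label_t = label.repeat(n_row)       # tiles the whole vector n_row times
grid = label_t.view(n_row, n_row)   # row-major layout, like save_image with nrow=n_row
print(label_t.tolist())  # [10, 42, 7, 10, 42, 7, 10, 42, 7]
print(grid)
```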

In [None]:
from torchvision.utils import save_image
from IPython.display import Image
save_image(denorm(out), 'demo.png', nrow=n_row)
Image(filename='demo.png')