# Knowledge Distillation Example
---

In this example we train a small model (fewer layers and weights) to replicate the behaviour of a larger model that is more accurate but has a higher inference time and resource usage. To do so, we follow a teacher-student scheme. This technique offers two potential benefits:


1.   A smaller model that takes up less memory and runs faster.
2.   A compact model whose accuracy approaches that of a much more complex one.

---

## 1. Install and import the required libraries

In this example we work with PyTorch and torchvision.

In [1]:
!pip3 install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [2]:
from torchvision.models import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights
from torchinfo import summary
import torch
import torchvision
import time
import numpy as np

## 2. Define the models

We define the teacher and the student from a handful of basic layers, enough to reach acceptable accuracy.

In [3]:
import torch.nn as nn
import torch.nn.functional as F


class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

teacher = Teacher()
summary(teacher, input_size=(1, 3, 32, 32), col_names=["input_size", "output_size", "num_params", "mult_adds"])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Mult-Adds
Teacher                                  [1, 3, 32, 32]            [1, 10]                   --                        --
├─Conv2d: 1-1                            [1, 3, 32, 32]            [1, 32, 32, 32]           896                       917,504
├─MaxPool2d: 1-2                         [1, 32, 32, 32]           [1, 32, 16, 16]           --                        --
├─Conv2d: 1-3                            [1, 32, 16, 16]           [1, 64, 16, 16]           18,496                    4,734,976
├─MaxPool2d: 1-4                         [1, 64, 16, 16]           [1, 64, 8, 8]             --                        --
├─Conv2d: 1-5                            [1, 64, 8, 8]             [1, 128, 8, 8]            73,856                    4,726,784
├─MaxPool2d: 1-6                         [1, 128, 8, 8]            [1, 128, 4, 4]            --                        -

In [None]:
class Student(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


student = Student()
summary(student, input_size=(1, 3, 32, 32), col_names=["input_size", "output_size", "num_params", "mult_adds"])

Layer (type:depth-idx)                   Output Shape              Param #
Student                                  [1, 10]                   --
├─Conv2d: 1-1                            [1, 16, 32, 32]           448
├─MaxPool2d: 1-2                         [1, 16, 16, 16]           --
├─Conv2d: 1-3                            [1, 32, 16, 16]           4,640
├─MaxPool2d: 1-4                         [1, 32, 8, 8]             --
├─Conv2d: 1-5                            [1, 64, 8, 8]             18,496
├─MaxPool2d: 1-6                         [1, 64, 4, 4]             --
├─Conv2d: 1-7                            [1, 128, 4, 4]            73,856
├─MaxPool2d: 1-8                         [1, 128, 2, 2]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 128, 1, 1]            --
├─Linear: 1-10                           [1, 10]                   1,290
Total params: 98,730
Trainable params: 98,730
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 4.01
Input size (MB): 0.01
Forward/
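
As a quick cross-check of the two summaries, the parameter counts can be reproduced by hand from the layer definitions above. The sketch below is not part of the original notebook; `conv2d_params` and `linear_params` are hypothetical helpers that apply the standard formulas (weights plus one bias per output unit).

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch


def linear_params(in_features, out_features):
    """Weights (in * out) plus one bias per output feature."""
    return in_features * out_features + out_features


# Layer shapes copied from the Teacher and Student classes.
teacher_params = (
    conv2d_params(3, 32, 3) + conv2d_params(32, 64, 3)
    + conv2d_params(64, 128, 3) + conv2d_params(128, 256, 3)
    + linear_params(256, 128) + linear_params(128, 64) + linear_params(64, 10)
)
student_params = (
    conv2d_params(3, 16, 3) + conv2d_params(16, 32, 3)
    + conv2d_params(32, 64, 3) + conv2d_params(64, 128, 3)
    + linear_params(128, 10)
)

print(teacher_params, student_params)  # 430218 98730
print(f"ratio: {teacher_params / student_params:.2f}")  # ratio: 4.36
```

The student total matches the 98,730 parameters reported by `summary`, and the teacher comes out roughly 4.4 times larger.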

## 3. Define a data loader

To keep compute time manageable, we work with CIFAR-10, although any dataset would do. First, we create PyTorch DataLoaders so the data can be fed to our models.


In [None]:
import torchvision.transforms as transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_data_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_data_loader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=8)
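
To make the preprocessing concrete: `ToTensor()` scales pixel values to [0, 1], and `Normalize` with mean 0.5 and standard deviation 0.5 then maps each channel to [-1, 1]. The plain-Python sketch below reproduces that arithmetic and the number of batches per epoch. The `normalize` and `num_batches` helpers are illustrative names, not part of the notebook, and the dataset sizes assume the standard CIFAR-10 split of 50,000 training and 10,000 test images.

```python
import math


def normalize(pixel, mean=0.5, std=0.5):
    """Same per-channel arithmetic as transforms.Normalize: (x - mean) / std."""
    return (pixel - mean) / std


def num_batches(dataset_size, batch_size):
    """Batches per epoch when the last, smaller batch is kept (drop_last=False)."""
    return math.ceil(dataset_size / batch_size)


print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
print(num_batches(50_000, 128))  # 391 training batches per epoch
print(num_batches(10_000, 128))  # 79 test batches
```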



## 4. Train the baseline models
With the data loaded, we first train the teacher and the student on their own and measure their baseline accuracy.

In [None]:
n_epochs = 10
print(f'** Training Teacher **')
opt = torch.optim.Adam(teacher.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
teacher.train().to('cuda')
for epoch in range(n_epochs):  # train for n_epochs epochs
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    time_start = time.time()
    for inputs, labels in train_data_loader:  # iterate over all training batches
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        opt.zero_grad()
        outs_teacher = teacher(inputs)
        loss = loss_fn(outs_teacher, labels)
        train_running_loss += loss.item()
        _, preds = torch.max(outs_teacher.data, 1)
        train_running_correct += (preds == labels).sum().item()
        counter = counter + 1
        loss.backward()
        opt.step()

    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(train_data_loader.dataset))
    time_end = time.time() - time_start
    print(f'** Summary for epoch {epoch}: '
		f'loss: {epoch_loss:#.3g}, acc: {epoch_acc:#.3g}]  '
		f'time: {time_end:.3f}s **')

# Test
test_correct = 0
with torch.no_grad():
    time_start = time.time()
    for inputs, labels in test_data_loader:  # iterate over all test batches
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outs_teacher = teacher(inputs)
        _, preds = torch.max(outs_teacher.data, 1)
        test_correct += (preds == labels).sum().item()

    acc = 100. * (test_correct / len(test_data_loader.dataset))
    time_end = time.time() - time_start
    print(f'** Summary for teacher: '
		f'acc: {acc:#.3g}]  '
		f'time: {time_end:.3f}s **')

** Training Teacher **
** Summary for epoch 0: loss: 1.58, acc: 41.1]  time: 18.246s **
** Summary for epoch 1: loss: 1.15, acc: 58.4]  time: 15.279s **
** Summary for epoch 2: loss: 0.929, acc: 66.9]  time: 13.973s **
** Summary for epoch 3: loss: 0.786, acc: 72.0]  time: 14.027s **
** Summary for epoch 4: loss: 0.687, acc: 76.0]  time: 14.495s **
** Summary for epoch 5: loss: 0.595, acc: 79.0]  time: 14.184s **
** Summary for epoch 6: loss: 0.518, acc: 81.9]  time: 14.071s **
** Summary for epoch 7: loss: 0.453, acc: 84.0]  time: 14.147s **
** Summary for epoch 8: loss: 0.387, acc: 86.2]  time: 15.525s **
** Summary for epoch 9: loss: 0.326, acc: 88.5]  time: 13.958s **
** Summary for teacher: acc: 76.8]  time: 2.653s **


In [None]:
n_epochs = 10
print(f'** Training Student **')
opt = torch.optim.Adam(student.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
student.train().to('cuda')
for epoch in range(n_epochs):  # train for n_epochs epochs
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    time_start = time.time()
    for inputs, labels in train_data_loader:  # iterate over all training batches
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        opt.zero_grad()
        outs_student = student(inputs)
        loss = loss_fn(outs_student, labels)
        train_running_loss += loss.item()
        _, preds = torch.max(outs_student.data, 1)
        train_running_correct += (preds == labels).sum().item()
        counter = counter + 1
        loss.backward()
        opt.step()

    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(train_data_loader.dataset))
    time_end = time.time() - time_start
    print(f'** Summary for epoch {epoch}: '
		f'loss: {epoch_loss:#.3g}, acc: {epoch_acc:#.3g}]  '
		f'time: {time_end:.3f}s **')

# Test
test_correct = 0
with torch.no_grad():
    time_start = time.time()
    for inputs, labels in test_data_loader:  # iterate over all test batches
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outs_student = student(inputs)
        _, preds = torch.max(outs_student.data, 1)
        test_correct += (preds == labels).sum().item()

    acc = 100. * (test_correct / len(test_data_loader.dataset))
    time_end = time.time() - time_start
    print(f'** Summary for student: '
		f'acc: {acc:#.3g}]  '
		f'time: {time_end:.3f}s **')

** Training Student **
** Summary for epoch 0: loss: 1.72, acc: 36.5]  time: 14.386s **
** Summary for epoch 1: loss: 1.41, acc: 48.6]  time: 13.588s **
** Summary for epoch 2: loss: 1.27, acc: 54.4]  time: 13.798s **
** Summary for epoch 3: loss: 1.17, acc: 58.6]  time: 13.849s **
** Summary for epoch 4: loss: 1.06, acc: 62.3]  time: 13.761s **
** Summary for epoch 5: loss: 0.992, acc: 64.8]  time: 13.835s **
** Summary for epoch 6: loss: 0.929, acc: 67.2]  time: 13.658s **
** Summary for epoch 7: loss: 0.886, acc: 68.7]  time: 13.782s **
** Summary for epoch 8: loss: 0.830, acc: 70.8]  time: 13.536s **
** Summary for epoch 9: loss: 0.790, acc: 72.4]  time: 13.578s **
** Summary for student: acc: 69.4]  time: 3.645s **
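
The teacher and student cells repeat the same test loop. One possible refactor, shown here as a sketch rather than as the notebook's own code, is a small `evaluate` helper (a hypothetical name). It also switches the model to evaluation mode and falls back to the CPU when no GPU is available, which the original cells do not do.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, data_loader, device):
    """Return the top-1 accuracy (in percent) of `model` over `data_loader`."""
    was_training = model.training
    model.eval()  # matters for layers such as dropout or batch norm
    correct = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
    model.train(was_training)  # restore the previous mode
    return 100.0 * correct / len(data_loader.dataset)


# Toy check: an identity "model" fed one-hot logits classifies everything correctly.
device = "cuda" if torch.cuda.is_available() else "cpu"
toy_loader = DataLoader(TensorDataset(torch.eye(4), torch.arange(4)), batch_size=2)
toy_acc = evaluate(nn.Identity().to(device), toy_loader, device)
print(f"toy accuracy: {toy_acc:.1f}")  # toy accuracy: 100.0
```

With the notebook's objects, the call would be `evaluate(teacher, test_data_loader, 'cuda')` or `evaluate(student, test_data_loader, 'cuda')`.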


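## 5. Knowledge distillation of the student model

We now train the student for a few more epochs so that its weights reproduce the behaviour of the teacher. For each batch, we first run the teacher in inference mode to obtain its output distribution and then train the student to match it, using the dataset downloaded in step 3. As the loss function we use the Hinton loss, here in its soft-target form:

Loss = $-\sum_{c=1}^{M} q_{o,c}^{(T)} \log p_{o,c}^{(T)}$, with $q_{o}^{(T)} = \mathrm{softmax}(z_{o}^{\text{teacher}} / T)$ and $p_{o}^{(T)} = \mathrm{softmax}(z_{o}^{\text{student}} / T)$

Here $M$ is the number of classes, $o$ indexes the sample, and $z_{o}$ are the logits. In other words, it is a cross-entropy loss in which the logits of both models are divided by the temperature T before the softmax, and the teacher's softened probabilities replace the one-hot labels. A common choice is T = 2.0. With T = 1.0 it reduces to an ordinary cross-entropy against the teacher's unsoftened probabilities. Passing probabilities as targets to `torch.nn.CrossEntropyLoss`, as the next cell does, requires PyTorch 1.10 or later.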
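
The effect of the temperature is easy to check on a toy example. The sketch below uses plain Python and made-up logits (all values and helper names are assumptions for illustration). Dividing the logits by T > 1 flattens the softmax, so the teacher's soft targets carry more information about the non-maximal classes, and the distillation loss is the cross-entropy between the two softened distributions.

```python
import math


def softmax(logits, T=1.0):
    """Softmax of logits / T, computed stably by subtracting the maximum."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_cross_entropy(student_logits, teacher_logits, T=2.0):
    """-sum_c q_c * log(p_c), with q and p the softened teacher/student outputs."""
    q = softmax(teacher_logits, T)
    p = softmax(student_logits, T)
    return -sum(qc * math.log(pc) for qc, pc in zip(q, p))


teacher_logits = [4.0, 1.0, 0.5]  # hypothetical values
student_logits = [2.5, 1.5, 0.0]  # hypothetical values

print("T=1:", [round(x, 3) for x in softmax(teacher_logits, T=1.0)])
print("T=2:", [round(x, 3) for x in softmax(teacher_logits, T=2.0)])
print("loss:", round(soft_cross_entropy(student_logits, teacher_logits, T=2.0), 3))
```

The loss is smallest when the student's softened distribution equals the teacher's, which is what the training loop pushes towards.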

In [None]:
n_epochs = 10
opt = torch.optim.Adam(student.parameters(), lr=0.0005)
dist_loss_fn = torch.nn.CrossEntropyLoss()
teacher.eval()
student.train().to('cuda')
T = 2.0
for epoch in range(n_epochs):  # train for n_epochs epochs
    train_running_loss_dist = 0.0
    train_running_correct = 0
    counter = 0
    time_start = time.time()
    for inputs, labels in train_data_loader:  # iterate over all training batches
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        opt.zero_grad()
        with torch.no_grad():
          outs_teacher = teacher(inputs) / T
          outs_teacher = torch.nn.functional.softmax(outs_teacher, dim=1)

        outs_student = student(inputs)
        loss_dist = dist_loss_fn(outs_student / T, outs_teacher)
        train_running_loss_dist += loss_dist.item()
        _, preds = torch.max(outs_student.data, 1)
        train_running_correct += (preds == labels).sum().item()
        counter = counter + 1
        loss_dist.backward()
        opt.step()

    epoch_loss_dist = train_running_loss_dist / counter
    epoch_acc = 100. * (train_running_correct / len(train_data_loader.dataset))
    time_end = time.time() - time_start
    print(f'** Summary for epoch {epoch}: '
		f'loss_dist: {epoch_loss_dist:#.3g}, acc: {epoch_acc:#.3g}]  '
		f'time: {time_end:.3f}s **')

# Test
test_correct = 0
with torch.no_grad():
    time_start = time.time()
    for inputs, labels in test_data_loader:  # iterate over all test batches
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outs_student = student(inputs)
        _, preds = torch.max(outs_student.data, 1)
        test_correct += (preds == labels).sum().item()

    acc = 100. * (test_correct / len(test_data_loader.dataset))
    time_end = time.time() - time_start
    print(f'** Summary for final model: '
		f'acc: {acc:#.3g}]  '
		f'time: {time_end:.3f}s **')

** Summary for epoch 0: loss_dist: 1.10, acc: 73.1]  time: 13.981s **
** Summary for epoch 1: loss_dist: 1.08, acc: 74.0]  time: 14.169s **
** Summary for epoch 2: loss_dist: 1.06, acc: 74.7]  time: 13.894s **
** Summary for epoch 3: loss_dist: 1.05, acc: 75.2]  time: 13.925s **
** Summary for epoch 4: loss_dist: 1.04, acc: 75.8]  time: 15.415s **
** Summary for epoch 5: loss_dist: 1.03, acc: 76.3]  time: 13.846s **
** Summary for epoch 6: loss_dist: 1.02, acc: 76.7]  time: 14.150s **
** Summary for epoch 7: loss_dist: 1.02, acc: 77.2]  time: 13.639s **
** Summary for epoch 8: loss_dist: 1.01, acc: 77.5]  time: 14.031s **
** Summary for epoch 9: loss_dist: 1.00, acc: 77.9]  time: 13.656s **
** Summary for final model: acc: 73.2]  time: 2.527s **


## 6. Export the model

Once the student has been trained, we export it to a file so that it can be used in our application.

In [None]:
student.eval()
torch.save(student.state_dict(), './student.pt')
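
Since only the `state_dict` is saved, the application has to rebuild the architecture before loading the weights. A minimal sketch of that round trip follows. `TinyNet` is a stand-in for the `Student` class so that the snippet runs on its own, and the temporary path is an assumption.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Stand-in for Student; in the notebook you would instantiate Student()."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        return self.fc(x)


trained = TinyNet().eval()
path = os.path.join(tempfile.mkdtemp(), "student.pt")
torch.save(trained.state_dict(), path)

# In the application: rebuild the same architecture, then load the weights.
restored = TinyNet()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # inference mode

x = torch.rand(2, 8)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)  # True: both models produce identical outputs
```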

In addition, we run a test inference to compare the speed of the teacher and the student.

In [None]:
# Random input; the adaptive pooling lets both models accept sizes other than 32x32
image = torch.rand(1, 3, 224, 224, device='cuda')

# Teacher model
times = []
for i in range(50):
    torch.cuda.synchronize()  # synchronize before starting the timer
    time_start = time.time()
    teacher(image)
    torch.cuda.synchronize()  # wait for the GPU to finish
    time_end = time.time() - time_start
    times.append(time_end)

time_end = np.mean(times)
print(f'Execution time of the teacher model: {time_end:.3f}s')

# Student model
times = []
for i in range(50):
    torch.cuda.synchronize()  # synchronize before starting the timer
    time_start = time.time()
    student(image)
    torch.cuda.synchronize()  # wait for the GPU to finish
    time_end = time.time() - time_start
    times.append(time_end)

time_end = np.mean(times)
print(f'Execution time of the student model: {time_end:.3f}s')

Execution time of the teacher model: 0.006s
Execution time of the student model: 0.005s
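
The timing loops above include the first, typically slower iterations and run with gradient tracking enabled. A slightly more careful measurement is sketched below, assuming that a warm-up phase and `torch.no_grad()` are acceptable for your use case. The `benchmark` helper and the placeholder model are hypothetical names, and the snippet falls back to the CPU when no GPU is available.

```python
import time

import torch
import torch.nn as nn


def benchmark(model, example, n_warmup=5, n_runs=20):
    """Mean forward-pass time in seconds, excluding warm-up iterations."""
    use_cuda = example.is_cuda
    model.eval()
    timings = []
    with torch.no_grad():  # no autograd bookkeeping during inference
        for i in range(n_warmup + n_runs):
            if use_cuda:
                torch.cuda.synchronize()  # make sure earlier GPU work is done
            start = time.perf_counter()
            model(example)
            if use_cuda:
                torch.cuda.synchronize()  # wait for the GPU to finish
            if i >= n_warmup:
                timings.append(time.perf_counter() - start)
    return sum(timings) / len(timings)


device = "cuda" if torch.cuda.is_available() else "cpu"
toy_model = nn.Conv2d(3, 8, 3, padding=1).to(device)  # placeholder model
mean_time = benchmark(toy_model, torch.rand(1, 3, 32, 32, device=device))
print(f"mean forward time: {mean_time * 1e3:.3f} ms")
```

In the notebook, `benchmark(teacher, image)` and `benchmark(student, image)` would replace the two loops.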
