# Introductory Example of Ensemble Learning in Knowledge Distillation

### Define Single Teacher Network

In [2]:
import torch.nn as nn
from torch.nn import functional as F


class MLP(nn.Module):
    """Teacher network: fully connected 784 -> 128 -> 128 -> 10 classifier.

    Each sample is flattened before the first layer (e.g. an MNIST batch of
    shape ``(N, 1, 28, 28)`` becomes ``(N, 784)``). ReLU follows the first two
    layers; the last layer's output is returned as raw logits.
    """

    def __init__(self):
        super().__init__()
        # The attribute names become state_dict keys; keep them stable so that
        # previously saved teacher weights still load into this class.
        self.linear1 = nn.Linear(784, 128)
        self.linear2 = nn.Linear(128, 128)
        self.linear3 = nn.Linear(128, 10)

    def forward(self, data):
        batch_size = data.size(0)
        flat = data.view(batch_size, -1)
        hidden = self.linear1(flat).relu()
        hidden = self.linear2(hidden).relu()
        return self.linear3(hidden)

### Define Ensemble of Teachers

In [13]:
from torchensemble import BaggingClassifier, FusionClassifier

# torchensemble bagging ensemble built from 10 `MLP` base estimators, kept on the CPU.
teacher_ensemble = BaggingClassifier(estimator=MLP, n_estimators=10, cuda=False)

### Prepare Dataset

In [14]:
import torch
from torchvision import datasets, transforms

DATA_DIR = "../Dataset"
BATCH_SIZE = 128

transform = transforms.Compose(
    # Standardize pixels; presumably (0.1307, 0.3081) are the MNIST
    # training-set mean/std commonly used for this dataset.
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

train = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)
test = datasets.MNIST(DATA_DIR, train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
# Accuracy does not depend on evaluation order, so the test set is not shuffled:
# this keeps evaluation deterministic and avoids consuming RNG state that would
# otherwise change the training shuffle order between runs.
test_loader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

### Load Training Parameters

In [15]:
import json

with open("./params.json", "r") as params_fd:
    params = json.load(params_fd)

# Pull every run setting out of params.json in one place; a missing key raises KeyError.
(
    TRAIN_TEACHERS,
    LR,
    STUDENT_EPOCHS,
    TEACHER_EPOCHS,
    ALPHA,
    TEMPERATURE,
) = (
    params[key]
    for key in (
        "train_teachers",
        "lr",
        "student_epochs",
        "teacher_epochs",
        "alpha",
        "temperature",
    )
)

### Offline Teacher Training

In [16]:
if TRAIN_TEACHERS:
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the teacher learning rate is fixed here and independent of
    # LR from params.json (which only drives the student optimizers) — confirm.
    optimizer = {"optimizer_name": "Adam", "lr": 1e-3}

    teacher_ensemble.set_criterion(criterion)
    teacher_ensemble.set_optimizer(**optimizer)

    print("Training Teacher Ensemble")
    # Use the configured budget so params.json's "teacher_epochs" takes effect
    # (previously a hard-coded 20 silently overrode it).
    teacher_ensemble.fit(train_loader, epochs=TEACHER_EPOCHS, test_loader=test_loader)
else:
    from torchensemble.utils import io

    # Restore previously saved teacher weights from the current directory.
    io.load(teacher_ensemble, ".")

# Test-set score reported by the ensemble's own evaluate().
print(teacher_ensemble.evaluate(test_loader))

98.43


### Define Student Model

In [17]:
class StudentMLP(nn.Module):
    """Small student network: fully connected 784 -> 16 -> 10 classifier.

    Each sample is flattened before the first layer; ReLU is applied to the
    hidden layer and the final layer's output is returned as raw logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names/order kept so state_dict keys and seeded init match.
        self.linear1 = nn.Linear(784, 16)
        self.linear2 = nn.Linear(16, 10)

    def forward(self, data):
        flat = data.view(data.size(0), -1)
        return self.linear2(self.linear1(flat).relu())

### Implement Training Loop

In [18]:
import torch.nn.functional as F
from torch.optim import Adam


def train_loop(student, train_loader, test_loader, epochs, lr=None):
    """Train ``student`` on the true labels with cross-entropy (no-teacher baseline).

    After every epoch the student is evaluated on ``test_loader`` and a line with
    the epoch's summed per-batch loss and the validation accuracy (%) is printed.

    Args:
        student: ``nn.Module`` to train; its parameters are updated in place.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        test_loader: iterable of ``(inputs, labels)`` validation batches.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate. ``None`` (default) uses the notebook-level ``LR``.
    """
    optimizer = Adam(student.parameters(), lr=LR if lr is None else lr)

    for epoch in range(epochs):
        # Re-enter training mode every epoch because validation switches to eval.
        student.train()
        total_epoch_loss = 0.0

        for inputs, labels in train_loader:
            loss = F.cross_entropy(student(inputs), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_epoch_loss += loss.item()

        # Validate in eval mode so dropout / batch-norm layers (if the student
        # has any) act deterministically and do not update running statistics.
        student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                predicted = student(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against an empty validation loader instead of dividing by zero.
        accuracy = 100 * correct / total if total else float("nan")
        print(f"Epoch: {epoch}, Loss: {total_epoch_loss}, Validation Accuracy: {accuracy}")

    # Leave the student in training mode, as it was when this function began.
    student.train()

### Implement KD Training Loop

In [36]:
def kd_train_loop(student, teacher, temperature, train_loader, test_loader, epochs):
    student.train()
    teacher.eval()

    optimizer = Adam(student.parameters(), lr=LR)

    for epoch in range(epochs):
        total_epoch_loss = 0.0

        for i, (inputs, labels) in enumerate(train_loader):
            student_outputs = student(inputs)

            with torch.no_grad():
                teacher_outputs = teacher(inputs)

            loss = F.cross_entropy(
                student_outputs / temperature, teacher_outputs / temperature
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_epoch_loss += loss.item()

        correct = 0
        total = 0

        with torch.no_grad():
            for data in test_loader:
                inputs, labels = data
                outputs = student(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(
            f"Epoch: {epoch}, Loss: {total_epoch_loss}, Validation Accuracy: {100 * correct/total}"
        )

### Train Student Model

In [39]:
# Both students start from identical initial weights, so the comparison below
# reflects the training signal (labels vs. teacher) rather than the random init.
student_1 = StudentMLP()
student_2 = StudentMLP()
student_2.load_state_dict(student_1.state_dict())

print("No teacher training")
# Same epoch budget as the distilled student for a like-for-like comparison
# (previously a hard-coded 20 regardless of STUDENT_EPOCHS).
train_loop(student_1, train_loader, test_loader, STUDENT_EPOCHS)

print("Teacher Ensemble")
kd_train_loop(
    student_2,
    teacher_ensemble,
    TEMPERATURE,
    train_loader,
    test_loader,
    STUDENT_EPOCHS,
)

No teacher training
Epoch: 0, Accuracy: 91.58, Loss: 217.10807114839554
Epoch: 1, Accuracy: 92.24, Loss: 129.20910249650478
Epoch: 2, Accuracy: 93.16, Loss: 113.32058991491795
Epoch: 3, Accuracy: 93.77, Loss: 101.85108511894941
Epoch: 4, Accuracy: 94.02, Loss: 93.68176006525755
Epoch: 5, Accuracy: 94.43, Loss: 87.81360195577145
Epoch: 6, Accuracy: 94.5, Loss: 83.12724731862545
Epoch: 7, Accuracy: 94.88, Loss: 79.44343261048198
Epoch: 8, Accuracy: 94.85, Loss: 76.5991804562509
Epoch: 9, Accuracy: 94.71, Loss: 73.84392862021923
Epoch: 10, Accuracy: 95.02, Loss: 70.94671323522925
Epoch: 11, Accuracy: 95.21, Loss: 69.52563505247235
Epoch: 12, Accuracy: 95.22, Loss: 67.35294060781598
Epoch: 13, Accuracy: 95.27, Loss: 65.7399962823838
Epoch: 14, Accuracy: 95.25, Loss: 63.78615561127663
Epoch: 15, Accuracy: 95.3, Loss: 63.23735174909234
Epoch: 16, Accuracy: 95.35, Loss: 61.72276658937335
Epoch: 17, Accuracy: 95.52, Loss: 61.291945934295654
Epoch: 18, Accuracy: 95.45, Loss: 60.18793055601418
E