# Introductory Example of Ensemble Learning in Knowledge Distillation

### Define Single Teacher Network

In [1]:
import torch.nn as nn
from torch.nn import functional as F

TRAIN_TEACHERS = False


class MLP(nn.Module):
    """Teacher network: a 784 -> 128 -> 128 -> 10 fully connected classifier."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784, 128)
        self.linear2 = nn.Linear(128, 128)
        self.linear3 = nn.Linear(128, 10)

    def forward(self, data):
        """Return the raw (un-normalised) class scores for a batch.

        Every sample is flattened to a vector before the first layer, so each
        sample in ``data`` must contain exactly 784 values.
        """
        flat = data.view(data.size(0), -1)
        hidden = F.relu(self.linear1(flat))
        hidden = F.relu(self.linear2(hidden))
        return self.linear3(hidden)

### Define Ensemble of Teachers

In [2]:
from torchensemble import BaggingClassifier, FusionClassifier

# Teacher: a bagging ensemble built from ten MLP estimators, with cuda=False.
teacher_ensemble = BaggingClassifier(estimator=MLP, n_estimators=10, cuda=False)

### Prepare Dataset

In [3]:
import torch
from torchvision import datasets, transforms

# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081 (presumably the usual MNIST statistics).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Only the training split requests a download; both splits share one root.
train = datasets.MNIST("../Dataset", train=True, download=True, transform=transform)
test = datasets.MNIST("../Dataset", train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True)
# Evaluation only accumulates totals over the whole split, so shuffling it
# adds nondeterminism and RNG consumption without changing the result.
test_loader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=False)

### Offline Teacher Training

In [None]:
if not TRAIN_TEACHERS:
    # Skip training: load the ensemble through torchensemble's io helper,
    # using the current directory as the location.
    from torchensemble.utils import io

    io.load(teacher_ensemble, ".")
else:
    # Train the ensemble from scratch with cross-entropy and Adam.
    criterion = nn.CrossEntropyLoss()
    optimizer = dict(optimizer_name="Adam", lr=1e-3, weight_decay=5e-4)

    teacher_ensemble.set_criterion(criterion)
    teacher_ensemble.set_optimizer(**optimizer)

    print("Training Teacher Ensemble")
    teacher_ensemble.fit(train_loader, epochs=20, test_loader=test_loader)

# Report the ensemble's score on the test split in either case.
print(teacher_ensemble.evaluate(test_loader))

### Define Student Model

In [11]:
class StudentMLP(nn.Module):
    """Student network: a small 784 -> 16 -> 10 fully connected classifier."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784, 16)
        self.linear2 = nn.Linear(16, 10)

    def forward(self, data):
        """Return the raw (un-normalised) class scores for a batch.

        Every sample is flattened to a vector before the first layer, so each
        sample in ``data`` must contain exactly 784 values.
        """
        flat = data.view(data.size(0), -1)
        hidden = F.relu(self.linear1(flat))
        return self.linear2(hidden)

### Implement Training Loop

In [7]:
import torch.nn.functional as F
from torch.optim import Adam


def train_loop(student, train_loader, test_loader, epochs):
    """Train ``student`` with plain cross-entropy and report test accuracy.

    After every epoch one line is printed with the epoch index, the accuracy
    (in percent) on ``test_loader`` and the sum of the per-batch training
    losses for that epoch.

    Args:
        student: Module mapping a batch of inputs to per-class scores.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
            May be empty, in which case the accuracy is reported as ``nan``.
        epochs: Number of passes over ``train_loader``.

    The student is left in training mode when the function returns.
    """
    optimizer = Adam(student.parameters(), lr=1e-3)

    for epoch in range(epochs):
        # The evaluation pass below switches to eval mode, so training mode
        # has to be re-enabled at the start of every epoch.
        student.train()
        total_epoch_loss = 0.0

        for inputs, labels in train_loader:
            student_outputs = student(inputs)
            loss = F.cross_entropy(student_outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_epoch_loss += loss.item()

        correct = 0
        total = 0

        # Score in eval mode so that mode-dependent layers (dropout,
        # batch norm, ...) behave as they would at inference time.
        student.eval()
        with torch.no_grad():
            for inputs, labels in test_loader:
                outputs = student(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # An empty test loader would otherwise divide by zero.
        accuracy = 100 * correct / total if total else float("nan")

        print(f"Epoch: {epoch}, Accuracy: {accuracy}, Loss: {total_epoch_loss}")

    # Hand the model back in training mode regardless of ``epochs``.
    student.train()

### Implement KD Training Loop

In [14]:
def kd_train_loop(student, teacher, alpha, temperature, train_loader, test_loader, epochs):
    """Train ``student`` by distilling from ``teacher`` and report test accuracy.

    The per-batch loss is
    ``alpha * T**2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, labels)``
    where ``*_T`` are softmax distributions of the outputs divided by the
    temperature ``T``. After every epoch one line is printed with the epoch
    index, the student's accuracy (in percent) on ``test_loader`` and the sum
    of the per-batch training losses.

    Args:
        student: Module being trained; maps inputs to per-class scores.
        teacher: Frozen module providing soft targets; it is put in eval
            mode and only called under ``torch.no_grad()``.
        alpha: Weight of the distillation term; ``1 - alpha`` weights the
            hard-label cross-entropy.
        temperature: Softmax temperature applied to both sets of outputs.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
            May be empty, in which case the accuracy is reported as ``nan``.
        epochs: Number of passes over ``train_loader``.

    The student is left in training mode when the function returns.

    NOTE(review): teacher outputs are treated as logits here. If the teacher
    already returns probabilities (torchensemble classifiers may do so --
    confirm), the softmax below is applied a second time and the soft
    targets become much flatter than intended.
    """
    teacher.eval()

    optimizer = Adam(student.parameters(), lr=1e-3)
    # "batchmean" sums over classes and averages over the batch, which is the
    # KL divergence; the default "mean" additionally divides by the number
    # of classes and silently shrinks the distillation term.
    kd_criterion = nn.KLDivLoss(reduction="batchmean")

    for epoch in range(epochs):
        # The evaluation pass below switches to eval mode, so training mode
        # has to be re-enabled at the start of every epoch.
        student.train()
        total_epoch_loss = 0.0

        for inputs, labels in train_loader:
            student_outputs = student(inputs)

            with torch.no_grad():
                teacher_outputs = teacher(inputs)

            soft_student = F.log_softmax(student_outputs / temperature, dim=1)
            soft_teacher = F.softmax(teacher_outputs / temperature, dim=1)
            # The T**2 factor is kept from the original loss expression.
            distill_loss = kd_criterion(soft_student, soft_teacher) * temperature**2
            hard_loss = F.cross_entropy(student_outputs, labels)
            loss = distill_loss * alpha + hard_loss * (1.0 - alpha)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_epoch_loss += loss.item()

        correct = 0
        total = 0

        # Score in eval mode so that mode-dependent layers (dropout,
        # batch norm, ...) behave as they would at inference time.
        student.eval()
        with torch.no_grad():
            for inputs, labels in test_loader:
                outputs = student(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # An empty test loader would otherwise divide by zero.
        accuracy = 100 * correct / total if total else float("nan")

        print(f"Epoch: {epoch}, Accuracy: {accuracy}, Loss: {total_epoch_loss}")

    # Hand the model back in training mode regardless of ``epochs``.
    student.train()

### Train Student Model

In [None]:
# Two students of the same architecture, both created before any training.
student_1, student_2 = StudentMLP(), StudentMLP()

# Baseline run: hard labels only.
print("No teacher training")
train_loop(student_1, train_loader, test_loader, epochs=10)

# Distillation run: soft targets from the teacher ensemble.
print("Teacher Ensemble")
kd_train_loop(
    student_2,
    teacher_ensemble,
    alpha=0.95,
    temperature=10,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=10,
)