Use KL Divergence loss on Knowledge Distillation Task. You can use any teacher and student model (prefer small models). You need to show that it works, and update README.md with proper logs


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

try:
    from torchsummary import summary
except ModuleNotFoundError:
    !pip install torchsummary
    from torchsummary import summary

from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torchvision

import os
import time
import math

In [2]:
train_transforms = transforms.Compose([
                                      #  transforms.Resize((28, 28)),
                                      #  transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
                                      #  transforms.RandomRotation((-7.0, 7.0), fill=(1,)),
                                       transforms.RandomAffine(degrees=10, shear = 10),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       # Note the difference between (0.1307) and (0.1307,)
                                       ])

# Test Phase transformations
test_transforms = transforms.Compose([
                                      #  transforms.Resize((28, 28)),
                                      #  transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       ])

train = datasets.CIFAR10(root = './data', train=True, download=True, transform=train_transforms)
test = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)

# Do we have CUDA drivers for us?
cuda = torch.cuda.is_available()
print ("Cuda Available?", cuda)

dataloader_args = dict(shuffle=True, batch_size=2048, num_workers=2, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)

# Dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train, **dataloader_args)
test_loader = torch.utils.data.DataLoader(dataset=test, **dataloader_args)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 30.4MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Cuda Available? True


In [3]:
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv01 = nn.Conv2d(3, 16, 3, bias=False, padding=1)
        self.batch01 = nn.BatchNorm2d(num_features=16)

        # ---- Lets take a skip connection
        self.skip_conv1 = nn.Conv2d(16, 16, 3, padding=0, dilation=2)

        self.conv02 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch02 = nn.BatchNorm2d(num_features=16)
        self.conv03 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch03 = nn.BatchNorm2d(num_features=16)
        self.conv04 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch04 = nn.BatchNorm2d(num_features=16)
        self.pool01 = nn.MaxPool2d(2, 2)                                #O=16
        self.conv05 = nn.Conv2d(16, 16, 1, bias=False)

        self.conv11 = nn.Conv2d(16, 64, 3, bias=False, padding=1)
        self.batch11 = nn.BatchNorm2d(num_features=64)
        self.conv12 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch12 = nn.BatchNorm2d(num_features=64)
        self.conv13 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch13 = nn.BatchNorm2d(num_features=64)
        self.conv14 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch14 = nn.BatchNorm2d(num_features=64)
        self.pool11 = nn.MaxPool2d(2, 2)                                #O=8
        self.conv15 = nn.Conv2d(64, 64, 1, bias=False)

        self.conv21 = nn.Conv2d(64, 128, 3, bias=False, padding=1)
        self.batch21 = nn.BatchNorm2d(num_features=128)
        self.conv22 = nn.Conv2d(128, 128, 3, bias=False, padding=1)
        self.batch22 = nn.BatchNorm2d(num_features=128)
        self.conv23 = nn.Conv2d(128,128, 3, bias=False, padding=1)
        self.batch23 = nn.BatchNorm2d(num_features=128)
        self.conv24 = nn.Conv2d(128, 128, 3, bias=False, padding=1)
        self.batch24 = nn.BatchNorm2d(num_features=128)
        self.pool21 = nn.MaxPool2d(2, 2)                                #O=4
        self.conv25 = nn.Conv2d(128, 128, 1, bias=False)

        self.conv31 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV1= nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, bias = False, padding = 0)
        self.batch31 = nn.BatchNorm2d(num_features=128)
        self.conv32 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV2= nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, bias = False, padding = 0)
        self.batch32 = nn.BatchNorm2d(num_features=256)


        self.avg_pool = nn.AvgPool2d(kernel_size=4)
        self.convx3 = nn.Conv2d(256, 10, 1, bias=False, padding=0)

    def forward(self, x):
        x = self.batch01(F.relu(self.conv01(x)))

        # ---- Lets take a skip connection
        skip_channels = self.skip_conv1(self.skip_conv1(self.skip_conv1(self.skip_conv1(x))))

        x = self.batch02(F.relu(self.conv02(x)))
        x = self.batch03(F.relu(self.conv03(x)))
        x = self.batch04(F.relu(self.conv04(x)))
        x = self.pool01(x)
        x = self.conv05(x)
        # ----------------------------------------------------------

        # ---- Lets add the skip connection here
        x = skip_channels + x

        x = self.batch11(F.relu(self.conv11(x)))
        x = self.batch12(F.relu(self.conv12(x)))
        x = self.batch13(F.relu(self.conv13(x)))
        x = self.batch14(F.relu(self.conv14(x)))
        x = self.pool11(x)
        x = self.conv15(x)
        # ----------------------------------------------------------

        x = self.batch21(F.relu(self.conv21(x)))
        x = self.batch22(F.relu(self.conv22(x)))
        x = self.batch23(F.relu(self.conv23(x)))
        x = self.batch24(F.relu(self.conv24(x)))
        x = self.pool21(x)
        x = self.conv25(x)
        # ----------------------------------------------------------

        x = self.batch31(F.relu(self.convPV1(F.relu(self.conv31(x)))))
        x = self.batch32(F.relu(self.convPV2(F.relu(self.conv32(x)))))


        x = self.avg_pool(x)
        x = self.convx3(x)
        x = x.view(-1, 10)                           # Don't want 10x1x1..
        return F.log_softmax(x, dim=1)  # Added dim=1 parameter)

In [4]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met



def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)

    correct = 0
    processed = 0
    epoch_loss = 0
    time_taken.clear()

    for batch_idx, (data, target) in enumerate(pbar):
        t0 = time.time()

        data, target = data.to(device), target.to(device)

        # Don't want history of gradients
        optimizer.zero_grad()

        y_predict = model(data)

        # Calculate loss
        loss = F.nll_loss(y_predict, target)
        epoch_loss += loss.item()

        # Backpropagate error
        loss.backward()

        # Take an optimizer step
        optimizer.step()

        torch.cuda.synchronize()
        t1 = time.time()

        time_taken.append((t1 - t0))

        pred = y_predict.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')
        train_acc.append(100 * correct / processed)

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    return avg_train_loss


def test(model, device, test_loader):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc.append(100. * correct / len(test_loader.dataset))
    return test_loss

In [5]:
# Initialize model, optimizer, and early stopping
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (f'Device Using = {device}')
model = TeacherModel().to(device)
summary(model, input_size=(3, 32, 32))
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
early_stopping = EarlyStopping(tolerance=5, min_delta=0.02)

EPOCHS = 100
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    avg_train_loss = train(model, device, train_loader, optimizer, epoch)
    print(f" --> EPOCH: {epoch}, Avg Training Loss: {avg_train_loss:.4f}, Avg Time Taken = {(sum(time_taken) / len(time_taken)) * 1000:.2f}ms")
    val_loss = test(model, device, test_loader)

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_heavy_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_heavy_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

Device Using = cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 28, 28]           2,320
            Conv2d-4           [-1, 16, 24, 24]           2,320
            Conv2d-5           [-1, 16, 20, 20]           2,320
            Conv2d-6           [-1, 16, 16, 16]           2,320
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
            Conv2d-9           [-1, 16, 32, 32]           2,304
      BatchNorm2d-10           [-1, 16, 32, 32]              32
           Conv2d-11           [-1, 16, 32, 32]           2,304
      BatchNorm2d-12           [-1, 16, 32, 32]              32
        MaxPool2d-13           [-1, 16, 16, 16]               0
           Conv2d-1

Loss=1.7319616079330444 Batch_id=24 Accuracy=25.04: 100%|██████████| 25/25 [00:25<00:00,  1.03s/it]

 --> EPOCH: 0, Avg Training Loss: 1.9959, Avg Time Taken = 444.17ms






Test set: Average loss: 2.3616, Accuracy: 1000/10000 (10.00%)

Epoch 2/100


Loss=1.485198974609375 Batch_id=24 Accuracy=42.36: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 1, Avg Training Loss: 1.5596, Avg Time Taken = 436.66ms






Test set: Average loss: 1.5399, Accuracy: 4465/10000 (44.65%)

Epoch 3/100


Loss=1.2382609844207764 Batch_id=24 Accuracy=51.22: 100%|██████████| 25/25 [00:17<00:00,  1.47it/s]

 --> EPOCH: 2, Avg Training Loss: 1.3384, Avg Time Taken = 442.32ms






Test set: Average loss: 1.3909, Accuracy: 4906/10000 (49.06%)

Epoch 4/100


Loss=1.0791206359863281 Batch_id=24 Accuracy=57.21: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 3, Avg Training Loss: 1.1853, Avg Time Taken = 429.29ms






Test set: Average loss: 1.2819, Accuracy: 5486/10000 (54.86%)

Epoch 5/100


Loss=1.0430524349212646 Batch_id=24 Accuracy=61.28: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

 --> EPOCH: 4, Avg Training Loss: 1.0724, Avg Time Taken = 423.51ms






Test set: Average loss: 1.1312, Accuracy: 5964/10000 (59.64%)

Epoch 6/100


Loss=0.9689779281616211 Batch_id=24 Accuracy=64.64: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 5, Avg Training Loss: 0.9871, Avg Time Taken = 429.04ms






Test set: Average loss: 1.0078, Accuracy: 6466/10000 (64.66%)

Epoch 7/100


Loss=0.9574969410896301 Batch_id=24 Accuracy=67.50: 100%|██████████| 25/25 [00:16<00:00,  1.48it/s]

 --> EPOCH: 6, Avg Training Loss: 0.9085, Avg Time Taken = 431.25ms






Test set: Average loss: 0.9926, Accuracy: 6495/10000 (64.95%)

Epoch 8/100


Loss=0.7988631129264832 Batch_id=24 Accuracy=69.91: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 7, Avg Training Loss: 0.8450, Avg Time Taken = 431.23ms






Test set: Average loss: 0.9559, Accuracy: 6643/10000 (66.43%)

Epoch 9/100


Loss=0.7736895680427551 Batch_id=24 Accuracy=72.47: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]

 --> EPOCH: 8, Avg Training Loss: 0.7854, Avg Time Taken = 429.17ms






Test set: Average loss: 0.8552, Accuracy: 7049/10000 (70.49%)

Epoch 10/100


Loss=0.6935614943504333 Batch_id=24 Accuracy=74.41: 100%|██████████| 25/25 [00:16<00:00,  1.47it/s]

 --> EPOCH: 9, Avg Training Loss: 0.7258, Avg Time Taken = 426.59ms






Test set: Average loss: 0.7971, Accuracy: 7247/10000 (72.47%)

Epoch 11/100


Loss=0.7197743654251099 Batch_id=24 Accuracy=75.98: 100%|██████████| 25/25 [00:18<00:00,  1.39it/s]

 --> EPOCH: 10, Avg Training Loss: 0.6862, Avg Time Taken = 429.92ms






Test set: Average loss: 0.7972, Accuracy: 7256/10000 (72.56%)

Epoch 12/100


Loss=0.5909600257873535 Batch_id=24 Accuracy=77.33: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

 --> EPOCH: 11, Avg Training Loss: 0.6445, Avg Time Taken = 431.57ms






Test set: Average loss: 0.7919, Accuracy: 7303/10000 (73.03%)

Epoch 13/100


Loss=0.5948521494865417 Batch_id=24 Accuracy=79.17: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]

 --> EPOCH: 12, Avg Training Loss: 0.5996, Avg Time Taken = 430.04ms






Test set: Average loss: 0.7760, Accuracy: 7402/10000 (74.02%)

Epoch 14/100


Loss=0.5677103996276855 Batch_id=24 Accuracy=80.28: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]

 --> EPOCH: 13, Avg Training Loss: 0.5691, Avg Time Taken = 431.12ms






Test set: Average loss: 0.7186, Accuracy: 7550/10000 (75.50%)

Epoch 15/100


Loss=0.5184250473976135 Batch_id=24 Accuracy=81.20: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]

 --> EPOCH: 14, Avg Training Loss: 0.5364, Avg Time Taken = 429.78ms






Test set: Average loss: 0.7707, Accuracy: 7421/10000 (74.21%)

Epoch 16/100


Loss=0.5586455464363098 Batch_id=24 Accuracy=82.31: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 15, Avg Training Loss: 0.5106, Avg Time Taken = 429.70ms






Test set: Average loss: 0.7712, Accuracy: 7424/10000 (74.24%)

Epoch 17/100


Loss=0.5239001512527466 Batch_id=24 Accuracy=83.15: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

 --> EPOCH: 16, Avg Training Loss: 0.4862, Avg Time Taken = 425.72ms






Test set: Average loss: 0.7160, Accuracy: 7628/10000 (76.28%)

Epoch 18/100


Loss=0.47968998551368713 Batch_id=24 Accuracy=83.89: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 17, Avg Training Loss: 0.4622, Avg Time Taken = 431.19ms






Test set: Average loss: 0.7288, Accuracy: 7623/10000 (76.23%)

Epoch 19/100


Loss=0.4700057804584503 Batch_id=24 Accuracy=84.77: 100%|██████████| 25/25 [00:16<00:00,  1.48it/s]

 --> EPOCH: 18, Avg Training Loss: 0.4395, Avg Time Taken = 430.30ms






Test set: Average loss: 0.7055, Accuracy: 7703/10000 (77.03%)

Epoch 20/100


Loss=0.3999972343444824 Batch_id=24 Accuracy=85.79: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 19, Avg Training Loss: 0.4106, Avg Time Taken = 428.19ms






Test set: Average loss: 0.7002, Accuracy: 7705/10000 (77.05%)

Epoch 21/100


Loss=0.41455698013305664 Batch_id=24 Accuracy=86.32: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

 --> EPOCH: 20, Avg Training Loss: 0.3939, Avg Time Taken = 428.76ms






Test set: Average loss: 0.7007, Accuracy: 7755/10000 (77.55%)

---------- prev = 0.4106309008598328 current = 0.3939350962638855 ---------
Epoch 22/100


Loss=0.3779963254928589 Batch_id=24 Accuracy=87.07: 100%|██████████| 25/25 [00:16<00:00,  1.48it/s]

 --> EPOCH: 21, Avg Training Loss: 0.3684, Avg Time Taken = 427.82ms






Test set: Average loss: 0.6848, Accuracy: 7793/10000 (77.93%)

Epoch 23/100


Loss=0.3876402974128723 Batch_id=24 Accuracy=88.04: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

 --> EPOCH: 22, Avg Training Loss: 0.3492, Avg Time Taken = 428.29ms






Test set: Average loss: 0.7240, Accuracy: 7709/10000 (77.09%)

---------- prev = 0.36843979835510254 current = 0.3491759085655212 ---------
Epoch 24/100


Loss=0.3711775541305542 Batch_id=24 Accuracy=88.47: 100%|██████████| 25/25 [00:17<00:00,  1.47it/s]

 --> EPOCH: 23, Avg Training Loss: 0.3356, Avg Time Taken = 428.88ms






Test set: Average loss: 0.6859, Accuracy: 7805/10000 (78.05%)

---------- prev = 0.3491759085655212 current = 0.33556448817253115 ---------
Epoch 25/100


Loss=0.32334133982658386 Batch_id=24 Accuracy=88.93: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 24, Avg Training Loss: 0.3224, Avg Time Taken = 429.51ms






Test set: Average loss: 0.7563, Accuracy: 7690/10000 (76.90%)

---------- prev = 0.33556448817253115 current = 0.3224374294281006 ---------
Epoch 26/100


Loss=0.3557409942150116 Batch_id=24 Accuracy=89.33: 100%|██████████| 25/25 [00:17<00:00,  1.47it/s]

 --> EPOCH: 25, Avg Training Loss: 0.3090, Avg Time Taken = 427.88ms






Test set: Average loss: 0.6971, Accuracy: 7838/10000 (78.38%)

---------- prev = 0.3224374294281006 current = 0.3089509832859039 ---------
Epoch 27/100


Loss=0.2627030909061432 Batch_id=24 Accuracy=90.14: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 26, Avg Training Loss: 0.2847, Avg Time Taken = 428.39ms






Test set: Average loss: 0.6992, Accuracy: 7767/10000 (77.67%)

Epoch 28/100


Loss=0.2679571807384491 Batch_id=24 Accuracy=90.61: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 27, Avg Training Loss: 0.2701, Avg Time Taken = 428.03ms






Test set: Average loss: 0.7986, Accuracy: 7638/10000 (76.38%)

---------- prev = 0.2846775817871094 current = 0.27013249933719635 ---------
Epoch 29/100


Loss=0.2750738561153412 Batch_id=24 Accuracy=91.07: 100%|██████████| 25/25 [00:16<00:00,  1.47it/s]

 --> EPOCH: 28, Avg Training Loss: 0.2627, Avg Time Taken = 431.52ms






Test set: Average loss: 0.6986, Accuracy: 7880/10000 (78.80%)

---------- prev = 0.27013249933719635 current = 0.26270510017871856 ---------
Epoch 30/100


Loss=0.24570909142494202 Batch_id=24 Accuracy=91.66: 100%|██████████| 25/25 [00:17<00:00,  1.47it/s]

 --> EPOCH: 29, Avg Training Loss: 0.2451, Avg Time Taken = 429.83ms






Test set: Average loss: 0.7485, Accuracy: 7804/10000 (78.04%)

---------- prev = 0.26270510017871856 current = 0.24509883403778077 ---------
Epoch 31/100


Loss=0.2544326186180115 Batch_id=24 Accuracy=91.71: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 30, Avg Training Loss: 0.2412, Avg Time Taken = 428.61ms






Test set: Average loss: 0.7455, Accuracy: 7845/10000 (78.45%)

---------- prev = 0.24509883403778077 current = 0.2411676061153412 ---------
Epoch 32/100


Loss=0.21317756175994873 Batch_id=24 Accuracy=92.10: 100%|██████████| 25/25 [00:17<00:00,  1.39it/s]

 --> EPOCH: 31, Avg Training Loss: 0.2304, Avg Time Taken = 429.03ms






Test set: Average loss: 0.7414, Accuracy: 7891/10000 (78.91%)

---------- prev = 0.2411676061153412 current = 0.2303864675760269 ---------
Model saved at: /content/drive/MyDrive/EPAi_V5/model_heavy_acc_92.pth
Early stopping triggered!


In [6]:
model.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_heavy_acc_92.pth', weights_only=True))
test(model, device, test_loader)


Test set: Average loss: 0.7414, Accuracy: 7891/10000 (78.91%)



0.7414170166015625

In [7]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv01 = nn.Conv2d(3, 16, 3, bias=False, padding=1)
        self.batch01 = nn.BatchNorm2d(num_features=16)

        # ---- Lets take a skip connection
        self.skip_conv1 = nn.Conv2d(16, 16, 3, padding=0, dilation=2)

        self.conv02 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch02 = nn.BatchNorm2d(num_features=16)
        self.conv03 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch03 = nn.BatchNorm2d(num_features=16)
        self.conv04 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch04 = nn.BatchNorm2d(num_features=16)
        self.pool01 = nn.MaxPool2d(2, 2)                                #O=16
        self.conv05 = nn.Conv2d(16, 16, 1, bias=False)

        self.conv11 = nn.Conv2d(16, 32, 3, bias=False, padding=1)
        self.batch11 = nn.BatchNorm2d(num_features=32)
        self.conv12 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch12 = nn.BatchNorm2d(num_features=32)
        self.conv13 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch13 = nn.BatchNorm2d(num_features=32)
        self.conv14 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch14 = nn.BatchNorm2d(num_features=32)
        self.pool11 = nn.MaxPool2d(2, 2)                                #O=8
        self.conv15 = nn.Conv2d(32, 32, 1, bias=False)

        self.conv21 = nn.Conv2d(32, 64, 3, bias=False, padding=1)
        self.batch21 = nn.BatchNorm2d(num_features=64)
        self.conv22 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch22 = nn.BatchNorm2d(num_features=64)
        self.conv23 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch23 = nn.BatchNorm2d(num_features=64)
        self.conv24 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch24 = nn.BatchNorm2d(num_features=64)
        self.pool21 = nn.MaxPool2d(2, 2)                                #O=4
        self.conv25 = nn.Conv2d(64, 64, 1, bias=False)

        self.conv31 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, groups=64, bias = False, padding = 1)
        self.convPV1= nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, bias = False, padding = 0)
        self.batch31 = nn.BatchNorm2d(num_features=128)
        self.conv32 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV2= nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, bias = False, padding = 0)
        self.batch32 = nn.BatchNorm2d(num_features=256)


        self.avg_pool = nn.AvgPool2d(kernel_size=4)
        self.convx3 = nn.Conv2d(256, 10, 1, bias=False, padding=0)

    def forward(self, x):
        x = self.batch01(F.relu(self.conv01(x)))

        # ---- Lets take a skip connection
        skip_channels = self.skip_conv1(self.skip_conv1(self.skip_conv1(self.skip_conv1(x))))

        x = self.batch02(F.relu(self.conv02(x)))
        x = self.batch03(F.relu(self.conv03(x)))
        x = self.batch04(F.relu(self.conv04(x)))
        x = self.pool01(x)
        x = self.conv05(x)
        # ----------------------------------------------------------

        # ---- Lets add the skip connection here
        x = skip_channels + x

        x = self.batch11(F.relu(self.conv11(x)))
        x = self.batch12(F.relu(self.conv12(x)))
        x = self.batch13(F.relu(self.conv13(x)))
        x = self.batch14(F.relu(self.conv14(x)))
        x = self.pool11(x)
        x = self.conv15(x)
        # ----------------------------------------------------------

        x = self.batch21(F.relu(self.conv21(x)))
        x = self.batch22(F.relu(self.conv22(x)))
        x = self.batch23(F.relu(self.conv23(x)))
        x = self.batch24(F.relu(self.conv24(x)))
        x = self.pool21(x)
        x = self.conv25(x)
        # ----------------------------------------------------------

        x = self.batch31(F.relu(self.convPV1(F.relu(self.conv31(x)))))
        x = self.batch32(F.relu(self.convPV2(F.relu(self.conv32(x)))))


        x = self.avg_pool(x)
        x = self.convx3(x)
        x = x.view(-1, 10)                           # Don't want 10x1x1..
        return F.log_softmax(x, dim=1)  # Added dim=1 parameter)

In [8]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met



def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)

    correct = 0
    processed = 0
    epoch_loss = 0
    time_taken.clear()

    for batch_idx, (data, target) in enumerate(pbar):
        t0 = time.time()

        data, target = data.to(device), target.to(device)

        # Don't want history of gradients
        optimizer.zero_grad()

        y_predict = model(data)

        # Calculate loss
        loss = F.nll_loss(y_predict, target)
        epoch_loss += loss.item()

        # Backpropagate error
        loss.backward()

        # Take an optimizer step
        optimizer.step()

        torch.cuda.synchronize()
        t1 = time.time()

        time_taken.append((t1 - t0))

        pred = y_predict.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')
        train_acc.append(100 * correct / processed)

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    return avg_train_loss


def test(model, device, test_loader):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc.append(100. * correct / len(test_loader.dataset))
    return test_loss

In [9]:
# Initialize model, optimizer, and early stopping
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (f'Device Using = {device}')
model = StudentModel().to(device)
summary(model, input_size=(3, 32, 32))
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
early_stopping = EarlyStopping(tolerance=5, min_delta=0.02)

EPOCHS = 100
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    avg_train_loss = train(model, device, train_loader, optimizer, epoch)
    print(f" --> EPOCH: {epoch}, Avg Training Loss: {avg_train_loss:.4f}, Avg Time Taken = {(sum(time_taken) / len(time_taken)) * 1000:.2f}ms")
    val_loss = test(model, device, test_loader)

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_small_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_small_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

Device Using = cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 28, 28]           2,320
            Conv2d-4           [-1, 16, 24, 24]           2,320
            Conv2d-5           [-1, 16, 20, 20]           2,320
            Conv2d-6           [-1, 16, 16, 16]           2,320
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
            Conv2d-9           [-1, 16, 32, 32]           2,304
      BatchNorm2d-10           [-1, 16, 32, 32]              32
           Conv2d-11           [-1, 16, 32, 32]           2,304
      BatchNorm2d-12           [-1, 16, 32, 32]              32
        MaxPool2d-13           [-1, 16, 16, 16]               0
           Conv2d-1

Loss=1.839409351348877 Batch_id=24 Accuracy=19.72: 100%|██████████| 25/25 [00:16<00:00,  1.55it/s]

 --> EPOCH: 0, Avg Training Loss: 2.1232, Avg Time Taken = 342.92ms






Test set: Average loss: 2.2944, Accuracy: 998/10000 (9.98%)

Epoch 2/100


Loss=1.6231263875961304 Batch_id=24 Accuracy=36.33: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 1, Avg Training Loss: 1.6994, Avg Time Taken = 341.14ms






Test set: Average loss: 1.6074, Accuracy: 4080/10000 (40.80%)

Epoch 3/100


Loss=1.4510281085968018 Batch_id=24 Accuracy=45.22: 100%|██████████| 25/25 [00:16<00:00,  1.48it/s]

 --> EPOCH: 2, Avg Training Loss: 1.4865, Avg Time Taken = 333.43ms






Test set: Average loss: 1.5061, Accuracy: 4478/10000 (44.78%)

Epoch 4/100


Loss=1.2703484296798706 Batch_id=24 Accuracy=50.72: 100%|██████████| 25/25 [00:16<00:00,  1.53it/s]

 --> EPOCH: 3, Avg Training Loss: 1.3461, Avg Time Taken = 332.31ms






Test set: Average loss: 1.4149, Accuracy: 4999/10000 (49.99%)

Epoch 5/100


Loss=1.0891814231872559 Batch_id=24 Accuracy=54.58: 100%|██████████| 25/25 [00:16<00:00,  1.52it/s]

 --> EPOCH: 4, Avg Training Loss: 1.2475, Avg Time Taken = 336.23ms






Test set: Average loss: 1.2892, Accuracy: 5345/10000 (53.45%)

Epoch 6/100


Loss=1.1289831399917603 Batch_id=24 Accuracy=58.19: 100%|██████████| 25/25 [00:16<00:00,  1.52it/s]

 --> EPOCH: 5, Avg Training Loss: 1.1631, Avg Time Taken = 338.79ms






Test set: Average loss: 1.2490, Accuracy: 5582/10000 (55.82%)

Epoch 7/100


Loss=1.046385407447815 Batch_id=24 Accuracy=60.31: 100%|██████████| 25/25 [00:16<00:00,  1.52it/s]

 --> EPOCH: 6, Avg Training Loss: 1.1034, Avg Time Taken = 335.35ms






Test set: Average loss: 1.1784, Accuracy: 5824/10000 (58.24%)

Epoch 8/100


Loss=1.0095020532608032 Batch_id=24 Accuracy=62.77: 100%|██████████| 25/25 [00:16<00:00,  1.53it/s]

 --> EPOCH: 7, Avg Training Loss: 1.0367, Avg Time Taken = 334.90ms






Test set: Average loss: 1.0438, Accuracy: 6294/10000 (62.94%)

Epoch 9/100


Loss=0.9633558988571167 Batch_id=24 Accuracy=65.29: 100%|██████████| 25/25 [00:16<00:00,  1.51it/s]

 --> EPOCH: 8, Avg Training Loss: 0.9727, Avg Time Taken = 336.57ms






Test set: Average loss: 1.0176, Accuracy: 6326/10000 (63.26%)

Epoch 10/100


Loss=0.9215368628501892 Batch_id=24 Accuracy=66.87: 100%|██████████| 25/25 [00:16<00:00,  1.54it/s]

 --> EPOCH: 9, Avg Training Loss: 0.9314, Avg Time Taken = 336.92ms






Test set: Average loss: 0.9417, Accuracy: 6658/10000 (66.58%)

Epoch 11/100


Loss=0.8912086486816406 Batch_id=24 Accuracy=68.56: 100%|██████████| 25/25 [00:16<00:00,  1.49it/s]

 --> EPOCH: 10, Avg Training Loss: 0.8809, Avg Time Taken = 339.47ms






Test set: Average loss: 0.9288, Accuracy: 6737/10000 (67.37%)

Epoch 12/100


Loss=0.8185917735099792 Batch_id=24 Accuracy=70.43: 100%|██████████| 25/25 [00:16<00:00,  1.53it/s]

 --> EPOCH: 11, Avg Training Loss: 0.8329, Avg Time Taken = 336.72ms






Test set: Average loss: 0.9871, Accuracy: 6612/10000 (66.12%)

Epoch 13/100


Loss=0.8015299439430237 Batch_id=24 Accuracy=71.58: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

 --> EPOCH: 12, Avg Training Loss: 0.8017, Avg Time Taken = 335.48ms






Test set: Average loss: 0.8578, Accuracy: 6995/10000 (69.95%)

Epoch 14/100


Loss=0.7465499043464661 Batch_id=24 Accuracy=72.99: 100%|██████████| 25/25 [00:16<00:00,  1.55it/s]

 --> EPOCH: 13, Avg Training Loss: 0.7707, Avg Time Taken = 336.59ms






Test set: Average loss: 0.8662, Accuracy: 7012/10000 (70.12%)

Epoch 15/100


Loss=0.7175785899162292 Batch_id=24 Accuracy=73.92: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 14, Avg Training Loss: 0.7418, Avg Time Taken = 336.46ms






Test set: Average loss: 0.8718, Accuracy: 7024/10000 (70.24%)

Epoch 16/100


Loss=0.7081397175788879 Batch_id=24 Accuracy=74.78: 100%|██████████| 25/25 [00:16<00:00,  1.55it/s]

 --> EPOCH: 15, Avg Training Loss: 0.7179, Avg Time Taken = 336.56ms






Test set: Average loss: 0.8439, Accuracy: 7085/10000 (70.85%)

Epoch 17/100


Loss=0.7096620798110962 Batch_id=24 Accuracy=75.60: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

 --> EPOCH: 16, Avg Training Loss: 0.6938, Avg Time Taken = 334.90ms






Test set: Average loss: 0.7778, Accuracy: 7344/10000 (73.44%)

Epoch 18/100


Loss=0.7124909162521362 Batch_id=24 Accuracy=76.71: 100%|██████████| 25/25 [00:16<00:00,  1.55it/s]

 --> EPOCH: 17, Avg Training Loss: 0.6645, Avg Time Taken = 335.87ms






Test set: Average loss: 0.8050, Accuracy: 7238/10000 (72.38%)

Epoch 19/100


Loss=0.6205840110778809 Batch_id=24 Accuracy=77.70: 100%|██████████| 25/25 [00:16<00:00,  1.51it/s]

 --> EPOCH: 18, Avg Training Loss: 0.6393, Avg Time Taken = 335.28ms






Test set: Average loss: 0.7403, Accuracy: 7448/10000 (74.48%)

Epoch 20/100


Loss=0.5964459180831909 Batch_id=24 Accuracy=78.29: 100%|██████████| 25/25 [00:16<00:00,  1.55it/s]

 --> EPOCH: 19, Avg Training Loss: 0.6189, Avg Time Taken = 336.73ms






Test set: Average loss: 0.7596, Accuracy: 7387/10000 (73.87%)

Epoch 21/100


Loss=0.610698938369751 Batch_id=24 Accuracy=79.03: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 20, Avg Training Loss: 0.5962, Avg Time Taken = 334.92ms






Test set: Average loss: 0.7527, Accuracy: 7440/10000 (74.40%)

Epoch 22/100


Loss=0.5598589181900024 Batch_id=24 Accuracy=79.42: 100%|██████████| 25/25 [00:16<00:00,  1.54it/s]

 --> EPOCH: 21, Avg Training Loss: 0.5833, Avg Time Taken = 333.95ms






Test set: Average loss: 0.7510, Accuracy: 7465/10000 (74.65%)

---------- prev = 0.596239652633667 current = 0.5832714319229126 ---------
Epoch 23/100


Loss=0.5501305460929871 Batch_id=24 Accuracy=80.18: 100%|██████████| 25/25 [00:16<00:00,  1.53it/s]

 --> EPOCH: 22, Avg Training Loss: 0.5660, Avg Time Taken = 335.76ms






Test set: Average loss: 0.7475, Accuracy: 7447/10000 (74.47%)

---------- prev = 0.5832714319229126 current = 0.5659991955757141 ---------
Epoch 24/100


Loss=0.5897893309593201 Batch_id=24 Accuracy=80.50: 100%|██████████| 25/25 [00:16<00:00,  1.55it/s]

 --> EPOCH: 23, Avg Training Loss: 0.5526, Avg Time Taken = 338.42ms






Test set: Average loss: 0.7181, Accuracy: 7549/10000 (75.49%)

---------- prev = 0.5659991955757141 current = 0.5526035451889038 ---------
Epoch 25/100


Loss=0.5340969562530518 Batch_id=24 Accuracy=81.42: 100%|██████████| 25/25 [00:16<00:00,  1.54it/s]

 --> EPOCH: 24, Avg Training Loss: 0.5336, Avg Time Taken = 336.70ms






Test set: Average loss: 0.7186, Accuracy: 7558/10000 (75.58%)

---------- prev = 0.5526035451889038 current = 0.5336136102676392 ---------
Epoch 26/100


Loss=0.5731205344200134 Batch_id=24 Accuracy=81.76: 100%|██████████| 25/25 [00:16<00:00,  1.53it/s]

 --> EPOCH: 25, Avg Training Loss: 0.5237, Avg Time Taken = 334.28ms






Test set: Average loss: 0.7493, Accuracy: 7525/10000 (75.25%)

---------- prev = 0.5336136102676392 current = 0.5237028408050537 ---------
Model saved at: /content/drive/MyDrive/EPAi_V5/model_small_acc_81.pth
Early stopping triggered!


In [10]:
model.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_small_acc_81.pth', weights_only=True))
test(model, device, test_loader)


Test set: Average loss: 0.7493, Accuracy: 7525/10000 (75.25%)



0.74932431640625

# Distillation Process

In [11]:
class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device = {device}')

# Prepare teacher model
teacher = TeacherModel().to(device)
teacher.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_heavy_acc_92.pth', weights_only=True))
teacher.eval()

# Prepare student model
student = StudentModel().to(device)
student.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_small_acc_81.pth', weights_only=True))

print("=============================================")
print("Student model accuracy before training ")
student.eval()
test(student, device, test_loader)
print("=============================================")

# Loss functions
hard_loss = nn.CrossEntropyLoss() #Hard label loss
soft_loss = nn.KLDivLoss(reduction="batchmean")  # Distillation loss

# Temperature and alpha
T = 5.0  # Temperature
alpha = 0.5  # Weight for distillation loss

# Optimizer
optimizer = optim.Adam(student.parameters(), lr=0.001)

# Training loop
EPOCHS = 100
for epoch in range(EPOCHS):

    correct = 0
    processed = 0
    epoch_loss = 0
    pbar = tqdm(train_loader)
    student.train()
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # Teacher predictions (soft labels)
        with torch.no_grad():
            teacher_logits = teacher(data) / T

        teacher_probs = torch.softmax(teacher_logits, dim=1)

        # Student predictions
        student_logits = student(data)
        student_probs = torch.log_softmax(student_logits / T, dim=1)

        # Compute losses
        loss_soft = soft_loss(student_probs, teacher_probs) * (T ** 2)  # Scale by T^2
        loss_hard = hard_loss(student_logits, target)
        loss = alpha * loss_hard + (1 - alpha) * loss_soft

        epoch_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()

        pred = student_logits.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    train_acc.append(100 * correct / processed) # Append to train_acc here

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_distil_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_small_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

    student.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = student(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))



Using device = cuda
Student model accuracy before training 

Test set: Average loss: 0.7493, Accuracy: 7525/10000 (75.25%)



Loss=1.246791958808899 Batch_id=24 Accuracy=76.08: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]



Test set: Average loss: 0.8613, Accuracy: 7480/10000 (74.80%)



Loss=1.036960482597351 Batch_id=24 Accuracy=79.96: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]



Test set: Average loss: 0.7897, Accuracy: 7633/10000 (76.33%)



Loss=1.0559091567993164 Batch_id=24 Accuracy=80.92: 100%|██████████| 25/25 [00:17<00:00,  1.47it/s]



Test set: Average loss: 0.7581, Accuracy: 7699/10000 (76.99%)



Loss=1.017242670059204 Batch_id=24 Accuracy=81.65: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]



Test set: Average loss: 0.7319, Accuracy: 7787/10000 (77.87%)



Loss=0.9263374209403992 Batch_id=24 Accuracy=82.10: 100%|██████████| 25/25 [00:16<00:00,  1.47it/s]



Test set: Average loss: 0.7310, Accuracy: 7857/10000 (78.57%)



Loss=1.0241096019744873 Batch_id=24 Accuracy=82.66: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]

---------- prev = 0.9687244987487793 current = 0.9527503371238708 ---------






Test set: Average loss: 0.8039, Accuracy: 7691/10000 (76.91%)



Loss=0.8990921974182129 Batch_id=24 Accuracy=82.76: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]



Test set: Average loss: 0.7652, Accuracy: 7733/10000 (77.33%)



Loss=0.8882584571838379 Batch_id=24 Accuracy=83.16: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

---------- prev = 0.9171848630905152 current = 0.8985904693603516 ---------






Test set: Average loss: 0.7453, Accuracy: 7769/10000 (77.69%)



Loss=0.9041925668716431 Batch_id=24 Accuracy=83.76: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]



Test set: Average loss: 0.7226, Accuracy: 7857/10000 (78.57%)



Loss=0.8615862131118774 Batch_id=24 Accuracy=83.87: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

---------- prev = 0.8707234072685242 current = 0.8623436975479126 ---------






Test set: Average loss: 0.7005, Accuracy: 7893/10000 (78.93%)



Loss=0.794651448726654 Batch_id=24 Accuracy=84.48: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]



Test set: Average loss: 0.7302, Accuracy: 7843/10000 (78.43%)



Loss=0.8175041675567627 Batch_id=24 Accuracy=84.57: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]

---------- prev = 0.8261134171485901 current = 0.8205741333961487 ---------






Test set: Average loss: 0.6839, Accuracy: 7917/10000 (79.17%)



Loss=0.7790884971618652 Batch_id=24 Accuracy=84.96: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]



Test set: Average loss: 0.7128, Accuracy: 7889/10000 (78.89%)



Loss=0.7970046997070312 Batch_id=24 Accuracy=85.23: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

---------- prev = 0.8001772952079773 current = 0.7927077388763428 ---------






Test set: Average loss: 0.7171, Accuracy: 7871/10000 (78.71%)



Loss=0.8118811845779419 Batch_id=24 Accuracy=85.82: 100%|██████████| 25/25 [00:17<00:00,  1.43it/s]



Test set: Average loss: 0.7101, Accuracy: 7901/10000 (79.01%)



Loss=0.7828143835067749 Batch_id=24 Accuracy=85.89: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

---------- prev = 0.7703193974494934 current = 0.7603784418106079 ---------






Test set: Average loss: 0.7035, Accuracy: 7905/10000 (79.05%)



Loss=0.8103650212287903 Batch_id=24 Accuracy=85.86: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

---------- prev = 0.7603784418106079 current = 0.7619184136390686 ---------






Test set: Average loss: 0.7008, Accuracy: 7956/10000 (79.56%)



Loss=0.7591429352760315 Batch_id=24 Accuracy=86.19: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

---------- prev = 0.7619184136390686 current = 0.7482853484153748 ---------






Test set: Average loss: 0.6687, Accuracy: 7981/10000 (79.81%)



Loss=0.7329272031784058 Batch_id=24 Accuracy=86.72: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]



Test set: Average loss: 0.6767, Accuracy: 7984/10000 (79.84%)



Loss=0.7703170776367188 Batch_id=24 Accuracy=86.99: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

---------- prev = 0.723474714756012 current = 0.7126913523674011 ---------






Test set: Average loss: 0.6846, Accuracy: 7964/10000 (79.64%)



Loss=0.7100624442100525 Batch_id=24 Accuracy=86.98: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]

---------- prev = 0.7126913523674011 current = 0.7117586159706115 ---------






Test set: Average loss: 0.6718, Accuracy: 7972/10000 (79.72%)



Loss=0.726957380771637 Batch_id=24 Accuracy=87.49: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

---------- prev = 0.7117586159706115 current = 0.6972872185707092 ---------






Test set: Average loss: 0.7240, Accuracy: 7880/10000 (78.80%)



Loss=0.6941761374473572 Batch_id=24 Accuracy=87.25: 100%|██████████| 25/25 [00:17<00:00,  1.43it/s]

---------- prev = 0.6972872185707092 current = 0.6909924077987671 ---------






Test set: Average loss: 0.6660, Accuracy: 7970/10000 (79.70%)



Loss=0.7098481059074402 Batch_id=24 Accuracy=87.79: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

---------- prev = 0.6909924077987671 current = 0.6800520181655884 ---------
Model saved at: /content/drive/MyDrive/EPAi_V5/model_distil_acc_87.pth
Early stopping triggered!



