In [29]:
import torch
import torch.nn as nn
from torchvision import datasets,transforms
import torch.optim as optim


In [30]:
# load dataset
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=data_transforms)
test_dataset = datasets.CIFAR10(root='./data',train=False,download=True,transform=data_transforms)

Files already downloaded and verified
Files already downloaded and verified


In [31]:
#creating the dataloaders for training and test
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=32,shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=32,shuffle=False)


In [32]:
#creating the teacher model
class Teacher_model(nn.Module):
    def __init__(self):
        super(Teacher_model,self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3,128,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(128,64,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.Conv2d(64,64,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(64,32,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2)
        )
        self.classfier = nn.Sequential(
            nn.Linear(2048,512),
            nn.ReLU(),
            nn.Linear(512,10)
        )

    def forward(self,x):
        x=self.layers(x)
        x=torch.flatten(x,1)
        x = self.classfier(x)
        return x



In [52]:
class student_model(nn.Module):
    def __init__(self):
        super(student_model,self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16,16,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2),  
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024,256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256,10)
        )

    def forward(self,x):
        x= self.layers(x)
        x = torch.flatten(x,1)
        x = self.classifier(x)
        return x


In [36]:
#training the model
epochs = 10
learning_rate = 0.001

def train(model,train_dataloader,epochs,learning_rate,device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),lr = learning_rate)
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs,labels in train_dataloader:
            inputs,labels = inputs.to(device),labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs,labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'epoch {epoch+1}/{epochs},loss={running_loss/len(train_dataloader)}')


def test(model,test_dataloader,device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [53]:
torch.manual_seed(42)
teacher = Teacher_model()
student = student_model()



In [None]:
train(teacher, train_dataloader, epochs=10, learning_rate=0.001, device=device)

In [37]:
test_accuracy_deep = test(teacher, test_dataloader, device)

Test Accuracy: 73.06%


In [54]:
train(student, train_dataloader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(student, test_dataloader, device)

epoch 1/10,loss=1.3500913418185916
epoch 2/10,loss=1.010557404025121
epoch 3/10,loss=0.873339850953658
epoch 4/10,loss=0.7763097844517391
epoch 5/10,loss=0.6941467424805776
epoch 6/10,loss=0.6254208700174867
epoch 7/10,loss=0.5580854325037466
epoch 8/10,loss=0.4969942342592445
epoch 9/10,loss=0.4504838226052026
epoch 10/10,loss=0.40440677486298104
Test Accuracy: 69.26%


In [55]:
print(f'test_accuracy for teacher :{test_accuracy_deep:.2f}%')
print(f'test accuracy fot student :{test_accuracy_light_ce:.2f}%')

test_accuracy for teacher :73.06%
test accuracy fot student :69.26%


In [None]:
# to implement the teacher-student network we need few more parameters to be incorporated.
#t= temperature (it controls how smooth the output is distributed)
#soft_traget_loss_weight = A weight assigned to extra objective we're about to include
#ce_loss_weight = A weight assigned to cross entropy

In [58]:
def train_kd(teacher,student,train_dataloader,T,epochs,learning_rate,device,soft_target_loss_weight,ce_loss_weight):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(),lr=learning_rate)
    teacher.eval()
    student.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs,labels in train_dataloader:
            inputs,labels = inputs.to(device),labels.to(device)
            optimizer.zero_grad()

            #forward method with teacher model
            with torch.no_grad(): #we donot need to save gradients of teacher
                teacher_logits = teacher(inputs)
            student_logits = student(inputs)
            #softening the distributions
            soft_targets = nn.functional.softmax(teacher_logits/T,dim=1)
            soft_prob = nn.functional.log_softmax(student_logits/T, dim= -1)
            #distillation loss
            distillation_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob) / soft_prob.size()[0]*(T**2))
            #true label loss
            label_loss = ce_loss(student_logits,labels)

            #weighted sum of two losses
            loss = distillation_loss * soft_target_loss_weight + ce_loss_weight*label_loss
            loss.backward()
            optimizer.step()

            running_loss +=loss.item()

        print(f'epoch {epoch+1}/{epochs},loss = {running_loss / len(train_dataloader)}')

train_kd(teacher,student,train_dataloader,T=2,epochs=10,learning_rate=0.001,device=device,soft_target_loss_weight=0.25,ce_loss_weight=0.75)
test_accuracy_light_ce_kd = test(student,test_dataloader,device)
            

epoch 1/10,loss = 0.919714184045334
epoch 2/10,loss = 0.8160603245862081
epoch 3/10,loss = 0.7886000238239803
epoch 4/10,loss = 0.7687715685146402
epoch 5/10,loss = 0.7494939404729842
epoch 6/10,loss = 0.7359033011886758
epoch 7/10,loss = 0.720445447828399
epoch 8/10,loss = 0.7077185597010934
epoch 9/10,loss = 0.6964009245541792
epoch 10/10,loss = 0.6875427604408044
Test Accuracy: 69.37%


In [59]:
print(f'test_accuracy for student model with kd {test_accuracy_light_ce_kd :.2f}% ')

test_accuracy for student model with kd 69.37% 


# fine tuning resnet 50

In [71]:
from torchvision import models
from tqdm import tqdm

In [62]:
resnet_model = models.resnet50(pretrained = True,progress=True)

#replace final fully connected layer for 10 output neurons

num_features = resnet_model.fc.in_features
num_classes = 10
resnet_model.fc = nn.Linear(num_features,num_classes)
model = resnet_model.to(device)




In [75]:
from tqdm import tqdm  # Ensure tqdm is imported

def finetune_resnet(resnet_model, train_dataloader, epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(resnet_model.parameters(), lr=0.001)
    resnet_model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0

        # tqdm for the inner loop to show progress within each epoch
        for images, labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = resnet_model(images)
            loss = criterion(outputs, labels)
            
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        # Print the average loss for this epoch
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_dataloader):.4f}')


In [73]:
def test_resnet(resnet_model,test_dataloader):
    resnet_model.eval()
    correct=0
    total=0
    
    with torch.no_grad():
        for images,labels in test_dataloader:
            images,labels = images.to(device),labels.to(device)
            #forward pass
            outputs = model(images)
            _,predicted = torch.max(outputs.data,1)
            total +=labels.size(0)
            correct += (predicted==labels).sum().item()
    
    print(f'accuracy of resnet model on test data {100 *correct/total :.2f}%')
    
    

In [76]:
finetune_resnet(resnet_model,train_dataloader,epochs=10)
test_resnet(resnet_model,test_dataloader)

Epoch 1/10: 100%|██████████| 1563/1563 [17:42<00:00,  1.47it/s]


Epoch [1/10], Loss: 1.2167


Epoch 2/10: 100%|██████████| 1563/1563 [18:41<00:00,  1.39it/s]


Epoch [2/10], Loss: 0.8796


Epoch 3/10: 100%|██████████| 1563/1563 [19:03<00:00,  1.37it/s]


Epoch [3/10], Loss: 0.7548


Epoch 4/10: 100%|██████████| 1563/1563 [19:19<00:00,  1.35it/s]


Epoch [4/10], Loss: 0.6470


Epoch 5/10: 100%|██████████| 1563/1563 [19:25<00:00,  1.34it/s]


Epoch [5/10], Loss: 0.5423


Epoch 6/10: 100%|██████████| 1563/1563 [19:38<00:00,  1.33it/s]


Epoch [6/10], Loss: 0.5250


Epoch 7/10: 100%|██████████| 1563/1563 [19:49<00:00,  1.31it/s]


Epoch [7/10], Loss: 0.4156


Epoch 8/10: 100%|██████████| 1563/1563 [19:55<00:00,  1.31it/s]


Epoch [8/10], Loss: 0.3376


Epoch 9/10: 100%|██████████| 1563/1563 [19:56<00:00,  1.31it/s]


Epoch [9/10], Loss: 0.2734


Epoch 10/10: 100%|██████████| 1563/1563 [20:09<00:00,  1.29it/s]


Epoch [10/10], Loss: 0.2703
accuracy of resnet model on test data 79.25%


# using resnet50 as teacher model

In [80]:
def train_resnet_cnn(teacher=resnet_model,student=student,train_dataloader=train_dataloader,T=2,epochs=10,learning_rate=0.001,device=device,soft_target_loss_weight=0.25,ce_loss_weight=0.75):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(),lr=learning_rate)
    teacher.eval()
    student.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs,labels = inputs.to(device),labels.to(device)
            optimizer.zero_grad()

            #forward method with teacher model
            with torch.no_grad(): #we donot need to save gradients of teacher
                teacher_logits = teacher(inputs)
            student_logits = student(inputs)
            #softening the distributions
            soft_targets = nn.functional.softmax(teacher_logits/T,dim=1)
            soft_prob = nn.functional.log_softmax(student_logits/T, dim= -1)
            #distillation loss
            distillation_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob) / soft_prob.size()[0]*(T**2))
            #true label loss
            label_loss = ce_loss(student_logits,labels)

            #weighted sum of two losses
            loss = distillation_loss * soft_target_loss_weight + ce_loss_weight*label_loss
            loss.backward()
            optimizer.step()

            running_loss +=loss.item()

        print(f'epoch {epoch+1}/{epochs},loss = {running_loss / len(train_dataloader)}')

train_resnet_cnn(teacher,student,train_dataloader,T=2,epochs=10,learning_rate=0.001,device=device,soft_target_loss_weight=0.25,ce_loss_weight=0.75)
test_accuracy_light_ce_kd = test(student,test_dataloader,device)
            

Epoch 1/10:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 1/10: 100%|██████████| 1563/1563 [03:06<00:00,  8.37it/s]


epoch 1/10,loss = 0.6793455219162021


Epoch 2/10: 100%|██████████| 1563/1563 [03:05<00:00,  8.43it/s]


epoch 2/10,loss = 0.6676746448948836


Epoch 3/10: 100%|██████████| 1563/1563 [03:04<00:00,  8.46it/s]


epoch 3/10,loss = 0.6594219421699729


Epoch 4/10: 100%|██████████| 1563/1563 [03:04<00:00,  8.47it/s]


epoch 4/10,loss = 0.653477133655121


Epoch 5/10: 100%|██████████| 1563/1563 [03:05<00:00,  8.43it/s]


epoch 5/10,loss = 0.648380565742461


Epoch 6/10: 100%|██████████| 1563/1563 [03:04<00:00,  8.45it/s]


epoch 6/10,loss = 0.6424061305921046


Epoch 7/10: 100%|██████████| 1563/1563 [03:04<00:00,  8.45it/s]


epoch 7/10,loss = 0.6374071494021327


Epoch 8/10: 100%|██████████| 1563/1563 [03:05<00:00,  8.44it/s]


epoch 8/10,loss = 0.6329485592518719


Epoch 9/10: 100%|██████████| 1563/1563 [03:04<00:00,  8.46it/s]


epoch 9/10,loss = 0.6283573880076637


Epoch 10/10: 100%|██████████| 1563/1563 [03:04<00:00,  8.47it/s]


epoch 10/10,loss = 0.6240302789539233
Test Accuracy: 67.61%


In [85]:
def train_2teacher_model(teacher1,teacher2,student,train_dataloader,T=2,epochs=10,learning_rate=0.001,device=device,soft_target_loss_weight=0.5,ce_loss_weight=0.5):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(),lr=learning_rate)
    teacher1.eval()
    teacher2.eval()
    student.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs,labels = inputs.to(device),labels.to(device)
            optimizer.zero_grad()

            #forward method with teacher model
            with torch.no_grad(): #we donot need to save gradients of teacher
                teacher1_logits = resnet_model(inputs)
                teacher2_logits = teacher(inputs)
            student_logits = student(inputs)
            #softening the distributions
            soft_targets1 = nn.functional.softmax(teacher1_logits/T,dim=1)
            soft_targets2 = nn.functional.softmax(teacher2_logits/T,dim=1)
            soft_prob = nn.functional.log_softmax(student_logits/T, dim= -1)
            soft_targets = (soft_targets1 + soft_targets2)/2
            #distillation loss
            distillation_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob) / soft_prob.size()[0]*(T**2))
            #true label loss
            label_loss = ce_loss(student_logits,labels)

            #weighted sum of two losses
            loss = distillation_loss * soft_target_loss_weight + ce_loss_weight*label_loss
            loss.backward()
            optimizer.step()

            running_loss +=loss.item()

        print(f'epoch {epoch+1}/{epochs},loss = {running_loss / len(train_dataloader)}')

train_2teacher_model(teacher1=resnet_model,teacher2=teacher,student=student,train_dataloader=train_dataloader,T=2,epochs=10,learning_rate=0.001,device=device,soft_target_loss_weight=0.25,ce_loss_weight=0.75)
test_accuracy_light_ce_kd = test(student,test_dataloader,device)
            

Epoch 1/10: 100%|██████████| 1563/1563 [07:21<00:00,  3.54it/s]


epoch 1/10,loss = 0.3631158641379229


Epoch 2/10: 100%|██████████| 1563/1563 [07:03<00:00,  3.69it/s]


epoch 2/10,loss = 0.35228893229462593


Epoch 3/10: 100%|██████████| 1563/1563 [07:02<00:00,  3.70it/s]


epoch 3/10,loss = 0.3474642373740635


Epoch 4/10: 100%|██████████| 1563/1563 [07:06<00:00,  3.66it/s]


epoch 4/10,loss = 0.344817390702355


Epoch 5/10: 100%|██████████| 1563/1563 [07:26<00:00,  3.50it/s]


epoch 5/10,loss = 0.33911166408278587


Epoch 6/10: 100%|██████████| 1563/1563 [12:52<00:00,  2.02it/s]


epoch 6/10,loss = 0.3359857910120251


Epoch 7/10: 100%|██████████| 1563/1563 [07:12<00:00,  3.61it/s]


epoch 7/10,loss = 0.33080340293608484


Epoch 8/10: 100%|██████████| 1563/1563 [06:53<00:00,  3.78it/s]


epoch 8/10,loss = 0.33205147951326536


Epoch 9/10: 100%|██████████| 1563/1563 [06:54<00:00,  3.77it/s]


epoch 9/10,loss = 0.3265010122965332


Epoch 10/10: 100%|██████████| 1563/1563 [06:55<00:00,  3.76it/s]


epoch 10/10,loss = 0.32474255322532936
Test Accuracy: 67.66%


In [89]:
student.parameters

<bound method Module.parameters of student_model(
  (layers): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=10, bias=True)
  )
)>