In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from models_package.models import Teacher, Student
import time
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import s3fs
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# new libraries
from data.data_loader import load_cifar10, load_cifar100, load_imagenet, load_prof
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights
from utils.loss_functions import tkd_kdloss
from torch.utils.data import Subset


import getpass
import os
from PIL import Image
import tarfile
import json

In [2]:
# access_key = getpass.getpass("Enter your access: ")

# secret_key = password = getpass.getpass("Enter your secret: ")

In [3]:
# # Run once to get images onto EC2

# wider_dir = './WIDER'
# if not os.path.exists(wider_dir):
#     os.makedirs(wider_dir)

# # Specify your S3 bucket and file path
# bucket_name = '210bucket'
# s3_file_path = 'wider_attribute_image.tgz'

# # Initialize an S3 filesystem
# s3 = s3fs.S3FileSystem(key=access_key, secret=secret_key)

# # Download the .tgz file from S3
# with s3.open(f"{bucket_name}/{s3_file_path}", 'rb') as s3_file:
#     with tarfile.open(fileobj=s3_file, mode="r:gz") as tar:
#         # Specify the destination directory where you want to store the extracted contents
#         extract_dir = wider_dir # Change this to your desired directory
#         tar.extractall(path=extract_dir)

# print("File downloaded and extracted successfully.")


In [4]:
# # Specify your S3 bucket and directory path
# s3_directory_path = 'wider_attribute_annotation/'

# local_directory = './WIDER/Annotations'  # Change this to your desired directory

# s3_files = s3.ls(f"{bucket_name}/{s3_directory_path}")


# # Create the local directory if it doesn't exist
# os.makedirs(local_directory, exist_ok=True)

# # Download each file from the S3 directory to the local directory
# for s3_file in s3_files:
#     # Get the filename from the S3 file path
#     filename = os.path.basename(s3_file)
    
#     # Download the file to the local directory
#     local_path = os.path.join(local_directory, filename)
#     with s3.open(s3_file, 'rb') as s3_file_obj:
#         with open(local_path, 'wb') as local_file:
#             local_file.write(s3_file_obj.read())

# print("Files downloaded successfully.")


# Load WIDER


In [5]:
# def make_wider(tag, value, data_path):
#     img_path = os.path.join(data_path, "Image")
#     ann_path = os.path.join(data_path, "Annotations")
#     ann_file = os.path.join(ann_path, "wider_attribute_{}.json".format(tag))

#     data = json.load(open(ann_file, "r"))

#     final = []
#     image_list = data['images']
#     for image in image_list:
#         for person in image["targets"]: # iterate over each person
#             tmp = {}
#             tmp['img_path'] = os.path.join(img_path, image['file_name'])
#             tmp['bbox'] = person['bbox']
#             attr = person["attribute"]
#             for i, item in enumerate(attr):
#                 if item == -1:
#                     attr[i] = 0
#                 if item == 0:
#                     attr[i] = value  # pad un-specified samples
#                 if item == 1:
#                     attr[i] = 1
#             tmp["target"] = attr
#             final.append(tmp)

#     json.dump(final, open("data/wider/{}_wider.json".format(tag), "w"))
#     print("data/wider/{}_wider.json".format(tag))

In [6]:
# #run once
# if not os.path.exists("data/wider"):
#     os.makedirs("data/wider")

# # 0 (zero) means negative, we treat un-specified attribute as negative in the trainval set
# make_wider(tag='trainval', value=0, data_path='WIDER') 
# make_wider(tag='test', value=99, data_path='WIDER')

In [7]:
class DataSet(Dataset):
    def __init__(self,
                ann_files,
                augs,
                img_size,
                dataset,
                ):
        self.dataset = dataset
        self.ann_files = ann_files
        self.augment = self.augs_function(augs, img_size)
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
            ] 
            # In this paper, we normalize the image data to [0, 1]
            # You can also use the so called 'ImageNet' Normalization method
        )
        self.anns = []
        self.load_anns()
        print(self.augment)

        # in wider dataset we use vit models
        # so transformation has been changed
        if self.dataset == "wider":
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                ] 
            )        

    def augs_function(self, augs, img_size):            
        t = []
        if 'randomflip' in augs:
            t.append(transforms.RandomHorizontalFlip())
        if 'ColorJitter' in augs:
            t.append(transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0))
        if 'resizedcrop' in augs:
            t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
        if 'RandAugment' in augs:
            t.append(RandAugment())

        t.append(transforms.Resize((img_size, img_size)))

        return transforms.Compose(t)
    
    def load_anns(self):
        self.anns = []
        for ann_file in self.ann_files:
            json_data = json.load(open(ann_file, "r"))
            self.anns += json_data

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        idx = idx % len(self)
        ann = self.anns[idx]
        img = Image.open(ann["img_path"]).convert("RGB")
        
        if self.dataset == "wider":
            x, y, w, h = ann['bbox']
            img_area = img.crop([x, y, x+w, y+h])
            img_area = self.augment(img_area)
            img_area = self.transform(img_area)
        
            # Extract label from image path
            img_path = ann['img_path']
            label = None
        
            if "WIDER/Image/train" in img_path:
                # For images in the "train" folder, extract the numeric label after "train/"
                label_str = img_path.split("WIDER/Image/train/")[1].split("/")[0]
                label = int(label_str.split("--")[0])  # Extract the numeric part
            elif "WIDER/Image/test" in img_path:
                # For images in the "test" folder, extract the numeric label after "test/"
                label_str = img_path.split("WIDER/Image/test/")[1].split("/")[0]
                label = int(label_str.split("--")[0])  # Extract the numeric part

            message = {
                "label": label,
                "target": torch.Tensor(ann['target']),
                "img": img_area
            }
        return message

In [8]:
train_file = ['data/wider/trainval_wider.json']
test_file = ['data/wider/test_wider.json']
batch_size = 256

In [9]:
train_dataset = DataSet(train_file, augs = ['randomflip'], img_size = 224, dataset = 'wider')
# subset_indices = range(0, 1000)  # Select indices 0 to 999
# train_dataset = Subset(train_dataset, subset_indices)

test_dataset = DataSet(test_file, augs = [], img_size = 224, dataset = 'wider')
trainloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
testloader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)


Compose(
    RandomHorizontalFlip(p=0.5)
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
)
Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
)


In [10]:
train_dataset[100000]

{'label': 39,
 'target': tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]),
 'img': tensor([[[ 0.9686,  0.9765,  0.9765,  ...,  0.5216,  0.5294,  0.5373],
          [ 0.9765,  0.9765,  0.9765,  ...,  0.5294,  0.5451,  0.5608],
          [ 0.9765,  0.9765,  0.9765,  ...,  0.5529,  0.5608,  0.5686],
          ...,
          [-0.6392, -0.6314, -0.6235,  ...,  0.8275,  0.8196,  0.8118],
          [ 0.1216,  0.1373,  0.1608,  ...,  0.7882,  0.7882,  0.7882],
          [ 0.4980,  0.4980,  0.4980,  ...,  0.7255,  0.7333,  0.7333]],
 
         [[ 0.3569,  0.3647,  0.3725,  ...,  0.5529,  0.5608,  0.5686],
          [ 0.3882,  0.3961,  0.3961,  ...,  0.5529,  0.5765,  0.5922],
          [ 0.4275,  0.4275,  0.4353,  ...,  0.5608,  0.5686,  0.5765],
          ...,
          [-0.7490, -0.7412, -0.7333,  ...,  0.9059,  0.8902,  0.8824],
          [ 0.0980,  0.1137,  0.1451,  ...,  0.9059,  0.9059,  0.9059],
          [ 0.5529,  0.5529,  0.5608,  ...,  0.8431,  0.8431,  0.8431]],
 
  

In [None]:
none_indices = []

# Iterate through the dataset to find 'None' values
for idx in range(len(train_dataset)):
    try:
        sample = train_dataset[idx]
        if None in sample['label'] or None in sample['target'] or None in sample['img']:
            none_indices.append(idx)
    except Exception as e:
        print(f"Error encountered for sample at index {idx}:\n{e}")
        print(sample)
        # Handle the error as needed

Error encountered for sample at index 0:
argument of type 'int' is not iterable
{'label': 0, 'target': tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.]), 'img': tensor([[[-0.0275,  0.0039,  0.0510,  ..., -0.1137, -0.1294, -0.1608],
         [-0.0118,  0.0196,  0.0510,  ..., -0.1059, -0.1294, -0.1686],
         [-0.0196, -0.0118,  0.0039,  ..., -0.0667, -0.0980, -0.1373],
         ...,
         [-0.3647, -0.3569, -0.3412,  ..., -0.7333, -0.7176, -0.7020],
         [-0.5294, -0.5137, -0.5059,  ..., -0.8196, -0.8039, -0.7804],
         [-0.5922, -0.5922, -0.5922,  ..., -0.8588, -0.8431, -0.8196]],

        [[ 0.1137,  0.0196, -0.1373,  ..., -0.0667, -0.0980, -0.1373],
         [ 0.1529,  0.0588, -0.0980,  ..., -0.0510, -0.0980, -0.1451],
         [ 0.1529,  0.0667, -0.0824,  ..., -0.0118, -0.0667, -0.1137],
         ...,
         [-0.3255, -0.3176, -0.3020,  ..., -0.7176, -0.7098, -0.7020],
         [-0.4902, -0.4745, -0.4667,  ..., -0.8039, -0.7961, -0.7804],
         [-0.

In [None]:
none_indices

# Start Training Process

In [11]:
def compare_model_size(teacher, student):
    teacher_params = sum(p.numel() for p in teacher.parameters())
    student_params = sum(p.numel() for p in student.parameters())
    return teacher_params, student_params

def compare_inference_time(teacher, student, dataloader):
    data = next(iter(dataloader))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    teacher = teacher.to(device)
    student = student.to(device)
    inputs = data['img'].to(device)
    
    start_time = time.time()
    with torch.no_grad():
        teacher_outputs = teacher(inputs)
    teacher_time = time.time() - start_time

    start_time = time.time()
    with torch.no_grad():
        student_outputs = student(inputs)
    student_time = time.time() - start_time
    
    return teacher_time, student_time

def compare_performance_metrics(teacher, student, dataloader):
    teacher.eval()
    student.eval()
    
    all_labels = []
    all_teacher_preds = []
    all_student_preds = []

    for i, data in enumerate(dataloader):
        with torch.no_grad():
            teacher_outputs = teacher(data['img'].to(device))
            student_outputs = student(data['img'].to(device))
        all_labels.append(data['label'].cpu().numpy())
        all_teacher_preds.append(torch.argmax(teacher_outputs, dim=1).cpu().numpy())
        all_student_preds.append(torch.argmax(student_outputs, dim=1).cpu().numpy())

    all_labels = np.concatenate(all_labels)
    all_teacher_preds = np.concatenate(all_teacher_preds)
    all_student_preds = np.concatenate(all_student_preds)
    
    metrics = {
        'accuracy': (accuracy_score(all_labels, all_teacher_preds), accuracy_score(all_labels, all_student_preds)),
        'precision': (precision_score(all_labels, all_teacher_preds, average='weighted', zero_division=0), precision_score(all_labels, all_student_preds, average='weighted', zero_division=0)),  # Updated line
        'recall': (recall_score(all_labels, all_teacher_preds, average='weighted'), recall_score(all_labels, all_student_preds, average='weighted')),
        'f1': (f1_score(all_labels, all_teacher_preds, average='weighted'), f1_score(all_labels, all_student_preds, average='weighted'))
    }

    return metrics

def plot_comparison(labels, teacher_values, student_values, title, ylabel):
    # Convert parameter count to millions
    if 'Parameter Count' in title or 'Parameter Count' in ylabel:
        teacher_values = [value / 1e6 for value in teacher_values]
        student_values = [value / 1e6 for value in student_values]

    x = np.arange(len(labels))  # the label locations
    width = 0.35  # the width of the bars

    fig, ax = plt.subplots()
    rects1 = ax.bar(x - width/2, teacher_values, width, label='Teacher')
    rects2 = ax.bar(x + width/2, student_values, width, label='Student')

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend()

    fig.tight_layout()

    plt.show()

In [12]:
# Hyperparameters
learning_rate = 0.01379 # 0.096779
num_epochs = 3  # 200
num_workers = 2
batch_size = 256
temperature = 4.0
alpha = 0.9
momentum = 0.9
num_classes = 61
step_size = 30
gamma = 0.1
beta = 0.5

# new parameters
# lr_input = 0.1
# momentum_input = 0.9
weight_decay_input = 5e-4
# epochs = 20
# T = 4.0 # temperatureture
# alpha = 0.9
patience = 5  # for early stopping


In [13]:
# # Load IdenProf dataset
# train_path = '/home/ubuntu/capstone/W210-Capstone/notebooks/idenprof/train'
# test_path = '/home/ubuntu/capstone/W210-Capstone/notebooks/idenprof/test'
# trainloader, testloader  = load_prof(train_path, test_path, batch_size=batch_size)

In [14]:
# Instantiate the models
###################### Testing 1 ######################
# Create instances of your models
teacher_model = torchvision.models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1).cuda()
teacher_model.eval()  # Set teacher model to evaluation mode
student_model = torchvision.models.resnet18(weights=None).cuda()

In [15]:
# # Instantiate the models
# ###################### Testing 2 ######################
# # Create instances of your models
# teacher_model = Teacher()
# teacher_model.eval()  # Set teacher model to evaluation mode
# student_model = Student()

In [16]:
# Optimizer and scheduler for the student model
optimizer = optim.SGD(student_model.parameters(), lr=learning_rate, momentum=momentum)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# Optimizer and scheduler for the teacher model
teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=learning_rate, momentum=momentum)
teacher_scheduler = torch.optim.lr_scheduler.StepLR(teacher_optimizer, step_size=step_size, gamma=gamma)

criterion = nn.CrossEntropyLoss()

# Assuming the device is a CUDA device if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [17]:
def recall_disparity_loss(outputs, targets, attributes, num_classes):
    """
    Compute the recall disparity loss.

    :param outputs: Tensor of shape (batch_size, num_classes), model's class probabilities or logits.
    :param targets: Tensor of shape (batch_size,), true class indices.
    :param attributes: Tensor of shape (batch_size, num_attributes), binary attributes for each instance.
    :param num_classes: int, number of classes.
    """
    # Ensure we're working with probabilities
    probs = torch.softmax(outputs, dim=1)
    preds = torch.argmax(probs, dim=1)

    # Initialize tensors to store recall for when attributes are present and absent
    recall_when_present = torch.zeros(num_classes, attributes.size(1))
    recall_when_absent = torch.zeros(num_classes, attributes.size(1))

    # Calculate recall for each class based on attributes being present or absent
    for class_idx in range(num_classes):
        for attr_idx in range(attributes.size(1)):
            # Indices of instances with the current class and attribute present/absent
            class_and_attr_present = (targets == class_idx) & (attributes[:, attr_idx] == 1)
            class_and_attr_absent = (targets == class_idx) & (attributes[:, attr_idx] == 0)

            # True positives for current class when attribute is present/absent
            true_positive_present = ((preds == class_idx) & class_and_attr_present).sum().float()
            true_positive_absent = ((preds == class_idx) & class_and_attr_absent).sum().float()

            # Condition positives for current class when attribute is present/absent
            condition_positive_present = class_and_attr_present.sum().float()
            condition_positive_absent = class_and_attr_absent.sum().float()

            # Recall for current class when attribute is present/absent
            recall_when_present[class_idx, attr_idx] = true_positive_present / (condition_positive_present + 1e-8)
            recall_when_absent[class_idx, attr_idx] = true_positive_absent / (condition_positive_absent + 1e-8)

    # Disparity is the absolute difference between recall when attribute is present and when it's absent
    disparity = (recall_when_present - recall_when_absent).abs()

    # Final loss is the mean of disparity across all classes and attributes
    loss = disparity.mean()

    return loss



In [18]:
# #### finding the optimal learning rate
# def train_teacher(model, trainloader, criterion, optimizer, scheduler, device, num_epochs=5, lr_range=(1e-4, 1e-1), plot_loss=True):
#     model.train()
#     model.to(device)
#     lr_values = np.logspace(np.log10(lr_range[0]), np.log10(lr_range[1]), num_epochs * len(trainloader))  # Generate learning rates for each batch
#     lr_iter = iter(lr_values)
#     losses = []
#     lrs = []
    
#     for epoch in range(num_epochs):
#         for i, (inputs, labels, annotation) in enumerate(tqdm(trainloader)):
#             lr = next(lr_iter)
#             for param_group in optimizer.param_groups:
#                 param_group['lr'] = lr  # Set new learning rate
            
#             inputs, labels = inputs.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
            
#             losses.append(loss.item())
#             lrs.append(lr)
    
#     # Calculate the derivative of the loss
#     loss_derivative = np.gradient(losses)
    
#     # Find the learning rate corresponding to the minimum derivative (steepest decline)
#     best_lr_index = np.argmin(loss_derivative)
#     best_lr = lrs[best_lr_index]
    
#     if plot_loss:
#         import matplotlib.pyplot as plt
#         plt.figure()
#         plt.plot(lrs, losses)
#         plt.xscale('log')
#         plt.xlabel('Learning Rate')
#         plt.ylabel('Loss')
#         plt.title('Learning Rate Range Test')
#         plt.axvline(x=best_lr, color='red', linestyle='--', label=f'Best LR: {best_lr}')
#         plt.legend()
#         plt.show()
    
#     print(f'Best learning rate: {best_lr}')
#     return best_lr

# ############# input ############## 
# batch_size = 16  #to find the optimal learning rate
# best_lr = train_teacher(teacher_model, trainloader, criterion, teacher_optimizer, teacher_scheduler, device, num_epochs=3)  
# print(best_lr)

In [19]:
# Function to train the teacher model
def train_teacher(model, trainloader, criterion, optimizer, scheduler, device, num_epochs=1, patience=5):
    model.train()
    model.to(device)
    best_train_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        running_loss = 0.0
        epoch_loss = 0.0  
        num_batches = 0  
        for index, data in enumerate(tqdm(trainloader)):
            inputs = data['img'].cuda()
            labels = data['label'].cuda()
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            epoch_loss += loss.item()
            num_batches += 1
            if index % 100 == 99:  # Print every 100 mini-batches
                print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
                running_loss = 0.0

        epoch_loss /= num_batches  
        
        # Check for early stopping
        if epoch_loss < best_train_loss:
            best_train_loss = epoch_loss
            patience_counter = 0 
            # checkpoint
            torch.save(model.state_dict(), f'teacher_model_weights_ckd_prof_checkpoint.pth')
            torch.save(model, f'teacher_model_ckd_prof_checkpoint.pth')

        else:
            patience_counter += 1

        if patience_counter >= patience:
            print('Early stopping')
            break

        scheduler.step()

    print("Finished Training Teacher")


# Function to train the student model with knowledge distillation
def train_student_with_distillation_disparity(student, teacher, trainloader, criterion, optimizer, scheduler, device, alpha, temperature, num_epochs, patience=5):
    student.train()
    teacher.eval()
    student.to(device)
    teacher.to(device)
    best_train_loss = float('inf')  
    patience_counter = 0 

    for epoch in range(num_epochs):
        running_loss = 0.0 
        epoch_loss = 0.0  
        num_batches = 0  
        epoch_disparity = 0.0
        running_recall = 0.0

        for index, data in enumerate(tqdm(trainloader)):
            inputs = data['img'].cuda()
            labels = data['label'].cuda()
            annot = data['target'].cuda()
            optimizer.zero_grad()
            student_outputs = student(inputs)
            with torch.no_grad():
                teacher_outputs = teacher(inputs)
        
            ce_loss = criterion(student_outputs, labels)
            kd_loss = tkd_kdloss(student_outputs, teacher_outputs, temperature=temperature)  # from utils.loss_functions
        
            # Calculate the recall difference loss
            recall_difference = recall_disparity_loss(student_outputs, labels, annot, num_classes=61)
        
            # Combine the losses
            alpha = 0.5  # Replace with your actual hyperparameter values
            beta = 0.2
            loss = alpha * kd_loss + (1 - alpha) * ce_loss + beta * recall_difference

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            epoch_loss += loss.item()
            epoch_disparity += recall_difference
            num_batches += 1
    
            if index % 100 == 99:  
                print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
                running_loss = 0.0
    
        epoch_loss /= num_batches
        print(epoch_loss)

        # Check for early stopping
        if epoch_loss < best_train_loss:
            best_train_loss = epoch_loss
            patience_counter = 0 
            torch.save(student.state_dict(), f'student_model_weights_ckd_prof_checkpoint.pth')
            torch.save(student, f'student_model_ckd_prof_checkpoint.pth')
        else:
            patience_counter += 1 

        if patience_counter >= patience:
            print('Early stopping')
            break  

        scheduler.step() 

    print("Finished Training Student")

In [20]:
# Assuming the device is a CUDA device if available
num_epochs = 7
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Call the function to train the teacher model
train_teacher(teacher_model, trainloader, criterion, teacher_optimizer, teacher_scheduler, device, num_epochs=num_epochs)

# Call the function to train the student model with knowledge distillation
train_student_with_distillation_disparity(student_model, teacher_model, trainloader, criterion, optimizer, scheduler, device, alpha, temperature, num_epochs=num_epochs)


  0%|                                                   | 0/111 [00:03<?, ?it/s]


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 127, in collate
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 127, in <dictcomp>
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 183, in collate_int_fn
    return torch.tensor(batch)
RuntimeError: Could not infer dtype of NoneType


In [None]:
###################### Testing 1 ######################
# Save the student and teacher model weights and architecture
# torch.save(student_model.state_dict(), 'student_model_weights_ckd_prof.pth')
# torch.save(student_model, 'student_model_ckd_prof.pth')
# print('student weights and architecture saved and exported')

# torch.save(teacher_model.state_dict(), 'teacher_model_weights_ckd_prof.pth')
# torch.save(teacher_model, 'teacher_model_ckd_prof.pth')
# print('teacher weights and architecture saved and exported')

In [None]:
# ###################### Testing 2 ######################
# # Save the student and teacher model weights and architecture
# torch.save(student_model.state_dict(), 'student_model_weights_ckd_2.pth')
# torch.save(student_model, 'student_model_ckd_2.pth')
# print('weights and architecture saved and exported')

# torch.save(teacher_model.state_dict(), 'teacher_model_weights_ckd_2.pth')
# torch.save(teacher_model, 'teacher_model_ckd_2.pth')
# print('teacher weights and architecture saved and exported')

In [None]:
# Call the comparison and plotting functions after training
teacher_params, student_params = compare_model_size(teacher_model, student_model)
teacher_time, student_time = compare_inference_time(teacher_model, student_model, trainloader)
performance_metrics = compare_performance_metrics(teacher_model, student_model, trainloader)

# Extracting the metric values for plotting
performance_labels = ['accuracy', 'precision', 'recall', 'f1']
teacher_performance_values = [performance_metrics[metric][0] for metric in performance_labels]
student_performance_values = [performance_metrics[metric][1] for metric in performance_labels]

# Plotting the comparison for performance metrics
plot_comparison(performance_labels, teacher_performance_values, student_performance_values, 'Performance Comparison', 'Score')

# Plotting the comparison for model size
model_size_labels = ['Model Size']
teacher_model_size_values = [teacher_params]
student_model_size_values = [student_params]
plot_comparison(model_size_labels, teacher_model_size_values, student_model_size_values, 'Model Size Comparison', 'Parameter Count (millions)')

# Plotting the comparison for inference time
inference_time_labels = ['Inference Time']
teacher_inference_time_values = [teacher_time]
student_inference_time_values = [student_time]
plot_comparison(inference_time_labels, teacher_inference_time_values, student_inference_time_values, 'Inference Time Comparison', 'Time (s)')

In [None]:
performance_metrics

In [None]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# teacher_model = torchvision.models.resnet34(weights=None)

# weights_path = 'teacher_model_weights_ckd_prof.pth'

# teacher_model.load_state_dict(torch.load(weights_path))

# student_model = torchvision.models.resnet18(weights=None)

# weights_path = 'student_model_weights_ckd_prof.pth'

# student_model.load_state_dict(torch.load(weights_path))

# Call the comparison and plotting functions after training
teacher_params, student_params = compare_model_size(teacher_model, student_model)
teacher_time, student_time = compare_inference_time(teacher_model, student_model, testloader)
performance_metrics = compare_performance_metrics(teacher_model, student_model, testloader)

# Extracting the metric values for plotting
performance_labels = ['accuracy', 'precision', 'recall', 'f1']
teacher_performance_values = [performance_metrics[metric][0] for metric in performance_labels]
student_performance_values = [performance_metrics[metric][1] for metric in performance_labels]

# Plotting the comparison for performance metrics
plot_comparison(performance_labels, teacher_performance_values, student_performance_values, 'Performance Comparison', 'Score')

# Plotting the comparison for model size
model_size_labels = ['Model Size']
teacher_model_size_values = [teacher_params]
student_model_size_values = [student_params]
plot_comparison(model_size_labels, teacher_model_size_values, student_model_size_values, 'Model Size Comparison', 'Parameter Count (millions)')

# Plotting the comparison for inference time
inference_time_labels = ['Inference Time']
teacher_inference_time_values = [teacher_time]
student_inference_time_values = [student_time]
plot_comparison(inference_time_labels, teacher_inference_time_values, student_inference_time_values, 'Inference Time Comparison', 'Time (s)')

In [None]:
import torch

def calculate_recall(preds, targets, attribute, attr_value):
    """
    Calculate recall for instances where the binary attribute has a specific value.
    
    :param preds: Predicted classes (batch_size,)
    :param targets: True classes (batch_size,)
    :param attribute: Binary attribute values for the batch (batch_size,)
    :param attr_value: The value of the attribute to consider (0 or 1)
    :return: Recall value.
    """
    relevant = (attribute == attr_value)
    condition_positive = (targets[relevant]).sum().float()
    true_positive = ((preds == targets) & relevant).sum().float()
    recall = true_positive / (condition_positive + 1e-8)
    return relevant

def evaluate_disparity(model, dataloader, num_classes):
    """
    Evaluate the disparity on the test data.

    :param model: The trained model.
    :param dataloader: DataLoader for the test data.
    :param num_classes: The number of classes in the dataset.
    :return: Average disparity across all attributes and classes.
    """
    model.eval()  # Set the model to evaluation mode
    total_disparity = 0.0
    num_batches = 0

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            inputs = batch['img'].cuda()
            targets = batch['label'].cuda()
            attributes = batch['target'].cuda()

            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)

            batch_disparity = 0.0
            for attr_idx in range(attributes.size(1)):
                recall_present = calculate_recall(preds, targets, attributes[:, attr_idx], 1)
                recall_absent = calculate_recall(preds, targets, attributes[:, attr_idx], 0)
                print('*'*50)
                print(f'batch: {i}, index: {attr_idx}')
                print(recall_present)
                print(recall_absent)
                print('*'*50)
                batch_disparity += (recall_present - recall_absent).abs().item()

            total_disparity += batch_disparity / attributes.size(1)  # Average disparity per attribute
            num_batches += 1
            break
    average_disparity = total_disparity / num_batches  # Average disparity across all batches
    return average_disparity

# Example usage:
# Assume 'test_dataloader' is a DataLoader for your test dataset
# and 'model' is your trained model.

# Evaluate the model's performance on disparity
disparity = evaluate_disparity(student_model, testloader, num_classes=61)
print(f'Average recall disparity across all attributes and classes: {disparity}')