In [1]:
import time
import torch
import json
import random
import math
import warnings
import torchvision
import os
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from datetime import datetime
from tqdm import tqdm
from PIL import Image
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import pairwise_distances_argmin_min, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, precision_recall_fscore_support
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch.utils.data import Subset, Dataset, DataLoader
from torchvision.models import EfficientNet

from torchvision.datasets import ImageFolder
from utils.loss_functions import tkd_kdloss
# from models_package.models import Teacher, Student

# Suppress all warnings
warnings.filterwarnings("ignore")

In [None]:
from utils.loss_functions import DKDLoss
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from models_package.models import Teacher, Student
from torchvision import datasets, transforms, models
import models_package
import time
from datetime import datetime
import json
import random
import logging
from pathlib import Path
import argparse
import warnings
from torch.utils.tensorboard import SummaryWriter
import pdb
import time
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Subset, Dataset, DataLoader
from PIL import Image
from pathlib import Path
from collections import OrderedDict
import os, shutil

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os
# new libraries
from data.data_loader import load_cifar10, load_cifar100, load_imagenet, load_prof
import boto3
import io
import models_package
from utils.loss_functions import DKDLoss, DirectNormLoss, KDLoss
from utils.compare_tools import compare_model_size, compare_inference_time, compare_performance_metrics, plot_comparison
from utils.misc_tools import colorstr, Save_Checkpoint, AverageMeter, epoch_loop_reviewkd
from utils.misc_tools import best_LR, best_LR_nd, best_LR_wider, train_teacher, train_teacher_wider, train_teacher_efficientnet, train_teacher_efficientnet_wider, retrieve_teacher_class_weights, new_teacher_class_weights


In [28]:
# Hyperparameters
learning_rate = 0.001 # 0.096779
epochs = 300
epochs_pretrain = 3
epochs_optimal_lr = 5
patience_teacher = 7
patience_student = 10
temperature = 4.0
alpha = 0.9
momentum = 0.9
step_size = 30
gamma = 0.1
batch_size = 64
num_workers = 4

# set to true to use stratified sampling
stratified_sampling_flag = False

# list of lambda values to loop through for grid search
lmda_list_student = [10,5,3,0.5,0]
lmda_list_teacher = [10,5,3,0.5,0]

# labels used including for plotting
class_labels = [0, 1, 3, 4, 6, 7, 11, 15, 17, 18, 19, 20, 22, 25, 27, 28, 30, 31, 33, 35, 36, 37, 39, 43, 44, 50, 51, 54, 57, 58]
class_labels_new = torch.tensor([i for i in range(len(class_labels))])
num_classes = 16
class_names_new = [f"Class {label}" for label in range(num_classes)]

# Create directory and file path to save all outputs
output_dir = f'./runs_{datetime.now().strftime("%Y_%m_%d_%H_%M")}'
os.makedirs(output_dir, exist_ok=True)

In [29]:
# set device to cuda if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [30]:
# Assuming your JSON file is named 'your_file.json'
file_path = './WIDER/Annotations/wider_attribute_trainval.json'

with open(file_path, 'r') as file:
    # Load the JSON data from the file
    data = json.load(file)

class_idx = data['scene_id_map']

In [31]:
new_label_mapping = {
    0: "Team_Sports",
    1: "Celebration",
    2: "Parade",
    3: "Waiter_Or_Waitress",
    4: "Individual_Sports",
    5: "Surgeons",
    6: "Spa",
    7: "Law_Enforcement",
    8: "Business",
    9: "Dresses",
    10: "Water_Activities",
    11: "Picnic",
    12: "Rescue",
    13: "Cheering",
    14: "Performance_And_Entertainment",
    15: "Family"
}

# Ensure that all 16 new classes are covered
# If some classes are not explicitly mentioned in new_label_mapping, add them
for i in range(num_classes):
    if i not in new_label_mapping:
        new_label_mapping[i] = "Additional Category {}".format(i)

class_idx = new_label_mapping

In [32]:
class StratifiedBatchSampler:
    """Stratified batch sampling
    Provides equal representation of target classes in each batch
    """
    def __init__(self, y, batch_size, shuffle=True):
        if torch.is_tensor(y):
            y = y.numpy()
        assert len(y.shape) == 1, 'label array must be 1D'
        n_batches = int(len(y) / batch_size)
        self.skf = StratifiedKFold(n_splits=n_batches, shuffle=shuffle)
        self.X = torch.randn(len(y),1).numpy()
        self.y = y
        self.shuffle = shuffle
        self.batch_size = batch_size

    def __iter__(self):
        if self.shuffle:
            self.skf.random_state = torch.randint(0,int(1e8),size=()).item()
        for train_idx, test_idx in self.skf.split(self.X, self.y):
            yield test_idx

    def __len__(self):
        return len(self.y)//self.batch_size

In [33]:
class DataSet(Dataset):
    def __init__(self, ann_files, augs, img_size, dataset):

        # Create a mapping from old labels to new labels
        self.label_mapping = {old_label: new_label for new_label, old_label in enumerate(sorted(class_labels))}

        self.new_label_mapping = {
            0: 2,  # Parade
            1: 8,  # Business
            2: 7,  # Law Enforcement
            3: 14,  # Performance and Entertainment
            4: 1,  # Celebration
            5: 13,  # Cheering
            6: 8,  # Business
            7: 8,  # Business
            8: 1,  # Celebration
            9: 14,  # Performance and Entertainment
            10: 15, # Family
            11: 15, # Family
            12: 11, # Picnic
            13: 7, # Law Enforcement
            14: 6, # Spa
            15: 13, # Cheering
            16: 5, # Surgeons
            17: 3, # Waiter or Waitress
            18: 4, # Individual Sports
            19: 0, # Team Sports
            20: 0, # Team Sports
            21: 0, # Team Sports
            22: 4, # Individual Sports
            23: 10, # Water Activities
            24: 4, # Individual Sports
            25: 1, # Celebration
            26: 9, # Dresses
            27: 12, # Rescue
            28: 10,# Water Activities
            29: 0  # Team Sports
        }

        
        self.dataset = dataset
        self.ann_files = ann_files
        self.augment = self.augs_function(augs, img_size)
        # Initialize transformations directly
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
            ] 
        )
        if self.dataset == "wider":
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])                ] 
            )        

        self.anns = []
        self.load_anns()
        print(self.augment)

    def augs_function(self, augs, img_size):            
        t = []
        if 'randomflip' in augs:
            t.append(transforms.RandomHorizontalFlip())
        if 'ColorJitter' in augs:
            t.append(transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0))
        if 'resizedcrop' in augs:
            t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
        if 'RandAugment' in augs:
            t.append(transforms.RandAugment())

        t.append(transforms.Resize((img_size, img_size)))

        return transforms.Compose(t)
    
    def load_anns(self):
        self.anns = []
        for ann_file in self.ann_files:
            json_data = json.load(open(ann_file, "r"))
            self.anns += json_data

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        # Make sure the index is within bounds
        idx = idx % len(self)
        ann = self.anns[idx]
        
        try:
            # Attempt to open the image file
            img = Image.open(f'WIDER/Image/{ann["file_name"]}').convert("RGB")

            # If this is the wider dataset, proceed with specific processing
            # x, y, w, h = ann['bbox']
            # img_area = img.crop([x, y, x+w, y+h])
            img_area = self.augment(img)
            img_area = self.transform(img_area)
            attributes_list = [target['attribute'] for target in ann['targets']]
            num_people = len(attributes_list)
            attributes_distribution = [max(sum(attribute), 0)/num_people for attribute in zip(*attributes_list)]
            # Extract label from image path
            img_path = f'WIDER/Image/{ann["file_name"]}'
            label = self.extract_label(img_path)  # You might need to implement this method
            
            return {
                "label": label,
                "target": torch.tensor([attributes_distribution[0]], dtype=torch.float32),
                "img": img_area
            }
            
        except Exception as e:
            # If any error occurs during the processing of an image, log the error and the index
            print(f"Error processing image at index {idx}: {e}")
            # Instead of returning None, raise the exception
            raise

    def extract_label(self, img_path):
        original_label = None
    
        if "WIDER/Image/train" in img_path:
            label_str = img_path.split("WIDER/Image/train/")[1].split("/")[0]
            original_label = int(label_str.split("--")[0])
        elif "WIDER/Image/test" in img_path:
            label_str = img_path.split("WIDER/Image/test/")[1].split("/")[0]
            original_label = int(label_str.split("--")[0])
        elif "WIDER/Image/val" in img_path:  # Handle validation images
            label_str = img_path.split("WIDER/Image/val/")[1].split("/")[0]
            original_label = int(label_str.split("--")[0])
    
        if original_label is not None:
            remapped_label = self.label_mapping[original_label]
            new_label_mapping = self.new_label_mapping[remapped_label]
            return new_label_mapping
        else:
            raise ValueError(f"Label could not be extracted from path: {img_path}")


In [34]:
train_file = ['data/wider/trainval_wider.json']
test_file = ['data/wider/test_wider.json']


In [35]:
def custom_collate(batch):
    # Filter out any None items in the batch
    batch = [item for item in batch if item is not None]
    # If after filtering the batch is empty, handle this case by either returning an empty tensor or raising an exception
    if len(batch) == 0:
        raise ValueError("Batch is empty after filtering out None items.")
    return torch.utils.data.dataloader.default_collate(batch)


In [36]:
# train_dataset = DataSet(train_file, augs = ['RandAugment'], img_size = 226, dataset = 'wider')
train_dataset = DataSet(train_file, augs = [], img_size = 226, dataset = 'wider')
test_dataset = DataSet(test_file, augs = [], img_size = 226, dataset = 'wider')

if stratified_sampling_flag:
    trainloader = DataLoader(train_dataset, 
                             batch_sampler=StratifiedBatchSampler(torch.tensor([train_dataset[i]['label'] for i in range(len(train_dataset))]), 
                             batch_size=batch_size), num_workers=num_workers, collate_fn=custom_collate)
    testloader = DataLoader(test_dataset, batch_sampler=StratifiedBatchSampler(torch.tensor([test_dataset[i]['label'] for i in range(len(test_dataset))]), 
                             batch_size=batch_size), num_workers=num_workers, collate_fn=custom_collate)
else:
    trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             num_workers=num_workers, collate_fn=custom_collate)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=custom_collate)

Compose(
    Resize(size=(226, 226), interpolation=bilinear, max_size=None, antialias=warn)
)
Compose(
    Resize(size=(226, 226), interpolation=bilinear, max_size=None, antialias=warn)
)


In [37]:
len(train_dataset)

10324

In [38]:
# def print_batch_class_counts(data_loader, label_mapping, num_batches=5):
#     for i, batch in enumerate(data_loader):
#         if i >= num_batches:
#             break

#         # Extract labels from the batch
#         labels = batch['label']

#         # Count occurrences of each class
#         class_counts = torch.bincount(labels)

#         # Map class counts to class names
#         class_counts_with_names = {label_mapping.get(j, f"Unknown Class {j}"): class_counts[j].item() for j in range(len(class_counts))}

#         # Print class counts and total observations
#         print(f"Batch {i + 1}:")
#         for class_name, count in class_counts_with_names.items():
#             print(f"    {class_name}: {count}")
#         print(f"Total Observations: {len(labels)}\n")

# print_batch_class_counts(trainloader, new_label_mapping)


# Start Training Process

In [39]:
def one_hot_encode(labels, num_classes):
    return np.eye(num_classes)[labels]

def calculate_recall_multiclass(conf_matrix):
    recalls = np.diag(conf_matrix) / np.sum(conf_matrix, axis=1)
    recalls[np.isnan(recalls)] = 0  # Replace NaN with 0
    return recalls

def evaluate_model_with_gender_multiclass(pred, label, gender, num_classes):
    predictions = pred.cpu()
    true_labels = label.cpu()
    gender = gender.cpu()

    # Identify male and female indices based on the gender threshold
    male_indices = np.where(gender >= 0.5)[0]
    female_indices = np.where(gender < 0.5)[0]

    # Convert labels to one-hot encoding
    one_hot_labels = one_hot_encode(true_labels, num_classes=num_classes)
    one_hot_preds = one_hot_encode(predictions, num_classes=num_classes)
    # Initialize recall arrays
    male_recall = np.zeros(num_classes)
    female_recall = np.zeros(num_classes)

    # Extract predictions and labels for male and female indices
    male_predictions = np.argmax(one_hot_preds[male_indices], axis=1)
    female_predictions = np.argmax(one_hot_preds[female_indices], axis=1)
    male_labels = np.argmax(one_hot_labels[male_indices], axis=1)
    female_labels = np.argmax(one_hot_labels[female_indices], axis=1)

    # Check if the class labels are within the expected range
    assert (0 <= male_predictions.min() < num_classes) and (0 <= male_predictions.max() < num_classes), "Invalid class indices in male predictions"
    assert (0 <= female_predictions.min() < num_classes) and (0 <= female_predictions.max() < num_classes), "Invalid class indices in female predictions"
    assert (0 <= male_labels.min() < num_classes) and (0 <= male_labels.max() < num_classes), "Invalid class indices in male labels"
    assert (0 <= female_labels.min() < num_classes) and (0 <= female_labels.max() < num_classes), "Invalid class indices in female labels"

    # Calculate confusion matrices for each gender
    male_conf_matrix = confusion_matrix(male_labels, male_predictions, labels=np.arange(num_classes))
    female_conf_matrix = confusion_matrix(female_labels, female_predictions, labels=np.arange(num_classes))

    # Calculate recall for each class and gender
    male_recall[:len(male_conf_matrix)] = calculate_recall_multiclass(male_conf_matrix)
    female_recall[:len(female_conf_matrix)] = calculate_recall_multiclass(female_conf_matrix)

    return male_recall - female_recall, male_conf_matrix, female_conf_matrix


In [None]:
##### HELPER FUNCTION FOR FEATURE EXTRACTION

def get_features(name):
    def hook(model, input, output):
        features[name] = output.detach()
    return hook

In [40]:
# Instantiate the models
###################### Testing 1 ######################
# Create instances of your models
teacher_model = torchvision.models.efficientnet_b3(weights='DEFAULT')
teacher_model.classifier = nn.Linear(1536, num_classes)
student_model = torchvision.models.efficientnet_b0(weights='DEFAULT')
student_model.classifier = nn.Linear(1280, num_classes)

# Load teacher
# teacher_model = torch.load('teacher_model_ckd_wider.pth')
# teacher_model.load_state_dict(torch.load('teacher_model_weights_ckd_wider.pth'))
# torch.save(teacher_model.state_dict(), 'teacher_model_weights_ckd_wider.pth')
# # Load the studnet
# student_model = torch.load('student_model_ckd_prof.pth')
# student_model.load_state_dict(torch.load('student_model_weights_ckd_prof_checkpoint.pth'))
# student_model = student_model.to(device)


This is the initialization of the 2-layer Adversary Perceptron. It is initialized with the number of classes*2, which represents the predicted labels (y_hat) and the true labels (y). The output of the final layer is a regression output, which is intended to predict the strength of gender (continuous number where anything past 0.5 is more male).


In [41]:
class Adversary(nn.Module):
    def __init__(self, input_size=num_classes):
        super(Adversary, self).__init__()

        self.a1 = nn.Linear(input_size*2, 16)
        self.a2 = nn.Linear(16, 1)  # Output size 1 for regression
        nn.init.xavier_normal_(self.a1.weight)
        nn.init.kaiming_normal_(self.a2.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, input_ids):
        adversary = F.relu(self.a1(input_ids))
        adversary_output = F.sigmoid(self.a2(adversary))  # Linear activation for regression
        return adversary_output

# Instantiate the Adversary
adv = Adversary()

In [42]:
def pretrain_student(student, teacher, trainloader, criterion, optimizer, device, alpha, temperature, epochs_pretrain, patience=patience_student):
    teacher.eval()
    teacher.to(device)
    best_val_loss = float('inf')  
    patience_counter = 0 
    student_epoch_losses = []
    val_losses = []
    
    for epoch in range(epochs_pretrain):
        student.train()
        student.to(device)
        running_loss = 0.0 
        epoch_loss = 0.0  
        num_batches = 0  
        
        for index, data in enumerate(tqdm(trainloader)):

            inputs = data['img'].to(device)
            labels = data['label'].to(device)
            optimizer.zero_grad()
            student_outputs = student(inputs)

            with torch.no_grad():
                teacher_outputs = teacher(inputs)

            ce_loss = criterion(student_outputs, labels)
            kd_loss = tkd_kdloss(student_outputs, teacher_outputs, temperature=temperature)  # Make sure this returns a scalar
            
            # If not scalar, sum up to make sure the loss is scalar
            if kd_loss.ndim != 0:
                kd_loss = kd_loss.sum()
                
            # Now combine the losses
            loss = alpha * kd_loss + (1 - alpha) * ce_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            epoch_loss += loss.item()
            num_batches += 1
        
        epoch_loss /= num_batches
        print(f'*******Epoch {epoch}: loss - {epoch_loss}')
        student_epoch_losses.append(epoch_loss)


In [43]:
def pretrain_teacher(teacher, trainloader, criterion, optimizer, device, epochs_pretrain, patience=patience_student):
    teacher.to(device)
    teacher.train()  # Set the model to training mode
    best_val_loss = float('inf')  
    patience_counter = 0 
    teacher_epoch_losses = []
    val_losses = []
    
    for epoch in range(epochs_pretrain):
        running_loss = 0.0 
        epoch_loss = 0.0  
        num_batches = 0  
        
        for index, data in enumerate(tqdm(trainloader)):
            inputs = data['img'].to(device)
            labels = data['label'].to(device)
            optimizer.zero_grad()
            teacher_outputs = teacher(inputs)

            ce_loss = criterion(teacher_outputs, labels)
                
            loss = ce_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            epoch_loss += loss.item()
            num_batches += 1
        
        epoch_loss /= num_batches
        print(f'*******Epoch {epoch}: loss - {epoch_loss}')
        teacher_epoch_losses.append(epoch_loss)


In [44]:
def pretrain_adversary(adv, student, adversary_optimizer, trainloader, adv_criterion, device, epochs_pretrain):

  for epoch in range(epochs_pretrain):
    epoch_loss = 0
    epoch_batches = 0
    for i, data in enumerate(tqdm(trainloader)): # starting from the 0th batch
        # get the inputs and labels
        adv.train()
        adv.to(device)
        inputs = data['img'].to(device)
        labels = data['label'].to(device)
        targets = data['target'].to(device)
        student = student.to(device)
        adversary_optimizer.zero_grad()
        student_output = student(inputs)
        one_hot_labels = F.one_hot(labels, num_classes=num_classes).to(torch.float32)
        concatenated_output = torch.cat((student_output, one_hot_labels), dim=1)
        adversary_output = adv(concatenated_output)
        adversary_loss = adv_criterion(adversary_output, targets) # compute loss
        adversary_loss.backward() # back prop
        adversary_optimizer.step()
        epoch_loss += adversary_loss.item()
        epoch_batches += 1

    print("Average Pretrain Adversary epoch loss: ", epoch_loss/epoch_batches)


In [45]:
# Optimizer and scheduler for the student model
student_optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

# Optimizer and scheduler for the teacher model
teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=learning_rate, momentum=momentum)
teacher_scheduler = torch.optim.lr_scheduler.StepLR(teacher_optimizer, step_size=step_size, gamma=gamma)

optimizer_adv = optim.Adam(adv.parameters(), lr=learning_rate)

# Instantiate the model and the loss function
criterion_clf = nn.CrossEntropyLoss()
adv_criterion = nn.MSELoss()


In [46]:
# #### finding the optimal learning rate
# def train_teacher_optimal_lr(model, trainloader, criterion, optimizer, scheduler, device, epochs_optimal_lr=5, lr_range=(1e-4, 1e-1), plot_loss=True):
#     model.train()
#     model.to(device)
#     lr_values = np.logspace(np.log10(lr_range[0]), np.log10(lr_range[1]), epochs_optimal_lr * len(trainloader))  # Generate learning rates for each batch
#     lr_iter = iter(lr_values)
#     losses = []
#     lrs = []
    
#     for epoch in range(epochs_optimal_lr):
#         for i, batch in enumerate(tqdm(trainloader)):
#             lr = next(lr_iter)
#             for param_group in optimizer.param_groups:
#                 param_group['lr'] = lr  # Set new learning rate
            
#             inputs, labels = batch['img'].to(device), batch['label'].to(device)
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
            
#             losses.append(loss.item())
#             lrs.append(lr)
    
#     # Calculate the derivative of the loss
#     loss_derivative = np.gradient(losses)
    
#     # Find the learning rate corresponding to the minimum derivative (steepest decline)
#     best_lr_index = np.argmin(loss_derivative)
#     best_lr = lrs[best_lr_index]
    
#     if plot_loss:
#         import matplotlib.pyplot as plt
#         plt.figure()
#         plt.plot(lrs, losses)
#         plt.xscale('log')
#         plt.xlabel('Learning Rate')
#         plt.ylabel('Loss')
#         plt.title('Learning Rate Range Test - Teacher')
#         plt.axvline(x=best_lr, color='red', linestyle='--', label=f'Best LR: {best_lr}')
#         plt.legend()
#         plt.show()
    
#     print(f'Best Learning Rate Teacher: {best_lr}')
#     return best_lr

# ############# input ############## 
# best_lr_teacher = train_teacher_optimal_lr(teacher_model, trainloader, criterion_clf, teacher_optimizer, teacher_scheduler, device, epochs_optimal_lr)  
# print(best_lr_teacher)


In [47]:
# #### finding the optimal learning rate
# def train_student_optimal_lr(model, trainloader, criterion, optimizer, device, epochs_optimal_lr=5, lr_range=(1e-4, 1e-1), plot_loss=True):
#     model.train()
#     model.to(device)
#     lr_values = np.logspace(np.log10(lr_range[0]), np.log10(lr_range[1]), epochs_optimal_lr * len(trainloader))  # Generate learning rates for each batch
#     lr_iter = iter(lr_values)
#     losses = []
#     lrs = []
    
#     for epoch in range(epochs_optimal_lr):
#         for i, batch in enumerate(tqdm(trainloader)):
#             lr = next(lr_iter)
#             for param_group in optimizer.param_groups:
#                 param_group['lr'] = lr  # Set new learning rate
            
#             inputs, labels = batch['img'].to(device), batch['label'].to(device)
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
            
#             losses.append(loss.item())
#             lrs.append(lr)
    
#     # Calculate the derivative of the loss
#     loss_derivative = np.gradient(losses)
    
#     # Find the learning rate corresponding to the minimum derivative (steepest decline)
#     best_lr_index = np.argmin(loss_derivative)
#     best_lr = lrs[best_lr_index]
    
#     if plot_loss:
#         import matplotlib.pyplot as plt
#         plt.figure()
#         plt.plot(lrs, losses)
#         plt.xscale('log')
#         plt.xlabel('Learning Rate')
#         plt.ylabel('Loss')
#         plt.title('Learning Rate Range Test - Student')
#         plt.axvline(x=best_lr, color='red', linestyle='--', label=f'Best LR: {best_lr}')
#         plt.legend()
#         plt.show()
    
#     print(f'Best Learning Rate Student: {best_lr}')
#     return best_lr

# ############# input ############## 
# best_lr_student = train_student_optimal_lr(student_model, trainloader, criterion_clf, student_optimizer, device, epochs_optimal_lr)  
# print(best_lr_student)

In [48]:
best_lr_student = 0.09999999999999999
best_lr_teacher = 9.999999999999999e-05

In [49]:
def plot_loss_curve(losses):
    epochs = range(1, len(losses) + 1)
    plt.plot(epochs, losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Val Loss Curve')
    plt.show()

In [50]:
# This is the adversary training function, where we input the student outputs, 
# with the true labels into the adversary model created previously.
def train_adversary(adv, model, optimizer, trainloader, criterion, epochs):
    best_val_loss = float('inf')
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_batches = 0
        for i, data in enumerate(tqdm(trainloader)):
            # get the inputs and labels
            inputs = data['img'].to(device)
            labels = data['label'].to(device)
            targets = data['target'].to(device)
            model.eval()
            model.to(device)
            adv.train()
            adv.to(device)
            optimizer.zero_grad()
            # output the student model, join with ohe labels. 
            model_output = model(inputs)
            one_hot_labels = F.one_hot(labels, num_classes=num_classes).to(torch.float32)
            concatenated_output = torch.cat((model_output, one_hot_labels), dim=1)
            adversary_output = adv(concatenated_output)

            adversary_loss = criterion(adversary_output, targets)
            adversary_loss.backward()
            epoch_loss += adversary_loss.item()
            epoch_batches += 1
            optimizer.step()
        epoch_loss/=epoch_batches
        print("Average Adversary epoch loss:", epoch_loss)

In [51]:
# Function to train the teacher model
def train_teacher(model_name, dataset, model, adv, trainloader, criterion, adv_criterion, optimizer, optimizer_adv, device, 
                  epochs, lmda, patience=patience_teacher):

    best_val_loss = float('inf')
    patience_counter = 0
    epoch_losses = [] 
    val_losses = []
    val_disparities = []
    val_accuracies = []
    best_total_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        adv.train()
        model.to(device)
        adv.to(device)
        running_loss = 0.0
        epoch_loss = 0.0  
        num_batches = 0  
        
        for index, data in enumerate(tqdm(trainloader)):
            inputs = data['img'].to(device)
            labels = data['label'].to(device)
            targets = data['target'].to(device)
        
            # Forward pass for teacher model
            optimizer.zero_grad()
            outputs = model(inputs)
            classification_loss = criterion(outputs, labels)
        
            # Forward pass for adversary model
            optimizer_adv.zero_grad()
            with torch.no_grad():
                teacher_outputs_detached = outputs.detach()
            one_hot_labels = F.one_hot(labels, num_classes=num_classes).to(torch.float32)
            concatenated_output = torch.cat((teacher_outputs_detached, one_hot_labels), dim=1)
            adversary_output = adv(concatenated_output)
            adversary_loss = adv_criterion(adversary_output, targets)
        
            # Calculate the total loss by combining classification and adversary loss
            if lmda != 0:
                total_loss = classification_loss + classification_loss/adversary_loss - lmda * adversary_loss
            else:
                total_loss = classification_loss
                
            total_loss.backward()
        
            optimizer.step()
            optimizer_adv.step()
        
            running_loss += total_loss.item()
            epoch_loss += total_loss.item()
            num_batches += 1

        epoch_loss /= num_batches  
        epoch_losses.append(epoch_loss)

        model.eval()
        total_correct = 0
        total_samples = 0
        total_val_loss = 0.0
        num_batches = 0
        confusion_male = np.zeros((num_classes, num_classes))
        confusion_female = np.zeros((num_classes, num_classes))
        
        with torch.no_grad():
            for val_data in tqdm(testloader):
                val_inputs = val_data['img'].to(device)
                val_labels = val_data['label'].to(device)
                val_targets = val_data['target'].to(device)
                
                val_outputs = model(val_inputs)
                with torch.no_grad():
                    teacher_outputs_detached_val = val_outputs.detach()                
                one_hot_labels_val = F.one_hot(val_labels, num_classes=num_classes).to(torch.float32)
                concatenated_output_val = torch.cat((teacher_outputs_detached_val, one_hot_labels_val), dim=1)
                adversary_output_val = adv(concatenated_output_val)
                adversary_loss_val = adv_criterion(adversary_output_val, val_targets)
                
                # Compute validation loss
                val_ce_loss = criterion(val_outputs, val_labels)

                if lmda !=0:
                    val_loss = val_ce_loss + val_ce_loss/adversary_loss_val - lmda * adversary_loss_val
                else:
                    val_loss = val_ce_loss
                    
                total_val_loss += val_loss.item()

                # Compute the validation accuracy
                _, predicted = torch.max(val_outputs, 1)
                total_samples += val_labels.size(0)
                total_correct += (predicted == val_labels).sum().item()
                num_batches += 1

                # Compute recall differences for gender
                recall_diff = evaluate_model_with_gender_multiclass(predicted, val_labels, val_targets, num_classes=num_classes)
                confusion_male += recall_diff[1]
                confusion_female += recall_diff[2]

            total_val_loss /= num_batches
            confusion_male /= num_batches
            confusion_female /= num_batches
            
            epoch_disparity = calculate_recall_multiclass(confusion_male) - calculate_recall_multiclass(confusion_female)
            val_losses.append(total_val_loss)
            non_zero_abs_values = np.abs(epoch_disparity[epoch_disparity != 0])
            mean_non_zero_abs_disparity = np.mean(non_zero_abs_values)
            val_disparities.append(mean_non_zero_abs_disparity)
            accuracy = total_correct / total_samples
            val_accuracies.append(accuracy)
            print(f'*****Epoch {epoch + 1}/{epochs}*****\n' 
            f'*****Train Loss: {epoch_loss: .6f} Val Loss: {total_val_loss: .6f}*****\n'
            f'*****Validation Accuracy: {accuracy * 100:.2f}%*****\n'
            f'*****Total Avg Disparity: {mean_non_zero_abs_disparity}*****\n')
            class_recall_mapping = {class_name: epoch_disparity[int(class_label)] for class_label, class_name in class_idx.items()}
            
            # Print disparities by class label
            for class_label, recall_diff in class_recall_mapping.items():
                print(f"Class {class_label}: Recall Difference = {recall_diff}")
        
        # Check for early stopping
        if abs(total_val_loss) < abs(best_total_val_loss):
            best_total_val_loss = total_val_loss
            patience_counter = 0 

            # checkpoint
            save_path = './weights/'

            # save locally
            model_save_path = os.path.join(save_path, model_name)
            
            os.makedirs(model_save_path, exist_ok=True)
        
            model_save_name = os.path.join(model_save_path, f'checkpoint{lmda}.pth')
            mode_weights_name = os.path.join(model_save_path, f'weights{lmda}.pth')
            
            best_epoch_mean_abs_disparity = mean_non_zero_abs_disparity

            torch.save(model.state_dict(), mode_weights_name)
            torch.save(model, model_save_name)

            # push to s3
            session = boto3.session.Session()
            s3 = session.client('s3')
            
            bucket_name = '210bucket' 
            
            # Teacher Model
            #### IMPORTANT!!!!! Change the file name so that you do not overwrite the existing files
            teacher_model_weights_path = f'weights/teacher_model_weights_{model_name}_{dataset}_{lmda}.pth'
            teacher_model_path = f'models/testing_teacher_model_{model_name}_{dataset}_{lmda}.pth'
            
            # Save state dict to buffer
            teacher_model_weights_buffer = io.BytesIO()
            torch.save(teacher_model.state_dict(), teacher_model_weights_buffer)
            teacher_model_weights_buffer.seek(0)
            
            # Save entire model to buffer
            teacher_model_buffer = io.BytesIO()
            torch.save(teacher_model, teacher_model_buffer)
            teacher_model_buffer.seek(0)
            
            # Upload to S3
            s3.put_object(Bucket=bucket_name, Key=teacher_model_weights_path, Body=teacher_model_weights_buffer)
            s3.put_object(Bucket=bucket_name, Key=teacher_model_path, Body=teacher_model_buffer)
            print('teacher weights and architecture saved and exported to S3')


        else:
            patience_counter += 1 

        if patience_counter >= patience:
            print('Early stopping')
            break  

        file_path = os.path.join(output_dir, f'teacher_validation_{lmda}.txt')
        
        # Append data to the text file
        with open(file_path, 'a') as file:
            file.write(f'********Epoch: {epochs}***********')
            
            file.write("Teacher Val Accuracies:\n")
            for accuracy in val_accuracies:
                file.write(f"{accuracy}\n")
        
            file.write("\nTeacher Val Disparities:\n")
            for disparity in val_disparities:
                file.write(f"{disparity}\n")

            for class_label, recall_diff in class_recall_mapping.items():
                file.write(f"Class {class_label}: Recall Difference = {recall_diff}\n")
        
        print(f"Data has been appended to {file_path}")
    
    plot_loss_curve(val_losses)
    print("Finished Training Teacher")
    return val_disparities


## Extract Teacher Class Mean Embeddings

In [None]:
def get_emb_fea(model, dataloader, batch_size):
    ''' Used to extract the feature embeddings in a teacher model '''
    
    def get_features(name):
        def hook(model, input, output):
            features[name] = output.detach()
        return hook

    
    model.eval()
    model.avgpool.register_forward_hook(get_features('feats'))

    EMB = {}


    with torch.no_grad():
        for index, data in enumerate(tqdm(trainloader)):
            FEATS = []
            features = {}
            
            inputs = data['img'].to(device)
            labels = data['label'].to(device)

            curr_batch_size = len(inputs)

            # compute output
            # emb_fea, logits = model(images, embed=True)
            outputs = model(inputs)
            # feats = features['feats'].cpu().numpy()
            # emb_fea = feats.flatten()
            FEATS.append(features['feats'].cpu().numpy())
            emb_fea = np.concatenate(FEATS)
            # reshape embedding features to flatten 
            emb_fea = emb_fea.reshape((curr_batch_size, emb_fea.shape[1]))


            for emb, i in zip(emb_fea, labels):
                i = i.item()
                emb_size = len(emb) 
                if str(i) in EMB:
                    for j in range(emb_size):
                        EMB[str(i)][j].append(round(emb[j].item(), 4))
                else:
                    EMB[str(i)] = [[] for _ in range(emb_size)]
                    for j in range(emb_size):
                        EMB[str(i)][j].append(round(emb[j].item(), 4))

    for key, value in EMB.items():
        for i in range(emb_size):
            EMB[key][i] = round(np.array(EMB[key][i]).mean(), 4)

    return EMB


def retrieve_teacher_class_weights(model_name, model, model_weight_path, num_class, data_name, dataloader, batch_size, bucket_name, lmda):
    ''' Use the extracted feature embeddings to create a json of class means for teacher'''

    session = boto3.session.Session()
    s3 = session.client('s3')

    teacher_model_weights_buffer = io.BytesIO()
    s3.download_fileobj(bucket_name, model_weight_path, teacher_model_weights_buffer)
    teacher_model_weights_buffer.seek(0)  

    # Load the model
    # model = models_package.__dict__[model_name](num_class=num_class)
    checkpoint = torch.load(teacher_model_weights_buffer)
    # print("Keys in checkpoint:", checkpoint.keys())
    print("model is loaded properly")

    new_state_dict = OrderedDict()
    for k, v in checkpoint.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.eval()

    for param in model.parameters():
        param.requires_grad = False
    
    model = model.cuda()

    # emb = get_emb_fea(model=model, dataloader=dataloader, batch_size=batch_size)
    # emb_json = json.dumps(emb, indent=4)
    # with open("./class_means/{}_embedding_fea/{}.json".format(data_name, model_name), 'w', encoding='utf-8') as f:
    #     f.write(emb_json)

    emb = get_emb_fea(model=model, dataloader=dataloader, batch_size=batch_size)
    emb_json = json.dumps(emb, indent=4)

    # Create the directory if it doesn't exist
    output_dir = "./class_means/{}_embedding_fea".format(data_name)
    os.makedirs(output_dir, exist_ok=True)

    with open("{}/{}_lmda{}.json".format(output_dir, model_name, lmda), 'w', encoding='utf-8') as f:
        f.write(emb_json)

## KD training

In [None]:
# # Function to train the student model with knowledge distillation
# def train_student_with_distillation_disparity(student, teacher, adv, trainloader, testloader, criterion, adv_criterion, optimizer, 
#                                               device, alpha, temperature, epochs, lmda, patience=patience_student, optimizer_adv=None):
#     teacher.eval()
#     teacher.to(device)
#     best_val_accuracy = 0
#     best_total_val_loss = float('inf')
#     best_epoch_accuracy = 0.0
#     best_epoch_disparity = 0.0
#     patience_counter = 0 
#     student_epoch_losses = []
#     val_losses = []
#     val_disparities = []
#     val_accuracies = []

#     for epoch in range(epochs):
#         # Train the adversary at the start of each epoch
#         train_adversary(adv, student, optimizer_adv, trainloader, adv_criterion, 1)

#         student.train()
#         student.to(device)
#         adv.eval()
#         adv.to(device)
#         running_loss = 0.0 
#         epoch_loss = 0.0  
#         num_batches = 0 
#         confusion_male = np.zeros((num_classes, num_classes))
#         confusion_female = np.zeros((num_classes, num_classes))

#         for index, data in enumerate(tqdm(trainloader)):
#             inputs = data['img'].to(device)
#             labels = data['label'].to(device)
#             targets = data['target'].to(device)
#             optimizer.zero_grad()
#             student_outputs = student(inputs)
#             with torch.no_grad():
#                 teacher_outputs = teacher(inputs)

#             # detach student_outputs to avoid exploding gradients by passing same inputs (with gradience) into two different models. 
#             studentached = student_outputs.detach()
#             # One-hot encode labels and concatenate with student's predictions
#             one_hot_labels = F.one_hot(labels, num_classes=num_classes).to(torch.float32)
#             concatenated_output = torch.cat((studentached, one_hot_labels), dim=1)

#             # Run the adversarial model on concatenated true labels, and predicted labels
#             with torch.no_grad():
#                 adversary_output = adv(concatenated_output)

#             # Calc adversary loss, which is an MSE loss, because this is a regression output. 
#             adversary_loss = adv_criterion(adversary_output, targets)
#             ce_loss = criterion(student_outputs, labels)
#             kd_loss = tkd_kdloss(student_outputs, teacher_outputs, temperature=temperature)  # Make sure this returns a scalar
            
#             if kd_loss.ndim != 0:
#                 kd_loss = kd_loss.sum()

#             # Now combine the losses, subtract weighted adversary loss because we need to maximize that loss 
#             # goal of the model is to have the adversary not predict gender. 
#             if lmda != 0:
#                 loss = cls_loss + div_loss + norm_dir_loss + (cls_loss + div_loss + norm_dir_loss)/adversary_loss - lmda * adversary_loss
#                 # loss = (alpha * kd_loss + (1 - alpha) * ce_loss) + (alpha * kd_loss + (1 - alpha) * ce_loss)/adversary_loss - lmda * adversary_loss
#             else:
#                 # loss = alpha * kd_loss + (1 - alpha) * ce_loss
#                 loss = cls_loss + div_loss + norm_dir_loss
                
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
#             epoch_loss += loss.item()
#             num_batches += 1

#         epoch_loss /= num_batches
#         # print(f'*******Epoch {epoch}: running_recall_with - {running_recall_with/num_batches}  |  running_recall_without - {running_recall_without/num_batches}  |  disparity - {epoch_disparity/num_batches}******')
#         student_epoch_losses.append(epoch_loss)

#         student.eval()
#         total_correct = 0
#         total_samples = 0
#         total_val_loss = 0.0
#         num_batches = 0
#         # Validation after each epoch
#         with torch.no_grad():
#             for val_data in tqdm(testloader):
#                 val_inputs = val_data['img'].to(device)
#                 val_labels = val_data['label'].to(device)
#                 val_targets = val_data['target'].to(device)
                
#                 # Forward pass for validation
#                 val_student_outputs = student(val_inputs)
#                 val_teacher_outputs = teacher(val_inputs)

#                 val_studentached = val_student_outputs.detach()   
#                 val_one_hot_labels = F.one_hot(val_labels, num_classes=num_classes).to(torch.float32)
#                 val_concatenated_output = torch.cat((val_studentached, val_one_hot_labels), dim=1)
                
#                 val_adversary_output = adv(val_concatenated_output)
#                 val_adversary_loss = adv_criterion(val_adversary_output, val_targets)
#                 val_ce_loss = criterion(val_student_outputs, val_labels)
#                 val_kd_loss = tkd_kdloss(val_student_outputs, val_teacher_outputs, temperature=temperature)  # Make sure this returns a scalar
                
#                 if val_kd_loss.ndim != 0:
#                     val_kd_loss = val_kd_loss.sum()
#                 if lmda != 0:
#                     val_loss = (alpha * val_kd_loss + (1 - alpha) * val_ce_loss) + (alpha * val_kd_loss + (1 - alpha) * val_ce_loss)/val_adversary_loss - lmda * val_adversary_loss
#                 else:
#                     val_loss = alpha * val_kd_loss + (1 - alpha) * val_ce_loss


#                 # if lmda != 0:
#                 #     loss = cls_loss + div_loss + norm_dir_loss + (cls_loss + div_loss + norm_dir_loss)/adversary_loss - lmda * adversary_loss
#                 #     # loss = (alpha * kd_loss + (1 - alpha) * ce_loss) + (alpha * kd_loss + (1 - alpha) * ce_loss)/adversary_loss - lmda * adversary_loss
#                 # else:
#                 #     # loss = alpha * kd_loss + (1 - alpha) * ce_loss
#                 #     loss = cls_loss + div_loss + norm_dir_loss
                

                
#                 total_val_loss += val_loss.item()
    
#                 # Compute the validation accuracy
#                 _, predicted = torch.max(val_student_outputs, 1)
#                 total_samples += val_labels.size(0)
#                 total_correct += (predicted == val_labels).sum().item()
#                 num_batches += 1
#                 recall_diff = evaluate_model_with_gender_multiclass(predicted, val_labels, val_targets, num_classes=num_classes)
#                 confusion_male += recall_diff[1]
#                 confusion_female += recall_diff[2]
    
#             total_val_loss /= num_batches
#             confusion_male /= num_batches
#             confusion_female /= num_batches

#             epoch_disparity = calculate_recall_multiclass(confusion_male) - calculate_recall_multiclass(confusion_female)
#             val_losses.append(total_val_loss)
#             non_zero_abs_values = np.abs(epoch_disparity[epoch_disparity != 0])
#             mean_non_zero_abs_disparity = np.mean(non_zero_abs_values)
#             val_disparities.append(mean_non_zero_abs_disparity)
#             accuracy = total_correct / total_samples
#             val_accuracies.append(accuracy)
#             print(f'*****Epoch {epoch + 1}/{epochs}*****\n' 
#             f'*****Train Loss: {epoch_loss: .6f} Val Loss: {total_val_loss: .6f}*****\n'
#             f'*****Validation Accuracy: {accuracy * 100:.2f}%*****\n'
#             f'*****Total Avg Disparity: {mean_non_zero_abs_disparity}*****\n')
#             class_recall_mapping = {class_name: epoch_disparity[int(class_label)] for class_label, class_name in class_idx.items()}
            
#             # Print disparities by class label
#             for class_label, recall_diff in class_recall_mapping.items():
#                 print(f"Class {class_label}: Recall Difference = {recall_diff}")

#         # Check for early stopping
#         if abs(total_val_loss) < abs(best_total_val_loss):
#             best_total_val_loss = total_val_loss
#             patience_counter = 0
#             best_epoch_mean_abs_disparity = mean_non_zero_abs_disparity
#             torch.save(student.state_dict(), f'student_model_weights_ckd_wider_checkpoint_lambda{lmda}.pth')
#             torch.save(student, f'student_model_ckd_wider_checkpoint_lambda{lmda}.pth')
#         else:
#             patience_counter += 1 

#         if patience_counter >= patience:
#             print('Early stopping')
#             break  
    
#         file_path = os.path.join(output_dir, f'student_validation_{lmda}.txt')
        
#         # Append data to the text file
#         with open(file_path, 'a') as file:
#             file.write(f'********Epoch: {epochs}***********')
            
#             file.write("Student Val Accuracies:\n")
#             for accuracy in val_accuracies:
#                 file.write(f"{accuracy}\n")
        
#             file.write("\nStudent Val Disparities:\n")
#             for disparity in val_disparities:
#                 file.write(f"{disparity}\n")

#             for class_label, recall_diff in class_recall_mapping.items():
#                 file.write(f"Class {class_label}: Recall Difference = {recall_diff}\n")
        
        
#         print(f"Data has been appended to {file_path}")
#     plot_loss_curve(val_losses)
                
#     return best_epoch_mean_abs_disparity

In [None]:
##### SET VARIABLES ####
data_name = 'WIDER'
model_name = 'efficientnetb3'

In [52]:
# Initialize the dictionary for results
lambda_results = {}

# Loop for training the teacher model with different lambda values
for i in lmda_list_teacher:
    # Reset the teacher model for each lambda
    teacher_model = torchvision.models.efficientnet_b3(weights='DEFAULT')    
    # Replace the last fully connected layer with a new one
    teacher_model.classifier = nn.Linear(1536, num_classes)
    teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=best_lr_teacher)
    
    # Initialize the adversary for the teacher
    adv = Adversary()
    teacher_optimizer_adv = optim.Adam(adv.parameters(), lr=best_lr_teacher)

    # pretrain_teacher(teacher_model, trainloader, criterion_clf, teacher_optimizer, device, epochs_pretrain)
    # pretrain_adversary(adv, student_model, optimizer_adv, trainloader, adv_criterion, device, epochs_pretrain)
    # lmda = i

    # Train the teacher model with adversarial training
    teacher_mean_abs_val_disparity = train_teacher(model_name, data_name, teacher_model, adv, trainloader, criterion_clf, adv_criterion, teacher_optimizer, teacher_optimizer_adv, device, epochs, i, patience=patience_teacher)

    # extract class mean embeddings for teacher
    retrieve_teacher_class_weights(model_name, teacher_model, teacher_model_weights_path, num_classes, data_name, testloader, batch_size, bucket_name, i)

    # Save the teacher model and its state
    # torch.save(teacher_model.state_dict(), f'teacher_model_weights_ckd_wider_lambda{i}.pth')
    # torch.save(teacher_model, f'teacher_model_ckd_wider_lambda{i}.pth')
    print('Teacher weights and architecture saved and exported for lambda:', i)

    # Store the teacher results in the dictionary
    lambda_results[i] = {
        'teacher_mean_abs_val_disparity': teacher_mean_abs_val_disparity
    }


  0%|                                                                                                | 0/323 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 0 has a total capacty of 14.75 GiB of which 59.06 MiB is free. Process 2487 has 12.90 GiB memory in use. Including non-PyTorch memory, this process has 1.79 GiB memory in use. Of the allocated memory 1.64 GiB is allocated by PyTorch, and 31.58 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
## Training script

def train(student, teacher, T_EMB, train_dataloader, optimizer, criterion, kd_loss, nd_loss, epoch, batch_size, alpha, temperature, adv, adv_criterion, optimizer_adv=None):

    def get_features(name):
        def hook(model, input, output):
            features[name] = output.detach()
        return hook
    
    train_loss = AverageMeter()
    train_error = AverageMeter()

    Cls_loss = AverageMeter()
    Div_loss = AverageMeter()
    Norm_Dir_loss = AverageMeter()

    best_val_accuracy = 0
    best_total_val_loss = float('inf')
    best_epoch_accuracy = 0.0
    best_epoch_disparity = 0.0
    patience_counter = 0 
    student_epoch_losses = []
    val_losses = []
    val_disparities = []
    val_accuracies = []

    # First train adversary in this epoch
    train_adversary(adv, student, optimizer_adv, trainloader, adv_criterion, 1)

    # test T_EMB
    T_EMB = T_EMB

    # Model on train mode
    student.train()
    teacher.eval()
    running_loss = 0.0 
    epoch_loss = 0.0  
    num_batches = 0 
    confusion_male = np.zeros((num_classes, num_classes))
    confusion_female = np.zeros((num_classes, num_classes))
 
    step_per_epoch = len(train_dataloader)

    for step, data in enumerate(tqdm(train_dataloader)):
        
        start = time.time()
        s_FEATS = []
        features = {}

        inputs = data['img'].to(device)
        labels = data['label'].to(device)
        targets = data['target'].to(device)

        curr_batch_size = len(inputs)

        # register hook for feature embeddings
        student.avgpool.register_forward_hook(get_features('feats'))
        
        # compute output
        optimizer.zero_grad()
        s_logits = student(inputs)

        s_FEATS.append(features['feats'].cpu().numpy())
        s_emb = np.concatenate(s_FEATS)
        # print(f'before reshaping s_emb: {s_emb.shape}')
        # reshape embedding features to flatten 
        s_emb = s_emb.reshape((curr_batch_size, s_emb.shape[1]))
        s_emb = torch.from_numpy(s_emb)
        s_emb = s_emb.to(device)

        # fix embedding output on student model
        s_emb_size = 1280
        t_emb_size = 1536
        
        emb_inflate = nn.Sequential(
            nn.BatchNorm1d(s_emb_size),
            nn.Dropout(0.5),
            nn.Linear(s_emb_size, t_emb_size)
            )
        # # clean_model
        for m in student.modules():
            if isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
    
        ##
        emb_inflate.to(device)

        s_emb = emb_inflate(s_emb)
        # student.fc1.to(device)
        
        # s_emb = student.fc1(s_emb)

        with torch.no_grad():
                        
            ####
            
            t_FEATS = []
            features = {}
    
            # compute output
            # emb_fea, logits = student(images, embed=True)
            teacher.avgpool.register_forward_hook(get_features('feats'))
            
            t_logits = teacher(inputs)
    
            t_FEATS.append(features['feats'].cpu().numpy())
            t_emb = np.concatenate(t_FEATS)
            # reshape embedding features to flatten 
            t_emb = t_emb.reshape((curr_batch_size, t_emb.shape[1]))


        ## save s_emb and t_emb as torch tensors 
        # s_emb = torch.from_numpy(s_emb)
        t_emb = torch.from_numpy(t_emb)

        # s_emb = s_emb.to(device)
        t_emb = t_emb.to(device)


        # print(s_emb.size() == s_emb.size())
        # print(s_emb.size())
        # print(s_emb.size())
        
        ###

        # cls loss
        cls_loss = criterion(s_logits, labels) * cls_loss_factor
        # KD loss
        div_loss = kd_loss(s_out = s_logits, t_out = t_logits) * min(1.0, epoch/warm_up)
        # ND loss
        norm_dir_loss = nd_loss(s_emb=s_emb, t_emb=t_emb, T_EMB=T_EMB, labels=labels)

        loss = cls_loss + div_loss + norm_dir_loss
        # measure accuracy and record loss
        batch_size = inputs.size(0)
        _, pred = s_logits.data.cpu().topk(1, dim=1)
        train_error.update(torch.ne(pred.squeeze(), labels.cpu()).float().sum().item() / batch_size, batch_size)
        train_loss.update(loss.item(), batch_size)

        Cls_loss.update(cls_loss.item(), batch_size)
        Div_loss.update(div_loss.item(), batch_size)
        Norm_Dir_loss.update(norm_dir_loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        t = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        s1 = '\r{} [{}/{}]'.format(t, step+1, step_per_epoch)
        s2 = ' - {:.2f}ms/step - nd_loss: {:.3f} - kd_loss: {:.3f} - cls_loss: {:.3f} - train_loss: {:.3f} - train_acc: {:.3f}'.format(
             1000 * (time.time() - start), norm_dir_loss.item(), div_loss.item(), cls_loss.item(), train_loss.val, 1-train_error.val)

        print(s1+s2, end='', flush=True)

    print()
    return Norm_Dir_loss.avg, Div_loss.avg, Cls_loss.avg, train_loss.avg, train_error.avg


def test(student, test_dataloader, criterion):
    test_loss = AverageMeter()
    test_error = AverageMeter()

    # Model on eval mode
    student.eval()


    with torch.no_grad():
        for step, data in enumerate(tqdm(test_dataloader)):

            inputs = data['img'].to(device)
            labels = data['label'].to(device)
            # compute logits
            logits = student(inputs)

            loss = criterion(logits, labels)

            # measure accuracy and record loss
            batch_size = inputs.size(0)
            _, pred = logits.data.cpu().topk(1, dim=1)
            test_error.update(torch.ne(pred.squeeze(), labels.cpu()).float().sum().item() / batch_size, batch_size)
            test_loss.update(loss.item(), batch_size)

    return test_loss.avg, test_error.avg


def epoch_loop(student, teacher, train_loader, test_loader, num_class, T_EMB, save_dir, batch_size, logger):

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # student = nn.DataParallel(student, device_ids=args.gpus)
    # student = nn.DataParallel(student)
    student = student
    student.to(device)
    # teacher = nn.DataParallel(teacher, device_ids=args.gpus)
    # teacher = nn.DataParallel(teacher)
    teacher = teacher
    teacher.to(device)

    # student.avgpool.register_forward_hook(get_features('s_feats'))
    # teacher.avgpool.register_forward_hook(get_features('t_feats'))

    # loss
    criterion = nn.CrossEntropyLoss().to(device)
    kd_loss = KDLoss(kl_loss_factor=kd_loss_factor, T=t).to(device)
    nd_loss = DirectNormLoss(num_class=num_class, nd_loss_factor=nd_loss_factor).to(device)
    # optimizer
    optimizer = torch.optim.SGD(params=student.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)

    # weights
    save_dir = Path(save_dir)
    weights = save_dir / 'weights'
    weights.mkdir(parents=True, exist_ok=True)
    last = weights / 'last'
    best = weights / 'best'

    # acc,loss
    acc_loss = save_dir / 'acc_loss'
    acc_loss.mkdir(parents=True, exist_ok=True)

    train_acc_savepath = acc_loss / 'train_acc.npy'
    train_loss_savepath = acc_loss / 'train_loss.npy'
    val_acc_savepath = acc_loss / 'val_acc.npy'
    val_loss_savepath = acc_loss / 'val_loss.npy'

    # tensorboard
    logdir = save_dir / 'logs'
    logdir.mkdir(parents=True, exist_ok=True)
    summary_writer = SummaryWriter(logdir, flush_secs=120)


    start_epoch = 0
    best_error = 0
    train_acc = []
    train_loss = []
    test_acc = []
    test_loss = []

    logger = logger

    # Train model
    best_error = 1
    for epoch in range(start_epoch, epochs):
        if epoch in [150, 180, 210]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        print("Epoch {}/{}".format(epoch + 1, epochs))
        norm_dir_loss, div_loss, cls_loss, train_epoch_loss, train_error = train(model=model,
                                                                                 teacher=teacher,
                                                                                 T_EMB=T_EMB,
                                                                                 train_dataloader=train_loader,
                                                                                 optimizer=optimizer,
                                                                                 criterion=criterion,
                                                                                 kd_loss=kd_loss,
                                                                                 nd_loss=nd_loss,
                                                                                 epoch=epoch,
                                                                                 batch_size = batch_size)
        test_epoch_loss, test_error = test(student=student,
                                           test_dataloader=test_loader,
                                           criterion=criterion)

        s = "Train Loss: {:.3f}, Train Acc: {:.3f}, Test Loss: {:.3f}, Test Acc: {:.3f}, lr: {:.5f}".format(
            train_epoch_loss, 1-train_error, test_epoch_loss, 1-test_error, optimizer.param_groups[0]['lr'])
        logger.info(colorstr('green', s))

        # save acc,loss
        train_loss.append(train_epoch_loss)
        train_acc.append(1-train_error)
        test_loss.append(test_epoch_loss)
        test_acc.append(1-test_error)

        # save model
        is_best = test_error < best_error
        best_error = min(best_error, test_error)
        state = {
                'epoch': epoch + 1,
                'model_state_dict': student.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_error': best_error,
                'train_acc': train_acc,
                'train_loss': train_loss,
                'test_acc': test_acc,
                'test_loss': test_loss,
            }

        last_path = last / 'epoch_{}_loss_{:.3f}_acc_{:.3f}'.format(
            epoch + 1, test_epoch_loss, 1-test_error)
        best_path = best / 'epoch_{}_acc_{:.3f}'.format(
                epoch + 1, 1-best_error)

        Save_Checkpoint(state, last, last_path, best, best_path, is_best)

        # tensorboard
        if epoch == 1:
            # images, labels = next(iter(train_loader))
            data = next(iter(train_loader))
            images = data['img'].to(device)
            labels = data['label'].to(device)

            img_grid = torchvision.utils.make_grid(images)
            summary_writer.add_image('Image', img_grid)
        summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        summary_writer.add_scalar('train_loss', train_epoch_loss, epoch)
        summary_writer.add_scalar('train_error', train_error, epoch)
        summary_writer.add_scalar('val_loss', test_epoch_loss, epoch)
        summary_writer.add_scalar('val_error', test_error, epoch)

        summary_writer.add_scalar('nd_loss', norm_dir_loss, epoch)
        summary_writer.add_scalar('kd_loss', div_loss, epoch)
        summary_writer.add_scalar('cls_loss', cls_loss, epoch)

    summary_writer.close()
    import os
    if not os.path.exists(train_acc_savepath) or not os.path.exists(train_loss_savepath):
        np.save(train_acc_savepath, train_acc)
        np.save(train_loss_savepath, train_loss)
        np.save(val_acc_savepath, test_acc)
        np.save(val_loss_savepath, test_loss)

In [None]:
# Loop for training the student model with different lambda values
for i in lmda_list_student:
    # load teacher model with lambda 0
    teacher_model = torch.load('teacher_model_ckd_wider_lambda0.pth')
    teacher_model.load_state_dict(torch.load('teacher_model_weights_ckd_wider_lambda0.pth'))
    teacher_model = teacher_model.to(device)
    
    # # Reset the student model for each lambda
    student_model = torchvision.models.efficientnet_b0(weights='DEFAULT')
    student_model.classifier = nn.Linear(1280, num_classes)
    student_optimizer = optim.Adam(student_model.parameters(), lr=best_lr_student)
    student_scheduler = torch.optim.lr_scheduler.StepLR(student_optimizer, step_size=step_size, gamma=gamma)
    
    adv = Adversary()
    adv.to(device)
    student_optimizer_adv = optim.Adam(adv.parameters(), lr=best_lr_student)

    pretrain_student(student_model, teacher_model, trainloader, criterion_clf, student_optimizer, device, alpha, temperature, epochs_pretrain)
    pretrain_adversary(adv, teacher_model, optimizer_adv, trainloader, adv_criterion, device, epochs_pretrain)
    
    student_mean_abs_val_disparity = train_student_with_distillation_disparity(student_model, teacher_model, adv, trainloader, testloader, criterion_clf, adv_criterion, student_optimizer, device, alpha, temperature, epochs, lmda=0, patience=patience_student, optimizer_adv=student_optimizer_adv)

    torch.save(student_model.state_dict(), f'student_model_weights_ckd_wider_lambda{i}.pth')
    torch.save(student_model, f'student_model_ckd_wider_lambda{i}.pth')
    print('Student weights and architecture saved and exported for lambda:', i)

    # Check if the key exists in the dictionary
    if i not in lambda_results:
        # If not, create a new entry for that key
        lambda_results[i] = {
            'student_mean_abs_val_disparity': student_mean_abs_val_disparity
        }
    else:
        # If the key exists, update the existing entry
        lambda_results[i].update({
            'student_mean_abs_val_disparity': student_mean_abs_val_disparity
        })


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]


*******Epoch 0: loss - 1.280115338586132


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]


*******Epoch 1: loss - 1.1886339979882565


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]


*******Epoch 2: loss - 1.1371552414775634


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:42<00:00,  1.57it/s]


Average Pretrain Adversary epoch loss:  0.37134633643656784


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:42<00:00,  1.57it/s]


Average Pretrain Adversary epoch loss:  0.3713579224133343


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:42<00:00,  1.57it/s]


Average Pretrain Adversary epoch loss:  0.3713367807198755


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:59<00:00,  2.70it/s]


Average Adversary epoch loss: 0.2714322510158053


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.72it/s]


*****Epoch 1/300*****
*****Train Loss:  1.086075 Val Loss:  0.719439*****
*****Validation Accuracy: 35.27%*****
*****Total Avg Disparity: 0.11554208707368566*****

Class Team_Sports: Recall Difference = 0.17157704209727342
Class Celebration: Recall Difference = -0.08604651162790697
Class Parade: Recall Difference = -0.08936412888508694
Class Waiter_Or_Waitress: Recall Difference = -0.01710587147480358
Class Individual_Sports: Recall Difference = -0.04066333484779644
Class Surgeons: Recall Difference = 0.07982261640798229
Class Spa: Recall Difference = -0.38
Class Law_Enforcement: Recall Difference = 0.18343343343343355
Class Business: Recall Difference = -0.054518297236743785
Class Dresses: Recall Difference = -0.19241192411924118
Class Water Activities: Recall Difference = 0.01859459459459467
Class Picnic: Recall Difference = 0.3818181818181818
Class Rescue: Recall Difference = -0.007341772151898733
Class Cheering: Recall Difference = 0.0
Class Performance_And_Entertainment: Recall Di

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:57<00:00,  2.79it/s]


Average Adversary epoch loss: 0.21479314639701608


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:20<00:00,  2.68it/s]


*****Epoch 2/300*****
*****Train Loss:  1.043859 Val Loss:  0.759058*****
*****Validation Accuracy: 33.39%*****
*****Total Avg Disparity: 0.09401443878396845*****

Class Team_Sports: Recall Difference = 0.11102936883283704
Class Celebration: Recall Difference = 0.018604651162790725
Class Parade: Recall Difference = -0.09404049044767615
Class Waiter_Or_Waitress: Recall Difference = 0.09061488673139154
Class Individual_Sports: Recall Difference = -0.024307133121308555
Class Surgeons: Recall Difference = 0.0672579453067258
Class Spa: Recall Difference = -0.1
Class Law_Enforcement: Recall Difference = 0.16216216216216228
Class Business: Recall Difference = -0.06774778619438816
Class Dresses: Recall Difference = 0.04471544715447151
Class Water Activities: Recall Difference = 0.10875675675675672
Class Picnic: Recall Difference = 0.309090909090909
Class Rescue: Recall Difference = 0.0
Class Cheering: Recall Difference = 0.0
Class Performance_And_Entertainment: Recall Difference = 0.0478746060

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.74it/s]


Average Adversary epoch loss: 0.2076936146209699


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.77it/s]


*****Epoch 3/300*****
*****Train Loss:  0.997721 Val Loss:  0.726690*****
*****Validation Accuracy: 34.66%*****
*****Total Avg Disparity: 0.09414410522145725*****

Class Team_Sports: Recall Difference = 0.10311204530857698
Class Celebration: Recall Difference = -0.07674418604651165
Class Parade: Recall Difference = 0.008611348731109247
Class Waiter_Or_Waitress: Recall Difference = -0.06749884419787328
Class Individual_Sports: Recall Difference = 0.03517340602756325
Class Surgeons: Recall Difference = -0.016999260901699953
Class Spa: Recall Difference = 0.0
Class Law_Enforcement: Recall Difference = 0.164914914914915
Class Business: Recall Difference = 0.00021337885415553814
Class Dresses: Recall Difference = -0.3997289972899729
Class Water Activities: Recall Difference = 0.05967567567567561
Class Picnic: Recall Difference = 0.12727272727272732
Class Rescue: Recall Difference = 0.0
Class Cheering: Recall Difference = 0.009433962264150943
Class Performance_And_Entertainment: Recall Diffe

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:57<00:00,  2.78it/s]


Average Adversary epoch loss: 0.19918108282622343


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.94it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.78it/s]


*****Epoch 4/300*****
*****Train Loss:  0.962554 Val Loss:  0.707927*****
*****Validation Accuracy: 37.98%*****
*****Total Avg Disparity: 0.09162178249284278*****

Class Team_Sports: Recall Difference = 0.04895194721784302
Class Celebration: Recall Difference = -0.10930232558139533
Class Parade: Recall Difference = -0.0017108639863130382
Class Waiter_Or_Waitress: Recall Difference = -0.007397133610725859
Class Individual_Sports: Recall Difference = -0.14553990610328632
Class Surgeons: Recall Difference = 0.03843311160384333
Class Spa: Recall Difference = -0.30000000000000004
Class Law_Enforcement: Recall Difference = 0.1011011011011011
Class Business: Recall Difference = -0.14083004374266508
Class Dresses: Recall Difference = -0.2791327913279133
Class Water Activities: Recall Difference = 0.07048648648648648
Class Picnic: Recall Difference = 0.018181818181818188
Class Rescue: Recall Difference = -0.014683544303797473
Class Cheering: Recall Difference = -0.010869565217391306
Class Perfo

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.76it/s]


Average Adversary epoch loss: 0.18353772265200288


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.72it/s]


*****Epoch 5/300*****
*****Train Loss:  0.917476 Val Loss:  0.774077*****
*****Validation Accuracy: 34.75%*****
*****Total Avg Disparity: 0.12822837710406237*****

Class Team_Sports: Recall Difference = 0.13763064167688438
Class Celebration: Recall Difference = 0.01395348837209301
Class Parade: Recall Difference = -0.060507556315939515
Class Waiter_Or_Waitress: Recall Difference = -0.09107720758206198
Class Individual_Sports: Recall Difference = -0.08121308496138124
Class Surgeons: Recall Difference = 0.04434589800443467
Class Spa: Recall Difference = -0.52
Class Law_Enforcement: Recall Difference = 0.1941941941941942
Class Business: Recall Difference = -0.14595113624239836
Class Dresses: Recall Difference = -0.3157181571815718
Class Water Activities: Recall Difference = 0.11156756756756758
Class Picnic: Recall Difference = 0.16363636363636364
Class Rescue: Recall Difference = 0.043291139240506316
Class Cheering: Recall Difference = -0.08470057424118131
Class Performance_And_Entertainm

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.73it/s]


Average Adversary epoch loss: 0.1852096799090042


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.75it/s]


*****Epoch 6/300*****
*****Train Loss:  0.871100 Val Loss:  0.740055*****
*****Validation Accuracy: 38.24%*****
*****Total Avg Disparity: 0.10772539327090783*****

Class Team_Sports: Recall Difference = 0.15171366847667433
Class Celebration: Recall Difference = -0.1953488372093024
Class Parade: Recall Difference = -0.07579127459366963
Class Waiter_Or_Waitress: Recall Difference = 0.00046232085067038353
Class Individual_Sports: Recall Difference = -0.003180372557928246
Class Surgeons: Recall Difference = 0.008130081300812941
Class Spa: Recall Difference = -0.36000000000000004
Class Law_Enforcement: Recall Difference = 0.16891891891891897
Class Business: Recall Difference = -0.130374479889043
Class Dresses: Recall Difference = -0.2777777777777778
Class Water Activities: Recall Difference = 0.020216216216216165
Class Picnic: Recall Difference = -0.18181818181818177
Class Rescue: Recall Difference = -0.00936708860759495
Class Cheering: Recall Difference = -0.001435602953240361
Class Perfor

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.75it/s]


Average Adversary epoch loss: 0.1902398879391066


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.77it/s]


*****Epoch 7/300*****
*****Train Loss:  0.828988 Val Loss:  0.836846*****
*****Validation Accuracy: 35.06%*****
*****Total Avg Disparity: 0.10308759921667121*****

Class Team_Sports: Recall Difference = 0.03881590471185847
Class Celebration: Recall Difference = -0.04186046511627911
Class Parade: Recall Difference = -0.11531223267750212
Class Waiter_Or_Waitress: Recall Difference = -0.05501618122977342
Class Individual_Sports: Recall Difference = -0.002233833106163874
Class Surgeons: Recall Difference = 0.06873614190687355
Class Spa: Recall Difference = -0.54
Class Law_Enforcement: Recall Difference = 0.141891891891892
Class Business: Recall Difference = -0.07798997119385467
Class Dresses: Recall Difference = -0.24119241192411922
Class Water Activities: Recall Difference = 0.048648648648648596
Class Picnic: Recall Difference = 0.10909090909090911
Class Rescue: Recall Difference = 0.012658227848101264
Class Cheering: Recall Difference = 0.017432321575061527
Class Performance_And_Entertai

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.76it/s]


Average Adversary epoch loss: 0.1834763211678274


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:20<00:00,  2.70it/s]


*****Epoch 8/300*****
*****Train Loss:  0.783150 Val Loss:  0.709317*****
*****Validation Accuracy: 40.09%*****
*****Total Avg Disparity: 0.14451439198842225*****

Class Team_Sports: Recall Difference = 0.1334734629532316
Class Celebration: Recall Difference = -0.06744186046511624
Class Parade: Recall Difference = -0.03820929569432563
Class Waiter_Or_Waitress: Recall Difference = -0.008321775312066626
Class Individual_Sports: Recall Difference = -0.07360290776919592
Class Surgeons: Recall Difference = 0.007390983000739093
Class Spa: Recall Difference = -0.6600000000000001
Class Law_Enforcement: Recall Difference = 0.11811811811811818
Class Business: Recall Difference = -0.0915395284327323
Class Dresses: Recall Difference = -0.5000000000000001
Class Water Activities: Recall Difference = 0.10659459459459453
Class Picnic: Recall Difference = -0.19999999999999996
Class Rescue: Recall Difference = 0.023291139240506332
Class Cheering: Recall Difference = -0.03035274815422477
Class Performanc

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.75it/s]


Average Adversary epoch loss: 0.1798726999611588


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.76it/s]


*****Epoch 9/300*****
*****Train Loss:  0.733449 Val Loss:  0.705649*****
*****Validation Accuracy: 38.93%*****
*****Total Avg Disparity: 0.12301392653681001*****

Class Team_Sports: Recall Difference = 0.1216325100718163
Class Celebration: Recall Difference = -0.030232558139534904
Class Parade: Recall Difference = -0.07927003136583974
Class Waiter_Or_Waitress: Recall Difference = -0.06842348589921399
Class Individual_Sports: Recall Difference = -0.05743601393306075
Class Surgeons: Recall Difference = -0.08277900960827789
Class Spa: Recall Difference = -0.58
Class Law_Enforcement: Recall Difference = 0.10635635635635637
Class Business: Recall Difference = -0.11479782353568765
Class Dresses: Recall Difference = -0.32791327913279134
Class Water Activities: Recall Difference = 0.04616216216216207
Class Picnic: Recall Difference = -0.03636363636363632
Class Rescue: Recall Difference = 0.07189873417721519
Class Cheering: Recall Difference = 0.011689909762100083
Class Performance_And_Enterta

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.74it/s]


Average Adversary epoch loss: 0.17964719971699744


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.70it/s]


*****Epoch 10/300*****
*****Train Loss:  0.686407 Val Loss:  0.760573*****
*****Validation Accuracy: 37.49%*****
*****Total Avg Disparity: 0.13419714658876125*****

Class Team_Sports: Recall Difference = 0.05166111986921229
Class Celebration: Recall Difference = 0.04883720930232566
Class Parade: Recall Difference = -0.07305389221556885
Class Waiter_Or_Waitress: Recall Difference = -0.13546000924641705
Class Individual_Sports: Recall Difference = 0.014425261244888687
Class Surgeons: Recall Difference = -0.180339985218034
Class Spa: Recall Difference = -0.6000000000000001
Class Law_Enforcement: Recall Difference = 0.06031031031031031
Class Business: Recall Difference = -0.05729222234076603
Class Dresses: Recall Difference = -0.35230352303523044
Class Water Activities: Recall Difference = 0.015567567567567497
Class Picnic: Recall Difference = -0.14545454545454545
Class Rescue: Recall Difference = -0.06734177215189874
Class Cheering: Recall Difference = 0.007383100902379008
Class Performan

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:59<00:00,  2.72it/s]


Average Adversary epoch loss: 0.17888039105242085


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.74it/s]


*****Epoch 11/300*****
*****Train Loss:  0.640381 Val Loss:  0.810872*****
*****Validation Accuracy: 39.57%*****
*****Total Avg Disparity: 0.08961984068434128*****

Class Team_Sports: Recall Difference = 0.12049979564430435
Class Celebration: Recall Difference = -0.07209302325581393
Class Parade: Recall Difference = 0.09187339606501282
Class Waiter_Or_Waitress: Recall Difference = -0.0073971336107258034
Class Individual_Sports: Recall Difference = 0.014046645464182939
Class Surgeons: Recall Difference = -0.11973392461197332
Class Spa: Recall Difference = -0.21999999999999997
Class Law_Enforcement: Recall Difference = 0.10035035035035042
Class Business: Recall Difference = -0.003093993385255525
Class Dresses: Recall Difference = -0.27168021680216814
Class Water Activities: Recall Difference = 0.026594594594594623
Class Picnic: Recall Difference = -0.12727272727272732
Class Rescue: Recall Difference = 0.011898734177215181
Class Cheering: Recall Difference = 0.01825266611977029
Class Perf

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.77it/s]


Average Adversary epoch loss: 0.17461632252294826


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.75it/s]


*****Epoch 12/300*****
*****Train Loss:  0.589834 Val Loss:  0.781554*****
*****Validation Accuracy: 39.80%*****
*****Total Avg Disparity: 0.11322410910569038*****

Class Team_Sports: Recall Difference = 0.057149529981899816
Class Celebration: Recall Difference = -0.03488372093023254
Class Parade: Recall Difference = 0.011291702309666363
Class Waiter_Or_Waitress: Recall Difference = -0.17753120665742025
Class Individual_Sports: Recall Difference = 0.018325003786157812
Class Surgeons: Recall Difference = 0.03843311160384333
Class Spa: Recall Difference = -0.34
Class Law_Enforcement: Recall Difference = 0.09934934934934941
Class Business: Recall Difference = -0.06059959458017711
Class Dresses: Recall Difference = -0.39363143631436315
Class Water Activities: Recall Difference = 0.02464864864864863
Class Picnic: Recall Difference = -0.054545454545454564
Class Rescue: Recall Difference = 0.11518987341772156
Class Cheering: Recall Difference = 0.00020508613617722937
Class Performance_And_Ent

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:59<00:00,  2.70it/s]


Average Adversary epoch loss: 0.1730465592804903


100%|█████████████████████████████████████████████████████████████████████████| 161/161 [01:22<00:00,  1.95it/s]
100%|███████████████████████████████████████████████████████████████████████████| 54/54 [00:19<00:00,  2.75it/s]


*****Epoch 13/300*****
*****Train Loss:  0.536662 Val Loss:  0.773910*****
*****Validation Accuracy: 39.77%*****
*****Total Avg Disparity: 0.10374705683319026*****

Class Team_Sports: Recall Difference = 0.10008758101243642
Class Celebration: Recall Difference = -0.07674418604651162
Class Parade: Recall Difference = 0.04014827487881362
Class Waiter_Or_Waitress: Recall Difference = -0.09847434119278764
Class Individual_Sports: Recall Difference = -0.004808420414962766
Class Surgeons: Recall Difference = -0.08277900960827789
Class Spa: Recall Difference = -0.21999999999999997
Class Law_Enforcement: Recall Difference = 0.07582582582582587
Class Business: Recall Difference = -0.1071161847860877
Class Dresses: Recall Difference = -0.39498644986449866
Class Water Activities: Recall Difference = 0.06551351351351342
Class Picnic: Recall Difference = -0.16363636363636358
Class Rescue: Recall Difference = -0.08202531645569622
Class Cheering: Recall Difference = -0.052091878589007386
Class Perfor

100%|█████████████████████████████████████████████████████████████████████████| 161/161 [00:58<00:00,  2.77it/s]


Average Adversary epoch loss: 0.17666654099033485


 24%|█████████████████▍                                                        | 38/161 [00:20<01:01,  1.99it/s]

In [None]:
def compare_performance_metrics(teacher, student, dataloader):
    teacher.eval()
    student.eval()
    
    all_labels = []
    all_teacher_preds = []
    all_student_preds = []

    for batch in dataloader:
        inputs = batch['img'].to(device)
        labels = batch['label'].to(device)
        
        with torch.no_grad():
            teacher_outputs = teacher(inputs)
            student_outputs = student(inputs)
            
        teacher_preds = torch.argmax(teacher_outputs, dim=1).cpu().numpy()
        student_preds = torch.argmax(student_outputs, dim=1).cpu().numpy()
        
        all_labels.append(labels.cpu().numpy())
        all_teacher_preds.append(teacher_preds)
        all_student_preds.append(student_preds)

    all_labels = np.concatenate(all_labels)
    all_teacher_preds = np.concatenate(all_teacher_preds)
    all_student_preds = np.concatenate(all_student_preds)
    
    metrics = {
        'accuracy': (accuracy_score(all_labels, all_teacher_preds), accuracy_score(all_labels, all_student_preds)),
        'precision': (precision_score(all_labels, all_teacher_preds, average='weighted', zero_division=0), precision_score(all_labels, all_student_preds, average='weighted', zero_division=0)),
        'recall': (recall_score(all_labels, all_teacher_preds, average='weighted'), recall_score(all_labels, all_student_preds, average='weighted')),
        'f1': (f1_score(all_labels, all_teacher_preds, average='weighted'), f1_score(all_labels, all_student_preds, average='weighted'))
    }

    return {
        'metrics': metrics,
        'all_labels': all_labels,
        'all_teacher_preds': all_teacher_preds,
        'all_student_preds': all_student_preds
    }

def compare_model_size(teacher, student):
    teacher_params = sum(p.numel() for p in teacher.parameters())
    student_params = sum(p.numel() for p in student.parameters())
    return teacher_params, student_params

def compare_inference_time(teacher, student, dataloader):
    dataiter = iter(dataloader)
    data = next(dataiter)
    inputs = data['img']
    
    teacher = teacher.to(device)
    student = student.to(device)
    inputs = inputs.to(device)
    
    start_time = time.time()
    with torch.no_grad():
        teacher_outputs = teacher(inputs)
    teacher_time = time.time() - start_time

    start_time = time.time()
    with torch.no_grad():
        student_outputs = student(inputs)
    student_time = time.time() - start_time
    
    return teacher_time, student_time

In [None]:
# Loop through each lambda value
for lmda_teacher in lmda_list_teacher:
    for lmda_student in lmda_list_student:

    # Load teacher and student models for the current lambda
    teacher_model = torch.load(f'teacher_model_ckd_wider_lambda{lmda}.pth')
    student_model = torch.load(f'student_model_ckd_wider_lambda{lmda}.pth')

    # Compute performance metrics
    performance_metrics = compare_performance_metrics(teacher_model, student_model, testloader)

    # Compute model sizes and inference times
    teacher_params, student_params = compare_model_size(teacher_model, student_model)
    teacher_time, student_time = compare_inference_time(teacher_model, student_model, testloader)

    # Update results for the current lambda value
    if lmda in lambda_results:
        lambda_results[lmda].update({
            'performance_metrics': performance_metrics,
            'teacher_params': teacher_params,
            'student_params': student_params,
            'teacher_time': teacher_time,
            'student_time': student_time
        })
    else:
        lambda_results[lmda] = {
            'performance_metrics': performance_metrics,
            'teacher_params': teacher_params,
            'student_params': student_params,
            'teacher_time': teacher_time,
            'student_time': student_time
        }

In [None]:
# Initialize lists to store accuracies
teacher_accuracies = []
student_accuracies = []
lambda_pairs = list(lambda_results.keys())

# Iterate over the keys in lambda_results
for key in lambda_pairs:
    # Check if the key is a tuple (indicating a lambda pair)
    if isinstance(key, tuple) and len(key) == 2:
        lmda_teacher, lmda_student = key
    else:
        # If the key is not a tuple, skip this iteration
        continue

    # Access the performance metrics for each pair
    teacher_accuracy = lambda_results[(lmda_teacher, lmda_student)]['performance_metrics']['metrics']['accuracy'][0]
    student_accuracy = lambda_results[(lmda_teacher, lmda_student)]['performance_metrics']['metrics']['accuracy'][1]

    # Append accuracies to the lists
    teacher_accuracies.append((lmda_teacher, teacher_accuracy))
    student_accuracies.append((lmda_student, student_accuracy))

# To plot, you might need to separate the lambda values and accuracies
teacher_lambdas, teacher_acc = zip(*teacher_accuracies)
student_lambdas, student_acc = zip(*student_accuracies)

# Plotting
plt.plot(teacher_lambdas, teacher_acc, label='Teacher Accuracy', marker='o')
plt.plot(student_lambdas, student_acc, label='Student Accuracy', marker='o')
plt.xlabel('Lambda')
plt.ylabel('Accuracy')
plt.title('Accuracy Comparison Across Lambdas')
plt.legend()
plt.show()


In [None]:
# Initialize lists to store precisions
teacher_precisions = []
student_precisions = []
lambda_pairs = list(lambda_results.keys())

# Iterate over the keys in lambda_results
for key in lambda_pairs:
    # Check if the key is a tuple (indicating a lambda pair)
    if isinstance(key, tuple) and len(key) == 2:
        lmda_teacher, lmda_student = key
        # Access the precision metrics for each pair
        teacher_precision = lambda_results[(lmda_teacher, lmda_student)]['performance_metrics']['metrics']['precision'][0]
        student_precision = lambda_results[(lmda_teacher, lmda_student)]['performance_metrics']['metrics']['precision'][1]
    else:
        # If the key is not a tuple, skip this iteration
        continue

    # Append precisions to the lists along with lambda values
    teacher_precisions.append((lmda_teacher, teacher_precision))
    student_precisions.append((lmda_student, student_precision))

# To plot, you might need to separate the lambda values and precisions
teacher_lambdas, teacher_prec = zip(*teacher_precisions)
student_lambdas, student_prec = zip(*student_precisions)

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(teacher_lambdas, teacher_prec, label='Teacher Precision', marker='o')
plt.plot(student_lambdas, student_prec, label='Student Precision', marker='o')
plt.xlabel('Lambda')
plt.ylabel('Precision')
plt.title('Precision Comparison Across Lambdas')
plt.legend()
plt.show()


In [None]:
# Initialize lists to store recalls
teacher_recalls = []
student_recalls = []
lambda_pairs = list(lambda_results.keys())

# Iterate over the keys in lambda_results
for key in lambda_pairs:
    # Check if the key is a tuple (indicating a lambda pair)
    if isinstance(key, tuple) and len(key) == 2:
        lmda_teacher, lmda_student = key
        # Access the recall metrics for each pair
        teacher_recall = lambda_results[(lmda_teacher, lmda_student)]['performance_metrics']['metrics']['recall'][0]
        student_recall = lambda_results[(lmda_teacher, lmda_student)]['performance_metrics']['metrics']['recall'][1]
    else:
        # If the key is not a tuple, skip this iteration
        continue

    # Append recalls to the lists along with lambda values
    teacher_recalls.append((lmda_teacher, teacher_recall))
    student_recalls.append((lmda_student, student_recall))

# To plot, you might need to separate the lambda values and recalls
teacher_lambdas, teacher_rec = zip(*teacher_recalls)
student_lambdas, student_rec = zip(*student_recalls)

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(teacher_lambdas, teacher_rec, label='Teacher Recall', marker='o')
plt.plot(student_lambdas, student_rec, label='Student Recall', marker='o')
plt.xlabel('Lambda')
plt.ylabel('Recall')
plt.title('Recall Comparison Across Lambdas')
plt.legend()
plt.show()


In [None]:
# Initialize lists to store F1 scores
teacher_f1s = []
student_f1s = []
lambda_pairs = list(lambda_results.keys())

# Iterate over the keys in lambda_results
for key in lambda_pairs:
    # Check if the key is a tuple (indicating a lambda pair)
    if isinstance(key, tuple) and len(key) == 2:
        lmda_teacher, lmda_student = key
        # Access the F1 scores for each pair
        teacher_f1 = lambda_results[(lmda_teacher, lmda_student)]['performance_metrics']['metrics']['f1'][0]
        student_f1 = lambda_results[(lmda_teacher, lmda_student)]['performance_metrics']['metrics']['f1'][1]
    else:
        # If the key is not a tuple, skip this iteration
        continue

    # Append F1 scores to the lists along with lambda values
    teacher_f1s.append((lmda_teacher, teacher_f1))
    student_f1s.append((lmda_student, student_f1))

# To plot, you might need to separate the lambda values and F1 scores
teacher_lambdas, teacher_f1_scores = zip(*teacher_f1s)
student_lambdas, student_f1_scores = zip(*student_f1s)

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(teacher_lambdas, teacher_f1_scores, label='Teacher F1 Score', marker='o')
plt.plot(student_lambdas, student_f1_scores, label='Student F1 Score', marker='o')
plt.xlabel('Lambda')
plt.ylabel('F1 Score')
plt.title('F1 Score Comparison Across Lambdas')
plt.legend()
plt.show()


In [None]:
# Initialize lists to store model sizes
teacher_sizes = []
student_sizes = []
lambda_pairs = list(lambda_results.keys())

# Iterate over the keys in lambda_results
for key in lambda_pairs:
    # Check if the key is a tuple (indicating a lambda pair)
    if isinstance(key, tuple) and len(key) == 2:
        lmda_teacher, lmda_student = key
        # Access the model sizes for each pair
        teacher_size = lambda_results[(lmda_teacher, lmda_student)]['teacher_params'] / 1e6  # Convert to millions
        student_size = lambda_results[(lmda_teacher, lmda_student)]['student_params'] / 1e6
    else:
        # If the key is not a tuple, skip this iteration
        continue

    # Append model sizes to the lists along with lambda values
    teacher_sizes.append((lmda_teacher, teacher_size))
    student_sizes.append((lmda_student, student_size))

# To plot, you might need to separate the lambda values and model sizes
teacher_lambdas, teacher_model_sizes = zip(*teacher_sizes)
student_lambdas, student_model_sizes = zip(*student_sizes)

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(teacher_lambdas, teacher_model_sizes, label='Teacher Model Size', marker='o')
plt.plot(student_lambdas, student_model_sizes, label='Student Model Size', marker='o')
plt.xlabel('Lambda')
plt.ylabel('Model Size (Millions of Parameters)')
plt.title('Model Size Comparison Across Lambdas')
plt.legend()
plt.show()


In [None]:
# Initialize lists to store inference times
teacher_times = []
student_times = []
lambda_pairs = list(lambda_results.keys())

# Iterate over the keys in lambda_results
for key in lambda_pairs:
    # Check if the key is a tuple (indicating a lambda pair)
    if isinstance(key, tuple) and len(key) == 2:
        lmda_teacher, lmda_student = key
        # Access the inference times for each pair
        teacher_time = lambda_results[(lmda_teacher, lmda_student)]['teacher_time']
        student_time = lambda_results[(lmda_teacher, lmda_student)]['student_time']
    else:
        # If the key is not a tuple, skip this iteration
        continue

    # Append inference times to the lists along with lambda values
    teacher_times.append((lmda_teacher, teacher_time))
    student_times.append((lmda_student, student_time))

# To plot, you might need to separate the lambda values and inference times
teacher_lambdas, teacher_inference_times = zip(*teacher_times)
student_lambdas, student_inference_times = zip(*student_times)

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(teacher_lambdas, teacher_inference_times, label='Teacher Inference Time', marker='o')
plt.plot(student_lambdas, student_inference_times, label='Student Inference Time', marker='o')
plt.xlabel('Lambda')
plt.ylabel('Inference Time (s)')
plt.title('Inference Time Comparison Across Lambdas')
plt.legend()
plt.show()


In [None]:
# Initialize lists to store disparity values
teacher_disparities = []
student_disparities = []
lambda_pairs = list(lambda_results.keys())

# Iterate over the keys in lambda_results
for key in lambda_pairs:
    # Check if it's an integer key (indicating a lambda value for student)
    if isinstance(key, int):
        # Check and extract teacher disparity if it exists
        if 'teacher_mean_abs_val_disparity' in lambda_results[key]:
            teacher_disparity = lambda_results[key]['teacher_mean_abs_val_disparity']
            if isinstance(teacher_disparity, list):  # Assuming the disparity could be stored as a list
                teacher_disparity = teacher_disparity[0]
            teacher_disparities.append((key, teacher_disparity))

        # Extract student disparity
        if 'student_mean_abs_val_disparity' in lambda_results[key]:
            student_disparity = lambda_results[key]['student_mean_abs_val_disparity']
            student_disparities.append((key, student_disparity))

# Separate the lambda values and disparity values
teacher_lambdas, teacher_disparity_values = zip(*teacher_disparities) if teacher_disparities else ([], [])
student_lambdas, student_disparity_values = zip(*student_disparities) if student_disparities else ([], [])

# Plotting
plt.figure(figsize=(10, 6))
if teacher_disparities:
    plt.plot(teacher_lambdas, teacher_disparity_values, label='Teacher Average Disparity', marker='o')
if student_disparities:
    plt.plot(student_lambdas, student_disparity_values, label='Student Average Disparity', marker='o')
plt.xlabel('Lambda')
plt.ylabel('Average Disparity')
plt.title('Average Disparity Comparison Across Lambdas')
plt.legend()
plt.show()


In [None]:
def plot_distribution(predictions, class_names, title):
    plt.figure(figsize=(6, 4))
    sns.countplot(x=predictions)
    plt.title(title)
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.xticks(range(len(class_names)), class_names, rotation=45)
    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(all_labels, predictions, class_names, title):
    cm = confusion_matrix(all_labels, predictions)
    plt.figure(figsize=(6, 6))
    sns.heatmap(pd.DataFrame(cm, index=class_names, columns=class_names), annot=True, fmt='g')
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

# Function to generate predictions and compute metrics
def generate_predictions_and_metrics(model, dataloader):
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['img'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(inputs)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            
            all_labels.append(labels.cpu().numpy())
            all_preds.append(preds)

    all_labels = np.concatenate(all_labels)
    all_preds = np.concatenate(all_preds)

    return all_labels, all_preds

# Loop over each lambda value for the teacher model
for lmda_teacher in lmda_list_teacher:
    # Load the teacher model
    teacher_model = torch.load(f'teacher_model_ckd_wider_lambda{lmda_teacher}.pth')

    # Generate predictions for the teacher model
    all_labels, all_teacher_preds = generate_predictions_and_metrics(teacher_model, testloader)

    # Plot distribution and confusion matrix for the teacher model
    plot_distribution(all_teacher_preds, class_names_new, f'Teacher Model Predictions (Lambda={lmda_teacher})')
    plot_confusion_matrix(all_labels, all_teacher_preds, class_names_new, f'Teacher Confusion Matrix (Lambda={lmda_teacher})')

    # Print classification report for the teacher model
    teacher_report = classification_report(all_labels, all_teacher_preds, target_names=class_names_new, zero_division=0)
    print(f'Classification Report - Teacher Model (Lambda={lmda_teacher})')
    print(teacher_report)

# Loop over each lambda value for the student model
for lmda_student in lmda_list_student:
    # Load the student model
    student_model = torch.load(f'student_model_ckd_wider_lambda{lmda_student}.pth')

    # Generate predictions for the student model
    all_labels, all_student_preds = generate_predictions_and_metrics(student_model, testloader)

    # Plot distribution and confusion matrix for the student model
    plot_distribution(all_student_preds, class_names_new, f'Student Model Predictions (Lambda={lmda_student})')
    plot_confusion_matrix(all_labels, all_student_preds, class_names_new, f'Student Confusion Matrix (Lambda={lmda_student})')

    # Print classification report for the student model
    student_report = classification_report(all_labels, all_student_preds, target_names=class_names_new, zero_division=0)
    print(f'Classification Report - Student Model (Lambda={lmda_student})')
    print(student_report)


In [None]:
def plot_bias_variance_tradeoff(model_results, model_type, lambdas):
    bias_values = []
    accuracy_values = []

    if model_type == 'teacher':
        for lmda in lambdas:
            if lmda in model_results and 'teacher_mean_abs_val_disparity' in model_results[lmda]:
                bias_values.append(model_results[lmda]['teacher_mean_abs_val_disparity'][0])
                performance_key = next((key for key in model_results if isinstance(key, tuple) and key[0] == lmda), None)
                if performance_key:
                    accuracy_values.append(model_results[performance_key]['performance_metrics']['metrics']['accuracy'][0])
        model_name = "Teacher"
    elif model_type == 'student':
        for lmda in lambdas:
            if lmda in model_results and 'student_mean_abs_val_disparity' in model_results[lmda]:
                bias_values.append(model_results[lmda]['student_mean_abs_val_disparity'])
                performance_key = next((key for key in model_results if isinstance(key, tuple) and key[1] == lmda), None)
                if performance_key:
                    accuracy_values.append(model_results[performance_key]['performance_metrics']['metrics']['accuracy'][1])
        model_name = "Student"
    else:
        raise ValueError("Invalid model type. Choose 'teacher' or 'student'.")


    # Weight for the trade-off (can be adjusted based on preference)
    bias_weight = 1

    # Calculate the weighted ratio
    weighted_ratios = np.array(accuracy_values) / (1 + bias_weight * np.array(bias_values))
    closest_to_one_index = np.argmin(np.abs(weighted_ratios - 1))
    optimal_bias = bias_values[closest_to_one_index]
    optimal_accuracy = accuracy_values[closest_to_one_index]
    optimal_ratio = weighted_ratios[closest_to_one_index]

    # Plotting the bias-variance trade-off curve
    plt.plot(bias_values, accuracy_values, marker='o', linestyle='-', label=f'{model_name} Trade-off Points')

    # Mark all points with their lambda values
    for i, (bias, acc, lmbda) in enumerate(zip(bias_values, accuracy_values, lambdas)):
        plt.annotate(f'λ={lmbda}', (bias, acc), textcoords="offset points", xytext=(0,10), ha='center')

    # Highlight the optimal point
    plt.scatter(optimal_bias, optimal_accuracy, color='r', s=100, marker='X', label=f'Optimal Point (λ={lambdas[closest_to_one_index]})')
    plt.xlabel('Disparity')
    plt.ylabel('Accuracy')
    plt.title(f'{model_name} Accuracy-Fairness Trade-off Curve')
    plt.legend()
    plt.show()

    # Print optimal values
    print(f"Optimal Lambda for {model_name}: {lambdas[closest_to_one_index]}")
    print(f"Optimal Bias/Disparity for {model_name}: {optimal_bias}")
    print(f"Optimal Accuracy for {model_name}: {optimal_accuracy}")
    print(f"Optimal Weighted Ratio for {model_name}: {optimal_ratio:.2f}")
    
# Plot for Teacher
plot_bias_variance_tradeoff(lambda_results, 'teacher', lmda_list_teacher)

# Plot for Student
plot_bias_variance_tradeoff(lambda_results, 'student', lmda_list_student)

In [None]:
lambda_results

In [None]:
def compare_performance_metrics_for_demo(teacher, student, dataloader):
    teacher.eval()
    student.eval()

    detailed_info = []

    for batch in dataloader:
        inputs = batch['img'].to(device)
        labels = batch['label'].to(device)
        # Assuming gender or other attributes are part of 'target'
        attributes = batch['target'].to(device)  

        with torch.no_grad():
            teacher_outputs = teacher(inputs)
            student_outputs = student(inputs)

        teacher_preds = torch.argmax(teacher_outputs, dim=1)
        student_preds = torch.argmax(student_outputs, dim=1)

        for i in range(inputs.size(0)):
            if teacher_preds[i] != labels[i] and student_preds[i] == labels[i]:
                info = {
                    'image': inputs[i],
                    'actual_class': labels[i].item(),
                    'teacher_pred_class': teacher_preds[i].item(),
                    'student_pred_class': student_preds[i].item(),
                    'actual_attribute': attributes[i].item(),  # Modify based on your dataset
                    # If your model also predicts attributes, include them here
                }
                detailed_info.append(info)

    return detailed_info



In [None]:
def plot_images_with_details(info_list, rows=5, cols=5):
    fig, axes = plt.subplots(rows, cols, figsize=(15, 15))
    axes = axes.ravel()

    for i in range(rows * cols):
        if i < len(info_list):
            data = info_list[i]
            image = data['image']
            actual_class = data['actual_class']
            teacher_pred_class = data['teacher_pred_class']
            student_pred_class = data['student_pred_class']
            actual_attribute = round(data['actual_attribute'], 2)  # Round to 2 decimal places

            # Normalize the image for display
            image_display = image.cpu().numpy().transpose(1, 2, 0)
            image_display = (image_display - image_display.min()) / (image_display.max() - image_display.min())

            title = f'Actual: Class {actual_class}, Attr {actual_attribute}\n' + \
                    f'Teacher: Class {teacher_pred_class}\n' + \
                    f'Student: Class {student_pred_class}'

            axes[i].imshow(image_display)
            axes[i].set_title(title)
            axes[i].axis('off')
        else:
            axes[i].axis('off')

    plt.subplots_adjust(wspace=0.5)
    plt.show()


In [None]:
# Get detailed info where student is correct and teacher is wrong
detailed_info = compare_performance_metrics_for_demo(teacher_model, student_model, testloader)

# Display images with details
plot_images_with_details(detailed_info)
