In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import time
import os
import torch.nn.functional as F
from tqdm import tqdm

print(torch.cuda.is_available())
print(torch.cuda.device_count())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

True
1


In [19]:
learning_rate = 0.001
num_epochs = 2
num_workers = 2
batch_size = 100
temperature = 4.0
alpha = 0.9
momentum = 0.9
num_classes = 100
step_size = 20
gamma = 0.1

In [20]:
# Define the Teacher and Student model architectures
class TeacherCNN(nn.Module):
    def __init__(self):
        super(TeacherCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.fc1 = nn.Linear(in_features=2304, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = nn.MaxPool2d(kernel_size=2)(x)
        x = nn.ReLU()(self.conv2(x))
        x = nn.MaxPool2d(kernel_size=2)(x)
        x = x.view(x.size(0), -1)
        x = nn.ReLU()(self.fc1(x))
        x = self.fc2(x)
        return x

class StudentCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(StudentCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(16)
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=14400, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

In [21]:
# Load CIFAR-100 dataset
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Create instances of your models
teacher = TeacherCNN()
teacher.eval()  # Set teacher model to evaluation mode
student = StudentCNN()

Files already downloaded and verified


In [22]:
# Define optimizer and loss function
optimizer = optim.SGD(student.parameters(), lr=learning_rate, momentum=momentum)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

criterion = nn.CrossEntropyLoss()

In [23]:
def fair_metric(pred, labels, sens):
    sens = sens.cpu()
    labels = labels.cpu()
    idx_s0 = sens == 0
    idx_s1 = sens == 1
    idx_s0_y1 = idx_s0 & (labels == 1)
    idx_s1_y1 = idx_s1 & (labels == 1)

    if idx_s0.sum().item() == 0 or idx_s1.sum().item() == 0 or idx_s0_y1.sum().item() == 0 or idx_s1_y1.sum().item() == 0:
        return torch.tensor(0.0), torch.tensor(0.0)  # Return zeros

    parity = torch.abs((pred[idx_s0].sum() / (idx_s0.sum().item() + 1e-8)) - (pred[idx_s1].sum() / (idx_s1.sum().item() + 1e-8)))
    equality = torch.abs((pred[idx_s0_y1].sum() / (idx_s0_y1.sum().item() + 1e-8)) - (pred[idx_s1_y1].sum() / (idx_s1_y1.sum().item() + 1e-8)))

    return parity.item(), equality.item()

def approx_loss(logits, sens, idx):
    logits, sens, idx = logits.cpu(), sens.cpu(), idx.cpu()
    g1 = np.argwhere(sens.numpy() == 0.0).reshape(-1)
    g2 = np.argwhere(sens.numpy() == 1.0).reshape(-1)
    idx_set = set(idx.numpy())
    g1 = np.array(list(set(g1) & idx_set))
    g2 = np.array(list(set(g2) & idx_set))

    if g1.shape[0] == 0 or g2.shape[0] == 0:
        return torch.tensor(0.0)  # Return zero

    loss = torch.square(approx_func(logits[g1]).sum(axis=0) / (g1.shape[0]) - approx_func(logits[g2]).sum(axis=0) / (g2.shape[0])).sum()
    return loss

def approx_loss_eo(logits, sens, labels, idx):
    logits, sens, labels, idx = logits.cpu(), sens.cpu(), labels.cpu(), idx.cpu()
    g1 = np.argwhere(sens.numpy() == 0).reshape(-1)
    g2 = np.argwhere(sens.numpy() == 1).reshape(-1)
    g = np.argwhere(labels.numpy() == 1).reshape(-1)
    idx_set = set(idx.numpy())
    g1 = np.array(list(set(g1) & set(g) & idx_set))
    g2 = np.array(list(set(g2) & set(g) & idx_set))
    
    if g1.shape[0] == 0 or g2.shape[0] == 0:
        return torch.tensor(0.0)  # Return zero 

    loss = torch.square(approx_func(logits[g1]).sum(axis=0) / (g1.shape[0] + 1e-8) - approx_func(logits[g2]).sum(axis=0) / (g2.shape[0] + 1e-8)).sum()
    return loss

def approx_func(s):
    device = s.device  # Store the device information
    dtype = s.dtype    # Store the data type information
    s = s.detach().cpu().numpy()
    x = 2 * s - 1
    result = 1 / 2 + 1 / 2 * x - 1 / 8 / 2 * (5 * x * x * x - 3 * x) + 1 / 16 / 8 * (
                63 * x * x * x * x * x - 70 * x * x * x + 15 * x) - 5 / 128 / 16 * (
                       429 * x * x * x * x * x * x * x - 693 * x * x * x * x * x + 315 * x * x * x - 35 * x) + 7 / 256 / 128 * (
                       12155 * x * x * x * x * x * x * x * x * x - 25740 * x * x * x * x * x * x * x + 18018 * x * x * x * x * x - 4620 * x * x * x + 315 * x)
    return torch.tensor(result, device=device, dtype=dtype)




In [24]:

# Define the logging lists at the beginning of your script or before the training loop
teacher_loss_log, teacher_combined_loss_log, teacher_fairness_penalty_sp_log, teacher_fairness_penalty_eo_log, teacher_fairness_penalty_approx_log, teacher_fairness_penalty_approx_eo_log = [], [], [], [], [], []
student_loss_log, student_combined_loss_log, student_fairness_penalty_sp_log, student_fairness_penalty_eo_log, student_fairness_penalty_approx_log, student_fairness_penalty_approx_eo_log = [], [], [], [], [], []

def adversarial_train_model(model, dataloader, criterion, optimizer, epoch, loss_log, combined_loss_log, fairness_penalty_sp_log, fairness_penalty_eo_log, fairness_penalty_approx_log, fairness_penalty_approx_eo_log):
    model.to(device)
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(enumerate(dataloader, 0), total=len(dataloader), unit="batch")
    for i, data in progress_bar:
        inputs, labels = data
        sensitive_attrs = labels  # Assuming label itself as a sensitive attribute
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        # Incorporate fairness metrics
        fairness_penalty_sp, fairness_penalty_eo = fair_metric(outputs, labels, sensitive_attrs)
        fairness_penalty_approx = approx_loss(outputs, sensitive_attrs, torch.arange(len(labels)).to(device))
        fairness_penalty_approx_eo = approx_loss_eo(outputs, sensitive_attrs, labels, torch.arange(len(labels)).to(device))
        combined_loss = loss + fairness_penalty_sp + fairness_penalty_eo + fairness_penalty_approx + fairness_penalty_approx_eo

        progress_bar.set_postfix(
            combined_loss=combined_loss.item(),
            fairness_penalty_sp=fairness_penalty_sp,
            fairness_penalty_eo=fairness_penalty_eo,
            fairness_penalty_approx=fairness_penalty_approx.item(),
            fairness_penalty_approx_eo=fairness_penalty_approx_eo.item()
        )
        combined_loss.backward()
        optimizer.step()
        running_loss += combined_loss.item()
        progress_bar.set_description(f"Epoch {epoch} Loss: {running_loss/(i+1):.4f}")
        
        # Logging the losses
        loss_log.append(loss.item())
        combined_loss_log.append(combined_loss.item())
        fairness_penalty_sp_log.append(fairness_penalty_sp)
        fairness_penalty_eo_log.append(fairness_penalty_eo)
        fairness_penalty_approx_log.append(fairness_penalty_approx.item())
        fairness_penalty_approx_eo_log.append(fairness_penalty_approx_eo.item())

    # Print the loss components and total loss after each epoch
    print(f'Epoch {epoch + 1} Loss Components:')
    print(f'Total Loss: {running_loss / len(dataloader):.4f}')
    print(f'Cross-Entropy Loss: {sum(loss_log) / len(loss_log):.4f}')
    print(f'SP Fairness Penalty: {sum(fairness_penalty_sp_log) / len(fairness_penalty_sp_log):.4f}')
    print(f'EO Fairness Penalty: {sum(fairness_penalty_eo_log) / len(fairness_penalty_eo_log):.4f}')
    print(f'Approx Fairness Penalty: {sum(fairness_penalty_approx_log) / len(fairness_penalty_approx_log):.4f}')
    print(f'Approx EO Fairness Penalty: {sum(fairness_penalty_approx_eo_log) / len(fairness_penalty_approx_eo_log):.4f}')

    return running_loss / len(dataloader)

# Train your model with adversarial training
for epoch in range(num_epochs):
    adversarial_train_model(student, trainloader, criterion, optimizer, epoch, student_loss_log, student_combined_loss_log, student_fairness_penalty_sp_log, student_fairness_penalty_eo_log, student_fairness_penalty_approx_log, student_fairness_penalty_approx_eo_log)


Epoch 0 Loss: 521693591602633176514560.0000: 100%|█| 500/500 [00:10<00:00, 46.59batch/s, combined_loss=3.58, fairness_penalty_approx=0, fairness_penalty_app

Epoch 1 Loss Components:
Total Loss: 521693591602633176514560.0000
Cross-Entropy Loss: 4.0301
SP Fairness Penalty: 0.0000
EO Fairness Penalty: 0.0000
Approx Fairness Penalty: 521693591602633176514560.0000
Approx EO Fairness Penalty: 0.0000



Epoch 1 Loss: 246523166593024254728470528.0000: 100%|█| 500/500 [00:10<00:00, 47.65batch/s, combined_loss=3.72, fairness_penalty_approx=0, fairness_penalty_

Epoch 2 Loss Components:
Total Loss: 246523166593024254728470528.0000
Cross-Entropy Loss: 3.8573
SP Fairness Penalty: 0.0000
EO Fairness Penalty: 0.0000
Approx Fairness Penalty: 123522430092313452374654976.0000
Approx EO Fairness Penalty: 0.0000





In [25]:
# Save the student model architecture and weights
torch.save(student.state_dict(), 'student_model.pth')

# Define your model architectures as strings
teacher_arch = '''
class TeacherCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(TeacherCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.fc1 = nn.Linear(in_features=2304, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = nn.MaxPool2d(kernel_size=2)(x)
        x = nn.ReLU()(self.conv2(x))
        x = nn.MaxPool2d(kernel_size=2)(x)
        x = x.view(x.size(0), -1)
        x = nn.ReLU()(self.fc1(x))
        x = self.fc2(x)
        return x
'''

student_arch = '''
class StudentCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(StudentCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(16)
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=14400, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x
'''

# Open model_arch.py in write mode
with open('model_arch.py', 'w') as file:
    # Write the model architectures to model_arch.py
    file.write(teacher_arch)
    file.write(student_arch)
