In [None]:
# Step 1: imports, device selection, and CelebA dataset / data-loader setup
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid
from PIL import Image
from time import time
from tqdm import tqdm
from transformers import ViTConfig, ViTModel

# Prefer the second GPU (the original configuration), but fall back to the
# first GPU or the CPU so the notebook still runs on smaller machines.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

image_size = 64
batch_size = 1024
celeba_root = "../../celeba/datasets/"

# One preprocessing pipeline shared by both splits: resize to a square image,
# convert to a tensor, then normalize each channel with mean 0.5 / std 0.5.
celeba_transform = tvt.Compose([
    tvt.Resize((image_size, image_size)),
    tvt.ToTensor(),
    tvt.Normalize(mean=[0.5, 0.5, 0.5],
                  std=[0.5, 0.5, 0.5]),
])

dataset = torchvision.datasets.CelebA(celeba_root, split='train', transform=celeba_transform)
test_dataset = torchvision.datasets.CelebA(celeba_root, split='test', transform=celeba_transform)

# Training batches are shuffled and kept at a fixed size.
training_data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluation must see every test sample exactly once: no shuffling and no
# dropping of the final partial batch (drop_last=True silently discarded it,
# so the reported metrics covered only part of the test split).
test_data_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
print('Done')

In [None]:
class Teacher_VisionTransformer(nn.Module):
    """Teacher classifier: a ViT backbone followed by a small sigmoid MLP head.

    The backbone's output must expose ``last_hidden_state``; the vector at
    sequence position 0 (presumably the [CLS] token for a Hugging Face
    ``ViTModel`` -- confirm) is passed through the head.

    Args:
        vit: Backbone module. If it has ``config.hidden_size`` that value is
            used as the head's input width; otherwise 768 is assumed, which
            matches the previously hard-coded size.
    """

    def __init__(self, vit):
        super().__init__()
        self.vit = vit
        # Derive the head width from the backbone instead of hard-coding 768,
        # so backbones configured with another hidden size work unchanged.
        hidden_size = getattr(getattr(vit, "config", None), "hidden_size", 768)
        # Attribute names (`vit`, `seq`) are kept: they define the
        # state_dict keys used by saved checkpoints.
        self.seq = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one sigmoid output per sample, shape ``(batch, 1)``."""
        hidden_states = self.vit(x).last_hidden_state
        first_token = hidden_states[:, 0]
        return self.seq(first_token)
    
class Student_VisionTransformer(nn.Module):
    """Student classifier: a ViT backbone followed by a small sigmoid MLP head.

    Structurally identical to the teacher wrapper; the two differ only in
    the backbone they are given. The backbone's output must expose
    ``last_hidden_state``; the vector at sequence position 0 (presumably the
    [CLS] token for a Hugging Face ``ViTModel`` -- confirm) feeds the head.

    Args:
        vit: Backbone module. If it has ``config.hidden_size`` that value is
            used as the head's input width; otherwise 768 is assumed, which
            matches the previously hard-coded size.
    """

    def __init__(self, vit):
        super().__init__()
        self.vit = vit
        # Derive the head width from the backbone instead of hard-coding 768,
        # so a narrower student backbone works without editing this class.
        hidden_size = getattr(getattr(vit, "config", None), "hidden_size", 768)
        # Attribute names (`vit`, `seq`) are kept: they define the
        # state_dict keys used by saved checkpoints.
        self.seq = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one sigmoid output per sample, shape ``(batch, 1)``."""
        hidden_states = self.vit(x).last_hidden_state
        first_token = hidden_states[:, 0]
        return self.seq(first_token)

In [None]:
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import seaborn as sns

def seed_everything(seed):
    """
    Seed Python's, NumPy's and PyTorch's random number generators with
    ``seed`` and switch cuDNN to deterministic mode with benchmarking off.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    if not denominator:
        return float('nan')
    return float(numerator) / float(denominator)


def _group_rates(y_true, y_pred):
    """Return ``(positive prediction rate, TPR, FPR)`` for one group.

    ``labels=[0, 1]`` forces a 2x2 confusion matrix even when the group holds
    a single class; without it the matrix can be 1x1 and indexing fails.
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    positive_rate = _safe_ratio(tp + fp, tn + fp + fn + tp)
    tpr = _safe_ratio(tp, tp + fn)
    fpr = _safe_ratio(fp, fp + tn)
    return positive_rate, tpr, fpr


def _evaluate_fairness(student_model, data_loader):
    """Print per-group TPR, DP, EOP, EoD, accuracy and trade-off for the student.

    Column 2 of the attribute tensor is the target and column 20 the sensitive
    attribute; samples whose sensitive value is 0 are reported as "female",
    all others as "male". Uses the notebook-level ``device``.
    """
    student_model.eval()
    pred_batches, target_batches, sensitive_batches = [], [], []
    with torch.no_grad():
        for test_input, attributes in data_loader:
            probabilities = student_model(test_input.to(device))
            # Threshold the sigmoid output into a hard 0/1 prediction.
            hard_pred = torch.round(probabilities.squeeze(1))
            pred_batches.append(hard_pred.cpu().numpy().astype(int))
            # Labels never need to visit the GPU; keep them as 1-D arrays.
            target_batches.append(attributes[:, 2].cpu().numpy())
            sensitive_batches.append(attributes[:, 20].cpu().numpy())

    test_pred = np.concatenate(pred_batches)
    test_gt = np.concatenate(target_batches)
    sense_gt = np.concatenate(sensitive_batches)

    is_female = sense_gt == 0
    female_dp, female_TPR, female_FPR = _group_rates(test_gt[is_female], test_pred[is_female])
    male_dp, male_TPR, male_FPR = _group_rates(test_gt[~is_female], test_pred[~is_female])

    accuracy = accuracy_score(test_gt, test_pred)
    equalized_odds = 0.5 * (abs(female_FPR - male_FPR) + abs(female_TPR - male_TPR))

    print('Female TPR', female_TPR)
    print('male TPR', male_TPR)
    print('DP', abs(female_dp - male_dp))
    print('EOP', abs(female_TPR - male_TPR))
    print('EoD', equalized_odds)
    print('acc', accuracy)
    print('Trade off', accuracy * (1 - equalized_odds))


def train_model(epoch=15, sensitive_epoch=20, lambda_coeff=1):
    """Train a ViT teacher, then distill it into a smaller ViT student.

    Phase 1 trains the teacher with binary cross-entropy on column 2 of the
    attribute tensor (presumably CelebA's "Attractive" -- confirm) and saves
    its weights to ``knowledge_distillation.pt``. Phase 2 trains the student
    on a weighted sum of a soft-label loss (against the frozen teacher's
    outputs) and a hard-label loss (against the ground truth). After every
    student epoch the fairness/accuracy metrics are printed for the test set.

    Relies on the notebook-level ``device``, ``training_data_loader`` and
    ``test_data_loader``.

    Args:
        epoch: Number of teacher training epochs. With 0, an existing
            checkpoint file is loaded instead of training.
        sensitive_epoch: Number of student (distillation) epochs.
        lambda_coeff: Weight of the soft-label loss; the hard-label loss is
            weighted by ``1 - lambda_coeff``.
    """
    teacher_configuration = ViTConfig(num_hidden_layers=12, num_attention_heads=8,
                                      intermediate_size=768, image_size=64, patch_size=16)
    student_configuration = ViTConfig(num_hidden_layers=8, num_attention_heads=8,
                                      intermediate_size=768, image_size=64, patch_size=16)

    # Construction order (teacher backbone, student backbone, teacher head,
    # student head) is preserved so seeded weight initialization is unchanged.
    vit_teacher = ViTModel(teacher_configuration).to(device)
    vit_student = ViTModel(student_configuration).to(device)
    model = Teacher_VisionTransformer(vit_teacher).to(device)
    student_model = Student_VisionTransformer(vit_student).to(device)

    criterion = nn.BCELoss()
    criterion_student = nn.BCELoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    student_optimizer = optim.AdamW(student_model.parameters(), lr=1e-4)
    PATH2 = r'knowledge_distillation.pt'

    # ---- Phase 1: supervised training of the teacher ----
    model.train()
    for epoches in range(epoch):
        with tqdm(training_data_loader, unit="batch") as tepoch:
            tepoch.set_description(f"epoch {epoches:2d} ")
            for train_input, attributes in tepoch:
                train_input = train_input.to(device)
                # BCELoss needs float targets shaped like the (batch, 1) output.
                train_target = attributes[:, 2].float().to(device).unsqueeze(1)

                optimizer.zero_grad()
                outputs = model(train_input)
                loss = criterion(outputs, train_target)
                loss.backward()
                optimizer.step()
                tepoch.set_postfix(loss=loss.item())

    if epoch > 0:
        torch.save(model.state_dict(), PATH2)
    # Always start distillation from the checkpointed teacher weights.
    model.load_state_dict(torch.load(PATH2, map_location=device), strict=True)

    # ---- Phase 2: distill the teacher into the student ----
    for epoches in range(sensitive_epoch):
        model.eval()
        student_model.train()
        with tqdm(training_data_loader, unit="batch") as tepoch:
            tepoch.set_description(f"epoch {epoches:2d} ")
            for train_input, attributes in tepoch:
                train_input = train_input.to(device)
                train_target = attributes[:, 2].float().to(device).unsqueeze(1)

                student_optimizer.zero_grad()
                student_output = student_model(train_input)
                # The teacher is frozen here: run it without autograd so the
                # soft labels are constants and backward() does not traverse
                # (and accumulate gradients in) the teacher network.
                with torch.no_grad():
                    soft_label = model(train_input)

                loss_soft = criterion_student(student_output, soft_label)
                loss_hard = criterion_student(student_output, train_target)
                loss = lambda_coeff * loss_soft + (1 - lambda_coeff) * loss_hard
                loss.backward()
                student_optimizer.step()
                tepoch.set_postfix(loss=loss.item())

        # Report test-set fairness and accuracy after every student epoch.
        _evaluate_fairness(student_model, test_data_loader)




# Fix every RNG seed first so the run is repeatable, then train the teacher
# and distill it into the student.
seed_everything(seed=4096)
train_model()