# Imports & Setup

In [None]:
# Total number of images; read_split derives the split sizes (7/11, 2/11) from it.
NUM = 5500
# Number of training samples; also the number of images generate_dataset writes.
TRAIN = 3500
# Number of validation samples (denominator of the validation accuracy).
VAL = 1000
# Identifier of the stored split; interpolated into the split / image paths.
SPLIT = 3
# Side length in pixels that images are resized to.
DIM = 64
# CSV rows whose score is >= BOUNDARY receive the positive label y = 1.
BOUNDARY = 2.85
# Base channel width handed to Generator, Discriminator and Classifier.
NUM_FEATURES = 64
# Number of channels of the generator's noise input.
NOISE = 128
# Batch sizes; the models' linear layers are sized from these values.
GAN_BATCH_SIZE = 4
CLASSIFIER_BATCH_SIZE = 8
# Adam learning rates for discriminator, generator and classifier.
D_LEARNING_RATE = 0.0001
G_LEARNING_RATE = 0.0001
CLASSIFIER_LEARNING_RATE = 0.0005
# Display names of the sensitive groups s == 1 and s == 0.
TEXT1 = 'Female'
TEXT0 = 'Male'
# Dataset directory name that prefixes every data path.
DATASET = 'SCUT'
# Path prefix; '' means a local run, anything else enables tqdm and the Colab drive mount.
URL = ''

In [None]:
import torch
import torch.nn as nn

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import os

from PIL import Image

if URL != '':
    from tqdm.notebook import tqdm

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Mount Google Drive only when a path prefix is configured (URL == '' means a local run).
if URL != '':
    from google.colab import drive
    drive.mount('./mount')

# Helper Functions: Tracking and Plotting

In [None]:
class Tracker():
  """Bookkeeping for one GAN training run.

  Holds the run id ``j``, the current epoch, the discriminator / generator
  loss histories and the running accuracy of the discriminator's real/fake
  decision (overall, on real batches only and on fake batches only).
  """

  def __init__(self):

    # random run id in [1, 1000]
    self.j = random.randint(1, 1000)
    self.epoch = 1

    # samples seen so far: total / real / fake
    self.count = 0
    self.real_count = 0
    self.fake_count = 0
    # samples judged correctly so far
    self.acc_fake = 0
    self.acc_real = 0
    # running accuracies, one entry appended per tracked batch
    self.acc_list = []
    self.acc_list_fake = []
    self.acc_list_real = []
    self.d_loss_list = []
    self.g_loss_list = []

  def add_d_loss(self, loss):
    """Append one discriminator loss value to the history."""
    self.d_loss_list.append(loss)

  def add_g_loss(self, loss):
    """Append one generator loss value to the history."""
    self.g_loss_list.append(loss)

  def add_epoch(self):
    """Advance the epoch counter by one."""
    self.epoch += 1

  def change_j(self, num):
    """Replace the run id."""
    self.j = num

  def set_epoch(self, num):
    """Set the epoch counter to ``num``."""
    self.epoch = num

  def track(self, source, outputs_source):
    """Update the running discriminator accuracy with one batch.

    source: 'Real' or 'Fake' - which kind of batch was judged.
    outputs_source: indexable collection of per-sample scores whose elements
      expose ``.item()``; a score >= 0.5 means "judged real".

    The batch size is taken from ``outputs_source`` itself rather than from a
    global constant, so the counters stay correct for any batch length.
    """
    batch_size = len(outputs_source)
    if batch_size == 0:
      # nothing to count; also avoids dividing by a zero sample count below
      return

    # increase counters
    self.count += batch_size
    if source == 'Real':
      self.real_count += batch_size
    elif source == 'Fake':
      self.fake_count += batch_size

    # calculate accuracies
    for i in range(batch_size):
      score = outputs_source[i].item()
      if source == 'Real' and score >= 0.5:
        self.acc_real += 1
      elif source == 'Fake' and score < 0.5:
        self.acc_fake += 1

    # log accuracies
    self.acc_list.append((self.acc_real + self.acc_fake) / self.count)
    if source == 'Real':
      self.acc_list_real.append(self.acc_real / self.real_count)
    elif source == 'Fake':
      self.acc_list_fake.append(self.acc_fake / self.fake_count)

In [None]:
class C_Tracker():
  """Histories recorded while training one classifier, plus a random run id."""

  def __init__(self):
    # three independent, initially empty histories
    self.train_acc, self.val_acc, self.loss_list = [], [], []
    # random run id in [1, 1000]
    self.j = random.randint(1, 1000)

  def add_train_acc(self, acc):
    """Record one train-set accuracy."""
    self.train_acc += [acc]

  def add_val_acc(self, acc):
    """Record one validation-set accuracy."""
    self.val_acc += [acc]

  def add_loss(self, loss):
    """Record one training loss value."""
    self.loss_list += [loss]

In [None]:
def check_skew(dataset):
  """Print label / sensitive-attribute statistics of a dataset.

  dataset: yields (X, y, s) items where y is the label and s the sensitive
    attribute; a value of 1.0 is counted as positive / as group TEXT1.

  Prints absolute counts and percentages for every (group, label)
  combination. Percentages whose denominator is zero are reported as 0.0.
  """

  def pct(part, whole):
    # percentage rounded to 3 decimals; 0.0 when the denominator is empty
    return round(part / whole * 100, 3) if whole else 0.0

  pos_count = 0
  s1_count = 0
  s1_pos = 0

  total = len(dataset)
  dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

  for _, y, s in dataloader:
    is_s1 = s.item() == 1.0

    if y.item() == 1.0:
      pos_count += 1
      if is_s1:
        s1_pos += 1

    if is_s1:
      s1_count += 1

  # every remaining cell follows from the three counters above
  neg_count = total - pos_count
  s0_count = total - s1_count
  s0_pos = pos_count - s1_pos
  s1_neg = s1_count - s1_pos
  s0_neg = s0_count - s0_pos

  print(f'Total is {total}.')
  print(f'Number of positives is {pos_count} or {pct(pos_count, total)}%.')
  print(f'Number of negatives is {neg_count} or {pct(neg_count, total)}%.')
  print(f'Number of {TEXT1} faces is {s1_count} or {pct(s1_count, total)}%.')
  print(f'Number of {TEXT0} faces is {s0_count} or {pct(s0_count, total)}%.')
  print(f'Number of {TEXT1} positives is {s1_pos}: {pct(s1_pos, s1_count)}% of {TEXT1} faces or {pct(s1_pos, total)}% of total.')
  # the share of the total must use the same count as the rest of the line
  print(f'Number of {TEXT0} positives is {s0_pos}: {pct(s0_pos, s0_count)}% of {TEXT0} faces or {pct(s0_pos, total)}% of total.')
  print(f'Number of {TEXT1} negatives is {s1_neg}: {pct(s1_neg, s1_count)}% of {TEXT1} faces or {pct(s1_neg, total)}% of total.')
  print(f'Number of {TEXT0} negatives is {s0_neg}: {pct(s0_neg, s0_count)}% of {TEXT0} faces or {pct(s0_neg, total)}% of total.')

In [None]:
def plot(l, title):
  """Plot the sequence l as one line chart whose single column is named title."""
  frame = pd.DataFrame(l, columns=[title])
  # y-axis starts at 0 and is ticked in quarter steps up to 1.0
  options = dict(
    ylim=0,
    figsize=(10, 5),
    alpha=0.1,
    marker='.',
    grid=True,
    yticks=(0, 0.25, 0.5, 0.75, 1.0))
  frame.plot(**options)

In [None]:
def plot_gan(T):
  """Plot the loss and discriminator-accuracy histories stored in tracker T."""
  charts = (
    (T.d_loss_list, 'Discriminator Loss'),
    (T.g_loss_list, 'Generator Loss'),
    (T.acc_list, 'Discriminator Accuracy'),
    (T.acc_list_real, 'Discriminator Accuracy - Real'),
    (T.acc_list_fake, 'Discriminator Accuracy - Fake'))
  for values, heading in charts:
    plot(values, heading)

In [None]:
def plot_classifier(CT):
  """Plot the loss and accuracy histories stored in classifier tracker CT."""
  charts = (
    (CT.loss_list, 'Classifier Loss'),
    (CT.train_acc, 'Train Set Accuracy'),
    (CT.val_acc, 'Validation Set Accuracy'))
  for values, heading in charts:
    plot(values, heading)

In [None]:
def generate_images(G):
  """Display a 4 x 4 grid of images sampled from generator G.

  Each grid row is one generator batch. The batch size equals the number of
  columns, so the grid no longer depends on GAN_BATCH_SIZE being at least 4.
  Sampling runs under ``torch.no_grad()`` because no gradients are needed.
  """
  rows, cols = 4, 4
  f, axarr = plt.subplots(rows, cols, figsize = (10, 5))
  to_pil = transforms.ToPILImage()

  with torch.no_grad():
    for i in range(rows):
      t = G.forward(torch.randn(cols, NOISE, 1, 1).to(device))
      t = (t * 0.5) + 0.5 # undo normalisation
      for j in range(cols):
        img = to_pil(t[j].cpu())
        axarr[i, j].imshow(img)
        axarr[i, j].axis('off')

# Helper Functions: Data Processing

In [None]:
# based on https://www.youtube.com/watch?v=ZoZHd0Zm3RY

class MyDataset(Dataset):
  """Image dataset described by a header-less CSV file.

  Column 0 of the CSV holds the image file name (relative to ``root``) and
  column 1 a numeric score. Each item is a triple ``(X, y, s)``:

  X: the image as a 3-channel tensor, resized to ``size`` x ``size`` and
     normalised from [0, 1] to [-1, 1].
  y: ones(1) if the score is >= BOUNDARY, otherwise zeros(1).
  s: ones(1) if the second character of the file name is 'F', otherwise
     zeros(1).
  """

  def __init__(self, csv, root, size):
    self.f = pd.read_csv(csv, header=None)
    self.root = root
    self.size = size

    # ToTensor divides all values by 255
    # normalizing transforms all values from [0, 1] to [-1, 1] for tanh
    # built once here instead of on every __getitem__ call
    self.edit = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize([0.5], [0.5])])

  def __len__(self):
    return len(self.f)

  def __getitem__(self, index):
    name = self.f.iloc[index, 0]

    # generate correct file path
    img_path = os.path.join(self.root, name)

    # open image; the context manager releases the file handle afterwards
    with Image.open(img_path) as raw:
      # force 3 channels so greyscale / RGBA files match the models' input,
      # then resize (resampling filter 0, as before)
      img = raw.convert('RGB').resize((self.size, self.size), 0)

    X = self.edit(img)

    y = torch.zeros(1)
    if self.f.iloc[index, 1] >= BOUNDARY:
      y = torch.ones(1)

    s = torch.zeros(1)
    if name[1] == 'F':
      s = torch.ones(1)

    return X, y, s

In [None]:
class RealDataset(Dataset):
  """In-memory dataset of (X, y, s) triples backed by three parallel containers."""

  def __init__(self, X, y, s):
    self.X, self.y, self.s = X, y, s

  def __len__(self):
    # the length is defined by the inputs alone
    return len(self.X)

  def __getitem__(self, index):
    columns = (self.X, self.y, self.s)
    return tuple(column[index] for column in columns)

In [None]:
class FairDataset(Dataset):
  """In-memory dataset of (X, y) pairs; carries no sensitive attribute."""

  def __init__(self, X, y):
    self.X, self.y = X, y

  def __len__(self):
    # the length is defined by the inputs alone
    return len(self.X)

  def __getitem__(self, index):
    columns = (self.X, self.y)
    return tuple(column[index] for column in columns)

In [None]:
def generate_dataset(G, D, text, index):
  """Write TRAIN generated images, labelled by D's outcome head, to disk.

  Images are sampled from G in batches of GAN_BATCH_SIZE. The second output
  of D is thresholded at 0.5 to obtain the label, and each image is stored
  through save_data (numbered from 1, no sensitive attribute). Both networks
  are put in eval mode for the duration and switched back to train mode.
  """
  saved = 0

  G.eval()
  D.eval()

  with torch.no_grad():

    while saved < TRAIN:

      fakes = G.forward(torch.randn(GAN_BATCH_SIZE, NOISE, 1, 1).to(device))

      _, outcome = D.forward(fakes.to(device).detach())

      # take only as many images from this batch as are still missing
      for j in range(min(GAN_BATCH_SIZE, TRAIN - saved)):
        saved += 1
        label = 1 if outcome[j].item() >= 0.5 else 0
        save_data(fakes[j].unsqueeze(0), label, -1, text, index, saved)

  G.train()
  D.train()

In [None]:
def save_data(X, y, s, text, index, i):
  """Persist one sample below ``{URL}{DATASET}/{text}/{index}/``.

  X: tensor, written to ``X{i}.npy``.
  y: label, appended as a single digit to ``y.txt``.
  s: sensitive attribute, appended as a single digit to ``s.txt``; pass -1 to
     skip writing it.

  The directory is created when missing. Labels are appended, so writing
  into a directory that already holds ``y.txt`` / ``s.txt`` extends those
  files rather than replacing them.
  """
  root = f'{URL}{DATASET}/{text}/{index}/'
  os.makedirs(root, exist_ok=True)

  # write X tensor to file
  np.save(root + 'X' + str(i), X.detach().cpu().numpy())

  # write y float to file
  with open(root + 'y.txt', 'a') as f:
    f.write(str(int(y)))

  if s != -1:
    # write s float to file
    with open(root + 's.txt', 'a') as f:
      f.write(str(int(s)))

In [None]:
def save_split(loader, text, index, fair=False):
  """Write every sample produced by ``loader`` to disk through save_data.

  loader: yields (X, y) batches when ``fair`` is true, (X, y, s) batches
    otherwise.
  text, index: path components forwarded to save_data.
  fair: when false (the default) the sensitive attribute s of each sample is
    stored as well, so the split can later be read back with its s values.

  Samples are numbered consecutively starting at 1.
  """
  count = 0
  for batch in loader:
    if fair:
      X, y = batch
      s = None
    else:
      X, y, s = batch

    for i in range(len(X)):
      count += 1
      # -1 tells save_data that there is no sensitive attribute to store
      s_i = -1 if s is None else s[i]
      save_data(X[i], y[i], s_i, text, index, count)

In [None]:
def save_splits(text, train_loader, validation_loader, test_loader, fair=False):
  """Store the train / validation / test loaders under ``text``.

  test_loader may be -1 to skip the test split.
  fair: forwarded to save_split; false (the default) means the loaders yield
    (X, y, s) batches.
  """
  save_split(train_loader, text, 'Train', fair)
  save_split(validation_loader, text, 'Validation', fair)
  if test_loader != -1:
    save_split(test_loader, text, 'Test', fair)

In [None]:
def read_split(text, batch_size, kind, load, fair):
  """Read a stored split back from disk.

  text: path component identifying the stored data.
  batch_size: batch size of the returned loader (used only when ``load``).
  kind: 'Train', 'Validation' or 'Test'. Train holds 7/11 of NUM samples and
    is shuffled; the other two hold 2/11 of NUM each and keep their order.
  load: return a DataLoader when true, otherwise the dataset itself.
  fair: read (X, y) only when true, otherwise (X, y, s).

  Raises ValueError for any other ``kind``.
  """

  if kind == 'Train':
    n = int((7 * NUM) / 11)
    shuff = True
  elif kind == 'Validation' or kind == 'Test':
    n = int((2 * NUM) / 11)
    shuff = False
  else:
    raise ValueError(f"kind must be 'Train', 'Validation' or 'Test', got {kind!r}")

  if fair:
    X, y = read_tensors(text, kind, n, True)
    my_data = FairDataset(X, y)
  else:
    X, y, s = read_tensors(text, kind, n, False)
    my_data = RealDataset(X, y, s)

  if load:
    return DataLoader(my_data, batch_size=batch_size, shuffle=shuff)
  return my_data

In [None]:
# read X tensor from file
def tensor_helper(doc, n):
  """Load X1.npy ... Xn.npy from directory prefix doc into one (n, 3, DIM, DIM) tensor."""
  X = torch.zeros(n, 3, DIM, DIM).to(device)
  # files are numbered from 1, tensor rows from 0
  for row, number in enumerate(range(1, n + 1)):
    array = np.load(f'{doc}X{number}.npy')
    X[row] = torch.from_numpy(array)
  return X

In [None]:
# read y and s tensors from file
def float_helper(doc, n):
  """Read the first n single-digit labels from file doc as a float tensor.

  The file stores one character per sample (see save_data). The file is read
  in one call and is closed even when parsing fails.

  Raises ValueError when the file holds fewer than n characters or when a
  character is not a digit.
  """
  with open(doc) as f:
    digits = f.read(n)

  if len(digits) < n:
    raise ValueError(f'{doc} holds {len(digits)} labels, expected {n}')

  return torch.FloatTensor([int(c) for c in digits]).to(device)

In [None]:
# read in X, y and s tensors using filepath
def read_tensors(text, index, n, fair):
  """Read n stored samples from ``{URL}{DATASET}/{text}/{index}/``.

  Returns (X, y) when fair is true and (X, y, s) when fair is False.
  """
  folder = f'{URL}{DATASET}/{text}/{index}/'

  X = tensor_helper(folder, n)
  y = float_helper(f'{folder}y.txt', n)

  if fair == False:
    # the sensitive attribute lives in its own file next to the labels
    s = float_helper(f'{folder}s.txt', n)
    return X, y, s

  return X, y

In [None]:
def load_real_data(dataset, batch_size, a, b, c):
  """Randomly split dataset into train / validation / test loaders.

  a, b, c: fractions of the dataset for train, validation and test.

  The train and validation sizes are the rounded fractions; the test split
  receives every remaining sample. The three sizes therefore always add up
  to len(dataset), which random_split requires, even when the fractions do
  not multiply out to whole numbers. ``c`` is kept for the interface; it is
  honoured exactly whenever a + b + c == 1.

  Only the train loader shuffles.
  """
  n = len(dataset)
  train_len = min(n, int(round(n * a)))
  val_len = min(n - train_len, int(round(n * b)))
  test_len = n - train_len - val_len

  train_data, validation_data, test_data = torch.utils.data.random_split(
    dataset, [train_len, val_len, test_len])

  train_loader = DataLoader(train_data, batch_size, shuffle=True)
  validation_loader = DataLoader(validation_data, batch_size, shuffle=False)
  test_loader = DataLoader(test_data, batch_size, shuffle=False)

  return train_loader, validation_loader, test_loader

# Helper Functions: GAN

In [None]:
def threshold(t, index):
  """Return 1.0 when element index of t is at least 0.5, otherwise 0.0."""
  return 1.0 if t[index].item() >= 0.5 else 0.0

In [None]:
def initialise_weights(model):
  """Re-initialise convolution and batch-norm weights of model in place.

  Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02) and BatchNorm2d
  weights from N(1, 0.02); all other modules are left untouched.
  """
  conv_types = (nn.Conv2d, nn.ConvTranspose2d)

  for layer in model.modules():
    if isinstance(layer, conv_types):
      mean = 0.0
    elif isinstance(layer, nn.BatchNorm2d):
      mean = 1.0
    else:
      continue
    nn.init.normal_(layer.weight.data, mean, 0.02)

In [None]:
def random_classes():
  """Return GAN_BATCH_SIZE random class labels (0.0 or 1.0) as a float tensor."""
  labels = torch.randint(2, (GAN_BATCH_SIZE,))
  return labels.float()

In [None]:
def save_gan(G, D, text, j, epoch):
  """Save the state dicts of G and D under ``{URL}{DATASET}/{text}/Params/``.

  The file names encode the network ('Gen' / 'Disc'), text, run id j and epoch.
  """
  folder = f'{URL}{DATASET}/{text}/Params/'

  for prefix, network in (('Gen', G), ('Disc', D)):
    torch.save(network.state_dict(), f'{folder}{prefix}_{text}{j}_{epoch}.pt')

In [None]:
def load_gan(text, j, epoch):
  """Rebuild a generator / discriminator pair from a saved checkpoint.

  Loads the files written by save_gan for (text, j, epoch), shows a grid of
  generated images and returns (G, D, T). T is a fresh Tracker whose run id
  is j and whose epoch counter is epoch + 1.
  """
  # without CUDA the stored tensors have to be remapped onto the CPU
  if torch.cuda.is_available():
    load_options = {}
  else:
    load_options = {'map_location': torch.device('cpu')}

  folder = f'{URL}{DATASET}/{text}/Params/'

  G = Generator(NOISE, NUM_FEATURES)
  D = Discriminator(NUM_FEATURES)
  T = Tracker()
  D.to(device)
  G.to(device)

  T.set_epoch(epoch + 1)
  T.change_j(j)

  G.load_state_dict(torch.load(f'{folder}Gen_{text}{j}_{epoch}.pt', **load_options))
  D.load_state_dict(torch.load(f'{folder}Disc_{text}{j}_{epoch}.pt', **load_options))

  generate_images(G)

  return G, D, T

# Helper Functions: Classifier

In [None]:
def save_classifier(C, text, j, i):
  """Save the state dict of C as ``Classifier_{text}{j}_{i}.pt`` in the Classifier folder."""
  target = f'{URL}{DATASET}/Classifier/Classifier_{text}{j}_{i}.pt'
  weights = C.state_dict()
  torch.save(weights, target)

In [None]:
def load_classifier(text, j, i):
  """Build a Classifier, load the checkpoint written by save_classifier and move it to device."""
  # without CUDA the stored tensors have to be remapped onto the CPU
  if torch.cuda.is_available():
    load_options = {}
  else:
    load_options = {'map_location': torch.device('cpu')}

  C = Classifier(NUM_FEATURES)
  source = f'{URL}{DATASET}/Classifier/Classifier_{text}{j}_{i}.pt'
  C.load_state_dict(torch.load(source, **load_options))

  C.to(device)

  return C

In [None]:
def print_results(d, verbose):
  """Print the confusion-matrix counts stored in d.

  d maps 's1_tp', 's1_fp', 's1_tn', 's1_fn' and the matching 's0_*' keys to
  counts. With verbose each count is printed with a description; otherwise
  only the bare numbers are printed, one per line, in that key order.
  """
  cells = ('tp', 'fp', 'tn', 'fn')

  if not verbose:
    for group in ('s1', 's0'):
      for cell in cells:
        print(d[f'{group}_{cell}'])
    return

  for group, label in (('s1', TEXT1), ('s0', TEXT0)):
    print(f'True positive count for {label}: {d[group + "_tp"]}')
    print(f'False positive count for {label}: {d[group + "_fp"]}')
    print(f'True negative count for {label}: {d[group + "_tn"]}')
    print(f'False negative count for {label}: {d[group + "_fn"]} \n')

In [None]:
def get_accuracy(d, verbose):
  """Print per-group and overall classification accuracy; return the overall value.

  d holds the confusion-matrix counts ('s1_tp' ... 's0_fn'). Printed values
  are rounded to 3 decimals; the returned accuracy is not rounded.
  """
  def tally(group):
    # (correct predictions, group size)
    correct = d[f'{group}_tp'] + d[f'{group}_tn']
    size = d[f'{group}_tp'] + d[f'{group}_fp'] + d[f'{group}_tn'] + d[f'{group}_fn']
    return correct, size

  s1_correct, s1_size = tally('s1')
  s0_correct, s0_size = tally('s0')

  s1_acc = s1_correct / s1_size
  s0_acc = s0_correct / s0_size
  gap = s1_acc - s0_acc
  acc = (s1_correct + s0_correct) / (s1_size + s0_size)

  if verbose:
    print(f'Classification accuracy for {TEXT1} subgroup is {round(s1_acc, 3)}')
    print(f'Classification accuracy for {TEXT0} subgroup is {round(s0_acc, 3)}')
    print(f'Difference in classification accuracy is {round(gap, 3)}')
    print(f'Overall classification accuracy is {round(acc, 3)} \n')
  else:
    for value in (s1_acc, s0_acc, gap, acc):
      print(round(value, 3))

  return acc

In [None]:
def get_dp(d, verbose):
  """Print the demographic parity difference computed from the counts in d.

  For each group the rate is (tp + fp) / group size, i.e. the share of the
  group that received a positive prediction; the difference is the s1 rate
  minus the s0 rate. Values are printed rounded to 3 decimals. Nothing is
  returned.
  """
  def positive_rate(group):
    predicted_positive = d[f'{group}_tp'] + d[f'{group}_fp']
    predicted_negative = d[f'{group}_tn'] + d[f'{group}_fn']
    return predicted_positive / (predicted_positive + predicted_negative)

  dp_s1 = positive_rate('s1')
  dp_s0 = positive_rate('s0')
  dp = dp_s1 - dp_s0

  if not verbose:
    for value in (dp_s1, dp_s0, dp):
      print(round(value, 3))
    return

  print(f'Conditional classification accuracy for {TEXT1} subgroup is {round(dp_s1, 3)}')
  print(f'Conditional classification accuracy for {TEXT0} subgroup is {round(dp_s0, 3)}')
  print(f'Demographic parity difference is {round(dp, 3)} \n')

In [None]:
def get_eo(d, verbose):
  """Print the equality of opportunity difference computed from the counts in d.

  For each group the true positive rate tp / (tp + fn) is computed; the
  difference is the s1 rate minus the s0 rate. Values are printed rounded to
  3 decimals. Nothing is returned.
  """
  def true_positive_rate(group):
    hits = d[f'{group}_tp']
    return hits / (hits + d[f'{group}_fn'])

  eo_s1 = true_positive_rate('s1')
  eo_s0 = true_positive_rate('s0')
  eo = eo_s1 - eo_s0

  if not verbose:
    for value in (eo_s1, eo_s0, eo):
      print(round(value, 3))
    return

  print(f'True positive rate for {TEXT1} subgroup is {round(eo_s1, 3)}')
  print(f'True positive rate for {TEXT0} subgroup is {round(eo_s0, 3)}')
  print(f'Equality of opportunity difference is {round(eo, 3)} \n')

# Discriminator, Generator, GAN Training Loop

In [None]:
class Discriminator(nn.Module):
  """Convolutional discriminator with two sigmoid heads.

  forward returns (source, outcome): ``source`` is the real/fake score and
  ``outcome`` the class score, each a vector of GAN_BATCH_SIZE values.

  Both heads are linear layers applied to the activations of the WHOLE batch
  flattened into one vector, so the input must contain exactly
  GAN_BATCH_SIZE images and every output value depends on all of them.
  """
  def __init__(self, disc_features):
    super().__init__()

    # build model
    # four stride-2 convolutions; each halves the spatial size and the
    # channel count grows from disc_features to disc_features * 8
    self.model = nn.Sequential(
      nn.Conv2d(3, disc_features, kernel_size=4, stride=2, padding=1),
      nn.LeakyReLU(0.2),
      
      self.block(disc_features, disc_features * 2, 4, 2, 1),
      self.block(disc_features * 2, disc_features * 4, 4, 2, 1),
      self.block(disc_features * 4, disc_features * 8, 4, 2, 1))

    # real / fake head
    # presumably disc_features * 128 = (disc_features * 8) channels on a
    # 4 x 4 map per image, which holds for 64 x 64 inputs - TODO confirm
    self.model_source = nn.Sequential(
      nn.Linear(GAN_BATCH_SIZE * disc_features * 128, GAN_BATCH_SIZE),
      nn.Sigmoid())
    
    # class (outcome) head, same input as the source head
    self.model_outcome = nn.Sequential(
      nn.Linear(GAN_BATCH_SIZE * disc_features * 128, GAN_BATCH_SIZE),
      nn.Sigmoid())
  
  def block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) -> Dropout(0.5) stage."""
    return nn.Sequential(
      nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
      nn.BatchNorm2d(out_channels),
      nn.LeakyReLU(0.2),
      nn.Dropout(0.5))

  # run model
  def forward(self, inputs):
    # flatten() without arguments also collapses the batch dimension
    t = self.model(inputs).flatten()
    return self.model_source(t), self.model_outcome(t)

In [None]:
class Generator(nn.Module):
  """Transposed-convolution generator producing 3-channel images in [-1, 1].

  noise_size: number of channels of the noise input.
  gen_features: base channel width; the stages use 16x, 8x, 4x and 2x of it.

  For a noise input with 1 x 1 spatial size the first stage yields a 4 x 4
  map and each of the four following stride-2 stages doubles it (64 x 64
  output). The final Tanh bounds the output to [-1, 1].
  """
  def __init__(self, noise_size, gen_features):
    super().__init__()

    # build model
    self.model = nn.Sequential(
      self.block(noise_size, gen_features * 16, 4, 1, 0),
      self.block(gen_features * 16, gen_features * 8, 4, 2, 1),
      self.block(gen_features * 8, gen_features * 4, 4, 2, 1),
      self.block(gen_features * 4, gen_features * 2, 4, 2, 1),
      nn.ConvTranspose2d(gen_features * 2, 3, 4, 2, 1),
      nn.Tanh())
    
  def block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
    return nn.Sequential(
      nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
      nn.BatchNorm2d(out_channels),
      nn.ReLU())

  # run model
  def forward(self, inputs):
    return self.model(inputs)

In [None]:
def train_gan(G, D, T, train_loader, epochs, text, save):
  """Train generator G and discriminator D for ``epochs`` epochs.

  T: Tracker receiving losses and discriminator accuracy; its epoch counter
    is advanced after every epoch.
  train_loader: yields (X, y, s) batches; batches that do not hold exactly
    GAN_BATCH_SIZE samples are skipped.
  text: 'EO' or 'CF' filters the generator's outcome loss through eo_loss /
    cf_loss; any other value uses the unfiltered outcome loss. Also used in
    the checkpoint file names.
  save: checkpoint through save_gan whenever the epoch number is a multiple
    of ``save``; 0 disables checkpointing.
  """

  loss_function = nn.BCELoss()

  optim_D = torch.optim.Adam(D.parameters(), lr = D_LEARNING_RATE, betas=(0.5, 0.999))
  optim_G = torch.optim.Adam(G.parameters(), lr = G_LEARNING_RATE, betas=(0.5, 0.999))

  end_epoch = T.epoch + epochs - 1

  for i in range(epochs):

    if URL == '':
      loop = train_loader
      print(f'Epoch is {T.epoch}')
    else:
      loop = tqdm(train_loader)
      loop.set_description(f'Epoch {T.epoch}/{end_epoch}')

    for batch in loop:

      X, y, s = batch

      # discard non-full batches; this check has to come before the reshape
      # below, which fails for any other batch length
      if len(X) != GAN_BATCH_SIZE:
        continue

      y = y.view(GAN_BATCH_SIZE).to(device)

      # train D on a batch of reals
      DR_source, DR_outcome = D.forward(X.to(device))
      T.track('Real', DR_source)
      DR_loss_source = loss_function(DR_source, torch.ones_like(DR_source))
      DR_loss_outcome = loss_function(DR_outcome, y)

      # train D on a batch of fakes
      gen_image = G.forward(torch.randn(GAN_BATCH_SIZE, NOISE, 1, 1).to(device))
      targets_outcome = random_classes().to(device)

      DF_source, DF_outcome = D.forward(gen_image.detach())
      T.track('Fake', DF_source)
      DF_loss_source = loss_function(DF_source, torch.zeros_like(DF_source))
      DF_loss_outcome = loss_function(DF_outcome, targets_outcome)

      # get total loss for D
      loss_DR = (DR_loss_source + DR_loss_outcome) / 2
      loss_DF = (DF_loss_source + DF_loss_outcome) / 2
      loss_D = (loss_DR + loss_DF) / 2
      T.add_d_loss(loss_D.item())

      D.zero_grad()
      loss_D.backward()
      optim_D.step()

      # train G using D loss
      # don't use detach here so gradient flows
      gen_image = G.forward(torch.randn(GAN_BATCH_SIZE, NOISE, 1, 1).to(device))
      targets_outcome = random_classes().to(device)

      G_source, G_outcome = D.forward(gen_image)
      G_loss_source = loss_function(G_source, torch.ones_like(G_source))

      if text == 'EO':
        G_outcome, targets_outcome = eo_loss(G_outcome, targets_outcome)
      elif text == 'CF':
        G_outcome, targets_outcome = cf_loss(G_outcome, targets_outcome)

      if G_outcome.numel() > 0:
        G_loss_outcome = loss_function(G_outcome, targets_outcome)
      else:
        # the fairness filter removed every sample: the mean of an empty
        # selection is undefined, so the outcome term contributes nothing
        G_loss_outcome = torch.zeros((), device=G_source.device)

      loss_G = (G_loss_source + G_loss_outcome) / 2
      T.add_g_loss(loss_G.item())

      G.zero_grad()
      loss_G.backward()
      optim_G.step()

    if save != 0 and T.epoch % save == 0:
      save_gan(G, D, text, T.j, T.epoch)

    T.add_epoch()

In [None]:
def cf_loss(outputs, targets):
    """Select the samples that enter the combined-fairness outcome loss.

    Samples whose target is 0.0 are always kept. Every other sample is kept
    when a fresh ``random.random()`` draw is >= ALPHA, i.e. dropped with
    probability ALPHA.

    The batch length and the device are taken from the arguments, so the
    function works for any batch size.

    Returns (fair_outputs, fair_targets): the kept entries of both tensors,
    in their original order (possibly empty).
    """
    # the draw is only made for non-zero targets (short-circuit `or`)
    keep = [i for i in range(len(targets))
            if targets[i] == 0.0 or random.random() >= ALPHA]

    index = torch.tensor(keep, dtype=torch.long, device=outputs.device)

    fair_outputs = outputs.index_select(0, index)
    fair_targets = targets.index_select(0, index.to(targets.device))

    return fair_outputs, fair_targets

In [None]:
def eo_loss(outputs, targets):
    """Select the samples that enter the equality-of-opportunity outcome loss.

    Only samples whose target is 0.0 are kept. The batch length and the
    device are taken from the arguments, so the function works for any
    batch size.

    Returns (fair_outputs, fair_targets): the kept entries of both tensors,
    in their original order (possibly empty).
    """
    keep = [i for i in range(len(targets)) if targets[i] == 0.0]

    index = torch.tensor(keep, dtype=torch.long, device=outputs.device)

    fair_outputs = outputs.index_select(0, index)
    fair_targets = targets.index_select(0, index.to(targets.device))

    return fair_outputs, fair_targets

# Classifier, Train and Test Loops

In [None]:
class Classifier(nn.Module):
  """Convolutional binary classifier with a sigmoid output.

  forward returns a vector of CLASSIFIER_BATCH_SIZE scores.

  The final linear layer is applied to the activations of the WHOLE batch
  flattened into one vector, so the input must contain exactly
  CLASSIFIER_BATCH_SIZE images and every score depends on all of them.
  """
  def __init__(self, c_features):
    super(Classifier, self).__init__()

    # build model
    # four stride-2 convolutions; each halves the spatial size and the
    # channel count grows from c_features to c_features * 8
    self.model = nn.Sequential(
      nn.Conv2d(3, c_features, kernel_size=4, stride=2, padding=1),
      nn.LeakyReLU(0.2),
      
      self.block(c_features, c_features * 2, 4, 2, 1),
      self.block(c_features * 2, c_features * 4, 4, 2, 1),
      self.block(c_features * 4, c_features * 8, 4, 2, 1))
    
    # presumably c_features * 128 = (c_features * 8) channels on a 4 x 4 map
    # per image, which holds for 64 x 64 inputs - TODO confirm
    self.model2 = nn.Sequential(
      nn.Linear(CLASSIFIER_BATCH_SIZE * c_features * 128, CLASSIFIER_BATCH_SIZE),
      nn.Sigmoid())
  
  def block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) -> Dropout(0.5) stage."""
    return nn.Sequential(
      nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
      nn.BatchNorm2d(out_channels),
      nn.LeakyReLU(0.2),
      nn.Dropout(0.5)
      )
      
  def forward(self, inputs):
    t = self.model(inputs)
    # flatten() without arguments also collapses the batch dimension
    return self.model2(t.flatten())

In [None]:
def train_classifier(C, CT, train_loader, validation_loader, text, fair):
  """Train classifier C with early stopping on validation accuracy.

  CT: C_Tracker receiving the loss and accuracy histories; its run id is
    used in the checkpoint name.
  train_loader: yields (X, y) batches when ``fair`` is true, (X, y, s)
    batches otherwise. Batches that do not hold exactly
    CLASSIFIER_BATCH_SIZE samples are skipped during training.
  validation_loader: always read as (X, y, s) batches.
  text: tag used in the checkpoint file name.

  Every third epoch train and validation accuracy are measured. A new best
  validation accuracy saves a checkpoint; five evaluations in a row without
  improvement stop the training.

  Returns the epoch of the best checkpoint (0 if none was saved).
  """
  epoch = 0
  count = 0
  best_epoch = 0
  best_acc = 0

  loss_function = nn.BCELoss()
  optimiser = torch.optim.Adam(C.parameters(), lr=CLASSIFIER_LEARNING_RATE)

  # accuracy denominators come from the loaders themselves, so they stay
  # correct when a split does not have the size of the global constants
  train_size = len(train_loader.dataset)
  val_size = len(validation_loader.dataset)

  C.train()

  while count < 5:

    epoch += 1

    if URL == '':
      loop = train_loader
    else:
      loop = tqdm(train_loader)
      loop.set_description(f'Epoch {epoch}')

    for batch in loop:

      if fair == True:
        X, y = batch
      else:
        X, y, _ = batch

      # discard non-full batches
      if len(X) != CLASSIFIER_BATCH_SIZE:
        continue

      # run model
      outputs = C.forward(X.detach().to(device))

      # calculate loss
      loss = loss_function(outputs, y.view(CLASSIFIER_BATCH_SIZE).detach().to(device))
      CT.add_loss(loss.item())

      # zero gradients, backward pass, update weights
      C.zero_grad()
      loss.backward()
      optimiser.step()

    if epoch % 3 == 0:

      train_acc = test_acc(C, train_loader, train_size, fair)
      CT.add_train_acc(train_acc)

      val_acc = test_acc(C, validation_loader, val_size, False)
      CT.add_val_acc(val_acc)

      if val_acc > best_acc:
        count = 0
        best_acc = val_acc
        best_epoch = epoch
        save_classifier(C, text, CT.j, best_epoch)
      else:
        count += 1

  return best_epoch

In [None]:
def test_acc(C, loader, n, fair):
  """Return the number of correct predictions of C over loader, divided by n.

  loader yields (X, y) batches when fair is true, (X, y, s) otherwise. A
  prediction is correct when the label is 1.0 and the score is >= 0.5, or
  the label is 0.0 and the score is < 0.5. C is evaluated in eval mode and
  switched back to train mode afterwards.
  """
  C.eval()
  correct = 0

  with torch.no_grad():
    for batch in loader:
      if fair == True:
        X, y = batch
      else:
        X, y, _ = batch

      k = len(X)

      # pad non-full batches with ones up to the fixed batch size
      if k == CLASSIFIER_BATCH_SIZE:
        X_mod = X
      else:
        X_mod = torch.ones(CLASSIFIER_BATCH_SIZE, 3, DIM, DIM)
        X_mod[:k] = X

      t = C.forward(X_mod.to(device))

      # only the first k scores belong to real samples
      for i in range(k):
        label = y[i].item()
        score = t[i].item()
        if (label == 1.0 and score >= 0.5) or (label == 0.0 and score < 0.5):
          correct += 1

  C.train()

  return correct / n

In [None]:
def test_classifier(C, loader):
  """Return the per-group confusion-matrix counts of C over loader.

  loader yields (X, y, s) batches. Samples with s == 1.0 are counted under
  the 's1_*' keys, all others under 's0_*'. A label of 1.0 is a positive and
  a score >= 0.5 a positive prediction. C is evaluated in eval mode and
  switched back to train mode afterwards.
  """
  C.eval()

  d = {
    's1_tp': 0,
    's1_fp': 0,
    's1_tn': 0,
    's1_fn': 0,
    's0_tp': 0,
    's0_fp': 0,
    's0_tn': 0,
    's0_fn': 0
  }

  with torch.no_grad():
    for batch in loader:
      X, y, s = batch
      k = len(X)

      # pad non-full batches with ones up to the fixed batch size
      if k == CLASSIFIER_BATCH_SIZE:
        X_mod = X
      else:
        X_mod = torch.ones(CLASSIFIER_BATCH_SIZE, 3, DIM, DIM)
        X_mod[:k] = X

      t = C.forward(X_mod.to(device))

      # only the first k scores belong to real samples
      for i in range(k):
        group = 's1' if s[i].item() == 1.0 else 's0'
        actual = y[i].item() == 1.0
        predicted = t[i].item() >= 0.5

        # first letter: was the prediction right; second: what was predicted
        verdict = 't' if predicted == actual else 'f'
        sign = 'p' if predicted else 'n'
        d[f'{group}_{verdict}{sign}'] += 1

  C.train()

  return d

# Train Classifier on Real Train Set

In [None]:
# CSV index file and image directory of the real dataset
f1 = f'{URL}{DATASET}/Images/SCUTData.csv'
f2 = f'{URL}{DATASET}/Images/'

# images are resized to DIM x DIM when they are read
SCUTData = MyDataset(f1, f2, DIM)

In [None]:
# print label / sensitive-attribute statistics of the real dataset
check_skew(SCUTData)

In [None]:
# random 7/11 : 2/11 : 2/11 train / validation / test split of the real data
train_loader, validation_loader, test_loader = load_real_data(SCUTData, CLASSIFIER_BATCH_SIZE, (7/11), (2/11), (2/11))

In [None]:
# persist the split under Real/Splits/{SPLIT} so later cells can read it back
save_splits(f'Real/Splits/{SPLIT}', train_loader, validation_loader, test_loader)

In [None]:
# reload the stored split (with sensitive attribute) as classifier-sized loaders
train_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Train', load=True, fair=False)
validation_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Validation', load=True, fair=False)
test_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Test', load=True, fair=False)

In [None]:
# ten independent runs: train a fresh classifier on the real train set,
# reload its best checkpoint and report test-set results and fairness metrics
for i in range(10):
    C = Classifier(NUM_FEATURES)
    CT = C_Tracker()

    C.to(device)

    best_epoch = train_classifier(C, CT, train_loader, validation_loader, 'Real', fair=False)

    # evaluate the best checkpoint, not the weights of the last epoch
    C = load_classifier('Real', CT.j, best_epoch)

    print(CT.j)
    print(best_epoch)
    d = test_classifier(C, test_loader)
    print_results(d, False)
    get_accuracy(d, False)
    get_dp(d, False)
    get_eo(d, False)
    print('---')

In [None]:
# CT is the tracker of the last run of the loop above
plot_classifier(CT)

# Train Classifier on Generated Train Set (Demographic Parity)

**Train Demographic Parity GAN**

In [None]:
# fresh discriminator, generator and tracker for the demographic-parity GAN
D_DP = Discriminator(NUM_FEATURES)
G_DP = Generator(NOISE, NUM_FEATURES)
T_DP = Tracker()

G_DP.to(device)
D_DP.to(device)

# re-initialise convolution and batch-norm weights
initialise_weights(D_DP)
initialise_weights(G_DP)

In [None]:
# the GAN trains on the real train split, batched with GAN_BATCH_SIZE
train_loader = read_split(f'Real/Splits/{SPLIT}', GAN_BATCH_SIZE, 'Train', load=True, fair=False)

In [None]:
%%time
# 400 epochs, checkpoint every 100 epochs
train_gan(G_DP, D_DP, T_DP, train_loader, 400, 'DP', 100)

In [None]:
# plot_gan reads the loss / accuracy histories, which live on the tracker
# (T_DP), not on the generator
plot_gan(T_DP)

In [None]:
# show a grid of samples from the trained generator
generate_images(G_DP)

**Load in GAN checkpoints**

In [None]:
# restore the checkpoint of run id 999 at epoch 400
# NOTE(review): 999 looks like a placeholder - Tracker draws its run id at
# random, so this presumably has to be replaced by the id of the trained run
G_DP, D_DP, T_DP = load_gan('DP', 999, 400)

**Generate data from trained GAN**

In [None]:
# write TRAIN generated, discriminator-labelled images to DP/Images/{SPLIT}
generate_dataset(G_DP, D_DP, 'DP', f'Images/{SPLIT}')

In [None]:
# read the generated images and labels back (no sensitive attribute)
X_DP, y_DP = read_tensors('DP', f'Images/{SPLIT}', TRAIN, fair=True)
dp_data = FairDataset(X_DP, y_DP)

**Save and load splits**

In [None]:
# train on generated data; validate and test on the real splits
dp_train_loader = DataLoader(dp_data, CLASSIFIER_BATCH_SIZE, shuffle=True)
validation_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Validation', load=True, fair=False)
test_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Test', load=True, fair=False)

**Test classifier using generated dataset**

In [None]:
# ten independent runs: train a fresh classifier on the DP-generated train
# set, reload its best checkpoint and report results on the real test set
for i in range(10):
    C_DP = Classifier(NUM_FEATURES)
    CT_DP = C_Tracker()

    C_DP.to(device)

    best_epoch = train_classifier(C_DP, CT_DP, dp_train_loader, validation_loader, 'DP', fair=True)

    # evaluate the best checkpoint, not the weights of the last epoch
    C_DP = load_classifier('DP', CT_DP.j, best_epoch)

    print(CT_DP.j)
    print(best_epoch)
    d = test_classifier(C_DP, test_loader)
    print_results(d, False)
    get_accuracy(d, False)
    get_dp(d, False)
    get_eo(d, False)
    print('---')

In [None]:
# CT_DP is the tracker of the last run of the loop above
plot_classifier(CT_DP)

# Train Classifier on Generated Train Set (Equality of Opportunity)

**Train Equality of Opportunity GAN**

In [None]:
# fresh discriminator, generator and tracker for the equality-of-opportunity GAN
D_EO = Discriminator(NUM_FEATURES)
G_EO = Generator(NOISE, NUM_FEATURES)
T_EO = Tracker()

G_EO.to(device)
D_EO.to(device)

# re-initialise convolution and batch-norm weights
initialise_weights(D_EO)
initialise_weights(G_EO)

In [None]:
# the GAN trains on the real train split, batched with GAN_BATCH_SIZE
train_loader = read_split(f'Real/Splits/{SPLIT}', GAN_BATCH_SIZE, 'Train', load=True, fair=False)

In [None]:
%%time
# 400 epochs, checkpoint every 100 epochs; 'EO' routes the generator's
# outcome loss through eo_loss
train_gan(G_EO, D_EO, T_EO, train_loader, 400, 'EO', 100)

In [None]:
# plot the loss and discriminator-accuracy histories of the EO run
plot_gan(T_EO)

In [None]:
# show a grid of samples from the trained generator
generate_images(G_EO)

**Load in GAN checkpoints**

In [None]:
# restore the checkpoint of run id 999 at epoch 400
# NOTE(review): 999 looks like a placeholder - Tracker draws its run id at
# random, so this presumably has to be replaced by the id of the trained run
G_EO, D_EO, T_EO = load_gan('EO', 999, 400)

**Generate data from trained GAN**

In [None]:
# write TRAIN generated, discriminator-labelled images to EO/Images/{SPLIT}
generate_dataset(G_EO, D_EO, 'EO', f'Images/{SPLIT}')

In [None]:
# read the generated images and labels back (no sensitive attribute)
X_EO, y_EO = read_tensors('EO', f'Images/{SPLIT}', TRAIN, fair=True)
eo_data = FairDataset(X_EO, y_EO)

**Save and load splits**

In [None]:
# train on generated data; validate and test on the real splits
eo_train_loader = DataLoader(eo_data, CLASSIFIER_BATCH_SIZE, shuffle=True)
validation_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Validation', load=True, fair=False)
test_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Test', load=True, fair=False)

**Test classifier using generated dataset**

In [None]:
# ten independent runs: train a fresh classifier on the EO-generated train
# set, reload its best checkpoint and report results on the real test set
for i in range(10):
    C_EO = Classifier(NUM_FEATURES)
    CT_EO = C_Tracker()

    C_EO.to(device)

    best_epoch = train_classifier(C_EO, CT_EO, eo_train_loader, validation_loader, 'EO', fair=True)

    # the best checkpoint has to replace C_EO itself; assigning it to C_DP
    # would leave the last-epoch weights in C_EO for the evaluation below
    C_EO = load_classifier('EO', CT_EO.j, best_epoch)

    print(CT_EO.j)
    print(best_epoch)
    d = test_classifier(C_EO, test_loader)
    print_results(d, False)
    get_accuracy(d, False)
    get_dp(d, False)
    get_eo(d, False)
    print('---')

In [None]:
# CT_EO is the tracker of the last run of the loop above
plot_classifier(CT_EO)

# Train Classifier on Generated Train Set (Combined Fairness)

**Train Combined Fairness (DP and EO) GAN**

In [None]:
# fresh discriminator, generator and tracker for the combined-fairness GAN
D_CF = Discriminator(NUM_FEATURES)
G_CF = Generator(NOISE, NUM_FEATURES)
T_CF = Tracker()

G_CF.to(device)
D_CF.to(device)

# re-initialise convolution and batch-norm weights
initialise_weights(D_CF)
initialise_weights(G_CF)

In [None]:
# the GAN trains on the real train split, batched with GAN_BATCH_SIZE
train_loader = read_split(f'Real/Splits/{SPLIT}', GAN_BATCH_SIZE, 'Train', load=True, fair=False)

In [None]:
%%time
# ALPHA is read by cf_loss as a module-level global: samples with a non-zero
# target are dropped from the generator's outcome loss with probability ALPHA
ALPHA = 0.25
# 80 epochs, checkpoint every 20 epochs
train_gan(G_CF, D_CF, T_CF, train_loader, 80, 'CF', 20)

In [None]:
# plot the loss and discriminator-accuracy histories of the CF run
plot_gan(T_CF)

In [None]:
# show a grid of samples from the trained generator
generate_images(G_CF)

**Load in GAN checkpoints**

In [None]:
# restore the checkpoint of run id 999 at epoch 400
# NOTE(review): 999 looks like a placeholder run id, and the CF training cell
# in this notebook runs 80 epochs with a checkpoint every 20 - an epoch-400
# checkpoint would have to come from a different run; confirm both values
G_CF, D_CF, T_CF = load_gan('CF', 999, 400)

**Generate data from trained GAN**

In [None]:
# write TRAIN generated, discriminator-labelled images to CF/Images/{SPLIT}
generate_dataset(G_CF, D_CF, 'CF', f'Images/{SPLIT}')

In [None]:
# read the generated images and labels back (no sensitive attribute)
X_CF, y_CF = read_tensors('CF', f'Images/{SPLIT}', TRAIN, fair=True)
cf_data = FairDataset(X_CF, y_CF)

**Save and load splits**

In [None]:
# train on generated data; validate and test on the real splits
cf_train_loader = DataLoader(cf_data, CLASSIFIER_BATCH_SIZE, shuffle=True)
validation_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Validation', load=True, fair=False)
test_loader = read_split(f'Real/Splits/{SPLIT}', CLASSIFIER_BATCH_SIZE, 'Test', load=True, fair=False)

**Test classifier using generated dataset**

In [None]:
# ten independent runs: train a fresh classifier on the CF-generated train
# set, reload its best checkpoint and report results on the real test set
for i in range(10):
    C_CF = Classifier(NUM_FEATURES)
    CT_CF = C_Tracker()

    C_CF.to(device)

    best_epoch = train_classifier(C_CF, CT_CF, cf_train_loader, validation_loader, 'CF', fair=True)

    # evaluate the best checkpoint, not the weights of the last epoch
    C_CF = load_classifier('CF', CT_CF.j, best_epoch)

    print(CT_CF.j)
    print(best_epoch)
    d = test_classifier(C_CF, test_loader)
    print_results(d, False)
    get_accuracy(d, False)
    get_dp(d, False)
    get_eo(d, False)
    print('---')

In [None]:
# CT_CF is the tracker of the last run of the loop above
plot_classifier(CT_CF)