# Semi-supervised learning

#### This notebook clusters the dataset using a classifier that is trained on a small portion of the data. It uses a Residual Network (ResNet) as the classifier. Once trained, the classifier predicts the sensitive attribute of every training sample, and each sample is allocated to a cluster based on that prediction and the sample's class. The notebook can be run sequentially, cell by cell, provided the cells under "Helper functions" at the end are executed first, since earlier cells call the functions defined there.

# Imports

In [None]:
from __future__ import print_function, division

import sys  
import time
import os
import copy
import shutil
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from torch.utils.data import TensorDataset
import torchvision

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from math import ceil
from PIL import Image

import visdom
from IPython.display import clear_output
from PIL import Image
import nltk
from nltk.cluster.kmeans import KMeansClusterer

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity, pairwise_distances

sys.path.insert(0, '../Resnet/')
%load_ext autoreload
%autoreload 2
from model import *
from my_ImageFolder import *
from fairness_metrics import *

plt.ion()   # interactive mode

# Defining the inputs

In [None]:
# Run configuration, one setting per line.
W_PROTECTED = 0
BIAS = 0.8                 # selects the training folder f"train_{BIAS}"
VAL_MODE = False
START_EPOCH = 0
NUM_EPOCH = 10
SHOW_PROGRESS = False
ID = 0
DATASET = "basket_volley"
NUM_TRIALS = 1
PROTECTED = "jc"           # "gd" labels samples by `females`, anything else by `yellow`

# Importing the dataset

In [None]:
# One folder per (sport, gender, colour) combination. The suffixes encode
# gender (f = female, m = male) and colour (r = red, y = yellow).
path_bask_r_f = '../Datasets/basket_volley/basket/basket_f_r/'
path_bask_y_f = '../Datasets/basket_volley/basket/basket_f_y/'
path_bask_r_m = '../Datasets/basket_volley/basket/basket_m_r/'
path_bask_y_m = '../Datasets/basket_volley/basket/basket_m_y/'

path_voll_r_f = '../Datasets/basket_volley/volley/volley_f_r/'
path_voll_y_f = '../Datasets/basket_volley/volley/volley_f_y/'
path_voll_r_m = '../Datasets/basket_volley/volley/volley_m_r/'
path_voll_y_m = '../Datasets/basket_volley/volley/volley_m_y/'

# File names found in each folder (lists, in os.listdir order).
bask_r_f = os.listdir(path_bask_r_f)
bask_y_f = os.listdir(path_bask_y_f)
bask_r_m = os.listdir(path_bask_r_m)
bask_y_m = os.listdir(path_bask_y_m)

voll_r_f = os.listdir(path_voll_r_f)
voll_y_f = os.listdir(path_voll_y_f)
voll_r_m = os.listdir(path_voll_r_m)
voll_y_m = os.listdir(path_voll_y_m)

# File names selected as the "min" group of class 0 (basket) and class 1
# (volley); which folders are used depends on the protected attribute.
if PROTECTED == "jc":
    class0_min = bask_y_m + bask_y_f
    class1_min = voll_r_m + voll_r_f
else:
    class0_min = bask_y_f + bask_r_f
    class1_min = voll_r_m + voll_y_m
protected_groups = set(class0_min + class1_min)

# Fast membership lookups by file name.
females = {*bask_r_f, *bask_y_f, *voll_r_f, *voll_y_f}
yellow = {*bask_y_f, *bask_y_m, *voll_y_f, *voll_y_m}
volley = {*voll_r_f, *voll_r_m, *voll_y_f, *voll_y_m}

## Creating Dataset

In [None]:
class my_GenderFolder(torchvision.datasets.ImageFolder):
    """
    ImageFolder variant that serves precomputed embeddings instead of images.

    Each item is the tuple ``(embedding, is_female, 1, index)`` where
    ``embedding`` is looked up in ``embeddings``, ``is_female`` is 1 when the
    image's file name is in ``females`` (0 otherwise), ``1`` is a constant
    per-sample weight and ``index`` is the dataset index that was requested.

    Arguments:
        root: Image folder root, passed to ``ImageFolder``.
        females: Collection of file names (no directory part) labelled 1.
        embeddings: Per-class containers of embeddings, indexed by class label.
        indexes: Per-class arrays of dataset indices; ``indexes[label][k]`` is
            the dataset index of ``embeddings[label][k]``.
    """

    def __init__(self, root, females, embeddings, indexes):
        super().__init__(root)
        self.females = females
        self.embeddings = embeddings
        self.indexes = indexes

    def __getitem__(self, index: int):
        """
        Return ``(embedding, is_female, 1, index)`` for the given dataset index.

        Raises IndexError when ``index`` is out of range or is not listed in
        ``indexes`` for its class.
        """
        # Only the class label is needed, so read it from `samples` rather
        # than decoding the image through ImageFolder.__getitem__.
        path, label = self.samples[index]
        position = np.where(self.indexes[label] == index)[0][0]
        # os.path.basename handles the platform's path separator, which a
        # plain split("/") does not.
        is_female = int(os.path.basename(path) in self.females)
        return self.embeddings[label][position], is_female, 1, index
    
class my_ImageGenderFolder(torchvision.datasets.ImageFolder):
    """
    ImageFolder variant whose label is membership in a set of file names.

    Each item is the tuple ``(image, in_group, 1, index)`` where ``image`` is
    the transformed image from ``ImageFolder``, ``in_group`` is 1 when the
    image's file name is in ``group`` (0 otherwise), ``1`` is a constant
    per-sample weight and ``index`` is the dataset index that was requested.
    The folder-derived class label is discarded.

    Arguments:
        root: Image folder root, passed to ``ImageFolder``.
        transform: Transform applied to every loaded image.
        group: Collection of file names (no directory part) labelled 1.
    """

    def __init__(self, root, transform, group):
        super().__init__(root, transform)
        self.group = group

    def __getitem__(self, index: int):
        """Return ``(image, in_group, 1, index)`` for the given dataset index."""
        inp, _ = super().__getitem__(index)
        # os.path.basename handles the platform's path separator, which a
        # plain split("/") does not.
        filename = os.path.basename(self.samples[index][0])
        return inp, int(filename in self.group), 1, index
    
class my_subset(torch.utils.data.Dataset):
    """
    View of a dataset restricted to the given indices.

    Arguments:
        dataset (Dataset): The full dataset; it must expose a ``samples``
            sequence and support integer indexing.
        indices (sequence): Positions in ``dataset`` that make up the subset.

    Items are whatever ``dataset[index]`` returns for the mapped index;
    ``samples`` holds the matching entries of ``dataset.samples``.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
        # Materialise the subset's sample records once, in subset order.
        self.samples = [dataset.samples[position] for position in indices]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the subset position into a position in the full dataset.
        return self.dataset[self.indices[idx]]

In [None]:
def _resize_crop_normalize():
    """Build a new deterministic pipeline: resize, centre-crop, to-tensor, normalise."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


# No augmentation is applied: the random crop / horizontal flip that used to be
# part of the training pipeline is disabled, so both splits are preprocessed
# identically. Each split gets its own Compose instance.
data_transforms = {
    'train_all': _resize_crop_normalize(),
    'test': _resize_crop_normalize(),
}

# Root of the train/test split that matches the chosen protected attribute.
data_dir = f'../Datasets/basket_volley/train_test_split_{PROTECTED}'

# File names that receive label 1: `females` for "gd", `yellow` otherwise.
_label_group = females if PROTECTED == "gd" else yellow
image_datasets = {
    'train_all': my_ImageGenderFolder(
        os.path.join(data_dir, f"train_{BIAS}"), data_transforms['train_all'], _label_group),
    'test': my_ImageGenderFolder(
        os.path.join(data_dir, 'test'), data_transforms['test'], _label_group),
}

# Small subset used to train the classifier. `balance_indices` is defined in
# the "Helper functions" cell and must already have been executed.
split = 0.15
indices = balance_indices(image_datasets["train_all"], split)
image_datasets["train"] = my_subset(image_datasets["train_all"], indices=indices)

dataloaders = {}
dataset_sizes = {}
for _phase in ('train', 'train_all', 'test'):
    dataloaders[_phase] = torch.utils.data.DataLoader(
        image_datasets[_phase], batch_size=4, shuffle=True, num_workers=4)
    dataset_sizes[_phase] = len(image_datasets[_phase])

class_names = image_datasets['train_all'].classes

# Use the first GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Count the samples of the small training subset in each (label, sport) group:
#   a: label 0, not volley    b: label 0, volley
#   c: label 1, not volley    d: label 1, volley
a, b, c, d = 0, 0, 0, 0
for _, l, _, i in image_datasets["train"]:
    # os.path.basename handles the platform's path separator, which a plain
    # split("/") does not.
    v = os.path.basename(image_datasets["train_all"].samples[i][0]) in volley
    # The four groups are mutually exclusive, so one chain is enough.
    if l and v:
        d += 1
    elif l:
        c += 1
    elif v:
        b += 1
    else:
        a += 1

print(f"The proportions are: {a}:{b}:{c}:{d}")

# Defining Resnet network

In [None]:
# Pretrained ResNet-18 used as a fixed feature extractor: every existing
# parameter is frozen.
model_conv = torchvision.models.resnet18(pretrained=True)
model_conv.requires_grad_(False)

# Replace the final layer with a new one sized to the dataset's classes.
# Parameters of newly constructed modules have requires_grad=True by default.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(in_features=num_ftrs, out_features=len(class_names))
model_conv = model_conv.to(device)

# Per-sample weighted loss, used here in place of nn.CrossEntropyLoss().
criterion = weighted_cross_entropy_loss

# Only the parameters of the new final layer are handed to the optimizer.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Multiply the learning rate by 0.2 every 5 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=5, gamma=0.2)

### Training

In [None]:
# Training settings passed to train_model in the next cell.
VAL_MODE = False
START_EPOCH = 0
NUM_EPOCH = 10
SHOW_PROGRESS = False

In [None]:
# Train the classifier with the settings from the previous cell and keep the
# returned model. `train_model` is not defined in this notebook; presumably it
# comes from one of the star imports (`model`, `my_ImageFolder`,
# `fairness_metrics`) - TODO confirm.
# NOTE(review): `dataloaders` holds the keys 'train', 'train_all' and 'test';
# which of them `train_model` consumes cannot be seen from here.
model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, dataloaders, dataset_sizes,
                            device,
                            start_epoch=START_EPOCH,
                            num_epochs=NUM_EPOCH,
                            val_mode=VAL_MODE, show_progress=SHOW_PROGRESS)

### Evaluation

In [None]:
# Report accuracy on the small training subset, the full training set and the
# test set. The first message previously read "small raining set".
print(f"Acc. on small training set: {float(accuracy(model_conv, device, dataloaders['train']))}")
print(f"Acc. on all training set: {float(accuracy(model_conv, device, dataloaders['train_all']))}")
print(f"Acc. on Test set: {float(accuracy(model_conv, device, dataloaders['test']))}")

# Build Clustering

In [None]:
# Predict the sensitive attribute of every training sample. Despite its name,
# `kmeans` holds the classifier's predictions, grouped by class (entry 0 and
# entry 1); `indices` holds the matching dataset indices.
kmeans, indices = clustering_1(
    model_conv,
    dataloaders["train_all"],
    image_datasets["train_all"].samples,
)

# Flatten each per-class array to one dimension.
kmeans = [per_class.reshape(-1) for per_class in kmeans[:2]]
indices = [per_class.reshape(-1) for per_class in indices[:2]]

In [None]:
# Copy the training images of each class into per-cluster folders under
# "class_0/" and "class_1/", and keep the returned cluster folder paths for
# the statistics cells below.
cluster_paths_0 = view_clusters("class_0/", kmeans[0], indices[0])
cluster_paths_1 = view_clusters("class_1/", kmeans[1], indices[1])

### Basket

In [None]:
# Print, for every class-0 cluster folder, how many images belong to each
# sport, colour and gender group.
statistics("class_0/", cluster_paths_0)

### Volley

In [None]:
# Print, for every class-1 cluster folder, how many images belong to each
# sport, colour and gender group.
statistics("class_1/", cluster_paths_1)

### Saving clustering

In [None]:
# Map every training image's file name to its cluster and write the mapping to
# disk. The result is bound to `cluster_dict` rather than `dict` so that the
# builtin `dict` stays usable in the rest of the session.
# NOTE(review): the output name is fixed to "resnet_jc.txt" even when
# PROTECTED is not "jc" - confirm whether it should follow PROTECTED.
cluster_dict = make_save_dict(image_datasets["train_all"].samples, [kmeans[0], kmeans[1]], [indices[0], indices[1]], save=True, name="resnet_jc.txt")

# Helper functions

In [None]:
def eval(model, dataloader):
    """
    Print the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to evaluation mode and left in it. Batches are moved
    to the device that holds the model's parameters, and no gradients are
    recorded.

    Arguments:
        model: A ``torch.nn.Module`` returning one score per class.
        dataloader: Iterable of ``(inputs, labels, _, _)`` batches.

    Returns None; the accuracy (in percent) is printed. When the dataloader
    yields no samples a notice is printed instead of dividing by zero.
    """
    model.eval()
    # Run on whatever device the model lives on (None for parameter-free models).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    corr, total = 0, 0
    with torch.no_grad():
        for inputs, labels, _, _ in dataloader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            corr += int((preds == labels).sum())
            total += len(preds)

    if total == 0:
        print("Accuracy is undefined: the dataloader yielded no samples")
        return
    print("Accuracy is {}%".format(corr/total*100))
    
def clustering_1(model, dataloader, samples, group=None):
    """
    Predict a label for every sample and group the predictions by class.

    A sample belongs to class 1 when its file name is in ``group`` and to
    class 0 otherwise. Inference runs in evaluation mode without gradients, on
    the device that holds the model's parameters; the model's previous
    train/eval mode is restored afterwards.

    Arguments:
        model: A ``torch.nn.Module`` returning one score per label.
        dataloader: Iterable of ``(inputs, _, _, indices)`` batches where
            ``indices`` is a tensor of positions into ``samples``.
        samples: Sequence of records whose first element is the image path.
        group: Collection of file names that define class 1. Defaults to the
            notebook-level ``volley`` set.

    Returns:
        ``(predictions, indexes)``: two lists of length 2 (class 0, class 1),
        each entry an int64 array of shape ``(n, 1)``.
    """
    if group is None:
        group = volley

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Collect into Python lists and build the arrays once at the end instead
    # of re-allocating with np.concatenate for every sample.
    preds_by_class = ([], [])
    idx_by_class = ([], [])

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, _, _, indices in dataloader:
                if device is not None:
                    inputs = inputs.to(device)
                _, preds = torch.max(model(inputs), 1)
                for pred, idx in zip(preds.cpu().tolist(), indices.tolist()):
                    # os.path.basename handles the platform's path separator.
                    cls = int(os.path.basename(samples[idx][0]) in group)
                    preds_by_class[cls].append(pred)
                    idx_by_class[cls].append(idx)
    finally:
        model.train(was_training)

    kmeans = [np.array(p, dtype=np.int64).reshape((-1, 1)) for p in preds_by_class]
    indexes = [np.array(i, dtype=np.int64).reshape((-1, 1)) for i in idx_by_class]
    return kmeans, indexes

def clustering_2(model, transforms):
    """
    Predict a label for every precomputed embedding, grouped by class.

    Inference runs one embedding at a time in evaluation mode without
    gradients, on the device that holds the model's parameters; the model's
    previous train/eval mode is restored afterwards.

    Arguments:
        model: A ``torch.nn.Module`` mapping a ``(1, d)`` tensor to one score
            per label.
        transforms: Per-class collections of embeddings, indexed by the class
            labels ``0 .. len(transforms) - 1``.

    Returns:
        A list with one list of int predictions per class.
    """
    kmeans = [[] for _ in range(len(transforms))]

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for label in range(len(transforms)):
                for emb in transforms[label]:
                    # as_tensor avoids the copy (and warning) torch.tensor
                    # produces when `emb` is already a tensor.
                    emb = torch.as_tensor(emb).reshape((1, -1))
                    if device is not None:
                        emb = emb.to(device)
                    _, pred = torch.max(model(emb), 1)
                    kmeans[label].append(int(pred))
    finally:
        model.train(was_training)
    return kmeans

def view_clusters(path, kmeans, indexes, samples=None):
    """
    Copy each image into a folder named after its cluster.

    Images are copied to ``<path>/clustering_<K>/cluster_<label>/`` where ``K``
    is the number of distinct labels in ``kmeans``. One folder is created per
    label that actually occurs, so labels need not be the contiguous range
    ``0 .. K-1`` (for example when every sample is assigned to cluster 1).

    Arguments:
        path: Output root directory.
        kmeans: Sequence of cluster labels, one per image.
        indexes: Sequence of positions into ``samples``, aligned with ``kmeans``.
        samples: Sequence of records whose first element is the image path.
            Defaults to ``image_datasets["train_all"].samples``.

    Returns:
        The list of cluster folder paths, ordered by label.
    """
    if samples is None:
        samples = image_datasets["train_all"].samples

    labels = sorted(set(kmeans))
    K = len(labels)

    cluster_dirs = {
        label: os.path.join(path, f"clustering_{K}/cluster_{label}")
        for label in labels
    }
    for directory in cluster_dirs.values():
        os.makedirs(directory, exist_ok=True)

    for label, idx in zip(kmeans, indexes):
        # Copying into a directory keeps the source file name.
        shutil.copy(samples[idx][0], cluster_dirs[label])

    return [cluster_dirs[label] for label in labels]

def make_save_dict(samples, k_means_list, indexes_list, save=False, name="dict.txt"):
    """
    Build a ``{file name: cluster}`` mapping and optionally write it to disk.

    Arguments:
        samples: Sequence of records whose first element is the image path.
        k_means_list: One sequence of cluster labels per class.
        indexes_list: One sequence of positions into ``samples`` per class,
            aligned with ``k_means_list``.
        save: When true, write ``str(mapping)`` to the file ``name``.
        name: Output file path used when ``save`` is true.

    Returns:
        The mapping. NumPy and tensor scalars are converted to plain Python
        values so that the written text does not depend on NumPy's scalar
        representation.
    """
    dic = {}
    for k_means, indexes in zip(k_means_list, indexes_list):
        for cluster, idx in zip(k_means, indexes):
            # os.path.basename handles the platform's path separator.
            img = os.path.basename(samples[idx][0])
            dic[img] = cluster.item() if hasattr(cluster, "item") else cluster

    if save:
        # Overwrite rather than append: re-running would otherwise leave several
        # dict literals back to back in one file, which cannot be parsed.
        with open(name, "w") as f:
            f.write(str(dic))

    return dic

def balance_indices(dataset, p, group=None):
    """
    Select dataset indices so the four (label, class) groups are equally sized.

    The dataset is scanned in order. A sample is kept while its group - given
    by its label and by whether its file name is in ``group`` - holds fewer
    than ``len(dataset) * p / 4`` samples. Scanning stops as soon as every
    group is full, or when the dataset is exhausted (a group with too few
    samples simply stays smaller instead of raising ``StopIteration``).

    Arguments:
        dataset: Dataset yielding ``(_, label, _, index)`` items and exposing
            a ``samples`` sequence whose records start with the image path.
        p: Fraction of the dataset to select in total.
        group: Collection of file names that define class 1. Defaults to the
            notebook-level ``volley`` set.

    Returns:
        The list of selected indices, in scan order.
    """
    if group is None:
        group = volley

    thres = len(dataset) * p / 4
    # Keyed by (label is truthy, file name is in group).
    counts = {(False, False): 0, (False, True): 0, (True, False): 0, (True, True): 0}
    indices = []

    it = iter(dataset)
    exhausted = object()
    while any(count < thres for count in counts.values()):
        item = next(it, exhausted)
        if item is exhausted:
            break
        _, l, _, i = item
        # os.path.basename handles the platform's path separator.
        v = os.path.basename(dataset.samples[i][0]) in group
        key = (bool(l), v)
        if counts[key] < thres:
            counts[key] += 1
            indices.append(i)

    return indices
            

def statistics(path, clusters):
    """
    Print the composition of every cluster folder.

    For each folder the number of images per sport (basket / volley), colour
    (red / yellow) and gender (male / female) is printed together with its
    share of the folder. An image is attributed through the notebook-level
    file-name lists (``bask_r_f`` ... ``voll_y_m``); a name present in several
    lists is counted once per list. An empty folder reports 0.0% everywhere
    instead of dividing by zero.

    Arguments:
        path: Unused; kept for call compatibility.
        clusters: Sequence of cluster folder paths.
    """
    K = len(set(clusters))

    # (file names, sport, colour, gender); sets make each lookup constant time.
    groups = [
        (set(bask_r_f), "bask", "r", "f"),
        (set(bask_r_m), "bask", "r", "m"),
        (set(bask_y_f), "bask", "y", "f"),
        (set(bask_y_m), "bask", "y", "m"),
        (set(voll_r_f), "voll", "r", "f"),
        (set(voll_r_m), "voll", "r", "m"),
        (set(voll_y_f), "voll", "y", "f"),
        (set(voll_y_m), "voll", "y", "m"),
    ]

    for k in range(K):
        cluster = os.listdir(clusters[k])
        total = len(cluster)
        counts = {"bask": 0, "voll": 0, "r": 0, "y": 0, "m": 0, "f": 0}

        for img in cluster:
            for names, sport, colour, gender in groups:
                if img in names:
                    counts[sport] += 1
                    counts[colour] += 1
                    counts[gender] += 1

        n_bask, n_voll = counts["bask"], counts["voll"]
        n_r, n_y = counts["r"], counts["y"]
        n_m, n_f = counts["m"], counts["f"]

        def pct(n):
            # Share of the cluster in percent; 0.0 for an empty cluster.
            return n/total*100 if total else 0.0

        print(
            f"--------------Cluster {k}--------- \n"
            f" n. samples: {total}\n"
            f" n. of bask: {n_bask} ({pct(n_bask):.1f}%)\n"
            f" n. of volley: {n_voll} ({pct(n_voll):.1f}%)\n"
            f" n. of red: {n_r} ({pct(n_r):.1f}%)\n"
            f" n. of yellow: {n_y} ({pct(n_y):.1f}%)\n"
            f" n. of males: {n_m} ({pct(n_m):.1f}%)\n"
            f" n. of females: {n_f} ({pct(n_f):.1f}%)"
        )