# Training a neural network for multi-label classification on the CelebA dataset

In [None]:
!wget -nc -O 'archive.zip' 'https://storage.googleapis.com/kaggle-data-sets/29561/37705/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20221102%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20221102T043448Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=4a8303ca846eaae80b62ff3705ad6b86de91724f8d76b4373083c4c33225314b6c43af9bab79e32100af51aa02e230a1eaa81ae33aea94680fbb118fb74dd6bf4d2a9426809a43ce699da806f156781c98f5076845d8936c304657265ecaa323ba0a369aa7d436916a6a87ba8a5873865da1b30052c9f128f2005fa12e1526850a79fe96e74aa698bc4ce59b3822d159b22ece231b7246d7800b685c5e9974e7837dfb4d2d7c18147749443bdbe9ae44f5db3e045f25774df6d464c7f0762279fea7b201ce76ceb15d1301f4354112db105c108dd5564f3a165f9f2af1fb581f5fa6d5b471a194007290b03637a38105c09a3c6f2aebb26e3df2fccf49d349a6'
!unzip -qq archive.zip

In [None]:
# Package installation (hidden on docs website).
# If running on Colab, may want to use GPU (select: Runtime > Change runtime type > Hardware accelerator > GPU)
# Package versions we used: matplotlib==3.5.1, torch==1.11.0, torchvision==0.12.0, timm==0.5.4

dependencies = ["torch", "torchvision", "sklearn", "timm", "pillow","scikit-learn"]

if "google.colab" in str(get_ipython()):  # Check if it's running in Google Colab
    cmd = ' '.join([dep for dep in dependencies if dep != "cleanlab"])
    %pip install $cmd
else:
missing_dependencies = []
for dependency in dependencies:
    try:
        __import__(dependency)
    except ImportError:
        missing_dependencies.append(dependency)

if len(missing_dependencies) > 0:
    print("Missing required dependencies:")
    print(*missing_dependencies, sep=", ")
    print("\nPlease install them before running the rest of this notebook.")

In [None]:
from torch import nn
from timm.data.loader import create_loader
import torch.utils.data as data
import torch
import pandas as pd
from collections import defaultdict
from timm.utils import CheckpointSaver
from types import SimpleNamespace
from PIL import Image
import os
import time
import numpy as np
from sklearn.metrics import accuracy_score
from collections import Counter
from timm.optim import create_optimizer
from sklearn.model_selection import StratifiedKFold

from timm.data import (
    resolve_data_config,
)
from timm.models import create_model

In [None]:
# CelebA attribute table: the first column is image_id, every other column is an attribute.
dat = pd.read_csv("list_attr_celeba.csv")
# Attribute values are converted with (x + 1) / 2, i.e. -1 -> 0 (and +1 -> 1), stored as int32.
dat[dat.columns[1:]] = dat[dat.columns[1:]].add(1).div(2).astype(np.int32)

# Columns kept for training: the image id followed by the attributes to predict.
selected = ['image_id'] + [
    'Eyeglasses',
    'Wearing_Earrings',
    'Wearing_Hat',
    'Wearing_Necklace',
    'Wearing_Necktie',
    'No_Beard',
    'Smiling',
]

In [None]:
def is_label(row):
    """Return True if ``row`` is non-zero in any attribute column listed in ``selected[1:]``.

    ``selected`` is looked up in the notebook's global scope at call time; its first
    entry (the id/path column) is skipped.
    """
    return any(row[attr] != 0 for attr in selected[1:])

def get_loc(i):
    """Return the path of image file ``i`` under <cwd>/img_align_celeba/img_align_celeba/."""
    image_dir = os.path.join(os.getcwd(), 'img_align_celeba/img_align_celeba/')
    return image_dir + i


In [None]:
# Boolean mask: True for images where at least one selected attribute is non-zero
# (the same test as `is_label`, vectorised instead of a row-wise Python apply).
dat_label = dat[selected[1:]].ne(0).any(axis=1)
# Name 'image_id' explicitly instead of reading selected[0]: the last line of this cell overwrites
# selected[0], so this keeps the cell safe to re-run. .copy() yields an independent frame, so adding
# the 'image_path' column does not raise SettingWithCopyWarning.
dat_selected = dat.loc[dat_label, ['image_id'] + selected[1:]].copy()
dat_selected['image_path'] = dat_selected['image_id'].map(get_loc)
# The next cell selects `selected` from dat_selected, so its first entry must name the path column.
selected[0] = 'image_path'

In [None]:
# Copy so that adding 'unique_label' below modifies an independent frame rather than a
# column subset of dat_selected (which would trigger SettingWithCopyWarning).
df = dat_selected[selected].copy()


def _label_key(row):
    """String key for a row's attribute combination: every value after the first (image_path) column."""
    return str(row.tolist()[1:])


# Give every distinct attribute combination an integer id, in order of first appearance.
set_lab = {}
for _, row in df.iterrows():
    set_lab.setdefault(_label_key(row), len(set_lab))


def get_lab(row):
    """Return the integer id of the attribute combination in ``row`` (a row of ``df``)."""
    return set_lab[_label_key(row)]


df['unique_label'] = df.apply(get_lab, axis=1)
cnt = Counter(df['unique_label'])

# StratifiedKFold needs enough samples per stratum: keep only combinations seen MORE than
# MIN_COMBO_COUNT times, i.e. drop combinations with <= MIN_COMBO_COUNT samples.
MIN_COMBO_COUNT = 10


def drop(val):
    """Return True if combination id ``val`` should be KEPT (despite the function's name)."""
    return cnt[val] > MIN_COMBO_COUNT


is_drop = df['unique_label'].map(drop)  # True means "keep this row"

df = df[is_drop]


In [None]:
class MultiLabelModel(nn.Module):
    """Multi-label classifier wrapper around a backbone that outputs one score per label.

    ``forward`` returns the backbone's raw outputs flattened from dim 1. The train, validate
    and predict helpers apply a sigmoid to those outputs and threshold predictions at 0.5,
    so the ``loss_fn`` they use must accept probabilities (e.g. ``nn.BCELoss``), not logits.
    """

    def __init__(self, model, n_classes, class_weights=None, verbose = False):
        """
        Args:
            model: backbone ``nn.Module``; its output is flattened to ``(batch, -1)``.
            n_classes: number of labels, stored as ``self.num_classes``.
            class_weights: accepted for backward compatibility; not used.
            verbose: if True, print the current batch loss every 80 training batches.
        """
        super().__init__()
        self.base_model = model
        self.num_classes = n_classes
        self.verbose = verbose

    def forward(self, x):
        """Run the backbone and flatten its output from dim 1 (no sigmoid is applied here)."""
        x = self.base_model(x)
        return torch.flatten(x, 1)

    def get_loss(self, loss_fn, output, target):
        """Return ``loss_fn(output, target)``."""
        return loss_fn(output, target)

    def _device(self):
        """Device of the model's parameters (CUDA if it has none, matching the old hard-coded .cuda())."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cuda")

    @staticmethod
    def _label_names(loader):
        """``label_names`` of the dataset behind ``loader``, or None if it does not expose them."""
        return getattr(getattr(loader, "dataset", None), "label_names", None)

    @staticmethod
    def _report(title, total_loss, num_batches, labels, preds, label_names):
        """Print average loss, exact-match accuracy and per-label accuracy; return the average loss.

        ``labels`` and ``preds`` are lists of per-batch tensors with one column per label.
        """
        avg_loss = total_loss / num_batches
        print(title)
        print("AVERAGE LOSS:", avg_loss)
        preds = torch.cat(preds).int()
        labels = torch.cat(labels).int()
        # On 2-D label-indicator inputs sklearn's accuracy_score is subset (exact-match) accuracy.
        print("MULTILABEL accuracy score:", accuracy_score(labels, preds))
        per_class = [accuracy_score(labels[:, i], preds[:, i]) for i in range(preds.shape[1])]
        print(label_names)
        print(per_class)
        print('\n\n')
        return avg_loss

    def validate(self, loader, loss_fn=None):
        """Evaluate on ``loader`` without gradients and print loss/accuracy statistics.

        Args:
            loader: iterable of ``(input, target)`` batches with 0/1 targets, one column per label.
            loss_fn: loss on sigmoid probabilities; defaults to ``nn.BCELoss()``.

        Returns:
            Mean per-batch loss.
        """
        loss_fn = nn.BCELoss() if loss_fn is None else loss_fn
        device = self._device()
        self.eval()
        total_loss = 0
        labels, preds = [], []
        with torch.no_grad():
            for inputs, target in loader:
                labels.append(target.detach().cpu())
                output = torch.sigmoid(self(inputs.to(device)))
                loss = self.get_loss(loss_fn, output, target.float().to(device))
                total_loss += loss.item()
                preds.append((output > 0.5).cpu())
        return self._report("VALIDATION DATA STATS", total_loss, len(loader),
                            labels, preds, self._label_names(loader))

    def predict_proba(self, loader):
        """Return sigmoid probabilities for every sample in ``loader`` as one CPU tensor.

        Batches are ``(input, target)`` pairs; targets are ignored. Row order follows the loader.
        """
        device = self._device()
        self.eval()
        with torch.no_grad():
            return torch.cat([torch.sigmoid(self(inputs.to(device))).cpu() for inputs, _ in loader])

    def train_one_epoch(
        self,
        loader,
        optimizer,
        loss_fn,
    ):
        """Run one optimisation pass over ``loader``, then print statistics and the elapsed time.

        Returns:
            Mean per-batch training loss.
        """
        start = time.time()
        # Optimizers that report is_second_order get create_graph=True in backward().
        second_order = getattr(optimizer, "is_second_order", False)
        device = self._device()
        self.train()
        total_loss = 0
        labels, preds = [], []
        for batch_num, (inputs, target) in enumerate(loader, start=1):
            labels.append(target.detach().cpu())
            output = torch.sigmoid(self(inputs.to(device)))
            loss = self.get_loss(loss_fn, output, target.float().to(device))
            total_loss += loss.item()
            preds.append((output > 0.5).detach().cpu())
            optimizer.zero_grad()
            loss.backward(create_graph=second_order)
            optimizer.step()
            if self.verbose and batch_num % 80 == 0:
                print("LOSS:", loss.item())
        avg_loss = self._report("TRAINING DATA STATS", total_loss, len(loader),
                                labels, preds, self._label_names(loader))
        print("training time", time.time() - start)
        return avg_loss

    def fit(self, loader_train, load_val, num_epochs=10, loss_fn=None):
        """Train with Adam (lr 1e-4), checkpointing every epoch into ``weights_model/``.

        Any existing ``weights_model`` directory is deleted first. Validation loss is the
        checkpoint metric with ``decreasing=True`` (timm's default is False, which would mark
        the HIGHEST-loss epoch as best), so ``weights_model/model_best.pth.tar`` is the epoch
        with the lowest validation loss.

        Args:
            loader_train: training loader.
            load_val: validation loader (parameter name kept for backward compatibility).
            num_epochs: number of epochs.
            loss_fn: loss on sigmoid probabilities; defaults to ``nn.BCELoss()``.

        Returns:
            List of ``[train_loss, val_loss]`` pairs, one per epoch.
        """
        import shutil

        loss_fn = nn.BCELoss() if loss_fn is None else loss_fn
        if os.path.exists("weights_model"):
            print("removing weights directory")
            shutil.rmtree("weights_model")
        os.mkdir("weights_model")
        args = SimpleNamespace(weight_decay=0, lr=1e-4, opt='adam', momentum=0.9, sched="step")
        optimizer = create_optimizer(args, self)
        saver = CheckpointSaver(
            model=self,
            optimizer=optimizer,
            checkpoint_dir="weights_model",
            decreasing=True,  # metric is a loss: lower is better
        )
        errs = []
        for epoch in range(num_epochs):
            loss_train = self.train_one_epoch(loader_train, optimizer, loss_fn)
            loss_val = self.validate(load_val, loss_fn)
            errs.append([loss_train, loss_val])
            saver.save_checkpoint(epoch, metric=loss_val)
        return errs

In [None]:
class DatasetMultiLabel(data.Dataset):
    """Dataset of ``(image, labels)`` pairs, where ``labels`` is a list of floats (one per label).

    Built from a DataFrame (or a CSV at ``annotation_path``) that has an ``image_path``
    column and one column per label; an optional ``unique_label`` column is ignored.
    """

    def __init__(
            self,
            annotation_path=None,
            df = None,
            transform=None):
        """
        Args:
            annotation_path: optional CSV path; when given it is read and ``df`` is ignored.
            df: DataFrame used when ``annotation_path`` is None (required in that case).
            transform: optional callable applied to each loaded image in ``__getitem__``.

        Raises:
            ValueError: if the table has no ``image_path`` column.
        """
        super().__init__()
        self.transform = transform
        if annotation_path is None:
            assert df is not None, "pass either annotation_path or df"
        else:
            df = pd.read_csv(annotation_path)

        if 'image_path' not in df.columns:
            raise ValueError("expected an 'image_path' column")
        # Every column other than image_path / unique_label is a label, in frame order. Names and
        # values come from the same list, so they stay aligned even without a unique_label column.
        label_cols = [c for c in df.columns if c not in ('image_path', 'unique_label')]
        self.label_names = label_cols
        self.data = df['image_path'].tolist()
        self.labels = df[label_cols].to_numpy(dtype=float).tolist()

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return ``(img, labels)``."""
        # The context manager closes the file; convert('RGB') forces the pixels to load and
        # guarantees three channels even for grayscale or palette images.
        with Image.open(self.data[idx]) as img:
            img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

    def __len__(self):
        """Number of images."""
        return len(self.data)

In [None]:
# Dataset over the full frame; here it is only used to count the label columns.
dataset = DatasetMultiLabel(df = df)
loss_fn = nn.BCELoss()

num_labels = len(dataset.labels[0])
backbone = create_model('efficientnet_b0', num_classes=num_labels)
# Preprocessing settings (input size, mean/std, interpolation) derived from the backbone.
data_config = resolve_data_config(args={}, model=backbone)

model = MultiLabelModel(backbone, n_classes=num_labels).cuda()


In [None]:
def reset_weights(m):
    """Re-initialise the direct children of ``m`` that define ``reset_parameters``.

    Only immediate children are touched; use ``model.apply(reset_weights)`` to reach every
    submodule. Used between folds so that weights do not leak from one fold to the next.
    """
    resettable = (child for child in m.children() if hasattr(child, 'reset_parameters'))
    for child in resettable:
        child.reset_parameters()

def create_df(pred_probs,dataset):
    """
    Creates a dataframe with image_loc and predicted probabilities.

    Args:
        pred_probs: array/tensor with one row per sample and one column per label (row i is
            expected to correspond to ``dataset.data[i]``).
        dataset: object exposing ``data`` (image paths) and ``label_names`` (column names).

    Returns:
        DataFrame with an ``image_loc`` column followed by one probability column per label.
    """
    # Uses the arguments, not the notebook globals `pred_val` / `dataset_val`.
    columns = {'image_loc': list(dataset.data)}
    for i, name in enumerate(dataset.label_names):
        columns[name] = pred_probs.T[i].tolist()
    return pd.DataFrame(columns)

In [None]:
# Number of cross-validation folds; the training loop writes one "<k>_fold.csv" per fold.
num_splits = 4
skf = StratifiedKFold(num_splits)

In [None]:
def _make_fold_loader(fold_dataset, is_training):
    """Build a timm loader for one CV split using ``data_config`` preprocessing and batch size 64."""
    return create_loader(
        fold_dataset,
        input_size=data_config["input_size"],
        batch_size=64,
        is_training=is_training,
        mean=data_config["mean"],
        std=data_config["std"],
        interpolation=data_config["interpolation"],
    )


ct = 1
for train_index, test_index in skf.split(df, df['unique_label']):
    # Every fold after the first starts from re-initialised weights, not the previous fold's.
    if ct != 1:
        model.apply(reset_weights)

    dataset_train = DatasetMultiLabel(df=df.iloc[train_index])
    dataset_val = DatasetMultiLabel(df=df.iloc[test_index])
    loader_train = _make_fold_loader(dataset_train, is_training=True)
    loader_val = _make_fold_loader(dataset_val, is_training=False)

    # Train, reload the checkpoint saved as model_best, then predict on the held-out fold.
    model.fit(loader_train, loader_val, num_epochs=40)
    checkpoint = torch.load("weights_model/model_best.pth.tar")
    model.load_state_dict(checkpoint['state_dict'])

    pred_val = model.predict_proba(loader_val)
    df_pred = create_df(pred_val, dataset_val)
    df_pred.to_csv(str(ct) + "_fold.csv", index=False)
    ct += 1
    


In [None]:
# Stitch the per-fold out-of-fold predictions back into a single table.
dfl = [pd.read_csv(str(i) + "_fold.csv") for i in range(1, num_splits + 1)]

cols = dfl[0].columns[1:]  # every column after image_loc

df_pred = pd.concat(dfl, axis=0)
# Keep only the text after the last '/', i.e. drop the directory part of each image path.
df_pred['image_loc'] = df_pred['image_loc'].map(lambda loc: loc.rsplit('/', 1)[-1])

df_pred.set_index('image_loc').to_csv("pred_probs.csv")

