# Train a neural network for multi-label classification on the CelebA dataset

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cleanlab/examples/blob/master/multilabel_classification/pytorch_network_training.ipynb)

This notebook demonstrates how to train a PyTorch neural network for image tagging and use the model to produce out-of-sample predicted class probabilities for each image in the dataset. These `pred_probs` are required inputs for finding label errors in multi-label classification datasets with cleanlab. Here we consider a subset of the [CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset, where each image may be tagged with one or more of the following tags: `['Eyeglasses', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Necklace', 'Wearing_Necktie', 'No_Beard', 'Smiling']`, depending on which ones apply to the person depicted.
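As a concrete picture of the data format: each image's tags can be stored as a multi-hot vector with one 0/1 entry per tag, and `pred_probs` has the same shape, holding an independent probability for each tag (so its rows need not sum to 1). A minimal NumPy sketch, using the seven tag columns that the preprocessing code below keeps; `to_multi_hot` is a hypothetical helper, not part of cleanlab or TIMM:

```python
import numpy as np

# The seven tag columns kept by the preprocessing code, in the same order
TAGS = ["Eyeglasses", "Wearing_Earrings", "Wearing_Hat", "Wearing_Necklace",
        "Wearing_Necktie", "No_Beard", "Smiling"]

def to_multi_hot(image_tags, tags=TAGS):
    """Convert the list of tag names for one image into a 0/1 vector of length len(tags)."""
    vec = np.zeros(len(tags), dtype=np.float32)
    for t in image_tags:
        vec[tags.index(t)] = 1.0
    return vec

# Two example images: one smiling without a beard, one with glasses, a hat and no beard
labels = np.stack([
    to_multi_hot(["No_Beard", "Smiling"]),
    to_multi_hot(["Eyeglasses", "Wearing_Hat", "No_Beard"]),
])
```

Each row of `labels` may contain several ones; a matching `pred_probs` array would also have shape `(n_images, n_tags)`.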

This notebook only shows how to train the network and use it to produce `pred_probs`; using them to find label issues is demonstrated in our other [example](https://github.com/cleanlab/examples/) notebook, [Find Label Errors in Multi-Label Classification Data (CelebA Image Tagging)](https://github.com/cleanlab/examples/blob/multilabel_tutorial/multilabel_classification/image_tagging.ipynb). Here we fine-tune a state-of-the-art neural network initialized from a pretrained [TIMM](https://timm.fast.ai/) network backbone. You can use this same code to obtain a multi-label classifier (i.e. an image tagging model) for *any* image dataset.


Please install the dependencies specified in this [requirements.txt](https://github.com/cleanlab/examples/blob/master/multilabel_classification/requirements.txt) file before running the notebook. 

Next, download the dataset. The code below attempts to download it from the Google Drive where its authors host it, but Google Drive only allows a limited number of programmatic downloads per day. If the download script fails, manually download the following files from the [CelebA dataset webpage](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) and save them in the current working directory:
 * [Images](https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM)
 * [Labels](https://drive.google.com/uc?id=0B7EVK8r0v71pblRyaVFSWGxPY0U)

In [1]:
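import os
import gdown

def download_drive(url, output):
    """Download a Google Drive file with gdown and print a hint if the download fails."""
    try:
        filename = gdown.download(url, output, quiet=False)
    except Exception as e:  # depending on the gdown version, a failed download may raise instead of returning None
        print(e)
        filename = None
    if filename is None:
        print(f"Downloading {url} failed, please download it from your browser and save it in {os.getcwd()}")

# Attribute labels for every image
url = 'https://drive.google.com/uc?id=0B7EVK8r0v71pblRyaVFSWGxPY0U'
output = 'list_attr_celeba.txt'
if not os.path.exists(output):
    download_drive(url, output)

# Zipped images: this download usually fails; if so, download the file manually from the link above
url = 'https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM'
output = 'img_align_celeba.zip'
if not os.path.exists(output):
    download_drive(url, output)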

If this code failed, remember that you can download the data manually from the links above. Once you have the zipped image file, unzip it to get a folder of individual image files.
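If the `unzip` command used below is not available (for example on Windows), Python's standard-library `zipfile` module can extract the archive instead. A minimal sketch; `extract_zip` is a hypothetical helper name:

```python
import zipfile
from pathlib import Path

def extract_zip(zip_path, dest_dir="."):
    """Extract every entry of zip_path into dest_dir and return the number of archive entries."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)  # creates dest_dir and any sub-folders as needed
        return len(zf.namelist())

# Only runs if the CelebA archive is present in the current working directory
if Path("img_align_celeba.zip").exists():
    extract_zip("img_align_celeba.zip")
```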

In [2]:
!unzip -qq img_align_celeba.zip

Next we import the other required dependencies (make sure you have installed these packages).

In [3]:
import time
from collections import defaultdict
from collections import Counter
from types import SimpleNamespace

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from PIL import Image
import torch
from torch import nn
import torch.utils.data as data
from timm.optim import create_optimizer
from timm.models import create_model
from timm.data import resolve_data_config
from timm.data.loader import create_loader
from timm.utils import CheckpointSaver

Now we load the dataset and preprocess it to keep only the classes of interest (i.e. the *image tags*) listed above, dropping images to which none of these tags apply.

In [4]:
# list_attr_celeba.txt layout: line 1 = number of images, line 2 = attribute names,
# then one line per image: '<image file name> <+1 or -1 for each attribute>'
with open("list_attr_celeba.txt") as f:
    lines = f.read().splitlines()
header = ["image_id"] + lines[1].split()

d_data = defaultdict(list)
for line in lines[2:]:
    values = line.split()
    if not values:  # skip blank lines
        continue
    d_data[header[0]].append(values[0])  # image file name
    for name, value in zip(header[1:], values[1:]):
        # map -1 entries -> 0 (and +1 entries -> 1)
        d_data[name].append((int(value) + 1) // 2)

dat = pd.DataFrame.from_dict(d_data)


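selected = [
    'image_id',
    'Eyeglasses',
    'Wearing_Earrings',
    'Wearing_Hat',
    'Wearing_Necklace',
    'Wearing_Necktie',
    'No_Beard',
    'Smiling',
]

def is_label(row):
    """True if at least one of the selected tags applies to this image."""
    for s in selected[1:]:
        if row[s] != 0:
            return True
    return False

def get_loc(i):
    """Absolute path of image file `i` inside the unzipped img_align_celeba folder."""
    return os.path.join(os.getcwd(), 'img_align_celeba', i)

# Keep only images with at least one selected tag, and only the selected columns
dat_label = dat.apply(is_label, axis=1)
dat_selected = dat[dat_label][selected].copy()  # copy avoids pandas' SettingWithCopyWarning
dat_selected['image_path'] = dat_selected['image_id'].map(get_loc)
selected[0] = 'image_path'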

df = dat_selected[selected].copy()

def tag_key(row):
    """Hashable key for a row's tag values (every column after image_path)."""
    return tuple(int(v) for v in row.iloc[1:])

# Assign an integer ID to every distinct combination of tag values
set_lab = {}
for _, row in df.iterrows():
    key = tag_key(row)
    if key not in set_lab:
        set_lab[key] = len(set_lab)

df['unique_label'] = df.apply(lambda row: set_lab[tag_key(row)], axis=1)
cnt = Counter(df['unique_label'])

# Here we drop rare tag combinations (10 or fewer images) just to simplify stratified data splitting
def keep(val):
    return cnt[val] > 10

is_kept = df['unique_label'].apply(keep)
df = df[is_kept]
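The cell above parses `list_attr_celeba.txt` line by line. For reference, a more compact sketch of the same parsing follows, assuming the file layout used above (first line: number of images; second line: attribute names; then one line per image with a +1/-1 value per attribute). `parse_celeba_attrs` is a hypothetical helper, not part of any library:

```python
import pandas as pd

def parse_celeba_attrs(text):
    """
    Parse the contents of list_attr_celeba.txt into a DataFrame with an
    'image_id' column plus one 0/1 column per attribute.
    """
    lines = text.strip().splitlines()
    attr_names = lines[1].split()               # line 2: attribute names
    rows = [line.split() for line in lines[2:]]  # one row per image
    df = pd.DataFrame(rows, columns=["image_id"] + attr_names)
    # Map the file's -1/+1 encoding to 0/1
    df[attr_names] = (df[attr_names].astype(int) + 1) // 2
    return df

sample = """2
Eyeglasses Smiling
000001.jpg -1  1
000002.jpg  1 -1
"""
parsed = parse_celeba_attrs(sample)
```

On the real file, `parse_celeba_attrs(open("list_attr_celeba.txt").read())` should yield the same table as `dat`.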

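The resulting DataFrame `df` contains the path to each image, one 0/1 column per tag indicating whether that tag applies to the image, and a `unique_label` column identifying the image's combination of tags (used only for stratified data splitting).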
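Why `unique_label`? `StratifiedKFold` stratifies on a single class per sample, so the preprocessing cell encodes each image's whole tag combination as one integer. Rare combinations are awkward to stratify (scikit-learn warns when a class has fewer members than `n_splits`), which is why combinations with 10 or fewer images are dropped. A pandas sketch of the same two steps, using the hypothetical helper `add_combination_label`:

```python
import pandas as pd

def add_combination_label(df, tag_cols, min_count=11):
    """
    Give every distinct combination of 0/1 tag values an integer ID ('unique_label'),
    then keep only rows whose combination occurs at least min_count times
    (min_count=11 keeps combinations with more than 10 images).
    """
    out = df.copy()
    # One group per distinct tag combination, numbered in order of first appearance
    out["unique_label"] = out.groupby(tag_cols, sort=False).ngroup()
    # Number of images sharing each row's combination
    counts = out["unique_label"].map(out["unique_label"].value_counts())
    return out[counts >= min_count]

toy = pd.DataFrame({"Smiling": [0] * 12 + [1] * 3, "No_Beard": [1] * 15})
filtered = add_combination_label(toy, ["Smiling", "No_Beard"])
```

In `toy`, the combination `(Smiling=1, No_Beard=1)` appears only 3 times, so those rows are dropped.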

We define a general class of PyTorch neural networks for multi-label image classification. You can use this class to train models on *any* image dataset, and it can utilize *any* pretrained [TIMM](https://timm.fast.ai/) network backbone. The output layer, with one unit per class, is created when the backbone is built with TIMM's `create_model` (via its `num_classes` argument); `MultiLabelModel` wraps this backbone, applies a sigmoid to each output, and fine-tunes the entire network jointly on your multi-label classification dataset.
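To see why the class below uses sigmoid outputs with binary cross-entropy rather than softmax, this NumPy sketch (with hypothetical logits for three tags) shows that sigmoid scores each class independently, so several tags can exceed the 0.5 threshold at once, while softmax makes the classes compete and forces the scores to sum to 1. `binary_cross_entropy` mirrors the per-element loss that `nn.BCELoss` averages under its default `'mean'` reduction, except that this sketch clips probabilities rather than clamping the log terms as PyTorch does:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

def binary_cross_entropy(probs, targets, eps=1e-7):
    """Mean over all entries of -[y*log(p) + (1 - y)*log(1 - p)]."""
    p = np.clip(probs, eps, 1 - eps)
    return float(np.mean(-(targets * np.log(p) + (1 - targets) * np.log(1 - p))))

logits = np.array([2.0, 1.5, -3.0])  # hypothetical raw network outputs for 3 tags
probs = sigmoid(logits)              # each tag scored independently
pred_tags = probs > 0.5              # same 0.5 threshold used in validate()
```

If you adapt the class, PyTorch's `nn.BCEWithLogitsLoss` applied to the raw logits combines the sigmoid and the loss in a more numerically stable way; a 0.5 threshold on probabilities then corresponds to a threshold of 0 on the logits.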

In [5]:
class MultiLabelModel(nn.Module):
    """ 
    Pytorch network for multi-label classification that can utilize any TIMM network backbone.
    Some of this code is inspired by: https://github.com/yang-ruixin/PyTorch-Image-Models-Multi-Label-Classification
    Note this network uses Sigmoid output activations because the predicted probabilities do not need to sum to 1
    for multi-label classification, in which each image may belong to multiple classes rather than only one.
    """
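    def __init__(self, model, n_classes, class_weights=None, verbose = False):
        super().__init__()
        self.base_model = model  # network backbone can be any TIMM model
        self.num_classes = n_classes
        self.verbose = verbose  # if True, print the loss every 80 training batches
        # note: class_weights is accepted but currently unused

    def forward(self, x):
        # Returns raw scores (logits), one per class; the sigmoid is applied in training/prediction
        x = self.base_model(x)
        x = torch.flatten(x, 1)
        return x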

    def get_loss(self, loss_fn, output, target):

        return loss_fn(output, target)

    def validate(self, loader):
        self.eval();
        with torch.no_grad():
            total_loss = 0
            m = nn.Sigmoid()
            labels = []
            preds = []
            for batch_idx, (input, target) in enumerate(loader):
                input = input.cuda()
                labels.append(target.detach().cpu())
                target = target.float().cuda()
                output = m(self(input))
                loss = self.get_loss(loss_fn, output, target)

                total_loss += loss.item()
                pred_model = (output > 0.5).detach().cpu()
                preds.append(pred_model)

            num_of_batches_per_epoch = len(loader)
            avg_loss = total_loss / num_of_batches_per_epoch
            print("VALIDATION DATA STATS")

            print("AVERAGE LOSS:", avg_loss)
            preds = torch.cat(preds).int()
            labels = torch.cat(labels).int()
            acc_score = accuracy_score(labels, preds)
            print("MULTILABEL accuracy score:", acc_score)
            per_class = []
            for i in range(len(preds.T)):
                per_class.append(accuracy_score(labels.T[i], preds.T[i]))
            print(loader.dataset.label_names)
            print(per_class)
            print('\n\n')
        return avg_loss

    def predict_proba(self, loader):
        self.eval();
        with torch.no_grad():
            m = nn.Sigmoid()
            preds = []
            for batch_idx, (input, target) in enumerate(loader):
                input = input.cuda()
                output = m(self(input))
                pred_model = output.detach().cpu()
                preds.append(pred_model)
            preds = torch.cat(preds)
        return preds

    def train_one_epoch(
        self,
        loader,
        optimizer,
        loss_fn,
    ):
        sta = time.time()
        second_order = hasattr(optimizer, "is_second_order") and optimizer.is_second_order
        self.train()
        total_loss = 0
        m = nn.Sigmoid()
        labels = []
        preds = []
        ct = 0
        for batch_idx, (input, target) in enumerate(loader):
            input = input.cuda()
            ct += 1
            labels.append(target.detach().cpu())
            target = target.float().cuda()
            output = m(self(input))
            loss = self.get_loss(loss_fn, output, target)
            total_loss += loss.item()
            pred_model = (output > 0.5).detach().cpu()
            preds.append(pred_model)
            optimizer.zero_grad()
            loss.backward(create_graph=second_order)
            optimizer.step()
            if ct % 80 == 0 and self.verbose:
                print("LOSS:", loss.item())
        num_of_batches_per_epoch = len(loader)
        avg_loss = total_loss / num_of_batches_per_epoch
        print("TRAINING DATA STATS")
        print("AVERAGE LOSS:", avg_loss)
        preds = torch.cat(preds).int()
        labels = torch.cat(labels).int()
        acc_score = accuracy_score(labels, preds)
        print("MULTILABEL accuracy score:", acc_score)
        per_class = []
        for i in range(len(preds.T)):
            per_class.append(accuracy_score(labels.T[i], preds.T[i]))
        print(loader.dataset.label_names)
        print(per_class)
        print('\n\n')
        sto = time.time()
        print("training time", sto - sta)
        return avg_loss
    

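    def fit(self, loader_train, loader_val, num_epochs=10):
        # Start from an empty checkpoint directory
        if os.path.exists("weights_model"):
            print("removing weights directory")
            os.system('rm -rf weights_model')
        os.mkdir("weights_model")
        # Optimizer settings read by timm's create_optimizer
        args = SimpleNamespace()
        args.weight_decay = 0
        args.lr = 1e-4
        args.opt = 'adam'
        args.momentum = 0.9

        optimizer = create_optimizer(args, self)
        # Lower validation loss is better, so decreasing=True makes weights_model/model_best.pth.tar
        # hold the epoch with the lowest validation loss (CheckpointSaver treats higher as better by default)
        saver = CheckpointSaver(
            model=self,
            optimizer=optimizer,
            checkpoint_dir="weights_model",
            decreasing=True,
        )
        errs = []  # [training loss, validation loss] for each epoch
        for epoch in range(num_epochs):
            # loss_fn is the global loss function defined when the model is created
            loss_train = self.train_one_epoch(
                loader_train,
                optimizer,
                loss_fn,
            )
            loss_val = self.validate(loader_val)
            errs.append([loss_train, loss_val])
            saver.save_checkpoint(epoch, metric=loss_val)
        return errs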

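We also define a PyTorch `Dataset` class that pairs each image with its multi-hot label vector, for training `MultiLabelModel` on multi-label image classification datasets. You can easily apply it to your own datasets: it expects a DataFrame (or a CSV file) laid out like `df` above, with an `image_path` column, one 0/1 column per class, and a final `unique_label` column.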
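`DatasetMultiLabel` builds its label lists row by row with `iterrows`, which can be slow on large DataFrames. A vectorized sketch of the same extraction follows; `labels_from_df` is a hypothetical helper that selects label columns by name (assuming the `image_path` and `unique_label` column names used in this notebook) rather than by position:

```python
import numpy as np
import pandas as pd

def labels_from_df(df, path_col="image_path", ignore_cols=("unique_label",)):
    """
    Return (image_paths, label_names, label_matrix) from a DataFrame with one path
    column, one 0/1 column per tag, and optional helper columns to ignore.
    """
    label_names = [c for c in df.columns if c != path_col and c not in ignore_cols]
    paths = df[path_col].tolist()
    labels = df[label_names].to_numpy(dtype=np.float32)  # shape (n_images, n_tags)
    return paths, label_names, labels

toy = pd.DataFrame({
    "image_path": ["a.jpg", "b.jpg"],
    "Smiling": [1, 0],
    "No_Beard": [0, 1],
    "unique_label": [0, 1],
})
paths, names, y = labels_from_df(toy)
```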

In [6]:
class DatasetMultiLabel(data.Dataset):
    def __init__(
            self,
            annotation_path=None,
            df = None,
            transform=None):

        super().__init__()
        self.transform = transform
        self.data = []
        self.labels = []
        self.label_names = []
        if annotation_path is None:
            assert df is not None
        else:
            df = pd.read_csv(annotation_path)
        
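        cols = df.columns
        # Every column except the image path and the stratification helper is a class (tag)
        self.label_names = [c for c in cols if c not in ('image_path', 'unique_label')]
        for i, row in df.iterrows():
            self.data.append(row['image_path'])
            # Multi-hot label vector: 1.0 if the tag applies to this image, else 0.0
            self.labels.append([float(row[c]) for c in self.label_names])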
                
    def __getitem__(self, idx):
        img_path = self.data[idx]

        img = Image.open(img_path).convert('RGB')  # ensure 3 channels for the backbone
        if self.transform:
            img = self.transform(img)

        labels = self.labels[idx]

        return img, labels

    def __len__(self):
        return len(self.data)

In [7]:
def reset_weights(m):
    '''
    Re-initializes model weights, e.g. between cross-validation folds.
    Note: layers are re-initialized randomly, so any pretrained weights are lost.
    '''
    for layer in m.children():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()

def create_df(pred_probs, dataset):
    """
    Returns a DataFrame with the image_path and the predicted probability of each class for each image.
    """
    ls = dataset.label_names
    cl = defaultdict(list)
    cl['image_path'] = dataset.data
    for i in range(len(ls)):
        cl[ls[i]] = pred_probs.T[i].tolist()
    return pd.DataFrame.from_dict(cl)

Let's create a `DatasetMultiLabel` and a `MultiLabelModel` for the CelebA dataset. Here we use the [efficientnet_b0](https://rwightman.github.io/pytorch-image-models/models/tf-efficientnet/) backbone for our neural network, but you can easily swap in any other TIMM backbone.

In [8]:
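dataset = DatasetMultiLabel(df=df)

# Sigmoid outputs + binary cross-entropy: each tag is an independent yes/no prediction
loss_fn = nn.BCELoss()

model = create_model(
    'efficientnet_b0',  # you can replace this with any TIMM backbone
    pretrained=True,  # start from the backbone's pretrained weights
    num_classes=len(dataset.labels[0]),  # output layer with one unit per tag
)
# Input size, normalization mean/std and interpolation expected by this backbone
data_config = resolve_data_config(args={}, model=model)

model = MultiLabelModel(
    model,
    n_classes=len(dataset.labels[0]),
).cuda()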

We train this network using K-fold cross-validation. This gives us **out-of-sample** predictions for every image in the dataset (i.e. predictions from a copy of the model that never saw this image during training). Out-of-sample predictions are better suited for finding label issues, because a model that was trained on an image can simply memorize (overfit to) its given label, including a wrong one. From each fold of cross-validation, we store the predicted class probabilities for the held-out images in a DataFrame `df_pred` and save it to a per-fold CSV file. These predictions are subsequently used for finding label issues in cleanlab's tutorial on [Multi-Label Classification](https://docs.cleanlab.ai/).
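The bookkeeping behind out-of-sample predictions can be sketched without a neural network. In the hypothetical helper below, the per-fold "model" is just the frequency of each tag in the training fold, standing in for the network trained on that fold; what matters is that each row's prediction comes from the one fold in which it was held out, and that writing through `test_idx` keeps the result aligned with the original row order:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

def out_of_sample_pred_probs(labels, strata, n_splits=4):
    """
    labels: (n, K) 0/1 array; strata: one class ID per row used for stratification.
    Returns an (n, K) array where row i was predicted by a 'model' that never saw row i.
    """
    n, k = labels.shape
    pred_probs = np.full((n, k), np.nan)
    skf = StratifiedKFold(n_splits=n_splits)
    # np.zeros(n) is a placeholder for X: the splits depend only on the strata
    for train_idx, test_idx in skf.split(np.zeros(n), strata):
        tag_freq = labels[train_idx].mean(axis=0)  # "fit" on the training fold only
        pred_probs[test_idx] = tag_freq            # predict only for held-out rows
    return pred_probs

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=(40, 3))
strata = np.repeat([0, 1], 20)
p = out_of_sample_pred_probs(y, strata)
```

The training loop below instead writes each fold's predictions to its own CSV file, so after concatenation the rows are ordered by fold rather than by the original `df`; that is why `pred_probs.csv` is indexed by image file name.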

In [9]:
num_splits = 4  # number of cross-validation splits (higher values will take longer but give better results)
skf = StratifiedKFold(n_splits=num_splits)

In [None]:
import copy

# Keep a copy of the initial (pretrained) weights so that every fold starts from the same point
initial_state = copy.deepcopy(model.state_dict())

ct = 1
for train_index, test_index in skf.split(df, df['unique_label']):
    if ct != 1:
        model.load_state_dict(initial_state)
    dataset_train = DatasetMultiLabel(df=df.iloc[train_index])
    dataset_val = DatasetMultiLabel(df=df.iloc[test_index])
    loader_train = create_loader(
        dataset_train,
        input_size=data_config["input_size"],
        batch_size=64,
        is_training=True,
        mean=data_config["mean"],
        std=data_config["std"],
        interpolation=data_config["interpolation"],
    )
    loader_val = create_loader(
        dataset_val,
        input_size=data_config["input_size"],
        batch_size=64,
        is_training=False,
        mean=data_config["mean"],
        std=data_config["std"],
        interpolation=data_config["interpolation"],
    )
    model.fit(loader_train, loader_val, num_epochs=40)
    # Reload the best checkpoint saved during training before predicting on the held-out fold
    checkpoint = torch.load("weights_model/model_best.pth.tar")
    model.load_state_dict(checkpoint['state_dict'])
    pred_val = model.predict_proba(loader_val)
    df_pred = create_df(pred_val, dataset_val)
    df_pred.to_csv(f"{ct}_fold.csv", index=False)
    ct += 1

In [11]:
# Combine the per-fold predictions; each image appears in exactly one fold
dfl = []
for i in range(1, num_splits + 1):
    dfl.append(pd.read_csv(f"{i}_fold.csv"))

df_pred = pd.concat(dfl, axis=0)
# Index predictions by image file name (e.g. 000001.jpg) rather than the full local path
df_pred['image_path'] = df_pred['image_path'].map(os.path.basename)
df_pred.set_index('image_path').to_csv("pred_probs.csv")
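To our understanding, cleanlab's multi-label utilities (for example `cleanlab.filter.find_label_issues(..., multi_label=True)`) expect the given labels as a list of lists of class indices, together with a `pred_probs` array whose rows are in the same order as those labels; check the cleanlab documentation for your version. A sketch of aligning `pred_probs.csv` with the label DataFrame by image file name (`align_for_cleanlab` is a hypothetical helper):

```python
import os
import numpy as np
import pandas as pd

def align_for_cleanlab(pred_df, label_df, tag_cols):
    """
    pred_df: indexed by image file name, one probability column per tag (like pred_probs.csv).
    label_df: has an 'image_path' column plus the 0/1 tag columns.
    Returns (labels, pred_probs): labels as a list of lists of tag indices,
    pred_probs as an (n, K) array with rows in the same order as label_df.
    """
    file_names = label_df["image_path"].map(os.path.basename).tolist()
    pred_probs = pred_df.loc[file_names, tag_cols].to_numpy()
    multi_hot = label_df[tag_cols].to_numpy()
    labels = [np.flatnonzero(row).tolist() for row in multi_hot]
    return labels, pred_probs

pred_df = pd.DataFrame({"Smiling": [0.9, 0.2], "No_Beard": [0.1, 0.8]}, index=["b.jpg", "a.jpg"])
label_df = pd.DataFrame({"image_path": ["/x/a.jpg", "/x/b.jpg"], "Smiling": [0, 1], "No_Beard": [1, 1]})
labels, probs = align_for_cleanlab(pred_df, label_df, ["Smiling", "No_Beard"])
```

With this notebook's outputs, the call would look like `align_for_cleanlab(pd.read_csv("pred_probs.csv", index_col="image_path"), df, dataset.label_names)`.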