In [1]:
import timm
from torch import nn
from timm.data.loader import create_loader
import torch.utils.data as data

import pandas as pd

from sklearn.metrics import balanced_accuracy_score


from imp import reload



import numpy as np
from sklearn.metrics import accuracy_score

In [36]:
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import accuracy_score

In [2]:

class MultiLabelModel(nn.Module):
    """Wrap a backbone so it can be trained as a multi-label classifier.

    The wrapper flattens the backbone output to ``(batch, features)`` and exposes a
    per-label weighted loss (``get_loss``).

    Args:
        model: backbone ``nn.Module``; its output is flattened from dim 1 onward.
        n_classes: number of labels. Every label starts with weight ``1 / n_classes``.
    """

    def __init__(self, model, n_classes):
        super().__init__()
        self.base_model = model
        # Equal per-label weights; get_loss sums WEIGHTS[i] * per-label loss.
        self.WEIGHTS = [1/n_classes]*n_classes
        self.num_classes = n_classes

    def forward(self, x):
        """Run the backbone and flatten its output to ``(batch, features)``."""
        return torch.flatten(self.base_model(x), 1)

    def get_loss(self, loss_fn, output, target):
        """Return ``sum_i WEIGHTS[i] * loss_fn(output[:, i], target[:, i])``.

        ``output`` and ``target`` are indexed as ``(batch, num_classes)``; ``loss_fn``
        is called once per label column.
        """
        losses = 0.0
        for i in range(self.num_classes):
            # Column indexing instead of `.T[i]`: identical for 2-D tensors and avoids the
            # deprecated `.T` on tensors of other ranks.
            losses += self.WEIGHTS[i] * loss_fn(output[:, i], target[:, i])
        return losses

    def as_sequential_for_ML(self):
        """Return the backbone's feature layers as a single ``nn.Sequential``.

        The layers are read from the wrapped backbone (``self.base_model``), not from the
        wrapper itself, which has no such attributes. Expected attributes:
        ``conv_stem``, ``bn1``, ``act1``, ``blocks``, ``conv_head``, ``bn2``, ``act2``
        (presumably the timm EfficientNet layout created in this notebook — confirm for
        other backbones). The classifier head is not included.

        Raises:
            AttributeError: if the backbone lacks any of these attributes.
        """
        base = self.base_model
        layers = [base.conv_stem, base.bn1, base.act1]
        layers.extend(base.blocks)
        layers.extend([base.conv_head, base.bn2, base.act2])
        return nn.Sequential(*layers)
    
from PIL import Image


In [3]:
class DatasetMultiLabel(data.Dataset):
    """Multi-label image dataset backed by a CSV annotation file.

    The CSV must contain an ``image_path`` column; every other column is a label and
    is converted to ``float``. ``label_names`` lists exactly those label columns, in
    the same order as the values in each ``labels`` row.

    Args:
        annotation_path: path to the CSV file.
        transform: optional callable applied to each PIL image in ``__getitem__``.

    Raises:
        KeyError: if the CSV has no ``image_path`` column.
    """

    def __init__(
            self,
            annotation_path,
            transform=None):

        super().__init__()

        self.transform = transform

        df = pd.read_csv(annotation_path)
        # Label names come from the columns actually used as labels, so they stay
        # aligned with self.labels even when image_path is not the first column.
        self.label_names = [col for col in df.columns if col != 'image_path']
        # Paths to the images, one per CSV row.
        self.data = df['image_path'].tolist()
        # One list of floats per row (plain lists, as before — NOTE(review): the loader's
        # collate function receives these lists; keep the type unless collate is changed).
        self.labels = df[self.label_names].to_numpy(dtype=float).tolist()

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for sample ``idx``.

        The image is loaded and converted to RGB inside a ``with`` block so the file
        handle is released immediately; ``labels`` is the row's list of floats.
        """
        with Image.open(self.data[idx]) as raw:
            img = raw.convert('RGB')

        # apply the image augmentations if needed
        if self.transform:
            img = self.transform(img)

        return img, self.labels[idx]

    def __len__(self):
        """Number of annotated images."""
        return len(self.data)

In [4]:
# Build the train and validation datasets from their annotation CSVs (no transform passed here).
dataset_train, dataset_val = (
    DatasetMultiLabel("img_align_celeba/train.csv"),
    DatasetMultiLabel("img_align_celeba/val.csv"),
)

In [5]:
from timm.data import (
    resolve_data_config,
)
from timm.models import create_model

In [6]:
# Show the label column names read from the training CSV.
list(dataset_train.label_names)

['Eyeglasses',
 'Wearing_Earrings',
 'Wearing_Hat',
 'Wearing_Necklace',
 'Wearing_Necktie']

In [7]:
# EfficientNet-B0 backbone with one output per label column of the dataset.
model = create_model('efficientnet_b0', num_classes=len(dataset_val.labels[0]))

In [8]:
# Wrap the backbone for multi-label training and move it to the GPU.
model = MultiLabelModel(model, n_classes=len(dataset_val.labels[0])).cuda()

In [9]:
from timm.optim import create_optimizer
import pickle

In [10]:
# NOTE(review): `args` is replaced by a fresh SimpleNamespace in the next cell, so this load has
# no lasting effect — consider removing it. pickle.load can execute arbitrary code: only load
# files you trust. The `with` block closes the file handle the original left open.
with open("PyTorch-Image-Models-Multi-Label-Classification/args.p", 'rb') as args_file:
    args = pickle.load(args_file)

In [11]:
from types import SimpleNamespace

# Optimizer hyper-parameters passed to create_optimizer.
args = SimpleNamespace(
    weight_decay=0,
    lr=1e-4,
    opt='adam',  # 'lookahead_adam' to use `lookahead`
    momentum=0.9,
)


In [12]:
# Build the optimizer described by `args` (opt, lr, weight_decay, momentum) for the wrapped model.
optimizer = create_optimizer(args, model)

In [13]:
# Preprocessing settings (input_size, mean, std, interpolation) used by both data loaders.
data_config = resolve_data_config(args={}, model=model)

In [14]:
# Training loader (is_training=True), batches of 64, preprocessing taken from data_config.
loader_train = create_loader(
    dataset_train,
    batch_size=64,
    is_training=True,
    input_size=data_config["input_size"],
    interpolation=data_config["interpolation"],
    mean=data_config["mean"],
    std=data_config["std"],
)

In [15]:
from timm.scheduler import create_scheduler


In [16]:
args.epochs = 1_000  # epoch count handed to create_scheduler through `args`

In [17]:
# Step LR schedule settings, read from `args` by create_scheduler.
vars(args).update(
    sched="step",
    decay_epochs=30,
    decay_rate=0.1,
    warmup_lr=0.0001,
    min_lr=1e-5,
    warmup_epochs=3,
)

In [18]:
# create_scheduler returns the scheduler together with the number of epochs to run.
lr_scheduler, num_epochs = create_scheduler(args, optimizer)

In [19]:
# Validation loader (is_training=False), same preprocessing config and batch size as training.
loader_val = create_loader(
    dataset_val,
    batch_size=64,
    is_training=False,
    input_size=data_config["input_size"],
    interpolation=data_config["interpolation"],
    mean=data_config["mean"],
    std=data_config["std"],
)

In [20]:
import torch    

In [21]:
model.train(False);  # equivalent to model.eval(); trailing ';' hides the module repr

In [22]:
# Binary cross-entropy on probabilities: model outputs are passed through a sigmoid before the loss.
loss_fn = nn.BCELoss(reduction="mean")


In [23]:
def validate(loader, model, loss_fn=None, device="cuda", label_names=None):
    """Evaluate ``model`` on ``loader`` and print loss and accuracy metrics.

    The model is put in eval mode for the pass (dropout off, BatchNorm running
    statistics used and left untouched) and restored to its previous mode afterwards.
    A sigmoid is applied to the model output before the loss. Predictions are
    ``probability > 0.5``.

    Args:
        loader: iterable of ``(input, target)`` batches. ``target`` is assumed to be
            a ``(batch, num_labels)`` tensor of 0/1 labels.
        model: module exposing ``get_loss(loss_fn, output, target)``.
        loss_fn: loss on probabilities. Defaults to ``nn.BCELoss()``, the loss this
            notebook trains with.
        device: device the batches are moved to. The default ``"cuda"`` keeps the
            original behaviour.
        label_names: names printed above the per-label accuracies. Defaults to
            ``loader.dataset.label_names`` when that attribute exists.

    Returns:
        The mean per-batch loss as a float.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if loss_fn is None:
        loss_fn = nn.BCELoss()
    if label_names is None:
        label_names = getattr(getattr(loader, "dataset", None), "label_names", None)

    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    labels, preds = [], []
    try:
        with torch.no_grad():
            for images, target in loader:
                images = images.to(device)
                labels.append(target.detach().cpu())
                target = target.float().to(device)
                output = torch.sigmoid(model(images))
                loss = model.get_loss(loss_fn, output, target)
                total_loss += loss.item()
                num_batches += 1
                preds.append((output > 0.5).cpu())
    finally:
        # Restore the caller's mode (e.g. training) even if evaluation fails.
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("validate() received a loader with no batches")

    avg_loss = total_loss / num_batches
    print("AVERAGE LOSS:", avg_loss)
    preds = torch.cat(preds).int()
    labels = torch.cat(labels).int()
    correct = preds == labels
    # Subset accuracy: a sample counts only if every one of its labels is predicted correctly.
    acc_score = correct.all(dim=1).double().mean().item()
    print("MULTILABEL accuracy score:", acc_score)
    # Per-label accuracy: fraction of samples whose label i is predicted correctly.
    per_class = correct.double().mean(dim=0).tolist()
    print(label_names)
    print(per_class)
    return avg_loss

In [37]:
validate(loader=loader_val, model=model)

torch.Size([10, 5]) torch.Size([10, 5])
AVERAGE LOSS: 0.6931521567431363
MULTILABEL accuracy score: 0.025256135334762925
['Eyeglasses', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Necklace', 'Wearing_Necktie']
[0.38634739099356685, 0.5041696449845128, 0.504884441267572, 0.5651655944722421, 0.32487491065046464]


0.6931521567431363

In [38]:
def train_one_epoch(
    num_of_data_train,
    epoch,
    model,
    loader,
    optimizer,
    loss_fn,
    lr_scheduler=None,
    saver=None,
    output_dir="",
    device="cuda",
    label_names=None,
    log_interval=10,
):
    """Train ``model`` for one epoch over ``loader`` and print loss and accuracy metrics.

    A sigmoid is applied to the model output before ``model.get_loss``. Predictions
    are ``probability > 0.5``. If ``lr_scheduler`` is given, it is stepped once at the
    end of the epoch with ``epoch + 1``, so the next epoch runs at its scheduled LR.

    Args:
        num_of_data_train, saver, output_dir: accepted for compatibility; not used.
        epoch: zero-based index of this epoch.
        model: module exposing ``get_loss(loss_fn, output, target)``.
        loader: iterable of ``(input, target)`` batches. ``target`` is assumed to be
            ``(batch, num_labels)`` 0/1 labels.
        optimizer: optimizer over ``model``'s parameters. If it has a truthy
            ``is_second_order`` attribute, the backward graph is kept.
        loss_fn: loss on probabilities, passed to ``model.get_loss``.
        lr_scheduler: optional object with a ``step(epoch)`` method.
        device: device the batches are moved to. The default ``"cuda"`` keeps the
            original behaviour.
        label_names: names printed above the per-label accuracies. Defaults to
            ``loader.dataset.label_names`` when that attribute exists.
        log_interval: print the batch loss every this many batches (positive int).

    Returns:
        The mean per-batch training loss as a float.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    second_order = getattr(optimizer, "is_second_order", False)
    if label_names is None:
        label_names = getattr(getattr(loader, "dataset", None), "label_names", None)

    model.train()
    total_loss = 0.0
    num_batches = 0
    labels, preds = [], []
    for batch_idx, (images, target) in enumerate(loader):
        images = images.to(device)
        labels.append(target.detach().cpu())
        target = target.float().to(device)
        output = torch.sigmoid(model(images))
        loss = model.get_loss(loss_fn, output, target)
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        preds.append((output > 0.5).detach().cpu())

        optimizer.zero_grad()
        loss.backward(create_graph=second_order)
        optimizer.step()

        if (batch_idx + 1) % log_interval == 0:
            print("LOSS:", loss_value)

    if num_batches == 0:
        raise ValueError("train_one_epoch() received a loader with no batches")

    avg_loss = total_loss / num_batches
    print("AVERAGE LOSS:", avg_loss)
    preds = torch.cat(preds).int()
    labels = torch.cat(labels).int()
    correct = preds == labels
    # Subset accuracy: a sample counts only if every one of its labels is predicted correctly.
    acc_score = correct.all(dim=1).double().mean().item()
    print("MULTILABEL accuracy score:", acc_score)
    # Per-label accuracy: fraction of samples whose label i is predicted correctly.
    per_class = correct.double().mean(dim=0).tolist()
    print(label_names)
    print(per_class)

    if lr_scheduler is not None:
        # Previously the scheduler was accepted but never stepped, so the LR never changed.
        lr_scheduler.step(epoch + 1)
    return avg_loss

In [39]:
from timm.utils import CheckpointSaver


In [40]:
# The metric passed to save_checkpoint is validate()'s average LOSS, so lower is better.
# CheckpointSaver treats higher metrics as better unless decreasing=True; without it the
# saver keeps the worst checkpoints.
saver = CheckpointSaver(
    model=model,
    optimizer=optimizer,
    checkpoint_dir="weights",
    decreasing=True,
)

In [41]:
eval_metrics = validate(loader=loader_val, model=model)  # baseline validation loss before training

torch.Size([10, 5]) torch.Size([10, 5])
AVERAGE LOSS: 0.6931521567431363
MULTILABEL accuracy score: 0.025256135334762925
['Eyeglasses', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Necklace', 'Wearing_Necktie']
[0.38634739099356685, 0.5041696449845128, 0.504884441267572, 0.5651655944722421, 0.32487491065046464]


In [42]:
# Delete the contents of weights/ (the CheckpointSaver directory) before saving new checkpoints.
!rm -rf weights/*

In [43]:
# Save a baseline checkpoint under epoch index 1, ranked by the pre-training validation loss.
# NOTE(review): the training loop also saves epoch 1 — confirm the checkpoint filenames do not collide.
save_metric = eval_metrics
best_metric, best_epoch = saver.save_checkpoint(1, metric=save_metric)

In [44]:
# Number of training samples (passed to train_one_epoch).
num_of_data_train = len(dataset_train)

In [46]:
# Main loop: train one epoch, validate, then checkpoint on validation loss.
# Runs for the `num_epochs` returned by create_scheduler (derived from args.epochs)
# instead of a hard-coded 1000, so the loop and the LR schedule stay in sync.
for epoch in range(num_epochs):
    loss_train = train_one_epoch(
        num_of_data_train,
        epoch,
        model,
        loader_train,
        optimizer,
        loss_fn,
        lr_scheduler=lr_scheduler,
    )
    # train_one_epoch leaves the model in train mode; evaluate with dropout off and
    # without updating BatchNorm running statistics on validation data.
    model.eval()
    loss_val = validate(loader_val, model)
    saver.save_checkpoint(epoch, metric=loss_val)


RuntimeError: DataLoader worker (pid(s) 25184) exited unexpectedly

In [None]:
validate(loader=loader_val, model=model)