### Training ResNet
In this notebook we train a simple resnet50 to classify celebrity sex.

In [1]:
import sklearn
import torch
import pandas as pd
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset
import warnings
import os
from os.path import isfile, join
import numpy as np
from pathlib import Path
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


#### Extracting label information

In [3]:
myfile = open('data/original/CelebAMask-HQ/CelebAMask-HQ-attribute-anno.txt', 'r')

mytext = myfile.readlines()[1:]
csv_file = open('custom_dataset.csv', 'w')
csv_file.write("".join(mytext))

3622212

In [4]:
all_imgs = list(sorted(['data/original/CelebAMask-HQ/CelebA-HQ-img/'+filename.split('_')[0] for filename in os.listdir('data/original/CelebAMask-HQ/CelebA-HQ-img/') if isfile(join('data/original/CelebAMask-HQ/CelebA-HQ-img/', filename))]))

In [5]:
dataset = pd.read_csv('custom_dataset.csv', delimiter = ' ')
dataset.index = list(map(lambda x: x[0], dataset.index))

In [5]:
def to_prob_dist(labels: torch.tensor, size = 2):
	arr = np.full((len(labels), size), 0)
	for i, label in enumerate(labels):
		arr[i][int(label.item())] = 1
	return torch.tensor(arr, dtype=torch.float32)

class CelebDataset(Dataset):
	def __init__(self, csv_file, root_dir, transform=None):
		"""
		Arguments:
			csv_file (string): Path to the csv file with annotations.
			root_dir (string): Directory with all the images.
			transform (callable, optional): Optional transform to be applied
				on a sample.
		"""
		self.label_frame = pd.read_csv(csv_file, delimiter = ' ')
		# self.label_frame.index = list(map(lambda x: x[0], self.label_frame.index))
		self.labels = self.label_frame.loc[
			list(map(lambda x: x.split('/')[-1],
					 all_imgs
					))
		]['Male']

		self.labels = torch.tensor((self.labels.values == 1).astype(np.uint8))
		self.imgs = list(map(lambda x: x[0],
			self.label_frame.loc[list(sorted(map(lambda x: x.split('/')[-1],all_imgs)))].index.values
							))
		self.root_dir = root_dir
		self.transform = transform

	def __len__(self):
		return len(self.label_frame)

	def __getitem__(self, idx):
		if torch.is_tensor(idx):
			idx = idx.tolist()

		img_name = os.path.join(self.root_dir,
								self.imgs[idx])
		
		image = Image.open(img_name)
		label = torch.tensor([self.labels[idx]])

		if self.transform:
			image = self.transform(image)

		return image, label

In [6]:
celeb_dataset = CelebDataset(
    csv_file = 'custom_dataset.csv',
    root_dir = 'data/original/CelebAMask-HQ/CelebA-HQ-img',
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                          std=[0.229, 0.224, 0.225])
    ])
)

train_set, val_set, test_set = torch.utils.data.random_split(celeb_dataset, [20000, 5000, 5000])

In [7]:
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import pickle
from sklearn.metrics import classification_report

def accurate_count(pred, true):
    return (torch.round(pred).int() == true).sum()

class CustomEarlyStopping():
	def __init__(self, patience, min_loss_delta = 0, min_acc_delta = 0):
		self.patience = patience
		self.best_loss = 1e9
		self.best_acc = 0
		self.patience_count = 0
		self.count = 0
		self.stop = False

		self.min_loss_delta = min_loss_delta
		self.min_acc_delta = min_acc_delta

	def __call__(self, loss, accuracy):
		self.save_state = False

		if self.best_loss - loss > self.min_loss_delta or accuracy - self.best_acc > self.min_acc_delta:
			if self.best_loss - loss > self.min_loss_delta:
				self.best_loss = loss
			if accuracy - self.best_acc > self.min_acc_delta:
				self.best_acc = accuracy

			self.count = 0
		else:
			self.count += 1
		
		if self.count >= self.patience:
			self.stop = True
            
def train_model(model_name, model, n_epochs, optimizer, criterion, batch_size, early_stopping, train_set, val_set, device):
    
    model = model.to(device)
    
    train_loader = DataLoader(train_set, batch_size = batch_size, shuffle = True, num_workers = 4)
    val_loader = DataLoader(val_set, batch_size = batch_size, num_workers = 4)

    best_val_loss = 1e9
    best_val_acc = 0

    for epoch in range(1, n_epochs + 1):
        train_loss = 0.0
        train_acc_count = 0
        total_sample = 0

        model.train()
        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device).float()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            train_loss += loss.item()
            train_acc_count += accurate_count(outputs, labels)
            total_sample += len(labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = train_loss / len(train_loader)
        train_acc = train_acc_count / total_sample

        val_loss = 0.0
        val_acc_count = 0
        total_sample = 0

        model.eval()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device).float()

                outputs = model(inputs)

                loss = criterion(outputs, labels)

                val_loss += loss.item()
                val_acc_count += accurate_count(outputs, labels)
                total_sample += len(labels)


        val_loss = val_loss / len(val_loader)

        val_acc = val_acc_count / total_sample

        print(f"Epoch: {epoch} | Train Loss: {train_loss:.5f} \tVal Loss: {val_loss:.5f} \tTrain Acc: {train_acc:.3f} \tVal Acc: {val_acc:.3f}")

        done = epoch

        if val_loss < best_val_loss:
            best_val_loss = val_loss

            with open(f'{model_name}_best_loss_checkpoint.pkl', 'wb') as myfile:
                torch.save(model.state_dict(), myfile)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            with open(f'{model_name}_best_acc_checkpoint.pkl', 'wb') as myfile:
                torch.save(model.state_dict(), myfile)

        early_stopping(val_loss, val_acc)

        if early_stopping.stop:
            print("Stopping due to early stopping | patience =", early_stopping.patience)
            
def eval_model(model, test_set, batch_size, device):
    test_loader = DataLoader(test_set, batch_size = batch_size, num_workers = 4)

    model = model.to(device)
    model.eval()
    total_sample = 0
    test_acc_count = 0

    output = []
    input_list = []
    label = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device).float()

            outputs = model(inputs)

            input_list.append(inputs.detach().cpu().numpy())    
            label.append(labels.detach().cpu().numpy())
            output.append(outputs.detach().cpu().numpy())

            test_acc_count += accurate_count(outputs, labels)

            total_sample += len(labels)
    
    output = np.concatenate(output).flatten()
    output = np.round(output).astype(int)
    label = np.concatenate(label).flatten().astype(int)
    
    print(output.shape, label.shape)
    
    print(classification_report(
        y_true = label,
        y_pred = output,
        target_names = ['Female', 'Male']
    ))
            
    test_acc = test_acc_count / total_sample
    
    return test_acc
            
import math

def imshow(arr: list, label: list = None, figsize=None, shape = (32, 32, 3), is_int = None):
	if is_int == None:
		if type(arr[0]) == torch.Tensor:
			is_int = (arr[0].detach().cpu().numpy() > 1).sum() > 0
		else:
			is_int = (arr[0] > 1).sum() > 0
	if label == None:
		label = [''] * len(arr)

	height = int(len(arr) ** 0.5)
	width = math.ceil(len(arr) / height)

	if figsize == None:
		fig = plt.figure()
	else:
		fig = plt.figure(figsize=figsize)
	for i in range(height):
		for j in range(width):
			ax = fig.add_subplot(height, width, i * height + j + 1)
			ax.grid(False)
			ax.set_xticks([])
			ax.set_yticks([])
			show = arr[i * height + j]
			if type(arr[i * height + j]) != torch.Tensor:
				show = torch.Tensor(show)
			if len(show.squeeze(0).cpu().shape) == 2:
				ax.imshow((show.squeeze(0).detach().cpu()).type(torch.uint8 if is_int else torch.float), cmap='gray')
			else:
				ax.imshow((show.squeeze(0).detach().cpu().permute(1,2,0)).type(torch.uint8 if is_int else torch.float))
			ax.set_title(label[i * height + j])

In [11]:
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights = None)
model.fc = torch.nn.Sequential(
    torch.nn.Linear(in_features = 2048, out_features = 1, bias = True),
    torch.nn.Sigmoid()
)

In [12]:
train_model(
    model_name = 'RN50',
    model = model,
    n_epochs = 5,
    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3),
    criterion = torch.nn.BCELoss(),
    batch_size = 128,
    early_stopping = CustomEarlyStopping(patience = 5),
    train_set = train_set,
    val_set = val_set, 
    device = "cuda"
)

100%|██████████| 157/157 [01:46<00:00,  1.47it/s]


Epoch: 1 | Train Loss: 0.33473 	Val Loss: 0.18429 	Train Acc: 0.850 	Val Acc: 0.932


100%|██████████| 157/157 [01:46<00:00,  1.47it/s]


Epoch: 2 | Train Loss: 0.11841 	Val Loss: 0.15126 	Train Acc: 0.953 	Val Acc: 0.939


100%|██████████| 157/157 [01:46<00:00,  1.47it/s]


Epoch: 3 | Train Loss: 0.09742 	Val Loss: 0.10521 	Train Acc: 0.964 	Val Acc: 0.955


100%|██████████| 157/157 [01:47<00:00,  1.46it/s]


Epoch: 4 | Train Loss: 0.07352 	Val Loss: 0.12550 	Train Acc: 0.972 	Val Acc: 0.951


100%|██████████| 157/157 [01:46<00:00,  1.47it/s]


Epoch: 5 | Train Loss: 0.05879 	Val Loss: 0.09999 	Train Acc: 0.979 	Val Acc: 0.966


In [13]:
model = resnet50(weights = None)
model.fc = torch.nn.Sequential(
    torch.nn.Linear(in_features = 2048, out_features = 1, bias = True),
    torch.nn.Sigmoid()
)
model.load_state_dict(torch.load('RN50_best_acc_checkpoint.pkl'))
model.eval()
pass

In [14]:
eval_model(model, test_set, 128, device = "cuda:0")

(5000,) (5000,)
              precision    recall  f1-score   support

      Female       0.98      0.97      0.97      3146
        Male       0.95      0.96      0.95      1854

    accuracy                           0.97      5000
   macro avg       0.96      0.97      0.96      5000
weighted avg       0.97      0.97      0.97      5000



tensor(0.9660, device='cuda:0')

### Using pretrained weights

In [8]:
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights = ResNet50_Weights.DEFAULT)
model.fc = torch.nn.Sequential(
    torch.nn.Linear(in_features = 2048, out_features = 1, bias = True),
    torch.nn.Sigmoid()
)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:09<00:00, 11.2MB/s]


In [9]:
train_model(
    model_name = 'RN50_pretrained',
    model = model,
    n_epochs = 5,
    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3),
    criterion = torch.nn.BCELoss(),
    batch_size = 128,
    early_stopping = CustomEarlyStopping(patience = 5),
    train_set = train_set,
    val_set = val_set, 
    device = "cuda"
)

100%|██████████| 157/157 [01:46<00:00,  1.48it/s]


Epoch: 1 | Train Loss: 0.07698 	Val Loss: 0.07818 	Train Acc: 0.972 	Val Acc: 0.977


100%|██████████| 157/157 [01:46<00:00,  1.48it/s]


Epoch: 2 | Train Loss: 0.03939 	Val Loss: 0.16602 	Train Acc: 0.985 	Val Acc: 0.942


100%|██████████| 157/157 [01:47<00:00,  1.46it/s]


Epoch: 3 | Train Loss: 0.02559 	Val Loss: 0.07209 	Train Acc: 0.991 	Val Acc: 0.973


100%|██████████| 157/157 [01:47<00:00,  1.46it/s]


Epoch: 4 | Train Loss: 0.02230 	Val Loss: 0.18221 	Train Acc: 0.991 	Val Acc: 0.950


100%|██████████| 157/157 [01:46<00:00,  1.47it/s]


Epoch: 5 | Train Loss: 0.01613 	Val Loss: 0.07577 	Train Acc: 0.995 	Val Acc: 0.975


In [10]:
model = resnet50(weights = None)
model.fc = torch.nn.Sequential(
    torch.nn.Linear(in_features = 2048, out_features = 1, bias = True),
    torch.nn.Sigmoid()
)
model.load_state_dict(torch.load('RN50_pretrained_best_acc_checkpoint.pkl'))
model.eval()
pass

In [11]:
eval_model(model, test_set, 128, device = "cuda:0")

(5000,) (5000,)
              precision    recall  f1-score   support

      Female       0.96      1.00      0.98      3141
        Male       1.00      0.94      0.97      1859

    accuracy                           0.98      5000
   macro avg       0.98      0.97      0.97      5000
weighted avg       0.98      0.98      0.98      5000



tensor(0.9760, device='cuda:0')