In [1]:
# from model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
import pandas as pd 
import os 
from torch.optim import lr_scheduler

import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck

import warnings
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score

from sklearn.metrics import plot_confusion_matrix
import sklearn
import matplotlib.pyplot as plt 
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay
)

from pylab import rcParams
rcParams['figure.figsize'] = 13, 13


import tensorboardX
tb_writer = tensorboardX.SummaryWriter('storage/ResNetMultiV5')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


def validate(val_loader, model, batch=False):
    model.eval()
    correct_g = 0
    correct_v = 0
    correct_c = 0
    
    l = len(val_loader.dataset)
    
    for _, sample in enumerate(val_loader):
        if torch.cuda.is_available():
            data = sample['image'].cuda()
            target_grapheme = sample['labels']['grapheme_labels'].cuda()
            target_vowel = sample['labels']['vowel_labels'].cuda()
            target_consonant = sample['labels']['consonant_labels'].cuda()
        
        output = model(data)
        
        pred_g = output['grapheme'].max(1, keepdim=True)[1]
        correct_g += pred_g.eq(target_grapheme.view_as(pred_g)).cpu().sum()

        pred_v = output['vowel'].max(1, keepdim=True)[1]
        correct_v += pred_v.eq(target_vowel.view_as(pred_v)).cpu().sum()

        pred_c = output['consonant'].max(1, keepdim=True)[1]
        correct_c += pred_c.eq(target_consonant.view_as(pred_c)).cpu().sum()
        
        if batch: 
            acc_g = 100.0 * float(correct_g)/batch_size
            acc_v = 100.0 * float(correct_v)/batch_size
            acc_c = 100.0 * float(correct_c)/batch_size
            
            # Average recall
            overall = ((acc_g*2) + acc_v + acc_c)/4

            print(f'\nOn Val set Accuracy: grapheme root: {correct_g}/{batch_size} ({round(acc_g,2)}%) ' +
                  f'vowel diatric: {correct_v}/{l} ({round(acc_v,2)}%) '  + 
                  f'consonant diatric: {correct_c}/{l} ({round(acc_c,2)}%) ' +
                  f'overall: ({round(overall,2)}%) \n')

            return acc_g, acc_v, acc_c, overall

    acc_g = 100.0 * float(correct_g)/l
    acc_v = 100.0 * float(correct_v)/l
    acc_c = 100.0 * float(correct_c)/l
    overall = ((acc_g*2) + acc_v + acc_c)/4

    print(f'\nOn Val set Accuracy: grapheme root: {correct_g}/{l} ({round(100.0 * float(correct_g)/l,2)}%) ' +
          f'vowel diatric: {correct_v}/{l} ({round(float(correct_v)/l*100,2)}%) '  + 
          f'consonant diatric: {correct_c}/{l} ({round(float(correct_c)/l*100,2)}%) ' +
          f'overall: ({round(overall,2)}%) \n')
    
    return acc_g, acc_v, acc_c, overall

cuda:0


In [2]:
class MNISTResNet(ResNet):
    def __init__(self):
        super(MNISTResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=1292) # Based on ResNet18
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3,bias=False)

        self.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=512, out_features=168)
        )
        self.fc1 = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=512, out_features=11)
        )
        self.fc2 = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=512, out_features=7)
        )
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return {
            'grapheme': self.fc(x),
            'vowel': self.fc1(x),
            'consonant': self.fc2(x)
        }

    def get_loss(self, net_output, ground_truth):
        color_loss = F.cross_entropy(net_output['grapheme'], ground_truth['grapheme_labels'])
        gender_loss = F.cross_entropy(net_output['vowel'], ground_truth['vowel_labels'])
        article_loss = F.cross_entropy(net_output['consonant'], ground_truth['consonant_labels'])
        loss = color_loss + gender_loss + article_loss
        return loss, {'grapheme': color_loss, 'vowel': gender_loss, 'consonant': article_loss}



model = MNISTResNet()
model.to(device)

MNISTResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=Tru

model

In [3]:
class BengaliDatasetMultiClass(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        
        self.label_df = pd.read_csv(csv_file)
        self.label_df = self.label_df[['image_id','grapheme_root',
                             'vowel_diacritic','consonant_diacritic',
                             'label','grapheme','textlabel']]
        
        self.root_dir = root_dir 
        self.transform = transform 
    
    def __len__(self):
        return len(self.label_df)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        img_name = os.path.join(self.root_dir,
                                self.label_df.iloc[idx, 0] + '.png')
        image = Image.open(img_name).convert('L')

        label = tuple(self.label_df.iloc[idx, 1:])
        textlabel = self.label_df.iloc[idx, -1]  
        
        if self.transform:
            image = self.transform(image)
            
        sample = {
            "image": image,
            "labels": {
                "grapheme_labels": label[0],
                "vowel_labels": label[1],
                "consonant_labels": label[2],
            },
            "human_labels":{
                "typeface":label[4],
                "stringlabel":label[5]
            }
        }

        return sample

In [4]:
batch_size=128
num_workers = 16

transform=transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
])

trainset = BengaliDatasetMultiClass("data/train.csv","data/trainsplit", transform)
testset = BengaliDatasetMultiClass("data/test.csv","data/testsplit", transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                         shuffle=True, num_workers=num_workers)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=True, num_workers=num_workers)

In [5]:
def calculate_metrics(output, target):
    _, predicted_color = output['grapheme'].cpu().max(1)
    gt_color = target['grapheme_labels'].cpu()

    _, predicted_gender = output['vowel'].cpu().max(1)
    gt_gender = target['vowel_labels'].cpu()

    _, predicted_article = output['consonant'].cpu().max(1)
    gt_article = target['consonant_labels'].cpu()

    with warnings.catch_warnings():  # sklearn may produce a warning when processing zero row in confusion matrix
        warnings.simplefilter("ignore")
        accuracy_color = accuracy_score(y_true=gt_color.numpy(), y_pred=predicted_color.numpy())
        accuracy_gender = accuracy_score(y_true=gt_gender.numpy(), y_pred=predicted_gender.numpy())
        accuracy_article = accuracy_score(y_true=gt_article.numpy(), y_pred=predicted_article.numpy())

    return accuracy_color, accuracy_gender, accuracy_article

In [6]:
# from train.py

import argparse
import os
from datetime import datetime

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter


def get_cur_time():
    return datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')


def checkpoint_save(model, name, epoch):
    f = os.path.join(name, 'checkpoint-{:06d}.pth'.format(epoch))
    torch.save(model.state_dict(), f)
    print('Saved checkpoint:', f)

    return f
    

def train(start_epoch=1, N_epochs=20, log_interval=100):

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    
    step = 0
    
    for epoch in range(start_epoch, N_epochs + 1):
        total_loss = 0
        accuracy_g = 0
        accuracy_v = 0
        accuracy_c = 0

        for i, batch in enumerate(trainloader):
            optimizer.zero_grad()

            img = batch['image']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            loss_train, losses_train = model.get_loss(output, target_labels)
            total_loss += loss_train.item()
            batch_accuracy_g, batch_accuracy_v, batch_accuracy_c = \
                calculate_metrics(output, target_labels)

            accuracy_g += batch_accuracy_g
            accuracy_v += batch_accuracy_v
            accuracy_c += batch_accuracy_c

            loss_train.backward()
            optimizer.step()
            
            
            if (i+1)%log_interval == 0: 
                print("epoch {:4d}, loss: {:.4f}, grapheme: {:.4f}, vowel: {:.4f}, consonant: {:.4f}".format(
                    epoch,
                    total_loss / (log_interval*batch_size),
                    accuracy_g / (log_interval*batch_size),
                    accuracy_v / (log_interval*batch_size),
                    accuracy_c / (log_interval*batch_size)))
                
                print(step)
                acc_g, acc_v, acc_c, overall = validate(testloader, model, batch=True)
                tb_writer.add_scalar("loss", loss_train, step)
                tb_writer.add_scalar("grapheme-root-accuracy", acc_g, step)
                tb_writer.add_scalar("vowel-diatric-accuracy", acc_v, step)
                tb_writer.add_scalar("consonant_diatric-accuracy", acc_c, step)
                tb_writer.add_scalar("accuracy", overall, step)
                
            step += 1 
                

In [7]:
train()

epoch    1, loss: 0.0632, grapheme: 0.0002, vowel: 0.0024, consonant: 0.0049
99

On Val set Accuracy: grapheme root: 4/128 (3.12%) vowel diatric: 40/66278 (31.25%) consonant diatric: 82/66278 (64.06%) overall: (25.39%) 

epoch    1, loss: 0.0649, grapheme: 0.0004, vowel: 0.0044, consonant: 0.0097
199

On Val set Accuracy: grapheme root: 6/128 (4.69%) vowel diatric: 48/66278 (37.5%) consonant diatric: 87/66278 (67.97%) overall: (28.71%) 

epoch    1, loss: 0.0542, grapheme: 0.0007, vowel: 0.0081, consonant: 0.0148
299

On Val set Accuracy: grapheme root: 4/128 (3.12%) vowel diatric: 78/66278 (60.94%) consonant diatric: 87/66278 (67.97%) overall: (33.79%) 

epoch    1, loss: 0.0495, grapheme: 0.0011, vowel: 0.0127, consonant: 0.0201
399

On Val set Accuracy: grapheme root: 5/128 (3.91%) vowel diatric: 92/66278 (71.88%) consonant diatric: 104/66278 (81.25%) overall: (40.23%) 

epoch    1, loss: 0.0452, grapheme: 0.0016, vowel: 0.0181, consonant: 0.0257
499

On Val set Accuracy: grapheme r


On Val set Accuracy: grapheme root: 76/128 (59.38%) vowel diatric: 113/66278 (88.28%) consonant diatric: 116/66278 (90.62%) overall: (74.41%) 

epoch    4, loss: 0.0120, grapheme: 0.0401, vowel: 0.0582, consonant: 0.0580
3955

On Val set Accuracy: grapheme root: 82/128 (64.06%) vowel diatric: 116/66278 (90.62%) consonant diatric: 117/66278 (91.41%) overall: (77.54%) 

epoch    4, loss: 0.0121, grapheme: 0.0452, vowel: 0.0655, consonant: 0.0653
4055

On Val set Accuracy: grapheme root: 81/128 (63.28%) vowel diatric: 121/66278 (94.53%) consonant diatric: 116/66278 (90.62%) overall: (77.93%) 

epoch    4, loss: 0.0121, grapheme: 0.0504, vowel: 0.0729, consonant: 0.0726
4155

On Val set Accuracy: grapheme root: 86/128 (67.19%) vowel diatric: 116/66278 (90.62%) consonant diatric: 113/66278 (88.28%) overall: (78.32%) 

epoch    5, loss: 0.0101, grapheme: 0.0055, vowel: 0.0074, consonant: 0.0074
4307

On Val set Accuracy: grapheme root: 86/128 (67.19%) vowel diatric: 121/66278 (94.53%) conso

epoch    8, loss: 0.0067, grapheme: 0.0250, vowel: 0.0301, consonant: 0.0300
7763

On Val set Accuracy: grapheme root: 88/128 (68.75%) vowel diatric: 124/66278 (96.88%) consonant diatric: 123/66278 (96.09%) overall: (82.62%) 

epoch    8, loss: 0.0071, grapheme: 0.0312, vowel: 0.0376, consonant: 0.0375
7863

On Val set Accuracy: grapheme root: 94/128 (73.44%) vowel diatric: 123/66278 (96.09%) consonant diatric: 122/66278 (95.31%) overall: (84.57%) 

epoch    8, loss: 0.0070, grapheme: 0.0374, vowel: 0.0452, consonant: 0.0450
7963

On Val set Accuracy: grapheme root: 91/128 (71.09%) vowel diatric: 121/66278 (94.53%) consonant diatric: 116/66278 (90.62%) overall: (81.84%) 

epoch    8, loss: 0.0070, grapheme: 0.0436, vowel: 0.0527, consonant: 0.0525
8063

On Val set Accuracy: grapheme root: 87/128 (67.97%) vowel diatric: 117/66278 (91.41%) consonant diatric: 122/66278 (95.31%) overall: (80.66%) 

epoch    8, loss: 0.0071, grapheme: 0.0497, vowel: 0.0602, consonant: 0.0600
8163

On Val se


On Val set Accuracy: grapheme root: 105/128 (82.03%) vowel diatric: 122/66278 (95.31%) consonant diatric: 122/66278 (95.31%) overall: (88.67%) 

epoch   12, loss: 0.0042, grapheme: 0.0136, vowel: 0.0153, consonant: 0.0153
11771

On Val set Accuracy: grapheme root: 97/128 (75.78%) vowel diatric: 124/66278 (96.88%) consonant diatric: 121/66278 (94.53%) overall: (85.74%) 

epoch   12, loss: 0.0042, grapheme: 0.0204, vowel: 0.0229, consonant: 0.0229
11871

On Val set Accuracy: grapheme root: 106/128 (82.81%) vowel diatric: 121/66278 (94.53%) consonant diatric: 124/66278 (96.88%) overall: (89.26%) 

epoch   12, loss: 0.0046, grapheme: 0.0271, vowel: 0.0305, consonant: 0.0305
11971

On Val set Accuracy: grapheme root: 89/128 (69.53%) vowel diatric: 124/66278 (96.88%) consonant diatric: 124/66278 (96.88%) overall: (83.2%) 

epoch   12, loss: 0.0048, grapheme: 0.0338, vowel: 0.0382, consonant: 0.0381
12071

On Val set Accuracy: grapheme root: 98/128 (76.56%) vowel diatric: 122/66278 (95.31%) 


On Val set Accuracy: grapheme root: 108/128 (84.38%) vowel diatric: 122/66278 (95.31%) consonant diatric: 125/66278 (97.66%) overall: (90.43%) 

epoch   15, loss: 0.0039, grapheme: 0.0554, vowel: 0.0612, consonant: 0.0611
15527

On Val set Accuracy: grapheme root: 105/128 (82.03%) vowel diatric: 122/66278 (95.31%) consonant diatric: 124/66278 (96.88%) overall: (89.06%) 

epoch   15, loss: 0.0037, grapheme: 0.0623, vowel: 0.0689, consonant: 0.0688
15627

On Val set Accuracy: grapheme root: 91/128 (71.09%) vowel diatric: 122/66278 (95.31%) consonant diatric: 121/66278 (94.53%) overall: (83.01%) 

epoch   15, loss: 0.0040, grapheme: 0.0691, vowel: 0.0765, consonant: 0.0764
15727

On Val set Accuracy: grapheme root: 97/128 (75.78%) vowel diatric: 122/66278 (95.31%) consonant diatric: 119/66278 (92.97%) overall: (84.96%) 

epoch   16, loss: 0.0028, grapheme: 0.0071, vowel: 0.0077, consonant: 0.0077
15879

On Val set Accuracy: grapheme root: 108/128 (84.38%) vowel diatric: 123/66278 (96.09%


On Val set Accuracy: grapheme root: 100/128 (78.12%) vowel diatric: 122/66278 (95.31%) consonant diatric: 126/66278 (98.44%) overall: (87.5%) 

epoch   19, loss: 0.0031, grapheme: 0.0287, vowel: 0.0308, consonant: 0.0308
19335

On Val set Accuracy: grapheme root: 100/128 (78.12%) vowel diatric: 123/66278 (96.09%) consonant diatric: 125/66278 (97.66%) overall: (87.5%) 

epoch   19, loss: 0.0031, grapheme: 0.0357, vowel: 0.0385, consonant: 0.0384
19435

On Val set Accuracy: grapheme root: 86/128 (67.19%) vowel diatric: 119/66278 (92.97%) consonant diatric: 124/66278 (96.88%) overall: (81.05%) 

epoch   19, loss: 0.0031, grapheme: 0.0428, vowel: 0.0462, consonant: 0.0461
19535

On Val set Accuracy: grapheme root: 99/128 (77.34%) vowel diatric: 122/66278 (95.31%) consonant diatric: 124/66278 (96.88%) overall: (86.72%) 

epoch   19, loss: 0.0031, grapheme: 0.0498, vowel: 0.0538, consonant: 0.0538
19635

On Val set Accuracy: grapheme root: 96/128 (75.0%) vowel diatric: 123/66278 (96.09%) co

In [8]:
torch.save(model.state_dict(), "multilabel_resnet_run2")

In [9]:
validate(testloader, model)


On Val set Accuracy: grapheme root: 51331/66278 (77.45%) vowel diatric: 63193/66278 (95.35%) consonant diatric: 63044/66278 (95.12%) overall: (86.34%) 



(77.44802196807387, 95.3453634690244, 95.12055282295785, 86.3404900570325)

In [10]:
sample = next(iter(testloader))

In [None]:
sample = next(iter(testloader))
output = model(sample['image'].to(device))
labels = sample['labels']['grapheme_labels']
pred = output['grapheme'].max(1, keepdim=True)[1].cpu()

In [None]:
n_classes = [168, 11, 7]

for j,i in enumerate(['grapheme','vowel','consonant']):
    labels = sample['labels'][i+'_labels']
    pred = output[i].max(1, keepdim=True)[1].cpu()

    print(sklearn.metrics.accuracy_score(labels, pred))


    # plt.figure(figsize=(12,12))
    cn_matrix = confusion_matrix(
        y_true=labels.numpy(),
        y_pred=pred,
        labels=list(range(n_classes[j])),
        normalize="true",
    )
    ConfusionMatrixDisplay(cn_matrix, list(range(n_classes[j]))).plot(
        include_values=False, xticks_rotation="vertical"
    )
    plt.title(i)

    plt.tight_layout()
    plt.show()

In [None]:
sample['labels']

In [None]:
print("hi")

In [None]:
cn_matrix