In [1]:
# from model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
import pandas as pd 
import os 

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

class MultiOutputModel(nn.Module):
    def __init__(self, n_grapheme_classes, n_vowel_classes, n_consonant_classes):
        super().__init__()
        self.base_model = models.mobilenet_v2().features  # take the model without classifier
        self.base_model[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        last_channel = models.mobilenet_v2().last_channel  # size of the layer before classifier

        # the input for the classifier should be two-dimensional, but we will have
        # [batch_size, channels, width, height]
        # so, let's do the spatial averaging: reduce width and height to 1
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # create separate classifiers for our outputs
        self.grapheme = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_grapheme_classes)
        )
        self.vowel = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_vowel_classes)
        )
        self.consonant = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_consonant_classes)
        )

    def forward(self, x):
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier
        x = torch.flatten(x, 1)

        return {
            'grapheme': self.grapheme(x),
            'vowel': self.vowel(x),
            'consonant': self.consonant(x)
        }

    def get_loss(self, net_output, ground_truth):
        color_loss = F.cross_entropy(net_output['grapheme'], ground_truth['grapheme_labels'])
        gender_loss = F.cross_entropy(net_output['vowel'], ground_truth['vowel_labels'])
        article_loss = F.cross_entropy(net_output['consonant'], ground_truth['consonant_labels'])
        loss = color_loss + gender_loss + article_loss
        return loss, {'grapheme': color_loss, 'vowel': gender_loss, 'consonant': article_loss}

cuda:0


In [2]:
class BengaliDatasetMultiClass(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        
        self.label_df = pd.read_csv(csv_file)
        self.label_df = self.label_df[['image_id','grapheme_root',
                             'vowel_diacritic','consonant_diacritic',
                             'label','grapheme','textlabel']]
        
        self.root_dir = root_dir 
        self.transform = transform 
    
    def __len__(self):
        return len(self.label_df)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        img_name = os.path.join(self.root_dir,
                                self.label_df.iloc[idx, 0] + '.png')
        image = Image.open(img_name).convert('L')

        label = tuple(self.label_df.iloc[idx, 1:])
        textlabel = self.label_df.iloc[idx, -1]  
        
        if self.transform:
            image = self.transform(image)
            
        sample = {
            "image": image,
            "labels": {
                "grapheme_labels": label[0],
                "vowel_labels": label[1],
                "consonant_labels": label[2],
            },
            "human_labels":{
                "typeface":label[4],
                "stringlabel":label[5]
            }
        }

        return sample

In [3]:
batch_size=128
num_workers = 16

transform=transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
])

trainset = BengaliDatasetMultiClass("data/train.csv","data/trainsplit", transform)
testset = BengaliDatasetMultiClass("data/test.csv","data/testsplit", transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                         shuffle=True, num_workers=num_workers)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=True, num_workers=num_workers)

In [4]:
samples = next(iter(trainloader))
model = MultiOutputModel(168, 11, 7).to(device)

In [5]:
model

MultiOutputModel(
  (base_model): Sequential(
    (0): ConvBNReLU(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, af

In [6]:
import warnings
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score

def calculate_metrics(output, target):
    _, predicted_color = output['grapheme'].cpu().max(1)
    gt_color = target['grapheme_labels'].cpu()

    _, predicted_gender = output['vowel'].cpu().max(1)
    gt_gender = target['vowel_labels'].cpu()

    _, predicted_article = output['consonant'].cpu().max(1)
    gt_article = target['consonant_labels'].cpu()

    with warnings.catch_warnings():  # sklearn may produce a warning when processing zero row in confusion matrix
        warnings.simplefilter("ignore")
        accuracy_color = accuracy_score(y_true=gt_color.numpy(), y_pred=predicted_color.numpy())
        accuracy_gender = accuracy_score(y_true=gt_gender.numpy(), y_pred=predicted_gender.numpy())
        accuracy_article = accuracy_score(y_true=gt_article.numpy(), y_pred=predicted_article.numpy())

    return accuracy_color, accuracy_gender, accuracy_article

In [7]:
# from train.py

import argparse
import os
from datetime import datetime

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter


def get_cur_time():
    return datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')


def checkpoint_save(model, name, epoch):
    f = os.path.join(name, 'checkpoint-{:06d}.pth'.format(epoch))
    torch.save(model.state_dict(), f)
    print('Saved checkpoint:', f)

    return f
    

def train(start_epoch=1, N_epochs=20, log_interval=100):

    optimizer = torch.optim.Adam(model.parameters())


    for epoch in range(start_epoch, N_epochs + 1):
        total_loss = 0
        accuracy_color = 0
        accuracy_gender = 0
        accuracy_article = 0

        for i, batch in enumerate(trainloader):
            optimizer.zero_grad()

            img = batch['image']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            loss_train, losses_train = model.get_loss(output, target_labels)
            total_loss += loss_train.item()
            batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \
                calculate_metrics(output, target_labels)

            accuracy_color += batch_accuracy_color
            accuracy_gender += batch_accuracy_gender
            accuracy_article += batch_accuracy_article

            loss_train.backward()
            optimizer.step()
            
            if (i+1)%log_interval == 0: 
                print("epoch {:4d}, loss: {:.4f}, grapheme: {:.4f}, vowel: {:.4f}, consonant: {:.4f}".format(
                    epoch,
                    total_loss / (log_interval*batch_size),
                    accuracy_color / (log_interval*batch_size),
                    accuracy_gender / (log_interval*batch_size),
                    accuracy_article / (log_interval*batch_size)))
                
                # Reset stats 
                total_loss = 0
                accuracy_color = 0
                accuracy_gender = 0
                accuracy_article = 0



In [8]:
train()

epoch    1, loss: 0.0649, grapheme: 0.0002, vowel: 0.0017, consonant: 0.0048
epoch    1, loss: 0.0606, grapheme: 0.0002, vowel: 0.0024, consonant: 0.0049
epoch    1, loss: 0.0565, grapheme: 0.0003, vowel: 0.0033, consonant: 0.0050
epoch    1, loss: 0.0539, grapheme: 0.0003, vowel: 0.0039, consonant: 0.0051
epoch    1, loss: 0.0512, grapheme: 0.0004, vowel: 0.0044, consonant: 0.0052
epoch    1, loss: 0.0490, grapheme: 0.0004, vowel: 0.0047, consonant: 0.0054
epoch    1, loss: 0.0471, grapheme: 0.0005, vowel: 0.0050, consonant: 0.0055
epoch    1, loss: 0.0451, grapheme: 0.0005, vowel: 0.0053, consonant: 0.0057
epoch    1, loss: 0.0430, grapheme: 0.0006, vowel: 0.0055, consonant: 0.0058
epoch    1, loss: 0.0413, grapheme: 0.0007, vowel: 0.0058, consonant: 0.0059
epoch    2, loss: 0.0390, grapheme: 0.0008, vowel: 0.0060, consonant: 0.0060
epoch    2, loss: 0.0374, grapheme: 0.0009, vowel: 0.0061, consonant: 0.0062
epoch    2, loss: 0.0362, grapheme: 0.0011, vowel: 0.0062, consonant: 0.0062

epoch   11, loss: 0.0108, grapheme: 0.0054, vowel: 0.0073, consonant: 0.0073
epoch   11, loss: 0.0107, grapheme: 0.0054, vowel: 0.0073, consonant: 0.0073
epoch   11, loss: 0.0108, grapheme: 0.0054, vowel: 0.0073, consonant: 0.0073
epoch   12, loss: 0.0095, grapheme: 0.0056, vowel: 0.0074, consonant: 0.0074
epoch   12, loss: 0.0095, grapheme: 0.0056, vowel: 0.0074, consonant: 0.0074
epoch   12, loss: 0.0102, grapheme: 0.0055, vowel: 0.0074, consonant: 0.0074
epoch   12, loss: 0.0097, grapheme: 0.0056, vowel: 0.0074, consonant: 0.0074
epoch   12, loss: 0.0102, grapheme: 0.0055, vowel: 0.0074, consonant: 0.0073
epoch   12, loss: 0.0101, grapheme: 0.0055, vowel: 0.0074, consonant: 0.0074
epoch   12, loss: 0.0102, grapheme: 0.0055, vowel: 0.0074, consonant: 0.0074
epoch   12, loss: 0.0100, grapheme: 0.0055, vowel: 0.0074, consonant: 0.0074
epoch   12, loss: 0.0103, grapheme: 0.0055, vowel: 0.0074, consonant: 0.0074
epoch   12, loss: 0.0104, grapheme: 0.0055, vowel: 0.0074, consonant: 0.0073

In [9]:
samples = next(iter(trainloader))
output = model(samples['image'].to(device))

In [10]:
calculate_metrics(output, samples["labels"])


(0.8046875, 0.9765625, 0.9609375)

In [11]:
torch.save(model.state_dict(), "multilabel_mobilenet_class_fixed")

In [12]:
df = pd.read_csv("data/train-test.csv")

In [13]:
df

Unnamed: 0.1,Unnamed: 0,image_id,grapheme_root,vowel_diacritic,consonant_diacritic,grapheme,label,type,textlabel
0,0,Train_0,15,9,5,ক্ট্রো,64,test,test\ngr=15 vd=9 cd=5\n64
1,1,Train_1,159,0,0,হ,1243,train,train\ngr=159 vd=0 cd=0\n1243
2,2,Train_2,22,3,5,খ্রী,107,train,train\ngr=22 vd=3 cd=5\n107
3,3,Train_3,53,2,2,র্টি,333,train,train\ngr=53 vd=2 cd=2\n333
4,4,Train_4,71,9,5,থ্রো,504,test,test\ngr=71 vd=9 cd=5\n504
...,...,...,...,...,...,...,...,...,...
200835,200835,Train_200835,22,7,2,র্খে,112,train,train\ngr=22 vd=7 cd=2\n112
200836,200836,Train_200836,65,9,0,ত্তো,462,test,test\ngr=65 vd=9 cd=0\n462
200837,200837,Train_200837,2,1,4,অ্যা,3,test,test\ngr=2 vd=1 cd=4\n3
200838,200838,Train_200838,152,9,0,স্নো,1213,test,test\ngr=152 vd=9 cd=0\n1213


In [14]:
df.grapheme_root.max()

167

In [15]:
df.vowel_diacritic.max()

10

In [16]:
df.consonant_diacritic.max()

6