In [87]:
# from model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
import pandas as pd 

class MultiOutputModel(nn.Module):
    def __init__(self, n_grapheme_classes, n_vowel_classes, n_consonant_classes):
        super().__init__()
        self.base_model = models.mobilenet_v2().features  # take the model without classifier
        self.base_model[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        last_channel = models.mobilenet_v2().last_channel  # size of the layer before classifier

        # the input for the classifier should be two-dimensional, but we will have
        # [batch_size, channels, width, height]
        # so, let's do the spatial averaging: reduce width and height to 1
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # create separate classifiers for our outputs
        self.grapheme = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_grapheme_classes)
        )
        self.vowel = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_vowel_classes)
        )
        self.consonant = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_consonant_classes)
        )

    def forward(self, x):
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier
        x = torch.flatten(x, 1)

        return {
            'grapheme': self.grapheme(x),
            'vowel': self.vowel(x),
            'consonant': self.consonant(x)
        }

    def get_loss(self, net_output, ground_truth):
        color_loss = F.cross_entropy(net_output['grapheme'], ground_truth['grapheme_labels'])
        gender_loss = F.cross_entropy(net_output['vowel'], ground_truth['vowel_labels'])
        article_loss = F.cross_entropy(net_output['consonant'], ground_truth['consonant_labels'])
        loss = color_loss + gender_loss + article_loss
        return loss, {'grapheme': color_loss, 'vowel': gender_loss, 'consonant': article_loss}

In [96]:
class BengaliDatasetMultiClass(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        
        self.label_df = pd.read_csv(csv_file)
        self.label_df = self.label_df[['image_id','grapheme_root',
                             'vowel_diacritic','consonant_diacritic',
                             'label','grapheme','textlabel']]
        
        self.root_dir = root_dir 
        self.transform = transform 
    
    def __len__(self):
        return len(self.label_df)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        img_name = os.path.join(self.root_dir,
                                self.label_df.iloc[idx, 0] + '.png')
        image = Image.open(img_name).convert('L')

        label = tuple(self.label_df.iloc[idx, 1:4])
        label = torch.tensor(label)
        textlabel = self.label_df.iloc[idx, -1]  
        
        if self.transform:
            image = self.transform(image)
            
        sample = {
            "image": image,
            "labels": {
                "grapheme_labels": label[0],
                "vowel_labels": label[1],
                "consonant_labels": label[2]
            }
        }

        return sample

In [97]:
batch_size=4
transform=transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
])

trainset = BengaliDatasetMultiClass("data/train.csv","data/trainsplit", transform)
testset = BengaliDatasetMultiClass("data/test.csv","data/testsplit", transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                         shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=True, num_workers=0)

In [90]:
samples = next(iter(trainloader))

In [91]:
model = MultiOutputModel(168, 11,8)

In [92]:
net(samples["image"])["grapheme"].shape

torch.Size([4, 168])

In [105]:
import warnings
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score

def calculate_metrics(output, target):
    _, predicted_color = output['grapheme'].cpu().max(1)
    gt_color = target['grapheme_labels'].cpu()

    _, predicted_gender = output['vowel'].cpu().max(1)
    gt_gender = target['vowel_labels'].cpu()

    _, predicted_article = output['consonant'].cpu().max(1)
    gt_article = target['consonant_labels'].cpu()

    with warnings.catch_warnings():  # sklearn may produce a warning when processing zero row in confusion matrix
        warnings.simplefilter("ignore")
        accuracy_color = accuracy_score(y_true=gt_color.numpy(), y_pred=predicted_color.numpy())
        accuracy_gender = accuracy_score(y_true=gt_gender.numpy(), y_pred=predicted_gender.numpy())
        accuracy_article = accuracy_score(y_true=gt_article.numpy(), y_pred=predicted_article.numpy())

    return accuracy_color, accuracy_gender, accuracy_article

In [109]:
# from train.py

import argparse
import os
from datetime import datetime

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def get_cur_time():
    return datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')


def checkpoint_save(model, name, epoch):
    f = os.path.join(name, 'checkpoint-{:06d}.pth'.format(epoch))
    torch.save(model.state_dict(), f)
    print('Saved checkpoint:', f)

    return f
    

def train(start_epoch=1, N_epochs=1, batch_size=16, num_workers=8):

    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(start_epoch, N_epochs + 1):
        total_loss = 0
        accuracy_color = 0
        accuracy_gender = 0
        accuracy_article = 0

        for batch in trainloader:
            print("hello")
            optimizer.zero_grad()

            img = batch['image']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            loss_train, losses_train = model.get_loss(output, target_labels)
            total_loss += loss_train.item()
            batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \
                calculate_metrics(output, target_labels)

            accuracy_color += batch_accuracy_color
            accuracy_gender += batch_accuracy_gender
            accuracy_article += batch_accuracy_article

            loss_train.backward()
            optimizer.step()
            
        print("epoch {:4d}, loss: {:.4f}, color: {:.4f}, gender: {:.4f}, article: {:.4f}".format(
            epoch,
            total_loss / n_train_samples,
            accuracy_color / n_train_samples,
            accuracy_gender / n_train_samples,
            accuracy_article / n_train_samples))


In [110]:
train()

hello
{'grapheme_labels': tensor([ 86,  57, 160, 113]), 'vowel_labels': tensor([0, 2, 2, 1]), 'consonant_labels': tensor([0, 0, 0, 4])}
hello
{'grapheme_labels': tensor([113, 148, 133, 150]), 'vowel_labels': tensor([5, 7, 2, 4]), 'consonant_labels': tensor([2, 0, 4, 0])}
hello
{'grapheme_labels': tensor([ 79, 142,  82,  13]), 'vowel_labels': tensor([3, 2, 4, 1]), 'consonant_labels': tensor([0, 0, 0, 5])}
hello
{'grapheme_labels': tensor([81, 72, 71, 56]), 'vowel_labels': tensor([7, 8, 0, 2]), 'consonant_labels': tensor([2, 0, 3, 5])}
hello
{'grapheme_labels': tensor([ 42,  64,  96, 134]), 'vowel_labels': tensor([ 4,  8, 10,  4]), 'consonant_labels': tensor([1, 5, 1, 0])}
hello
{'grapheme_labels': tensor([155, 129, 149,  44]), 'vowel_labels': tensor([7, 3, 2, 3]), 'consonant_labels': tensor([0, 0, 5, 0])}
hello
{'grapheme_labels': tensor([72, 57, 88, 10]), 'vowel_labels': tensor([3, 1, 0, 0]), 'consonant_labels': tensor([2, 0, 0, 0])}
hello
{'grapheme_labels': tensor([79, 13, 52, 59]), 

KeyboardInterrupt: 

In [69]:
device = 'cpu'