In [45]:
# from model.py
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class MultiOutputModel(nn.Module):
    """MobileNetV2 feature extractor with three independent classification heads.

    The first convolution of the backbone is replaced by one that takes a
    single input channel. The pooled backbone features are shared by three
    ``Dropout -> Linear`` heads that emit logits for the grapheme, vowel and
    consonant targets.

    Args:
        n_grapheme_classes: Number of logits produced by the ``grapheme`` head.
        n_vowel_classes: Number of logits produced by the ``vowel`` head.
        n_consonant_classes: Number of logits produced by the ``consonant`` head.
    """

    def __init__(self, n_grapheme_classes, n_vowel_classes, n_consonant_classes):
        super().__init__()

        # Build the reference network once and take both the feature extractor
        # and the feature width from that single instance.
        backbone = models.mobilenet_v2()
        self.base_model = backbone.features  # take the model without classifier
        # Replace the first convolution with one that accepts 1 input channel.
        self.base_model[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        last_channel = backbone.last_channel  # size of the layer before classifier

        # The classifier input must be two-dimensional, but the backbone yields
        # [batch_size, channels, width, height]; average the spatial dims to 1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Separate classifiers, one per output.
        self.grapheme = self._head(last_channel, n_grapheme_classes)
        self.vowel = self._head(last_channel, n_vowel_classes)
        self.consonant = self._head(last_channel, n_consonant_classes)

    @staticmethod
    def _head(in_features, n_classes):
        """Return one classification head: dropout followed by a linear layer."""
        return nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=in_features, out_features=n_classes),
        )

    def forward(self, x):
        """Compute the logits of all three heads for a batch of images.

        Args:
            x: Image batch passed to the backbone; the replaced first
                convolution expects 1 channel.

        Returns:
            dict with keys ``'grapheme'``, ``'vowel'`` and ``'consonant'``,
            each holding that head's logits of shape [batch, n_classes].
        """
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] for the heads
        x = torch.flatten(x, 1)

        return {
            'grapheme': self.grapheme(x),
            'vowel': self.vowel(x),
            'consonant': self.consonant(x)
        }

    def get_loss(self, net_output, ground_truth):
        """Sum the cross-entropy losses of the three heads.

        Args:
            net_output: Mapping with the ``'grapheme'``, ``'vowel'`` and
                ``'consonant'`` logits, as returned by ``forward``.
            ground_truth: Mapping with the target class indices under the keys
                ``'grapheme_labels'``, ``'vowel_labels'`` and
                ``'consonant_labels'``.
                NOTE(review): ``BengaliDatasetMultiClass`` in this notebook
                emits differently named keys (``grapheme_root``,
                ``vowel_diacritic``, ``consonent_diacritic``), so its batches
                must be remapped by the caller before being passed here.

        Returns:
            ``(loss, parts)`` where ``loss`` is the unweighted sum of the three
            cross-entropy terms and ``parts`` maps ``'grapheme'``, ``'vowel'``
            and ``'consonant'`` to the individual terms.
        """
        grapheme_loss = F.cross_entropy(net_output['grapheme'], ground_truth['grapheme_labels'])
        vowel_loss = F.cross_entropy(net_output['vowel'], ground_truth['vowel_labels'])
        consonant_loss = F.cross_entropy(net_output['consonant'], ground_truth['consonant_labels'])
        loss = grapheme_loss + vowel_loss + consonant_loss
        return loss, {'grapheme': grapheme_loss, 'vowel': vowel_loss, 'consonant': consonant_loss}

In [28]:
# Scratch instance of torchvision's MobileNetV2, kept for interactive
# inspection; no other cell shown in this notebook reads `m`.
m = models.mobilenet_v2()

In [18]:
class BengaliDatasetMultiClass(Dataset):
    """Image dataset that yields three label tensors per image.

    Annotations are read from a CSV file; the image of a row is loaded from
    ``<root_dir>/<image_id>.png`` and converted to single-channel ('L') mode.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        # Pin the column order: __getitem__ addresses columns by position
        # (0 = image_id, 1..3 = the three label columns).
        self.label_df = pd.read_csv(csv_file)[[
            'image_id', 'grapheme_root',
            'vowel_diacritic', 'consonant_diacritic',
            'label', 'grapheme', 'textlabel',
        ]]

        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated rows."""
        return len(self.label_df)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Row position; a tensor index is converted with ``tolist()``.

        Returns:
            dict with ``"image"`` (the transformed image if a transform was
            given, otherwise the PIL image) and the label tensors
            ``"grapheme_root"``, ``"vowel_diacritic"`` and
            ``"consonent_diacritic"``. The last key is a historical
            misspelling that is kept for existing consumers; the same value is
            also exposed under the correctly spelled ``"consonant_diacritic"``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = os.path.join(self.root_dir,
                                self.label_df.iloc[idx, 0] + '.png')
        image = Image.open(img_path).convert('L')

        # Columns 1..3 are grapheme_root, vowel_diacritic, consonant_diacritic.
        labels = torch.tensor(tuple(self.label_df.iloc[idx, 1:4]))

        if self.transform is not None:
            image = self.transform(image)

        return {
            "image": image,
            "grapheme_root": labels[0],
            "vowel_diacritic": labels[1],
            # Misspelled key retained for backward compatibility.
            "consonent_diacritic": labels[2],
            # Correctly spelled alias of the entry above.
            "consonant_diacritic": labels[2],
        }

In [19]:
# Loader configuration shared by the training and test loaders.
batch_size = 4

# Resize each image with size argument 28, then convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
])

trainset = BengaliDatasetMultiClass("data/train.csv", "data/trainsplit", transform)
testset = BengaliDatasetMultiClass("data/test.csv", "data/testsplit", transform)

# Only training batches are shuffled. The test loader keeps dataset order so
# evaluation outputs line up with the rows of data/test.csv on every run.
trainloader = DataLoader(trainset, batch_size=batch_size,
                         shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=batch_size,
                        shuffle=False, num_workers=0)

# df = pd.read_csv("data/train.csv")
# classes = list(range(df.label.max()+1))
# n_classes = len(classes)

In [47]:
# Pull a single mini-batch from the training loader to use as a smoke-test
# input for the model.
samples = next(iter(trainloader))

In [48]:
# Instantiate the three-headed model; keyword arguments make explicit which
# class count belongs to which head.
net = MultiOutputModel(
    n_grapheme_classes=168,
    n_vowel_classes=11,
    n_consonant_classes=8,
)

In [51]:
# Smoke test: forward the sampled batch and display the shape of the grapheme
# logits (the recorded output is [4, 168], i.e. [batch_size, n_grapheme_classes]).
net(samples["image"])["grapheme"].shape

torch.Size([4, 168])