In [1]:
%load_ext autoreload

In [2]:
%autoreload 2
from src.datagen import BengaliGraphemes
from src.models import densenet, resnet50

In [3]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd

from sklearn.metrics import precision_recall_curve
from sklearn.metrics import recall_score
from torchvision import transforms
from PIL import Image

from tqdm import tqdm

In [13]:
device = torch.device("cuda")

In [16]:
dataloader = BengaliGraphemes(128, csv_file="data/train.csv", root_dir="/media/data/bengaliai-cv19/cropped/")

Index(['image_id', 'grapheme_root', 'vowel_diacritic', 'consonant_diacritic',
       'grapheme'],
      dtype='object')
200840


In [7]:
gcriterion = torch.nn.CrossEntropyLoss()
vcriterion = torch.nn.CrossEntropyLoss()
ccriterion = torch.nn.CrossEntropyLoss()

In [8]:
def crop(img, thres=230):
    xmin, xmax = 10, img.shape[1]-10
    ymin, ymax = 10, img.shape[0]-10
    
    #threshold: Is there anything below 230 (darkwhite) in this row/column
    # argmax: where is the(first) of those values 
    rows = np.argmax((img[10:-10, 10:-10]<thres), axis=1)
    cols = np.argmax((img[10:-10, 10:-10]<thres), axis=0)
    
    while xmin < len(cols)-30 and cols[xmin-10] == 0: 
        xmin += 1
        
    while xmax > xmin+30 and cols[xmax-20] == 0: 
        xmax -= 1        
        
    while ymin < len(rows-30) and rows[ymin-10] == 0: 
        ymin += 1
    
    while ymax > ymin+30 and rows[ymax-20] == 0: 
        ymax -= 1
        
    if xmin < 10: xmin-=10 
    
    
    return xmin, xmax, ymin, ymax

In [10]:
class PandasDataloaderWithPreprocessing(torch.utils.data.Dataset): 
    def __init__(self, df): 
        self.df = df
        self.mean = 226.83368
        self.std = 59.658222
        self.transforms=transforms.Compose([
            transforms.Resize((128,128)),
            transforms.ToTensor()
        ])
        
    def __len__(self): 
        return len(self.df)
    
    def __getitem__(self, idx): 
        _img = np.reshape(self.df.iloc[idx, 1:].values.astype(np.int32),(137, 236))
        xmin, xmax, ymin, ymax = crop(_img, thres=200)        
        image = _img[ymin: ymax, xmin: xmax]
        image = Image.fromarray(image.astype(np.int32))
        width, height = image.size 
        if width > height:
            image = transforms.Pad((0,(width-height)//2), padding_mode="constant", fill=255)(image)
        elif height > width:
            image = transforms.Pad(((height-width)//2,0), padding_mode="constant", fill=255)(image)
            
        image = self.transforms(image)       
        image = (image-self.mean)/self.std
        return image, self.df.iloc[idx,0]
        

In [50]:
model = torch.load("models/started_26-02-2020_/bengali-ai_densenet121-weighted-loss_35_Linear_0.001-0.005.25")
model.to(device)                   
model.eval()

DenseNet(
  (relu): ReLU(inplace=True)
  (conv1): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (densenet): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.

In [51]:
results = []
for i in range(4): 
    df = pd.read_parquet(f"/media/data/bengaliai-cv19/test_image_data_{i}.parquet")
    dl = torch.utils.data.DataLoader(
        PandasDataloaderWithPreprocessing(df),
        batch_size=128,
        shuffle=False)

          
    for inputs, imname in dl:
        inputs = inputs.to(device)
        
        grapheme_out, vowel_out, cons_out = model(inputs)
        
        _, grapheme_preds = torch.max(grapheme_out, 1)
        _, vowel_preds = torch.max(vowel_out, 1)
        _, cons_preds = torch.max(cons_out, 1)
        
        for ix, name in enumerate(imname): 
            results.append(f"{name}_grapheme_root, {grapheme_preds.data.cpu()[ix]}")
            results.append(f"{name}_vowel_diacritic, {vowel_preds.data.cpu()[ix]}")
            results.append(f"{name}_consonant_diacritic, {cons_preds.data.cpu()[ix]}")
            
print(results)

  labels = getattr(columns, 'labels', None) or [
  return pd.MultiIndex(levels=new_levels, labels=labels, names=columns.names)
  labels, = index.labels


['Test_0_grapheme_root, 3', 'Test_0_vowel_diacritic, 0', 'Test_0_consonant_diacritic, 0', 'Test_1_grapheme_root, 93', 'Test_1_vowel_diacritic, 2', 'Test_1_consonant_diacritic, 0', 'Test_2_grapheme_root, 19', 'Test_2_vowel_diacritic, 0', 'Test_2_consonant_diacritic, 0', 'Test_3_grapheme_root, 115', 'Test_3_vowel_diacritic, 0', 'Test_3_consonant_diacritic, 0', 'Test_4_grapheme_root, 79', 'Test_4_vowel_diacritic, 4', 'Test_4_consonant_diacritic, 0', 'Test_5_grapheme_root, 115', 'Test_5_vowel_diacritic, 2', 'Test_5_consonant_diacritic, 0', 'Test_6_grapheme_root, 147', 'Test_6_vowel_diacritic, 9', 'Test_6_consonant_diacritic, 5', 'Test_7_grapheme_root, 137', 'Test_7_vowel_diacritic, 7', 'Test_7_consonant_diacritic, 0', 'Test_8_grapheme_root, 119', 'Test_8_vowel_diacritic, 9', 'Test_8_consonant_diacritic, 0', 'Test_9_grapheme_root, 133', 'Test_9_vowel_diacritic, 9', 'Test_9_consonant_diacritic, 0', 'Test_10_grapheme_root, 148', 'Test_10_vowel_diacritic, 1', 'Test_10_consonant_diacritic, 0', 

In [52]:


cpreds, vpreds, gpreds = [],[],[]
clabels, vlabels, glabels = [],[],[]

for phase in ["val"]:
    
    grapheme_running_loss = 0.0
    vowel_running_loss = 0.0
    cons_running_loss = 0.0    
    grapheme_running_corrects = 0.0
    vowel_running_corrects = 0.0
    cons_running_corrects = 0.0 
    
    for batch, (inputs, labels) in tqdm(enumerate(dataloader[phase]), total=dataloader.sizes[phase]//128+1):
        glabels += list(labels[:,0].data.numpy())
        vlabels += list(labels[:,1].data.numpy())
        clabels += list(labels[:,2].data.numpy())

        inputs, labels = inputs.to(device), labels.to(device)

        with torch.set_grad_enabled(False): 
            grapheme_out, vowel_out, cons_out = model(inputs)

            _, grapheme_preds = torch.max(grapheme_out, 1)
            _, vowel_preds = torch.max(vowel_out, 1)
            _, cons_preds = torch.max(cons_out, 1)

            gpreds += list(grapheme_preds.data.cpu().numpy())
            vpreds += list(vowel_preds.data.cpu().numpy())
            cpreds += list(cons_preds.data.cpu().numpy())

            grapheme_loss = gcriterion(grapheme_out, torch.flatten(labels[:,0]))
            vowel_loss = vcriterion(vowel_out, torch.flatten(labels[:,1]))
            cons_loss = ccriterion(cons_out, torch.flatten(labels[:,2]))
            
        grapheme_running_loss += grapheme_loss.item() * inputs.size(0)
        vowel_running_loss += vowel_loss.item() * inputs.size(0)
        cons_running_loss += cons_loss.item() * inputs.size(0)

        grapheme_running_corrects += torch.sum(grapheme_preds==labels.data[:,0])
        vowel_running_corrects += torch.sum(vowel_preds == labels[:,1])
        cons_running_corrects += torch.sum(cons_preds==labels[:,2])

    epoch_loss_interm = [l/dataloader.sizes[phase] for l in [grapheme_running_loss, vowel_running_loss, cons_running_loss]]
    epoch_acc_interm = [l.item()/dataloader.sizes[phase] for l in [grapheme_running_corrects, vowel_running_corrects, cons_running_corrects]]

    print(f'{phase} loss: {" ".join([f"{iv:{.4}}" for iv in epoch_loss_interm])} acc: {" ".join([f"{iv:{.4}}" for iv in epoch_acc_interm])} {" "*20}')

100%|██████████| 318/318 [00:16<00:00, 19.72it/s]

val loss: 0.3697 0.166 0.1681 acc: 0.911 0.9615 0.9582                     





In [53]:
recall_score(glabels, gpreds, average="macro")

0.8638533077140358

In [54]:
recall_score(vlabels, vpreds, average="macro")

0.9207927609783297

In [55]:
recall_score(clabels, cpreds, average="macro")

0.8247982791245239

In [56]:
np.average((recall_score(glabels, gpreds, average="macro"), recall_score(vlabels, vpreds, average="macro"), recall_score(clabels, cpreds, average="macro")), weights = [2,1,1])

0.8683244138827313

In [24]:
"models/started_25-02-2020_01-42-08/bengali-ai_densenet201-notweighted_50_Linear_0.0002-0.005.49"
0.9795216265513096

"models/started_26-02-2020_11-04-22/bengali-ai_densenet201-notweighted_25_Linear_0.001-0.005.24"
0.972985252822109

0.9484540085894461