In [1]:
# Re-import edited project modules automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# ResNet34 inference: fold-blend predictions, sanity-checked on training shard 0

In [1]:
import albumentations
import gc
import numpy as np
import pandas as pd
import pretrainedmodels
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm import tqdm

## Constants

In [2]:
# Inference batch size and the resize target fed to the network (H x W).
TEST_BATCH_SIZE = 64
IMG_HEIGHT, IMG_WIDTH = 137, 236

# Per-channel normalisation statistics, applied to RGB images.
# NOTE(review): the usual ImageNet green-channel mean is 0.456, not 0.465 --
# this looks like a digit transposition. Keep it only if training used the
# same value; otherwise inference inputs are shifted relative to training.
MODEL_MEAN = (0.485, 0.465, 0.406)
MODEL_STD = (0.229, 0.224, 0.225)

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DATA_PATH = Path('..', 'input')

In [3]:
# List the input files available under DATA_PATH.
!ls {DATA_PATH}

class_map.csv		   train_folds.csv
image_pickles		   train_image_data_0.feather
sample_submission.csv	   train_image_data_0.parquet
test.csv		   train_image_data_1.feather
test_image_data_0.parquet  train_image_data_1.parquet
test_image_data_1.parquet  train_image_data_2.feather
test_image_data_2.parquet  train_image_data_2.parquet
test_image_data_3.parquet  train_image_data_3.feather
train.csv		   train_image_data_3.parquet


## Dataset

In [4]:
class BengaliDatasetTest(torch.utils.data.Dataset):
    """Map-style dataset over flattened grayscale images stored row-wise in a frame.

    Each row of ``df`` holds an ``image_id`` plus ``raw_height * raw_width``
    pixel columns. ``__getitem__`` rebuilds the 2-D image, replicates it into
    three channels, resizes it to ``(img_height, img_width)``, normalises it
    with ``mean``/``std`` and returns it channel-first as a float32 tensor.

    Args:
        df: frame with an ``image_id`` column; every other column is a pixel.
        img_height, img_width: output size passed to ``albumentations.Resize``.
        mean, std: per-channel statistics passed to ``albumentations.Normalize``.
        raw_height, raw_width: stored size of each flattened image
            (defaults 137 x 236).
    """

    def __init__(self, df, img_height, img_width, mean, std,
                 raw_height=137, raw_width=236):
        self.image_ids = df['image_id'].values
        self.image_array = df.drop('image_id', axis=1).values
        self.raw_height = raw_height
        self.raw_width = raw_width

        self.aug = albumentations.Compose([
            albumentations.Resize(img_height, img_width, always_apply=True),
            albumentations.Normalize(mean, std, always_apply=True),
        ])

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``{'image_id': id, 'image': (3, H, W) float32 tensor}``."""
        image_id = self.image_ids[idx]

        gray = self.image_array[idx].reshape(self.raw_height, self.raw_width)
        # Clip to 0..255, truncate to uint8 and replicate into three channels.
        # For integer pixel values this equals the float -> PIL 'F' -> 'L' ->
        # 'RGB' conversion, without the per-sample float copy and PIL round-trip.
        gray = np.clip(gray, 0, 255).astype(np.uint8)
        image = np.stack([gray, gray, gray], axis=-1)

        image = self.aug(image=image)['image']
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)

        return {
            'image_id': image_id,
            'image': torch.tensor(image, dtype=torch.float),
        }

## Model

In [5]:
class ResNet34(nn.Module):
    """ResNet-34 backbone with three linear classification heads.

    The backbone comes from ``pretrainedmodels``. Its ``features`` output is
    global-average-pooled to one vector per image and fed to three heads:
    ``l0`` (168 outputs), ``l1`` (11 outputs) and ``l2`` (7 outputs).
    The attribute names ``model``/``l0``/``l1``/``l2`` are the state_dict keys
    used by saved checkpoints, so they must not change.

    Args:
        pretrained: if truthy, load ImageNet weights into the backbone;
            otherwise build it without loading pretrained weights.
    """

    def __init__(self, pretrained):
        super().__init__()
        weights = 'imagenet' if pretrained else None
        self.model = pretrainedmodels.__dict__['resnet34'](pretrained=weights)

        # Each head consumes the 512-wide pooled feature vector.
        self.l0 = nn.Linear(512, 168)
        self.l1 = nn.Linear(512, 11)
        self.l2 = nn.Linear(512, 7)

    def forward(self, x):
        """Return the ``(out0, out1, out2)`` head logits for the batch ``x``."""
        pooled = F.adaptive_avg_pool2d(self.model.features(x), 1)
        flat = pooled.reshape(x.shape[0], -1)
        return self.l0(flat), self.l1(flat), self.l2(flat)

In [6]:
# Build the architecture only: weights come from the per-fold checkpoints
# loaded in the blend loop, so ImageNet weights are not needed here.
model = ResNet34(pretrained=False)

## Inference

In [7]:
# Training shard 0 -- used for evaluation because ground-truth labels for it
# are available in train.csv.
df = pd.read_feather(DATA_PATH.joinpath('train_image_data_0.feather'))
df.head()

Unnamed: 0,image_id,0,1,2,3,4,5,6,7,8,...,32322,32323,32324,32325,32326,32327,32328,32329,32330,32331
0,Train_0,254,253,252,253,251,252,253,251,251,...,253,253,253,253,253,253,253,253,253,251
1,Train_1,251,244,238,245,248,246,246,247,251,...,255,255,255,255,255,255,255,255,255,254
2,Train_2,251,250,249,250,249,245,247,252,252,...,254,253,252,252,253,253,253,253,251,249
3,Train_3,247,247,249,253,253,252,251,251,250,...,254,254,254,254,254,253,253,252,251,252
4,Train_4,249,248,246,246,248,244,242,242,229,...,255,255,255,255,255,255,255,255,255,255


In [46]:
# Training metadata; used below to build the ground-truth labels for `df`.
df_train = pd.read_csv(DATA_PATH.joinpath('train.csv'))
df_train.head()

Unnamed: 0,image_id,grapheme_root,vowel_diacritic,consonant_diacritic,grapheme
0,Train_0,15,9,5,ক্ট্রো
1,Train_1,159,0,0,হ
2,Train_2,22,3,5,খ্রী
3,Train_3,53,2,2,র্টি
4,Train_4,71,9,5,থ্রো


In [52]:
# Ground-truth labels for the rows of `df`, aligned by image_id so the
# evaluation depends neither on a hard-coded row count nor on train.csv listing
# the shard's images in the same order. `.loc` raises KeyError if any image in
# `df` is missing from train.csv.
df_train_aligned = df_train.set_index('image_id').loc[df['image_id']]
labels = {
    'grapheme_root': df_train_aligned['grapheme_root'].values,
    'vowel_diacritic': df_train_aligned['vowel_diacritic'].values,
    'consonant_diacritic': df_train_aligned['consonant_diacritic'].values,
    'image_id': df_train_aligned.index.values,
}
del df_train_aligned

In [53]:
# train.csv is no longer needed once `labels` is built; release it.
del df_train
gc.collect()

6320

In [11]:
def get_predictions(model, df):
    """Run `model` over every image in `df` and collect per-image logits.

    Args:
        model: network whose forward returns a ``(grapheme, vowel, consonant)``
            tuple of logit tensors. It must already be on ``DEVICE``; only the
            inputs are moved here.
        df: frame with an ``image_id`` column plus the flattened pixel columns
            expected by ``BengaliDatasetTest``.

    Returns:
        ``(g_logits, v_logits, c_logits, image_id_list)``: three lists holding
        one 1-D numpy array of logits per image, and the matching image ids,
        all in ``df`` row order (the loader does not shuffle).
    """
    dataset = BengaliDatasetTest(
        df=df,
        img_height=IMG_HEIGHT,
        img_width=IMG_WIDTH,
        mean=MODEL_MEAN,
        std=MODEL_STD)

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False)

    g_logits, v_logits, c_logits, image_id_list = [], [], [], []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(DEVICE)
            g, v, c = model(images)

            # One device->host copy per head per batch instead of one per
            # sample. Extending with a (batch, classes) array appends its rows,
            # so each list still holds one 1-D array per image.
            image_id_list.extend(batch['image_id'])
            g_logits.extend(g.cpu().numpy())
            v_logits.extend(v.cpu().numpy())
            c_logits.extend(c.cpu().numpy())

    return g_logits, v_logits, c_logits, image_id_list

### Blend: average the logits of each fold checkpoint

In [12]:
# Fold checkpoints whose predictions on `df` are blended.
FOLDS = range(3, 5)

g_logits_arr, v_logits_arr, c_logits_arr = [], [], []
image_ids = []

for fold_idx in FOLDS:
    # map_location lets checkpoints saved on GPU load on a CPU-only machine.
    state_dict = torch.load(
        f'../src/weights/resnet34_fold{fold_idx}.pth', map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)

    g_logits, v_logits, c_logits, image_id_list = get_predictions(model, df)

    g_logits_arr.append(g_logits)
    v_logits_arr.append(v_logits)
    c_logits_arr.append(c_logits)

    # Record the ids from whichever fold runs first (testing for a specific
    # fold number silently skips this when FOLDS does not contain it), then
    # check that every later fold produced predictions in the same order.
    if not image_ids:
        image_ids.extend(image_id_list)
    else:
        assert image_ids == image_id_list, (
            f'fold {fold_idx} returned image ids in a different order')

100%|██████████| 785/785 [01:10<00:00, 11.08it/s]
100%|██████████| 785/785 [01:11<00:00, 11.05it/s]


In [17]:
# Average each head's logits across folds, then take the arg-max class per image.
g_preds, v_preds, c_preds = (
    np.array(fold_logits).mean(axis=0).argmax(axis=1)
    for fold_logits in (g_logits_arr, v_logits_arr, c_logits_arr)
)

In [55]:
# Fraction of correct predictions pooled over all three heads (each head has
# one prediction per image, so this equals the mean per-head accuracy).
heads = [
    (g_preds, 'grapheme_root'),
    (v_preds, 'vowel_diacritic'),
    (c_preds, 'consonant_diacritic'),
]
correct = sum((preds == labels[key]).sum() for preds, key in heads)
total = 3 * len(g_preds)
correct / total

0.9955188209520016

In [56]:
# Release cached, currently unused CUDA memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [57]:
# Run a full garbage collection; the displayed value is the number of
# unreachable objects found.
gc.collect()

2034