In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import torch
import warnings
warnings.filterwarnings("ignore")

In [2]:
from torch.utils.data import Dataset

class GraphemeDataset(Dataset):
    def __init__(self, images, labels=None, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.train = labels is not None
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        image = image / image.max()
        if self.transform:
            image = self.transform(image)
        if self.train:
            label = self.labels[idx]
            return image, label[0], label[1], label[2]
        else:
            return image

In [3]:
from sklearn.metrics import recall_score

def get_recall(y_true, y_pred):
    pred_labels = np.argmax(y_pred, axis=1)
    res = recall_score(y_true, pred_labels, average='macro')
    return res

In [4]:
image_size = 128
batch_size = 320

in_dir = Path('../input/bengaliai-cv19')
feather_dir = Path('../input/bengaliai-cv19-feather')
out_dir = Path('')

In [6]:
from crop_resize import read_feathers

filenames = [feather_dir/f'train_image_data_1x{image_size}x{image_size}_{i}.feather' for i in range(4)]
images = read_feathers(filenames, image_size)
print(images.shape)

(200840, 1, 128, 128)


In [7]:
train_label = pd.read_csv(in_dir/'train.csv')
labels = train_label[['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']].values
nunique = list(train_label.nunique())[1:-1]
nunique

[168, 11, 7]

In [8]:
from my_efficientnet_pytorch import EfficientNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = 'efficientnet-b3'
model_dir = Path('./20200313_pretrained_models')
pretrained_models = [model_dir/'base.pth',
                     model_dir/'base_rotate.pth',
                     model_dir/'base_mixup.pth',
                     model_dir/'base_cutmix.pth',
                     model_dir/'base_mixup_cutmix.pth']

<a id="inference"></a> 
# Inference

In [11]:
test_dataset = GraphemeDataset(images)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for pre_model in pretrained_models:
    print('model : ', str(pre_model))
    model = EfficientNet.from_my_pretrained(model_name, pre_model, in_channels=1).to(device)
    model.eval()

    predictions = np.zeros((len(test_dataset), sum(nunique)), dtype = 'float32')

    cnt = 0

    with torch.no_grad():
        for idx, inputs in tqdm(enumerate(test_loader), total=len(test_loader)):
            inputs = inputs.float().to(device)
            outputs1, outputs2, outputs3 = model(inputs)
            predictions[cnt:cnt + inputs.size()[0], 0:nunique[0]] = outputs1.cpu().numpy()
            predictions[cnt:cnt + inputs.size()[0], nunique[0]:sum(nunique[:2])] = outputs2.cpu().numpy()
            predictions[cnt:cnt + inputs.size()[0], sum(nunique[:2]):sum(nunique[:3])] = outputs3.cpu().numpy()
            cnt += inputs.size()[0]
            
    predictions = np.split(predictions, np.cumsum(nunique), axis=1)
    recall1 = get_recall(labels[:, 0], predictions[0])
    recall2 = get_recall(labels[:, 1], predictions[1])
    recall3 = get_recall(labels[:, 2], predictions[2])
    recall = np.average([recall1, recall2, recall3], weights=[2, 1, 1])
    
    del model
    del predictions
    torch.cuda.empty_cache()
    gc.collect()

    print(f'recall_grapheme  : 1.0 * {recall1}')
    print(f'recall_vowel     : 0.5 * {recall2}')
    print(f'recall_consonant : 0.5 * {recall3}')
    print('-'*50)
    print(f'final recall     : {recall}')

model :  20200313_pretrained_models\base.pth
Loaded pretrained weights for efficientnet-b3


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


recall_grapheme  : 1.0 * 1.0
recall_vowel     : 0.5 * 0.9998864380053203
recall_consonant : 0.5 * 0.9998603800152164
--------------------------------------------------
final recall     : 0.9999367045051342
model :  20200313_pretrained_models\base_rotate.pth
Loaded pretrained weights for efficientnet-b3


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


recall_grapheme  : 1.0 * 1.0
recall_vowel     : 0.5 * 0.9999850200140087
recall_consonant : 0.5 * 0.9999496886424513
--------------------------------------------------
final recall     : 0.9999836771641151
model :  20200313_pretrained_models\base_mixup.pth
Loaded pretrained weights for efficientnet-b3


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


recall_grapheme  : 1.0 * 0.998754734420601
recall_vowel     : 0.5 * 0.9922038474730329
recall_consonant : 0.5 * 0.9884337552157908
--------------------------------------------------
final recall     : 0.9945367678825064
model :  20200313_pretrained_models\base_cutmix.pth
Loaded pretrained weights for efficientnet-b3


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


recall_grapheme  : 1.0 * 0.99788018094675
recall_vowel     : 0.5 * 0.9901965756835683
recall_consonant : 0.5 * 0.986073596462992
--------------------------------------------------
final recall     : 0.9930076335100151
model :  20200313_pretrained_models\base_mixup_cutmix.pth
Loaded pretrained weights for efficientnet-b3


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


recall_grapheme  : 1.0 * 0.9787705435163141
recall_vowel     : 0.5 * 0.9839572128474301
recall_consonant : 0.5 * 0.9749104261263476
--------------------------------------------------
final recall     : 0.9791021815016014
