In [1]:
# Install timm 0.1.26 from a local wheel (no dependency resolution); stdout is discarded.
!pip install --no-deps '../input/timm-package/timm-0.1.26-py3-none-any.whl' > /dev/null

In [2]:
import glob
import torch
import albumentations
import pandas as pd
import numpy as np

from tqdm import tqdm
from PIL import Image
import joblib
import torch.nn as nn
from torch.nn import functional as F

In [3]:
# Per-channel mean / std handed to albumentations.Normalize
# (these are the standard ImageNet statistics).
MODEL_MEAN = (0.485, 0.456, 0.406)
MODEL_STD = (0.229, 0.224, 0.225)

# Target size handed to albumentations.Resize by BengaliDataset.
IMG_HEIGHT = 137
IMG_WIDTH = 236

# Run on the GPU when one is present; otherwise fall back to the CPU instead
# of failing on the first `.to(DEVICE)` call.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
import timm

class bengalimodel(nn.Module):
    """timm backbone followed by three independent linear heads.

    The heads produce 168, 11 and 7 outputs; downstream cells use them as the
    grapheme-root, vowel-diacritic and consonant-diacritic scores respectively.

    Args:
        backbone: model name passed to ``timm.create_model``; the backbone is
            created with ``pretrained=False``.
    """

    def __init__(self, backbone = 'resnet18'):
        super().__init__()
        # NOTE(review): the heads assume the backbone emits 1000 values per
        # sample (its stock classifier output) -- confirm for other backbones.
        backbone_out_dim = 1000
        self.backbone = timm.create_model(backbone, pretrained=False)
        self.l1 = nn.Linear(backbone_out_dim, 168)
        self.l2 = nn.Linear(backbone_out_dim, 11)
        self.l3 = nn.Linear(backbone_out_dim, 7)

    def forward(self, x):
        """Return the ``(l1, l2, l3)`` head outputs for the input batch ``x``."""
        features = self.backbone(x)
        return self.l1(features), self.l2(features), self.l3(features)

In [5]:
class BengaliDataset:
    """Map-style dataset that turns rows of a pixel DataFrame into model inputs.

    Args:
        df: frame with an ``image_id`` column; every column from position 1
            onward is taken as flattened pixel data
            (presumably ``image_id`` is column 0 -- verify against the parquet schema).
        img_height: target height for the resize step.
        img_width: target width for the resize step.
        mean: per-channel mean for normalisation.
        std: per-channel standard deviation for normalisation.
    """

    def __init__(self, df, img_height, img_width, mean, std):
        self.image_ids = df.image_id.values
        self.img_arr = df.iloc[:, 1:].values

        # Resize then normalise; both transforms are always applied.
        resize = albumentations.Resize(img_height, img_width, always_apply=True)
        normalize = albumentations.Normalize(mean, std, always_apply=True)
        self.aug = albumentations.Compose([resize, normalize])

    def __len__(self):
        """Number of images in the frame."""
        return len(self.image_ids)

    def __getitem__(self, item):
        """Return ``{"image": float tensor in C x H x W order, "image_id": id}`` for row ``item``."""
        # Each row is reshaped to a fixed 137 x 236 image.
        pixels = self.img_arr[item, :].reshape(137, 236).astype(float)

        # Round-trip through PIL to obtain a 3-channel (RGB) array.
        rgb = np.array(Image.fromarray(pixels).convert("RGB"))

        augmented = self.aug(image=rgb)["image"]

        # H x W x C -> C x H x W, as float32.
        chw = augmented.transpose(2, 0, 1).astype(np.float32)

        sample = {
            "image": torch.tensor(chw, dtype=torch.float),
            "image_id": self.image_ids[item],
        }
        return sample


In [6]:
def model_predict(model):
    """Run ``model`` over the four test parquet files and collect its raw outputs.

    Reads ``../input/bengaliai-cv19/test_image_data_{0..3}.parquet`` in order,
    wraps each file in a ``BengaliDataset`` and feeds it through ``model`` in
    batches of ``TEST_BATCH_SIZE`` on ``DEVICE`` (both are notebook-level
    globals that must be defined before this function is called).

    The caller is responsible for moving ``model`` to ``DEVICE`` and putting
    it in eval mode.

    Args:
        model: callable returning three tensors ``(g, v, c)`` for an image batch.

    Returns:
        Tuple ``(g_pred, v_pred, c_pred, img_ids_list)``. The first three are
        lists holding one NumPy array per image (that image's row of the
        first, second and third model output); ``img_ids_list`` holds the
        matching image ids in the same order.
    """
    g_pred, v_pred, c_pred = [], [], []
    img_ids_list = []

    for file_idx in range(4):
        df = pd.read_parquet(f"../input/bengaliai-cv19/test_image_data_{file_idx}.parquet")

        dataset = BengaliDataset(df=df,
                                 img_height=IMG_HEIGHT,
                                 img_width=IMG_WIDTH,
                                 mean=MODEL_MEAN,
                                 std=MODEL_STD)

        data_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=TEST_BATCH_SIZE,
            shuffle=False,  # keep file order so ids line up across folds
            num_workers=4
        )

        # len(data_loader) counts the trailing partial batch, so the progress
        # bar total is correct even when the dataset size is not a multiple
        # of the batch size.
        for d in tqdm(data_loader, total=len(data_loader)):
            image = d["image"].to(DEVICE, dtype=torch.float)

            # Inference only: do not record an autograd graph.
            with torch.no_grad():
                g, v, c = model(image)

            # One device-to-host copy per head per batch; extending with a
            # 2-D array appends one row (one image) at a time.
            g_pred.extend(g.detach().cpu().numpy())
            v_pred.extend(v.detach().cpu().numpy())
            c_pred.extend(c.detach().cpu().numpy())
            img_ids_list.extend(d["image_id"])

    return g_pred, v_pred, c_pred, img_ids_list

In [7]:
# Batch size read by model_predict() when it builds its DataLoader.
TEST_BATCH_SIZE = 32

# A single model instance is reused; only its weights change per fold.
model = bengalimodel()
model.to(DEVICE)

# One entry per fold; each entry is the per-image list returned by model_predict().
final_g_pred = []
final_v_pred = []
final_c_pred = []
final_img_ids = []

for i in range(5):
    # map_location makes the load independent of the device the checkpoint
    # was saved from.
    state_dict = torch.load(
        f"../input/resnet18bengaliai/resnet18_fold{i}.pth",
        map_location=DEVICE,
    )
    model.load_state_dict(state_dict)
    model.eval()
    g_pred, v_pred, c_pred, img_ids_list = model_predict(model)

    final_g_pred.append(g_pred)
    final_v_pred.append(v_pred)
    final_c_pred.append(c_pred)
    # Every fold sees the same images in the same order, so the ids are
    # recorded only once.
    if i == 0:
        final_img_ids.extend(img_ids_list)

1it [00:00,  1.24it/s]
1it [00:00,  7.86it/s]
1it [00:00,  9.22it/s]
1it [00:00,  8.44it/s]
1it [00:00,  9.23it/s]
1it [00:00,  8.60it/s]
1it [00:00,  9.64it/s]
1it [00:00,  8.33it/s]
1it [00:00,  9.18it/s]
1it [00:00,  8.22it/s]
1it [00:00,  9.17it/s]
1it [00:00,  7.99it/s]
1it [00:00,  9.06it/s]
1it [00:00,  8.20it/s]
1it [00:00,  9.27it/s]
1it [00:00,  8.20it/s]
1it [00:00,  9.07it/s]
1it [00:00,  8.16it/s]
1it [00:00,  8.08it/s]
1it [00:00,  3.95it/s]


In [8]:
# Average the raw model outputs over the fold axis (axis 0), then take the
# index of the largest averaged score for every image.
final_g = np.asarray(final_g_pred).mean(axis=0).argmax(axis=1)
final_v = np.asarray(final_v_pred).mean(axis=0).argmax(axis=1)
final_c = np.asarray(final_c_pred).mean(axis=0).argmax(axis=1)

In [9]:
# Three (row_id, target) pairs per image, in the order
# grapheme_root, vowel_diacritic, consonant_diacritic.
predictions = []
for idx, image_id in enumerate(final_img_ids):
    predictions.extend([
        (f"{image_id}_grapheme_root", final_g[idx]),
        (f"{image_id}_vowel_diacritic", final_v[idx]),
        (f"{image_id}_consonant_diacritic", final_c[idx]),
    ])

In [10]:
# Two-column submission table: one row per (image, target component) pair.
sub = pd.DataFrame(predictions, columns=["row_id", "target"])

In [11]:
# Show the submission frame as this cell's output for a visual sanity check.
sub

Unnamed: 0,row_id,target
0,Test_0_grapheme_root,3
1,Test_0_vowel_diacritic,0
2,Test_0_consonant_diacritic,0
3,Test_1_grapheme_root,93
4,Test_1_vowel_diacritic,2
5,Test_1_consonant_diacritic,0
6,Test_2_grapheme_root,19
7,Test_2_vowel_diacritic,0
8,Test_2_consonant_diacritic,0
9,Test_3_grapheme_root,115


In [12]:
# Write the submission to the working directory without the DataFrame index column.
sub.to_csv("submission.csv", index=False)