## Libraries

In [None]:
import albumentations as A
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata
import torchvision.models as models
import yaml

from pathlib import Path
from typing import Tuple, Dict, Union, Optional, List

from catalyst.utils import get_device
from fastprogress import progress_bar
from sklearn.model_selection import KFold, train_test_split
from torch.nn.parameter import Parameter

## Settings

In [None]:
# Presumably the file name of a model checkpoint.
# NOTE(review): this name is not referenced anywhere else in the visible
# part of the notebook (checkpoints are listed under `bin` in the config) --
# confirm whether it is still needed.
bin_name = "resnet34_best.pth"

## Config

In [None]:
# Inline YAML configuration; it is parsed into the `config` dict in the next
# cell. Sections: dataset flags, input data paths, checkpoint paths (`bin`),
# model constructor arguments (`model`), test batch size, per-phase transform
# flags plus normalisation statistics, and a few global scalars.
# Only a `test` phase is defined under `transforms`.
# NOTE(review): `data.test_df_path` points at train.csv -- confirm intent.
# NOTE(review): `model.pretrained` is True although load_model() loads the
# weights from the `bin` checkpoints right after construction -- presumably
# this triggers an unnecessary pretrained-weight download; confirm.
conf_string = '''
dataset:
  test:
    affine: False
    morphology: False

data:
  test_df_path: ../input/bengaliai-cv19/train.csv
  test_parquet_path:
    - ../input/bengaliai-cv19/test_image_data_0.parquet
    - ../input/bengaliai-cv19/test_image_data_1.parquet
    - ../input/bengaliai-cv19/test_image_data_2.parquet
    - ../input/bengaliai-cv19/test_image_data_3.parquet
  sample_submission_path: ../input/bengaliai-cv19/sample_submission.csv

bin:
  - ../input/bengali-resnet34-init/fold0.pth
  - ../input/bengali-resnet34-init/fold1.pth
  - ../input/bengali-resnet34-init/fold2.pth
  - ../input/bengali-resnet34-init/fold3.pth
  - ../input/bengali-resnet34-init/fold4.pth

model:
  model_name: resnet34
  pretrained: True
  num_classes: 186
  head: custom
  in_channels: 3

test:
  batch_size: 128

transforms:
  test:
    HorizontalFlip: False
    VerticalFlip: False
    Noise: False
    Contrast: False
    Rotate: False
    RandomScale: False
    Cutout:
      num_holes: 0
  mean: [0.485, 0.456, 0.406]
  std: [0.229, 0.224, 0.225]

num_workers: 2
seed: 1213
img_size: 128
'''

In [None]:
# Parse the inline YAML with the safe loader and keep it as a plain dict.
config = dict(yaml.safe_load(conf_string))

## Data and utilities preparation

## Transforms

In [None]:
def get_transforms(config: dict, phase: str = "train"):
    """Build the albumentations pipeline for the requested phase.

    Args:
        config: Parsed configuration. Reads ``config["transforms"][<key>]``
            (``<key>`` being ``"train"``, ``"val"`` or ``"test"``) for the
            augmentation flags, and ``config["transforms"]["mean"]`` /
            ``config["transforms"]["std"]`` for normalisation.
        phase: One of ``"train"``, ``"valid"`` or ``"test"``.

    Returns:
        An ``A.Compose`` containing every enabled augmentation followed by
        ``A.Normalize``.

    Raises:
        AssertionError: If ``phase`` is not one of the known phases.
        KeyError: If the config lacks the phase section or one of its flags.
    """
    assert phase in ["train", "valid", "test"]
    # The "valid" phase is stored under the shorter "val" key in the config.
    section = {"train": "train", "valid": "val", "test": "test"}[phase]
    cfg = config["transforms"][section]

    list_transforms = []
    if cfg["HorizontalFlip"]:
        list_transforms.append(A.HorizontalFlip())
    if cfg["VerticalFlip"]:
        list_transforms.append(A.VerticalFlip())
    if cfg["Rotate"]:
        list_transforms.append(A.Rotate(limit=15))
    if cfg["RandomScale"]:
        list_transforms.append(A.RandomScale())
    if cfg["Noise"]:
        list_transforms.append(
            A.OneOf(
                [A.GaussNoise(), A.IAAAdditiveGaussianNoise()], p=0.5))
    if cfg["Contrast"]:
        list_transforms.append(
            A.OneOf(
                [A.RandomContrast(0.5),
                 A.RandomGamma(),
                 A.RandomBrightness()],
                p=0.5))
    if cfg["Cutout"]["num_holes"] > 0:
        # `config` is a plain dict, so the Cutout arguments must be taken
        # from the phase section by key rather than by attribute access.
        list_transforms.append(A.Cutout(**cfg["Cutout"]))

    # Normalisation is always applied last, regardless of the phase.
    list_transforms.append(
        A.Normalize(
            mean=config["transforms"]["mean"],
            std=config["transforms"]["std"],
            p=1))

    return A.Compose(list_transforms, p=1.0)

## Data Loading

In [None]:
# Table read from the configured CSV path; `df` is rebound to each parquet
# chunk later in the inference cell.
df = pd.read_csv(config["data"]["test_df_path"])

# Only the test-phase pipeline is needed for inference.
transforms_dict = dict(test=get_transforms(config, "test"))

# Number of classes predicted for each of the three targets.
cls_levels = dict(grapheme=168, vowel=11, consonant=7)

## Dataset and DataLoader

In [None]:
class BaseTestDataset(torchdata.Dataset):
    """Dataset yielding resized (and optionally transformed) test images.

    Every column of ``df`` except the first is treated as pixel data and
    reshaped to ``(n_images, 137, 236)``.
    """

    def __init__(self, df: pd.DataFrame, transforms, size: Tuple[int, int]):
        pixel_table = df.iloc[:, 1:].values
        self.images = pixel_table.reshape(-1, 137, 236)
        self.transforms = transforms
        self.size = size

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        img = self.images[idx]
        # Replicate a single-channel image into three identical channels
        # (channels-last layout).
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        img = cv2.resize(img, self.size)
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
        # Resize again because the transform pipeline may change the size.
        img = cv2.resize(img, self.size)
        # Convert channels-last to channels-first for the model.
        if img.shape[2] == 3:
            img = np.transpose(img, (2, 0, 1))
        return img
    
    
def get_base_test_loader(df: pd.DataFrame,
                         size: Tuple[int, int] = (128, 128),
                         batch_size=256,
                         num_workers=2,
                         transforms=None):
    """Wrap ``df`` in a ``BaseTestDataset`` and return an ordered DataLoader.

    The loader never shuffles and never drops the last partial batch, so the
    batches follow the row order of ``df``.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False)
    test_dataset = BaseTestDataset(df, transforms, size)
    return torchdata.DataLoader(test_dataset, **loader_options)

## Model

In [None]:
def gem(x: torch.Tensor, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two dimensions of ``x``.

    Values are clamped to at least ``eps``, raised to the power ``p``,
    averaged over the full spatial window and raised to ``1 / p``.
    """
    window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, window)
    return pooled.pow(1. / p)


def mish(input):
    '''
    Element-wise mish activation:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    The output has the same shape as the input.
    '''
    gate = torch.tanh(F.softplus(input))
    return input * gate


class Mish(nn.Module):
    '''
    Module form of the mish activation, applied element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def __init__(self):
        '''
        Create the (parameter-free) activation module.
        '''
        super(Mish, self).__init__()

    def forward(self, input):
        '''
        Apply mish to every element of ``input``.
        '''
        return input * torch.tanh(F.softplus(input))


class GeM(nn.Module):
    """Generalized-mean pooling layer with a learnable exponent ``p``.

    The pooled result has its two trailing (spatial) dimensions squeezed
    away.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        # One-element parameter so the exponent can be learned.
        self.p = Parameter(p * torch.ones(1))

    def forward(self, x):
        pooled = gem(x, p=self.p, eps=self.eps)
        return pooled.squeeze(-1).squeeze(-1)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})"


class Resnet(nn.Module):
    """torchvision ResNet backbone with either a linear or a three-branch head.

    With ``head="linear"`` the backbone's ``fc`` layer is replaced by a single
    ``nn.Linear`` producing ``num_classes`` outputs. With ``head="custom"``
    the last two child modules of the backbone are removed and three parallel
    heads (168, 11 and 7 outputs) are applied to the feature map; their
    outputs are concatenated along dim 1.
    """

    def __init__(self,
                 model_name: str,
                 num_classes: int,
                 pretrained=False,
                 head="linear",
                 in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.base = getattr(models, model_name)(pretrained=pretrained)
        self.head = head
        assert in_channels in [1, 3]
        assert head in ["linear", "custom"]

        if in_channels == 1:
            rgb_weight = self.base.conv1.weight
            self.base.conv1 = nn.Conv2d(
                1, 64, kernel_size=7, stride=2, padding=3, bias=False)
            if pretrained:
                # Collapse the 3-channel kernels into one channel by
                # averaging over the input-channel axis.
                self.base.conv1.weight = nn.Parameter(
                    data=torch.mean(rgb_weight, dim=1, keepdim=True),
                    requires_grad=True)

        feature_dim = self.base.fc.in_features
        if head == "linear":
            self.base.fc = nn.Linear(feature_dim, self.num_classes)
        elif head == "custom":
            # Keep everything except the last two child modules.
            trunk = list(self.base.children())[:-2]
            self.base = nn.Sequential(*trunk)
            self.grapheme_head = self._build_head(feature_dim, 168)
            self.vowel_head = self._build_head(feature_dim, 11)
            self.consonant_head = self._build_head(feature_dim, 7)
        else:
            raise NotImplementedError

    @staticmethod
    def _build_head(in_features: int, n_outputs: int) -> nn.Sequential:
        """Mish -> 3x3 conv -> batch norm -> GeM pooling -> linear classifier."""
        return nn.Sequential(
            Mish(), nn.Conv2d(in_features, 512, kernel_size=3),
            nn.BatchNorm2d(512), GeM(), nn.Linear(512, n_outputs))

    def forward(self, x):
        if self.head == "linear":
            return self.base(x)
        if self.head != "custom":
            raise NotImplementedError
        features = self.base(x)
        branches = (self.grapheme_head, self.vowel_head, self.consonant_head)
        return torch.cat([branch(features) for branch in branches], dim=1)


def get_model(config: dict):
    """Instantiate the model described by ``config["model"]``.

    Only model names containing ``"resnet"`` are supported; anything else
    raises ``NotImplementedError``. The whole ``config["model"]`` mapping is
    forwarded to ``Resnet`` as keyword arguments.
    """
    model_params = config["model"]
    if "resnet" not in model_params["model_name"]:
        raise NotImplementedError
    return Resnet(**model_params)

## Inference Utilities

In [None]:
def load_model(config: dict, bin_path: Union[str, Path]):
    """Build the configured model and load weights from ``bin_path``.

    The checkpoint may either be a bare state dict or a dict that stores the
    state dict under the ``"model_state_dict"`` key.
    """
    model = get_model(config)
    checkpoint = torch.load(bin_path, map_location=get_device())
    has_wrapper = "model_state_dict" in checkpoint.keys()
    weights = checkpoint["model_state_dict"] if has_wrapper else checkpoint
    model.load_state_dict(weights)
    return model

In [None]:
def inference_loop(model: nn.Module,
                   loader: torchdata.DataLoader,
                   cls_levels: dict,
                   loss_fn: Optional[nn.Module] = None,
                   requires_soft=False):
    """Run ``model`` over ``loader`` and collect per-target predictions.

    The model output is expected to hold the grapheme, vowel and consonant
    logits concatenated along dim 1, in that order.

    Args:
        model: Network to evaluate. It is switched to eval mode and moved to
            the device returned by ``get_device()``.
        loader: Yields either image tensors or dicts with ``"images"`` and
            ``"targets"`` entries.
        cls_levels: Class counts under the keys ``"grapheme"``, ``"vowel"``
            and ``"consonant"``.
        loss_fn: Optional loss; only evaluated for batches that carry targets.
        requires_soft: When true, also return softmax probabilities.

    Returns:
        Dict with ``"prediction"`` (``(n_samples, 3)`` uint8 array of argmax
        class indices), ``"loss"`` (sum of per-batch losses divided by the
        number of batches) and, if requested, ``"soft_prediction"``
        (``(n_samples, total_classes)`` float32 array of per-head softmax
        probabilities).
    """
    n_grapheme = cls_levels["grapheme"]
    n_vowel = cls_levels["vowel"]
    n_consonant = cls_levels["consonant"]

    # Column range (head, tail) of every target inside the concatenated logits.
    spans = []
    offset = 0
    for width in (n_grapheme, n_vowel, n_consonant):
        spans.append((offset, offset + width))
        offset += width
    n_total = offset

    dataset_length = len(loader.dataset)
    # uint8 is sufficient as long as every head has at most 256 classes.
    prediction = np.zeros((dataset_length, 3), dtype=np.uint8)
    if requires_soft:
        # Probabilities need a floating dtype and one column per class of
        # every head (a 2-D array, not a 3-D one).
        soft_prediction = np.zeros(
            (dataset_length, n_total), dtype=np.float32)

    device = get_device()
    # Inputs are sent to `device`, so the weights have to live there too.
    model.to(device)

    avg_loss = 0.
    model.eval()

    targets: Optional[torch.Tensor] = None

    # Track rows by the real batch length so the last (smaller) batch and
    # loaders without a fixed `batch_size` are handled uniformly.
    row_start = 0
    for batch in progress_bar(loader, leave=False):
        with torch.no_grad():
            if isinstance(batch, dict):
                images = batch["images"].to(device)
                targets = batch["targets"].to(device)
            else:
                images = batch.to(device)
                targets = None
            pred = model(images).detach()
            if loss_fn is not None and targets is not None:
                avg_loss += loss_fn(pred, targets).item() / len(loader)

            row_stop = row_start + pred.size(0)
            for column, (head, tail) in enumerate(spans):
                logits = pred[:, head:tail]
                prediction[row_start:row_stop, column] = torch.argmax(
                    logits, dim=1).cpu().numpy()
                if requires_soft:
                    soft_prediction[row_start:row_stop,
                                    head:tail] = F.softmax(
                                        logits, dim=1).cpu().numpy()
            row_start = row_stop

    return_dict = {"prediction": prediction, "loss": avg_loss}
    if requires_soft:
        return_dict["soft_prediction"] = soft_prediction

    return return_dict

## Inference

In [None]:
# Order in which the targets are written for every image in the submission.
components = ['consonant_diacritic', 'grapheme_root', 'vowel_diacritic']
target_grapheme = []
target_vowel = []
target_consonant = []
row_id = []  # image ids collected across all parquet files

n_grapheme = cls_levels["grapheme"]
n_vowel = cls_levels["vowel"]
n_consonant = cls_levels["consonant"]
# Width of the concatenated probability vector (one column per class).
n_total = n_grapheme + n_vowel + n_consonant

parquets = config["data"]["test_parquet_path"]
binaries = config["bin"]
for path in parquets:
    df = pd.read_parquet(path)
    row_id.extend(df["image_id"].values)

    loader = get_base_test_loader(
        df,
        size=(config["img_size"], config["img_size"]),
        batch_size=config["test"]["batch_size"],
        num_workers=config["num_workers"],
        transforms=transforms_dict["test"])

    # Average the soft predictions of all fold checkpoints.
    result_array = np.zeros((len(df), n_total))
    for binary in binaries:
        model = load_model(config, binary)
        prediction = inference_loop(
            model,
            loader,
            cls_levels,
            loss_fn=None,
            requires_soft=True)
        result_array += prediction["soft_prediction"] / len(binaries)

    # Each target occupies a consecutive column range, so every `tail` must
    # be computed relative to the preceding `head`.
    head = 0
    tail = head + n_grapheme
    grapheme_preds = np.argmax(result_array[:, head:tail], axis=1)

    head = tail
    tail = head + n_vowel
    vowel_preds = np.argmax(result_array[:, head:tail], axis=1)

    head = tail
    tail = head + n_consonant
    consonant_preds = np.argmax(result_array[:, head:tail], axis=1)

    target_grapheme.extend(grapheme_preds)
    target_vowel.extend(vowel_preds)
    target_consonant.extend(consonant_preds)

prediction_df = pd.DataFrame({
    "image_id": row_id,
    "grapheme_root": target_grapheme,
    "vowel_diacritic": target_vowel,
    "consonant_diacritic": target_consonant
})

## Make Submission

In [None]:
# Flatten the per-image predictions into the long submission format:
# one "<image_id>_<target_name>" row per image and target.
name = []
result = []

# Iterating a DataFrame directly yields column labels; iterrows() yields the
# (index, row) pairs that are needed here.
for _, row in prediction_df.iterrows():
    for target_name in components:
        name.append(row["image_id"] + "_" + target_name)
        result.append(row[target_name])

submission = pd.DataFrame({
    "row_id": name,
    "target": result
})
submission.to_csv("submission.csv", index=False)
submission.head()