In [None]:
# Environment switch: True selects the Kaggle package install, Kaggle input
# paths and a fixed worker count; False selects the local equivalents.
KAGGLE = False

In [None]:
%%time
# Install efficientnet_pytorch. The Kaggle branch installs from a local
# directory with --no-index (no package index is contacted); the other
# branch does a regular pip install into the kernel environment.
if KAGGLE:
    !pip install ../input/efficientnet-pytorch/EfficientNet-PyTorch-1.0 -f ./ --no-index
else:
    %pip install efficientnet_pytorch

In [None]:
#!g1.1
import os
import gc
import sys
import json
import time
import cv2
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms
from torch.utils.data.sampler import SequentialSampler
from efficientnet_pytorch import model as enet
# Device that the models and every input batch are moved to.
DEVICE = torch.device('cuda')

In [None]:
#!g1.1
# --- run configuration -------------------------------------------------
TEST = True   # True -> read images from test_images, False -> from train_images
VER = 'v0'    # version tag embedded in the models directory name

# Dataset and checkpoint locations for the Kaggle layout vs. the local one.
DATA_PATH = '../input/plant-pathology-2021-fgvc8' if KAGGLE else '/home/jupyter/mnt/datasets/PLANT_DATASET'
MDLS_PATH = f'../input/plant-models-{VER}' if KAGGLE else f'./models_{VER}'

# Manual per-class thresholds, e.g. [.3, .35, .35, .35, .35, .35];
# a falsy value means the thresholds are read from ths.json instead.
TH = None
# Flip modes handed to PlantDataset(tta=...): 0 none, 1 rows reversed,
# 2 columns reversed, 3 both.
TTAS = [0]
# Fold indices whose model_best_<fold>.pth checkpoints are loaded.
FOLDS = [0]

if TEST:
    IMGS_PATH = f'{DATA_PATH}/test_images'
else:
    IMGS_PATH = f'{DATA_PATH}/train_images'

# Reference point for the "time elapsed" report printed after inference.
start_time = time.time()

In [None]:
#!g1.1
# Training-time parameters saved next to the checkpoints.
with open(f'{MDLS_PATH}/params.json') as file:
    params = json.load(file)
LABELS_ = params['labels_']
LABELS = params['labels']
WORKERS = 2 if KAGGLE else params['workers']
print('loaded params:', params)

# Per-class decision thresholds keyed by the class index as a string
# ('0', '1', ...). A manual TH list takes precedence over ths.json and may
# have any length; one threshold is created per entry.
if TH:
    ths = {str(i): th for i, th in enumerate(TH)}
else:
    with open(f'{MDLS_PATH}/ths.json') as file:
        ths = json.load(file)
print('thresholds:', ths)

# os.listdir returns names in arbitrary order; sorting makes the row order
# (and the 100-image subset used when TEST is False) reproducible.
img_names = sorted(os.listdir(IMGS_PATH))
if not TEST:
    img_names = img_names[:100]
# Building the frame from a named column also works for an empty directory,
# where assigning `.columns` on a column-less frame would raise.
df_sub = pd.DataFrame({'image': img_names})
df_sub['labels'] = 'healthy'  # placeholder, overwritten by the predictions
display(df_sub.head())

In [None]:
#!g1.1
def flip(img, axis=0):
    """Return a flipped view of ``img`` for test-time augmentation.

    ``axis`` selects the flip: 1 reverses the first dimension, 2 reverses the
    second dimension, 3 reverses both. Any other value returns ``img``
    unchanged (the same object, not a copy).
    """
    if axis not in (1, 2, 3):
        return img
    keep, reverse = slice(None), slice(None, None, -1)
    first = reverse if axis in (1, 3) else keep
    second = reverse if axis in (2, 3) else keep
    return img[first, second]

class PlantDataset(data.Dataset):
    """Dataset that reads images listed in a DataFrame from ``IMGS_PATH``.

    Args:
        df: frame with an ``image`` column (file names relative to
            ``IMGS_PATH``) and, when ``labels`` is given, a ``labels`` column
            of space-separated label names.
        size: images are resized to ``size`` x ``size``.
        labels: mapping from label name to class index; when falsy the
            dataset yields images only (inference mode).
        transform: optional callable invoked as ``transform(image=img)`` whose
            result is indexed with ``['image']``.
        tta: flip mode forwarded to ``flip`` (inference mode only).
    """

    def __init__(self, df, size, labels, transform=None, tta=0):
        self.df = df.reset_index(drop=True)
        self.size = size
        self.labels = labels
        self.transform = transform
        self.tta = tta

    def __len__(self):
        """Number of rows (images) in the frame."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, target)`` when ``labels`` is set, where ``target`` is a
            float32 multi-hot vector; otherwise just the (optionally flipped)
            image. Images are float32 in [0, 1], channels first.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        row = self.df.iloc[index]
        img_path = f'{IMGS_PATH}/{row.image}'
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None. Checking pixel values
        # instead would flag valid all-black images and still crash later in
        # cvtColor with an unhelpful error for a genuinely missing file.
        if img is None:
            raise FileNotFoundError(f'no img file read: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.size, self.size))
        img = img.astype(np.float32) / 255
        if self.transform is not None:
            img = self.transform(image=img)['image']
        if self.labels:
            img = img.transpose(2, 0, 1)  # HWC -> CHW
            label = np.zeros(len(self.labels)).astype(np.float32)
            for lbl in row.labels.split():
                label[self.labels[lbl]] = 1
            return torch.tensor(img), torch.tensor(label)
        else:
            img = flip(img, axis=self.tta)
            img = img.transpose(2, 0, 1)  # HWC -> CHW
            # flip() yields a negatively-strided view; copy before handing it
            # to torch.
            return torch.tensor(img.copy())

class EffNet(nn.Module):
    """EfficientNet feature extractor with a two-layer linear head.

    ``params['backbone']`` names the EfficientNet variant and
    ``params['dropout']`` is the dropout probability used before each linear
    layer; ``out_dim`` is the number of outputs.
    """

    def __init__(self, params, out_dim):
        super().__init__()
        backbone = enet.EfficientNet.from_name(params['backbone'])
        feat_dim = backbone._fc.in_features
        # Replace the stock classifier so the backbone emits raw features.
        backbone._fc = nn.Identity()
        self.enet = backbone

        hidden_dim = int(feat_dim / 4)
        drop_p = params['dropout']
        self.myfc = nn.Sequential(
            nn.Dropout(drop_p),
            nn.Linear(feat_dim, hidden_dim),
            nn.Dropout(drop_p),
            nn.Linear(hidden_dim, out_dim),
        )

    def extract(self, x):
        """Run only the backbone and return its features."""
        return self.enet(x)

    def forward(self, x):
        """Backbone features followed by the classification head."""
        return self.myfc(self.extract(x))
    
class ResNext(nn.Module):
    """ResNeXt-50 (32x4d) whose ``fc`` layer is replaced by a small MLP head.

    The network is created without pretrained weights and wrapped in
    ``nn.DataParallel``; ``params['dropout']`` is the head's dropout
    probability and ``out_dim`` the number of outputs.
    """

    def __init__(self, params, out_dim):
        super().__init__()
        net = torchvision.models.resnext50_32x4d(pretrained=False)
        feat_dim = net.fc.in_features
        hidden_dim = int(feat_dim / 4)
        net.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(params['dropout']),
            nn.Linear(hidden_dim, out_dim),
        )
        self.rsnxt = nn.DataParallel(net)

    def forward(self, x):
        """Forward ``x`` through the wrapped network."""
        return self.rsnxt(x)

In [None]:
#!g1.1
def _load_fold_model(n_fold):
    """Build the network selected by ``params['backbone']`` and load fold ``n_fold``.

    The checkpoint ``{MDLS_PATH}/model_best_{n_fold}.pth`` is read onto the CPU
    first, then the model is switched to float32 / eval mode and moved to
    ``DEVICE``. Returns the ready-to-use model.
    """
    if params['backbone'] == 'resnext':
        model = ResNext(params=params, out_dim=len(LABELS_))
    else:
        model = EffNet(params=params, out_dim=len(LABELS_))
    path = '{}/model_best_{}.pth'.format(MDLS_PATH, n_fold)
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.float()
    model.eval()
    # .to() follows whatever DEVICE is configured instead of assuming CUDA.
    model.to(DEVICE)
    print('loaded:', path)
    # state_dict is local, so it is released on return; no module-level `del`
    # is needed (the old one raised NameError when FOLDS was empty).
    return model


# NOTE: this rebinds the name `models`, which the import cell bound to
# torchvision.models; the name is kept because later cells iterate over it.
models = [_load_fold_model(n_fold) for n_fold in FOLDS]
gc.collect();

In [None]:
#!g1.1
# One inference dataset/loader pair per test-time-augmentation mode.
datasets, loaders = [], []
for tta in TTAS:
    # labels=None puts the dataset in image-only mode.
    dataset = PlantDataset(df=df_sub, size=params['img_size'], labels=None, transform=None, tta=tta)
    # Sequential sampling keeps predictions aligned with the rows of df_sub.
    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=params['batch_size'],
        num_workers=WORKERS,
    )
    datasets += [dataset]
    loaders += [loader]

def get_labels(row, labels, ths):
    """Turn one row of per-class scores into a label string.

    Args:
        row: iterable of scores, one per class index.
        labels: mapping from class index as a string ('0', '1', ...) to the
            label name.
        ths: mapping from class index as a string to its threshold.

    Returns:
        'healthy' when no score exceeds its threshold or when 'healthy' is
        among the selected names; otherwise the selected names joined by a
        single space, in class-index order.

    Raises:
        KeyError: if a class index has no threshold, or a selected class has
            no label name. Previously every error was swallowed and the raw
            score row was returned, which ended up in the output column.
    """
    try:
        picked = [labels[str(i)] for i, score in enumerate(row) if score > ths[str(i)]]
    except KeyError as err:
        raise KeyError(f'class index {err} is missing from the thresholds or label names') from err
    if not picked or 'healthy' in picked:
        return 'healthy'
    return ' '.join(picked)

# One (n_images, n_classes) probability matrix per (model, TTA loader) pair.
probs_per_run = []
with torch.no_grad():
    for i, model in enumerate(models):
        for j, loader in enumerate(loaders):
            batch_probs = []
            for img_data in loader:
                img_data = img_data.to(DEVICE)
                # Keep the batch axis: squeezing would collapse a batch of one
                # image into a 1-D vector.
                batch_probs.append(model(img_data).sigmoid().cpu().numpy())
            print('model {} | loader {} -> done'.format(i, j))
            # Concatenate the batches before averaging. A nested list of
            # per-batch arrays is ragged whenever the last batch is smaller,
            # and np.mean cannot build an array from that.
            probs_per_run.append(np.concatenate(batch_probs, axis=0))
# Average over models and TTA runs. Despite the name, `logits` holds sigmoid
# probabilities; it stays 2-D even for a single image so that each row goes
# through get_labels as a vector.
logits = np.mean(probs_per_run, axis=0)
df_sub['labels'] = [get_labels(x, LABELS, ths) for x in logits]

elapsed_time = time.time() - start_time
print(f'time elapsed: {elapsed_time // 60:.0f} min {elapsed_time % 60:.0f} sec')

In [None]:
# Distribution of the predicted label strings, then a preview of the frame.
print('value counts:')
print(df_sub['labels'].value_counts())
df_sub.head()

In [None]:
# Write the image/labels table to submission.csv without the index column.
df_sub.to_csv('submission.csv', index=False)