In [None]:
!pip install ../input/timm-packages/timm-0.3.4-py3-none-any.whl

In [None]:
import os
import random

import cv2
import numpy as np
import pandas as pd
import timm
import torch
from PIL import Image
from scipy.stats import mode
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import CenterCrop, ColorJitter, Compose, Normalize, Resize, ToTensor
from tqdm import tqdm

In [None]:
# True: score the model against the labeled training set (train.csv) and report accuracy.
# False: predict the unlabeled test images for submission.
exists_label = False

DATA_ROOT = '/kaggle/input/cassava-leaf-disease-classification'
# train.csv lists files that live in train_images/, so the image directory must follow the
# mode; otherwise validation runs would look up training ids inside test_images/.
# The trailing '' keeps a trailing path separator, matching the original value.
dataset_dir = os.path.join(DATA_ROOT, 'train_images' if exists_label else 'test_images', '')
annotations_path = os.path.join(DATA_ROOT, 'train.csv')

model_path = '../input/efficientnetb1-8-chk0/checkpoint-0.pth.tar'
model_name = 'efficientnet_b1'

In [None]:
# Inference hyper-parameters read by the cells below.
CFG = dict(
    num_classes=5,      # number of target classes
    img_size=512,       # side length of the center crop fed to the model
    inference_bs=64,    # DataLoader batch size
    num_workers=2,      # DataLoader worker processes
    device='cuda:0',    # torch device string
    tta_steps=1,        # passes over the data whose hard predictions are majority-voted
)

In [None]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), and switches cuDNN to deterministic algorithm selection.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Setting this at runtime does not change hashing in the already-running interpreter;
    # it only reaches processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # lazily applied; safe when CUDA is unavailable
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick kernels by timing them, which can differ between
    # runs and undoes the deterministic setting above.
    torch.backends.cudnn.benchmark = False

In [None]:
class ImageInferenceDataset(Dataset):
    """Map-style dataset that loads the images listed in a DataFrame for inference.

    Args:
        root_dir: Directory containing the image files.
        images_df: DataFrame with an ``image_id`` column (file names joined onto
            ``root_dir``) and, when ``exists_label`` is True, a ``label`` column.
        transform: Callable applied to the loaded RGB ``PIL.Image``; ``None`` skips it.
        exists_label: Whether ``images_df`` carries ground-truth labels.

    Each item is ``(image, label)``, where ``label`` is a scalar ``torch.long`` tensor.
    It is ``-1`` when no labels exist, so batches always collate the same way.
    """

    def __init__(self, root_dir, images_df, transform, exists_label=False):
        self.images_df = images_df
        self.transform = transform
        self.root_dir = root_dir
        self.exists_label = exists_label

    def __len__(self):
        return len(self.images_df)

    def __getitem__(self, index):
        row = self.images_df.iloc[index]
        img_path = os.path.join(self.root_dir, row['image_id'])
        # The context manager closes the file handle per item; the old bare open() leaked
        # one descriptor per image. convert() decodes fully before the file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = int(row['label']) if self.exists_label else -1
        return img, torch.tensor(label, dtype=torch.long)

In [None]:
def get_inference_transforms(color_jitter=None):
    """Build the preprocessing pipeline applied to every inference image.

    Steps:
    - Resize with bicubic interpolation to ``CFG['img_size'] + 50`` (torchvision resizes
      the shorter edge for an int size).
    - Center-crop to ``CFG['img_size']``.
    - Optionally apply a random color jitter.
    - Convert to a tensor and normalize with the standard ImageNet mean/std.

    Args:
        color_jitter: Whether to include a random ``ColorJitter`` (brightness, contrast,
            saturation 0.2). Random jitter is only useful as test-time augmentation
            whose passes are majority-voted; on a single pass it just adds noise to the
            predictions. ``None`` (default) enables it only when ``CFG['tta_steps'] > 1``.

    Returns:
        A torchvision ``Compose`` transform.
    """
    if color_jitter is None:
        color_jitter = CFG['tta_steps'] > 1

    steps = [
        Resize(CFG['img_size'] + 50, interpolation=Image.BICUBIC),
        CenterCrop(CFG['img_size']),
    ]
    if color_jitter:
        steps.append(ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2))
    steps += [
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return Compose(steps)

In [None]:
# Images to score: train.csv (with ground truth) in validation mode, otherwise every
# file found in the test directory.
if exists_label:
    images_df = pd.read_csv(annotations_path)
else:
    # Sorted so the row order (and the submission) does not depend on filesystem listing order.
    images_df = pd.DataFrame({'image_id': sorted(os.listdir(dataset_dir))})

assert CFG['tta_steps'] >= 1, "CFG['tta_steps'] must be at least 1"

dataset = ImageInferenceDataset(dataset_dir, images_df, get_inference_transforms(), exists_label=exists_label)

data_loader = DataLoader(
    dataset,
    batch_size=CFG['inference_bs'],
    num_workers=CFG['num_workers'],
    shuffle=False,  # predictions are written back to images_df by position
    pin_memory=False,
)

device = torch.device(CFG['device'])
model = timm.create_model(model_name, num_classes=CFG['num_classes'], checkpoint_path=model_path).to(device)
model.eval()

# One array of hard class ids per TTA pass, each aligned with the rows of images_df.
all_predictions = []
with torch.no_grad():
    for _ in range(CFG['tta_steps']):
        step_predictions = [
            torch.argmax(model(img_data.to(device).float()), dim=1).cpu().numpy()
            for img_data, _labels in tqdm(data_loader, leave=False)
        ]
        all_predictions.append(
            np.concatenate(step_predictions) if step_predictions else np.empty(0, dtype=np.int64)
        )

# Majority vote across TTA passes; ties go to the smallest class id, as with scipy's mode.
# bincount is used instead of scipy.stats.mode because mode's output shape changed in
# SciPy 1.11 (keepdims default), turning the old `mode(...)[0][0]` into a single scalar
# that would have been broadcast as the label of every image.
votes = np.stack(all_predictions)  # (tta_steps, n_images)
final_predictions = np.array(
    [np.bincount(column, minlength=CFG['num_classes']).argmax() for column in votes.T],
    dtype=np.int64,
)

# Accuracy is only meaningful with ground truth, and is measured on the voted predictions.
if exists_label:
    expected = images_df['label'].to_numpy()
    print(f'{len(expected)} labeled images')
    if len(expected):
        print(f'accuracy (voted predictions): {np.mean(expected == final_predictions):.4f}')
else:
    print(f'{len(final_predictions)} test images predicted; no labels, accuracy skipped')

images_df['label'] = final_predictions

# Free GPU memory held by the model now that predictions are stored.
del model
torch.cuda.empty_cache()

In [None]:
# The competition submission format is image_id,label; select those columns explicitly so
# any extra columns that end up in images_df never leak into the submission file.
images_df[['image_id', 'label']].to_csv('submission.csv', index=False)