In [1]:
from torch import nn
import torch
from torchvision import models
from typing import Union
import cv2
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from inference import InferenceDs, InferenceModel, Predictor
import numpy as np

# assemble network
WEIGHTS_PATH = "../input/leafdiseaseclassificationmodelweights/weights_fold0.pt"

# Backbone is built without pretrained weights; every parameter is restored
# from WEIGHTS_PATH when the state dict is loaded at the end of this cell.
classifier = models.resnext50_32x4d()

# dims for the base model
# The head's input width is taken from the backbone's final fc *output* size.
# NOTE(review): presumably InferenceModel feeds classifier(x) into the head --
# confirm against inference.py.
num_ftrs = classifier.fc.out_features
h1 = 512
h2 = h1 // 2

# MLP head: three BatchNorm -> ReLU -> Dropout -> Linear stages, ending in
# 5 output logits.
base_model = nn.Sequential(
    nn.BatchNorm1d(num_ftrs),
    nn.ReLU(inplace=True),
    nn.Dropout(0.25),
    nn.Linear(num_ftrs, h1),
    nn.BatchNorm1d(h1),
    nn.ReLU(inplace=True),
    nn.Dropout(0.25),
    nn.Linear(h1, h2),
    nn.BatchNorm1d(h2),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(h2, 5)
)

# init model
net = InferenceModel(classifier=classifier, base=base_model)

# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on
# a CPU-only machine too; load_state_dict then copies the values into net's
# own parameters. Kept as the last expression so the key-matching report is
# displayed as the cell output.
net.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))

<All keys matched successfully>

In [2]:
# Test images: resolve every image_id in the sample submission to a file path.
test_dir = "../input/cassava-leaf-disease-classification/test_images"
test_df = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")
test_df["filePath"] = [os.path.join(test_dir, image_id) for image_id in test_df["image_id"]]

# Per-class weights derived from the training labels:
#   weight(c) = 1 - (share of training rows whose label is c)
# so the more frequent a class is in train.csv, the smaller its weight.
train_df = pd.read_csv("../input/cassava-leaf-disease-classification/train.csv")
n_train = len(train_df)
weights = [
    1 - (len(train_df[train_df["label"] == cls]) / n_train)
    for cls in range(5)
]

In [3]:
# Augmentation pipeline for the test images. Everything except Normalize and
# ToTensorV2 is randomised, so two passes over the same image generally see
# different views of it.
tta_transforms = [
    A.RandomResizedCrop(224, 224),
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HueSaturationValue(p=0.5),
    A.RandomBrightnessContrast(
        brightness_limit=(-0.1, 0.1),
        contrast_limit=(-0.1, 0.1),
        p=0.5,
    ),
    A.Normalize(max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
]
test_augs = A.Compose(tta_transforms, p=1.0)

# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = "cpu"

# One image per batch, in file order.
# NOTE(review): presumably InferenceDs applies test_augs on every item
# access -- confirm in inference.py.
test_ds = InferenceDs(test_df, test_augs)
test_dl = torch.utils.data.DataLoader(test_ds, shuffle=False, batch_size=1)

In [4]:
# Test-time augmentation: run several prediction passes over the test loader,
# scale each pass's scores by the per-class weights, average the passes, and
# take the best-scoring class for every image.
N_TTA_ROUNDS = 5

ids = None
tta_scores = []

predictor = Predictor(net, device)

with torch.no_grad():
    for _ in tqdm(range(N_TTA_ROUNDS)):
        # NOTE(review): assumes predict() returns a NumPy array of per-class
        # scores whose last axis has len(weights) entries, plus the matching
        # image ids -- confirm in inference.py.
        preds, ids = predictor.predict(test_dl)
        tta_scores.append(preds * weights)

# Average over the TTA rounds (axis 0 indexes the rounds).
mean_scores = np.asarray(np.mean(tta_scores, axis=0))

# One label per image. Without an axis, np.argmax flattens the whole score
# array and returns a single index, which is only correct when there is
# exactly one test image. Reshaping to (n_images, n_classes) first keeps this
# correct for any number of images.
test_preds = mean_scores.reshape(-1, len(weights)).argmax(axis=1)

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [5]:
# Build the submission table (image id + predicted label, in that column
# order) and write it without the index column.
submit = pd.DataFrame().assign(image_id=ids, label=test_preds)
submit.to_csv("submission.csv", index=False)