In [None]:
import sys
sys.path.append('../')
import warnings
from torch.utils.data import DataLoader, Dataset
from src.pl_module import MelanomaModel
import pandas as pd
import torch
import torch.nn as nn
from typing import Tuple
import albumentations as A
from tqdm.auto import tqdm
import skimage.io
import numpy as np
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")

In [None]:
def load_model(model_name: str, model_type: str, weights: str):
    """Build a network and restore its parameters from a checkpoint file.

    Args:
        model_name: Architecture identifier forwarded to
            ``MelanomaModel.net_mapping``.
        model_type: Model-type identifier forwarded to
            ``MelanomaModel.net_mapping``.
        weights: Path to a checkpoint whose deserialized content is passed
            directly to ``load_state_dict``.

    Returns:
        The model, switched to eval mode and moved to the current CUDA device.
    """
    print('Loading {}'.format(model_name))
    model = MelanomaModel.net_mapping(model_name, model_type)
    # Deserialize onto the CPU first: without map_location, tensors are
    # restored to the device they were saved from, which fails when that
    # device (e.g. a second GPU) does not exist on this machine.
    state_dict = torch.load(weights, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    model.cuda()
    print("Loaded model {} from checkpoint {}".format(model_name, weights))
    return model

class MelanomaDataset(Dataset):
    """Dataset yielding one image per row of a dataframe.

    Each row's ``image_name`` value is used to build the path
    ``{image_folder}/{image_name}.jpg``, which is read with
    ``skimage.io.imread``.

    Args:
        image_folder: Directory containing the ``.jpg`` files.
        df: Dataframe with an ``image_name`` column.
        transform: Optional callable invoked as ``transform(image=image)``
            and expected to return a mapping with an ``'image'`` entry.
    """

    def __init__(self, image_folder, df, transform=None):
        super().__init__()
        self.image_folder = image_folder
        self.df = df
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of rows in the dataframe."""
        return self.df.shape[0]

    def __getitem__(self, index) -> dict:
        """Load and preprocess the image at positional row ``index``.

        Returns:
            A dict with ``'features'`` (the image as a ``torch.Tensor``)
            and ``'img_id'`` (the row's ``image_name`` value).
        """
        row = self.df.iloc[index]
        img_id = row.image_name
        img_path = f"{self.image_folder}/{img_id}.jpg"
        image = skimage.io.imread(img_path)
        if self.transform is not None:
            image = self.transform(image=image)['image']
        # Move the last axis to the front; assumes a 3-D (H, W, C) array --
        # TODO confirm every input image is multi-channel.
        image = image.transpose(2, 0, 1)
        image = torch.from_numpy(image)
        return {'features': image, 'img_id': img_id}


def get_valid_transforms():
    """Return the inference-time pipeline: a single ``A.Normalize`` step.

    The step is wrapped in ``A.Compose`` with ``p=1.0``.
    """
    steps = [A.Normalize()]
    return A.Compose(steps, p=1.0)

In [None]:
# Read the test-set table and preview its first rows.
test_csv_path = '../data/test.csv'
data = pd.read_csv(test_csv_path)
data.head()

In [None]:
# Ensemble configuration: three parallel lists matched by position.
# To disable a model, comment out the SAME position in both
# ``model_name_list`` and ``weights_list``.
model_name_list = [
    'resnest50d',
    'resnest269e',
    'resnest101e',
    #'seresnext101_32x4d',
    'tf_efficientnet_b3_ns',
    'tf_efficientnet_b7_ns',
    'tf_efficientnet_b5_ns',
]
model_type_list = ['SingleHeadMax'] * len(model_name_list)
weights_list = [
    '../weights/train_384_balancedW_resnest50d_fold0_heavyaugs_averaged_best_weights.pth',
    '../weights/07.09_train_384_balancedW_resnest269e_heavyaugs_averaged_best_weights.pth',
    '../weights/03.09_train_384_balancedW_resnest101e_fold0_heavyaugs_averaged_best_weights.pth',
    #'../weights/06.18_train_384_balancedW_seresnext101_32x4d_fold0_heavyaugs_averaged_best_weights.pth',
    '../weights/06.10_train_384_balancedW_b3_fold0_heavyaugs_averaged_best_weights.pth',
    '../weights/05.23_train_384_balancedW_b7_fold0_heavyaugs_averaged_best_weights.pth',
    '../weights/03.18_train_384_balancedW_b5_fold0_heavyaugs_averaged_best_weights.pth',
]

# zip() stops at the shortest input, so a half-commented entry would silently
# drop a model or pair a name with the wrong checkpoint. Fail loudly instead.
if not (len(model_name_list) == len(model_type_list) == len(weights_list)):
    raise ValueError(
        'Ensemble lists differ in length: {} names, {} types, {} weights'.format(
            len(model_name_list), len(model_type_list), len(weights_list)
        )
    )

models = [
    load_model(model_name, model_type, weights)
    for model_name, model_type, weights in zip(
        model_name_list, model_type_list, weights_list
    )
]

In [None]:
# Test images with normalization only; order is kept (shuffle=False) so the
# predictions line up with the rows of ``data``.
valid_transforms = get_valid_transforms()
dataset = MelanomaDataset(
    '../data/jpeg-melanoma-384x384/test/',
    data,
    valid_transforms,
)
dataloader = DataLoader(
    dataset,
    num_workers=4,
    shuffle=False,
    batch_size=16,
)

In [None]:
# Ensemble inference: for every batch, run each model, apply a sigmoid to its
# output, and average index 0 of the last output axis across models.
mean_cls_1_list = []
with torch.no_grad():
    for batch in tqdm(dataloader, total=len(dataloader)):
        # Transfer the batch to the GPU once instead of once per model.
        features = batch['features'].cuda()
        # Stacking adds a leading axis with one entry per model.
        preds = torch.stack([torch.sigmoid(model(features)) for model in models])
        # Take index 0 of the last axis, then average over the model axis.
        mean_cls_1 = preds[..., 0].cpu().numpy().mean(axis=0)
        mean_cls_1_list.extend(mean_cls_1)

In [None]:
# Attach the averaged ensemble scores as the ``target`` column and write the
# labeled table to disk without the index.
labeled_csv_path = '../data/labeled_test.csv'
data['target'] = mean_cls_1_list
data.to_csv(labeled_csv_path, index=False)