In [1]:
#This is the testing code for DenseNet on Kaggle platform
#Group 4, Xiaoyu Wan, Yonghao Duan, Yu Sun
#Code reference1: https://www.kaggle.com/ratthachat/aptos-eye-preprocessing-in-diabetic-retinopathy
#Code reference2: https://www.kaggle.com/leighplt/densenet121-pytorch
#Code reference3: https://www.kaggle.com/abhishek/pytorch-inference-kernel-lazy-tta
#Performance: Private Score: 0.869439; Public Score: 0.684283

import os
import time
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision
from torchvision import transforms
from sklearn.metrics import cohen_kappa_score, accuracy_score
import warnings
import datetime
import matplotlib.pyplot as plt # Plotting
from sklearn.metrics import confusion_matrix
warnings.filterwarnings('ignore')
print("Done module loading")

Done module loading


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(device.type == 'cuda')

testdir = '../input/aptos2019-blindness-detection/test_images/'
testcsvfile='../input/aptos2019-blindness-detection/test.csv'
submitcsv = '../input/aptos2019-blindness-detection/sample_submission.csv'
valid_size = 0.2
batch_size = 16

cpu
False


In [3]:
class RetinopathyDatasetTest(Dataset):
    def __init__(self, csv_file, transform):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        img_name = os.path.join(testdir, self.data.loc[idx, 'id_code'] + '.png')
        image = Image.open(img_name)
        image = image.resize((256, 256), resample=Image.BILINEAR)
        #label = torch.tensor(self.data.loc[idx, 'diagnosis'])
        return transforms.ToTensor()(image)

def round_off_preds(preds, coef=[0.5, 1.5, 2.5, 3.5]):
    for i, pred in enumerate(preds):
        if pred < coef[0]:
            preds[i] = 0
        elif pred >= coef[0] and pred < coef[1]:
            preds[i] = 1
        elif pred >= coef[1] and pred < coef[2]:
            preds[i] = 2
        elif pred >= coef[2] and pred < coef[3]:
            preds[i] = 3
        else:
            preds[i] = 4
    return preds

In [4]:
model = torchvision.models.densenet121(pretrained=False)
print("Model downloaded")
model.load_state_dict(torch.load("../input/densenetmodel/DenseNet.2019DataTrain.best_model.pt"))
model = model.to(device)
print("Trained Model loaded")
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_dataset = RetinopathyDatasetTest(csv_file=testcsvfile, transform=train_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

Model downloaded
Trained Model loaded


In [5]:
def prediction(best_model, loader):
    preds = np.empty(0)
    for x in loader:
        x = x.to(device)
        output = best_model(x)
        y_pred = output[:, -1].detach().cpu().numpy()
        print(y_pred)
        p = round_off_preds(y_pred)
        preds = np.append(preds, p, axis=0)         #append from batches
    return preds

preds_test = prediction(model, test_loader)
print(preds_test)

[ 1.5679291   2.1672525   2.440809    1.585267    1.8980842   1.5941353
  1.7012388  -0.05385725  2.987105   -0.1563124   2.6148765   2.4223943
  0.7691243   1.7343746   1.4726311  -0.02783853]
[ 1.4253204   2.5060515   1.5359476   0.18685153  2.0553117   2.370138
  1.7646054   3.0120645   2.422573   -0.08158717  1.5133433   0.19960853
  2.5512145   2.0282972  -0.05777397  2.3413625 ]
[ 1.5104376  -0.03500561  1.3320215   1.9110365   2.4197755   1.7345097
 -0.2015163   2.6359766   2.2232237   1.421387    1.5385951   2.5113251
  0.39679253  2.3161101   2.5108314   1.8079921 ]
[ 1.6889223   2.0519338   2.4058876   1.7207547   0.2827479  -0.12635294
  2.3738134   0.2733149   1.1616635   0.02401722  0.8414283   2.5274963
  1.2590729   1.4950396   2.6892745   2.014206  ]
[ 1.2941922   2.5344398   2.8042274   2.8959556   1.6242383   2.045545
  2.0517824   1.3342317   1.8787951  -0.13667578  1.9727867   1.2780293
  1.6448604   2.3796086   0.28641397 -0.06233956]
[0.21484964 2.3927252  0.43403

In [6]:
sample_sub = pd.read_csv(submitcsv)
sample_sub.diagnosis = preds_test
print(preds_test)
sample_sub.diagnosis = sample_sub['diagnosis'].astype(int)
sample_sub.head()

sample_sub.to_csv('submission.csv', index=False)

print("All set")

[2. 2. 2. ... 2. 3. 0.]
All set
