In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from functools import partial
import os

from mkrsna.rsna.loaders.loaders import RSNAData
from mkrsna.rsna.model import RSNAModel
from mkrsna.torch.collate import mixed_collate_imgs_fn_with_pad_value

from albumentations import (
    Compose, Normalize,
    ImageOnlyTransform
)
from albumentations.pytorch import ToTensorV2


In [2]:
# Environment probe: `kaggle_secrets` is assumed to be importable only inside a Kaggle
# kernel, so a successful import is taken to mean "running on Kaggle".
IS_KAGGLE = False
try:
    from kaggle_secrets import UserSecretsClient  # noqa: F401 -- imported only as a probe
except ImportError:
    pass
else:
    IS_KAGGLE = True

In [3]:
# Dataset location. Please download the competition data manually from Kaggle; on Kaggle
# it is read from /kaggle/input, locally from ~/rsna-breast.
DATASET_URI = "kaggle/rsna-breast-cancer-detection"
DATASET_PATH_SUFFIX = ""  # optional sub-directory inside the dataset folder
DATASET_DIR = os.path.basename(DATASET_URI)
DATASETS_LOCAL_REPO = "/kaggle/input/" if IS_KAGGLE else os.path.expanduser("~/rsna-breast")
DATASET_PATH_START = os.path.join(DATASETS_LOCAL_REPO, DATASET_DIR)
# os.path.join(x, "") appends a trailing separator, which would make derived paths such as
# f"{DATASET_PATH}/test.csv" contain "//"; normpath strips it (and any other redundancy).
DATASET_PATH = os.path.normpath(os.path.join(DATASET_PATH_START, DATASET_PATH_SUFFIX))

In [4]:
class Config:
    """Static settings for this inference notebook: checkpoint, test-data paths, loader sizes."""

    # Trained Lightning checkpoint restored for inference.
    checkpoint_path = "checkpoints/ckpt_epoch05_trainf065.ckpt"

    # Competition test split: image folder and metadata CSV, both under DATASET_PATH.
    test_imgs_path = DATASET_PATH + "/test_images"
    test_csv_path = DATASET_PATH + "/test.csv"

    # DataLoader settings.
    test_bs = 8  # test batch size
    # NOTE(review): the test DataLoader in this notebook passes num_workers=0 and does
    # not read this value -- confirm which one is intended.
    dataloader_workers_count = 4

In [5]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
str(device)

'cuda:0'

In [6]:
# Collate function for the test DataLoader: the project's mixed-image collate with
# pad_value fixed to 0 (presumably the fill used when batching images of different
# sizes -- confirm in mkrsna.torch.collate).
mixed_collate_imgs_fn = partial(mixed_collate_imgs_fn_with_pad_value, pad_value=0)

class ExpandTo3Channels(ImageOnlyTransform):
    """Replicate a single-channel image into three identical channels: (H, W) -> (H, W, 3).

    Lets grayscale images feed a backbone that expects 3-channel input
    (presumably the RSNAModel backbone -- confirm).
    """

    def __init__(self, always_apply=True, p=1.0):
        super().__init__(always_apply=always_apply, p=p)

    def apply(self, img, **params):
        """Return ``img`` with its single channel repeated 3 times along a new last axis.

        Accepts an (H, W) or (H, W, 1) array. ``np.repeat`` keeps the input dtype, whereas
        multiplying by ``np.ones(3)`` would upcast float32 input to float64 and double the
        memory of every image; the pixel values are identical either way.
        """
        channel_last = img.reshape(img.shape[0], img.shape[1], 1)
        return np.repeat(channel_last, 3, axis=2)

    def get_transform_init_args_names(self):
        # No constructor arguments beyond the base-class ones need serializing.
        return ()

# Inference-time preprocessing: normalize (mean = std = 0.5 with max_pixel_value=1.0,
# presumably mapping pixels already scaled to [0, 1] onto [-1, 1] -- confirm the loader's
# scaling), replicate to 3 channels, then convert to a CHW torch tensor.
valid_augments = Compose(
    [
        Normalize(mean=0.5, std=0.5, max_pixel_value=1.0, p=1.0),
        ExpandTo3Channels(p=1.0),
        ToTensorV2(p=1.0),
    ]
)

In [7]:
# Test metadata; each image's path relative to the image folder is "<patient_id>/<image_id>.png".
test_df = pd.read_csv(Config.test_csv_path)
test_df["img_name"] = (
    test_df["patient_id"].astype(str).str.cat(test_df["image_id"].astype(str), sep="/") + ".png"
)

# NOTE(review): `img_name` ends in ".png" but the dataset is told extension="dcm";
# presumably RSNAData substitutes the extension -- confirm.
test_dataset = RSNAData(
    df=test_df,
    img_folder=Config.test_imgs_path,
    has_patient_folder_sturcture=True,  # (sic) keyword spelled as RSNAData defines it
    resize_longer_axis_to=1024,
    pre_resize_for_countours_aspect=0.1,
    extension="dcm",
    is_test=True,
    transform=valid_augments,
)

# Test order is kept (no shuffling).
# NOTE(review): num_workers is hard-coded to 0, so Config.dataloader_workers_count is
# unused -- confirm this is intentional.
test_loader = DataLoader(
    test_dataset,
    batch_size=Config.test_bs,
    shuffle=False,
    num_workers=0,
    collate_fn=mixed_collate_imgs_fn,
    pin_memory=True,
)

In [8]:
# Restore the trained LightningModule from its checkpoint and prepare it for inference.
# `pretrained=False` is forwarded to RSNAModel (presumably skips fetching backbone weights
# that the checkpoint overwrites anyway -- confirm in RSNAModel).
eval_model = RSNAModel.load_from_checkpoint(checkpoint_path=Config.checkpoint_path, pretrained=False)

# No TorchScript export is done: its result would never be used. `scripted_model` is the
# eager module moved to `device` (nn.Module.to returns the module itself); the name is
# kept because the inference cell refers to it.
scripted_model = eval_model.to(device)
scripted_model.eval()    # inference mode for dropout / batch-norm layers
scripted_model.freeze()  # Lightning: eval mode + requires_grad=False on all parameters

In [12]:

# Score every test image, remembering the dataset index the loader reports for each one.
predictions = []
indices = []

with torch.no_grad():
    for batch_imgs, batch_indices in test_loader:
        batch_scores = scripted_model.predict(batch_imgs.to(device))
        predictions += batch_scores.detach().cpu().numpy().ravel().tolist()
        indices += batch_indices.detach().cpu().numpy().tolist()

# Map each image back to its prediction_id and keep the highest score per id.
prediction_ids = test_df["prediction_id"].to_numpy()[indices]
summary_df = pd.DataFrame({"prediction_id": prediction_ids, "cancer": predictions})
submission_series = summary_df.groupby(["prediction_id"])["cancer"].max()
submission_series.to_csv("submission.csv")

  img = torch.tensor(img, dtype=torch.float)
