In [None]:
import os
import platform
from collections import namedtuple
import time

from tqdm.notebook import tqdm
import tabulate

import pandas as pd
import numpy as np
import sparse

from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler
from my_efficientnet_pytorch_3d import EfficientNet3D

import torchio

from torch.utils.tensorboard import SummaryWriter

from utils import CTDataset

In [3]:
import os
import platform
import pandas as pd
import numpy as np
from typing import Dict

In [4]:
from ..src.utils import segmentate_patient, resample, CTDataset

In [None]:
from ..src.model_utils import OSICNet

In [None]:
# Tensor dtype shared by the model and the inputs built for it.
dtype = torch.float32

# Use the first CUDA device when GPU use is enabled and one is available; otherwise the CPU.
USE_GPU = True
device = torch.device('cuda:0' if (USE_GPU and torch.cuda.is_available()) else 'cpu')

In [None]:
# Trained OSICNet weights; passed to `torch.load` when the state dict is restored.
# NOTE(review): the extension is .npz but the file is read with torch.load — confirm the format.
MODEL_PATH = '../models/model.npz'

In [5]:
# On Linux (presumably a Kaggle kernel) read the mounted competition data;
# on any other platform fall back to a local `data/` directory.
if 'linux' in platform.platform().lower():
    IMAGE_PATH = "../input/osic-pulmonary-fibrosis-progression/"
else:
    IMAGE_PATH = 'data/'

In [None]:
# Preprocessed per-patient arrays (images / masks / meta) are written beneath this directory.
PROCESSED_PATH = os.path.join(IMAGE_PATH, "processed_data")

In [None]:
# Patient directories under <IMAGE_PATH>/test in sorted (deterministic) order;
# a patient's position in this list is the index handed to segmentate_patient.
test_patients = list(os.listdir(os.path.join(IMAGE_PATH, 'test')))
test_patients.sort()

In [None]:
mode = 'test'

# Every scan is resampled to this many slices along the slice axis.
N_SLICES = 196

# In-plane downsampling rules keyed by the native image size (images.shape[1]):
# (border cropped from each side, stride). Every rule yields 256 rows, e.g.
# 632 -> crop 60 per side -> 512 -> stride 2 -> 256; 1302 -> crop 11 -> 1280 -> stride 5 -> 256.
IN_PLANE_DOWNSAMPLE = {
    512: (0, 2),
    632: (60, 2),
    768: (0, 3),
    1302: (11, 5),
}


def _slice_ordering(meta_data, n_slices):
    """Return indices that sort a scan's slices.

    Tries SliceLocation first, then InstanceNumber; if neither is present and
    parseable as float for every slice, keeps the order the slices were read in.
    """
    for key in ('SliceLocation', 'InstanceNumber'):
        # noinspection PyBroadException
        try:
            return np.argsort([float(v) for v in meta_data[key]])
        except Exception:
            # Missing key or unparseable value: fall through to the next source.
            continue
    return np.arange(n_slices)


def _downsample_in_plane(images, masks):
    """Crop and stride `images` and `masks` along axes 1 and 2 per IN_PLANE_DOWNSAMPLE.

    Returns (images, masks, stride). `stride` is None when images.shape[1] has no
    rule, in which case both arrays are returned unchanged.
    """
    rule = IN_PLANE_DOWNSAMPLE.get(images.shape[1])
    if rule is None:
        return images, masks, None
    border, stride = rule
    # `-border or None` keeps the full extent when nothing is cropped (border == 0),
    # matching `::stride`; otherwise it is `border:-border:stride`.
    in_plane = slice(border, -border or None, stride)
    return images[:, in_plane, in_plane], masks[:, in_plane, in_plane], stride


# `tqdm` is the class imported from tqdm.notebook, so it is called directly
# (`tqdm.tqdm` raised AttributeError).
for patient_n in tqdm(range(len(test_patients))):
    patient = test_patients[patient_n]

    all_images, _, _, all_masks, meta_data = segmentate_patient(mode, patient_n, IMAGE_PATH, perform_hack=False)
    SliceThickness, PixelSpacing = meta_data['SliceThickness'][0], meta_data['PixelSpacing'][0]
    assert len(PixelSpacing) == 2

    # Voxel spacing ordered (slice, row, column): index 0 is the slice thickness,
    # indices 1 and 2 are the in-plane pixel spacing.
    new_spacing = np.array([SliceThickness] + list(PixelSpacing))

    # Reorder slices and every per-slice metadata list consistently.
    ordering = _slice_ordering(meta_data, len(all_images))
    all_images, all_masks = np.array(all_images)[ordering], np.array(all_masks)[ordering]
    for key, values in meta_data.items():
        meta_data[key] = np.array(values)[ordering].tolist()

    if len(all_images) != N_SLICES:
        all_images, _ = resample(
            all_images, [N_SLICES, *all_images.shape[1:]], SliceThickness, PixelSpacing
        )
        # Assumes resample returns spacing in the same (slice, row, column) order as
        # new_spacing above — consistent with SliceThickness taking index 0 below.
        all_masks, new_spacing = resample(
            all_masks, [N_SLICES, *all_images.shape[1:]], SliceThickness, PixelSpacing
        )
        # Presumably resampling interpolates mask values; re-binarise them.
        all_masks = all_masks > 0

        meta_data['SliceThickness'] = [new_spacing[0] for _ in meta_data['SliceThickness']]
        # Pixel spacing is new_spacing[1:3]; index 0 is the slice thickness.
        # NOTE(review): if the training data was preprocessed with the old
        # [new_spacing[0], new_spacing[1]] values and the model consumes PixelSpacing,
        # re-process the training data too so train/test metadata agree.
        meta_data['PixelSpacing'] = [[new_spacing[1], new_spacing[2]] for _ in meta_data['PixelSpacing']]

    base_path = os.path.join(PROCESSED_PATH, mode, patient)
    os.makedirs(base_path, exist_ok=True)

    all_images, all_masks, stride = _downsample_in_plane(all_images, all_masks)
    if stride is not None:
        # Striding by `stride` multiplies the in-plane pixel spacing by the same factor.
        meta_data['PixelSpacing'] = [
            [new_spacing[1] * stride, new_spacing[2] * stride] for _ in meta_data['PixelSpacing']
        ]

    np.save(os.path.join(base_path, 'meta.npy'), meta_data)
    np.save(os.path.join(base_path, 'images.npy'), all_images)
    sparse.save_npz(os.path.join(base_path, 'masks.npz'), sparse.COO(np.array(all_masks)))

In [None]:
# Test split read from the preprocessed arrays and the competition's test.csv;
# transform and padding are disabled and nothing is held out (test_size=0).
test_dataset = CTDataset(
    f'{PROCESSED_PATH}/test',
    f'{IMAGE_PATH}/test.csv',
    random_state=42,
    test_size=0,
    train=False,
    transform=None,
    padding_mode=None,
    pad_global=False,
)

# NOTE(review): this loader is not used by the inference loop, which iterates
# `test_dataset` directly.
test_dataloader = DataLoader(test_dataset, num_workers=4, batch_size=1)

In [None]:
# Hyper-parameters must describe the same architecture as the checkpoint in
# MODEL_PATH, because load_state_dict matches parameters by name and shape.
# efficient_net_model_number=0 presumably selects the EfficientNet-B0 backbone.
model = OSICNet(
    dtype=dtype,
    device=device,
    efficient_net_model_number=0,
    use_poly=False,
    hidden_size=256,
    dropout_rate=0.5,
)

In [None]:
# Restore the trained weights. `map_location=device` remaps tensors that were saved
# from a GPU onto the selected device, so the notebook also runs on CPU-only machines.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

In [None]:
# Predict FVC and a confidence (sigma) for every (patient, week) row of the
# submission: weeks -12 .. 133 inclusive for each test patient.
WEEKS = list(range(-12, 133 + 1))

# Patient IDs in the order `test_dataset` yields samples. The previous code used the
# undefined name `test_dataset_test_patients` (NameError on a fresh kernel); it most
# likely meant `test_dataset.test_patients`. Fall back to the sorted directory listing
# used during preprocessing if the dataset has no such attribute.
# NOTE(review): confirm CTDataset iterates patients in this order.
patient_ids = getattr(test_dataset, 'test_patients', test_patients)

answer = []

model.eval()
for cur_iter, data in enumerate(test_dataset):
    # Copy into a list so items can be replaced even if the dataset returns a tuple.
    data = list(data)

    with torch.no_grad():
        # Repeat the first value of inputs 0 and 2 once per target week, and feed
        # the target weeks themselves as input 1.
        data[0] = torch.tensor([data[0][0].item()] * len(WEEKS), dtype=dtype)
        data[1] = torch.tensor(WEEKS, dtype=dtype)
        data[2] = torch.tensor([data[2][0].item()] * len(WEEKS), dtype=dtype)

        all_preds = model(data)

    FVC_low, FVC_preds, FVC_high = all_preds[0]
    # Confidence is the predicted interval width, kept strictly positive.
    sigmas = torch.clamp_min(FVC_high - FVC_low, 1e-7)

    patient_id = patient_ids[cur_iter]
    for week, fvc, confidence in zip(WEEKS, FVC_preds.reshape(-1).tolist(), sigmas.reshape(-1).tolist()):
        answer.append([f'{patient_id}_{week}', fvc, confidence])

In [None]:
# One submission row per (patient, week) prediction collected above.
result = pd.DataFrame(data=answer, columns=['Patient_Week', 'FVC', 'Confidence'])

In [None]:
# Write the submission file without the DataFrame's integer index column.
result.to_csv(path_or_buf="submission.csv", index=False)