In [1]:
import platform
import os
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset


import numpy as np
import pandas as pd
import sparse

from efficientnet_pytorch import EfficientNet

from sklearn.model_selection import train_test_split


RUNNING_IN_KAGGLE = 'linux' in platform.platform().lower()
IMAGE_PATH = "../input/osic-pulmonary-fibrosis-progression/" if RUNNING_IN_KAGGLE else '/Users/Macbook/datasets/KaggleOSICPulmonaryFibrosisProgression'
PROCESSED_PATH = 'FIX IT!' if RUNNING_IN_KAGGLE else '/Users/Macbook/datasets/processed-data/'  # TODO: fix this line

dtype = torch.float32
USE_GPU = True
if USE_GPU and torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
device = torch.device(device)

In [2]:
class CTDataset(Dataset):
    _ReturnValue = namedtuple('ReturnValue', ['weeks', 'fvcs', 'features', 'masks', 'images'])
    
    def __init__(
            self, root, csv_path, train=True, test_size=0.25, random_state=42):
        """
        :param dataset:

        :param root:
        :param train:
        :param train_test_split:
        :param random_state:
        """
        assert test_size is not None
        
        self.root = root
        self.train = train
        self.csv_path = csv_path
        self.test_size = test_size
        self.random_state = random_state
        
        if not os.path.exists(self.root):
            raise ValueError('Data is missing')
        
        self._patients = list(sorted(os.listdir(self.root)))
        
        if self.test_size == 0:
            self._train_patients, self._test_patients = self._patients, []
        else:
            self._train_patients, self._test_patients = train_test_split(
                self._patients, test_size=self.test_size, random_state=random_state
            )
        
        self._table_features = dict()
        table_data = pd.read_csv(self.csv_path)
        for patient in self._patients:
            patient_data = table_data[table_data.Patient == patient]
            
            all_weeks = patient_data.Weeks.tolist()
            all_fvcs = patient_data.FVC.tolist()
            
            all_weeks, all_fvcs = zip(*sorted(zip(all_weeks, all_fvcs), key=lambda x: x[0]))
            
            age = sorted(zip(*np.unique(patient_data.Age, return_counts=True)), key=lambda x: x[1])[-1][0]
            sex = sorted(zip(*np.unique(patient_data.Sex, return_counts=True)), key=lambda x: x[1])[-1][0]
            smoking_status = sorted(zip(*np.unique(patient_data.SmokingStatus, return_counts=True)), key=lambda x: x[1])[-1][0]

            sex = [0, 1] if sex == 'Female' else [1, 0]
            smoking_status = (
                [1, 0, 0] if smoking_status == 'Ex-smoker' else
                [0, 1, 0] if smoking_status == 'Never smoked' else
                [0, 0, 1] if smoking_status == 'Currently smokes' else
                [0, 0, 0]
            )
            self._table_features[patient] = (
                all_weeks, all_fvcs, [age] + sex + smoking_status
            )
        
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        patient = self._train_patients[index] if self.train else self._test_patients[index]
        base_path = os.path.join(self.root, patient)

        meta = np.load(os.path.join(base_path, 'meta.npy'), allow_pickle=True).tolist()
        masks = sparse.load_npz(os.path.join(base_path, 'masks.npz'))
        images = np.load(os.path.join(base_path, 'images.npy'))
        
        meta_processed = dict()
        for key, values in meta.items():
            if key in {'SliceLocation', 'InstanceNumber'}:
                continue
            else:
                unique_values, values_cnt = np.unique(values, return_counts=True, axis=0)
                most_frequent = sorted(zip(unique_values, values_cnt), key=lambda x: x[1])[-1][0]
                most_frequent = np.array(most_frequent).reshape(-1)
                if key in {
                    'SliceThickness', 'TableHeight', 'WindowCenter', 'WindowWidth'
                }:
                    meta_processed[key] = most_frequent[0]
                elif key == 'PixelSpacing':
                    if len(most_frequent) == 1:
                        meta_processed['PixelSpacingX'], meta_processed['PixelSpacingY'] = (
                            most_frequent[0], most_frequent[0]
                        )
                    else:
                        meta_processed['PixelSpacingX'], meta_processed['PixelSpacingY'] = (
                            most_frequent[0], most_frequent[1]
                        )
                elif key == 'PatientPosition':
                    pass
                elif key == 'PositionReferenceIndicator':
                    pass
                    
        all_weeks, all_fvcs, features = self._table_features[patient]
        features = [value for key, value in meta_processed.items()] + features
        
        return CTDataset._ReturnValue(weeks=all_weeks, fvcs=all_fvcs, features=features, masks=masks, images=images)

    def __len__(self):
        return len(self._train_patients if self.train else self._test_patients)

    def __repr__(self):
        fmt_str = 'OSIC Pulmonary Fibrosis Progression Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        return fmt_str

In [3]:
train_dataset = CTDataset(
    f'{PROCESSED_PATH}/train',
    f'{IMAGE_PATH}/train.csv',
    train=True, test_size=0.25, random_state=42
)

val_dataset = CTDataset(
    f'{PROCESSED_PATH}/train',
    f'{IMAGE_PATH}/train.csv',
    train=False, test_size=0.25, random_state=42
)

In [4]:
data = train_dataset[0]
len(data.features)

12

In [5]:
len(train_dataset), len(val_dataset)

(132, 44)

In [42]:
# model = EfficientNet.from_pretrained('efficientnet-b0', in_channels=1280)

# img = torch.rand(1, 1280, 1100, 256)
# print(img.shape) 
# features = model.extract_features(img)
# print(features.shape) 

Loaded pretrained weights for efficientnet-b0
torch.Size([1, 1280, 1100, 256])
torch.Size([1, 1280, 35, 8])


In [6]:
def laplace_loss(y_true, y_pred, log_sigma):
    losses = np.sqrt(2) * (y_true - y_pred).abs() / log_sigma.exp() + log_sigma + np.log(2) / 2
    return losses.mean()

In [7]:
class Net2D(nn.Module):
    
    def __init__(self):
        super(Net2D, self).__init__()
        
        self.efficient_net_1 = EfficientNet.from_pretrained('efficientnet-b0', in_channels=1)
        self.efficient_net_2 = EfficientNet.from_pretrained('efficientnet-b0', in_channels=1280)

        self.fc_1 = nn.Linear(2560 + 14, 1000)
        self.fc_2 = nn.Linear(1000, 500)
        self.fc_3 = nn.Linear(500, 5)
    
    def forward(self, X, meta_X):
        """
        X: tensor (s, h, w): s - slices
        meta_X: tensor (n, 12)
        """
        s, h, w = X.shape
        
        X = X.unsqueeze(1) # add chanel axis (s, 1, h, w)
        
        X = self.efficient_net_1.extract_features(X) # shape (s, 1280, 8, 8)
        X = X.view(s, 1280, 64) # shape (s, 1280, 64)
        
        X = X.unsqueeze(0) # shape (1, s, 1280, 64)
        X = X.transpose(1, 2) # shape (1,  1280, s, 64)
        
        X = self.efficient_net_2.extract_features(X) # shape (1, 1280, ?, 2)
        
        X = torch.mean(X, dim=2) # shape (1, 1280, 2)
        X = X.view(1, 2560)
        
        X = torch.cat([X.repeat(meta_X.shape[0], 1), meta_X], dim=1) 
        
        X = F.relu(self.fc_1(X))
        X = F.relu(self.fc_2(X))
        y = self.fc_3(X)
        
        return y

In [8]:
def polynom(coords, coefs):
    # coords shape (n, )
    # coefs shape (4, )
    
    poly_coords = torch.empty((coords.shape[0], 4)).to(device)
    poly_coords[:, 3] = 1
    poly_coords[:, 2] = coords
    poly_coords[:, 1] = coords**2
    poly_coords[:, 0] = coords**3
    return (poly_coords * coefs.unsqueeze(0)).sum(dim=1)

In [9]:
def eval_model(model, loss_func, val_dataset):
    running_loss = 0.0
    model.eval()    
    for i, data in enumerate(val_dataset):
        #prepare lungs
        masks = data.masks 
        images = data.images

        lungs = -1000 * (1.0 - masks) + masks * images
        X = transforms(lungs).to(device)

        # prepare features
        weeks = torch.tensor(data.all_weeks).to(device)
        fvcs = torch.tensor(data.all_fvcs).to(device)
        features = torch.tensor(data.features).to(device)

        num_weeks = len(weeks)
        meta_X = torch.concat([weeks.unsqueeze(1), fvcs.unsqueeze(1), 
                               features.unsqueeze(0).repeat(num_weeks, 1)], dim=1)

        preds = model(X, meta_X) # shape (num_weeks, 5)

        coefs = pred[: 0:4]
        log_sigma = pred[:, 5]

        loss = 0
        for i in range(num_weeks):
            fwc_pred = polynom(weeks, coefs[i])
            loss += loss_func(fwc, fwc_pred, log_sigma[i])

        loss /= num_weeks
        running_loss += loss.item()
    
    return running_loss / len(val_dataset)

In [10]:
def train_model(model, optimizer, loss_func, train_dataset, val_dataset, epochs, scheduler=None):

    def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    val_loss_min = np.inf
    for epoch in range(epochs):
        running_loss = 0.0
        
        for i, data in enumerate(train_dataset):
            model.train()
            optimizer.zero_grad()
            
            #prepare lungs
            masks = data.masks 
            images = data.images
            
            lungs = -1000 * (1.0 - masks) + masks * images
            X = torch.tensor(lungs, dtype=dtype).to(device)

            # prepare features
            weeks = torch.tensor(data.weeks, dtype=dtype).to(device)
            fvcs = torch.tensor(data.fvcs, dtype=dtype).to(device)
            features = torch.tensor(data.features, dtype=dtype).to(device)
            
            num_weeks = len(weeks)
            meta_X = torch.cat([weeks.unsqueeze(1), fvcs.unsqueeze(1), 
                                   features.unsqueeze(0).repeat(num_weeks, 1)], dim=1)

            preds = model(X, meta_X) # shape (num_weeks, 5)

            
            coefs = preds[:, 0:4]
            log_sigma = preds[:, 4]
            
            loss = 0.0
            for i in range(num_weeks):
                fvcs_pred = polynom(weeks, coefs[i])
                loss += loss_func(fvcs, fvcs_pred, log_sigma[i])
            
            loss /= num_weeks
            loss.backward() 
            optimizer.step()
            running_loss += loss.item()
           
            print("Epoch: {} ".format(epoch + 1),
                  "Iteration: {} ".format(i),
                  "lr: {:.6f} ".format(get_lr(optimizer)),
                  "Loss: {:.6f}.".format(running_loss / (i + 1) ))
            
        running_loss /= len(train_dataset)  
        
#         val_loss = eval_model(model, loss_func, val_dataset)
    
        
#         print("Epoch: {}/{}...".format(epoch + 1, n_epochs),
#               "lr: {:.6f}...".format(get_lr(optimizer)),
#               "Loss: {:.6f}...".format(running_loss.item()),
#               "Val Loss: {:.6f}".format(val_loss))
#         print('------------------------------')

        
#         if val_loss <= val_loss_min:
#             torch.save(model.state_dict(), f'./state_dict{epoch}.pt')
#             print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(val_loss_min, val_loss))
#             val_loss_min = val_loss
        
#         if scheduler is not None:
#             scheduler.step()

In [11]:
model = Net2D().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)

Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0


In [None]:
train_model(model, optimizer, laplace_loss, train_dataset, val_dataset, epochs=1, scheduler=None)

Epoch: 1  Iteration: 8  lr: 0.050000  Loss: 443540.250000.
