In [2]:
import torch
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision import transforms
from torchvision.models import resnet18
from torch import nn

import pytorch_lightning as pl

from pathlib import Path
import pandas as pd
import ast

In [3]:
# All locations are resolved relative to the parent of the current working
# directory, so the notebook must be started from a direct sub-folder of the
# project root.
project_path = Path.cwd().parent
mri_path = project_path.joinpath('data/processed/mri/test')  # folder with the test PNG slices
csv_path = project_path.joinpath('data/processed/csv/test_mri_patients.csv')  # patient/label table

In [20]:
class TestMRIDataset(Dataset):
    """Patient-level dataset of MRI slices stored as PNG images.

    One item is *all* slices of one patient stacked into a single tensor,
    together with that patient's label vector parsed from the CSV file.
    An item can be requested either by patient id (``str``) or by position
    in ``self.patient_ids`` (the integer indices a ``DataLoader`` sampler
    produces). ``len(dataset)`` is the number of patients, so both ways of
    indexing agree with the reported length.
    """

    def __init__(self, mri_path: Path, csv_path: Path) -> None:
        """
        Args:
            mri_path: directory that directly contains the ``*.png`` slices.
            csv_path: CSV file with (at least) the columns ``patient`` and
                ``label``; ``label`` holds a Python list literal as text.
        """
        self.data_dir = mri_path
        self.df = pd.read_csv(csv_path)
        # Sorted (lexicographically) so the slice order is reproducible
        # instead of depending on the file system's directory order.
        self.mri_png_paths = sorted(self.data_dir.glob('*.png'))  # list of mri slices as png images
        # One entry per patient, in order of first appearance in the CSV.
        self.patient_ids = self.df['patient'].drop_duplicates().tolist()
        self.img_size = 128  # slices are resized to img_size x img_size

    def find_mri_png_paths(self, patient_id: str) -> list[Path]:
        """
        Return the paths of every PNG slice whose file name contains
        ``patient_id``. The list is empty when nothing matches.
        """
        return [path for path in self.mri_png_paths if patient_id in path.name]

    def __getitem__(self, patient_id):
        """
        Return all MRI slices of one patient and the patient's labels.

        Args:
            patient_id: the patient id as it appears in the CSV and in the
                PNG file names, or an index into ``self.patient_ids``.

        Returns:
            ``(data, label)`` where ``data`` is a float tensor holding the
            resized slices stacked along a new first dimension and ``label``
            is a float tensor built from the CSV ``label`` entry.

        Raises:
            IndexError: an integer index is outside ``range(len(self))``.
            KeyError: the patient has no PNG slices or no row in the CSV.
        """
        if not isinstance(patient_id, str):
            # Positional access (e.g. from a DataLoader sampler); the list
            # lookup raises IndexError for out-of-range positions.
            patient_id = self.patient_ids[patient_id]

        png_paths = self.find_mri_png_paths(patient_id)
        if not png_paths:
            # torch.stack([]) would otherwise fail with an unrelated message
            raise KeyError(f'no MRI .png slices found for patient {patient_id!r} in {self.data_dir}')

        labels = self.df.loc[self.df['patient'] == patient_id, 'label']
        if labels.empty:
            raise KeyError(f'patient {patient_id!r} has no row in the label CSV')

        # Build the transform once instead of once per slice.
        resize = transforms.Resize((self.img_size, self.img_size))
        data = torch.stack([resize(read_image(str(path)).float()) for path in png_paths])

        # The CSV stores the label vector as text, e.g. "[1, 0, 0, 0, 0]".
        label = torch.tensor(ast.literal_eval(labels.iloc[0]), dtype=torch.float32)

        return data, label

    def __len__(self):
        """Number of patients (i.e. number of items ``__getitem__`` can serve)."""
        return len(self.patient_ids)

# Dataset over the held-out test patients (PNG slices + label table).
test_dataset = TestMRIDataset(mri_path=mri_path, csv_path=csv_path)

In [44]:
# Pull every slice and the label vector of one example patient.
example_patient_id = '6dc8bd6b-e2a8-92bf-613d-8b477eb87d7c'
mri_data, label = test_dataset[example_patient_id]
print(mri_data.shape, '---', label)

torch.Size([115, 1, 128, 128]) --- tensor([1., 0., 0., 0., 0.])


In [39]:
class MRIModel(pl.LightningModule):
    """ResNet-18 based classifier for single-channel MRI slices.

    The network emits 5 raw logits per slice and is trained with
    ``binary_cross_entropy_with_logits``, i.e. each of the 5 outputs is
    treated as an independent yes/no target.

    Args:
        lr: learning rate handed to the Adam optimizer.
    """

    def __init__(self, lr=0.001):
        super().__init__()
        self.lr = lr
        self.resnet = resnet18()
        # The stock stem expects 3-channel input; replace it with an
        # otherwise identical convolution that takes 1 channel.
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7),
                                      stride=(2, 2), padding=(3, 3), bias=False)
        # The ResNet keeps its own final fc layer; this extra head maps
        # that layer's outputs to the 5 target logits.
        self.linear = nn.Linear(in_features=self.resnet.fc.out_features,
                                out_features=5)
        self.save_hyperparameters()  # log hyperparameters

    def forward(self, xs):
        """Return the raw (pre-sigmoid) logits for a batch of slices."""
        return self.linear(self.resnet(xs))

    def _shared_step(self, batch):
        """Compute the BCE-with-logits loss for one ``(inputs, targets)`` batch."""
        xs, ys = batch
        # Call the module itself rather than ``forward`` so that any
        # registered forward hooks are executed as well.
        y_hats = self(xs)
        return F.binary_cross_entropy_with_logits(y_hats, ys)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for one batch."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, on_step=True)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [41]:
# Restore the trained weights and put the network into evaluation mode.
checkpoint = project_path.joinpath("weights/epoch=1-step=320.ckpt")
model = MRIModel.load_from_checkpoint(checkpoint_path=checkpoint)
model.eval()

MRIModel(
  (resnet): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_run

Example: running the model on all MRI slices of a single patient

In [45]:
# Inference only: disable autograd so no graph is built for the whole
# stack of slices (saves memory; the numerical result is unchanged).
with torch.no_grad():
    y_hats = torch.sigmoid(model(mri_data))  # per-slice probabilities

# Average the per-slice probabilities over the slice dimension to get one
# score per label for the patient, then threshold at 0.5.
mean_y_hats = torch.mean(y_hats, dim=0)
pred = (mean_y_hats > 0.5).int().tolist()
print(f'predicted = {pred}')
print(f'true labels = {label}')

predicted = [0, 0, 0, 1, 0]
true labels = tensor([1., 0., 0., 0., 0.])
