In [1]:
import numpy as np
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import config
import cv2

import glob
import os
import re

import joblib
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import pandas as pd
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F

import torch.nn as nn
import monai

In [2]:
# Volume / training configuration
NUM_IMAGES_3D = 64                          # slices stacked (and zero-padded) per 3D volume
TRAINING_BATCH_SIZE, TEST_BATCH_SIZE = 4, 4
IMAGE_SIZE = 224                            # every slice is resized to IMAGE_SIZE x IMAGE_SIZE
N_EPOCHS = 15
do_valid = True
n_workers = 1                               # DataLoader worker processes

In [3]:
# Labels: one row per case (BraTS21ID) with its MGMT_value; preview the first 20.
df = pd.read_csv("train_labels.csv")
df.head(20), len(df)

(    BraTS21ID  MGMT_value
 0           0           1
 1           2           1
 2           3           0
 3           5           1
 4           6           1
 5           8           1
 6           9           0
 7          11           1
 8          12           1
 9          14           1
 10         17           0
 11         18           0
 12         19           0
 13         20           1
 14         21           0
 15         22           0
 16         24           0
 17         25           1
 18         26           1
 19         28           1,
 585)

In [4]:
def extract_cropped_image_size(path):
    """
    Read one DICOM slice and return its pixel count after zero-border cropping.

    Parameters
    ----------
    path : str
        Path to a single ``.dcm`` file.

    Returns
    -------
    int
        ``rows * cols`` of the slice after ``crop_img`` strips the all-zero
        border; 0 for an all-zero slice.
    """
    # `dcmread` is the supported reader; `read_file` is a deprecated alias that
    # was removed in pydicom 3.0.
    dicom = pydicom.dcmread(path)
    cropped = crop_img(dicom.pixel_array)
    return cropped.shape[0] * cropped.shape[1]

In [5]:
def crop_img(img):
    """
    Crop a 2D slice to the bounding box of its non-zero pixels.

    Parameters
    ----------
    img : np.ndarray
        2D pixel array.

    Returns
    -------
    np.ndarray
        View of ``img`` restricted to the rows/columns that contain at least one
        non-zero pixel (bounds inclusive). An all-zero slice yields an empty
        ``(0, 0)`` array (not a list).
    """
    rows = np.any(img, axis=1)   # True for every row holding a non-zero pixel
    cols = np.any(img, axis=0)   # True for every column holding a non-zero pixel

    # If no row has a non-zero pixel then no column does either: empty slice.
    if not rows.any():
        return img[0:0, 0:0]

    row_idx = np.flatnonzero(rows)
    col_idx = np.flatnonzero(cols)
    rmin, rmax = row_idx[0], row_idx[-1]
    cmin, cmax = col_idx[0], col_idx[-1]
    # +1 because slice stops are exclusive; without it the last non-zero
    # row/column is dropped and a single-pixel slice would come back empty.
    return img[rmin : rmax + 1, cmin : cmax + 1]


In [6]:
def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotate=0):
    """
    Load one DICOM slice as a resized, min-max normalised float array.

    Parameters
    ----------
    path : str
        Path to a ``.dcm`` file.
    img_size : int
        Output side length; the slice is resized to ``(img_size, img_size)``.
    voi_lut : bool
        Apply the file's VOI LUT (``apply_voi_lut``) before resizing.
    rotate : int
        0 = none, 1 = 90 deg clockwise, 2 = 90 deg counter-clockwise, 3 = 180 deg.

    Returns
    -------
    np.ndarray
        ``(img_size, img_size)`` float array scaled to [0, 1]; a constant slice
        becomes all zeros.
    """
    # `dcmread` replaces the deprecated `read_file` (removed in pydicom 3.0).
    dicom = pydicom.dcmread(path)
    # Decode the pixel data once and reuse it for both branches.
    data = dicom.pixel_array
    if voi_lut:
        data = apply_voi_lut(data, dicom)

    if rotate > 0:
        # Index 0 is a placeholder; only rotate in {1, 2, 3} reaches this lookup.
        rot_choices = [
            0,
            cv2.ROTATE_90_CLOCKWISE,
            cv2.ROTATE_90_COUNTERCLOCKWISE,
            cv2.ROTATE_180,
        ]
        data = cv2.rotate(data, rot_choices[rotate])

    data = cv2.resize(data, (img_size, img_size))
    # Cast first so a constant slice is also returned as float (the division
    # below is skipped for it, which previously left an integer dtype).
    data = data.astype(np.float64)
    data = data - np.min(data)
    max_val = np.max(data)
    if max_val > 0:
        data = data / max_val
    return data

In [7]:
class BrainRSNADataset(Dataset):
    """
    Per-case 3D MRI volumes for one sequence type (e.g. ``FLAIR``).

    Each item stacks up to ``num_imgs`` DICOM slices centred on the slice with
    the largest non-zero area. That slice index is pre-computed by
    ``_prepare_biggest_images`` and cached under ``indices/``. The depth axis
    is zero-padded when fewer slices are available.

    Parameters
    ----------
    data : pd.DataFrame
        One row per case with a ``BraTS21ID`` column and, for training data,
        the ``target`` column. Rows are looked up with ``.loc[index]``, so the
        frame is expected to carry a 0..n-1 index (e.g. after ``reset_index``).
    transform : optional
        Stored on the instance but not applied anywhere in this class.
    target : str
        Label column name.
    mri_type : str
        Sub-folder holding the sequence's DICOM files.
    is_train : bool
        When True, read from ``train/`` and return targets. When False, read
        from ``test/``.
    ds_type : str
        Tag embedded in the cache file name so different splits do not collide.
    do_load : bool
        Reuse a cached index file when present instead of recomputing it.
    """

    def __init__(
        self, data, transform=None, target="MGMT_value", mri_type="T1w", is_train=True, ds_type="forgot", do_load=True
    ):
        self.target = target
        self.data = data
        self.type = mri_type

        self.transform = transform
        self.is_train = is_train
        self.folder = "train" if self.is_train else "test"
        self.do_load = do_load
        self.ds_type = ds_type
        self.img_indexes = self._prepare_biggest_images()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.loc[index]
        case_id = int(row.BraTS21ID)
        _3d_images = torch.tensor(self.load_dicom_images_3d(case_id)).float()
        if self.is_train:
            # Only read the label for training data: test frames may not have it.
            target = int(row[self.target])
            return {"image": _3d_images, "target": target, "case_id": case_id}
        return {"image": _3d_images, "case_id": case_id}

    @staticmethod
    def _natural_key(path):
        """Sort key that orders embedded numbers numerically (``Image-2`` before ``Image-10``)."""
        return [int(tok) if tok.isdigit() else tok for tok in re.findall(r"[^0-9]|[0-9]+", path)]

    def _sorted_case_files(self, case_id):
        """Return this case's ``.dcm`` paths for ``self.type`` in natural order; raise if none exist."""
        pattern = f"{self.folder}/{case_id}/{self.type}/*.dcm"
        files = sorted(glob.glob(pattern), key=self._natural_key)
        if not files:
            raise FileNotFoundError(f"No DICOM files match {pattern}")
        return files

    def _prepare_biggest_images(self):
        """
        Map each zero-padded case id to the index of its largest slice.

        "Largest" means the most pixels left after ``crop_img`` strips the
        zero border. The mapping is cached with joblib in
        ``indices/big_image_indexes_{ds_type}_{mri_type}.pkl``.
        """
        cache_dir = "indices"
        cache_path = os.path.join(cache_dir, f"big_image_indexes_{self.ds_type}_{self.type}.pkl")
        if self.do_load and os.path.exists(cache_path):
            print("Loading the best images indexes for all the cases...")
            return joblib.load(cache_path)

        print("Caulculating the best scans for every case...")
        big_image_indexes = {}
        for _, row in tqdm(self.data.iterrows(), total=len(self.data)):
            case_id = str(int(row.BraTS21ID)).zfill(5)
            # Scan self.folder (not a hard-coded "train/") so test datasets use their own files.
            files = self._sorted_case_files(case_id)
            resolutions = [extract_cropped_image_size(f) for f in files]
            big_image_indexes[case_id] = np.array(resolutions).argmax()

        # Create the cache directory on first use instead of crashing in os.listdir.
        os.makedirs(cache_dir, exist_ok=True)
        joblib.dump(big_image_indexes, cache_path)
        return big_image_indexes

    def load_dicom_images_3d(
        self,
        case_id,
        num_imgs=NUM_IMAGES_3D,
        img_size=IMAGE_SIZE,
        rotate=0,
    ):
        """
        Build a ``(1, img_size, img_size, num_imgs)`` volume for one case.

        Takes up to ``num_imgs // 2`` slices on each side of the case's largest
        slice, clipped at the series boundaries. The last axis is then
        zero-padded up to ``num_imgs``.
        """
        case_id = str(case_id).zfill(5)
        files = self._sorted_case_files(case_id)

        middle = self.img_indexes[case_id]   # index of the largest cropped slice
        half = num_imgs // 2
        start = max(0, middle - half)
        stop = min(len(files), middle + half)
        # Pass img_size through so the slices match the padding shape below.
        image_stack = [
            load_dicom_image(f, img_size=img_size, rotate=rotate, voi_lut=True) for f in files[start:stop]
        ]

        # np.stack gives (depth, H, W); .T reverses every axis -> (W, H, depth).
        img3d = np.stack(image_stack).T
        if img3d.shape[-1] < num_imgs:
            pad = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]))
            img3d = np.concatenate((img3d, pad), axis=-1)

        return np.expand_dims(img3d, 0)

In [8]:
random.seed(42)

df = pd.read_csv("train_labels.csv")

# Random split: N_TRAIN_CASES labelled cases for training, the rest for validation.
N_TRAIN_CASES = 420
train_idx = random.sample(list(range(len(df))), k=N_TRAIN_CASES)
train_idx_set = set(train_idx)   # O(1) membership instead of scanning a list per row

# reset_index(drop=False) keeps the original row number in an "index" column;
# the new 0..n-1 index is what BrainRSNADataset's .loc lookup relies on.
train_df = df.iloc[train_idx].reset_index(drop=False)

val_idx = [x for x in range(len(df)) if x not in train_idx_set]

val_df = df.iloc[val_idx].reset_index(drop=False)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"train: {len(train_df)} cases, valid: {len(val_df)} cases")
train_dataset = BrainRSNADataset(data=train_df, mri_type="FLAIR", ds_type="train")

valid_dataset = BrainRSNADataset(data=val_df, mri_type="FLAIR", ds_type="val")


train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TRAINING_BATCH_SIZE,
    shuffle=True,
    num_workers=n_workers,
    drop_last=True,
    pin_memory=True,
)


validation_dl = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=n_workers,
    pin_memory=True,
)

train_{}
Loading the best images indexes for all the cases...
Loading the best images indexes for all the cases...


In [None]:

# Pretrained VGG16 from the torchvision hub repo pinned at v0.10.0.
hub_repo, arch = "pytorch/vision:v0.10.0", "vgg16"
net = torch.hub.load(hub_repo, arch, pretrained=True)

In [None]:
# The MRI slices have a single channel, so swap VGG's 3-channel stem for a
# 1-channel conv. Initialise it with the pretrained filters summed over the
# input channels (same response as feeding the gray image replicated to RGB)
# rather than throwing the pretrained stem away for a random init.
old_stem = net.features[0]
new_stem = nn.Conv2d(1, 64, (3, 3), padding=(1, 1))
with torch.no_grad():
    new_stem.weight.copy_(old_stem.weight.sum(dim=1, keepdim=True))
    new_stem.bias.copy_(old_stem.bias)
net.features[0] = new_stem
# Replace the final classifier layer with a 4096 -> 256 feature projection.
net.classifier[6] = nn.Linear(4096, 256)

In [9]:
# Pull one training batch to check the tensor layout fed to the models.
batch = next(iter(train_dl))
batch["image"].shape

torch.Size([4, 1, 224, 224, 64])

In [None]:
# The VGG encoder is 2D, so feed it one depth slice of the first volume:
# (1, 1, H, W, depth)[..., 0] -> (1, 1, H, W). The full 5D volume is rejected by Conv2d.
net(batch["image"][:1, ..., 0])

# LSTM sequence classifier

In [None]:
class RNN(nn.Module):
    """
    Recurrent classifier over a sequence of feature vectors.

    Parameters
    ----------
    input_size : int
        Feature dimension of each sequence step.
    hidden_size : int
        LSTM hidden/cell state size.
    num_layers : int
        Number of stacked LSTM layers.
    num_classes : int
        Number of output logits per sequence.

    ``forward`` expects ``x`` of shape ``(batch, seq_len, input_size)``
    (``batch_first=True``). It returns ``(batch, num_classes)`` logits
    computed from the last time step.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # NOTE(review): `self.rnn` is not used by forward(); presumably kept as an
        # alternative cell. It is a registered submodule, so its weights appear in
        # state_dict and in model.parameters().
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Zero initial hidden/cell states created on x's device and dtype, so the
        # module works on CPU or GPU without relying on a global `device`.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))   # (batch, seq_len, hidden_size)
        return self.fc(out[:, -1, :])     # logits from the last time step

In [None]:
# NOTE(review): input_size=1000 assumes 1000-d per-slice features, but the CNN
# heads built in this notebook project to 256 features — confirm before training.
lstm = RNN(1000, 1000, 16, 2).to(device)   # follow the notebook-wide `device`, not a hard-coded "cuda"
lstm

In [None]:
# NOTE(review): uses `seq`, which is only created in a later cell (out-of-order
# execution), and reshapes it to (1, 64, 1000); that works only if `seq` holds
# exactly 64 * 1000 values.
lstm(seq.reshape(1,64,1000).to(device))

In [None]:
# Re-inspect the batch tensor shape (duplicate of the earlier check).
batch["image"].shape

In [None]:
# Encode each depth slice of the first volume with the 2D VGG encoder and stack
# the per-slice feature vectors into a (depth, n_features) sequence. The feature
# width comes from the network output rather than a hard-coded 1000, because the
# replaced classifier head (Linear(4096, 256)) no longer emits 1000 values.
net_device = next(net.parameters()).device
first_volume = batch["image"][:1].to(net_device)   # (1, 1, H, W, depth)
seq = torch.stack(
    [net(first_volume[..., i]).squeeze(0) for i in range(first_volume.shape[-1])]
)
    

In [None]:
# Tensor.to() returns a new tensor rather than moving in place, so keep the result.
seq = seq.to(device)

# Model: per-slice CNN encoder + LSTM

In [None]:
class MODEL(nn.Module):
    """
    CNN encoder + LSTM classifier for 3D volumes treated as slice sequences.

    Each depth slice is encoded by a 1-channel VGG16 (pretrained weights loaded
    from torch hub) into an ``input_size`` feature vector. The slice features
    are fed in depth order to a stacked LSTM, and the last hidden state is
    mapped to ``num_classes`` logits.

    Parameters
    ----------
    input_size : int
        Per-slice feature size produced by the CNN and consumed by the LSTM.
    hidden_size : int
        LSTM hidden/cell state size.
    num_layers : int
        Number of stacked LSTM layers.
    num_classes : int
        Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()

        # CNN encoder: 1-channel stem for single-channel slices, and a head that
        # emits exactly `input_size` features so it always matches the LSTM.
        cnn = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
        cnn.features[0] = nn.Conv2d(1, 64, (3, 3), padding=(1, 1))
        cnn.classifier[6] = nn.Linear(4096, input_size)
        self.cnn = cnn   # stored as an attribute so it is registered and trained

        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor or dict
            Volume batch of shape ``(batch, channels, H, W, depth)``, or a
            dataset batch dict holding it under ``"image"``.

        Returns
        -------
        torch.Tensor
            ``(batch, num_classes)`` logits.
        """
        if isinstance(x, dict):
            x = x["image"]
        batch_size, channels, height, width, depth = x.shape

        # Fold depth into the batch axis: (B, C, H, W, D) -> (B*D, C, H, W), keeping
        # each volume's slices contiguous and in depth order.
        slices = x.permute(0, 4, 1, 2, 3).reshape(batch_size * depth, channels, height, width)
        feats = self.cnn(slices).reshape(batch_size, depth, -1)   # (B, D, input_size)

        # Zero initial states on the features' device/dtype (no global `device`).
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        h0 = feats.new_zeros(state_shape)
        c0 = feats.new_zeros(state_shape)
        out, _ = self.lstm(feats, (h0, c0))   # (B, D, hidden_size)
        return self.fc(out[:, -1, :])         # logits from the last slice

In [None]:
# Scratch: all-zero tensor with the same shape, dtype and device as `seq`.
se = torch.zeros_like(seq)

In [None]:
# Scratch check: sum of the zero tensor `se`; expected to be 0.
se.sum()

In [None]:
# Inspect `seq`; its shape depends on which feature-extraction cell ran last.
seq.shape

In [None]:
# Move the VGG encoder to `device` (nn.Module.to modifies the module in place);
# inputs fed to `net` afterwards must be on the same device.
net.to(device)

In [11]:
# Pretrained ResNet-18 from the torchvision hub repo pinned at v0.10.0.
hub_repo, arch = "pytorch/vision:v0.10.0", "resnet18"
model = torch.hub.load(hub_repo, arch, pretrained=True)

Using cache found in /home/ahmed/.cache/torch/hub/pytorch_vision_v0.10.0


In [None]:
# Display the ResNet-18 module structure.
model

In [12]:
# Single-channel slices: swap the 3-channel stem for a 1-channel conv initialised
# from the pretrained filters summed over input channels, instead of a random init.
old_stem = model.conv1
new_stem = nn.Conv2d(1, 64, (7, 7), stride=(2, 2), padding=(3, 3), bias=False)
with torch.no_grad():
    new_stem.weight.copy_(old_stem.weight.sum(dim=1, keepdim=True))
model.conv1 = new_stem
# Project the 512-d pooled features to 256-d slice embeddings.
model.fc = nn.Linear(512, 256)

In [None]:
# Move the ResNet encoder to `device` (nn.Module.to modifies the module in place);
# inputs fed to `model` afterwards must be on the same device.
model.to(device)

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's CUDA allocator.
torch.cuda.empty_cache()

In [13]:
# Encode every depth slice of every volume in the batch with the 2D ResNet into
# a (batch, depth, 256) feature sequence. Slices go through one at a time, so
# BatchNorm sees a single slice per forward pass.
model_device = next(model.parameters()).device   # inputs must live with the weights
images = batch["image"].to(model_device)         # (batch, 1, H, W, depth)
seq = torch.zeros([TRAINING_BATCH_SIZE, NUM_IMAGES_3D, 256], device=model_device)
for i in range(TRAINING_BATCH_SIZE):
    for j in range(NUM_IMAGES_3D):
        seq[i, j, :] = model(images[i, :, :, :, j].unsqueeze(dim=0))


In [16]:
# Shape after adding a leading axis of size 1 (`seq` itself is not modified).
seq.unsqueeze(dim=0).shape

torch.Size([1, 4, 64, 256])

In [15]:
# Shape of the (batch, depth, features) sequence produced by the ResNet cell.
seq.shape

torch.Size([4, 64, 256])

In [None]:
# FIXME(review): fails as written. `loss` is not defined in the visible notebook,
# and `seq` from the ResNet cell is (4, 64, 256) = 65536 values, which cannot be
# reshaped to (1, 64, 1000).
loss(lstm(seq.reshape(1,64,1000).to(device)), torch.tensor([1],device=device))

In [None]:
# Scratch: a 0-d tensor; note `torch.tensor([1])` would instead have shape (1,).
torch.tensor(1)

In [None]:
def _encode_and_classify(model, lstm, images):
    """
    Run the 2D CNN on every depth slice, then the LSTM over the slice features.

    ``images`` is ``(batch, channels, H, W, depth)``. The CNN receives
    ``(batch * depth, channels, H, W)`` and must return one feature vector per
    slice. Returns whatever ``lstm`` returns for ``(batch, depth, n_features)``.
    """
    batch_size, channels, height, width, depth = images.shape
    slices = images.permute(0, 4, 1, 2, 3).reshape(batch_size * depth, channels, height, width)
    features = model(slices).reshape(batch_size, depth, -1)
    return lstm(features)


def train_monai(model, lstm, optimizer, train_loader, val_loader, loss_fucntion=None, device=device, epochs=20,
                loss_function=None):
    """
    Train a per-slice CNN encoder + LSTM classifier, validating every 2 epochs.

    Parameters
    ----------
    model : nn.Module
        2D CNN applied to each depth slice; maps ``(N, C, H, W)`` to ``(N, n_features)``.
    lstm : nn.Module
        Sequence classifier taking ``(batch, depth, n_features)`` and returning
        ``(batch, num_classes)`` logits.
    optimizer : torch.optim.Optimizer
        Presumably covers the parameters of both ``model`` and ``lstm``.
    train_loader, val_loader : DataLoader
        Yield dicts with ``"image"`` ``(batch, C, H, W, depth)`` and ``"target"``.
    loss_fucntion : callable
        ``loss(logits, labels)``. The misspelled name is kept for backward
        compatibility; ``loss_function`` may be passed instead.
    device : torch.device
        Where batches are moved before the forward pass.
    epochs : int
        Number of training epochs (previously ignored in favour of a fixed 5).
    loss_function : callable, optional
        Correctly spelled alias for ``loss_fucntion``.

    Returns
    -------
    tuple
        ``(best_metric, best_metric_epoch)``; ``(-1, -1)`` if validation never ran.
    """
    # Previously used without being imported.
    from torch.utils.tensorboard import SummaryWriter

    loss_fn = loss_function if loss_function is not None else loss_fucntion
    if loss_fn is None:
        raise TypeError("train_monai() requires a loss function (loss_fucntion / loss_function)")

    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1   # defined up front so the final report works without validation
    steps_per_epoch = len(train_loader.dataset) // train_loader.batch_size

    writer = SummaryWriter()
    try:
        for epoch in range(epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{epochs}")
            model.train()
            lstm.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs = batch_data["image"].to(device)
                labels = batch_data["target"].to(device)
                optimizer.zero_grad()
                outputs = _encode_and_classify(model, lstm, inputs)
                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{steps_per_epoch}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), steps_per_epoch * epoch + step)
            epoch_loss /= max(step, 1)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                lstm.eval()
                with torch.no_grad():
                    num_correct = 0.0
                    metric_count = 0
                    for val_data in val_loader:
                        # Dataset items are dicts, not (image, label) tuples.
                        val_images = val_data["image"].to(device)
                        val_labels = val_data["target"].to(device)
                        val_outputs = _encode_and_classify(model, lstm, val_images)
                        value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                        metric_count += len(value)
                        num_correct += value.sum().item()
                    metric = num_correct / max(metric_count, 1)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), f"{best_metric}_best_metric_model_classification3d_array.pth")
                        # The LSTM head is part of the classifier too; save it alongside the CNN.
                        torch.save(lstm.state_dict(), f"{best_metric}_best_metric_lstm_classification3d_array.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_accuracy", metric, epoch + 1)
    finally:
        # Close the writer even if training raises, so event files are flushed.
        writer.close()
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    return best_metric, best_metric_epoch

In [None]:

# FIXME(review): `x` is not defined at top level anywhere in the visible notebook,
# so this cell fails on a fresh kernel; remove it or define `x` first.
x[0] = torch.tensor([1])