In [None]:
!pip install monai pydicom



In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Mount Google Drive at /content/drive (Colab-only API) so that the
# /content/drive/MyDrive/... paths configured below can be read.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from glob import glob
import matplotlib.pyplot as plt
import monai
from monai.config import KeysCollection
from monai.data import (
    Dataset,
    PersistentDataset,
    DataLoader
)
from monai.transforms.compose import MapTransform
import monai.transforms as transforms
from monai.transforms import (
    Compose,
    PadListDataCollate,
    Spacing,
    Pad, Resize,
    adaptor,
    LoadImaged,
    EnsureChannelFirstd,
    SpatialCropd,
    SpatialPadd,
    ScaleIntensityd,
    RandRotate90d,
    RandRotated,
    RandAxisFlipd,
    Resized,
    NormalizeIntensityd,
    ToTensord,
)
from monai.utils import set_determinism
import numpy as np
import pandas as pd
from pathlib import Path
import pydicom
import random
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import GroupKFold, train_test_split
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from tqdm.notebook import tqdm as tqdm_nb

# Show full cell contents and every row when a DataFrame is displayed.
# `None` is the documented "no limit" value for both options; the former
# sentinel -1 for max_colwidth is deprecated and may be rejected outright by
# newer pandas releases.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_rows", None)

  pd.set_option("display.max_colwidth", -1)


In [None]:
class Config:
    """Run-wide settings: seed, data locations, image/batch sizes and label columns."""

    # Reproducibility
    SEED = 42

    # Data locations on the mounted Drive
    BASE_PATH = "/content/drive/MyDrive/rsna_data"
    TEST_PATH = "/content/drive/MyDrive/png_test_lyd"

    # Values interpolated into the image glob pattern
    IMAGE_SIZE = [256, 256]
    PARAMETER = 'g0.5'

    # Training loop
    BATCH_SIZE = 1
    EPOCHS = 5

    # Label columns, in the order they are stacked into the label tensor
    TARGET_COLS = [
        "bowel_injury",
        "extravasation_injury",
        "kidney_healthy",
        "kidney_low",
        "kidney_high",
        "liver_healthy",
        "liver_low",
        "liver_high",
        "spleen_healthy",
        "spleen_low",
        "spleen_high",
    ]

In [None]:
# Single settings instance; later cells read their parameters from it.
config = Config()

In [None]:
# Seed NumPy's global RNG and hand the same seed to MONAI's set_determinism.
# NOTE(review): set_determinism presumably also seeds Python's `random` and
# torch — confirm against the installed MONAI version.
random_seed = config.SEED
np.random.seed(random_seed)
set_determinism(random_seed)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
# Pull the frequently used settings out of `config` into module-level names.
BASE_PATH, TEST_PATH = config.BASE_PATH, config.TEST_PATH
IMAGE_SIZE, PARAMETER = config.IMAGE_SIZE, config.PARAMETER
# IMAGE_SIZE holds two entries, unpacked here in (width, height) order.
width, height = IMAGE_SIZE

In [None]:
# CSV file names, relative to BASE_PATH.
# NOTE(review): only train_csv_path and test_csv_path are read in the visible
# cells; image_level_labels_path is not referenced again in this section.
image_level_labels_path = "image_level_labels.csv"
train_csv_path = "train.csv"
test_csv_path = "sample_submission.csv"

In [None]:
# Load the training label table (later merged onto the image table by
# patient_id) and the sample-submission file from BASE_PATH.
train_csv = pd.read_csv(f'{BASE_PATH}/{train_csv_path}')
test_csv = pd.read_csv(f'{BASE_PATH}/{test_csv_path}')

In [None]:
# collect the files for train
# NOTE(review): IMAGE_SIZE is a list, so it is interpolated as "[256, 256]",
# which glob treats as a character class (one of '2', '5', '6', ',', ' ')
# rather than a literal size. Directories such as "img_225x225_d1_g0.5" match
# because their name contains one of those characters, not because of their
# size — confirm the intended size filter.
# NOTE(review): the images are globbed from TEST_PATH although they are used
# as training data below — confirm this is the intended directory.
train_img_paths = glob(f'{TEST_PATH}/*/*/*{IMAGE_SIZE}*{PARAMETER}/*.png')
print(f'Total number of images {len(train_img_paths)}')

Total number of images 12512


In [None]:
# One row per image. The patient and series ids are the directory names four
# and three components from the end of the path
# (.../<patient_id>/<series_id>/<image dir>/<file>.png). The configured
# width/height are attached as constants, then the label table is joined on
# patient_id (inner join: images without a label row are dropped).
# dataframe = dataframe.drop_duplicates()
dataframe = (
    pd.DataFrame(train_img_paths, columns=["image_path"])
    .assign(
        patient_id=lambda frame: frame.image_path.map(lambda p: p.split('/')[-4]).astype(int),
        series_id=lambda frame: frame.image_path.map(lambda p: p.split('/')[-3]).astype(int),
        width=width,
        height=height,
    )
    .merge(train_csv, on='patient_id', how='inner')
)

In [None]:
# Check row count, dtypes and non-null counts after the merge.
dataframe.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 12512 entries, 0 to 12511
Data columns (total 19 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   image_path             12512 non-null  object
 1   patient_id             12512 non-null  int64 
 2   series_id              12512 non-null  int64 
 3   width                  12512 non-null  int64 
 4   height                 12512 non-null  int64 
 5   bowel_healthy          12512 non-null  int64 
 6   bowel_injury           12512 non-null  int64 
 7   extravasation_healthy  12512 non-null  int64 
 8   extravasation_injury   12512 non-null  int64 
 9   kidney_healthy         12512 non-null  int64 
 10  kidney_low             12512 non-null  int64 
 11  kidney_high            12512 non-null  int64 
 12  liver_healthy          12512 non-null  int64 
 13  liver_low              12512 non-null  int64 
 14  liver_high             12512 non-null  int64 
 15  spleen_healthy     

In [None]:
# Peek at the first rows to sanity-check the parsed ids and merged labels.
dataframe.head(5)

Unnamed: 0,image_path,patient_id,series_id,width,height,bowel_healthy,bowel_injury,extravasation_healthy,extravasation_injury,kidney_healthy,kidney_low,kidney_high,liver_healthy,liver_low,liver_high,spleen_healthy,spleen_low,spleen_high,any_injury
0,/content/drive/MyDrive/png_test_lyd/10004/21057/img_225x225_d1_g0.5/image_22.png,10004,21057,256,256,1,0,0,1,0,1,0,1,0,0,0,0,1,1
1,/content/drive/MyDrive/png_test_lyd/10004/21057/img_225x225_d1_g0.5/image_23.png,10004,21057,256,256,1,0,0,1,0,1,0,1,0,0,0,0,1,1
2,/content/drive/MyDrive/png_test_lyd/10004/21057/img_225x225_d1_g0.5/image_24.png,10004,21057,256,256,1,0,0,1,0,1,0,1,0,0,0,0,1,1
3,/content/drive/MyDrive/png_test_lyd/10004/21057/img_225x225_d1_g0.5/image_25.png,10004,21057,256,256,1,0,0,1,0,1,0,1,0,0,0,0,1,1
4,/content/drive/MyDrive/png_test_lyd/10004/21057/img_225x225_d1_g0.5/image_26.png,10004,21057,256,256,1,0,0,1,0,1,0,1,0,0,0,0,1,1


In [None]:
# Function to handle the split for each group
def split_group(group, test_size=0.2):
    """Split one group of rows into (train, validation) parts.

    Args:
        group: DataFrame holding the rows of a single group.
        test_size: Fraction of rows assigned to the validation part. For a
            single-row group it is the probability that the row goes to
            validation.

    Returns:
        A ``(train, validation)`` pair of DataFrames, in the same order as
        ``train_test_split`` returns them. One side is an empty DataFrame for
        single-row groups.
    """
    if len(group) == 1:
        # A single row cannot be split: send it to validation with probability
        # `test_size`, otherwise keep it for training, so singletons follow
        # the same train/validation proportions as the larger groups.
        if np.random.rand() < test_size:
            return pd.DataFrame(), group
        return group, pd.DataFrame()
    return train_test_split(group, test_size=test_size, random_state=42)

# Split every label-combination group separately, collecting the pieces and
# concatenating once at the end (growing a DataFrame with pd.concat inside the
# loop re-copies all accumulated rows on every iteration).
train_parts, val_parts = [], []
for _, group in dataframe.groupby(config.TARGET_COLS):
    train_group, val_group = split_group(group)
    # Single-row groups yield an empty frame on one side; leave those out.
    if len(train_group):
        train_parts.append(train_group)
    if len(val_group):
        val_parts.append(val_group)

# pd.concat raises on an empty list, so fall back to an empty frame.
train_data = pd.concat(train_parts, ignore_index=True) if train_parts else pd.DataFrame()
val_data = pd.concat(val_parts, ignore_index=True) if val_parts else pd.DataFrame()

# NOTE(review): rows are individual images, so images of one patient can land
# on both sides of the split — consider a patient-grouped split (GroupKFold is
# already imported) if that leakage matters.
train_data.shape, val_data.shape

((10009, 19), (2503, 19))

In [None]:

# Confirm the split kept every column of the merged table.
train_data.columns

Index(['image_path', 'patient_id', 'series_id', 'width', 'height',
       'bowel_healthy', 'bowel_injury', 'extravasation_healthy',
       'extravasation_injury', 'kidney_healthy', 'kidney_low', 'kidney_high',
       'liver_healthy', 'liver_low', 'liver_high', 'spleen_healthy',
       'spleen_low', 'spleen_high', 'any_injury'],
      dtype='object')

In [None]:
# Gather the image paths and the matching label rows for both splits
print("[INFO] Building the dataset...")

train_paths = train_data.image_path.values
valid_paths = val_data.image_path.values

# Label matrices as float32 arrays: one row per image, one column per target
train_labels = train_data[config.TARGET_COLS].values.astype(np.float32)
valid_labels = val_data[config.TARGET_COLS].values.astype(np.float32)

# Convert to float tensors and insert a singleton axis at position 1,
# turning (N, n_targets) into (N, 1, n_targets)
train_labels = torch.tensor(train_labels, dtype=torch.float).unsqueeze(1)
valid_labels = torch.tensor(valid_labels, dtype=torch.float).unsqueeze(1)

[INFO] Building the dataset...


In [None]:
# Pair every image path with its label tensor as an {"image", "label"} record,
# one list for training and one for validation
train_dicts = [
    dict(image=path, label=target) for path, target in zip(train_paths, train_labels)
]
valid_dicts = [
    dict(image=path, label=target) for path, target in zip(valid_paths, valid_labels)
]

In [None]:
class RemoveAlphaChannel(transforms.Transform):
    """Dictionary transform that strips a trailing alpha channel from ``img['image']``.

    When the last axis of the image has exactly four entries, only the first
    three are kept; any other image is passed through untouched. The input
    mapping is modified in place and returned.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, img):
        pixels = img['image']
        # Only a 4-entry last axis is treated as carrying an alpha channel.
        if pixels.shape[-1] != 4:
            return img
        img['image'] = pixels[..., :3]
        return img

def transform():
    """Build the image preprocessing pipeline.

    Returns:
        A ``Compose`` that, in order: loads the "image" entry as float64,
        drops an alpha channel if present, applies ``EnsureChannelFirstd``
        and resizes to 224x224 with nearest-neighbour interpolation.
    """
    steps = [
        LoadImaged(keys='image', dtype=np.float64),
        RemoveAlphaChannel(),
        EnsureChannelFirstd(keys='image'),
        Resized(keys="image", mode="nearest", spatial_size=(224, 224)),
    ]
    return Compose(steps)

In [None]:
# Wrap the record lists in MONAI datasets; each gets its own pipeline instance.
train_ds = Dataset(data=train_dicts, transform=transform())
valid_ds = Dataset(data=valid_dicts, transform=transform())

# The training rows were concatenated label group by label group, so without
# shuffling the optimizer would see long runs of identical labels. Reshuffle
# the training data every epoch; validation order does not matter.
train_loader = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True,
                          num_workers=0, pin_memory=True)
valid_loader = DataLoader(valid_ds, batch_size=config.BATCH_SIZE,
                          num_workers=0, pin_memory=True)



In [None]:
# Pull a single batch from the training loader to inspect the pipeline output.
first_batch = monai.utils.misc.first(train_loader)
inputs = first_batch['image']

In [None]:
# Shape of the image batch produced by the pipeline.
inputs.shape

torch.Size([1, 3, 224, 224])

In [None]:
# For each label in the batch: its full shape, then entries 2..4, which are
# the kidney_healthy / kidney_low / kidney_high columns of TARGET_COLS.
for label in first_batch['label']:
  print(label.shape)
  print(label[..., 2:5])

torch.Size([1, 11])
tensor([[1., 0., 0.]])


In [None]:
# Print each image's shape and display its first channel.
for image in inputs:
  print(image.shape)
  plt.imshow(image[0])

  plt.show()

In [None]:
# List the names exposed by torchvision.models.
dir(models)

['AlexNet',
 'AlexNet_Weights',
 'ConvNeXt',
 'ConvNeXt_Base_Weights',
 'ConvNeXt_Large_Weights',
 'ConvNeXt_Small_Weights',
 'ConvNeXt_Tiny_Weights',
 'DenseNet',
 'DenseNet121_Weights',
 'DenseNet161_Weights',
 'DenseNet169_Weights',
 'DenseNet201_Weights',
 'EfficientNet',
 'EfficientNet_B0_Weights',
 'EfficientNet_B1_Weights',
 'EfficientNet_B2_Weights',
 'EfficientNet_B3_Weights',
 'EfficientNet_B4_Weights',
 'EfficientNet_B5_Weights',
 'EfficientNet_B6_Weights',
 'EfficientNet_B7_Weights',
 'EfficientNet_V2_L_Weights',
 'EfficientNet_V2_M_Weights',
 'EfficientNet_V2_S_Weights',
 'GoogLeNet',
 'GoogLeNetOutputs',
 'GoogLeNet_Weights',
 'Inception3',
 'InceptionOutputs',
 'Inception_V3_Weights',
 'MNASNet',
 'MNASNet0_5_Weights',
 'MNASNet0_75_Weights',
 'MNASNet1_0_Weights',
 'MNASNet1_3_Weights',
 'MaxVit',
 'MaxVit_T_Weights',
 'MobileNetV2',
 'MobileNetV3',
 'MobileNet_V2_Weights',
 'MobileNet_V3_Large_Weights',
 'MobileNet_V3_Small_Weights',
 'RegNet',
 'RegNet_X_16GF_Weights'

[Input size for EfficientNet](https://discuss.pytorch.org/t/input-size-for-efficientnet-versions-from-torchvision-models/140525)

In [None]:
# model = models.efficientnet_b0(pretrained=True)
# print(model)

In [None]:
def build_model(pretrained=True, fine_tune=True, num_classes=3):
    """Create an EfficientNet-B5 classifier with a replaced final layer.

    Args:
        pretrained: Load torchvision's default pre-trained weights when True;
            otherwise start from random initialisation.
        fine_tune: When True every parameter stays trainable; when False the
            existing parameters are frozen and only the new final layer is
            trainable.
        num_classes: Number of outputs of the replacement final layer.

    Returns:
        The torchvision EfficientNet-B5 module with ``classifier[1]`` replaced
        by a new ``nn.Linear`` producing ``num_classes`` outputs.
    """
    if pretrained:
        print('[INFO]: Loading pre-trained weights')
    else:
        print('[INFO]: Not loading pre-trained weights')
    # The boolean `pretrained=` argument is deprecated in torchvision; select
    # the weights through the weights enum instead.
    weights = models.EfficientNet_B5_Weights.DEFAULT if pretrained else None
    model = models.efficientnet_b5(weights=weights)

    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
    else:
        print('[INFO]: Freezing hidden layers...')
    trainable = bool(fine_tune)
    for params in model.parameters():
        params.requires_grad = trainable

    # Change the final classification head. Its input width is read from the
    # layer being replaced rather than hard-coded; the new layer is created
    # after the freeze above, so it is always trainable.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=in_features, out_features=num_classes)
    return model

In [None]:
# class CustomMultiHeadModel(nn.Module):
#   def __init__(self, num_classes=[1, 1, 3, 3, 3]):
#     super(CustomMultiHeadModel, self).__init__()

#     # define the backbone
#     self.backbone = models.efficientnet_b5(pretrained=True)

#     # global average pooling (GAP) layer
#     self.gap = nn.AdaptiveAvgPool2d(1)

#     # num_features = self.backbone.classifier.in_features # ?

#     # define 'necks' for each head
#     self.neck = nn.Linear(num_features, 32)
#     self.activation = nn.ReLU()

#     # define heads
#     self.head_bowel = nn.Linear(32, num_classes[0])
#     self.head_extra = nn.Linear(32, num_classes[1])
#     self.head_liver = nn.Linear(32, num_classes[2])
#     self.head_kidney = nn.Linear(32, num_classes[3])
#     self.head_spleen = nn.Linear(32, num_classes[4])

#   def forward(self, x):
#     x = self.backbone(x)
#     x = self.gap(x)
#     x = x.view(x.size(0), -1)
#     x = self.neck(x)
#     x = self.activation(x)

#     out_bowel = self.head_bowel(x)
#     out_extra = self.head_extra(x)
#     out_liver = self.head_liver(x)
#     out_kidney = self.head_kidney(x)
#     out_spleen = self.head_spleen(x)

#     outputs = [out_bowel, out_extra, out_liver, out_kidney, out_spleen]
#     print(outputs.shape)
#     return outputs

In [None]:
# model = CustomMultiHeadModel()

In [None]:
# Build with the defaults: pre-trained weights, all layers trainable, 3 outputs.
model = build_model()

[INFO]: Loading pre-trained weights




[INFO]: Fine-tuning all layers...


In [None]:
#model.save_pretrained("./your_file_name")

In [None]:
learning_rate = 1e-5

# move the model to the selected device

model = model.to(device)

# loss function: cross-entropy over the model outputs

loss_function = torch.nn.CrossEntropyLoss()

# Adam optimizer over all model parameters

optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [None]:
# Entries 2..4 of TARGET_COLS are kidney_healthy / kidney_low / kidney_high:
# the one-hot kidney target that the 3-output model is trained on.
KIDNEY_COLS = slice(2, 5)

best_metric, best_metric_epoch = -1, -1
tr_metric_values, vl_metric_values = list(), list()
tr_epoch_loss_values, vl_epoch_loss_values = list(), list()
state = None


def _prepare_batch(batch):
    """Move one batch to `device` and return (inputs, one-hot kidney targets).

    The targets are flattened to (batch, 3) regardless of the batch size, so
    the loop does not depend on BATCH_SIZE == 1.
    """
    # The pipeline loads images as float64; cast to float32 for the model
    # (a no-op if the batch already is float32).
    inputs = batch['image'].to(device, dtype=torch.float32)
    labels = batch['label'].to(device)[..., KIDNEY_COLS]
    return inputs, labels.reshape(labels.shape[0], -1)


def _evaluate(loader):
    """Run the model over `loader` without gradients.

    Returns:
        ``(mean loss per batch, balanced accuracy)`` over the whole loader.
    """
    total_loss, num_batches = 0.0, 0
    y_pred, y_true = list(), list()
    with torch.no_grad():
        for batch in loader:
            inputs, labels = _prepare_batch(batch)
            outputs = model(inputs)
            total_loss += loss_function(outputs, labels).item()
            num_batches += 1
            # balanced_accuracy_score needs class indices on both sides;
            # feeding one-hot rows as y_true against index predictions raises
            # a ValueError. argmax of the logits equals argmax of the softmax.
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            y_true.extend(torch.argmax(labels, dim=1).cpu().numpy())
    return total_loss / num_batches, balanced_accuracy_score(y_true, y_pred)


for epoch in tqdm_nb(range(1, config.EPOCHS + 1)):
    # --- optimisation pass over the training data ---
    model.train()
    for batch in train_loader:
        inputs, labels = _prepare_batch(batch)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

    # --- end-of-epoch metrics on both splits, in eval mode ---
    model.eval()
    tr_loss, tr_acc = _evaluate(train_loader)
    vl_loss, vl_acc = _evaluate(valid_loader)

    tr_metric_values.append(tr_acc)
    tr_epoch_loss_values.append(tr_loss)
    vl_metric_values.append(vl_acc)
    vl_epoch_loss_values.append(vl_loss)

    # Keep the checkpoint with the best validation balanced accuracy.
    if vl_acc > best_metric:
        best_metric = vl_acc
        best_metric_epoch = epoch
        # state_dict() holds live references to the parameters, so the file
        # written here is the authoritative copy of the best weights.
        state = model.state_dict()
        torch.save(state, 'best_metric_model.pth')

values = {
    'tr_metric' : tr_metric_values,
    'vl_metric' : vl_metric_values,
    'tr_loss' : tr_epoch_loss_values,
    'vl_loss' : vl_epoch_loss_values
}

  0%|          | 0/5 [00:00<?, ?it/s]

ValueError: ignored