In [0]:
# refs 
# https://learnopencv.com/3d-u-net-brats/#aioseo-dataset-preprocessing (WRT BraTS dataset)

# https://nipy.org/nibabel/nifti_images.html

## NIFTI (brain imaging related but not everyone uses it; DICOM may be preferred)
# https://github.com/DataCurationNetwork/data-primers/blob/main/Neuroimaging%20DICOM%20and%20NIfTI%20Data%20Curation%20Primer/neuroimaging-dicom-and-nifti-data-curation-primer.md
# https://discovery.ucl.ac.uk/id/eprint/10146893/1/geometry_medim.pdf 


# test data
# https://www.kaggle.com/datasets/aiocta/brats2023-part-1

In [0]:
# UC path 
# mmt_mlops_demos.cv.data
# /Volumes/mmt_mlops_demos/cv/data/BraTS2021_00495/

In [0]:
## TODO: move the setup below (installs, imports, paths) into a shared utils/config module.

# Notebook dependencies, installed through the shell escape (`!pip`) without
# version pins. The following cell calls dbutils.library.restartPython().
# NOTE(review): consider `%pip install pkg==x.y.z` so the install targets the
# notebook environment and runs stay reproducible -- confirm against the
# cluster runtime before switching.
!pip install nibabel -q
!pip install scikit-learn -q
!pip install tqdm -q
!pip install split-folders -q
!pip install torchinfo -q
!pip install segmentation-models-pytorch-3d -q
!pip install livelossplot -q
!pip install torchmetrics -q
!pip install tensorboard -q

# Presumably for the COCO/YOLO conversion experiments at the end of the
# notebook (these two are installed without -q) -- TODO confirm.
!pip install pycocotools
!pip install opencv-python-headless

In [0]:
dbutils.library.restartPython()

In [0]:
import os
import random
import splitfolders
from tqdm import tqdm
import nibabel as nib
import glob
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import shutil
import time
 
from dataclasses import dataclass
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as transforms
from torch.cuda import amp
 
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassAccuracy
 
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
import gc
 
import segmentation_models_pytorch_3d as smp
 
from livelossplot import PlotLosses
from livelossplot.outputs import MatplotlibPlot, ExtremaPrinter


In [0]:
# https://www.kaggle.com/datasets/aiocta/brats2023-part-1

# Do ONCE
# !pip install kaggle -q
# !kaggle datasets download -d aiocta/brats2023-part-1 -p /Volumes/mmt_mlops_demos/cv/data/BraTS2023/

In [0]:
# DO ONCE 
# !sudo apt install unzip
# !unzip /Volumes/mmt_mlops_demos/cv/data/BraTS2023/brats2023-part-1.zip -d /Volumes/mmt_mlops_demos/cv/data/BraTS2023/BraTS2023-Glioma/

# !rm -rf <UD Vols path to>/brats2023-part-1.zip

In [0]:
# def seed_everything(SEED):
#    np.random.seed(SEED)
#    torch.manual_seed(SEED)
#    torch.cuda.manual_seed_all(SEED)
#    torch.backends.cudnn.deterministic = True
#    torch.backends.cudnn.benchmark = False
 
 
# def get_default_device():
#    gpu_available = torch.cuda.is_available()
#    return torch.device('cuda' if gpu_available else 'cpu'), gpu_available


In [0]:
# @dataclass(frozen=True)
# class TrainingConfig:
#    BATCH_SIZE:      int = 5
#    EPOCHS:          int = 100
#    LEARNING_RATE: float = 1e-3
#    CHECKPOINT_DIR:  str = os.path.join('model_checkpoint', '3D_UNet_Brats2023')
#    NUM_WORKERS:     int = 4

In [0]:
# Folder holding the unzipped BraTS2023 glioma cases on the UC volume.
DATASET_PATH = '/Volumes/mmt_mlops_demos/cv/data/BraTS2023/BraTS2023-Glioma/'

# One shared min-max scaler; every use below calls fit_transform, so it is
# re-fit on each volume it is applied to.
scaler = MinMaxScaler()

dataset_entries = os.listdir(DATASET_PATH)
print("Total Files: ", len(dataset_entries))
# Expected output -> Total Files:  625

In [0]:
# Read the FLAIR (t2f) volume of the first case as a floating point array.
sample_flair_path = os.path.join(DATASET_PATH, "BraTS-GLI-00000-000/BraTS-GLI-00000-000-t2f.nii")
sample_flair_nifti = nib.load(sample_flair_path)
sample_image_flair = sample_flair_nifti.get_fdata()
print("Original max value:", sample_image_flair.max())
# Expected output -> Original max value: 2934.0

# Flatten to a single column so the scaler sees the whole volume as one feature.
sample_image_flair_flat = sample_image_flair.reshape((-1, 1))

In [0]:
sample_image_flair

In [0]:
sample_image_flair_flat

In [0]:
# Min-max scale all voxels together, then restore the original 3D layout.
sample_flair_scaled_flat = scaler.fit_transform(sample_image_flair_flat)
sample_image_flair_scaled = sample_flair_scaled_flat.reshape(sample_image_flair.shape)

print("Scaled max value:", sample_image_flair_scaled.max())
print("Shape of scaled Image: ", sample_image_flair_scaled.shape)

# Expected output ->
# Scaled max value: 1.0
# Shape of scaled Image:  (240, 240, 155)

In [0]:
# Segmentation labels of the same case, cast from float to uint8 class ids.
sample_mask_nifti = nib.load(DATASET_PATH + "BraTS-GLI-00000-000/BraTS-GLI-00000-000-seg.nii")
sample_mask = sample_mask_nifti.get_fdata().astype(np.uint8)

sample_mask_classes = np.unique(sample_mask)
print("Unique class in the mask", sample_mask_classes)
print("Shape of sample_mask: ", sample_mask.shape)

# Expected output ->
# Unique class in the mask [0 1 2 3]
# Shape of sample_mask:  (240, 240, 155)

In [0]:
def _load_sample_modality(suffix):
    """Load one modality file of case BraTS-GLI-00000-000 via nibabel's get_fdata()."""
    return nib.load(
        DATASET_PATH + f"BraTS-GLI-00000-000/BraTS-GLI-00000-000-{suffix}.nii"
    ).get_fdata()


# Raw intensities are kept exactly as loaded (no uint8 cast).
sample_image_t1 = _load_sample_modality("t1n")
sample_image_t1ce = _load_sample_modality("t1c")
sample_image_t2 = _load_sample_modality("t2w")

In [0]:
import numpy as np

# Slice range to sample from: `low` is inclusive, `high` is exclusive.
low, high = 50, 90  # alternative upper bound tried earlier: 141

# Unseeded draw, so the value changes from run to run.
rand_int = np.random.randint(low, high)
print(f"Random integer between {low} and {high}: {rand_int}")

In [0]:
# n_slice = random.randint(0, sample_mask.shape[2])  # random slice between 0 - 154
n_slice = np.random.randint(low, high)  # e.g. 77
print("n_slice: ", n_slice)

# (volume, title, colormap) for each panel in subplot order;
# a colormap of None means "use matplotlib's default".
slice_panels = [
    (sample_image_flair_scaled, 'Image flair', 'gray'),
    (sample_image_t1, "Image t1", "gray"),
    (sample_image_t1ce, "Image t1ce", 'gray'),
    (sample_image_t2, "Image t2", 'gray'),
    (sample_mask, "Seg Mask", None),
    (sample_mask, 'Mask Gray', 'gray'),
]

plt.figure(figsize=(12, 8))
for panel_no, (volume, panel_title, panel_cmap) in enumerate(slice_panels, start=1):
    plt.subplot(2, 3, panel_no)
    plt.imshow(volume[:, :, n_slice], cmap=panel_cmap)
    plt.title(panel_title)
    plt.colorbar()

plt.show()

In [0]:
def _scale_sample_volume(volume):
    """Min-max scale a whole 3D volume to [0, 1] (one min/max for all voxels)."""
    return scaler.fit_transform(volume.reshape(-1, 1)).reshape(volume.shape)


# FLAIR was scaled to [0, 1] in an earlier cell, but T1ce and T2 were loaded raw.
# Scale them the same way before stacking so the three channels share one
# intensity range (raw FLAIR alone peaked at 2934.0).
combined_x = np.stack(
    [
        sample_image_flair_scaled,
        _scale_sample_volume(sample_image_t1ce),
        _scale_sample_volume(sample_image_t2),
    ],
    axis=3,
)  # the modalities become the last (channel) axis
print("Shape of Combined x ", combined_x.shape)

# Expected output -> Shape of Combined x  (240, 240, 155, 3)

In [0]:
## Per https://learnopencv.com/3d-u-net-brats/#aioseo-dataset-preprocessing, the mask regions of interest in this dataset fall within the crop [56:184, 56:184, 13:141]; for other datasets this ROI has to be re-derived from the data.

ROI_dims = [[56, 184], [56, 184], [13, 141]]
ROI_dims

In [0]:
# One slice object per spatial axis, built from the [start, stop] pairs.
roi_slices = tuple(slice(start, stop) for start, stop in ROI_dims)

# Same as combined_x[56:184, 56:184, 13:141]; the channel axis is left whole.
# NOTE: `combined_x` is overwritten, so re-running this cell crops it again.
combined_x = combined_x[roi_slices]
print("Shape after cropping: ", combined_x.shape)

# Same as sample_mask[56:184, 56:184, 13:141].
sample_mask_c = sample_mask[roi_slices]
print("Mask shape after cropping: ", sample_mask_c.shape)

# Expected output ->
# Shape after cropping:  (128, 128, 128, 3)
# Mask shape after cropping:  (128, 128, 128)

In [0]:
len(list(range(56, 185)))

In [0]:
sample_mask[:,:,n_slice]

In [0]:
sample_mask[:,:,n_slice].shape

In [0]:
# The crop starts at ROI_dims[2][0] along the slice axis, so original slice
# `n_slice` sits at index `n_slice - ROI_dims[2][0]` in the cropped arrays.
# (The earlier extra "- 1" displayed the neighbouring slice instead.)
cropped_slice = n_slice - ROI_dims[2][0]

plt.figure(figsize=(6, 4))

plt.subplot(121)
plt.imshow(combined_x[:, :, cropped_slice])  # 3 channels -> rendered as RGB
plt.title("combined_x")

plt.subplot(122)
plt.imshow(sample_mask_c[:, :, cropped_slice])
plt.title("sample_mask_c")

plt.show()

In [0]:
# https://www.synapse.org/Synapse:syn51156910/wiki/622351
# Task: Tumor Sub-region Segmentation
# The participants are called to address this task by using the provided clinically-acquired training data to develop their method and produce segmentation labels of the different glioma sub-regions. The sub-regions considered for evaluation are the "enhancing tumor" (ET), the "tumor core" (TC), and the "whole tumor" (WT) [see figure below]. The ET is described by areas that show hyper-intensity in T1Gd when compared to T1, but also when compared to “healthy” white matter in T1Gd. The TC describes the bulk of the tumor, which is what is typically resected. The TC entails the ET, as well as the necrotic (NCR) parts of the tumor. The appearance of NCR is typically hypo-intense in T1-Gd when compared to T1. The WT describes the complete extent of the disease, as it entails the TC and the peritumoral edematous/invaded tissue (ED), which is typically depicted by hyper-intense signal in FLAIR.

# The provided segmentation labels have values of:

# 1 for NCR
# 2 for ED
# 3 for ET
# 0 for everything else.

In [0]:
# Original slice `n_slice` maps to index `n_slice - ROI_dims[2][0]` after the crop
# (no extra "- 1", which would select the slice before it).
cropped_slice = n_slice - ROI_dims[2][0]

# Dump the chosen mask slice one row at a time, prefixed by the row index.
for s in range(sample_mask_c.shape[0]):
  print(s)
  print(sample_mask_c[s, :, cropped_slice])

In [0]:
## one_hot encoding may not be necessary ------------------

In [0]:
# One-hot encode the cropped label volume; the 4 classes become a trailing axis.
sample_mask_labels = torch.as_tensor(sample_mask_c, dtype=torch.long)
sample_mask_cat = F.one_hot(sample_mask_labels, num_classes=4)

In [0]:
sample_mask_cat.shape

In [0]:
# for s in range(sample_mask_cat.shape[0]):
#   print(s)
#   print(sample_mask_cat[:,s,n_slice-ROI_dims[2][0]-1],2)

In [0]:
# t1ce_list = sorted(glob.glob(f"{DATASET_PATH}/*/*t1c.nii"))
# t2_list = sorted(glob.glob(f"{DATASET_PATH}/*/*t2w.nii"))
# flair_list = sorted(glob.glob(f"{DATASET_PATH}/*/*t2f.nii"))
# mask_list = sorted(glob.glob(f"{DATASET_PATH}/*/*seg.nii"))


# print("t1ce list: ", len(t1ce_list))
# print("t2 list: ", len(t2_list))
# print("flair list: ", len(flair_list))
# print("Mask list: ", len(mask_list))

# # t1ce list:  625
# # t2 list:  625
# # flair list:  625
# # Mask list:  625

In [0]:
# import json

# # Save the lists to UC volume
# dbutils.fs.put("/Volumes/mmt_mlops_demos/cv/data/BraTS2023/t1ce_list.json", json.dumps(t1ce_list), overwrite=True)
# dbutils.fs.put("/Volumes/mmt_mlops_demos/cv/data/BraTS2023/t2_list.json", json.dumps(t2_list), overwrite=True)
# dbutils.fs.put("/Volumes/mmt_mlops_demos/cv/data/BraTS2023/flair_list.json", json.dumps(flair_list), overwrite=True)
# dbutils.fs.put("/Volumes/mmt_mlops_demos/cv/data/BraTS2023/mask_list.json", json.dumps(mask_list), overwrite=True)

In [0]:
import json

# Folder on the UC volume where the per-modality path lists are cached as JSON.
LIST_CACHE_DIR = "/Volumes/mmt_mlops_demos/cv/data/BraTS2023"


def _load_path_list(list_name, file_suffix):
    """Return the sorted NIfTI paths for one modality, cached as JSON on the volume.

    Args:
        list_name: Base name of the cache file, e.g. "t1ce_list".
        file_suffix: File-name ending before ".nii", e.g. "t1c".

    Returns:
        list[str]: The cached paths if the JSON file exists; otherwise the paths
        found by globbing DATASET_PATH, which are then written to the cache.
    """
    cache_path = f"{LIST_CACHE_DIR}/{list_name}.json"
    if os.path.exists(cache_path):
        # Read the whole file; dbutils.fs.head() would cut it off at its byte limit.
        with open(cache_path, "r") as cache_file:
            return json.load(cache_file)

    # First run on a fresh volume: build the list and cache it for later runs.
    paths = sorted(glob.glob(f"{DATASET_PATH}/*/*{file_suffix}.nii"))
    with open(cache_path, "w") as cache_file:
        json.dump(paths, cache_file)
    return paths


t1ce_list = _load_path_list("t1ce_list", "t1c")
t2_list = _load_path_list("t2_list", "t2w")
flair_list = _load_path_list("flair_list", "t2f")
mask_list = _load_path_list("mask_list", "seg")

# Print the lengths of the lists
print("t1ce list: ", len(t1ce_list))
print("t2 list: ", len(t2_list))
print("flair list: ", len(flair_list))
print("Mask list: ", len(mask_list))

# The lists are consumed position by position, so they must line up.
if not (len(t1ce_list) == len(t2_list) == len(flair_list) == len(mask_list)):
    raise ValueError("Modality and mask path lists have different lengths")

In [0]:
# Next steps: continue with the preprocessing for the standard PyTorch pipeline, then try converting the result to COCO format.

In [0]:
## DATASET Preprocessing test to pytorch dataloader 
# -- we will need to see how to reformat to coco/yolo friendly format 

In [0]:
# '/'.join(f"{DATASET_PATH}".split("/")[:-2])
UCV_folderpath =  "/Volumes/mmt_mlops_demos/cv/data/BraTS2023/"
UCV_subfoldername = "BraTS2023_Preprocessed"
# UCV_subfoldername = "BraTS2023_Preprocessed_v2" ## without one-hot-mask encoding

In [0]:
## Do Once -- cases whose image and mask files already exist are skipped, so a
## re-run only processes what is still missing.

# Output folders on the UC volume (created up front instead of once per case).
images_out_dir = f"{UCV_folderpath}{UCV_subfoldername}/input_data_3channels/images"
masks_out_dir = f"{UCV_folderpath}{UCV_subfoldername}/input_data_3channels/masks"
os.makedirs(images_out_dir, exist_ok=True)
os.makedirs(masks_out_dir, exist_ok=True)

# The four lists are walked in lock-step; a length mismatch would pair the wrong files.
if not (len(t1ce_list) == len(t2_list) == len(flair_list) == len(mask_list)):
    raise ValueError("Modality and mask path lists have different lengths")

# Spatial crop taken from ROI_dims, i.e. [56:184, 56:184, 13:141].
roi_crop = tuple(slice(start, stop) for start, stop in ROI_dims)


def _load_scaled_volume(nifti_path):
    """Load one NIfTI volume and min-max scale it with the shared `scaler`.

    NOTE(review): reshaping to (-1, n_slices) makes MinMaxScaler treat each slice
    along the last axis as its own feature, so every slice is scaled to [0, 1]
    independently. The single-case exploration earlier reshapes to (-1, 1), which
    uses one min/max per volume. The per-slice behaviour is kept unchanged here;
    confirm which one is intended.
    """
    volume = nib.load(nifti_path).get_fdata()
    return scaler.fit_transform(
        volume.reshape(-1, volume.shape[-1])
    ).reshape(volume.shape)


for idx, (t1ce_path, t2_path, flair_path, mask_path) in tqdm(
    enumerate(zip(t1ce_list, t2_list, flair_list, mask_list)),
    total=len(t2_list),
    desc="Preparing to stack, crop and save",
    unit="file",
):
    image_out = f"{images_out_dir}/image_{idx}.npy"
    mask_out = f"{masks_out_dir}/mask_{idx}.npy"
    if os.path.exists(image_out) and os.path.exists(mask_out):
        continue  # already written by a previous run

    temp_mask = nib.load(mask_path).get_fdata()[roi_crop]

    # If a volume has less than 1% of mask, we simply ignore it to reduce computation.
    # Counting non-zero voxels does not rely on label 0 being present in the crop.
    # The check runs before the three modalities are loaded and scaled.
    if np.count_nonzero(temp_mask) / temp_mask.size <= 0.01:
        continue

    # Channel order is FLAIR, T1ce, T2 (last axis), cropped to the ROI.
    temp_combined_images = np.stack(
        [
            _load_scaled_volume(flair_path),
            _load_scaled_volume(t1ce_path),
            _load_scaled_volume(t2_path),
        ],
        axis=3,
    )[roi_crop]

    # The "_v2" output folder stores plain label volumes; everything else is one-hot.
    if UCV_subfoldername != "BraTS2023_Preprocessed_v2":
        # .numpy() hands np.save an ndarray rather than a torch tensor.
        # NOTE(review): the one-hot mask is stored as int64; a smaller dtype would
        # shrink the files but changes what the dataset returns -- confirm first.
        temp_mask = F.one_hot(
            torch.tensor(temp_mask, dtype=torch.long), num_classes=4
        ).numpy()

    np.save(image_out, temp_combined_images)
    np.save(mask_out, temp_mask)

In [0]:
###

In [0]:
# compare tqdm vs vectorized pandasUDF for processing

In [0]:
# from pyspark.sql.functions import pandas_udf, PandasUDFType
# import pandas as pd
# import numpy as np
# import nibabel as nib
# import os
# import torch
# import torch.nn.functional as F
# from typing import List

# # Define the UDF
# @pandas_udf("string", PandasUDFType.SCALAR)
# def process_images(t1ce_path: pd.Series, 
#                    t2_path: pd.Series, 
#                    flair_path: pd.Series, 
#                    mask_path: pd.Series, 
#                    idx: pd.Series
#                    ) -> pd.Series:
#     results: List[str] = []
#     for i in range(len(t1ce_path)):
#         temp_image_t1ce: np.ndarray = nib.load(t1ce_path[i]).get_fdata()
#         temp_image_t1ce = scaler.fit_transform(
#             temp_image_t1ce.reshape(-1, temp_image_t1ce.shape[-1])
#         ).reshape(temp_image_t1ce.shape)

#         temp_image_t2: np.ndarray = nib.load(t2_path[i]).get_fdata()
#         temp_image_t2 = scaler.fit_transform(
#             temp_image_t2.reshape(-1, temp_image_t2.shape[-1])
#         ).reshape(temp_image_t2.shape)

#         temp_image_flair: np.ndarray = nib.load(flair_path[i]).get_fdata()
#         temp_image_flair = scaler.fit_transform(
#             temp_image_flair.reshape(-1, temp_image_flair.shape[-1])
#         ).reshape(temp_image_flair.shape)

#         temp_mask: np.ndarray = nib.load(mask_path[i]).get_fdata()

#         temp_combined_images: np.ndarray = np.stack(
#             [temp_image_flair, temp_image_t1ce, temp_image_t2], axis=3
#         )

#         temp_combined_images = temp_combined_images[56:184, 56:184, 13:141]
#         temp_mask = temp_mask[56:184, 56:184, 13:141]

#         val, counts = np.unique(temp_mask, return_counts=True)

#         # If a volume has less than 1% of mask, we simply ignore to reduce computation
#         if (1 - (counts[0] / counts.sum())) > 0.01:
#             if UCV_subfoldername != "BraTS2023_Preprocessed_v2":
#                 temp_mask = F.one_hot(torch.tensor(temp_mask, dtype=torch.long), num_classes=4).numpy()
            
#             images_dir: str = f"{UCV_folderpath}{UCV_subfoldername}/input_data_3channels/images"
#             masks_dir: str = f"{UCV_folderpath}{UCV_subfoldername}/input_data_3channels/masks"
#             os.makedirs(images_dir, exist_ok=True)
#             os.makedirs(masks_dir, exist_ok=True)

#             np.save(
#                 f"{images_dir}/image_" + str(idx[i]) + ".npy",
#                 temp_combined_images,
#             )
#             np.save(
#                 f"{masks_dir}/mask_" + str(idx[i]) + ".npy",
#                 temp_mask,
#             )
#             results.append("Processed")
#         else:
#             results.append("Skipped")
#     return pd.Series(results)

# # Create a Spark DataFrame
# data = list(zip(t1ce_list, t2_list, flair_list, mask_list, range(len(t1ce_list))))
# columns = ["t1ce_path", "t2_path", "flair_path", "mask_path", "idx"]
# df = spark.createDataFrame(data, columns)

# # Apply the UDF
# df = df.withColumn("result", process_images("t1ce_path", "t2_path", "flair_path", "mask_path", "idx"))

# display(df)

In [0]:
# Count what the preprocessing step wrote to the volume.
preprocessed_root = f"{UCV_folderpath}{UCV_subfoldername}/input_data_3channels"

images_folder = f"{preprocessed_root}/images"
masks_folder = f"{preprocessed_root}/masks"

# Images first, then masks.
for folder in (images_folder, masks_folder):
    print(len(os.listdir(folder)))

# Expected output ->
# Images: 575
# Masks: 575

In [0]:

input_folder = f"{UCV_folderpath}{UCV_subfoldername}/input_data_3channels/"

output_folder = f"{UCV_folderpath}{UCV_subfoldername}/input_data_128/"

## Do once: the split is only created when the output folder has no `train`
## directory yet, so re-running the notebook does not repeat it.
# NOTE: delete `output_folder` first to force a fresh split (e.g. after the
# preprocessed data changed or a previous split was interrupted).
if os.path.isdir(os.path.join(output_folder, "train")):
    print(f"{output_folder} already contains a split; skipping splitfolders.ratio")
else:
    splitfolders.ratio(
        input_folder, output_folder, seed=42, ratio=(0.75, 0.25), group_prefix=None
        # input_folder, output_folder, seed=42, ratio=(0.7, 0.2, 0.1), group_prefix=None
    )

In [0]:
# if os.path.exists(input_folder):
#     shutil.rmtree(input_folder)
#     print(f"{input_folder} is removed")
# else:
#     print(f"{input_folder} doesn't exist")

In [0]:
class BraTSDataset(Dataset):
    """Dataset of preprocessed BraTS volumes stored as paired ``.npy`` files.

    Images and masks are paired by position in the alphabetically sorted
    directory listings, so both folders must contain matching files.

    Args:
        img_dir: Folder with image arrays; each array is 4D with channels last
            (it is permuted with ``(3, 2, 0, 1)``).
        mask_dir: Folder with mask arrays, also 4D with the class axis last.
        normalization: When True, images are passed through
            ``transforms.Normalize(mean=[0.5], std=[0.5])``.

    Raises:
        ValueError: If the two folders hold a different number of files.
    """

    def __init__(self, img_dir, mask_dir, normalization=True):
        super().__init__()

        self.img_dir = img_dir
        self.mask_dir = mask_dir
        # Sort both listings so that index i refers to the same case in each.
        self.img_list = sorted(os.listdir(img_dir))
        self.mask_list = sorted(os.listdir(mask_dir))
        self.normalization = normalization

        # Fail early: with unequal counts, positional pairing would either raise
        # IndexError later or silently pair an image with the wrong mask.
        if len(self.img_list) != len(self.mask_list):
            raise ValueError(
                f"Found {len(self.img_list)} images in {img_dir} but "
                f"{len(self.mask_list)} masks in {mask_dir}"
            )

        # If normalization is True, set up a normalization transform
        if self.normalization:
            self.normalizer = transforms.Normalize(
                mean=[0.5], std=[0.5]
            )  # Adjust mean and std based on your data

    def load_file(self, filepath):
        """Load a single ``.npy`` array from disk."""
        return np.load(filepath)

    def __len__(self):
        """Number of image files found in ``img_dir``."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair at position ``idx``.

        Both arrays are permuted from channels-last to channels-first with
        ``(3, 2, 0, 1)``, i.e. (C, D, H, W) as needed by 3D models. The dtypes
        are whatever was stored in the ``.npy`` files.
        """
        image_path = os.path.join(self.img_dir, self.img_list[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_list[idx])

        # Load the image and mask
        image = self.load_file(image_path)
        mask = self.load_file(mask_path)

        # Convert to torch tensors and move the last axis to the front.
        image = torch.from_numpy(image).permute(3, 2, 0, 1)  # Shape: C, D, H, W
        mask = torch.from_numpy(mask).permute(3, 2, 0, 1)  # Shape: C, D, H, W

        # Alternatives for masks saved without one-hot encoding (3D label volumes):
        # mask = torch.from_numpy(mask).permute(2, 0, 1)  # Shape: D, H, W
        # mask = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)  # Shape: 1, D, H, W

        # Normalize the image if normalization is enabled
        if self.normalization:
            image = self.normalizer(image)

        return image, mask

In [0]:
# Root of the train/val split produced by splitfolders.
split_root = f"{UCV_folderpath}{UCV_subfoldername}/input_data_128"

train_img_dir = f"{split_root}/train/images"
train_mask_dir = f"{split_root}/train/masks"

val_img_dir = f"{split_root}/val/images"
val_mask_dir = f"{split_root}/val/masks"

# Raw (unsorted) listings of the validation files.
val_img_list = os.listdir(val_img_dir)
val_mask_list = os.listdir(val_mask_dir)

# Both splits use the dataset's built-in normalization and nothing else.
train_dataset = BraTSDataset(train_img_dir, train_mask_dir, normalization=True)
val_dataset = BraTSDataset(val_img_dir, val_mask_dir, normalization=True)

# Report the size of each split.
print("Total Training Samples: ", len(train_dataset))
print("Total Val Samples: ", len(val_dataset))

# Expected output ->
# Total Training Samples:  431
# Total Val Samples:  144

In [0]:
# train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=5, shuffle=False, num_workers=4)

# # Sanity Check
# images, masks = next(iter(train_loader))
# print(f"Train Image batch shape: {images.shape}")
# print(f"Train Mask batch shape: {masks.shape}")

# # Train Image batch shape: torch.Size([5, 3, 128, 128, 128])
# # Train Mask batch shape: torch.Size([5, 4, 128, 128, 128])


# # Train Image batch shape: torch.Size([5, 3, 128, 128, 128])
# # Train Mask batch shape: torch.Size([5, 1, 128, 128, 128])

# # Train Image batch shape: torch.Size([5, 3, 128, 128, 128])
# # Train Mask batch shape: torch.Size([5, 128, 128, 128])

In [0]:
## redo 

In [0]:
# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=5, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Sanity check: pull one training batch and report the tensor shapes.
images, masks = next(iter(train_loader))
for label, batch in (("Image", images), ("Mask", masks)):
    print(f"Train {label} batch shape: {batch.shape}")

# # Adjust mask dimensions if necessary
# if masks.dim() == 3:
#     masks = masks.unsqueeze(0)  # Add a channel dimension if missing
# elif masks.dim() == 4:
#     masks = masks.permute(3, 2, 0, 1)  # Shape: C, D, H, W


# Train Image batch shape: torch.Size([5, 3, 128, 128, 128])
# Train Mask batch shape: torch.Size([5, 4, 128, 128, 128])

# Train Image batch shape: torch.Size([5, 3, 128, 128, 128])
# Train Mask batch shape: torch.Size([5, 1, 128, 128, 128])

# Train Image batch shape: torch.Size([5, 3, 128, 128, 128])
# Train Mask batch shape: torch.Size([5, 128, 128, 128])

In [0]:
# Print, row by row, the plane at [sample 3, channel 0, index 64] of the mask batch.
sample_plane = masks[3, 0, 64]
for i in range(128):
  print(i)
  print(sample_plane[i, :])

In [0]:
# masks.shape 
# torch.Size([5, 4, 128, 128, 128])

# torch.argmax(masks, dim=1).shape
# torch.Size([5, 128, 128, 128])

# Collapse the class axis once, outside the loop, instead of re-running argmax
# over the whole batch for every printed row.
label_volumes = torch.argmax(masks, dim=1)

# Print, row by row, the label plane at [sample 1, index 64].
for i in range(128):
  print(i)
  print(label_volumes[1, 64, i, :])

In [0]:
def visualize_slices(images, masks, num_slices=20):
    """Plot the middle slice of each modality and of the mask for a batch.

    Args:
        images: Tensor indexed as (batch, channel, depth, H, W). The channels
            follow the stacking order used during preprocessing: FLAIR, T1ce, T2.
        masks: Either one-hot masks (batch, classes, depth, H, W), single-channel
            label masks (batch, 1, depth, H, W), or label volumes
            (batch, depth, H, W).
        num_slices: Maximum number of samples from the batch to plot (one figure
            per sample).
    """
    if masks.dim() == 5:
        if masks.shape[1] == 1:
            # Single-channel label volume: just drop the channel axis.
            masks = masks.squeeze(1)
        else:
            # One-hot encoded: the index of the hot channel is the class label.
            masks = torch.argmax(masks, dim=1)

    sample_count = min(num_slices, images.shape[0])
    middle_slice = images.shape[2] // 2
    # Titles must follow the channel order of the saved arrays (FLAIR first).
    modality_titles = ("FLAIR", "T1ce", "T2")

    for i in range(sample_count):
        fig, ax = plt.subplots(1, 5, figsize=(15, 5))

        for channel, title in enumerate(modality_titles):
            ax[channel].imshow(images[i, channel, middle_slice, :, :], cmap="gray")
            ax[channel].set_title(title)

        ax[3].imshow(masks[i, middle_slice, :, :], cmap="viridis")
        ax[3].set_title("Seg Mask")
        ax[4].imshow(masks[i, middle_slice, :, :], cmap="gray")
        ax[4].set_title("Mask - Gray")

        plt.show()
        plt.close(fig)  # release the figure; one is created per sample
 
 
visualize_slices(images, masks, num_slices=20)

In [0]:
images, masks = next(iter(train_loader))
visualize_slices(images, masks, num_slices=20)

In [0]:
# def visualize_slices(images, masks, num_slices=20):
#     batch_size = images.shape[0]
 
#     ## if torch.Size([5, 4, 128, 128, 128])
#     masks = torch.argmax(masks, dim=1)  # along the channel/class dim
 
#     for i in range(min(num_slices, batch_size)):
#         fig, ax = plt.subplots(1, 5, figsize=(15, 5))
 
#         middle_slice = images.shape[2] // 2
#         ax[0].imshow(images[i, 0, middle_slice, :, :], cmap="gray")
#         ax[1].imshow(images[i, 1, middle_slice, :, :], cmap="gray")
#         ax[2].imshow(images[i, 2, middle_slice, :, :], cmap="gray")
                
#         ax[3].imshow(masks[i, middle_slice, :, :], cmap="viridis")
#         ax[4].imshow(masks[i, middle_slice, :, :], cmap="gray")
#         # ax[3].imshow(masks[i][middle_slice, :, :], cmap="viridis")
#         # ax[4].imshow(masks[i][middle_slice, :, :], cmap="gray")

        
#         ax[0].set_title("T1ce")
#         ax[1].set_title("FLAIR")
#         ax[2].set_title("T2")
#         ax[3].set_title("Seg Mask")
#         ax[4].set_title("Mask - Gray")
 
#         plt.show()
 
 
# visualize_slices(images, masks, num_slices=20)

In [0]:
middle_slice = images.shape[2] // 2
middle_slice

In [0]:
# Depth index of the central slice.
middle_slice = images.shape[2] // 2

# Reduce sample 3 of the batch to a label volume before slicing. Indexing the
# 5D one-hot batch as masks[3, middle_slice] would address the class axis
# (size 4) with the slice index and raise IndexError.
if masks.dim() == 5:
    if masks.shape[1] == 1:
        sample_labels = masks[3, 0]
    else:
        sample_labels = torch.argmax(masks[3], dim=0)
else:
    sample_labels = masks[3]

# Convert the mask to a 2D array
mask_2d = sample_labels[middle_slice, :, :]

# Plot the 2D mask
plt.imshow(mask_2d, cmap="viridis")
plt.title(f"Mask for middle slice {middle_slice}")
plt.colorbar()
plt.show()

In [0]:
# images, masks = next(iter(train_loader))
# visualize_slices(images, masks, num_slices=20)

In [0]:
def visualize_slices(images, masks, num_slices=20):
    """Plot the middle slice of each modality and of the mask for a batch.

    NOTE: this re-defines the function of the same name from an earlier cell.

    Args:
        images: Tensor indexed as (batch, channel, depth, H, W). The channels
            follow the stacking order used during preprocessing: FLAIR, T1ce, T2.
        masks: Either one-hot masks (batch, classes, depth, H, W), single-channel
            label masks (batch, 1, depth, H, W), or label volumes
            (batch, depth, H, W).
        num_slices: Maximum number of samples from the batch to plot (one figure
            per sample).
    """
    if masks.dim() == 5:
        if masks.shape[1] == 1:
            # Single-channel label volume: just drop the channel axis.
            masks = masks.squeeze(1)
        else:
            # One-hot encoded: the index of the hot channel is the class label.
            masks = torch.argmax(masks, dim=1)

    sample_count = min(num_slices, images.shape[0])
    middle_slice = images.shape[2] // 2
    # Titles must follow the channel order of the saved arrays (FLAIR first).
    modality_titles = ("FLAIR", "T1ce", "T2")

    for i in range(sample_count):
        fig, ax = plt.subplots(1, 5, figsize=(15, 5))

        for channel, title in enumerate(modality_titles):
            ax[channel].imshow(images[i, channel, middle_slice, :, :], cmap="gray")
            ax[channel].set_title(title)

        ax[3].imshow(masks[i, middle_slice, :, :], cmap="viridis")
        ax[3].set_title("Seg Mask")
        ax[4].imshow(masks[i, middle_slice, :, :], cmap="gray")
        ax[4].set_title("Mask - Gray")

        plt.show()
        plt.close(fig)  # release the figure; one is created per sample
 
 
visualize_slices(images, masks, num_slices=20)

In [0]:
# NOTE(review): both packages are already installed by the setup cell near the top
# of the notebook, so this repeat looks redundant -- confirm before removing.
!pip install pycocotools
!pip install opencv-python-headless


In [0]:
images.shape, masks.shape

In [0]:
# import json
# import numpy as np
# from pycocotools import mask as maskUtils

# def convert_to_coco_format(images, masks, categories):
#     coco_format = {
#         "images": [],
#         "annotations": [],
#         "categories": []
#     }

#     annotation_id = 1

#     for i in range(images.shape[0]):
#         image_info = {
#             "id": i,
#             "width": images.shape[3],
#             "height": images.shape[4],
#             "file_name": f"image_{i}.png"
#         }
#         coco_format["images"].append(image_info)

#         for category_id in range(masks.shape[1]):
#             mask = masks[i, category_id, :, :].numpy().astype(np.uint8)
#             mask = np.asfortranarray(mask)
#             rle = maskUtils.encode(mask)
#             if isinstance(rle, list):
#                 rle = rle[0]  # Take the first element if rle is a list
#             rle['counts'] = rle['counts'].decode('utf-8')  # Decode bytes to string
#             area = maskUtils.area(rle)
#             bbox = maskUtils.toBbox(rle)

#             annotation_info = {
#                 "id": annotation_id,
#                 "image_id": i,
#                 "category_id": category_id,
#                 "segmentation": rle,
#                 "area": area.tolist(),
#                 "bbox": bbox.tolist(),
#                 "iscrowd": 0
#             }
#             coco_format["annotations"].append(annotation_info)
#             annotation_id += 1

#     for category_id, category_name in enumerate(categories):
#         category_info = {
#             "id": category_id,
#             "name": category_name
#         }
#         coco_format["categories"].append(category_info)

#     return coco_format

# # Example usage
# # categories = ["category1", "category2", "category3", "category4"]

# # 1 for NCR
# # 2 for ED
# # 3 for ET
# # 0 for everything else.

# categories = ["nonT", "NCR", "ED", "ET"]

# coco_format = convert_to_coco_format(images, masks, categories)

# # Save to JSON file
# # with open("coco_annotations.json", "w") as f:
# #     json.dump(coco_format, f)

In [0]:
# coco_format

In [0]:
# Disabled: `coco_format` is only built by the commented-out
# convert_to_coco_format cell, so this raises NameError on a fresh kernel.
# Re-enable together with that cell.
# coco_format.keys()

In [0]:
# coco_format['images']

In [0]:
# Disabled: `coco_format` is only built by the commented-out
# convert_to_coco_format cell, so this raises NameError on a fresh kernel.
# Re-enable together with that cell.
# len(coco_format)

In [0]:
import json
import os
from PIL import Image
import numpy as np

def create_coco_json(images, masks, categories, output_file):
    """Write a COCO-style annotation file (bounding boxes only) from image/mask pairs.

    For every image, each category whose id occurs in the mask gets one
    annotation with the tight bounding box and pixel count of that id.

    Args:
        images: Iterable of image file paths.
        masks: Iterable of mask file paths, paired with `images` by position.
            Mask pixel values are compared against the category ids.
        categories: List of dicts, each with at least an "id" key; stored
            unchanged under "categories".
        output_file: Destination path of the JSON file.

    Paths starting with "/Volumes/" or "/dbfs/" are used as given; any other
    path is treated as a DBFS path and prefixed with "/dbfs".
    """

    def _local_path(path):
        # Unity Catalog volume paths and explicit /dbfs paths are already local
        # paths; prefixing them with /dbfs would point at a non-existent location.
        if path.startswith(("/Volumes/", "/dbfs/")):
            return path
        return f"/dbfs{path}"

    coco_data = {
        "images": [],
        "annotations": [],
        "categories": categories
    }

    annotation_id = 1
    for image_id, (image_path, mask_path) in enumerate(zip(images, masks)):
        # Only the size is needed; close the file handle right away.
        with Image.open(_local_path(image_path)) as image:
            width, height = image.size

        coco_data["images"].append({
            "id": image_id,
            "file_name": os.path.basename(image_path),
            "width": width,
            "height": height
        })

        with Image.open(_local_path(mask_path)) as mask_image:
            mask = np.array(mask_image)

        for category in categories:
            category_id = category["id"]
            binary_mask = (mask == category_id)
            area = int(binary_mask.sum())
            if area == 0:
                continue  # category not present in this mask

            # Find bounding box (rows are y, columns are x).
            pos = np.nonzero(binary_mask)
            # int(): numpy integers are not JSON serialisable.
            xmin, xmax = int(pos[1].min()), int(pos[1].max())
            ymin, ymax = int(pos[0].min()), int(pos[0].max())
            # COCO boxes are [x, y, width, height]; both ends are pixel indices
            # that belong to the object, hence the +1.
            bbox = [xmin, ymin, xmax - xmin + 1, ymax - ymin + 1]

            # Create annotation
            coco_data["annotations"].append({
                "id": annotation_id,
                "image_id": image_id,
                "category_id": category_id,
                "bbox": bbox,
                "area": area,
                "iscrowd": 0
            })
            annotation_id += 1

    with open(_local_path(output_file), "w") as f:
        json.dump(coco_data, f, indent=4)



In [0]:
import json

# COCO annotation file on the UC volume; it is written by the create_coco_json
# example call in a later cell, so it may not exist yet on a fresh run.
json_file_path = "/Volumes/mmt_mlops_demos/cv/data/BraTS2023/BraTS2023_Preprocessed/coco_annotations.json"

if os.path.exists(json_file_path):
    # Read the JSON file
    with open(json_file_path, 'r') as f:
        coco_annotations = json.load(f)

    # Display the content
    print(coco_annotations)
else:
    print(f"{json_file_path} not found; run the create_coco_json example cell first")

In [0]:
from PIL import Image
import os
import json

def create_coco_json(images, masks, categories, output_file):
    """Write a COCO-style annotation file treating each mask as one binary object.

    Every non-zero mask pixel counts as foreground and is assigned to category
    id 1. Pairs with a missing image or mask file are reported and skipped, and
    masks without any foreground pixel produce no annotation.

    Args:
        images: Iterable of image file paths.
        masks: Iterable of mask file paths, paired with `images` by position.
        categories: Category list stored unchanged under "categories".
        output_file: Destination path of the JSON file.

    NOTE(review): "segmentation" holds the flat 0/1 pixel list, which is not one
    of the COCO segmentation encodings (polygon or RLE) -- confirm what the
    consumer of this file expects.
    """
    coco_data = {
        "images": [],
        "annotations": [],
        "categories": categories
    }

    annotation_id = 1
    for image_id, (image_path, mask_path) in enumerate(zip(images, masks)):
        print(f"Processing image: {image_path}, mask: {mask_path}")
        if not os.path.exists(image_path):
            print(f"Image file not found: {image_path}")
            continue
        if not os.path.exists(mask_path):
            print(f"Mask file not found: {mask_path}")
            continue

        # Only the size is needed; close the file handle right away.
        with Image.open(image_path) as image:
            width, height = image.size

        coco_data["images"].append({
            "id": image_id,
            "file_name": os.path.basename(image_path),
            "width": width,
            "height": height
        })

        with Image.open(mask_path) as mask:
            mask_width = mask.size[0]
            # Multi-band masks yield one tuple per pixel; a pixel is foreground
            # when any band is non-zero. Single-band masks yield plain numbers.
            mask_data = [
                1 if (any(pixel) if isinstance(pixel, tuple) else pixel > 0) else 0
                for pixel in mask.getdata()
            ]

        # getdata() is flattened row by row, so index = row * mask_width + column.
        foreground = [i for i, flag in enumerate(mask_data) if flag]
        if not foreground:
            continue  # empty mask: nothing to annotate

        columns = [i % mask_width for i in foreground]
        xmin, xmax = min(columns), max(columns)
        ymin, ymax = foreground[0] // mask_width, foreground[-1] // mask_width

        coco_data["annotations"].append({
            "id": annotation_id,
            "image_id": image_id,
            "category_id": 1,  # Assuming a single category for simplicity
            "segmentation": mask_data,
            "area": len(foreground),
            # Tight [x, y, width, height] box around the foreground pixels.
            "bbox": [xmin, ymin, xmax - xmin + 1, ymax - ymin + 1],
            "iscrowd": 0
        })
        annotation_id += 1

    with open(output_file, 'w') as f:
        json.dump(coco_data, f)

# Example usage with placeholder file names inside the training image folder.
# NOTE: this rebinds the notebook-level names `images` and `masks`.
preprocessed_dir = "/Volumes/mmt_mlops_demos/cv/data/BraTS2023/BraTS2023_Preprocessed"
example_dir = f"{preprocessed_dir}/input_data_128/train/images"

images = [f"{example_dir}/image{k}.jpg" for k in (1, 2)]
masks = [f"{example_dir}/mask{k}.png" for k in (1, 2)]
categories = [
    {"id": 1, "name": "category1"},
    {"id": 2, "name": "category2"},
]
output_file = f"{preprocessed_dir}/coco_annotations.json"

create_coco_json(images, masks, categories, output_file)

In [0]:
import json

# Annotation file on the UC volume.
json_file_path = "/Volumes/mmt_mlops_demos/cv/data/BraTS2023/BraTS2023_Preprocessed/coco_annotations.json"

# Read the raw text, then parse it.
with open(json_file_path, 'r') as f:
    coco_annotations = json.loads(f.read())

# Echo the parsed content.
print(coco_annotations)

In [0]:
import os
from PIL import Image
import numpy as np

def create_yolo_files(images, masks, categories, output_dir):
    """Write one YOLO-format label file per image from image/mask pairs.

    For each image, every category whose id occurs in the mask contributes one
    line ``<class_index> <center_x> <center_y> <width> <height>``, with the box
    normalised by the image size. The label file is named after the image
    (extension replaced by ``.txt``) and is written even when it has no lines.
    Pairs with a missing image or mask file are reported and skipped.

    Args:
        images: Iterable of image file paths.
        masks: Iterable of mask file paths, paired with `images` by position.
            Mask pixel values are compared against the category ids.
        categories: List of dicts with an "id" key; a category's position in
            this list is its YOLO class index.
        output_dir: Folder for the label files (created if needed).
    """
    os.makedirs(output_dir, exist_ok=True)

    category_map = {category["id"]: idx for idx, category in enumerate(categories)}

    for image_path, mask_path in zip(images, masks):
        if not os.path.exists(image_path):
            print(f"Image file not found: {image_path}")
            continue
        if not os.path.exists(mask_path):
            print(f"Mask file not found: {mask_path}")
            continue

        # Only the size is needed; close the file handle right away.
        with Image.open(image_path) as image:
            width, height = image.size

        with Image.open(mask_path) as mask_image:
            mask = np.array(mask_image)

        annotations = []
        for category in categories:
            category_id = category["id"]
            pos = np.nonzero(mask == category_id)
            if pos[0].size == 0:
                continue  # category not present in this mask

            # Find bounding box (rows are y, columns are x).
            xmin, xmax = int(pos[1].min()), int(pos[1].max())
            ymin, ymax = int(pos[0].min()), int(pos[0].max())

            # Both ends are pixel indices that belong to the object, so the box
            # spans [min, max + 1) in continuous image coordinates.
            box_w = xmax - xmin + 1
            box_h = ymax - ymin + 1

            # Convert to YOLO format
            center_x = (xmin + box_w / 2) / width
            center_y = (ymin + box_h / 2) / height
            bbox_width = box_w / width
            bbox_height = box_h / height

            annotations.append(f"{category_map[category_id]} {center_x} {center_y} {bbox_width} {bbox_height}")

        # Save annotations to file
        annotation_file = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + ".txt")
        with open(annotation_file, "w") as f:
            f.write("\n".join(annotations))

# Example usage with placeholder paths -- swap in real image/mask files.
# NOTE: this rebinds the notebook-level names `images` and `masks`.
output_dir = "yolo_annotations"
categories = [
    {"id": 1, "name": "category1"},
    {"id": 2, "name": "category2"},
]
images = [f"path/to/image{k}.jpg" for k in (1, 2)]
masks = [f"path/to/mask{k}.png" for k in (1, 2)]

create_yolo_files(images, masks, categories, output_dir)