<a href="https://colab.research.google.com/github/ayushmaanFCB/3D-Medical-Image-Reconstruction-for-visualizing-Internal-Body-Structure/blob/main/UNet_3D_Vision_Transformer_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tcia_utils pydicom monai pytorch_msssim

In [2]:
import requests
import pandas as pd
from tcia_utils import nbia

In [4]:
# get list of available collections as JSON
print(nbia.getCollections())

[{'Collection': '4D-Lung'}, {'Collection': 'ACRIN-6698'}, {'Collection': 'ACRIN-Contralateral-Breast-MR'}, {'Collection': 'ACRIN-FLT-Breast'}, {'Collection': 'ACRIN-NSCLC-FDG-PET'}, {'Collection': 'Adrenal-ACC-Ki67-Seg'}, {'Collection': 'Advanced-MRI-Breast-Lesions'}, {'Collection': 'Anti-PD-1_Lung'}, {'Collection': 'B-mode-and-CEUS-Liver'}, {'Collection': 'BREAST-DIAGNOSIS'}, {'Collection': 'Breast-Cancer-Screening-DBT'}, {'Collection': 'Breast-MRI-NACT-Pilot'}, {'Collection': 'C4KC-KiTS'}, {'Collection': 'CBIS-DDSM'}, {'Collection': 'CC-Radiomics-Phantom'}, {'Collection': 'CC-Radiomics-Phantom-2'}, {'Collection': 'CC-Radiomics-Phantom-3'}, {'Collection': 'CC-Tumor-Heterogeneity'}, {'Collection': 'CMB-AML'}, {'Collection': 'CMB-CRC'}, {'Collection': 'CMB-GEC'}, {'Collection': 'CMB-LCA'}, {'Collection': 'CMB-MEL'}, {'Collection': 'CMB-MML'}, {'Collection': 'CMB-PCA'}, {'Collection': 'CMMD'}, {'Collection': 'COVID-19-AR'}, {'Collection': 'COVID-19-NY-SBU'}, {'Collection': 'CPTAC-CCRCC'}

In [3]:
data = nbia.getSeries(collection = "Soft-tissue-Sarcoma")
data[0]

{'SeriesInstanceUID': '1.3.6.1.4.1.14519.5.2.1.5168.1900.104193299251798317056218297018',
 'StudyInstanceUID': '1.3.6.1.4.1.14519.5.2.1.5168.1900.154535988064062152660648619556',
 'Modality': 'MR',
 'SeriesDate': '2003-12-12 00:00:00.0',
 'SeriesDescription': '2. AXIAL T1 BOTH LEGS - RESEARCH',
 'BodyPartExamined': 'EXTREMITY',
 'SeriesNumber': 2,
 'Collection': 'Soft-tissue-Sarcoma',
 'PatientID': 'STS_010',
 'Manufacturer': 'GE MEDICAL SYSTEMS',
 'ManufacturerModelName': 'GENESIS_SIGNA',
 'SoftwareVersions': '09',
 'ImageCount': 48,
 'TimeStamp': '2015-05-27 17:12:21.0',
 'LicenseName': 'Creative Commons Attribution 3.0 Unported License',
 'LicenseURI': 'http://creativecommons.org/licenses/by/3.0/',
 'CollectionURI': 'https://doi.org/10.7937/K9/TCIA.2015.7GO2GSKS',
 'FileSize': 25273786,
 'DateReleased': '2015-05-27 17:12:21.0',
 'StudyDesc': 'MRI LT LEG +C',
 'StudyDate': '2003-12-12 00:00:00.0'}

In [5]:
nbia.downloadSeries(data, number = 10)

In [None]:
import pydicom
import os
import shutil

def organize_dicom_by_patient(base_dir, output_dir):
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Iterate through all the SeriesInstanceUID folders
    for series_folder in os.listdir(base_dir):
        series_path = os.path.join(base_dir, series_folder)

        if os.path.isdir(series_path):
            # Get all DICOM files in the folder
            dicom_files = [f for f in os.listdir(series_path) if f.endswith(".dcm")]

            for dicom_file in dicom_files:
                dicom_path = os.path.join(series_path, dicom_file)
                dicom_data = pydicom.dcmread(dicom_path)

                # Extract Patient ID and Series Description from DICOM file
                patient_id = dicom_data.PatientID
                series_desc = dicom_data.SeriesDescription

                # Create output folder based on Patient ID
                patient_folder = os.path.join(output_dir, f"patient_{patient_id}")
                os.makedirs(patient_folder, exist_ok=True)

                # Copy or move the DICOM file to the patient's folder
                shutil.copy(dicom_path, os.path.join(patient_folder, dicom_file))
                print(f"Copied {dicom_file} to {patient_folder}")

# Set base_dir to where your downloaded DICOM series are
base_dir = "/content/tciaDownload"  # The folder containing SeriesInstanceUID subfolders
output_dir = "./organized_data"

organize_dicom_by_patient(base_dir, output_dir)

In [None]:
import cv2, numpy as np

def dicom_to_png(input_dir, output_dir):
    # Ensure output directories exist
    os.makedirs(output_dir, exist_ok=True)

    # Loop over patients
    for patient_folder in os.listdir(input_dir):
        patient_path = os.path.join(input_dir, patient_folder)

        # Create output directory for the patient
        patient_output_dir = os.path.join(output_dir, patient_folder)
        os.makedirs(patient_output_dir, exist_ok=True)

        for dicom_file in os.listdir(patient_path):
            if dicom_file.endswith(".dcm"):
                dicom_path = os.path.join(patient_path, dicom_file)
                dicom_data = pydicom.dcmread(dicom_path)
                pixel_array = dicom_data.pixel_array

                # Normalize image to 0-255 for PNG saving
                img = cv2.normalize(pixel_array, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

                # Create the PNG filename
                png_filename = dicom_file.replace(".dcm", ".png")

                # Save as PNG
                cv2.imwrite(os.path.join(patient_output_dir, png_filename), img)
                print(f"Saved {png_filename} in {patient_output_dir}")

# Paths to organized DICOM data and desired PNG output
organized_dicom_dir = "/content/organized_data"
png_output_dir = "./organized_data_PNG"

dicom_to_png(organized_dicom_dir, png_output_dir)

In [9]:
import torch
import torch.nn as nn
from monai.networks.nets import UNet
from torchvision.models import vit_b_16
import torch.nn.functional as F

class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3D, self).__init__()
        self.unet = UNet(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2)  # Downsampling
        )

    def forward(self, x):
        return self.unet(x)

class TransformerModule(nn.Module):
    def __init__(self):
        super(TransformerModule, self).__init__()
        self.transformer = vit_b_16(pretrained=True)  # Vision Transformer

    def forward(self, x):
        batch_size, channels, depth, height, width = x.shape

        # Loop through each slice in the depth dimension and resize height and width to 224x224
        resized_slices = []
        for i in range(depth):
            slice_2d = x[:, :, i, :, :]  # Extract each 2D slice (N, C, H, W)

            # Duplicate the grayscale channel to form 3 channels (N, 3, H, W)
            slice_2d = slice_2d.repeat(1, 3, 1, 1)

            # Resize to (224x224) as required by Vision Transformer
            resized_slice = F.interpolate(slice_2d, size=(224, 224), mode='bilinear', align_corners=False)
            resized_slices.append(resized_slice)

        # Stack the resized slices back into a tensor of shape (N, C, D, 224, 224)
        resized_3d = torch.stack(resized_slices, dim=2)

        # Flatten depth to treat each slice as a separate image for the Transformer
        x = resized_3d.permute(0, 2, 1, 3, 4).reshape(batch_size * depth, 3, 224, 224)

        # Apply Vision Transformer
        x = self.transformer(x)

        # Reshape the output back to original batch and depth dimensions
        x = x.view(batch_size, depth, -1)
        return x

class Hybrid3DModel(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Hybrid3DModel, self).__init__()
        self.unet3d = UNet3D(in_channels, out_channels)
        self.transformer = TransformerModule()
        self.conv_final = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # 3D CNN (U-Net) for volumetric feature extraction
        unet_output = self.unet3d(x)

        # Transformer for capturing long-range dependencies
        transformer_output = self.transformer(x)

        # Combine 3D CNN output with Transformer output
        combined = unet_output + transformer_output

        # Final convolution for output prediction
        output = self.conv_final(combined.unsqueeze(1))
        return output

# Instantiate the hybrid model
model = Hybrid3DModel(in_channels=1, out_channels=1)

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:02<00:00, 140MB/s]


In [10]:
from pytorch_msssim import ssim
import torch.optim as optim

# Loss functions
mse_loss = nn.MSELoss()  # Mean Squared Error for pixel-wise comparison

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-4)  # AdamW optimizer

In [24]:
from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np
import torch

class MedicalImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.patient_dirs = sorted(os.listdir(image_dir))  # List of patient folders

    def __len__(self):
        return len(self.patient_dirs)

    def __getitem__(self, idx):
        patient_folder = os.path.join(self.image_dir, self.patient_dirs[idx])
        slice_files = sorted(os.listdir(patient_folder))  # List of slice images

        slices = []
        for slice_file in slice_files:
            img_path = os.path.join(patient_folder, slice_file)
            image = Image.open(img_path).convert("L")  # Grayscale
            image = np.array(image, dtype=np.float32)
            slices.append(image)

        # Stack slices to create a 3D volume
        volume = np.stack(slices, axis=0)  # Shape: [depth, height, width]
        volume = torch.FloatTensor(volume).unsqueeze(0)  # Add channel dimension

        if self.transform:
            volume = self.transform(volume)

        return volume

In [25]:
from torch.utils.data import DataLoader

# Create the dataset
train_dataset = MedicalImageDataset(image_dir='/content/organized_data_PNG/train')
test_dataset = MedicalImageDataset(image_dir='/content/organized_data_PNG/test/')

# Adjust batch size as needed (start with batch_size=1)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [27]:
def train_model(model, train_loader, optimizer, epochs=10):
    model.train()  # Set the model to training mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs in train_loader:
            optimizer.zero_grad()  # Zero the gradients
            outputs = model(inputs)  # Forward pass

            # Calculate the loss
            loss_mse = mse_loss(outputs, inputs)  # Assuming reconstruction task
            loss_ssim = 1 - ssim(outputs, inputs, data_range=outputs.max() - outputs.min())
            total_loss = loss_mse + 0.5 * loss_ssim

            # Backward pass and optimization
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()

        # Print the average loss per epoch
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

# Train the model
train_model(model, train_loader, optimizer, epochs=10)

RuntimeError: The size of tensor a (512) must match the size of tensor b (1000) at non-singleton dimension 4