In [4]:
from __future__ import annotations
from collections.abc import Sequence

import torch
import torch.nn as nn

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.utils import deprecated_arg
from collections.abc import Sequence


__all__ = ["ViT"]


class ViT(nn.Module):
    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    ViT supports Torchscript but only works for Pytorch after 1.8.
    """

    @deprecated_arg(
        name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
    )
    def __init__(
        self,
        in_channels: int,
        img_size: Sequence[int] | int,
        patch_size: Sequence[int] | int,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "conv",
        proj_type: str = "conv",
        pos_embed_type: str = "learnable",
        classification: bool = False,
        num_classes: int = 2,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        post_activation="Tanh",
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        super().__init__()
        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            proj_type=proj_type,
            pos_embed_type=pos_embed_type,
            dropout_rate=dropout_rate,
        )
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_size=hidden_size,
                    mlp_dim=mlp_dim,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    spatial_dims=spatial_dims,
                    post_activation=post_activation,
                    qkv_bias=qkv_bias,
                    save_attn=save_attn,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size)
        if classification:
            self.classification_head = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        ct = x[:, 0, :, :, :]  # Extract CT scan
        pet = x[:, 1, :, :, :]  # Extract PET scan

        x_ct = self.patch_embedding(ct)  # Convert CT scan to patch embeddings
        x_pet = self.patch_embedding(pet)  # Convert PET scan to patch embeddings

        x = torch.cat((x_ct, x_pet), dim=1)  # Concatenate CT and PET patch embeddings

        if hasattr(self, "cls_token"):
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_token, x), dim=1)

        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        if hasattr(self, "classification_head"):
            x = self.classification_head(x[:, 0])
        return x, hidden_states_out


##UNFROZEN DECODER


In [5]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from glob import glob
# import wandb

import monai
from monai.losses import DiceCELoss, DiceFocalLoss, FocalLoss
from monai.inferers import sliding_window_inference
from monai import transforms

from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    MapTransform,
    ScaleIntensityd,
    #AddChanneld,
    SpatialPadd,
    CenterSpatialCropd,
    EnsureChannelFirstd,
    ConcatItemsd,
    AdjustContrastd, 
    Rand3DElasticd,
    HistogramNormalized,
    NormalizeIntensityd,
    Invertd,
    SaveImage,

)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR, UNETR, SegResNet

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)
from monai import data


from monai.utils import first, set_determinism
from sklearn.model_selection import train_test_split
import json


import torch

In [6]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
# 
from unetr import CustomedUNETR
from vit import ViT
from monai.transforms import Compose
#from monai.data import NiftiDataset
import json
from monai.data import Dataset, DataLoader
from torch import device
from monai.transforms import AsDiscrete


In [7]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1" 
import torch
torch.cuda.device_count()

1

In [8]:
data_dir = '/home/nada.saadi/MIS-FM/hecktor2022_cropped'
json_dir = '/home/nada.saadi/MIS-FM/hecktor2022_cropped/MDA_CTPT_TRAIN.json'

In [9]:
def datafold_read(datalist, basedir, fold=0, key="training"):
    with open(datalist) as f:
        json_data = json.load(f)

    json_data = json_data[key]

    for d in json_data:
        for k in d:
            if isinstance(d[k], list):
                d[k] = [os.path.join(basedir, iv) for iv in d[k]]
            elif isinstance(d[k], str):
                d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]

    tr = []
    val = []
    for d in json_data:
        if "fold" in d and d["fold"] == fold:
            val.append(d)
        else:
            tr.append(d)

    return tr, val

In [10]:
train_files, validation_files = datafold_read(datalist=json_dir, basedir=data_dir, fold=0)
len(train_files), len(validation_files)

(152, 44)

In [11]:
class ClipCT(MapTransform):
    """
    Convert labels to multi channels based on hecktor classes:
    label 1 is the tumor
    label 2 is the lymph node

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            if key == "ct":
                d[key] = torch.clip(d[key], min=-200, max=200)
            # elif key == "pt":
            #     d[key] = torch.clip(d[key], d[key].min(), 5)
        return d

class MulPTFM(MapTransform):
    """
    Mult PT and FM 

    """

    def __call__(self, data):
        d = dict(data)

        fm = d["ct"] > 0
        d["pt"] = d["pt"] * fm
        return d

class SelectClass(MapTransform):
    """
    Select the class for which you want to fine tune the model 

    """
    # def __init__(self, keys, cls=1):
    #     super(self).__init__(keys)
    #     self.cls = cls

    def __call__(self, data):
        d = dict(data)
        d["seg"][d["seg"] == 1] = 0
        # d["seg"][d["seg"] == 2] = 1
        
        return d

In [12]:

from monai.transforms import EnsureTyped
# Path to the JSON file
json_file_path = '/home/nada.saadi/MIS-FM/hecktor2022_cropped/MDA_CTPT_TRAIN.json'

# Load the data from the JSON file
with open(json_file_path, 'r') as file:
    data_json = json.load(file)["training"]

from monai.transforms import (
    Compose, LoadImaged, Orientationd, NormalizeIntensityd, 
     ScaleIntensityd, ConcatItemsd, RandCropByPosNegLabeld, 
    RandFlipd, RandRotate90d, SpatialPadd
)

# Split data into training and validation based on fold
train_data = [entry for entry in data_json if entry['fold'] != 0]  # Using fold 0 for validation
val_data = [entry for entry in data_json if entry['fold'] == 0]

# Create datasets and dataloaders
def create_dataloader(data, transforms, batch_size=2, shuffle=True):
    dataset = Dataset(data=data, transform=transforms)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8)


num_samples = 4

train_transforms = Compose(
    [
        LoadImaged(keys=["ct", "pt", "seg"], ensure_channel_first=True),
        SpatialPadd(keys=["ct", "pt", "seg"], spatial_size=(200, 200, 310), method='end'),
        Orientationd(keys=["ct", "pt", "seg"], axcodes="PLS"),
        NormalizeIntensityd(keys=["pt"]),
        ClipCT(keys=["ct"]),
        ScaleIntensityd(keys=["ct"], minv=0, maxv=1),
        ConcatItemsd(keys=["pt", "ct"], name="ctpt"),  # Concatenate CT and PET
        RandCropByPosNegLabeld(
            keys=["ctpt", "seg"],
            label_key="seg",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="ctpt",
            image_threshold=0,
        ),
        RandFlipd(keys=["ctpt", "seg"], spatial_axis=[0], prob=0.20),
        RandFlipd(keys=["ctpt", "seg"], spatial_axis=[1], prob=0.20),
        RandFlipd(keys=["ctpt", "seg"], spatial_axis=[2], prob=0.20),
        RandRotate90d(keys=["ctpt", "seg"], prob=0.20, max_k=3),
        EnsureTyped(keys=["ctpt", "seg"]),
    
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys=["ct", "pt", "seg"], ensure_channel_first=True),
        SpatialPadd(keys=["ct", "pt", "seg"], spatial_size=(200, 200, 310), method='end'),
        Orientationd(keys=["ct", "pt", "seg"], axcodes="PLS"),
        NormalizeIntensityd(keys=["pt"]),
        ClipCT(keys=["ct"]),
        ScaleIntensityd(keys=["ct"], minv=0, maxv=1),
        ConcatItemsd(keys=["pt", "ct"], name="ctpt"),# Concatenate CT and PET
        EnsureTyped(keys=["ctpt", "seg"]),
    ]
)

train_loader = create_dataloader(train_data, train_transforms, shuffle=True)
val_loader = create_dataloader(val_data, val_transforms, shuffle=False)

### -------------------

In [13]:
model_dir='/home/nada.saadi/CTPET/hecktor2022_cropped/UNFROZEN-PATCH-MDA'

import torch
import torch.nn as nn
import json
from torch.utils.data import DataLoader


from monai.transforms import Compose, LoadImaged, ScaleIntensityRanged, ConcatItemsd
from monai.data import Dataset
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from collections.abc import Sequence
from monai.transforms import EnsureTyped


from vit import ViT
from unetr import CustomedUNETR



# from monai.networks.blocks.dynunet_block import UnetOutBlock
# from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
# from vit import ViT
# from monai.utils import deprecated_arg, ensure_tuple_rep

# Path to the JSON file
json_file_path = '/home/nada.saadi/MIS-FM/hecktor2022_cropped/MDA_CTPT_TRAIN.json'

# Load the data from the JSON file
with open(json_file_path, 'r') as file:
    data_json = json.load(file)["training"]

from monai.transforms import (
    Compose, LoadImaged, Orientationd, NormalizeIntensityd, 
     ScaleIntensityd, ConcatItemsd, RandCropByPosNegLabeld, 
    RandFlipd, RandRotate90d, SpatialPadd
)

# # Split data into training and validation based on fold
# train_data = [entry for entry in data_json if entry['fold'] != 0]  # Using fold 0 for validation
# val_data = [entry for entry in data_json if entry['fold'] == 0]

# # Create datasets and dataloaders
# def create_dataloader(data, transforms, batch_size=2, shuffle=True):
#     dataset = Dataset(data=data, transform=transforms)
#     return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8)


num_samples = 4

train_transforms = Compose(
    [
        LoadImaged(keys=["ct", "pt", "seg"], ensure_channel_first=True),
        SpatialPadd(keys=["ct", "pt", "seg"], spatial_size=(200, 200, 310), method='end'),
        Orientationd(keys=["ct", "pt", "seg"], axcodes="PLS"),
        NormalizeIntensityd(keys=["pt"]),
        ClipCT(keys=["ct"]),
        ScaleIntensityd(keys=["ct"], minv=0, maxv=1),
        ConcatItemsd(keys=["pt", "ct"], name="ctpt"),  # Concatenate CT and PET
        RandCropByPosNegLabeld(
            keys=["ctpt", "seg"],
            label_key="seg",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="ctpt",
            image_threshold=0,
        ),
        RandFlipd(keys=["ctpt", "seg"], spatial_axis=[0], prob=0.20),
        RandFlipd(keys=["ctpt", "seg"], spatial_axis=[1], prob=0.20),
        RandFlipd(keys=["ctpt", "seg"], spatial_axis=[2], prob=0.20),
        RandRotate90d(keys=["ctpt", "seg"], prob=0.20, max_k=3),
        EnsureTyped(keys=["ctpt", "seg"]),
    
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys=["ct", "pt", "seg"], ensure_channel_first=True),
        SpatialPadd(keys=["ct", "pt", "seg"], spatial_size=(200, 200, 310), method='end'),
        Orientationd(keys=["ct", "pt", "seg"], axcodes="PLS"),
        NormalizeIntensityd(keys=["pt"]),
        ClipCT(keys=["ct"]),
        ScaleIntensityd(keys=["ct"], minv=0, maxv=1),
        ConcatItemsd(keys=["pt", "ct"], name="ctpt"),# Concatenate CT and PET
        EnsureTyped(keys=["ctpt", "seg"]),
    ]
)


from monai.data import CacheDataset  # Using CacheDataset for efficiency

def create_dataloader(data, transforms, batch_size=2, shuffle=True):
    # Reformat the data to match the expected format
    formatted_data = [{"ct": entry["ct"], "pt": entry["pt"], "seg": entry["seg"]} for entry in data]

    # Create CacheDataset with the reformatted data
    dataset = CacheDataset(data=formatted_data, transform=transforms, cache_rate=1.0)

    train_data = [entry for entry in data_json if entry['fold'] != 0]  # Using fold 0 for validation
    val_data = [entry for entry in data_json if entry['fold'] == 0]

    # Create datasets and dataloaders
    def create_dataloader(data, transforms, batch_size=2, shuffle=True):
        dataset = Dataset(data=data, transform=transforms)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8)



    # Create DataLoader
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8)

# Create dataloaders using the modified function
train_loader = create_dataloader(train_data, train_transforms, shuffle=True)
val_loader = create_dataloader(val_data, val_transforms, shuffle=False)






Loading dataset: 100%|██████████| 152/152 [02:10<00:00,  1.16it/s]
Loading dataset: 100%|██████████| 44/44 [00:36<00:00,  1.20it/s]


In [None]:
# Model, loss, optimizer, and metrics
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Instantiate the model
model = UNETR(
    in_channels=1,  # Number of input channels
    out_channels=3,  # Number of output channels
    img_size=(96, 96, 96),  # Size of the input image
    feature_size=48,  # Size of the feature maps
    hidden_size=768,
    num_heads = 12,# Size of the hidden layers in the transformer
    mlp_dim=3072,  # Dimension of the MLP in the transformer
    pos_embed="perceptron",  # Type of positional embedding
    norm_name="instance",  # Type of normalization
    res_block=True,  # Whether to use residual blocks
    dropout_rate=0.0,
    proj_type= "conv",
).to(device)


loss_function = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch")

# Training loop
max_num_epochs = 530
max_iterations = 18000
eval_num = 100
global_step = 0
dice_val_best = 0.0
global_step_best = 0

for epoch in range(max_num_epochs):
    model.train()
    for batch_data in train_loader:
        # Debugging: Check the type and keys of batch_data
        print("Batch data type:", type(batch_data))
        if isinstance(batch_data, dict):
            print("Batch data keys:", batch_data.keys())

        if isinstance(batch_data, dict) and "ctpt" in batch_data and "seg" in batch_data:
            inputs, targets = batch_data["ctpt"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()
        else:
            print("Invalid batch format received")
    # Validation
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_inputs, val_targets = val_data["ctpt"].to(device), val_data["seg"].to(device)
            val_outputs = model(val_inputs)
            dice_metric(y_pred=val_outputs, y=val_targets)
            dice_metric_batch(y_pred=val_outputs, y=val_targets)

    avg_dice = dice_metric.aggregate().item()
    metric_batch = dice_metric_batch.aggregate()
    tumor_dice, lymph_dice = metric_batch[0].item(), metric_batch[1].item()
    dice_metric.reset()
    dice_metric_batch.reset()

    # Print metrics
    print(f"Epoch {epoch+1}/{max_num_epochs}, Avg Dice: {avg_dice}, Tumor Dice: {tumor_dice}, Lymph Dice: {lymph_dice}")

    # Save model if it's the best so far
    if avg_dice > dice_val_best:
        dice_val_best = avg_dice
        torch.save(model.state_dict(), os.path.join(model_dir, f"/home/nada.saadi/CTPET/hecktor2022_cropped/UNFROZEN-PATCH-MDA/Unfrozen-mda-model{epoch+1}.pth"))

    # Update global step
    global_step += len(train_loader)
    if global_step >= max_iterations:
        break


In [None]:
import inspect
from monai.networks.nets import UNETR

print(inspect.signature(UNETR))


(in_channels: 'int', out_channels: 'int', img_size: 'Sequence[int] | int', feature_size: 'int' = 16, hidden_size: 'int' = 768, mlp_dim: 'int' = 3072, num_heads: 'int' = 12, pos_embed: 'str' = 'conv', proj_type: 'str' = 'conv', norm_name: 'tuple | str' = 'instance', conv_block: 'bool' = True, res_block: 'bool' = True, dropout_rate: 'float' = 0.0, spatial_dims: 'int' = 3, qkv_bias: 'bool' = False, save_attn: 'bool' = False) -> 'None'


In [11]:
train_data.shape()

AttributeError: 'list' object has no attribute 'shape'

In [14]:
a =[1,2,3,4,5,6]
a[::2]

[1, 3, 5]

In [None]:
print(device)


cuda


In [None]:
import monai.networks.nets as nets
print(nets.UNETR)


<class 'monai.networks.nets.unetr.UNETR'>
