In [2]:
import numpy as np
import torch
import os
import nibabel as nib

# Define parameters
data_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Images"  # Update with your actual path
file_list_path = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/train.txt"  # Text file containing filenames without extension

#
# Read file names
with open(file_list_path, "r") as f:
    file_names = [line.strip() for line in f.readlines()]

# Initialize running mean tensor
running_mean = np.zeros((4, 240, 240, 155), dtype=np.float32)

# Iterate through each file and update mean incrementally
for i, file_name in enumerate(file_names):
    file_path = os.path.join(data_dir, file_name + ".nii.gz")  # Assuming .nii.gz format
    nii_img = nib.load(file_path)  # Load the NIfTI file
    sample = np.array(nii_img.get_fdata(), dtype=np.float32)  # Convert to NumPy array

    # Ensure the shape is correct and transpose to (4, 240, 240, 155)
    if sample.shape != (240, 240, 155, 4):
        raise ValueError(f"Unexpected shape {sample.shape} for file {file_name}")
    
    sample = np.transpose(sample, (3, 0, 1, 2))  # Move channels to first dimension

    # Incremental mean update
    running_mean += (sample - running_mean) / (i + 1)

# Convert final mean to PyTorch tensor
mean_tensor = torch.tensor(running_mean)

# Save the mean tensor
torch.save(mean_tensor, "mean_tensor.pt")


print("Mean tensor shape:", mean_tensor.shape)  # Should be (4, 240, 240, 155)

Mean tensor shape: torch.Size([4, 240, 240, 155])


In [5]:
import numpy as np
now_lr = round(0.0002 * np.power(1 - np.float32(999)/np.float32(10000), 0.9), 8) 
print(now_lr)

0.00018192


In [3]:
import torch

checkpoint = torch.load("/mnt/disk1/hjlee/orhun/repo/mmFormer/pt_model/model_last.pth")
new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}

for name, param in new_state_dict.items():
    print(name)


flair_pos
t1ce_pos
t1_pos
t2_pos
flair_encoder.e1_c1.weight
flair_encoder.e1_c1.bias
flair_encoder.e1_c2.conv.weight
flair_encoder.e1_c2.conv.bias
flair_encoder.e1_c3.conv.weight
flair_encoder.e1_c3.conv.bias
flair_encoder.e2_c1.conv.weight
flair_encoder.e2_c1.conv.bias
flair_encoder.e2_c2.conv.weight
flair_encoder.e2_c2.conv.bias
flair_encoder.e2_c3.conv.weight
flair_encoder.e2_c3.conv.bias
flair_encoder.e3_c1.conv.weight
flair_encoder.e3_c1.conv.bias
flair_encoder.e3_c2.conv.weight
flair_encoder.e3_c2.conv.bias
flair_encoder.e3_c3.conv.weight
flair_encoder.e3_c3.conv.bias
flair_encoder.e4_c1.conv.weight
flair_encoder.e4_c1.conv.bias
flair_encoder.e4_c2.conv.weight
flair_encoder.e4_c2.conv.bias
flair_encoder.e4_c3.conv.weight
flair_encoder.e4_c3.conv.bias
flair_encoder.e5_c1.conv.weight
flair_encoder.e5_c1.conv.bias
flair_encoder.e5_c2.conv.weight
flair_encoder.e5_c2.conv.bias
flair_encoder.e5_c3.conv.weight
flair_encoder.e5_c3.conv.bias
t1ce_encoder.e1_c1.weight
t1ce_encoder.e1_c1.bi

In [4]:
import torch

checkpoint = torch.load("/mnt/disk1/hjlee/orhun/repo/mmFormer/pt_model/model_last.pth")
new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
new_state_dict = {k: v for k, v in new_state_dict.items() if not k.startswith("decoder_fuse")}

for name, param in new_state_dict.items():
    print(name)

flair_pos
t1ce_pos
t1_pos
t2_pos
flair_encoder.e1_c1.weight
flair_encoder.e1_c1.bias
flair_encoder.e1_c2.conv.weight
flair_encoder.e1_c2.conv.bias
flair_encoder.e1_c3.conv.weight
flair_encoder.e1_c3.conv.bias
flair_encoder.e2_c1.conv.weight
flair_encoder.e2_c1.conv.bias
flair_encoder.e2_c2.conv.weight
flair_encoder.e2_c2.conv.bias
flair_encoder.e2_c3.conv.weight
flair_encoder.e2_c3.conv.bias
flair_encoder.e3_c1.conv.weight
flair_encoder.e3_c1.conv.bias
flair_encoder.e3_c2.conv.weight
flair_encoder.e3_c2.conv.bias
flair_encoder.e3_c3.conv.weight
flair_encoder.e3_c3.conv.bias
flair_encoder.e4_c1.conv.weight
flair_encoder.e4_c1.conv.bias
flair_encoder.e4_c2.conv.weight
flair_encoder.e4_c2.conv.bias
flair_encoder.e4_c3.conv.weight
flair_encoder.e4_c3.conv.bias
flair_encoder.e5_c1.conv.weight
flair_encoder.e5_c1.conv.bias
flair_encoder.e5_c2.conv.weight
flair_encoder.e5_c2.conv.bias
flair_encoder.e5_c3.conv.weight
flair_encoder.e5_c3.conv.bias
t1ce_encoder.e1_c1.weight
t1ce_encoder.e1_c1.bi

In [10]:
checkpoint = torch.load("/mnt/disk1/hjlee/orhun/repo/mmFormer/pt_model/model_last.pth")
new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}

for name, param in new_state_dict.items():
    print(name)

flair_pos
t1ce_pos
t1_pos
t2_pos
flair_encoder.e1_c1.weight
flair_encoder.e1_c1.bias
flair_encoder.e1_c2.conv.weight
flair_encoder.e1_c2.conv.bias
flair_encoder.e1_c3.conv.weight
flair_encoder.e1_c3.conv.bias
flair_encoder.e2_c1.conv.weight
flair_encoder.e2_c1.conv.bias
flair_encoder.e2_c2.conv.weight
flair_encoder.e2_c2.conv.bias
flair_encoder.e2_c3.conv.weight
flair_encoder.e2_c3.conv.bias
flair_encoder.e3_c1.conv.weight
flair_encoder.e3_c1.conv.bias
flair_encoder.e3_c2.conv.weight
flair_encoder.e3_c2.conv.bias
flair_encoder.e3_c3.conv.weight
flair_encoder.e3_c3.conv.bias
flair_encoder.e4_c1.conv.weight
flair_encoder.e4_c1.conv.bias
flair_encoder.e4_c2.conv.weight
flair_encoder.e4_c2.conv.bias
flair_encoder.e4_c3.conv.weight
flair_encoder.e4_c3.conv.bias
flair_encoder.e5_c1.conv.weight
flair_encoder.e5_c1.conv.bias
flair_encoder.e5_c2.conv.weight
flair_encoder.e5_c2.conv.bias
flair_encoder.e5_c3.conv.weight
flair_encoder.e5_c3.conv.bias
t1ce_encoder.e1_c1.weight
t1ce_encoder.e1_c1.bi

In [7]:
import torch
checkpoint = torch.load("/mnt/disk1/hjlee/orhun/repo/mmFormer/pt_model/model_last.pth")

for name, param in checkpoint["state_dict"].items():
    print(name)

module.flair_pos
module.t1ce_pos
module.t1_pos
module.t2_pos
module.flair_encoder.e1_c1.weight
module.flair_encoder.e1_c1.bias
module.flair_encoder.e1_c2.conv.weight
module.flair_encoder.e1_c2.conv.bias
module.flair_encoder.e1_c3.conv.weight
module.flair_encoder.e1_c3.conv.bias
module.flair_encoder.e2_c1.conv.weight
module.flair_encoder.e2_c1.conv.bias
module.flair_encoder.e2_c2.conv.weight
module.flair_encoder.e2_c2.conv.bias
module.flair_encoder.e2_c3.conv.weight
module.flair_encoder.e2_c3.conv.bias
module.flair_encoder.e3_c1.conv.weight
module.flair_encoder.e3_c1.conv.bias
module.flair_encoder.e3_c2.conv.weight
module.flair_encoder.e3_c2.conv.bias
module.flair_encoder.e3_c3.conv.weight
module.flair_encoder.e3_c3.conv.bias
module.flair_encoder.e4_c1.conv.weight
module.flair_encoder.e4_c1.conv.bias
module.flair_encoder.e4_c2.conv.weight
module.flair_encoder.e4_c2.conv.bias
module.flair_encoder.e4_c3.conv.weight
module.flair_encoder.e4_c3.conv.bias
module.flair_encoder.e5_c1.conv.weigh

In [6]:
from nets.mmformer import Model, LR_Scheduler, mmformer_mask, get_mmformer_loaders, softmax_weighted_loss_monai, softmax_weighted_loss

model = Model(num_cls=4)
for name, param in model.state_dict().items():
    print(name)

flair_pos
t1ce_pos
t1_pos
t2_pos
flair_encoder.e1_c1.weight
flair_encoder.e1_c1.bias
flair_encoder.e1_c2.conv.weight
flair_encoder.e1_c2.conv.bias
flair_encoder.e1_c3.conv.weight
flair_encoder.e1_c3.conv.bias
flair_encoder.e2_c1.conv.weight
flair_encoder.e2_c1.conv.bias
flair_encoder.e2_c2.conv.weight
flair_encoder.e2_c2.conv.bias
flair_encoder.e2_c3.conv.weight
flair_encoder.e2_c3.conv.bias
flair_encoder.e3_c1.conv.weight
flair_encoder.e3_c1.conv.bias
flair_encoder.e3_c2.conv.weight
flair_encoder.e3_c2.conv.bias
flair_encoder.e3_c3.conv.weight
flair_encoder.e3_c3.conv.bias
flair_encoder.e4_c1.conv.weight
flair_encoder.e4_c1.conv.bias
flair_encoder.e4_c2.conv.weight
flair_encoder.e4_c2.conv.bias
flair_encoder.e4_c3.conv.weight
flair_encoder.e4_c3.conv.bias
flair_encoder.e5_c1.conv.weight
flair_encoder.e5_c1.conv.bias
flair_encoder.e5_c2.conv.weight
flair_encoder.e5_c2.conv.bias
flair_encoder.e5_c3.conv.weight
flair_encoder.e5_c3.conv.bias
t1ce_encoder.e1_c1.weight
t1ce_encoder.e1_c1.bi

In [None]:
from monai.networks.nets import SwinUNETR
import torch

swinunetr_model = SwinUNETR(
    img_size=(128, 128, 128),
    in_channels=4, 
    out_channels=3,
    feature_size=48,
    use_checkpoint=True)

print(swinunetr_model.swinViT)



SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv3d(4, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers1): ModuleList(
    (0): BasicLayer(
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=48, out_features=144, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=48, out_features=48, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
          (mlp): MLPBlock(
            (linear1): Linear(in_features=48, out_features=192, bias=True)
            (linear2): Linear(in_features=192, out_features=48, bias=True)
            (fn): GELU(approximate='n

In [1]:
import torch
checkpoint_path = "/data/hjlee/orhun/thesis/models/finetune_mm48_swinunetr21base_nodecoderloaded_discrete04/finetune_mm48_swinunetr21base_nodecoderloaded_discrete04_checkpoint_Epoch_180.pt"
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda:2'))["model_state_dict"]

In [10]:

def sup_128(xmin, xmax):
    if xmax - xmin < 128:
        print('#' * 100)
        ecart = int((128 - (xmax - xmin)) / 2)
        xmax = xmax + ecart + 1
        xmin = xmin - ecart
    if xmin < 0:
        xmax -= xmin
        xmin = 0
    return xmin, xmax

def crop(vol):
    assert len(vol.shape) == 4  # Expecting (H, W, D, C)
    vol_max = np.amax(vol, axis=-1)  # Max projection over channels to find non-zero region
    
    x_nonzeros, y_nonzeros, z_nonzeros = np.where(vol_max != 0)

    x_min, x_max = np.amin(x_nonzeros), np.amax(x_nonzeros)
    y_min, y_max = np.amin(y_nonzeros), np.amax(y_nonzeros)
    z_min, z_max = np.amin(z_nonzeros), np.amax(z_nonzeros)

    x_min, x_max = sup_128(x_min, x_max)
    y_min, y_max = sup_128(y_min, y_max)
    z_min, z_max = sup_128(z_min, z_max)

    return x_min, x_max, y_min, y_max, z_min, z_max

def normalize(vol):
    mask = vol.sum(axis=(0, 1, 2)) > 0  # Sum over spatial dimensions, keep only valid channels
    for k in range(vol.shape[-1]):  # Iterate over channels
        if mask[k]:
            x = vol[..., k]
            y = x[x > 0]
            if y.size > 0:
                vol[..., k] = (x - y.mean()) / (y.std() + 1e-8)  # Normalize with numerical stability
    return vol

def process_and_save(image_path, label_path, output_img_folder, output_label_folder):
    img_nii = nib.load(image_path)
    label_nii = nib.load(label_path)
    
    img_data = img_nii.get_fdata()
    label_data = label_nii.get_fdata()
    
    x_min, x_max, y_min, y_max, z_min, z_max = crop(img_data)  # Extract coordinates from image data
    
    cropped_img = img_data[x_min:x_max, y_min:y_max, z_min:z_max, :]
    cropped_label = label_data[x_min:x_max, y_min:y_max, z_min:z_max]  # Apply same crop to label data
    
    normalized_img = normalize(cropped_img)  # Apply normalization
    
    cropped_img_nii = nib.Nifti1Image(normalized_img, img_nii.affine, img_nii.header)
    cropped_label_nii = nib.Nifti1Image(cropped_label, label_nii.affine, label_nii.header)
    
    os.makedirs(output_img_folder, exist_ok=True)
    os.makedirs(output_label_folder, exist_ok=True)
    
    img_filename = os.path.basename(image_path)
    label_filename = os.path.basename(label_path)
    
    nib.save(cropped_img_nii, os.path.join(output_img_folder, img_filename))
    nib.save(cropped_label_nii, os.path.join(output_label_folder, label_filename))
    
    print(f"Processed and saved: {img_filename}, {label_filename}")

In [3]:
import nibabel as nib

# Load the NIfTI file
nii_file = nib.load("/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Images/Brats18_2013_0_1.nii.gz")

# Get the shape of the image data
image_data = nii_file.get_fdata()
print("Shape of the image:", image_data.shape)


Shape of the image: (128, 128, 128, 4)


In [11]:
image_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Images"
label_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels"
output_img_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C_N/Images"
output_label_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C_N/Labels"

for img_file in os.listdir(image_dir):
    if img_file.endswith(".nii.gz"):
        img_path = os.path.join(image_dir, img_file)
        label_path = os.path.join(label_dir, img_file.replace(".nii.gz", "_label.nii.gz"))
        
        if os.path.exists(label_path):
            process_and_save(img_path, label_path, output_img_dir, output_label_dir)
        else:
            print(f"Label file not found for {img_file}")

Processed and saved: Brats18_TCIA08_105_1.nii.gz, Brats18_TCIA08_105_1_label.nii.gz
Processed and saved: Brats18_CBICA_AAG_1.nii.gz, Brats18_CBICA_AAG_1_label.nii.gz
Processed and saved: Brats18_TCIA13_642_1.nii.gz, Brats18_TCIA13_642_1_label.nii.gz
Processed and saved: Brats18_TCIA10_640_1.nii.gz, Brats18_TCIA10_640_1_label.nii.gz
Processed and saved: Brats18_CBICA_ANG_1.nii.gz, Brats18_CBICA_ANG_1_label.nii.gz
Processed and saved: Brats18_TCIA06_211_1.nii.gz, Brats18_TCIA06_211_1_label.nii.gz
Processed and saved: Brats18_CBICA_AQT_1.nii.gz, Brats18_CBICA_AQT_1_label.nii.gz
####################################################################################################
Processed and saved: Brats18_CBICA_AUN_1.nii.gz, Brats18_CBICA_AUN_1_label.nii.gz
Processed and saved: Brats18_TCIA09_254_1.nii.gz, Brats18_TCIA09_254_1_label.nii.gz
Processed and saved: Brats18_TCIA05_478_1.nii.gz, Brats18_TCIA05_478_1_label.nii.gz
Processed and saved: Brats18_TCIA02_608_1.nii.gz, Brats18_TCIA02_60

In [1]:
from monai.networks.nets import UNet

_model = UNet(
    spatial_dims=3,  # 3D UNet
    in_channels=4,   # BraTS18 has 4 modalities (T1, T1ce, T2, FLAIR)
    out_channels=3,  # 3 output channels for WT, TC, and ET
    channels=(32, 64, 128, 256, 512),  # Number of filters at each level
    strides=(2, 2, 2, 2),  # Strides for downsampling
    num_res_units=2,  # Residual units
    norm='instance',  # Instance normalization
)

print(_model)


UNet(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (unit1): Convolution(
          (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(
          (conv): Se

In [7]:
import torch
import itertools
original_tensor = torch.randn(1, 4, 384, 8, 8)
modified_tensors = []
# Random tensor of shape [384, 8, 8] (same for all replacements)
random_tensor = torch.randn(384, 8, 8)

retained_channels_dict = {}


for r in range(1, 4):  # Only consider 1, 2, or 3 channels
    for channels in itertools.combinations(range(4), r):  # Generate all valid subsets
        print(channels)

        modified_tensor = original_tensor.clone()  # Clone original tensor
        for ch in channels:
            print(ch)
            modified_tensor[0, ch] = random_tensor  # Set selected channels

        modified_tensors.append((channels, modified_tensor))  # Store result

(0,)
0
(1,)
1
(2,)
2
(3,)
3
(0, 1)
0
1
(0, 2)
0
2
(0, 3)
0
3
(1, 2)
1
2
(1, 3)
1
3
(2, 3)
2
3
(0, 1, 2)
0
1
2
(0, 1, 3)
0
1
3
(0, 2, 3)
0
2
3
(1, 2, 3)
1
2
3


In [1]:
# get the mean of each modality features
import torch
from glob import glob
import os
import config
from monai.data import ImageDataset,DataLoader
import utils.utils as utils
from nets.multimodal_swinunetr import Multimodal_SwinUNETR
from nets.mm_ldmv2 import mm_ldmv2
from monai.transforms import (Lambda,
                               Compose, EnsureChannelFirst,
                                 RandSpatialCrop, RandRotate90, 
                                 NormalizeIntensity, RandAdjustContrast,
                                   RandZoom, RandFlip, RandGaussianNoise,
                                     RandGaussianSmooth, RandAdjustContrast,
                                     ConvertToMultiChannelBasedOnBratsClasses,
                                     RandScaleIntensity, RandShiftIntensity
                                )
from monai.transforms import Activations, AsDiscrete, Compose
from tqdm import tqdm

Training_config = config.Training_config()
Database_config = config.Database_config
chosen_ds = Training_config.dataset_to_train[0]
channel_indices = []
modalities_to_train = Training_config.modalities_to_train
for _m in modalities_to_train:
    channel_indices.append(Database_config.channels[chosen_ds].index(_m))

img_path=Database_config.img_path[chosen_ds]
seg_path=Database_config.seg_path[chosen_ds]

images= sorted(glob(os.path.join(img_path, "*.*")))
segs = sorted(glob(os.path.join(seg_path,"*.*"))) 

train_images, train_segs, val_images, val_segs = utils.separate_paths(images,
                                                                    segs,
                                                                    Database_config.split_path[chosen_ds]["train"],
                                                                    Database_config.split_path[chosen_ds]["val"])


def select_channels(x):
    if x.ndim == 4:
        return x[..., channel_indices]
    else:
        return x
imtrans = Compose(
        [   Lambda(select_channels),
            EnsureChannelFirst(strict_check=True),
            NormalizeIntensity(nonzero=True,channel_wise=True),
            #RandFlip(prob=0.5,spatial_axis=[0, 1, 2]),
            #RandGaussianNoise(prob=0.15, mean=0.0, std= 0.33),
            #RandGaussianSmooth(prob=0.15, sigma_x=(0.5, 1.5),sigma_y=(0.5, 1.5),sigma_z=(0.5, 1.5)),
            #RandAdjustContrast(prob=0.15, gamma=(0.7, 1.3)),
            #RandScaleIntensity(factors=0.1, prob=1.0),
            #RandShiftIntensity(offsets=0.1, prob=1.0),
            RandSpatialCrop((128, 128, 128), random_size=False)            
        ])
labeltrans = Compose(
        [   Lambda(select_channels),
            EnsureChannelFirst(strict_check=True),
            #ConvertToMultiChannelBasedOnBratsClassesCustom(),
            #RandFlip(prob=0.5,spatial_axis=[0, 1, 2]),
            RandSpatialCrop((128, 128, 128), random_size=False),
            RandRotate90(prob=0.1, spatial_axes=(0, 2))
        ])
_ds = ImageDataset(train_images, train_segs, transform=imtrans, seg_transform=labeltrans)
_loader = DataLoader(_ds, batch_size=1, drop_last=True, shuffle=True, num_workers=2, pin_memory=0)

device="cuda:2"
#load_model_path = "/data/hjlee/orhun/thesis/models/mm12_sd_ds_dicece/mm12_sd_ds_dicece_checkpoint_Epoch_50.pt"
load_model_path = "/data/hjlee/orhun/thesis/models/mm12_sd_ds_dicece_rd_EXPtpconv/mm12_sd_ds_dicece_rd_EXPtpconv_BEST_ET.pth"

swinunetr = Multimodal_SwinUNETR(
                img_size=(128, 128, 128),
                in_channels=1, 
                out_channels=3,
                feature_size=12,
                deep_supervision=True,
                sep_dec=True,
                tp_conv=True,
                dec_upsample = False
                )

#checkpoint = torch.load(load_model_path, map_location=torch.device(device))["model_state_dict"]
checkpoint = torch.load(load_model_path, map_location=torch.device(device))

swinunetr.load_state_dict(checkpoint)
swinunetr.to(device)
swinunetr.eval()
swinunetr.is_training = False


mean_tensor = torch.zeros((1, 768, 4, 4, 4)).to(device)
step = 0
for batch in tqdm(_loader):
    input_data = batch[0].to(device)
    with torch.no_grad():
        c_hidden_states_out_m1 = swinunetr.swinViT_1(input_data[:,0:1,:,:], normalize=True)[4]
        c_hidden_states_out_m2 = swinunetr.swinViT_2(input_data[:,1:2,:,:], normalize=True)[4]
        c_hidden_states_out_m3 = swinunetr.swinViT_3(input_data[:,2:3,:,:], normalize=True)[4]
        c_hidden_states_out_m4 = swinunetr.swinViT_4(input_data[:,3:4,:,:], normalize=True)[4]
        c_dec4_m1 = swinunetr.encoder10_1(c_hidden_states_out_m1)
        c_dec4_m2 = swinunetr.encoder10_2(c_hidden_states_out_m2)
        c_dec4_m3 = swinunetr.encoder10_3(c_hidden_states_out_m3)
        c_dec4_m4 = swinunetr.encoder10_4(c_hidden_states_out_m4)

        complete_modality_features  = torch.cat((c_dec4_m1, c_dec4_m2, c_dec4_m3, c_dec4_m4), dim=1) # ([B, 4 * fs*16, 8, 8])
    print(complete_modality_features.shape)
    mean_tensor += (complete_modality_features - mean_tensor) / (step + 1)
    step +=1

print(mean_tensor.shape)

torch.save(mean_tensor, "mean_features_mm12_sd_ds_separate_768_4_4_4.pt")

  0%|          | 1/228 [00:03<14:56,  3.95s/it]

torch.Size([1, 768, 4, 4, 4])


  1%|          | 2/228 [00:04<06:43,  1.79s/it]

torch.Size([1, 768, 4, 4, 4])


  1%|▏         | 3/228 [00:04<04:52,  1.30s/it]

torch.Size([1, 768, 4, 4, 4])


  2%|▏         | 4/228 [00:05<03:50,  1.03s/it]

torch.Size([1, 768, 4, 4, 4])


  2%|▏         | 5/228 [00:07<04:50,  1.30s/it]

torch.Size([1, 768, 4, 4, 4])


  3%|▎         | 6/228 [00:08<04:07,  1.11s/it]

torch.Size([1, 768, 4, 4, 4])


  3%|▎         | 7/228 [00:10<06:02,  1.64s/it]

torch.Size([1, 768, 4, 4, 4])


  4%|▎         | 8/228 [00:11<05:00,  1.37s/it]

torch.Size([1, 768, 4, 4, 4])


  4%|▍         | 9/228 [00:14<06:26,  1.77s/it]

torch.Size([1, 768, 4, 4, 4])


  4%|▍         | 10/228 [00:14<04:50,  1.33s/it]

torch.Size([1, 768, 4, 4, 4])


  5%|▍         | 11/228 [00:17<06:27,  1.79s/it]

torch.Size([1, 768, 4, 4, 4])


  5%|▌         | 12/228 [00:18<05:22,  1.49s/it]

torch.Size([1, 768, 4, 4, 4])


  6%|▌         | 13/228 [00:20<06:27,  1.80s/it]

torch.Size([1, 768, 4, 4, 4])


  6%|▌         | 14/228 [00:21<05:02,  1.41s/it]

torch.Size([1, 768, 4, 4, 4])


  7%|▋         | 15/228 [00:23<06:12,  1.75s/it]

torch.Size([1, 768, 4, 4, 4])


  7%|▋         | 16/228 [00:24<04:47,  1.36s/it]

torch.Size([1, 768, 4, 4, 4])


  7%|▋         | 17/228 [00:26<06:05,  1.73s/it]

torch.Size([1, 768, 4, 4, 4])


  8%|▊         | 18/228 [00:27<04:31,  1.29s/it]

torch.Size([1, 768, 4, 4, 4])


  8%|▊         | 19/228 [00:29<05:49,  1.67s/it]

torch.Size([1, 768, 4, 4, 4])


  9%|▉         | 20/228 [00:29<04:21,  1.25s/it]

torch.Size([1, 768, 4, 4, 4])


  9%|▉         | 21/228 [00:33<06:30,  1.89s/it]

torch.Size([1, 768, 4, 4, 4])


 10%|▉         | 22/228 [00:33<04:48,  1.40s/it]

torch.Size([1, 768, 4, 4, 4])


 10%|█         | 23/228 [00:36<06:13,  1.82s/it]

torch.Size([1, 768, 4, 4, 4])


 11%|█         | 24/228 [00:36<04:36,  1.36s/it]

torch.Size([1, 768, 4, 4, 4])


 11%|█         | 25/228 [00:39<05:35,  1.65s/it]

torch.Size([1, 768, 4, 4, 4])


 11%|█▏        | 26/228 [00:39<04:10,  1.24s/it]

torch.Size([1, 768, 4, 4, 4])


 12%|█▏        | 27/228 [00:41<05:35,  1.67s/it]

torch.Size([1, 768, 4, 4, 4])


 12%|█▏        | 28/228 [00:42<04:09,  1.25s/it]

torch.Size([1, 768, 4, 4, 4])


 13%|█▎        | 29/228 [00:45<06:22,  1.92s/it]

torch.Size([1, 768, 4, 4, 4])


 13%|█▎        | 30/228 [00:45<04:42,  1.43s/it]

torch.Size([1, 768, 4, 4, 4])


 14%|█▎        | 31/228 [00:48<05:16,  1.61s/it]

torch.Size([1, 768, 4, 4, 4])


 14%|█▍        | 32/228 [00:48<03:57,  1.21s/it]

torch.Size([1, 768, 4, 4, 4])


 14%|█▍        | 33/228 [00:51<05:46,  1.78s/it]

torch.Size([1, 768, 4, 4, 4])


 15%|█▍        | 34/228 [00:51<04:16,  1.32s/it]

torch.Size([1, 768, 4, 4, 4])


 15%|█▌        | 35/228 [00:54<05:57,  1.85s/it]

torch.Size([1, 768, 4, 4, 4])


 16%|█▌        | 36/228 [00:55<04:24,  1.38s/it]

torch.Size([1, 768, 4, 4, 4])


 16%|█▌        | 37/228 [00:57<05:45,  1.81s/it]

torch.Size([1, 768, 4, 4, 4])


 17%|█▋        | 38/228 [00:58<04:16,  1.35s/it]

torch.Size([1, 768, 4, 4, 4])


 17%|█▋        | 39/228 [00:59<04:39,  1.48s/it]

torch.Size([1, 768, 4, 4, 4])


 18%|█▊        | 40/228 [01:00<03:30,  1.12s/it]

torch.Size([1, 768, 4, 4, 4])


 18%|█▊        | 41/228 [01:02<04:30,  1.44s/it]

torch.Size([1, 768, 4, 4, 4])


 18%|█▊        | 42/228 [01:02<03:23,  1.09s/it]

torch.Size([1, 768, 4, 4, 4])


 19%|█▉        | 43/228 [01:05<05:01,  1.63s/it]

torch.Size([1, 768, 4, 4, 4])


 19%|█▉        | 44/228 [01:05<03:44,  1.22s/it]

torch.Size([1, 768, 4, 4, 4])


 20%|█▉        | 45/228 [01:08<05:01,  1.65s/it]

torch.Size([1, 768, 4, 4, 4])


 20%|██        | 46/228 [01:08<03:46,  1.24s/it]

torch.Size([1, 768, 4, 4, 4])


 21%|██        | 47/228 [01:11<04:58,  1.65s/it]

torch.Size([1, 768, 4, 4, 4])


 21%|██        | 48/228 [01:11<03:46,  1.26s/it]

torch.Size([1, 768, 4, 4, 4])


 21%|██▏       | 49/228 [01:14<05:12,  1.74s/it]

torch.Size([1, 768, 4, 4, 4])


 22%|██▏       | 50/228 [01:14<04:00,  1.35s/it]

torch.Size([1, 768, 4, 4, 4])


 22%|██▏       | 51/228 [01:17<04:35,  1.56s/it]

torch.Size([1, 768, 4, 4, 4])


 23%|██▎       | 52/228 [01:18<04:13,  1.44s/it]

torch.Size([1, 768, 4, 4, 4])


 23%|██▎       | 53/228 [01:19<04:04,  1.40s/it]

torch.Size([1, 768, 4, 4, 4])


 24%|██▎       | 54/228 [01:20<04:04,  1.41s/it]

torch.Size([1, 768, 4, 4, 4])


 24%|██▍       | 55/228 [01:21<03:41,  1.28s/it]

torch.Size([1, 768, 4, 4, 4])


 25%|██▍       | 56/228 [01:23<04:17,  1.49s/it]

torch.Size([1, 768, 4, 4, 4])


 25%|██▌       | 57/228 [01:24<03:17,  1.15s/it]

torch.Size([1, 768, 4, 4, 4])


 25%|██▌       | 58/228 [01:25<03:42,  1.31s/it]

torch.Size([1, 768, 4, 4, 4])


 26%|██▌       | 59/228 [01:26<03:00,  1.07s/it]

torch.Size([1, 768, 4, 4, 4])


 26%|██▋       | 60/228 [01:28<03:35,  1.28s/it]

torch.Size([1, 768, 4, 4, 4])


 27%|██▋       | 61/228 [01:28<02:43,  1.02it/s]

torch.Size([1, 768, 4, 4, 4])


 27%|██▋       | 62/228 [01:30<03:45,  1.36s/it]

torch.Size([1, 768, 4, 4, 4])


 28%|██▊       | 63/228 [01:31<02:57,  1.08s/it]

torch.Size([1, 768, 4, 4, 4])


 28%|██▊       | 64/228 [01:33<03:59,  1.46s/it]

torch.Size([1, 768, 4, 4, 4])


 29%|██▊       | 65/228 [01:33<03:00,  1.11s/it]

torch.Size([1, 768, 4, 4, 4])


 29%|██▉       | 66/228 [01:35<03:49,  1.42s/it]

torch.Size([1, 768, 4, 4, 4])


 29%|██▉       | 67/228 [01:36<03:03,  1.14s/it]

torch.Size([1, 768, 4, 4, 4])


 30%|██▉       | 68/228 [01:38<03:33,  1.34s/it]

torch.Size([1, 768, 4, 4, 4])


 30%|███       | 69/228 [01:39<03:12,  1.21s/it]

torch.Size([1, 768, 4, 4, 4])


 31%|███       | 70/228 [01:41<03:55,  1.49s/it]

torch.Size([1, 768, 4, 4, 4])


 31%|███       | 71/228 [01:41<03:07,  1.20s/it]

torch.Size([1, 768, 4, 4, 4])


 32%|███▏      | 72/228 [01:44<04:06,  1.58s/it]

torch.Size([1, 768, 4, 4, 4])


 32%|███▏      | 73/228 [01:44<03:04,  1.19s/it]

torch.Size([1, 768, 4, 4, 4])


 32%|███▏      | 74/228 [01:47<04:02,  1.57s/it]

torch.Size([1, 768, 4, 4, 4])


 33%|███▎      | 75/228 [01:47<03:08,  1.23s/it]

torch.Size([1, 768, 4, 4, 4])


 33%|███▎      | 76/228 [01:50<04:12,  1.66s/it]

torch.Size([1, 768, 4, 4, 4])


 34%|███▍      | 77/228 [01:50<03:08,  1.25s/it]

torch.Size([1, 768, 4, 4, 4])


 34%|███▍      | 78/228 [01:52<03:52,  1.55s/it]

torch.Size([1, 768, 4, 4, 4])


 35%|███▍      | 79/228 [01:53<03:07,  1.26s/it]

torch.Size([1, 768, 4, 4, 4])


 35%|███▌      | 80/228 [01:55<04:11,  1.70s/it]

torch.Size([1, 768, 4, 4, 4])


 36%|███▌      | 81/228 [01:56<03:09,  1.29s/it]

torch.Size([1, 768, 4, 4, 4])


 36%|███▌      | 82/228 [01:58<03:50,  1.58s/it]

torch.Size([1, 768, 4, 4, 4])


 36%|███▋      | 83/228 [01:58<02:54,  1.20s/it]

torch.Size([1, 768, 4, 4, 4])


 37%|███▋      | 84/228 [02:01<03:53,  1.62s/it]

torch.Size([1, 768, 4, 4, 4])


 37%|███▋      | 85/228 [02:01<02:56,  1.24s/it]

torch.Size([1, 768, 4, 4, 4])


 38%|███▊      | 86/228 [02:04<03:47,  1.60s/it]

torch.Size([1, 768, 4, 4, 4])


 38%|███▊      | 87/228 [02:04<02:55,  1.24s/it]

torch.Size([1, 768, 4, 4, 4])


 39%|███▊      | 88/228 [02:07<03:47,  1.62s/it]

torch.Size([1, 768, 4, 4, 4])


 39%|███▉      | 89/228 [02:07<02:49,  1.22s/it]

torch.Size([1, 768, 4, 4, 4])


 39%|███▉      | 90/228 [02:09<03:37,  1.58s/it]

torch.Size([1, 768, 4, 4, 4])


 40%|███▉      | 91/228 [02:10<02:44,  1.20s/it]

torch.Size([1, 768, 4, 4, 4])


 40%|████      | 92/228 [02:12<03:29,  1.54s/it]

torch.Size([1, 768, 4, 4, 4])


 41%|████      | 93/228 [02:13<02:46,  1.23s/it]

torch.Size([1, 768, 4, 4, 4])


 41%|████      | 94/228 [02:15<03:36,  1.61s/it]

torch.Size([1, 768, 4, 4, 4])


 42%|████▏     | 95/228 [02:16<02:53,  1.30s/it]

torch.Size([1, 768, 4, 4, 4])


 42%|████▏     | 96/228 [02:18<03:34,  1.62s/it]

torch.Size([1, 768, 4, 4, 4])


 43%|████▎     | 97/228 [02:19<02:50,  1.30s/it]

torch.Size([1, 768, 4, 4, 4])


 43%|████▎     | 98/228 [02:21<03:36,  1.67s/it]

torch.Size([1, 768, 4, 4, 4])


 43%|████▎     | 99/228 [02:22<02:54,  1.35s/it]

torch.Size([1, 768, 4, 4, 4])


 44%|████▍     | 100/228 [02:24<03:25,  1.61s/it]

torch.Size([1, 768, 4, 4, 4])


 44%|████▍     | 101/228 [02:25<02:51,  1.35s/it]

torch.Size([1, 768, 4, 4, 4])


 45%|████▍     | 102/228 [02:26<03:05,  1.47s/it]

torch.Size([1, 768, 4, 4, 4])


 45%|████▌     | 103/228 [02:28<02:53,  1.39s/it]

torch.Size([1, 768, 4, 4, 4])


 46%|████▌     | 104/228 [02:29<02:50,  1.38s/it]

torch.Size([1, 768, 4, 4, 4])


 46%|████▌     | 105/228 [02:30<02:53,  1.41s/it]

torch.Size([1, 768, 4, 4, 4])


 46%|████▋     | 106/228 [02:32<02:59,  1.47s/it]

torch.Size([1, 768, 4, 4, 4])


 47%|████▋     | 107/228 [02:33<02:33,  1.27s/it]

torch.Size([1, 768, 4, 4, 4])


 47%|████▋     | 108/228 [02:35<03:14,  1.62s/it]

torch.Size([1, 768, 4, 4, 4])


 48%|████▊     | 109/228 [02:36<02:24,  1.21s/it]

torch.Size([1, 768, 4, 4, 4])


 48%|████▊     | 110/228 [02:38<03:18,  1.68s/it]

torch.Size([1, 768, 4, 4, 4])


 49%|████▊     | 111/228 [02:39<02:27,  1.26s/it]

torch.Size([1, 768, 4, 4, 4])


 49%|████▉     | 112/228 [02:42<03:24,  1.77s/it]

torch.Size([1, 768, 4, 4, 4])


 50%|████▉     | 113/228 [02:42<02:31,  1.32s/it]

torch.Size([1, 768, 4, 4, 4])


 50%|█████     | 114/228 [02:45<03:26,  1.81s/it]

torch.Size([1, 768, 4, 4, 4])


 50%|█████     | 115/228 [02:45<02:32,  1.35s/it]

torch.Size([1, 768, 4, 4, 4])


 51%|█████     | 116/228 [02:47<03:04,  1.65s/it]

torch.Size([1, 768, 4, 4, 4])


 51%|█████▏    | 117/228 [02:48<02:16,  1.23s/it]

torch.Size([1, 768, 4, 4, 4])


 52%|█████▏    | 118/228 [02:50<03:08,  1.71s/it]

torch.Size([1, 768, 4, 4, 4])


 52%|█████▏    | 119/228 [02:51<02:19,  1.28s/it]

torch.Size([1, 768, 4, 4, 4])


 53%|█████▎    | 120/228 [02:53<02:35,  1.44s/it]

torch.Size([1, 768, 4, 4, 4])


 53%|█████▎    | 121/228 [02:53<01:56,  1.09s/it]

torch.Size([1, 768, 4, 4, 4])


 54%|█████▎    | 122/228 [02:55<02:27,  1.39s/it]

torch.Size([1, 768, 4, 4, 4])


 54%|█████▍    | 123/228 [02:55<01:50,  1.05s/it]

torch.Size([1, 768, 4, 4, 4])


 54%|█████▍    | 124/228 [02:57<02:21,  1.36s/it]

torch.Size([1, 768, 4, 4, 4])


 55%|█████▍    | 125/228 [02:58<01:46,  1.03s/it]

torch.Size([1, 768, 4, 4, 4])


 55%|█████▌    | 126/228 [03:00<02:17,  1.35s/it]

torch.Size([1, 768, 4, 4, 4])


 56%|█████▌    | 127/228 [03:00<01:43,  1.03s/it]

torch.Size([1, 768, 4, 4, 4])


 56%|█████▌    | 128/228 [03:02<02:16,  1.37s/it]

torch.Size([1, 768, 4, 4, 4])


 57%|█████▋    | 129/228 [03:02<01:42,  1.04s/it]

torch.Size([1, 768, 4, 4, 4])


 57%|█████▋    | 130/228 [03:05<02:16,  1.39s/it]

torch.Size([1, 768, 4, 4, 4])


 57%|█████▋    | 131/228 [03:05<01:45,  1.08s/it]

torch.Size([1, 768, 4, 4, 4])


 58%|█████▊    | 132/228 [03:08<02:31,  1.57s/it]

torch.Size([1, 768, 4, 4, 4])


 58%|█████▊    | 133/228 [03:08<01:59,  1.26s/it]

torch.Size([1, 768, 4, 4, 4])


 59%|█████▉    | 134/228 [03:11<02:51,  1.82s/it]

torch.Size([1, 768, 4, 4, 4])


 59%|█████▉    | 135/228 [03:12<02:15,  1.46s/it]

torch.Size([1, 768, 4, 4, 4])


 60%|█████▉    | 136/228 [03:14<02:43,  1.78s/it]

torch.Size([1, 768, 4, 4, 4])


 60%|██████    | 137/228 [03:15<02:17,  1.51s/it]

torch.Size([1, 768, 4, 4, 4])


 61%|██████    | 138/228 [03:18<02:42,  1.80s/it]

torch.Size([1, 768, 4, 4, 4])


 61%|██████    | 139/228 [03:19<02:13,  1.50s/it]

torch.Size([1, 768, 4, 4, 4])


 61%|██████▏   | 140/228 [03:21<02:40,  1.83s/it]

torch.Size([1, 768, 4, 4, 4])


 62%|██████▏   | 141/228 [03:22<02:17,  1.58s/it]

torch.Size([1, 768, 4, 4, 4])


 62%|██████▏   | 142/228 [03:25<02:39,  1.85s/it]

torch.Size([1, 768, 4, 4, 4])


 63%|██████▎   | 143/228 [03:26<02:21,  1.67s/it]

torch.Size([1, 768, 4, 4, 4])


 63%|██████▎   | 144/228 [03:28<02:33,  1.82s/it]

torch.Size([1, 768, 4, 4, 4])


 64%|██████▎   | 145/228 [03:30<02:36,  1.88s/it]

torch.Size([1, 768, 4, 4, 4])


 64%|██████▍   | 146/228 [03:32<02:31,  1.85s/it]

torch.Size([1, 768, 4, 4, 4])


 64%|██████▍   | 147/228 [03:34<02:39,  1.98s/it]

torch.Size([1, 768, 4, 4, 4])


 65%|██████▍   | 148/228 [03:35<02:20,  1.76s/it]

torch.Size([1, 768, 4, 4, 4])


 65%|██████▌   | 149/228 [03:38<02:39,  2.01s/it]

torch.Size([1, 768, 4, 4, 4])


 66%|██████▌   | 150/228 [03:39<02:18,  1.78s/it]

torch.Size([1, 768, 4, 4, 4])


 66%|██████▌   | 151/228 [03:42<02:37,  2.05s/it]

torch.Size([1, 768, 4, 4, 4])


 67%|██████▋   | 152/228 [03:43<02:15,  1.79s/it]

torch.Size([1, 768, 4, 4, 4])


 67%|██████▋   | 153/228 [03:45<02:24,  1.93s/it]

torch.Size([1, 768, 4, 4, 4])


 68%|██████▊   | 154/228 [03:46<02:02,  1.65s/it]

torch.Size([1, 768, 4, 4, 4])


 68%|██████▊   | 155/228 [03:49<02:32,  2.09s/it]

torch.Size([1, 768, 4, 4, 4])


 68%|██████▊   | 156/228 [03:50<01:51,  1.55s/it]

torch.Size([1, 768, 4, 4, 4])


 69%|██████▉   | 157/228 [03:53<02:34,  2.17s/it]

torch.Size([1, 768, 4, 4, 4])


 69%|██████▉   | 158/228 [03:54<01:54,  1.64s/it]

torch.Size([1, 768, 4, 4, 4])


 70%|██████▉   | 159/228 [03:57<02:28,  2.15s/it]

torch.Size([1, 768, 4, 4, 4])


 70%|███████   | 160/228 [03:57<01:48,  1.59s/it]

torch.Size([1, 768, 4, 4, 4])


 71%|███████   | 161/228 [04:01<02:18,  2.06s/it]

torch.Size([1, 768, 4, 4, 4])


 71%|███████   | 162/228 [04:01<01:42,  1.56s/it]

torch.Size([1, 768, 4, 4, 4])


 71%|███████▏  | 163/228 [04:04<02:02,  1.88s/it]

torch.Size([1, 768, 4, 4, 4])


 72%|███████▏  | 164/228 [04:04<01:35,  1.50s/it]

torch.Size([1, 768, 4, 4, 4])


 72%|███████▏  | 165/228 [04:08<02:12,  2.11s/it]

torch.Size([1, 768, 4, 4, 4])


 73%|███████▎  | 166/228 [04:08<01:39,  1.60s/it]

torch.Size([1, 768, 4, 4, 4])


 73%|███████▎  | 167/228 [04:11<02:03,  2.02s/it]

torch.Size([1, 768, 4, 4, 4])


 74%|███████▎  | 168/228 [04:12<01:37,  1.63s/it]

torch.Size([1, 768, 4, 4, 4])


 74%|███████▍  | 169/228 [04:14<01:47,  1.82s/it]

torch.Size([1, 768, 4, 4, 4])


 75%|███████▍  | 170/228 [04:15<01:33,  1.62s/it]

torch.Size([1, 768, 4, 4, 4])


 75%|███████▌  | 171/228 [04:17<01:42,  1.79s/it]

torch.Size([1, 768, 4, 4, 4])


 75%|███████▌  | 172/228 [04:19<01:34,  1.69s/it]

torch.Size([1, 768, 4, 4, 4])


 76%|███████▌  | 173/228 [04:22<01:48,  1.97s/it]

torch.Size([1, 768, 4, 4, 4])


 76%|███████▋  | 174/228 [04:23<01:44,  1.94s/it]

torch.Size([1, 768, 4, 4, 4])


 77%|███████▋  | 175/228 [04:26<01:57,  2.22s/it]

torch.Size([1, 768, 4, 4, 4])


 77%|███████▋  | 176/228 [04:27<01:37,  1.87s/it]

torch.Size([1, 768, 4, 4, 4])


 78%|███████▊  | 177/228 [04:30<01:51,  2.18s/it]

torch.Size([1, 768, 4, 4, 4])


 78%|███████▊  | 178/228 [04:31<01:23,  1.66s/it]

torch.Size([1, 768, 4, 4, 4])


 79%|███████▊  | 179/228 [04:34<01:42,  2.09s/it]

torch.Size([1, 768, 4, 4, 4])


 79%|███████▉  | 180/228 [04:34<01:14,  1.56s/it]

torch.Size([1, 768, 4, 4, 4])


 79%|███████▉  | 181/228 [04:37<01:34,  2.01s/it]

torch.Size([1, 768, 4, 4, 4])


 80%|███████▉  | 182/228 [04:37<01:08,  1.50s/it]

torch.Size([1, 768, 4, 4, 4])


 80%|████████  | 183/228 [04:41<01:28,  1.97s/it]

torch.Size([1, 768, 4, 4, 4])


 81%|████████  | 184/228 [04:41<01:08,  1.56s/it]

torch.Size([1, 768, 4, 4, 4])


 81%|████████  | 185/228 [04:44<01:18,  1.83s/it]

torch.Size([1, 768, 4, 4, 4])


 82%|████████▏ | 186/228 [04:44<01:03,  1.52s/it]

torch.Size([1, 768, 4, 4, 4])


 82%|████████▏ | 187/228 [04:47<01:11,  1.74s/it]

torch.Size([1, 768, 4, 4, 4])


 82%|████████▏ | 188/228 [04:48<01:04,  1.62s/it]

torch.Size([1, 768, 4, 4, 4])


 83%|████████▎ | 189/228 [04:50<01:06,  1.69s/it]

torch.Size([1, 768, 4, 4, 4])


 83%|████████▎ | 190/228 [04:51<00:58,  1.55s/it]

torch.Size([1, 768, 4, 4, 4])


 84%|████████▍ | 191/228 [04:52<00:55,  1.51s/it]

torch.Size([1, 768, 4, 4, 4])


 84%|████████▍ | 192/228 [04:53<00:46,  1.30s/it]

torch.Size([1, 768, 4, 4, 4])


 85%|████████▍ | 193/228 [04:55<00:51,  1.48s/it]

torch.Size([1, 768, 4, 4, 4])


 85%|████████▌ | 194/228 [04:56<00:39,  1.17s/it]

torch.Size([1, 768, 4, 4, 4])


 86%|████████▌ | 195/228 [04:58<00:47,  1.45s/it]

torch.Size([1, 768, 4, 4, 4])


 86%|████████▌ | 196/228 [04:58<00:35,  1.10s/it]

torch.Size([1, 768, 4, 4, 4])


 86%|████████▋ | 197/228 [05:00<00:39,  1.28s/it]

torch.Size([1, 768, 4, 4, 4])


 87%|████████▋ | 198/228 [05:00<00:30,  1.02s/it]

torch.Size([1, 768, 4, 4, 4])


 87%|████████▋ | 199/228 [05:04<00:51,  1.77s/it]

torch.Size([1, 768, 4, 4, 4])


 88%|████████▊ | 200/228 [05:04<00:36,  1.31s/it]

torch.Size([1, 768, 4, 4, 4])


 88%|████████▊ | 201/228 [05:06<00:38,  1.44s/it]

torch.Size([1, 768, 4, 4, 4])


 89%|████████▊ | 202/228 [05:06<00:28,  1.09s/it]

torch.Size([1, 768, 4, 4, 4])


 89%|████████▉ | 203/228 [05:08<00:33,  1.33s/it]

torch.Size([1, 768, 4, 4, 4])


 89%|████████▉ | 204/228 [05:08<00:24,  1.01s/it]

torch.Size([1, 768, 4, 4, 4])


 90%|████████▉ | 205/228 [05:10<00:26,  1.17s/it]

torch.Size([1, 768, 4, 4, 4])


 90%|█████████ | 206/228 [05:10<00:19,  1.11it/s]

torch.Size([1, 768, 4, 4, 4])


 91%|█████████ | 207/228 [05:11<00:23,  1.10s/it]

torch.Size([1, 768, 4, 4, 4])


 91%|█████████ | 208/228 [05:12<00:17,  1.17it/s]

torch.Size([1, 768, 4, 4, 4])


 92%|█████████▏| 209/228 [05:14<00:24,  1.30s/it]

torch.Size([1, 768, 4, 4, 4])


 92%|█████████▏| 210/228 [05:14<00:17,  1.01it/s]

torch.Size([1, 768, 4, 4, 4])


 93%|█████████▎| 211/228 [05:16<00:21,  1.29s/it]

torch.Size([1, 768, 4, 4, 4])


 93%|█████████▎| 212/228 [05:17<00:15,  1.02it/s]

torch.Size([1, 768, 4, 4, 4])


 93%|█████████▎| 213/228 [05:18<00:17,  1.20s/it]

torch.Size([1, 768, 4, 4, 4])


 94%|█████████▍| 214/228 [05:19<00:12,  1.09it/s]

torch.Size([1, 768, 4, 4, 4])


 94%|█████████▍| 215/228 [05:20<00:14,  1.14s/it]

torch.Size([1, 768, 4, 4, 4])


 95%|█████████▍| 216/228 [05:20<00:10,  1.14it/s]

torch.Size([1, 768, 4, 4, 4])


 95%|█████████▌| 217/228 [05:22<00:12,  1.15s/it]

torch.Size([1, 768, 4, 4, 4])


 96%|█████████▌| 218/228 [05:22<00:08,  1.13it/s]

torch.Size([1, 768, 4, 4, 4])


 96%|█████████▌| 219/228 [05:24<00:10,  1.13s/it]

torch.Size([1, 768, 4, 4, 4])


 96%|█████████▋| 220/228 [05:24<00:06,  1.14it/s]

torch.Size([1, 768, 4, 4, 4])


 97%|█████████▋| 221/228 [05:26<00:06,  1.03it/s]

torch.Size([1, 768, 4, 4, 4])


 97%|█████████▋| 222/228 [05:26<00:04,  1.32it/s]

torch.Size([1, 768, 4, 4, 4])


 98%|█████████▊| 223/228 [05:28<00:05,  1.03s/it]

torch.Size([1, 768, 4, 4, 4])


 98%|█████████▊| 224/228 [05:28<00:03,  1.25it/s]

torch.Size([1, 768, 4, 4, 4])


 99%|█████████▊| 225/228 [05:29<00:03,  1.04s/it]

torch.Size([1, 768, 4, 4, 4])


 99%|█████████▉| 226/228 [05:30<00:01,  1.24it/s]

torch.Size([1, 768, 4, 4, 4])


100%|█████████▉| 227/228 [05:31<00:00,  1.21it/s]

torch.Size([1, 768, 4, 4, 4])


100%|██████████| 228/228 [05:31<00:00,  1.45s/it]

torch.Size([1, 768, 4, 4, 4])
torch.Size([1, 768, 4, 4, 4])





In [1]:
import torch

mean_feature = torch.load("/mnt/disk1/hjlee/orhun/repo/thesis/mean_features_1536_4_4_4.pt")

In [3]:
mean_feature.shape

torch.Size([1, 1536, 4, 4, 4])

In [11]:

import nibabel as nib

# Load the .nii.gz file
nii_image = nib.load("/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_2013_0_1.nii.gz")

# Get the data as a NumPy array
image_data = nii_image.get_fdata()
image_data.shape

(4, 192, 192, 192)

In [2]:
import os
import nibabel as nib
import numpy as np
# Set the directory path
directory = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images"


# Initialize an empty list to store shapes
shapes = []

# Iterate through all files in the directory
for root, _, files in os.walk(directory):
    for file in files:
        if file.endswith(".nii.gz"):
            file_path = os.path.join(root, file)
            nii_img = nib.load(file_path)  # Load the NIfTI file
            shape = nii_img.shape
            shapes.append(shape)  # Store shape
            print(f"File: {file_path}, Shape: {shape}")

# Compute maximum size for each dimension
if shapes:
    max_shape = np.max(np.array([list(shape) for shape in shapes]), axis=0)
    print(f"\nMaximum size for each dimension: {max_shape}")
else:
    print("No NIfTI files found in the directory.")

File: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images/Brats18_TCIA08_105_1.nii.gz, Shape: (133, 159, 136, 4)
File: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images/Brats18_CBICA_AAG_1.nii.gz, Shape: (133, 186, 139, 4)
File: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images/Brats18_TCIA13_642_1.nii.gz, Shape: (136, 182, 140, 4)
File: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images/Brats18_TCIA10_640_1.nii.gz, Shape: (143, 175, 132, 4)
File: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images/Brats18_CBICA_ANG_1.nii.gz, Shape: (132, 175, 139, 4)
File: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images/Brats18_TCIA06_211_1.nii.gz, Shape: (129, 169, 134, 4)
File: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images/Brats18_CBICA_AQT_1.nii.gz, Shape: (141, 161, 134, 4)
File: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/Images/Brats18_CB

In [9]:
from monai.transforms import (Lambda,
                               Compose, EnsureChannelFirst,
                                 RandSpatialCrop, RandRotate90, 
                                 NormalizeIntensity, RandAdjustContrast,
                                   RandZoom, RandFlip, RandGaussianNoise,
                                     RandGaussianSmooth, RandAdjustContrast,
                                     ConvertToMultiChannelBasedOnBratsClasses,
                                     RandScaleIntensity, RandShiftIntensity,
                                     SpatialCrop, SpatialPad
                                )
class ConvertToMultiChannelBasedOnBratsClassesCustom(ConvertToMultiChannelBasedOnBratsClasses):

    def __call__(self, img):
        if img.ndim == 4 and img.shape[0] == 1:
            img = img.squeeze(0)
        #result = [ img == 3, (img == 1) | (img == 3), (img == 1) | (img == 3) | (img == 2)]
        result = [(img == 1) | (img == 3), (img == 1) | (img == 3) | (img == 2), img == 3]
        
        # order: ET, TC, WT
        return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)



import torch
import os
import glob
import nibabel as nib
import numpy as np
from monai.transforms import SpatialPad

# Define input and output directories
input_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_C/Training/"  # Adjust this
output_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N"

# Ensure output directories exist
os.makedirs(os.path.join(output_dir, "Images"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "Labels"), exist_ok=True)

# Get all image and label files
image_files = glob.glob(os.path.join(input_dir, "Images", "*.nii.gz"))
label_files = glob.glob(os.path.join(input_dir, "Labels", "*.nii.gz"))

# Define the target shape
target_shape = (192, 192, 192)

# Padding transform
pad_transform = SpatialPad(spatial_size=target_shape, method="symmetric")
normalize= NormalizeIntensity(nonzero=True,channel_wise=True)
change_label= ConvertToMultiChannelBasedOnBratsClassesCustom()


def process_and_save_nii_images(file_path, output_folder):
    """Load, transform, and save a NIfTI file with proper shape and padding."""
    # Load NIfTI image
    img_nib = nib.load(file_path)
    img_data = img_nib.get_fdata()  # Get numpy array

    # Ensure modality channel is first: (X, Y, Z, 4) -> (4, X, Y, Z)
    img_data = np.moveaxis(img_data, -1, 0)  # Move last axis to first

    # Apply padding
    norm_img = normalize(img_data)
    padded_img = pad_transform(norm_img)
    
    # Convert back to NIfTI image
    new_nib = nib.Nifti1Image(padded_img, img_nib.affine, img_nib.header)

    # Save file
    output_path = os.path.join(output_folder, os.path.basename(file_path))
    nib.save(new_nib, output_path)
    print(f"Saved: {output_path}")

def process_and_save_nii_labels(file_path, output_folder):
    """Load, transform, and save a NIfTI file with proper shape and padding."""
    # Load NIfTI image
    img_nib = nib.load(file_path)
    img_data = img_nib.get_fdata()  # Get numpy array

    # Ensure modality channel is first: (X, Y, Z, 4) -> (4, X, Y, Z)
    #img_data = np.moveaxis(img_data, -1, 0)  # Move last axis to first
    img_data = img_data[None, ...]  # Shape becomes (1, X, Y, Z)
    # Apply padding
    norm_img = normalize(img_data)
    padded_img = pad_transform(norm_img)
    final = change_label(padded_img)
    # Convert back to NIfTI image
    new_nib = nib.Nifti1Image(final, img_nib.affine, img_nib.header)

    # Save file
    output_path = os.path.join(output_folder, os.path.basename(file_path))
    nib.save(new_nib, output_path)
    print(f"Saved: {output_path}")


# Process Images
for file in image_files:
    process_and_save_nii_images(file, os.path.join(output_dir, "Images"))

# Process Labels
for file in label_files:
    process_and_save_nii_labels(file, os.path.join(output_dir, "Labels"))

Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_TCIA08_105_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_CBICA_AAG_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_TCIA13_642_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_TCIA10_640_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_CBICA_ANG_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_TCIA06_211_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_CBICA_AQT_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_CBICA_AUN_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_TCIA09_254_1.nii.gz
Saved: /data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N/Images/Brats18_TCIA05_478_1.nii.gz
Saved: /data/hjlee/orhun

In [2]:
import os
import glob
import nibabel as nib
import numpy as np
from monai.transforms import SpatialPad

# Define input and output directories
input_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/"  # Adjust this
output_dir = "/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed_P_N"

# Ensure output directories exist
os.makedirs(os.path.join(output_dir, "Images"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "Labels"), exist_ok=True)

# Get all image and label files
image_files = glob.glob(os.path.join(input_dir, "Images", "*.nii.gz"))
label_files = glob.glob(os.path.join(input_dir, "Labels", "*.nii.gz"))

print(label_files)

['/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_2013_8_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_CBICA_ASG_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_TCIA10_393_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_TCIA02_179_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_CBICA_AAL_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_CBICA_AZD_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_CBICA_AOZ_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_TCIA10_241_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/Labels/Brats18_TCIA02_608_1_label.nii.gz', '/data/hjlee/orhun/data/BRATS18/BRATS18_Preprocessed/Training/

In [1]:
import itertools
from monai.metrics import DiceMetric

# Step 1: Generate all 4-bit binary combinations excluding 0000 and 1111
binary_combinations = [''.join(map(str, bits)) for bits in itertools.product([0, 1], repeat=4)]
binary_combinations = [comb for comb in binary_combinations if comb != '0000' and comb != '1111']

# Step 2: Create a dictionary with binary combinations as keys and Dice metrics as values
dice_metric = DiceMetric(include_background=True, reduction="mean_batch", get_not_nans=False)
dice_dict = {comb: dice_metric for comb in binary_combinations}

In [2]:
print(dice_dict)

{'0001': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '0010': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '0011': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '0100': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '0101': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '0110': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '0111': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '1000': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '1001': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '1010': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '1011': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '1100': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '1101': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>, '1110': <monai.metrics.meandice.DiceMetric object at 0x74afd015e890>}


In [6]:
binary_key = ['0'] * 4
print(binary_key)
print(str(binary_key))
print("".join(binary_key))

['0', '0', '0', '0']
['0', '0', '0', '0']
0000


In [3]:
from monai.metrics import DiceMetric
import itertools

binary_combinations = [''.join(map(str, bits)) for bits in itertools.product([0, 1], repeat=4)]
binary_combinations = [comb for comb in binary_combinations if comb != '0000' and comb != '1111']

dice_dict_m = {comb: DiceMetric(include_background=True, reduction="mean_batch", get_not_nans=False) for comb in binary_combinations}

binary_key = ['0'] * 4
index_list= [0]
for index in index_list:
    binary_key[index] = '1'
    binary_key = ''.join(binary_key) 

In [6]:
for sc_key in dice_dict_m:
    print(sc_key)
    if sc_key == "".join(binary_key):
        print("True")
    else:
        print("False")

0001
False
0010
False
0011
False
0100
False
0101
False
0110
False
0111
False
1000
True
1001
False
1010
False
1011
False
1100
False
1101
False
1110
False


In [None]:
import itertools
import random

# Define the elements
elements = {0, 1, 2, 3}

# Generate all valid subsets (excluding empty and full set)
valid_subsets = [list(subset) for i in range(1, len(elements))
                 for subset in itertools.combinations(elements, i)]

# Pick one random subset
random_scenario = random.choice(valid_subsets)

print(random_scenario)





[1, 2]


In [4]:
import itertools
for r in range(1, 2):       
    for channels in itertools.combinations(range(1), r):
        print(channels)

(0,)
