In [1]:
import random

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from monai.losses import DiceCELoss
from torch.optim import Adam
# from dataset import VSDataset
from model import DynUNet

  check_for_updates()


In [None]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from utils.utils import load_nifti_as_dhw, dicom_load

class VSDataset(Dataset):
    """Paired image volumes and segmentation masks listed in a CSV.

    Each CSV row supplies an ``image_path`` (read with ``dicom_load``) and a
    ``SegmentationPath`` (read with ``load_nifti_as_dhw``), both joined onto
    ``data_dir``. A sample is returned as ``(image, mask)``, each a tensor of
    shape ``(1, D, H, W)``; the mask is float. (Assumes the loaders return
    (D, H, W) arrays, as ``load_nifti_as_dhw``'s name suggests — TODO confirm.)

    Args:
        csv_path: CSV file with ``image_path`` and ``SegmentationPath`` columns.
        data_dir: Directory the CSV paths are joined onto.
        transform: Optional albumentations-style callable applied to the whole
            volume with the slices as channels. It must return tensors (e.g. end
            with ``ToTensorV2``). ``None`` skips augmentation.
        target_slices: If truthy, the depth D is forced to this value: copies of
            slices whose mask contains both 0 and 1 are appended, or slices whose
            mask is entirely 0 are dropped (lowest index first).

    Raises (from ``__getitem__``):
        ValueError: if D must grow but no slice contains both labels, or must
            shrink but there are not enough all-zero slices to drop.
    """

    def __init__(self, csv_path, data_dir, transform=None, target_slices=None):
        self.data = pd.read_csv(csv_path)
        self.data_dir = data_dir
        self.transform = transform
        self.target_slices = target_slices
        self.image_filenames = self.data['image_path'].tolist()
        self.mask_filenames = self.data['SegmentationPath'].tolist()

    def transform_volume(self, image_volume, mask_volume):
        """Augment one volume and return ``(image, mask)`` tensors shaped (D, H, W).

        Albumentations works channel-last, so both volumes are moved from
        (D, H, W) to (H, W, D). The D slices become channels and therefore all
        receive the same random spatial transform, which keeps image and mask
        aligned.

        NOTE(review): OpenCV kernels reject arrays with more than 512 channels
        (CV_CN_MAX). The ``cv::flip`` assertion in this notebook's training
        output is consistent with that, so very deep volumes may fail here.
        """
        if self.transform is None:
            # No augmentation requested: convert straight to float tensors.
            images = torch.from_numpy(np.ascontiguousarray(image_volume, dtype=np.float32))
            masks = torch.from_numpy(np.ascontiguousarray(mask_volume, dtype=np.float32))
            return images, masks

        transformed = self.transform(
            image=image_volume.transpose(1, 2, 0),
            mask=mask_volume.transpose(1, 2, 0),
        )
        images = transformed['image']
        # ToTensorV2 returns the image as (C, H, W) = (D, H, W). By default it
        # leaves the mask channel-last, so move the mask's slice axis back to the front.
        masks = transformed['mask'].permute(2, 0, 1)
        return images, masks.float()

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.data_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.data_dir, self.mask_filenames[idx])

        # The mask is loaded first because its shape is handed to dicom_load
        # (which logs it as the target shape for the image stack).
        mask = load_nifti_as_dhw(mask_path)
        image = dicom_load(image_path, mask.shape)

        image_volume, mask_volume = self.transform_volume(image, mask)

        # Add a channel axis: (D, H, W) -> (1, D, H, W).
        image_tensor = image_volume.unsqueeze(0)
        mask_tensor = mask_volume.unsqueeze(0)

        if self.target_slices:
            image_tensor, mask_tensor = self._fit_depth(image_tensor, mask_tensor)

        return image_tensor, mask_tensor

    def _fit_depth(self, image_tensor, mask_tensor):
        """Resize the depth axis (dim 1) of both tensors to ``self.target_slices``.

        The result is gathered with a single ``index_select`` on an index list,
        instead of growing the tensors one slice at a time.

        NOTE(review): duplicated slices are appended after the last slice, so the
        padded volume is no longer spatially contiguous along D.
        """
        current_slices = image_tensor.shape[1]
        target = self.target_slices
        # One row per slice of the single mask channel, flattened over (H, W).
        mask_rows = mask_tensor[0].reshape(current_slices, -1)

        if current_slices < target:
            has_background = (mask_rows == 0).any(dim=1)
            has_foreground = (mask_rows == 1).any(dim=1)
            labeled_slices = torch.nonzero(has_background & has_foreground).flatten().tolist()
            if len(labeled_slices) == 0:
                raise ValueError("No labeled slices (with both 0 and 1) found. Cannot duplicate.")
            # Cycle through the labelled slices until the target depth is reached.
            extra = [labeled_slices[k % len(labeled_slices)] for k in range(target - current_slices)]
            keep = list(range(current_slices)) + extra
        elif current_slices > target:
            unlabeled_slices = torch.nonzero((mask_rows == 0).all(dim=1)).flatten().tolist()
            if len(unlabeled_slices) == 0:
                raise ValueError("No unlabeled slices found for removal.")
            # Drop the lowest-indexed empty slices, as many as are needed (or available).
            to_drop = set(unlabeled_slices[:current_slices - target])
            keep = [i for i in range(current_slices) if i not in to_drop]
        else:
            return image_tensor, mask_tensor

        index = torch.tensor(keep, dtype=torch.long, device=image_tensor.device)
        image_tensor = image_tensor.index_select(1, index)
        mask_tensor = mask_tensor.index_select(1, index)

        # Only reachable when too few empty slices existed to drop.
        if image_tensor.shape[1] != target:
            raise ValueError(f"After removal, slice count {image_tensor.shape[1]} not equal to target {self.target_slices}.")
        return image_tensor, mask_tensor

In [11]:
# Training / data-loading configuration.
batch_size = 1
num_workers = 0
pin_memory = False
LEARNING_RATE = 1e-4
num_epochs = 4
image_size = 128  # every slice is resized to image_size x image_size

# Seed the Python, NumPy and PyTorch RNGs so weight init, DataLoader shuffling
# and augmentation draws are repeatable across Restart & Run All.
# NOTE(review): newer albumentations releases keep their own RNG
# (A.Compose(seed=...)); confirm whether the installed version honours these seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [12]:
# Volume-level augmentation. VSDataset feeds each volume with its slices as
# channels, so every op below hits all slices of image and mask identically.
_augmentations = [
    A.Resize(image_size, image_size),  # in-plane resize to image_size x image_size
    A.Rotate(limit=35, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    # NOTE(review): max_pixel_value=255 assumes 8-bit intensities; MRI volumes are
    # usually on a larger scale — confirm the range dicom_load returns.
    A.Normalize(mean=0.0, std=1.0, max_pixel_value=255.0),
    ToTensorV2(),
]
slice_transform = A.Compose(_augmentations)

In [13]:
# Dataset locations. Set VS_DATA_DIR / VS_PATHS_CSV in the environment to point
# at another machine's copy instead of editing these machine-specific defaults.
# (The path-cleaning cell at the end of this notebook writes
# D:\VSdata\vs_paths_cleaned.csv, an alternative value for VS_PATHS_CSV.)
data_dir = os.environ.get("VS_DATA_DIR", r"D:\VSdata")
csv_path = os.environ.get("VS_PATHS_CSV", r"C:\Users\Acer\Desktop\vs_paths.csv")

In [14]:
# Full training set: augmented by slice_transform, every volume forced to 128 slices.
dataset = VSDataset(csv_path=csv_path, data_dir=data_dir, transform=slice_transform, target_slices=128)

In [15]:
# Pull one sample and check it has the (1, D, H, W) layout the 3D model expects,
# with D forced to target_slices and H = W = image_size by the resize.
x, y = dataset[3]
expected_shape = (1, dataset.target_slices, image_size, image_size)
assert x.shape == y.shape == expected_shape, (x.shape, y.shape)
x.shape, y.shape

(torch.Size([1, 128, 128, 128]), torch.Size([1, 128, 128, 128]))

In [16]:
# Shuffled loader over whole volumes; sizing and worker options come from the config cell.
train_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [9]:
# 3D DynUNet with six resolution levels: a single-channel input volume and a
# single-channel output (binary segmentation). The first level keeps full
# resolution and each later level downsamples by 2.
model = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=[3] * 6,
    strides=[1] + [2] * 5,
    upsample_kernel_size=[2] * 5,
    res_block=True,
)

In [10]:
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), lr=1e-4)

In [11]:
device= torch.device('cpu')

In [12]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # optimizer.zero_grad()

        # outputs = model(inputs)
        # outputs = torch.sigmoid(outputs)
        # loss = loss_function(outputs, targets)
        # loss.backward()
        # optimizer.step()

        # running_loss += loss.item()
        
        # if batch_idx % 10 == 0:
        #     print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item()}")

    # avg_loss = running_loss / len(train_loader)
    # print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss}")

print("Training complete!")

Reading image: D:\VSdata\Vestibular-Schwannoma-MC-RC\VS-MC-RC-048\01-27-1995-NA-t4of5  MRI Head-13489\15.000000-T2cor iams-56664
Reading mask: D:\VSdata\VS-MC-RC segmentations-NIfTI-Release 2023/VS-MC-RC-048/1995-01-27/seg_T2.nii.gz
[INFO] image_stack shape already matches target shape.
[INFO] Final image_stack shape: (55, 640, 640)
(55, 640, 640) (55, 640, 640)
torch.Size([640, 128, 128]) torch.Size([128, 128, 640])
Reading image: D:\VSdata\Vestibular-Schwannoma-MC-RC\VS-MC-RC-015\08-15-1990-NA-t2of3  External Images for PACS-98115\301.000000-T2I.A.C.      T2 TSE-97536
Reading mask: D:\VSdata\VS-MC-RC segmentations-NIfTI-Release 2023/VS-MC-RC-015/1990-08-15/seg_T2.nii.gz
[INFO] image_stack shape already matches target shape.
[INFO] Final image_stack shape: (30, 256, 256)
(30, 256, 256) (30, 256, 256)
torch.Size([256, 128, 128]) torch.Size([128, 128, 256])
Reading image: D:\VSdata\Vestibular-Schwannoma-MC-RC\VS-MC-RC-052\09-26-1998-NA-t3of3  MRI IAMS-98521\5.000000-T2FIESTA-C-94297
Rea

error: OpenCV(4.11.0) D:\a\opencv-python\opencv-python\opencv\modules\core\src\matrix_transform.cpp:784: error: (-215:Assertion failed) _src.dims() <= 2 in function 'cv::flip'


In [None]:
# Inspect the raw (un-cleaned) path table for one patient. Separate names are
# used so this cell does not overwrite the `csv_path` the dataset was built from.
raw_csv_path = r'D:\VSdata\vs_paths.csv'
raw_paths_df = pd.read_csv(raw_csv_path)
raw_paths_df.loc[raw_paths_df['patient_id'] == 'VS-MC-RC-036']

Unnamed: 0,patient_id,series_instance_uid,study_date,SegmentationPath,image_path,imagenifti_path
182,VS-MC-RC-036,1.3.6.1.4.1.14519.5.2.1.2356822478735907152433...,07-21-1995,VS-MC-RC segmentations-NIfTI-Release 2023/VS-M...,manifest-1742405880893\Vestibular-Schwannoma-M...,image-nifti\VS-MC-RC-036\07-21-1995\image.nii.gz
183,VS-MC-RC-036,1.3.6.1.4.1.14519.5.2.1.1160169313945234564236...,07-09-1999,VS-MC-RC segmentations-NIfTI-Release 2023/VS-M...,manifest-1742405880893\Vestibular-Schwannoma-M...,image-nifti\VS-MC-RC-036\07-09-1999\image.nii.gz


In [None]:
from utils.utils import load_nifti_as_dhw
from utils.utils import dicom_load

# Spot-check one case (VS-MC-RC-028): dicom_load is given the mask's shape as
# its target, so the loaded image stack should come back with that same shape.
mask_path = r'D:\VSdata\VS-MC-RC segmentations-NIfTI-Release 2023/VS-MC-RC-028/1992-10-12/seg_T2.nii.gz'
dicom_dir = r'D:\VSdata\Vestibular-Schwannoma-MC-RC\VS-MC-RC-028\10-12-1992-NA-t1of7  External Images for PACS-81342\3.000000-T2t2 tra 3d-ciss-71390'

mask = load_nifti_as_dhw(mask_path)
image = dicom_load(dicom_dir, mask.shape)
assert image.shape == mask.shape, (image.shape, mask.shape)
image.shape, mask.shape
print(mask.shape)

[INFO] Transposing image_stack from (40, 512, 384) to (40, 384, 512) using order (0, 2, 1)
[INFO] Final image_stack shape: (40, 384, 512)
(40, 384, 512)
(40, 384, 512)


In [None]:
# Load the NIfTI conversion of the same case (VS-MC-RC-028, 10-12-1992) under its
# own name, leaving the DICOM-loaded `image` array intact.
nifti_image_path = r'D:\VSdata\image-nifti\VS-MC-RC-028\10-12-1992\image.nii.gz'
nifti_image = load_nifti_as_dhw(nifti_image_path)
nifti_image.shape, type(nifti_image)

(40, 384, 512)
<class 'numpy.ndarray'>


In [None]:
import pandas as pd

# Strip the download-manifest folder ('manifest-<digits>' followed by a path
# separator, '\' or '/') from the front of every image_path, and write the
# result to a new CSV next to the original. The original file is not modified,
# and the dataset's `csv_path` is left alone.
raw_csv_path = r'D:\VSdata\vs_paths.csv'
cleaned_csv_path = r'D:\VSdata\vs_paths_cleaned.csv'

paths_clean = pd.read_csv(raw_csv_path).assign(
    image_path=lambda frame: frame['image_path'].str.replace(r'^manifest-\d+[\\/]', '', regex=True)
)
paths_clean.to_csv(cleaned_csv_path, index=False)
