# Imports

In [7]:
import torch
import glob
from Algorithms.Unet3D.unet3D import UNet3D
from Algorithms.Unet3D.mha_loader import MHA_Dataset
from torch.utils.data import DataLoader
from Algorithms.utils import Get_biggest_target_shape
from Algorithms.trainer import UnetTrainer
from Datasets.SegmentationDataset import SegmentationDataset
import torchio as tio
from pathlib import Path

# Loading the ToothFairy2 dataset

In [None]:
# Relative locations of the ToothFairy2 training images and label maps.
toothFairy2_image_path = "Data/ToothFairy2/imagesTr"
toothFairy2_labels_path = "Data/ToothFairy2/labelsTr"

# Volume shape passed to MHA_Dataset; every dimension is divisible by 16.
target_shape = [304, 512, 512]

In [3]:
# Collect every .mha image volume and every .mha label map.
# glob.glob() returns matches in arbitrary, filesystem-dependent order, so both
# lists are sorted to keep the image at index i paired with the label at index i.
# NOTE(review): this relies on image and label file names sorting in the same
# order (e.g. a shared case-id prefix) — confirm against the dataset layout.
mha_image_files = sorted(glob.glob(f"{toothFairy2_image_path}/*.mha"))
mha_labels_files = sorted(glob.glob(f"{toothFairy2_labels_path}/*.mha"))
print(mha_image_files)
print(mha_labels_files)

['Data/ToothFairy2/imagesTr\\ToothFairy2F_001_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_002_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_003_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_004_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_005_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_006_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_007_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_008_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_009_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_010_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_011_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_012_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_013_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_014_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_015_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_016_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_018_0000.mha', 'Data/ToothFairy2/imagesTr\\ToothFairy2F_020_00

In [None]:
# Wrap the image/label file lists in the project's MHA dataset.
dataset = MHA_Dataset(mha_image_files,mha_labels_files,target_shape)

# 80/20 train/validation split sizes (the validation part absorbs the rounding remainder).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

print(train_size,val_size)

# The sizes above used to be computed and then never applied. Materialise the
# split so it can actually be used; the fixed seed makes the partition
# reproducible across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
train_subset, val_subset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Loader over the FULL dataset, kept unchanged for the shape sanity check that follows.
# NOTE(review): for real training, build loaders from train_subset / val_subset instead.
dataLoader = DataLoader(dataset,batch_size=2,shuffle=True)

In [None]:
# Ask the project helper for the biggest target shape across all ToothFairy2
# image files, then show it.
target_size = Get_biggest_target_shape(mha_image_files)

print(target_size)

In [10]:
# Pull a single batch from the loader and report its tensor shapes.
# An empty loader prints nothing, exactly like a loop that never runs.
_first_batch = next(iter(dataLoader), None)
if _first_batch is not None:
    images, labels = _first_batch
    print(f"Image size: {images.shape}, label size: {labels.shape}")

Image size: torch.Size([2, 1, 298, 512, 512]), label size: torch.Size([2, 298, 512, 512])


## New DataLoader test

In [3]:
# Three ChinaCBCT volumes used to smoke-test SegmentationDataset.
images = ["Data/ChinaCBCT/img/1000813648_20180116.nii.gz",
          "Data/ChinaCBCT/img/1000889125_20171009.nii.gz",
          "Data/ChinaCBCT/img/1000889125_20171016.nii.gz"
          ]

# Label maps for the same three scans, in the same order as `images`.
# The third entry previously contained a stray double slash ("Data//ChinaCBCT");
# it is normalised here so every path follows the same pattern.
masks = ["Data/ChinaCBCT/label/1000813648_20180116.nii.gz",
         "Data/ChinaCBCT/label/1000889125_20171009.nii.gz",
         "Data/ChinaCBCT/label/1000889125_20171016.nii.gz"
        ]

# Shape requested from the dataset for every sample.
target_shape_val = (400,400,280)

# Random flip along any of the three spatial axes, each with probability 0.5.
transform_val = tio.RandomFlip(axes=(0, 1, 2), flip_probability=0.5)

In [4]:
# Build the segmentation dataset from the three test scans and serve it one
# sample at a time, in list order (no shuffling).
dataset = SegmentationDataset(
    images,
    masks,
    target_shape=target_shape_val,
    transform=transform_val,
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [5]:
# Fetch one batch from the loader and inspect the tensor shapes.
batch_iter = iter(loader)
image_tensor, mask_tensor = next(batch_iter)

# Both tensors are expected to be [1, 1, D, H, W].
print("Obraz shape:", image_tensor.shape)
print("Maska shape:", mask_tensor.shape)

Obraz shape: torch.Size([1, 1, 400, 400, 280])
Maska shape: torch.Size([1, 1, 400, 400, 280])


## Przetestowanie wytrenowania 3 zdjęć na Unet3D


In [10]:
# List every ChinaCBCT image and label volume as a plain string path.
# Sorting both lists keeps images and masks in a deterministic order.
image_paths = sorted(map(str, Path('Data/ChinaCBCT/img').glob('*.nii.gz')))
masks_paths = sorted(map(str, Path('Data/ChinaCBCT/label').glob('*.nii.gz')))

print(image_paths)
print(masks_paths)

['Data\\ChinaCBCT\\img\\1000813648_20180116.nii.gz', 'Data\\ChinaCBCT\\img\\1000889125_20171009.nii.gz', 'Data\\ChinaCBCT\\img\\1000889125_20171016.nii.gz', 'Data\\ChinaCBCT\\img\\1000889125_20180109.nii.gz', 'Data\\ChinaCBCT\\img\\1000889125_20180521.nii.gz', 'Data\\ChinaCBCT\\img\\1000889125_20181106.nii.gz', 'Data\\ChinaCBCT\\img\\1000889125_20190408.nii.gz', 'Data\\ChinaCBCT\\img\\1000889125_20191101.nii.gz', 'Data\\ChinaCBCT\\img\\1000889125_20200421.nii.gz', 'Data\\ChinaCBCT\\img\\1000915187_20180115.nii.gz', 'Data\\ChinaCBCT\\img\\1000915187_20191217.nii.gz', 'Data\\ChinaCBCT\\img\\1000966359_20180113.nii.gz', 'Data\\ChinaCBCT\\img\\1000971031_20180112.nii.gz', 'Data\\ChinaCBCT\\img\\1000983254_20180109.nii.gz', 'Data\\ChinaCBCT\\img\\1000983254_20180904.nii.gz', 'Data\\ChinaCBCT\\img\\1000995722_20180112.nii.gz', 'Data\\ChinaCBCT\\img\\1001009635_20180116.nii.gz', 'Data\\ChinaCBCT\\img\\1001012179_20180110.nii.gz', 'Data\\ChinaCBCT\\img\\1001012179_20180116.nii.gz', 'Data\\Chin

In [11]:
# Training data / validation split: keep only the first 10 scans.
image_paths = image_paths[:10]
masks_paths = masks_paths[:10]

# The first 8 scans are used for training ...
train_image_paths = image_paths[:8]
train_mask_paths = masks_paths[:8]

# ... and the remaining ones for validation.
val_image_paths = image_paths[8:]
val_mask_paths = masks_paths[8:]

## Network parameters
batch_size = 1
# Backward-compatible alias: the name was originally misspelled as `batch_siez`,
# and other cells may still refer to it.
batch_siez = batch_size
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs = 40
lr = 1e-4 # learning rate
target_shape_val = (400,400,280)


In [12]:
# Augmentations applied to training samples, in this order:
# random flips along any spatial axis, a small random affine, then additive noise.
random_flip = tio.RandomFlip(axes=(0, 1, 2))
random_affine = tio.RandomAffine(scales=(0.9, 1.1), degrees=10, translation=5)
random_noise = tio.RandomNoise(std=0.01)

train_transform = tio.Compose([random_flip, random_affine, random_noise])

In [14]:
# Training split: samples go through the augmentation pipeline.
train_dataset = SegmentationDataset(
    train_image_paths,
    train_mask_paths,
    target_shape=target_shape_val,
    transform=train_transform,
)

# Validation split: same target shape, no transform.
val_dataset = SegmentationDataset(
    val_image_paths,
    val_mask_paths,
    target_shape=target_shape_val,
    transform=None,
)

In [None]:
# 3D U-Net with one input channel and 33 output channels.
# NOTE(review): `model` is rebound to a 1-output-channel UNet3D in a later cell,
# which replaces this instance — confirm which configuration is intended.
model = UNet3D(
    in_channels=1,
    out_channels=33,
)

# 3D Unet

In [13]:
# 3D U-Net with one input channel, one output channel and base_channels=32.
model = UNet3D(
    in_channels=1,
    out_channels=1,
    base_channels=32,
)

In [14]:
# Move the model to the device chosen in the network-parameters cell.
# `device` falls back to 'cpu' when CUDA is unavailable, whereas the previous
# hard-coded model.cuda() would fail on a machine without a GPU.
model = model.to(device)

In [None]:
# Build the trainer for the model on the training dataset.
# The call previously passed the undefined name `tra` (NameError at run time);
# `train_dataset` is the training SegmentationDataset created earlier in this notebook.
# NOTE(review): UnetTrainer may expect further arguments (validation data, epochs,
# learning rate, device, ...) — confirm against Algorithms.trainer.
trainer = UnetTrainer(model=model, train_dataset=train_dataset)