In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import os


In [2]:
class NiftiDataset(Dataset):
    def __init__(self, source_dir, target_dir, transforms):
        """
        create a dataset in PyTorch for reading NIfTI files
        Args:
            source_dir (str): path to source images
            target_dir (str): path to target images
            transform (callable) transform to apply to both source and target images
        """
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.transform = transforms
        self.source_images = NiftiDataset.get_data_paths(source_dir)
        self.target_images = NiftiDataset.change_path_for_targets(self.source_images)
        
    def get_data_paths(directory):
        paths = []
        for i in os.listdir(directory):
            paths.append(os.path.join(directory, i))
        return paths
    
    def change_path_for_targets(source_images):
        target_images = []
        for i in source_images:
            mapped_string = i.replace("t1", "t2")
            mapped_string = mapped_string.replace("T1_fcm","T2_reg_fcm" )
            target_images.append(mapped_string)
        return target_images
            
    def __len__(self):
        return len(self.source_dir)
    
    def __getitem__(self, idx):
        source_image = nib.load(self.source_images[idx]).get_fdata()
        target_image = nib.load(self.target_images[idx]).get_fdata()
        source_patch, target_patch = self.transform(source_image, target_image)

        return np.transpose(source_patch, (1, 0, 2)), np.transpose(target_patch, (1, 0, 2))
    
class RandomCrop3D:
    def __init__(self, output_size):
        self.output_size = output_size
        
    def generate_patch(self, source, target):
        if isinstance(self.output_size, int):
            crop_start_x = np.random.randint(0, source.shape[0] - self.output_size + 1)
            crop_start_y = np.random.randint(0, source.shape[2] - self.output_size + 1)
            crop_end_x = crop_start_x + self.output_size
            crop_end_y = crop_start_y + self.output_size

            return (source[crop_start_x:crop_end_x, :, crop_start_y: crop_end_y],
                   target[crop_start_x:crop_end_x, :, crop_start_y: crop_end_y])
        
        if isinstance(self.output_size, tuple):
            crop_start_x = np.random.randint(0, source.shape[0] - self.output_size[0] + 1)
            crop_start_y = np.random.randint(0, source.shape[2] - self.output_size[1] + 1)
            crop_end_x = crop_start_x + self.output_size[0]
            crop_end_y = crop_start_y + self.output_size[1]

            return (source[crop_start_x:crop_end_x, :, crop_start_y: crop_end_y],
                   target[crop_start_x:crop_end_x, :, crop_start_y: crop_end_y])

        

    

In [14]:
class ConvNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.layers = torch.nn.Sequential(
                        torch.nn.Conv3d(120, 240, 3, 1, 1),
                        torch.nn.BatchNorm3d(240),
                        torch.nn.ReLU(),
                        torch.nn.Conv3d(240, 240, 3, 1, 1),
                        torch.nn.BatchNorm3d(240),
                        torch.nn.ReLU(),
                        torch.nn.Conv3d(240, 120, 3, 1, 1),
                                         )
    
    def forward(self, X):
        return self.layers(X)
    



In [15]:
rand = RandomCrop3D(65)
df = NiftiDataset("../small_data/small/t1/", "../small_data/small/t2/", rand.generate_patch)


In [16]:
train_data, test_data = torch.utils.data.random_split(df, [17, 6])

In [17]:
device = torch.device("cuda")
model = ConvNet()
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr = 5e-5)
criterion = torch.nn.MSELoss()
train_loader = DataLoader(train_data, batch_size=48)
test_loader = DataLoader(test_data, batch_size=6)


In [18]:
for t in range(1, 100):
    model.train()
    for i, (src, tgt) in enumerate(train_loader):
        src.unsqueeze_(-1)
        tgt.unsqueeze_(-1)
        src, tgt = src.to(device).float(), tgt.to(device).float()
        pred = model(src)
        optimizer.zero_grad()
        loss = criterion(pred, tgt)
        loss.backward()
        optimizer.step()
        
    model.eval()
    with torch.no_grad():
        for (src, tgt) in test_loader:
            src.unsqueeze_(-1)
            tgt.unsqueeze_(-1)
            src, tgt = src.to(device).float(), tgt.to(device).float()
            pred = model(src)
            test_loss = criterion(pred, tgt)
    if t % 10 == 0:
        print(f'Epoch: {t}, Train Loss: {loss}, Test Loss: {test_loss}')

Epoch: 10, Train Loss: 0.535723865032196, Test Loss: 0.6659237146377563
Epoch: 20, Train Loss: 0.3057956099510193, Test Loss: 0.525753378868103
Epoch: 30, Train Loss: 0.2300391048192978, Test Loss: 0.40821942687034607
Epoch: 40, Train Loss: 0.2193366140127182, Test Loss: 0.2777760922908783
Epoch: 50, Train Loss: 0.2065686583518982, Test Loss: 0.22226104140281677
Epoch: 60, Train Loss: 0.2040197104215622, Test Loss: 0.18528293073177338
Epoch: 70, Train Loss: 0.1952318698167801, Test Loss: 0.18756912648677826
Epoch: 80, Train Loss: 0.184870645403862, Test Loss: 0.17772535979747772
Epoch: 90, Train Loss: 0.17650218307971954, Test Loss: 0.17603114247322083
