In [1]:
#-----------------
# This is just for rendering on the website
import os
import sys
import glob
sys.path.append("..")
#-----------------

from IPPy import operators, utilities, metrics, models
from IPPy.nn import trainer, losses

import torch
from torch import nn

from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np

# Set device (chosen by the IPPy utilities helper)
device = utilities.get_device()

# Define model: a UNet mapping a 1-channel image to a 1-channel image.
# down_layers / up_layers name IPPy block types (presumably used on the
# downsampling and upsampling paths -- confirm in IPPy.models).
# final_activation=None: no output activation is requested, so predictions are
# presumably not squashed into [0, 1] -- verify against IPPy.models.UNet.
model = models.UNet(
    ch_in=1,
    ch_out=1,
    middle_ch=[64, 128, 256],
    n_layers_per_block=2,
    down_layers=("ResDownBlock", "ResDownBlock"),
    up_layers=("ResUpBlock", "ResUpBlock"),
    final_activation=None,
)
model = model.to(device)

# Define dataset class
class MayoDataset(Dataset):
    """Grey-scale PNG slices of one Mayo split, resized and min-max normalized.

    Parameters
    ----------
    data_path : str
        Root of one split, e.g. "./data/Mayo/train" or "./data/Mayo/test".
        Images are collected from ``data_path/*/*.png`` (exactly one
        sub-folder level below the root).
    data_shape : int or tuple of int
        Target size forwarded to ``torchvision.transforms.Resize``.

    Each item is a float32 tensor of shape (1, n_x, n_y) with values in
    [0, 1]. A constant image is returned as all zeros.
    """

    def __init__(self, data_path, data_shape):
        super().__init__()

        self.data_path = data_path
        self.data_shape = data_shape

        # Sorted so that index -> file is reproducible across machines
        # (glob order depends on the filesystem).
        self.fname_list = sorted(
            glob.glob(os.path.join(data_path, "*", "*.png"))
        )

    def __len__(self):
        return len(self.fname_list)

    @staticmethod
    def _normalize(x):
        """Min-max scale ``x`` to [0, 1].

        A constant image has a zero range; it is returned as all zeros
        instead of the NaNs a plain division would produce.
        """
        x = x - x.min()
        x_max = x.max()
        return x / x_max if x_max > 0 else x

    def __getitem__(self, idx):
        img_path = self.fname_list[idx]

        # Load the image as 8-bit grey-scale; the context manager closes the file.
        with Image.open(img_path) as img:
            x = np.array(img.convert("L"))  # (n_x, n_y), e.g. (512, 512)

        # Cast to float *before* resizing, so interpolated values are not
        # rounded back to 8-bit integers -> (1, n_x, n_y) <-> (c, n_x, n_y)
        x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)

        # Resize to the required shape
        x = transforms.Resize(self.data_shape)(x)

        # Normalize in [0, 1] range (safe for constant images)
        return self._normalize(x)

# --- Load data
# Training split, each image resized with data_shape=256 (256 x 256 as long as
# the source PNGs are square), served in shuffled batches of 4.
train_data = MayoDataset("../data/Mayo/train", data_shape=256)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=4,
    shuffle=True,
)

# Define CTProjector operator: parallel-beam geometry, 60 view angles between
# 0 and pi, det_size=512. img_shape must agree with the training image size.
# NOTE(review): np.linspace includes its endpoint by default, so both 0 and pi
# are sampled -- mirror-image views in parallel geometry. endpoint=False would
# drop the duplicate, but keep whatever geometry saved weights were trained on.
K = operators.CTProjector(
    geometry="parallel",
    img_shape=(256, 256),
    det_size=512,
    angles=np.linspace(0, np.pi, 60),
)

# --- Parameters
# NOTE: with n_epochs = 0 the loop below does not run at all (presumably so
# the docs build does not train); raise it to actually train the model.
n_epochs = 0

# Combination of MSE, SSIM-based and Fourier-domain losses, presumably
# weighted by (1, 0.1, 0.1) -- confirm in IPPy.nn.losses.MixedLoss.
loss_fn = losses.MixedLoss(
    (nn.MSELoss(), losses.SSIMLoss(), losses.FourierLoss()),
    (1, 0.1, 0.1),
)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Make sure the model is in training mode (matters if an earlier cell called
# model.eval()).
model.train()

# Cycle over the epochs
for epoch in range(n_epochs):

    # Running sums over this epoch's batches; averaged for display below.
    epoch_loss = 0.0
    ssim_loss = 0.0  # accumulates the SSIM *metric*, not a loss term
    for t, x in enumerate(train_loader):
        # Send the ground-truth batch to device
        x = x.to(device)

        # Simulate noisy measurements and the FBP reconstruction the network
        # takes as input; no gradients are needed for this preprocessing.
        with torch.no_grad():
            y = K(x)
            y_delta = y + utilities.gaussian_noise(y, noise_level=0.01)
            x_FBP = K.FBP(y_delta)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        x_pred = model(x_FBP)
        loss = loss_fn(x_pred, x)
        loss.backward()
        optimizer.step()

        # update running statistics
        epoch_loss += loss.item()
        ssim_loss += metrics.SSIM(x_pred.detach().cpu(), x.detach().cpu())

        # Per-batch progress line (stands in for a progress bar)
        print(
            {
                "Loss": f"{epoch_loss / (t + 1):.4f}",
                "SSIM": f"{ssim_loss / (t + 1):.4f}",
            }
        )

    # Save model every 5 epochs and after the final epoch (overwrite), so the
    # last epochs are not lost when n_epochs is not a multiple of 5.
    if (epoch + 1) % 5 == 0 or epoch + 1 == n_epochs:
        trainer.save(model, weights_path="../weights/CTUNet")



CUDA not available. CTProjector will use CPU.
Attempting to create ASTRA projector type: 'linear' for 'parallel' geometry...
Successfully created ASTRA projector type: 'linear'
CTProjector initialized. Geometry: parallel. Using GPU: False. FBP Algorithm: FBP
