In [None]:
import os

# Locations used by the rest of the notebook. The defaults are the original
# author's machine; set AWLOSS_PATH_ROOT / AWLOSS_PATH_DATA in the environment
# to run elsewhere without editing this cell.
# path_root: directory that contains the `AWLoss` checkout (appended to sys.path below).
path_root = os.environ.get("AWLOSS_PATH_ROOT", '/home/oab18/Projects/')
# path_data: directory holding the MRI dataset.
# Previously used alternative: "/home/dp4018/data/ultrasound-data/Ultrasound-MRI-sagittal/"
path_data = os.environ.get("AWLOSS_PATH_DATA", '/home/oab18/Desktop/HCP Young Adult Database/')

In [None]:
import os
import sys 

sys.path.append(path_root+'/AWLoss/awloss/')
sys.path.append(path_root+'/AWLoss/examples')
sys.path.append(path_root+'/AWLoss/')

%load_ext autoreload
import torch
from torchvision.utils import make_grid
from torchvision.transforms import Compose, Resize, Lambda, Normalize
from torch.utils.data import DataLoader, Subset
from monai.networks.nets import UNet
from sklearn.impute import SimpleImputer, KNNImputer
from torch.nn.functional import interpolate
from awloss import AWLoss

from train_utils import *

%autoreload 2
from networks import *
from datasets import MaskedUltrasoundDataset2D
from landscape import *


import matplotlib.pyplot as plt
import matplotlib.colors as clt
import progressbar
import random
import numpy as np


## CUDA Setup

In [None]:
# Set seed, clear cache and enable anomaly detection (for debugging)
# `set_seed` and `set_device` come from the project's star imports above.
# presumably set_seed seeds every RNG the notebook relies on (the data split
# below uses numpy's global RNG) — TODO confirm against train_utils.
set_seed(42)
# Release cached CUDA memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()
# NOTE(review): anomaly detection is a debugging aid and is expected to add
# overhead to every training step — consider disabling it for full runs.
torch.autograd.set_detect_anomaly(True)                     
# `device` is read as a global by make_model() and passed to train_model().
device=set_device("cuda", 0)            

# MRI Dataset 

In [None]:
# Side length handed to Resize and used for the square mask.
size = 320
path = os.path.abspath(path_data)


def _divide_by_abs_max(x):
    """Divide a tensor by its largest absolute value."""
    return x / x.abs().max()


def _clip_outer_outliers(x):
    """Apply the project's `clip_outliers` helper in "outer" mode."""
    return clip_outliers(x, "outer")


def _rescale_to_unit_interval(x):
    """Apply the project's `scale2range` helper with the target range [0, 1]."""
    return scale2range(x, [0., 1.])


# Per-sample preprocessing, applied in this order.
train_transform = Compose([
    Resize(size),
    Lambda(_divide_by_abs_max),
    Lambda(_clip_outer_outliers),
    Lambda(_rescale_to_unit_interval),
])

# NOTE(review): the meaning of the two range tuples is defined by create_mask — confirm.
mask = create_mask((size, size), (0, 3), (0, 1))

ds = MaskedUltrasoundDataset2D(
    path,
    mode="mri",
    transform=train_transform,
    mask=mask,
    maxsamples=None,
)
print(ds, "\n")
print(len(ds))

print(ds.info(nsamples=30))


## Data split

In [None]:
# Fraction of the dataset held out for validation.
valid_ratio = 0.2
n_total = len(ds)
n_valid = int(n_total * valid_ratio)

# Shuffle with numpy's global RNG so the split is reproducible whenever that
# RNG has been seeded beforehand.
idxs = np.arange(0, n_total, 1)
np.random.shuffle(idxs)

# Split on an explicit index. Negative slicing (`idxs[:-n]`, `idxs[-n:]`)
# inverts the split when n == 0: the training set becomes empty and the
# validation set becomes the whole dataset.
split_at = n_total - n_valid
train_idxs, valid_idxs = idxs[:split_at], idxs[split_at:]
trainds, validds = Subset(ds, train_idxs), Subset(ds, valid_idxs)

print(len(trainds), len(validds))

In [None]:
# Histogram of the target values (item [1] of each sample) for up to the
# first 10 training samples. The arrays are concatenated in numpy instead of
# being expanded element-by-element into a Python list, which is very slow and
# memory-hungry for image-sized tensors.
n_preview = min(len(trainds), 10)
chunks = [trainds[k][1].flatten().detach().cpu().numpy() for k in range(n_preview)]
# np.concatenate rejects an empty sequence, so fall back to an empty array.
data = np.concatenate(chunks) if chunks else np.array([])
plt.title("Train Dataset Value")
plt.hist(data)
plt.show()

# Model 

In [None]:
def make_model(nc=64):
    """Build the 2D MONAI UNet used for reconstruction.

    The global seed is reset first so repeated calls start from the same
    initial weights. The network is wrapped in ``torch.nn.DataParallel`` and
    moved to the notebook-level ``device``.

    Args:
        nc: number of input channels; the same number of output channels is used.

    Returns:
        The ``DataParallel``-wrapped UNet on ``device``.
    """
    set_seed(42)
    channels = (16, 32, 64)
    model = UNet(
        spatial_dims=2,
        in_channels=nc,
        out_channels=nc,
        channels=channels,
        # MONAI expects one stride per down/up-sampling step, i.e.
        # len(channels) - 1. A surplus entry is ignored with a warning, so
        # dropping it leaves the architecture unchanged.
        strides=(2,) * (len(channels) - 1),
        num_res_units=3,
        act="mish",
    )
    # Use the explicitly imported `torch` rather than a bare `nn`, which is
    # only in scope if one of the star imports happens to export it.
    model = torch.nn.DataParallel(model)
    return model.to(device)

# Train Function

In [None]:
def _aw_filters_and_penalty(loss):
    """Best-effort lookup of the stored filter and penalty on ``loss``.

    Tries ``loss.filters[0]`` / ``loss.T`` first, then looks for a member of
    ``loss.losses`` whose ``str()`` is ``"AWLoss()"``. Returns a pair of
    single-zero tensors when neither lookup succeeds.
    """
    try:
        return loss.filters[0], loss.T
    except Exception:
        pass
    try:
        names = [str(l) for l in loss.losses]
        aw = loss.losses[names.index("AWLoss()")]
        return aw.filters[0], aw.T
    except Exception:
        return torch.tensor([0.]), torch.tensor([0.])


def train_model(model, optimizer, loss, train_loader, valid_loader=None, nepochs=150, 
                log_frequency=10, sample_input=None, sample_target=None, device="cpu", 
                exp_name="", save=True, scheduler=None):
    """Train ``model``, periodically visualise progress and optionally save the run.

    Each epoch calls the project helpers ``train`` and, when ``valid_loader``
    is given, ``validate``. Every ``log_frequency`` epochs (and on the last
    epoch) the losses are printed, a validation batch is plotted, and the
    reconstruction of ``sample_input`` is plotted together with the loss
    filter and penalty.

    A ``KeyboardInterrupt`` stops training early but still runs the save step,
    so a long run can be interrupted without losing its results. Any other
    exception propagates without saving.

    Args:
        model: network to train; its output is passed through a sigmoid for display.
        optimizer: optimizer forwarded to ``train``.
        loss: loss callable; ``filters`` / ``T`` attributes are read when present.
        train_loader: training DataLoader. The save step reads
            ``train_loader.dataset.dataset.mask`` / ``.mode``, so it expects a
            Subset-style wrapper around the dataset.
        valid_loader: optional validation DataLoader.
        nepochs: number of epochs.
        log_frequency: epoch interval between logs/plots.
        sample_input: optional single input sample (no batch dimension) to visualise.
        sample_target: optional target matching ``sample_input``.
        device: device used for the visualisation forward passes.
        exp_name: currently unused.
        save: when True, call ``save_exp`` with objects, figures and a summary.
        scheduler: optional scheduler forwarded to ``train``.

    Returns:
        None
    """
    print("\n\nTraining started ...")

    all_train_losses, all_valid_losses = [], []

    # Defaults for everything the logging and save steps read, so that an early
    # interrupt or a missing optional argument cannot raise a NameError.
    epoch = None
    recon = None
    samples = {}
    samples_fig, losses_fig = None, None
    v, T = torch.tensor([0.]), torch.tensor([0.])

    try:
        with progressbar.ProgressBar(max_value=nepochs) as bar:    
            for epoch in range(nepochs):
                # Train and validate epoch
                train_loss = train(model, train_loader, optimizer, loss, scheduler, device)
                all_train_losses.append(train_loss.item())
                log = {"epoch": epoch, "train_loss": all_train_losses[-1]}
                if valid_loader:
                    valid_loss = validate(model, valid_loader, loss, device)
                    all_valid_losses.append(valid_loss.item())
                    log["valid_loss"] = all_valid_losses[-1]

                bar.update(epoch)

                # Logging
                is_log_epoch = (epoch % log_frequency == 0) or (epoch == nepochs - 1)
                if is_log_epoch:
                    print("\n", log)

                    if valid_loader:
                        model.eval()
                        X, target = next(iter(valid_loader))
                        nshow = train_loader.batch_size
                        X, target = X[:nshow], target[:nshow]
                        # Visualisation only: no autograd graph is needed.
                        with torch.no_grad():
                            recon = torch.sigmoid(model(X.to(device)))

                        # vmin/vmax are imshow arguments; make_grid has no such
                        # parameters, so there they were ignored or rejected.
                        fig, axs = plt.subplots(4, 1, figsize=(10*nshow, 15))
                        axs[0].imshow(make_grid(X, pad_value=0, padding=2).cpu().data[0],
                                      cmap='Greys_r', vmin=0, vmax=1)
                        axs[1].imshow(make_grid(recon, pad_value=0, padding=2).cpu().data[0],
                                      cmap='Greys_r', vmin=0, vmax=1)
                        axs[2].imshow(make_grid(target, pad_value=0, padding=2).cpu().data[0],
                                      cmap='Greys_r', vmin=0, vmax=1)
                        # NOTE(review): the filters are computed between the masked
                        # input and the target here, not between the reconstruction
                        # and the target — confirm this is intended.
                        try:
                            loss(X, target)
                            v = loss.filters
                        except Exception:
                            # Best effort: losses without stored filters show a blank panel.
                            # `Exception` (not bare except) keeps Ctrl-C working.
                            v = torch.zeros_like(X)
                        axs[3].imshow(make_grid(v, pad_value=0, padding=2).cpu().data[0],
                                      cmap='seismic', vmin=-0.1, vmax=0.1)
                        plt.show()

                    if sample_input is not None:
                        idx = int(sample_input.shape[0]/2)
                        samples = {"Input idx %g"%idx: sample_input[idx]}

                        # Model forward pass
                        model.eval()
                        with torch.no_grad():
                            recon = torch.sigmoid(model(sample_input.unsqueeze(0).to(device)))[0]
                        samples["Reconstruction idx %g"%idx] = recon[idx].cpu().detach().numpy()

                        # If testing sample provided
                        if sample_target is not None:
                            samples["Target idx %g"%idx] = sample_target[idx]

                            # Evaluate the loss for its side effect of refreshing
                            # the stored filters; the value itself is not used.
                            loss(recon.unsqueeze(0).to(device), sample_target.unsqueeze(0).to(device))
                            v, T = _aw_filters_and_penalty(loss)
                            print(" argidx T, v: ",torch.argmin(torch.abs(T)).item(), torch.argmax(torch.abs(v)).item())

                    if samples:
                        samples_fig = plot_samples(samples)
                    losses_fig = plot_losses(losses={"train": all_train_losses, "valid":all_valid_losses},
                                filters={"Weiner Filter": v.flatten().cpu().detach().numpy(), "Penalty": T.flatten().cpu().detach().numpy()})
    except KeyboardInterrupt:
        # Manual interruption is a supported way to stop early; fall through to saving.
        print("\nTraining interrupted at epoch %s, finalising ..." % epoch)

    if save:
        objs = { "mask": train_loader.dataset.dataset.mask,
            "train_loader":train_loader,
            "valid_loader":valid_loader,
            # Use the function arguments rather than notebook globals.
            "x_sample": sample_input,
            "y_sample": sample_target,
            "recon": recon,

            "model": model,
            "optim": optimizer,
            "loss": loss,
            "train_losses": all_train_losses,
            "vald_losses": all_valid_losses,
            "penalty": T,
            }

        summary = { "data_mode": train_loader.dataset.dataset.mode,
                    "interpolation_model": "UNet",
                    "loss": str(loss),
                    "img_size": tuple(sample_input.shape) if sample_input is not None else None,
                    "device":device,
                    "nepochs": nepochs,
                    "current_epoch":epoch, 
                    "learning_rate":optimizer.defaults["lr"],
                    "batch_size":train_loader.batch_size,
                    "ntrain": len(train_loader.dataset),}
        summary["nvalid"] = len(valid_loader.dataset) if valid_loader is not None else 0
        try:
            # Read all three first so the summary is all-or-nothing. Stray
            # trailing commas used to wrap the first two values in 1-tuples.
            aw_filter_dim, aw_epsilon, aw_std = loss.filter_dim, loss.epsilon, loss.std
        except AttributeError:
            aw_filter_dim, aw_epsilon, aw_std = None, None, None
        summary["aw_filter_dim"] = aw_filter_dim
        summary["aw_epsilon"] = aw_epsilon
        summary["aw_std"] = aw_std

        # Only pass figures that were actually produced.
        figs = {name: fig for name, fig in (("losses", losses_fig), ("samples", samples_fig))
                if fig is not None}
        save_exp(objs=objs, figs=figs, summary=summary, overwrite=False)            
    return None
            

# Training Setup

In [None]:
# Static training parameters and hyperparameters
nepochs = 300
learning_rate = 1e-2
batch_size = 32

# Loaders: training batches are reshuffled every epoch, validation keeps a
# fixed order and uses a batch size of 1000.
train_loader = DataLoader(trainds, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(validds, batch_size=1000, shuffle=False, num_workers=4)

# First validation item, used as the fixed sample for visualisation
x_sample, y_sample = validds[0]
f = (nepochs//2) + 1

In [None]:
# Show one training batch: masked inputs on the top row, targets underneath.
train_batch = next(iter(train_loader))
fig, axs = plt.subplots(2, 1, figsize=(4*batch_size,7))
panels = (
    (train_batch[0], "Masked Input"),
    (train_batch[1], "Target"),
)
for ax, (images, title) in zip(axs, panels):
    # Only the first channel of the assembled grid is displayed.
    ax.imshow(make_grid(images, pad_value=0, padding=3).data[0], cmap='Greys_r')
    ax.set_title(title)
plt.show()

# Train AWLoss

In [None]:
def laplacian2D(mesh):
    """Radial penalty map passed to AWLoss as ``penalty_function``.

    With ``r = sqrt(xx**2 + yy**2)`` taken from the last axis of ``mesh``,
    computes ``1 - exp(-|r|**alpha)**beta`` for ``alpha = -0.2`` and
    ``beta = 1.5``, then passes the result through the project's
    ``scale2range`` helper with the range ``[0.05, 1.]``.

    Args:
        mesh: tensor indexed as ``mesh[:, :, 0]`` (x coordinates) and
            ``mesh[:, :, 1]`` (y coordinates).

    Returns:
        The output of ``scale2range`` for the penalty map.
    """
    exponent, power = -0.2, 1.5
    radius = torch.sqrt(mesh[:, :, 0]**2 + mesh[:, :, 1]**2)
    decay = torch.exp(-torch.abs(radius) ** exponent) ** power
    return scale2range(1 - decay, [0.05, 1.])

In [None]:
# Build the network with as many input/output channels as the visualisation
# sample has along its first axis, and optimise it with Adam.
awmodel = make_model(nc=x_sample.shape[0])
awoptim = torch.optim.Adam(awmodel.parameters(), lr=learning_rate)

# AWLoss with the penalty map defined by laplacian2D above.
# NOTE(review): the meaning of filter_dim / method / epsilon / filter_scale /
# store_filters is defined by the AWLoss package — confirm against its docs.
# train_model reads `loss.filters` and `loss.T`, so filters presumably need to
# be stored for the filter plots to be meaningful.
awloss     = AWLoss(filter_dim=2, method="fft", reduction="mean", store_filters="unorm", 
                    epsilon=250., filter_scale=2, penalty_function=laplacian2D)


# With log_frequency=100 and nepochs=300, progress is logged and plotted at
# epochs 0, 100, 200 and on the final epoch (299); results are saved at the end.
train_model(awmodel, awoptim, awloss, train_loader, valid_loader=valid_loader, nepochs=nepochs, log_frequency=100, 
            sample_input=x_sample, sample_target=y_sample, device=device, save=True)


