In [8]:
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
from tqdm.notebook import tqdm
import math
import numpy as np
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import DataLoader
import time
import wandb
import os
import json
from torch.utils.data import Subset, ConcatDataset, DataLoader
from torch.utils.data import Dataset
import pickle
import cv2


ERROR! Session/line number was not unique in database. History logging moved to new session 314


In [9]:
class MiniBatchStd(torch.nn.Module):
    """Minibatch standard-deviation layer (ProGAN).

    Computes the standard deviation of every feature (channel, y, x) across the
    batch dimension, averages those values into a single scalar and appends it
    as one extra constant feature map. This lets the critic detect low sample
    diversity in a batch.

    Args:
        epsilon: added to the variance before the square root so the statistic
            (and its gradient) stays finite, including for a batch of size 1.
    """

    def __init__(self, epsilon=1e-8):
        super(MiniBatchStd, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x`` with one extra channel appended: (B, C, H, W) -> (B, C + 1, H, W)."""
        # The statistic must be taken over the *batch* axis (dim 0). Reducing over
        # dim=1 (channels) ignores batch diversity entirely and is NaN for C == 1.
        centered = x - x.mean(dim=0, keepdim=True)
        std = torch.sqrt(centered.pow(2).mean(dim=0) + self.epsilon)
        stat = std.mean().expand(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, stat], dim=1)


class PixelWiseNormalization(torch.nn.Module):
    """Normalize each spatial position's feature vector to unit root-mean-square.

    Every pixel is divided by ``sqrt(mean over channels of x**2 + epsilon)``, so
    after the layer the channel-wise mean square at each location is ~1.
    """

    def __init__(self):
        super(PixelWiseNormalization, self).__init__()
        # Keeps the denominator away from zero for all-zero feature vectors.
        self.epsilon = 1e-8

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        rms = torch.sqrt(mean_square + self.epsilon)
        return x / rms


class WeightedConv2d(torch.nn.Module):
    """2-D convolution with an equalized learning rate.

    Weights are initialized from N(0, 1). At runtime the input is multiplied by
    the He constant ``sqrt(2 / fan_in)``, where ``fan_in = kernel_size**2 * in_channels``.
    The bias is detached from the inner conv so that it is *not* scaled. It is
    stored as this module's own ``bias`` parameter, initialized to zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(WeightedConv2d, self).__init__()

        conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=kernel_size, padding=padding)
        self.conv = conv
        fan_in = kernel_size ** 2 * in_channels
        self.scale = (2 / fan_in) ** .5

        # Move the bias out of the conv so it is added after scaling.
        self.bias = conv.bias
        conv.bias = None

        torch.nn.init.normal_(conv.weight, mean=0, std=1)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled = self.conv(x * self.scale)
        return scaled + self.bias.view(1, -1, 1, 1)


class ConvBlock(torch.nn.Module):
    """Two equalized 3x3 convolutions, each followed by LeakyReLU(0.2).

    The first conv keeps ``in_channels``; the second maps to ``out_channels``.
    Spatial size is preserved (padding=1). When ``apply_pixelnorm`` is True, a
    pixel-wise normalization follows each activation.
    """

    def __init__(self, in_channels, out_channels, apply_pixelnorm=False):
        super(ConvBlock, self).__init__()
        self.apply_pixelnorm = apply_pixelnorm
        self.conv1 = WeightedConv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1)
        self.leaky_relu = torch.nn.LeakyReLU(0.2)
        self.conv2 = WeightedConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.pixelnorm = PixelWiseNormalization()

    def _activate(self, features):
        """LeakyReLU, then optional pixel norm."""
        features = self.leaky_relu(features)
        if not self.apply_pixelnorm:
            return features
        return self.pixelnorm(features)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self._activate(conv(x))
        return x

In [10]:
def normalized_mean_squared_error(fake_batch, real_batch):
    """Per-sample NMSE: ||fake - real||_F^2 / ||real||_F^2.

    Args:
        fake_batch: generated images; expected (B, 1, H, W).
        real_batch: reference images, (B, H, W); a channel axis is inserted to match.

    Returns:
        Tensor of per-sample ratios (shape (B,) for B > 1).
    """
    residual = fake_batch - real_batch[:, None, :, :]
    squared_error = torch.linalg.matrix_norm(residual).pow(2)
    squared_reference = torch.linalg.matrix_norm(real_batch).pow(2)
    return squared_error.squeeze() / squared_reference


class Generator(torch.nn.Module):
    """Progressive-growing generator conditioned on a surrogate-signal window.

    ``layers`` gives six channel widths. Level 0 turns a 1x1 input into 4x4 via a
    4x4 transposed convolution. Each later level first doubles the resolution
    (nearest-neighbour) and then applies a pixel-normalized ``ConvBlock``. Each
    level has its own 1x1 "to-gray" head producing one output channel.
    """

    def __init__(self, layers):
        super(Generator, self).__init__()
        n_levels = 6
        self.togray_layers = torch.nn.ModuleList(
            [WeightedConv2d(in_channels=layers[level], out_channels=1, kernel_size=1) for level in range(n_levels)]
        )

        initial_block = torch.nn.Sequential(
            PixelWiseNormalization(),
            torch.nn.ConvTranspose2d(in_channels=layers[0], out_channels=layers[0], kernel_size=4, stride=1, padding=0),
            torch.nn.LeakyReLU(0.2),
            WeightedConv2d(in_channels=layers[0], out_channels=layers[0], kernel_size=3, padding=1),
            torch.nn.LeakyReLU(0.2),
            PixelWiseNormalization(),
        )
        growth_blocks = [
            ConvBlock(layers[level - 1], layers[level], apply_pixelnorm=True) for level in range(1, n_levels)
        ]
        self.layers = torch.nn.ModuleList([initial_block, *growth_blocks])

    def forward(self, x, surrogate_window, step, alpha):
        """Generate images at resolution ``4 * 2**step``, squashed to [-1, 1] by tanh.

        Args:
            x: noise; expected (B, noise_len, 1, 1).
            surrogate_window: conditioning signal, (B, W); appended as W extra channels.
            step: highest level to run (0-based).
            alpha: fade-in weight of the newest level (ignored at step 0).
        """
        features = torch.cat([surrogate_window[:, :, None, None], x], dim=1)

        for level in range(step + 1):
            # Level 0 works on the 1x1 input directly; later levels upsample first.
            if level == 0:
                upsampled = features
            else:
                upsampled = F.interpolate(features, scale_factor=2, mode='nearest')
            features = self.layers[level](upsampled)

        newest = self.togray_layers[step](features)
        if step == 0:
            return torch.tanh(newest)

        # Blend with the previous level's head applied to its upsampled features.
        previous = self.togray_layers[step - 1](upsampled)
        return torch.tanh(newest * alpha + previous * (1 - alpha))


class Discriminator(torch.nn.Module):
    """Progressive-growing critic mirroring ``Generator``.

    ``fromgray_layers[i]`` maps a 1-channel image to ``layers[i]`` channels.
    ``layers[i]`` (for i < 5) is a ``ConvBlock`` from ``layers[i]`` to
    ``layers[i + 1]`` channels followed by 2x average pooling in ``forward``.
    The last level is the head: minibatch std, 3x3 conv, 4x4 valid conv, and a
    linear layer giving one score per sample.
    """

    def __init__(self, layers):
        super(Discriminator, self).__init__()

        n_levels = 6
        self.fromgray_layers = torch.nn.ModuleList(
            [WeightedConv2d(in_channels=1, out_channels=layers[level], kernel_size=1) for level in range(n_levels)]
        )
        self.act = torch.nn.LeakyReLU(0.2)

        down_blocks = [ConvBlock(layers[level], layers[level + 1]) for level in range(n_levels - 1)]
        head = torch.nn.Sequential(
            MiniBatchStd(),
            WeightedConv2d(in_channels=layers[5] + 1, out_channels=layers[5], kernel_size=3, padding=1),
            torch.nn.LeakyReLU(0.2),
            # A 4x4 kernel without padding reduces the 4x4 map to 1x1.
            WeightedConv2d(in_channels=layers[5], out_channels=layers[5], kernel_size=4, padding=0),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Flatten(),
            torch.nn.Linear(layers[5], out_features=1),
        )
        self.layers = torch.nn.ModuleList([*down_blocks, head])

    def forward(self, input, step, alpha):
        """Score images at resolution ``4 * 2**step``; returns (B, 1).

        Args:
            input: images; expected (B, 1, 4 * 2**step, 4 * 2**step).
            step: current progression step (0 = 4x4 only).
            alpha: fade-in weight of the newest (highest-resolution) block.
        """
        n_blocks = len(self.layers)
        first_block = n_blocks - step - 1

        x = self.act(self.fromgray_layers[len(self.fromgray_layers) - step - 1](input))

        for i in range(first_block, n_blocks):
            x = self.layers[i](x)

            # Don't pool after the final head.
            if i != n_blocks - 1:
                x = F.avg_pool2d(x, kernel_size=2)

            # Fade-in after the newest block (not at step 0). The skip path must
            # get the same LeakyReLU as the main from-gray path. Then at alpha=0
            # this network equals the previous step's critic on a downsampled input.
            if i == first_block and step != 0:
                x_hat = F.avg_pool2d(input, kernel_size=2)
                x_hat = self.act(self.fromgray_layers[len(self.fromgray_layers) - step](x_hat))
                x = x_hat * (1 - alpha) + x * alpha

        return x


class ConditionalProGAN(torch.nn.Module):
    """Progressive-growing WGAN-GP whose generator is conditioned on a surrogate window.

    Bundles a ``Generator``/``Discriminator`` pair with their Adam optimizers and
    linear learning-rate decay schedules. Provides one-epoch training, evaluation
    and some step/alpha scheduling helpers.

    Args:
        device: torch device that networks and batches are moved to.
        desired_resolution: final square resolution. Step 0 produces 4x4 and each
            step doubles it, so this should be ``4 * 2**k``.
        G_lr, D_lr: Adam learning rates of the generator / discriminator.
        n_critic: discriminator updates per generator update (must be >= 1).
        n_epochs: total epochs; must be divisible by the number of steps. Also
            the length of the linear LR decay (factor 1 -> 0.01).
        D_layers, G_layers: six channel widths each. The generator input has
            ``G_layers[0]`` channels, of which ``G_layers[0] - 17`` are noise; this
            assumes a 17-sample conditioning window. TODO confirm for all subjects.
    """

    def __init__(self,
                 device,
                 desired_resolution,
                 G_lr,
                 D_lr,
                 n_critic,
                 n_epochs,
                 D_layers,
                 G_layers):
        super(ConditionalProGAN, self).__init__()
        self.desired_resolution = desired_resolution
        # Step 0 is 4x4; every further step doubles the resolution.
        self.total_steps = 1 + math.log2(desired_resolution / 4)
        if n_epochs % self.total_steps != 0:
            raise Exception("Total number of epochs should be divisible by the total number of steps")
        if n_critic < 1:
            raise ValueError("n_critic must be >= 1")

        self.n_epochs = n_epochs
        self.noise_vector_length = G_layers[0] - 17
        self.device = device
        self.D = Discriminator(D_layers).to(device)
        self.G = Generator(G_layers).to(device)
        self.n_critic = n_critic
        self.G_optimizer = torch.optim.Adam(self.G.parameters(), betas=(0, 0.99), lr=G_lr, eps=1e-8)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), betas=(0, 0.99), lr=D_lr, eps=1e-8)

        self.G_scheduler = lr_scheduler.LinearLR(self.G_optimizer, start_factor=1, end_factor=.01, total_iters=n_epochs)
        self.D_scheduler = lr_scheduler.LinearLR(self.D_optimizer, start_factor=1, end_factor=.01, total_iters=n_epochs)

    def train_single_epoch(self, dataloader, current_epoch, gp_lambda, step, alpha, epochs_in_curr_step, dataset_length):
        """Run one WGAN-GP epoch at progression ``step``.

        The discriminator is updated on every batch. The generator is updated after
        every ``n_critic``-th batch. Alpha grows linearly per sample and reaches 1
        after half of ``epochs_in_curr_step`` epochs. Both LR schedulers step once
        at the end.

        Returns:
            ``(summed D loss, summed G loss, updated alpha)``. The losses are sums
            over the batches on which each network was updated, not means.
        """
        self.D.train()
        self.G.train()

        resolution = 4 * 2 ** step
        running_D_loss, running_G_loss = 0, 0
        progress = tqdm(dataloader, desc=f"Epoch {current_epoch + 1}, step {step}, alpha {round(alpha, 2)}: ", total=len(dataloader))
        for batch_idx, data in enumerate(progress):
            mri_batch = data["mr"].to(self.device)
            coil_batch = data["coil"].to(self.device)

            noise_batch = torch.randn(mri_batch.shape[0], self.noise_vector_length, 1, 1, device=self.device)
            fake = self.G(noise_batch, coil_batch, step, alpha)
            # Downsample the real frames to the current resolution and add a channel axis.
            real_input = F.adaptive_avg_pool2d(mri_batch, (resolution, resolution))[:, None]
            d_fake = self.D(fake.detach(), step, alpha)
            d_real = self.D(real_input, step, alpha)

            gp = self.compute_gradient_penalty(real_input, fake, step, alpha)
            d_loss = (
                    -(torch.mean(d_real) - torch.mean(d_fake))
                    + gp_lambda * gp
                    # Drift term from ProGAN: keeps real scores near zero.
                    + (0.001 * torch.mean(d_real ** 2))
            )

            self.D_optimizer.zero_grad()
            d_loss.backward()
            self.D_optimizer.step()
            running_D_loss += d_loss.item()

            # Honor n_critic: one generator update per n_critic critic updates.
            if (batch_idx + 1) % self.n_critic == 0:
                g_fake = self.D(fake, step, alpha)
                g_loss = -torch.mean(g_fake)

                self.G_optimizer.zero_grad()
                g_loss.backward()
                self.G_optimizer.step()
                running_G_loss += g_loss.item()

            alpha += mri_batch.shape[0] / ((epochs_in_curr_step*.5) * dataset_length)
            alpha = min(alpha, 1.0)

        self.G_scheduler.step()
        self.D_scheduler.step()

        return running_D_loss, running_G_loss, alpha

    def evaluate(self, dataloader, step, alpha):
        """Generate for every batch and score against the real frames.

        Generated images are upscaled (nearest) to ``desired_resolution`` before
        scoring.

        Returns:
            ``(fakes, reals, per-sample NMSE, per-sample SSIM)`` as numpy arrays.
        """
        self.D.eval()
        self.G.eval()

        # Brings a 4 * 2**step output up to 4 * 2**(total_steps - 1) = desired_resolution.
        upscale_factor = 2 ** (self.total_steps - step - 1)

        fake_to_return, real_to_return = [], []
        all_nmse, all_ssim = [], []
        # Inference only: don't build autograd graphs (saves memory and time).
        with torch.no_grad():
            for data in dataloader:
                mr_batch = data["mr"].to(self.device)
                coil_batch = data["coil"].to(self.device)

                noise_batch = torch.randn(mr_batch.shape[0], self.noise_vector_length, 1, 1, device=self.device)
                fake = self.G(noise_batch, coil_batch, step, alpha)
                fake_upscaled = F.interpolate(fake, scale_factor=upscale_factor, mode='nearest')
                nmse = normalized_mean_squared_error(fake_upscaled, mr_batch).detach().cpu().numpy().flatten()

                fake_np = fake_upscaled.detach().cpu().numpy()
                real_np = mr_batch.detach().cpu().numpy()
                for fake_img, real_img in zip(fake_np, real_np):
                    # data_range=2 because images live in [-1, 1].
                    all_ssim.append(ssim(fake_img.squeeze(), real_img.squeeze(), win_size=11, data_range=2))
                all_nmse.extend(nmse)
                fake_to_return.extend(fake_np)
                real_to_return.extend(real_np)

        return np.array(fake_to_return), np.array(real_to_return), np.array(all_nmse), np.array(all_ssim)

    def compute_gradient_penalty(self, real, fake, step, alpha):
        """WGAN-GP penalty: mean of (||grad_x D(x_hat)||_2 - 1)^2 on random real/fake interpolates."""
        epsilon = torch.rand((real.shape[0], 1, 1, 1), device=self.device)
        x_hat = (epsilon * real + (1-epsilon)*fake.detach()).requires_grad_(True)

        score = self.D(x_hat, step, alpha)
        gradient = torch.autograd.grad(
            inputs=x_hat,
            outputs=score,
            grad_outputs=torch.ones_like(score),
            # Needed so the penalty itself can be back-propagated into D.
            create_graph=True,
            retain_graph=True
        )[0]
        gradient = gradient.view(gradient.shape[0], -1)
        gradient_norm = gradient.norm(2, dim=1)
        gradient_penalty = torch.mean((gradient_norm-1)**2)
        return gradient_penalty

    @staticmethod
    def _get_alpha_linear(curr_epoch, epochs_per_step, quickness):
        """Alpha rising linearly within a step, ``quickness`` times faster than the step length, capped at 1."""
        alpha = quickness * (curr_epoch % epochs_per_step) / epochs_per_step

        return min(alpha, 1)

    @staticmethod
    def _get_step_linear(n_epochs, total_steps, curr_epoch):
        """Step index when epochs are split evenly over steps (capped at the last step)."""
        epochs_per_step = n_epochs // total_steps
        step = int(curr_epoch / (epochs_per_step))

        return min(step, int(total_steps - 1))

    @staticmethod
    def _get_step_root(n_epochs, n_steps, curr_epoch):
        """Step index on a square-root schedule (later steps get more epochs); not capped."""
        return math.floor(n_steps / math.sqrt(n_epochs) * math.sqrt(curr_epoch))

    def _get_milestones(self, n_epochs, n_steps):
        """Epochs at which steps 1..n_steps-1 begin under the square-root schedule."""
        milestones = []
        for i in range(1, int(n_steps)):
            milestone = math.ceil(n_epochs * i ** 2 / n_steps ** 2)
            milestones.append(milestone)

        return milestones

    @staticmethod
    def _get_alpha_root(n_epochs, n_steps, curr_epoch, curr_step, quickness=2):
        """Alpha for the square-root schedule: reaches 1 after 1/quickness of the step, capped at 1."""
        start_step = math.ceil(n_epochs * curr_step ** 2 / n_steps ** 2)
        end_step = math.ceil(n_epochs * (curr_step + 1) ** 2 / n_steps ** 2)

        dx = (end_step - start_step) / quickness
        dy = 1

        alpha = dy / dx * (curr_epoch - start_step)

        return min(alpha, 1)

In [11]:
class DatasetSplitter:
    """Split every breathing pattern of a dataset into contiguous train/val/test ranges.

    ``dataset.splits`` maps a pattern name to ``{"start": i, "end": j}`` (inclusive
    end). Ordinary patterns are cut in order into train / val / test by the
    given fractions. Patterns whose name ends in "BH" (breath holds) have too
    few frames to train on: they are cut only into val / test, in the ratio
    ``val_fraction : test_fraction``.
    """

    def __init__(self, dataset, train_fraction, val_fraction, test_fraction):
        self.dataset = dataset
        self.splits = self.dataset.splits
        self.patterns = list(self.splits.keys())

        self.train_subsets, self.val_subsets, self.test_subsets = {}, {}, {}
        for pattern in self.patterns:
            bounds = self.splits[pattern]
            indices = torch.arange(bounds["start"], bounds["end"] + 1)
            count = len(indices)

            if pattern.endswith("BH"):
                # Breath holds: evaluation only.
                cut = int(count * val_fraction / (val_fraction + test_fraction))
                self.val_subsets[pattern] = Subset(dataset, indices[:cut])
                self.test_subsets[pattern] = Subset(dataset, indices[cut:])
                continue

            train_end = int(count * train_fraction)
            val_end = int(count * (train_fraction + val_fraction))
            self.train_subsets[pattern] = Subset(dataset, indices[:train_end])
            self.val_subsets[pattern] = Subset(dataset, indices[train_end:val_end])
            self.test_subsets[pattern] = Subset(dataset, indices[val_end:])

        self.concatenated_train = ConcatDataset(self.train_subsets.values())

    def get_train_dataset(self):
        """All non-breath-hold training subsets concatenated into one dataset."""
        return self.concatenated_train

In [12]:
class CustomDataset(Dataset):
    """Per-patient dataset pairing each MR frame with the surrogate signals recorded before it.

    Everything is loaded eagerly from ``<root_path>/<patient>/``: ``settings.json``,
    ``mr.pickle``, ``surrogates.pickle``, ``us_wave_detrended.pickle``,
    ``mr2us_new.pickle``, ``mr_wave.pickle`` and ``splits.pickle``.

    Each ``*_normalizer`` is a ``(shift, scale)`` pair applied as ``(x - shift) / scale``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load trusted files.
    """

    def __init__(self, root_path, patient, coil_normalizer=(0, 1), heat_normalizer=(0, 1), us_normalizer=(0, 1)):
        patient_dir = os.path.join(root_path, patient)

        with open(os.path.join(patient_dir, "settings.json")) as file:
            self.settings = json.load(file)
        self.TR = self.settings["MRI"]["TR"]
        # Surrogate sampling rate (presumably Hz, with TR in seconds).
        surrogate_freq = 50
        self.signals_between_mrs = self._window_length(self.TR, surrogate_freq)
        if self.signals_between_mrs < 1:
            raise ValueError(f"TR={self.TR} yields an empty surrogate window")

        # Recorded so that they can be inspected later (set_normalizers uses the same names).
        self.coil_normalizer = coil_normalizer
        self.heat_normalizer = heat_normalizer
        self.us_normalizer = us_normalizer

        with open(os.path.join(patient_dir, "mr.pickle"), 'rb') as file:
            mr = pickle.load(file)["images"]
        mr = np.clip(mr, a_min=0, a_max=255).astype(np.uint8)
        # Contrast boost by 1.7 (uint8 result, saturating).
        mr = cv2.addWeighted(mr, 1.7, np.zeros(mr.shape, mr.dtype), 0, 0)
        # Map [0, 255] -> [-1, 1], keep rows [0, 128) and drop 32 columns on each side.
        mr = torch.from_numpy(mr).float() * 2 / 255 - 1
        self.mr = mr[:, :128, 32:-32]

        with open(os.path.join(patient_dir, "surrogates.pickle"), 'rb') as file:
            surrogates = pickle.load(file)
        self.us = np.float32(surrogates["us"])
        self.heat = (torch.tensor(np.float32(surrogates["heat"])) - heat_normalizer[0]) / heat_normalizer[1]
        self.coil = (torch.tensor(np.float32(surrogates["coil"])) - coil_normalizer[0]) / coil_normalizer[1]

        with open(os.path.join(patient_dir, "us_wave_detrended.pickle"), "rb") as file:
            us_wave = torch.tensor(np.float32(pickle.load(file)))
        self.us_wave = (us_wave - us_normalizer[0]) / us_normalizer[1]

        with open(os.path.join(patient_dir, "mr2us_new.pickle"), 'rb') as file:
            # Index of the surrogate sample aligned with each MR frame.
            self.mr2us = pickle.load(file)["mr2us"]

        with open(os.path.join(patient_dir, "mr_wave.pickle"), 'rb') as file:
            self.mr_wave = torch.Tensor(pickle.load(file)["mri_waveform"])

        with open(os.path.join(patient_dir, "splits.pickle"), 'rb') as file:
            self.splits = pickle.load(file)

    @staticmethod
    def _window_length(tr, surrogate_freq=50):
        """Number of surrogate samples per MR repetition time: floor(surrogate_freq * tr).

        The small tolerance guards against float rounding. The former expression
        ``surrogate_freq // (1 / tr)`` returns 14.0 for tr=0.3 instead of 15.
        """
        return int(math.floor(surrogate_freq * tr + 1e-6))

    def set_normalizers(self, heat, us, coil):
        """Record normalizer pairs. Does NOT re-normalize signals that are already loaded."""
        self.heat_normalizer = heat
        self.us_normalizer = us
        self.coil_normalizer = coil

    def visualize(self):
        """Play the MR frames in an OpenCV window (~30 ms per frame)."""
        for img in self.mr:
            cv2.imshow("Frame", (img.numpy() + 1)/2)
            cv2.waitKey(30)

    def __getitem__(self, idx):
        """Return frame ``idx`` and the ``signals_between_mrs`` surrogate samples ending at it.

        Raises:
            IndexError: if the surrogate window would start before the first sample
                or run past the end of the signals. Plain slicing would silently wrap
                or truncate in that case.
        """
        mr2us = self.mr2us[idx]
        start = mr2us - self.signals_between_mrs + 1
        stop = mr2us + 1
        available = min(len(self.heat), len(self.coil), len(self.us_wave))
        if start < 0 or stop > available:
            raise IndexError(
                f"Surrogate window [{start}, {stop}) for frame {idx} is outside [0, {available})"
            )

        return {
            "mr": self.mr[idx],
            "heat": self.heat[start:stop],
            "mr_wave": self.mr_wave[idx],
            "us_wave": self.us_wave[start:stop],
            "coil": self.coil[start:stop],
        }

    def __len__(self):
        return self.mr.shape[0]

<h1> Training </h1>

In [13]:
def get_mean_std(train_dataset):
    """Return ``(mean, std)`` pairs for heat, coil and ultrasound-wave signals.

    The whole dataset is collated into a single batch, so every sample must fit
    in memory at once. ``std`` is the default (unbiased) torch estimate.

    Returns:
        ``(heat, coil, us)``: each a ``(mean, std)`` tuple of 0-d tensors.
    """
    full_batch = next(iter(DataLoader(train_dataset, batch_size=len(train_dataset))))

    def _moments(key):
        values = full_batch[key]
        return values.mean(), values.std()

    return _moments("heat"), _moments("coil"), _moments("us_wave")

In [14]:
def eval(cProGAN, val_patterns, val_loaders, step, alpha):
    """Evaluate every validation pattern and build side-by-side W&B videos.

    Returns:
        ``(videos, ssim_per_pattern, nmse_per_pattern)``. Both dicts also hold an
        ``"All"`` entry: the mean over all samples, weighted by each pattern's sample count.
        Each video shows generated (left) next to real (right) frames, as 3-channel uint8.
    """
    videos = []
    ssim_per_pattern, nmse_per_pattern = {}, {}
    total_count = 0
    ssim_sum, nmse_sum = 0, 0
    for pattern_name, loader in zip(val_patterns, val_loaders):
        fakes, reals, nmse_values, ssim_values = cProGAN.evaluate(loader, step, alpha)

        ssim_per_pattern[pattern_name] = ssim_values.mean()
        nmse_per_pattern[pattern_name] = nmse_values.mean()
        total_count += len(ssim_values)
        ssim_sum += ssim_values.mean() * len(ssim_values)
        nmse_sum += nmse_values.mean() * len(nmse_values)

        # Concatenate along width, map [-1, 1] -> [0, 255] and replicate to RGB.
        side_by_side = np.concatenate([fakes, reals[:, None, :, :]], axis=3)
        frames = np.uint8((side_by_side + 1) / 2 * 255).repeat(3, axis=1)
        videos.append(wandb.Video(frames, fps=5, caption=pattern_name))

    ssim_per_pattern["All"] = ssim_sum / total_count
    nmse_per_pattern["All"] = nmse_sum / total_count

    return videos, ssim_per_pattern, nmse_per_pattern


In [15]:
def train(subject,
          data_root=os.path.join("F:", os.sep, "Formatted_datasets"),
          checkpoint_dir="C:\\dev\\depth-tests\\GAN\\best_models",
          project=None):
    """Train a ConditionalProGAN for one subject and save the best final-step generator.

    Opens its own W&B run. Previously the function read a global ``run`` that no
    cell defined, so saving a checkpoint raised ``NameError``. Each epoch logs the
    losses, validation SSIM/NMSE and comparison videos. ``G.state_dict()`` is saved
    to ``<checkpoint_dir>/<run name>.pth`` whenever the pooled validation SSIM improves
    at the last step with the fade-in complete (alpha == 1).

    Args:
        subject: patient folder name under ``data_root``.
        data_root: directory with one folder per patient.
        checkpoint_dir: output directory for the best generator (created if missing).
        project: optional W&B project name passed to ``wandb.init``.

    Returns:
        ``(top_ssim, best_epoch)``. ``best_epoch`` is the 0-based global epoch of the
        best checkpoint; the result is ``(0, 0)`` if no checkpoint was saved.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    config = dict(
        batch_size=6,
        n_epochs=60,
        desired_resolution=128,
        G_learning_rate=0.001,
        D_learning_rate=0.001,
        GP_lambda=10,
        n_critic=1,
        patient=subject,
        surrogates="All",
        D_layers=[8, 16, 32, 64, 128, 256],
        G_layers=[256, 128, 64, 32, 16, 8]
    )

    # First pass: normalization statistics computed on the training split only.
    dataset = CustomDataset(data_root, config["patient"])
    splitter = DatasetSplitter(dataset, .8, .1, .1)
    heat_normalizer, coil_normalizer, us_normalizer = get_mean_std(splitter.get_train_dataset())

    # Second pass: reload with the signals normalized by those statistics.
    dataset = CustomDataset(data_root, config["patient"], coil_normalizer, heat_normalizer, us_normalizer)
    splitter = DatasetSplitter(dataset, .8, .1, .1)
    train_dataset = splitter.get_train_dataset()

    val_patterns = ["Regular Breathing", "Shallow Breathing", "Deep Breathing", "Deep BH", "Half Exhale BH", "Full Exhale BH"]
    val_loaders = [
        DataLoader(splitter.val_subsets[pattern], batch_size=10, shuffle=False, pin_memory=True)
        for pattern in val_patterns
    ]

    cProGAN = ConditionalProGAN(
        device=device,
        desired_resolution=config["desired_resolution"],
        G_lr=config["G_learning_rate"],
        D_lr=config["D_learning_rate"],
        n_critic=config["n_critic"],
        n_epochs=config["n_epochs"],
        D_layers=config["D_layers"],
        G_layers=config["G_layers"],
    )

    # Epochs and batch size per progression step; steps with 0 epochs are skipped.
    prog_epochs = [0, 0, 0, 10, 20, 30]
    batch_sizes = [0, 0, 0, 8, 8, 4]
    # The LR schedulers decay over config["n_epochs"], so the schedule must match it.
    if len(prog_epochs) != cProGAN.total_steps or sum(prog_epochs) != config["n_epochs"]:
        raise ValueError("prog_epochs needs one entry per step and must sum to config['n_epochs']")

    os.makedirs(checkpoint_dir, exist_ok=True)
    run = wandb.init(project=project, config=config)
    top_ssim = 0
    best_epoch = 0
    try:
        for step, n_epochs in enumerate(prog_epochs):
            if n_epochs == 0:
                continue
            alpha = 0
            train_dataloader = DataLoader(train_dataset, batch_size=batch_sizes[step], shuffle=True, pin_memory=True)

            for i in range(n_epochs):
                epoch = sum(prog_epochs[:step]) + i
                start_time = time.time()
                D_loss, G_loss, alpha = cProGAN.train_single_epoch(train_dataloader, epoch, config["GP_lambda"], step, alpha, n_epochs, len(train_dataset))
                epoch_time = time.time() - start_time
                vids, val_ssim, val_nmse = eval(cProGAN, val_patterns, val_loaders, step, alpha)

                metrics = {
                    "epoch": epoch,
                    "progression_step": step,
                    "alpha": alpha,
                    "D_loss": D_loss,
                    "G_loss": G_loss,
                    "epoch_time": epoch_time,
                    "videos": vids,
                }
                metrics.update({f"SSIM/{name}": value for name, value in val_ssim.items()})
                metrics.update({f"NMSE/{name}": value for name, value in val_nmse.items()})
                run.log(metrics)

                if step == cProGAN.total_steps - 1 and alpha == 1 and val_ssim["All"] > top_ssim:
                    top_ssim = val_ssim["All"]
                    best_epoch = epoch
                    torch.save(cProGAN.G.state_dict(), os.path.join(checkpoint_dir, f"{run.name}.pth"))
    finally:
        run.finish()

    return top_ssim, best_epoch

In [16]:

# Train one model per subject, sequentially.
subjects = ["D1", "D2", "D3", "E1", "E2", "E3", "F1", "F3", "F4", "G2", "G3", "G4"]
for subject in subjects:
    train(subject)

KeyboardInterrupt: 