In [None]:
import sys, importlib
# "../" goes back one directory so the project packages can be imported
sys.path.append('../')
import os
import time
import numpy as np
from Modules.Utils.Imports import *
from Modules.Utils.Gradient import Gradient
from Modules.Utils.ModelWrapper import ModelWrapper
from torch.autograd import Variable
import Scripts.FKPP_functions as FKPP

In [None]:
# Set CUDA: build the torch device from whatever GetLowestGPU returns for the
# candidate GPU list [0]. GetLowestGPU is not defined in this file
# (presumably it comes from the wildcard import of Modules.Utils.Imports and
# picks the least-used GPU -- TODO confirm).
device = torch.device(GetLowestGPU(pick_from=[0]))
# Reload the FKPP module so edits made to it since the first import take effect.
importlib.reload(FKPP)

# 
def numpy_to_tensor(ndarray):
    """Convert array-like data into a float32 tensor that tracks gradients.

    The tensor is built first, flagged with ``requires_grad`` and only then
    moved to the module-level ``device``; the moved tensor is returned.
    """
    tracked = torch.tensor(ndarray, dtype=torch.float).requires_grad_(True)
    return tracked.to(device)

def to_torch(x):
    """Wrap a NumPy array as a float32 tensor on the module-level ``device``."""
    as_float = torch.from_numpy(x).float()
    return as_float.to(device)
def to_numpy(x):
    """Return the values of tensor ``x`` as a NumPy array.

    The tensor is detached from the autograd graph and copied to the CPU
    before the conversion.
    """
    host_copy = x.detach().cpu()
    return host_copy.numpy()

In [None]:
# biologically informed neural network
class BINN(nn.Module):
    """Biologically informed neural network for the Fisher-KPP equation.

    Combines a surface-fitting network u(x, t) with diffusion and growth
    terms from the ``FKPP`` module, and provides the data, PDE-residual and
    parameter-range losses used for training.

    NOTE: the class reads the module-level globals ``x_min``, ``x_max``,
    ``t_min``, ``t_max``, ``gamma`` and ``device``. They must be defined
    before an instance is created and before any loss is evaluated.
    """

    def __init__(self):
        """Create the sub-modules, loss weights and the evaluation mesh."""

        super().__init__()

        # surface fitter
        self.surface_fitter = FKPP.SurfaceFitter(K=1.0)

        # equation variables
        self.diffusion = FKPP.ScalarDiffusion()
        self.growth = FKPP.LogisticGrowth(K=1.0)

        # parameter extrema, as exposed by the diffusion/growth terms
        self.D_min = self.diffusion.min
        self.D_max = self.diffusion.max
        self.r_min = self.growth.min
        self.r_max = self.growth.max

        # loss weights
        self.IC_weight = 1e0
        self.surface_weight = 1e0
        self.pde_weight = 1e0
        self.pde_IC_weight = 0.1
        self.D_weight = 1e4
        self.r_weight = 1e4

        # number of random collocation points drawn per loss evaluation
        self.num_samples = 1000

        # input meshgrid: 100 x 100 (x, t) points flattened to shape (10000, 2)
        x = np.linspace(x_min, x_max, 100)
        t = np.linspace(t_min, t_max, 100)
        X, T = np.meshgrid(x, t, indexing='ij')
        self.inputs_mesh = np.concatenate([X.reshape(-1, 1),
                                           T.reshape(-1, 1)], axis=1)
        self.inputs_mesh = numpy_to_tensor(self.inputs_mesh)

    def forward(self, inputs):
        """Evaluate the surface fitter and cache ``inputs`` for the losses.

        ``surface_loss`` later reads the cached batch from ``self.inputs``.
        """

        # cache input batch on forward pass
        self.inputs = inputs

        return self.surface_fitter(self.inputs)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    def _sample_domain(self):
        """Draw ``num_samples`` uniform random (x, t) points on the domain.

        Returns a float tensor of shape (num_samples, 2) on ``device``.
        """
        x = torch.rand(self.num_samples, 1, requires_grad=True)
        x = x*(x_max - x_min) + x_min
        t = torch.rand(self.num_samples, 1, requires_grad=True)
        t = t*(t_max - t_min) + t_min
        return torch.cat([x, t], dim=1).float().to(device)

    def _sample_initial_condition(self):
        """Draw ``num_samples`` random x locations, all at time ``t_min``.

        Returns a float tensor of shape (num_samples, 2) on ``device``.
        """
        x_IC = torch.rand(self.num_samples, 1, requires_grad=True)
        x_IC = x_IC*(x_max - x_min) + x_min
        t_IC = t_min*torch.ones(self.num_samples, 1)
        return torch.cat([x_IC, t_IC], dim=1).float().to(device)

    def _equation_terms(self, u, t):
        """Evaluate the diffusion and growth terms.

        Each term is called with ``u`` alone when its ``inputs`` attribute
        equals 1, otherwise with ``(u, t)``. Returns the pair ``(D, G)``.
        """
        if self.diffusion.inputs == 1:
            D = self.diffusion(u)
        else:
            D = self.diffusion(u, t)

        if self.growth.inputs == 1:
            G = self.growth(u)
        else:
            G = self.growth(u, t)

        return D, G

    @staticmethod
    def _bound_penalty(value, lower, upper, weight):
        """Weighted squared distance of ``value`` outside [lower, upper].

        Entries inside the interval contribute zero.
        """
        zeros = torch.zeros_like(value)
        below = torch.where(value < lower, (value - lower)**2, zeros)
        above = torch.where(value > upper, (value - upper)**2, zeros)
        return weight*below + weight*above

    # ------------------------------------------------------------------
    # loss terms
    # ------------------------------------------------------------------

    def surface_loss(self, pred, true):
        """Mean weighted squared error between ``pred`` and ``true``.

        Points of the cached batch whose time column equals 0 are scaled by
        ``IC_weight``; every residual is additionally scaled by
        ``|pred|**(-gamma)`` with ``|pred|`` clamped below at 1e-4.
        Requires ``forward`` to have been called so ``self.inputs`` exists.
        """

        residual = (pred - true)**2

        # add weight to initial condition
        residual = residual*torch.where(self.inputs[:, 1][:, None]==0,
                                self.IC_weight*torch.ones_like(pred),
                                torch.ones_like(pred))

        residual = residual*pred.abs().clamp(min=1e-4)**(-gamma)

        return torch.mean(residual)

    def pde_loss(self, inputs, outputs, return_mean=True):
        """Squared residual of the Fisher-KPP equation at each input point.

        Computes ``(u_t - (d/dx(D*u_x) + G*u))**2`` where the derivatives are
        taken with ``Gradient`` with respect to ``inputs`` (column 0 is x,
        column 1 is t).

        ``return_mean`` is accepted for interface compatibility but is not
        used: the per-point squared residual is always returned.
        """

        # unpack the time column (only needed by time-dependent terms)
        t = inputs[:, 1][:, None]

        # partial derivative computations
        u = outputs.clone()
        d1 = Gradient(u, inputs, order=1)
        ux = d1[:, 0][:, None]
        ut = d1[:, 1][:, None]

        # diffusion and growth
        D, G = self._equation_terms(u, t)

        # Fisher-KPP equation
        LHS = ut
        RHS = Gradient(D*ux, inputs)[:, 0][:, None] + G*u

        return (LHS - RHS)**2

    def parameter_loss(self, inputs, outputs):
        """Penalty for diffusion/growth values outside their allowed ranges.

        The growth value checked is ``self.growth.r`` when that attribute
        exists, otherwise the evaluated growth term. The per-entry penalties
        are stored in ``self.D_loss`` and ``self.r_loss``; their total sum is
        returned.
        """

        # unpack the time column (only needed by time-dependent terms)
        t = inputs[:, 1][:, None]
        u = outputs

        # diffusion and growth
        D, G = self._equation_terms(u, t)

        # constraints on learned parameters: fall back to the evaluated
        # growth term only when the growth module exposes no `r` attribute
        try:
            r = self.growth.r
        except AttributeError:
            r = G

        # s_loss is never updated here; kept so the attribute still exists
        self.s_loss = 0
        self.D_loss = self._bound_penalty(D, self.D_min, self.D_max,
                                          self.D_weight)
        self.r_loss = self._bound_penalty(r, self.r_min, self.r_max,
                                          self.r_weight)

        return torch.sum(self.D_loss + self.r_loss)

    def parameter_loss_PCGrad(self, pred, true):
        """Parameter-range loss at random domain points.

        ``pred`` and ``true`` are unused; they keep the signature uniform
        with the other loss callables. Note the result is scaled by
        ``pde_weight``.
        """

        # randomly sample from input domain
        inputs_rand = self._sample_domain()

        # predict surface fitter at sampled points
        outputs_rand = self.surface_fitter(inputs_rand)

        # compute parameter loss at sampled locations
        self.parameter_loss_val = self.pde_weight*self.parameter_loss(inputs_rand, outputs_rand)

        return torch.sum(self.parameter_loss_val)

    def pde_loss_PCGrad(self, pred, true):
        """Weighted PDE residual loss at random domain points.

        ``pred`` and ``true`` are unused; they keep the signature uniform
        with the other loss callables.
        """

        # randomly sample from input domain
        inputs_rand = self._sample_domain()

        # predict surface fitter at sampled points
        outputs_rand = self.surface_fitter(inputs_rand)

        # compute PDE loss at sampled locations
        self.pde_loss_val = self.pde_weight*self.pde_loss(inputs_rand, outputs_rand)

        return torch.sum(self.pde_loss_val)

    def pde_IC_loss_PCGrad(self, pred, true):
        """Weighted PDE residual loss at random points on the initial time.

        ``pred`` and ``true`` are unused; they keep the signature uniform
        with the other loss callables.
        """

        # kept for backward compatibility: this attribute was always reset
        # here, while the computed value is stored in `pde_IC_loss_val`
        self.IC_loss_val = 0

        # randomly sample from input domain for x at initial condition
        inputs_rand_IC = self._sample_initial_condition()

        # compute PDE loss at IC data
        outputs_rand_IC = self.surface_fitter(inputs_rand_IC)
        self.pde_IC_loss_val = self.pde_IC_weight*self.pde_loss(inputs_rand_IC, outputs_rand_IC)

        return torch.sum(self.pde_IC_loss_val)

    def loss(self, pred, true):
        """Total training loss: data misfit plus PDE residuals.

        Adds the weighted surface loss on ``(pred, true)`` to the summed PDE
        residuals at random interior points (``pde_weight``) and at random
        initial-time points (``pde_IC_weight``). The two parts are stored in
        ``surface_loss_val`` and ``pde_loss_val``.
        """

        # randomly sample from input domain
        inputs_rand = self._sample_domain()

        # predict surface fitter at sampled points
        outputs_rand = self.surface_fitter(inputs_rand)

        # compute surface loss
        self.surface_loss_val = self.surface_weight*self.surface_loss(pred, true)

        # randomly sample from input domain for x at initial condition
        inputs_rand_IC = self._sample_initial_condition()

        # compute PDE loss at IC data
        outputs_rand_IC = self.surface_fitter(inputs_rand_IC)
        self.pde_loss_val = self.pde_IC_weight*self.pde_loss(inputs_rand_IC, outputs_rand_IC)

        # compute PDE loss at sampled locations
        self.pde_loss_val = self.pde_loss_val + self.pde_weight*self.pde_loss(inputs_rand, outputs_rand)

        return torch.sum(self.surface_loss_val) + torch.sum(self.pde_loss_val)

In [None]:
# ---- run configuration -------------------------------------------------
cur_dir = os.getcwd()
path = f'{cur_dir}/Data'          # directory holding the patient .npy files

# training settings handed to ModelWrapper.fit
epochs = 1_000_000
batch_size = 10
rel_save_thresh = 0.01
early_stopping = 2000
use_PCGrad = False

# 1-based, inclusive ranges of patients and runs to process
start_patient = 5
N_patients = 5
star_run = 8  # first run to execute (original note: R14 + 1)
num_run = 14

# ---- data file names, indexed by 0-based patient index -----------------
# Entries before start_patient stay as empty strings.
file_names = [''] * N_patients
for i_patient in range(start_patient - 1, N_patients):
    # patient numbers are zero-padded to two digits: patient05, patient10, ...
    file_names[i_patient] = 'patient{:02d}_N10.npy'.format(i_patient + 1)

# Main experiment loop: for every run and every selected patient, load the
# data, train a fresh BINN, save diagnostic plots and estimate parameters.
for i_run in range(star_run-1,num_run):
    for i_patient in range(start_patient-1,N_patients):

        # load data
        # NOTE: allow_pickle=True unpickles the file contents; only load
        # data files from a trusted source.
        file_name = file_names[i_patient]
        print(file_name)
        data = np.load(path +'/'+ file_name, allow_pickle=True).item()

        x = data['x'].copy()
        t = data['t'].copy()
        U = data['U'].copy().T
        # gamma is read as a module-level global by BINN.surface_loss
        gamma = data['gamma']
        shape = U.shape
        D = data['D']
        r = data['r']
        K = 1
        metast_index = data['metast_index']

        # compute extrema (x_min/x_max/t_min/t_max are read as globals by BINN)
        x_min, x_max = np.min(x), np.max(x)
        t_min, t_max = np.min(t), np.max(t)
        u_min, u_max = np.min(U), np.max(U)

        # convert to 2D
        X, T = np.meshgrid(x, t, indexing='ij')

        # prepare for surface fit: one (x, t) row per observation
        inputs = np.concatenate([X.reshape(-1)[:, None],
                                 T.reshape(-1)[:, None]], axis=1)
        outputs = U.reshape(-1)[:, None]


        # split into train/val (80% / 20%)
        N = len(outputs)
        split = int(0.8*N)

        # Shuffle the list 1 to N
        p = np.random.permutation(N)
        x_train = inputs[p[:split]]
        y_train = outputs[p[:split]]
        x_val = inputs[p[split:]]
        y_val = outputs[p[split:]]

        # convert to pytorch
        x_train = numpy_to_tensor(x_train)
        y_train = numpy_to_tensor(y_train)
        x_val = numpy_to_tensor(x_val)
        y_val = numpy_to_tensor(y_val)
        inputs = numpy_to_tensor(inputs)
        outputs = numpy_to_tensor(outputs)

        # Initialize BINN
        binn = BINN()
        binn.to(device)

        # exist_ok avoids the check-then-create race of os.path.exists
        weights_dir = cur_dir+'/Weights'
        os.makedirs(weights_dir, exist_ok=True)
        
        # compile 
        parameters = binn.parameters()
        opt = torch.optim.Adam(parameters, lr=1e-3)
        model = ModelWrapper(
            model=binn,
            optimizer=opt,
            loss=binn.loss,
            regularizer=None,
            save_name=weights_dir+'/'+file_name[:-4]+'_R'+str(i_run))

        t0 = time.time()
        # train jointly
        model.fit(
            x=x_train,
            y=y_train,
            batch_size=batch_size,
            epochs=epochs,
            callbacks=None,
            verbose=1,
            validation_data=[x_val, y_val],
            early_stopping=early_stopping,
            rel_save_thresh=rel_save_thresh)
        training_time = time.time() - t0
        #################################################################
        #
        # Plotting
        #
        #################################################################
        plotting_path = cur_dir+'/Plots/'+file_name[:-4]+'/R'+str(i_run)
        os.makedirs(plotting_path, exist_ok=True)

        u_pred = model.predict(inputs.to(device)).cpu().detach().numpy().reshape(-1)

        # evaluate surface fitter on mesh
        x_mesh = np.linspace(x_min, x_max, 100)
        t_mesh = np.linspace(t_min, t_max, 100)
        x_mesh, t_mesh = np.meshgrid(x_mesh, t_mesh, indexing='ij')
        x_mesh, t_mesh = x_mesh.reshape(-1, 1), t_mesh.reshape(-1, 1)
        inputs_mesh = numpy_to_tensor(np.concatenate([x_mesh, t_mesh], 1))
        u_mesh = model.predict(inputs_mesh).cpu().detach().numpy()
        x_mesh = x_mesh.reshape(100,100)
        t_mesh = t_mesh.reshape(100,100)
        u_mesh = u_mesh.reshape(100,100)

        # plot
        fig = plt.figure(figsize=(12,7))
        ax = fig.add_subplot(1, 1, 1, projection='3d')
        ax.plot_surface(x_mesh, t_mesh, u_mesh, alpha=0.8, cmap=cm.coolwarm)
        ax.scatter(X.reshape(-1), T.reshape(-1), U.reshape(-1), color='k', s=5)
        ax.set_xlabel('X')
        ax.set_ylabel('T')
        ax.set_zlabel('U')
        plt.savefig(plotting_path+'/learned_surface.png')
        plt.close(fig)

        # plot in time
        prop_cycle = plt.rcParams['axes.prop_cycle']
        colors = prop_cycle.by_key()['color']
        markers = ['x', 'o', 's', 'd', '^', '1',  'P', 'd', '+', '|']
        n_times = len(X.T)
        # Bind the new figure to `fig` so the plt.close(fig) below closes
        # THIS figure; previously it closed the already-closed surface figure
        # and leaked one open figure per iteration.
        fig = plt.figure(figsize=(14.6,7))
        for i in range(n_times):
            # map the i-th observed time onto the 100-column time mesh
            # (guarded so a single time point does not divide by zero)
            mesh_col = int(i/(n_times-1)*99) if n_times > 1 else 0
            plt.plot(x_mesh[:,i], u_mesh[:,mesh_col], '-',
                     c=colors[i % len(colors)])
        for i in range(n_times):
            plt.plot(X[:,i], U[:,i], marker=markers[i % len(markers)],
                     c=colors[i % len(colors)], linestyle='')
        plt.xlabel('X')
        plt.ylabel('U')
        # one label per time point (t0..t9 for ten time points)
        plt.legend(['t'+str(k) for k in range(n_times)])
        plt.grid()
        plt.savefig(plotting_path+'/UvsX.png')
        plt.close(fig)

        model.load_best_val()
        # evaluate surface fitter on mesh
        x_mesh = np.linspace(x_min, x_max, 100)
        t_mesh = np.linspace(t_min, t_max, 100)
        x_mesh, t_mesh = np.meshgrid(x_mesh, t_mesh, indexing='ij')
        x_mesh, t_mesh = x_mesh.reshape(-1, 1), t_mesh.reshape(-1, 1)
        inputs_mesh = numpy_to_tensor(np.concatenate([x_mesh, t_mesh], 1))
        u_mesh = model.predict(inputs_mesh)

        # compute derivatives
        d1 = Gradient(u_mesh, inputs_mesh, order=1)
        d1 = [d1[:, i] for i in range(d1.shape[1])]
        d2 = [Gradient(d, inputs_mesh, order=1) for d in d1]
        # flattened order: [xx, xt, tx, tt]
        d2 = [d2[i][:, j] for i in range(len(d1)) for j in range(len(d1))]

        # extract
        u0 = to_numpy(u_mesh).reshape([100, 100])
        ux = to_numpy(d1[0]).reshape([100, 100])
        ut = to_numpy(d1[1]).reshape([100, 100])
        uxx = to_numpy(d2[0]).reshape([100, 100])

        # 3d surface plots
        fig = plt.figure(figsize=(10,7))
        ax = fig.add_subplot(2, 2, 1, projection='3d')
        ax.plot_surface(to_numpy(inputs_mesh[:, 0]).reshape([100, 100]), 
                        to_numpy(inputs_mesh[:, 1]).reshape([100, 100]), 
                        u0, cmap=cm.coolwarm, alpha=0.9)
        plt.title('u')
        ax = fig.add_subplot(2, 2, 2, projection='3d')
        ax.plot_surface(to_numpy(inputs_mesh[:, 0]).reshape([100, 100]), 
                        to_numpy(inputs_mesh[:, 1]).reshape([100, 100]), 
                        ut, cmap=cm.coolwarm, alpha=0.9)
        plt.title('ut')
        ax = fig.add_subplot(2, 2, 3, projection='3d')
        ax.plot_surface(to_numpy(inputs_mesh[:, 0]).reshape([100, 100]), 
                        to_numpy(inputs_mesh[:, 1]).reshape([100, 100]), 
                        ux, cmap=cm.coolwarm, alpha=0.9)
        plt.title('ux')
        ax = fig.add_subplot(2, 2, 4, projection='3d')
        ax.plot_surface(to_numpy(inputs_mesh[:, 0]).reshape([100, 100]), 
                        to_numpy(inputs_mesh[:, 1]).reshape([100, 100]), 
                        uxx, cmap=cm.coolwarm, alpha=0.9)
        plt.title('uxx')
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        plt.savefig(plotting_path+'/U_Ut_Ux_Uxx.png')
        plt.close(fig)


        ############################# Train/Val Error/Improvement #############################

        # load training errors
        total_train_losses = model.train_loss_list
        total_val_losses = model.val_loss_list


        # find where errors decreased by more than rel_save_thresh (relative)
        # NOTE(review): the "-1" excludes the last recorded epoch -- confirm
        # this is intended.
        train_idx, train_loss, val_idx, val_loss = [], [], [], []
        best_train, best_val = 1e12, 1e12
        for i in range(len(total_train_losses)-1):
            rel_diff = (best_train - total_train_losses[i])
            rel_diff /= best_train
            if rel_diff > rel_save_thresh:
                best_train = total_train_losses[i]
                train_idx.append(i)
                train_loss.append(best_train)
            rel_diff = (best_val - total_val_losses[i])
            rel_diff /= best_val
            if rel_diff > rel_save_thresh:
                best_val = total_val_losses[i]
                val_idx.append(i)
                val_loss.append(best_val)
        idx = np.argmin(val_loss)

        # plot
        fig = plt.figure(figsize=(15,5))
        ax = fig.add_subplot(1, 2, 1)
        plt.semilogy(total_train_losses, 'b')
        plt.semilogy(total_val_losses, 'r')
        plt.semilogy(val_idx[idx], val_loss[idx], 'ko')
        plt.legend(['train mse', 'val mse', 'best val'])
        plt.xlabel('epochs')
        plt.ylabel('MSE')
        plt.title('Train/Val errors')
        plt.grid()
        ax = fig.add_subplot(1, 2, 2)
        plt.semilogy(train_idx, train_loss, 'b.-')
        plt.semilogy(val_idx, val_loss, 'r.-')
        plt.legend(['train mse', 'val mse'])
        plt.xlabel('epochs')
        plt.ylabel('MSE')
        plt.title('Train/Val improvements')
        plt.grid()
        plt.savefig(plotting_path+'/Train_Val_Err_Improv.png')
        plt.close(fig)    

        ############################# Residuals #############################
        model.load_best_val()

        #
        # Residual plots
        #

        # model prediction
        u_pred = model.predict(inputs.to(device)).cpu().detach().numpy().reshape(-1)
        u_true = U.reshape(-1)
        residuals = u_pred - u_true
        # same |u|**(-gamma) scaling (clamped at 1e-4) as BINN.surface_loss
        modified_residuals = residuals * np.abs(u_pred).clip(1e-4,np.inf)**(-gamma)

        # plot modified residuals
        fig = plt.figure(figsize=(10,7))
        plt.scatter(u_pred, modified_residuals, color='k', s=10)
        plt.plot([np.min(u_pred), np.max(u_pred)], [0, 0], 'k--')
        plt.xlabel('U')
        plt.ylabel('Residuals')
        plt.savefig(plotting_path+'/Residuals.png')
        plt.close(fig)

        # plot heatmap of residuals
        # NOTE(review): unless rcParams changes it, imshow uses
        # origin='upper' (array row 0 at the top) while `extent` labels the
        # top edge as max(x) -- confirm the vertical orientation is intended.
        fig = plt.figure(figsize=(11,7))
        ax = fig.add_subplot(1, 1, 1)
        res = ax.imshow(U-u_pred.reshape(shape), aspect='auto',vmin=-0.5, vmax=0.5,
                        extent=[np.min(t), np.max(t), np.min(x), np.max(x)],)
        ax.set_xlabel('T')
        ax.set_ylabel('X')
        ax.set_title('Surface Residuals')
        fig.colorbar(res)
        plt.savefig(plotting_path+'/Surface_Residuals.png')
        plt.close(fig)

        # evaluate surface fitter on mesh
        x_mesh = np.linspace(np.min(x), np.max(x), 100)
        t_mesh = np.linspace(np.min(t), np.max(t), 100)
        x_mesh, t_mesh = np.meshgrid(x_mesh, t_mesh, indexing='ij')
        x_mesh, t_mesh = x_mesh.reshape(-1, 1), t_mesh.reshape(-1, 1)
        inputs_mesh = numpy_to_tensor(np.concatenate([x_mesh, t_mesh], 1))
        u_mesh = model.predict(inputs_mesh)
        # per-point squared PDE residuals on the mesh
        pde_losses = binn.pde_loss(inputs_mesh, u_mesh, return_mean=False)
        u_mesh = u_mesh.cpu().detach().numpy().reshape(100,100)
        pde_losses = pde_losses.cpu().detach().numpy().reshape(100,100)
        x_mesh = x_mesh.reshape(100,100)
        t_mesh = t_mesh.reshape(100,100)

        # plot heatmap of pde losses (square root = absolute residual)
        fig = plt.figure(figsize=(11,7))
        ax = fig.add_subplot(1, 1, 1)
        res = ax.imshow(np.sqrt(pde_losses), aspect='auto', vmin=0, vmax=0.5,
                        extent=[np.min(t), np.max(t), np.min(x), np.max(x)],)
        ax.set_xlabel('Time (days)')
        ax.set_ylabel('Position (mm)')
        ax.set_title('PDE Losses, max = {0:1.4e}'.format(np.max(np.sqrt(pde_losses))))
        fig.colorbar(res)
        plt.savefig(plotting_path+'/PDE_Losses.png')
        plt.close(fig)

        #################################################################
        #
        # BINN prediction
        #
        #################################################################

        u0 = u0.reshape(-1)
        ux = ux.reshape(-1)
        ut = ut.reshape(-1)
        uxx = uxx.reshape(-1)

        # least-squares fit of ut ~ theta0*uxx + theta1*u + theta2*u**2
        # on the mesh derivatives, via the pseudo-inverse
        A = np.concatenate([uxx.reshape(-1, 1), u0.reshape(-1, 1), u0.reshape(-1, 1)**2], axis=1)
        B = ut.reshape(-1, 1)
        theta = np.dot(np.linalg.pinv(A), B)

        # with logistic growth r*u*(1 - u/K): theta1 = r and theta2 = -r/K
        D_nn = theta[0]
        r_nn = theta[1]
        K_nn = -r_nn*(1/theta[2])
        D_binn = to_numpy(binn.diffusion.D)
        r_binn = to_numpy(binn.growth.r)
        K_binn = binn.growth.K