In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
import os

### data functions

In [None]:
def gen_cube_data(L, N_ax, bc_prop=0.1):
    """
    Build collocation (interior) and boundary point sets from a uniform N_ax^3 grid over [0, L]^3.

    The grid includes the cube faces (np.linspace endpoints) and its bottom corner sits at the
    origin. Interior points are randomly over- or under-sampled (using the global np.random state)
    so that boundary points make up roughly `bc_prop` of all returned points.

    Inputs:
    L: cube side length
    N_ax: number of grid points per axis (>= 2)
    bc_prop: target fraction of boundary points, in (0, 1]

    Returns:
    coll_data: (N_coll, 3) interior points (contains repeated rows when oversampled)
    bc_data: (N_bc, 3) points lying on a cube face
    bc_u: (N_bc, 1) zeros, the boundary values of u
    d: grid spacing, L / (N_ax - 1)

    Raises:
    ValueError: if N_ax < 2, bc_prop is outside (0, 1], or oversampling is requested but the grid
    has no interior points
    """
    if N_ax < 2:
        raise ValueError(f"N_ax must be >= 2, got {N_ax}")
    if not 0 < bc_prop <= 1:
        raise ValueError(f"bc_prop must be in (0, 1], got {bc_prop}")

    # np.linspace includes both endpoints, so consecutive grid points are L / (N_ax - 1) apart
    d = L / (N_ax - 1)

    # generate xyz of cube (bottom corner at origin, extends into positive xyz; includes boundaries)
    axis_pts = np.linspace(0, L, N_ax)
    cube = np.stack(np.meshgrid(axis_pts, axis_pts, axis_pts, indexing='ij'), axis=-1).reshape(-1, 3)

    # a point is on the boundary if any coordinate sits exactly on a face (linspace hits 0 and L exactly)
    on_boundary = np.any((cube == 0) | (cube == L), axis=1)
    coll_data = cube[~on_boundary]
    bc_data = cube[on_boundary]

    # over/under sample coll data to fit desired bc_prop
    n_bc = bc_data.shape[0]
    n_coll = coll_data.shape[0]
    coll_target_N = int(n_bc / bc_prop) - n_bc
    if n_coll < coll_target_N:
        if n_coll == 0:
            raise ValueError(f"grid with N_ax={N_ax} has no interior points to oversample")
        # oversample: draw extra interior points with replacement
        np.random.shuffle(coll_data)
        extra_idx = np.random.choice(n_coll, size=coll_target_N - n_coll)
        coll_data = np.concatenate([coll_data, coll_data[extra_idx]])
    elif n_coll > coll_target_N:
        # undersample: keep a random subset
        np.random.shuffle(coll_data)
        coll_data = coll_data[:coll_target_N]

    # bc solution values (u vanishes on the walls)
    bc_u = np.zeros((n_bc, 1))

    print(f"initial full cube shape: {cube.shape}")
    print(f"step size: {d}")
    print(f"collocation data shape: {coll_data.shape}")
    print(f"boundary condition data shape: {bc_data.shape}")
    print(f"total N after over/under sample: {coll_data.shape[0] + bc_data.shape[0]}")

    return coll_data, bc_data, bc_u, d

def perturb(vals, minimum, maximum, perturb_delta):
    """
    Add Gaussian noise to `vals` and fold any result back into [minimum, maximum].

    Values that cross a bound are mirrored about it (v -> 2*bound - v). If the noise is larger
    than the interval width, a reflected value can still overshoot the opposite bound, so the
    result is finally clamped into [minimum, maximum].

    Inputs:
    vals: tensor of values
    minimum, maximum: bounds of the allowed domain
    perturb_delta: standard deviation of the added noise

    Returns a new tensor with the same shape as vals.
    """
    noisy = vals + torch.randn_like(vals) * perturb_delta

    # mirror about the crossed bound (the lower bound needs 2*min - v, not min - v, when min != 0)
    reflected = torch.where(noisy < minimum, 2 * minimum - noisy, noisy)
    reflected = torch.where(reflected > maximum, 2 * maximum - reflected, reflected)

    return torch.clamp(reflected, minimum, maximum)

def perturb_data(data, perturb_delta=0.01):
    """
    Perturb every column of a 2-D tensor independently with perturb().

    Each column gets noise whose std is perturb_delta times that column's (max - min) range, and
    that column's original min/max are passed to perturb() as its bounds.

    Inputs:
    data: 2-D tensor, one coordinate per column
    perturb_delta: noise std as a fraction of each column's range

    Returns a new tensor with the same shape, dtype and device as data.
    """
    column_bounds = []
    for col in range(data.shape[1]):
        column = data[:, col]
        col_min = torch.min(column)
        col_max = torch.max(column)
        column_bounds.append((col_min.item(), col_max.item(), (col_max - col_min).item() * perturb_delta))

    perturbed = torch.ones_like(data)
    for col, (lo, hi, sigma) in enumerate(column_bounds):
        perturbed[:, col] = perturb(data[:, col], lo, hi, perturb_delta=sigma)

    return perturbed

def visualize(xyz_data):
    """
    Scatter-plot points in 3D, all drawn with the same colour.

    Inputs:
    xyz_data: 2-D array whose first three columns are x, y, z
    """
    n_points = xyz_data.shape[0]
    fig = plt.figure(figsize=(8, 6), constrained_layout=True)
    ax = fig.add_subplot(projection='3d')
    points = ax.scatter(xyz_data[:, 0], xyz_data[:, 1], xyz_data[:, 2], c=[1.] * n_points, cmap="Reds")

    # a colorbar is added and then immediately removed (presumably to keep the layout close to
    # make_3D_plot's figures — unverified)
    fig.colorbar(points, shrink=0.55, aspect=6, pad=0.1).remove()

    for set_axis_label, text in ((ax.set_xlabel, "x"), (ax.set_ylabel, "y"), (ax.set_zlabel, "z")):
        set_axis_label(text, weight="bold")
    ax.zaxis.labelpad = 5
    plt.show()

### generate data

In [None]:
# set gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# fix the RNG state: gen_cube_data over/under-samples with np.random, and torch later draws the
# network initialisation and the collocation perturbation noise
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# generate grid of data (evaluation grid)
L = 1
N_ax = 25
coll_data, bc_data, bc_u, d = gen_cube_data(L, N_ax)
visualize(coll_data)

# separate (coarser) train data
train_N_ax = 10
train_coll_data, train_bc_data, train_bc_u, train_d = gen_cube_data(L, train_N_ax)
visualize(train_coll_data)
visualize(train_bc_data)

# cast arrays to tensors. Creating them directly on `device` (rather than calling .to(device)
# afterwards) keeps them leaf tensors. The collocation coordinates need requires_grad=True because
# PINN.forward differentiates u with respect to them.
coll_data = torch.tensor(coll_data, device=device, requires_grad=True)
bc_data = torch.tensor(bc_data, device=device, requires_grad=True)
bc_u = torch.tensor(bc_u, device=device, requires_grad=True)

train_coll_data = torch.tensor(train_coll_data, device=device, requires_grad=True)
train_bc_data = torch.tensor(train_bc_data, device=device, requires_grad=True)
train_bc_u = torch.tensor(train_bc_u, device=device, requires_grad=True)

### loss class definition

In [None]:
class PINNLoss(nn.Module):
    """
    Loss terms for the PINN: differential-equation residual MSE, boundary-condition MSE and a
    penalty that grows as the collocation predictions approach zero (trivial solution).
    """

    def __init__(self):
        super().__init__()
        self.loss = None

        # NOTE if you change loss terms returned, must change this and return+calculations in calc_loss()
        self.return_names = ["loss_ovr", "loss_ovr_no_reg", "loss_f", "loss_bc", "u_trivial_penalty"]

    def MSE_f(self, f_pred):
        """
        Mean squared differential-equation residual.

        Inputs:
        f_pred: batch of differential equation predictions
        """
        squared_residual = f_pred**2
        return squared_residual.mean()

    def MSE_bc(self, u_target, u_pred):
        """
        Mean squared error between boundary targets and boundary predictions.

        Inputs:
        u_target: actual value of u
        u_pred: batch of predicted u values at b.c. points
        """
        error = u_target - u_pred
        return (error**2).mean()

    def calc_u_trivial_penalty(self, u_pred):
        """
        Inverse mean square of the predictions (plus 1e-6 for stability); large when u ~ 0.

        Inputs:
        u_pred: batch of wavefunction predictions
        """
        mean_square = (u_pred**2 + 1e-6).mean()
        return 1 / mean_square

    def calc_loss(self, u_coll_pred, f_pred, u_bc_target, u_bc_pred):
        """
        Compute all loss terms. The returned tuple is ordered like self.return_names:
        (loss_ovr, loss_ovr_no_reg, loss_f, loss_bc, u_trivial_penalty).

        Inputs:
        u_coll_pred: batch of wavefunction predictions
        f_pred: batch of differential equation predictions (only generated from collocation data)
        u_bc_target: actual u values for b.c. data
        u_bc_pred: batch of b.c. wavefunction predictions
        """
        loss_f = self.MSE_f(f_pred)
        u_trivial_penalty = self.calc_u_trivial_penalty(u_coll_pred)
        loss_bc = self.MSE_bc(u_bc_target, u_bc_pred)

        # physics + boundary terms only, then add the regulariser on top
        loss_ovr_no_reg = loss_f + loss_bc
        loss_ovr = loss_ovr_no_reg + u_trivial_penalty

        return loss_ovr, loss_ovr_no_reg, loss_f, loss_bc, u_trivial_penalty

### PINN class definition

In [None]:
class SinActivation(torch.nn.Module):
    """Element-wise sine activation, used between the PINN's linear layers."""

    @staticmethod
    def forward(input):
        """
        Return sin(input), element-wise.

        Inputs:
        input: tensor batch of layer outputs
        """
        return input.sin()

class PINN(nn.Module):
    """
    Fully-connected network with sin activations mapping (x, y, z, E) -> u.

    forward() also evaluates the differential-equation residual
        f = planck_term_constant * (u_xx + u_yy + u_zz) + E * u
    E is not a trainable parameter: it is appended as a constant column to every input row, so it
    must be set (constructor or `self.E = ...`) before a forward pass. Each instance owns its Adam
    optimizer and its PINNLoss criterion.
    """

    def __init__(self, input_size, device, E=None, num_hidden_layers=4, optimizer_lr=0.0001,
                 optimizer_betas=(0.999, 0.9999), hidden_size=64):
        """
        Inputs:
        input_size: number of coordinate inputs; the first layer gets one extra input for E
        device: torch device (stored on the instance)
        E: energy value appended to each input row (may be assigned later)
        num_hidden_layers: number of hidden (hidden_size x hidden_size) layers
        optimizer_lr, optimizer_betas: Adam hyperparameters
        hidden_size: width of every hidden layer (default 64)
        """
        super(PINN, self).__init__()

        # constants: hbar**2 / (2m) with hbar = m = 1
        self.planck_term_constant = 0.5

        # torch device
        self.device = device

        # E initialization
        self.E = E

        # network (float64 weights)
        self.activation = SinActivation()
        self.dense0 = nn.Linear(input_size + 1, hidden_size, dtype=torch.double)  # NOTE +1 to input shape for implicit E value
        self.hidden_layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size, dtype=torch.double) for _ in range(num_hidden_layers)])
        self.dense_out = nn.Linear(hidden_size, 1, dtype=torch.double)

        # optimizer + loss
        self.optimizer = optim.Adam(self.parameters(), lr=optimizer_lr, betas=optimizer_betas)
        self.loss_criteria = PINNLoss()
        self.loss = None

    def _build_input(self, x, y, z):
        """
        Concatenate x, y, z and a constant E column into the network input.

        The E column is created with torch.full_like so it shares x's dtype and device (a
        float64 E is not rounded to float32) without building an N-element Python list.

        Raises:
        ValueError: if self.E has not been set
        """
        if self.E is None:
            raise ValueError("PINN.E must be set before running a forward pass")
        E_col = torch.full_like(x, float(self.E))
        return torch.cat([x, y, z, E_col], dim=1)

    @staticmethod
    def _second_derivative(u, coord):
        """Return d^2 u / d coord^2, keeping the graph so the result can be backpropagated."""
        du = torch.autograd.grad(u, coord, grad_outputs=torch.ones_like(u), create_graph=True)[0]
        return torch.autograd.grad(du, coord, grad_outputs=torch.ones_like(du), create_graph=True)[0]

    def forward_net(self, data):
        """
        Forward pass through full network. Returns network output.

        Inputs:
        data: tensor with correct input shape
        """
        out = self.activation(self.dense0(data))
        for hidden_layer in self.hidden_layers:
            out = self.activation(hidden_layer(out))
        return self.dense_out(out)

    def forward(self, x, y, z):
        """
        Full PINN forward pass: E concat, network forward pass, second derivatives and the
        differential-equation residual. Returns (u, f).

        Inputs:
        x: tensor of x with shape (N, 1)
        y: tensor of y with shape (N, 1)
        z: tensor of z with shape (N, 1)
        x, y, z must require grad (they are differentiated against), and stay separate until
        here for that reason.
        """
        u = self.forward_net(self._build_input(x, y, z))

        laplacian = self._second_derivative(u, x) + self._second_derivative(u, y) + self._second_derivative(u, z)

        # differential equation residual
        f = self.planck_term_constant * laplacian + (self.E * u)

        return u, f

    def forward_inference(self, x, y, z):
        """
        Inference forward pass. Doesn't include gradient + f calculation.

        Inputs:
        x: tensor of x with shape (N, 1)
        y: tensor of y with shape (N, 1)
        z: tensor of z with shape (N, 1)
        """
        return self.forward_net(self._build_input(x, y, z))

    def backward(self, u_coll_pred, f_pred, u_bc_target, u_bc_pred):
        """
        Compute the losses, backpropagate the overall loss and take one optimizer step.
        Returns the tuple of loss tensors from PINNLoss.calc_loss.

        Inputs:
        u_coll_pred: batch of collocation wavefunction predictions
        f_pred: batch of differential equation predictions
        u_bc_target: actual u values for b.c. data
        u_bc_pred: batch of b.c. wavefunction predictions
        """
        losses = self.loss_criteria.calc_loss(u_coll_pred, f_pred, u_bc_target, u_bc_pred)
        self.loss = losses[0]  # NOTE overall loss should always be first value returned in losses tuple

        # backprop + update params
        self.loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        return losses

### plotting functions

In [None]:
def display_plot(x_vals, y_vals, x_label="x", y_label="y", line_color="b", log_scale=False, data_label=None, title=None):
    """
    Draw y_vals against x_vals as one line on the current pyplot figure, then show it.

    Inputs:
    x_vals, y_vals: data to plot
    x_label, y_label: axis labels
    line_color: matplotlib colour of the line
    log_scale: if True, use a logarithmic y axis
    data_label: legend label; a legend is drawn only when this is truthy
    title: figure title
    """
    line_kwargs = {"color": line_color}
    if data_label:
        line_kwargs["label"] = data_label
    plt.plot(x_vals, y_vals, **line_kwargs)
    if data_label:
        plt.legend()

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)

    if log_scale:
        plt.yscale("log")

    plt.show()

### train loop function

In [None]:
def _clone_pinn(source, input_size, device, lr, betas, num_hidden_layers):
    """Return a new PINN on `device` with a copy of `source`'s weights and E (and a fresh optimizer)."""
    clone = PINN(input_size, device, E=source.E, optimizer_lr=lr, optimizer_betas=betas,
                 num_hidden_layers=num_hidden_layers).to(device)
    clone.load_state_dict(source.state_dict())
    return clone


def run_train_loop(coll_data,
                   bc_data,
                   bc_u,
                   device,
                   E_range,
                   num_train_steps_per_E,
                   save_denom=1000,
                   loss_criteria_idx=0,
                   num_hidden_layers=4,
                   lr=8e-3,
                   betas=(0.999, 0.9999),
                   min_save_epoch=100,
                   epoch_print_denom=1000):
    """
    Train one PINN over a sequence of fixed E values and keep snapshots of the best models.

    For each E in E_range the network is trained for num_train_steps_per_E steps. Each step
    perturbs the collocation points, evaluates the residual on them and u on bc_data, and takes
    one optimizer step. "epoch" counts steps over the whole run.

    From epoch min_save_epoch on, the model with the lowest loss (entry loss_criteria_idx of the
    loss tuple) is tracked. Every save_denom epochs a copy of the current best is appended to the
    returned list and the running minimum is reset to 1000.

    Inputs:
    coll_data, bc_data: (N, 3) coordinate tensors (coll_data must require grad)
    bc_u: boundary target values
    device: torch device
    E_range: iterable of E values, trained in order
    num_train_steps_per_E: optimisation steps per E value
    save_denom: epoch interval between snapshots
    loss_criteria_idx: index into PINNLoss.return_names used for model selection
    num_hidden_layers, lr, betas: network / Adam settings
    min_save_epoch: first epoch eligible for model selection
    epoch_print_denom: epoch interval between progress printouts

    Returns:
    prev_pinns: list of (model copy, epoch at which that model was the running best)
    train_loss_vals: per-epoch list of loss floats, ordered like PINNLoss.return_names
    epoch: total number of training steps run
    """
    start_time = time.time()
    print(f"training network on {len(E_range)} values of E with {num_train_steps_per_E} train steps per E...")
    epoch = 0
    input_size = coll_data.shape[1]

    # initialize network
    pinn = PINN(input_size, device, optimizer_lr=lr, optimizer_betas=betas, num_hidden_layers=num_hidden_layers)
    pinn.to(device)

    # copy of the current network (tracks the best model across training)
    current_best_pinn = _clone_pinn(pinn, input_size, device, lr, betas, num_hidden_layers)

    # best model save variables
    prev_pinns = []  # list to store best models
    default_min_loss = 1000
    min_loss = default_min_loss
    current_best_epoch = -1

    E_predictions = []  # track E by training step (epoch)
    train_loss_vals = []  # track all loss values across training

    for E in E_range:
        pinn.E = E

        for _ in range(num_train_steps_per_E):
            # perturb data
            coll_perturbed = perturb_data(coll_data).to(device)

            # collocation forward pass
            u_coll_pred, f_pred = pinn.forward(coll_perturbed[:, 0].reshape(-1, 1),
                                               coll_perturbed[:, 1].reshape(-1, 1),
                                               coll_perturbed[:, 2].reshape(-1, 1))

            # b.c. forward pass
            u_bc_pred = pinn.forward_inference(bc_data[:, 0].reshape(-1, 1),
                                               bc_data[:, 1].reshape(-1, 1),
                                               bc_data[:, 2].reshape(-1, 1))

            # backward pass; keep plain floats so no step's autograd graph outlives the step
            losses = pinn.backward(u_coll_pred, f_pred, bc_u, u_bc_pred)
            loss_vals = [loss_val.item() for loss_val in losses]
            train_loss_vals.append(loss_vals)

            # track E
            E_predictions.append(E)

            # check if new minimum loss attained across this window of epochs
            criterion = loss_vals[loss_criteria_idx]
            if epoch >= min_save_epoch and criterion < min_loss:
                min_loss = criterion
                current_best_epoch = epoch
                current_best_pinn = _clone_pinn(pinn, input_size, device, lr, betas, num_hidden_layers)

            # save the best (min loss) model across this window of epochs
            if epoch >= min_save_epoch and (epoch + 1) % save_denom == 0:
                min_loss = default_min_loss  # reset min loss to default (ideally next best model beats quickly)
                prev_pinns.append((_clone_pinn(current_best_pinn, input_size, device, lr, betas, num_hidden_layers),
                                   current_best_epoch))
                # NOTE current_best_pinn is not reset here: if no model in the next window beats the
                # default min loss, the best model from the previous window is saved again

            if epoch == 0 or (epoch + 1) % epoch_print_denom == 0:
                print("\n" + "=" * 20 + f" EPOCH {epoch} " + "=" * 20)
                print()
                print("train losses:")
                for return_name, val in zip(pinn.loss_criteria.return_names, train_loss_vals[epoch]):
                    print(f"{return_name} = {val:.3e}")
                print()
                print(f"current E = {float(pinn.E)}")
                print(f"elapsed model training time: {(time.time() - start_time) / 60 :.2f} minutes")

            epoch += 1

    # plot losses and E (nothing to plot when no steps were run)
    if epoch > 0:
        epoch_list = list(range(epoch))
        loss_history = np.asarray(train_loss_vals, dtype=np.float64)
        for loss_idx, loss_name in enumerate(pinn.loss_criteria.return_names):
            display_plot(epoch_list,
                         loss_history[:, loss_idx],
                         x_label="epoch",
                         y_label=loss_name,
                         log_scale=True)

        display_plot(epoch_list, E_predictions, x_label="epoch", y_label=r"$\hat{E}$")

    return prev_pinns, train_loss_vals, epoch

### train model(s)

In [None]:
def calc_E(nx, ny, nz, box_length=None):
    """
    Analytical particle-in-a-box energy with hbar = m = 1:
        E = (pi^2 / (2 L^2)) * (nx^2 + ny^2 + nz^2)

    Inputs:
    nx, ny, nz: quantum numbers
    box_length: cube side L; falls back to the notebook-level `L` when None
    """
    side = L if box_length is None else box_length
    # hbar**2/(2m) * pi**2/L**2 == 0.5 * pi**2/L**2
    return (0.5 * np.pi**2 / side**2) * (nx**2 + ny**2 + nz**2)

# (nx, ny, nz) quantum numbers of the target eigenstates
nx_ny_nz = [(1, 1, 1), (2, 1, 1), (3, 1, 1), (2, 2, 2)]
E_delta = 0.1

# analytical energy of each target state (units with hbar = m = 1)
E_exact = np.array([calc_E(*n_tuple) for n_tuple in nx_ny_nz])

# train on the bracket [E - E_delta, E, E + E_delta] around every analytical energy
E_offsets = np.array([-E_delta, 0.0, E_delta])
E_range = (E_exact[:, np.newaxis] + E_offsets).ravel()

In [None]:
# display the trial E values: three per target state (E - E_delta, E, E + E_delta)
E_range

In [None]:
# 30k optimisation steps for every trial E value; model selection uses entry 1 of the loss
# tuple ("loss_ovr_no_reg", i.e. the loss without the trivial-solution penalty)
num_train_steps_per_E = 30000
loss_criteria_idx = 1
saved_pinns, train_loss_vals, num_epochs = run_train_loop(
    train_coll_data,
    train_bc_data,
    train_bc_u,
    device,
    E_range,
    num_train_steps_per_E,
    loss_criteria_idx=loss_criteria_idx,
    num_hidden_layers=4,
)

### get best model(s) and evaluate predictions

In [None]:
def get_best_pinns(min_loss_windows, pinns, train_loss_vals, loss_criteria_idx):
    """
    For each epoch window, pick the saved model from the epoch with the minimum loss
    (entry `loss_criteria_idx` of each row of train_loss_vals).
    Returns a list of dicts {"model", "epoch", "losses"}.

    Windows that are empty, or whose minimum-loss epoch has no saved model, are skipped with a
    printed message, so the result can be shorter than min_loss_windows.

    Inputs:
    min_loss_windows: iterable of (start, stop) epoch ranges (stop exclusive)
    pinns: list of (model, epoch) tuples saved during training
    train_loss_vals: per-epoch sequences of loss values
    loss_criteria_idx: index of loss term to use for evaluating best model (see PINNLoss)
    """
    if len(pinns) == 0:
        print("No models saved during training.")
        return []

    # keep float64: the training loop compared float64 losses when saving, and a float32 copy can
    # round near-equal losses together so argmin lands on an epoch that was never saved
    criteria_losses = np.asarray(train_loss_vals, dtype=np.float64)[:, loss_criteria_idx]

    # first saved model per epoch (same "first match wins" rule as a linear scan)
    saved_by_epoch = {}
    for net, epoch_num in pinns:
        saved_by_epoch.setdefault(epoch_num, net)

    eigen_pinns = []  # best models (assuming each corresponds to a unique eigenvalue)
    for start, stop in min_loss_windows:
        window = criteria_losses[start:stop]
        if window.size == 0:
            print(f"Empty epoch window [{start}, {stop}]; skipping.")
            continue

        # epoch where the minimum loss is achieved within this window
        epoch_idx = int(np.argmin(window)) + start

        if epoch_idx not in saved_by_epoch:
            print(f"No matching saved models found for epoch window [{start}, {stop}] with minimum epoch = {epoch_idx}. Check self-defined minimum loss windows.")
            continue

        eigen_pinns.append({"model": saved_by_epoch[epoch_idx], "epoch": epoch_idx, "losses": train_loss_vals[epoch_idx]})

    return eigen_pinns

In [None]:
# One evaluation window per target eigenstate: each state was trained for
# (E values per state) * num_train_steps_per_E consecutive epochs. Each window starts 100 epochs
# in, matching run_train_loop's default min_save_epoch.
window_offset = 100
num_E_per_state = len(E_range) // len(nx_ny_nz)
state_epochs = num_E_per_state * num_train_steps_per_E
loss_windows = [[window_offset + i * state_epochs, (i + 1) * state_epochs] for i in range(len(nx_ny_nz))]
best_pinns = get_best_pinns(loss_windows, saved_pinns, train_loss_vals, loss_criteria_idx)
best_pinns

In [None]:
def calc_analytical(x, y, z, nx, ny, nz, L):
    """
    Normalised cube eigenfunction
        sqrt(8 / L^3) * sin(nx pi x / L) * sin(ny pi y / L) * sin(nz pi z / L)
    evaluated element-wise on x, y, z (scalars or numpy arrays).
    """
    amplitude = np.sqrt(8 / L**3)
    sin_x = np.sin(nx * np.pi * x / L)
    sin_y = np.sin(ny * np.pi * y / L)
    sin_z = np.sin(nz * np.pi * z / L)
    return amplitude * sin_x * sin_y * sin_z

def calc_pred_norm_const(approx_domain_pred_vec, d):
    """
    Normalisation constant 1 / sqrt(d^3 * sum(u^2)): the inverse root of a Riemann-sum estimate
    of the integral of u^2, treating each value as covering a d x d x d cell.
    """
    cell_volume = d**3
    discrete_norm_sq = cell_volume * torch.sum(approx_domain_pred_vec**2)
    return 1. / torch.sqrt(discrete_norm_sq)

def make_3D_plot(x_vals, y_vals, z_vals, cmap_vals, cbar_label="", xlabel="", ylabel="", zlabel=""):
    """
    3D scatter of (x_vals, y_vals, z_vals) coloured by cmap_vals ("Reds" colormap) with a
    labelled colorbar. Closes the current pyplot figure first, then shows the new one.
    """
    plt.close()
    fig = plt.figure(figsize=(8, 6), constrained_layout=True)
    ax = fig.add_subplot(projection='3d')
    points = ax.scatter(x_vals, y_vals, z_vals, c=cmap_vals, cmap="Reds")

    cbar = fig.colorbar(points, shrink=0.55, aspect=6, pad=0.1)
    cbar.set_label(cbar_label, labelpad=30)
    cbar.ax.yaxis.label.set_rotation(0)  # horizontal colorbar label

    for set_axis_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel), (ax.set_zlabel, zlabel)):
        set_axis_label(text, weight="bold")
    ax.zaxis.labelpad = 10
    ax.set_box_aspect(aspect=None, zoom=0.9)
    plt.show()

def _prob_density_subset(x, y, z, densities, prob_thresh):
    """Keep only the points whose probability density exceeds prob_thresh."""
    keep = densities > prob_thresh
    return x[keep], y[keep], z[keep], densities[keep]


def eval_pinn(pinn_dict, coll_data, bc_data, nx, ny, nz, prob_thresh=0.3, grid_step=None, box_length=None):
    """
    Report a saved model and compare its normalised prediction with the analytical (nx, ny, nz)
    eigenfunction. Prints epoch/E/losses and the MSE, then plots both probability densities
    (only points with density > prob_thresh).

    Inputs:
    pinn_dict: dict with "model", "epoch", "losses" (as returned by get_best_pinns)
    coll_data, bc_data: (N, 3) coordinate tensors of the evaluation grid
    nx, ny, nz: quantum numbers of the analytical reference (a degenerate state may differ)
    prob_thresh: density threshold for plotting
    grid_step: grid spacing used for normalisation; inferred from the distinct x values when None
    box_length: cube side; inferred as the largest x coordinate when None (cube anchored at origin)
    """
    pinn = pinn_dict["model"]
    print("=" * 100)
    print(f"PINN saved on epoch {pinn_dict['epoch']}")
    print(f"E = {pinn.E}")

    for loss_val, loss_name in zip(pinn_dict["losses"], pinn.loss_criteria.return_names):
        print(f"{loss_name} = {loss_val}")

    # combine bc+coll data; drop repeated rows (oversampled collocation points) so every grid
    # point is counted once in the normalisation sum and the MSE
    data = torch.unique(torch.concat([coll_data, bc_data]).detach(), dim=0)
    x = data[:, 0].cpu().numpy()
    y = data[:, 1].cpu().numpy()
    z = data[:, 2].cpu().numpy()

    # grid geometry from the points themselves instead of notebook globals
    axis_vals = np.unique(x)
    if box_length is None:
        box_length = float(axis_vals[-1])
    if grid_step is None:
        if axis_vals.size < 2:
            raise ValueError("need at least two distinct x coordinates to infer the grid step")
        grid_step = float(axis_vals[1] - axis_vals[0])

    # make predictions and normalize (no autograd needed for inference)
    with torch.no_grad():
        u_pred = pinn.forward_inference(data[:, 0:1], data[:, 1:2], data[:, 2:3])
        u_pred = u_pred * calc_pred_norm_const(u_pred, grid_step)
    u_pred = u_pred.flatten().cpu().numpy()
    u_pred_prob_densities = u_pred**2

    # analytical solution (NOTE based on input nx ny nz, could display different degenerate solution)
    u_analytical = calc_analytical(x, y, z, nx, ny, nz, box_length)
    u_analytical_prob_densities = u_analytical**2

    # eigenfunctions are only defined up to sign: align via the sign of the overlap
    # (a pointwise ratio u_analytical / u_pred blows up where u_pred ~ 0, e.g. on the walls)
    flip_val = -1.0 if np.dot(u_analytical, u_pred) < 0 else 1.0
    flipped_sign = flip_val < 0
    if flipped_sign:
        u_pred = -u_pred

    # MSE calc
    mse = np.mean((u_analytical - u_pred)**2)

    print(f"flip val = {flip_val}")
    print(f"flipped sign = {flipped_sign}")
    print(f"MSE = {mse}")

    # predicted u
    x_p, y_p, z_p, dens_p = _prob_density_subset(x, y, z, u_pred_prob_densities, prob_thresh)
    make_3D_plot(x_p, y_p, z_p, dens_p, cbar_label=r"$|\hat{\psi}(x,y,z)|^2$", xlabel="x", ylabel="y", zlabel="z")

    # analytical u
    x_a, y_a, z_a, dens_a = _prob_density_subset(x, y, z, u_analytical_prob_densities, prob_thresh)
    make_3D_plot(x_a, y_a, z_a, dens_a, cbar_label=r"$|\psi(x,y,z)|^2$", xlabel="x", ylabel="y", zlabel="z")

In [None]:
# evaluate each selected model against the analytical state it was trained around
for pinn_dict, (nx, ny, nz) in zip(best_pinns, nx_ny_nz):
    eval_pinn(pinn_dict, coll_data, bc_data, nx, ny, nz)

### Plot analytical probability densities of other degenerate solutions (quick hack)

In [None]:
def output(pinn_dict, coll_data, bc_data, nx, ny, nz, prob_thresh=0.3, box_length=None):
    """
    Plot the analytical |psi|^2 of the (nx, ny, nz) cube eigenstate on the evaluation grid,
    showing only points with density > prob_thresh.

    Inputs:
    pinn_dict: unused (kept so the call signature mirrors eval_pinn)
    coll_data, bc_data: (N, 3) coordinate tensors of the evaluation grid
    nx, ny, nz: quantum numbers
    prob_thresh: density threshold for plotting
    box_length: cube side; inferred as the largest x coordinate when None (cube anchored at origin)
    """
    # combine bc+coll data, dropping repeated (oversampled) points so they are not overplotted
    data = torch.unique(torch.concat([coll_data, bc_data]).detach(), dim=0)
    x = data[:, 0].cpu().numpy()
    y = data[:, 1].cpu().numpy()
    z = data[:, 2].cpu().numpy()

    if box_length is None:
        box_length = float(x.max())

    # analytical solution
    u_analytical = calc_analytical(x, y, z, nx, ny, nz, box_length)
    densities = u_analytical**2

    # analytical subset
    keep = densities > prob_thresh
    make_3D_plot(x[keep], y[keep], z[keep], densities[keep], cbar_label=r"$|\psi(x,y,z)|^2$", xlabel="x", ylabel="y", zlabel="z")

In [None]:
# (1, 2, 1) has the same nx^2 + ny^2 + nz^2, hence the same calc_E energy, as the trained (2, 1, 1) state
output(None, coll_data, bc_data, 1, 2, 1)

In [None]:
# (1, 3, 1) has the same nx^2 + ny^2 + nz^2, hence the same calc_E energy, as the trained (3, 1, 1) state
output(None, coll_data, bc_data, 1, 3, 1)