In [None]:
test = 'Test1A'

import torch
import pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import math
import DRLPDE.neuralnets as neuralnets
import importlib

mpl.rcParams['figure.dpi'] = 300
#plt.rcParams['text.usetex'] = True
#plt.rcParams['font.family'] = 'serif'
#plt.rcParams['font.serif'] = ['Computer Modern']
#plt.rcParams['font.size'] = 8

# Load the PDE problem definition from the `examples.example1` module and pull
# out what is needed to rebuild the network and to lay out the plotting grid.
problem = importlib.import_module('examples.example1')

input_dim = [problem.x_dim, problem.t_dim, problem.hyper_dim]
output_dim, boundingbox = problem.output_dim, problem.boundingbox

# Load the hyper-parameters that were saved alongside the trained model.
# NOTE: pickle.load can execute arbitrary code -- only open experiment files
# that you produced yourself.
with open('experiments/' + test + '_parameters.pickle', 'rb') as f:
    parameters = pickle.load(f)

# Map the saved architecture name onto the network class. Fail with a clear
# message on an unknown name rather than a NameError on MyNeuralNetwork below.
if parameters['neuralnetwork'] == 'FeedForward':
    MyNeuralNetwork = neuralnets.FeedForwardNN
elif parameters['neuralnetwork'] == 'Incompressible':
    MyNeuralNetwork = neuralnets.IncompressibleNN
else:
    raise ValueError(
        f"Unknown neural network type {parameters['neuralnetwork']!r} in "
        f"experiments/{test}_parameters.pickle; expected 'FeedForward' or "
        "'Incompressible'"
    )

# Keyword arguments describing the network size, forwarded to the constructor.
nn_size = parameters['nn_size']

model = MyNeuralNetwork(input_dim, output_dim, **nn_size)
# map_location='cpu' lets a checkpoint written from a GPU run be loaded on a
# CPU-only machine; load_state_dict then copies it into the model's parameters.
model.load_state_dict(torch.load("savedmodels/" + test + ".pt", map_location='cpu'))
# Inference only: switch layers such as dropout/batch-norm (if the network
# has any) to evaluation behaviour.
model.eval()

# Exact solution and polar boundary description supplied by the problem module.
true_fun = problem.true_fun
polar_eq = problem.polar_eq

# Trace the boundary r = polar_eq(theta) over one full turn; linspace includes
# the 2*pi endpoint, so the plotted curve closes on itself.
theta = torch.linspace(0, 2*math.pi, 120)
r = polar_eq(theta)
polar_x, polar_y = r*torch.cos(theta), r*torch.sin(theta)

# Regular num_x-by-num_y evaluation grid spanning the bounding box.
num_x, num_y = 51, 61
grid_x = torch.linspace(boundingbox[0][0], boundingbox[0][1], num_x)
grid_y = torch.linspace(boundingbox[1][0], boundingbox[1][1], num_y)
# cartesian_prod varies its second factor fastest, so reshaping a per-point
# quantity to (num_x, num_y) puts x along rows and y along columns.
X = torch.cartesian_prod(grid_x, grid_y)


def _to_grid(values):
    """Detach a per-point tensor and reshape it into a (num_x, num_y) array."""
    return values.detach().reshape(num_x, num_y).numpy()


# Assumes model(X) and true_fun(X) return exactly one value per grid point,
# otherwise the reshape fails -- TODO confirm for multi-output problems.
Y = _to_grid(model(X))
Ytrue = _to_grid(true_fun(X))
Xplot = _to_grid(X[:, 0])
Yplot = _to_grid(X[:, 1])

def _add_colorbar(fig, axis, mappable):
    """Attach a thin colorbar just to the right of *axis*.

    Returns the new colorbar axes and the Colorbar object.
    """
    box = axis.get_position()
    cax = fig.add_axes([box.x1 + 0.01, box.y0, 0.01, box.height])
    return cax, fig.colorbar(mappable, cax=cax)


# Shared levels so the network and exact-solution panels use one colour scale.
# extend='both' paints values outside [-15, 15] with the end colours instead
# of leaving those regions unfilled.
levels = np.linspace(-15, 15, 31)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[12, 4])

fig.tight_layout(pad=5.0)

ax[0].set_title('Neural Network')
contour0 = ax[0].contourf(Xplot, Yplot, Y, levels=levels,
                          cmap=plt.cm.viridis, extend='both')
ax[0].plot(polar_x, polar_y)
colorbar0_param, colorbar0 = _add_colorbar(fig, ax[0], contour0)

ax[1].set_title('True function')
contour1 = ax[1].contourf(Xplot, Yplot, Ytrue, levels=levels,
                          cmap=plt.cm.viridis, extend='both')
ax[1].plot(polar_x, polar_y)
colorbar1_param, colorbar1 = _add_colorbar(fig, ax[1], contour1)

# Error panel: a diverging colormap only reads correctly when zero error sits
# at its neutral midpoint, so use levels symmetric about 0.
Yerr = Y - Ytrue
err_max = float(np.nanmax(np.abs(Yerr)))
if not np.isfinite(err_max) or err_max == 0.0:
    # Degenerate case (identically zero or non-finite error): any
    # non-empty symmetric range avoids contourf rejecting equal levels.
    err_max = 1.0
err_levels = np.linspace(-err_max, err_max, 21)

ax[2].set_title('NN - True')
contour2 = ax[2].contourf(Xplot, Yplot, Yerr, levels=err_levels,
                          cmap=plt.cm.coolwarm)
ax[2].plot(polar_x, polar_y)
colorbar2_param, colorbar2 = _add_colorbar(fig, ax[2], contour2)
