In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import re

import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl
from utilities import *



import imageio
from io import BytesIO
from IPython.display import Image as DisplayImage

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#%%
"""
Adapted from Zongyi Li. TODO: include the reference in the README.
This file implements the Fourier Neural Operator for 3D problems: it treats the 2D spatial + 1D temporal equation directly as a 3D problem.
"""

import torch.nn.functional as F
from utilities import *
from timeit import default_timer

# Fix the PyTorch and NumPy global RNG seeds so the random weight initialisation
# of the layers defined below is repeatable from run to run.
# NOTE(review): `torch` is not imported explicitly in this cell — presumably it
# arrives through `from utilities import *`; confirm, or import it directly.
torch.manual_seed(0)
np.random.seed(0)
#%%
################################################################
# 3d fourier layers
################################################################

class SpectralConv3d(nn.Module):
    """3D Fourier layer: real FFT -> linear transform on selected modes -> inverse real FFT.

    Only the ``modes1 x modes2 x modes3`` coefficients in each of the four
    low/high corners of the first two transformed axes are multiplied by
    learned complex weights; every other coefficient of the output spectrum
    stays zero.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply per axis, at most floor(N/2) + 1.
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2, self.modes3)
        # One weight tensor per spectral corner. They are created in the order
        # weights1..weights4 so the random stream consumed at construction (and
        # the parameter names used by saved state dicts) stay the same.
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))

    def compl_mul3d(self, input, weights):
        """Contract the input-channel axis of ``input`` against complex ``weights``.

        (batch, in_channel, x, y, t), (in_channel, out_channel, x, y, t)
            -> (batch, out_channel, x, y, t)
        """
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x):
        n_batch = x.shape[0]
        size1, size2, size3 = x.size(-3), x.size(-2), x.size(-1)

        # Fourier coefficients of the last three axes (real-input FFT, so the
        # last axis only holds size3 // 2 + 1 entries).
        spectrum = torch.fft.rfftn(x, dim=[-3, -2, -1])

        out_ft = torch.zeros(
            n_batch, self.out_channels, size1, size2, size3 // 2 + 1,
            dtype=torch.cfloat, device=x.device,
        )

        low1, high1 = slice(None, self.modes1), slice(-self.modes1, None)
        low2, high2 = slice(None, self.modes2), slice(-self.modes2, None)
        keep3 = slice(None, self.modes3)
        # (band on axis -3, band on axis -2, weights), written in the same
        # order as weights1..weights4.
        corners = (
            (low1, low2, self.weights1),
            (high1, low2, self.weights2),
            (low1, high2, self.weights3),
            (high1, high2, self.weights4),
        )
        for band1, band2, weights in corners:
            out_ft[:, :, band1, band2, keep3] = self.compl_mul3d(
                spectrum[:, :, band1, band2, keep3], weights
            )

        # Back to physical space at the original grid size.
        return torch.fft.irfftn(out_ft, s=(size1, size2, size3))

class MLP(nn.Module):
    """Pointwise two-layer perceptron: 1x1x1 Conv3d -> GELU -> 1x1x1 Conv3d."""

    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()
        # Kernel size 1 means each grid point is transformed independently.
        self.mlp1 = nn.Conv3d(in_channels, mid_channels, kernel_size=1)
        self.mlp2 = nn.Conv3d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        hidden = F.gelu(self.mlp1(x))
        return self.mlp2(hidden)

class FNO3d(nn.Module):
    """Fourier Neural Operator acting on three grid axes.

    The network has three stages (layer attribute names are unchanged so that
    existing state dicts / pickled models keep loading):

    1. ``self.p`` lifts the ``in_channels`` input features to ``width`` channels.
    2. Four Fourier blocks compute ``mlp_i(conv_i(x)) + w_i(x)``; a GELU follows
       every block except the last one.
    3. ``self.q`` projects the ``width`` channels to a single output channel.

    Input is channels-last, ``(batch, d1, d2, d3, in_channels)``; the output is
    ``(batch, d1, d2, d3, 1)``.
    NOTE(review): the original note gave the layout as
    (batchsize, x=32, y=32, t=61, c=6), while tensors elsewhere in this notebook
    are reshaped as (N, 61, 32, 32, C) — confirm which axis is time. Only the
    last grid axis (d3) is zero-padded.

    Args:
        modes1, modes2, modes3: Fourier modes kept per grid axis by each
            ``SpectralConv3d``.
        width: number of hidden channels.
        in_channels: number of input features per grid point. Defaults to 6
            (per the original comment: Por, Perm, Pressure + x, y, time
            encodings).
        padding: number of zeros appended to the last grid axis before the
            Fourier blocks and removed afterwards (useful when the input is
            non-periodic). ``0`` disables padding. Defaults to 1.
    """

    def __init__(self, modes1, modes2, modes3, width, in_channels=6, padding=1):
        super().__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.padding = padding  # pad the domain if input is non-periodic

        # Layers are created in the same order as before so that seeded
        # initialisation and state-dict key order are unchanged.
        self.p = nn.Linear(in_channels, self.width)
        self.conv0 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv1 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv2 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv3 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.mlp0 = MLP(self.width, self.width, self.width)
        self.mlp1 = MLP(self.width, self.width, self.width)
        self.mlp2 = MLP(self.width, self.width, self.width)
        self.mlp3 = MLP(self.width, self.width, self.width)
        self.w0 = nn.Conv3d(self.width, self.width, 1)
        self.w1 = nn.Conv3d(self.width, self.width, 1)
        self.w2 = nn.Conv3d(self.width, self.width, 1)
        self.w3 = nn.Conv3d(self.width, self.width, 1)
        self.q = MLP(self.width, 1, self.width * 4)  # output channel is 1: u(x, y)

    def forward(self, x):
        """Map a channels-last input tensor to a single-channel prediction.

        The grid encoding from ``get_grid`` is NOT concatenated here (that call
        was commented out in the original), so ``x`` must already carry all
        ``in_channels`` features.
        """
        x = self.p(x)
        x = x.permute(0, 4, 1, 2, 3)  # channels-last -> channels-first for Conv3d / FFT
        if self.padding > 0:
            x = F.pad(x, [0, self.padding])  # pad the domain if input is non-periodic

        x1 = self.mlp0(self.conv0(x))
        x2 = self.w0(x)
        x = F.gelu(x1 + x2)

        x1 = self.mlp1(self.conv1(x))
        x2 = self.w1(x)
        x = F.gelu(x1 + x2)

        x1 = self.mlp2(self.conv2(x))
        x2 = self.w2(x)
        x = F.gelu(x1 + x2)

        # No activation after the last Fourier block.
        x1 = self.mlp3(self.conv3(x))
        x2 = self.w3(x)
        x = x1 + x2

        # Strip the padding again. The guard matters: ``x[..., :-0]`` would
        # return an EMPTY tensor rather than the unpadded one.
        if self.padding > 0:
            x = x[..., :-self.padding]
        x = self.q(x)
        x = x.permute(0, 2, 3, 4, 1)  # back to channels-last
        return x

    def get_grid(self, shape, device):
        """Return normalised grid coordinates for the three grid axes.

        Args:
            shape: sequence whose first four entries are
                (batch, size_x, size_y, size_z).
            device: device the result is moved to.

        Returns:
            Float tensor of shape (batch, size_x, size_y, size_z, 3) whose last
            axis holds the per-axis coordinate, each running linearly from 0 to 1.
        """
        batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1])
        gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float)
        gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1])
        return torch.cat((gridx, gridy, gridz), dim=-1).to(device)
        
      


In [3]:
import torch
import matplotlib.pyplot as plt
import imageio
from io import BytesIO
from IPython.display import Image as DisplayImage

def plot_comparison(true, predicted, ax, title):
    """Draw ``true`` as a 'jet' heat map on ``ax`` with a colorbar on its right edge.

    The colour range spans the minimum and maximum of ``true``. ``predicted``
    is accepted for signature compatibility but is not used.
    """
    image = ax.imshow(true, cmap='jet', vmin=true.min(), vmax=true.max())
    # Carve a narrow axes off the right side of `ax` to host the colorbar.
    colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(image, cax=colorbar_axes, orientation='vertical')
    ax.set_title(title)
    ax.axis('off')

def create_memory_image(true, predicted, time_step):
    """Render a three-panel frame (true, predicted, absolute error) as an in-memory PNG.

    Args:
        true: 2D array with the reference field for one time step.
        predicted: 2D array with the model output for the same time step.
        time_step: value interpolated into the panel titles.

    Returns:
        A ``BytesIO`` holding the PNG, rewound to position 0.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 6))
    # plot_comparison takes (image, <unused second image>, ax, title) and draws
    # its FIRST argument. The previous calls passed only three arguments and
    # raised TypeError before anything was drawn.
    plot_comparison(true, predicted, ax1, f'True at time step {time_step}')
    plot_comparison(predicted, true, ax2, f'Predicted at time step {time_step}')

    diff = np.abs(true - predicted)
    im = ax3.imshow(diff, cmap='jet')
    divider = make_axes_locatable(ax3)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax, orientation='vertical')
    ax3.set_title(f'Absolute Error - Time step {time_step}')
    ax3.axis('off')

    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format='png')
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    buf.seek(0)
    return buf

def create_animation(data, predictions, index, device, model, output_path):
    """Run ``model`` on sample ``index`` of ``data`` and write a true-vs-predicted GIF.

    ``data[index]`` must yield an (input, target) pair. Both the model output
    and the target are decoded with the module-level ``y_normalizer`` and
    reshaped to (1, 61, 32, 32, 1); one frame is rendered per entry of the
    second axis. ``predictions`` is accepted but not used.

    Returns:
        A ``BytesIO`` with the GIF bytes (also written to ``output_path``),
        rewound to position 0.
    """
    frame_shape = (1, 61, 32, 32, 1)
    with torch.no_grad():
        model_input = data[index][0].unsqueeze(0).to(device)
        target = data[index][1].unsqueeze(0).to(device)
        prediction = y_normalizer.decode(model(model_input))
        target = y_normalizer.decode(target)
        true_frames = target.view(frame_shape).cpu().numpy()
        predicted_frames = prediction.view(frame_shape).cpu().numpy()

        image_buffers = [
            create_memory_image(
                true_frames[0, step, :, :, 0],
                predicted_frames[0, step, :, :, 0],
                step,
            )
            for step in range(61)
        ]

    # Decode every PNG frame and assemble them into one GIF held in memory.
    frames = [imageio.imread(png.getvalue()) for png in image_buffers]
    gif_buffer = BytesIO()
    imageio.mimsave(gif_buffer, frames, format='GIF', duration=0.5)

    with open(output_path, 'wb') as gif_file:
        gif_file.write(gif_buffer.getvalue())

    gif_buffer.seek(0)
    for png in image_buffers:
        png.close()

    return gif_buffer


# Example usage for CO_2: load the dataset, normalise it, restore the trained
# FNO3d weights, plot the last time step of the first `num_samples` test cases
# and write an animation for test sample 0.
# NOTE(review): ReadXarrayDataset and UnitGaussianNormalizer are not defined in
# this notebook — presumably they come from `from utilities import *`; confirm.
# NOTE(review): the recorded output of this cell is
# "TypeError: __init__() got an unexpected keyword argument 'traintest_split'",
# so one of the ReadXarrayDataset calls in this cell does not match the
# installed helper's signature.
folder = "/scratch/smrserraoseabr/Projects/FluvialCO2/results32/"
input_vars = ['Por', 'Perm', 'gas_rate']
output_vars = ['CO_2']
device = 'cpu'
num_files = 1000
traintest_split = 0.8
dataset = ReadXarrayDataset(
    folder=folder,
    input_vars=input_vars,
    output_vars=output_vars,
    num_files=num_files,
    traintest_split=traintest_split,
)

train_a = dataset.train_data_input.to(device)
train_u = dataset.train_data_output.to(device)
test_a = dataset.test_data_input.to(device)
test_u = dataset.test_data_output.to(device)

# Both normalizers are built from the TRAINING tensors only and then applied to
# the train and test splits alike.
a_normalizer = UnitGaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

# `y_normalizer` is also read as a module-level name by create_animation().
y_normalizer = UnitGaussianNormalizer(train_u)
train_u = y_normalizer.encode(train_u)
test_u = y_normalizer.encode(test_u)

modes = 12
width = 128
path_model = '/scratch/smrserraoseabr/Projects/NO-DA/runs/ns_fourier_3d_N800.0_ep300_m12_w128'

# The checkpoint at `path_model` is loaded as a state dict into a freshly built network.
model = FNO3d(modes, modes, modes, width).to(device)
model.load_state_dict(torch.load(path_model))
model.eval()

# NOTE(review): `pred` is allocated but never filled in this section, and
# create_animation() does not read its `predictions` argument.
pred = torch.zeros(test_u.shape)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False
)

# One figure row per plotted sample: true | predicted | absolute error.
num_samples = 10
fig, axes = plt.subplots(nrows=num_samples, ncols=3, figsize=(14, 6 * num_samples))
# Shared colour scale for the true and predicted panels.
vmin, vmax = 0.0, 1.0
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

with torch.no_grad():
    for index, (x, y) in enumerate(test_loader):
        # Only the first `num_samples` batches (of size 1) are plotted.
        if index >= num_samples:
            break

        x = x.to(device)
        y = y.to(device)

        # Decode both prediction and target with the output normalizer before plotting.
        out = model(x)
        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)

        test_y_shape = (1, 61, 32, 32, 1)
        predicted_y_shape = (1, 61, 32, 32, 1)

        test_y = y.view(test_y_shape).cpu().numpy()
        predicted_y = out.view(predicted_y_shape).cpu().numpy()

        # Index -1 on the second axis selects the last of the 61 steps; the
        # error panel uses its own automatic colour scale (no `norm`).
        img1 = axes[index, 0].imshow(test_y[0, -1, :, :, 0], cmap='jet', norm=norm)
        img2 = axes[index, 1].imshow(predicted_y[0, -1, :, :, 0], cmap='jet', norm=norm)
        img3 = axes[index, 2].imshow(np.abs(test_y[0, -1, :, :, 0] - predicted_y[0, -1, :, :, 0]), cmap='jet')

        axes[index, 0].set_title("True CO2 - Last Time Step")
        axes[index, 1].set_title("Predicted CO2 - Last Time Step")
        axes[index, 2].set_title("Absolute Error - Last Time Step")

        axes[index, 0].axis('off')
        axes[index, 1].axis('off')
        axes[index, 2].axis('off')

        plt.colorbar(img1, ax=axes[index, 0], fraction=0.046, pad=0.04)
        plt.colorbar(img2, ax=axes[index, 1], fraction=0.046, pad=0.04)
        plt.colorbar(img3, ax=axes[index, 2], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig("/scratch/smrserraoseabr/Projects/NO-DA/plots/ns_fourier_3d_N800.0_ep300_m12_w128_predictions.png")
plt.close()

# NOTE(review): `dataset.test_data` is assumed to be indexable as
# test_data[i] -> (input, target); that attribute is not shown here — confirm.
create_animation(dataset.test_data, pred, 0, device, model, "output_CO2.gif")

# Example usage for Pressure: load the dataset, normalise it, load the pickled
# Pressure model, plot the last time step of test sample `index` and write an
# animation for test sample 0.
# NOTE(review): ReadXarrayDataset and UnitGaussianNormalizer are not defined in
# this notebook — presumably they come from `from utilities import *`; confirm.
folder = "/scratch/smrserraoseabr/Projects/FluvialCO2/results32/"
input_vars = ['Por', 'Perm', 'gas_rate']
output_vars = ['Pressure']
device = 'cpu'
num_files = 1000
traintest_split = 0.8
dataset = ReadXarrayDataset(
    folder=folder,
    input_vars=input_vars,
    output_vars=output_vars,
    num_files=num_files,
    traintest_split=traintest_split,
)

train_a = dataset.train_data_input.to(device)
train_u = dataset.train_data_output.to(device)
test_a = dataset.test_data_input.to(device)
test_u = dataset.test_data_output.to(device)

# Normalizers are built from the training tensors only.
a_normalizer = UnitGaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

# `y_normalizer` is also read as a module-level name by create_animation().
y_normalizer = UnitGaussianNormalizer(train_u)
train_u = y_normalizer.encode(train_u)
test_u = y_normalizer.encode(test_u)

modes = 12
width = 128
path_model = '/scratch/smrserraoseabr/Projects/NO-DA/runs/ns_fourier_3d_N800.0_ep300_m12_w128_INPUT_Por_Perm_gas_rate_OUTPUT_Pressure/model/ns_fourier_3d_N800.0_ep300_m12_w128_pressure'

# Unlike the CO2 section, this checkpoint is loaded as a whole pickled model.
# SECURITY: torch.load unpickles arbitrary objects — only load trusted files.
model = torch.load(path_model)
model.eval()

pred = torch.zeros(test_u.shape)
index = 0
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_a, test_u),
    batch_size=1,
    shuffle=False
)

with torch.no_grad():
    # Pull out the single batch at position `index`.
    for i, (x, y) in enumerate(test_loader):
        if i == index:
            x_single, y_single = x, y
            break

    x_single = x_single.to(device)
    y_single = y_single.to(device)
    out_single = model(x_single)
    out_single = y_normalizer.decode(out_single)
    y_single = y_normalizer.decode(y_single)

    true_data = y_single.view(1, 61, 32, 32, 1).cpu().numpy()
    predicted_data = out_single.view(1, 61, 32, 32, 1).cpu().numpy()

    # Store the decoded prediction in `pred`. The NumPy array is converted to a
    # tensor explicitly (assigning an ndarray into a tensor slice is not a
    # supported conversion) and all 61 steps are copied in one assignment.
    pred[0, :, :, :, 0] = torch.from_numpy(predicted_data[0, :, :, :, 0])

    # NOTE(review): `num_samples` is defined in the CO2 section above; only one
    # sample is available here, so all but the first figure row stay empty.
    fig, axes = plt.subplots(nrows=num_samples, ncols=3, figsize=(14, 6 * num_samples))
    vmin, vmax = 200.0, 550.0
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    # `row` (not `index`) is used as loop variable so the sample index above is
    # not overwritten.
    for row, (true, predicted) in enumerate(zip(true_data, predicted_data)):
        if row >= num_samples:
            break

        img1 = axes[row, 0].imshow(true[-1, :, :, 0], cmap='jet', norm=norm)
        img2 = axes[row, 1].imshow(predicted[-1, :, :, 0], cmap='jet', norm=norm)
        img3 = axes[row, 2].imshow(np.abs(true[-1, :, :, 0] - predicted[-1, :, :, 0]), cmap='jet')

        axes[row, 0].set_title("True Pressure - Last Time Step")
        axes[row, 1].set_title("Predicted Pressure - Last Time Step")
        axes[row, 2].set_title("Absolute Error - Last Time Step")

        axes[row, 0].axis('off')
        axes[row, 1].axis('off')
        axes[row, 2].axis('off')

        plt.colorbar(img1, ax=axes[row, 0], fraction=0.046, pad=0.04)
        plt.colorbar(img2, ax=axes[row, 1], fraction=0.046, pad=0.04)
        plt.colorbar(img3, ax=axes[row, 2], fraction=0.046, pad=0.04)

    plt.tight_layout()
    # Saved under a Pressure-specific name: the previous path was identical to
    # the CO2 figure's path and silently overwrote it.
    plt.savefig("/scratch/smrserraoseabr/Projects/NO-DA/plots/ns_fourier_3d_N800.0_ep300_m12_w128_pressure_predictions.png")
    plt.close()

    # NOTE(review): `dataset.test_data` is assumed to be indexable as
    # test_data[i] -> (input, target); that attribute is not shown here — confirm.
    create_animation(dataset.test_data, pred, 0, device, model, "output_pressure.gif")


TypeError: __init__() got an unexpected keyword argument 'traintest_split'

In [22]:

def plot_comparison(true, predicted, ax, title):
    """Show ``true`` on ``ax`` as a 'jet' image with its own colorbar.

    Colour limits are taken from the minimum and maximum of ``true``.
    ``predicted`` is part of the signature but is not used.
    """
    lower, upper = true.min(), true.max()
    image = ax.imshow(true, cmap='jet', vmin=lower, vmax=upper)
    # Attach a slim colorbar axes to the right of `ax`.
    colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(image, cax=colorbar_axes, orientation='vertical')
    ax.set_title(title)
    ax.axis('off')


def create_memory_image(true, predicted, time_step):
    """Render a three-panel CO2 frame (true, predicted, absolute error) as an in-memory PNG.

    Args:
        true: 2D array with the reference field for one time step.
        predicted: 2D array with the model output for the same time step.
        time_step: value interpolated into the panel titles.

    Returns:
        A ``BytesIO`` holding the PNG, rewound to position 0.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 6))
    # plot_comparison takes (image, <unused second image>, ax, title) and draws
    # its FIRST argument. The second call previously passed only three
    # arguments and raised TypeError.
    plot_comparison(true, predicted, ax1, f'True CO2 at time step {time_step}')
    plot_comparison(predicted, true, ax2, f'Predicted CO2 at time step {time_step}')

    diff = np.abs(true - predicted)
    im = ax3.imshow(diff, cmap='jet')
    divider = make_axes_locatable(ax3)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax, orientation='vertical')
    ax3.set_title(f'Absolute Error - Time step {time_step}')
    ax3.axis('off')

    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format='png')
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    buf.seek(0)
    return buf


def create_animation(test_data, predicted_data, index, device, model, output_path):
    """Run ``model`` on sample ``index`` of ``test_data`` and write a true-vs-predicted GIF.

    ``test_data[index]`` must yield an (input, target) pair. Model output and
    target are decoded with the module-level ``y_normalizer`` and reshaped to
    (1, 61, 32, 32, 1); one frame is rendered per entry of the second axis.
    The ``predicted_data`` argument is not read: frames are always built from a
    fresh forward pass.

    Returns:
        A ``BytesIO`` with the GIF bytes (also written to ``output_path``),
        rewound to position 0.
    """
    frame_shape = (1, 61, 32, 32, 1)
    with torch.no_grad():
        model_input = test_data[index][0].unsqueeze(0).to(device)
        target = test_data[index][1].unsqueeze(0).to(device)
        prediction = y_normalizer.decode(model(model_input))
        target = y_normalizer.decode(target)
        true_frames = target.view(frame_shape).cpu().numpy()
        predicted_frames = prediction.view(frame_shape).cpu().numpy()

        image_buffers = [
            create_memory_image(
                true_frames[0, step, :, :, 0],
                predicted_frames[0, step, :, :, 0],
                step,
            )
            for step in range(61)
        ]

    # Decode every PNG frame and assemble them into one GIF held in memory.
    frames = [imageio.imread(png.getvalue()) for png in image_buffers]
    gif_buffer = BytesIO()
    imageio.mimsave(gif_buffer, frames, format='GIF', duration=0.5)

    with open(output_path, 'wb') as gif_file:
        gif_file.write(gif_buffer.getvalue())

    gif_buffer.seek(0)
    for png in image_buffers:
        png.close()

    return gif_buffer



# CO2 evaluation script: load the dataset, normalise it, restore the trained
# FNO3d weights, plot the last time step of the first `num_samples` test cases
# and write an animation for test sample 0.
# NOTE(review): ReadXarrayDataset and UnitGaussianNormalizer are not defined in
# this notebook — presumably they come from `from utilities import *`; confirm.
folder = "/scratch/smrserraoseabr/Projects/FluvialCO2/results32/"
input_vars = ['Por', 'Perm', 'gas_rate']
output_vars = ['CO_2']
device = 'cpu'
num_files = 1000
traintest_split = 0.8
dataset = ReadXarrayDataset(
    folder=folder,
    input_vars=input_vars,
    output_vars=output_vars,
    num_files=num_files,
    traintest_split=traintest_split,
)

train_a = dataset.train_data_input.to(device)
train_u = dataset.train_data_output.to(device)
test_a = dataset.test_data_input.to(device)
test_u = dataset.test_data_output.to(device)

# Both normalizers are built from the TRAINING tensors only and then applied to
# the train and test splits alike.
a_normalizer = UnitGaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

# `y_normalizer` is also read as a module-level name by create_animation().
y_normalizer = UnitGaussianNormalizer(train_u)
train_u = y_normalizer.encode(train_u)
test_u = y_normalizer.encode(test_u)

modes = 12
width = 128
path_model = '/scratch/smrserraoseabr/Projects/NO-DA/runs/ns_fourier_3d_N800.0_ep300_m12_w128'

# The checkpoint at `path_model` is loaded as a state dict into a freshly built network.
model = FNO3d(modes, modes, modes, width).to(device)
model.load_state_dict(torch.load(path_model))
model.eval()

# NOTE(review): `pred` is allocated but never filled in this cell, and
# create_animation() does not use the value passed for it.
pred = torch.zeros(test_u.shape)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False
)

# One figure row per plotted sample: true | predicted | absolute error.
num_samples = 10
fig, axes = plt.subplots(nrows=num_samples, ncols=3, figsize=(14, 6 * num_samples))
# Shared colour scale for the true and predicted panels.
vmin, vmax = 0.0, 1.0
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

with torch.no_grad():
    for index, (x, y) in enumerate(test_loader):
        # Only the first `num_samples` batches (of size 1) are plotted.
        if index >= num_samples:
            break

        x = x.to(device)
        y = y.to(device)

        # Decode both prediction and target with the output normalizer before plotting.
        out = model(x)
        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)

        test_y_shape = (1, 61, 32, 32, 1)
        predicted_y_shape = (1, 61, 32, 32, 1)

        test_y = y.view(test_y_shape).cpu().numpy()
        predicted_y = out.view(predicted_y_shape).cpu().numpy()

        # Index -1 on the second axis selects the last of the 61 steps; the
        # error panel uses its own automatic colour scale (no `norm`).
        img1 = axes[index, 0].imshow(test_y[0, -1, :, :, 0], cmap='jet', norm=norm)
        img2 = axes[index, 1].imshow(predicted_y[0, -1, :, :, 0], cmap='jet', norm=norm)
        img3 = axes[index, 2].imshow(np.abs(test_y[0, -1, :, :, 0] - predicted_y[0, -1, :, :, 0]), cmap='jet')

        axes[index, 0].set_title("True CO2 - Last Time Step")
        axes[index, 1].set_title("Predicted CO2 - Last Time Step")
        axes[index, 2].set_title("Absolute Error - Last Time Step")

        axes[index, 0].axis('off')
        axes[index, 1].axis('off')
        axes[index, 2].axis('off')

        plt.colorbar(img1, ax=axes[index, 0], fraction=0.046, pad=0.04)
        plt.colorbar(img2, ax=axes[index, 1], fraction=0.046, pad=0.04)
        plt.colorbar(img3, ax=axes[index, 2], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig("/scratch/smrserraoseabr/Projects/NO-DA/plots/ns_fourier_3d_N800.0_ep300_m12_w128_predictions.png")
plt.close()

# NOTE(review): `dataset.test_data` is assumed to be indexable as
# test_data[i] -> (input, target); that attribute is not shown here — confirm.
create_animation(dataset.test_data, pred, 0, device, model, "output_CO2.gif")



In [16]:
import torch
import matplotlib.pyplot as plt
import imageio
from io import BytesIO
from IPython.display import Image as DisplayImage


def create_memory_image(true_pressure, predicted_pressure, time_step):
    """Render a two-panel frame (true vs. predicted pressure) as an in-memory PNG.

    Both panels use the 'jet' colour map with automatic colour limits and have
    their axes hidden. Returns a ``BytesIO`` rewound to position 0.
    """
    png_buffer = BytesIO()
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(16, 6))

    panels = (
        (left_ax, true_pressure, f'True Pressure at time step {time_step}'),
        (right_ax, predicted_pressure, f'Predicted Pressure at time step {time_step}'),
    )
    for panel_ax, field, panel_title in panels:
        panel_ax.imshow(field, cmap='jet')
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

    plt.tight_layout()
    plt.savefig(png_buffer, format='png')
    plt.close()
    png_buffer.seek(0)
    return png_buffer


def create_animation(data, predictions, index, device, model, output_path):
    """Run ``model`` on sample ``index`` of ``data`` and write a true-vs-predicted GIF.

    ``data[index]`` must yield an (input, target) pair. Model output and target
    are decoded with the module-level ``y_normalizer`` and reshaped to
    (1, 61, 32, 32, 1); one frame is rendered per entry of the second axis.
    ``predictions`` is accepted but not used.

    Returns:
        A ``BytesIO`` with the GIF bytes (also written to ``output_path``),
        rewound to position 0.
    """
    frame_shape = (1, 61, 32, 32, 1)

    # Add the batch axis and move the sample to the target device (once is
    # enough: a second .to(device) on the same device is a no-op).
    model_input = data[index][0].unsqueeze(0).to(device)
    target = data[index][1].unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = y_normalizer.decode(model(model_input))
        target = y_normalizer.decode(target)

        true_frames = target.view(frame_shape).cpu().numpy()
        predicted_frames = prediction.view(frame_shape).cpu().numpy()

        image_buffers = [
            create_memory_image(
                true_frames[0, step, :, :, 0],
                predicted_frames[0, step, :, :, 0],
                step,
            )
            for step in range(61)
        ]

    # Decode every PNG frame and assemble them into one GIF held in memory.
    frames = [imageio.imread(png.getvalue()) for png in image_buffers]
    gif_buffer = BytesIO()
    imageio.mimsave(gif_buffer, frames, format='GIF', duration=0.5)

    with open(output_path, 'wb') as gif_file:
        gif_file.write(gif_buffer.getvalue())

    gif_buffer.seek(0)
    for png in image_buffers:
        png.close()

    return gif_buffer


def main():
    """Evaluate the pickled Pressure model on one test sample and export a figure and a GIF.

    Steps: load the dataset, normalise inputs/outputs with statistics from the
    training split, load the model from ``path_model``, run it on test sample
    ``index``, plot the last time step (true / predicted / absolute error) and
    finally call ``create_animation`` for test sample 0.

    NOTE(review): ReadXarrayDataset and UnitGaussianNormalizer are not defined
    in this notebook — presumably they come from ``from utilities import *``.
    """
    # create_animation() looks up `y_normalizer` at module level. Without this
    # declaration the normalizer built below would stay local to main(), and the
    # animation would be decoded with whatever `y_normalizer` another cell left
    # behind (or fail with NameError on a fresh kernel).
    global y_normalizer

    folder = "/scratch/smrserraoseabr/Projects/FluvialCO2/results32/"
    input_vars = ['Por', 'Perm', 'gas_rate']
    output_vars = ['Pressure']
    device = 'cpu'
    num_files = 1000
    traintest_split = 0.8
    # Number of figure rows. Defined here so main() does not depend on a
    # `num_samples` global created by a different cell.
    num_samples = 10

    dataset = ReadXarrayDataset(
        folder=folder,
        input_vars=input_vars,
        output_vars=output_vars,
        num_files=num_files,
        traintest_split=traintest_split,
    )

    train_a = dataset.train_data_input.to(device)
    train_u = dataset.train_data_output.to(device)
    test_a = dataset.test_data_input.to(device)
    test_u = dataset.test_data_output.to(device)

    # Normalizers are built from the training tensors only.
    a_normalizer = UnitGaussianNormalizer(train_a)
    train_a = a_normalizer.encode(train_a)
    test_a = a_normalizer.encode(test_a)

    y_normalizer = UnitGaussianNormalizer(train_u)
    train_u = y_normalizer.encode(train_u)
    test_u = y_normalizer.encode(test_u)

    path_model = '/scratch/smrserraoseabr/Projects/NO-DA/runs/ns_fourier_3d_N800.0_ep300_m12_w128_INPUT_Por_Perm_gas_rate_OUTPUT_Pressure/model/ns_fourier_3d_N800.0_ep300_m12_w128_pressure'

    # The checkpoint is a whole pickled model, not a state dict.
    # SECURITY: torch.load unpickles arbitrary objects — only load trusted files.
    model = torch.load(path_model)
    model.eval()

    pred = torch.zeros(test_u.shape)
    index = 0

    with torch.no_grad():
        # Slicing keeps the leading batch axis (size 1), which is what a
        # batch_size=1, shuffle=False DataLoader would have produced.
        x_single = test_a[index:index + 1].to(device)
        y_single = test_u[index:index + 1].to(device)
        out_single = y_normalizer.decode(model(x_single))
        y_single = y_normalizer.decode(y_single)

        true_data = y_single.view(1, 61, 32, 32, 1).cpu().numpy()
        predicted_data = out_single.view(1, 61, 32, 32, 1).cpu().numpy()

        # Store the decoded prediction in `pred`. The NumPy array is converted
        # to a tensor explicitly (assigning an ndarray into a tensor slice is
        # not a supported conversion) and all 61 steps are copied at once.
        pred[0, :, :, :, 0] = torch.from_numpy(predicted_data[0, :, :, :, 0])

        # NOTE(review): only one sample is available here, so all but the first
        # of the `num_samples` figure rows stay empty.
        fig, axes = plt.subplots(nrows=num_samples, ncols=3, figsize=(14, 6 * num_samples))
        vmin, vmax = 200.0, 550.0
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

        for row, (true, predicted) in enumerate(zip(true_data, predicted_data)):
            if row >= num_samples:
                break

            img1 = axes[row, 0].imshow(true[-1, :, :, 0], cmap='jet', norm=norm)
            img2 = axes[row, 1].imshow(predicted[-1, :, :, 0], cmap='jet', norm=norm)
            img3 = axes[row, 2].imshow(np.abs(true[-1, :, :, 0] - predicted[-1, :, :, 0]), cmap='jet')

            axes[row, 0].set_title("True Pressure - Last Time Step")
            axes[row, 1].set_title("Predicted Pressure - Last Time Step")
            axes[row, 2].set_title("Absolute Error - Last Time Step")

            axes[row, 0].axis('off')
            axes[row, 1].axis('off')
            axes[row, 2].axis('off')

            plt.colorbar(img1, ax=axes[row, 0], fraction=0.046, pad=0.04)
            plt.colorbar(img2, ax=axes[row, 1], fraction=0.046, pad=0.04)
            plt.colorbar(img3, ax=axes[row, 2], fraction=0.046, pad=0.04)

        plt.tight_layout()
        # Pressure-specific file name: the previous path was the one the CO2
        # cells write to, so this figure overwrote the CO2 figure.
        plt.savefig("/scratch/smrserraoseabr/Projects/NO-DA/plots/ns_fourier_3d_N800.0_ep300_m12_w128_pressure_predictions.png")
        plt.close()

        # The former target "/path/to/output.gif" was a template placeholder
        # pointing at a directory that is not expected to exist; write next to
        # the notebook instead, as the other Pressure cell does.
        # NOTE(review): `dataset.test_data` is assumed to be indexable as
        # test_data[i] -> (input, target); that attribute is not shown — confirm.
        create_animation(dataset.test_data, pred, 0, device, model, "output_pressure.gif")


# Run the Pressure evaluation when this code is executed as the top-level module.
if __name__ == "__main__":
    main()




In [23]:
# CO2: configuration shared by the cells below (device, data folder, dataset
# size/split, variable names and the path of the saved model).
# NOTE(review): both branches of the conditional below are 'cpu', so the CUDA
# check has no effect — confirm whether 'cuda' was intended for the first branch.
device = torch.device('cpu' if torch.cuda.is_available() else 'cpu')
folder = "/scratch/smrserraoseabr/Projects/FluvialCO2/results32/"
num_files= 1000
traintest_split = 0.8

input_vars_CO2 = ['Por', 'Perm', 'gas_rate']
output_vars_CO2 = ['CO_2']

# NOTE(review): the run directory says "ep500" while the file name says "ep300" — confirm the intended checkpoint.
path_model_CO2 = '/scratch/smrserraoseabr/Projects/NO-DA/runs/ns_fourier_3d_N800.0_ep500_m12_w128_INPUT_Por_Perm_gas_rate_OUTPUT_CO_2/model/ns_fourier_3d_N800.0_ep300_m12_w128'

def load_model(path_model, map_location=None):
    """Load a whole pickled model from ``path_model`` and put it in eval mode.

    Args:
        path_model: path of a file written with ``torch.save(model, ...)``
            (a full model object, not a state dict).
        map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to load
            a checkpoint saved on a GPU onto a CPU-only machine. The default
            ``None`` keeps the storage locations recorded in the file.

    Returns:
        The loaded model after ``model.eval()`` has been called.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects — only load
    # checkpoints from trusted sources.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects whole-model pickles; if that applies to the installed version,
    # pass weights_only=False here.
    model = torch.load(path_model, map_location=map_location)
    model.eval()
    return model

model_CO2 = load_model(path_model_CO2)

# Decoded true / predicted CO2 tensors for the train and test splits.
# NOTE(review): load_and_normalize_data is not defined in this notebook
# (presumably it comes from `utilities`). The call hard-codes device='cpu',
# num_files = 100 and traintest_split = 0.8 instead of using the `device`,
# `num_files` and `traintest_split` variables set above — confirm this is intended.
# NOTE(review): the recorded output of this cell is "RuntimeError: shape
# '[1, 61, 32, 32, 1]' is invalid for input of size 4997120"; 4997120 equals
# 80 * 61 * 32 * 32, so a tensor holding 80 samples is probably being viewed as
# a single sample somewhere inside the helper — confirm.
reshaped_decoded_train_u_CO2, reshaped_decoded_train_predicted_CO2, reshaped_decoded_test_u_CO2, reshaped_decoded_test_predicted_CO2 = \
    load_and_normalize_data(input_vars_CO2, output_vars_CO2, model_CO2, device='cpu', folder=folder, num_files = 100, traintest_split = 0.8)





RuntimeError: shape '[1, 61, 32, 32, 1]' is invalid for input of size 4997120

In [None]:
num_samples = 10
path_to_save = '/scratch/smrserraoseabr/Projects/NO-DA/runs/ns_fourier_3d_N800.0_ep500_m12_w128_INPUT_Por_Perm_gas_rate_OUTPUT_CO_2/'

# Plot decoded true vs. predicted CO2 for the test split.
# The original call placed positional arguments after `variable_name=...`,
# which Python rejects with a SyntaxError; the trailing arguments are now
# passed by keyword as well.
# NOTE(review): evaluate_and_plot_results is not defined in this notebook; the
# keyword names `num_samples` and `path_to_save` are assumed to match its
# parameter names — confirm against the helper's signature.
evaluate_and_plot_results(
    reshaped_decoded_test_u_CO2,
    reshaped_decoded_test_predicted_CO2,
    variable_name='CO_2',
    num_samples=num_samples,
    path_to_save=path_to_save,
)

In [14]:
# Display the shape of the CO2 training inputs.
# NOTE(review): `train_a_CO2` is not assigned in any cell shown here — it
# presumably survives from an earlier kernel session; confirm before relying on it.
train_a_CO2.shape

torch.Size([800, 61, 32, 32, 6])

In [None]:
# Perform evaluation and plot results
# NOTE(review): this call passes (model, inputs, targets, normalizer,
# num_samples, path) while another cell calls the same helper with
# (true, predicted, variable_name, num_samples, path) — the two call shapes
# cannot both match one signature; confirm which one is current.
# NOTE(review): test_a_CO2, test_u_CO2 and y_normalizer_CO2 are not assigned in
# any cell shown here.
evaluate_and_plot_results(model_CO2, test_a_CO2, test_u_CO2, y_normalizer_CO2, num_samples, path_to_save)

# Create a GIF
gif_path = f'/scratch/smrserraoseabr/Projects/NO-DA/runs/ns_fourier_3d_N800.0_ep500_m12_w128_INPUT_Por_Perm_gas_rate_OUTPUT_CO_2/images/ns_fourier_3d_N800.0_ep300_m12_w128_INPUT_Por_Perm_gas_rate_OUTPUT_CO_2_model_{num_samples}.gif'
# NOTE(review): create_gif is called twice with different argument lists
# (tensors + gif_path vs. a DataLoader + path_to_save + index); create_gif is
# not defined in this notebook, so at most one of these can match it — confirm
# and drop the stale call.
create_gif(model_CO2, test_a_CO2, test_u_CO2, y_normalizer_CO2, num_samples, gif_path, device)
create_gif(model_CO2, test_loader,  y_normalizer, path_to_save, device, index)


In [None]:
# Pressure: load data, load the saved Pressure model, then evaluate and plot.
input_vars_pressure = ['Por', 'Perm', 'gas_rate']
# NOTE(review): other cells in this notebook request the output variable as
# 'Pressure' (capital P); the lowercase 'pressure' here may not match the
# dataset's variable name — confirm.
output_vars_pressure = ['pressure']
path_model_pressure = '/scratch/smrserraoseabr/Projects/NO-DA/runs/ns_fourier_3d_N800.0_ep300_m12_w128_INPUT_Por_Perm_gas_rate_OUTPUT_Pressure/model/ns_fourier_3d_N800.0_ep300_m12_w128_pressure'

# NOTE(review): here load_and_normalize_data is called with two arguments and
# unpacked into six values, whereas the CO2 cell passes a model plus keyword
# options and unpacks four values — the helper is not visible, so confirm which
# call shape is current. `load_model` must have been defined by an earlier cell.
train_a_pressure, train_u_pressure, test_a_pressure, test_u_pressure, a_normalizer_pressure, y_normalizer_pressure = load_and_normalize_data(input_vars_pressure, output_vars_pressure)
model_pressure = load_model(path_model_pressure)
evaluate_and_plot_results(model_pressure, test_a_pressure, test_u_pressure, y_normalizer_pressure, 'pressure')