In [1]:
import os

# Work from the project root so the relative data paths used below resolve.
# Set the PROJECT_ROOT environment variable to run this notebook elsewhere;
# the default is the original TACC ls6 location.
PROJECT_ROOT = os.environ.get('PROJECT_ROOT', '/work/08649/mmm6558/ls6/Placement-Control-Optim-CO2')
os.chdir(PROJECT_ROOT)

In [2]:
# NOTE(review): every `!` line runs in its own child shell, so `module purge`
# cannot unload modules from (or change the environment of) this running
# kernel; it has no lasting effect here.
!module purge

In [3]:
# Diagnostic: print $BASH, the bash executable that `!` shell commands run under.
!echo $BASH

/bin/bash


In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.io import loadmat

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split

In [5]:
# ---- Unit conversions and constants ----
sec2year   = 365.25 * 24 * 60 * 60   # seconds in a 365.25-day year
psi2pascal = 6894.76                 # pascals per psi
co2_rho    = 686.5266                # CO2 density (presumably kg/m^3 -- confirm)
mega       = 1e6

# ---- Simulation grid ----
n_timesteps = 33                     # steps after t=0; samples carry n_timesteps + 1 = 34 states
nx, ny, nz  = 100, 100, 11

# Linear (Fortran-order) indices of the active cells, exported from MATLAB.
# NOTE(review): MATLAB indices are 1-based; confirm 'gci' was shifted to 0-based
# before export, otherwise every active cell below is offset by one.
indexMap = loadmat('data_100_100_11/G_cells_indexMap.mat', simplify_cells=True)['gci']

# Activity mask on the full grid: 1 for cells listed in indexMap, 0 elsewhere.
Grid = np.zeros(nx * ny * nz)
Grid[indexMap] = 1
Grid = Grid.reshape(nx, ny, nz, order='F')

# Use the NpzFile as a context manager so its file handle is closed once read.
with np.load('data_npy_100_100_11/tops_grid.npz') as tops_npz:
    Tops = tops_npz['tops']

In [6]:
def check_torch(verbose: bool = True) -> torch.device:
    """Report the PyTorch/CUDA setup and return the device to run on.

    Args:
        verbose: If True, print the torch version, whether CUDA is available
            and, when it is, the number and names of the visible GPUs.

    Returns:
        ``torch.device('cuda')`` when CUDA is available, otherwise
        ``torch.device('cpu')``.
    """
    cuda_avail = torch.cuda.is_available()
    if verbose:
        print('-'*60)
        print('----------------------- VERSION INFO -----------------------')
        print('Torch version: {} | CUDA available? {}'.format(torch.__version__, cuda_avail))
        if cuda_avail:
            count = torch.cuda.device_count()
            # List every visible GPU (not only the current one) to match "Name(s)".
            names = ', '.join(torch.cuda.get_device_name(i) for i in range(count))
            print('# Device(s) available: {}, Name(s): {}'.format(count, names))
        print('-'*60)
    return torch.device('cuda' if cuda_avail else 'cpu')
device = check_torch()

------------------------------------------------------------
----------------------- VERSION INFO -----------------------
Torch version: 2.1.1+cu121 | Torch Built with CUDA? True
# Device(s) available: 3, Name(s): NVIDIA A100-PCIE-40GB
------------------------------------------------------------


In [7]:
class MiONet(nn.Module):
    """Multi-input operator network mapping reservoir inputs to output fields.

    Five inputs are encoded by separate branches and merged with ``einsum``:

    * ``xm`` -- 2-channel 3-D volume (in this notebook: porosity, permeability),
      conv encoder.
    * ``xg`` -- 2-channel 3-D volume (in this notebook: tops, active-cell mask);
      shares the conv and BatchNorm layers with ``xm``.
    * ``xw`` -- ``(batch, n_locs, n_wells)`` well-location features, MLP encoder.
    * ``xc`` -- ``(batch, T, n_wells)`` control sequence, two stacked LSTMs.
    * ``xt`` -- ``(batch, T, 1)`` time sequence, two stacked LSTMs.

    The output has shape ``(batch, T, n_locs, n_cells)``.

    Args:
        hidden_channels: Width of the first layer of every branch.
        output_channels: Width of the last layer of every branch (shared latent size).
        n_wells: Feature size of ``xw`` rows and of each ``xc`` step.
        n_locs: Number of rows of ``xw``; also the size of the third output axis.
        n_pooled: Flattened spatial size of the conv features after the two 2x
            max-pools (25 * 25 * 2 = 1250 for a 100 x 100 x 11 grid).
        n_latent: Hidden width of the output decoder.
        n_cells: Number of output values per field.

    The defaults give the layer names and shapes of the saved ``MiONet.pth``
    checkpoint.
    """

    def __init__(self, hidden_channels=16, output_channels=32, n_wells=5, n_locs=2,
                 n_pooled=1250, n_latent=10000, n_cells=29128):
        super().__init__()
        self.hidden = hidden_channels
        self.output = output_channels

        # Volume encoder, shared by the xm and xg branches.
        self.conv1 = nn.Conv3d(2, self.hidden, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(self.hidden, self.output, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm3d(self.hidden)
        self.norm2 = nn.BatchNorm3d(self.output)
        self.pool  = nn.MaxPool3d(2)
        self.gelu  = nn.GELU()

        # Well-location encoder. BatchNorm1d on a (batch, n_locs, features)
        # tensor normalizes over dim 1, so both layers need n_locs features.
        # (hidden//8 and output//16 only equal 2 for the default widths.)
        self.linW1 = nn.Linear(n_wells, self.hidden)
        self.linW2 = nn.Linear(self.hidden, self.output)
        self.bnW1  = nn.BatchNorm1d(n_locs)
        self.bnW2  = nn.BatchNorm1d(n_locs)

        # Control-sequence encoder.
        self.lstmC1 = nn.LSTM(n_wells, self.hidden, num_layers=1, batch_first=True)
        self.lstmC2 = nn.LSTM(self.hidden, self.output, num_layers=1, batch_first=True)

        # Time-sequence encoder (the "lmst" spelling is kept so checkpoint keys match).
        self.lmstT1 = nn.LSTM(1, self.hidden, num_layers=1, batch_first=True)
        self.lmstT2 = nn.LSTM(self.hidden, self.output, num_layers=1, batch_first=True)

        # Decoder from pooled spatial features to per-cell values.
        self.linY1 = nn.Linear(n_pooled, n_latent)
        self.linY2 = nn.Linear(n_latent, n_cells)

    def _encode_volume(self, v):
        """Apply (conv -> BN -> pool -> GELU) twice and flatten space: (B, output, P)."""
        z = self.gelu(self.pool(self.norm1(self.conv1(v))))
        z = self.gelu(self.pool(self.norm2(self.conv2(z))))
        return z.view(z.shape[0], self.output, -1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Sequence ``(xm, xg, xw, xc, xt)`` of input tensors (see class docstring).

        Returns:
            Tensor of shape ``(batch, T, n_locs, n_cells)``.
        """
        xm, xg, xw, xc, xt = x

        zm = self._encode_volume(xm)                          # (B, C, P)
        zg = self._encode_volume(xg)                          # (B, C, P)

        zw = self.gelu(self.bnW1(self.linW1(xw)))
        zw = self.gelu(self.bnW2(self.linW2(zw)))             # (B, L, C)

        zc, _ = self.lstmC1(xc)
        zc, _ = self.lstmC2(zc)                               # (B, T, C)

        zt, _ = self.lmstT1(xt)
        zt, _ = self.lmstT2(zt)                               # (B, T, C)

        mg = torch.einsum('bcp,bcp->bcp', zm, zg)             # elementwise product
        wc = torch.einsum('blc,btc->btlc', zw, zc)
        branch = torch.einsum('bcp,btlc->btpl', mg, wc)       # contracts over C
        # 'c' appears only in zt here, so zt is summed over its channels and
        # acts as a per-(batch, time) scale on the branch features.
        merge  = torch.einsum('btpl,btc->btlp', branch, zt)   # (B, T, L, P)

        yy = self.gelu(self.linY1(merge))
        return self.linY2(yy)                                 # (B, T, L, n_cells)

In [8]:
class CustomDataset(Dataset):
    """Map-style dataset pairing simulation inputs with their output fields.

    Input and output files are paired by position after sorting each directory
    listing by filename, so corresponding files must sort into the same order.

    Each item is ``((xm, xg, xw, xc, xt), yy)``; every tensor is float32 on the
    module-level ``device``:

    * ``xm`` -- ``poro`` and ``perm`` stacked on a new leading axis.
    * ``xg`` -- module-level ``Tops`` and ``Grid`` stacked on a new leading axis.
    * ``xw`` -- the ``locs`` array.
    * ``xc`` -- ``ctrl`` with a row of zeros prepended.
    * ``xt`` -- ``time`` with 0 prepended, as a column vector.
    * ``yy`` -- ``pressure`` and ``saturation`` stacked on axis 1.

    Raises:
        ValueError: if the input and output folders hold different file counts.
    """

    def __init__(self, data_folder: str = 'data_npy_100_100_11'):
        self.data_folder = data_folder

        self.x_folder = os.path.join(data_folder, 'inputs_rock_rates_locs_time')
        self.y_folder = os.path.join(data_folder, 'outputs_masked_pressure_saturation')

        # os.listdir returns entries in arbitrary order; sort both listings so
        # index i of each refers to the same realization.
        self.x_file_list = sorted(os.listdir(self.x_folder))
        self.y_file_list = sorted(os.listdir(self.y_folder))
        if len(self.x_file_list) != len(self.y_file_list):
            raise ValueError(
                'Mismatched dataset: {} input files vs {} output files'.format(
                    len(self.x_file_list), len(self.y_file_list)))

    def __len__(self):
        return len(self.x_file_list)

    def __getitem__(self, idx):
        x_path = os.path.join(self.x_folder, self.x_file_list[idx])
        y_path = os.path.join(self.y_folder, self.y_file_list[idx])
        # Context managers close each NpzFile once the arrays have been read.
        with np.load(x_path) as x, np.load(y_path) as y:
            poro, perm = x['poro'], x['perm']
            locs, ctrl, time = x['locs'], x['ctrl'], x['time']
            yp, ys = y['pressure'], y['saturation']

        xg = np.stack([Tops, Grid], axis=0)
        xm = np.stack([poro, perm], axis=0)
        xw = locs
        # Prepend a zero-control row and t=0 so controls and times line up.
        xc = np.concatenate([np.zeros((1, xw.shape[-1])), ctrl], axis=0)
        xt = np.expand_dims(np.insert(time, 0, 0), -1)
        yy = np.stack([yp, ys], axis=1)

        def to_tensor(arr):
            return torch.tensor(arr, dtype=torch.float32, device=device)

        inputs = tuple(to_tensor(arr) for arr in (xm, xg, xw, xc, xt))
        return inputs, to_tensor(yy)

In [9]:
# Training and validation loss curves saved during training.
history = pd.read_csv('MiONet_losses.csv')

fig, ax = plt.subplots(figsize=(12,4))
for column, curve_label in (('train', 'Train'), ('valid', 'Valid')):
    ax.plot(history.index, history[column], ls='-', label=curve_label)
ax.legend()
ax.grid(True, which='both')
fig.tight_layout()
fig.savefig('MiONet_losses.png')
plt.show()

In [10]:
# Train / valid / test split with a fixed seed so the split is reproducible.
# NOTE(review): this only matches the split used to train MiONet.pth if training
# used the same seed -- confirm, otherwise "testset" may contain training samples.
SPLIT_SEED = 0
N_TEST, N_VALID = 100, 200

dataset = CustomDataset()
split_gen = torch.Generator().manual_seed(SPLIT_SEED)
trainset, testset  = random_split(dataset,  [len(dataset) - N_TEST, N_TEST], generator=split_gen)
trainset, validset = random_split(trainset, [len(trainset) - N_VALID, N_VALID], generator=split_gen)

trainloader = DataLoader(trainset, batch_size=16, shuffle=True)
validloader = DataLoader(validset, batch_size=16, shuffle=False)
testloader  = DataLoader(testset, batch_size=16, shuffle=False)

In [14]:
# NOTE(review): as with the earlier purge, this runs in a throwaway child shell
# and cannot change the modules or environment of the running kernel.
!module purge

In [15]:
# Diagnostic: print $BASH, the bash executable that `!` shell commands run under.
!echo $BASH

/bin/bash


In [16]:
model = MiONet().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# weights_only=True refuses arbitrary pickled objects (torch.load unpickles).
model.load_state_dict(torch.load('MiONet.pth', map_location=device, weights_only=True))

<All keys matched successfully>

In [17]:
# Take one sample and convert it to NumPy for plotting.
# NOTE(review): this is a training sample, not a held-out one; use testset to
# look at generalization.
x, y = trainset[114]
xm, xg, xw, xc, xt = [part.cpu().detach().numpy() for part in x]
xt = xt / sec2year          # seconds -> years, for plotting
y = y.cpu().detach().numpy()

for tag, arr in zip(('xm', 'xg', 'xw', 'xc', 'xt', 'yy'), (xm, xg, xw, xc, xt, y)):
    print(tag, arr.shape)

xm (2, 100, 100, 11)
xg (2, 100, 100, 11)
xw (2, 5)
xc (34, 5)
xt (34, 1)
yy (34, 2, 29128)


In [18]:
# xm per layer: channel 0 (poro) on the top row, channel 1 (perm) on the bottom row.
fig, axs = plt.subplots(2, 11, figsize=(12,3.5), sharex=True, sharey=True)
for layer in range(11):
    for channel in (0, 1):
        panel = axs[channel, layer]
        image = panel.imshow(xm[channel, :, :, layer], cmap='jet')
        plt.colorbar(image, ax=panel, pad=0.04, fraction=0.046)
plt.tight_layout()
plt.savefig('xm.png')
plt.show()
plt.close()

In [19]:
# xg per layer: channel 0 (Tops) on the top row, channel 1 (active mask) on the bottom row.
fig, axs = plt.subplots(2, 11, figsize=(12,3.5), sharex=True, sharey=True)
for layer in range(11):
    for channel in (0, 1):
        panel = axs[channel, layer]
        image = panel.imshow(xg[channel, :, :, layer], cmap='jet')
        plt.colorbar(image, ax=panel, pad=0.04, fraction=0.046)
plt.tight_layout()
plt.savefig('xg.png')
plt.show()
plt.close()

In [20]:
# Well locations drawn over layer 5 of xm channel 0 (poro).
# NOTE(review): imshow puts array axis 0 on the vertical axis, while xw[0] is
# plotted horizontally here -- confirm which grid axis each row of xw indexes
# (and whether the indices are 0- or 1-based).
fig, ax = plt.subplots(figsize=(8,5))
ax.imshow(xm[0, :, :, 5], cmap='jet')
for well in range(5):
    ax.scatter(xw[0, well], xw[1, well], marker='v', color='k', s=100)
plt.tight_layout()
plt.savefig('xw.png')
plt.show()
plt.close()

In [21]:
# Control schedule as a heatmap: 34 rows (prepended zero row + controls) x 5 wells.
fig, ax = plt.subplots(figsize=(6,8))
heat = ax.imshow(xc, cmap='inferno', vmin=0, vmax=1)
plt.colorbar(heat, ax=ax, pad=0.04, fraction=0.046)
well_labels = ['W%d' % w for w in range(1, 6)]
plt.yticks(range(34), labels=range(34))
plt.xticks(range(5), labels=well_labels)
plt.tight_layout()
plt.savefig('xc.png')
plt.show()
plt.close()

In [22]:
# Per-well injection rate as a piecewise-constant (step) plot over time in years.
# NOTE(review): row t of xc is drawn on [xt[t], xt[t+1]], so the zero row shows
# first and the last control row is never drawn -- confirm whether ctrl[k]
# applies before or after time[k].
colors = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple']
times = xt[:, 0]                    # step boundaries; xt was converted to years
n_steps = len(times) - 1
plt.figure(figsize=(10,5))
for i in range(xc.shape[1]):
    for t in range(n_steps):
        plt.hlines(xc[t, i], times[t], times[t+1], color=colors[i % len(colors)], ls='-')
        # Dashed connector at each rate change; at t=0, xc[t-1] would wrap to
        # the last row and draw a spurious line, so start at t=1.
        if t > 0:
            plt.vlines(times[t], xc[t-1, i], xc[t, i], color='k', ls='--', alpha=0.25)
plt.xlim(0,10)
plt.xticks(np.arange(0,10.5,step=0.5))
plt.xlabel('Time [years]')
plt.ylabel('Injection Rate')
plt.grid(True, which='both')
plt.tight_layout()
plt.savefig('xc2.png')
plt.show()
plt.close()

In [23]:
# Active-cell mask repeated for every saved state (t=0 plus n_timesteps steps).
Grid_ext = np.repeat(Grid[np.newaxis], n_timesteps + 1, axis=0)
print(Grid_ext.shape)

(34, 100, 100, 11)


In [24]:
# Scatter the true active-cell values onto the full nx*ny*nz grid (Fortran
# order, matching Grid/indexMap) for every state, then mask inactive cells.
n_states = y.shape[0]
yp_g = np.zeros((n_states, nx * ny * nz))
ys_g = np.zeros((n_states, nx * ny * nz))
yp_g[:, indexMap] = y[:, 0]          # pressure
ys_g[:, indexMap] = y[:, 1]          # saturation
yp_g = yp_g.reshape((n_states, nx, ny, nz), order='F')
ys_g = ys_g.reshape((n_states, nx, ny, nz), order='F')

yp = np.ma.masked_where(Grid_ext==0, yp_g)
ys = np.ma.masked_where(Grid_ext==0, ys_g)

print(yp.shape, ys.shape)

(34, 100, 100, 11) (34, 100, 100, 11)


In [56]:
# True pressure: rows are layers 5-9, columns are every 2nd state (0, 2, ..., 32).
fig, axs = plt.subplots(5, 17, figsize=(20,10), sharex=True, sharey=True)
for row, row_axes in enumerate(axs):
    for col, panel in enumerate(row_axes):
        panel.imshow(yp[2*col, :, :, 5+row], cmap='jet')
plt.tight_layout()
plt.savefig('yp.png')
plt.show()
plt.close()

In [57]:
# True saturation: rows are layers 5-9, columns are every 2nd state (0, 2, ..., 32).
fig, axs = plt.subplots(5, 17, figsize=(20,10), sharex=True, sharey=True)
for row, row_axes in enumerate(axs):
    for col, panel in enumerate(row_axes):
        panel.imshow(ys[2*col, :, :, 5+row], cmap='jet')
plt.tight_layout()
plt.savefig('ys.png')
plt.show()
plt.close()

In [45]:
# Re-take the sample's tensors (the names were rebound to NumPy arrays for
# plotting) and add a leading batch axis of size 1. xt stays in the dataset's
# original time units, as fed to the model during training.
xm, xg, xw, xc, xt = (part.unsqueeze(0) for part in x)

In [50]:
# eval() makes the BatchNorm layers use their stored running statistics instead
# of this single sample's statistics (and stops updating them); no_grad() skips
# building the autograd graph, which inference does not need.
model.eval()
with torch.no_grad():
    y_pred = model([xm, xg, xw, xc, xt])
yp_pred = y_pred[0, :, 0].cpu().numpy()     # (states, cells) pressure
ys_pred = y_pred[0, :, 1].cpu().numpy()     # (states, cells) saturation
print(yp_pred.shape, ys_pred.shape)

(34, 29128) (34, 29128)


In [51]:
# Scatter the predicted active-cell values onto the full nx*ny*nz grid (Fortran
# order, matching Grid/indexMap), then mask inactive cells. The reshape must use
# the prediction buffers -- reshaping yp_g/ys_g would plot the ground truth.
n_states = yp_pred.shape[0]
yp_pred_g = np.zeros((n_states, nx * ny * nz))
ys_pred_g = np.zeros((n_states, nx * ny * nz))
yp_pred_g[:, indexMap] = yp_pred
ys_pred_g[:, indexMap] = ys_pred
yp_pred_g = yp_pred_g.reshape((n_states, nx, ny, nz), order='F')
ys_pred_g = ys_pred_g.reshape((n_states, nx, ny, nz), order='F')

yp_pred = np.ma.masked_where(Grid_ext==0, yp_pred_g)
ys_pred = np.ma.masked_where(Grid_ext==0, ys_pred_g)

print(yp_pred.shape, ys_pred.shape)

(34, 100, 100, 11) (34, 100, 100, 11)


In [54]:
# Predicted pressure: rows are layers 5-9, columns are every 2nd state (0, 2, ..., 32).
fig, axs = plt.subplots(5, 17, figsize=(20,10), sharex=True, sharey=True)
for row, row_axes in enumerate(axs):
    for col, panel in enumerate(row_axes):
        panel.imshow(yp_pred[2*col, :, :, 5+row], cmap='jet')
plt.tight_layout()
plt.savefig('yp_pred.png')
plt.show()
plt.close()

In [55]:
# Predicted saturation: rows are layers 5-9, columns are every 2nd state (0, 2, ..., 32).
fig, axs = plt.subplots(5, 17, figsize=(20,10), sharex=True, sharey=True)
for row, row_axes in enumerate(axs):
    for col, panel in enumerate(row_axes):
        panel.imshow(ys_pred[2*col, :, :, 5+row], cmap='jet')
plt.tight_layout()
plt.savefig('ys_pred.png')
plt.show()
plt.close()

***
# END