In [9]:
import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import check_torch
# check_torch() comes from the project-local `utils` module; judging by the captured
# output of this cell it prints torch version / CUDA device info. Its return value is
# presumably a torch.device — TODO confirm against utils.check_torch.
# NOTE(review): `device` is not used by any later cell shown here; every tensor and
# model below stays on the CPU. Add `.to(device)` calls if GPU execution is intended.
device = check_torch()

------------------------------------------------------------
----------------------- VERSION INFO -----------------------
Torch version: 2.3.1+cu121 | Torch Built with CUDA? True
# Device(s) available: 1, Name(s): NVIDIA GeForce RTX 3080
------------------------------------------------------------


In [10]:
# Unit conversions and constants kept for later post-processing cells.
sec2year   = 365.25 * 24 * 60 * 60   # seconds in a 365.25-day year
psi2pascal = 6894.76                 # pascals per psi
co2_rho    = 686.5266                # NOTE(review): presumably CO2 density in kg/m^3 — confirm
mega       = 1e6

n_timesteps = 33
nx, ny, nz, nz_short  = 100, 100, 11, 5
z0 = 5   # first of the nz_short vertical layers that are kept (layers z0 .. z0+nz_short-1)

# Active-cell mask: indexMap holds linear (Fortran-order) indices of the active cells.
# NOTE(review): if 'gci' is MRST's G.cells.indexMap exported unchanged, its indices are
# 1-based (MATLAB); confirm the export converted them to 0-based, otherwise every
# active cell here is shifted by one.
indexMap = loadmat('data_100_100_11/G_cells_indexMap.mat', simplify_cells=True)['gci']
Grid = np.zeros(nx * ny * nz)
Grid[indexMap] = 1
Grid = Grid.reshape(nx, ny, nz, order='F')

# `with` closes the underlying .npz file handle as soon as the array has been read.
with np.load('data_npy_100_100_11/tops_grid.npz') as tops_file:
    Tops = tops_file['tops']
print('Grid: {} | Tops: {}'.format(Grid.shape, Tops.shape))

# Keep the modelled layers and tile both masks over every timestep.
Grid_short = Grid[:, :, z0:z0 + nz_short]
Grid_ext = np.repeat(np.expand_dims(Grid, 0), n_timesteps, axis=0)
Grid_short_ext = np.repeat(np.expand_dims(Grid_short, 0), n_timesteps, axis=0)
print('Grid_ext: {} | Grid_short_ext: {}'.format(Grid_ext.shape, Grid_short_ext.shape))

Grid: (100, 100, 11) | Tops: (100, 100, 11)
Grid_ext: (33, 100, 100, 11) | Grid_short_ext: (33, 100, 100, 5)


In [147]:
# Reproducible train/test split: a fixed seed makes Restart & Run All pick the
# same realizations every time.
SEED = 42
rng = np.random.default_rng(SEED)

n_realizations = 1272   # NOTE(review): number of x_{i}/y_{i} files on disk — confirm
n_train        = 5
n_wells        = 5      # matches the captured xw/xc shapes (..., 2, 5) and (..., 33, 5)

train_idx = rng.choice(n_realizations, size=n_train, replace=False)
test_idx  = np.setdiff1d(np.arange(n_realizations), train_idx)

xm = np.zeros((n_train, 3, nx, ny, nz_short))               # channels: [tops, poro, perm]
xw = np.zeros((n_train, 2, n_wells))                        # 'locs': row 0 = x index, row 1 = y index
xc = np.zeros((n_train, n_timesteps, n_wells))              # 'ctrl': presumably one column per well
xt = np.zeros((n_train, n_timesteps, 1))                    # 'time' array
yy = np.zeros((n_train, n_timesteps, 2, nx, ny, nz_short))  # channels: [pressure, saturation]

def apply_mask(x, imap=indexMap, mask_value=0.0):
    """Keep only the active cells of an (nx, ny, nz) array.

    Entries of ``x`` at the Fortran-order linear indices ``imap`` are copied into a
    fresh array; every other entry is set to ``mask_value``. The result has shape
    (nx, ny, nz) and never aliases ``x``.

    NOTE: ``imap`` defaults to the module-level ``indexMap`` as it was when this
    function was defined.
    """
    grid_shape = (nx, ny, nz)
    source = x.ravel(order='F')
    masked = mask_value * np.ones(nx * ny * nz)
    masked[imap] = source[imap]
    return np.reshape(masked, grid_shape, order='F')

# Vertical layers 5..9 are kept, matching the 5:10 slice used for Grid_short.
kept = slice(5, 10)

# The tops map is the same for every realization, so mask + normalize it once.
tops_norm = np.expand_dims(apply_mask(Tops), 0)[..., kept] / Tops.max()

for i, idx in enumerate(train_idx):
    # `with` closes each .npz file handle once its arrays have been read.
    with np.load('data_npy_100_100_11/inputs_rock_rates_locs_time/x_{}.npz'.format(idx)) as m:
        p = np.expand_dims(apply_mask(m['poro']), 0)[..., kept] / 0.3
        k = np.expand_dims(apply_mask(m['perm']), 0)[..., kept] / 3.3
        xw[i] = m['locs']
        xc[i] = m['ctrl']
        xt[i] = m['time']
    xm[i] = np.concatenate([tops_norm, p, k], 0)

    with np.load('data_npy_100_100_11/outputs_pressure_saturation/y_{}.npz'.format(idx)) as dd:
        # Broadcasting assigns the (time, x, y, z) slabs directly.
        yy[i, :, 0] = dd['pressure'][..., kept]
        yy[i, :, 1] = dd['saturation'][..., kept]

inj_locs  = np.zeros((len(train_idx), 1, nx, ny, nz_short))
inj_rates = np.zeros((len(train_idx), 1, nx, ny, nz_short))
inj_times = np.zeros((len(train_idx), 1, nx, ny, nz_short))

# One leading zero row pads the 3 * 33 = 99 tiled time rows up to nx = 100.
pad_row = np.zeros((1, ny))
for i in range(len(train_idx)):
    # Binary well-location mask, set over every kept layer.
    inj_locs[i, 0, xw[i, 0].astype(int), xw[i, 1].astype(int), :] = 1

    # Sample i's own controls tiled into an (nx, ny) image: each timestep spans
    # 3 rows and each well 20 columns; broadcast over the kept layers.
    rate_img = np.repeat(np.repeat(xc[i], 20, axis=-1), 3, axis=0)
    inj_rates[i, 0] = np.concatenate([pad_row, rate_img], axis=0)[..., None]

    # Sample i's own timestamps tiled the same way along rows, constant across columns.
    time_img = np.repeat(np.repeat(xt[i], 3, axis=0), ny, axis=1)
    inj_times[i, 0] = np.concatenate([pad_row, time_img], axis=0)[..., None]

xx = np.concatenate([xm, inj_locs, inj_rates, inj_times], 1)

print('xx', xx.shape)
print('yy', yy.shape)
print('yy', yy.shape)

xx (5, 6, 100, 100, 5)
yy (5, 33, 2, 100, 100, 5)


In [175]:
class SeparableConv3d(nn.Module):
    """Depthwise-separable 3-D convolution.

    A per-channel (``groups=in_channels``) convolution with ``kernel_size`` is followed
    by a 1x1x1 convolution that mixes channels into ``out_channels``. ``stride``,
    ``padding`` and ``dilation`` apply to the depthwise stage only; ``bias`` is
    forwarded to both stages.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        # Spatial filtering: one filter per input channel.
        self.depthwise = nn.Conv3d(in_channels, in_channels, kernel_size,
                                   stride=stride, padding=padding, dilation=dilation,
                                   groups=in_channels, bias=bias)
        # Channel mixing with a 1x1x1 kernel.
        self.pointwise = nn.Conv3d(in_channels, out_channels, 1,
                                   stride=1, padding=0, dilation=1,
                                   groups=1, bias=bias)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise one."""
        return self.pointwise(self.depthwise(x))

In [176]:
class SqueezeExcite3d(nn.Module):
    """Squeeze-and-excitation gate for 5-D ``(batch, channels, D1, D2, D3)`` tensors.

    Each channel is globally average-pooled, passed through a two-layer bottleneck MLP
    (ReLU, then sigmoid), and the resulting per-channel weights ``s`` rescale the input.
    Note the residual form: the output is ``x + x * s`` (i.e. ``x * (1 + s)``), not the
    plain ``x * s`` of the standard SE block.

    Args:
        channels: number of input (and output) channels.
        ratio: bottleneck reduction factor. The hidden width is
            ``max(1, channels // ratio)``, so the gate never collapses to zero units.
    """

    def __init__(self, channels, ratio=4):
        super(SqueezeExcite3d, self).__init__()
        self.ratio = ratio
        # channels < ratio would otherwise give a 0-unit bottleneck, i.e. a constant gate.
        hidden = max(1, channels // ratio)
        self.squeeze = nn.AdaptiveAvgPool3d(1)
        self.excite1 = nn.Linear(channels, hidden)
        self.excite2 = nn.Linear(hidden, channels)

    def forward(self, x):
        b, c, _, _, _ = x.size()
        se_tensor = self.squeeze(x).view(b, c)
        se_tensor = F.relu(self.excite1(se_tensor))
        se_tensor = torch.sigmoid(self.excite2(se_tensor)).view(b, c, 1, 1, 1)
        # Broadcasting over the three spatial dims replaces an explicit expand_as.
        return x + x * se_tensor

In [177]:
class SpatialEncoder(nn.Module):
    """Three-stage convolutional encoder for 5-D ``(batch, channels, D1, D2, D3)`` inputs.

    Stage k (k = 1, 2, 3) applies ``SeparableConv3d -> SqueezeExcite3d`` and records that
    activation, then applies ``GroupNorm -> GELU -> MaxPool3d(pool_size) -> Dropout3d``.
    ``nn.GroupNorm(c, c)`` puts every channel in its own group (per-channel normalization).

    Args:
        in_channels: channels of the input tensor.
        hidden_channels: output channels ``[c1, c2, c3]`` of the three stages.
        return_hidden_states: if True, ``forward`` returns ``(out, (x1, x2, x3))`` where
            ``xk`` is the stage-k activation taken before its norm/pool.
        kernel_size, stride, padding, dilation, bias: forwarded to every SeparableConv3d.
        pool_size: max-pool window applied after each stage.
        dropout_rate: Dropout3d probability applied after each stage.
    """

    def __init__(self, in_channels, hidden_channels:list, return_hidden_states:bool=False,
                 kernel_size=3, stride=1, padding=1, dilation=1, bias=True, 
                 pool_size=(2,2,1), dropout_rate=0.1):
        super().__init__()
        assert len(hidden_channels) == 3, 'Hidden channels must be a list of 3 integers'
        c1, c2, c3 = hidden_channels
        self.return_hidden_states = return_hidden_states

        # Registers conv{k}, sae{k}, norm{k} stage by stage, so parameter names and
        # initialization order match an explicit attribute-by-attribute listing.
        for k, (c_in, c_out) in enumerate(zip((in_channels, c1, c2), (c1, c2, c3)), start=1):
            setattr(self, 'conv{}'.format(k),
                    SeparableConv3d(c_in, c_out, kernel_size, stride, padding, dilation, bias))
            setattr(self, 'sae{}'.format(k), SqueezeExcite3d(c_out))
            setattr(self, 'norm{}'.format(k), nn.GroupNorm(c_out, c_out))

        self.pool = nn.MaxPool3d(pool_size)
        self.gelu = nn.GELU()
        self.drop = nn.Dropout3d(dropout_rate)

    def forward(self, x):
        stages = ((self.conv1, self.sae1, self.norm1),
                  (self.conv2, self.sae2, self.norm2),
                  (self.conv3, self.sae3, self.norm3))
        hidden_states = []
        for conv, sae, norm in stages:
            x = sae(conv(x))
            hidden_states.append(x)
            x = self.drop(self.pool(self.gelu(norm(x))))

        if not self.return_hidden_states:
            return x
        return x, tuple(hidden_states)

In [192]:
# Smoke test: push the first training realization's [tops, poro, perm] through the encoder.
temp_m = torch.from_numpy(xm[:1]).float()
print('temp_m', temp_m.shape)

spatial_encoder = SpatialEncoder(in_channels=3, hidden_channels=[16, 64, 256],
                                 return_hidden_states=True)

zm, hm = spatial_encoder(temp_m)
print('zm', zm.shape)

temp_m torch.Size([1, 3, 100, 100, 5])
zm torch.Size([1, 256, 12, 12, 5])


In [185]:
# Shapes of the well-location, control, time and target arrays.
for name, arr in (('xw', xw), ('xc', xc), ('xt', xt), ('yy', yy)):
    print(name, arr.shape)

xw (5, 2, 5)
xc (5, 33, 5)
xt (5, 33, 1)
yy (5, 33, 2, 100, 100, 5)


In [205]:
class LiftingLayer(nn.Module):
    """Linear projection followed by LayerNorm and GELU.

    Maps ``(..., in_features)`` to ``(..., out_features)``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.norm = nn.LayerNorm(out_features)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return ``gelu(layer_norm(fc(x)))``."""
        return self.gelu(self.norm(self.fc(x)))

In [214]:
class DenseBlock(nn.Module):
    """Stack of ``num_layers`` LiftingLayers applied one after another.

    Despite the name, layers are chained sequentially (no DenseNet-style concatenation).
    The first layer maps ``in_features -> out_features`` and every later layer maps
    ``out_features -> out_features``, so the stack is shape-consistent even when
    ``in_features != out_features``. When the two are equal the layers are exactly the
    same as before.
    """

    def __init__(self, in_features, out_features, num_layers):
        super(DenseBlock, self).__init__()
        # Each layer after the first consumes the previous layer's out_features-wide output.
        self.layers = nn.ModuleList([
            LiftingLayer(in_features if layer_idx == 0 else out_features, out_features)
            for layer_idx in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [233]:
# Smoke test of the lifting layers + dense branches on the first training sample.
latent_dim = 1024

temp_w = torch.tensor(np.expand_dims(xw[0].flatten(), 0), dtype=torch.float32)
temp_c = torch.tensor(np.expand_dims(xc[0].flatten(), 0), dtype=torch.float32)
temp_t = torch.tensor(np.expand_dims(xt[0].flatten(), 0), dtype=torch.float32)

# Input widths are read from the data (2*5, 33*5 and 33 for the current arrays)
# so they stay correct if the number of wells or timesteps changes.
lift_w = LiftingLayer(temp_w.shape[-1], latent_dim)
lift_c = LiftingLayer(temp_c.shape[-1], latent_dim)
lift_t = LiftingLayer(temp_t.shape[-1], latent_dim)

branch_w = DenseBlock(latent_dim, latent_dim, 5)
branch_c = DenseBlock(latent_dim, latent_dim, 5)
trunk_t  = DenseBlock(latent_dim, latent_dim, 5)

print('temp_w', temp_w.shape)
zw = branch_w(lift_w(temp_w))
print('zw', zw.shape)

print('temp_c', temp_c.shape)
zc = branch_c(lift_c(temp_c))
print('zc', zc.shape)

print('temp_t', temp_t.shape)
zt = trunk_t(lift_t(temp_t))
print('zt', zt.shape)

temp_w torch.Size([1, 10])
zw torch.Size([1, 1024])
temp_c torch.Size([1, 165])
zc torch.Size([1, 1024])
temp_t torch.Size([1, 33])
zt torch.Size([1, 1024])


In [235]:
class ConvLSTM3DCell(nn.Module):
    """Single ConvLSTM cell for 5-D ``(batch, channels, D1, D2, D3)`` tensors.

    NOTE(review): this class is defined again in later cells of this notebook, and the
    later definitions silently shadow this one. Keep a single definition. This version
    also provides ``init_hidden``, so ``ConvLSTM3D`` works whichever cell ran last.

    Args:
        input_channels: channels of the input frame ``x``.
        hidden_channels: channels of the hidden/cell states ``h`` and ``c``.
        kernel_size: int kernel of the gate convolution; padding ``kernel_size // 2``
            keeps the spatial size for odd kernels.
        bias: whether the gate convolution uses a bias.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTM3DCell, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.bias = bias

        # One convolution produces all four gates (input, forget, output, candidate).
        self.conv = nn.Conv3d(in_channels=self.input_channels + self.hidden_channels,
                              out_channels=4 * self.hidden_channels,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, x, h, c):
        """One step: returns the next ``(h, c)`` given input ``x`` and current ``h``, ``c``."""
        combined = torch.cat([x, h], dim=1)
        conv_output = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(conv_output, self.hidden_channels, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Zero ``(h, c)`` of shape ``(batch_size, hidden_channels, *image_size)`` on the conv's device."""
        depth, height, width = image_size
        return (torch.zeros(batch_size, self.hidden_channels, depth, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_channels, depth, height, width, device=self.conv.weight.device))

In [238]:
class ConvLSTM3DCell(nn.Module):
    """ConvLSTM cell for 5-D ``(batch, channels, D1, D2, D3)`` tensors.

    NOTE(review): this notebook redefines this class in more than one cell; the last
    executed definition wins. Keep a single definition.

    A single convolution over ``[x, h]`` yields the input, forget and output gates and
    the candidate state; padding ``kernel_size // 2`` keeps the spatial size for odd
    int kernels.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.bias = bias

        self.conv = nn.Conv3d(input_channels + hidden_channels, 4 * hidden_channels,
                              kernel_size=kernel_size, padding=self.padding, bias=bias)

    def forward(self, x, h, c):
        """Advance one step and return ``(h_next, c_next)``."""
        gates = self.conv(torch.cat((x, h), dim=1))
        in_gate, forget_gate, out_gate, candidate = gates.chunk(4, dim=1)
        c_next = torch.sigmoid(forget_gate) * c + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return two fresh zero tensors ``(h, c)`` on the same device as the conv weights."""
        depth, height, width = image_size
        state_shape = (batch_size, self.hidden_channels, depth, height, width)
        target = self.conv.weight.device
        return torch.zeros(state_shape, device=target), torch.zeros(state_shape, device=target)

In [308]:
class ConvLSTM3DCell(nn.Module):
    """Single ConvLSTM cell for 5-D ``(batch, channels, D1, D2, D3)`` tensors.

    Args:
        input_channels: channels of the input frame ``x``.
        hidden_channels: channels of the hidden/cell states ``h`` and ``c``.
        kernel_size: int, or 3-tuple for anisotropic kernels (e.g. ``(3, 3, 1)``). Padding
            is ``k // 2`` per dimension, which preserves the spatial size only for odd
            kernel sizes; even sizes would make ``h``/``c`` change shape between steps.
        bias: whether the gate convolution uses a bias.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTM3DCell, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        # "Same" padding per dimension; tuples are handled so anisotropic kernels work.
        if isinstance(kernel_size, int):
            self.padding = kernel_size // 2
        else:
            self.padding = tuple(k // 2 for k in kernel_size)
        self.bias = bias

        # One convolution produces all four gates (input, forget, output, candidate).
        self.conv = nn.Conv3d(in_channels=self.input_channels + self.hidden_channels,
                              out_channels=4 * self.hidden_channels,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, x, h, c):
        """One step: returns the next ``(h, c)`` given input ``x`` and current ``h``, ``c``."""
        combined = torch.cat([x, h], dim=1)
        conv_output = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(conv_output, self.hidden_channels, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Zero ``(h, c)`` of shape ``(batch_size, hidden_channels, *image_size)`` on the conv's device."""
        depth, height, width = image_size
        return (torch.zeros(batch_size, self.hidden_channels, depth, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_channels, depth, height, width, device=self.conv.weight.device))


class ConvLSTM3D(nn.Module):
    """Stacked ConvLSTM over sequences of 3-D feature volumes.

    Input ``x`` is ``(batch, time, channels, D1, D2, D3)`` when ``batch_first`` is True,
    otherwise ``(time, batch, channels, D1, D2, D3)``. Layer ``i`` is a ConvLSTM3DCell with
    ``hidden_channels[i]`` channels; layer 0 reads ``input_channels`` and every later layer
    reads the previous layer's hidden-state sequence.

    Args:
        input_channels: channels of each input frame.
        hidden_channels: per-layer hidden widths (sequence with at least ``num_layers``
            entries), or a single int used for every layer.
        kernel_size: kernel size shared by all cells.
        num_layers: number of stacked cells.
        batch_first: layout of ``x`` (see above).
        bias: whether the cells' convolutions use a bias.
        return_all_layers: return outputs/states of every layer, not only the last.

    ``forward`` returns ``(layer_outputs, last_states)``: ``layer_outputs`` holds
    ``(batch, time, hidden, D1, D2, D3)`` tensors (always batch-first) and
    ``last_states`` holds ``[h, c]`` pairs, one per returned layer. When
    ``num_layers == 1`` the single entries are returned unwrapped rather than in lists.

    Raises:
        ValueError: if ``hidden_channels`` has fewer than ``num_layers`` entries.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, num_layers, batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM3D, self).__init__()

        # A single int means "same width for every layer".
        if isinstance(hidden_channels, int):
            hidden_channels = [hidden_channels] * num_layers
        if len(hidden_channels) < num_layers:
            raise ValueError('hidden_channels has {} entries but num_layers is {}'.format(
                len(hidden_channels), num_layers))

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_channels = self.input_channels if i == 0 else self.hidden_channels[i - 1]
            cell_list.append(ConvLSTM3DCell(input_channels=cur_input_channels,
                                            hidden_channels=self.hidden_channels[i],
                                            kernel_size=self.kernel_size,
                                            bias=self.bias))
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, x, hidden_state=None):
        if not self.batch_first:
            # (time, batch, ...) -> (batch, time, ...)
            x = x.permute(1, 0, 2, 3, 4, 5)

        # Spatial sizes get distinct names so they are not confused with the hidden state h.
        batch_size, seq_len, _, depth, height, width = x.size()
        if hidden_state is None:
            hidden_state = self._init_hidden(batch_size=batch_size,
                                             image_size=(depth, height, width))

        layer_output_list = []
        last_state_list = []
        cur_layer_input = x

        for layer_idx, cell in enumerate(self.cell_list):
            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = cell(x=cur_layer_input[:, t], h=h, c=c)
                output_inner.append(h)
            # The full hidden sequence of this layer is the next layer's input.
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        if self.num_layers == 1:
            layer_output_list = layer_output_list[0]
            last_state_list = last_state_list[0]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        """Zero ``(h, c)`` pairs for every layer."""
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states


In [349]:
# Smoke test: a 33-step sequence of 256-channel 12x12x5 volumes, the same size the
# SpatialEncoder produced above.
model = ConvLSTM3D(
    input_channels=256,
    hidden_channels=[64, 16, 4],
    kernel_size=3,
    num_layers=3,
    batch_first=True,
    bias=True,
    return_all_layers=True,
)

input_tensor = torch.rand((1, 33, 256, 12, 12, 5))

zd, hd = model(input_tensor)