In [None]:
# Wildcard imports: every public name exported by the project-local `utils` and
# `HeteTrans` modules is pulled into the notebook namespace.
from utils import *
from HeteTrans import *

# `Heterogeneity` is not defined anywhere in this notebook, so it presumably comes
# from one of the wildcard imports above -- TODO confirm which module and import it by name.
hete = Heterogeneity()
# NOTE(review): judging by the method name this reports whether PyTorch can see a GPU;
# its return value is discarded here -- verify against the class definition.
hete.check_torch_gpu()

***
# END

In [None]:
import os, sys, glob, math, re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyvista as pv
from time import time

from cv2 import resize
from scipy.stats import zscore
from scipy.io import loadmat, savemat
from numpy.matlib import repmat
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary
from torchviz import make_dot
import torchio as tio

In [None]:
class NormalizeTransform:
    """Callable transform that min-max scales both members of an ``(x, y)`` sample.

    ``x`` and ``y`` are scaled independently, each with its own freshly fitted
    ``MinMaxScaler`` (default settings).
    """

    def __call__(self, sample):
        """Return ``(normalize_data(x), normalize_data(y))`` for ``sample = (x, y)``."""
        inputs, targets = sample
        return self.normalize_data(inputs), self.normalize_data(targets)

    def normalize_data(self, data):
        """Scale ``data`` with a MinMaxScaler fitted on ``data`` itself.

        ``data`` must expose ``.numpy()`` (the original author assumed a PyTorch
        tensor). The array is flattened to 2-D rows of length ``data.shape[-1]``
        before fitting, so each position along the last axis is scaled as its own
        feature; the result is reshaped back and returned as a ``torch.Tensor``.
        """
        original = data.numpy()
        # MinMaxScaler only accepts 2-D input: collapse every leading axis into rows.
        as_rows = original.reshape(-1, original.shape[-1])
        scaled_rows = MinMaxScaler().fit_transform(as_rows)
        return torch.Tensor(scaled_rows.reshape(original.shape))

In [None]:
class MyDataset(Dataset):
    """Dataset pairing input arrays in ``<data_folder>/X_data`` with ``.mat`` targets.

    Inputs are read with ``np.load``; targets are the ``.mat`` files found directly
    in ``data_folder`` and are read with ``scipy.io.loadmat``.

    Args:
        data_folder: Folder holding the ``.mat`` target files and an ``X_data``
            sub-folder with the input arrays.
        transform: Optional callable applied to the ``(x, y)`` tuple before it is
            returned.
        n_timesteps: Length of the time axis built for ``x`` and ``y``
            (defaults to 61, the value that used to be hard-coded).
    """

    def __init__(self, data_folder, transform=None, n_timesteps=61):
        self.y_folder = data_folder
        self.x_folder = os.path.join(data_folder, 'X_data')
        # os.listdir returns names in arbitrary order, and the two folders may be
        # enumerated differently. Sort both listings so the idx-th input is always
        # paired with the same target on every run and every machine.
        # NOTE(review): this assumes inputs and targets share a naming scheme that
        # sorts into the same order -- confirm against the actual file names.
        self.y_list = sorted(file for file in os.listdir(self.y_folder) if file.endswith('.mat'))
        self.x_list = sorted(os.listdir(self.x_folder))
        self.transform = transform
        self.n_timesteps = n_timesteps

    def __len__(self):
        # Only indices that have BOTH an input and a target are valid.
        return min(len(self.x_list), len(self.y_list))

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx``.

        ``x`` has shape ``(3, n_timesteps, H, W)``: ``x_data[0]`` and
        ``log10(x_data[1])`` repeated along the time axis, plus a channel whose
        value equals the time-step index. ``y`` has shape ``(2, n_timesteps, H, W)``:
        ``PRESSURE`` and ``SGAS * YMF_3`` from the ``.mat`` file. ``H, W`` are taken
        from the input array.
        """
        x_path = os.path.join(self.x_folder, self.x_list[idx])
        y_path = os.path.join(self.y_folder, self.y_list[idx])
        n_t = self.n_timesteps

        x_data = np.load(x_path)
        poro = np.expand_dims(x_data[0], 0)
        perm = np.expand_dims(np.log10(x_data[1]), 0)
        height, width = poro.shape[-2:]

        static = torch.Tensor(np.concatenate([poro, perm], 0)).unsqueeze(1).repeat(1, n_t, 1, 1)
        # Time channel: t[0, k, :, :] == k. Built explicitly instead of using `.T`
        # on a 3-D tensor, which PyTorch has deprecated for non-2-D tensors.
        t = torch.arange(n_t, dtype=torch.float32).view(1, n_t, 1, 1).expand(1, n_t, height, width)
        x = torch.cat([static, t], dim=0)

        y_data = loadmat(y_path, simplify_cells=True)
        # Convert once instead of once per time step.
        pressure = torch.Tensor(y_data['PRESSURE'])
        sgas_ymf3 = torch.Tensor(y_data['SGAS'] * y_data['YMF_3'])

        y = torch.zeros((2, n_t, height, width))
        # NOTE(review): as in the original loop, the same arrays are written to every
        # time step (they are not indexed by time step). If the .mat fields carry a
        # leading time axis of length n_timesteps, this assignment now copies them
        # step-for-step; otherwise they are broadcast along time -- confirm the layout.
        y[0] = pressure
        y[1] = sgas_ymf3

        sample = (x, y)

        if self.transform:
            sample = self.transform(sample)

        return sample

In [None]:
class PatchEmbedding(nn.Module):
    """Project an image into patch embeddings: strided Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: Number of channels of the input image.
        patch_size: Used as both kernel size and stride of the projection.
        embed_dim: Number of output channels of the projection.
    """

    def __init__(self, in_channels, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        # kernel_size == stride, so the convolution windows tile the input
        # without overlapping.
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.batchnorm = nn.BatchNorm2d(embed_dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return ``relu(batchnorm(projection(x)))``."""
        return self.activation(self.batchnorm(self.projection(x)))

In [None]:
class MultiHeadAttention(nn.Module):
    """Self-attention wrapper around ``nn.MultiheadAttention``.

    Args:
        embed_dim: Embedding dimension passed to ``nn.MultiheadAttention``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        """Attend ``x`` to itself and return only the attention output."""
        # x serves as query, key and value; the second element of the returned
        # tuple is dropped.
        attended, _ = self.attention(x, x, x)
        return attended

In [None]:
class VisionTransformer(nn.Module):
    """Patch embedding -> stack of Transformer encoder layers -> per-patch linear head.

    Args:
        in_channels: Channels of the input image; also the size of the head output.
        patch_size: Patch size forwarded to ``PatchEmbedding``.
        embed_dim: Token width (``d_model`` of every encoder layer).
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
    """

    def __init__(self, in_channels, patch_size, embed_dim, num_heads, num_layers):
        super().__init__()
        self.patch_embedding    = PatchEmbedding(in_channels, patch_size, embed_dim)
        self.transformer_layers = nn.ModuleList([ nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) for _ in range(num_layers) ])
        self.fc                 = nn.Linear(embed_dim, in_channels)
        self.batchnorm          = nn.BatchNorm1d(in_channels)

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to ``(B, in_channels, num_patches)``.

        The previous implementation fed a ``(B, embed_dim, num_patches)`` tensor to
        ``fc``, which only type-checked when ``num_patches == embed_dim``, and then
        handed ``(B, embed_dim, in_channels)`` to ``BatchNorm1d(in_channels)``,
        which raises. The head is now applied over the embedding axis and the
        normalisation over the channel axis, matching how both layers are sized
        in ``__init__``.
        """
        x = self.patch_embedding(x)            # (B, E, H', W')
        # The encoder layers are built with the default batch_first=False,
        # so tokens go first: (N, B, E) with N = H' * W'.
        x = x.flatten(2).permute(2, 0, 1)
        for layer in self.transformer_layers:
            x = layer(x)
        x = x.permute(1, 0, 2)                 # (B, N, E): embedding last for nn.Linear
        x = self.fc(x)                         # (B, N, C)
        x = x.transpose(1, 2)                  # (B, C, N): channels on dim 1 for BatchNorm1d
        x = self.batchnorm(x)
        x = F.gelu(x)
        return x

In [None]:
# Data pipeline: every sample is normalised by NormalizeTransform and batched.
transform  = NormalizeTransform()
dataset    = MyDataset(data_folder='h2dataf', transform=transform)
dataloader = DataLoader(dataset, batch_size=50, shuffle=True)

model      = VisionTransformer(in_channels=3, patch_size=16, embed_dim=256, num_heads=8, num_layers=4)
criterion  = nn.MSELoss()
optimizer  = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs   = 10
train_tsteps = 40
for epoch in range(num_epochs):
    for i, (xbatch, ybatch) in enumerate(dataloader):
        # Keep only the first `train_tsteps` time steps; batches are
        # (batch, channel, time, 256, 256) as produced by MyDataset.
        xbatch = xbatch[:, :, :train_tsteps, :, :]
        ybatch = ybatch[:, :, :train_tsteps, :, :]
        # Remember the real batch size before time is folded into the batch axis.
        batch_size = xbatch.size(0)

        # Fold time into the batch axis so the model sees one 3-channel frame per
        # item. The time axis must be moved next to the batch axis first:
        # reshaping (B, 3, T, H, W) directly would mix channels with time steps.
        frames = xbatch.permute(0, 2, 1, 3, 4).reshape(-1, 3, 256, 256)
        ybatch = ybatch.reshape(-1, 2, train_tsteps, 256, 256)

        outputs = model(frames)
        # Use the real batch size and `train_tsteps` (previously the flattened
        # batch size and a hard-coded 40), so the target shape matches `ybatch`.
        # NOTE(review): this reshape needs the model to emit
        # batch_size * 2 * train_tsteps * 256 * 256 values. The VisionTransformer
        # in this notebook ends in a Linear to `in_channels` per patch, so it
        # likely does not -- a decoder head mapping tokens back to 2 x 256 x 256
        # fields appears to be missing. Also confirm the output ordering
        # (frame-major vs channel-major) before trusting the loss.
        outputs = outputs.reshape(batch_size, 2, train_tsteps, 256, 256)

        loss = criterion(outputs, ybatch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('Epoch: [{}/{}] | Batch: [{}/{}] | Loss: {}'.format(epoch+1, num_epochs, i+1, len(dataloader), loss.item()))

In [None]:
# Plot the three input channels of one sample at every 5th time step.
# The sample is read straight from `dataset` rather than from the training-loop
# variable `xbatch`: after the loop `xbatch` has been flattened to 4-D (so
# `xbatch[55, i, j*5]` is a 1-D row that imshow rejects), and index 55 exceeds
# the batch size of 50 anyway.
sample_idx = min(55, len(dataset) - 1)
x_sample, _ = dataset[sample_idx]   # (channel, time, height, width)

n_channels, n_cols, t_stride = 3, 12, 5
fig, axes = plt.subplots(n_channels, n_cols, figsize=(20, 6))
for i in range(n_channels):
    for j in range(n_cols):
        ax = axes[i, j]
        im = ax.imshow(x_sample[i, j * t_stride], cmap='jet')
        fig.colorbar(im, ax=ax, pad=0.04, fraction=0.046)
        ax.set_xticks([])
        ax.set_yticks([])
fig.tight_layout()
plt.show()