# Heterogeneity Image Translation Transformer

### Multiscale Residual Spatiotemporal Vision Transformer (MRSt-ViT | PixFormer)
***

In [1]:
# Pull every public name of the project's `main` module into the notebook namespace
from main import *

# Create the experiment driver with its default settings
hete = Heterogeneity()
# Display the instance attributes (device, data folder, learning rate, loss weights,
# split fractions, batch size and epoch count appear in the cell output)
hete.__dict__

{'device': device(type='cuda'),
 'verbose': True,
 'folder': 'Fdataset',
 'lr': 0.001,
 'weight_decay': 1e-05,
 'mse_weight': 1.0,
 'ssim_weight': 1.0,
 'train_perc': 0.75,
 'valid_perc': 0.1,
 'batch_size': 32,
 'num_epochs': 100}

In [2]:
# Report the Torch version, whether the build has CUDA, and the visible GPU device(s)
hete.check_torch_gpu()


------------------------------------------------------------
----------------------- VERSION INFO -----------------------
Torch version: 2.1.0+cu121
Torch build with CUDA? True
# Device(s) available: 1, Name(s): NVIDIA GeForce RTX 3080
------------------------------------------------------------



In [3]:
# Build the train / validation / test dataloaders
# NOTE(review): presumably split according to hete.train_perc / hete.valid_perc — confirm in main.py
train_dataloader, valid_dataloader, test_dataloader = hete.make_dataloaders()

In [4]:
# Run training; the call prints the number of trainable parameters (see cell output)
# NOTE(review): the four metric returns look like per-epoch loss / SSIM histories — confirm in main.py
model, train_loss, valid_loss, train_ssim, valid_ssim = hete.trainer()

------------------------------------------------------------
Total number of trainable parameters: 185,746,560


***
# END

In [2]:
import os, time, math
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
from torchvision.ops import SqueezeExcitation
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM

In [3]:
class CustomLoss(nn.Module):
    '''
    Combined image loss: L = mse_weight * MSE + ssim_weight * (1 - SSIM)

    The SSIM metric is moved to CUDA when it is available, otherwise it stays on the CPU.
    '''
    def __init__(self, mse_weight=1.0, ssim_weight=1.0):
        super().__init__()
        has_cuda = torch.cuda.is_available()
        self.device      = torch.device('cuda' if has_cuda else 'cpu')   # device that hosts the SSIM metric
        self.mse_weight  = mse_weight                                    # multiplier of the MSE term
        self.ssim_weight = ssim_weight                                   # multiplier of the (1 - SSIM) term
        self.mse_loss    = nn.MSELoss()
        self.ssim        = SSIM().to(self.device)

    def forward(self, pred, target):
        '''Return the weighted sum of the MSE term and the SSIM dissimilarity term.'''
        mse_term  = self.mse_weight * self.mse_loss(pred, target)
        ssim_term = self.ssim_weight * (1.0 - self.ssim(pred, target))   # SSIM is a similarity, so use 1 - SSIM
        return mse_term + ssim_term

In [4]:
class PatchEmbedding(nn.Module):
    '''
    Turn an image into a sequence of patch tokens for the vision transformer.

    A convolution whose kernel size and stride both equal patch_size cuts the image
    into non-overlapping patches and projects each one to embed_dim features.
    '''
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size = image_size   # kept for reference; not used in forward
        self.patch_size = patch_size   # side length of one square patch
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        feature_map = self.projection(x)                         # (b, embed_dim, h', w') grid of patch features
        return rearrange(feature_map, 'b c h w -> b (h w) c')    # flatten the grid into a token sequence
    
class PositionalEncoding(nn.Module):
    '''
    Fixed sinusoidal positional encoding added to a sequence of patch tokens.

    Even feature indices hold sin(pos * w_k) and odd indices hold cos(pos * w_k),
    with w_k = 10000^(-2k / embed_dim). The table is stored as a (non-trainable)
    buffer of shape (1, max_seq_len, embed_dim).

    Args:
        embed_dim:   feature size of each token (odd values are supported)
        max_seq_len: number of positions precomputed; inputs must not be longer
    '''
    def __init__(self, embed_dim, max_seq_len=512):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()   # (max_seq_len, 1) position indices
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        angles   = position * div_term                                 # (max_seq_len, ceil(embed_dim / 2))
        pos_enc  = torch.zeros((1, max_seq_len, embed_dim))            # instantiate empty table
        pos_enc[0, :, 0::2] = torch.sin(angles)                        # sine on even feature indices
        # An odd embed_dim has one fewer odd index than even indices, so trim the
        # angle table to the number of cosine slots to avoid a shape mismatch.
        pos_enc[0, :, 1::2] = torch.cos(angles[:, :embed_dim // 2])    # cosine on odd feature indices
        self.register_buffer('pos_enc', pos_enc)                       # buffer: follows .to(device), no gradient
    def forward(self, x):
        '''Add the encoding of the first x.size(1) positions to x of shape (b, seq, embed_dim).'''
        return x + self.pos_enc[:, :x.size(1)]

In [5]:
class MultiHeadAttention(nn.Module):
    '''
    Scaled dot-product multi-head attention with learned Q/K/V and output projections.

    Args:
        embed_dim: feature size of the input and output tokens
        num_heads: number of attention heads; must divide embed_dim exactly

    Raises:
        ValueError: if num_heads is not positive or does not divide embed_dim
    '''
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Without this check a non-divisible embed_dim makes .view() either fail
        # late in forward or silently regroup features across tokens.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError('embed_dim ({}) must be divisible by a positive num_heads ({})'.format(embed_dim, num_heads))
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim  = embed_dim // num_heads                # features per head
        self.query  = nn.Linear(embed_dim, embed_dim)          # query projection
        self.key    = nn.Linear(embed_dim, embed_dim)          # key projection
        self.value  = nn.Linear(embed_dim, embed_dim)          # value projection
        self.fc_out = nn.Linear(embed_dim, embed_dim)          # output projection

    def _split_heads(self, x, batch_size):
        '''Reshape (b, seq, embed_dim) -> (b, heads, seq, head_dim).'''
        return x.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value):
        '''
        Attend from `query` tokens to `key`/`value` tokens.
        Inputs are (b, seq, embed_dim); the result is (b, seq_query, embed_dim).
        '''
        batch_size = query.shape[0]
        Q = self._split_heads(self.query(query), batch_size)
        K = self._split_heads(self.key(key), batch_size)
        V = self._split_heads(self.value(value), batch_size)
        # Python-float scale: no per-call tensor allocation, no device/dtype concerns
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)                                     # normalise over the key axis
        out = torch.matmul(attention_weights, V)                                          # weighted sum of values
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.embed_dim)   # merge heads back
        return self.fc_out(out)

In [6]:
class MLPBlock(nn.Module):
    '''
    Two-layer feed-forward block of the vision transformer:
    embed_dim -> mlp_hidden_dim -> embed_dim with a GELU in between
    '''
    def __init__(self, embed_dim, mlp_hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, mlp_hidden_dim)   # expansion layer
        self.fc2 = nn.Linear(mlp_hidden_dim, embed_dim)   # projection back to the token size

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(F.gelu(hidden))                   # no activation on the output

In [7]:
class TransformerEncoderBlock(nn.Module):
    '''
    One transformer encoder block: self-attention and an MLP, each wrapped in a
    residual connection that is followed by LayerNorm (normalisation after the add)
    '''
    def __init__(self, embed_dim, num_heads, mlp_hidden_dim=1024):
        super().__init__()
        self.self_attention = MultiHeadAttention(embed_dim, num_heads)
        self.mlp_block = MLPBlock(embed_dim, mlp_hidden_dim)
        self.norm1 = nn.LayerNorm(embed_dim)   # applied after the attention residual
        self.norm2 = nn.LayerNorm(embed_dim)   # applied after the MLP residual

    def forward(self, x):
        # tokens attend to themselves: query, key and value are all x
        attended = self.norm1(x + self.self_attention(x, x, x))
        return self.norm2(attended + self.mlp_block(attended))

In [8]:
class ViTencoder(nn.Module):
    '''
    Vision-transformer encoder: patch embedding -> positional encoding ->
    num_layers transformer blocks -> average over tokens -> linear projection.

    The output is a flat vector of projection_dim * latent_size * latent_size
    features per sample.

    Args:
        image_size:     side length of the (square) input image
        latent_size:    side length of the latent feature map the output encodes
        in_channels:    number of input image channels
        patch_size:     side length of one square patch
        projection_dim: number of channels of the latent feature map
        embed_dim:      token feature size inside the transformer
        num_heads:      attention heads per transformer block
        num_layers:     number of transformer blocks
    '''
    def __init__(self, image_size=256, latent_size=32, in_channels=3, patch_size=16, projection_dim=256, embed_dim=1024, num_heads=16, num_layers=8):
        super(ViTencoder, self).__init__()
        self.patch_embedding     = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)   # patch embedding
        # Size the positional table for the actual token count. The previous fixed
        # 512 positions broke any configuration with more patches (e.g. 256 / 8).
        # Keeping 512 as the minimum leaves the default buffer shape unchanged.
        num_patches = (image_size // patch_size) ** 2
        self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len=max(512, num_patches)) # positional encoding
        self.transformer_blocks  = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads) for _ in range(num_layers)])              # N Transformer blocks
        self.global_avg_pooling  = nn.AdaptiveAvgPool1d(1)                                          # global average pooling
        self.fc = nn.Linear(embed_dim, projection_dim*latent_size*latent_size)                      # fully connected layer
    def forward(self, x):
        x = self.patch_embedding(x)                                                                 # (b, tokens, embed_dim)
        x = self.positional_encoding(x)                                                             # add position information
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x)                                                                # transformer block(s)
        x = self.global_avg_pooling(x.transpose(1, 2))                                              # average over tokens -> (b, embed_dim, 1)
        x = x.squeeze(2)                                                                            # (b, embed_dim)
        x = self.fc(x)                                                                              # linear projection (no activation)
        return x

In [9]:
class MultiScaleResidual(nn.Module):
    '''
    Multiscale residual block: concatenates the input with a resized copy of
    itself along the channel axis, doubling the channel count
    '''
    def __init__(self, image_size=256):
        super().__init__()
        self.image_size = image_size   # reference (full-resolution) image size

    def forward(self, x):
        _batch, _channels, height, width = x.shape
        # integer ratio between the reference size and the current feature-map size
        ratio_h = self.image_size // height
        ratio_w = self.image_size // width
        # When height/width divide image_size exactly this equals (height, width).
        # NOTE(review): a feature map larger than image_size gives a ratio of 0 and a
        # ZeroDivisionError here — confirm that case never occurs.
        target = (self.image_size // ratio_h, self.image_size // ratio_w)
        resizer = transforms.Resize(target, antialias=True)
        return torch.cat([x, resizer(x)], dim=1)

In [10]:
class PixFormer(nn.Module):
    '''
    PixFormer image-to-image model:
    (1) Vision Transformer encoder producing a flat latent vector
    (2) Convolutional decoder of three multiscale-residual blocks, each halving the
        channel count and doubling the spatial size, followed by a 1-channel output conv
    '''
    def __init__(self, projection_dim=128, latent_size=32):
        super().__init__()
        self.projection_dim = projection_dim
        self.latent_size    = latent_size
        self.encoder = ViTencoder(latent_size=latent_size, projection_dim=projection_dim)
        # channel widths along the decoder: C, C/2, C/4, C/8
        widths = [projection_dim // divisor for divisor in (1, 2, 4, 8)]
        self.layers = nn.Sequential(
            *[self._conv_block(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])])
        self.out = nn.Conv2d(widths[-1], 1, kernel_size=3, padding=1)

    def _conv_block(self, in_channels, out_channels):
        '''Build one decoder block mapping in_channels -> out_channels at twice the resolution.'''
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            SqueezeExcitation(out_channels, out_channels//4),                   # channel re-weighting
            nn.InstanceNorm2d(out_channels),
            nn.PReLU(),
            MultiScaleResidual(),                                               # doubles the channel count
            nn.Upsample(scale_factor=2),
            nn.Conv2d(out_channels*2, out_channels, kernel_size=3, padding=1),  # fuse the concatenated channels
        ]
        return nn.Sequential(*stages)

    def forward(self, x):
        latent = self.encoder(x)                                                # z = Enc(x), flat vector
        # unflatten the latent vector into (b, projection_dim, latent_size, latent_size)
        features = latent.view(-1, self.projection_dim, self.latent_size, self.latent_size)
        for stage in self.layers:
            features = stage(features)                                          # y = Dec(z), one block at a time
        return self.out(features)

In [11]:
class MyDataset(Dataset):
    '''
    Generate a custom dataset from .npz files
    (x) porosity, permeability, timesteps
    (y) pressure, saturation

    Each item is one file, returned as a pair of float32 tensors:
    X of shape (num_timesteps, 3, image_size, image_size) and
    y of shape (num_timesteps, 2, image_size, image_size).

    Args:
        file_paths:    sequence of paths to .npz files holding the arrays
                       'poro', 'perm', 'pres' and 'sat'
        num_timesteps: number of time steps stored per file (default 60)
        image_size:    side length of the square grid (default 256)
    '''
    def __init__(self, file_paths, num_timesteps=60, image_size=256):
        self.file_paths    = file_paths
        self.num_timesteps = num_timesteps
        self.image_size    = image_size
    def __len__(self):
        return len(self.file_paths)
    def __getitem__(self, idx):
        steps, size = self.num_timesteps, self.image_size
        # np.load keeps the archive open until it is closed; the context manager
        # releases the file handle as soon as the arrays have been read.
        with np.load(self.file_paths[idx]) as data:
            poro = np.tile(data['poro'], (steps, 1, 1, 1))                 # static porosity repeated per timestep
            perm = np.tile(data['perm'], (steps, 1, 1, 1))                 # static permeability repeated per timestep
            pres = data['pres'].reshape(steps, 1, size, size)              # pressure channel
            sat  = data['sat'].reshape(steps, 1, size, size)               # saturation channel
        # time channel: every pixel of frame t holds the 1-based timestep index t
        timesteps = np.tile(np.arange(1, steps + 1).reshape(steps, 1, 1, 1), (1, 1, size, size))
        X_data = np.concatenate([poro, perm, timesteps], axis=1).reshape(-1, 3, size, size)   # Inputs  (X)
        y_data = np.concatenate([pres, sat], axis=1).reshape(-1, 2, size, size)               # Outputs (y)
        return (torch.as_tensor(X_data, dtype=torch.float32),
                torch.as_tensor(y_data, dtype=torch.float32))

In [12]:
class MyDataLoader(DataLoader):
    '''
    DataLoader that keeps only the time window belonging to its mode and merges
    the batch and time axes of every batch
    (train): x,y at timesteps 0-40
    (valid): x,y at timesteps 40-50
    (test):  x,y at timesteps 50-60
    '''
    def __init__(self, *args, mode:str=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.mode = mode

    def __iter__(self):
        # time-axis window selected by each mode
        windows = {
            'train': slice(None, 40),
            'valid': slice(40, 50),
            'test':  slice(50, None),
        }

        def merge_batch_and_time(tensor):
            # (b, t, c, h, w) -> (b*t, c, h, w)
            return tensor.reshape(-1, tensor.size(2), tensor.size(3), tensor.size(4))

        for X_data, y_data in super().__iter__():     # each batch has shape (b, t, c, h, w)
            window = windows.get(self.mode)
            if window is None:                        # mode is checked once a batch has been loaded
                raise ValueError('Invalid mode: {} | select between "train", "valid" or "test"'.format(self.mode))
            yield merge_batch_and_time(X_data[:, window]), merge_batch_and_time(y_data[:, window])