In [None]:
import os, time, math
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

def check_torch_gpu():
    """Print a short PyTorch/CUDA summary and return the device to compute on.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    torch_version, cuda_avail = torch.__version__, torch.cuda.is_available()
    count = torch.cuda.device_count()
    # get_device_name() raises when no CUDA device is usable, so device names
    # are only queried when CUDA is actually available; otherwise the CPU
    # fallback below would never be reached.
    if cuda_avail and count > 0:
        name = ', '.join(torch.cuda.get_device_name(idx) for idx in range(count))
    else:
        name = 'N/A'
    print('\n'+'-'*60)
    print('----------------------- VERSION INFO -----------------------')
    print('Torch version: {}'.format(torch_version))
    print('Torch build with CUDA? {}'.format(cuda_avail))
    print('# Device(s) available: {}, Name(s): {}'.format(count, name))
    print('-'*60+'\n')
    device = torch.device('cuda' if cuda_avail else 'cpu')
    return device
device = check_torch_gpu()

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into patches and linearly embed each patch.

    The projection is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``, so every output location sees exactly one patch.

    Args:
        image_size: stored on the module; not used by the computation here.
        patch_size: side length of a patch (kernel size and stride of the conv).
        in_channels: number of channels of the input image.
        embed_dim: number of output channels, i.e. the embedding width.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size, self.patch_size = image_size, patch_size
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map a 4-D image batch to a token sequence of shape (b, h*w, c)."""
        feature_map = self.projection(x)
        # Flatten the spatial grid into one token axis and move channels last.
        return rearrange(feature_map, 'b c h w -> b (h w) c')

In [None]:
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to a token sequence.

    A table of shape ``(1, max_seq_len, embed_dim)`` is precomputed and stored
    as the buffer ``pos_enc``: even feature indices hold sines, odd feature
    indices hold cosines.

    Args:
        embed_dim: feature width of the tokens (odd values are supported).
        max_seq_len: number of positions in the precomputed table.
    """
    def __init__(self, embed_dim, max_seq_len=512):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        angles = position * div_term
        pos_enc  = torch.zeros((1, max_seq_len, embed_dim))
        pos_enc[0, :, 0::2] = torch.sin(angles)
        # For an odd embed_dim there is one cosine column fewer than sine
        # columns, so the angle table is trimmed to the number of odd indices.
        pos_enc[0, :, 1::2] = torch.cos(angles[:, :embed_dim // 2])
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Raises:
            RuntimeError: if the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_enc.size(1):
            # Same exception type as the broadcast failure this replaces, but
            # with a message that names the actual cause.
            raise RuntimeError(
                'sequence length {} exceeds max_seq_len {} of the positional encoding'.format(
                    seq_len, self.pos_enc.size(1)))
        # Buffers carry no gradient, so no detach is required here.
        return x + self.pos_enc[:, :seq_len]

In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over several heads.

    Args:
        embed_dim: feature width of queries, keys, values and of the output.
        num_heads: number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Without this check the head reshape in forward() either fails with an
        # unrelated-looking error or, for some sequence lengths, succeeds and
        # silently mixes features of different tokens.
        if embed_dim % num_heads != 0:
            raise ValueError(
                'embed_dim ({}) must be divisible by num_heads ({})'.format(embed_dim, num_heads))
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim  = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key   = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, embed_dim) to (batch, heads, seq, head_dim)."""
        return projected.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``.

        Inputs are reshaped as (batch, seq, embed_dim); the result has the
        batch size and sequence length of ``query`` and width ``embed_dim``.
        """
        batch_size = query.shape[0]
        Q = self._split_heads(self.query(query), batch_size)
        K = self._split_heads(self.key(key), batch_size)
        V = self._split_heads(self.value(value), batch_size)
        # A plain Python scalar avoids allocating a tensor on every call.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        out = torch.matmul(attention_weights, V)
        # Merge the heads back into a single feature axis.
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.embed_dim)
        out = self.fc_out(out)
        return out

In [None]:
class MLPBlock(nn.Module):
    """Two-layer feed-forward network with a GELU in between.

    Args:
        embed_dim: width of the input and of the output.
        mlp_hidden_dim: width of the hidden layer.
    """

    def __init__(self, embed_dim, mlp_hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, mlp_hidden_dim)  # expand
        self.fc2 = nn.Linear(mlp_hidden_dim, embed_dim)  # project back

    def forward(self, x):
        """Apply ``fc1``, GELU, then ``fc2`` to the last axis of ``x``."""
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One encoder block: self-attention followed by an MLP.

    Each sub-layer output is added to its input and the sum is then passed
    through a LayerNorm (normalisation is applied after the residual add).

    Args:
        embed_dim: token feature width.
        num_heads: number of attention heads.
        mlp_hidden_dim: hidden width of the MLP sub-layer.
    """

    def __init__(self, embed_dim, num_heads, mlp_hidden_dim=1024):
        super().__init__()
        self.self_attention = MultiHeadAttention(embed_dim, num_heads)
        self.mlp_block = MLPBlock(embed_dim, mlp_hidden_dim)
        self.norm1, self.norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run attention and MLP sub-layers, each with residual add + norm."""
        # Query, key and value are all the same sequence (self-attention).
        attended = self.norm1(x + self.self_attention(x, x, x))
        return self.norm2(attended + self.mlp_block(attended))

In [None]:
class ViTencoder(nn.Module):
    """Patch-based transformer encoder that maps an image to a flat vector.

    Pipeline: patch embedding -> sinusoidal positional encoding ->
    ``num_layers`` encoder blocks -> mean over the token axis -> linear head.

    Args:
        image_size: nominal input size, an int or an (height, width) pair.
        in_channels: number of image channels.
        patch_size: patch side length, an int or an (height, width) pair.
        num_classes: width of the output vector produced by the linear head.
        embed_dim: token feature width.
        num_heads: attention heads per encoder block.
        num_layers: number of encoder blocks.
    """
    def __init__(self, image_size=256, in_channels=3, patch_size=16, num_classes=32*32*128, embed_dim=1024, num_heads=16, num_layers=8):
        super(ViTencoder, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        # The positional table must cover every patch of the configured image.
        # It is never made smaller than 512 positions so that configurations
        # that already worked keep the same `pos_enc` buffer shape.
        img_h, img_w = self._as_pair(image_size)
        patch_h, patch_w = self._as_pair(patch_size)
        num_patches = (img_h // patch_h) * (img_w // patch_w)
        self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len=max(512, num_patches))
        self.transformer_blocks = nn.ModuleList([TransformerEncoderBlock(embed_dim, num_heads) for _ in range(num_layers)])
        self.global_avg_pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(embed_dim, num_classes)

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a (height, width) pair; a scalar is duplicated."""
        try:
            first, second = value
        except TypeError:
            first = second = value
        return first, second

    def forward(self, x):
        """Encode an image batch into a tensor of shape (batch, num_classes)."""
        x = self.patch_embedding(x)
        x = self.positional_encoding(x)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x)
        # AdaptiveAvgPool1d pools the last axis, so tokens are moved there
        # first: (batch, tokens, embed) -> (batch, embed, tokens) -> (batch, embed, 1).
        x = self.global_avg_pooling(x.transpose(1, 2))
        x = x.squeeze(2)
        x = self.fc(x)
        return x

In [None]:
class pixFormer(nn.Module):
    """ViT encoder whose flat output is reshaped into a stack of feature maps.

    Args:
        out_channels: number of output feature maps per sample.
        out_size: side length of each (square) output feature map.

    The encoder head is sized as ``out_channels * out_size * out_size`` so that
    it always matches the reshape in ``forward``. Previously the encoder's
    default head produced 32*32*128 values per sample while the reshape
    consumed 256*32*32, which merged every two samples into one output (and
    failed for odd batch sizes). The default of 256 channels keeps the
    per-sample output shape (256, 32, 32) that the original reshape declared.
    """
    def __init__(self, out_channels=256, out_size=32):
        super(pixFormer, self).__init__()
        self.out_channels = out_channels
        self.out_size = out_size
        self.Tencoder = ViTencoder(num_classes=out_channels * out_size * out_size)

    def forward(self, x):
        """Return a tensor of shape (batch, out_channels, out_size, out_size)."""
        x = self.Tencoder(x)
        # The batch axis is fixed explicitly so a size mismatch raises instead
        # of being absorbed into the batch dimension.
        x = x.view(x.size(0), self.out_channels, self.out_size, self.out_size)
        return x

In [None]:
# Smoke test: push one random batch through the model to inspect output shapes.
model = pixFormer()
input_tensor = torch.rand((32, 3, 256, 256))
# Only the output is inspected, so gradient tracking is switched off to avoid
# keeping the activations of every layer alive for a batch of 32 images.
with torch.no_grad():
    output = model(input_tensor)

In [None]:
# Report the tensor shapes going into and coming out of the model.
print(f'Inputs: {input_tensor.shape} | Outputs: {output.shape}')

In [None]:
# Scratch check: 262144 values spread over a 32x32 map give 256.0 channels.
262144 / (32 * 32)

In [None]:
# Imported explicitly instead of relying on `transforms.v2` being reachable as
# an attribute of the already-imported `torchvision.transforms` module.
from torchvision.transforms import v2

# Load the first 10 porosity realizations into one (10, 256, 256) tensor.
porosity = torch.zeros((10,256,256))
for i in range(10):
    # Sample files are numbered from 1. The archive is opened in a `with`
    # block so its file handle is closed once the array has been read.
    with np.load('Fdataset/sample_{}.npz'.format(i+1)) as sample_file:
        porosity[i] = torch.Tensor(sample_file['poro'])
print('Porosity:', porosity.shape)

sample_number = 3
# Add leading batch and channel axes: (256, 256) -> (1, 1, 256, 256).
sample_poro = porosity[sample_number].unsqueeze(0).unsqueeze(0)
print('Sample poro:', sample_poro.shape)

# Original field plus three progressively coarser, antialiased resamplings.
poro_0 = sample_poro
poro_1, poro_2, poro_3 = (
    v2.Resize(size=(side, side), antialias=True)(sample_poro)
    for side in (128, 64, 32)
)
poro_levels = [poro_0, poro_1, poro_2, poro_3]

plt.figure(figsize=(15,4))
# Iterate over an explicit list rather than eval()-ing variable names.
for i, level in enumerate(poro_levels):
    data = level.squeeze()
    plt.subplot(1,4,i+1)
    plt.imshow(data, cmap='jet')
    plt.title('upscale_{} | {}x{}'.format(i, data.shape[0], data.shape[1]))
    plt.colorbar(pad=0.04, fraction=0.046)
    plt.xticks([]); plt.yticks([])
plt.suptitle('Realization {}'.format(sample_number), weight='bold')
plt.tight_layout(); plt.show()

In [None]:
import torch
import torch.nn as nn

class ImageTransformer(nn.Module):
    """Stack of four transformer encoder layers with model width 256.

    The input is passed straight to ``nn.TransformerEncoder``; the layers are
    built with ``batch_first=True``, 8 heads and GELU activation.
    """

    def __init__(self):
        super().__init__()
        layer_options = dict(d_model=256, nhead=8, activation='gelu', batch_first=True)
        # The template layer stays registered under `encoder_layer1`, so its
        # parameters remain part of this module's state_dict.
        self.encoder_layer1 = nn.TransformerEncoderLayer(**layer_options)
        self.transformer_encoder1 = nn.TransformerEncoder(self.encoder_layer1, num_layers=4)

    def forward(self, x):
        """Return the encoder stack's output for ``x``."""
        return self.transformer_encoder1(x)

# Run one porosity realization through an untrained ImageTransformer and plot
# input and output side by side.
model = ImageTransformer()

# Add a batch axis: (256, 256) -> (1, 256, 256).
input_tensor = porosity[sample_number].unsqueeze(0)
print("Input shape:", input_tensor.shape)

# NOTE(review): the model is left in training mode, so the encoder layers'
# default dropout is active in this pass; call model.eval() first if a
# deterministic forward pass is wanted.
with torch.no_grad():
    output_tensor = model(input_tensor).numpy()
print("Output shape:", output_tensor.shape)

# `.T` on a 3-D torch.Tensor is deprecated, so the batch axis is dropped first
# and the remaining 2-D image is transposed; the picture shown is unchanged.
input_image = input_tensor[0].T.numpy()
output_image = output_tensor[0].T

plt.figure(figsize=(15,4))
plt.subplot(121); plt.imshow(input_image, cmap='jet'); plt.colorbar(pad=0.04, fraction=0.046)
plt.subplot(122); plt.imshow(output_image, cmap='jet'); plt.colorbar(pad=0.04, fraction=0.046)
plt.tight_layout(); plt.show()

In [None]:
# Grid of 3 rows x 10 columns: row 0 = inputs, row 1 = latent, row 2 = outputs.
# NOTE(review): `inputs`, `latent` and `outputs` are not defined in the cells
# visible here - presumably they come from a cell that is not shown; confirm
# before running on a fresh kernel.
fig, axs = plt.subplots(3, 10, figsize=(15, 5))
for col in range(10):
    for row, batch in enumerate((inputs, latent, outputs)):
        panel = axs[row, col]
        panel.imshow(batch[col].detach().numpy(), cmap='jet')
        panel.set(xticks=[], yticks=[])
plt.tight_layout(); plt.show()