In [75]:
import os, time, math
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.ops import SqueezeExcitation
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM

In [70]:
def check_torch_gpu():
    """Print a short PyTorch/CUDA banner and return the device to run on.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    torch_version, cuda_avail = torch.__version__, torch.cuda.is_available()
    count = torch.cuda.device_count()
    # get_device_name() raises when no CUDA device can be used, which would
    # make the CPU fallback below unreachable; only query it when CUDA is up.
    name = torch.cuda.get_device_name() if cuda_avail else 'N/A'
    print('\n'+'-'*60)
    print('----------------------- VERSION INFO -----------------------')
    print('Torch version: {}'.format(torch_version))
    print('Torch build with CUDA? {}'.format(cuda_avail))
    print('# Device(s) available: {}, Name(s): {}'.format(count, name))
    print('-'*60+'\n')
    device = torch.device('cuda' if cuda_avail else 'cpu')
    return device

def count_params(model):
    """Return the total element count over ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is false are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [71]:
# Print the torch/CUDA banner and keep the torch.device that the helper returns.
device = check_torch_gpu()


------------------------------------------------------------
----------------------- VERSION INFO -----------------------
Torch version: 2.1.0+cu121
Torch build with CUDA? True
# Device(s) available: 1, Name(s): NVIDIA GeForce RTX 3080
------------------------------------------------------------



In [None]:
class CustomLoss(nn.Module):
    """Weighted sum of an MSE term and a ``1 - SSIM`` term.

    Args:
        mse_weight: multiplier applied to the MSE term.
        ssim_weight: multiplier applied to the ``1 - SSIM`` term.
        todevice: when true, the SSIM metric is moved to CUDA if
            ``torch.cuda.is_available()`` and to CPU otherwise; when false
            the metric is left on whatever device it was created on.
    """

    def __init__(self, mse_weight=1.0, ssim_weight=1.0, todevice=True):
        super().__init__()
        self.mse_weight = mse_weight
        self.ssim_weight = ssim_weight
        self.mse_loss = nn.MSELoss()
        ssim_metric = SSIM()
        if todevice:
            target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            ssim_metric = ssim_metric.to(target_device)
        self.ssim = ssim_metric

    def forward(self, pred, target):
        """Return ``mse_weight * MSE + ssim_weight * (1 - SSIM)`` for the pair."""
        mse_term = self.mse_loss(pred, target)
        dissimilarity = 1.0 - self.ssim(pred, target)
        return self.mse_weight * mse_term + self.ssim_weight * dissimilarity

In [2]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch tokens.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each window to ``embed_dim`` channels; the spatial grid is then
    flattened into the token axis.

    Args:
        image_size: stored on the module; not used by the projection itself.
        patch_size: kernel size and stride of the projecting convolution.
        in_channels: number of channels of the input images.
        embed_dim: number of output channels (token width).
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return tokens laid out as (batch, h*w, channels) of the conv output."""
        return rearrange(self.projection(x), 'b c h w -> b (h w) c')

In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a token sequence.

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. The table is precomputed once and stored as the
    buffer ``pos_enc`` with shape ``(1, max_seq_len, embed_dim)``.

    Args:
        embed_dim: width of each token; may be even or odd.
        max_seq_len: number of positions to precompute.
    """

    def __init__(self, embed_dim, max_seq_len=512):
        super().__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        angles = position * div_term
        pos_enc = torch.zeros((1, max_seq_len, embed_dim))
        pos_enc[0, :, 0::2] = torch.sin(angles)
        # For an odd embed_dim there is one fewer cosine column than sine
        # column; trim the angles so the assignment shapes match. For an even
        # embed_dim the slice keeps every column.
        pos_enc[0, :, 1::2] = torch.cos(angles[:, :embed_dim // 2])
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is expected to be ``(batch, seq_len, embed_dim)`` so that it
        broadcasts against the ``(1, seq_len, embed_dim)`` slice of the table.
        """
        # pos_enc is a buffer, not a parameter, so it carries no gradient.
        return x + self.pos_enc[:, :x.size(1)]

In [4]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``num_heads`` heads.

    Args:
        embed_dim: width of the query/key/value inputs and of the output.
        num_heads: number of heads; must divide ``embed_dim`` exactly.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Without this check the reshape in forward() either fails with an
        # opaque size error or silently folds features into the sequence axis.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                'embed_dim ({}) must be divisible by num_heads ({})'.format(embed_dim, num_heads)
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim  = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key   = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, embed_dim) into (batch, heads, seq, head_dim)."""
        return projected.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value):
        """Attend from ``query`` over ``key``/``value``.

        All three inputs are (batch, seq, embed_dim); ``key`` and ``value``
        must share a sequence length. The result has the query's sequence
        length and width ``embed_dim``.
        """
        batch_size = query.shape[0]
        Q = self._split_heads(self.query(query), batch_size)
        K = self._split_heads(self.key(key), batch_size)
        V = self._split_heads(self.value(value), batch_size)
        # A plain Python scalar avoids allocating a CPU tensor on every call.
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        out = torch.matmul(attention_weights, V)
        # Merge the heads back into a single embed_dim-wide feature axis.
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.embed_dim)
        out = self.fc_out(out)
        return out

In [5]:
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear.

    Args:
        embed_dim: input and output width.
        mlp_hidden_dim: width of the hidden layer.
    """

    def __init__(self, embed_dim, mlp_hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, mlp_hidden_dim)
        self.fc2 = nn.Linear(mlp_hidden_dim, embed_dim)

    def forward(self, x):
        """Expand to the hidden width, apply GELU, project back to ``embed_dim``."""
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)

In [6]:
class TransformerEncoderBlock(nn.Module):
    """Self-attention and MLP sub-layers, each with a residual and LayerNorm.

    The normalisation is applied after each residual sum.

    Args:
        embed_dim: token width.
        num_heads: number of attention heads.
        mlp_hidden_dim: hidden width of the feed-forward sub-layer.
    """

    def __init__(self, embed_dim, num_heads, mlp_hidden_dim=1024):
        super().__init__()
        self.self_attention = MultiHeadAttention(embed_dim, num_heads)
        self.mlp_block = MLPBlock(embed_dim, mlp_hidden_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run attention then the MLP, normalising after each residual add."""
        # Attention sub-layer: query, key and value are all the same sequence.
        attended = self.norm1(x + self.self_attention(x, x, x))
        # Feed-forward sub-layer.
        return self.norm2(attended + self.mlp_block(attended))

In [7]:
class ViTencoder(nn.Module):
    """Vision-Transformer encoder that maps an image to a flat latent vector.

    Pipeline: patch embedding -> sinusoidal positions -> ``num_layers``
    transformer blocks -> mean over tokens -> linear projection to
    ``projection_dim * latent_size * latent_size`` features.

    Args:
        image_size: input image size, an int or a ``(height, width)`` pair.
        latent_size: spatial side of the latent grid the output is sized for.
        in_channels: number of input image channels.
        patch_size: patch side, an int or a ``(height, width)`` pair.
        projection_dim: channel count of the latent grid the output is sized for.
        embed_dim: token width inside the transformer.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks.
    """

    def __init__(self, image_size=256, latent_size=32, in_channels=3, patch_size=16, projection_dim=256, embed_dim=1024, num_heads=16, num_layers=8):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        # The position table previously always had the default 512 entries, so
        # configurations producing more patches failed in forward(). Keep 512
        # as a floor so the buffer shape is unchanged for existing configs.
        max_seq_len = max(512, self._num_patches(image_size, patch_size))
        self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len=max_seq_len)
        self.transformer_blocks = nn.ModuleList([TransformerEncoderBlock(embed_dim, num_heads) for _ in range(num_layers)])
        self.global_avg_pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(embed_dim, projection_dim*latent_size*latent_size)

    @staticmethod
    def _num_patches(image_size, patch_size):
        """Count patches for an int or (height, width) image and patch size."""
        height, width = (image_size, image_size) if isinstance(image_size, int) else image_size
        patch_h, patch_w = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        return (height // patch_h) * (width // patch_w)

    def forward(self, x):
        """Encode ``x`` and return a ``(batch, projection_dim*latent_size**2)`` tensor."""
        x = self.patch_embedding(x)
        x = self.positional_encoding(x)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x)
        # Pool over the token axis: move it last, average it down to length 1.
        x = self.global_avg_pooling(x.transpose(1, 2))
        x = x.squeeze(2)
        x = self.fc(x)
        return x

In [64]:
class MultiScaleResidual(nn.Module):
    """Concatenate a feature map with a resized copy of itself along channels.

    Args:
        image_size: reference size used to derive the resize target.
    """

    def __init__(self, image_size=256):
        super().__init__()
        self.image_size = image_size

    def forward(self, x):
        """Return ``cat([x, resized(x)], dim=1)`` for a 4-D input.

        The output has twice the channels of ``x``.
        """
        height, width = x.shape[2:]
        # Target extent per axis is image_size // (image_size // extent).
        # NOTE(review): when image_size is an exact multiple of the extent this
        # equals the extent itself, i.e. the resize targets the input's own
        # size — confirm that this is the intended "multi-scale" behaviour.
        target_size = tuple(
            self.image_size // (self.image_size // extent)
            for extent in (height, width)
        )
        resized = transforms.Resize(target_size, antialias=True)(x)
        return torch.cat((x, resized), dim=1)

In [62]:
class PixFormer(nn.Module):
    """ViT encoder followed by a three-stage convolutional upsampling decoder.

    The encoder output is reshaped to a ``(projection_dim, latent_size,
    latent_size)`` grid. Each decoder stage halves the channel count and
    doubles the spatial size, and a final convolution maps to one channel.

    Args:
        projection_dim: channel count of the latent grid.
        latent_size: spatial side of the latent grid.
    """

    def __init__(self, projection_dim=128, latent_size=32):
        super().__init__()
        self.projection_dim = projection_dim
        self.latent_size = latent_size

        self.encoder = ViTencoder(latent_size=latent_size, projection_dim=projection_dim)

        # Channel widths along the decoder: projection_dim, /2, /4, /8.
        widths = [projection_dim // (2 ** step) for step in range(4)]
        self.layers = nn.Sequential(
            *(self._conv_block(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:]))
        )
        self.out = nn.Conv2d(widths[-1], 1, kernel_size=3, padding=1)

    def _conv_block(self, in_channels, out_channels):
        """Build one decoder stage ending in a 2x upsample and a fusing conv."""
        stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            SqueezeExcitation(out_channels, out_channels//4),
            nn.InstanceNorm2d(out_channels),
            nn.PReLU(),
            # Doubles the channel count by concatenation ...
            MultiScaleResidual(),
            nn.Upsample(scale_factor=2),
            # ... which this convolution folds back down to out_channels.
            nn.Conv2d(out_channels*2, out_channels, kernel_size=3, padding=1),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Encode ``x``, decode the latent grid, and return a 1-channel map."""
        latent = self.encoder(x)
        feature_map = latent.view(-1, self.projection_dim, self.latent_size, self.latent_size)
        for stage in self.layers:
            feature_map = stage(feature_map)
        return self.out(feature_map)

In [74]:
# Smoke test: build the model, push one random batch through it, and report
# the trainable-parameter count together with the input/output shapes.
model = PixFormer()
input_tensor = torch.rand((32, 3, 256, 256))
# Only shapes are inspected below, so run without autograd to avoid keeping
# every intermediate activation of the forward pass in memory.
with torch.no_grad():
    output = model(input_tensor)
print('# Parameters: {:,}'.format(count_params(model)))
print('Inputs: {} | Outputs: {}'.format(input_tensor.shape, output.shape))

# Parameters: 185,746,560
Inputs: torch.Size([32, 3, 256, 256]) | Outputs: torch.Size([32, 1, 256, 256])


***
# END