## Vision Convolutional Transformer

In [1]:
import torch
import torch.nn.functional as F
from torchvision.datasets import OxfordIIITPet
import matplotlib.pyplot as plt
from random import random
from torchvision.transforms import Resize, ToTensor
from torchvision.transforms.functional import to_pil_image

from torch import nn
from einops.layers.torch import Rearrange
from torch import Tensor

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from einops import repeat, rearrange

import numpy as np

In [2]:
class TIFFImageDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        """
        Args:
            folder_path (str): Path to the folder containing TIFF images.
            transform (callable, optional): A function/transform to apply to the images.
        """
        self.folder_path = folder_path
        self.image_paths = [
            os.path.join(folder_path, fname)
            for fname in os.listdir(folder_path)
            if fname.lower().endswith(('.tif', '.tiff'))  # Include TIFF files
        ]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the TIFF image
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # Convert to RGB
        
        if self.transform:
            image = self.transform(image)
        
        return image

# Path to the folder containing TIFF images
folder_path = "/Users/nelsonfarrell/Documents/Northeastern/7180/projects/spectral_ratio/data/folder_3/processed_folder_3/high_quality"

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 144x144
    transforms.ToTensor(),          # Convert image to PyTorch tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])

# Create the dataset
tiff_dataset = TIFFImageDataset(folder_path=folder_path, transform=transform)

# Wrap the dataset in a DataLoader for batch processing
dataloader = DataLoader(tiff_dataset, batch_size = 16, shuffle=True)

In [3]:
for batch_idx, images in enumerate(dataloader):
    print(f"Batch Shape: {images.shape}")  # (batch_size, channels, height, width)

    # Access the first image in the batch and print its shape
    single_image = images[0]  # Access the first image in the batch
    print(f"Shape of one image: {single_image.shape}")  # (channels, height, width)
    break

Batch Shape: torch.Size([16, 3, 224, 224])
Shape of one image: torch.Size([3, 224, 224])


In [4]:
single_image_batch = single_image.unsqueeze(0)

In [5]:
single_image_batch.shape

torch.Size([1, 3, 224, 224])

In [6]:
class MyPatchEmbedding(nn.Module):
    """ 
    Creates a non-overlapping patch embedding

    Dims in: 1, 3, 144, 144
    Dims out: 1, 2304, 128
        * 144 / 3 = 48   -- The number of windows that fit the in image.
        * 48 * 48 = 2304 -- The number of positional embeddings.
        * 128            -- The size of each embedding.
    """
    def __init__(self, in_channels = 3, patch_size = 4, stride = 4, emb_size = 128):
        super().__init__()

        # 
        self.conv = nn.Conv2d(
            in_channels = in_channels,
            out_channels = emb_size,  # Embedding dims
            kernel_size = patch_size,
            stride = stride
        )
        
        # Rearrange functions flattens the patches, but maintains the embedding.
        self.rearrange = Rearrange('b c h w -> b (h w) c')

    def forward(self, x):
        """ 
        Applies conv2d and generates patch embeddings
        """
        # Apply convolution to generate patch embeddings
        patches = self.conv(x)

        # Rearrange to sequence format
        patches = self.rearrange(patches)
        return patches


# Initialize the patch embedding layer
patch_embedding = MyPatchEmbedding(in_channels = 3, patch_size = 4, stride = 4, emb_size = 128)

# Forward pass
patches = patch_embedding(single_image_batch)
print("Patch Embeddings Shape:", patches.shape)

Patch Embeddings Shape: torch.Size([1, 3136, 128])


In [7]:
def get_positional_embeddings(num_patches, embedding_size):
    """ 
    Generates a positional encoding
    """
    result = torch.ones(num_patches, embedding_size)
    for i in range(num_patches):
        for j in range(embedding_size):
            result[i][j] = np.sin(i / (10000 ** (j / embedding_size))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / embedding_size)))
    return result

In [8]:
class Attention(nn.Module):
    """ 
    Performs Multihead attention
    """
    def __init__(self, emb_dim, n_heads, dropout):
        super().__init__()
        self.n_heads = n_heads
        self.att = torch.nn.MultiheadAttention(embed_dim = emb_dim,
                                               num_heads = n_heads,
                                               dropout = dropout)
        # Fully connected linear layers
        self.q = torch.nn.Linear(emb_dim, emb_dim)
        self.k = torch.nn.Linear(emb_dim, emb_dim)
        self.v = torch.nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        attn_output, attn_output_weights = self.att(q, k, v)
        return attn_output

In [9]:
class PreNorm(nn.Module):
    """ 
    Performs layer normalization
    """
    def __init__(self, emb_dim, fn):
        super().__init__()

        # Norm function
        self.norm = nn.LayerNorm(emb_dim)

        # Outer function ~ here Attention
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

In [10]:
norm = PreNorm(emb_dim = 128, fn = Attention(emb_dim=128, n_heads=4, dropout=0.))
norm(patches).shape

torch.Size([1, 3136, 128])

In [11]:
class FeedForward(nn.Sequential):
    def __init__(self, emb_dim, hidden_dim, dropout = 0.):
        super().__init__(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, emb_dim),
            nn.Dropout(dropout)
        )
ff = FeedForward(emb_dim = 128, hidden_dim = 256)
ff(patches).shape

torch.Size([1, 3136, 128])

In [12]:
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x

In [13]:
class ViT(nn.Module):
    def __init__(self, 
                 channels_in = 3, 
                 img_size = 224, 
                 patch_size = 16, 
                 emb_dim = 768,
                 n_layers = 2, 
                 dropout = 0.1, 
                 heads = 12):
        super(ViT, self).__init__()

        # Attributes
        self.channels = channels_in
        self.height = img_size
        self.width = img_size
        self.patch_size = patch_size
        self.n_layers = n_layers # The number transformer layers
        self.dim_out = (img_size ** 2) * self.channels

        # Get the patch embeddings - This uses a CNN
        self.patch_embedding = MyPatchEmbedding(in_channels = channels_in,
                                                patch_size = patch_size, 
                                                stride = 16,
                                                emb_size = emb_dim)
        # Get the number of patches
        num_patches = (img_size // patch_size) ** 2

        # Get the postional encoding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))

        # Transformer Encoder
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            transformer_block = nn.Sequential(

                # MHSA
                ResidualAdd(PreNorm(emb_dim, Attention(in_dim = emb_dim, num_heads = heads, kernel_size = 3))),

                # Feed forward linear layer
                ResidualAdd(PreNorm(emb_dim, FeedForward(emb_dim, emb_dim, dropout = dropout))))
            self.layers.append(transformer_block)


    def forward(self, img):

        # Get patch embedding vectors
        x = self.patch_embedding(img)

        # The number of patches
        batch_size, num_patches, emb_dim = x.shape

        # Add the positional encoding
        x += self.pos_embedding[:, :(num_patches)]

        # Transformer layers
        for i in range(self.n_layers):
            print(f"Layer {i} -- Shape: {x.shape}")
            x = self.layers[i](x)
        
        # Rearrange back to image dims
        x = rearrange(x, 'b (h w) (patch_c ph pw) -> b patch_c (h ph) (w pw)', 
                         h=14, w=14, patch_c=3, ph=16, pw=16)


        # Final fully connected layer
        print(f"Shape end: {x.shape}")
        
        return x



model = ViT()
#print(model)
print(single_image_batch.shape)
x = model(single_image_batch)
print(x.shape)

TypeError: Attention.__init__() got an unexpected keyword argument 'in_dim'