# **Visual Transformer 2D**


In [1]:
import os
import sys
import re
from glob import glob
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import MulticlassMatthewsCorrCoef
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn.functional import kl_div
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor

import open3d as o3
import math
import yaml
import argparse


np.random.seed(0)
torch.manual_seed(0)


sys.path.append('../')
from src.models.VisualTransformerEncoder import *
from src.models.VisualTransformerDecoder import *
from src.models.MultiHeadAttentionBlock import *
from src.models.VisualTransformerGenerator import *
from src.utils import features, utils
from src.data.dataset import DataMNIST

from tqdm import tqdm, trange

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch
torch.manual_seed(0)


torch.mps.empty_cache()

In [3]:
# TEMP for supressing pytorch user warnings
import warnings
warnings.filterwarnings("ignore")

In [4]:
print('MPS is build: {}'.format(torch.backends.mps.is_built()))
print('MPS Availability: {}'.format(torch.backends.mps.is_available()))
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps'
print('Device is set to :{}'.format(DEVICE))
torch.seed = 42

MPS is build: True
MPS Availability: True
Device is set to :mps


# DATA LOADING

In [5]:
with open('../config/config.yaml', 'r') as file:
    config = yaml.safe_load(file)

In [6]:
data = DataMNIST(**config["data_parameters"])
data.setup()

In [7]:
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()

In [10]:
### GENERATE DATA ###
training_data = []
for index_ref, (image_ref, number_ref) in tqdm(enumerate(data.train_dataset)):
    for index_gen, (image_gen, number_gen) in enumerate(data.train_dataset):
        if  index_ref != index_gen and number_ref == number_gen:
            training_data.append((image_ref, image_gen))
            break


60000it [00:17, 3361.70it/s]


In [39]:
### GENERATE VALIDATION DATA ###
validating_data = []
for index_ref, (image_ref, number_ref) in tqdm(enumerate(data.val_dataset)):
    for index_gen, (image_gen, number_gen) in enumerate(data.val_dataset):
        if  index_ref != index_gen and number_ref == number_gen:
            validating_data.append((image_ref, image_gen))
            break

10000it [00:03, 2826.10it/s]


In [40]:
val_dataloader = DataLoader(
            validating_data,
            batch_size=16,
            shuffle=True,
            pin_memory=False,
        )

## PRE & POST PROCESSING SCRIPTS

In [71]:
def TensorToImageGrid(images_batch, rows, cols):
    grid = torchvision.utils.make_grid(images_batch, nrow=cols)
    grid = grid.permute(1, 2, 0).cpu().numpy()
    plt.figure(figsize=(cols, rows))
    plt.imshow(grid, cmap='gray')
    plt.axis('off')
    return plt.show()

def TensorToImage(image):
    plt.figure(figsize = (2, 2))
    plt.imshow(image.numpy(), cmap='gray')
    return plt.show()

![Alt text](image.png)

![Alt text](image-1.png)

In [13]:
n_patches = 7
num_layers = 4
hidden_d = 128
n_heads = 8
d_ff = 512
dropout = 0.1
learning_rate = 0.0001

transformer = Transformer(hidden_d, n_heads, num_layers, d_ff, dropout, n_patches).to(DEVICE)
criterion = nn.MSELoss().to(DEVICE)
optimizer = optim.Adam(transformer.parameters(), lr=learning_rate, eps=1e-9)
transformer.train()


for epoch in range(2):
    loss_epoch = []
    n=0
    for images, y in tqdm(train_dataloader):
        #images, y = batch
        #print(batch)
        images = images.to(DEVICE)
        y = y.to(DEVICE)
        optimizer.zero_grad()
        output = transformer(images)
        loss = criterion(y, output)
        loss.backward()
        optimizer.step()
        loss_epoch.append(loss.detach().cpu().item()) 
        #torch.mps.empty_cache()
    
    print("Epoch: {}; Loss: {}".format(epoch+1, np.mean(loss_epoch)))

100%|██████████| 3750/3750 [20:36<00:00,  3.03it/s]


Epoch: 1; Loss: 0.032163921030486625


100%|██████████| 3750/3750 [20:30<00:00,  3.05it/s]

Epoch: 2; Loss: 0.006669528393726796





In [None]:
loss.detach().cpu().item()

In [41]:
images_ref, images_gen = next(iter(val_dataloader))

In [114]:
transformer.eval()
N=0

In [115]:
TensorToImage(transformer(images_ref)[N].squeeze().detach().cpu()), TensorToImage(images_gen[N].squeeze().detach().cpu()), TensorToImage(images_ref[N].squeeze().detach().cpu())

IndexError: index 16 is out of bounds for dimension 0 with size 16