# **Visual Transformer 2D**


In [1]:
import os
import sys
import re
from glob import glob
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import MulticlassMatthewsCorrCoef
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn.functional import kl_div
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor

import open3d as o3
import math
import yaml
import argparse


# Fix the global NumPy and PyTorch RNG seeds so runs are reproducible.
_SEED = 0
np.random.seed(_SEED)
torch.manual_seed(_SEED)


# Put the parent directory on the import path so the `src.*` imports resolve.
sys.path.append('../')
from src.models.VisualTransformerEncoder import *
from src.models.VisualTransformerDecoder import *
from src.models.MultiHeadAttentionBlock import *
from src.models.VisualTransformerGenerator import *
from src.utils import features, utils
from src.data.dataset import DataMNIST, DataMNISTGen

from tqdm import tqdm, trange

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch
torch.manual_seed(0)


# Release cached accelerator memory only on a backend that is actually present.
# An unconditional `torch.mps.empty_cache()` fails on machines without an MPS device
# (e.g. CUDA-only or CPU-only hosts), which would abort the notebook here.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()

In [3]:
# TEMP: suppress PyTorch UserWarnings while experimenting.
# Limited to UserWarning so that other categories (e.g. numerical RuntimeWarnings)
# remain visible instead of being silently swallowed.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [4]:
print('MPS is build: {}'.format(torch.backends.mps.is_built()))
print('MPS Availability: {}'.format(torch.backends.mps.is_available()))
# Seed the global PyTorch RNG with 42. The previous `torch.seed = 42` replaced the
# `torch.seed` function with an int, which seeded nothing and broke later `torch.seed()` calls.
torch.manual_seed(42)
# Prefer CUDA, then Apple MPS, and fall back to CPU so the notebook also runs on
# machines without an accelerator (previously 'mps' was chosen even when unavailable).
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
print('Device is set to :{}'.format(DEVICE))

MPS is build: True
MPS Availability: True
Device is set to :mps


# DATA LOADING

In [5]:
# Load experiment settings from the project config; its "data_parameters" section is
# passed to the datamodule in the next cell. Explicit UTF-8 (the YAML encoding)
# avoids depending on the platform's locale default. safe_load never builds arbitrary objects.
with open('../config/config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

In [6]:
# Build the MNIST datamodule from the config's "data_parameters" section, then call
# setup() before requesting loaders. Presumably setup() builds the train/val datasets
# (the 60000/10000-item progress bars in the recorded output suggest MNIST's train and
# test splits) — TODO confirm against DataMNISTGen.
data = DataMNISTGen(**config["data_parameters"])
data.setup()

60000it [00:13, 4378.66it/s]
10000it [00:03, 2880.95it/s]


In [7]:
# Obtain training and validation loaders from the prepared datamodule (train first, then val).
train_dataloader, val_dataloader = data.train_dataloader(), data.val_dataloader()

## PRE- & POST-PROCESSING HELPERS (image visualization)

In [8]:
def TensorToImageGrid(images_batch, rows, cols):
    """Display a batch of images as one grid with ``cols`` images per grid row.

    Args:
        images_batch: Images accepted by ``torchvision.utils.make_grid`` — a tensor
            (assumed (B, C, H, W); TODO confirm against callers) or a list of tensors.
            May live on any device and may be attached to an autograd graph.
        rows: Number of grid rows; only used to size the figure (height in inches).
        cols: Images per grid row; also the figure width in inches.

    Returns:
        Whatever ``plt.show()`` returns (``None``); kept for backward compatibility.
    """
    # Detach so that .numpy() below also works on raw model outputs that require grad.
    grid = torchvision.utils.make_grid(images_batch, nrow=cols).detach()
    # make_grid yields (C, H, W); matplotlib wants (H, W, C) as a CPU NumPy array.
    grid = grid.permute(1, 2, 0).cpu().numpy()
    plt.figure(figsize=(cols, rows))
    # make_grid replicates single-channel input to 3 channels, so the cmap is
    # effectively ignored for such input; kept for parity with TensorToImage.
    plt.imshow(grid, cmap='gray')
    plt.axis('off')
    return plt.show()

def TensorToImage(image):
    """Display a single image tensor in grayscale on a small 2x2-inch figure.

    Args:
        image: Tensor matplotlib can render once converted to NumPy — presumably
            (H, W), since callers in this notebook pass ``x[N].squeeze()``. May live
            on any device and may be attached to an autograd graph.

    Returns:
        Whatever ``plt.show()`` returns (``None``); kept for backward compatibility.
    """
    plt.figure(figsize = (2, 2))
    # .numpy() only works on CPU tensors that do not require grad, so normalize here
    # instead of relying on every caller to do it.
    plt.imshow(image.detach().cpu().numpy(), cmap='gray')
    return plt.show()

![Alt text](image.png)

![Alt text](image-1.png)

In [9]:
# Model and optimisation hyper-parameters (passed positionally to the project's Transformer).
n_patches = 7
num_layers = 4
hidden_d = 128
n_heads = 8
d_ff = 512
dropout = 0.1
learning_rate = 0.0001
num_epochs = 1

transformer = Transformer(hidden_d, n_heads, num_layers, d_ff, dropout, n_patches).to(DEVICE)
# nn.MSELoss has no parameters or buffers, so it does not need to be moved to DEVICE.
criterion = nn.MSELoss()
optimizer = optim.Adam(transformer.parameters(), lr=learning_rate, eps=1e-9)
transformer.train()


for epoch in range(num_epochs):
    loss_epoch = []
    for images, y in tqdm(train_dataloader):
        images = images.to(DEVICE)
        y = y.to(DEVICE)
        optimizer.zero_grad()
        output = transformer(images)
        # nn.MSELoss is called as (input, target): prediction first, ground truth second.
        # MSE is symmetric, but the API contract (and its size-mismatch warning) is not.
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        # .item() already returns a detached Python float on the host.
        loss_epoch.append(loss.item())

    print("Epoch: {}; Loss: {}".format(epoch+1, np.mean(loss_epoch)))

  4%|▍         | 148/3750 [00:49<20:29,  2.93it/s]

In [None]:
# Pull one validation batch for qualitative inspection. Judging by the training loop,
# each batch is (model input, target); here named reference / generated images —
# confirm against DataMNISTGen.
first_val_batch = next(iter(val_dataloader))
images_ref, images_gen = first_val_batch

In [None]:
# Switch the model to inference mode (train(False) is what eval() does), which
# presumably disables the dropout configured above; then choose the sample to plot.
transformer.train(False)
N = 0

In [None]:
# Run inference without building an autograd graph. The batch must be on the model's
# device: the training loop moves batches to DEVICE explicitly, so they presumably
# arrive from the DataLoader on the CPU (.to is a no-op if already there).
with torch.no_grad():
    prediction = transformer(images_ref.to(DEVICE))[N].squeeze().cpu()
# Show prediction, target and input side by side (same order as before).
TensorToImage(prediction), TensorToImage(images_gen[N].squeeze().detach().cpu()), TensorToImage(images_ref[N].squeeze().detach().cpu())