## Visual Processing

### 1. Import Libraries

In [1]:
import torch.nn as nn
from torchvision import models
from torchvision import transforms
import torch
import os
from utils.image_preprocessing_utils import *

In [2]:
# Input folder with the raw images and output folder for the saved representations.
IMAGE_PATH = "../raw_data/image/"
SAVE_PATH = "../preprocessed_data/image_representations_vit/"

# Create the save path; exist_ok makes the cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(SAVE_PATH, exist_ok=True)

# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise CPU.
# Falling back to "mps" unconditionally breaks on machines that have neither backend.
# getattr guards against torch builds that do not ship the mps backend module.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
DEVICE

device(type='mps')

In [3]:
class CustomViT(nn.Module):
    """Wrapper around torchvision's ViT-B/16 that returns encoder token features.

    Instead of calling the wrapped model's own ``forward``, this module runs the
    input-processing step, prepends the class token, and returns the encoder
    output for the whole token sequence (class token first).

    Args:
        pretrained: When True (default) the torchvision "DEFAULT" weights are
            loaded; when False the backbone is created without pretrained weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # Honour the `pretrained` flag; previously it was accepted but ignored
        # and the default weights were always loaded.
        self.vit = models.vit_b_16(weights="DEFAULT" if pretrained else None)

    def forward(self, x):
        """Return the encoder output for a batch of images.

        Args:
            x: Batch of image tensors accepted by the wrapped model's
                ``_process_input`` (presumably (batch, 3, 224, 224) for
                ViT-B/16 — confirm against the installed torchvision).

        Returns:
            The tensor produced by ``self.vit.encoder`` for the sequence
            ``[class_token, patch tokens...]``.
        """
        # NOTE: `_process_input` is a private torchvision method; this relies on
        # its current behaviour of turning images into a token sequence.
        tokens = self.vit._process_input(x)
        batch_size = tokens.shape[0]

        # Expand the class token to the full batch and put it first in the sequence
        batch_class_token = self.vit.class_token.expand(batch_size, -1, -1)
        tokens = torch.cat([batch_class_token, tokens], dim=1)

        return self.vit.encoder(tokens)

# Create an instance of the custom model using the constructor defaults;
# `visual` is the model handed to saveRepresentations further down.
visual = CustomViT()

In [4]:
# Target size expected by the ViT input and the ImageNet normalization
# statistics used for pretrained models.
vit_input_size = (224, 224)
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Steps are applied in order: convert to PIL, resize, convert to tensor, normalize.
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize(vit_input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
transformations = transforms.Compose(preprocessing_steps)

In [5]:
# Run the helper from utils.image_preprocessing_utils over the image folder.
# Presumably it reads each image in IMAGE_PATH, applies `transformations`, runs
# `visual` on DEVICE and writes the result under SAVE_PATH — verify against the helper.
saveRepresentations(
    read_folder_path=IMAGE_PATH,
    save_folder_path=SAVE_PATH,
    model=visual,
    transforms=transformations,
    device=DEVICE,
)

100%|██████████| 10659/10659 [05:17<00:00, 33.62it/s]
