## Visual Processing

### 1. Import Libraries

In [None]:
import os
import sys
# Put the parent of the current working directory on the import path (only
# once), so modules located there can be imported from this notebook.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [1]:
## Torch libraries
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

## Import 
from skimage import io
import os

from tqdm import tqdm

In [2]:
# Folder the input images are read from, and folder the extracted
# representations are written to (both are passed to saveRepresentations).
IMAGE_PATH = "data/images/"
SAVE_PATH = "data/image_representations_vit/"

# Per-channel normalization statistics (named after ImageNet).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda')

In [3]:
# Preprocessing applied to every image before it is fed to the model:
# array -> PIL image -> 224x224 resize -> tensor -> channel-wise normalization.
img_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # Resize to the expected input size for ViT
    transforms.ToTensor(),
    # Reuse the notebook-level constants instead of repeating the literals,
    # so the normalization values are defined in exactly one place.
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [4]:
# NOTE(review): `visual` is reassigned to a CustomViT instance in a later
# cell, so this ViT-L/16 model is not the one used further down the notebook.
# `weights="IMAGENET1K_V1"` is the non-deprecated spelling of `pretrained=True`.
visual = models.vit_l_16(weights="IMAGENET1K_V1")
extracted_layers = list(visual.children())
# Keep only the first two child modules; every later child is dropped.
extracted_layers = extracted_layers[0:2]
# NOTE(review): VisionTransformer.forward presumably does more than call its
# children in order — confirm a plain nn.Sequential of them accepts image
# batches before relying on this model.
visual = torch.nn.Sequential(*extracted_layers)



In [4]:
import torch.nn as nn
from torchvision import models

class CustomViT(nn.Module):
    """Frozen ViT-B/16 backbone with its classification head replaced by identity.

    Args:
        pretrained: When True (the default) torchvision's default pretrained
            weights are loaded; when False no pretrained weights are loaded.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # Load pretrained weights only when requested, so the `pretrained`
        # argument actually controls what is loaded.
        self.visual = models.vit_b_16(weights="DEFAULT" if pretrained else None)

        # Remove the classifier head
        self.visual.heads.head = nn.Identity()

        # Freeze the model
        for param in self.visual.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Pass the batch `x` through the headless ViT and return its output.

        The debug cell of this notebook records an output of shape (1, 768)
        for a single 224x224 image.
        """
        return self.visual(x)


# Create an instance of the custom model
visual = CustomViT()

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to C:\Users\damia/.cache\torch\hub\checkpoints\vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:02<00:00, 116MB/s]  


In [5]:
def readImage(image_path):
    """Load the image stored at `image_path` with `io.imread` and return it."""
    return io.imread(image_path)

def countImages(folder_path):
    """Return how many directory entries (of any kind) `folder_path` contains."""
    entries = os.listdir(folder_path)
    return len(entries)

def transformImage(img, transforms):
    """Apply the callable `transforms` to a copy of `img` and return the result.

    The copy is made with `img.copy()`, so the callable never receives the
    caller's own object.
    """
    working_copy = img.copy()
    return transforms(working_copy)

def getRepresentation(image, model):
    """Run `model` on a single image and return whatever the model outputs.

    A leading dimension of size 1 is added to `image` before the call, and the
    forward pass is executed with gradient tracking switched off.
    """
    batch = torch.unsqueeze(image, 0)
    with torch.no_grad():
        return model(batch)

def saveRepresentations(read_folder_path, save_folder_path, transforms, model):
    """Compute and save one model representation per image.

    Images are read from `read_folder_path` as `0.tif`, `1.tif`, ... and the
    representation of `<i>.tif` is written to `<i>.pt` in `save_folder_path`.

    Args:
        read_folder_path: Folder containing the numbered `.tif` images.
        save_folder_path: Folder the `.pt` files are written to; it is created
            if it does not exist yet.
        transforms: Callable applied to each image (via `transformImage`)
            before it is passed to the model.
        model: Model used to compute the representations; it is moved to
            `DEVICE` and switched to eval mode.
    """
    # NOTE(review): the entry count is halved, which presumably means the
    # folder holds two files per image — confirm against the data layout.
    total_images = countImages(read_folder_path) // 2

    # torch.save does not create missing folders, so make sure the target exists.
    if save_folder_path:
        os.makedirs(save_folder_path, exist_ok=True)

    # Set model to device and to inference mode
    model.to(DEVICE)
    model.eval()

    # Looping through all images
    for i in tqdm(range(total_images)):
        read_path = os.path.join(read_folder_path, str(i) + '.tif')
        save_path = os.path.join(save_folder_path, str(i) + '.pt')

        # Read image, transform it and put it on DEVICE
        img = readImage(read_path)
        img_transformed = transformImage(img, transforms).to(DEVICE)

        # Get representation
        representation = getRepresentation(img_transformed, model)

        # Drop the leading batch dimension and bring the result back to the CPU
        representation = representation.squeeze(0).cpu().detach()

        # Save representation
        saveArray(representation, save_path)

def saveArray(array, path):
    """Serialize `array` to the file `path` using `torch.save`."""
    torch.save(obj=array, f=path)

def loadArray(path):
    """Read back and return the object stored at `path` using `torch.load`.

    NOTE(review): `torch.load` is pickle-based; only use it on trusted files.
    """
    loaded = torch.load(path)
    return loaded

In [6]:
# DEBUG: check the tensor shapes on the first image before the full extraction.
# The path is built from IMAGE_PATH and read through readImage so this cell
# stays consistent with what saveRepresentations does.
image = readImage(os.path.join(IMAGE_PATH, "0.tif"))
image_transformed = transformImage(image, img_transforms)
print(image_transformed.shape)
# Run the model in eval mode, as saveRepresentations does.
representation = getRepresentation(image_transformed.to(DEVICE), visual.to(DEVICE).eval())
representation.shape

torch.Size([3, 224, 224])


torch.Size([1, 768])

In [7]:
saveRepresentations(IMAGE_PATH, SAVE_PATH, img_transforms, visual) ## Long-running cell: the recorded run below took about 1 min 49 s for 10659 images (an earlier note said 12/13 minutes); runtime depends on hardware

100%|██████████| 10659/10659 [01:49<00:00, 97.28it/s]
