## Visual Processing

### 1. Import Libraries

In [1]:
## Standard library
import os

## Torch libraries
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms
from torchvision import transforms as T

## Image I/O and progress bar
from skimage import io
from tqdm import tqdm

In [3]:
# Folder holding the input images and folder receiving the saved representations
IMAGE_PATH = "data/images/"
SAVE_PATH = "data/image_representations/"

# Per-channel mean / std passed to the Normalize transform
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE

device(type='cuda')

In [4]:
# Preprocessing pipeline: convert the image to a tensor, then normalise each
# channel with the ImageNet statistics.
# The pipeline is built through the `T` alias of torchvision.transforms because
# this cell rebinds the name `transforms` to the pipeline itself; going through
# the alias keeps the cell re-runnable without re-importing the module first.
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [5]:
# Load the pretrained ResNet-152 and keep only its first eight child modules,
# i.e. drop the trailing average pooling and fully connected layers.
extracted_layers = list(models.resnet152(pretrained=True).children())[:8]

# Chain the remaining layers into a single feature-extraction module
visual = nn.Sequential(*extracted_layers)

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to C:\Users\damia/.cache\torch\hub\checkpoints\resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:38<00:00, 6.34MB/s] 


In [8]:
def readImage(image_path):
    """Read the image stored at ``image_path`` and return what ``io.imread`` yields."""
    return io.imread(image_path)

def countImages(folder_path):
    """Return the number of entries found in ``folder_path``.

    Every directory entry is counted, whatever its name or extension.
    """
    entries = os.listdir(folder_path)
    return len(entries)

def transformImage(img, transforms):
    """Apply ``transforms`` to a copy of ``img`` and return the result.

    The transform receives ``img.copy()``, so the caller's object is never
    handed to it directly.
    """
    return transforms(img.copy())

def getRepresentation(image, model):
    """Run ``model`` on a single image and return its output.

    A leading dimension of size one is added to ``image`` before the forward
    pass, and the pass runs with gradient tracking disabled.
    """
    batch = image.unsqueeze(0)
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        return model(batch)

def saveRepresentations(read_folder_path, save_folder_path, transforms, model):
    """Compute and save a representation for every numbered image in a folder.

    Images are read as ``<index>.tif`` from ``read_folder_path`` for
    ``index = 0 .. total_images - 1``. Each one is transformed, passed through
    ``model`` on ``DEVICE`` and written to ``save_folder_path`` as
    ``<index>.pt``.

    Args:
        read_folder_path: Folder containing the ``<index>.tif`` images.
        save_folder_path: Folder receiving the ``<index>.pt`` files. It is
            created when it does not exist yet.
        transforms: Callable applied to each raw image before the model.
        model: Module used to compute the representation; it is moved to
            ``DEVICE`` and switched to evaluation mode.
    """
    # Half of the directory entries are taken as the number of images.
    # NOTE(review): presumably each image has one companion file in the
    # folder, hence the division by two - confirm against the dataset layout.
    total_images = countImages(read_folder_path) // 2

    # torch.save does not create missing parent folders, so make sure the
    # destination exists before the (long) loop starts.
    if save_folder_path:
        os.makedirs(save_folder_path, exist_ok=True)

    # Put the model on the target device, in evaluation mode
    model.to(DEVICE)
    model.eval()

    for i in tqdm(range(total_images)):
        read_path = os.path.join(read_folder_path, str(i) + '.tif')
        save_path = os.path.join(save_folder_path, str(i) + '.pt')

        # Read, transform and move the image to DEVICE
        img = readImage(read_path)
        img_transformed = transformImage(img, transforms).to(DEVICE)

        # Forward pass
        representation = getRepresentation(img_transformed, model)

        # Drop the leading batch dimension and bring the result back to the
        # CPU before writing it to disk
        representation = representation.squeeze(0).cpu().detach()

        saveArray(representation, save_path)

def saveArray(array, path):
    """Serialise ``array`` to ``path`` with ``torch.save``."""
    torch.save(array, path)
    return None

def loadArray(path):
    """Load and return the object stored at ``path`` with ``torch.load``."""
    # NOTE(review): torch.load relies on pickle; only use it on files that
    # come from a trusted source.
    loaded = torch.load(path)
    return loaded

In [9]:
# Long-running cell: encodes every image and writes one .pt file per image.
# Expect several minutes for the full set; the exact time depends on the hardware.
saveRepresentations(
    read_folder_path=IMAGE_PATH,
    save_folder_path=SAVE_PATH,
    transforms=transforms,
    model=visual,
)

  0%|          | 0/10659 [00:00<?, ?it/s]

100%|██████████| 10659/10659 [06:28<00:00, 27.42it/s]
