This notebook was run in Google Colab on an A100 GPU.

In [None]:
# Mount Google Drive at /content/drive so the project folder under MyDrive
# (PROJECT_DIR, configured further down) can be read from and written to.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
!pip install pytorch_lightning
!pip install xformers
!pip install triton

!pip freeze > requirements.txt
!python --version

# Importing Models

### DINOv2

In [None]:
import torch

# Fetch the 'dinov2_vitl14_reg' backbone (ViT-L/14 with registers, per DINOv2's hub naming)
# from the facebookresearch/dinov2 torch.hub repository.
dino_model = torch.hub.load(repo_or_dir='facebookresearch/dinov2', model='dinov2_vitl14_reg')

# The model is only used for feature extraction, so switch it to inference behaviour.
dino_model.eval()

# Importing Libraries and Setup




In [None]:
import os
import random
import numpy as np
from numpy import dot
from numpy.linalg import norm
from numpy import random as rand
from pathlib import Path

from scipy.spatial import distance

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import cv2
import skimage
from skimage import exposure
import PIL
from PIL import Image

import torch
import torch.nn as nn

from torch.amp.autocast_mode import autocast
from torchvision import transforms
from torchvision.transforms import v2
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.v2 import functional as F
from torchvision.io import read_image

from tqdm import tqdm
from datetime import datetime

# Date stamp (DDMMYYYY) appended to every saved embedding file name
date = datetime.now().strftime('%d%m%Y')

### EDIT PATHS ###
PROJECT_DIR = Path("/content/drive/MyDrive/YOUR_PROJECT_FOLDER")
DATASET_PATH = PROJECT_DIR / "preprocessed_images" # Folder containing the Dreamachine drawings

# Folder where extracted feature vectors are saved (created if it does not exist)
OUTPUT_DIR = PROJECT_DIR / "embeddings"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# DataLoader worker processes: one per CPU core.
# os.cpu_count() returns None when the count cannot be determined, and DataLoader
# rejects num_workers=None, so fall back to loading in the main process (0 workers).
NUM_WORKERS = os.cpu_count() or 0

# Random seed for reproducibility
SEED = 42

# Use the GPU when available (the original run used an A100)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def randseed(seed):
    ''' Seed the random number generators used in this notebook for reproducibility.

        seed: integer seed applied to Python's `random`, NumPy's global RNG and PyTorch.

        This only seeds generators. cuDNN kernel determinism is a separate setting
        (torch.backends.cudnn.deterministic / benchmark).
    '''

    # Python and NumPy global generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch default generator; torch.manual_seed also seeds every CUDA device
    torch.manual_seed(seed)

    # Explicitly seed all CUDA generators as well. One call covers every device,
    # replacing separate per-device calls.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed all RNGs before any stochastic work
randseed(SEED)

# Ask cuDNN for deterministic kernels and disable autotuning, which may pick different kernels per run
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Report the run configuration
print("Device:", device)
print("Number of workers:", NUM_WORKERS)
print('Images in dataset: ', len(os.listdir(DATASET_PATH)))

Copy the Dreamachine images from Google Drive to `/dev/shm` (RAM-backed temporary storage) so the DataLoader can read them much faster than from Drive. (Runtime ~10 min)

In [None]:
import shutil

# Copy the dataset from Google Drive into /dev/shm and read from there afterwards.
# Guarded so the cell can be re-run. Without the guard, a second run either fails
# (copytree refuses an existing destination) or tries to copy /dev/shm/dataset onto
# itself, because DATASET_PATH has already been repointed.
SHM_DATASET_PATH = Path("/dev/shm/dataset")

if Path(DATASET_PATH) != SHM_DATASET_PATH:
    # dirs_exist_ok lets a re-run refresh an earlier (possibly partial) copy instead of raising
    shutil.copytree(DATASET_PATH, SHM_DATASET_PATH, dirs_exist_ok=True)
    DATASET_PATH = SHM_DATASET_PATH

print('Images in dataset: ', len(os.listdir(DATASET_PATH)))

# Preprocessing & Augmentation Transforms

In [None]:
# DINOv2 settings. Normalisation uses the standard ImageNet statistics
# (https://github.com/facebookresearch/dinov2/issues/181).
dino_params = dict(
    norm_mean=(0.485, 0.456, 0.406),
    norm_std_dev=(0.229, 0.224, 0.225),
    image_size=448,       # NOTE(review): not read by the preprocessing in this notebook; images appear to be pre-sized — confirm
    feature_length=1024,  # length of one embedding; used to size the feature-vector tensor
)

def equalize(tensor):
    ''' Equalize the histogram of an image to make exposure uniform.

        Applies skimage's contrast-limited adaptive histogram equalization
        (exposure.equalize_adapthist) to a C x H x W image tensor.

        tensor: C x H x W image tensor on the CPU (it is converted with .numpy()).
        Returns: float32 C x H x W tensor. skimage returns floating-point data, so integer
                 inputs come back as floats (presumably scaled to [0, 1] — confirm in skimage docs).
    '''

    # skimage works on channels-last arrays (H x W x C)
    np_image = tensor.permute(1, 2, 0).numpy()

    # Clip limit determined via testing...
    np_image = exposure.equalize_adapthist(np_image, clip_limit=0.03)

    # Back to a channels-first tensor, as float32 for subsequent processing
    return torch.tensor(np_image).permute(2, 0, 1).float()

def preprocess_transforms(image, model):
    ''' Build the DINOv2 preprocessing pipeline and apply it to a single image.

        image: image tensor, e.g. as returned by torchvision.io.read_image
        model: model parameter dict (part of the call interface; not used by this pipeline)

        Returns the preprocessed image as a float32 tensor.
    '''
    steps = [
        v2.ToImage(),                                    # wrap as a torchvision Image tensor (not a PIL image)
        v2.ToDtype(torch.uint8, scale=True),             # ensure uint8 pixel values
        v2.Lambda(lambda x: equalize(x)),                # adaptive histogram equalization (returns float32)
        v2.GaussianBlur(kernel_size=(7, 7), sigma=(1)),  # 7x7 Gaussian blur, sigma fixed at 1
        v2.ToDtype(torch.float32, scale=True),           # guarantee float32 output
    ]
    pipeline = v2.Compose(steps)
    return pipeline(image)

class TestTimeAugmentations:
    ''' Test-time augmentations (TTA).

        For each input image, transform() produces 16 * num_transforms views:
        {original, horizontal flip} x {original, inverted} x {0, 90, 180, 270 degree rotation}
        gives 16 variants. Each variant goes through num_transforms rounds of random colour
        jitter and is then optionally normalized.
    '''

    def __init__(self, num_transforms=5, model=None, norm=True):
        ''' Initialize test time augmentations

            num_transforms: number of colour-jitter rounds; each round yields 16 views
            model: model parameter dict; must provide 'norm_mean' and 'norm_std_dev' when norm=True
            norm: normalize every view with the model's mean / std
        '''

        self.num_transforms = num_transforms
        self.model = model
        self.norm = norm

    def transform(self, image):
        ''' Apply the full set of augmentations to one image.

            image: single C x H x W image tensor — presumably float in [0, 1] as produced by
                   preprocess_transforms; confirm for other callers.

            Returns a list of 16 * num_transforms tensors, ordered as
            [jitter round 0: 16 views, jitter round 1: 16 views, ...].

            The jitter RNG is reseeded with SEED on every call, so every image receives the same
            sequence of jitter parameters. The caller's torch CPU RNG state is saved and restored
            around this, so calling transform() does not reset the global generator.
        '''

        # Start with a list containing just the original image (1 img)
        images = [image]

        # Horizontally flip the image and add to list (2 imgs)
        images.extend([v2.functional.hflip(img) for img in images])

        # Apply inversion to images in list (4 imgs)
        images.extend([v2.functional.invert(img) for img in images])

        # Rotate each image to all 90 degree positions (16 imgs)
        # NOTE(review): rotate() keeps the input size (expand=False), so 90/270 degree rotations
        # of non-square images lose corners / gain padding — fine if inputs are square.
        images = [v2.functional.rotate(img, angle) for img in images for angle in [0, 90, 180, 270]]

        # Define random color jitter augmentation
        jitter = v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)

        # ColorJitter draws its parameters from torch's default CPU generator. Fork that generator
        # so reseeding yields the same jitter sequence for every image without clobbering the
        # caller's RNG state (restored when the block exits).
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(SEED)

            # Apply jitter to images in list (16 * num_transforms imgs)
            images = [jitter(img) for _ in range(self.num_transforms) for img in images]

        # Normalize images
        if self.norm:
            normalize = v2.Normalize(mean=self.model['norm_mean'], std=self.model['norm_std_dev'])
            images = [normalize(img) for img in images]

        return images

### Visualization of Transforms

In [None]:
def show(image):
    ''' Display a single image in its own figure.

        image: anything v2.ToPILImage accepts, e.g. a C x H x W tensor
    '''

    # Convert to a PIL image for matplotlib
    pil_image = v2.ToPILImage()(image)

    plt.figure()
    plt.imshow(pil_image)
    # plt.show must be called, not just referenced. Otherwise consecutive show() calls in one
    # cell draw onto the same axes and only the last image is displayed.
    plt.show()

# Pick one image at random and show it before and after preprocessing.
# The listing is sorted so the seeded random.choice selects the same file on every run;
# os.listdir order is arbitrary and can differ between filesystems.
image_files = sorted(os.listdir(DATASET_PATH))

# set random seed
randseed(SEED)

sample_image = read_image(os.path.join(DATASET_PATH, random.choice(image_files)))
show(sample_image)

# Model parameters used for preprocessing
params = dino_params

# One jitter round -> 16 augmented views
num_transforms = 1
num_views = 16 * num_transforms

# preprocess and show
preprocess = preprocess_transforms(sample_image, params)
show(preprocess)

# Unnormalized augmentations so the views can be displayed as images
augmentation = TestTimeAugmentations(num_transforms=num_transforms, model=params, norm=False)
images = augmentation.transform(preprocess)
assert len(images) == num_views, f"expected {num_views} views, got {len(images)}"

# Show the views in 4 x 4 grids, one figure per jitter round
VIEWS_PER_GRID = 16
for start in range(0, len(images), VIEWS_PER_GRID):
    fig = plt.figure(figsize=(12, 12))

    for idx, view in enumerate(images[start:start + VIEWS_PER_GRID]):
        plt.subplot(4, 4, idx + 1)
        plt.axis('off')
        plt.imshow(v2.ToPILImage()(view))

    plt.show()
    plt.close(fig)

# Inference

### Define Dataset Class

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

class MyImageFolder(torch.utils.data.Dataset):
    ''' Flat image-folder dataset yielding (image, file name) pairs.

        root_dir: directory whose regular files are expected to be images readable by read_image
        model: model parameter dict forwarded to preprocess_transform
        preprocess_transform: optional callable(image, model=...) applied to each loaded image
    '''

    def __init__(self, root_dir, model, preprocess_transform=None):
        self.root_dir = root_dir
        self.preprocess_transform = preprocess_transform
        # Sorted so the dataset order (and therefore the row order of saved embeddings) is
        # reproducible; os.listdir order is arbitrary. Sub-directories are skipped because
        # read_image cannot open them.
        self.img_files = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )
        self.model = model

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        ''' Load image `idx`, optionally preprocess it, and return (image, file name). '''
        file_name = self.img_files[idx]
        image = read_image(os.path.join(self.root_dir, file_name))

        if self.preprocess_transform:
            image = self.preprocess_transform(image, model=self.model)

        return image, file_name

In [None]:
from time import time
from collections import defaultdict
import gc

def ExtractFeatureVectors(model,
                          dataloader,
                          model_params,
                          model_name,
                          num_augmentations=1,
                          feature_vector_name=None):
    ''' Run the dataset through `model` with test-time augmentation and save the embeddings.

        model: feature extractor, called on a batch of stacked augmented views
        dataloader: yields (images, file names) batches, e.g. built from MyImageFolder
        model_params: dict providing 'feature_length', 'norm_mean' and 'norm_std_dev'
        model_name: tag used in the output file names
        num_augmentations: colour-jitter rounds per image; each round yields 16 views
        feature_vector_name: tag used in the output file names

        Writes three files to OUTPUT_DIR:
          feature_vectors_all_augments_*.pt      - tensor with one row per augmented view
          feature_vectors_dict_all_augments_*.pt - {file name: [embedding of each view, ...]}
          feature_vectors_dict_*.pt              - {file name: mean embedding over its views}
        Returns None.
    '''
    randseed(SEED)

    labels = []
    embedding_chunks = []

    # Define test time augmentations
    augmentation_transforms = TestTimeAugmentations(num_transforms=num_augmentations,
                                                    model=model_params)

    # Process the images in batches
    for inputs, names in tqdm(dataloader):
        augmented_images = []

        for image, name in zip(inputs, names):
            views = augmentation_transforms.transform(image)
            augmented_images.extend(views)
            # One label per view, so labels[k] names the source image of output row k
            labels.extend([name] * len(views))

        batch = torch.stack(augmented_images).to(device)

        with torch.no_grad(), autocast(device.type):
            output = model(batch)

        # Keep per-batch results on the CPU and concatenate once at the end. Growing a tensor
        # with torch.cat every batch copies all previous rows (quadratic) and keeps every
        # embedding in GPU memory. .float() keeps the saved tensor float32 even when the
        # model runs in half precision.
        embedding_chunks.append(output.float().cpu())

    if embedding_chunks:
        feature_vectors = torch.cat(embedding_chunks, dim=0)
    else:
        # Empty dataloader: still save a (0, feature_length) tensor
        feature_vectors = torch.empty((0, model_params['feature_length']))

    file_suffix = f"{feature_vector_name}_{model_name}_{num_augmentations*16}_augmentations_{date}.pt"

    # Save all feature vectors
    torch.save(feature_vectors, str(OUTPUT_DIR / f"feature_vectors_all_augments_{file_suffix}"))

    # Group the per-view embeddings by source file name
    feature_vectors_dict = defaultdict(list)
    for row, name in enumerate(labels):
        feature_vectors_dict[name].append(feature_vectors[row])

    torch.save(feature_vectors_dict, str(OUTPUT_DIR / f"feature_vectors_dict_all_augments_{file_suffix}"))

    # Average embedding per file
    feature_vector_dict_mean = {
        name: torch.stack(vectors).mean(dim=0)
        for name, vectors in feature_vectors_dict.items()
    }

    torch.save(feature_vector_dict_mean, str(OUTPUT_DIR / f"feature_vectors_dict_{file_suffix}"))

Runtime of the feature-extraction cell below: ~58 min 20 s (A100 GPU).

In [None]:
# Select the GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Release cached GPU memory from earlier cells
torch.cuda.empty_cache()

# Set model and parameters
model = dino_model
model_params = dino_params
model_name =  'dino'

# Tag used in the saved feature-vector file names
feature_vector_name = 'DM_full'

# Half precision on the GPU. Many CPU kernels do not support float16, so the CPU fallback
# stays in float32. (Module.to converts in place, so dino_model is converted too.)
model_dtype = torch.float16 if device.type == 'cuda' else torch.float32
model.to(device=device, dtype=model_dtype)

# num_augmentations: colour-jitter rounds per image; each round yields 16 views (48 per image here).
# batch_size: images per DataLoader batch, i.e. batch_size * num_augmentations * 16 views per forward pass.
num_augmentations = 3
batch_size = 3

# Define dataset
dataset = MyImageFolder(root_dir=DATASET_PATH,
                        model = model_params,
                        preprocess_transform=preprocess_transforms)

# Define dataloader; pinned memory only helps (and only works) with a CUDA device
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=NUM_WORKERS,
                        pin_memory=device.type == 'cuda')

# Run inference and save the embeddings to OUTPUT_DIR
ExtractFeatureVectors(model,
                      dataloader,
                      model_params,
                      model_name,
                      num_augmentations=num_augmentations,
                      feature_vector_name=feature_vector_name)