In [None]:
import os
import random
import csv
import numpy as np
from PIL import Image
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
from torchvision.datasets.utils import download_and_extract_archive

In [None]:
# Mean / standard deviation handed to transforms.Normalize further down
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_SDEV = [0.229, 0.224, 0.225]


def get_image_file(image_dir):
    """Return the joined path of every regular file directly inside ``image_dir``."""
    found = []
    for entry in os.listdir(image_dir):
        candidate = os.path.join(image_dir, entry)
        if os.path.isfile(candidate):
            found.append(candidate)
    return found


def get_image(folder, index):
    """Open ``folder[index]`` with PIL and return the resulting image."""
    return Image.open(folder[index])


def get_path(x):
    """Return ``x`` joined onto the ``data`` folder of the current working directory."""
    return os.path.join(os.getcwd(), 'data', x)

In [None]:
class CopyDetectPretrainDataset(Dataset):
    """Dataset over the regular files found directly inside one directory.

    Each item is ``(image, image_id)``: ``image`` is the PIL image opened from
    the file and ``image_id`` is that file's path (``image_dir`` joined with
    the file name).

    Args:
        image_dir: directory whose files are the dataset samples;
            sub-directories are ignored.
    """

    def __init__(self,
                 image_dir: str):
        self.image_dir = image_dir
        candidates = (os.path.join(image_dir, f) for f in os.listdir(image_dir))
        # Sorted so an index maps to the same file on every run (os.listdir
        # returns entries in arbitrary order); dtype=str keeps an empty
        # directory from producing a float array.
        self.image_files = np.array(sorted(p for p in candidates if os.path.isfile(p)),
                                    dtype = str)

    def __len__(self) -> int:
        """Number of files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, index: int):
        """Return ``(PIL image, path)`` for the file at ``index``."""
        image_id = self.image_files[index]
        # image_files already holds paths joined with image_dir; joining again
        # would double the directory prefix whenever image_dir is relative.
        image = Image.open(str(image_id))
        return image, image_id

class CopyDetectCollateFn(nn.Module):
    """Collate ``(image, image_id)`` samples into reference / augmented batches.

    In each of ``n_crops`` rounds every reference image is paired with one
    augmented image: with probability 0.5 an augmentation of the same sample,
    otherwise an augmentation of a sample drawn uniformly from the batch
    (which may still be the same sample by chance).

    Args:
        transform: callable mapping one image to a tensor; every output must
            have the same shape so the results can be stacked.
        augment: callable mapping one image to an augmented image.
        n_crops: number of augmentation rounds per batch; ``None`` means 1.
    """

    def __init__(self,
                 transform,
                 augment: object,
                 n_crops: Optional[int] = 1):
        super().__init__()
        self.transform  = transform
        self.augment = augment
        self.n_crops = 1 if n_crops is None else n_crops

    def forward(self, batch):
        """Build the paired batches.

        Args:
            batch: sequence of ``(image, image_id)`` samples.

        Returns:
            ``(ref_imgs, aug_imgs, ref_ids, aug_ids)``. The images are tensors
            stacked along a new leading dimension of size
            ``n_crops * len(batch)``; the ids are NumPy arrays of the same
            length (ids may be strings, which tensors cannot hold), so
            ``ref_ids == aug_ids`` gives the element-wise match mask.
        """
        # The batch is a list of per-sample tuples: split it into two columns.
        imgs, ids = zip(*batch)
        ids = np.asarray(ids)
        batch_size = len(imgs)
        indices = np.arange(batch_size)

        # References are transformed once and reused for every round.
        ref_imgs = [self.transform(img) for img in imgs]

        ref_imgs_list, aug_imgs_list, aug_ids_list = [], [], []

        for _ in range(self.n_crops):
            # keep_same[i] is True -> augment sample i itself, else a random sample
            keep_same = np.random.uniform(size = batch_size) < 0.5
            rand_indices = np.random.randint(0, batch_size, size = batch_size)
            aug_indices = np.where(keep_same, indices, rand_indices)

            ref_imgs_list.extend(ref_imgs)
            aug_imgs_list.extend(self.transform(self.augment(imgs[i]))
                                 for i in aug_indices.tolist())
            aug_ids_list.append(ids[aug_indices])

        return (torch.stack(ref_imgs_list),
                torch.stack(aug_imgs_list),
                np.tile(ids, self.n_crops),
                np.concatenate(aug_ids_list))

In [None]:
# NOTE(review): this cell defines CopyDetectCollateFn a second time; being the
# later definition, it is the one bound to the name when the cells run in order.
class CopyDetectCollateFn(nn.Module):
    """Collate ``(image, image_id)`` samples into reference / augmented batches.

    In each of ``n_crops`` rounds every reference image is paired with one
    augmented image: with probability 0.5 an augmentation of the same sample,
    otherwise an augmentation of a sample drawn uniformly from the batch
    (which may still be the same sample by chance).

    Args:
        transform: callable mapping one image to a tensor; every output must
            have the same shape so the results can be stacked.
        augment: callable mapping one image to an augmented image.
        n_crops: number of augmentation rounds per batch; ``None`` means 1.
    """

    def __init__(self,
                 transform,
                 augment: object,
                 n_crops: Optional[int] = 1):
        super().__init__()
        self.transform  = transform
        self.augment = augment
        self.n_crops = 1 if n_crops is None else n_crops

    def forward(self, batch):
        """Build the paired batches.

        Args:
            batch: sequence of ``(image, image_id)`` samples.

        Returns:
            ``(ref_imgs, aug_imgs, ref_ids, aug_ids)``. The images are tensors
            stacked along a new leading dimension of size
            ``n_crops * len(batch)``; the ids are NumPy arrays of the same
            length (ids may be strings, which tensors cannot hold), so
            ``ref_ids == aug_ids`` gives the element-wise match mask.
        """
        # The batch is a list of per-sample tuples: split it into two columns.
        imgs, ids = zip(*batch)
        ids = np.asarray(ids)
        batch_size = len(imgs)
        indices = np.arange(batch_size)

        # References are transformed once and reused for every round.
        ref_imgs = [self.transform(img) for img in imgs]

        ref_imgs_list, aug_imgs_list, aug_ids_list = [], [], []

        for _ in range(self.n_crops):
            # keep_same[i] is True -> augment sample i itself, else a random sample
            keep_same = np.random.uniform(size = batch_size) < 0.5
            rand_indices = np.random.randint(0, batch_size, size = batch_size)
            aug_indices = np.where(keep_same, indices, rand_indices)

            ref_imgs_list.extend(ref_imgs)
            aug_imgs_list.extend(self.transform(self.augment(imgs[i]))
                                 for i in aug_indices.tolist())
            aug_ids_list.append(ids[aug_indices])

        return (torch.stack(ref_imgs_list),
                torch.stack(aug_imgs_list),
                np.tile(ids, self.n_crops),
                np.concatenate(aug_ids_list))

In [None]:
from src.datamodules.components.augmentation import Augment

# Image -> tensor, resized to 224 x 224, then normalised with IMAGENET_MEAN / IMAGENET_SDEV
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_SDEV),
])

# Augmentation object; overlay_image_dir points at <cwd>/data/train/
augment = Augment(
    overlay_image_dir=get_path('train/'),
    n_upper=2,
    n_lower=1,
)

In [None]:
# Collate function producing two augmentation rounds per batch.
# The keyword is `transform` (it was misspelled, which raised a TypeError).
collate_fn = CopyDetectCollateFn(transform = transform,
                                 augment = augment,
                                 n_crops = 2)

In [None]:
# Dataset over the files in <cwd>/data/dev_queries/
train_dataset = CopyDetectPretrainDataset(image_dir=get_path('dev_queries/'))

# Shuffled batches of 16 samples built by collate_fn; the incomplete final batch is dropped
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
    num_workers=8,
    pin_memory=True,
)

In [None]:
from transformers import ViTModel
from src.models.components.layers import CopyDetectEmbedding, NormalizedFeatures, SimImagePred, ContrastiveProj

# Load the pretrained ViT and keep a handle on its encoder
pretrained_arch = 'google/vit-base-patch16-224'
pretrained_model = ViTModel.from_pretrained(pretrained_arch)
encoder = pretrained_model.encoder

patch_size = pretrained_model.config.patch_size

# Embedding built from the pretrained ViT's cls token and position embeddings
embedding = CopyDetectEmbedding(
    config=pretrained_model.config,
    vit_cls=pretrained_model.embeddings.cls_token,
    pos_emb=pretrained_model.embeddings.position_embeddings,
)

# Feature normalisation configured from the ViT hidden size and layer-norm epsilon
normfeats = NormalizedFeatures(
    hidden_dim=pretrained_model.config.hidden_size,
    layer_norm_eps=pretrained_model.config.layer_norm_eps,
)

In [None]:
# Projection head. Its input width is taken from the ViT config (the same
# hidden size normfeats is built with) instead of the literal 786, which
# looks like a transposition of 768.
projection_head = ContrastiveProj(embedding_dim = pretrained_model.config.hidden_size,
                                  hidden_dim = 1024,
                                  projected_dim = 512)

# Full pipeline: embedding -> encoder -> normalised features -> projection head
contrastiveproj = nn.Sequential(embedding, encoder, normfeats, projection_head)

In [None]:
# Iterate over the collated batches and project the reference/query pairs whose ids match.
for batch in train_dataloader:
    ref_img, query_img, ref_id, query_id = batch
    # Mask that is True where the reference id equals the query id
    label = torch.tensor(ref_id == query_id, dtype = torch.bool)
    # Only the matching pairs are kept; both sides go through the same pipeline
    proj_r = contrastiveproj(ref_img[label])
    proj_q = contrastiveproj(query_img[label])
    
    # NOTE(review): contrastiveproj is the nn.Sequential assembled in the previous
    # cell, and nn.Sequential takes a single input, so this two-argument call looks
    # like it would raise a TypeError. A separate loss function was presumably
    # intended here -- confirm which one before running.
    loss = contrastiveproj(proj_r, proj_q)
    