In [1]:
import os
from typing import Any, List, Optional, Tuple

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import Dataset, DataLoader
from torchmetrics import MaxMetric
from torchmetrics.classification.accuracy import Accuracy
from torchvision import transforms
from transformers import ViTModel

from src.datamodules.components.augmentation import Augment
from src.datamodules.copydetect_datamodule import CopyDetectDataModule
from src.models.components.layers import CopyDetectEmbedding, NormalizedFeatures, SimImagePred, ContrastiveProj
from src.utils.nt_xent_loss import NTXentLoss


def get_path(x):
    """Return the path of `x` inside the `data` folder of the current working directory."""
    return os.path.join(os.getcwd(), 'data', x)

# Augmentation object; its overlay images are taken from the training directory.
augment = Augment(
    overlay_image_dir = get_path('train/'),
    n_upper = 2,
    n_lower = 1,
)

# Contrastive loss instance handed to CopyDetectModule.
ntxentloss = NTXentLoss(temperature = 0.9, eps = 1e-5)

# CUDA device (index 4) used by the hand-written, non-Trainer loops.
device = torch.device('cuda', 4)


In [2]:
class CopyDetectModule(LightningModule):
    """Copy-detection LightningModule built around a pretrained ViT encoder.

    A single embedding / encoder / normalisation stack is shared by three heads:

    * ``feature_extractor`` - image descriptor (cls token + GeM-pooled patch tokens),
    * ``simimagepred``      - one logit per (reference, query) pair, trained with
      ``BCEWithLogitsLoss``,
    * ``contrastiveproj``   - projection head trained with the supplied contrastive loss.

    The module defines no ``forward``; use ``feature_extract`` / ``predict_copy``
    or the Trainer ``test`` / ``predict`` loops.
    """

    def __init__(self,
                 pretrained_arch: str,          # Pretrained ViT architecture
                 ntxentloss: object,            # Contrastive loss
                 hidden_dim: int = 2048,        # Contrastive projection size of hidden layer
                 projected_dim: int = 512,      # Contrastive projection size of projection head
                 beta1: float = 1,              # Similar image BCE loss multiplier
                 beta2: float = 1,              # Contrastive loss multiplier
                 lr: float = 0.001,
                 weight_decay: float = 0.0005):
        super().__init__()
        self.save_hyperparameters(logger = False)

        # Instantiate ViT encoder from pretrained model
        pretrained_model = ViTModel.from_pretrained(pretrained_arch)
        encoder = pretrained_model.encoder
        self.patch_size = pretrained_model.config.patch_size

        # Instantiate embedding, we use the pretrained ViT cls and position embedding
        embedding = CopyDetectEmbedding(config = pretrained_model.config,
                                        vit_cls = pretrained_model.embeddings.cls_token,
                                        pos_emb = pretrained_model.embeddings.position_embeddings)

        # Normalized features
        normfeats = NormalizedFeatures(hidden_dim = pretrained_model.config.hidden_size,
                                       layer_norm_eps = pretrained_model.config.layer_norm_eps)
        # Feature Vector Extractor
        self.feature_extractor = nn.Sequential(embedding, encoder, normfeats)

        # Instantiate SimImagePredictor. The embedding is kept outside the
        # Sequential because it is called with two images and nn.Sequential
        # only forwards a single input.
        simimagepred = SimImagePred(embedding_dim = pretrained_model.config.hidden_size)
        self.embedding = embedding
        self.simimagepred = nn.Sequential(encoder, normfeats, simimagepred)

        # Instantiate ContrastiveProjection
        contrastiveproj = ContrastiveProj(embedding_dim = pretrained_model.config.hidden_size,
                                          hidden_dim = hidden_dim,
                                          projected_dim = projected_dim)
        self.contrastiveproj = nn.Sequential(embedding, encoder, normfeats, contrastiveproj)

        # Contrastive loss
        self.contrastive_loss = ntxentloss

        # Binary cross entropy loss for similar image pair
        self.bce_loss = torch.nn.BCEWithLogitsLoss()

        # Model accuracy in detecting modified copy
        self.train_acc, self.val_acc = Accuracy(), Accuracy()

        # For logging best validation accuracy
        self.val_acc_best = MaxMetric()

    @staticmethod
    def _logits_to_preds(logits: torch.Tensor) -> torch.Tensor:
        """Turn pair logits into 0/1 predictions of shape (batch,).

        The BCE loss in ``step`` is fed ``label.unsqueeze(dim = 1)``, so the
        logits have a single column. ``argmax`` over that column is always 0,
        hence the threshold instead: logit > 0 is the same as sigmoid > 0.5.
        """
        return (logits > 0).long().squeeze(dim = -1)

    def feature_extract(self, batch: Any) -> Tuple[torch.Tensor, Any]:
        """Compute the descriptor of every image in ``batch``.

        Args:
            batch: ``(images, image_ids)`` pair; images are (batch, channels, H, W).

        Returns:
            ``(feature_vector, image_ids)`` where ``feature_vector`` is the cls
            token concatenated with the GeM-pooled patch tokens along dim 1.
        """
        img_r, img_id = batch
        encoding = self.feature_extractor(img_r)
        _, _, H, W = img_r.size()
        # Size of the patch grid, needed to fold the token sequence back to 2-D
        h, w = H // self.patch_size, W // self.patch_size
        cls, feats = encoding[:, 0, :], encoding[:, 1:, :] # cls token / patch tokens

        # Clamp before the power so the fractional root below stays defined
        feats = einops.rearrange(feats, 'b (h w) d -> b d h w', h = h, w = w).clamp(min = 1e-6)
        # GeM Pooling (p = 4) over the whole patch grid
        feats = F.avg_pool2d(feats.pow(4), (h, w)).pow(1./4)
        feats = einops.rearrange(feats, 'b d () () -> b d')
        # Concatenate cls tokens with image patches to give local and global views of image
        feature_vector = torch.cat((cls, feats), dim = 1)

        return feature_vector, img_id

    def predict_copy(self, batch):
        """Predict 0/1 for every (reference image, query image) pair in ``batch``."""
        img_r, img_q = batch
        embedding_rq = self.embedding(img_r, img_q)
        logits = self.simimagepred(embedding_rq)
        return self._logits_to_preds(logits)

    def step(self,
             img_r: torch.Tensor,
             img_q: torch.Tensor,
             label: torch.Tensor):
        """Shared train / validation step.

        Args:
            img_r: batch of reference images.
            img_q: batch of query images paired with ``img_r``.
            label: per-pair labels of shape (batch,); non-zero marks a positive pair.

        Returns:
            ``(losses, preds)`` where ``losses`` has the keys ``'simimage'``,
            ``'contrastive'`` and ``'total'`` and ``preds`` holds 0/1 predictions.
        """
        # img_r, img_q to SimImagePredictor
        embedding_rq = self.embedding(img_r, img_q) # nn.Sequential does not take multiple inputs
        logits = self.simimagepred(embedding_rq)
        # Calculate binary cross entropy loss of similar image pair
        simimage_loss = self.bce_loss(logits, label.unsqueeze(dim = 1))
        # Predictions
        preds = self._logits_to_preds(logits)

        # Get positive indices
        pos_indices = label.bool()
        if pos_indices.any():
            # Forward positive indices of img_r and img_q to ContrastiveProjection
            proj_r = self.contrastiveproj(img_r[pos_indices])
            proj_q = self.contrastiveproj(img_q[pos_indices])
            # Contrastive loss between un-augmented img_r and augmented positive pair of img_q
            contrastive_loss = self.contrastive_loss(proj_r, proj_q)
        else:
            # No positive pair in this batch: skip the projection head rather
            # than push an empty batch through it.
            contrastive_loss = simimage_loss.new_zeros(())

        # Weighted sum of bce and contrastive loss
        total_loss = self.hparams.beta1 * simimage_loss + self.hparams.beta2 * contrastive_loss

        return {'simimage': simimage_loss, 'contrastive': contrastive_loss, 'total': total_loss}, preds

    def training_step(self, batch: Any, batch_idx: int):
        """Stack the per-sample sequences of the batch, run ``step`` and log train metrics."""
        img_r, img_q, label = batch
        img_r, img_q, label = torch.vstack(img_r), torch.vstack(img_q), torch.hstack(label)

        losses, preds = self.step(img_r, img_q, label)

        # Log train metrics
        acc = self.train_acc(preds, label.int())
        self.log("train/total_loss", losses['total'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/simimage_loss", losses['simimage'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/contrastive_loss", losses['contrastive'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/acc", acc, on_step = True, on_epoch = True, prog_bar = True)

        return losses['total']

    def validation_step(self, batch: Any, batch_idx: int):
        """Run ``step`` on a validation batch (used as delivered, no stacking) and log val metrics."""
        img_r, img_q, label = batch
        losses, preds = self.step(img_r, img_q, label)

        # Log val metrics
        acc = self.val_acc(preds, label.int())
        self.log("val/total_loss", losses['total'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/simimage_loss", losses['simimage'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/contrastive_loss", losses['contrastive'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/acc", acc, on_step = True, on_epoch = True, prog_bar = True)

        return losses['total']

    def validation_epoch_end(self, outputs: Any):
        """Track and log the best validation accuracy seen so far."""
        acc = self.val_acc.compute()  # get val accuracy from current epoch
        self.val_acc_best.update(acc)
        self.log("val/acc_best", self.val_acc_best.compute(), on_epoch = True, prog_bar = True)

    def test_step(self, batch: Any, batch_idx: int):
        """Return ``(feature_vector, image_ids)`` for one test batch."""
        feats = self.feature_extract(batch) # Get feat

        return feats

    def test_epoch_end(self, test_step_outputs: Any):
        """Gather all test batches and store ``(features, ids)`` on ``self.test_results``."""
        all_feats, all_ids = [], []
        for step_output in test_step_outputs:
            all_feats.append(step_output[0])
            all_ids.extend(step_output[1])

        all_feats = torch.vstack(all_feats)
        self.test_results = (all_feats, all_ids)

        return all_feats

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """Return 0/1 copy predictions for a batch of (reference, query) image pairs.

        Side effect: ``self.test_results`` is set to ``None``, so read the
        extracted features before starting a predict run.
        """
        self.test_results = None
        score = self.predict_copy(batch)

        return score

    def on_epoch_end(self):
        # Reset metrics at the end of every epoch
        self.train_acc.reset()
        self.val_acc.reset()

    def configure_optimizers(self):
        """Adam over all parameters, using ``lr`` and ``weight_decay`` from the hyperparameters."""
        return torch.optim.Adam(params = self.parameters(),
                                lr = self.hparams.lr,
                                weight_decay = self.hparams.weight_decay)
        
# Checkpoint name forwarded to ViTModel.from_pretrained inside the module.
pretrained_arch = 'google/vit-base-patch16-224'

model = CopyDetectModule(pretrained_arch = pretrained_arch,
                         ntxentloss = ntxentloss)

Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# All image folders and the validation CSV are resolved relative to ./data.
datamodule = CopyDetectDataModule(
    train_dir = get_path('train/'),
    references_dir = get_path('references/'),
    dev_queries_dir = get_path('dev_queries/'),
    final_queries_dir = get_path('final_queries/'),
    augment = augment,
    dev_validation_set = get_path('dev_validation_set.csv'),
    batch_size = 16,
    pin_memory = True,
    num_workers = 10,
    n_crops = 2,
)

datamodule.setup()

In [4]:
# Single-GPU trainer pinned to GPU index 4.
trainer = Trainer(
    accelerator = 'gpu',
    devices = [4],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [5]:
# The test loop runs feature extraction over the reference images;
# test_epoch_end leaves (stacked features, ids) on model.test_results.
trainer.test(model = model,
             dataloaders = datamodule.references_dataloader())
r_id = model.test_results[1]
r = model.test_results[0].detach().cpu().numpy()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


In [6]:
# Same extraction as for the references, this time over the final query images.
trainer.test(model = model,
             dataloaders = datamodule.final_queries_dataloader())
q_id = model.test_results[1]
q = model.test_results[0].detach().cpu().numpy()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing: 0it [00:00, ?it/s]



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


In [7]:
# Per-channel statistics handed to transforms.Normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_SDEV = [0.229, 0.224, 0.225]

# Tensor conversion -> resize to 224x224 -> normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_SDEV),
])

def get_image(img_dir, img):
    """Load the image `img` from `img_dir` as an RGB PIL image.

    The file is opened in a context manager so its handle is released, and the
    result is converted to RGB so that grayscale, palette or RGBA files still
    produce the three channels the Normalize step is configured for.
    """
    with Image.open(os.path.join(img_dir, img)) as handle:
        return handle.convert('RGB')


def get_image_file(image_dir):
    """Return the paths of the regular files directly inside `image_dir` (sub-directories are skipped)."""
    candidates = (os.path.join(image_dir, f) for f in os.listdir(image_dir))
    return [path for path in candidates if os.path.isfile(path)]

class CopyDetectPredDataset(Dataset):
    """Dataset of candidate (query, reference) pairs to be scored by the copy detector.

    Args:
        predictions: sequence of ``[final_query_id, reference_id]`` entries; each
            id is joined onto the matching directory to locate the image file.
        references_dir: directory holding the reference images.
        final_queries_dir: directory holding the final query images.
    """

    def __init__(self,
                 predictions: list,
                 references_dir: str,
                 final_queries_dir: str):
        self.predictions = predictions
        self.references_dir = references_dir
        self.final_queries_dir = final_queries_dir
        # Tensor conversion -> resize to 224x224 -> normalisation
        self.transform =  transforms.Compose([transforms.ToTensor(),
                                              transforms.Resize((224, 224)),
                                              transforms.Normalize(IMAGENET_MEAN, IMAGENET_SDEV)])

    def __len__(self) -> int:
        """Number of candidate pairs."""
        return len(self.predictions)

    def __getitem__(self, index: int):
        """Return ``(reference_tensor, query_tensor)`` for pair ``index`` - reference first."""
        final_queries_id, references_id = self.predictions[index][:2]

        # Force three channels: the Normalize step has three-channel statistics
        # and fails on grayscale or RGBA inputs.
        reference_image = get_image(self.references_dir, references_id).convert('RGB')
        final_queries_image = get_image(self.final_queries_dir, final_queries_id).convert('RGB')

        return self.transform(reference_image), self.transform(final_queries_image)

In [8]:
from src.utils import search_with_capped_res

lims, dis, ids = search_with_capped_res(q, r, 1000)

# Only the first 100 queries are expanded into [query id, reference id] pairs.
# Clamp to the number of extracted queries so lims[i + 1] cannot run past the
# end when fewer than 100 are available.
# NOTE(review): assumes lims holds one more entry than there are queries - confirm
# against search_with_capped_res.
n_queries = min(100, len(q_id))

predictions_list = [[q_id[i], r_id[ids[j]]]
                    for i in range(n_queries)
                    for j in range(lims[i], lims[i + 1])]

In [9]:
# Wrap the candidate pairs so they can be scored in batches.
copydetectpred = CopyDetectPredDataset(
    predictions = predictions_list,
    references_dir = get_path('references'),
    final_queries_dir = get_path('final_queries/'),
)

copydetect_dataloader = DataLoader(
    dataset = copydetectpred,
    batch_size = 16,
    num_workers = 8,
    pin_memory = True,
)

In [10]:
# One entry per batch: whatever predict_step returned for that batch.
p = trainer.predict(model = model,
                    dataloaders = copydetect_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Predicting: 0it [00:00, ?it/s]



In [None]:
# NOTE(review): predict_step assigns None to model.test_results, so once trainer.predict
# has run this attribute no longer holds the extracted features.
model.test_results

In [14]:
# Shape of all per-batch predictions joined into a single vector.
tuple(torch.hstack(p).shape)

(1000,)

In [22]:
# Manual feature extraction over the reference images (outside the Trainer).
# The model and every batch are moved to `device`, and autograd is switched off so
# no graph is kept alive for the stored features.
model.eval()
model.to(device)

ref_feats, ref_ids = [], []

with torch.no_grad():
    for img, img_id in datamodule.references_dataloader():
        feats, _ = model.feature_extract((img.to(device), img_id))
        ref_feats.append(feats.cpu())  # keep results on the host to free GPU memory
        ref_ids.extend(img_id)

ref_feats = torch.vstack(ref_feats).numpy()



KeyboardInterrupt: 

In [20]:
# NOTE(review): CopyDetectModule defines no forward(), so calling the module directly
# fails with the NotImplementedError shown below; feature_extract / predict_copy are
# the callable entry points.
model(img)

NotImplementedError: 

In [None]:
# Manual feature extraction over the final query images (outside the Trainer).
# CopyDetectModule has no forward(), so the descriptor comes from feature_extract.
model.eval()
model.to(device)

query_feats, query_ids = [], []

with torch.no_grad():
    for img, img_id in datamodule.final_queries_dataloader():
        feats, _ = model.feature_extract((img.to(device), img_id))
        query_feats.append(feats.cpu())  # keep results on the host to free GPU memory
        query_ids.extend(img_id)

query_feats = torch.vstack(query_feats).numpy()

In [None]:
from src.utils import search_with_capped_res

lims, dis, ids = search_with_capped_res(query_feats, ref_feats, 1000)

# Only the first 100 queries are expanded into [query id, reference id] pairs.
# Clamp to the number of extracted queries so lims[i + 1] cannot run past the
# end when fewer than 100 are available.
# NOTE(review): assumes lims holds one more entry than there are queries - confirm
# against search_with_capped_res.
n_queries = min(100, len(query_ids))

predictions_list = [[query_ids[i], ref_ids[ids[j]]]
                    for i in range(n_queries)
                    for j in range(lims[i], lims[i + 1])]

In [None]:
# Second element of the (features, ids) tuple stored by test_epoch_end, i.e. the collected ids.
# NOTE(review): predict_step sets model.test_results to None, so this subscript only
# works before any predict run.
model.test_results[1]

In [None]:
# Score every candidate pair by hand (outside the Trainer).
# CopyDetectPredDataset yields (reference image, query image) - in that order -
# and predict_copy expects the same (reference, query) batch layout.
model.eval()
model.to(device)

scores = []
with torch.no_grad():
    for img_r, img_q in copydetect_dataloader:
        score = model.predict_copy((img_r.to(device), img_q.to(device)))
        scores.append(score)

In [None]:
# Join the per-batch score tensors into one NumPy vector and total it.
# `scores` is rebound from a list of tensors to an ndarray here.
scores = (
    torch.hstack(scores)
    .detach()
    .cpu()
    .numpy()
)
scores.sum()