In [1]:
import os
import einops
from typing import Any, List, Optional

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer
from torchmetrics import MaxMetric
from torchmetrics.classification.accuracy import Accuracy

from transformers import ViTModel

from src.models.components.layers import CopyDetectEmbedding, NormalizedFeatures, SimImagePred, ContrastiveProj
from src.datamodules.copydetect_datamodule import CopyDetectDataModule
from src.datamodules.components.augmentation import Augment
from src.utils.nt_xent_loss import NTXentLoss


In [2]:

def get_path(x):
    """Return the path of ``x`` inside the ``data`` folder of the current working directory."""
    return os.path.join(os.getcwd(), 'data', x)


# Augmentation object; its overlay_image_dir points at the training image folder.
augment = Augment(overlay_image_dir = get_path('train/'),
                  n_upper = 2,
                  n_lower = 1)
# Contrastive loss instance, later handed to CopyDetectModule.
ntxentloss = NTXentLoss(temperature = 0.9, eps = 1e-5)


In [None]:

class CopyDetectModule(LightningModule):
    """LightningModule that fine-tunes a pretrained ViT for image copy detection.

    A single embedding / encoder / feature-normalisation stack is shared by
    three ``nn.Sequential`` pipelines:

    * ``feature_extractor``: embedding -> encoder -> normalised features.
    * ``simimagepred``: encoder -> normalised features -> similar-image head.
      It is fed the joint embedding of a (reference, query) pair and is trained
      with ``BCEWithLogitsLoss``.
    * ``contrastiveproj``: embedding -> encoder -> normalised features ->
      projection head. Its outputs are passed to the supplied contrastive loss.

    The optimised loss is ``beta1 * bce_loss + beta2 * contrastive_loss``.
    """

    def __init__(self,
                 pretrained_arch: str,          # Pretrained ViT architecture
                 ntxentloss: object,            # Contrastive loss
                 hidden_dim: int = 2048,        # Contrastive projection size of hidden layer
                 projected_dim: int = 512,      # Contrastive projection size of projection head
                 beta1: float = 1,              # Similar image BCE loss multiplier
                 beta2: float = 1,              # Contrastive loss multiplier
                 lr: float = 0.001,
                 weight_decay: float = 0.0005):
        """Build the shared backbone, both heads, the losses and the metrics.

        Args:
            pretrained_arch: Name or path given to ``ViTModel.from_pretrained``.
            ntxentloss: Callable used as the contrastive loss; it is invoked as
                ``ntxentloss(proj_r, proj_q)``.
            hidden_dim: Hidden layer size of the contrastive projection head.
            projected_dim: Output size of the contrastive projection head.
            beta1: Multiplier of the similar-image BCE loss.
            beta2: Multiplier of the contrastive loss.
            lr: Learning rate of the Adam optimizer.
            weight_decay: Weight decay of the Adam optimizer.
        """
        super().__init__()
        self.save_hyperparameters(logger = False)
        #! Change all this to hparams
        self.beta1 = beta1
        self.beta2 = beta2
        self.lr = lr
        self.weight_decay = weight_decay

        # Instantiate ViT encoder from pretrained model
        pretrained_model = ViTModel.from_pretrained(pretrained_arch)
        encoder = pretrained_model.encoder

        # Instantiate embedding, we use the pretrained ViT cls and position embedding
        embedding = CopyDetectEmbedding(config = pretrained_model.config,
                                        vit_cls = pretrained_model.embeddings.cls_token,
                                        pos_emb = pretrained_model.embeddings.position_embeddings)

        # Normalized features
        normfeats = NormalizedFeatures(hidden_dim = pretrained_model.config.hidden_size,
                                       layer_norm_eps = pretrained_model.config.layer_norm_eps)
        # Feature Vector Extractor
        self.feature_extractor = nn.Sequential(embedding, encoder, normfeats)

        # Instantiate SimImagePredictor.
        # The embedding is kept outside this Sequential because it is called
        # with two inputs (img_r, img_q) and nn.Sequential forwards only one.
        simimagepred = SimImagePred(embedding_dim = pretrained_model.config.hidden_size)
        self.embedding = embedding
        self.simimagepred = nn.Sequential(encoder, normfeats, simimagepred)

        # Instantiate ContrastiveProjection
        contrastiveproj = ContrastiveProj(embedding_dim = pretrained_model.config.hidden_size,
                                          hidden_dim = hidden_dim,
                                          projected_dim = projected_dim)
        self.contrastiveproj = nn.Sequential(embedding, encoder, normfeats, contrastiveproj)

        # Contrastive loss
        self.contrastive_loss = ntxentloss

        # Binary cross entropy loss for similar image pair
        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        # Model accuracy in detecting modified copy
        self.train_acc, self.val_acc = Accuracy(), Accuracy()
        # For logging best validation accuracy
        self.val_acc_best = MaxMetric()

    @staticmethod
    def _logits_to_preds(logits: torch.Tensor) -> torch.Tensor:
        """Turn ``(batch, 1)`` similar-image logits into ``(batch,)`` 0/1 predictions.

        The logits are trained with ``BCEWithLogitsLoss`` against
        ``label.unsqueeze(1)``, so dim 1 has size one. ``argmax`` over that
        dimension would always return 0; instead threshold at logit 0, which
        is the same as ``sigmoid(logit) > 0.5``.
        """
        return (logits.squeeze(dim = 1) > 0).long()

    def forward(self,
                img_r: torch.Tensor,
                img_q: Optional[torch.Tensor] = None,
                ) -> torch.Tensor:
        """Extract features for one image batch, or classify an image pair batch.

        Args:
            img_r: Reference image batch.
            img_q: Optional query image batch.

        Returns:
            The output of ``feature_extractor(img_r)`` when ``img_q`` is not
            given; otherwise the 0/1 similar-image prediction for every
            (img_r, img_q) pair.
        """
        # Single-input call: there is no pair to compare, return features only.
        if img_q is None:
            return self.feature_extractor(img_r)

        embedding_rq = self.embedding(img_r, img_q)
        logits = self.simimagepred(embedding_rq)
        return self._logits_to_preds(logits)

    def step(self, img_r: torch.Tensor, img_q: torch.Tensor, label: torch.Tensor, val: Optional[bool] = False):
        """Compute the losses and predictions shared by training and validation.

        Args:
            img_r: Reference image batch.
            img_q: Query image batch, paired element-wise with ``img_r``.
            label: One value per pair; non-zero entries are treated as positive pairs.
            val: Currently unused.

        Returns:
            A tuple ``(losses, preds)`` where ``losses`` is a dict with the keys
            ``'simimage'``, ``'contrastive'`` and ``'total'`` and ``preds``
            holds the 0/1 prediction of every pair.
        """
        # img_r, img_q to SimImagePredictor.
        # nn.Sequential forwards a single input, so the two-input embedding is
        # called explicitly before the rest of the pipeline.
        embedding_rq = self.embedding(img_r, img_q)
        logits = self.simimagepred(embedding_rq)
        # Calculate binary cross entropy loss of similar image pair.
        # BCEWithLogitsLoss needs a floating-point target of the logits' shape.
        target = label.unsqueeze(dim = 1).type_as(logits)
        simimage_loss = self.bce_loss(logits, target)
        # Predictions
        preds = self._logits_to_preds(logits)

        # Get positive indices
        pos_indices = label.bool()
        # Forward positive indices of img_r and img_q to ContrastiveProjection
        # NOTE(review): when a batch has no positive label both selections are
        # empty - confirm the contrastive loss copes with that.
        proj_r = self.contrastiveproj(img_r[pos_indices])
        proj_q = self.contrastiveproj(img_q[pos_indices])

        # Calculate contrastive loss between un-augmented img_r and augmented positive pair of img_q
        contrastive_loss = self.contrastive_loss(proj_r, proj_q)

        # Weighted sum of bce and contrastive loss
        total_loss = self.beta1 * simimage_loss + self.beta2 * contrastive_loss

        return {'simimage': simimage_loss, 'contrastive': contrastive_loss, 'total': total_loss}, preds

    def training_step(self, batch: Any, batch_idx: int):
        """Run one training step, log the losses and accuracy, return the total loss."""
        img_r, img_q, label = batch
        # Each batch element is a sequence of tensors; stack them into one tensor per field.
        img_r, img_q, label = torch.vstack(img_r), torch.vstack(img_q), torch.hstack(label)

        losses, preds = self.step(img_r, img_q, label)

        # Log train metrics
        acc = self.train_acc(preds, label.int())
        self.log("train/total_loss", losses['total'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/simimage_loss", losses['simimage'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/contrastive_loss", losses['contrastive'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/acc", acc, on_step = True, on_epoch = True, prog_bar = True)

        return losses['total']

    def validation_step(self, batch: Any, batch_idx: int):
        """Run one validation step, log the losses and accuracy, return the total loss."""
        # NOTE(review): unlike training_step the batch is not stacked here;
        # presumably the validation loader already yields tensors - confirm.
        img_r, img_q, label = batch
        losses, preds = self.step(img_r, img_q, label, val = True)

        # Log val metrics
        acc = self.val_acc(preds, label.int())
        self.log("val/total_loss", losses['total'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/simimage_loss", losses['simimage'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/contrastive_loss", losses['contrastive'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/acc", acc, on_step = True, on_epoch = True, prog_bar = True)

        return losses['total']

    def validation_epoch_end(self, outputs: Any):
        """Update and log the best validation accuracy seen so far."""
        acc = self.val_acc.compute()  # get val accuracy from current epoch
        self.val_acc_best.update(acc)
        self.log("val/acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True)

    def on_epoch_end(self):
        """Reset the accuracy metrics so every epoch starts from a clean state."""
        # Reset metrics at the end of every epoch
        self.train_acc.reset()
        self.val_acc.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``lr`` and ``weight_decay``."""
        return torch.optim.Adam(
            params = self.parameters(), lr = self.lr, weight_decay = self.weight_decay
        )

# Checkpoint name that CopyDetectModule forwards to ViTModel.from_pretrained
pretrained_arch = 'google/vit-base-patch16-224'
model = CopyDetectModule(pretrained_arch, ntxentloss)

# Data module wired to the image folders and the validation CSV under data/
datamodule = CopyDetectDataModule(
    train_dir=get_path('train/'),
    references_dir=get_path('references/'),
    dev_queries_dir=get_path('dev_queries/'),
    final_queries_dir=get_path('final_queries/'),
    augment=augment,
    dev_validation_set=get_path('dev_validation_set.csv'),
    batch_size=16,
    pin_memory=True,
    num_workers=10,
    n_crops=2,
)
datamodule.setup()

# Quick end-to-end check of the training loop using Lightning's fast_dev_run mode
trainer = Trainer(fast_dev_run=True)
trainer.fit(model=model, datamodule=datamodule)

In [3]:
# Build the data module on its own (no training) so its loaders can be inspected
datamodule = CopyDetectDataModule(
    train_dir=get_path('train/'),
    references_dir=get_path('references/'),
    dev_queries_dir=get_path('dev_queries/'),
    final_queries_dir=get_path('final_queries/'),
    augment=augment,
    dev_validation_set=get_path('dev_validation_set.csv'),
    batch_size=16,
    pin_memory=True,
    num_workers=10,
    n_crops=2,
)
datamodule.setup()

In [8]:
# Shape of the first batch produced by the test dataloader
next(iter(datamodule.test_dataloader())).shape



torch.Size([16, 3, 224, 224])