In [1]:
import os
import einops
from PIL import Image
from typing import Any, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from torchmetrics import MaxMetric
from torchmetrics.classification.accuracy import Accuracy
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


from transformers import ViTModel

from src.models.components.layers import CopyDetectEmbedding, NormalizedFeatures, SimImagePred, ContrastiveProj
from src.datamodules.copydetect_datamodule import CopyDetectDataModule
from src.datamodules.components.augmentation import Augment
from src.utils.nt_xent_loss import NTXentLoss


In [2]:

get_path = lambda x: os.path.join(os.getcwd(),'data', x)

augment = Augment(overlay_image_dir = get_path('train/'),
                  n_upper = 2,
                  n_lower = 1)
ntxentloss = NTXentLoss(temperature = 0.9, eps = 1e-5)


In [3]:
class CopyDetectModule(LightningModule):
    def __init__(self,
                 pretrained_arch: str,          # Pretrained ViT architecture
                 ntxentloss: object,            # Contrastive loss
                 hidden_dim: int = 2048,        # Contrastive projection size of hidden layer
                 projected_dim: int = 512,      # Contrastive projection size of projection head 
                 beta1: int = 1,                # Similar image BCE loss multiplier
                 beta2: int = 1,                # Contrastive loss multiplier
                 lr: float = 0.001,
                 weight_decay: float = 0.0005):               
        super().__init__()
        self.save_hyperparameters(logger = False)
         
        # Instantiate ViT encoder from pretrained model
        pretrained_model = ViTModel.from_pretrained(pretrained_arch)
        encoder = pretrained_model.encoder
        self.patch_size = pretrained_model.config.patch_size
                
        # Instantiate embedding, we use the pretrained ViT cls and position embedding
        embedding = CopyDetectEmbedding(config = pretrained_model.config,
                                        vit_cls = pretrained_model.embeddings.cls_token,
                                        pos_emb = pretrained_model.embeddings.position_embeddings)
        
        # Normalized features
        normfeats = NormalizedFeatures(hidden_dim = pretrained_model.config.hidden_size,
                                       layer_norm_eps = pretrained_model.config.layer_norm_eps)
        # Feature Vector Extractor
        self.feature_extractor = nn.Sequential(embedding, encoder, normfeats)
        
        # Instantiate SimImagePredictor
        simimagepred = SimImagePred(embedding_dim = pretrained_model.config.hidden_size)
        self.embedding = embedding
        self.simimagepred = nn.Sequential(encoder, normfeats, simimagepred)

        # Instantiate ContrastiveProjection
        contrastiveproj = ContrastiveProj(embedding_dim = pretrained_model.config.hidden_size,
                                          hidden_dim = hidden_dim,
                                          projected_dim = projected_dim)
        self.contrastiveproj = nn.Sequential(embedding, encoder, normfeats, contrastiveproj)
        
        # Contrastive loss 
        self.contrastive_loss = ntxentloss
        
        # Binary cross entropy loss for similar image pair
        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        
        # Model accuracy in detecting modified copy
        self.train_acc, self.val_acc = Accuracy(), Accuracy()   
          
        # For logging best validation accuracy
        self.val_acc_best = MaxMetric()

    def forward(self,
                img_r: torch.Tensor,
                img_q: Optional[torch.Tensor] = None,
                ) -> torch.Tensor:
        
        if img_q is None:
            encoding = self.feature_extractor(img_r)
            batch_size, num_ch, H, W, = img_r.size()
            #dim = encoding.size(2) # batch_size, seq_len, dim 
            h, w = int(H/self.patch_size), int(W/self.patch_size)
            cls, feats = encoding[:,0,:], encoding[:,1:,:] # Get the cls token and all the images features
            
            #feats = feats.reshape(batch_size, h, w, dim).clamp(min = 1e-6).permute(0,3,1,2)
            feats = einops.rearrange(feats, 'b (h w) d -> b d h w', h = h, w = w).clamp(min = 1e-6)
            # GeM Pooling
            feats = F.avg_pool2d(feats.pow(4), (h,w)).pow(1./4)
            feats = einops.rearrange(feats, 'b d () () -> b d')
            # Concatenate cls tokens with image patches to give local and global views of image
            feature_vector = torch.cat((cls, feats), dim = 1)
            return feature_vector
        else:
            embedding_rq = self.embedding(img_r, img_q)
            logits = self.simimagepred(embedding_rq)
            preds = torch.argmax(logits, dim = 1)
            return preds

    def step(self,
             img_r: List[torch.Tensor],
             img_q: List[torch.Tensor],
             label: List[torch.Tensor]):
        
        # img_r, img_q to SimImagePredictor
        embedding_rq = self.embedding(img_r, img_q) ## nn sequential don't take multiple input
        logits = self.simimagepred(embedding_rq)
        # Calculate binary cross entropy loss of similar image pair
        simimage_loss = self.bce_loss(logits, label.unsqueeze(dim = 1))
        # Predictions
        preds = torch.argmax(logits, dim = 1)
        
        # Get positive indices
        pos_indices = label.bool()
        # Forward positive indices of img_r and img_q to ContrastiveProjection
        proj_r = self.contrastiveproj(img_r[pos_indices])
        proj_q = self.contrastiveproj(img_q[pos_indices])

        # Calculate contrastive loss between un-augmented img_r and augmented positive pair of img_q
        contrastive_loss = self.contrastive_loss(proj_r, proj_q)
        
        # Weighted sum of bce and contrastive loss
        total_loss = self.hparams.beta1 * simimage_loss + self.hparams.beta2 * contrastive_loss
        
        return {'simimage': simimage_loss, 'contrastive': contrastive_loss, 'total': total_loss}, preds

    def training_step(self, batch: Any, batch_idx: int):
        img_r, img_q, label = batch
        img_r, img_q, label = torch.vstack(img_r), torch.vstack(img_q), torch.hstack(label)

        losses, preds = self.step(img_r, img_q, label)
        
        # Log train metrics
        acc = self.train_acc(preds, label.int())
        self.log("train/total_loss", losses['total'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/simimage_loss", losses['simimage'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/contrastive_loss", losses['contrastive'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("train/acc", acc, on_step = True, on_epoch = True, prog_bar = True)

        return losses['total']

    def validation_step(self, batch: Any, batch_idx: int):
        img_r, img_q, label = batch
        losses, preds = self.step(img_r, img_q, label)

        # Log val metrics
        acc = self.val_acc(preds, label.int())
        self.log("val/total_loss", losses['total'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/simimage_loss", losses['simimage'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/contrastive_loss", losses['contrastive'], on_step = True, on_epoch = True, prog_bar = False)
        self.log("val/acc", acc, on_step = True, on_epoch = True, prog_bar = True)

        return losses['total']

    def validation_epoch_end(self, outputs: Any):
        acc = self.val_acc.compute()  # get val accuracy from current epoch
        self.val_acc_best.update(acc)
        self.log("val/acc_best", self.val_acc_best.compute(), on_epoch = True, prog_bar = True)
        
    def test_step(self, batch: Any, batch_idx: int):
        feats = self.forward(batch) # Get feat
        return feats

    def on_epoch_end(self):
        # Reset metrics at the end of every epoch
        self.train_acc.reset()
        self.val_acc.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(params = self.parameters(),
                                lr = self.hparams.lr,
                                weight_decay = self.hparams.weight_decay)

In [4]:

pretrained_arch = 'google/vit-base-patch16-224'

model = CopyDetectModule(pretrained_arch, ntxentloss)

datamodule = CopyDetectDataModule(train_dir = get_path('train/'),
                           references_dir = get_path('references/'),
                           dev_queries_dir = get_path('dev_queries/'),
                           final_queries_dir = get_path('final_queries/'),
                           augment = augment,
                           dev_validation_set = get_path('dev_validation_set.csv'),
                           batch_size = 16,
                           pin_memory = True,
                           num_workers = 10,
                           n_crops = 2)

datamodule.setup()


Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_SDEV = [0.229, 0.224, 0.225]

transform =  transforms.Compose([transforms.ToTensor(),
                                 transforms.Resize((224, 224)),
                                 transforms.Normalize(IMAGENET_MEAN, IMAGENET_SDEV)])

In [None]:
next(iter(datamodule.references_dataloader())).size()

In [None]:

#trainer = Trainer(fast_dev_run =  True)
#trainer.fit(model = model, datamodule = datamodule)

In [6]:
get_image = lambda folder, img: Image.open(folder[index])
get_image_file = lambda image_dir: [os.path.join(image_dir, f) for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_SDEV = [0.229, 0.224, 0.225]

class CopyDetectPredDataset(Dataset):
    def __init__(self,
                 predictions: list,
                 references_dir: str,
                 final_queries_dir: str):
        self.predictions = predictions
        self.references_dir = references_dir
        self.final_queries_dir = final_queries_dir
        self.transform =  transforms.Compose([transforms.ToTensor(),
                                              transforms.Resize((224, 224)),
                                              transforms.Normalize(IMAGENET_MEAN, IMAGENET_SDEV)])
        
    def __len__(self) -> int:
        return len(self.predictions)
    
    def __getitem__(self, index: int) -> torch.Tensor:
        final_queries_id, references_id = self.predictions[index]
        reference_image = get_image(self.references_images, index)
        final_queries_image = get_image(self.final_queries_images, index)
        
        return self.transform(reference_image). self.transform(final_queries_image) 
    

In [9]:
ref_feats = []

device = torch.device('cuda:4')
model.to(device)

for img in datamodule.references_dataloader():
    img = img.to(device)
    feats = model(img)
    ref_feats.append(feats)
ref_feats = torch.vstack(ref_feats).detach().cpu().numpy()

In [10]:
query_feats = []

for img in datamodule.final_queries_dataloader():
    feats = model(img.to(device))
    query_feats.append(feats)
    
query_feats = torch.vstack(query_feats).detach().cpu().numpy()



In [11]:
get_file_ids = lambda x: [f for f in os.listdir(x)]
ref_ids = get_file_ids(get_path('references/'))[:100]
query_ids = get_file_ids(get_path('final_queries'))[:100]

In [12]:
from src.utils import search_with_capped_res

lims, dis, ids = search_with_capped_res(query_feats, ref_feats, 1000)

In [13]:
lims

array([   0,    0,    2,   10,   37,   53,   53,   57,   59,   75,   75,
         75,  103,  136,  137,  150,  170,  172,  194,  195,  207,  207,
        207,  207,  231,  236,  248,  265,  280,  281,  292,  297,  314,
        335,  341,  361,  362,  367,  402,  419,  421,  421,  433,  433,
        456,  481,  481,  495,  495,  495,  497,  521,  541,  541,  541,
        567,  572,  590,  590,  604,  604,  604,  634,  653,  653,  655,
        655,  656,  662,  662,  689,  701,  702,  702,  717,  723,  723,
        734,  739,  744,  754,  771,  805,  839,  861,  864,  870,  873,
        873,  905,  921,  923,  953,  970,  976,  978,  985,  995,  996,
        996, 1000])

In [15]:
ids

array([66, 87,  0, 23, 26, 28, 29, 57, 62, 91,  0,  3,  4,  6, 14, 19, 20,
       26, 28, 32, 33, 35, 36, 43, 45, 48, 49, 57, 58, 76, 79, 81, 84, 90,
       91, 94, 98,  3,  6, 16, 20, 26, 33, 36, 43, 55, 57, 58, 75, 81, 84,
       90, 91, 29, 57, 74, 91, 30, 82,  0,  4, 12, 13, 19, 24, 29, 36, 49,
       57, 69, 89, 90, 91, 94, 96,  0,  3,  4,  6, 13, 14, 19, 20, 28, 29,
       32, 36, 43, 45, 49, 55, 56, 57, 58, 60, 62, 79, 81, 84, 90, 91, 94,
       96,  0,  3,  4,  6,  8, 13, 14, 19, 20, 22, 29, 32, 33, 35, 36, 40,
       43, 49, 56, 57, 58, 61, 64, 67, 76, 79, 81, 84, 89, 90, 91, 94, 96,
       17, 10, 13, 19, 32, 36, 38, 41, 49, 59, 69, 76, 77, 85,  0,  2,  6,
        7,  8, 14, 22, 33, 35, 36, 43, 48, 55, 58, 65, 79, 81, 84, 90, 98,
       12, 91,  0,  1,  3,  6,  8, 26, 28, 29, 32, 35, 36, 42, 43, 45, 48,
       57, 58, 62, 81, 84, 90, 91, 27, 13, 19, 30, 32, 41, 49, 52, 76, 77,
       78, 82, 85,  0,  2,  3,  4,  6,  8, 14, 16, 18, 20, 33, 35, 36, 43,
       45, 46, 48, 54, 55

In [18]:
predictions_list = []
for i in range(100):
    for j in range(lims[i], lims[i+1]):
        #predictions_list.append((lims[i], ids[j]))
        predictions_list.append([query_ids[i], ref_ids[ids[j]]])

In [19]:
copydetectpred = CopyDetectPredDataset(predictions = predictions_list,
                                       references_dir = get_path('references'),
                                       final_queries_dir = get_path('final_queries/'))

In [20]:
from torch.utils.data import DataLoader

copydetect_dataloader = DataLoader(dataset = copydetectpred,
                                   batch_size = 16,
                                   num_workers = 8,
                                   pin_memory = True)

In [21]:
next(iter(copydetect_dataloader))



AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/leejiahe/anaconda3/envs/cama/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/leejiahe/anaconda3/envs/cama/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/leejiahe/anaconda3/envs/cama/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_29059/4066742790.py", line 26, in __getitem__
    return self.transform(reference_image). self.transform(final_queries_image)
AttributeError: 'Tensor' object has no attribute 'self'


In [None]:
torch.__version__