In [1]:
import torch
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
model.eval()

Using cache found in /Users/francescobassignana/.cache/torch/hub/facebookresearch_dinov2_main


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-23): 24 x NestedTensorBlock(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (norm): LayerNorm((1024,), eps=1e-06, element

In [2]:
from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize(518, interpolation=3),
    transforms.CenterCrop(518),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

In [13]:
# Create dataloader AFTER loading the model so we can infer image size
from utils.dataset import BatchedImageIterable
from torch.utils.data import DataLoader
images_dir: str = "./data/COCO_inpainted"
masks_dir: str = "./data/masks"
dataset = BatchedImageIterable(
    images_dir,
    batch_size=1,
    transform_img=transform,
)
dataloader = DataLoader(dataset, batch_size=4, num_workers=1)  # dataset yields pre-batched items

In [14]:
noise_level = 0.07

In [15]:
import torch.nn.functional as F
import torch
from utils.utils import Metric
avg_estimate = Metric()
embeddings = []

for i, (image, orig) in enumerate(dataloader):
    with torch.no_grad():
        image = image.squeeze(0)  # remove first dim
        print(image.shape)
        noise = torch.randn_like(image) * noise_level
        noisy = (image + noise).clamp(0, 1)

        batch_pair = torch.stack([image, noisy], dim=1)
        B = batch_pair.shape[0]
        flat = batch_pair.reshape(B * 2, 3, 518, 518)

        print(f"batch_pair shape {batch_pair.shape}, flat shape {flat.shape}")

        emb = model(flat)
        emb = emb / emb.norm()

        emb_pair = emb.reshape(B, 2, -1)

        print(f"emb flat shape {emb.shape} -> emb_pair shape {emb_pair.shape}")
        embeddings.append(emb_pair)

        # to continue

embeddings = torch.cat(embeddings, dim=0)
embeddings.shape

torch.Size([4, 1, 3, 518, 518])
batch_pair shape torch.Size([4, 2, 1, 3, 518, 518]), flat shape torch.Size([8, 3, 518, 518])
emb flat shape torch.Size([8, 1024]) -> emb_pair shape torch.Size([4, 2, 1024])
torch.Size([4, 1, 3, 518, 518])
batch_pair shape torch.Size([4, 2, 1, 3, 518, 518]), flat shape torch.Size([8, 3, 518, 518])
emb flat shape torch.Size([8, 1024]) -> emb_pair shape torch.Size([4, 2, 1024])
torch.Size([4, 1, 3, 518, 518])
batch_pair shape torch.Size([4, 2, 1, 3, 518, 518]), flat shape torch.Size([8, 3, 518, 518])
emb flat shape torch.Size([8, 1024]) -> emb_pair shape torch.Size([4, 2, 1024])
torch.Size([4, 1, 3, 518, 518])
batch_pair shape torch.Size([4, 2, 1, 3, 518, 518]), flat shape torch.Size([8, 3, 518, 518])
emb flat shape torch.Size([8, 1024]) -> emb_pair shape torch.Size([4, 2, 1024])
torch.Size([4, 1, 3, 518, 518])
batch_pair shape torch.Size([4, 2, 1, 3, 518, 518]), flat shape torch.Size([8, 3, 518, 518])
emb flat shape torch.Size([8, 1024]) -> emb_pair shape 

torch.Size([20, 2, 1024])

In [17]:
avg_similarity = Metric()
for pair in embeddings:
    # pair shape (2, D)
    sim = F.cosine_similarity(pair[0], pair[1], dim=0)
    avg_similarity.update(sim.item())
print("Average cosine similarity between original and noisy images:", avg_similarity.avg())
    

Average cosine similarity between original and noisy images: 0.8924035340547561
