In [1]:
"""
AIL861: Assignment 3

"""

import os
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageFilter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()]) 

current_dir = os.getcwd()
images = []

# Iterate over the files in the folder
for filename in os.listdir(current_dir):
    # Check if the file is a PNG image
    if filename.endswith(".png"):
        # Load the image and apply the transformation
        img = Image.open(os.path.join(current_dir, filename)).convert('RGB')
        img = img.filter(ImageFilter.MedianFilter(size=3))
        img_tensor = transform(img)
        # Append the tensor to the list of images
        images.append(img_tensor)
images_tensor = torch.stack(images)

In [3]:
import torchvision.models as models

vgg16 = models.vgg16(pretrained=True)



In [4]:
features = list(vgg16.features)
model = torch.nn.Sequential(*features)

In [5]:
## Define your function to find the similarity here
def findSimilarity(img1, img2):  ##Pass your arguments
    fm1 = torch.flatten(model(img1)).unsqueeze(0)
    cos_sim = []
    fm2 = model(img2)
    # right shift
    cols = fm2.size()[-1]
    rows = fm2.size()[-2]
    for i in range(cols):
        shifted_tensor = torch.roll(fm2, shifts=i, dims=-1)
        for j in range(rows):
            shifted_tensor = torch.roll(shifted_tensor, shifts=j, dims=-2)
            cos_sim.append(torch.nn.functional.cosine_similarity(fm1, torch.flatten(shifted_tensor).unsqueeze(0)).item())

    return max(cos_sim)

In [6]:
sim_scores = []
for im1 in images_tensor:
    temp = []
    for im2 in images_tensor:
        score = findSimilarity(im1, im2)
        temp.append(score*100)
    sim_scores.append(temp)

In [7]:
from tabulate import tabulate

headers = ['1','2','3','4','5','6']
table1 = tabulate(sim_scores, headers=headers, showindex=headers, tablefmt="fancy_grid")
print(table1)

╒════╤══════════╤══════════╤══════════╤══════════╤══════════╤══════════╕
│    │        1 │        2 │        3 │        4 │        5 │        6 │
╞════╪══════════╪══════════╪══════════╪══════════╪══════════╪══════════╡
│  1 │ 100      │  24.4571 │  24.5867 │  95.8261 │  24.2701 │  25.0057 │
├────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┤
│  2 │  24.4571 │ 100      │  44.6523 │  23.664  │  96.4061 │  42.8153 │
├────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┤
│  3 │  24.5867 │  44.6523 │ 100      │  24.3646 │  42.9165 │  95.6366 │
├────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┤
│  4 │  95.8261 │  23.664  │  24.3646 │ 100      │  24.1915 │  25.5874 │
├────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┤
│  5 │  24.2701 │  96.4061 │  42.9165 │  24.1915 │ 100      │  43.1631 │
├────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┤
│  6 │  25.0057 │  42.8153 │  95.6366 │  25.5874 │ 