In [1]:
import torch
from torch.utils.data import Dataset
import torchvision
from PIL import Image
import requests
from io import BytesIO
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
import numpy as np
import os


class LazyRotationImageDataset(Dataset):
	def __init__(self, root_dir, transform=None):
		super(LazyRotationImageDataset, self).__init__()
		self.data = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(".jpg")]
		self.rotations = [0, 90, 180, 270]
		self.transform = transform

	def __len__(self):
		return len(self.data)

	def __getitem__(self, idx):
		img_path = self.data[idx]
		image = Image.open(img_path)

		if self.transform:
			image = self.transform(image)

		rotation_idx = torch.randint(0, 4, (1,)).item()  # Random index for rotation
		rotation_angle = self.rotations[rotation_idx]  # Corresponding rotation angle

		rotation_transform = transforms.Compose([
			transforms.RandomRotation([rotation_angle, rotation_angle], expand=True),
			transforms.ToTensor()
		])

		rotated_image = rotation_transform(image)  # Applies the selected rotation
		return rotated_image, rotation_idx

# Example usage
dataset = LazyRotationImageDataset("../data/images")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Now you can use this dataloader in your training loop

In [2]:
len(dataset)

3000

In [3]:
from tqdm import tqdm

# Model modification to predict rotation
class RotationPredictor(nn.Module):
	def __init__(self):
		super().__init__()
		self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
		self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 4)  # Predicting 4 rotation classes

	def forward(self, x):
		return self.resnet(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RotationPredictor().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
# model.train()
# epochs = 10
# for epoch in range(epochs):
#   for images, labels in tqdm(dataloader):
#     images = images.to(device)
#     labels = labels.to(device)
#     outputs = model(images)
#     loss = criterion(outputs, labels)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#   print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}")

# Now the model is fine-tuned to predict rotations, which also improves its feature extraction capability

In [None]:
torch.save(model.state_dict(), "model_finetuned.pt")

In [4]:
def count_parameters(model):
	return sum(p.numel() for p in model.parameters())


In [None]:
count_parameters(model)

11178564

In [None]:
model.eval()
import copy

resnet_pretrained = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet_pretrained = nn.Sequential(*list(resnet_pretrained.children())[:-4], nn.AdaptiveAvgPool2d((1, 1)))
resnet_pretrained = torch.quantization.quantize_dynamic(
	resnet_pretrained, {nn.Linear, nn.Conv2d}, dtype=torch.qint8
)


model_finetunned = copy.deepcopy(model.resnet)
model_finetunned = nn.Sequential(*list(model_finetunned.children())[:-4], nn.AdaptiveAvgPool2d((1, 1)))
model_finetunned = torch.quantization.quantize_dynamic(
	model_finetunned, {nn.Linear, nn.Conv2d}, dtype=torch.qint8
)

count_parameters(resnet_pretrained), count_parameters(model_finetunned)

(683072, 683072)

In [None]:
torch.save(resnet_pretrained.state_dict(), "model_pretrained.pt")

In [None]:
features_pretrained = torch.empty((len(dataset), 128))
batch_size = 32

In [None]:
for i, (images, labels) in tqdm(enumerate(dataloader)):
	outputs = resnet_pretrained(images).squeeze().detach()
	features_pretrained[i * batch_size: i * batch_size + len(images)] = outputs

94it [01:04,  1.45it/s]


In [None]:
torch.save(features_pretrained, "features_pretrained.pt")

In [None]:
features_finetuned = torch.empty((len(dataset), 128)).to(device)
batch_size = 32

for i, (images, labels) in tqdm(enumerate(dataloader)):
	images = images.to(device)
	labels = labels.to(device)
	outputs = model_finetunned(images).squeeze().detach()
	features_finetuned[i * batch_size: i * batch_size + len(images)] = outputs

0it [00:00, ?it/s]

94it [00:05, 17.52it/s]


In [None]:
torch.save(features_finetuned, "features_finetuned.pt")

In [4]:
w_pretrained = torch.load("../models/model_pretrained.pt")
features_pretrained = torch.load("../data/features_pretrained.pt")
w_finetuned = torch.load("../models/model_finetuned.pt")
features_finetuned = torch.load("../data/features_finetuned.pt")

In [5]:
base_resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model_pretrained = nn.Sequential(
    *(list(base_resnet.children())[:-4]),
    nn.AdaptiveAvgPool2d((1, 1))
)
model_pretrained.load_state_dict(w_pretrained)

<All keys matched successfully>

In [7]:
# base_resnet = RotationPredictor().resnet
# model_finetuned = nn.Sequential(
#     *(list(base_resnet.children())[:-4]),
#     nn.AdaptiveAvgPool2d((1, 1))
# )
# model_finetuned.load_state_dict(w_finetuned)

In [6]:
from sklearn.cluster import KMeans
import numpy as np

kmeans_finetuned = KMeans(n_clusters=5, random_state=0)
clusters_finetuned = kmeans_finetuned.fit_predict(features_finetuned.cpu().numpy())
kmeans_pretrained = KMeans(n_clusters=5, random_state=0)
clusters_pretrained = kmeans_pretrained.fit_predict(features_pretrained.cpu().numpy())



In [7]:
import copy

resnet_original = models.resnet18(pretrained=False)
model_test_pretrained = nn.Sequential(*list(resnet_original.children())[:-4], nn.AdaptiveAvgPool2d((1, 1)))



In [8]:
# Load the saved model weights
model_weights_path = '../models/model_pretrained.pt'
model_test_pretrained.load_state_dict(torch.load(model_weights_path))

<All keys matched successfully>

In [9]:
model_test_pretrained.eval()

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

In [10]:
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

class LazyImageDataset(Dataset):
	def __init__(self, root_dir, transform=None):
		super(LazyImageDataset, self).__init__()
		self.data = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(".jpg")]
		self.transform = transform

	def __len__(self):
		return len(self.data)

	def __getitem__(self, idx):
		img_path = self.data[idx]
		image = Image.open(img_path)

		if self.transform:
			image = self.transform(image)

		return image

test_dataset = LazyImageDataset('../data/images_test', transform=transforms.ToTensor())
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) 

In [11]:
model_test_pretrained.eval()  # Ensure the model is in evaluation mode

batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_test_pretrained = model_test_pretrained.to(device)
test_embs = torch.empty((len(test_dataset), 128)).to(device)

for i, images in tqdm(enumerate(test_data_loader)):
	images = images.to(device)
	outputs = model_test_pretrained(images).squeeze().detach()
	test_embs[i * batch_size: i * batch_size + len(images)] = outputs

0it [00:00, ?it/s]

1it [00:00,  1.71it/s]


In [12]:
test_embs.shape

torch.Size([10, 128])

In [13]:
clusters = kmeans_pretrained.predict(test_embs.cpu().numpy())
clusters

array([0, 0, 3, 1, 1, 0, 1, 0, 0, 0], dtype=int32)

In [14]:
idx = 4
same_cluster = np.where(clusters == clusters[idx])[0]
cluster_embs = test_embs[same_cluster]
cluster_embs.shape

torch.Size([3, 128])

In [15]:
# compute similarity_matrix
vectors = cluster_embs / cluster_embs.norm(dim=1, keepdim=True)  # normalize the vectors
similarity_matrix = vectors @ vectors.t() 
similarity_matrix.shape

torch.Size([3, 3])

In [20]:
similarities = similarity_matrix.mean(dim=1)
similarities.shape

torch.Size([3])

In [21]:
def get_clusters(embs, kmeans):
	return kmeans.predict(embs.cpu().numpy())

def get_cluster_embs(embs, clusters, idx):
	same_cluster = np.where(clusters == clusters[idx])[0]
	return embs[same_cluster]

def get_similarity_embs(embs):
	vectors = embs / embs.norm(dim=1, keepdim=True)
	similarity_matrix = vectors @ vectors.t()
	return similarity_matrix.mean(dim=1)

In [None]:
import gradio as gr

images_dir = "data/images"

images = [os.path.join(images_dir, file) for file in os.listdir(images_dir) if file.endswith(".jpg")]

def images_to_show(img):
	
	return img, img, img

def set_as_input(img):
    blank = np.ones_like(img)*255
    return img, blank, blank, blank

with gr.Blocks() as gui:
  with gr.Row():
    with gr.Column(scale=2):
      img_in = gr.Image(type="numpy")
      btn = gr.Button("Search")
    with gr.Column(scale=1):
      img_out1 = gr.Image(show_download_button=False, interactive=False, type="numpy")
      btn1 = gr.Button("Set as input")
      img_out2 = gr.Image(show_download_button=False, interactive=False, type="numpy")
      btn2 = gr.Button("Set as input")
      img_out3 = gr.Image(show_download_button=False, interactive=False, type="numpy")
      btn3 = gr.Button("Set as input")
  
  btn.click(images_to_show, inputs=img_in, outputs=[img_out1, img_out2, img_out3])
  btn1.click(set_as_input, inputs=img_out1, outputs=[img_in, img_out1, img_out2, img_out3])
  btn2.click(set_as_input, inputs=img_out2, outputs=[img_in, img_out1, img_out2, img_out3])
  btn3.click(set_as_input, inputs=img_out3, outputs=[img_in, img_out1, img_out2, img_out3])

gui.launch()

# Fine tunning model

In [18]:
import copy

# Load the original pre-trained ResNet model
resnet_original = models.resnet18(weights=None)  # Make sure to use the correct ResNet model (resnet18, resnet50, etc.)


# Modify the model by removing the last four layers and adding an AdaptiveAvgPool2d
model_test_finetunning = nn.Sequential(*list(resnet_original.children())[:-4], nn.AdaptiveAvgPool2d((1, 1)))



In [19]:
# Load the saved model weights
model_weights_path = '../models/model_finetuned.pt'
model_test_finetunning.load_state_dict(torch.load(model_weights_path), strict=False)

_IncompatibleKeys(missing_keys=['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '4.0.conv1.weight', '4.0.bn1.weight', '4.0.bn1.bias', '4.0.bn1.running_mean', '4.0.bn1.running_var', '4.0.conv2.weight', '4.0.bn2.weight', '4.0.bn2.bias', '4.0.bn2.running_mean', '4.0.bn2.running_var', '4.1.conv1.weight', '4.1.bn1.weight', '4.1.bn1.bias', '4.1.bn1.running_mean', '4.1.bn1.running_var', '4.1.conv2.weight', '4.1.bn2.weight', '4.1.bn2.bias', '4.1.bn2.running_mean', '4.1.bn2.running_var', '5.0.conv1.weight', '5.0.bn1.weight', '5.0.bn1.bias', '5.0.bn1.running_mean', '5.0.bn1.running_var', '5.0.conv2.weight', '5.0.bn2.weight', '5.0.bn2.bias', '5.0.bn2.running_mean', '5.0.bn2.running_var', '5.0.downsample.0.weight', '5.0.downsample.1.weight', '5.0.downsample.1.bias', '5.0.downsample.1.running_mean', '5.0.downsample.1.running_var', '5.1.conv1.weight', '5.1.bn1.weight', '5.1.bn1.bias', '5.1.bn1.running_mean', '5.1.bn1.running_var', '5.1.conv2.weight', '5.1.bn2.weight', '5.1.bn2.b

In [21]:
model_test_finetunning.eval()

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con