In [1]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [14]:
import pandas as pd
import os
import numpy as np
import random
import requests
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from PIL import Image
from io import BytesIO
from time import time, sleep
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image, ImageReadMode

In [3]:
# Prefer the GPU, but fall back to the CPU so that every later `.to(device)`
# still works on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
def Conv_Block(channel_in, channel_out):
    """Build one down-sampling stage of the encoder.

    The stage applies, in this order: two 3x3 convolutions with padding 1
    (height and width are kept), a 2x2 max-pool (height and width are halved),
    an in-place ReLU, and a batch normalisation over ``channel_out`` channels.

    Args:
        channel_in: number of channels of the incoming feature map.
        channel_out: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the five layers.
    """
    conv_kwargs = {'kernel_size': 3, 'padding': 1}
    stage = [
        nn.Conv2d(channel_in, channel_out, **conv_kwargs),
        nn.Conv2d(channel_out, channel_out, **conv_kwargs),
        nn.MaxPool2d(2),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(channel_out),
    ]
    return nn.Sequential(*stage)

In [5]:
###### Define the Siamese network architecture
def SiameseNetwork():
    """Build the shared embedding network of the Siamese model.

    Layout: a batch normalisation over the 3 input channels, seven
    ``Conv_Block`` stages, a flatten, and a 2304 -> 1000 -> 256 linear head
    with dropout (p=0.2) between the two linear layers. The shape comments
    below assume a 3x768x768 input, for which the flattened feature size is
    64 * 6 * 6 = 2304.

    Returns:
        An ``nn.Sequential`` producing one 256-dimensional vector per image.
    """
    # (channels_in, channels_out) per stage; the comment is the stage's output shape.
    stage_channels = [
        (3, 4),    #  4x384x384
        (4, 8),    #  8x192x192
        (8, 16),   # 16x 96x 96
        (16, 32),  # 32x 48x 48
        (32, 32),  # 32x 24x 24
        (32, 64),  # 64x 12x 12
        (64, 64),  # 64x  6x  6
    ]

    # Input batch-norm is used instead of computing the dataset mean and variance by hand.
    layers = [nn.BatchNorm2d(3)]  # 3x768x768
    for c_in, c_out in stage_channels:
        layers.append(Conv_Block(c_in, c_out))
    layers.append(nn.Flatten())
    layers.append(nn.Linear(2304, 1000))
    layers.append(nn.Dropout(0.2))
    layers.append(nn.Linear(1000, 256))
    return nn.Sequential(*layers)

In [6]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    With ``d`` the ``nn.PairwiseDistance`` between the two embeddings, each
    pair contributes ``(1 - target) * d**2 + target * max(margin - d, 0)**2``
    and the result is the mean over all pairs. So ``target == 0`` penalises
    any distance, while ``target == 1`` penalises distances below ``margin``.
    """

    def __init__(self, margin=2.0):
        """Store the margin and create the distance module."""
        super().__init__()
        self.dist = nn.PairwiseDistance()
        self.margin = margin

    def forward(self, output1, output2, target):
        """Return the mean contrastive loss of the pairs (output1, output2).

        ``target`` may be a scalar or a tensor broadcastable against the
        per-pair distances.
        """
        gap = self.dist(output1, output2)
        pull_term = (1 - target) * gap.pow(2)
        shortfall = (self.margin - gap).clamp(min=0.0)
        push_term = target * shortfall.pow(2)
        return (pull_term + push_term).mean()

In [7]:
# Alternative image directory: '/kaggle/input/hackupc/Imatges'
img_dir = '/kaggle/input/masked10k/masked_images'
# os.listdir returns entries in arbitrary order; sorting makes everything that
# is derived from this list (grouping, embedding order, exported rows)
# reproducible between runs.
img_paths = sorted(os.listdir(img_dir))

In [10]:
class CustomImageDataset(Dataset):
    """Dataset whose items are groups of images that are loaded together.

    Each entry of ``img_paths`` is a sequence of file names relative to
    ``img_dir``. Indexing returns a list with one tensor per file name, in the
    same order, each produced by ``transforms.ToTensor()`` from the image
    converted to RGB.

    If any image of a group cannot be loaded, a randomly chosen other group is
    tried instead (best effort), at most ``max_retries`` times.

    Args:
        img_paths: sequence of groups, each group a sequence of file names.
        img_dir: directory that contains the files.
        max_retries: number of replacement groups tried after a load failure
            before giving up.

    Raises:
        IndexError: from ``__getitem__`` when ``idx`` is out of range.
        RuntimeError: from ``__getitem__`` when the requested group and all
            ``max_retries`` replacements failed to load; the last load error
            is attached as the cause.
    """

    def __init__(self, img_paths, img_dir, max_retries=10):
        self.img_paths = img_paths
        self.img_dir = img_dir
        self.max_retries = max_retries

    def __len__(self):
        return len(self.img_paths)

    def _load_image(self, path):
        """Read the image file at ``path`` and return it as an RGB tensor."""
        # The context manager closes the underlying file once it is decoded.
        with Image.open(path) as img:
            return transforms.ToTensor()(img.convert('RGB'))

    def __getitem__(self, idx):
        last_error = None
        for _ in range(self.max_retries + 1):
            # Looked up outside the try block so that an out-of-range index is
            # reported as IndexError instead of being retried.
            names = self.img_paths[idx]
            try:
                # Use this dataset's own directory, not a module-level variable.
                return [self._load_image(os.path.join(self.img_dir, name)) for name in names]
            except Exception as err:  # unreadable / missing / corrupt image
                last_error = err
                idx = random.randint(0, len(self.img_paths) - 1)
        raise RuntimeError(
            f"could not load any image group after {self.max_retries + 1} attempts"
        ) from last_error

In [11]:
# Bucket the file names by the integer found in their second '_'-separated field.
grouped_paths = {}
for path in img_paths:
    group = int(path.split('_')[1])
    grouped_paths.setdefault(group, []).append(path)

# Every group, and the subset of groups that hold more than one image.
all_groups = list(grouped_paths.values())
big_groups = [members for members in all_groups if len(members) > 1]

len(all_groups), len(big_groups)

(7739, 7050)

In [12]:
# Instantiate the embedding network on the selected device, then inspect it:
# first the module tree, then torchsummary's report for a 3x768x768 input.
siamese_net = SiameseNetwork()
siamese_net = siamese_net.to(device)

print(siamese_net)
summary(siamese_net, input_size=(3, 768, 768))

Sequential(
  (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ReLU(inplace=True)
    (4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (2): Sequential(
    (0): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ReLU(inplace=True)
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (3): Sequential(
    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 

In [None]:
# Train the Siamese network on image triplets: two images from the same group
# (anchor, positive) and one image from another group (negative).
criterion = ContrastiveLoss()
optimizer = optim.Adam(siamese_net.parameters(), lr=0.0004)
batch_size = 16
num_epochs = 2
log_every = 100  # number of batches between two progress reports

dist = nn.PairwiseDistance()       # only used for the progress report
to_pil = transforms.ToPILImage()   # only used for the progress report

# Mean loss of every finished epoch. It has to exist before the loop because
# the end-of-epoch code appends to it.
total_losses = []

siamese_net.train()
for epoch in range(num_epochs):

    # Shuffled copy of the groups to draw negatives from. `all_groups` itself
    # is left untouched so later cells still see it in its original order.
    negative_pool = random.sample(all_groups, k=len(all_groups))

    curr_groups = []
    for positive_group, negative_group in zip(big_groups, negative_pool):
        # Redraw when the negative group is the anchor's own group, otherwise
        # an image of the same group would be trained as a negative.
        while negative_group is positive_group and len(all_groups) > 1:
            negative_group = random.choice(all_groups)
        curr_groups.append(
            random.sample(positive_group, k=2) + random.sample(negative_group, k=1)
        )

    dataset = CustomImageDataset(curr_groups, img_dir)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    total_loss, last_loss = 0, 0
    for idx, (img1, img2, img3) in enumerate(tqdm(dataloader), start=1):
        optimizer.zero_grad()
        img1, img2, img3 = img1.to(device), img2.to(device), img3.to(device)
        emb1, emb2, emb3 = siamese_net(img1), siamese_net(img2), siamese_net(img3)
        # target 0: anchor/positive pair; target 1: anchor/negative pair.
        loss = criterion(emb1, emb2, 0) + criterion(emb1, emb3, 1)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if idx % log_every == 0:
            print("loss since last:", (total_loss - last_loss) / log_every)
            last_loss = total_loss

            # Spot check on the first triplet of the current batch.
            print("1-2 dist:", dist(emb1[0], emb2[0]).item())
            print("1-3 dist:", dist(emb1[0], emb3[0]).item())
            display(to_pil(img1[0]))
            display(to_pil(img2[0]))
            display(to_pil(img3[0]))

    # max(..., 1) keeps the division defined when there was nothing to train on.
    total_loss /= max(len(dataloader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}")
    total_losses.append(total_loss)
    plt.plot(total_losses)
    plt.show()

  0%|          | 0/441 [00:00<?, ?it/s]

In [16]:
# Embed every image group with the network and export one CSV row per group:
# the 256 embedding components followed by the three link columns.
link = '/kaggle/input/links-inditex/inditextech_hackupc_challenge_images.csv'
all_links = pd.read_csv(link).values

siamese_net.eval()
embeds = {}
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for vec in all_groups:
        name = vec[0]  # the group is keyed by its first file name
        # Decode as 3-channel RGB so grey-scale or RGBA files also match the
        # 3-channel network input, then scale the uint8 pixels to [0, 1].
        batch = torch.stack([
            read_image(img_dir + '/' + img_path, mode=ImageReadMode.RGB) / 255
            for img_path in vec
        ]).to(device)
        # Mean embedding of the group, kept as a dense NumPy vector: taking
        # the values of a sparse view would drop components that are exactly
        # zero and shift the remaining ones into the wrong columns.
        embeds[name] = siamese_net(batch).mean(dim=0).cpu().numpy()

embedding_columns = [f'embedding_{i}' for i in range(256)]
df = pd.DataFrame.from_dict(embeds, orient='index', columns=embedding_columns).reset_index()

# The integer in the second '_'-separated field of the file name is used as
# the row number into the links table (assumes the file names were generated
# from that table's row order -- TODO confirm).
row_ids = df['index'].map(lambda x: int(x.split('_')[1])).to_numpy(dtype=int)
for col, link_column in enumerate(('link_1', 'link_2', 'link_3')):
    df[link_column] = all_links[row_ids, col]

df = df.drop(columns='index')
df.to_csv('temp_embed.csv')