## Imports

In [1]:
import torch
torch.cuda.empty_cache() 
from torch import nn
import torch.nn.functional as F
from torchvision import transforms

import livia.embedding as embedding
from _7_image_embedding.dataset import TripletDataset
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm
import numpy as np

## Load data and create datasets / dataloaders

In [2]:
# Directory with the images that TripletDataset samples from.
root_dir = 'data/test_images/wm_cropped_train'

# Sentence embedding used to decide which images form a triplet
# (the file name suggests SBERT features, 256-d — confirm against the producer).
embedding_loaded = embedding.load_csv("data/wm/wm_sbert_title_districts_subjects_256d.csv")

# Image transforms: only ToTensor is applied, no resizing or cropping.
# EmbeddingNet's default head expects 224x224 inputs, so the images in root_dir are
# assumed to already have that size. `size` is kept for the disabled CenterCrop variant.
size = 224
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.CenterCrop(size),  # disabled
    ]
)

In [3]:
def _make_triplet_dataset(n_triplets):
    """Build a TripletDataset over `root_dir` using the shared embedding and transform."""
    return TripletDataset(
        sentence_embedding=embedding_loaded,
        root_dir=root_dir,
        n=n_triplets,
        transform=transform,
    )


# Train and test triplets.
# NOTE(review): both splits are drawn from the same image directory (root_dir), so test
# triplets may reuse training images — TODO confirm whether a held-out directory is intended.
train_dataset = _make_triplet_dataset(2000)
test_dataset = _make_triplet_dataset(10)

2000 triplets are generated. This may take a while ... 
3 triplets have been removed to preserve uniqueness
1997 triplets are returned
10 triplets are generated. This may take a while ... 
0 triplets have been removed to preserve uniqueness
10 triplets are returned


In [4]:
# Wrap the datasets in DataLoaders. Training batches are reshuffled every epoch;
# the test set is iterated in its stored order.
batch_size = 32

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

## Create Model

In [20]:
# Convolutional encoder that maps an image batch to embedding vectors.
class EmbeddingNet(torch.nn.Module):
    """Small CNN mapping images of shape (batch, 3, H, W) to (batch, embedding_dim).

    The encoder has three stages. Each stage is two unpadded convolutions with ReLU,
    followed by 2x2 max pooling. A two-layer MLP head comes after the encoder.

    Args:
        in_size: spatial size of the input images, either an int (square) or a
            ``(height, width)`` pair. The default of 224 reproduces the original
            1024-feature flatten size.
        embedding_dim: length of the output embedding (default 32).

    Raises:
        ValueError: if ``in_size`` is too small for the unpadded convolution stack.
    """

    def __init__(self, in_size=224, embedding_dim=32):
        super().__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(True)
        self.sigmoid = nn.Sigmoid()  # not used in forward
        self.flatten = nn.Flatten()

        # encoder layers (no padding, so every conv with kernel k shrinks H and W by k - 1)
        self.conv1 = nn.Conv2d(3, 16, 11)
        self.conv2 = nn.Conv2d(16, 16, 1)

        self.conv3 = nn.Conv2d(16, 8, 9)
        self.conv4 = nn.Conv2d(8, 8, 9)

        self.conv5 = nn.Conv2d(8, 8, 7)
        self.conv6 = nn.Conv2d(8, 4, 7)

        # Derive the flatten size from the input size instead of hard-coding 1024,
        # which is only correct for 224x224 inputs.
        flat_features = self._flat_features(in_size)

        self.lin1 = nn.Linear(flat_features, 512)
        self.lin2 = nn.Linear(512, embedding_dim)

    def _flat_features(self, in_size):
        """Return how many features the conv stack yields for one ``in_size`` image."""
        height, width = in_size if isinstance(in_size, (tuple, list)) else (in_size, in_size)
        probe = torch.zeros(1, self.conv1.in_channels, height, width)
        try:
            with torch.no_grad():
                return self._encode(probe).numel()
        except RuntimeError as err:
            # conv/pool raise RuntimeError when the feature map becomes smaller than the kernel
            raise ValueError(
                f"in_size={in_size!r} is too small for EmbeddingNet's convolution stack"
            ) from err

    def _encode(self, x):
        """Run the three conv/pool stages and return the final (unflattened) feature map."""
        h = self.relu(self.conv1(x))
        h = self.relu(self.conv2(h))
        h = self.pool(h)

        h = self.relu(self.conv3(h))
        h = self.relu(self.conv4(h))
        h = self.pool(h)

        h = self.relu(self.conv5(h))
        h = self.relu(self.conv6(h))
        return self.pool(h)

    def forward(self, x):
        """Embed ``x`` (batch, 3, H, W) into a (batch, embedding_dim) tensor."""
        h = self.flatten(self._encode(x))
        h = self.relu(self.lin1(h))
        return self.lin2(h)
    
    
class TripletNet(nn.Module):
    """Apply a single embedding network to all three members of a triplet.

    The same ``embedding_model`` instance is called for anchor, positive and
    negative, so the three branches share their weights.

    Args:
        embedding_model: module mapping one input batch to its embeddings.
    """

    def __init__(self, embedding_model):
        super().__init__()
        self.embedding_model = embedding_model

    def forward(self, anchor, pos, neg):
        """Return ``(f(anchor), f(pos), f(neg))`` with ``f = self.embedding_model``."""
        encode = self.embedding_model
        return encode(anchor), encode(pos), encode(neg)
    
# Instantiate an untrained model for the shape check that follows.
# Use the GPU when one is present; otherwise fall back to the CPU so this cell does not require CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb_net = EmbeddingNet()
model = TripletNet(emb_net).to(device=device)

In [21]:
# Shape smoke test: push one training batch through the untrained model and check
# that every triplet member is mapped to a (batch, embedding) matrix.
smoke_device = next(model.parameters()).device  # send the data wherever the model lives
anchor, pos, neg = next(iter(train_dataloader))

print("in:", anchor.shape)

anchor = anchor.to(device=smoke_device)
pos = pos.to(device=smoke_device)
neg = neg.to(device=smoke_device)

# No gradients are needed here. no_grad keeps the outputs stored in these notebook
# globals from holding an autograd graph (and its activations) in memory.
with torch.no_grad():
    anch_hidden, pos_hidden, neg_hidden = model(anchor, pos, neg)

print("h:", anch_hidden.shape)
assert anch_hidden.ndim == 2 and anch_hidden.shape[0] == anchor.shape[0]
assert anch_hidden.shape == pos_hidden.shape == neg_hidden.shape


in: torch.Size([32, 3, 224, 224])
h: torch.Size([32, 32])


## Train Model

In [22]:
# Train a fresh TripletNet with TripletMarginLoss (p=2 distance by default).
# TensorBoard logging: the per-batch loss every `log_every` optimizer steps, and the
# mean batch loss at the end of every epoch.
margin = 1
triplet_loss = torch.nn.TripletMarginLoss(margin=margin)
# cosine-distance variant:
#triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1-F.cosine_similarity(x, y),
#                                                margin=margin)

SEED = 0
torch.manual_seed(SEED)  # makes weight init and the DataLoader's shuffle order repeatable

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb_net = EmbeddingNet()
model = TripletNet(emb_net).to(device=device)
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

num_epochs = 100
log_every = 10  # optimizer steps between per-batch TensorBoard points

train_loss = list()  # mean training loss per epoch

writer = SummaryWriter(f'experiments/runs/triplet_network_margin={margin}_lr={lr}')

progress_bar = tqdm(range(num_epochs))
global_step = 0
model.train()
try:
    for epoch in progress_bar:
        epoch_loss = list()
        for anchor, positive, negative in train_dataloader:
            anchor = anchor.to(device=device)
            pos = positive.to(device=device)
            neg = negative.to(device=device)

            anch_hidden, pos_hidden, neg_hidden = model(anchor, pos, neg)
            loss = triplet_loss(anch_hidden, pos_hidden, neg_hidden)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

            loss_value = loss.item()
            epoch_loss.append(loss_value)

            if global_step % log_every == 0:
                writer.add_scalar('training/loss', loss_value, global_step)

        # Compute the epoch mean once. Guard against an empty loader instead of
        # letting np.mean warn and return nan.
        mean_loss = float(np.mean(epoch_loss)) if epoch_loss else float('nan')
        progress_bar.set_postfix_str(f"{mean_loss:.5f}")
        writer.add_scalar('training/mean epoch loss', mean_loss, epoch)
        train_loss.append(mean_loss)
finally:
    # Make sure buffered events reach disk, also when training is interrupted.
    # flush (not close) keeps `writer` usable by later cells.
    writer.flush()


100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [34:00<00:00, 20.40s/it, 0.01119]
