# Import Libraries

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet50
import numpy as np
import pandas as pd
import cv2
import os
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from torchsummary import summary
from tensorboard.plugins import projector
from scipy.optimize import linear_sum_assignment
from utils.utils import *
from utils.assignment import *
from utils.latent_loss import *

# Baseline

In [2]:
class Baseline(nn.Module):
    """DETR-style image encoder/decoder that produces a fixed set of query vectors.

    Pipeline: ResNet-50 backbone (avgpool and fc removed) -> 1x1 conv projection
    to ``hidden_dim`` -> Transformer encoder over the flattened feature map plus
    learned row/column positional embeddings -> Transformer decoder driven by
    100 learned query vectors (``self.sentence``).

    Args:
        hidden_dim: model width. Defaults to 384 to match the size of the SBERT
            sentence feature vectors used by this project.
        nheads: attention heads in every encoder/decoder layer.
        num_encoder_layers: depth of the Transformer encoder stack.
        num_decoder_layers: depth of the Transformer decoder stack.
    """

    def __init__(self, hidden_dim=384, nheads=4,
                 num_encoder_layers=3, num_decoder_layers=3):
        super().__init__()

        # ResNet-50 without its final avgpool and fc layers -> 2048-channel feature map.
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])

        # 1x1 conv projecting the 2048 backbone channels down to hidden_dim.
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # Template layers; nn.TransformerEncoder / nn.TransformerDecoder clone them
        # num_*_layers times to build the stacks below.
        self.encoder = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nheads)
        self.decoder = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=nheads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder, num_encoder_layers)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder, num_decoder_layers)

        # 100 learned decoder queries; the decoder emits one output vector per query.
        self.sentence = nn.Parameter(torch.rand(100, hidden_dim))

        # Learned spatial positional embeddings covering feature maps up to 50x50
        # (could be replaced with sinusoidal encodings).
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, X):
        """Encode an image batch and decode the learned queries against it.

        Args:
            X: image batch for the backbone, presumably (B, 3, H_img, W_img)
                normalized RGB -- TODO confirm against callers.

        Returns:
            R: decoder output, (B, 100, hidden_dim), batch-first.
            feat: encoder memory, (H*W, B, hidden_dim), sequence-first, where
                H x W is the backbone feature-map size.

        Raises:
            ValueError: if the feature map is larger than the grid covered by the
                learned positional embeddings (50x50 by default).
        """
        X = self.backbone(X)
        feat = self.conv(X)
        batch_size = feat.shape[0]
        H, W = feat.shape[-2:]

        max_rows, max_cols = self.row_embed.shape[0], self.col_embed.shape[0]
        if H > max_rows or W > max_cols:
            raise ValueError(
                f"feature map {H}x{W} exceeds the positional-embedding grid "
                f"{max_rows}x{max_cols}; use a smaller input image"
            )

        # (H*W, 1, hidden_dim): column embedding in the first half of the channels,
        # row embedding in the second half; the size-1 dim broadcasts over the batch.
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # (B, C, H, W) -> (H*W, B, C); image features are scaled by 0.1 relative to pos.
        feat = self.transformer_encoder(pos + 0.1 * feat.flatten(2).permute(2, 0, 1))

        # The decoder needs tgt and memory to share a batch size, so the learned
        # (100, 1, C) queries are repeated to (100, B, C). The previous code used
        # them un-repeated, which only worked for B == 1.
        queries = self.sentence.unsqueeze(1).repeat(1, batch_size, 1)
        R = self.transformer_decoder(queries, feat).transpose(0, 1)
        return R, feat

In [3]:
img_path = 'data/images'

# os.listdir order is arbitrary, so sort it; also skip hidden files such as
# .DS_Store so that the "first" image is the same on every machine.
filenames = sorted(name for name in os.listdir(img_path) if not name.startswith('.'))
if not filenames:
    raise FileNotFoundError(f"no image files found in {img_path!r}")
f = os.path.join(img_path, filenames[0])

# Force 3-channel RGB, because Normalize below uses per-channel stats for 3 channels.
# The context manager closes the file handle once convert() has loaded the pixels.
with Image.open(f) as im:
    img = im.convert('RGB')

# Resize shorter side to 800, then ImageNet mean/std normalization (backbone is ImageNet-pretrained).
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

# transform_img comes from the utils star imports; judging by the printed shape
# it also adds a leading batch dimension -- TODO confirm.
t_img = transform_img(img, transform)
print(t_img.shape)

torch.Size([1, 3, 975, 800])


In [5]:
model = Baseline()

# Shape check only: no_grad avoids storing every ResNet-50 activation for a
# backward pass that never happens. (Not inference_mode: feat is later used as
# decoder memory in an autograd-tracked computation.)
with torch.no_grad():
    R, feat = model(t_img)

print(R.shape, feat.shape)

torch.Size([1, 100, 384]) torch.Size([775, 1, 384])


# LSP Decoder

In [6]:
class LSP_Decoder(nn.Module):
    """Stack of Transformer decoder layers that attends target tokens to a memory.

    Args:
        hidden_dim: model width (d_model); must equal the last dim of both
            ``tgt`` and ``memory``.
        nhead: attention heads per layer.
        num_layers: number of stacked decoder layers.
    """

    def __init__(self, hidden_dim=384, nhead=4, num_layers=3):
        super().__init__()
        # Template layer; nn.TransformerDecoder clones it num_layers times.
        self.decoder = nn.TransformerDecoderLayer(hidden_dim, nhead)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder, num_layers)

    def forward(self, tgt, memory):
        """Decode ``tgt`` against ``memory``.

        Inputs are sequence-first (the decoder is built with the default
        ``batch_first=False``): tgt is (T, B, hidden_dim) and memory is
        (S, B, hidden_dim). The result has the same shape as ``tgt``.
        """
        return self.transformer_decoder(tgt, memory)

In [7]:
decoder = LSP_Decoder()

# N query tokens of width c; c has to equal the decoder's hidden_dim (384 by default).
N, c = 10, 384
emb = nn.Embedding(N, c)
x = emb(torch.arange(N)).unsqueeze(1)  # (N, 1, c): sequence-first, batch of one

y = decoder(x, feat).transpose(0, 1)  # back to batch-first: (1, N, c)
print(y.shape)

torch.Size([1, 10, 384])


# TensorBoard

In [12]:
# The scalars below are random placeholders that only exercise the TensorBoard
# scalar dashboard; replace them with real metrics once a training loop exists.
writer = SummaryWriter()
try:
    for n_iter in range(100):
        writer.add_scalar('Loss/train', np.random.random(), n_iter)
        writer.add_scalar('Loss/test', np.random.random(), n_iter)
        writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
        writer.add_scalar('Accuracy/test', np.random.random(), n_iter)

    # One projector row per decoded token. y from the decoder cell is (1, N, c),
    # so flatten it using its own last dim instead of hard-coding (10, 384).
    writer.add_embedding(y.detach().reshape(-1, y.shape[-1]))
finally:
    # Always flush and close the event file, even if logging raises.
    writer.close()

In [9]:
# Load the TensorBoard notebook extension and serve the event files under ./runs
# (where the SummaryWriter() above writes when no log_dir is given).
# Keep comments on their own lines: trailing text after a line magic may be
# forwarded to the magic as arguments.
%load_ext tensorboard
%tensorboard --logdir=runs

Reusing TensorBoard on port 6006 (pid 6208), started 12:12:28 ago. (Use '!kill 6208' to kill it.)

# Loss

In [12]:
# Latent loss under test. MSEGCRLatentLoss comes from one of the utils star
# imports (presumably utils.latent_loss -- TODO confirm).
loss_fn = MSEGCRLatentLoss()

In [16]:
# Toy check of the latent loss: B and R each stack 3 rows of 2-d vectors, with
# per-sample lengths [1, 2] -- presumably 2 variable-length sets concatenated
# along dim 0 (TODO confirm against MSEGCRLatentLoss).
len_B = torch.tensor([1, 2])
B = torch.tensor([[2, 0], [1, 2], [5, 3]], dtype=torch.float64, requires_grad=True)

len_R = torch.tensor([1, 2])
R = torch.tensor([[2, 0], [2, 0], [5, 2]], dtype=torch.float64, requires_grad=True)

loss = loss_fn.forward(B, len_B, R, len_R)
print(loss)
loss.backward()  # confirms the loss is differentiable w.r.t. B and R

tensor(0.8250, dtype=torch.float64, grad_fn=<AddBackward0>)
