In [1]:
import torch
import torch.nn as nn
import torch.nn.init
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from torchtext.models import ROBERTA_BASE_ENCODER
from torchtext.functional import to_tensor

import numpy as np
import pandas as pd
import glob
import os
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def normalization(X):
    """L2-normalization of features columns"""
    norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt()
    X = torch.div(X, norm)
    
    return X

In [3]:
class ImageEncoder(nn.Module):

  def __init__(self, embedding_size, cnn_type):
    """Initializing parameters"""
    super(ImageEncoder).__init__()
    self.embedding_size = embedding_size # Size of projected image
    self.cnn = self.load_cnn(cnn_type)

    # No need to finetune parameters = frozen layers
    for param in self.cnn.parameters():
        param.requires_grad = False

    # Replacing last fully connected layer with new one
    self.fc = nn.Linear(self.cnn.classifier._modules['6'].in_features, embedding_size)
    self.cnn.classifier = nn.Sequential(*list(self.cnn.classifier.children())[:-1])

    # Initializing the weights of fully-connected layer, which makes projection to new space
    self.initialization_weights()
  
  def load_cnn(self, cnn_type):
    """Loading pretrained model"""
    model = models.__dict__[cnn_type](pretrained=True)

    return model

  def initialization_weights(self):
    """Xavier initialization"""
    r = np.sqrt(6.) / np.sqrt(self.fc.in_features + self.fc.out_features)
    self.fc.weight.data.uniform_(-r, r)
    self.fc.bias.data.fill_(0)

  def forward(self, X):
    """Creation of features"""
    # Creation of embeddings
    features = self.cnn(X)

    # Normalization of embeddings
    features = normalization(features)

    # Projection to new space
    features = self.fc(features)

    return features

In [4]:
ROBERTA_OUT_DIM = 768

class TextEncoder(nn.Module):

    def __init__(self, embedding_size):
        """Initializing parameters"""
        super(TextEncoder).__init__()
        self.embedding_size = embedding_size # Size of projected text
        self.roberta = ROBERTA_BASE_ENCODER.get_model()
        self.transform = ROBERTA_BASE_ENCODER.transform()

        # Linear layer
        self.fc = nn.Linear(ROBERTA_OUT_DIM, embedding_size)

        # Initializing the weights of fully-connected layer, which makes projection to new space
        self.initialization_weights()
        
    def _roberta_encode(self, batch):
        transformed = self.transform(batch)
        model_input = to_tensor(transformed, padding_value=1)
        return self.roberta(model_input)

    def initialization_weights(self):
        """Xavier initialization"""
        r = np.sqrt(6.) / np.sqrt(self.fc.in_features + self.fc.out_features)
        self.fc.weight.data.uniform_(-r, r)
        self.fc.bias.data.fill_(0)

    def forward(self, X, lengths=None):
        """Creation of features"""
        # Creation of embeddings
        features = self._roberta_encode(X)

        # Normalization of embeddings
        features = normalization(features)

        # Projection to new space
        features = self.fc(features)

        return features

In [5]:
class FullEncoder(nn.Module):

    def __init__(self, embedding_size, cnn_type):
        """Initializing parameters"""
        super(FullEncoder).__init__()
        self.embedding_size = embedding_size # Size of projected text
        self.image_encoder = ImageEncoder(embedding_size)
        self.text_encoder = TextEncoder(embedding_size)

    def forward(self, X):
        """Creation of features"""
        img_feas = self.image_encoder(X)
        txt_feas = self.text_encoder(X)
        return img_feas, txt_feas

In [12]:
ex = '18 Kt Rose Gold Supreme Swan Charm Bracelet The Supreme Swan Charm Bracelet is the symbol of love, peace and grace. Gift this to a loved one that adds these elements to your life. Made in Sterling Silver with a rose gold polish.'
ex = "this is a test with 7 words"
roberta = ROBERTA_BASE_ENCODER.get_model()
transform = ROBERTA_BASE_ENCODER.transform()
transformed = transform([ex])
model_input = to_tensor(transformed, padding_value=1)
out = roberta(model_input)
print(out.shape)
out

torch.Size([1, 9, 768])


tensor([[[-0.1720,  0.3779, -0.1899,  ..., -0.6369, -0.4730, -0.1509],
         [-0.1646, -0.2277, -0.1225,  ..., -0.3788, -0.0618,  0.0266],
         [ 0.3021,  0.0470,  0.0477,  ..., -0.5150, -0.1109,  0.1990],
         ...,
         [-0.0529, -0.0704,  0.0889,  ...,  0.0667, -0.1826, -0.2713],
         [ 0.0733, -0.1830,  0.0620,  ..., -0.1513, -0.0279, -0.2905],
         [-0.0012,  0.0454, -0.0357,  ..., -0.1739, -0.1121, -0.0572]]],
       grad_fn=<TransposeBackward0>)

In [7]:
ex2 = ['18 Kt Rose Gold Supreme Swan Charm Bracelet The Supreme Swan Charm Bracelet is the symbol of love, peace and grace. Gift this to a loved one that adds these elements to your life. Made in Sterling Silver with a rose gold polish.',
 '18 Kt Yellow Gold Stardust Earrings Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',
 '18 Kt Yellow Gold Bliss Square Earrings Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',
 '18 Kt Yellow Gold Fiery Flame Earrings Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',
 '18 Kt Yellow Gold Sparkling Star Earrings Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',
 '18 Kt Rose Gold Queen of Heart Earrings Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',
 '18 Kt Rose Gold First Steps Earrings Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',
 '18 Kt Yellow Gold Wide-Eye Owl Earring Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',
 '18 Kt Yellow Gold XOXO Milennial Earrings Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',
 '18 Kt Rose Gold Cute Kitty Earrings Crafted with perfection, Talisman’s Sterling Silver Stud Earrings are perfect accessory to double up your stellar look. Whether it is an important work meeting or a night out with your besties, we have got you covered! Elevate your outfits with our fresh, everyday styles.',]

In [8]:
type(ex2[0])

str

In [9]:
roberta = ROBERTA_BASE_ENCODER.get_model()
transform = ROBERTA_BASE_ENCODER.transform()
transformed = transform(ex2)
model_input = to_tensor(transformed, padding_value=1)
out = roberta(model_input)
print(out.shape)

torch.Size([10, 72, 768])
