<a href="https://colab.research.google.com/github/earendil94/SMLExam/blob/master/SML_Project_Claudia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**STATISTICAL MACHINE LEARNING**

ARRIGHI Leonardo, BRAND Francesco, DORIGO Claudia


Dataset folder is saved in "/content/drive/My Drive/SML/SML_Project".

# Introduction

In [1]:
# link colab and drive
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
# then follow passages

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import pandas as pd
import os
import torchvision
from torchvision import transforms, models
from IPython import display
import shelve
from PIL import Image
import glob

torch.manual_seed(160898)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device: {}'.format(device))

Device: cpu


# Images Preprocessing

In [3]:
# APPLY TRANSFORMATIONS TO PIL IMAGE 
def transform_image(image):
  transform = transforms.Compose([transforms.Resize(256),
                                  transforms.CenterCrop(224),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                  ])
  return transform(image)

img_prova = Image.open('/content/drive/My Drive/SML/SML_Project/Chunk1/img1_200/1624481.jpg')
i = transform_image(img_prova)

# PRINT IMAGE HAVING AN IMAGE KEY
def print_image_from_key(image_key):
  path = '/content/drive/My Drive/SML/SML_Project/All_images/'
  im = Image.open(os.path.join(path,str(image_key)+'.jpg'))
  %matplotlib inline
  imshow(np.asarray(im))

#print_image_from_key(1624481)

In [4]:
# FUNCTION TO BUILD THE DICTIONARY (JUST FIRST 15K IMAGES)
# don't run, it takes 2 hours

def build_img_shelve(img_folder,shelve_path):
  path = os.path.join(img_folder,'*.jpg')
  with shelve.open(shelve_path) as d:
    for i in glob.glob(path)[:15000]: #just first 15000 otherwise too big
      im=Image.open(i)
      d[i[len(img_folder):].split('.')[0]]=transform_image(im)

    
imgs = '/content/drive/My Drive/SML/SML_Project/All_images/'
#build_img_shelve(imgs,'/content/drive/My Drive/SML/SML_Project/img_shelve_15k')

In [5]:
# DICTIONARY FROM SHELVE
def dict_from_shelve(shelve_file):
  dictionary = {}
  d=shelve.open(shelve_file)
  for k in d.keys():
    dictionary[int(k)]=d[k]
  return dictionary

chunk1_dict = dict_from_shelve('/content/drive/My Drive/SML/SML_Project/Chunk1/img_shelve')

# Encoder

In [6]:
#https://towardsdatascience.com/automatic-image-captioning-with-cnn-rnn-aae3cd442d83
# ENCODER CNN
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)
        
        modules = list(resnet.children())[:-1] # remove last layer
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.embed(features)
        return features

'''
image = a[10002456]
image = image.unsqueeze(0)
cnn = EncoderCNN()

if torch.cuda.is_available():
  image = image.to('cuda')
  cnn.to('cuda')

output = cnn(image)
print(output,output.shape)
'''

"\nimage = a[10002456]\nimage = image.unsqueeze(0)\ncnn = EncoderCNN()\n\nif torch.cuda.is_available():\n  image = image.to('cuda')\n  cnn.to('cuda')\n\noutput = cnn(image)\nprint(output,output.shape)\n"

In [7]:
# FUNCTION TO CHECK THE RESULT OF CNN 
#(without removing last layer print 10 more probable classes) 
import json
from matplotlib.pyplot import imshow

def CNN_classification(image,model):
  path = '/content/drive/My Drive/SML/SML_Project/All_images/'
  im = Image.open(os.path.join(path,str(image)+'.jpg'))
  model.eval()

  out = model(transform_image(im).unsqueeze(0))
  #class_probs = torch.nn.functional.softmax(out, dim=0) #not needed

  class_idx = json.load(open("/content/drive/My Drive/SML/SML_Project/classes.json"))
  idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]

  # sort wrt probability and take 10 more probable indexes
  for idx in out[0].sort()[1][-10:]:
    print(idx2label[idx])

  #print the image
  %matplotlib inline
  imshow(np.asarray(im))

#pretrained_ResNet = torch.hub.load('pytorch/vision:v0.6.0', 'resnet152', pretrained=True)
#CNN_classification('1624481',pretrained_ResNet)


# Dataset, sampler, split

In [8]:
# DATASET,SAMPLER CLASSES (FOR DATALOADER) AND SPLIT METHODS
import torch
from torch.utils.data import Dataset, Sampler, DataLoader, SubsetRandomSampler
import pandas as pd

class imageCaptionDataset(Dataset):
  '''
    What we want to achieve is a mapping of type:
    { img_name: [img_tensor, caption_1, caption_2, caption_3, caption_4, caption_5]}
  '''

  #TODO: should have a function that puts the two items in the same dictionary
  def __init__(self, preProcessedImages, preProcessedCaptions):
    '''
      This function takes in input two dictionaries and merge them in a big_dictionary
      according to common key values
    '''
    self.big_Dict = {}
    for k in preProcessedCaptions.keys():
      a = preProcessedImages.get(k) # return none if it doesn't exist
      if a is not None:
        self.big_Dict[k] = [a,preProcessedCaptions[k]]

  def __getitem__(self, key):
    '''
      This function returns the preprocessed image and the first preprocessed comment 
      associated with the given key
      TO DO:
      - introduce a function to select the best comment
      - define the behaviour when a key is not present
    '''
    #which_comment = 0 # get first comment - we can add here a function to select the best
    comments = self.big_Dict[key][1]
    returned_comment = comments.get('Caption_1')
    return self.big_Dict[key][0], returned_comment, returned_comment.size()

  def __len__(self):
    return len(self.big_Dict)

  def get_keys(self):
    '''
      I need this function in the sampler. It returns a list with all the 
      keys present in the big dictionary
    '''
    return self.big_Dict.keys()


class imageCaptionSampler(Sampler):
  '''
    TO DO: 
    - _iter_ which returns an iterable over the dataset
    - _len_ which returns the length of the dataset (needed to compute number of batches in dataloader)
  '''
  def __init__(self, data_source):
    self.data_source = data_source

  def __iter__(self):
    '''
      we don't care the order in which iterate the dataset so this function defines
      an iterator over the key list
    '''
    return iter(self.data_source.get_keys())

  def __len__(self):
    return len(self.data_source.get_keys())  



def split(imgDataset, val_size):
  '''
    @imgDataset: an image caption dataset object
    @val_size: percentage of the dataset that should compose the validation set

    This function allows us to split our dataset into 
    a validation set and a training set. This is used internally in 
    Loader, should check that one.
  '''

  # We want to split our dataset given itself and the % of sample for validation
  num = len(imgDataset)
  index = list(imgDataset.get_keys())
  np.random.shuffle(index) # pick at random
  flag_split = int(val_size * num)

  train_index = index[flag_split:]
  validation_index = index[:flag_split]

  # https://pytorch.org/docs/stable/data.html -> Samples elements randomly from a given list of indices, without replacement
  train_sampler = SubsetRandomSampler(train_index)
  validation_sampler = SubsetRandomSampler(validation_index)

  return train_sampler, validation_sampler

def loaders(dataset, val_size, batch_size, num_workers):
  ''' 
    @dataset: an image caption dataset object
    @val_size: the percentage (must be [0,1]) of the validation set data
    @batch_size: the number of data in each batch
    @num_workers: number of subprocesses to use in the data loader
  '''

  train_sampler, validation_sampler = split(dataset, val_size)
  train_loader = DataLoader(dataset,
                            batch_size = batch_size,
                            sampler = train_sampler,
                            num_workers = num_workers)
  val_loader = DataLoader(dataset,
                          batch_size = batch_size,
                          sampler = validation_sampler,
                          num_workers = num_workers)
  return train_loader, val_loader

# Text preProcessing

In [9]:
# CAPTION PREPROCESSING FUNCTIONS
from torchtext.data import Field
from torchtext.data import TabularDataset
from torchtext.data import Iterator

def prepare_data(path, input_file, output_file):
  input_path = os.path.join(path, input_file)
  output_path = os.path.join(path, output_file)
  df = pd.read_csv(input_path, sep = "|")
  captions_array = df[' comment']
  captions_array.fillna("", inplace=True)
  
  image_names = df["image_name"].values
  image_number = []
  for i in range(0, len(image_names)):
    image_number.append(image_names[i].split('.')[0])

  df.drop(labels=['image_name', ' comment_number'], axis=1, inplace=True)
  df.index = image_number
  df.to_csv(output_path, index_label="image_number")
  return output_path

def build_vocab(path_to_caption_file, caption_file):
  output_path = prepare_data(path_to_caption_file, caption_file, "clean.csv")
  tokenize = lambda x : x.split()
  TEXT = Field(sequential = True, tokenize = tokenize, lower=True, init_token='<start>', eos_token='<end>')
  LABEL = Field(sequential=False, use_vocab=False)
  td_datafields = [("image_number", LABEL ),
                  ("comment", TEXT)]

  trn = TabularDataset(
              path=output_path, # the root directory where the data lies
              format='csv',
              skip_header=True, # if your csv header has a header, make sure to pass this to ensure it doesn't get proceesed as data!
              fields=td_datafields,
              )
  
  TEXT.build_vocab(trn)
  return TEXT.vocab

def word_caption_to_index(path_to_caption_file, caption_file):
  '''
    This function takes as input the file containing the captions
    and returns a matrix of the captions indexed with respect to the inner
    vocabulary, as well as an array that can be used to map the indexed caption
    to the image it belongs to.
    @path_to_caption_file: the path to file containing the captions
    @caption_file: the name of the file containing the captions
  '''

  output_path = prepare_data(path_to_caption_file, caption_file, "clean.csv")
  tokenize = lambda x : x.split()
  TEXT = Field(sequential = True, tokenize = tokenize, lower=True, init_token='<start>', eos_token='<end>')
  LABEL = Field(sequential=False, use_vocab=False)

  td_datafields = [("image_number", LABEL ),
                  ("comment", TEXT)]

  trn = TabularDataset(
                path=output_path, # the root directory where the data lies
                format='csv',
                skip_header=True, # if your csv header has a header, make sure to pass this to ensure it doesn't get proceesed as data!
                fields=td_datafields
                )
  
  TEXT.build_vocab(trn)
  train_iter = Iterator(trn, batch_size=len(trn), device = -1)

  for i in train_iter: #TODO:I really havent got it how this still works honestly
    cmt = i.comment
    img = i.image_number

  return cmt.T, img

def vocab_as_dict(path_to_caption_file, caption_file):
  vocab = build_vocab(path_to_caption_file, caption_file)
  return vocab.stoi

def tensor_to_word(indexed_word, vocab):
  for i in indexed_word:
    k = i.item()
    if k == 1:
      break
    else:
      print(vocab.itos[k], end = " ")
    
  

def get_caption_from_image(caption_indexes, caption_refs, image_number):
  #If image name actually has the .jpg tail
  caption_refs = (caption_refs == int(image_number))
  caption_refs = caption_refs.nonzero().T.numpy()[0].tolist()
  return caption_indexes[caption_refs]

def buildCaptionDict(path_to_caption_file, caption_file):

  caption_index, refs = word_caption_to_index(path_to_caption_file, caption_file)
  refs_list = refs.numpy().tolist()
  refs_set = set(refs_list)
  unique_refs_list = list(refs_set)

  df = []
  df = pd.DataFrame(columns=["Image_number", "Caption_1", "Caption_2", "Caption_3", "Caption_4", "Caption_5"])

  for i in range(0, len(unique_refs_list)):
    capt_1 = get_caption_from_image(caption_index, refs, unique_refs_list[i])[0]
    capt_2 = get_caption_from_image(caption_index, refs, unique_refs_list[i])[1]
    capt_3 = get_caption_from_image(caption_index, refs, unique_refs_list[i])[2]
    capt_4 = get_caption_from_image(caption_index, refs, unique_refs_list[i])[3]
    capt_5 = get_caption_from_image(caption_index, refs, unique_refs_list[i])[4]
    df.loc[i] = [unique_refs_list[i], capt_1, capt_2, capt_3, capt_4, capt_5]

  df.set_index('Image_number', inplace=True)
  capt_dict = df.to_dict('index')

  return capt_dict

# Building dictionary and dataset

In [10]:
# TEST ON OUR DATA
# image dictionary
chunk1_dict = dict_from_shelve('/content/drive/My Drive/SML/SML_Project/Chunk1/img_shelve')

# caption dictionary
path_to_caption_file = "/content/drive/My Drive/SML/SML_Project/Chunk1"
caption_file = "results.csv"
caption_dict = buildCaptionDict(path_to_caption_file, caption_file)



The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.


In [11]:
dataset = imageCaptionDataset(chunk1_dict,caption_dict)
#chunk1_dict['1624481']
img_tensor, caption_tensor, size = dataset[1624481]
print(img_tensor)
print(caption_tensor)
print(size)
dataset.get_keys()

tensor([[[-0.7822, -0.6794, -0.6109,  ...,  0.2453,  0.4508,  0.8276],
         [-0.8678, -0.8507, -0.8678,  ..., -0.6794, -0.5596, -0.4226],
         [-0.8849, -0.8507, -0.8507,  ..., -0.7479, -0.7308, -0.7650],
         ...,
         [ 0.8447,  0.8447,  0.8276,  ...,  0.3309,  0.3481,  0.3481],
         [ 0.9303,  0.9474,  0.9303,  ...,  0.3652,  0.3652,  0.3652],
         [ 0.9132,  0.9303,  0.9303,  ...,  0.4166,  0.3823,  0.3652]],

        [[-0.6352, -0.4951, -0.3901,  ...,  0.4678,  0.6779,  1.0455],
         [-0.7227, -0.6702, -0.6527,  ..., -0.5476, -0.4251, -0.2675],
         [-0.7402, -0.6702, -0.6527,  ..., -0.7402, -0.7227, -0.7227],
         ...,
         [ 1.0980,  1.0980,  1.0805,  ...,  0.5728,  0.5903,  0.5903],
         [ 1.1856,  1.2031,  1.1856,  ...,  0.6078,  0.6078,  0.6078],
         [ 1.1681,  1.1856,  1.1856,  ...,  0.6604,  0.6254,  0.6078]],

        [[-0.6367, -0.5321, -0.4798,  ...,  0.2522,  0.4614,  0.8622],
         [-0.7064, -0.7064, -0.6890,  ..., -0

dict_keys([7340189, 134206, 5377361, 5771732, 4985704, 4199555, 1317156, 793558, 2760167, 10101477, 3680138, 3025093, 6827028, 6696219, 5648321, 667626, 8664920, 8664922, 3160699, 5914327, 8404753, 5521996, 148284, 9324151, 675153, 3035057, 5918675, 5918840, 807129, 5919020, 5526034, 1989609, 5791070, 5791244, 5791568, 6054169, 5400154, 9726060, 5402085, 3043766, 8680922, 4749855, 2784746, 9600569, 4489731, 8684718, 8029536, 2656351, 36979, 301246, 7510394, 6331511, 10002456, 960092, 6335241, 5287405, 438106, 178045, 6338704, 6338733, 6339096, 10010052, 10404007, 7520721, 7520731, 6734417, 574181, 4376178, 8832804, 1624481, 2148982, 4378823, 5558592, 2806447, 6214447, 7656601, 8443156, 7527111, 4906946, 2285664, 1369162, 4515460, 5958182, 4386588, 5566972, 10287332, 1243756, 3996401, 3734864, 5570219, 5570254, 65567, 8454235, 7013217, 984950, 8063007, 8849890, 2689611, 7015055, 854749, 10160966, 9637989, 5444724, 6100315, 726414, 4135695, 3219606, 205842, 7808046, 1254659, 4926723, 885

# Decoder

In [12]:
#TODO: Right now this is basically the same exact class described in the towards data science article
#Should/Do we have to make any changes to this?
class DecoderRNN(nn.Module): 
  def __init__(self, embed_size, hidden_size, vocab_size, num_layers = 1): #Here we define the layers
    super().__init__()
    self.embedding_captions_layer = nn.Embedding(vocab_size, embed_size)
    self.LSTM = nn.LSTM(input_size = embed_size, hidden_size = hidden_size, 
                        num_layers = num_layers, batch_first = True)
    self.linear = nn.Linear(hidden_size, vocab_size)

  def forward(self, features, captions): #Notice that features will only be used when we will have the encoder images in input
    captions = captions[:, :-1]
    # The following line retrieves the embedded representation of the indexes that we pass to it
    embed = self.embedding_captions_layer(captions)
    #A couple of comments here: torch.cat concatenates just like in unix
    #torch.unsqueeze instead transforms the tensor into a column vector (column since we specify 1 here)
    embed = torch.cat((features.unsqueeze(1), embed), dim = 1) 
    lstm_outputs, _ = self.LSTM(embed)
    out = self.linear(lstm_outputs)
    return out

  def sample(self, inputs, states=None, max_len=20):
    output_sentence = []

    for i in range(max_len):
      lstm_outputs, states = self.LSTM(inputs, states)
      lstm_outputs = lstm_outputs.squeeze(1)
      out = self.linear(lstm_outputs)
      last_pick = out.max(1)[1]
      output_sentence.append(last_pick.item())
      inputs = self.embedding_captions_layer(last_pick).unsqueeze(1)
      
    return output_sentence

# Parameters definition

In [13]:
#PARAMETERS DEFINITION


#from torch.nn.utils.rnn import pack_padded_sequence

# RNN - Model parameters
hidden_size = 1000
embed_size = 400
vocab_size = len(vocab_as_dict('/content/drive/My Drive/SML/SML_Project/Chunk1/', 'results.csv'))
dropout = 0.5 # suggested

# Parameters
epoch = 20
start_epoch = 0
epochs_since_improvement = 0 # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 64
decoder_lr = 5e-4  # learning rate
checkpoint = None
bleu4 = 0. # bleu score -> https://www.aclweb.org/anthology/P02-1040.pdf
workers = 1 # for dataloaders: how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process
grad_clip = 5. # must be changed after all 

In [14]:
# maybe this should be saved in a utils.py file

def adjust_learning_rate(optimizer, shrink_factor):
  for param_group in optimizer.param_groups:
    param_group['lr'] = param_group['lr'] * shrink_factor

def clip_gradient(optimizer, grad_clip):
  for group in optimizer.param_groups:
    for param in group['params']:
      if param.grad is not None:
        param.grad.data.clamp_(-grad_clip, grad_clip)

# Train

In [15]:
def train(train_loader, encoder, decoder, criterion, decoder_optimizer, epoch, device, grad_clip):

  decoder.train() # train mode

  # Epochs
  for i in range(epoch):
    
    # Batch
    for i_step, (image, caption, captionlen) in enumerate(train_loader):
      
      # Move to GPU, if available
      image = image.to(device)
      caption = caption.to(device)
      #captionlen = captionlen.to(device)

      # Output encoder
      image = encoder(image)

      # Forward 
      decoder_output = decoder(image, caption)
      
      
      # Calculate loss
      loss = criterion(decoder_output.view(-1, vocab_size), caption.view(-1))
      #loss = criterion(decoder_output, caption)
      #print('Loss = ', loss)
      print_every = 1
      stats = 'Epoch [%d/%d], Step: %d, Loss: %.4f, Perplexity: %5.4f' % (i, epoch, i_step, loss.item(), np.exp(loss.item()))
      if i_step % print_every == 0:
          print('\r' + stats)
    
      # Back propagation
      decoder_optimizer.zero_grad()
      loss.backward()
      

      # Clip gradients
      # https://machinelearningmastery.com/how-to-avoid-exploding-gradients-in-neural-networks-with-gradient-clipping/
      if grad_clip is not None:
        clip_gradient(decoder_optimizer, grad_clip)

      # Update weights
      decoder_optimizer.step()


# Variable declaration and training

In [16]:
# Encoder
encoder = EncoderCNN(embed_size)

# Decoder
decoder = DecoderRNN(embed_size, hidden_size, vocab_size) #dropout = dropout)

decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, 
                                                   decoder.parameters()),
                                     lr=decoder_lr)

encoder = encoder.to(device)
decoder = decoder.to(device)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

# DataLoader
train_loader, val_loader = loaders(dataset, 0.2, batch_size, 1)

'''
# Epochs
for epoch in range(start_epoch, epochs):
  # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
  if epochs_since_improvement == 20:
    break
  if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
    adjust_learning_rate(decoder_optimizer, 0.8)
'''

# Training
train(train_loader=train_loader,
      encoder=encoder,
      decoder=decoder,
      criterion=criterion,
      decoder_optimizer=decoder_optimizer,
      epoch=epoch,
      device=device,
      grad_clip=grad_clip)

  # Validation
  #bleu4_new

'''
  # BLEU4
  flag1 = bleu4_new > bleu4
  bleu4 = max(bleu4_new, bleu4)
  if not flag1:
    epochs_since_improvement += 1
  else:
    epochs_since_improvement = 0
'''

Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /root/.cache/torch/checkpoints/resnet152-b121ed2d.pth


HBox(children=(FloatProgress(value=0.0, max=241530880.0), HTML(value='')))


Epoch [0/20], Step: 0, Loss: 9.9471, Perplexity: 20891.0642
Epoch [0/20], Step: 1, Loss: 8.1101, Perplexity: 3327.8911
Epoch [0/20], Step: 2, Loss: 6.1173, Perplexity: 453.6281
Epoch [1/20], Step: 0, Loss: 4.3611, Perplexity: 78.3423


KeyboardInterrupt: ignored

In [None]:
# FUNCTION THAT GIVEN AN IMAGE (KEY) PRINTS IT AND THE PRODUCED CAPTION
def generate_caption(image_key,encoder,decoder,vocab):
  # print the image
  print_image_from_key(image_key)
  # image preprocessing
  input_image = transform_image(image_key)
  # CNN
  features = encoder(input_image)
  # RNN
  decoder_output = decoder.sample(features) 
  # get the caption
  caption = tensor_to_word(decoder_output, vocab)
  print(caption)

In [None]:
#We need to balance the lenght of the output of the NN for test purpouses
def paddingFill(idx_list, max_size):

  i = len(idx_list)
  for k in range(i,max_size):
    idx_list.append(1)

  return idx_list

def test(test_loader, encoder, decoder, criterion, device, vocab_size):

  decoder.eval()
  with torch.no_grad():

    total_loss = 0
    for i, (image, caption, captionlen) in enumerate(test_loader):

      image = image.to(device)
      caption = caption.to(device)

      for k_img, k_capt in zip(image,caption):

        img = k_img.unsqueeze(0)
        img = encoder(img)
        features = img.unsqueeze(1)

        decoder_out = decoder.sample(features)
        decoder_out = paddingFill(decoder_out, len(k_capt))

        decoder_out = torch.tensor(decoder_out, dtype=torch.int64)
        #decoder_out = decoder_out.to(device)
        embed_out = decoder.embedding_captions_layer(decoder_out)

        print("Embed_out shape: ", embed_out.shape)
        print("Caption shape: ", k_capt.shape)

        loss = criterion(embed_out, k_capt) #TODO: this might not work and we might need to check the representation of decoder_out and caption
        total_loss += loss
    
    print("Total loss: {} %".format(total_loss))

In [None]:
inpt = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

In [None]:
inpt.shape

In [None]:
target.shape

In [None]:
test(val_loader, 
     encoder=encoder, 
     decoder=decoder,
     criterion=criterion,
     device=device,
     vocab_size=vocab_size)