In [1]:
from __future__ import print_function
from __future__ import division

import sys
import torch
import torch.utils.data as data
import os
import time
import pickle
import numpy as np
from PIL import Image
import re
import io

import json
import matplotlib.pyplot as plt
from torchvision import transforms, datasets, models
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F



from IPython.display import clear_output
import torch.nn as nn
import torch.optim as optim

import itertools
import collections
import pdb
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
cuda

True

In [3]:
# Path for file dset_dataloader.json
def open_json(path):
    f = open(path) 
    data = json.load(f) 
    f.close()
    return data 

def flatten(S):
    if S == []:
        return S
    if isinstance(S[0], list):
        return flatten(S[0]) + flatten(S[1:])
    return S[:1] + flatten(S[1:])

### Bar to visualize progress

In [4]:
def update_progress(progress):
    bar_length = 20
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        progress = 0
    if progress < 0:
        progress = 0
    if progress >= 1:
        progress = 1

    block = int(round(bar_length * progress))
    clear_output(wait = True)
    text = "Progress: [{0}] {1:.1f}%".format( "#" * block + "-" * (bar_length - block), progress * 100)
    print(text)

In [7]:
# IMAGES
im_path_fur = '/home/ubuntu/ADARI/images/v2/full'
# im_path_fur = "../images/furniture/v2/full" # small are 64x64, medium 256x256 and large 512x512

# JSON_FILES
# data_path_fur = "../ADARI/json_files/cleaned/ADARI_v2/furniture_v2_c.json"

# WORD EMBEDDINGS
word_embeddings_path = "../json_files/fur_5c_50d_sk_glove_ft.json"

# IMAGE EMBEDDINGS
img_embds_id_p = "../json_files/afur_resnet_emb_id.json"
img_embds_name_p = "../json_files/afur_resnet_emb_names.json"

# FILES FOR DATALOADER
dset_words_p = "../json_files/ADARI_v2_furniture_images_words.json"
# dset_sentences_p = "../ADARI/json_files/ADARI_images_sentences_words/furniture/ADARI_v2_furniture_images_sentences.json"
# dset_sentences_POS_p = "../ADARI/json_files/ADARI_images_sentences_words/furniture/ADARI_v2_furniture_images_sentences_tokenized.json"

In [8]:
# Open json files with embeddings 
image_embeddings = open_json(img_embds_name_p)
dataset_labels = open_json(dset_words_p)
labels_embeddings = open_json(word_embeddings_path)

### Transform dictionary of ordered labels to list of labels 

In [9]:
glove_path = '../json_files/glove.6B.50d.txt'
with io.open(glove_path, 'r', encoding='utf8') as f:    
    glove_file = f.read()
    
glove_sentences = glove_file.splitlines()
glove_vocab = {}
for sentence in glove_sentences:
    word = sentence.split()[0]
    embedding = np.array(sentence.split()[1:], dtype = float)
    glove_vocab[word] = embedding

### nasty temporal vector for unknowns

In [10]:
with open(glove_path, 'r') as f:
    for i, line in enumerate(f):
        pass
n_vec = i + 1
hidden_dim = len(line.split(' ')) - 1

vecs = np.zeros((n_vec, hidden_dim), dtype=np.float32)

with open(glove_path, 'r') as f:
    for i, line in enumerate(f):
        vecs[i] = np.array([float(n) for n in line.split(' ')[1:]], dtype=np.float32)

AVG_VECTOR = np.mean(vecs, axis=0)

In [11]:
def is_date(word):
    rx = r"[0-9]+(?:st|[nr]d|th)"
    if re.findall(rx, word, flags=re.I) != []:
        return True
    return False

def labels_dict2list(dset_words):
    # The 2 dictionaries below for dataset dataloader
    im2idx = dict()
    im_words = dict()

    # Temp lists 
    image_names = list(dset_words.keys())
    words = list(dset_words.values())

    # Iterate over length of dictionary and get im2idx and im_words 
    for i in range(len(image_names)):
        im = image_names[i]
        words_list = flatten(list(words[i].values()))
        cleaned_w = []
        for w in words_list:
            if w != '"the' and w != '"The' and len(w) > 1 and is_date(w) != True:
                cleaned_w.append(w)

#         im_words[im] = cleaned_w
        im_words[im] = list(set(cleaned_w))
        im2idx[im] = i
    return im_words, im2idx

def create_vocab(dataset_labels):
    """
    We have 17532 images and a total of 707852 adjectives, so average of 40 words per image
    We have a total of 4786 unique words. This is our vocabulary size
    """
    # 1) Convert raw dataset (dictionary of ordered labels per image) to list of labels
    dset_im_words, _ = labels_dict2list(dataset_labels)
    
    # 2) Flatten the list 
    all_words = list(dset_im_words.values())
    flat_list = []
    for sublist in all_words:
        for item in sublist:
            flat_list.append(item)
    
    # 3) Get set of unique words = vocabulary
    unique_words = set(flat_list)
    
    # 4) Get dicitonary to map idx to words and viceversa
    words2idx = dict()
    idx2words = dict()
    
    set2list = list(unique_words)
    for i in range(len(set2list)):
        w = set2list[i]
        words2idx[w] = i
        idx2words[i] = w
        
    return dset_im_words, words2idx, idx2words

In [12]:
dset_im_words, vocab2idx, idx2vocab = create_vocab(dataset_labels)

### Split dataset into train, validation and test

In [13]:
def splitDict(d_img_words, d_img_embs, percent, val_number):

    val_n = val_number
    train_test_size = len(d_img_words) - val_n
    train_n = int(train_test_size*percent)
    test_n = train_test_size - train_n
    
    im_words = iter(d_img_words.items())      
    im_embs = iter(d_img_embs.items())
    
    # Image - words
    dtrain_imw = dict(itertools.islice(im_words, train_n))  
    dtest_imw = dict(itertools.islice(im_words, test_n))   
    dval_imw = dict(itertools.islice(im_words, val_n))
    
    
    print('trainset size: ', len(dtrain_imw), 'dataset size: ',len(dtest_imw), 'val set size: ', len(dval_imw))
    return dtrain_imw, dtest_imw, dval_imw 

In [55]:
dtrain_w, dtest_w, dval_w = splitDict(dset_im_words,image_embeddings, .95, 500)

trainset size:  16180 dataset size:  852 val set size:  500


In [56]:
img_size = 64
class ADARIdataset(Dataset):
    """
    Receives images and labels.
    Returns tensor image and tensor labels
    """
    def __init__(self, data_labels, word_embeddings, image_embeddings, vocab2idx, idx2vocab, img_path):

        self.labels_data = data_labels # dictionary of images -> labels
        self.word_embeds = word_embeddings
        
        self.images_names = list(self.labels_data.keys())    # names
        self.images_embeds = list(image_embeddings.values()) # values
        
        self.vocab2idx = vocab2idx
        self.idx2vocab = idx2vocab
        
        self.image_path = img_path
        
    def __len__(self):
        return len(self.images_names)
    
    def name2idx(self):
        self.name2idx = dict()
        self.idx2name = dict()
        for i, key in enumerate(self.images_names.keys()):
            self.name2idx[key] = i
            self.idx2name[i] = key
        
    def get_image_tensor(self, image_name):
        """
        Gets image name and returns a tensor
        """
        name = self.image_path + "/" + image_name
        img = Image.open(name)
        img = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor()])(img)
        
        return img
        
    def get_labels_embeddings_from_idx(self, idx):

        name_image = self.images_names[idx]
        labels = self.labels_data[name_image]
        labels = list(set(labels))
        
        # Set random distribution for setting a max number of labels
        if len(labels) > 20:
            labels = np.random.choice(labels, 20)
        
        # Get positive and negative labels
        all_idx = list(self.vocab2idx.values())
        pos_idxs = []
        # Remove indexes that correspond to the positive labels
        for l in labels:
            v2i = self.vocab2idx[l]
            pos_idxs.append(v2i)
            if v2i in all_idx:
                all_idx.remove(v2i)
        
        # Choose random labels as negative samples -> this can be improved with info about distance of labels
        neg_idxs = np.random.choice(all_idx, len(labels))
        
        neg_samples = []
        for n in neg_idxs:
            neg_samples.append(self.idx2vocab[n])
        
        assert(len(labels) == len(neg_samples))
        pos_w_embs = []
        neg_w_embs = []
        
        # positive
        for l in labels:
            try:
                pos_w_embs.append(self.word_embeds[l.lower()]) # appending 50 vector embedding
            except:
                try:
                    pos_w_embs.append(glove_vocab[l.lower()])
                except:
                    pos_w_embs.append(AVG_VECTOR)
        # negative
        for nl in neg_samples:
            try:
                neg_w_embs.append(self.word_embeds[nl.lower()]) # appending 50 vector embedding
            except:
                try:
                    neg_w_embs.append(glove_vocab[nl.lower()])
                except:
                    neg_w_embs.append(AVG_VECTOR)
                    
                    
        return pos_w_embs, neg_w_embs
    
    def __getitem__(self, index):
        """
        Return tensor image and label embedding
        """
        name_image = self.images_names[index]
        img = self.get_image_tensor(name_image)
        
        #image_emb = self.images_embeds[index] # list size 2048 
        pos_label_embs, neg_label_embs = self.get_labels_embeddings_from_idx(index) # list size variable 
        
        return img, pos_label_embs, neg_label_embs
 

def collate(sequence):
    """
    "the input of this function is the output of function __getitem__"
    "this gets BATCH_SIZE times GETITEM! "
    if batch_Size == 2 --> sequence is a list with length 2. 
    Each list is a tuple (image_embedding, labels_embedding) = (2048 vector, list of vectors size 50)
    Pad labels with maximum from batch
    """
    
    # Concatenate all images in the batch
    # For images (not embeddigns)
    images = torch.cat(([torch.FloatTensor(batch[0]).view(-1, 3, 64, 64) for batch in sequence]), dim=0)
    # For images embeddings
    #images = torch.cat(([torch.FloatTensor(batch[0]).view(-1, 2048) for batch in sequence]), dim=0)
    
    # Pad labels with max_sequence_label
    # batch 1 is batch * word embedding
    pos_labels = pad_sequence([torch.FloatTensor(batch[1]) for batch in sequence], batch_first=True)
    labels_length = torch.LongTensor([len(batch[1]) for batch in sequence])     

    neg_labels = pad_sequence([torch.FloatTensor(batch[2]) for batch in sequence], batch_first=True)
    
    return images, pos_labels, neg_labels, labels_length

In [57]:
dataset_train = ADARIdataset(dtrain_w, labels_embeddings, image_embeddings, vocab2idx, idx2vocab, im_path_fur)
dataset_test = ADARIdataset(dtest_w, labels_embeddings, image_embeddings, vocab2idx, idx2vocab, im_path_fur)
dataset_val = ADARIdataset(dval_w, labels_embeddings, image_embeddings, vocab2idx, idx2vocab, im_path_fur)

In [58]:
batch_size = 64
num_workers = 8 if cuda else 0

In [59]:
train_dataloader = DataLoader(dataset_train, batch_size=batch_size, collate_fn = collate, shuffle=True, num_workers=num_workers, drop_last=False)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size, collate_fn = collate,shuffle=False, num_workers=num_workers, drop_last=False)
val_dataloader = DataLoader(dataset_val, batch_size=batch_size, collate_fn = collate,shuffle=False, num_workers=num_workers, drop_last=False)

In [42]:
# to test dataloader
it = iter(test_dataloader)

min_ = 10000
for i in range(len(test_dataloader)):
    first = next(it)
    if first[1].shape[1] < min_:
        min_ = first[1].shape[1]
print(min_)
   # print(second[0].shape, second[1].shape, second[2])

20


### Image embedding

In [69]:
feature_extract = False
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

In [70]:
def initialize_model(num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    model_ft = models.resnet152(pretrained=use_pretrained)
    set_parameter_requires_grad(model_ft, feature_extract)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    input_size = 64
    
    return model_ft, input_size

model_ft, input_size = initialize_model(50, feature_extract)
model_ft

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [71]:
model_ft.fc.weight.shape

torch.Size([50, 2048])

### Matching CNN 

In [72]:
# MULTICONVOLUTION CONCATENATING
class matchingCNN(nn.Module):
    """
   
    """
    def __init__(self):
        super(matchingCNN, self).__init__()
        
        # Multimodal Convolution 
        self.cnn1 = nn.Conv1d(in_channels = 1, out_channels = 200,  kernel_size = 12, padding=0, stride=2)
        self.max1 = nn.MaxPool1d(2, stride=2)
        
        # 2nd Convolution
        self.cnn2 = nn.Conv1d(in_channels = 200, out_channels = 300, kernel_size = 12, padding=0, stride=2)
        self.max2 = nn.MaxPool1d(2, stride=2)
        
        # 3rd Convolution
        self.cnn3 = nn.Conv1d(in_channels = 300, out_channels = 300, kernel_size = 12, padding=0, stride=2)
        self.max3 = nn.MaxPool1d(2, stride=2)
                    
    def forward(self, image_feature, labels, labels_lengths):
        batch_size = labels.shape[0]
        max_length_per_batch = labels.shape[1]
        dim_labels = labels.shape[2]
        dim_image = image_feature.shape[1]
        
        # Mixing image and labels 
        new_vectors = [] # [batch, max_length, dim*3+256]
        receptive_field_words = 3
        new_vector_length = receptive_field_words * (dim_labels + dim_image)
        
        for i in range(labels.shape[1]): # iterating over words 
            next3words = labels[:, i:i+3, :] 
            next3words = next3words.view(batch_size, -1)
            words_img = torch.cat((next3words, image_feature), 1).unsqueeze(1) # [batch, 1, 200]
            len_last_dim = words_img.shape[2]
            
            if len_last_dim != new_vector_length:
                pad_times = new_vector_length - len_last_dim
                words_img = F.pad(words_img, (0, pad_times), "constant", 0)
            new_vectors.append(words_img)
        
        joint = torch.cat((new_vectors),1) # [batch, 20, 406]

        # Reshaping matrix
        joint = joint.view(-1, 1, new_vector_length) # [batch*20, 1, 406]
#         print('joint matrix shape: ', joint.shape)
        joint = F.relu(F.dropout(self.cnn1(joint), .1))
#         print('after first convolution: ', joint.shape)
        joint = self.max1(joint)
#         print('after maxpool1: ', joint.shape)

        joint = F.relu(F.dropout(self.cnn2(joint), .1))
#         print('after conv2: ', joint.shape)
        joint = self.max2(joint)
#         print('after 2nd maxpool: ', joint.shape)
        joint = F.relu(F.dropout(self.cnn3(joint), .1))
#         print('after conv3: ', joint.shape)
        joint = self.max3(joint)
#         print('after 3rd maxpool: ', joint.shape)
        
        joint = joint.view(batch_size, -1)
#         print(joint.shape)
        return joint
        
    
    

In [73]:
class Score_deviseAndKiros(nn.Module):
    def __init__(self, image_vector, label_true_vectors, label_random_vector, lengths, cnn_img_weights):
        super(Score_deviseAndKiros, self).__init__()
        M = cnn_img_weights      # [50, 2048]
        v_image = image_vector.T # [batch, 50]
        t_label = label_true_vectors   # [batch, 50, max_length]
        t_j     = label_random_vectors # [batch, 50, max_length]
    

In [74]:
class Scoring_function(nn.Module):
    def __init__(self):
        super(Scoring_function, self).__init__()
        self.fc1 = nn.Linear(6000, 400)
        self.fc2 = nn.Linear(400, 100)
        
    def forward(self, x): 
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [75]:
def margin_ranking_loss(margin, score_good, score_bad):
    cost = (margin - score_good + score_bad)
    z = torch.zeros_like(cost)
    loss = torch.max(z, cost)
    return loss

In [76]:
def normalize_vec(vec):
    norm = vec.norm(p=2, dim=1, keepdim=True)
    vec_norm = vec.div(norm)
    return vec_norm

In [89]:
def train_epoch(img_cnn, matching_cnn, mlp, train_loader, optimizer):
    img_cnn.train()
    matching_cnn.train()
    mlp.train()
    
    running_loss = 0.0
    total_predictions = 0.0
    correct_predictions = 0.0
    
    predictions = []
    ground_truth = []
    
    start_time = time.time()
    for batch_idx, (data, pos_labels, neg_labels, length) in enumerate(train_loader):   
        optimizer.zero_grad()   # .backward() accumulates gradients
        data = data.to(device)
        
        pos_labels = pos_labels.to(device) # [batch, 50, max_length]
        neg_labels = neg_labels.to(device) # [batch, 50, max_length]
        
        # image_CNN
        im_repres = img_cnn(data) # [batch, 50]
        im_repres = F.relu(im_repres) # [batch, 50]
        
        # Matching CNN
        joint_repre_related = matching_cnn(im_repres, pos_labels, length)   # [batch, 6000]
        joint_repre_unrelated = matching_cnn(im_repres, neg_labels, length) # [batch, 6000]
        
        # Scoring mLP
        score_good = mlp(joint_repre_related).mean((1,0))   # [batch, 100] -> [batch]
        score_bad = mlp(joint_repre_unrelated).mean((1,0))  # [batch, 100] -> [batch]
        
        loss = margin_ranking_loss(0.5, score_good, score_bad)

        if batch_idx % 10 == 0 and batch_idx != 0:
            print('sg: ', score_good.item())
            print('sb: ', score_bad.item())
            print('loss: ', loss.item())
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    end_time = time.time()
    
    loss_epoch = running_loss / len(train_loader)
    print('------ Training -----')
    return loss_epoch

In [90]:
def test_epoch(img_cnn, matching_cnn, mlp, test_loader):
    img_cnn.eval()
    matching_cnn.eval()
    mlp.eval()
    
    running_loss = 0.0
    total_predictions = 0.0
    correct_predictions = 0.0
    
    predictions = []
    ground_truth = []

    with torch.no_grad():
        for batch_idx, (data, pos_labels, neg_labels, length) in enumerate(test_loader):   
            data = data.to(device)
            pos_labels = pos_labels.to(device) # all data & model on same device
            neg_labels = neg_labels.to(device) 

            # image_CNN
            im_repres = img_cnn(data) # [batch, 50]
            im_repres = F.relu(im_repres)

            # Matching CNN
            joint_repre_related = matching_cnn(im_repres, pos_labels, length)
            joint_repre_unrelated = matching_cnn(im_repres, neg_labels, length)

            # Scoring mLP
            score_good = mlp(joint_repre_related).mean((1,0))
            score_bad = mlp(joint_repre_unrelated).mean((1,0))
            
            loss = margin_ranking_loss(0.5, score_good, score_bad)
            if batch_idx % 10 == 0 and batch_idx != 0:
                print('sg: ', score_good.item())
                print('sb: ', score_bad.item())

                print('loss: ', loss.item())

            running_loss += loss.item()
    
    
    loss_epoch = running_loss / len(test_loader)
    print('------ Testing -----')
    return loss_epoch

In [84]:
# MODELS 
device = torch.device("cuda" if cuda else "cpu")
image_CNN = model_ft.to(device)
matching_CNN = matchingCNN().to(device)
scoring_MLP = Scoring_function().to(device)

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.

params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

params = list(params_to_update) + list(matching_CNN.parameters()) + list(scoring_MLP.parameters())

# optimizer = optim.Adam(model.parameters(), lr = 0.001)
#optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Observe that all parameters are being optimized
optimizer = optim.SGD(params, lr=0.001, momentum=0.9)

Params to learn:
	 conv1.weight
	 bn1.weight
	 bn1.bias
	 layer1.0.conv1.weight
	 layer1.0.bn1.weight
	 layer1.0.bn1.bias
	 layer1.0.conv2.weight
	 layer1.0.bn2.weight
	 layer1.0.bn2.bias
	 layer1.0.conv3.weight
	 layer1.0.bn3.weight
	 layer1.0.bn3.bias
	 layer1.0.downsample.0.weight
	 layer1.0.downsample.1.weight
	 layer1.0.downsample.1.bias
	 layer1.1.conv1.weight
	 layer1.1.bn1.weight
	 layer1.1.bn1.bias
	 layer1.1.conv2.weight
	 layer1.1.bn2.weight
	 layer1.1.bn2.bias
	 layer1.1.conv3.weight
	 layer1.1.bn3.weight
	 layer1.1.bn3.bias
	 layer1.2.conv1.weight
	 layer1.2.bn1.weight
	 layer1.2.bn1.bias
	 layer1.2.conv2.weight
	 layer1.2.bn2.weight
	 layer1.2.bn2.bias
	 layer1.2.conv3.weight
	 layer1.2.bn3.weight
	 layer1.2.bn3.bias
	 layer2.0.conv1.weight
	 layer2.0.bn1.weight
	 layer2.0.bn1.bias
	 layer2.0.conv2.weight
	 layer2.0.bn2.weight
	 layer2.0.bn2.bias
	 layer2.0.conv3.weight
	 layer2.0.bn3.weight
	 layer2.0.bn3.bias
	 layer2.0.downsample.0.weight
	 layer2.0.downsample.1.weight

In [91]:
train_losses = []
test_losses = []
save_model_path = './saved_models/cnn_im_50_finetune/'
for i in range(10):
    print('-----Training epoch {}/{} --------'.format(i,9))
    tr_loss = train_epoch(image_CNN, matching_CNN, scoring_MLP, train_dataloader, optimizer)
    print('train epoch: {}, loss: {}'.format(i, tr_loss))
    print()
    print('------Testing epoch {}/{} --------'.format(i,9))
    tst_loss = test_epoch(image_CNN, matching_CNN, scoring_MLP, test_dataloader)
    print('test epoch: {}, loss: {}'.format(i, tst_loss))
    
    train_losses.append(tr_loss)
    test_losses.append(tst_loss)
    
    save_path_im = save_model_path + 'cnnim_{}.pt'.format(i)
    save_path_match= save_model_path + 'cnnmatch_{}.pt'.format(i)
    save_path_mlp = save_model_path + 'mlp_{}.pt'.format(i)
    
    torch.save(image_CNN, save_path_im)
    torch.save(matching_CNN, save_path_match)
    torch.save(scoring_MLP, save_path_mlp)
    

-----Training epoch 0/9 --------
sg:  0.004097073804587126
sb:  0.0010304147144779563
loss:  0.49693334102630615
sg:  0.00986428651958704
sb:  0.00474146194756031
loss:  0.4948771595954895
sg:  0.015590873546898365
sb:  0.00855383649468422
loss:  0.49296295642852783
sg:  0.02124929428100586
sb:  0.011777550913393497
loss:  0.49052825570106506
sg:  0.02864675037562847
sb:  0.017119891941547394
loss:  0.48847314715385437
sg:  0.03575052320957184
sb:  0.0215508583933115
loss:  0.4858003556728363
sg:  0.04412742331624031
sb:  0.027623068541288376
loss:  0.4834956228733063
sg:  0.05635696277022362
sb:  0.03603341802954674
loss:  0.4796764552593231
sg:  0.07268647849559784
sb:  0.047514453530311584
loss:  0.47482794523239136
sg:  0.09205743670463562
sb:  0.05860103294253349
loss:  0.466543585062027
sg:  0.12075681239366531
sb:  0.0774388462305069
loss:  0.456682026386261
sg:  0.15404389798641205
sb:  0.10454294830560684
loss:  0.450499027967453
sg:  0.1936601996421814
sb:  0.1302753537893295

sg:  2.3773210048675537
sb:  1.6056060791015625
loss:  0.0
sg:  2.403116464614868
sb:  1.618457555770874
loss:  0.0
sg:  2.4896867275238037
sb:  1.7250471115112305
loss:  0.0
sg:  2.4375646114349365
sb:  1.6536462306976318
loss:  0.0
sg:  2.333554267883301
sb:  1.622067928314209
loss:  0.0
------ Training -----
train epoch: 4, loss: 0.0

------Testing epoch 4/9 --------
sg:  2.594128370285034
sb:  1.7460918426513672
loss:  0.0
------ Testing -----
test epoch: 4, loss: 0.0
-----Training epoch 5/9 --------
sg:  2.5142130851745605
sb:  1.6515380144119263
loss:  0.0
sg:  2.4245498180389404
sb:  1.673907995223999
loss:  0.0
sg:  2.4741058349609375
sb:  1.675045132637024
loss:  0.0
sg:  2.4104628562927246
sb:  1.6379343271255493
loss:  0.0
sg:  2.4063973426818848
sb:  1.610346794128418
loss:  0.0
sg:  2.470592737197876
sb:  1.6634306907653809
loss:  0.0
sg:  2.4511282444000244
sb:  1.6806341409683228
loss:  0.0
sg:  2.4704549312591553
sb:  1.6823315620422363
loss:  0.0
sg:  2.442395925521850

sg:  2.4486687183380127
sb:  1.674466848373413
loss:  0.0
sg:  2.4145143032073975
sb:  1.624483346939087
loss:  0.0
sg:  2.466353416442871
sb:  1.6414159536361694
loss:  0.0
sg:  2.336916208267212
sb:  1.6099854707717896
loss:  0.0
sg:  2.365604877471924
sb:  1.6548970937728882
loss:  0.0
sg:  2.4263298511505127
sb:  1.6877131462097168
loss:  0.0
sg:  2.440871000289917
sb:  1.6585395336151123
loss:  0.0
sg:  2.4778621196746826
sb:  1.6961959600448608
loss:  0.0
------ Training -----
train epoch: 9, loss: 0.0

------Testing epoch 9/9 --------
sg:  2.601405620574951
sb:  1.7554138898849487
loss:  0.0
------ Testing -----
test epoch: 9, loss: 0.0
