In [11]:
from __future__ import print_function
from __future__ import division

import sys
import torch
import torch.utils.data as data
import os
import time
import pickle
import numpy as np
from PIL import Image
import re
import io

import json
import matplotlib.pyplot as plt
from torchvision import transforms, datasets, models
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F



from IPython.display import clear_output
import torch.nn as nn
import torch.optim as optim

import itertools
import collections
import pdb
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
cuda

True

In [12]:
# Path for file dset_dataloader.json
def open_json(path):
    f = open(path) 
    data = json.load(f) 
    f.close()
    return data 

def flatten(S):
    if S == []:
        return S
    if isinstance(S[0], list):
        return flatten(S[0]) + flatten(S[1:])
    return S[:1] + flatten(S[1:])

### Bar to visualize progress

In [13]:
def update_progress(progress):
    bar_length = 20
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        progress = 0
    if progress < 0:
        progress = 0
    if progress >= 1:
        progress = 1

    block = int(round(bar_length * progress))
    clear_output(wait = True)
    text = "Progress: [{0}] {1:.1f}%".format( "#" * block + "-" * (bar_length - block), progress * 100)
    print(text)

In [14]:
# IMAGES
im_path_fur = '/home/ubuntu/ADARI/images/v2/full'
# im_path_fur = "../images/furniture/v2/full" # small are 64x64, medium 256x256 and large 512x512

# JSON_FILES
# data_path_fur = "../ADARI/json_files/cleaned/ADARI_v2/furniture_v2_c.json"

# WORD EMBEDDINGS
word_embeddings_path = "../json_files/fur_5c_50d_sk_glove_ft.json"

# IMAGE EMBEDDINGS
img_embds_id_p = "../json_files/afur_resnet_emb_id.json"
img_embds_name_p = "../json_files/afur_resnet_emb_names.json"

# FILES FOR DATALOADER
dset_words_p = "../json_files/ADARI_v2_furniture_images_words.json"
# dset_sentences_p = "../ADARI/json_files/ADARI_images_sentences_words/furniture/ADARI_v2_furniture_images_sentences.json"
# dset_sentences_POS_p = "../ADARI/json_files/ADARI_images_sentences_words/furniture/ADARI_v2_furniture_images_sentences_tokenized.json"

In [15]:
# Open json files with embeddings 
image_embeddings = open_json(img_embds_name_p)
dataset_labels = open_json(dset_words_p)
labels_embeddings = open_json(word_embeddings_path)

### Transform dictionary of ordered labels to list of labels 

In [16]:
glove_path = '../json_files/glove.6B.50d.txt'
with io.open(glove_path, 'r', encoding='utf8') as f:    
    glove_file = f.read()
    
glove_sentences = glove_file.splitlines()
glove_vocab = {}
for sentence in glove_sentences:
    word = sentence.split()[0]
    embedding = np.array(sentence.split()[1:], dtype = float)
    glove_vocab[word] = embedding

### nasty temporal vector for unknowns

In [17]:
with open(glove_path, 'r') as f:
    for i, line in enumerate(f):
        pass
n_vec = i + 1
hidden_dim = len(line.split(' ')) - 1

vecs = np.zeros((n_vec, hidden_dim), dtype=np.float32)

with open(glove_path, 'r') as f:
    for i, line in enumerate(f):
        vecs[i] = np.array([float(n) for n in line.split(' ')[1:]], dtype=np.float32)

AVG_VECTOR = np.mean(vecs, axis=0)

In [18]:
def is_date(word):
    rx = r"[0-9]+(?:st|[nr]d|th)"
    if re.findall(rx, word, flags=re.I) != []:
        return True
    return False

def labels_dict2list(dset_words):
    # The 2 dictionaries below for dataset dataloader
    im2idx = dict()
    im_words = dict()

    # Temp lists 
    image_names = list(dset_words.keys())
    words = list(dset_words.values())

    # Iterate over length of dictionary and get im2idx and im_words 
    for i in range(len(image_names)):
        im = image_names[i]
        words_list = flatten(list(words[i].values()))
        cleaned_w = []
        for w in words_list:
            if w != '"the' and w != '"The' and len(w) > 1 and is_date(w) != True:
                cleaned_w.append(w)

#         im_words[im] = cleaned_w
        im_words[im] = list(set(cleaned_w))
        im2idx[im] = i
    return im_words, im2idx

def create_vocab(dataset_labels):
    """
    We have 17532 images and a total of 707852 adjectives, so average of 40 words per image
    We have a total of 4786 unique words. This is our vocabulary size
    """
    # 1) Convert raw dataset (dictionary of ordered labels per image) to list of labels
    dset_im_words, _ = labels_dict2list(dataset_labels)
    
    # 2) Flatten the list 
    all_words = list(dset_im_words.values())
    flat_list = []
    for sublist in all_words:
        for item in sublist:
            flat_list.append(item)
    
    # 3) Get set of unique words = vocabulary
    unique_words = set(flat_list)
    
    # 4) Get dicitonary to map idx to words and viceversa
    words2idx = dict()
    idx2words = dict()
    
    set2list = list(unique_words)
    for i in range(len(set2list)):
        w = set2list[i]
        words2idx[w] = i
        idx2words[i] = w
        
    return dset_im_words, words2idx, idx2words

In [19]:
dset_im_words, vocab2idx, idx2vocab = create_vocab(dataset_labels)

### Split dataset into train, validation and test

In [20]:
def splitDict(d_img_words, d_img_embs, percent, val_number):

    val_n = val_number
    train_test_size = len(d_img_words) - val_n
    train_n = int(train_test_size*percent)
    test_n = train_test_size - train_n
    
    im_words = iter(d_img_words.items())      
    im_embs = iter(d_img_embs.items())
    
    # Image - words
    dtrain_imw = dict(itertools.islice(im_words, train_n))  
    dtest_imw = dict(itertools.islice(im_words, test_n))   
    dval_imw = dict(itertools.islice(im_words, val_n))
    
    
    print('trainset size: ', len(dtrain_imw), 'dataset size: ',len(dtest_imw), 'val set size: ', len(dval_imw))
    return dtrain_imw, dtest_imw, dval_imw 

In [21]:
dtrain_w, dtest_w, dval_w = splitDict(dset_im_words,image_embeddings, .95, 500)

trainset size:  16180 dataset size:  852 val set size:  500


In [22]:
img_size = 64
class ADARIdataset(Dataset):
    """
    Receives images and labels.
    Returns tensor image and tensor labels
    """
    def __init__(self, data_labels, word_embeddings, image_embeddings, vocab2idx, idx2vocab, img_path):

        self.labels_data = data_labels # dictionary of images -> labels
        self.word_embeds = word_embeddings
        
        self.images_names = list(self.labels_data.keys())    # names
        self.images_embeds = list(image_embeddings.values()) # values
        
        self.vocab2idx = vocab2idx
        self.idx2vocab = idx2vocab
        
        self.image_path = img_path
        
    def __len__(self):
        return len(self.images_names)
    
    def name2idx(self):
        self.name2idx = dict()
        self.idx2name = dict()
        for i, key in enumerate(self.images_names.keys()):
            self.name2idx[key] = i
            self.idx2name[i] = key
        
    def get_image_tensor(self, image_name):
        """
        Gets image name and returns a tensor
        """
        name = self.image_path + "/" + image_name
        img = Image.open(name)
        img = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor()])(img)
        
        return img
        
    def get_labels_embeddings_from_idx(self, idx):

        name_image = self.images_names[idx]
        labels = self.labels_data[name_image]
        labels = list(set(labels))
        
        # Set random distribution for setting a max number of labels
        if len(labels) > 20:
            labels = np.random.choice(labels, 20)
        
        # Get positive and negative labels
        all_idx = list(self.vocab2idx.values())
        pos_idxs = []
        # Remove indexes that correspond to the positive labels
        for l in labels:
            v2i = self.vocab2idx[l]
            pos_idxs.append(v2i)
            if v2i in all_idx:
                all_idx.remove(v2i)
        
        # Choose random labels as negative samples -> this can be improved with info about distance of labels
        neg_idxs = np.random.choice(all_idx, len(labels))
        
        neg_samples = []
        for n in neg_idxs:
            neg_samples.append(self.idx2vocab[n])
        
        assert(len(labels) == len(neg_samples))
        pos_w_embs = []
        neg_w_embs = []
        
        # positive
        for l in labels:
            try:
                pos_w_embs.append(self.word_embeds[l.lower()]) # appending 50 vector embedding
            except:
                try:
                    pos_w_embs.append(glove_vocab[l.lower()])
                except:
                    pos_w_embs.append(AVG_VECTOR)
        # negative
        for nl in neg_samples:
            try:
                neg_w_embs.append(self.word_embeds[nl.lower()]) # appending 50 vector embedding
            except:
                try:
                    neg_w_embs.append(glove_vocab[nl.lower()])
                except:
                    neg_w_embs.append(AVG_VECTOR)
                    
                    
        return pos_w_embs, neg_w_embs
    
    def __getitem__(self, index):
        """
        Return tensor image and label embedding
        """
        name_image = self.images_names[index]
        img = self.get_image_tensor(name_image)
        
        #image_emb = self.images_embeds[index] # list size 2048 
        pos_label_embs, neg_label_embs = self.get_labels_embeddings_from_idx(index) # list size variable 
        
        return img, pos_label_embs, neg_label_embs
 

def collate(sequence):
    """
    "the input of this function is the output of function __getitem__"
    "this gets BATCH_SIZE times GETITEM! "
    if batch_Size == 2 --> sequence is a list with length 2. 
    Each list is a tuple (image_embedding, labels_embedding) = (2048 vector, list of vectors size 50)
    Pad labels with maximum from batch
    """
    
    # Concatenate all images in the batch
    # For images (not embeddigns)
    images = torch.cat(([torch.FloatTensor(batch[0]).view(-1, 3, 64, 64) for batch in sequence]), dim=0)
    # For images embeddings
    #images = torch.cat(([torch.FloatTensor(batch[0]).view(-1, 2048) for batch in sequence]), dim=0)
    
    # Pad labels with max_sequence_label
    # batch 1 is batch * word embedding
    pos_labels = pad_sequence([torch.FloatTensor(batch[1]) for batch in sequence], batch_first=True)
    labels_length = torch.LongTensor([len(batch[1]) for batch in sequence])     

    neg_labels = pad_sequence([torch.FloatTensor(batch[2]) for batch in sequence], batch_first=True)
    
    return images, pos_labels, neg_labels, labels_length

In [23]:
dataset_train = ADARIdataset(dtrain_w, labels_embeddings, image_embeddings, vocab2idx, idx2vocab, im_path_fur)
dataset_test = ADARIdataset(dtest_w, labels_embeddings, image_embeddings, vocab2idx, idx2vocab, im_path_fur)
dataset_val = ADARIdataset(dval_w, labels_embeddings, image_embeddings, vocab2idx, idx2vocab, im_path_fur)

In [24]:
batch_size = 64
num_workers = 8 if cuda else 0

In [25]:
train_dataloader = DataLoader(dataset_train, batch_size=batch_size, collate_fn = collate, shuffle=True, num_workers=num_workers, drop_last=False)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size, collate_fn = collate,shuffle=False, num_workers=num_workers, drop_last=False)
val_dataloader = DataLoader(dataset_val, batch_size=batch_size, collate_fn = collate,shuffle=False, num_workers=num_workers, drop_last=False)

In [42]:
# to test dataloader
it = iter(test_dataloader)

min_ = 10000
for i in range(len(test_dataloader)):
    first = next(it)
    if first[1].shape[1] < min_:
        min_ = first[1].shape[1]
print(min_)
   # print(second[0].shape, second[1].shape, second[2])

20


### Lstm

In [146]:
class LSTM(nn.Module):
    """
    The only combination I get to work is using the output of the LSTM, not the hiddens. 
    Can be either 1 direction or 2.
    
    """
    def __init__(self, embedding_dim, hidden_dim, n_layers, bidirectional, dropout):
        super(LSTM, self).__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.bi = bidirectional
        # lstm layer
        self.lstm = nn.LSTM(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           dropout=dropout,
                           batch_first=True)
        
    def forward(self, labels, labels_lengths):

        batch_size = labels.shape[0]
        # --------------------------------   SHAPES -----------------------------------------------------
        # labels ----------------------> [batch, max_len, 50] they come padded to go through the data loader

        # embedded_label --------------> [batch size, max_length, emb dim] (64 dimensions for each of the 5 labels)
        # packed_embdded --------------> [XXXX, emb dimension]
        # out lstm shape --------------> [seq_len, batch_size, hidden_dim * directions (2 if bidirectional else 1)])
        # Hidden lstm shape -----------> [batch_size, hidden_size*num_layers])
        # Cell lstm shape -------------> [directions * layers, batch_size, hidden_size])
        # REturn ----------------------> [seq_len * batch, vocab size]
        # Using pad packed ------------> [seq_len, batch, vocab size]

        # packed sequence
        packed_embedded = nn.utils.rnn.pack_padded_sequence(labels, labels_lengths, 
                                                            enforce_sorted=False, batch_first=True)

        # Lstm returns packed output 
        packed_output, (hidden, cell) = self.lstm(packed_embedded) # hidden shape[num_layers*num_directions, batch, 50]

        last_hidden = hidden[-1, :, :] # [batch, emb_dim=50]

        
        output, lens = nn.utils.rnn.pad_packed_sequence(packed_output)
        # output shape: [20, 64, 50]
        
        return last_hidden # last layer of hidden 
    

### Image embedding

In [68]:
feature_extract = False # so it fine tunes
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

In [72]:
def initialize_model(num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    model_ft = models.resnet152(pretrained=use_pretrained)
    model_ft = torch.nn.Sequential(*(list(model_ft.children())[:-1]))
    set_parameter_requires_grad(model_ft, feature_extract)
    #num_ftrs = model_ft.fc.in_features
#     model_ft.fc = nn.Linear(num_ftrs, num_classes)
#     input_size = 64
    
    return model_ft

model_ft = initialize_model(50, feature_extract)
model_ft

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [147]:
class Wim(nn.Module):
    def __init__(self, embed_size):
        super(Wim, self).__init__()   
        
        self.fc = nn.Linear(2048, embed_size)
    
    def forward(self, im):
        return self.fc(im)

In [205]:
def loss_devise(image_vector, Wim, label_true_vectors, label_random_vectors):
    
    image_v = normalize_vec(image_vector)         # [batch, 2048]
    t_label = normalize_vec(label_true_vectors)   # [batch, 50]
    t_j     = normalize_vec(label_random_vectors) # [batch, 50]
    
    W = Wim.fc.weight                              # [50, 2048]

    #x = torch.mm(image_v, W.T) # [batch, 2048] x [2048, 50] -> [batch, 50] 
    x = Wim(image_v) # [batch, 50]
    s_x_v = torch.mm(x, t_label.T) # [batch, 50] x [50, batch] -> [batch, batch]
    s_x_vj = torch.mm(x, t_j.T) # [batch, 50] x [50, batch] -> [batch, batch]
    
    batch = image_v.shape[0]
    # test with loop
    loss = 0.0
    for b in range(batch):
        sxv_sum = s_x_v[b, :].sum()
        sxvj_sum = s_x_vj[b, :].sum()
        c = 0.2 - sxv_sum + sxvj_sum
        loss_ind = torch.max(torch.zeros_like(c), c)
        loss += loss_ind
        
    # end test with loop
#     c = 0.2 - s_x_v + s_x_vj  
#     print(c)
#     loss = torch.max(torch.zeros_like(c), c)
#     loss = loss.sum()

    return loss 

In [206]:
def loss_cosine_similarity(image_vector, label_true_vectors, label_random_vectors):

    v_image = image_vector         # [batch, 50]
    t_label = label_true_vectors   # [batch, 50]
    t_j     = label_random_vectors # [batch, 50]
    
    # Cosine similarity
    d_true = F.cosine_similarity(v_image, t_label) #[batch]
    d_contrast = F.cosine_similarity(v_image, t_j)
    
    c = 0.2 - d_true + d_contrast
    loss = torch.max(torch.zeros_like(c), c)
    loss = loss.sum()

    return loss 

In [207]:
def normalize_vec(vec):
    norm = vec.norm(p=2, dim=1, keepdim=True)
    vec_norm = vec.div(norm)
    return vec_norm

In [208]:
def train_epoch(img_cnn, wim, lstm, train_loader, optimizer):
    img_cnn.train()
    lstm.train()
    wim.train()

    running_loss = 0.0
    total_predictions = 0.0
    correct_predictions = 0.0
    
    predictions = []
    ground_truth = []
    
    start_time = time.time()
    for batch_idx, (data, pos_labels, neg_labels, length) in enumerate(train_loader):   
        optimizer.zero_grad()   # .backward() accumulates gradients
        data = data.to(device)
        length = length.to(device)
        
        # Image representation 
        im_repres = img_cnn(data).squeeze(3).squeeze(2) # [batch, 2048]
        
        pos_labels = pos_labels.to(device) # [batch, max_len, 50]
        neg_labels = neg_labels.to(device) # [batch, max_len, 50]
        
        # Lstm
        pos_labels_vec = lstm(pos_labels, length) # [batch, 50]
        neg_labels_vec = lstm(neg_labels, length) # [batch, 50]

        
        # Loss 
        loss = loss_devise(im_repres, wim, pos_labels_vec, neg_labels_vec)

        if batch_idx % 20 == 0 and batch_idx != 0:
            
            print('loss: ', loss.item())
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    end_time = time.time()
    
    loss_epoch = running_loss / len(train_loader)
    print('------ Training -----')
    return loss_epoch

In [209]:
def test_epoch(img_cnn, wim, lstm, test_loader):
    img_cnn.eval()
    lstm.eval()
    wim.eval()
    
    running_loss = 0.0
    total_predictions = 0.0
    correct_predictions = 0.0
    
    predictions = []
    ground_truth = []

    with torch.no_grad():
        for batch_idx, (data, pos_labels, neg_labels, length) in enumerate(test_loader):   
            data = data.to(device)
            length = length.to(device)
            
            # Image representation 
            im_repres = img_cnn(data).squeeze(3).squeeze(2) # [batch, 2048]

            pos_labels = pos_labels.to(device) # [batch, max_len, 50]
            neg_labels = neg_labels.to(device) # [batch, max_len, 50]

            # Lstm
            pos_labels_vec = lstm(pos_labels, length) # [batch, 50]
            neg_labels_vec = lstm(neg_labels, length) # [batch, 50]

            # Loss 
            loss = loss_devise(im_repres, wim, pos_labels_vec, neg_labels_vec)

            if batch_idx % 20 == 0 and batch_idx != 0:
                print('loss: ', loss.item())

            running_loss += loss.item()
    
    loss_epoch = running_loss / len(test_loader)
    print('------ Testing -----')
    return loss_epoch

In [210]:
# MODELS 
device = torch.device("cuda" if cuda else "cpu")
image_CNN = model_ft.to(device)
W_image = Wim(50)
W_image = W_image.to(device)

words_lstm = LSTM(embedding_dim = 50, hidden_dim = 50, n_layers = 2,
                 bidirectional = True, dropout=0.1).to(device)

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.

params_to_update = image_CNN.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)


params = list(params_to_update) + list(W_image.parameters()) + list(words_lstm.parameters())
# optimizer = optim.Adam(model.parameters(), lr = 0.001)
#optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Observe that all parameters are being optimized
# optimizer = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
optimizer = optim.Adam(params, lr = 0.001)

Params to learn:
	 0.weight
	 1.weight
	 1.bias
	 4.0.conv1.weight
	 4.0.bn1.weight
	 4.0.bn1.bias
	 4.0.conv2.weight
	 4.0.bn2.weight
	 4.0.bn2.bias
	 4.0.conv3.weight
	 4.0.bn3.weight
	 4.0.bn3.bias
	 4.0.downsample.0.weight
	 4.0.downsample.1.weight
	 4.0.downsample.1.bias
	 4.1.conv1.weight
	 4.1.bn1.weight
	 4.1.bn1.bias
	 4.1.conv2.weight
	 4.1.bn2.weight
	 4.1.bn2.bias
	 4.1.conv3.weight
	 4.1.bn3.weight
	 4.1.bn3.bias
	 4.2.conv1.weight
	 4.2.bn1.weight
	 4.2.bn1.bias
	 4.2.conv2.weight
	 4.2.bn2.weight
	 4.2.bn2.bias
	 4.2.conv3.weight
	 4.2.bn3.weight
	 4.2.bn3.bias
	 5.0.conv1.weight
	 5.0.bn1.weight
	 5.0.bn1.bias
	 5.0.conv2.weight
	 5.0.bn2.weight
	 5.0.bn2.bias
	 5.0.conv3.weight
	 5.0.bn3.weight
	 5.0.bn3.bias
	 5.0.downsample.0.weight
	 5.0.downsample.1.weight
	 5.0.downsample.1.bias
	 5.1.conv1.weight
	 5.1.bn1.weight
	 5.1.bn1.bias
	 5.1.conv2.weight
	 5.1.bn2.weight
	 5.1.bn2.bias
	 5.1.conv3.weight
	 5.1.bn3.weight
	 5.1.bn3.bias
	 5.2.conv1.weight
	 5.2.bn1.weight

In [211]:
import gc
del dtrain_w, dtest_w, dval_w, dset_im_words, vocab2idx, idx2vocab
del image_embeddings
del dataset_labels 
del labels_embeddings 
del glove_vocab
gc.collect()

NameError: name 'dtrain_w' is not defined

In [212]:
train_losses = []
test_losses = []
save_model_path = './saved_models/unifying/'
for i in range(20):
    print('-----Training epoch {}/{} --------'.format(i,9))
    tr_loss = train_epoch(image_CNN, W_image, words_lstm, train_dataloader, optimizer)
    print('train epoch: {}, loss: {}'.format(i, tr_loss))
    print()
    print('------Testing epoch {}/{} --------'.format(i,9))
    tst_loss = test_epoch(image_CNN, test_dataloader)
    print('test epoch: {}, loss: {}'.format(i, tst_loss))
    
    train_losses.append(tr_loss)
    test_losses.append(tst_loss)
    
    save_path_im = save_model_path + 'cnnim_{}.pt'.format(i)
    torch.save(image_CNN, save_path_im)


-----Training epoch 0/9 --------
loss:  0.0
loss:  0.0
loss:  0.0
loss:  0.0
loss:  0.0
loss:  0.0
loss:  0.0


KeyboardInterrupt: 