In [1]:
from __future__ import division
from __future__ import print_function

import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import sys
sys.path.append("/home/marta/jku/SBNet/ssnet_fop")

import pandas as pd
from scipy import random
from sklearn import preprocessing
# import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn

from tqdm import tqdm
from retrieval_model import FOP


In [2]:
torch.cuda.is_available()

True

In [3]:
data_folder = '/home/marta/jku/LLaVA/mmimdb'

In [4]:
texts_folder = os.path.join(data_folder, 'llava_encoded_texts')

train_text_df = os.path.join(texts_folder, 'llava_plot_first_latent_train.csv')
test_text_df = os.path.join(texts_folder, 'llava_plot_first_latent_test.csv')

images_folder = os.path.join(data_folder, 'llava_encoded_images')

train_image_df = os.path.join(images_folder, 'llava_images_latent_train.csv')
test_image_df = os.path.join(images_folder, 'llava_images_latent_test.csv')

labels = ['action', 'adult', 'adventure', 'animation', 'biography', 'comedy',
       'crime', 'documentary', 'drama', 'family', 'fantasy', 'film-noir',
       'history', 'horror', 'music', 'musical', 'mystery', 'news',
       'reality-tv', 'romance', 'sci-fi', 'short', 'sport', 'talk-show',
       'thriller', 'war', 'western']

In [10]:
images_folder

'/home/marta/jku/LLaVA/mmimdb/llava_encoded_images'

In [5]:
def sigmoid(x):
   return 1. / (1. + np.exp(-x))

In [6]:
def read_data(FLAGS):

    print('Split Type: %s'%(FLAGS.split_type))

    if FLAGS.split_type == 'text_only':
        print('Reading Text Train')
        train_file_text = train_text_df
        train_data = pd.read_csv(train_file_text, index_col='item_id')
        train_label = train_data[labels]
        train_data = train_data.drop(columns=labels)
        train_data = np.asarray(train_data)
        # Shuffle the data also if only one modality is used
        combined = list(zip(train_data, train_label))
        random.shuffle(combined)
        train_data, train_label = zip(*combined)

        return train_data, train_label

    elif FLAGS.split_type == 'image_only':
        print('Reading Image Train')
        train_file_image = train_image_df
        train_data = pd.read_csv(train_file_image, index_col='item_id')
        train_label = train_data[labels]
        train_data = train_data.drop(columns=labels)
        train_data = np.asarray(train_data)
        # Shuffle the data also if only one modality is used
        combined = list(zip(train_data, train_label))
        random.shuffle(combined)
        train_data, train_label = zip(*combined)

        return train_data, train_label

    train_data = []
    train_label = []

    train_file_face = '/share/hel/datasets/voxceleb/sbnet_feats/data/face/facenetfaceTrain.csv'
    train_file_voice = '/share/hel/datasets/voxceleb/sbnet_feats/data/voice/voiceTrain.csv'

    print('Reading Train Faces')
    img_train = pd.read_csv(train_file_face, header=None)
    train_tmp = img_train[512]
    img_train = np.asarray(img_train)
    img_train = img_train[:, :-1]

    train_tmp = np.asarray(train_tmp)
    train_tmp = train_tmp.reshape((train_tmp.shape[0], 1))
    print('Reading Train Voices')
    voice_train = pd.read_csv(train_file_voice, header=None)
    voice_train = np.asarray(voice_train)
    voice_train = voice_train[:, :-1]

    combined = list(zip(img_train, voice_train, train_tmp))
    # todo marta: why do we need to shuffle here?
    random.shuffle(combined)
    img_train, voice_train, train_tmp = zip(*combined)

    if FLAGS.split_type == 'random':
        # todo marta: aren't we doubling the dataset, like this?
        train_data = np.vstack((img_train, voice_train))
        train_label = np.vstack((train_tmp, train_tmp))
        combined = list(zip(train_data, train_label))
        random.shuffle(combined)
        train_data, train_label = zip(*combined)
        train_data = np.asarray(train_data).astype(np.float)
        train_label = np.asarray(train_label)

    elif FLAGS.split_type == 'vfvf':
        for i in range(len(voice_train)):
            train_data.append(voice_train[i])
            train_data.append(img_train[i])
            train_label.append(train_tmp[i])
            train_label.append(train_tmp[i])

    elif FLAGS.split_type == 'fvfv':
        for i in range(len(voice_train)):
            train_data.append(img_train[i])
            train_data.append(voice_train[i])
            train_label.append(train_tmp[i])
            train_label.append(train_tmp[i])

    elif FLAGS.split_type == 'hefhev':
        train_data = np.vstack((img_train, voice_train))
        train_label = np.vstack((train_tmp, train_tmp))

    elif FLAGS.split_type == 'hevhef':
        train_data = np.vstack((voice_train, img_train))
        train_label = np.vstack((train_tmp, train_tmp))

    else:
        print('Invalid Split Type')

    le = preprocessing.LabelEncoder()
    le.fit(train_label)
    train_label = le.transform(train_label)

    # print("Train file length", len(img_train))
    # print('Shuffling\n')

    train_data = np.asarray(train_data).astype(np.float)
    train_label = np.asarray(train_label)

    return train_data, train_label

def get_batch(batch_index, batch_size, labels, f_lst):
    start_ind = batch_index * batch_size
    end_ind = (batch_index + 1) * batch_size
    return np.asarray(f_lst[start_ind:end_ind]), np.asarray(labels[start_ind:end_ind])

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

def main(train_data, train_label):
    n_class = train_label.shape[1]
    model = FOP(FLAGS, train_data.shape[1], n_class)
    model.apply(init_weights)
    
    # ce_loss = nn.CrossEntropyLoss().cuda()
    bce_logits_loss = nn.BCEWithLogitsLoss()
    # We do not necessarily want orthogonal projection loss imo
    # opl_loss = OrthogonalProjectionLoss().cuda()
    opl_loss = None
    
    if FLAGS.cuda:
        model.cuda()
        # ce_loss.cuda()    
        bce_logits_loss.cuda()
        if opl_loss:
            opl_loss.cuda()
        cudnn.benchmark = True
    
    optimizer = optim.Adam(model.parameters(), lr=FLAGS.lr, weight_decay=0.01)

    n_parameters = sum([p.data.nelement() for p in model.parameters()])
    print('  + Number of params: {}'.format(n_parameters))
    
    
    for alpha in FLAGS.alpha_list:
        epoch = 1
        num_of_batches = (len(train_label) // FLAGS.batch_size)
        loss_plot = []
        precision_list = []
        recall_list = []
        loss_per_epoch = 0
        s_fac_per_epoch = 0
        d_fac_per_epoch = 0
        txt_dir = 'output'
        save_dir = 'fc2_%s_%s_alpha_%0.2f'%(FLAGS.split_type, FLAGS.save_dir, alpha)
        txt = '%s/ce_opl_%03d_%0.2f.txt'%(txt_dir, FLAGS.max_num_epoch, alpha)
        
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        if not os.path.exists(txt_dir):
            os.makedirs(txt_dir)
        
        with open(txt,'w+') as f:
            f.write('EPOCH\tLOSS\tprecision\trecall\tS_FAC\tD_FAC\n')
        
        save_best = 'best_%s'%(save_dir)
        
        if not os.path.exists(save_best):
            os.mkdir(save_best)
        with open(txt,'a+') as f:
            while (epoch < FLAGS.max_num_epoch):
                print('%s\tEpoch %03d'%(FLAGS.split_type, epoch))
                for idx in tqdm(range(num_of_batches)):
                    train_batch, batch_labels = get_batch(idx, FLAGS.batch_size, train_label, train_data)
                    # voice_feats, _ = get_batch(idx, FLAGS.batch_size, train_label, voice_train)
                    loss_tmp, loss_opl, loss_soft, s_fac, d_fac = train(train_batch, 
                                                                 batch_labels, 
                                                                 model, optimizer, bce_logits_loss, opl_loss, alpha)
                    loss_per_epoch+=loss_tmp
                    s_fac_per_epoch+=s_fac
                    d_fac_per_epoch+=d_fac
                
                loss_per_epoch/=num_of_batches
                s_fac_per_epoch/=num_of_batches
                d_fac_per_epoch/=num_of_batches
                
                loss_plot.append(loss_per_epoch)
                # ToDo
                
                # if FLAGS.split_type == 'voice_only' or FLAGS.split_type == 'face_only':
                #     eer, auc = onlineTestSingleModality.test(FLAGS, model, test_feat)
                # else:
                #     eer, auc = online_evaluation.test(FLAGS, model, test_feat)
                # eer_list.append(eer)
                # auc_list.append(auc)
                # save_checkpoint({
                #    'epoch': epoch,
                #    'state_dict': model.state_dict()}, save_dir, 'checkpoint_%04d_%0.3f.pth.tar'%(epoch, eer*100))

#                 print('==> Epoch: %d/%d Loss: %0.2f Alpha:%0.2f, Min_EER: %0.2f'%(epoch, FLAGS.max_num_epoch, loss_per_epoch, alpha, min(eer_list)))
                
#                 if eer <= min(eer_list):
#                     min_eer = eer
#                     max_auc = auc
#                     save_checkpoint({
#                     'epoch': epoch,
#                     'state_dict': model.state_dict()}, save_best, 'checkpoint.pth.tar')
                # ToDo 
                eer, auc = 0., 0.
                f.write('%04d\t%0.4f\t%0.2f\t%0.2f\t%0.2f\t%0.2f\n'%(epoch, loss_per_epoch, eer, auc, s_fac_per_epoch, d_fac_per_epoch))
                loss_per_epoch = 0
                s_fac_per_epoch = 0
                d_fac_per_epoch = 0
                epoch += 1
        
        return loss_plot# , min_eer, max_auc                
#         return loss_plot, min_eer, max_auc
    
    
class OrthogonalProjectionLoss(nn.Module):
    def __init__(self):
        super(OrthogonalProjectionLoss, self).__init__()
        self.device = (torch.device('cuda') if FLAGS.cuda else torch.device('cpu'))

    def forward(self, features, labels=None):
        
        features = F.normalize(features, p=2, dim=1)

        labels = labels[:, None]

        mask = torch.eq(labels, labels.t()).bool().to(self.device)
        eye = torch.eye(mask.shape[0], mask.shape[1]).bool().to(self.device)

        mask_pos = mask.masked_fill(eye, 0).float()
        mask_neg = (~mask).float()
        dot_prod = torch.matmul(features, features.t())

        pos_pairs_mean = (mask_pos * dot_prod).sum() / (mask_pos.sum() + 1e-6)
        neg_pairs_mean = torch.abs(mask_neg * dot_prod).sum() / (mask_neg.sum() + 1e-6)

        loss = (1.0 - pos_pairs_mean) + (0.7 * neg_pairs_mean)

        return loss, pos_pairs_mean, neg_pairs_mean


def train(train_batch, labels, model, optimizer, bce_logits_loss, opl_loss, alpha):
    
    average_loss = RunningAverage()
    soft_losses = RunningAverage()
    if opl_loss:
        opl_losses = RunningAverage()

    model.train()
    # face_feats = torch.from_numpy(face_feats).float()
    train_batch = torch.from_numpy(train_batch).float()
    labels = torch.from_numpy(labels).float()
    
    if FLAGS.cuda:
        train_batch, labels = train_batch.cuda(), labels.cuda()

    train_batch, labels = Variable(train_batch), Variable(labels)
    comb = model.train_forward(train_batch)
    
    # loss_soft = ce_loss(comb[1], labels)
    loss_soft = bce_logits_loss(comb[1], labels)
    predictions = sigmoid(loss_soft)
    predictions = (predictions > 0.5).astype(int).reshape(-1)
    
    
    if opl_loss:
        loss_opl, s_fac, d_fac = opl_loss(comb[0], labels)
        loss = loss_soft + alpha * loss_opl
    else: 
        loss = loss_soft
        s_fac, d_fac = 0., 0.
        opl_losses = 0.

    optimizer.zero_grad()
    
    loss.backward()
    average_loss.update(loss.item())
    if opl_loss:
        opl_losses.update(loss_opl.item())
    soft_losses.update(loss_soft.item())
    
    optimizer.step()
    if opl_loss:
        return average_loss.avg(), opl_losses.avg(), soft_losses.avg(), s_fac, d_fac
    else:
        return average_loss.avg(), opl_losses, soft_losses.avg(), s_fac, d_fac

class RunningAverage(object):
    def __init__(self):
        self.value_sum = 0.
        self.num_items = 0. 

    def update(self, val):
        self.value_sum += val 
        self.num_items += 1

    def avg(self):
        average = 0.
        if self.num_items > 0:
            average = self.value_sum / self.num_items

        return average
 
def save_checkpoint(state, directory, filename):
    filename = os.path.join(directory, filename)
    torch.save(state, filename)
    

In [1]:
from sklearn.metrics import precision_score

In [2]:
# multilabel classification
y_true = [[0, 0, 0], [1, 1, 1], [0, 1, 1]]
y_pred = [[0, 0, 0], [1, 1, 1], [1, 1, 0]]

In [4]:
precision_score(y_true, y_pred, average=None)

array([0.5, 1. , 1. ])

In [11]:
precision_score(y_true, y_pred, average='da'
                                    )

ValueError: average has to be one of (None, 'micro', 'macro', 'weighted', 'samples')

In [7]:
global FLAGS

In [8]:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random Seed')
parser.add_argument('--cuda', action='store_true', default=True, help='CUDA Training')
parser.add_argument('--save_dir', type=str, default='model', help='Directory for saving checkpoints.')
parser.add_argument('--lr', type=float, default=1e-2, metavar='LR',
                    help='learning rate (default: 1e-4)') 
parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
parser.add_argument('--max_num_epoch', type=int, default=100, help='Max number of epochs to train, number')
parser.add_argument('--alpha_list', type=list, default=[1], help='Alpha Values List')
parser.add_argument('--dim_embed', type=int, default=256,
                    help='Embedding Size')
parser.add_argument('--split_type', type=str, default='image_only', help='split_type')

_StoreAction(option_strings=['--split_type'], dest='split_type', nargs=None, const=None, default='image_only', type=<class 'str'>, choices=None, help='split_type', metavar=None)

In [9]:
FLAGS, unparsed = parser.parse_known_args()

In [10]:
FLAGS

Namespace(alpha_list=[1], batch_size=128, cuda=True, dim_embed=256, lr=0.01, max_num_epoch=100, save_dir='model', seed=1, split_type='image_only')

In [11]:
train_data, train_label = read_data(FLAGS)

Split Type: image_only
Reading Image Train


In [12]:
print('Split Type: %s'%(FLAGS.split_type))

Split Type: image_only


In [13]:
train_data.shape, train_label.shape

((15552, 7168), (15552, 27))

In [14]:
losses = main(train_data, train_label)



  + Number of params: 1909019
image_only	Epoch 001


  1%|          | 1/121 [00:00<00:12,  9.91it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


  4%|▍         | 5/121 [00:00<00:04, 26.63it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


  8%|▊         | 10/121 [00:00<00:03, 34.69it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 12%|█▏        | 14/121 [00:00<00:02, 36.58it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 16%|█▌        | 19/121 [00:00<00:02, 39.71it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 20%|█▉        | 24/121 [00:00<00:02, 39.98it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 24%|██▍       | 29/121 [00:00<00:02, 40.49it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 28%|██▊       | 34/121 [00:00<00:02, 40.72it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 32%|███▏      | 39/121 [00:01<00:01, 42.19it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 40%|████      | 49/121 [00:01<00:01, 42.52it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 45%|████▍     | 54/121 [00:01<00:01, 42.85it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 49%|████▉     | 59/121 [00:01<00:01, 41.86it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 53%|█████▎    | 64/121 [00:01<00:01, 43.12it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])

 57%|█████▋    | 69/121 [00:01<00:01, 43.10it/s]


torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 61%|██████    | 74/121 [00:01<00:01, 43.15it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 65%|██████▌   | 79/121 [00:01<00:00, 43.04it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 69%|██████▉   | 84/121 [00:02<00:00, 42.72it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 74%|███████▎  | 89/121 [00:02<00:00, 42.82it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 78%|███████▊  | 94/121 [00:02<00:00, 43.06it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 86%|████████▌ | 104/121 [00:02<00:00, 45.44it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 90%|█████████ | 109/121 [00:02<00:00, 45.30it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 94%|█████████▍| 114/121 [00:02<00:00, 45.77it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


100%|██████████| 121/121 [00:02<00:00, 42.10it/s]


torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
image_only	Epoch 002


  0%|          | 0/121 [00:00<?, ?it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


  5%|▍         | 6/121 [00:00<00:02, 50.40it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 10%|▉         | 12/121 [00:00<00:02, 52.03it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 15%|█▍        | 18/121 [00:00<00:02, 51.30it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 20%|█▉        | 24/121 [00:00<00:01, 49.48it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 24%|██▍       | 29/121 [00:00<00:01, 47.11it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 28%|██▊       | 34/121 [00:00<00:01, 44.03it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27])

 32%|███▏      | 39/121 [00:00<00:01, 41.82it/s]

 torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 36%|███▋      | 44/121 [00:00<00:01, 42.16it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 40%|████      | 49/121 [00:01<00:01, 42.95it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 45%|████▍     | 54/121 [00:01<00:01, 44.37it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 49%|████▉     | 59/121 [00:01<00:01, 45.62it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 53%|█████▎    | 64/121 [00:01<00:01, 46.80it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 57%|█████▋    | 69/121 [00:01<00:01, 46.66it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 61%|██████    | 74/121 [00:01<00:01, 44.90it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 65%|██████▌   | 79/121 [00:01<00:00, 42.82it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 69%|██████▉   | 84/121 [00:01<00:00, 43.38it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 74%|███████▎  | 89/121 [00:01<00:00, 45.17it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 78%|███████▊  | 94/121 [00:02<00:00, 45.95it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 82%|████████▏ | 99/121 [00:02<00:00, 45.16it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 86%|████████▌ | 104/121 [00:02<00:00, 44.35it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 90%|█████████ | 109/121 [00:02<00:00, 43.95it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 94%|█████████▍| 114/121 [00:02<00:00, 44.24it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 98%|█████████▊| 119/121 [00:02<00:00, 45.54it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


100%|██████████| 121/121 [00:02<00:00, 45.21it/s]


image_only	Epoch 003


  0%|          | 0/121 [00:00<?, ?it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


  5%|▍         | 6/121 [00:00<00:02, 52.35it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 10%|▉         | 12/121 [00:00<00:02, 45.01it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 14%|█▍        | 17/121 [00:00<00:02, 42.86it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 18%|█▊        | 22/121 [00:00<00:02, 43.02it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 22%|██▏       | 27/121 [00:00<00:02, 41.82it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 26%|██▋       | 32/121 [00:00<00:02, 43.20it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27])

 31%|███       | 37/121 [00:00<00:01, 42.35it/s]

 torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 35%|███▍      | 42/121 [00:00<00:01, 43.36it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 39%|███▉      | 47/121 [00:01<00:01, 43.22it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 43%|████▎     | 52/121 [00:01<00:01, 43.35it/s]

torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])
torch.Size([128, 27]) torch.Size([128, 27])


 45%|████▌     | 55/121 [00:01<00:01, 42.71it/s]


KeyboardInterrupt: 