In [None]:
# -*- coding: utf-8 -*-
"""
create on Sep 24, 2019

@author: wangshuo
"""

import random
import pickle
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat

random.seed(1234)

workdir = '/content/drive/MyDrive/datasets/'

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='FilmTrust_SimRank', help='dataset name: FilmTrust_SimRank')
parser.add_argument('--test_prop', default=0.4, help='the proportion of data used for test')
args = parser.parse_args(args=[])

# load data
if args.dataset == 'Ciao':
	click_f = loadmat(workdir + 'Ciao/rating.mat')['rating']
	trust_f = loadmat(workdir + 'Ciao/trustnetwork.mat')['trustnetwork']
elif args.dataset == 'FilmTrust_SimRank':
	click_f = np.loadtxt(workdir+'FilmTrust_SimRank/ft_ratings.txt', dtype = np.int32)
	trust_f = np.loadtxt(workdir+'FilmTrust_SimRank/ft_trust_SimRank.txt', dtype = np.int32)
else:
	pass 

click_list = []
trust_list = []

u_items_list = []
u_users_list = []
u_users_items_list = []
i_users_list = []

user_count = 0
item_count = 0
rate_count = 0

for s in click_f:
	uid = s[0]
	iid = s[1]
	if args.dataset == 'Ciao':
		label = s[3]
	elif args.dataset == 'FilmTrust_SimRank':
		label = s[2]

	if uid > user_count:
		user_count = uid
	if iid > item_count:
		item_count = iid
	if label > rate_count:
		rate_count = label
	click_list.append([uid, iid, label])

pos_list = []
for i in range(len(click_list)):
	pos_list.append((click_list[i][0], click_list[i][1], click_list[i][2]))

# remove duplicate items in pos_list because there are some cases where a user may have different rate scores on the same item.
pos_list = list(set(pos_list))

# train, valid and test data split
random.shuffle(pos_list)
num_test = int(len(pos_list) * args.test_prop)
test_set = pos_list[:num_test]
valid_set = pos_list[num_test:2 * num_test]
train_set = pos_list[2 * num_test:]
print('Train samples: {}, Valid samples: {}, Test samples: {}'.format(len(train_set), len(valid_set), len(test_set)))

with open(workdir + args.dataset + '/dataset.pkl', 'wb') as f:
	pickle.dump(train_set, f, pickle.HIGHEST_PROTOCOL)
	pickle.dump(valid_set, f, pickle.HIGHEST_PROTOCOL)
	pickle.dump(test_set, f, pickle.HIGHEST_PROTOCOL)


train_df = pd.DataFrame(train_set, columns = ['uid', 'iid', 'label'])
valid_df = pd.DataFrame(valid_set, columns = ['uid', 'iid', 'label'])
test_df = pd.DataFrame(test_set, columns = ['uid', 'iid', 'label'])

click_df = pd.DataFrame(click_list, columns = ['uid', 'iid', 'label'])
train_df = train_df.sort_values(axis = 0, ascending = True, by = 'uid')

"""
u_items_list: 存储每个用户交互过的物品iid和对应的评分，没有则为[(0, 0)]
"""
for u in tqdm(range(user_count + 1)):
	hist = train_df[train_df['uid'] == u]
	u_items = hist['iid'].tolist()
	u_ratings = hist['label'].tolist()
	if u_items == []:
		u_items_list.append([(0, 0)])
	else:
		u_items_list.append([(iid, rating) for iid, rating in zip(u_items, u_ratings)])

train_df = train_df.sort_values(axis = 0, ascending = True, by = 'iid')

"""
i_users_list: 存储与每个物品相关联的用户及其评分，没有则为[(0, 0)]
"""
for i in tqdm(range(item_count + 1)):
	hist = train_df[train_df['iid'] == i]
	i_users = hist['uid'].tolist()
	i_ratings = hist['label'].tolist()
	if i_users == []:
		i_users_list.append([(0, 0)])
	else:
		i_users_list.append([(uid, rating) for uid, rating in zip(i_users, i_ratings)])

for s in trust_f:
	uid = s[0]
	fid = s[1]
	if uid > user_count or fid > user_count:
		continue
	trust_list.append([uid, fid])

trust_df = pd.DataFrame(trust_list, columns = ['uid', 'fid'])
trust_df = trust_df.sort_values(axis = 0, ascending = True, by = 'uid')


"""
u_users_list: 存储每个用户互动过的用户uid；
u_users_items_list: 存储用户每个朋友的物品iid列表
"""
for u in tqdm(range(user_count + 1)):
	hist = trust_df[trust_df['uid'] == u]
	u_users = hist['fid'].unique().tolist()
	if u_users == []:
		u_users_list.append([0])
		u_users_items_list.append([[(0,0)]])
	else:
		u_users_list.append(u_users)
		uu_items = []
		for uid in u_users:
			uu_items.append(u_items_list[uid])
		u_users_items_list.append(uu_items)
	
with open(workdir + args.dataset + '/list.pkl', 'wb') as f:
	pickle.dump(u_items_list, f, pickle.HIGHEST_PROTOCOL)
	pickle.dump(u_users_list, f, pickle.HIGHEST_PROTOCOL)
	pickle.dump(u_users_items_list, f, pickle.HIGHEST_PROTOCOL)
	pickle.dump(i_users_list, f, pickle.HIGHEST_PROTOCOL)
	pickle.dump((user_count, item_count, rate_count), f, pickle.HIGHEST_PROTOCOL)




Train samples: 14484, Valid samples: 28967, Test samples: 28967


100%|██████████| 17616/17616 [00:13<00:00, 1315.00it/s]
100%|██████████| 16122/16122 [00:11<00:00, 1369.62it/s]
100%|██████████| 17616/17616 [00:12<00:00, 1405.03it/s]


In [None]:
import numpy as np
import random
import torch
from torch.utils.data import Dataset

class GRDataset(Dataset):
	def __init__(self, data, u_items_list, u_users_list, u_users_items_list, i_users_list):
		self.data = data
		self.u_items_list = u_items_list
		self.u_users_list = u_users_list
		self.u_users_items_list = u_users_items_list
		self.i_users_list = i_users_list

	def __getitem__(self, index):
		uid = self.data[index][0]
		iid = self.data[index][1]
		label = self.data[index][2]
		u_items = self.u_items_list[uid]
		u_users = self.u_users_list[uid]
		u_users_items = self.u_users_items_list[uid]
		i_users = self.i_users_list[iid]

		return (uid, iid, label), u_items, u_users, u_users_items, i_users

	def __len__(self):
		return len(self.data)

In [None]:
import torch
import random

truncate_len = 30

"""
Ciao dataset info:
Avg number of items rated per user: 38.3
Avg number of users interacted per user: 2.7
Avg number of users connected per item: 16.4
"""

def collate_fn(batch_data):
    """This function will be used to pad the graph to max length in the batch
       It will be used in the Dataloader
    """
    uids, iids, labels = [], [], []
    u_items, u_users, u_users_items, i_users = [], [], [], []
    u_items_len, u_users_len, i_users_len = [], [], []

    for data, u_items_u, u_users_u, u_users_items_u, i_users_i in batch_data:

        (uid, iid, label) = data
        uids.append(uid)
        iids.append(iid)
        labels.append(label)

        # user-items    
        if len(u_items_u) <= truncate_len:
            u_items.append(u_items_u)
        else:
            u_items.append(random.sample(u_items_u, truncate_len))
        u_items_len.append(min(len(u_items_u), truncate_len))
        
        # user-users and user-users-items
        if len(u_users_u) <= truncate_len:
            u_users.append(u_users_u)
            u_u_items = [] 
            for uui in u_users_items_u:
                if len(uui) < truncate_len:
                    u_u_items.append(uui)
                else:
                    u_u_items.append(random.sample(uui, truncate_len))
            u_users_items.append(u_u_items)
        else:
            sample_index = random.sample(list(range(len(u_users_u))), truncate_len)
            u_users.append([u_users_u[si] for si in sample_index])

            u_users_items_u_tr = [u_users_items_u[si] for si in sample_index]
            u_u_items = [] 
            for uui in u_users_items_u_tr:
                if len(uui) < truncate_len:
                    u_u_items.append(uui)
                else:
                    u_u_items.append(random.sample(uui, truncate_len))
            u_users_items.append(u_u_items)

        u_users_len.append(min(len(u_users_u), truncate_len))	

        # item-users
        if len(i_users_i) <= truncate_len:
            i_users.append(i_users_i)
        else:
            i_users.append(random.sample(i_users_i, truncate_len))
        i_users_len.append(min(len(i_users_i), truncate_len))

    batch_size = len(batch_data)

    # padding
    u_items_maxlen = max(u_items_len)
    u_users_maxlen = max(u_users_len)
    i_users_maxlen = max(i_users_len)
    
    u_item_pad = torch.zeros([batch_size, u_items_maxlen, 2], dtype=torch.long)
    for i, ui in enumerate(u_items):
        u_item_pad[i, :len(ui), :] = torch.LongTensor(ui)
    
    u_user_pad = torch.zeros([batch_size, u_users_maxlen], dtype=torch.long)
    for i, uu in enumerate(u_users):
        u_user_pad[i, :len(uu)] = torch.LongTensor(uu)
    
    u_user_item_pad = torch.zeros([batch_size, u_users_maxlen, u_items_maxlen, 2], dtype=torch.long)
    for i, uu_items in enumerate(u_users_items):
        for j, ui in enumerate(uu_items):
            u_user_item_pad[i, j, :len(ui), :] = torch.LongTensor(ui)

    i_user_pad = torch.zeros([batch_size, i_users_maxlen, 2], dtype=torch.long)
    for i, iu in enumerate(i_users):
        i_user_pad[i, :len(iu), :] = torch.LongTensor(iu)

    return torch.LongTensor(uids), torch.LongTensor(iids), torch.FloatTensor(labels), \
            u_item_pad, u_user_pad, u_user_item_pad, i_user_pad

In [None]:
from torch import nn
import torch

class _MultiLayerPercep(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(_MultiLayerPercep, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2, bias=True),
            nn.ReLU(),            
            nn.Linear(input_dim // 2, output_dim, bias=True),
        )

    def forward(self, x):
        return self.mlp(x)


class _Aggregation(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(_Aggregation, self).__init__()
        self.aggre = nn.Sequential(
            nn.Linear(input_dim, output_dim, bias=True),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.aggre(x)


class _UserModel(nn.Module):
    ''' User modeling to learn user latent factors.
    User modeling leverages two types aggregation: item aggregation and social aggregation
    '''
    def __init__(self, emb_dim, user_emb, item_emb, rate_emb):
        super(_UserModel, self).__init__()
        self.user_emb = user_emb
        self.item_emb = item_emb
        self.rate_emb = rate_emb
        self.emb_dim = emb_dim

        self.g_v = _MultiLayerPercep(2 * self.emb_dim, self.emb_dim)

        self.user_items_att = _MultiLayerPercep(2 * self.emb_dim, 1)
        self.aggre_items = _Aggregation(self.emb_dim, self.emb_dim)

        self.user_users_att = _MultiLayerPercep(2 * self.emb_dim, 1)
        self.aggre_neigbors = _Aggregation(self.emb_dim, self.emb_dim)
        
        self.combine_mlp = nn.Sequential(
            nn.Linear(2 * self.emb_dim, self.emb_dim, bias = True),
            nn.ReLU(),
            nn.Linear(self.emb_dim, self.emb_dim, bias = True),
            nn.ReLU(),
            nn.Linear(self.emb_dim, self.emb_dim, bias = True),
            nn.ReLU(),
        )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # used for preventing zero div error when calculating softmax score
        self.eps = 1e-10

    def forward(self, uids, u_item_pad, u_user_pad, u_user_item_pad):
        # item aggregation
        q_a = self.item_emb(u_item_pad[:,:,0])   # B x maxi_len x emb_dim
        mask_u = torch.where(u_item_pad[:,:,0] > 0, torch.tensor([1.], device=self.device), torch.tensor([0.], device=self.device))   # B x maxi_len
        u_item_er = self.rate_emb(u_item_pad[:,:,1])  # B x maxi_len x emb_dim
        
        x_ia = self.g_v(torch.cat([q_a, u_item_er], dim = 2).view(-1, 2 * self.emb_dim)).view(q_a.size())  # B x maxi_len x emb_dim

        ## calculate attention scores in item aggregation 
        p_i = mask_u.unsqueeze(2).expand_as(x_ia) * self.user_emb(uids).unsqueeze(1).expand_as(x_ia)  # B x maxi_len x emb_dim
        
        alpha = self.user_items_att(torch.cat([x_ia, p_i], dim = 2).view(-1, 2 * self.emb_dim)).view(mask_u.size()) # B x maxi_len
        alpha = torch.exp(alpha) * mask_u
        alpha = alpha / (torch.sum(alpha, 1).unsqueeze(1).expand_as(alpha) + self.eps)

        h_iI = self.aggre_items(torch.sum(alpha.unsqueeze(2).expand_as(x_ia) * x_ia, 1))     # B x emb_dim

        # social aggregation
        q_a_s = self.item_emb(u_user_item_pad[:,:,:,0])   # B x maxu_len x maxi_len x emb_dim
        mask_s = torch.where(u_user_item_pad[:,:,:,0] > 0, torch.tensor([1.], device=self.device), torch.tensor([0.], device=self.device))  # B x maxu_len x maxi_len
        u_user_item_er = self.rate_emb(u_user_item_pad[:,:,:,1]) # B x maxu_len x maxi_len x emb_dim
        
        x_ia_s = self.g_v(torch.cat([q_a_s, u_user_item_er], dim = 3).view(-1, 2 * self.emb_dim)).view(q_a_s.size())  # B x maxu_len x maxi_len x emb_dim   

        p_i_s = mask_s.unsqueeze(3).expand_as(x_ia_s) * self.user_emb(u_user_pad).unsqueeze(2).expand_as(x_ia_s)  # B x maxu_len x maxi_len x emb_dim

        alpha_s = self.user_items_att(torch.cat([x_ia_s, p_i_s], dim = 3).view(-1, 2 * self.emb_dim)).view(mask_s.size())    # B x maxu_len x maxi_len
        alpha_s = torch.exp(alpha_s) * mask_s
        alpha_s = alpha_s / (torch.sum(alpha_s, 2).unsqueeze(2).expand_as(alpha_s) + self.eps)

        h_oI_temp = torch.sum(alpha_s.unsqueeze(3).expand_as(x_ia_s) * x_ia_s, 2)    # B x maxu_len x emb_dim
        h_oI = self.aggre_items(h_oI_temp.view(-1, self.emb_dim)).view(h_oI_temp.size())  # B x maxu_len x emb_dim

        ## calculate attention scores in social aggregation
        beta = self.user_users_att(torch.cat([h_oI, self.user_emb(u_user_pad)], dim = 2).view(-1, 2 * self.emb_dim)).view(u_user_pad.size())
        mask_su = torch.where(u_user_pad > 0, torch.tensor([1.], device=self.device), torch.tensor([0.], device=self.device))
        beta = torch.exp(beta) * mask_su
        beta = beta / (torch.sum(beta, 1).unsqueeze(1).expand_as(beta) + self.eps)
        h_iS = self.aggre_neigbors(torch.sum(beta.unsqueeze(2).expand_as(h_oI) * h_oI, 1))     # B x emb_dim

        ## learning user latent factor
        h_i = self.combine_mlp(torch.cat([h_iI, h_iS], dim = 1))

        return h_i


class _ItemModel(nn.Module):
    '''Item modeling to learn item latent factors.
    '''
    def __init__(self, emb_dim, user_emb, item_emb, rate_emb):
        super(_ItemModel, self).__init__()
        self.emb_dim = emb_dim
        self.user_emb = user_emb
        self.item_emb = item_emb
        self.rate_emb = rate_emb

        self.g_u = _MultiLayerPercep(2 * self.emb_dim, self.emb_dim)
        
        self.item_users_att = _MultiLayerPercep(2 * self.emb_dim, 1)
        self.aggre_users = _Aggregation(self.emb_dim, self.emb_dim)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # used for preventing zero div error when calculating softmax score
        self.eps = 1e-10

    def forward(self, iids, i_user_pad):
        # user aggregation
        p_t = self.user_emb(i_user_pad[:,:,0])
        mask_i = torch.where(i_user_pad[:,:,0] > 0, torch.tensor([1.], device=self.device), torch.tensor([0.], device=self.device))
        i_user_er = self.rate_emb(i_user_pad[:,:,1])
        
        f_jt = self.g_u(torch.cat([p_t, i_user_er], dim = 2).view(-1, 2 * self.emb_dim)).view(p_t.size())
        
        # calculate attention scores in user aggregation
        q_j = mask_i.unsqueeze(2).expand_as(f_jt) * self.item_emb(iids).unsqueeze(1).expand_as(f_jt)
        
        miu = self.item_users_att(torch.cat([f_jt, q_j], dim = 2).view(-1, 2 * self.emb_dim)).view(mask_i.size())
        miu = torch.exp(miu) * mask_i
        miu = miu / (torch.sum(miu, 1).unsqueeze(1).expand_as(miu) + self.eps)
        
        z_j = self.aggre_users(torch.sum(miu.unsqueeze(2).expand_as(f_jt) * f_jt, 1))

        return z_j


class GraphRec(nn.Module):
    '''GraphRec model proposed in the paper Graph neural network for social recommendation 

    Args:
        number_users: the number of users in the dataset.
        number_items: the number of items in the dataset.
        num_rate_levels: the number of rate levels in the dataset.
        emb_dim: the dimension of user and item embedding (default = 64).

    '''
    def __init__(self, num_users, num_items, num_rate_levels, emb_dim = 64):
        super(GraphRec, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_rate_levels = num_rate_levels
        self.emb_dim = emb_dim
        self.user_emb = nn.Embedding(self.num_users, self.emb_dim, padding_idx = 0)
        self.item_emb = nn.Embedding(self.num_items, self.emb_dim, padding_idx = 0)
        self.rate_emb = nn.Embedding(self.num_rate_levels, self.emb_dim, padding_idx = 0)

        self.user_model = _UserModel(self.emb_dim, self.user_emb, self.item_emb, self.rate_emb)

        self.item_model = _ItemModel(self.emb_dim, self.user_emb, self.item_emb, self.rate_emb)
        
        self.rate_pred = nn.Sequential(
            nn.Linear(2 * self.emb_dim, self.emb_dim, bias = True),
            nn.ReLU(),
            nn.Linear(self.emb_dim, self.emb_dim, bias = True),
            nn.ReLU(),
            nn.Linear(self.emb_dim, 1),
        )


    def forward(self, uids, iids, u_item_pad, u_user_pad, u_user_item_pad, i_user_pad):
        '''
        Args:
            uids: the user id sequences.
            iids: the item id sequences.
            u_item_pad: the padded user-item graph.
            u_user_pad: the padded user-user graph.
            u_user_item_pad: the padded user-user-item graph.
            i_user_pad: the padded item-user graph.

        Shapes:
            uids: (B).
            iids: (B).
            u_item_pad: (B, ItemSeqMaxLen, 2).
            u_user_pad: (B, UserSeqMaxLen).
            u_user_item_pad: (B, UserSeqMaxLen, ItemSeqMaxLen, 2).
            i_user_pad: (B, UserSeqMaxLen, 2).

        Returns:
            the predicted rate scores of the user to the item.
        '''

        h_i = self.user_model(uids, u_item_pad, u_user_pad, u_user_item_pad)
        z_j = self.item_model(iids, i_user_pad)

        # make prediction
        r_ij = self.rate_pred(torch.cat([h_i, z_j], dim = 1))
        return r_ij
       

In [None]:
#!/usr/bin/env python37
# -*- coding: utf-8 -*-
"""
Created on 30 Sep, 2019

@author: wangshuo
"""

import os
import time
import argparse
import pickle
import numpy as np
import random
from tqdm import tqdm
from os.path import join

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.autograd import Variable
from torch.backends import cudnn



parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', default='/content/drive/MyDrive/datasets/FilmTrust_SimRank/', help='dataset directory path: datasets/FilmTrust_SimRank')
parser.add_argument('--batch_size', type=int, default=256, help='input batch size')
parser.add_argument('--embed_dim', type=int, default=64, help='the dimension of embedding')
parser.add_argument('--epoch', type=int, default=30, help='the number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')  # [0.001, 0.0005, 0.0001]
parser.add_argument('--lr_dc', type=float, default=0.1, help='learning rate decay rate')
parser.add_argument('--lr_dc_step', type=int, default=30, help='the number of steps after which the learning rate decay')
parser.add_argument('--test', action='store_true', help='test')
args = parser.parse_args(args=[])
print(args)

#here = os.path.dirname(os.path.abspath(__file__))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main():
    print('Loading data...')
    with open(args.dataset_path + 'dataset.pkl', 'rb') as f:
        train_set = pickle.load(f)
        valid_set = pickle.load(f)
        test_set = pickle.load(f)

    with open(args.dataset_path + 'list.pkl', 'rb') as f:
        u_items_list = pickle.load(f)
        u_users_list = pickle.load(f)
        u_users_items_list = pickle.load(f)
        i_users_list = pickle.load(f)
        (user_count, item_count, rate_count) = pickle.load(f)
    
    train_data = GRDataset(train_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
    valid_data = GRDataset(valid_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
    test_data = GRDataset(test_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
    train_loader = DataLoader(train_data, batch_size = args.batch_size, shuffle = True, collate_fn = collate_fn)
    valid_loader = DataLoader(valid_data, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn)
    test_loader = DataLoader(test_data, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn)
   
    model = GraphRec(user_count+1, item_count+1, rate_count+1, args.embed_dim).to(device)
    
    if args.test:
        print('Load checkpoint and testing...')
        ckpt = torch.load('best_checkpoint.pth.tar')
        model.load_state_dict(ckpt['state_dict'])
        mae, rmse = validate(test_loader, model)
        print("Test: MAE: {:.4f}, RMSE: {:.4f}".format(mae, rmse))
        return

    optimizer = optim.RMSprop(model.parameters(), args.lr)
    criterion = nn.MSELoss()
    scheduler = StepLR(optimizer, step_size = args.lr_dc_step, gamma = args.lr_dc)

    for epoch in tqdm(range(args.epoch)):
        # train for one epoch
        scheduler.step(epoch = epoch)
        trainForEpoch(train_loader, model, optimizer, epoch, args.epoch, criterion, log_aggr = 100)

        mae, rmse = validate(valid_loader, model)

        # store best loss and save a model checkpoint
        ckpt_dict = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }

        torch.save(ckpt_dict, 'latest_checkpoint.pth.tar')

        if epoch == 0:
            best_mae = mae
        elif mae < best_mae:
            best_mae = mae
            torch.save(ckpt_dict, 'best_checkpoint.pth.tar')

        print('Epoch {} validation: MAE: {:.4f}, RMSE: {:.4f}, Best MAE: {:.4f}'.format(epoch, mae, rmse, best_mae))

 

def trainForEpoch(train_loader, model, optimizer, epoch, num_epochs, criterion, log_aggr=1):
    model.train()

    sum_epoch_loss = 0

    start = time.time()
    for i, (uids, iids, labels, u_items, u_users, u_users_items, i_users) in tqdm(enumerate(train_loader), total=len(train_loader)):
        uids = uids.to(device)
        iids = iids.to(device)
        labels = labels.to(device)
        u_items = u_items.to(device)
        u_users = u_users.to(device)
        u_users_items = u_users_items.to(device)
        i_users = i_users.to(device)
        
            
    
        optimizer.zero_grad()
        outputs = model(uids, iids, u_items, u_users, u_users_items, i_users)

        loss = criterion(outputs, labels.unsqueeze(1))
        loss.backward()
        optimizer.step() 

        loss_val = loss.item()
        sum_epoch_loss += loss_val

        iter_num = epoch * len(train_loader) + i + 1

        if i % log_aggr == 0:
            print('[TRAIN] epoch %d/%d batch loss: %.4f (avg %.4f) (%.2f im/s)'
                % (epoch + 1, num_epochs, loss_val, sum_epoch_loss / (i + 1),
                  len(uids) / (time.time() - start)))

        start = time.time()


           


def validate(valid_loader, model):
    model.eval()
    errors = []
    with torch.no_grad():
        for uids, iids, labels, u_items, u_users, u_users_items, i_users in tqdm(valid_loader):
            uids = uids.to(device)
            iids = iids.to(device)
            labels = labels.to(device)
            u_items = u_items.to(device)
            u_users = u_users.to(device)
            u_users_items = u_users_items.to(device)
            i_users = i_users.to(device)
            preds = model(uids, iids, u_items, u_users, u_users_items, i_users)
            error = torch.abs(preds.squeeze(1) - labels)
            errors.extend(error.data.cpu().numpy().tolist())
           
    
    mae = np.mean(errors)
    rmse = np.sqrt(np.mean(np.power(errors, 2)))
    return mae, rmse


if __name__ == '__main__':
    main()


Namespace(batch_size=256, dataset_path='/content/drive/MyDrive/datasets/FilmTrust_SimRank/', embed_dim=64, epoch=30, lr=0.001, lr_dc=0.1, lr_dc_step=30, test=False)
Loading data...



  0%|          | 0/57 [00:00<?, ?it/s][A
  2%|▏         | 1/57 [00:01<01:32,  1.66s/it][A

[TRAIN] epoch 1/30 batch loss: 18.4081 (avg 18.4081) (150.31 im/s)



  4%|▎         | 2/57 [00:02<01:23,  1.51s/it][A
  5%|▌         | 3/57 [00:03<01:12,  1.34s/it][A
  7%|▋         | 4/57 [00:04<01:07,  1.27s/it][A
  9%|▉         | 5/57 [00:05<01:03,  1.22s/it][A
 11%|█         | 6/57 [00:06<00:56,  1.11s/it][A
 12%|█▏        | 7/57 [00:07<00:51,  1.03s/it][A
 14%|█▍        | 8/57 [00:08<00:49,  1.00s/it][A
 16%|█▌        | 9/57 [00:09<00:42,  1.12it/s][A
 18%|█▊        | 10/57 [00:10<00:40,  1.15it/s][A
 19%|█▉        | 11/57 [00:10<00:35,  1.29it/s][A
 21%|██        | 12/57 [00:11<00:32,  1.37it/s][A
 23%|██▎       | 13/57 [00:12<00:36,  1.19it/s][A
 25%|██▍       | 14/57 [00:13<00:39,  1.09it/s][A
 26%|██▋       | 15/57 [00:13<00:33,  1.24it/s][A
 28%|██▊       | 16/57 [00:14<00:34,  1.20it/s][A
 30%|██▉       | 17/57 [00:15<00:30,  1.30it/s][A
 32%|███▏      | 18/57 [00:16<00:33,  1.15it/s][A
 33%|███▎      | 19/57 [00:17<00:30,  1.25it/s][A
 35%|███▌      | 20/57 [00:18<00:32,  1.13it/s][A
 37%|███▋      | 21/57 [00:19<00:31,  

Epoch 0 validation: MAE: 1.5259, RMSE: 2.0648, Best MAE: 1.5259



  2%|▏         | 1/57 [00:00<00:31,  1.79it/s][A

[TRAIN] epoch 2/30 batch loss: 0.3305 (avg 0.3305) (453.32 im/s)



  4%|▎         | 2/57 [00:01<00:40,  1.37it/s][A
  5%|▌         | 3/57 [00:02<00:40,  1.32it/s][A
  7%|▋         | 4/57 [00:03<00:38,  1.38it/s][A
  9%|▉         | 5/57 [00:04<00:43,  1.19it/s][A
 11%|█         | 6/57 [00:05<00:46,  1.09it/s][A
 12%|█▏        | 7/57 [00:06<00:43,  1.14it/s][A
 14%|█▍        | 8/57 [00:07<00:46,  1.06it/s][A
 16%|█▌        | 9/57 [00:07<00:41,  1.17it/s][A
 18%|█▊        | 10/57 [00:09<00:44,  1.06it/s][A
 19%|█▉        | 11/57 [00:10<00:45,  1.01it/s][A
 21%|██        | 12/57 [00:11<00:43,  1.03it/s][A
 23%|██▎       | 13/57 [00:11<00:37,  1.18it/s][A
 25%|██▍       | 14/57 [00:12<00:39,  1.08it/s][A
 26%|██▋       | 15/57 [00:13<00:40,  1.03it/s][A
 28%|██▊       | 16/57 [00:14<00:41,  1.01s/it][A
 30%|██▉       | 17/57 [00:16<00:41,  1.03s/it][A
 32%|███▏      | 18/57 [00:17<00:41,  1.05s/it][A
 33%|███▎      | 19/57 [00:17<00:35,  1.08it/s][A
 35%|███▌      | 20/57 [00:18<00:36,  1.02it/s][A
 37%|███▋      | 21/57 [00:19<00:31,  

Epoch 1 validation: MAE: 1.5057, RMSE: 2.0461, Best MAE: 1.5057



  2%|▏         | 1/57 [00:01<01:02,  1.12s/it][A

[TRAIN] epoch 3/30 batch loss: 0.4053 (avg 0.4053) (227.26 im/s)



  4%|▎         | 2/57 [00:02<01:01,  1.12s/it][A
  5%|▌         | 3/57 [00:02<00:52,  1.03it/s][A
  7%|▋         | 4/57 [00:03<00:50,  1.06it/s][A
  9%|▉         | 5/57 [00:04<00:47,  1.10it/s][A
 11%|█         | 6/57 [00:05<00:49,  1.04it/s][A
 12%|█▏        | 7/57 [00:06<00:44,  1.11it/s][A
 14%|█▍        | 8/57 [00:07<00:46,  1.04it/s][A
 16%|█▌        | 9/57 [00:08<00:41,  1.17it/s][A
 18%|█▊        | 10/57 [00:09<00:43,  1.08it/s][A
 19%|█▉        | 11/57 [00:10<00:45,  1.02it/s][A
 21%|██        | 12/57 [00:10<00:38,  1.16it/s][A
 23%|██▎       | 13/57 [00:11<00:34,  1.27it/s][A
 25%|██▍       | 14/57 [00:12<00:31,  1.35it/s][A
 26%|██▋       | 15/57 [00:12<00:29,  1.45it/s][A
 28%|██▊       | 16/57 [00:13<00:29,  1.37it/s][A
 30%|██▉       | 17/57 [00:14<00:26,  1.50it/s][A
 32%|███▏      | 18/57 [00:14<00:25,  1.52it/s][A
 33%|███▎      | 19/57 [00:15<00:30,  1.26it/s][A
 35%|███▌      | 20/57 [00:16<00:27,  1.34it/s][A
 37%|███▋      | 21/57 [00:17<00:25,  

Epoch 2 validation: MAE: 1.5343, RMSE: 2.0752, Best MAE: 1.5057



  2%|▏         | 1/57 [00:01<01:01,  1.10s/it][A

[TRAIN] epoch 4/30 batch loss: 0.3216 (avg 0.3216) (231.36 im/s)



  4%|▎         | 2/57 [00:02<01:00,  1.10s/it][A
  5%|▌         | 3/57 [00:02<00:54,  1.00s/it][A
  7%|▋         | 4/57 [00:04<00:54,  1.03s/it][A
  9%|▉         | 5/57 [00:04<00:46,  1.13it/s][A
 11%|█         | 6/57 [00:05<00:48,  1.05it/s][A
 12%|█▏        | 7/57 [00:06<00:47,  1.05it/s][A
 14%|█▍        | 8/57 [00:07<00:40,  1.20it/s][A
 16%|█▌        | 9/57 [00:07<00:37,  1.30it/s][A
 18%|█▊        | 10/57 [00:08<00:40,  1.15it/s][A
 19%|█▉        | 11/57 [00:09<00:36,  1.26it/s][A
 21%|██        | 12/57 [00:10<00:33,  1.35it/s][A
 23%|██▎       | 13/57 [00:11<00:33,  1.30it/s][A
 25%|██▍       | 14/57 [00:11<00:33,  1.27it/s][A
 26%|██▋       | 15/57 [00:12<00:31,  1.35it/s][A
 28%|██▊       | 16/57 [00:13<00:34,  1.19it/s][A
 30%|██▉       | 17/57 [00:14<00:36,  1.09it/s][A
 32%|███▏      | 18/57 [00:15<00:37,  1.03it/s][A
 33%|███▎      | 19/57 [00:16<00:38,  1.01s/it][A
 35%|███▌      | 20/57 [00:17<00:38,  1.04s/it][A
 37%|███▋      | 21/57 [00:19<00:37,  

Epoch 3 validation: MAE: 1.5486, RMSE: 2.0791, Best MAE: 1.5057



  2%|▏         | 1/57 [00:00<00:52,  1.06it/s][A

[TRAIN] epoch 5/30 batch loss: 0.2589 (avg 0.2589) (269.91 im/s)



  4%|▎         | 2/57 [00:02<00:54,  1.01it/s][A
  5%|▌         | 3/57 [00:03<00:55,  1.02s/it][A
  7%|▋         | 4/57 [00:03<00:50,  1.04it/s][A
  9%|▉         | 5/57 [00:04<00:44,  1.16it/s][A
 11%|█         | 6/57 [00:05<00:43,  1.17it/s][A
 12%|█▏        | 7/57 [00:06<00:46,  1.07it/s][A
 14%|█▍        | 8/57 [00:07<00:42,  1.14it/s][A
 16%|█▌        | 9/57 [00:08<00:40,  1.20it/s][A
 18%|█▊        | 10/57 [00:08<00:37,  1.25it/s][A
 19%|█▉        | 11/57 [00:09<00:34,  1.34it/s][A
 21%|██        | 12/57 [00:10<00:38,  1.17it/s][A
 23%|██▎       | 13/57 [00:11<00:41,  1.07it/s][A
 25%|██▍       | 14/57 [00:12<00:42,  1.01it/s][A
 26%|██▋       | 15/57 [00:13<00:42,  1.02s/it][A
 28%|██▊       | 16/57 [00:14<00:37,  1.10it/s][A
 30%|██▉       | 17/57 [00:15<00:38,  1.03it/s][A
 32%|███▏      | 18/57 [00:16<00:39,  1.01s/it][A
 33%|███▎      | 19/57 [00:17<00:36,  1.04it/s][A
 35%|███▌      | 20/57 [00:18<00:37,  1.00s/it][A
 37%|███▋      | 21/57 [00:19<00:32,  

Epoch 4 validation: MAE: 1.5625, RMSE: 2.0922, Best MAE: 1.5057



  2%|▏         | 1/57 [00:01<01:01,  1.10s/it][A

[TRAIN] epoch 6/30 batch loss: 0.2634 (avg 0.2634) (231.64 im/s)



  4%|▎         | 2/57 [00:01<00:51,  1.06it/s][A
  5%|▌         | 3/57 [00:02<00:45,  1.18it/s][A
  7%|▋         | 4/57 [00:03<00:48,  1.08it/s][A
  9%|▉         | 5/57 [00:04<00:45,  1.13it/s][A
 11%|█         | 6/57 [00:04<00:39,  1.28it/s][A
 12%|█▏        | 7/57 [00:05<00:43,  1.14it/s][A
 14%|█▍        | 8/57 [00:06<00:46,  1.06it/s][A
 16%|█▌        | 9/57 [00:07<00:39,  1.21it/s][A
 18%|█▊        | 10/57 [00:08<00:36,  1.31it/s][A
 19%|█▉        | 11/57 [00:08<00:35,  1.31it/s][A
 21%|██        | 12/57 [00:09<00:38,  1.16it/s][A
 23%|██▎       | 13/57 [00:10<00:34,  1.26it/s][A
 25%|██▍       | 14/57 [00:11<00:38,  1.13it/s][A
 26%|██▋       | 15/57 [00:12<00:39,  1.06it/s][A
 28%|██▊       | 16/57 [00:13<00:36,  1.12it/s][A
 30%|██▉       | 17/57 [00:14<00:37,  1.05it/s][A
 32%|███▏      | 18/57 [00:15<00:33,  1.17it/s][A
 33%|███▎      | 19/57 [00:16<00:35,  1.08it/s][A
 35%|███▌      | 20/57 [00:16<00:31,  1.19it/s][A
 37%|███▋      | 21/57 [00:17<00:27,  

Epoch 5 validation: MAE: 1.5199, RMSE: 2.0592, Best MAE: 1.5057



  2%|▏         | 1/57 [00:00<00:35,  1.57it/s][A

[TRAIN] epoch 7/30 batch loss: 0.3228 (avg 0.3228) (399.35 im/s)



  4%|▎         | 2/57 [00:01<00:37,  1.49it/s][A
  5%|▌         | 3/57 [00:02<00:38,  1.39it/s][A
  7%|▋         | 4/57 [00:02<00:36,  1.45it/s][A
  9%|▉         | 5/57 [00:03<00:34,  1.49it/s][A
 11%|█         | 6/57 [00:04<00:40,  1.26it/s][A
 12%|█▏        | 7/57 [00:05<00:40,  1.24it/s][A
 14%|█▍        | 8/57 [00:06<00:43,  1.12it/s][A
 16%|█▌        | 9/57 [00:07<00:38,  1.23it/s][A
 18%|█▊        | 10/57 [00:07<00:35,  1.32it/s][A
 19%|█▉        | 11/57 [00:08<00:35,  1.28it/s][A
 21%|██        | 12/57 [00:09<00:39,  1.15it/s][A
 23%|██▎       | 13/57 [00:10<00:41,  1.06it/s][A
 25%|██▍       | 14/57 [00:11<00:42,  1.01it/s][A
 26%|██▋       | 15/57 [00:12<00:37,  1.12it/s][A
 28%|██▊       | 16/57 [00:13<00:32,  1.26it/s][A
 30%|██▉       | 17/57 [00:14<00:35,  1.13it/s][A
 32%|███▏      | 18/57 [00:14<00:33,  1.17it/s][A
 33%|███▎      | 19/57 [00:15<00:29,  1.27it/s][A
 35%|███▌      | 20/57 [00:16<00:32,  1.14it/s][A
 37%|███▋      | 21/57 [00:17<00:33,  

Epoch 6 validation: MAE: 1.6250, RMSE: 2.1116, Best MAE: 1.5057



  2%|▏         | 1/57 [00:00<00:35,  1.57it/s][A

[TRAIN] epoch 8/30 batch loss: 0.4747 (avg 0.4747) (399.83 im/s)



  4%|▎         | 2/57 [00:01<00:42,  1.29it/s][A
  5%|▌         | 3/57 [00:02<00:47,  1.14it/s][A
  7%|▋         | 4/57 [00:03<00:46,  1.14it/s][A
  9%|▉         | 5/57 [00:04<00:45,  1.16it/s][A
 11%|█         | 6/57 [00:05<00:47,  1.07it/s][A
 12%|█▏        | 7/57 [00:06<00:49,  1.02it/s][A
 14%|█▍        | 8/57 [00:07<00:45,  1.07it/s][A
 16%|█▌        | 9/57 [00:08<00:47,  1.02it/s][A
 18%|█▊        | 10/57 [00:09<00:43,  1.07it/s][A
 19%|█▉        | 11/57 [00:10<00:37,  1.22it/s][A
 21%|██        | 12/57 [00:10<00:36,  1.22it/s][A
 23%|██▎       | 13/57 [00:11<00:33,  1.31it/s][A
 25%|██▍       | 14/57 [00:12<00:37,  1.15it/s][A
 26%|██▋       | 15/57 [00:13<00:39,  1.06it/s][A
 28%|██▊       | 16/57 [00:14<00:34,  1.18it/s][A
 30%|██▉       | 17/57 [00:15<00:36,  1.08it/s][A
 32%|███▏      | 18/57 [00:16<00:38,  1.03it/s][A
 33%|███▎      | 19/57 [00:17<00:38,  1.02s/it][A
 35%|███▌      | 20/57 [00:18<00:32,  1.14it/s][A
 37%|███▋      | 21/57 [00:18<00:28,  

Epoch 7 validation: MAE: 1.4915, RMSE: 2.0282, Best MAE: 1.4915



  2%|▏         | 1/57 [00:01<01:02,  1.12s/it][A

[TRAIN] epoch 9/30 batch loss: 0.5179 (avg 0.5179) (225.38 im/s)



  4%|▎         | 2/57 [00:01<00:56,  1.02s/it][A
  5%|▌         | 3/57 [00:03<00:56,  1.04s/it][A
  7%|▋         | 4/57 [00:04<00:56,  1.06s/it][A
  9%|▉         | 5/57 [00:05<00:55,  1.07s/it][A
 11%|█         | 6/57 [00:06<00:55,  1.08s/it][A
 12%|█▏        | 7/57 [00:06<00:46,  1.08it/s][A
 14%|█▍        | 8/57 [00:07<00:48,  1.02it/s][A
 16%|█▌        | 9/57 [00:09<00:48,  1.02s/it][A
 18%|█▊        | 10/57 [00:09<00:42,  1.11it/s][A
 19%|█▉        | 11/57 [00:10<00:44,  1.04it/s][A
 21%|██        | 12/57 [00:11<00:37,  1.19it/s][A
 23%|██▎       | 13/57 [00:12<00:34,  1.28it/s][A
 25%|██▍       | 14/57 [00:13<00:37,  1.14it/s][A
 26%|██▋       | 15/57 [00:14<00:39,  1.07it/s][A
 28%|██▊       | 16/57 [00:14<00:33,  1.22it/s][A
 30%|██▉       | 17/57 [00:15<00:33,  1.20it/s][A
 32%|███▏      | 18/57 [00:16<00:35,  1.09it/s][A
 33%|███▎      | 19/57 [00:17<00:34,  1.11it/s][A
 35%|███▌      | 20/57 [00:18<00:32,  1.13it/s][A
 37%|███▋      | 21/57 [00:19<00:34,  

Epoch 8 validation: MAE: 1.5503, RMSE: 2.0516, Best MAE: 1.4915



  2%|▏         | 1/57 [00:01<01:02,  1.12s/it][A

[TRAIN] epoch 10/30 batch loss: 0.2331 (avg 0.2331) (227.63 im/s)



  4%|▎         | 2/57 [00:01<00:57,  1.04s/it][A
  5%|▌         | 3/57 [00:03<00:56,  1.05s/it][A
  7%|▋         | 4/57 [00:03<00:49,  1.08it/s][A
  9%|▉         | 5/57 [00:04<00:50,  1.02it/s][A
 11%|█         | 6/57 [00:05<00:44,  1.15it/s][A
 12%|█▏        | 7/57 [00:06<00:43,  1.15it/s][A
 14%|█▍        | 8/57 [00:07<00:45,  1.07it/s][A
 16%|█▌        | 9/57 [00:08<00:47,  1.01it/s][A
 18%|█▊        | 10/57 [00:09<00:44,  1.06it/s][A
 19%|█▉        | 11/57 [00:10<00:45,  1.01it/s][A
 21%|██        | 12/57 [00:11<00:39,  1.13it/s][A
 23%|██▎       | 13/57 [00:12<00:41,  1.05it/s][A
 25%|██▍       | 14/57 [00:13<00:42,  1.01it/s][A
 26%|██▋       | 15/57 [00:13<00:36,  1.15it/s][A
 28%|██▊       | 16/57 [00:14<00:38,  1.07it/s][A
 30%|██▉       | 17/57 [00:16<00:39,  1.02it/s][A
 32%|███▏      | 18/57 [00:16<00:33,  1.17it/s][A
 33%|███▎      | 19/57 [00:17<00:29,  1.27it/s][A
 35%|███▌      | 20/57 [00:18<00:29,  1.25it/s][A
 37%|███▋      | 21/57 [00:19<00:32,  

Epoch 9 validation: MAE: 1.5238, RMSE: 2.0302, Best MAE: 1.4915



  2%|▏         | 1/57 [00:00<00:35,  1.56it/s][A

[TRAIN] epoch 11/30 batch loss: 0.2168 (avg 0.2168) (396.42 im/s)



  4%|▎         | 2/57 [00:01<00:38,  1.44it/s][A
  5%|▌         | 3/57 [00:02<00:43,  1.23it/s][A
  7%|▋         | 4/57 [00:03<00:43,  1.22it/s][A
  9%|▉         | 5/57 [00:03<00:38,  1.35it/s][A
 11%|█         | 6/57 [00:05<00:43,  1.17it/s][A
 12%|█▏        | 7/57 [00:06<00:46,  1.07it/s][A
 14%|█▍        | 8/57 [00:07<00:44,  1.10it/s][A
 16%|█▌        | 9/57 [00:07<00:38,  1.23it/s][A
 18%|█▊        | 10/57 [00:08<00:35,  1.32it/s][A
 19%|█▉        | 11/57 [00:08<00:32,  1.40it/s][A
 21%|██        | 12/57 [00:09<00:37,  1.20it/s][A
 23%|██▎       | 13/57 [00:11<00:40,  1.10it/s][A
 25%|██▍       | 14/57 [00:11<00:37,  1.14it/s][A
 26%|██▋       | 15/57 [00:12<00:37,  1.12it/s][A
 28%|██▊       | 16/57 [00:13<00:37,  1.10it/s][A
 30%|██▉       | 17/57 [00:14<00:32,  1.22it/s][A
 32%|███▏      | 18/57 [00:14<00:29,  1.32it/s][A
 33%|███▎      | 19/57 [00:15<00:27,  1.39it/s][A
 35%|███▌      | 20/57 [00:16<00:30,  1.20it/s][A
 37%|███▋      | 21/57 [00:17<00:27,  

Epoch 10 validation: MAE: 1.7117, RMSE: 2.1333, Best MAE: 1.4915



  2%|▏         | 1/57 [00:00<00:40,  1.39it/s][A

[TRAIN] epoch 12/30 batch loss: 0.6653 (avg 0.6653) (352.48 im/s)



  4%|▎         | 2/57 [00:01<00:45,  1.20it/s][A
  5%|▌         | 3/57 [00:02<00:49,  1.10it/s][A
  7%|▋         | 4/57 [00:03<00:43,  1.22it/s][A
  9%|▉         | 5/57 [00:04<00:46,  1.11it/s][A
 11%|█         | 6/57 [00:05<00:49,  1.04it/s][A
 12%|█▏        | 7/57 [00:06<00:42,  1.16it/s][A
 14%|█▍        | 8/57 [00:06<00:38,  1.27it/s][A
 16%|█▌        | 9/57 [00:08<00:42,  1.14it/s][A
 18%|█▊        | 10/57 [00:09<00:44,  1.05it/s][A
 19%|█▉        | 11/57 [00:09<00:37,  1.23it/s][A
 21%|██        | 12/57 [00:10<00:34,  1.32it/s][A
 23%|██▎       | 13/57 [00:10<00:31,  1.39it/s][A
 25%|██▍       | 14/57 [00:12<00:35,  1.20it/s][A
 26%|██▋       | 15/57 [00:13<00:38,  1.10it/s][A
 28%|██▊       | 16/57 [00:14<00:39,  1.03it/s][A
 30%|██▉       | 17/57 [00:15<00:37,  1.07it/s][A
 32%|███▏      | 18/57 [00:16<00:38,  1.01it/s][A
 33%|███▎      | 19/57 [00:17<00:39,  1.03s/it][A
 35%|███▌      | 20/57 [00:17<00:33,  1.12it/s][A
 37%|███▋      | 21/57 [00:18<00:34,  

Epoch 11 validation: MAE: 1.5284, RMSE: 2.0245, Best MAE: 1.4915



  2%|▏         | 1/57 [00:00<00:53,  1.05it/s][A

[TRAIN] epoch 13/30 batch loss: 0.3359 (avg 0.3359) (266.34 im/s)



  4%|▎         | 2/57 [00:02<00:54,  1.00it/s][A
  5%|▌         | 3/57 [00:02<00:47,  1.14it/s][A
  7%|▋         | 4/57 [00:03<00:50,  1.06it/s][A
  9%|▉         | 5/57 [00:04<00:43,  1.21it/s][A
 11%|█         | 6/57 [00:05<00:46,  1.10it/s][A
 12%|█▏        | 7/57 [00:05<00:40,  1.25it/s][A
 14%|█▍        | 8/57 [00:07<00:43,  1.12it/s][A
 16%|█▌        | 9/57 [00:08<00:45,  1.05it/s][A
 18%|█▊        | 10/57 [00:08<00:41,  1.12it/s][A
 19%|█▉        | 11/57 [00:09<00:43,  1.05it/s][A
 21%|██        | 12/57 [00:11<00:44,  1.00it/s][A
 23%|██▎       | 13/57 [00:11<00:38,  1.16it/s][A
 25%|██▍       | 14/57 [00:12<00:39,  1.08it/s][A
 26%|██▋       | 15/57 [00:13<00:41,  1.02it/s][A
 28%|██▊       | 16/57 [00:14<00:41,  1.02s/it][A
 30%|██▉       | 17/57 [00:15<00:38,  1.04it/s][A
 32%|███▏      | 18/57 [00:16<00:33,  1.16it/s][A
 33%|███▎      | 19/57 [00:17<00:32,  1.18it/s][A
 35%|███▌      | 20/57 [00:17<00:28,  1.28it/s][A
 37%|███▋      | 21/57 [00:18<00:25,  

Epoch 12 validation: MAE: 1.6053, RMSE: 2.0676, Best MAE: 1.4915



  2%|▏         | 1/57 [00:01<01:01,  1.09s/it][A

[TRAIN] epoch 14/30 batch loss: 0.3435 (avg 0.3435) (229.97 im/s)



  4%|▎         | 2/57 [00:02<01:00,  1.09s/it][A
  5%|▌         | 3/57 [00:03<00:54,  1.01s/it][A
  7%|▋         | 4/57 [00:03<00:47,  1.12it/s][A
  9%|▉         | 5/57 [00:04<00:42,  1.23it/s][A
 11%|█         | 6/57 [00:04<00:38,  1.32it/s][A
 12%|█▏        | 7/57 [00:05<00:43,  1.16it/s][A
 14%|█▍        | 8/57 [00:06<00:38,  1.27it/s][A
 16%|█▌        | 9/57 [00:07<00:42,  1.14it/s][A
 18%|█▊        | 10/57 [00:08<00:40,  1.16it/s][A
 19%|█▉        | 11/57 [00:09<00:39,  1.15it/s][A
 21%|██        | 12/57 [00:10<00:42,  1.07it/s][A
 23%|██▎       | 13/57 [00:11<00:43,  1.02it/s][A
 25%|██▍       | 14/57 [00:12<00:43,  1.02s/it][A
 26%|██▋       | 15/57 [00:13<00:43,  1.04s/it][A
 28%|██▊       | 16/57 [00:14<00:43,  1.06s/it][A
 30%|██▉       | 17/57 [00:15<00:43,  1.08s/it][A
 32%|███▏      | 18/57 [00:16<00:39,  1.01s/it][A
 33%|███▎      | 19/57 [00:17<00:39,  1.04s/it][A
 35%|███▌      | 20/57 [00:19<00:39,  1.06s/it][A
 37%|███▋      | 21/57 [00:19<00:33,  

Epoch 13 validation: MAE: 1.4888, RMSE: 1.9840, Best MAE: 1.4888



  2%|▏         | 1/57 [00:01<01:02,  1.12s/it][A

[TRAIN] epoch 15/30 batch loss: 0.3027 (avg 0.3027) (227.88 im/s)



  4%|▎         | 2/57 [00:02<01:01,  1.12s/it][A
  5%|▌         | 3/57 [00:03<01:00,  1.12s/it][A
  7%|▋         | 4/57 [00:03<00:51,  1.03it/s][A
  9%|▉         | 5/57 [00:04<00:45,  1.15it/s][A
 11%|█         | 6/57 [00:05<00:47,  1.07it/s][A
 12%|█▏        | 7/57 [00:06<00:41,  1.21it/s][A
 14%|█▍        | 8/57 [00:07<00:39,  1.24it/s][A
 16%|█▌        | 9/57 [00:07<00:37,  1.28it/s][A
 18%|█▊        | 10/57 [00:08<00:37,  1.26it/s][A
 19%|█▉        | 11/57 [00:09<00:40,  1.13it/s][A
 21%|██        | 12/57 [00:10<00:36,  1.23it/s][A
 23%|██▎       | 13/57 [00:11<00:39,  1.12it/s][A
 25%|██▍       | 14/57 [00:12<00:41,  1.04it/s][A
 26%|██▋       | 15/57 [00:13<00:40,  1.05it/s][A
 28%|██▊       | 16/57 [00:14<00:40,  1.00it/s][A
 30%|██▉       | 17/57 [00:15<00:37,  1.06it/s][A
 32%|███▏      | 18/57 [00:16<00:38,  1.01it/s][A
 33%|███▎      | 19/57 [00:17<00:35,  1.07it/s][A
 35%|███▌      | 20/57 [00:17<00:31,  1.18it/s][A
 37%|███▋      | 21/57 [00:18<00:28,  

Epoch 14 validation: MAE: 1.4436, RMSE: 1.9523, Best MAE: 1.4436



  2%|▏         | 1/57 [00:01<01:02,  1.12s/it][A

[TRAIN] epoch 16/30 batch loss: 0.2311 (avg 0.2311) (226.27 im/s)



  4%|▎         | 2/57 [00:01<00:53,  1.03it/s][A
  5%|▌         | 3/57 [00:02<00:47,  1.15it/s][A
  7%|▋         | 4/57 [00:03<00:50,  1.06it/s][A
  9%|▉         | 5/57 [00:04<00:51,  1.01it/s][A
 11%|█         | 6/57 [00:05<00:52,  1.03s/it][A
 12%|█▏        | 7/57 [00:06<00:52,  1.05s/it][A
 14%|█▍        | 8/57 [00:07<00:47,  1.04it/s][A
 16%|█▌        | 9/57 [00:08<00:41,  1.16it/s][A
 18%|█▊        | 10/57 [00:09<00:44,  1.06it/s][A
 19%|█▉        | 11/57 [00:09<00:38,  1.19it/s][A
 21%|██        | 12/57 [00:11<00:41,  1.09it/s][A
 23%|██▎       | 13/57 [00:11<00:36,  1.20it/s][A
 25%|██▍       | 14/57 [00:12<00:39,  1.08it/s][A
 26%|██▋       | 15/57 [00:13<00:41,  1.02it/s][A
 28%|██▊       | 16/57 [00:14<00:38,  1.06it/s][A
 30%|██▉       | 17/57 [00:15<00:39,  1.01it/s][A
 32%|███▏      | 18/57 [00:16<00:37,  1.05it/s][A
 33%|███▎      | 19/57 [00:17<00:35,  1.08it/s][A
 35%|███▌      | 20/57 [00:18<00:30,  1.19it/s][A
 37%|███▋      | 21/57 [00:19<00:31,  

Epoch 15 validation: MAE: 1.5425, RMSE: 2.0096, Best MAE: 1.4436



  2%|▏         | 1/57 [00:01<01:03,  1.14s/it][A

[TRAIN] epoch 17/30 batch loss: 0.3485 (avg 0.3485) (220.11 im/s)



  4%|▎         | 2/57 [00:02<01:02,  1.14s/it][A
  5%|▌         | 3/57 [00:02<00:52,  1.03it/s][A
  7%|▋         | 4/57 [00:03<00:45,  1.17it/s][A
  9%|▉         | 5/57 [00:04<00:48,  1.07it/s][A
 11%|█         | 6/57 [00:05<00:50,  1.02it/s][A
 12%|█▏        | 7/57 [00:06<00:51,  1.03s/it][A
 14%|█▍        | 8/57 [00:07<00:51,  1.05s/it][A
 16%|█▌        | 9/57 [00:08<00:43,  1.10it/s][A
 18%|█▊        | 10/57 [00:09<00:45,  1.03it/s][A
 19%|█▉        | 11/57 [00:10<00:40,  1.15it/s][A
 21%|██        | 12/57 [00:10<00:36,  1.25it/s][A
 23%|██▎       | 13/57 [00:11<00:39,  1.11it/s][A
 25%|██▍       | 14/57 [00:13<00:41,  1.04it/s][A
 26%|██▋       | 15/57 [00:13<00:36,  1.16it/s][A
 28%|██▊       | 16/57 [00:14<00:38,  1.07it/s][A
 30%|██▉       | 17/57 [00:15<00:39,  1.01it/s][A
 32%|███▏      | 18/57 [00:17<00:40,  1.03s/it][A
 33%|███▎      | 19/57 [00:17<00:36,  1.03it/s][A
 35%|███▌      | 20/57 [00:18<00:34,  1.07it/s][A
 37%|███▋      | 21/57 [00:19<00:35,  

Epoch 16 validation: MAE: 1.4352, RMSE: 1.9401, Best MAE: 1.4352



  2%|▏         | 1/57 [00:01<01:03,  1.13s/it][A

[TRAIN] epoch 18/30 batch loss: 0.2635 (avg 0.2635) (226.51 im/s)



  4%|▎         | 2/57 [00:02<01:01,  1.12s/it][A
  5%|▌         | 3/57 [00:03<01:00,  1.12s/it][A
  7%|▋         | 4/57 [00:03<00:51,  1.03it/s][A
  9%|▉         | 5/57 [00:04<00:45,  1.15it/s][A
 11%|█         | 6/57 [00:05<00:48,  1.06it/s][A
 12%|█▏        | 7/57 [00:06<00:49,  1.01it/s][A
 14%|█▍        | 8/57 [00:07<00:42,  1.16it/s][A
 16%|█▌        | 9/57 [00:08<00:45,  1.07it/s][A
 18%|█▊        | 10/57 [00:09<00:44,  1.07it/s][A
 19%|█▉        | 11/57 [00:10<00:40,  1.13it/s][A
 21%|██        | 12/57 [00:11<00:42,  1.05it/s][A
 23%|██▎       | 13/57 [00:12<00:40,  1.08it/s][A
 25%|██▍       | 14/57 [00:12<00:35,  1.20it/s][A
 26%|██▋       | 15/57 [00:13<00:32,  1.29it/s][A
 28%|██▊       | 16/57 [00:14<00:29,  1.37it/s][A
 30%|██▉       | 17/57 [00:14<00:30,  1.31it/s][A
 32%|███▏      | 18/57 [00:16<00:33,  1.15it/s][A
 33%|███▎      | 19/57 [00:16<00:30,  1.25it/s][A
 35%|███▌      | 20/57 [00:17<00:32,  1.13it/s][A
 37%|███▋      | 21/57 [00:18<00:29,  

Epoch 17 validation: MAE: 1.4239, RMSE: 1.9242, Best MAE: 1.4239



  2%|▏         | 1/57 [00:00<00:31,  1.77it/s][A

[TRAIN] epoch 19/30 batch loss: 0.3752 (avg 0.3752) (448.22 im/s)



  4%|▎         | 2/57 [00:01<00:40,  1.37it/s][A
  5%|▌         | 3/57 [00:02<00:45,  1.18it/s][A
  7%|▋         | 4/57 [00:03<00:44,  1.18it/s][A
  9%|▉         | 5/57 [00:04<00:43,  1.18it/s][A
 11%|█         | 6/57 [00:05<00:40,  1.25it/s][A
 12%|█▏        | 7/57 [00:06<00:44,  1.12it/s][A
 14%|█▍        | 8/57 [00:07<00:46,  1.05it/s][A
 16%|█▌        | 9/57 [00:08<00:45,  1.05it/s][A
 18%|█▊        | 10/57 [00:09<00:46,  1.00it/s][A
 19%|█▉        | 11/57 [00:10<00:47,  1.03s/it][A
 21%|██        | 12/57 [00:11<00:40,  1.12it/s][A
 23%|██▎       | 13/57 [00:11<00:35,  1.25it/s][A
 25%|██▍       | 14/57 [00:12<00:38,  1.12it/s][A
 26%|██▋       | 15/57 [00:13<00:40,  1.04it/s][A
 28%|██▊       | 16/57 [00:15<00:41,  1.01s/it][A
 30%|██▉       | 17/57 [00:16<00:41,  1.04s/it][A
 32%|███▏      | 18/57 [00:17<00:41,  1.06s/it][A
 33%|███▎      | 19/57 [00:18<00:38,  1.01s/it][A
 35%|███▌      | 20/57 [00:19<00:38,  1.04s/it][A
 37%|███▋      | 21/57 [00:20<00:38,  

Epoch 18 validation: MAE: 1.4219, RMSE: 1.9166, Best MAE: 1.4219



  2%|▏         | 1/57 [00:00<00:32,  1.75it/s][A

[TRAIN] epoch 20/30 batch loss: 0.4949 (avg 0.4949) (443.42 im/s)



  4%|▎         | 2/57 [00:01<00:31,  1.74it/s][A
  5%|▌         | 3/57 [00:01<00:35,  1.53it/s][A
  7%|▋         | 4/57 [00:03<00:41,  1.27it/s][A
  9%|▉         | 5/57 [00:04<00:46,  1.12it/s][A
 11%|█         | 6/57 [00:04<00:41,  1.24it/s][A
 12%|█▏        | 7/57 [00:05<00:37,  1.32it/s][A
 14%|█▍        | 8/57 [00:06<00:42,  1.16it/s][A
 16%|█▌        | 9/57 [00:07<00:37,  1.28it/s][A
 18%|█▊        | 10/57 [00:08<00:41,  1.13it/s][A
 19%|█▉        | 11/57 [00:09<00:40,  1.14it/s][A
 21%|██        | 12/57 [00:10<00:42,  1.06it/s][A
 23%|██▎       | 13/57 [00:11<00:39,  1.11it/s][A
 25%|██▍       | 14/57 [00:12<00:41,  1.04it/s][A
 26%|██▋       | 15/57 [00:13<00:42,  1.01s/it][A
 28%|██▊       | 16/57 [00:14<00:42,  1.04s/it][A
 30%|██▉       | 17/57 [00:15<00:37,  1.06it/s][A
 32%|███▏      | 18/57 [00:15<00:32,  1.19it/s][A
 33%|███▎      | 19/57 [00:16<00:34,  1.09it/s][A
 35%|███▌      | 20/57 [00:17<00:30,  1.20it/s][A
 37%|███▋      | 21/57 [00:18<00:29,  

Epoch 19 validation: MAE: 1.4119, RMSE: 1.9016, Best MAE: 1.4119



  2%|▏         | 1/57 [00:00<00:33,  1.68it/s][A

[TRAIN] epoch 21/30 batch loss: 0.3787 (avg 0.3787) (417.63 im/s)



  4%|▎         | 2/57 [00:01<00:33,  1.63it/s][A
  5%|▌         | 3/57 [00:02<00:41,  1.31it/s][A
  7%|▋         | 4/57 [00:03<00:41,  1.28it/s][A
  9%|▉         | 5/57 [00:03<00:37,  1.40it/s][A
 11%|█         | 6/57 [00:04<00:35,  1.45it/s][A
 12%|█▏        | 7/57 [00:05<00:40,  1.23it/s][A
 14%|█▍        | 8/57 [00:06<00:44,  1.11it/s][A
 16%|█▌        | 9/57 [00:07<00:46,  1.04it/s][A
 18%|█▊        | 10/57 [00:08<00:47,  1.01s/it][A
 19%|█▉        | 11/57 [00:09<00:43,  1.05it/s][A
 21%|██        | 12/57 [00:10<00:39,  1.13it/s][A
 23%|██▎       | 13/57 [00:10<00:33,  1.31it/s][A
 25%|██▍       | 14/57 [00:11<00:33,  1.28it/s][A
 26%|██▋       | 15/57 [00:12<00:30,  1.36it/s][A
 28%|██▊       | 16/57 [00:13<00:34,  1.18it/s][A
 30%|██▉       | 17/57 [00:14<00:36,  1.09it/s][A
 32%|███▏      | 18/57 [00:15<00:38,  1.02it/s][A
 33%|███▎      | 19/57 [00:16<00:38,  1.02s/it][A
 35%|███▌      | 20/57 [00:17<00:38,  1.05s/it][A
 37%|███▋      | 21/57 [00:18<00:38,  

Epoch 20 validation: MAE: 1.4812, RMSE: 1.9382, Best MAE: 1.4119



  2%|▏         | 1/57 [00:01<01:02,  1.12s/it][A

[TRAIN] epoch 22/30 batch loss: 0.1892 (avg 0.1892) (227.96 im/s)



  4%|▎         | 2/57 [00:01<00:53,  1.03it/s][A
  5%|▌         | 3/57 [00:02<00:54,  1.01s/it][A
  7%|▋         | 4/57 [00:03<00:50,  1.05it/s][A
  9%|▉         | 5/57 [00:04<00:51,  1.01it/s][A
 11%|█         | 6/57 [00:05<00:52,  1.03s/it][A
 12%|█▏        | 7/57 [00:06<00:52,  1.05s/it][A
 14%|█▍        | 8/57 [00:08<00:51,  1.06s/it][A
 16%|█▌        | 9/57 [00:08<00:44,  1.08it/s][A
 18%|█▊        | 10/57 [00:09<00:43,  1.07it/s][A
 19%|█▉        | 11/57 [00:10<00:42,  1.08it/s][A
 21%|██        | 12/57 [00:11<00:39,  1.15it/s][A
 23%|██▎       | 13/57 [00:11<00:35,  1.25it/s][A
 25%|██▍       | 14/57 [00:12<00:32,  1.33it/s][A
 26%|██▋       | 15/57 [00:13<00:35,  1.17it/s][A
 28%|██▊       | 16/57 [00:14<00:36,  1.14it/s][A
 30%|██▉       | 17/57 [00:15<00:32,  1.24it/s][A
 32%|███▏      | 18/57 [00:16<00:32,  1.22it/s][A
 33%|███▎      | 19/57 [00:17<00:34,  1.10it/s][A
 35%|███▌      | 20/57 [00:17<00:29,  1.25it/s][A
 37%|███▋      | 21/57 [00:18<00:32,  

Epoch 21 validation: MAE: 1.4013, RMSE: 1.8794, Best MAE: 1.4013



  2%|▏         | 1/57 [00:00<00:36,  1.54it/s][A

[TRAIN] epoch 23/30 batch loss: 0.3436 (avg 0.3436) (392.46 im/s)



  4%|▎         | 2/57 [00:01<00:43,  1.27it/s][A
  5%|▌         | 3/57 [00:02<00:38,  1.39it/s][A
  7%|▋         | 4/57 [00:02<00:36,  1.44it/s][A
  9%|▉         | 5/57 [00:03<00:38,  1.36it/s][A
 11%|█         | 6/57 [00:04<00:43,  1.18it/s][A
 12%|█▏        | 7/57 [00:05<00:42,  1.18it/s][A
 14%|█▍        | 8/57 [00:06<00:40,  1.21it/s][A
 16%|█▌        | 9/57 [00:07<00:43,  1.10it/s][A
 18%|█▊        | 10/57 [00:08<00:38,  1.21it/s][A
 19%|█▉        | 11/57 [00:08<00:35,  1.29it/s][A
 21%|██        | 12/57 [00:09<00:39,  1.15it/s][A
 23%|██▎       | 13/57 [00:10<00:39,  1.12it/s][A
 25%|██▍       | 14/57 [00:11<00:36,  1.18it/s][A
 26%|██▋       | 15/57 [00:12<00:35,  1.20it/s][A
 28%|██▊       | 16/57 [00:13<00:37,  1.10it/s][A
 30%|██▉       | 17/57 [00:14<00:35,  1.13it/s][A
 32%|███▏      | 18/57 [00:14<00:30,  1.27it/s][A
 33%|███▎      | 19/57 [00:16<00:33,  1.13it/s][A
 35%|███▌      | 20/57 [00:17<00:35,  1.06it/s][A
 37%|███▋      | 21/57 [00:17<00:30,  

Epoch 22 validation: MAE: 1.4936, RMSE: 1.9367, Best MAE: 1.4013



  2%|▏         | 1/57 [00:00<00:49,  1.13it/s][A

[TRAIN] epoch 24/30 batch loss: 0.2530 (avg 0.2530) (284.20 im/s)



  4%|▎         | 2/57 [00:01<00:47,  1.16it/s][A
  5%|▌         | 3/57 [00:02<00:50,  1.07it/s][A
  7%|▋         | 4/57 [00:03<00:52,  1.02it/s][A
  9%|▉         | 5/57 [00:04<00:43,  1.19it/s][A
 11%|█         | 6/57 [00:05<00:46,  1.09it/s][A
 12%|█▏        | 7/57 [00:06<00:40,  1.23it/s][A
 14%|█▍        | 8/57 [00:06<00:37,  1.31it/s][A
 16%|█▌        | 9/57 [00:07<00:37,  1.27it/s][A
 18%|█▊        | 10/57 [00:08<00:41,  1.13it/s][A
 19%|█▉        | 11/57 [00:09<00:37,  1.24it/s][A
 21%|██        | 12/57 [00:10<00:40,  1.12it/s][A
 23%|██▎       | 13/57 [00:10<00:35,  1.25it/s][A
 25%|██▍       | 14/57 [00:12<00:38,  1.12it/s][A
 26%|██▋       | 15/57 [00:12<00:36,  1.15it/s][A
 28%|██▊       | 16/57 [00:13<00:35,  1.16it/s][A
 30%|██▉       | 17/57 [00:14<00:30,  1.29it/s][A
 32%|███▏      | 18/57 [00:15<00:34,  1.14it/s][A
 33%|███▎      | 19/57 [00:16<00:36,  1.05it/s][A
 35%|███▌      | 20/57 [00:17<00:36,  1.00it/s][A
 37%|███▋      | 21/57 [00:18<00:33,  

Epoch 23 validation: MAE: 1.4797, RMSE: 1.9204, Best MAE: 1.4013



  2%|▏         | 1/57 [00:00<00:31,  1.77it/s][A

[TRAIN] epoch 25/30 batch loss: 0.2784 (avg 0.2784) (448.20 im/s)



  4%|▎         | 2/57 [00:01<00:31,  1.77it/s][A
  5%|▌         | 3/57 [00:01<00:30,  1.76it/s][A
  7%|▋         | 4/57 [00:02<00:34,  1.55it/s][A
  9%|▉         | 5/57 [00:03<00:33,  1.56it/s][A
 11%|█         | 6/57 [00:04<00:39,  1.29it/s][A
 12%|█▏        | 7/57 [00:05<00:44,  1.14it/s][A
 14%|█▍        | 8/57 [00:06<00:40,  1.20it/s][A
 16%|█▌        | 9/57 [00:06<00:36,  1.30it/s][A
 18%|█▊        | 10/57 [00:07<00:37,  1.26it/s][A
 19%|█▉        | 11/57 [00:08<00:34,  1.35it/s][A
 21%|██        | 12/57 [00:09<00:34,  1.31it/s][A
 23%|██▎       | 13/57 [00:09<00:31,  1.38it/s][A
 25%|██▍       | 14/57 [00:10<00:35,  1.20it/s][A
 26%|██▋       | 15/57 [00:11<00:38,  1.09it/s][A
 28%|██▊       | 16/57 [00:12<00:39,  1.03it/s][A
 30%|██▉       | 17/57 [00:14<00:40,  1.01s/it][A
 32%|███▏      | 18/57 [00:15<00:40,  1.03s/it][A
 33%|███▎      | 19/57 [00:16<00:40,  1.05s/it][A
 35%|███▌      | 20/57 [00:17<00:39,  1.07s/it][A
 37%|███▋      | 21/57 [00:18<00:38,  

Epoch 24 validation: MAE: 1.3789, RMSE: 1.8478, Best MAE: 1.3789



  2%|▏         | 1/57 [00:00<00:54,  1.03it/s][A

[TRAIN] epoch 26/30 batch loss: 0.2365 (avg 0.2365) (263.10 im/s)



  4%|▎         | 2/57 [00:01<00:52,  1.04it/s][A
  5%|▌         | 3/57 [00:03<00:54,  1.00s/it][A
  7%|▋         | 4/57 [00:03<00:50,  1.06it/s][A
  9%|▉         | 5/57 [00:04<00:51,  1.01it/s][A
 11%|█         | 6/57 [00:05<00:44,  1.14it/s][A
 12%|█▏        | 7/57 [00:06<00:40,  1.24it/s][A
 14%|█▍        | 8/57 [00:06<00:36,  1.36it/s][A
 16%|█▌        | 9/57 [00:07<00:33,  1.42it/s][A
 18%|█▊        | 10/57 [00:08<00:38,  1.22it/s][A
 19%|█▉        | 11/57 [00:09<00:38,  1.20it/s][A
 21%|██        | 12/57 [00:10<00:40,  1.10it/s][A
 23%|██▎       | 13/57 [00:11<00:42,  1.03it/s][A
 25%|██▍       | 14/57 [00:12<00:36,  1.18it/s][A
 26%|██▋       | 15/57 [00:13<00:38,  1.08it/s][A
 28%|██▊       | 16/57 [00:14<00:40,  1.02it/s][A
 30%|██▉       | 17/57 [00:15<00:40,  1.02s/it][A
 32%|███▏      | 18/57 [00:16<00:37,  1.05it/s][A
 33%|███▎      | 19/57 [00:16<00:31,  1.20it/s][A
 35%|███▌      | 20/57 [00:17<00:33,  1.09it/s][A
 37%|███▋      | 21/57 [00:18<00:35,  

Epoch 25 validation: MAE: 1.4604, RMSE: 1.8928, Best MAE: 1.3789



  2%|▏         | 1/57 [00:00<00:32,  1.74it/s][A

[TRAIN] epoch 27/30 batch loss: 0.2716 (avg 0.2716) (435.82 im/s)



  4%|▎         | 2/57 [00:01<00:35,  1.54it/s][A
  5%|▌         | 3/57 [00:02<00:42,  1.28it/s][A
  7%|▋         | 4/57 [00:03<00:42,  1.26it/s][A
  9%|▉         | 5/57 [00:04<00:46,  1.13it/s][A
 11%|█         | 6/57 [00:05<00:44,  1.15it/s][A
 12%|█▏        | 7/57 [00:06<00:47,  1.06it/s][A
 14%|█▍        | 8/57 [00:06<00:40,  1.21it/s][A
 16%|█▌        | 9/57 [00:07<00:39,  1.21it/s][A
 18%|█▊        | 10/57 [00:08<00:42,  1.10it/s][A
 19%|█▉        | 11/57 [00:09<00:44,  1.03it/s][A
 21%|██        | 12/57 [00:11<00:45,  1.01s/it][A
 23%|██▎       | 13/57 [00:11<00:39,  1.12it/s][A
 25%|██▍       | 14/57 [00:12<00:35,  1.22it/s][A
 26%|██▋       | 15/57 [00:12<00:32,  1.31it/s][A
 28%|██▊       | 16/57 [00:14<00:35,  1.15it/s][A
 30%|██▉       | 17/57 [00:15<00:37,  1.06it/s][A
 32%|███▏      | 18/57 [00:16<00:38,  1.00it/s][A
 33%|███▎      | 19/57 [00:17<00:39,  1.03s/it][A
 35%|███▌      | 20/57 [00:18<00:39,  1.06s/it][A
 37%|███▋      | 21/57 [00:19<00:38,  

Epoch 26 validation: MAE: 1.3637, RMSE: 1.8204, Best MAE: 1.3637



  2%|▏         | 1/57 [00:00<00:36,  1.54it/s][A

[TRAIN] epoch 28/30 batch loss: 0.2898 (avg 0.2898) (392.12 im/s)



  4%|▎         | 2/57 [00:01<00:39,  1.40it/s][A
  5%|▌         | 3/57 [00:02<00:37,  1.46it/s][A
  7%|▋         | 4/57 [00:02<00:35,  1.48it/s][A
  9%|▉         | 5/57 [00:03<00:34,  1.50it/s][A
 11%|█         | 6/57 [00:03<00:31,  1.61it/s][A
 12%|█▏        | 7/57 [00:04<00:31,  1.60it/s][A
 14%|█▍        | 8/57 [00:05<00:33,  1.47it/s][A
 16%|█▌        | 9/57 [00:06<00:38,  1.24it/s][A
 18%|█▊        | 10/57 [00:07<00:42,  1.11it/s][A
 19%|█▉        | 11/57 [00:08<00:44,  1.04it/s][A
 21%|██        | 12/57 [00:09<00:45,  1.01s/it][A
 23%|██▎       | 13/57 [00:10<00:38,  1.15it/s][A
 25%|██▍       | 14/57 [00:11<00:40,  1.06it/s][A
 26%|██▋       | 15/57 [00:12<00:35,  1.18it/s][A
 28%|██▊       | 16/57 [00:12<00:34,  1.17it/s][A
 30%|██▉       | 17/57 [00:13<00:31,  1.28it/s][A
 32%|███▏      | 18/57 [00:14<00:34,  1.14it/s][A
 33%|███▎      | 19/57 [00:15<00:31,  1.20it/s][A
 35%|███▌      | 20/57 [00:16<00:33,  1.10it/s][A
 37%|███▋      | 21/57 [00:17<00:29,  

Epoch 27 validation: MAE: 1.5045, RMSE: 1.9104, Best MAE: 1.3637



  2%|▏         | 1/57 [00:00<00:44,  1.25it/s][A

[TRAIN] epoch 29/30 batch loss: 0.3844 (avg 0.3844) (318.52 im/s)



  4%|▎         | 2/57 [00:01<00:41,  1.33it/s][A
  5%|▌         | 3/57 [00:02<00:46,  1.16it/s][A
  7%|▋         | 4/57 [00:03<00:41,  1.26it/s][A
  9%|▉         | 5/57 [00:03<00:37,  1.37it/s][A
 11%|█         | 6/57 [00:04<00:34,  1.47it/s][A
 12%|█▏        | 7/57 [00:05<00:40,  1.23it/s][A
 14%|█▍        | 8/57 [00:06<00:43,  1.12it/s][A
 16%|█▌        | 9/57 [00:07<00:38,  1.25it/s][A
 18%|█▊        | 10/57 [00:08<00:41,  1.12it/s][A
 19%|█▉        | 11/57 [00:08<00:36,  1.26it/s][A
 21%|██        | 12/57 [00:09<00:36,  1.25it/s][A
 23%|██▎       | 13/57 [00:10<00:35,  1.22it/s][A
 25%|██▍       | 14/57 [00:11<00:34,  1.26it/s][A
 26%|██▋       | 15/57 [00:11<00:30,  1.38it/s][A
 28%|██▊       | 16/57 [00:12<00:31,  1.29it/s][A
 30%|██▉       | 17/57 [00:13<00:30,  1.31it/s][A
 32%|███▏      | 18/57 [00:14<00:33,  1.16it/s][A
 33%|███▎      | 19/57 [00:15<00:35,  1.07it/s][A
 35%|███▌      | 20/57 [00:16<00:36,  1.01it/s][A
 37%|███▋      | 21/57 [00:17<00:36,  

Epoch 28 validation: MAE: 1.4207, RMSE: 1.8495, Best MAE: 1.3637



  2%|▏         | 1/57 [00:00<00:47,  1.18it/s][A

[TRAIN] epoch 30/30 batch loss: 0.2139 (avg 0.2139) (299.23 im/s)



  4%|▎         | 2/57 [00:01<00:42,  1.30it/s][A
  5%|▌         | 3/57 [00:02<00:47,  1.14it/s][A
  7%|▋         | 4/57 [00:03<00:49,  1.06it/s][A
  9%|▉         | 5/57 [00:04<00:48,  1.08it/s][A
 11%|█         | 6/57 [00:05<00:45,  1.13it/s][A
 12%|█▏        | 7/57 [00:05<00:39,  1.27it/s][A
 14%|█▍        | 8/57 [00:06<00:38,  1.28it/s][A
 16%|█▌        | 9/57 [00:07<00:42,  1.14it/s][A
 18%|█▊        | 10/57 [00:08<00:44,  1.06it/s][A
 19%|█▉        | 11/57 [00:09<00:45,  1.00it/s][A
 21%|██        | 12/57 [00:10<00:42,  1.06it/s][A
 23%|██▎       | 13/57 [00:11<00:36,  1.20it/s][A
 25%|██▍       | 14/57 [00:12<00:39,  1.09it/s][A
 26%|██▋       | 15/57 [00:13<00:41,  1.02it/s][A
 28%|██▊       | 16/57 [00:14<00:41,  1.02s/it][A
 30%|██▉       | 17/57 [00:15<00:38,  1.04it/s][A
 32%|███▏      | 18/57 [00:16<00:39,  1.00s/it][A
 33%|███▎      | 19/57 [00:17<00:39,  1.04s/it][A
 35%|███▌      | 20/57 [00:18<00:39,  1.06s/it][A
 37%|███▋      | 21/57 [00:19<00:35,  

Epoch 29 validation: MAE: 1.4097, RMSE: 1.8352, Best MAE: 1.3637


In [1]:
model = GraphRec(user_count+1, item_count+1, rate_count+1, args.embed_dim).to(device)
test_data = GRDataset(test_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
test_loader = DataLoader(test_data, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn)
print('Load checkpoint and testing...')
ckpt = torch.load('latest_checkpoint.pth.tar')
model.load_state_dict(ckpt['state_dict'])
mae, rmse = validate(test_loader, model)
print("Test: MAE: {:.4f}, RMSE: {:.4f}".format(mae, rmse))


NameError: ignored