In [4]:
%load_ext autoreload
%autoreload 2

In [8]:
import pandas as pd
import numpy as np
import json
import os
from collections import Counter
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
import re
import seaborn as sns
import matplotlib.pyplot as plt
import swifter
import multiprocessing
import time
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import statsmodels.formula.api as smf
from collections import defaultdict
from dataclasses import dataclass
from torch import nn
import torch.nn.functional as F
import torch
import random
import copy
from tqdm import tqdm

sns.set_style("darkgrid")

# Loading Processed Reviews

In [10]:
# Paths
PROCESSED_FOLDER = './data/processed/'
PROCESSED_REVIEWS_FILE = 'processed_reviews.csv'
PROCESSED_REVIEWS_FILE = 'processed_reviews_with_sentiment.csv'

In [11]:
reviews = pd.read_csv(os.path.join(PROCESSED_FOLDER, PROCESSED_REVIEWS_FILE))

In [12]:
reviews.head()

Unnamed: 0,review_id,user_id,item_id,text,rating,tokenized_text
0,255938,0,1,"First things first. My ""reviews"" system is exp...",8,"[['First', 'things', 'first', '.'], ['My', '``..."
1,259117,1,2,Let me start off by saying that Made in Abyss ...,10,"[['Let', 'me', 'start', 'off', 'by', 'saying',..."
2,253664,2,3,"Art 9/10: It is great, especially the actions ...",7,"[['Art', '9/10', ':', 'It', 'is', 'great', ','..."
3,247454,3,4,As someone who loves Studio Ghibli and its mov...,6,"[['As', 'someone', 'who', 'loves', 'Studio', '..."
4,23791,4,5,code geass is one of those series that everybo...,10,"[['code', 'geass', 'is', 'one', 'of', 'those',..."


# Converting Data for Modeling

In [13]:
# Convert item_id to 0 indexed
if min(reviews['item_id']) != 0:
    reviews['item_id'] = reviews['item_id'] - 1
    print("Done")

Done


In [14]:
@dataclass
class Review:
    user_id: int
    item_id: int
    rating: int
    text: str
    pos_sent: float
    neg_sent: float
    compound_sent: float

In [15]:
USER_KEY = 'user_id'
ITEM_KEY = 'item_id'
RATING_KEY = 'rating'

In [16]:
reviews.head()

Unnamed: 0,review_id,user_id,item_id,text,rating,tokenized_text
0,255938,0,0,"First things first. My ""reviews"" system is exp...",8,"[['First', 'things', 'first', '.'], ['My', '``..."
1,259117,1,1,Let me start off by saying that Made in Abyss ...,10,"[['Let', 'me', 'start', 'off', 'by', 'saying',..."
2,253664,2,2,"Art 9/10: It is great, especially the actions ...",7,"[['Art', '9/10', ':', 'It', 'is', 'great', ','..."
3,247454,3,3,As someone who loves Studio Ghibli and its mov...,6,"[['As', 'someone', 'who', 'loves', 'Studio', '..."
4,23791,4,4,code geass is one of those series that everybo...,10,"[['code', 'geass', 'is', 'one', 'of', 'those',..."


In [17]:
user_to_reviews = defaultdict(list)
for _, row in reviews.iterrows():
    user_id, item_id, rating, text = row[USER_KEY], row[ITEM_KEY], row[RATING_KEY], row['text']
    pos_sent, neg_sent, compound_sent = row['pos_sent_avg'], row['neg_sent_avg'], row['compound_sent_avg']
    user_to_reviews[user_id].append(Review(user_id, item_id, rating, text, pos_sent, neg_sent, compound_sent))

KeyError: 'pos_sent_avg'

## Creating the score matrix

In [18]:
# users by items
X = np.zeros(shape=(reviews['user_id'].nunique(), reviews['item_id'].nunique()))

In [19]:
for _, row in reviews.iterrows():
    user_id, item_id, rating = row[USER_KEY], row[ITEM_KEY], row[RATING_KEY]
    X[user_id][item_id] = rating

## Train/Test Split

In [13]:
train_X = copy.deepcopy(X)
valid_X = np.zeros(shape=X.shape)
test_X = np.zeros(shape=X.shape)

for user_id, reviews in user_to_reviews.items():
    # can confirm this actually shuffles properly (this code block works)
    random.shuffle(reviews)

    # Leave one out for valid
    valid_review = reviews[0]
    train_X[valid_review.user_id][valid_review.item_id] = 0
    valid_X[valid_review.user_id][valid_review.item_id] = valid_review.rating
    
    # Leave one out for test
    test_review = reviews[1]
    train_X[test_review.user_id][test_review.item_id] = 0
    test_X[test_review.user_id][test_review.item_id] = test_review.rating
    
    # Rest for train

## Creating bias terms for users / items from the training data

In [21]:
# users
user_to_pos_sent = defaultdict(list)
user_to_neg_sent = defaultdict(list)
user_to_compound_sent = defaultdict(list)

# items
item_to_pos_sent = defaultdict(list)
item_to_neg_sent = defaultdict(list)
item_to_compound_sent = defaultdict(list)

# loadding
for user_id, reviews in user_to_reviews.items():
    for r in reviews:
        # skip if not in train
        if train_X[user_id, r.item_id] == 0:
            continue
        user_to_pos_sent[user_id].append(r.pos_sent)
        user_to_neg_sent[user_id].append(r.neg_sent)
        user_to_compound_sent[user_id].append(r.compound_sent)
        item_to_pos_sent[r.item_id].append(r.pos_sent)
        item_to_neg_sent[r.item_id].append(r.neg_sent)
        item_to_compound_sent[r.item_id].append(r.compound_sent)

In [75]:
# Averaging values to get bias term
def list_mapping_to_float_mapping(hm: dict):
    id_to_sent_term = defaultdict(float)
    for k, v in hm.items():
        id_to_sent_term[k] = np.mean(v)
    return id_to_sent_term

user_to_pos_sent_term = list_mapping_to_float_mapping(user_to_pos_sent)
user_to_neg_sent_term = list_mapping_to_float_mapping(user_to_neg_sent)
user_to_compound_sent_term = list_mapping_to_float_mapping(user_to_compound_sent)
item_to_pos_sent_term = list_mapping_to_float_mapping(item_to_pos_sent)
item_to_neg_sent_term = list_mapping_to_float_mapping(item_to_neg_sent)
item_to_compound_sent_term = list_mapping_to_float_mapping(item_to_compound_sent)

## Vanilla NCF (w/ MLP)

In [150]:
def l2_regularization(values):
    return torch.sum(torch.square(values))

class VanillaNCF(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim=20, regularization_constant=1e-6, eps=1e-8):
        super().__init__()
        self.user_factors = nn.Embedding(num_users, embedding_dim)
        self.item_factors = nn.Embedding(num_items, embedding_dim)
        self.regularization_constant = regularization_constant
        self.eps = eps
        
        # NCF layers
        self.fc1 = nn.Linear(2*embedding_dim, 128)
        self.fc2 = nn.Linear(128, 10)
        self.fc3 = nn.Linear(10, 1)
        
        
    def forward(self, user: torch.LongTensor, item: torch.LongTensor):
        # user is shape (users, 1)
        # item is shape (items, 1)
        # embedding output shape is (*, emb_dim) = (users/items, emb_dim)
        user_latent_factors = self.user_factors(user)
        item_latent_factors = self.item_factors(item)
        
        # FC takes (*, in_dim) and outputs (*, out_dim)
        output = self.fc1(torch.cat((user_latent_factors, item_latent_factors), dim=1))
        output = self.fc2(output)
        output = self.fc(output)
        
        # Clip in the desired range
        pred_rating = 1 + 9 * torch.sigmoid(output)        
        return pred_rating.diagonal()
    
    def loss(self, pred_rating: torch.LongTensor, rating: torch.LongTensor, rmse=False):
        if rmse:
            loss = torch.sqrt(F.mse_loss(pred_rating, rating) + self.eps)
        else:
            loss = F.mse_loss(pred_rating, rating) + self.eps
        
        
        # L2 Regularization
        sum_of_squared_values = l2_regularization(self.user_factors.weight) + l2_regularization(self.item_factors.weight)
        l2_penalty = (1/len(rating)) * self.regularization_constant * sum_of_squared_values
        
        # Total Loss
        total_loss = loss + l2_penalty
        return total_loss
    
    def predict_single_interaction(self, user_id: int, item_id: int):
        user = torch.LongTensor([user_id])
        item = torch.LongTensor([item_id])
        return self.forward(user, item)

In [122]:
def eval_MSE_loss(eval_X, model, round_digits=3):
    """Uses reduction mean"""
    user_ids_list, item_ids_list = eval_X.nonzero()
    gt_ratings = torch.FloatTensor([eval_X[user_id, item_id] for user_id, item_id in zip(user_ids_list, item_ids_list)])
    curr_users_tensor = torch.LongTensor(user_ids_list)
    curr_items_tensor = torch.LongTensor(item_ids_list)
    pred_ratings = model.forward(curr_users_tensor, curr_items_tensor)
    
    return round(F.mse_loss(pred_ratings, gt_ratings).item(), 3)

def eval_RMSE_loss(eval_X, model):
    """Uses reduction mean"""
    user_ids_list, item_ids_list = eval_X.nonzero()
    gt_ratings = torch.FloatTensor([eval_X[user_id, item_id] for user_id, item_id in zip(user_ids_list, item_ids_list)])
    curr_users_tensor = torch.LongTensor(user_ids_list)
    curr_items_tensor = torch.LongTensor(item_ids_list)
    pred_ratings = model.forward(curr_users_tensor, curr_items_tensor)
    
    return round(torch.sqrt(F.mse_loss(pred_ratings, gt_ratings)).item(), 3)

def eval_MAE_loss(eval_X, model):
    """Uses reduction mean"""
    user_ids_list, item_ids_list = eval_X.nonzero()
    gt_ratings = torch.FloatTensor([eval_X[user_id, item_id] for user_id, item_id in zip(user_ids_list, item_ids_list)])
    curr_users_tensor = torch.LongTensor(user_ids_list)
    curr_items_tensor = torch.LongTensor(item_ids_list)
    pred_ratings = model.forward(curr_users_tensor, curr_items_tensor)
    
    return round(F.l1_loss(pred_ratings, gt_ratings).item(), 3)


def train_v2(train_X, valid_X, model, optimizer, n_epochs=10, batch_size=5, rmse=False):
    """Training Function, calculates training and validation loss"""
    
    for epoch in (range(1, n_epochs+1)):
        users, items = train_X.nonzero()
        num_examples = len(users)
        permuted_indices = np.random.permutation(num_examples)
        users, items = users[permuted_indices], items[permuted_indices]
        

        total_train_loss = 0
        
        
        for i in tqdm(range(num_examples // batch_size)):
            user_ids_list = users[i*batch_size:i*batch_size+batch_size]
            item_ids_list = items[i*batch_size:i*batch_size+batch_size]

            # Set gradients to zero
            optimizer.zero_grad()

            # Turn data into tensors
            rating = torch.FloatTensor([train_X[user_id, item_id] for user_id, item_id in zip(user_ids_list, item_ids_list)])
            curr_users_tensor = torch.LongTensor(user_ids_list)
            curr_items_tensor = torch.LongTensor(item_ids_list)

            # Predict and calculate loss
            pred_rating = model.forward(curr_users_tensor, curr_items_tensor)
            assert pred_rating.shape == rating.shape
            
            ## SELECTING LOSS HERE
            # loss = model.loss(pred_rating, rating)
            loss = model.loss(pred_rating, rating, rmse=rmse)

            # Backpropagate
            loss.backward()

            # Update the parameters
            optimizer.step()

            # MSE Loss w/o regularization (just for status updates)
            total_train_loss += F.mse_loss(pred_rating, rating, reduction='sum')

        # Computing validation loss for display
        total_valid_loss = eval_MSE_loss(valid_X, model)
        total_valid_RMSE_loss = eval_RMSE_loss(valid_X, model)
        total_valid_MAE_loss = eval_MAE_loss(valid_X, model)
        
        print(f"Epoch {epoch} MSE Loss: {round(total_train_loss.item() / (batch_size * (num_examples//batch_size)), 3)}, valid MSE Loss: {total_valid_loss}, valid RMSE Loss: {total_valid_RMSE_loss}, valid MAE Loss: {total_valid_MAE_loss}")

## Config Cuda

In [None]:
!nvidia-smi

In [9]:
## Forcing GPU
assert torch.cuda.is_available()
torch.cuda.set_device("cuda:5")
device = torch.device("cuda")
a = torch.tensor([[1., 2.], [3., 4.]]).to(device)
a

tensor([[1., 2.],
        [3., 4.]], device='cuda:5')

## Training

In [157]:
# these parameter settings are pretty good, just adjust LR after you get low enough => MSE of 3.6)
# batch_size=64
# Adam
# weight decay in regularization constant
embedding_dim=200
lr=1e-2
regularization_constant=1e-2
sentiment_regularization_constant=0

model = SentimentMF(num_users=X.shape[0], num_items=X.shape[1], 
                    embedding_dim=embedding_dim, 
                    regularization_constant=regularization_constant,
                    sentiment_regularization_constant=sentiment_regularization_constant
                   )
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [1]:
# embedding_dim=400
# lr=1e-3
# regularization_constant=1e-2
# sentiment_regularization_constant=0

# model = SentimentMF(num_users=X.shape[0], num_items=X.shape[1], 
#                     embedding_dim=embedding_dim, 
#                     regularization_constant=regularization_constant,
#                     sentiment_regularization_constant=sentiment_regularization_constant
#                    )
# optimizer = torch.optim.Adam(model.parameters(), lr=lr)

## Testing Above

In [211]:
# On-the-fly modifications
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model.regularization_constant = .011

In [214]:
train_v2(train_X, valid_X, model, optimizer, n_epochs=100, batch_size=32, rmse=True)

100%|███████████████████████████████████| 1017/1017 [00:12<00:00, 83.70it/s]


Epoch 1 MSE Loss: 1.187, valid MSE Loss: 2.99, valid RMSE Loss: 1.729, valid MAE Loss: 1.296


100%|███████████████████████████████████| 1017/1017 [00:12<00:00, 83.17it/s]


Epoch 2 MSE Loss: 1.496, valid MSE Loss: 2.978, valid RMSE Loss: 1.726, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:12<00:00, 84.20it/s]


Epoch 3 MSE Loss: 1.703, valid MSE Loss: 2.973, valid RMSE Loss: 1.724, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 84.92it/s]


Epoch 4 MSE Loss: 1.845, valid MSE Loss: 2.969, valid RMSE Loss: 1.723, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:12<00:00, 84.67it/s]


Epoch 5 MSE Loss: 1.929, valid MSE Loss: 2.968, valid RMSE Loss: 1.723, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.14it/s]


Epoch 6 MSE Loss: 1.983, valid MSE Loss: 2.975, valid RMSE Loss: 1.725, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.30it/s]


Epoch 7 MSE Loss: 2.02, valid MSE Loss: 2.978, valid RMSE Loss: 1.726, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.13it/s]


Epoch 8 MSE Loss: 2.042, valid MSE Loss: 2.977, valid RMSE Loss: 1.725, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.08it/s]


Epoch 9 MSE Loss: 2.059, valid MSE Loss: 2.978, valid RMSE Loss: 1.726, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.78it/s]


Epoch 10 MSE Loss: 2.069, valid MSE Loss: 2.977, valid RMSE Loss: 1.725, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.44it/s]


Epoch 11 MSE Loss: 2.073, valid MSE Loss: 2.977, valid RMSE Loss: 1.725, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.80it/s]


Epoch 12 MSE Loss: 2.075, valid MSE Loss: 2.979, valid RMSE Loss: 1.726, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.31it/s]


Epoch 13 MSE Loss: 2.077, valid MSE Loss: 2.981, valid RMSE Loss: 1.727, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.33it/s]


Epoch 14 MSE Loss: 2.075, valid MSE Loss: 2.982, valid RMSE Loss: 1.727, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.31it/s]


Epoch 15 MSE Loss: 2.074, valid MSE Loss: 2.983, valid RMSE Loss: 1.727, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.57it/s]


Epoch 16 MSE Loss: 2.073, valid MSE Loss: 2.983, valid RMSE Loss: 1.727, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.13it/s]


Epoch 17 MSE Loss: 2.07, valid MSE Loss: 2.978, valid RMSE Loss: 1.726, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.04it/s]


Epoch 18 MSE Loss: 2.068, valid MSE Loss: 2.981, valid RMSE Loss: 1.727, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.90it/s]


Epoch 19 MSE Loss: 2.065, valid MSE Loss: 2.982, valid RMSE Loss: 1.727, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.58it/s]


Epoch 20 MSE Loss: 2.063, valid MSE Loss: 2.985, valid RMSE Loss: 1.728, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.51it/s]


Epoch 21 MSE Loss: 2.06, valid MSE Loss: 2.984, valid RMSE Loss: 1.728, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 84.84it/s]


Epoch 22 MSE Loss: 2.059, valid MSE Loss: 2.985, valid RMSE Loss: 1.728, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.62it/s]


Epoch 23 MSE Loss: 2.056, valid MSE Loss: 2.987, valid RMSE Loss: 1.728, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 87.97it/s]


Epoch 24 MSE Loss: 2.053, valid MSE Loss: 2.982, valid RMSE Loss: 1.727, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.20it/s]


Epoch 25 MSE Loss: 2.049, valid MSE Loss: 2.986, valid RMSE Loss: 1.728, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.29it/s]


Epoch 26 MSE Loss: 2.048, valid MSE Loss: 2.987, valid RMSE Loss: 1.728, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.58it/s]


Epoch 27 MSE Loss: 2.045, valid MSE Loss: 2.985, valid RMSE Loss: 1.728, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.28it/s]


Epoch 28 MSE Loss: 2.045, valid MSE Loss: 2.987, valid RMSE Loss: 1.728, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.27it/s]


Epoch 29 MSE Loss: 2.042, valid MSE Loss: 2.988, valid RMSE Loss: 1.728, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.63it/s]


Epoch 30 MSE Loss: 2.039, valid MSE Loss: 2.987, valid RMSE Loss: 1.728, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 87.23it/s]


Epoch 31 MSE Loss: 2.033, valid MSE Loss: 2.987, valid RMSE Loss: 1.728, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.24it/s]


Epoch 32 MSE Loss: 2.035, valid MSE Loss: 2.99, valid RMSE Loss: 1.729, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.61it/s]


Epoch 33 MSE Loss: 2.031, valid MSE Loss: 2.987, valid RMSE Loss: 1.728, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.72it/s]


Epoch 34 MSE Loss: 2.03, valid MSE Loss: 2.988, valid RMSE Loss: 1.729, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 87.06it/s]


Epoch 35 MSE Loss: 2.029, valid MSE Loss: 2.993, valid RMSE Loss: 1.73, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.54it/s]


Epoch 36 MSE Loss: 2.025, valid MSE Loss: 2.988, valid RMSE Loss: 1.729, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.24it/s]


Epoch 37 MSE Loss: 2.022, valid MSE Loss: 2.99, valid RMSE Loss: 1.729, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.30it/s]


Epoch 38 MSE Loss: 2.021, valid MSE Loss: 2.99, valid RMSE Loss: 1.729, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.10it/s]


Epoch 39 MSE Loss: 2.019, valid MSE Loss: 2.991, valid RMSE Loss: 1.729, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 87.09it/s]


Epoch 40 MSE Loss: 2.018, valid MSE Loss: 2.991, valid RMSE Loss: 1.729, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.97it/s]


Epoch 41 MSE Loss: 2.017, valid MSE Loss: 2.991, valid RMSE Loss: 1.729, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.44it/s]


Epoch 42 MSE Loss: 2.011, valid MSE Loss: 2.993, valid RMSE Loss: 1.73, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 86.78it/s]


Epoch 43 MSE Loss: 2.011, valid MSE Loss: 2.988, valid RMSE Loss: 1.729, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 87.21it/s]


Epoch 44 MSE Loss: 2.01, valid MSE Loss: 2.993, valid RMSE Loss: 1.73, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 85.95it/s]


Epoch 45 MSE Loss: 2.01, valid MSE Loss: 2.992, valid RMSE Loss: 1.73, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.65it/s]


Epoch 46 MSE Loss: 2.007, valid MSE Loss: 2.992, valid RMSE Loss: 1.73, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.25it/s]


Epoch 47 MSE Loss: 2.007, valid MSE Loss: 2.994, valid RMSE Loss: 1.73, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.94it/s]


Epoch 48 MSE Loss: 2.003, valid MSE Loss: 2.992, valid RMSE Loss: 1.73, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.60it/s]


Epoch 49 MSE Loss: 2.0, valid MSE Loss: 2.995, valid RMSE Loss: 1.731, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.85it/s]


Epoch 50 MSE Loss: 2.002, valid MSE Loss: 2.994, valid RMSE Loss: 1.73, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.41it/s]


Epoch 51 MSE Loss: 1.999, valid MSE Loss: 2.993, valid RMSE Loss: 1.73, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 87.52it/s]


Epoch 52 MSE Loss: 1.997, valid MSE Loss: 2.995, valid RMSE Loss: 1.731, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.09it/s]


Epoch 53 MSE Loss: 1.997, valid MSE Loss: 2.996, valid RMSE Loss: 1.731, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.64it/s]


Epoch 54 MSE Loss: 1.995, valid MSE Loss: 2.995, valid RMSE Loss: 1.73, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.30it/s]


Epoch 55 MSE Loss: 1.993, valid MSE Loss: 2.994, valid RMSE Loss: 1.73, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.21it/s]


Epoch 56 MSE Loss: 1.994, valid MSE Loss: 2.995, valid RMSE Loss: 1.731, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.20it/s]


Epoch 57 MSE Loss: 1.992, valid MSE Loss: 2.995, valid RMSE Loss: 1.731, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 90.06it/s]


Epoch 58 MSE Loss: 1.988, valid MSE Loss: 2.997, valid RMSE Loss: 1.731, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.08it/s]


Epoch 59 MSE Loss: 1.988, valid MSE Loss: 2.996, valid RMSE Loss: 1.731, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.22it/s]


Epoch 60 MSE Loss: 1.986, valid MSE Loss: 2.999, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.52it/s]


Epoch 61 MSE Loss: 1.985, valid MSE Loss: 2.997, valid RMSE Loss: 1.731, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.37it/s]


Epoch 62 MSE Loss: 1.983, valid MSE Loss: 2.998, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.56it/s]


Epoch 63 MSE Loss: 1.982, valid MSE Loss: 3.001, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 90.35it/s]


Epoch 64 MSE Loss: 1.983, valid MSE Loss: 2.998, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.97it/s]


Epoch 65 MSE Loss: 1.981, valid MSE Loss: 2.999, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.02it/s]


Epoch 66 MSE Loss: 1.981, valid MSE Loss: 3.001, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.24it/s]


Epoch 67 MSE Loss: 1.977, valid MSE Loss: 3.0, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.25it/s]


Epoch 68 MSE Loss: 1.978, valid MSE Loss: 2.999, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.20it/s]


Epoch 69 MSE Loss: 1.976, valid MSE Loss: 3.0, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.63it/s]


Epoch 70 MSE Loss: 1.977, valid MSE Loss: 3.001, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.90it/s]


Epoch 71 MSE Loss: 1.974, valid MSE Loss: 3.002, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.67it/s]


Epoch 72 MSE Loss: 1.973, valid MSE Loss: 3.002, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.30it/s]


Epoch 73 MSE Loss: 1.972, valid MSE Loss: 3.0, valid RMSE Loss: 1.732, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.38it/s]


Epoch 74 MSE Loss: 1.971, valid MSE Loss: 3.001, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.50it/s]


Epoch 75 MSE Loss: 1.97, valid MSE Loss: 3.002, valid RMSE Loss: 1.733, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.47it/s]


Epoch 76 MSE Loss: 1.97, valid MSE Loss: 3.003, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.99it/s]


Epoch 77 MSE Loss: 1.969, valid MSE Loss: 3.003, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 87.98it/s]


Epoch 78 MSE Loss: 1.969, valid MSE Loss: 3.001, valid RMSE Loss: 1.732, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.06it/s]


Epoch 79 MSE Loss: 1.967, valid MSE Loss: 3.002, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.90it/s]


Epoch 80 MSE Loss: 1.965, valid MSE Loss: 3.001, valid RMSE Loss: 1.732, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.12it/s]


Epoch 81 MSE Loss: 1.965, valid MSE Loss: 3.003, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.29it/s]


Epoch 82 MSE Loss: 1.962, valid MSE Loss: 3.003, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.43it/s]


Epoch 83 MSE Loss: 1.961, valid MSE Loss: 3.004, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.76it/s]


Epoch 84 MSE Loss: 1.961, valid MSE Loss: 3.004, valid RMSE Loss: 1.733, valid MAE Loss: 1.293


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.06it/s]


Epoch 85 MSE Loss: 1.962, valid MSE Loss: 3.005, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.06it/s]


Epoch 86 MSE Loss: 1.961, valid MSE Loss: 3.005, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.67it/s]


Epoch 87 MSE Loss: 1.959, valid MSE Loss: 3.007, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.50it/s]


Epoch 88 MSE Loss: 1.959, valid MSE Loss: 3.005, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.87it/s]


Epoch 89 MSE Loss: 1.959, valid MSE Loss: 3.002, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.17it/s]


Epoch 90 MSE Loss: 1.958, valid MSE Loss: 3.003, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.90it/s]


Epoch 91 MSE Loss: 1.957, valid MSE Loss: 3.004, valid RMSE Loss: 1.733, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.75it/s]


Epoch 92 MSE Loss: 1.955, valid MSE Loss: 3.005, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.00it/s]


Epoch 93 MSE Loss: 1.956, valid MSE Loss: 3.006, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.97it/s]


Epoch 94 MSE Loss: 1.955, valid MSE Loss: 3.008, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 89.77it/s]


Epoch 95 MSE Loss: 1.955, valid MSE Loss: 3.006, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.40it/s]


Epoch 96 MSE Loss: 1.954, valid MSE Loss: 3.007, valid RMSE Loss: 1.734, valid MAE Loss: 1.295


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.75it/s]


Epoch 97 MSE Loss: 1.954, valid MSE Loss: 3.006, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.98it/s]


Epoch 98 MSE Loss: 1.952, valid MSE Loss: 3.006, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 87.83it/s]


Epoch 99 MSE Loss: 1.951, valid MSE Loss: 3.007, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


100%|███████████████████████████████████| 1017/1017 [00:11<00:00, 88.49it/s]


Epoch 100 MSE Loss: 1.951, valid MSE Loss: 3.007, valid RMSE Loss: 1.734, valid MAE Loss: 1.294


## Evaluation

In [217]:
total_test_loss = eval_MSE_loss(test_X, model)
total_test_RMSE_loss = eval_RMSE_loss(test_X, model)
total_test_MAE_loss = eval_MAE_loss(test_X, model)
print(f"test MSE Loss: {total_test_loss}, test RMSE Loss: {total_test_RMSE_loss}, test MAE Loss: {total_test_MAE_loss}")

test MSE Loss: 2.912, test RMSE Loss: 1.707, test MAE Loss: 1.273
