In [1]:
# Mount Google Drive so the data files under /content/drive/MyDrive are readable.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [2]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import json
import pickle

In [3]:
# Interaction log: one row per (user, track). Shape printed later shows 4 columns,
# including user_id, track_id and listen_count_bin (the regression target).
LISTEN_COUNT_CSV = '/content/drive/MyDrive/집교 2_Team P/user-track-listen_count_filtered5.csv'
df = pd.read_csv(LISTEN_COUNT_CSV)

In [4]:
# Load the precomputed roberta-emotion lyrics embeddings.
# NOTE(review): pickle.load can execute arbitrary code -- only load files from a trusted source.
EMBEDDING_PKL = '/content/drive/MyDrive/집교 2_Team P/lyrics_Embedding/all_embeddings_roberta-emotion_full.pkl'
with open(EMBEDDING_PKL, 'rb') as fh:
    data = pickle.load(fh)

# Build a two-column frame; the column order assumes each record is an
# (embedding, track_id) pair -- TODO confirm against the embedding export script.
# track_id is cast to int so it can be joined with the interaction log.
df_embedding = pd.DataFrame(data, columns=['embedding', 'track_id']).astype({'track_id': int})

print(df_embedding.head())


                                           embedding  track_id
0  [-0.43970332, 0.17358397, 0.52987355, 0.215903...         2
1  [-0.91214895, -0.50413597, -0.6426596, 0.88957...         8
2  [-0.35779303, -0.21245702, -0.4350589, 1.27470...      1524
3  [-0.5116009, 0.5157131, 0.46193805, -0.3704641...      1785
4  [-1.1975511, 0.5550555, -0.049068704, 0.417929...      1787


In [5]:
# Attach each interaction's lyrics embedding. The inner join drops interactions
# whose track has no embedding (compare the two printed shapes).
# validate='many_to_one' raises MergeError if a track_id appears more than once
# in df_embedding; without it, duplicates would silently multiply interactions.
print(df.shape)
df = pd.merge(df, df_embedding, on='track_id', how='inner', validate='many_to_one')
df.shape

(4645010, 4)


(4644051, 5)

In [6]:
from sklearn.preprocessing import LabelEncoder

# Map raw ids to contiguous 0..N-1 indices, as required by nn.Embedding.
user_encoder = LabelEncoder()
track_encoder = LabelEncoder()
df['user_id'] = user_encoder.fit_transform(df['user_id'])
df['track_id'] = track_encoder.fit_transform(df['track_id'])

# The lyrics table must live in the SAME index space as df['track_id'].
# A separate LabelEncoder fitted on df_embedding assigns codes by rank within
# that frame's own track set. Embeddings with no interactions are absent from df
# after the inner merge, so those codes would not match track_encoder's.
# Keep only tracks df knows about and encode them with track_encoder.
lyrics_encoder = track_encoder  # alias kept for any later references to lyrics_encoder
df_embedding = df_embedding[df_embedding['track_id'].isin(track_encoder.classes_)].copy()
df_embedding['track_id'] = track_encoder.transform(df_embedding['track_id'])

In [7]:
# Lookup table: df_embedding's track_id -> lyrics embedding vector (later duplicates win).
# NOTE(review): not used by the training code below -- the embedding already travels in `df`.
lyrics_dict = {track: vec for track, vec in zip(df_embedding['track_id'], df_embedding['embedding'])}

In [8]:
# !pip install torch torchvision -U

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import pandas as pd
import numpy as np
from tqdm import tqdm

# Convert interaction DataFrames into tensors for PyTorch DataLoaders.
def df_to_tensor(dataset):
    """Convert an interaction DataFrame into tensors for a TensorDataset.

    Args:
        dataset: DataFrame with integer ``user_id`` / ``track_id`` columns, a
            numeric ``listen_count_bin`` target, and an ``embedding`` column whose
            cells are equal-length 1-D vectors.

    Returns:
        Tuple ``(users, items, ratings, lyrics_embeddings)``:
        - ``users``, ``items``: int64 index tensors of shape (n_rows,).
        - ``ratings``: float32 target tensor of shape (n_rows,).
        - ``lyrics_embeddings``: float32 tensor of shape (n_rows, embedding_dim).
    """
    # int64 ("long") is the canonical index dtype for nn.Embedding lookups.
    users = torch.tensor(dataset['user_id'].to_numpy(), dtype=torch.long)
    items = torch.tensor(dataset['track_id'].to_numpy(), dtype=torch.long)
    ratings = torch.tensor(dataset['listen_count_bin'].to_numpy(), dtype=torch.float)
    # With millions of rows this matrix is many GB. np.vstack already returns a
    # fresh array, so cast it to float32 once and wrap it with from_numpy.
    # torch.tensor would make a second full copy.
    stacked = np.vstack(dataset['embedding'].to_numpy()).astype(np.float32, copy=False)
    lyrics_embeddings = torch.from_numpy(stacked)
    return users, items, ratings, lyrics_embeddings

# Reproducibility knobs for the split and the shuffle order.
SEED = 42
BATCH_SIZE = 256

# Random 80/20 split over interaction rows; a user or track may appear in both parts.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=SEED)

train_users, train_items, train_ratings, train_lyrics_embeddings = df_to_tensor(train_df)
test_users, test_items, test_ratings, test_lyrics_embeddings = df_to_tensor(test_df)

train_data = TensorDataset(train_users, train_items, train_ratings, train_lyrics_embeddings)
test_data = TensorDataset(test_users, test_items, test_ratings, test_lyrics_embeddings)

# Without an explicit generator the shuffle order depends on the global torch RNG
# state, so every re-run of the notebook would see a different batch order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                          generator=torch.Generator().manual_seed(SEED))
# Deterministic, unshuffled evaluation order.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# user_id / track_id were label-encoded to 0..N-1 over the whole df, so the
# distinct counts equal the embedding-table sizes.
num_users = df['user_id'].nunique()
num_items = df['track_id'].nunique()
print(num_users)
print(num_items)

23761
28309


In [11]:
# NCF model definition
class NCF_embedding(nn.Module):
    """Neural collaborative filtering regressor with a lyrics-embedding side input.

    User and item ids are embedded, the precomputed lyrics vector is projected to
    the same width, and the three are concatenated and fed through an MLP that
    halves its width three times before a single linear output.

    Args:
        num_users: size of the user embedding table (user ids must be in [0, num_users)).
        num_items: size of the item embedding table (item ids must be in [0, num_items)).
        embedding_size: width of each of the three concatenated embeddings. Must be
            >= 8, otherwise the embedding_size // 8 hidden layer would have width 0
            and the model could only output a constant.
        lyrics_dim: width of the precomputed lyrics vector (default 768, the value
            previously hard-coded).

    Raises:
        ValueError: if ``embedding_size`` < 8.
    """

    def __init__(self, num_users, num_items, embedding_size, lyrics_dim=768):
        super().__init__()
        if embedding_size < 8:
            raise ValueError(f'embedding_size must be >= 8, got {embedding_size}')
        self.lyrics_dim = lyrics_dim
        self.user_embedding = nn.Embedding(num_users, embedding_size)
        self.item_embedding = nn.Embedding(num_items, embedding_size)
        self.lyrics_embedding = nn.Linear(lyrics_dim, embedding_size)
        half, quarter, eighth = embedding_size // 2, embedding_size // 4, embedding_size // 8
        self.fc_layers = nn.Sequential(
            nn.Linear(embedding_size * 3, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, half),
            nn.ReLU(),
            nn.Linear(half, quarter),
            nn.ReLU(),
            nn.Linear(quarter, eighth),
            nn.ReLU(),
            nn.Linear(eighth, 1)
        )

    def forward(self, user, item, lyrics_embedding):
        """Predict a score per (user, item) pair.

        Args:
            user, item: 1-D index tensors of length batch.
            lyrics_embedding: tensor reshapeable to (batch, lyrics_dim).

        Returns:
            Tensor of shape (batch, 1).
        """
        user_embedding = self.user_embedding(user)
        item_embedding = self.item_embedding(item)
        lyrics_embedding = self.lyrics_embedding(
            lyrics_embedding.reshape(lyrics_embedding.shape[0], self.lyrics_dim))
        x = torch.cat((user_embedding, item_embedding, lyrics_embedding), dim=1)
        return self.fc_layers(x)

# Select the training device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
def train_embedding(embedding, n_epoch, lr=0.001, seed=None):
    """Train an NCF_embedding model and report train loss / test RMSE per epoch.

    Reads the notebook-level ``train_loader``, ``test_loader``, ``num_users``,
    ``num_items`` and ``device``.

    Args:
        embedding: ``embedding_size`` passed to NCF_embedding.
        n_epoch: number of passes over ``train_loader``.
        lr: Adam learning rate (default 0.001, the value previously hard-coded).
        seed: if not None, ``torch.manual_seed(seed)`` is called before the model
            is built so weight initialisation is reproducible. None keeps the old,
            unseeded behaviour.

    Returns:
        The trained model. Previously the model was discarded when the function
        returned, so a long training run could not be reused.
    """
    if seed is not None:
        torch.manual_seed(seed)
    model = NCF_embedding(num_users=num_users, num_items=num_items, embedding_size=embedding)
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    num_epochs = n_epoch
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for user, item, rating, lyrics_embedding in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            optimizer.zero_grad()
            user, item = user.to(device), item.to(device)
            rating, lyrics_embedding = rating.to(device), lyrics_embedding.to(device)
            output = model(user, item, lyrics_embedding.float())
            loss = criterion(output, rating.unsqueeze(1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Mean of per-batch mean losses; max(..., 1) guards an empty loader.
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f'Epoch {epoch+1}/{num_epochs}, Avg. Loss: {avg_loss:.4f}')

        # Evaluate on the test set after every epoch.
        model.eval()
        all_predictions = []
        all_targets = []
        with torch.no_grad():
            for user, item, rating, lyrics_embedding in tqdm(test_loader, desc=f'Testing Epoch {epoch+1}'):
                user, item, lyrics_embedding = user.to(device), item.to(device), lyrics_embedding.to(device)
                output = model(user, item, lyrics_embedding.float())
                all_predictions.append(output)
                # Targets stay on the CPU.
                all_targets.append(rating)

        # Targets come from the same batches as the predictions, so the pairing is
        # correct regardless of loader ordering. It no longer depends on test_df's
        # row order. reshape(-1) (not squeeze()) keeps a 1-D array even for a
        # single-sample test set.
        predictions = torch.cat(all_predictions).reshape(-1).cpu().numpy()
        targets = torch.cat(all_targets).reshape(-1).numpy()
        rmse = np.sqrt(mean_squared_error(targets, predictions))
        print(f'Epoch {epoch+1}/{num_epochs}, RMSE on test set: {rmse}')

    return model


Using device: cuda


In [12]:
# embedding_size=64, 24 epochs. In the recorded run the train loss stayed at 1.6426 from
# epoch 2 on and test RMSE stayed ~1.2803, so the network did not learn.
# NOTE(review): possibly dead ReLUs in the narrow (8-unit) last hidden layer -- TODO investigate.
train_embedding(embedding=64, n_epoch=24);

Epoch 1/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.74it/s]


Epoch 1/24, Avg. Loss: 1.9840


Testing Epoch 1: 100%|██████████| 3629/3629 [00:16<00:00, 225.14it/s]


Epoch 1/24, RMSE on test set: 1.280326800840603


Epoch 2/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.53it/s]


Epoch 2/24, Avg. Loss: 1.6426


Testing Epoch 2: 100%|██████████| 3629/3629 [00:16<00:00, 225.59it/s]


Epoch 2/24, RMSE on test set: 1.2803703243199893


Epoch 3/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.12it/s]


Epoch 3/24, Avg. Loss: 1.6426


Testing Epoch 3: 100%|██████████| 3629/3629 [00:15<00:00, 231.41it/s]


Epoch 3/24, RMSE on test set: 1.2803750972477477


Epoch 4/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.75it/s]


Epoch 4/24, Avg. Loss: 1.6426


Testing Epoch 4: 100%|██████████| 3629/3629 [00:15<00:00, 230.08it/s]


Epoch 4/24, RMSE on test set: 1.2803226469744386


Epoch 5/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.71it/s]


Epoch 5/24, Avg. Loss: 1.6426


Testing Epoch 5: 100%|██████████| 3629/3629 [00:15<00:00, 228.48it/s]


Epoch 5/24, RMSE on test set: 1.280320563776937


Epoch 6/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.03it/s]


Epoch 6/24, Avg. Loss: 1.6426


Testing Epoch 6: 100%|██████████| 3629/3629 [00:15<00:00, 230.10it/s]


Epoch 6/24, RMSE on test set: 1.2803145037856176


Epoch 7/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.56it/s]


Epoch 7/24, Avg. Loss: 1.6426


Testing Epoch 7: 100%|██████████| 3629/3629 [00:15<00:00, 230.54it/s]


Epoch 7/24, RMSE on test set: 1.2803240474059676


Epoch 8/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.79it/s]


Epoch 8/24, Avg. Loss: 1.6426


Testing Epoch 8: 100%|██████████| 3629/3629 [00:15<00:00, 230.16it/s]


Epoch 8/24, RMSE on test set: 1.2803188959040028


Epoch 9/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.32it/s]


Epoch 9/24, Avg. Loss: 1.6426


Testing Epoch 9: 100%|██████████| 3629/3629 [00:15<00:00, 230.75it/s]


Epoch 9/24, RMSE on test set: 1.2803213847448607


Epoch 10/24: 100%|██████████| 14513/14513 [01:42<00:00, 140.95it/s]


Epoch 10/24, Avg. Loss: 1.6426


Testing Epoch 10: 100%|██████████| 3629/3629 [00:15<00:00, 231.01it/s]


Epoch 10/24, RMSE on test set: 1.2803284783345075


Epoch 11/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.03it/s]


Epoch 11/24, Avg. Loss: 1.6426


Testing Epoch 11: 100%|██████████| 3629/3629 [00:15<00:00, 230.14it/s]


Epoch 11/24, RMSE on test set: 1.2803263792158612


Epoch 12/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.89it/s]


Epoch 12/24, Avg. Loss: 1.6426


Testing Epoch 12: 100%|██████████| 3629/3629 [00:15<00:00, 227.68it/s]


Epoch 12/24, RMSE on test set: 1.2803214730755774


Epoch 13/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.57it/s]


Epoch 13/24, Avg. Loss: 1.6426


Testing Epoch 13: 100%|██████████| 3629/3629 [00:15<00:00, 232.76it/s]


Epoch 13/24, RMSE on test set: 1.2803695992869133


Epoch 14/24: 100%|██████████| 14513/14513 [01:42<00:00, 140.98it/s]


Epoch 14/24, Avg. Loss: 1.6426


Testing Epoch 14: 100%|██████████| 3629/3629 [00:15<00:00, 228.36it/s]


Epoch 14/24, RMSE on test set: 1.2803182666064268


Epoch 15/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.65it/s]


Epoch 15/24, Avg. Loss: 1.6426


Testing Epoch 15: 100%|██████████| 3629/3629 [00:15<00:00, 229.10it/s]


Epoch 15/24, RMSE on test set: 1.280409498921844


Epoch 16/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.71it/s]


Epoch 16/24, Avg. Loss: 1.6426


Testing Epoch 16: 100%|██████████| 3629/3629 [00:15<00:00, 231.92it/s]


Epoch 16/24, RMSE on test set: 1.2803204514522328


Epoch 17/24: 100%|██████████| 14513/14513 [01:42<00:00, 140.93it/s]


Epoch 17/24, Avg. Loss: 1.6426


Testing Epoch 17: 100%|██████████| 3629/3629 [00:15<00:00, 232.47it/s]


Epoch 17/24, RMSE on test set: 1.2803161403045729


Epoch 18/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.10it/s]


Epoch 18/24, Avg. Loss: 1.6426


Testing Epoch 18: 100%|██████████| 3629/3629 [00:15<00:00, 230.92it/s]


Epoch 18/24, RMSE on test set: 1.2803186267066449


Epoch 19/24: 100%|██████████| 14513/14513 [01:42<00:00, 140.99it/s]


Epoch 19/24, Avg. Loss: 1.6426


Testing Epoch 19: 100%|██████████| 3629/3629 [00:15<00:00, 231.65it/s]


Epoch 19/24, RMSE on test set: 1.2803175717513893


Epoch 20/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.82it/s]


Epoch 20/24, Avg. Loss: 1.6426


Testing Epoch 20: 100%|██████████| 3629/3629 [00:15<00:00, 233.30it/s]


Epoch 20/24, RMSE on test set: 1.2803206400640315


Epoch 21/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.54it/s]


Epoch 21/24, Avg. Loss: 1.6426


Testing Epoch 21: 100%|██████████| 3629/3629 [00:15<00:00, 230.93it/s]


Epoch 21/24, RMSE on test set: 1.2803297633844934


Epoch 22/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.25it/s]


Epoch 22/24, Avg. Loss: 1.6426


Testing Epoch 22: 100%|██████████| 3629/3629 [00:15<00:00, 230.52it/s]


Epoch 22/24, RMSE on test set: 1.2803531337466083


Epoch 23/24: 100%|██████████| 14513/14513 [01:42<00:00, 141.37it/s]


Epoch 23/24, Avg. Loss: 1.6426


Testing Epoch 23: 100%|██████████| 3629/3629 [00:15<00:00, 230.27it/s]


Epoch 23/24, RMSE on test set: 1.2803371252925089


Epoch 24/24: 100%|██████████| 14513/14513 [01:43<00:00, 140.75it/s]


Epoch 24/24, Avg. Loss: 1.6425


Testing Epoch 24: 100%|██████████| 3629/3629 [00:15<00:00, 233.82it/s]

Epoch 24/24, RMSE on test set: 1.280338415472186





In [13]:
# embedding_size=256, 15 epochs. In the recorded run test RMSE bottomed out at ~1.1188
# (epoch 12) and then crept up while train loss kept falling, i.e. mild overfitting.
train_embedding(embedding=256, n_epoch=15);

Epoch 1/15: 100%|██████████| 14513/14513 [02:26<00:00, 98.82it/s] 


Epoch 1/15, Avg. Loss: 1.3935


Testing Epoch 1: 100%|██████████| 3629/3629 [00:15<00:00, 228.92it/s]


Epoch 1/15, RMSE on test set: 1.1567973015180522


Epoch 2/15: 100%|██████████| 14513/14513 [02:30<00:00, 96.50it/s]


Epoch 2/15, Avg. Loss: 1.3244


Testing Epoch 2: 100%|██████████| 3629/3629 [00:15<00:00, 228.09it/s]


Epoch 2/15, RMSE on test set: 1.1523342730283166


Epoch 3/15: 100%|██████████| 14513/14513 [02:30<00:00, 96.60it/s]


Epoch 3/15, Avg. Loss: 1.3115


Testing Epoch 3: 100%|██████████| 3629/3629 [00:15<00:00, 227.91it/s]


Epoch 3/15, RMSE on test set: 1.14773692535678


Epoch 4/15: 100%|██████████| 14513/14513 [02:29<00:00, 96.83it/s]


Epoch 4/15, Avg. Loss: 1.2928


Testing Epoch 4: 100%|██████████| 3629/3629 [00:15<00:00, 228.14it/s]


Epoch 4/15, RMSE on test set: 1.1416318633487716


Epoch 5/15: 100%|██████████| 14513/14513 [02:30<00:00, 96.60it/s]


Epoch 5/15, Avg. Loss: 1.2709


Testing Epoch 5: 100%|██████████| 3629/3629 [00:15<00:00, 230.34it/s]


Epoch 5/15, RMSE on test set: 1.1387877875798105


Epoch 6/15: 100%|██████████| 14513/14513 [02:30<00:00, 96.57it/s]


Epoch 6/15, Avg. Loss: 1.2522


Testing Epoch 6: 100%|██████████| 3629/3629 [00:15<00:00, 226.84it/s]


Epoch 6/15, RMSE on test set: 1.1343596584981313


Epoch 7/15: 100%|██████████| 14513/14513 [02:30<00:00, 96.59it/s]


Epoch 7/15, Avg. Loss: 1.2337


Testing Epoch 7: 100%|██████████| 3629/3629 [00:15<00:00, 230.74it/s]


Epoch 7/15, RMSE on test set: 1.132007221843524


Epoch 8/15: 100%|██████████| 14513/14513 [02:30<00:00, 96.65it/s] 


Epoch 8/15, Avg. Loss: 1.2139


Testing Epoch 8: 100%|██████████| 3629/3629 [00:15<00:00, 229.10it/s]


Epoch 8/15, RMSE on test set: 1.1276046874639563


Epoch 9/15: 100%|██████████| 14513/14513 [02:29<00:00, 96.94it/s]


Epoch 9/15, Avg. Loss: 1.1936


Testing Epoch 9: 100%|██████████| 3629/3629 [00:15<00:00, 229.06it/s]


Epoch 9/15, RMSE on test set: 1.1239433643186096


Epoch 10/15: 100%|██████████| 14513/14513 [02:29<00:00, 96.76it/s]


Epoch 10/15, Avg. Loss: 1.1730


Testing Epoch 10: 100%|██████████| 3629/3629 [00:15<00:00, 230.30it/s]


Epoch 10/15, RMSE on test set: 1.1255443342335298


Epoch 11/15: 100%|██████████| 14513/14513 [02:29<00:00, 96.90it/s]


Epoch 11/15, Avg. Loss: 1.1520


Testing Epoch 11: 100%|██████████| 3629/3629 [00:15<00:00, 231.77it/s]


Epoch 11/15, RMSE on test set: 1.1205889624097933


Epoch 12/15: 100%|██████████| 14513/14513 [02:30<00:00, 96.61it/s]


Epoch 12/15, Avg. Loss: 1.1305


Testing Epoch 12: 100%|██████████| 3629/3629 [00:15<00:00, 230.68it/s]


Epoch 12/15, RMSE on test set: 1.1187881397081527


Epoch 13/15: 100%|██████████| 14513/14513 [02:29<00:00, 96.89it/s] 


Epoch 13/15, Avg. Loss: 1.1092


Testing Epoch 13: 100%|██████████| 3629/3629 [00:15<00:00, 229.72it/s]


Epoch 13/15, RMSE on test set: 1.1196608687180727


Epoch 14/15: 100%|██████████| 14513/14513 [02:30<00:00, 96.57it/s]


Epoch 14/15, Avg. Loss: 1.0877


Testing Epoch 14: 100%|██████████| 3629/3629 [00:15<00:00, 228.10it/s]


Epoch 14/15, RMSE on test set: 1.1196297129425523


Epoch 15/15: 100%|██████████| 14513/14513 [02:29<00:00, 96.90it/s] 


Epoch 15/15, Avg. Loss: 1.0670


Testing Epoch 15: 100%|██████████| 3629/3629 [00:16<00:00, 226.34it/s]


Epoch 15/15, RMSE on test set: 1.120662276228274


In [14]:
# embedding_size=512, 15 epochs. In the recorded run test RMSE bottomed out at ~1.1186
# (epoch 11) and ended at ~1.1278 while train loss kept falling.
train_embedding(embedding=512, n_epoch=15);

Epoch 1/15: 100%|██████████| 14513/14513 [03:33<00:00, 67.86it/s]


Epoch 1/15, Avg. Loss: 1.3908


Testing Epoch 1: 100%|██████████| 3629/3629 [00:15<00:00, 228.83it/s]


Epoch 1/15, RMSE on test set: 1.1569401327845676


Epoch 2/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.61it/s]


Epoch 2/15, Avg. Loss: 1.3246


Testing Epoch 2: 100%|██████████| 3629/3629 [00:15<00:00, 228.88it/s]


Epoch 2/15, RMSE on test set: 1.1517807784855933


Epoch 3/15: 100%|██████████| 14513/14513 [03:39<00:00, 66.17it/s]


Epoch 3/15, Avg. Loss: 1.3029


Testing Epoch 3: 100%|██████████| 3629/3629 [00:15<00:00, 234.74it/s]


Epoch 3/15, RMSE on test set: 1.1431134920157646


Epoch 4/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.77it/s]


Epoch 4/15, Avg. Loss: 1.2793


Testing Epoch 4: 100%|██████████| 3629/3629 [00:15<00:00, 230.69it/s]


Epoch 4/15, RMSE on test set: 1.1390371211963237


Epoch 5/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.67it/s]


Epoch 5/15, Avg. Loss: 1.2587


Testing Epoch 5: 100%|██████████| 3629/3629 [00:15<00:00, 231.32it/s]


Epoch 5/15, RMSE on test set: 1.1342706166950958


Epoch 6/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.66it/s]


Epoch 6/15, Avg. Loss: 1.2380


Testing Epoch 6: 100%|██████████| 3629/3629 [00:15<00:00, 227.56it/s]


Epoch 6/15, RMSE on test set: 1.1338857243225617


Epoch 7/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.70it/s]


Epoch 7/15, Avg. Loss: 1.2154


Testing Epoch 7: 100%|██████████| 3629/3629 [00:15<00:00, 227.91it/s]


Epoch 7/15, RMSE on test set: 1.1288026744818036


Epoch 8/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.65it/s]


Epoch 8/15, Avg. Loss: 1.1913


Testing Epoch 8: 100%|██████████| 3629/3629 [00:15<00:00, 226.96it/s]


Epoch 8/15, RMSE on test set: 1.124738966379471


Epoch 9/15: 100%|██████████| 14513/14513 [03:38<00:00, 66.56it/s]


Epoch 9/15, Avg. Loss: 1.1666


Testing Epoch 9: 100%|██████████| 3629/3629 [00:16<00:00, 224.90it/s]


Epoch 9/15, RMSE on test set: 1.122235376827447


Epoch 10/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.74it/s]


Epoch 10/15, Avg. Loss: 1.1415


Testing Epoch 10: 100%|██████████| 3629/3629 [00:15<00:00, 228.85it/s]


Epoch 10/15, RMSE on test set: 1.1195787250389535


Epoch 11/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.61it/s]


Epoch 11/15, Avg. Loss: 1.1164


Testing Epoch 11: 100%|██████████| 3629/3629 [00:15<00:00, 228.80it/s]


Epoch 11/15, RMSE on test set: 1.1186011809100993


Epoch 12/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.69it/s]


Epoch 12/15, Avg. Loss: 1.0915


Testing Epoch 12: 100%|██████████| 3629/3629 [00:15<00:00, 229.38it/s]


Epoch 12/15, RMSE on test set: 1.119667766352682


Epoch 13/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.59it/s]


Epoch 13/15, Avg. Loss: 1.0678


Testing Epoch 13: 100%|██████████| 3629/3629 [00:15<00:00, 228.23it/s]


Epoch 13/15, RMSE on test set: 1.12150937266194


Epoch 14/15: 100%|██████████| 14513/14513 [03:37<00:00, 66.58it/s]


Epoch 14/15, Avg. Loss: 1.0453


Testing Epoch 14: 100%|██████████| 3629/3629 [00:15<00:00, 230.04it/s]


Epoch 14/15, RMSE on test set: 1.119627782748126


Epoch 15/15: 100%|██████████| 14513/14513 [03:38<00:00, 66.53it/s]


Epoch 15/15, Avg. Loss: 1.0278


Testing Epoch 15: 100%|██████████| 3629/3629 [00:15<00:00, 229.55it/s]

Epoch 15/15, RMSE on test set: 1.127771929172428





In [15]:
# embedding_size=768, 15 epochs. In the recorded run the best test RMSE was ~1.1165
# (epoch 13), marginally better than 256/512 at roughly twice the epoch time of 256.
train_embedding(embedding=768, n_epoch=15);

Epoch 1/15: 100%|██████████| 14513/14513 [04:47<00:00, 50.56it/s]


Epoch 1/15, Avg. Loss: 1.3895


Testing Epoch 1: 100%|██████████| 3629/3629 [00:15<00:00, 230.77it/s]


Epoch 1/15, RMSE on test set: 1.1560923084702708


Epoch 2/15: 100%|██████████| 14513/14513 [04:50<00:00, 50.04it/s]


Epoch 2/15, Avg. Loss: 1.3250


Testing Epoch 2: 100%|██████████| 3629/3629 [00:15<00:00, 229.57it/s]


Epoch 2/15, RMSE on test set: 1.150080516471623


Epoch 3/15: 100%|██████████| 14513/14513 [04:50<00:00, 49.95it/s]


Epoch 3/15, Avg. Loss: 1.3010


Testing Epoch 3: 100%|██████████| 3629/3629 [00:15<00:00, 232.59it/s]


Epoch 3/15, RMSE on test set: 1.1435084590552862


Epoch 4/15: 100%|██████████| 14513/14513 [04:50<00:00, 50.01it/s]


Epoch 4/15, Avg. Loss: 1.2755


Testing Epoch 4: 100%|██████████| 3629/3629 [00:15<00:00, 230.93it/s]


Epoch 4/15, RMSE on test set: 1.1394366839344108


Epoch 5/15: 100%|██████████| 14513/14513 [04:50<00:00, 49.94it/s]


Epoch 5/15, Avg. Loss: 1.2532


Testing Epoch 5: 100%|██████████| 3629/3629 [00:15<00:00, 231.63it/s]


Epoch 5/15, RMSE on test set: 1.1340000169430002


Epoch 6/15: 100%|██████████| 14513/14513 [04:50<00:00, 49.99it/s]


Epoch 6/15, Avg. Loss: 1.2318


Testing Epoch 6: 100%|██████████| 3629/3629 [00:15<00:00, 230.56it/s]


Epoch 6/15, RMSE on test set: 1.1289036892323845


Epoch 7/15: 100%|██████████| 14513/14513 [04:50<00:00, 50.01it/s]


Epoch 7/15, Avg. Loss: 1.2106


Testing Epoch 7: 100%|██████████| 3629/3629 [00:15<00:00, 231.06it/s]


Epoch 7/15, RMSE on test set: 1.1291618943981852


Epoch 8/15: 100%|██████████| 14513/14513 [04:50<00:00, 49.96it/s]


Epoch 8/15, Avg. Loss: 1.1897


Testing Epoch 8: 100%|██████████| 3629/3629 [00:15<00:00, 230.51it/s]


Epoch 8/15, RMSE on test set: 1.1238287719776034


Epoch 9/15: 100%|██████████| 14513/14513 [04:50<00:00, 50.01it/s]


Epoch 9/15, Avg. Loss: 1.1690


Testing Epoch 9: 100%|██████████| 3629/3629 [00:15<00:00, 228.23it/s]


Epoch 9/15, RMSE on test set: 1.1207072090052692


Epoch 10/15: 100%|██████████| 14513/14513 [04:50<00:00, 50.00it/s]


Epoch 10/15, Avg. Loss: 1.1482


Testing Epoch 10: 100%|██████████| 3629/3629 [00:15<00:00, 227.49it/s]


Epoch 10/15, RMSE on test set: 1.119987068171083


Epoch 11/15: 100%|██████████| 14513/14513 [04:50<00:00, 49.96it/s]


Epoch 11/15, Avg. Loss: 1.1276


Testing Epoch 11: 100%|██████████| 3629/3629 [00:15<00:00, 227.51it/s]


Epoch 11/15, RMSE on test set: 1.1172914651053485


Epoch 12/15: 100%|██████████| 14513/14513 [04:50<00:00, 50.01it/s]


Epoch 12/15, Avg. Loss: 1.1077


Testing Epoch 12: 100%|██████████| 3629/3629 [00:15<00:00, 230.40it/s]


Epoch 12/15, RMSE on test set: 1.1193985496026186


Epoch 13/15: 100%|██████████| 14513/14513 [04:50<00:00, 49.92it/s]


Epoch 13/15, Avg. Loss: 1.0879


Testing Epoch 13: 100%|██████████| 3629/3629 [00:15<00:00, 230.65it/s]


Epoch 13/15, RMSE on test set: 1.116467051483191


Epoch 14/15: 100%|██████████| 14513/14513 [04:51<00:00, 49.86it/s]


Epoch 14/15, Avg. Loss: 1.0682


Testing Epoch 14: 100%|██████████| 3629/3629 [00:15<00:00, 230.33it/s]


Epoch 14/15, RMSE on test set: 1.1201817049666423


Epoch 15/15: 100%|██████████| 14513/14513 [04:50<00:00, 49.97it/s]


Epoch 15/15, Avg. Loss: 1.0496


Testing Epoch 15: 100%|██████████| 3629/3629 [00:15<00:00, 232.53it/s]

Epoch 15/15, RMSE on test set: 1.1186582399967575



