# A Transformer-based recommendation system


## Introduction


This example demonstrates the [Behavior Sequence Transformer (BST) model, by Qiwei Chen et al.](https://arxiv.org/abs/1905.06874), using the Movielens dataset. The BST model leverages the sequential behaviour of the users in watching and rating movies, as well as user profile and movie features, to predict the rating of the user to a target movie.

More precisely, the BST model aims to predict the rating of a target movie by accepting the following inputs:

1. A fixed-length sequence of movie_ids watched by a user.
2. A fixed-length sequence of the ratings for the movies watched by a user.
3. A set of user features, including user_id, sex, occupation, and age_group.
4. A set of genres for each movie in the input sequence and the target movie.
5. A target_movie_id for which to predict the rating.

This example modifies the original BST model in the following ways:

1. We incorporate the movie features (genres) into the processing of the embedding of each movie of the input sequence and the target movie, rather than treating them as "other features" outside the transformer layer.
2. We utilize the ratings of movies in the input sequence, along with the their positions in the sequence, to update them before feeding them into the self-attention layer.

<img src='./images/BST1.png' width='800'>

## Setup

In [1]:
import pandas as pd
import numpy as np
from zipfile import ZipFile
import random
import os
from urllib import request
import sys
import time
from math import sqrt
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
import torchmetrics

from pathlib import Path
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def set_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    np.random.seed(random_seed)
    random.seed(random_seed)

set_seed(72)

## The Datasets

We use the 1M version of the Movielens dataset. The dataset includes around 1 million ratings from 6000 users on 4000 movies, along with some user features, movie genres. In addition, the timestamp of each user-movie rating is provided, which allows creating sequences of movie ratings for each user, as expected by the BST model.

In [4]:
# Download the actual data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip"
# Use the ratings.csv file
annotation_folder = os.path.abspath(".") + "/dataset/"
file_name = "ml-1m.zip"
url = "http://files.grouplens.org/datasets/movielens/ml-1m.zip"
savefile = os.path.join(annotation_folder, file_name)

if not os.path.exists(annotation_folder):
    os.makedirs(annotation_folder)
    print('Folder creation complete!')
else:
    print('The folder already exists.')
    
if not os.path.isfile(savefile):
    request.urlretrieve(url, savefile)
    with ZipFile(savefile, 'r') as zip_ref:
        zip_ref.extractall(annotation_folder)
    print('File creation complete!')
else:
    print('The File already exists.')

The folder already exists.
The File already exists.


In [5]:
users = pd.read_csv(
    os.path.join(annotation_folder, "ml-1m/users.dat"),
    sep="::",
    names=["user_id", "sex", "age_group", "occupation", "zip_code"],
    engine='python'
)

ratings = pd.read_csv(
    os.path.join(annotation_folder, "ml-1m/ratings.dat"),
    sep="::",
    names=["user_id", "movie_id", "rating", "unix_timestamp"],
    engine='python'
)

movies = pd.read_csv(
    os.path.join(annotation_folder, "ml-1m/movies.dat"), 
    sep="::", 
    names=["movie_id", "title", "genres"],
    engine='python'
)

In [6]:
users, ratings, movies

(      user_id sex  age_group  occupation zip_code
 0           1   F          1          10    48067
 1           2   M         56          16    70072
 2           3   M         25          15    55117
 3           4   M         45           7    02460
 4           5   M         25          20    55455
 ...       ...  ..        ...         ...      ...
 6035     6036   F         25          15    32603
 6036     6037   F         45           1    76006
 6037     6038   F         56           1    14706
 6038     6039   F         45           0    01060
 6039     6040   M         25           6    11106
 
 [6040 rows x 5 columns],          user_id  movie_id  rating  unix_timestamp
 0              1      1193       5       978300760
 1              1       661       3       978302109
 2              1       914       3       978301968
 3              1      3408       4       978300275
 4              1      2355       5       978824291
 ...          ...       ...     ...             .

In [7]:
## Movies
movies["year"] = movies["title"].apply(lambda x: x[-5:-1])
movies.year = pd.Categorical(movies.year)
movies["year"] = movies.year.cat.codes
## Users
users.sex = pd.Categorical(users.sex)
users["sex"] = users.sex.cat.codes

users.age_group = pd.Categorical(users.age_group)
users["age_group"] = users.age_group.cat.codes

users.occupation = pd.Categorical(users.occupation)
users["occupation"] = users.occupation.cat.codes

users.zip_code = pd.Categorical(users.zip_code)
users["zip_code"] = users.zip_code.cat.codes

## Ratings
ratings['unix_timestamp'] = pd.to_datetime(ratings['unix_timestamp'],unit='s')


In [8]:
users, ratings, movies

(      user_id  sex  age_group  occupation  zip_code
 0           1    0          0          10      1588
 1           2    1          6          16      2248
 2           3    1          2          15      1863
 3           4    1          4           7       140
 4           5    1          2          20      1938
 ...       ...  ...        ...         ...       ...
 6035     6036    0          2          15      1152
 6036     6037    0          4           1      2367
 6037     6038    0          6           1       626
 6038     6039    0          4           0        13
 6039     6040    1          2           6       466
 
 [6040 rows x 5 columns],
          user_id  movie_id  rating      unix_timestamp
 0              1      1193       5 2000-12-31 22:12:40
 1              1       661       3 2000-12-31 22:35:09
 2              1       914       3 2000-12-31 22:32:48
 3              1      3408       4 2000-12-31 22:04:35
 4              1      2355       5 2001-01-06 23:38:11


In [9]:
# Save primary csv's
users.to_csv("dataset/ml-1m/users.csv",index=False)
movies.to_csv("dataset/ml-1m/movies.csv",index=False)
ratings.to_csv("dataset/ml-1m/ratings.csv",index=False)

In [10]:
## Movies
movies["movie_id"] = movies["movie_id"].astype(str)
## Users
users["user_id"] = users["user_id"].astype(str)

## Ratings 
ratings["movie_id"] = ratings["movie_id"].astype(str)
ratings["user_id"] = ratings["user_id"].astype(str)

In [11]:
genres = [
    "Action",
    "Adventure",
    "Animation",
    "Children's",
    "Comedy",
    "Crime",
    "Documentary",
    "Drama",
    "Fantasy",
    "Film-Noir",
    "Horror",
    "Musical",
    "Mystery",
    "Romance",
    "Sci-Fi",
    "Thriller",
    "War",
    "Western",
]

for genre in genres:
    movies[genre] = movies["genres"].apply(
        lambda values: int(genre in values.split("|"))
    )

In [12]:
movies

Unnamed: 0,movie_id,title,genres,year,Action,Adventure,Animation,Children's,Comedy,Crime,...,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western
0,1,Toy Story (1995),Animation|Children's|Comedy,75,0,0,1,1,1,0,...,0,0,0,0,0,0,0,0,0,0
1,2,Jumanji (1995),Adventure|Children's|Fantasy,75,0,1,0,1,0,0,...,1,0,0,0,0,0,0,0,0,0
2,3,Grumpier Old Men (1995),Comedy|Romance,75,0,0,0,0,1,0,...,0,0,0,0,0,1,0,0,0,0
3,4,Waiting to Exhale (1995),Comedy|Drama,75,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
4,5,Father of the Bride Part II (1995),Comedy,75,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3878,3948,Meet the Parents (2000),Comedy,80,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
3879,3949,Requiem for a Dream (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3880,3950,Tigerland (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3881,3951,Two Family House (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


## Transform the movie ratings data into sequences

First, let's sort the the ratings data using the unix_timestamp, and then group the movie_id values and the rating values by user_id.

The output DataFrame will have a record for each user_id, with two ordered lists (sorted by rating datetime): the movies they have rated, and their ratings of these movies.

In [13]:
ratings_group = ratings.sort_values(by=["unix_timestamp"]).groupby("user_id")

ratings_data = pd.DataFrame(
    data={
        "user_id": list(ratings_group.groups.keys()),
        "movie_ids": list(ratings_group.movie_id.apply(list)),
        "ratings": list(ratings_group.rating.apply(list)),
        "timestamps": list(ratings_group.unix_timestamp.apply(list)),
    }
)


In [14]:
ratings_data

Unnamed: 0,user_id,movie_ids,ratings,timestamps
0,1,"[3186, 1721, 1270, 1022, 2340, 1836, 3408, 120...","[4, 4, 5, 5, 3, 5, 4, 4, 5, 4, 3, 5, 4, 4, 4, ...","[2000-12-31 22:00:19, 2000-12-31 22:00:55, 200..."
1,10,"[597, 858, 743, 1210, 1948, 2312, 3751, 1282, ...","[4, 3, 3, 4, 4, 5, 5, 5, 3, 3, 3, 5, 4, 4, 4, ...","[2000-12-31 00:59:35, 2000-12-31 00:59:35, 200..."
2,100,"[260, 1676, 1198, 541, 1210, 3948, 3536, 2567,...","[4, 3, 4, 3, 4, 3, 1, 1, 5, 4, 4, 3, 2, 3, 4, ...","[2000-12-23 17:46:35, 2000-12-23 17:46:35, 200..."
3,1000,"[971, 260, 2990, 2973, 1210, 3068, 3153, 1198,...","[4, 5, 4, 3, 5, 5, 2, 5, 5, 4, 5, 4, 3, 5, 5, ...","[2000-11-24 04:36:06, 2000-11-24 04:36:06, 200..."
4,1001,"[1198, 1617, 2885, 3909, 3555, 1479, 3903, 394...","[4, 4, 4, 2, 2, 1, 4, 5, 5, 4, 4, 4, 4, 3, 4, ...","[2000-11-24 04:19:51, 2000-11-24 04:21:42, 200..."
...,...,...,...,...
6035,995,"[1894, 260, 247, 433, 170, 74, 912, 3097, 1265...","[2, 4, 5, 3, 3, 4, 4, 4, 3, 5, 5, 5, 5, 5, 5, ...","[2000-11-24 08:33:05, 2000-11-24 08:33:05, 200..."
6036,996,"[1347, 2146, 1961, 2741, 1210, 527, 1196, 1213...","[4, 3, 5, 3, 5, 5, 5, 5, 4, 2, 5, 5, 5, 4, 5, ...","[2000-11-24 07:48:52, 2000-11-24 07:48:52, 200..."
6037,997,"[1196, 2082, 3247, 2447, 2633, 2028, 593, 318,...","[4, 3, 3, 3, 2, 5, 5, 5, 4, 4, 5, 4, 4, 3, 4, ...","[2000-11-24 05:37:15, 2000-11-24 05:40:25, 200..."
6038,998,"[2266, 1264, 1097, 1641, 805, 1388, 1968, 3751...","[3, 4, 5, 5, 4, 3, 4, 3, 4, 4, 4, 4, 5, 5, 4, ...","[2000-11-24 05:24:59, 2000-11-24 05:26:33, 200..."


Now, let's split the movie_ids list into a set of sequences of a fixed length. We do the same for the ratings. Set the sequence_length variable to change the length of the input sequence to the model. You can also change the step_size to control the number of sequences to generate for each user.

In [15]:
sequence_length = 4
step_size = 2


def create_sequences(values, window_size, step_size):
    sequences = []
    start_index = 0
    while True:
        end_index = start_index + window_size
        seq = values[start_index:end_index]
        if len(seq) < window_size:
            seq = values[-window_size:]
            if len(seq) == window_size:
                sequences.append(seq)
            break
        sequences.append(seq)
        start_index += step_size
    return sequences


ratings_data.movie_ids = ratings_data.movie_ids.apply(
    lambda ids: create_sequences(ids, sequence_length, step_size)
)

ratings_data.ratings = ratings_data.ratings.apply(
    lambda ids: create_sequences(ids, sequence_length, step_size)
)

del ratings_data["timestamps"]

In [16]:
ratings_data

Unnamed: 0,user_id,movie_ids,ratings
0,1,"[[3186, 1721, 1270, 1022], [1270, 1022, 2340, ...","[[4, 4, 5, 5], [5, 5, 3, 5], [3, 5, 4, 4], [4,..."
1,10,"[[597, 858, 743, 1210], [743, 1210, 1948, 2312...","[[4, 3, 3, 4], [3, 4, 4, 5], [4, 5, 5, 5], [5,..."
2,100,"[[260, 1676, 1198, 541], [1198, 541, 1210, 394...","[[4, 3, 4, 3], [4, 3, 4, 3], [4, 3, 1, 1], [1,..."
3,1000,"[[971, 260, 2990, 2973], [2990, 2973, 1210, 30...","[[4, 5, 4, 3], [4, 3, 5, 5], [5, 5, 2, 5], [2,..."
4,1001,"[[1198, 1617, 2885, 3909], [2885, 3909, 3555, ...","[[4, 4, 4, 2], [4, 2, 2, 1], [2, 1, 4, 5], [4,..."
...,...,...,...
6035,995,"[[1894, 260, 247, 433], [247, 433, 170, 74], [...","[[2, 4, 5, 3], [5, 3, 3, 4], [3, 4, 4, 4], [4,..."
6036,996,"[[1347, 2146, 1961, 2741], [1961, 2741, 1210, ...","[[4, 3, 5, 3], [5, 3, 5, 5], [5, 5, 5, 5], [5,..."
6037,997,"[[1196, 2082, 3247, 2447], [3247, 2447, 2633, ...","[[4, 3, 3, 3], [3, 3, 2, 5], [2, 5, 5, 5], [5,..."
6038,998,"[[2266, 1264, 1097, 1641], [1097, 1641, 805, 1...","[[3, 4, 5, 5], [5, 5, 4, 3], [4, 3, 4, 3], [4,..."


After that, we process the output to have each sequence in a separate records in the DataFrame. In addition, we join the user features with the ratings data.

In [17]:
ratings_data_movies = ratings_data[["user_id", "movie_ids"]].explode(
    "movie_ids", ignore_index=True
)
ratings_data_rating = ratings_data[["ratings"]].explode("ratings", ignore_index=True)
ratings_data_transformed = pd.concat([ratings_data_movies, ratings_data_rating], axis=1)
ratings_data_transformed = ratings_data_transformed.join(
    users.set_index("user_id"), on="user_id"
)
ratings_data_transformed.movie_ids = ratings_data_transformed.movie_ids.apply(
    lambda x: ",".join(x)
)
ratings_data_transformed.ratings = ratings_data_transformed.ratings.apply(
    lambda x: ",".join([str(v) for v in x])
)

del ratings_data_transformed["zip_code"]

ratings_data_transformed.rename(
    columns={"movie_ids": "sequence_movie_ids", "ratings": "sequence_ratings"},
    inplace=True,
)

In [18]:
ratings_data_transformed

Unnamed: 0,user_id,sequence_movie_ids,sequence_ratings,sex,age_group,occupation
0,1,3186172112701022,4455,0,0,10
1,1,1270102223401836,5535,0,0,10
2,1,2340183634081207,3544,0,0,10
3,1,340812072804260,4454,0,0,10
4,1,28042607201193,5435,0,0,10
...,...,...,...,...,...,...
498618,999,267625401363765,3233,1,2,15
498619,999,136376535651410,3342,1,2,15
498620,999,3565141022692504,4233,1,2,15
498621,999,22692504455193,3322,1,2,15


With sequence_length of 4 and step_size of 2, we end up with 498,623 sequences.

Finally, we split the data into training and testing splits, with 85% and 15% of the instances, respectively, and store them to CSV files.

In [19]:
random_selection = np.random.rand(len(ratings_data_transformed.index)) <= 0.85
train_data = ratings_data_transformed[random_selection]
test_data = ratings_data_transformed[~random_selection]

train_data.to_csv("dataset/ml-1m/train_data.csv", index=False, sep=",")
test_data.to_csv("dataset/ml-1m/test_data.csv", index=False, sep=",")

In [20]:
train_data

Unnamed: 0,user_id,sequence_movie_ids,sequence_ratings,sex,age_group,occupation
0,1,3186172112701022,4455,0,0,10
1,1,1270102223401836,5535,0,0,10
2,1,2340183634081207,3544,0,0,10
3,1,340812072804260,4454,0,0,10
4,1,28042607201193,5435,0,0,10
...,...,...,...,...,...,...
498618,999,267625401363765,3233,1,2,15
498619,999,136376535651410,3342,1,2,15
498620,999,3565141022692504,4233,1,2,15
498621,999,22692504455193,3322,1,2,15


In [21]:
test_data

Unnamed: 0,user_id,sequence_movie_ids,sequence_ratings,sex,age_group,occupation
11,1,20181501097914,4543,0,0,10
33,10,1544213536752657,4454,0,3,1
55,10,2997102832963702,4543,0,3,1
62,10,367114112324150,4455,0,3,1
65,10,3194127133581580,4555,0,3,1
...,...,...,...,...,...,...
498557,999,2932333252313,3233,1,2,15
498580,999,242037011810431,4233,1,2,15
498593,999,1598317434092120,2443,1,2,15
498601,999,24471515724266,3124,1,2,15


## Create Dataloder

In [22]:
users = pd.read_csv(
    "dataset/ml-1m/users.csv",
    sep=",",
    engine='python'
)

ratings = pd.read_csv(
    "dataset/ml-1m/ratings.csv",
    sep=",",
    engine='python'
)

movies = pd.read_csv(
    "dataset/ml-1m/movies.csv", 
    sep=",", 
    engine='python'
)

In [23]:
genres = [
    "Action",
    "Adventure",
    "Animation",
    "Children's",
    "Comedy",
    "Crime",
    "Documentary",
    "Drama",
    "Fantasy",
    "Film-Noir",
    "Horror",
    "Musical",
    "Mystery",
    "Romance",
    "Sci-Fi",
    "Thriller",
    "War",
    "Western",
]

for genre in genres:
    movies[genre] = movies["genres"].apply(
        lambda values: int(genre in values.split("|"))
    )

In [24]:
class CustomDataset(Dataset): 
    def __init__(self, file):
        self.dataset = pd.read_csv(file, delimiter=',')

    def __len__(self): 
        return len(self.dataset)

    def __getitem__(self, index): 
        features = self.dataset.iloc[index]
        user_id = features.user_id
        
        movie_history = eval(features.sequence_movie_ids)
        movie_history_ratings = eval(features.sequence_ratings)
        target_movie_id = movie_history[-1:][0]
        target_movie_rating = movie_history_ratings[-1:][0]
        
        movie_history = torch.LongTensor(movie_history[:-1])
        movie_history_ratings = torch.LongTensor(movie_history_ratings[:-1])
        
        sex = features.sex
        age_group = features.age_group
        occupation = features.occupation
        
        return user_id, movie_history, target_movie_id,  movie_history_ratings, target_movie_rating, sex, age_group, occupation

In [25]:
params = {'batch_size': 128,
          'shuffle': True,
          'num_workers': os.cpu_count(),
          'pin_memory' : True}

train_dataset = CustomDataset("dataset/ml-1m/train_data.csv")
val_dataset = CustomDataset("dataset/ml-1m/test_data.csv")
test_dataset = CustomDataset("dataset/ml-1m/test_data.csv")

trainloader = DataLoader(dataset=train_dataset, **params)
validloader = DataLoader(dataset=val_dataset, **params)
params['shuffle'] = False
testloader = DataLoader(dataset=test_dataset, **params)

In [26]:
movies

Unnamed: 0,movie_id,title,genres,year,Action,Adventure,Animation,Children's,Comedy,Crime,...,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western
0,1,Toy Story (1995),Animation|Children's|Comedy,75,0,0,1,1,1,0,...,0,0,0,0,0,0,0,0,0,0
1,2,Jumanji (1995),Adventure|Children's|Fantasy,75,0,1,0,1,0,0,...,1,0,0,0,0,0,0,0,0,0
2,3,Grumpier Old Men (1995),Comedy|Romance,75,0,0,0,0,1,0,...,0,0,0,0,0,1,0,0,0,0
3,4,Waiting to Exhale (1995),Comedy|Drama,75,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
4,5,Father of the Bride Part II (1995),Comedy,75,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3878,3948,Meet the Parents (2000),Comedy,80,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
3879,3949,Requiem for a Dream (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3880,3950,Tigerland (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3881,3951,Two Family House (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


## BST Model

<img src='./images/BST1.png' width='800'>

In [27]:
class BST(nn.Module):
    def __init__(self, args=None,):
        super(BST, self).__init__()
        self.args = args
        # Embedding layers
        ##Users 
        self.embeddings_user_id = nn.Embedding(
            int(users.user_id.max())+1, int(math.sqrt(users.user_id.max()))+1
        )
        ###Users features embeddings
        self.embeddings_user_sex = nn.Embedding(
            len(users.sex.unique()), int(math.sqrt(len(users.sex.unique())))
        )
        self.embeddings_age_group = nn.Embedding(
            len(users.age_group.unique()), int(math.sqrt(len(users.age_group.unique())))
        )
        self.embeddings_user_occupation = nn.Embedding(
            len(users.occupation.unique()), int(math.sqrt(len(users.occupation.unique())))
        )
        self.embeddings_user_zip_code = nn.Embedding(
            len(users.zip_code.unique()), int(math.sqrt(len(users.sex.unique())))
        )
        
        ##Movies
        self.embeddings_movie_id = nn.Embedding(
            int(movies.movie_id.max())+1, int(math.sqrt(movies.movie_id.max()))+1
        )
        self.embeddings_position  = nn.Embedding(
           sequence_length, int(math.sqrt(len(movies.movie_id.unique())))+1
        )
        ###Movies features embeddings
        genre_vectors = movies[genres].to_numpy()
        self.embeddings_movie_genre = nn.Embedding(
            genre_vectors.shape[0], genre_vectors.shape[1]
        )
        
        self.embeddings_movie_genre.weight.requires_grad = False #Not training genres
        
        
        self.embeddings_movie_year = nn.Embedding(
            len(movies.year.unique()), int(math.sqrt(len(movies.year.unique())))
        )
        
        
        # Network
        self.transfomerlayer = nn.TransformerEncoderLayer(63, 3, dropout=0.2)
        self.linear = nn.Sequential(
            nn.Linear(337,1024,),
            nn.LeakyReLU(),
            nn.Linear(1024, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
        )
        
    def encode_input(self,inputs):
        user_id, movie_history, target_movie_id,  movie_history_ratings, target_movie_rating, sex, age_group, occupation = inputs
        
        #MOVIES
        movie_history = self.embeddings_movie_id(movie_history)
        target_movie = self.embeddings_movie_id(target_movie_id)
        
        positions = torch.arange(0,sequence_length-1,1,dtype=int,device=self.args['device'])
        positions = self.embeddings_position(positions)
        
        encoded_sequence_movies_with_poistion_and_rating = (movie_history + positions) #Yet to multiply by rating
        
        target_movie = torch.unsqueeze(target_movie, 1)
        transfomer_features = torch.cat((encoded_sequence_movies_with_poistion_and_rating, target_movie),dim=1)
        
        #USERS
        user_id = self.embeddings_user_id(user_id)
        
        sex = self.embeddings_user_sex(sex)
        age_group = self.embeddings_age_group(age_group)
        occupation = self.embeddings_user_occupation(occupation)
        user_features = torch.cat((user_id, sex, age_group,occupation), 1)
        
        return transfomer_features, user_features, target_movie_rating.float()
    
    def forward(self, batch):
        transfomer_features, user_features, target_movie_rating = self.encode_input(batch)
        transformer_output = self.transfomerlayer(transfomer_features)
        transformer_output = torch.flatten(transformer_output,start_dim=1)
        
        #Concat with other features
        features = torch.cat((transformer_output,user_features),dim=1)
        
        output = self.linear(features)
        return output, target_movie_rating

In [28]:
model = BST({'device':device}).to(device)

In [29]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_fn = nn.MSELoss()
mae_fn = torchmetrics.MeanAbsoluteError()
mse_fn = torchmetrics.MeanSquaredError()

In [30]:
[i.shape for i in next(iter(trainloader))]

[torch.Size([128]),
 torch.Size([128, 3]),
 torch.Size([128]),
 torch.Size([128, 3]),
 torch.Size([128]),
 torch.Size([128]),
 torch.Size([128]),
 torch.Size([128])]

## Run training and evaluation experiment

In [31]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

In [32]:
def train(model, train_data, optimizer, loss_fn, use_fp16=True, max_norm=None, progress_display=False):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    mses = AverageMeter('MSE', ':.4e')
    maes = AverageMeter('MAE', ':.4e')
    rmses = AverageMeter('RMSE', ':.4e')
    progress = ProgressMeter(
        len(train_data),
        [batch_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    model.train()
    end = time.time()
    for idx, input in enumerate(train_data):
        optimizer.zero_grad(set_to_none=True)
        scaler = torch.cuda.amp.GradScaler()
        
#         input = input.to(device)
        input = [i.to(device) for i in input]
        
        with torch.cuda.amp.autocast(enabled=use_fp16):
            out, target_movie_rating = model(input)
            out = out.flatten()
            train_loss = loss_fn(out, target_movie_rating)
        if use_fp16:
            scaler.scale(train_loss).backward()
            if max_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            train_loss.backward()
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            
        losses.update(train_loss.item(), input[0].size(0))
        maes.update(mae_fn(out.detach().cpu(), target_movie_rating.detach().cpu()), input[0].size(0))
        mses.update(mse_fn(out.detach().cpu(), target_movie_rating.detach().cpu()), input[0].size(0))
        rmses.update(torch.sqrt(mse_fn(out.detach().cpu(), target_movie_rating.detach().cpu())), input[0].size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if progress_display == True and idx % 5 == 0:
            progress.display(idx)
        
    return losses.avg, maes.avg, mses.avg, rmses.avg

In [33]:
def validation(model, val_data, loss_fn):
    losses = AverageMeter('Loss', ':.4e')
    mses = AverageMeter('MSE', ':.4e')
    maes = AverageMeter('MAE', ':.4e')
    rmses = AverageMeter('RMSE', ':.4e')
    model.eval()
    val_loss = 0
    for idx, input in enumerate(val_data):
        input = [i.to(device) for i in input]
        with torch.no_grad():
            out, target_movie_rating = model(input)
            out = out.flatten()
            val_loss = loss_fn(out, target_movie_rating)
        losses.update(val_loss.item(), input[0].size(0))
        maes.update(mae_fn(out.detach().cpu(), target_movie_rating.detach().cpu()), input[0].size(0))
        mses.update(mse_fn(out.detach().cpu(), target_movie_rating.detach().cpu()), input[0].size(0))
        rmses.update(torch.sqrt(mse_fn(out.detach().cpu(), target_movie_rating.detach().cpu())), input[0].size(0))
        
    return losses.avg, maes.avg, mses.avg, rmses.avg

In [34]:
%%time
EPOCHS = 50
best = {"val_loss": sys.float_info.min}
history = dict()

for epoch in range(1, EPOCHS+1):
    epoch_loss, epoch_mae, epoch_mse, epoch_rmse = train(model, trainloader, optimizer, loss_fn, use_fp16=False)
    val_loss, val_mae, val_mse, val_rmse = validation(model, validloader, loss_fn)
    
    history.setdefault('loss', []).append(epoch_loss)
    history.setdefault('mae', []).append(epoch_mae) 
    history.setdefault('mse', []).append(epoch_mse)
    history.setdefault('rmse', []).append(epoch_rmse) 
    history.setdefault('val_loss', []).append(val_loss)
    history.setdefault('val_mae', []).append(val_mae) 
    history.setdefault('val_mse', []).append(val_mse) 
    history.setdefault('val_rmse', []).append(val_rmse) 
    
    print(f"[Train] Epoch : {epoch:^3}"\
        f"  Train Loss: {epoch_loss:.4}"\
        f"  Train MAE: {epoch_mae:.4}"\
        f"  Train MSE: {epoch_mse:.4}"\
        f"  Train RMSE: {epoch_rmse:.4}")
    print(f"[Validation] Valid Loss: {val_loss:.4}"\
        f"  Valid MAE: {val_mae:.4}"\
        f"  Valid MSE: {val_mse:.4}"\
        f"  Valid RMSE: {val_rmse:.4}"
         )


    if val_loss > best["val_loss"]:
            best["state"] = model.state_dict()
            best["epoch"] = epoch

[Train] Epoch :  1   Train Loss: 1.038  Train MAE: 0.8146  Train MSE: 1.038  Train RMSE: 1.012
[Validation] Valid Loss: 1.025  Valid MAE: 0.7812  Valid MSE: 1.025  Valid RMSE: 1.01
[Train] Epoch :  2   Train Loss: 0.8834  Train MAE: 0.7465  Train MSE: 0.8834  Train RMSE: 0.9379
[Validation] Valid Loss: 0.8829  Valid MAE: 0.7321  Valid MSE: 0.8829  Valid RMSE: 0.937
[Train] Epoch :  3   Train Loss: 0.8447  Train MAE: 0.7273  Train MSE: 0.8447  Train RMSE: 0.917
[Validation] Valid Loss: 0.8764  Valid MAE: 0.7267  Valid MSE: 0.8764  Valid RMSE: 0.9338
[Train] Epoch :  4   Train Loss: 0.8201  Train MAE: 0.7157  Train MSE: 0.8201  Train RMSE: 0.9035
[Validation] Valid Loss: 0.9006  Valid MAE: 0.7327  Valid MSE: 0.9006  Valid RMSE: 0.9465
[Train] Epoch :  5   Train Loss: 0.7989  Train MAE: 0.7054  Train MSE: 0.7989  Train RMSE: 0.8917
[Validation] Valid Loss: 0.8577  Valid MAE: 0.7206  Valid MSE: 0.8577  Valid RMSE: 0.9238
[Train] Epoch :  6   Train Loss: 0.776  Train MAE: 0.6948  Train MSE:

[Train] Epoch : 46   Train Loss: 0.05366  Train MAE: 0.1461  Train MSE: 0.05366  Train RMSE: 0.2299
[Validation] Valid Loss: 1.199  Valid MAE: 0.8408  Valid MSE: 1.199  Valid RMSE: 1.093
[Train] Epoch : 47   Train Loss: 0.0536  Train MAE: 0.1444  Train MSE: 0.0536  Train RMSE: 0.2299
[Validation] Valid Loss: 1.21  Valid MAE: 0.8345  Valid MSE: 1.21  Valid RMSE: 1.097
[Train] Epoch : 48   Train Loss: 0.05259  Train MAE: 0.1412  Train MSE: 0.05259  Train RMSE: 0.2275
[Validation] Valid Loss: 1.228  Valid MAE: 0.8426  Valid MSE: 1.228  Valid RMSE: 1.106
[Train] Epoch : 49   Train Loss: 0.05165  Train MAE: 0.1387  Train MSE: 0.05165  Train RMSE: 0.2254
[Validation] Valid Loss: 1.191  Valid MAE: 0.83  Valid MSE: 1.191  Valid RMSE: 1.089
[Train] Epoch : 50   Train Loss: 0.05135  Train MAE: 0.137  Train MSE: 0.05135  Train RMSE: 0.2247
[Validation] Valid Loss: 1.218  Valid MAE: 0.8394  Valid MSE: 1.218  Valid RMSE: 1.101
CPU times: user 37min 4s, sys: 4min 40s, total: 41min 45s
Wall time: 36m

## Inference

In [36]:
model.load_state_dict(best["state"])

<All keys matched successfully>

In [62]:
input = [i.to(device) for i in next(iter(validloader))]

with torch.no_grad():
    out, target_movie_rating = model(input)
    out = out.flatten()

for i in range(10):
    user_id = input[0][i].detach().cpu()
    target_movie_id = input[2][i].detach().cpu()
    target_movie_rating = input[4][i].detach().cpu()
    movie_detail = movies.query(f'movie_id=={target_movie_id}')[['title','genres']]
    print(f'User : {user_id}')
    print(f'Movie Title : {movie_detail["title"].values[0]}')
    print(f'Movie Genres : {movie_detail["genres"].values[0]}')
    print(f'Target rating : {target_movie_rating}, Predicted rating : {out[i].detach().cpu():.4}')
    print('-'*100)

User : 117
Movie Title : Predator 2 (1990)
Movie Genres : Action|Sci-Fi|Thriller
Target rating : 5, Predicted rating : 3.65
----------------------------------------------------------------------------------------------------
User : 3678
Movie Title : Pi (1998)
Movie Genres : Sci-Fi|Thriller
Target rating : 3, Predicted rating : 4.093
----------------------------------------------------------------------------------------------------
User : 3474
Movie Title : Creature Comforts (1990)
Movie Genres : Animation|Comedy
Target rating : 4, Predicted rating : 5.071
----------------------------------------------------------------------------------------------------
User : 5316
Movie Title : American Movie (1999)
Movie Genres : Documentary
Target rating : 5, Predicted rating : 5.098
----------------------------------------------------------------------------------------------------
User : 352
Movie Title : Illuminata (1998)
Movie Genres : Comedy
Target rating : 1, Predicted rating : 1.023
------

---