# Foursquare dataset next-POI Recommendation System

First off we import all the necessary libraries:

In [4]:
 %pip install lightning geohash2 wandb polars==0.20.25



In [5]:
 from google.colab import drive

 drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
import polars as rs
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as pl
import lightning.pytorch as torchpl
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import LabelEncoder
from dataclasses import dataclass
import wandb

In [7]:
import os

# define WANDB_NOTEBOOK_NAME
os.environ["WANDB_NOTEBOOK_NAME"] = "train.ipynb"

In [8]:
import gc

gc.collect()
# clean CUDA memory
torch.cuda.empty_cache()

# sometimes jupyter notebook does not release memory, we leave this here so a run-all
# can *sometimes* fix leaks

Next, we load the data, we utilize `polars` since it is much more efficient than `pandas` and can handle large datasets with ease.

In [9]:
columns = ["user", "poi", "date", "TZ"]
data = rs.read_csv(
    "/content/drive/MyDrive/dataset_TIST2015/dataset_TIST2015_Checkins.txt",
    has_header=False,
    low_memory=True,
    separator="\t",
)
data.columns = columns

In [10]:
data

user,poi,date,TZ
i64,str,str,i64
50756,"""4f5e3a72e4b053fd6a4313f6""","""Tue Apr 03 18:00:06 +0000 2012""",240
190571,"""4b4b87b5f964a5204a9f26e3""","""Tue Apr 03 18:00:07 +0000 2012""",180
221021,"""4a85b1b3f964a520eefe1fe3""","""Tue Apr 03 18:00:08 +0000 2012""",-240
66981,"""4b4606f2f964a520751426e3""","""Tue Apr 03 18:00:08 +0000 2012""",-300
21010,"""4c2b4e8a9a559c74832f0de2""","""Tue Apr 03 18:00:09 +0000 2012""",240
…,…,…,…
16349,"""4c957755c8a1bfb7e89024f3""","""Mon Sep 16 23:24:11 +0000 2013""",-240
256757,"""4c8bbb6d9ef0224bd2d6667b""","""Mon Sep 16 23:24:13 +0000 2013""",-180
66425,"""513e82a5e4b0ed4f0f3bcf2d""","""Mon Sep 16 23:24:14 +0000 2013""",-180
1830,"""4b447865f964a5204cf525e3""","""Mon Sep 16 23:24:14 +0000 2013""",120


Differently from what suggested by the professor, we utilize the full TIST2015 dataset, which has a far greater scale compared to the reduced NY one. However, by following the pruning steps detailed in the paper (http://dx.doi.org/10.1145/3477495.3531989, section 5.1), we obtain sequences that are much smaller in size, resulting in a dataset that is usable on Google Colab's free tier (as required by the assignment).

In [11]:
data_users = (
    data.lazy()
    .group_by("user")
    .agg(
        [
            rs.col("poi").n_unique().alias("n_pois"),
            rs.col("poi").count().alias("n_checkins"),
            # turn the rest into a list
            rs.col("poi").alias("pois"),
            rs.col("date").alias("dates"),
            rs.col("TZ").alias("TZs"),
        ]
    )
).collect()

In [12]:
data_users.describe()

statistic,user,n_pois,n_checkins,pois,dates,TZs
str,f64,f64,f64,f64,f64,f64
"""count""",266909.0,266909.0,266909.0,266909.0,266909.0,266909.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",133455.0,56.477459,124.62537,,,
"""std""",77050.135837,45.968603,140.692138,,,
"""min""",1.0,1.0,1.0,,,
"""25%""",66728.0,30.0,61.0,,,
"""50%""",133455.0,49.0,93.0,,,
"""75%""",200182.0,71.0,148.0,,,
"""max""",266909.0,1246.0,5430.0,,,


## Data Preprocessing

In [13]:
data_culled = data_users.filter(
    (rs.col("n_checkins") > 20) & (rs.col("n_checkins") < 50)
).drop_nulls()

Since the original dataset is huge, we delete it and call the python garbage collector to free up memory. We then proceed with the second pruning step (frequency-based pruning) as detailed in the paper.

In [14]:
del data
del data_users

import gc

gc.collect()

0

In [15]:
# print lengths

print(data_culled["pois"].list.len().min(), data_culled["pois"].list.len().max())

21 49


In [16]:
# extract unique elements from each lists in data_culled["pois"]
out = data_culled.with_columns(
    [
        rs.col("pois").list.unique(),
        rs.col("pois").list.unique().list.len().alias("n_unique_pois"),
    ]
)

In [17]:
out

user,n_pois,n_checkins,pois,dates,TZs,n_unique_pois
i64,u32,u32,list[str],list[str],list[i64],u32
98328,21,26,"[""4f9167dbe4b08d590747fea6"", ""4f9362f9e4b0cd2d539f6228"", … ""5001e67de4b023056d98fa66""]","[""Wed Oct 31 13:11:14 +0000 2012"", ""Sat Nov 17 17:48:38 +0000 2012"", … ""Tue Aug 20 15:07:05 +0000 2013""]","[-240, -240, … -240]",21
264666,34,44,"[""4e411f0a62e17b948c23a757"", ""4bcadc5468f976b000196083"", … ""4b55670ff964a52071e327e3""]","[""Sun May 27 12:22:45 +0000 2012"", ""Tue May 29 06:05:50 +0000 2012"", … ""Mon May 27 11:48:18 +0000 2013""]","[540, 540, … 540]",34
221671,30,33,"[""4e50440445ddff0031cce4d3"", ""4aa7a6c1f964a520e74c20e3"", … ""4b5c9ad1f964a520623929e3""]","[""Tue Jun 12 23:05:54 +0000 2012"", ""Tue Jun 12 23:07:11 +0000 2012"", … ""Mon Aug 12 22:20:28 +0000 2013""]","[-300, -300, … -300]",30
176589,19,27,"[""4ff60179e4b055f896988a44"", ""4d6799c7052ea1cd44f49b49"", … ""4efa47720e013b2128976189""]","[""Mon Apr 23 16:01:01 +0000 2012"", ""Mon Apr 23 20:09:08 +0000 2012"", … ""Tue Aug 27 22:54:15 +0000 2013""]","[-180, -180, … -180]",19
239374,34,42,"[""4b74fe05f964a52066fa2de3"", ""50d02692e4b0a6c92577c9a4"", … ""4b50d7dbf964a520dd3427e3""]","[""Fri Aug 10 19:12:25 +0000 2012"", ""Sat Sep 15 13:23:31 +0000 2012"", … ""Tue Aug 06 10:56:02 +0000 2013""]","[180, 180, … 180]",34
…,…,…,…,…,…,…
97512,22,30,"[""4b83685cf964a520880531e3"", ""4d4e127d5fb0b1f7cd1c7791"", … ""50b99c90e4b0f99bbe09991b""]","[""Sun Jan 06 00:37:10 +0000 2013"", ""Sun Jan 06 03:55:06 +0000 2013"", … ""Sat Sep 14 11:59:04 +0000 2013""]","[540, 540, … 540]",22
73655,25,47,"[""4c668d5de75ac928a4c0f7da"", ""4d9203f15f33b1f722517a7e"", … ""4e83d05de5e877f1e1751aa9""]","[""Fri Apr 06 16:38:29 +0000 2012"", ""Sun Apr 08 20:58:30 +0000 2012"", … ""Sun Sep 15 00:00:07 +0000 2013""]","[-240, -240, … -240]",25
113244,15,23,"[""4baa65f2f964a5205a663ae3"", ""4b5b4195f964a520dcee28e3"", … ""4b525236f964a520127727e3""]","[""Sat Apr 21 23:47:09 +0000 2012"", ""Sat Jun 16 23:38:50 +0000 2012"", … ""Thu Jul 04 17:39:27 +0000 2013""]","[-240, -240, … -240]",15
144178,34,44,"[""509b658ce4b0397ad7075fa7"", ""4f5da931e4b0695cbc2b9d8c"", … ""4cb1a6346c2695210887b8d9""]","[""Mon Jul 02 14:19:22 +0000 2012"", ""Mon Jul 02 19:13:05 +0000 2012"", … ""Fri Jun 07 21:53:50 +0000 2013""]","[180, 180, … 180]",34


In [18]:
l = out["pois"][0].to_list()
len(set(l))  # print number of unique POIs in first sequence

21

In [19]:
l2 = data_culled["pois"][0].to_list()
len(l2)  # print sequence length of first user

26

In [20]:
len(set(l2))  # confirm that the two match

21

In [21]:
# run a Polars query to obtain all the frequent POIs, the ones expected to survive the filtering
unique_pois = out["pois"]
frequent_pois = unique_pois.list.explode().value_counts().filter(rs.col("count") >= 10)

In [22]:
frequent_pois

pois,count
str,u32
"""4a367102f964a520869d1fe3""",16
"""4b505f21f964a520902127e3""",23
"""4b8bceedf964a52065ac32e3""",26
"""4df36e6d18a88611c6bc496e""",10
"""4c42030caf052d7f1d447e79""",182
…,…
"""4b0ce6b2f964a5207e4223e3""",25
"""4d0a931c6aa05481f324e5aa""",10
"""4d34236c2e56236a180a2cb4""",11
"""4c5e12552815c928b1dbb667""",10


In [23]:
frequent_pois = frequent_pois["pois"]
frequent_pois = set(frequent_pois.to_list())

In [24]:
data_culled

user,n_pois,n_checkins,pois,dates,TZs
i64,u32,u32,list[str],list[str],list[i64]
98328,21,26,"[""4fac332ce4b0346fb571b34a"", ""4e5ff6fccc3f621d78b87c46"", … ""4f9167dbe4b08d590747fea6""]","[""Wed Oct 31 13:11:14 +0000 2012"", ""Sat Nov 17 17:48:38 +0000 2012"", … ""Tue Aug 20 15:07:05 +0000 2013""]","[-240, -240, … -240]"
264666,34,44,"[""4bcadc5468f976b000196083"", ""4b586711f964a520255628e3"", … ""513059d8e4b054cc5dbbd425""]","[""Sun May 27 12:22:45 +0000 2012"", ""Tue May 29 06:05:50 +0000 2012"", … ""Mon May 27 11:48:18 +0000 2013""]","[540, 540, … 540]"
221671,30,33,"[""4f246345e4b0476579acdc9e"", ""4f246345e4b0476579acdc9e"", … ""4bd88ddc1671b7135afb765f""]","[""Tue Jun 12 23:05:54 +0000 2012"", ""Tue Jun 12 23:07:11 +0000 2012"", … ""Mon Aug 12 22:20:28 +0000 2013""]","[-300, -300, … -300]"
176589,19,27,"[""4b50f6f7f964a520e13a27e3"", ""4bd099931f89ce7272b867ea"", … ""4b50f6f7f964a520e13a27e3""]","[""Mon Apr 23 16:01:01 +0000 2012"", ""Mon Apr 23 20:09:08 +0000 2012"", … ""Tue Aug 27 22:54:15 +0000 2013""]","[-180, -180, … -180]"
239374,34,42,"[""4fcfae92e4b0ef77527d5432"", ""4c45dd45429a0f47a8ba4a1e"", … ""4d32fd806c7c721ee446b556""]","[""Fri Aug 10 19:12:25 +0000 2012"", ""Sat Sep 15 13:23:31 +0000 2012"", … ""Tue Aug 06 10:56:02 +0000 2013""]","[180, 180, … 180]"
…,…,…,…,…,…
97512,22,30,"[""50b99c90e4b0f99bbe09991b"", ""4b80f2c7f964a520fd9230e3"", … ""4b6fc43cf964a520f3fb2ce3""]","[""Sun Jan 06 00:37:10 +0000 2013"", ""Sun Jan 06 03:55:06 +0000 2013"", … ""Sat Sep 14 11:59:04 +0000 2013""]","[540, 540, … 540]"
73655,25,47,"[""4ca675e6ae1eef3b36013147"", ""4ca675e6ae1eef3b36013147"", … ""4c54645f728920a1be38af82""]","[""Fri Apr 06 16:38:29 +0000 2012"", ""Sun Apr 08 20:58:30 +0000 2012"", … ""Sun Sep 15 00:00:07 +0000 2013""]","[-240, -240, … -240]"
113244,15,23,"[""40e0b100f964a5201b051fe3"", ""40e0b100f964a5201b051fe3"", … ""4d69034c2acd6ea8888d34c0""]","[""Sat Apr 21 23:47:09 +0000 2012"", ""Sat Jun 16 23:38:50 +0000 2012"", … ""Thu Jul 04 17:39:27 +0000 2013""]","[-240, -240, … -240]"
144178,34,44,"[""4f5da931e4b0695cbc2b9d8c"", ""4e44f25f2271bdbcf67a1ee7"", … ""4e44f25f2271bdbcf67a1ee7""]","[""Mon Jul 02 14:19:22 +0000 2012"", ""Mon Jul 02 19:13:05 +0000 2012"", … ""Fri Jun 07 21:53:50 +0000 2013""]","[180, 180, … 180]"


In [25]:
data_culled = data_culled.with_columns(
    [
        rs.col("pois")
        .list.eval(
            rs.element().is_in(frequent_pois),
        )
        .alias("is_frequent")
    ]
)  # prep mask

In [26]:
final_data = (
    data_culled.lazy()
    .explode(
        [
            "pois",
            "dates",
            "TZs",
            "is_frequent",
        ]
    )
    .group_by("user")
    .agg(
        [
            rs.col("pois").filter(rs.col("is_frequent")).alias("pois"),
            rs.col("dates").filter(rs.col("is_frequent")).alias("dates"),
            rs.col("TZs").filter(rs.col("is_frequent")).alias("TZs"),
            rs.col("pois").filter(rs.col("is_frequent")).n_unique().alias("n_pois"),
            rs.col("pois").filter(rs.col("is_frequent")).count().alias("n_checkins"),
        ]
    )
    .filter(rs.col("n_checkins") > 0)
    .filter(rs.col("n_pois") > 0)
    .collect()
)  # filter out infrequent pois and users with no pois

In [27]:
final_data.describe()

statistic,user,pois,dates,TZs,n_pois,n_checkins
str,f64,f64,f64,f64,f64,f64
"""count""",19862.0,19862.0,19862.0,19862.0,19862.0,19862.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",156852.822274,,,,6.123452,8.831437
"""std""",76314.892884,,,,4.609024,6.877662
"""min""",49.0,,,,1.0,1.0
"""25%""",95613.0,,,,3.0,4.0
"""50%""",167846.0,,,,5.0,7.0
"""75%""",224576.0,,,,8.0,12.0
"""max""",266909.0,,,,32.0,46.0


At this stage, culling is done, we can appreciate that `polars`'s SQL/functional-style API is different from Pandas, but it is very powerful and efficient.

The next step is geohashing the POIs, that is, we want to convert the latitude-longitude positions of the POIs into a grid-based geohash representation, which will form the basis for our network's embeddings.

In [28]:
import geohash2 as gh

pois = rs.read_csv(
    "/content/drive/MyDrive/dataset_TIST2015/dataset_TIST2015_POIs.txt",
    has_header=False,
    low_memory=True,
    separator="\t",
)
pois.columns = ["poi", "lat", "long", "category", "country"]
pois = pois.drop("category").drop("country")

In [29]:
pois = (
    pois.lazy()
    .filter(rs.col("poi").is_in(frequent_pois))
    .select(
        [
            rs.col("poi"),
            rs.struct(
                [
                    rs.col("lat").cast(rs.Float32),
                    rs.col("long").cast(rs.Float32),
                ]
            )
            .alias("location")
            .map_elements(
                lambda s: gh.encode(s["lat"], s["long"], precision=6),
                return_dtype=rs.String,
            )
            .alias("geohash"),
        ]
    )
    .collect()
)
poi_geo_dict = dict(zip(pois["poi"], pois["geohash"]))

In [30]:
# for each row in final_data, add the geohash of the pois by hitting the poi_geo_dict

final_data = final_data.with_columns(
    [
        rs.col("pois")
        .map_elements(
            lambda s: [poi_geo_dict[s] for s in s],
            return_dtype=rs.List(rs.String),
        )
        .alias("geohashes")
    ]
)

In [31]:
final_data["dates"][79].to_list()  # check out a temporal sequence

['Sat Apr 14 19:36:51 +0000 2012',
 'Mon Apr 16 17:13:05 +0000 2012',
 'Wed May 09 15:58:48 +0000 2012',
 'Thu Nov 01 19:59:24 +0000 2012']

In [32]:
final_data["TZs"][79].to_list()  # ... and the corresponding timezones

[-180, -180, -240, -180]

The work *might* seem over, however, we still have timezones to account for, we want to normalize everything according to GMT, so we convert the timestamps accordingly.

In [33]:
import datetime


def UTC_to_local(utc, tz):

    date = datetime.datetime.strptime(utc, "%a %b %d %H:%M:%S %z %Y")
    date = date.replace(tzinfo=datetime.timezone.utc)

    # shift by tz offset
    date = date.astimezone(datetime.timezone(datetime.timedelta(minutes=tz)))

    date_s = datetime.datetime.strftime(date, "%Y-%m-%d %H:%M:%S")
    return date_s


def to_UNIX_time(date):
    return datetime.datetime.strptime(
        date, "%Y-%m-%d %H:%M:%S"
    ).timestamp()  # we use UNIX time as a key for sorting the POIs in our polars query

In [34]:
UTC_to_local("Mon May 21 15:53:01 +0000 2012", -420)  # example of usage

'2012-05-21 08:53:01'

In [35]:
final_data = final_data.with_columns(
    [
        rs.struct([rs.col("dates"), rs.col("TZs")])
        .alias("times")
        .map_elements(
            lambda struct: [
                UTC_to_local(date, tz)
                for date, tz in zip(struct["dates"], struct["TZs"])
            ],
            return_dtype=rs.List(rs.String),
        )
    ]
)  # This performs timezone conversion

In [36]:
final_sorted = final_data.select(  # sort the times
    [
        rs.col("user"),
        rs.struct(
            [
                rs.col("pois"),
                rs.col("times"),
            ]
        ).map_elements(
            lambda struct: [
                poi
                for poi, _ in sorted(
                    zip(  # here we sort the POIs struct by UNIX timestamps of the GMT times
                        struct["pois"], [to_UNIX_time(date) for date in struct["times"]]
                    ),
                    key=lambda s: s[1],
                )
            ],
            return_dtype=rs.List(rs.String),
        ),
        rs.struct(
            [
                rs.col("geohashes"),
                rs.col("times"),
            ]
        ).map_elements(
            lambda struct: [
                geo
                for geo, _ in sorted(
                    zip(
                        struct["geohashes"],  # same thing goes on for geohashes
                        [to_UNIX_time(date) for date in struct["times"]],
                    ),
                    key=lambda s: s[1],
                )
            ],
            return_dtype=rs.List(rs.String),
        ),
        rs.col("times")
        .map_elements(
            lambda dates: sorted(dates, key=to_UNIX_time),
            return_dtype=rs.List(rs.String),
        )
        .alias("times_sorted"),
        rs.col("n_checkins"),
    ]
)

# P.S, admittedly, it would have been more efficient to encode the geohashes *after* sorting the POIs,
# so that we could save on the sorting of the geohashes. Tough luck, you can't win 'em all.

In [37]:
final_sorted

user,pois,geohashes,times_sorted,n_checkins
i64,list[str],list[str],list[str],u32
224937,"[""4bb8b56e98c7ef3be94e3102"", ""4bb8b56e98c7ef3be94e3102"", … ""4b4dbb3ff964a52074d626e3""]","[""7pkdtn"", ""7pkdtn"", … ""6vjvtw""]","[""2012-04-16 12:23:44"", ""2012-04-19 16:05:53"", … ""2013-04-16 04:23:09""]",7
153503,"[""4bec1e73f90e9c74999ae3ed"", ""4a1c8506f964a520457b1fe3"", … ""4c9c6a059c48236a1cb14dee""]","[""u1huxw"", ""u33db1"", … ""gcpvhc""]","[""2012-10-16 15:45:05"", ""2013-04-07 12:10:32"", … ""2013-05-25 10:38:18""]",5
102118,"[""4541114bf964a520433c1fe3"", ""4aac1b4ef964a520475c20e3"", … ""4aa18a65f964a520ea4020e3""]","[""9mudjn"", ""9v1zzp"", … ""87z9q7""]","[""2012-04-19 08:17:27"", ""2012-04-19 14:24:03"", … ""2012-09-05 14:12:13""]",12
55062,"[""45b88de0f964a520ce411fe3"", ""41269080f964a520520c1fe3"", … ""45b88de0f964a520ce411fe3""]","[""9q5ctr"", ""9q5cfj"", … ""9q5ctr""]","[""2012-04-20 12:52:37"", ""2012-04-29 00:26:17"", … ""2012-11-24 21:21:23""]",11
160202,"[""4b572472f964a5200b2828e3"", ""4f1b29eae4b0d1a65569e663""]","[""xn77jj"", ""r3gwbn""]","[""2013-01-05 07:32:58"", ""2013-08-07 07:13:51""]",2
…,…,…,…,…
97652,"[""4bb5288a2ea19521a816aa2f"", ""4b4dbb3ff964a52074d626e3"", … ""4d8c9978ca75b60c0e4bd5a8""]","[""75cn8v"", ""6vjvtw"", … ""75cnnj""]","[""2012-04-10 05:00:21"", ""2012-04-24 15:07:27"", … ""2012-12-16 14:30:44""]",10
128440,"[""4d678903cba6a35df698cdc1"", ""4bf269b79f93952137b03098"", … ""4d678903cba6a35df698cdc1""]","[""qqguf8"", ""qqgvp2"", … ""qqguf8""]","[""2012-06-08 19:46:06"", ""2012-07-07 17:31:11"", … ""2013-04-21 14:53:42""]",7
3522,"[""4b155081f964a520b4b023e3"", ""4ad4c061f964a520adf720e3""]","[""dpz839"", ""dpz832""]","[""2013-03-14 07:52:42"", ""2013-04-19 19:27:48""]",2
122027,"[""4c34ecc53ffc9521e60791f5"", ""4be82d4488ed2d7fde85cb1d"", … ""4bb8469a1261d13a5586e898""]","[""ucftpy"", ""ucftpy"", … ""ucfut5""]","[""2012-12-28 13:30:23"", ""2012-12-28 14:08:34"", … ""2013-09-02 18:37:35""]",9


we now need to obtain a dataframe containing: each POI, it's geohash, and a set of all the check-ins it appears in
this is just one `polars` query away!

In [38]:
pois_checkins = final_sorted.explode(["pois", "geohashes"]).drop("n_checkins")

pois_checkins = (
    pois_checkins.with_columns(
        [
            rs.col("geohashes").map_elements(lambda s: s[:4], rs.String).alias("g4"),
        ]
    )
    .drop("geohashes")
    .group_by(["pois", "g4"])
    .agg([rs.col("times_sorted").flatten().alias("checkin_times")])
)

In [39]:
pois_checkins  # with this we can *efficiently* build our POI-POI spatial-temporal graphs

pois,g4,checkin_times
str,str,list[str]
"""4c65af8c7abde21eff6e6168""","""sxk9""","[""2012-04-23 00:31:38"", ""2012-04-24 16:23:51"", … ""2013-07-08 20:03:25""]"
"""4d160cd685fc6dcb5649a44e""","""w0zf""","[""2012-11-18 16:50:13"", ""2012-11-25 18:29:55"", … ""2013-08-22 14:23:43""]"
"""4a64ac42f964a52072c61fe3""","""9vg4""","[""2012-04-13 19:05:48"", ""2012-04-14 11:56:58"", … ""2012-12-02 15:14:04""]"
"""4c62ba497c9def3ba775d41c""","""66j8""","[""2012-11-17 06:20:19"", ""2013-01-26 11:11:16"", … ""2012-07-03 19:11:18""]"
"""4db435b31e7248d1359c7e13""","""sxk8""","[""2012-12-21 14:35:34"", ""2012-12-21 17:04:09"", … ""2013-07-28 08:49:15""]"
…,…,…
"""4b0a269df964a5200a2223e3""","""r1r0""","[""2012-10-13 14:09:50"", ""2012-10-13 16:49:43"", … ""2013-03-14 19:39:52""]"
"""4b4305d1f964a520b8db25e3""","""qqgu""","[""2012-04-15 11:13:43"", ""2012-05-13 12:18:22"", … ""2013-02-07 17:36:47""]"
"""4ba337d7f964a5205c3038e3""","""qqy8""","[""2012-12-27 18:15:19"", ""2012-12-28 14:36:50"", … ""2012-06-21 22:20:44""]"
"""4e318eeba809ef7b4ece9dcb""","""sxhs""","[""2012-04-29 20:48:06"", ""2012-05-19 21:41:18"", … ""2013-09-14 23:08:17""]"


In [40]:
def UTC_to_weekslot(utc: str) -> int:
    """UTC_to_weekslot converts a UTC timestamp to a weekslot.

    Parameters
    ----------
    utc : str
        A string representing a UTC timestamp.

    Returns
    -------
    int
        A weekslot in the range [0, 56).
    """

    date = datetime.datetime.strptime(utc, "%Y-%m-%d %H:%M:%S")
    week = date.weekday()
    hour = date.hour

    return week * 8 + hour // 3

Next, we want to encode all of our inputs for our neural networks, this could *probably* be done
with polars magic, but it's too delicate and we prefer classic for-looping.

In [41]:
encoder_dict = {
    "users": LabelEncoder(),
    "pois": LabelEncoder(),
    "g2": LabelEncoder(),
    "g3": LabelEncoder(),
    "g4": LabelEncoder(),
    "g5": LabelEncoder(),
    "g6": LabelEncoder(),
}

encoded_data = {
    "users": [],
    "pois": [],
    "g2": [],
    "g3": [],
    "g4": [],
    "g5": [],
    "g6": [],
}

unique_data = {
    "users": set(),
    "pois": set(),
    "g2": set(),
    "g3": set(),
    "g4": set(),
    "g5": set(),
    "g6": set(),
}

# quick and dirty encoding:
# 1. put every unique symbol in a list
# 2. fit the respective encoder
# 3. transform the lists

for i, row in enumerate(final_sorted.iter_rows()):

    user, pois, geohashes, times_sorted, n_checkins = row

    g2 = [geo[:2] for geo in geohashes]
    g3 = [geo[:3] for geo in geohashes]
    g4 = [geo[:4] for geo in geohashes]
    g5 = [geo[:5] for geo in geohashes]
    g6 = [geo[:6] for geo in geohashes]  # redundant, but I like symmetry

    unique_data["users"].add(user)
    unique_data["pois"].update(pois)
    unique_data["g2"].update(g2)
    unique_data["g3"].update(g3)
    unique_data["g4"].update(g4)
    unique_data["g5"].update(g5)
    unique_data["g6"].update(g6)

for property, enc, data in zip(
    encoder_dict.keys(), encoder_dict.values(), unique_data.values()
):
    enc.fit(list(data))
    encoder_dict[property] = enc

In [42]:
# this could be optimized, right now it takes a while, at least we have a nice progress bar to look at

ds_size = len(final_sorted)

for i, row in tqdm(enumerate(final_sorted.iter_rows()), total=ds_size):

    user, pois, geohashes, times_sorted, n_checkins = row

    g2 = [geo[:2] for geo in geohashes]
    g3 = [geo[:3] for geo in geohashes]
    g4 = [geo[:4] for geo in geohashes]
    g5 = [geo[:5] for geo in geohashes]
    g6 = [geo[:6] for geo in geohashes]

    encoded_data["users"].append(encoder_dict["users"].transform([user])[0])
    encoded_data["pois"].append(encoder_dict["pois"].transform(pois))
    encoded_data["g2"].append(encoder_dict["g2"].transform(g2))
    encoded_data["g3"].append(encoder_dict["g3"].transform(g3))
    encoded_data["g4"].append(encoder_dict["g4"].transform(g4))
    encoded_data["g5"].append(encoder_dict["g5"].transform(g5))
    encoded_data["g6"].append(encoder_dict["g6"].transform(g6))

    # sum 1 to all values to avoid 0s
    encoded_data["users"][-1] += 1
    encoded_data["pois"][-1] += 1
    encoded_data["g2"][-1] += 1
    encoded_data["g3"][-1] += 1
    encoded_data["g4"][-1] += 1
    encoded_data["g5"][-1] += 1
    encoded_data["g6"][-1] += 1

100%|██████████| 19862/19862 [02:44<00:00, 121.08it/s]


In [43]:
# check that we left space for the padding token
min((arr.min() for arr in encoded_data["pois"]))

1

In [44]:
pois_checkins

pois,g4,checkin_times
str,str,list[str]
"""4c65af8c7abde21eff6e6168""","""sxk9""","[""2012-04-23 00:31:38"", ""2012-04-24 16:23:51"", … ""2013-07-08 20:03:25""]"
"""4d160cd685fc6dcb5649a44e""","""w0zf""","[""2012-11-18 16:50:13"", ""2012-11-25 18:29:55"", … ""2013-08-22 14:23:43""]"
"""4a64ac42f964a52072c61fe3""","""9vg4""","[""2012-04-13 19:05:48"", ""2012-04-14 11:56:58"", … ""2012-12-02 15:14:04""]"
"""4c62ba497c9def3ba775d41c""","""66j8""","[""2012-11-17 06:20:19"", ""2013-01-26 11:11:16"", … ""2012-07-03 19:11:18""]"
"""4db435b31e7248d1359c7e13""","""sxk8""","[""2012-12-21 14:35:34"", ""2012-12-21 17:04:09"", … ""2013-07-28 08:49:15""]"
…,…,…
"""4b0a269df964a5200a2223e3""","""r1r0""","[""2012-10-13 14:09:50"", ""2012-10-13 16:49:43"", … ""2013-03-14 19:39:52""]"
"""4b4305d1f964a520b8db25e3""","""qqgu""","[""2012-04-15 11:13:43"", ""2012-05-13 12:18:22"", … ""2013-02-07 17:36:47""]"
"""4ba337d7f964a5205c3038e3""","""qqy8""","[""2012-12-27 18:15:19"", ""2012-12-28 14:36:50"", … ""2012-06-21 22:20:44""]"
"""4e318eeba809ef7b4ece9dcb""","""sxhs""","[""2012-04-29 20:48:06"", ""2012-05-19 21:41:18"", … ""2013-09-14 23:08:17""]"


In [45]:
# we also encode the graph dataframe so we can build the graphs

pois_checkins = (
    pois_checkins.lazy()
    .with_columns(
        [
            rs.col("pois").map_elements(
                lambda s: encoder_dict["pois"].transform([s])[0] + 1, rs.Int64
            ),
            rs.col("g4").map_elements(
                lambda s: encoder_dict["g4"].transform([s])[0] + 1, rs.Int64
            ),  # apply utc_to_weekslot to each timestamp in the list
            rs.col("checkin_times").map_elements(
                lambda s: [UTC_to_weekslot(date) for date in s], rs.List(rs.Int64)
            ),
        ]
    )
    .sort("pois")
    .collect()
)

In [46]:
# add fictitious POI 0 to the graph, with nonexistent geohash and no timeslot, so we get a 0 row and column for the padding token
fake_datapoint = rs.DataFrame(
    {
        "pois": [0],
        "g4": [pois_checkins["g4"].max() + 42],
        "checkin_times": [[43]],
    }
)
# this is a lot of work since polars dataframes are immutable by default, we have to run a query to change the 43 into an empty list
# we NEED the 43 otherwise polars won't infer the datatype of the list

fake_datapoint = fake_datapoint.with_columns(
    [rs.col("checkin_times").map_elements(lambda s: [], rs.List(rs.Int64))]
)

pois_checkins = fake_datapoint.vstack(pois_checkins)

In [47]:
spatial_row = np.array(pois_checkins["g4"].to_list()).reshape(-1, 1)

In [48]:
# outer product using equality
spatial_graph = (spatial_row == spatial_row.T).astype(np.int32)
spatial_graph[0, 0] = (
    0  # the fake g4 is still equal to itself, we suppress this equality
)
spatial_graph = torch.tensor(spatial_graph)

In [49]:
temporal_row = pois_checkins["checkin_times"].to_list()

In [50]:
temporal_graph = np.zeros((spatial_row.shape[0], spatial_row.shape[0]))

In [51]:
temporal_sets = [np.array(list(set(row))) for row in temporal_row]

In [52]:
time_sets = torch.zeros((len(temporal_sets), 56), dtype=torch.int8)

for i, r in enumerate(temporal_row):
    indices = torch.tensor(r, dtype=torch.long)
    time_sets[i, indices] = 1

In [53]:
time_sets.shape

torch.Size([4456, 56])

In [54]:
# AND outer product

intersection = time_sets @ time_sets.T
union = time_sets.unsqueeze(1) | time_sets.unsqueeze(0)
union = union.sum(dim=2)
iou = intersection / union

In [55]:
temporal_graph = iou >= 0.9
# cast to int
temporal_graph = temporal_graph.int()

In [56]:
temporal_graph[0, :].sum()

tensor(0)

We print information about the sparsity of the graphs, we note that
the sparsity of the graphs is similar to that of the paper.

In [57]:
temporal_density = (
    temporal_graph.sum() / (temporal_graph.shape[0] * temporal_graph.shape[1])
).item()
spatial_density = (
    spatial_graph.sum() / (spatial_graph.shape[0] * spatial_graph.shape[1])
).item()

print(f"Temporal sparsity: {(1 - temporal_density) * 100:.2f}%")

print(f"Spatial sparsity: {(1 - spatial_density) * 100:.2f}%")

Temporal sparsity: 97.05%
Spatial sparsity: 96.99%


## Train Test Split

We now generate two dataframes from the `encoded_data` dataframe, one for training and one for testing.

First, we have to drop every sequence that has less than 4 timestamps, as we wouldn't be able to get the minimum of two samples for each of the sets,
we then calculate the 80% of the sequences and split the data accordingly.

In [58]:
len(encoded_data["pois"])

19862

In [59]:
total_data = rs.DataFrame(encoded_data)

In [60]:
total_data = total_data.with_columns(
    [
        rs.col("pois").list.len().alias("length"),
    ]
)

In [61]:
total_data = total_data.with_columns(
    rs.col("length")
    .map_elements(lambda s: int(0.8 * s) - 1, rs.Int64)
    .alias("train_end")
)

In [62]:
# drop sequences that are too short
total_data = total_data.filter(
    (
        rs.col("train_end") >= 1
    )  # at least 2 elements in the training set (1 is the index)
    & (
        rs.col("length") - (rs.col("train_end") + 1) >= 2
    )  # at least 2 elements in the validation set
)
print(total_data["length"].mean())
print(total_data.count())

12.629353233830846
shape: (1, 9)
┌───────┬───────┬───────┬───────┬───┬───────┬───────┬────────┬───────────┐
│ users ┆ pois  ┆ g2    ┆ g3    ┆ … ┆ g5    ┆ g6    ┆ length ┆ train_end │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆   ┆ ---   ┆ ---   ┆ ---    ┆ ---       │
│ u32   ┆ u32   ┆ u32   ┆ u32   ┆   ┆ u32   ┆ u32   ┆ u32    ┆ u32       │
╞═══════╪═══════╪═══════╪═══════╪═══╪═══════╪═══════╪════════╪═══════════╡
│ 12060 ┆ 12060 ┆ 12060 ┆ 12060 ┆ … ┆ 12060 ┆ 12060 ┆ 12060  ┆ 12060     │
└───────┴───────┴───────┴───────┴───┴───────┴───────┴────────┴───────────┘


In [63]:
total_data.sort("length")  # check out the distribution of sequence lengths

users,pois,g2,g3,g4,g5,g6,length,train_end
i64,list[i64],list[i64],list[i64],list[i64],list[i64],list[i64],u32,i64
3518,"[3143, 2681, … 3165]","[101, 101, … 101]","[261, 261, … 261]","[502, 503, … 502]","[1218, 1221, … 1214]","[2475, 2487, … 2448]",6,3
13014,"[618, 745, … 745]","[77, 77, … 77]","[205, 205, … 205]","[368, 368, … 368]","[906, 905, … 905]","[1919, 1910, … 1910]",6,3
9035,"[1395, 1395, … 1395]","[86, 86, … 86]","[232, 232, … 232]","[439, 439, … 439]","[1043, 1043, … 1043]","[2130, 2130, … 2130]",6,3
5174,"[3249, 2072, … 2398]","[92, 92, … 90]","[245, 245, … 239]","[469, 469, … 452]","[1120, 1124, … 1064]","[2255, 2266, … 2156]",6,3
11509,"[1849, 179, … 119]","[31, 27, … 26]","[63, 51, … 49]","[111, 93, … 88]","[284, 244, … 237]","[524, 450, … 441]",6,3
…,…,…,…,…,…,…,…,…
18788,"[1561, 1561, … 1255]","[106, 106, … 109]","[266, 266, … 277]","[509, 509, … 541]","[1236, 1236, … 1344]","[2517, 2517, … 2746]",44,34
9814,"[1542, 1542, … 793]","[109, 109, … 109]","[277, 277, … 277]","[538, 538, … 538]","[1319, 1319, … 1322]","[2671, 2671, … 2685]",44,34
10822,"[1280, 3441, … 3915]","[68, 68, … 68]","[181, 181, … 181]","[334, 332, … 334]","[807, 795, … 802]","[1637, 1536, … 1585]",45,35
16188,"[599, 3122, … 1773]","[68, 68, … 68]","[181, 181, … 181]","[332, 334, … 334]","[795, 806, … 812]","[1537, 1614, … 1672]",46,35


In [64]:
# Check if the shortest sequence is long enough
total_data.sort("length")["pois"][0]

3143
2681
2681
4054
2681
3165


In [65]:
# slice the two dataframes
train_data = total_data.select(
    [
        rs.col("users"),
        rs.struct(
            [
                rs.col("pois"),
                rs.col("g2"),
                rs.col("g3"),
                rs.col("g4"),
                rs.col("g5"),
                rs.col("g6"),
                rs.col("train_end"),
            ]
        )
        .map_elements(
            lambda struct: [
                struct["pois"][: struct["train_end"]],
                struct["g2"][: struct["train_end"]],
                struct["g3"][: struct["train_end"]],
                struct["g4"][: struct["train_end"]],
                struct["g5"][: struct["train_end"]],
                struct["g6"][: struct["train_end"]],
            ],
            return_dtype=rs.List(rs.List(rs.Int64)),
        )
        .alias("sequences"),
    ]
)


test_data = total_data.select(
    [
        rs.col("users"),
        rs.struct(
            [
                rs.col("pois"),
                rs.col("g2"),
                rs.col("g3"),
                rs.col("g4"),
                rs.col("g5"),
                rs.col("g6"),
                rs.col("train_end"),
            ]
        )
        .map_elements(
            lambda struct: [
                struct["pois"][struct["train_end"] :],
                struct["g2"][struct["train_end"] :],
                struct["g3"][struct["train_end"] :],
                struct["g4"][struct["train_end"] :],
                struct["g5"][struct["train_end"] :],
                struct["g6"][struct["train_end"] :],
            ],
            return_dtype=rs.List(rs.List(rs.Int64)),
        )
        .alias("sequences"),
    ]
)

In [66]:
def explode_dict(d):
    """explode_dict Convert packed polars dataframe into a neat python dict

    Parameters
    ----------
    d : Polars.DataFrame
        A polars dataframe with a struct column

    Returns
    -------
    dict
        A python dict with the same structure as the struct column
    """
    ret = {
        "users": d["users"].to_list(),
        "pois": [],
        "g2": [],
        "g3": [],
        "g4": [],
        "g5": [],
        "g6": [],
    }

    columns = ["pois", "g2", "g3", "g4", "g5", "g6"]

    for sample in d["sequences"]:
        pois, g2, g3, g4, g5, g6 = sample
        ret["pois"].append(pois.to_list())
        ret["g2"].append(g2.to_list())
        ret["g3"].append(g3.to_list())
        ret["g4"].append(g4.to_list())
        ret["g5"].append(g5.to_list())
        ret["g6"].append(g6.to_list())

    return ret

In [67]:
encoded_data_train = explode_dict(train_data.to_dict())
encoded_data_test = explode_dict(test_data.to_dict())

## Metrics


The paper utilizes metrics that check if the target is in the top-k recommendations, we implement them here.

In [76]:
class AccuracyAtK(nn.Module):
    def __init__(self, k: int):
        """__init__ initializes the AccuracyAtK module.

        Accuracy@k is the proportion of correct predictions in the top-k elements.

        Parameters
        ----------
        k : int
            The number of top-k elements to consider.

        """
        super().__init__()
        self.k = k

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """forward computes the accuracy at k between logits and targets.

        Parameters
        ----------
        logits : torch.Tensor
            Class probability, either (B, C) or (B, T, C)
        targets : torch.Tensor
            Ground truth class indices, either (B,) or (B, T)

        Returns
        -------
        torch.Tensor
            The accuracy at k, a scalar-tensor.
        """

        # Gotta have at least one nasty python one-liner, in memory of the old
        # programming lab 1 bachelor course
        return (
            (logits.topk(self.k, dim=-1)[1] == targets.unsqueeze(-1))
            .any(dim=-1)
            .float()
            .mean()
        )
class AccuracyAt1(nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            predicted = logits.argmax(dim=-1)
            correct = (predicted == targets).float()
            accuracy = correct.mean()
            return accuracy

class MeanReciprocalRank(nn.Module):

    def __init__(self):
        """__init__ initializes the MeanReciprocalRank module.

        Mean reciprocal rank is the average of the reciprocal ranks of the top-k elements.

        """
        super().__init__()

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """forward computes the mean reciprocal rank between logits and targets.

        Parameters
        ----------
        logits : torch.Tensor
            Class probability
        targets : torch.Tensor
            Ground truth class indices

        Returns
        -------
        torch.Tensor
            The mean reciprocal rank, a scalar-tensor.
        """

        _, indices = logits.topk(logits.shape[-1], dim=-1)
        ranks = (indices == targets.unsqueeze(-1)).nonzero()[:, -1].float() + 1
        return (1.0 / ranks).mean()

## Dataset and Datamodule

We then define a pytorch dataset and a custom collation function that allows us to dynamically
pad sequences to the longest one in the batch (as opposed to the longest one in the dataset)
as they are loaded during training, this gives us an edge in performance by dramatically reducing the
sparsity of our inputs.

In [77]:
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence


def rnn_collation_fn(batch):

    users = []
    pois = []
    g2 = []
    g3 = []
    g4 = []
    g5 = []
    g6 = []

    for user, poi, geo2, geo3, geo4, geo5, geo6 in batch:
        users.append(user)
        pois.append(poi)
        g2.append(geo2)
        g3.append(geo3)
        g4.append(geo4)
        g5.append(geo5)
        g6.append(geo6)
    seq = (
        torch.tensor(users, dtype=torch.long),
        pad_sequence(pois, batch_first=True, padding_value=0),
        pad_sequence(g2, batch_first=True, padding_value=0),
        pad_sequence(g3, batch_first=True, padding_value=0),
        pad_sequence(g4, batch_first=True, padding_value=0),
        pad_sequence(g5, batch_first=True, padding_value=0),
        pad_sequence(g6, batch_first=True, padding_value=0),
    )  # build a sequence

    x = (
        seq[0],
        seq[1][:, :-1],
        seq[2][:, :-1],
        seq[3][:, :-1],
        seq[4][:, :-1],
        seq[5][:, :-1],
        seq[6][:, :-1],
    )  # omit the last one for sample

    y = (
        seq[0],
        seq[1][:, 1:],
        seq[2][:, 1:],
        seq[3][:, 1:],
        seq[4][:, 1:],
        seq[5][:, 1:],
        seq[6][:, 1:],
    )  # omit the first one for ground truth

    return x, y


class CheckinDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data["users"])

    def __getitem__(self, idx):

        x = (
            torch.tensor(encoded_data["users"][idx], dtype=torch.long),
            torch.tensor(encoded_data["pois"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g2"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g3"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g4"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g5"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g6"][idx], dtype=torch.long),
        )

        return x

In [78]:
class CheckinModule(pl.LightningDataModule):
    def __init__(self, encoded_data_train, encoded_data_test, batch_size=32, workers=4):
        """__init__ initializes the CheckinModule.

        Parameters
        ----------
        encoded_data_train : Union[dict, rs.DataFrame]
            The training data.
        encoded_data_test : Union[dict, rs.DataFrame]
            The testing data.
        batch_size : int, optional
            Size of the batches, by default 32
        workers : int, optional
            Number of worker processes, by default 4
        """
        super().__init__()
        self.encoded_data_train = encoded_data_train
        self.encoded_data_test = encoded_data_test
        self.batch_size = batch_size
        self.workers = workers

        assert isinstance(self.encoded_data_train, dict) or isinstance(
            self.encoded_data_train, rs.DataFrame
        ), "encoded_data_train must be a dict or a polars DataFrame"
        assert isinstance(self.encoded_data_test, dict) or isinstance(
            self.encoded_data_test, rs.DataFrame
        ), "encoded_data_test must be a dict or a polars DataFrame"

        assert batch_size > 0, "batch_size must be a positive integer"
        assert workers >= 0, "workers must be a non-negative integer"

    def setup(self, stage=None):

        if (
            isinstance(self.encoded_data_train, dict)
            or isinstance(self.encoded_data_train, rs.DataFrame)
        ) and (
            isinstance(self.encoded_data_test, dict)
            or isinstance(self.encoded_data_test, rs.DataFrame)
        ):
            print("Loading data from dict/dataframe")
            self.train_dataset = CheckinDataset(self.encoded_data_train)
            self.test_dataset = CheckinDataset(self.encoded_data_test)
            print(len(self.encoded_data_train["users"]))
        elif isinstance(self.encoded_data_train, CheckinDataset) and isinstance(
            self.encoded_data_test, CheckinDataset
        ):
            print("Loading data from pre-instantiated datasets")
            self.train_dataset = self.encoded_data_train
            self.test_dataset = self.encoded_data_test
        else:
            raise ValueError("Invalid data type")

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers,
            collate_fn=rnn_collation_fn,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.workers,
            collate_fn=rnn_collation_fn,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.workers,
            collate_fn=rnn_collation_fn,
        )

    def save(self, whole_path, train_path, test_path):

        torch.save(self.train_dataset, train_path)
        torch.save(self.test_dataset, test_path)

    @staticmethod  # load without instantiating
    def load(train_path, test_path):

        train_dataset = torch.load(train_path)
        test_dataset = torch.load(test_path)

        return CheckinModule(train_dataset, test_dataset)

## Baseline model: LSTM

In [84]:
@dataclass
class BaselineDimensions:
    nuser: int
    npoi: int
    g2len: int
    g3len: int
    g4len: int
    g5len: int
    g6len: int


# HMT_RN (Hierarchical Multi-Task Recurrent Network)
class HMT_RN(pl.LightningModule):
    def __init__(
        self,
        dimensions: BaselineDimensions,
        embedding_dim,
        lstm_hidden_dim,
        dropout_rate=0.9,  # 0.9 is a lot, but the paper says so.
    ):
        super(HMT_RN, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = lstm_hidden_dim
        self.dims = dimensions

        # Embedding layers one for user, one for poi and one for each G@P
        self.user_embedding = nn.Embedding(
            dimensions.nuser, embedding_dim, padding_idx=0
        )
        self.poi_embedding = nn.Embedding(dimensions.npoi, embedding_dim, padding_idx=0)
        self.g2_embed = nn.Embedding(dimensions.g2len, embedding_dim, padding_idx=0)
        self.g3_embed = nn.Embedding(dimensions.g3len, embedding_dim, padding_idx=0)
        self.g4_embed = nn.Embedding(dimensions.g4len, embedding_dim, padding_idx=0)
        self.g5_embed = nn.Embedding(dimensions.g5len, embedding_dim, padding_idx=0)
        self.g6_embed = nn.Embedding(dimensions.g6len, embedding_dim, padding_idx=0)

        # Dropout layer for embeddings
        self.e_drop = nn.Dropout(p=dropout_rate)

        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=embedding_dim, hidden_size=lstm_hidden_dim, batch_first=True
        )

        # Linear layers for prediction tasks
        self.linear_poi = nn.Linear(lstm_hidden_dim + embedding_dim, dimensions.npoi)
        self.linear_g2 = nn.Linear(lstm_hidden_dim + embedding_dim, dimensions.g2len)
        self.linear_g3 = nn.Linear(lstm_hidden_dim + embedding_dim, dimensions.g3len)
        self.linear_g4 = nn.Linear(lstm_hidden_dim + embedding_dim, dimensions.g4len)
        self.linear_g5 = nn.Linear(lstm_hidden_dim + embedding_dim, dimensions.g5len)
        self.linear_g6 = nn.Linear(lstm_hidden_dim + embedding_dim, dimensions.g6len)

        self.criterion = nn.CrossEntropyLoss()

        self.top1_argmax = AccuracyAt1()
        self.top1 = AccuracyAtK(1)
        self.top5 = AccuracyAtK(5)
        self.top10 = AccuracyAtK(10)
        self.top20 = AccuracyAtK(20)
        self.mrr = MeanReciprocalRank()

        self.apply(self.init_weights)

    def init_weights(self, w):

        if type(w) == nn.Linear:
            nn.init.kaiming_normal_(w.weight)
            nn.init.constant_(w.bias, 0)
        elif type(w) == nn.LSTM:
            for name, param in w.named_parameters():
                if "bias" in name:
                    nn.init.constant_(param, 0)
                elif "weight" in name:
                    nn.init.kaiming_normal_(param)
        elif type(w) == nn.Embedding:
            nn.init.kaiming_normal_(w.weight)
            nn.init.constant_(w.weight[0], 0)

    def forward(self, batch):
        """forward passes the batch through the model.

        Parameters
        ----------
        batch : `tuple[torch.Tensor]`
            A tuple of tensors ordered as follows:
            (users, poi, x_geoHash2, x_geoHash3, x_geoHash4, x_geoHash5, x_geoHash6)
        """

        users, poi, x_geoHash2, x_geoHash3, x_geoHash4, x_geoHash5, x_geoHash6 = batch

        B, T = poi.shape

        # make it so  that users are tiled T times
        users = users.repeat(T, 1).T

        e_user = self.e_drop(self.user_embedding(users))
        e_poi = self.e_drop(self.poi_embedding(poi))
        e_gap2 = self.e_drop(self.g2_embed(x_geoHash2))
        e_gap3 = self.e_drop(self.g3_embed(x_geoHash3))
        e_gap4 = self.e_drop(self.g4_embed(x_geoHash4))
        e_gap5 = self.e_drop(self.g5_embed(x_geoHash5))
        e_gap6 = self.e_drop(self.g6_embed(x_geoHash6))

        h_t, c_t = self.lstm(e_poi)

        # dense layers
        next_poi = self.linear_poi(torch.cat((h_t, e_user), dim=2))
        next_g2 = self.linear_g2(torch.cat((h_t, e_gap2), dim=2))
        next_g3 = self.linear_g3(torch.cat((h_t, e_gap3), dim=2))
        next_g4 = self.linear_g4(torch.cat((h_t, e_gap4), dim=2))
        next_g5 = self.linear_g5(torch.cat((h_t, e_gap5), dim=2))
        next_g6 = self.linear_g6(torch.cat((h_t, e_gap6), dim=2))

        return next_poi, next_g2, next_g3, next_g4, next_g5, next_g6

    def training_step(self, batch, batch_idx):
        x, y = batch
        (
            poi_pred,
            gap2_pred,
            gap3_pred,
            gap4_pred,
            gap5_pred,
            gap6_pred,
        ) = self(x)

        loss_poi = self.criterion(
            poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1)
        )
        loss_gap2 = self.criterion(
            gap2_pred.reshape(-1, self.dims.g2len), y[2].reshape(-1)
        )
        loss_gap3 = self.criterion(
            gap3_pred.reshape(-1, self.dims.g3len), y[3].reshape(-1)
        )
        loss_gap4 = self.criterion(
            gap4_pred.reshape(-1, self.dims.g4len), y[4].reshape(-1)
        )
        loss_gap5 = self.criterion(
            gap5_pred.reshape(-1, self.dims.g5len), y[5].reshape(-1)
        )
        loss_gap6 = self.criterion(
            gap6_pred.reshape(-1, self.dims.g6len), y[6].reshape(-1)
        )

        loss = (
            loss_poi + loss_gap2 + loss_gap3 + loss_gap4 + loss_gap5 + loss_gap6
        ) / 6
        self.log("train/loss", loss)
        self.log("train/loss_gap2", loss_gap2)
        self.log("train/loss_gap3", loss_gap3)
        self.log("train/loss_gap4", loss_gap4)
        self.log("train/loss_gap5", loss_gap5)
        self.log("train/loss_gap6", loss_gap6)
        self.log("train/loss_poi", loss_poi)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        (
            poi_pred,
            gap2_pred,
            gap3_pred,
            gap4_pred,
            gap5_pred,
            gap6_pred,
        ) = self(x)

        loss_poi = self.criterion(
            poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1)
        )
        loss_gap2 = self.criterion(
            gap2_pred.reshape(-1, self.dims.g2len), y[2].reshape(-1)
        )
        loss_gap3 = self.criterion(
            gap3_pred.reshape(-1, self.dims.g3len), y[3].reshape(-1)
        )
        loss_gap4 = self.criterion(
            gap4_pred.reshape(-1, self.dims.g4len), y[4].reshape(-1)
        )
        loss_gap5 = self.criterion(
            gap5_pred.reshape(-1, self.dims.g5len), y[5].reshape(-1)
        )
        loss_gap6 = self.criterion(
            gap6_pred.reshape(-1, self.dims.g6len), y[6].reshape(-1)
        )

        loss = (
            loss_poi + loss_gap2 + loss_gap3 + loss_gap4 + loss_gap5 + loss_gap6
        ) / 6

        top1_argmax=self.top1_argmax(poi_pred,y[1])
        top1_acc = self.top1(poi_pred,y[1])
        top5_acc = self.top5(poi_pred, y[1])
        top10_acc = self.top10(poi_pred, y[1])
        top20_acc = self.top20(poi_pred, y[1])
        mrr = self.mrr(poi_pred, y[1])

        self.log("val/loss", loss)
        self.log("val/loss_gap2", loss_gap2)
        self.log("val/loss_gap3", loss_gap3)
        self.log("val/loss_gap4", loss_gap4)
        self.log("val/loss_gap5", loss_gap5)
        self.log("val/loss_gap6", loss_gap6)
        self.log("val/loss_poi", loss_poi)

        # log "leaderboard" metrics
        self.log("val/top1_argmax", top1_argmax)
        self.log("val/top1", top1_acc)
        self.log("val/top5", top5_acc)
        self.log("val/top10", top10_acc)
        self.log("val/top20", top20_acc)
        self.log("val/mrr", mrr)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        (
            poi_pred,
            gap2_pred,
            gap3_pred,
            gap4_pred,
            gap5_pred,
            gap6_pred,
        ) = self(x)

        loss_poi = self.criterion(
            poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1)
        )
        loss_gap2 = self.criterion(
            gap2_pred.reshape(-1, self.dims.g2len), y[2].reshape(-1)
        )
        loss_gap3 = self.criterion(
            gap3_pred.reshape(-1, self.dims.g3len), y[3].reshape(-1)
        )
        loss_gap4 = self.criterion(
            gap4_pred.reshape(-1, self.dims.g4len), y[4].reshape(-1)
        )
        loss_gap5 = self.criterion(
            gap5_pred.reshape(-1, self.dims.g5len), y[5].reshape(-1)
        )
        loss_gap6 = self.criterion(
            gap6_pred.reshape(-1, self.dims.g6len), y[6].reshape(-1)
        )

        loss = (
            loss_poi + loss_gap2 + loss_gap3 + loss_gap4 + loss_gap5 + loss_gap6
        ) / 6

        self.log("test/loss", loss)
        self.log("test/loss_gap2", loss_gap2)
        self.log("test/loss_gap3", loss_gap3)
        self.log("test/loss_gap4", loss_gap4)
        self.log("test/loss_gap5", loss_gap5)
        self.log("test/loss_gap6", loss_gap6)
        self.log("test/loss_poi", loss_poi)

        top1_acc = self.top1(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        top5_acc = self.top5(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        top10_acc = self.top10(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        top20_acc = self.top20(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        mrr = self.mrr(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))

        # log "leaderboard" metrics
        self.log("test/top1", top1_acc)
        self.log("test/top5", top5_acc)
        self.log("test/top10", top10_acc)
        self.log("test/top20", top20_acc)
        self.log("test/mrr", mrr)

        return {"loss": loss}

    def configure_optimizers(self):
        # Define optimizer and scheduler
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, amsgrad=True)
        return optimizer

## Graph Neural Network

In [85]:
# GNN Components


class attn_LSTM(pl.LightningModule):

    def __init__(self, embedding_dim, hidden_dim):
        super(attn_LSTM, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.W = nn.Linear(embedding_dim, 4 * hidden_dim)
        self.U = nn.Linear(hidden_dim, 4 * hidden_dim)

        self.s_W = nn.Linear(embedding_dim, 4 * hidden_dim)
        self.t_W = nn.Linear(embedding_dim, 4 * hidden_dim)

    def forward(self, x, hidden, spatial, temporal, numTimeSteps):
        h_t, c_t = hidden

        previous_h_t = h_t
        previous_c_t = c_t

        allGates_preact = (
            self.W(x) + self.U(previous_h_t) + self.s_W(spatial) + self.t_W(temporal)
        )

        input_g = allGates_preact[:, :, : self.hidden_dim].sigmoid()
        forget_g = allGates_preact[
            :, :, self.hidden_dim : 2 * self.hidden_dim
        ].sigmoid()
        output_g = allGates_preact[
            :, :, 2 * self.hidden_dim : 3 * self.hidden_dim
        ].sigmoid()
        c_t_g = allGates_preact[:, :, 3 * self.hidden_dim :].tanh()

        c_t = forget_g * previous_c_t + input_g * c_t_g
        h_t = output_g * c_t.tanh()

        batchSize = x.shape[0]
        h_t = h_t.view(batchSize, numTimeSteps, self.hidden_dim)
        c_t = c_t.view(batchSize, numTimeSteps, self.hidden_dim)

        return h_t, c_t


def get_neighbours(adj_matrix, poi):
    neigh_indices_list = []
    max_length = 0

    for batch_poi in poi:
        batch_indices = []
        for single_poi in batch_poi:
            poi_row = adj_matrix[single_poi]
            neigh_indices = torch.where(poi_row == 1)[0]
            batch_indices.append(neigh_indices)
            max_length = max(max_length, len(neigh_indices))

        neigh_indices_list.append(batch_indices)

    padded_neigh_indices_list = []
    for batch_indices in neigh_indices_list:
        padded_batch_indices = pad_sequence(
            batch_indices, batch_first=True, padding_value=0
        )
        padded_neigh_indices_list.append(padded_batch_indices)

    padded_tensor = torch.stack(padded_neigh_indices_list)

    return padded_tensor


class GRNSelfAttention(torch.nn.Module):

    def __init__(self, hidden_dim, n_heads):

        super(GRNSelfAttention, self).__init__()

        self.hidden_dim = hidden_dim
        self.n_heads = n_heads

        self.Wp = nn.Linear(hidden_dim, hidden_dim)  # embeddings to pre-concat
        self.Wa = nn.Linear(2 * hidden_dim, hidden_dim)  # concatenation to pre-softmax

        # total size = 3 * (hidden_dim) ** 2, quadratic in embedding size

    def forward(self, poi, neighbors):
        """forward

        Parameters
        ----------
        poi: torch.Tensor
            A batched tensor of embedded POI vectors, (B x H) where H is the
            embedding dimension
        neighbors: torch.Tensor
            A batched tensor of sequences of embedded POI vectors that are extracted
            from an adjacency matrix (temporal or spatial neighbors of POI),
            (B x N x H), where N is the number of neighbours of POI, B is the
            batch size, H is the embedding dimension, and must be the same as POI

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
          A tuple containing the self-attention weighted hadamard product of neighbour activations
          in the first index, the attention weights in the second index.
        """
        # assert len(poi.shape) == 2, f"POI tensor must be 2D, got {poi.shape} instead"
        assert (
            len(neighbors.shape) == 3
        ), f"Neighbour tensor must be 3D, got {neighbors.shape} instead"

        B, N, H = neighbors.shape

        h_poi = self.Wp(poi)
        h_n = self.Wp(neighbors)
        h_cat = torch.cat([h_poi.expand(B, N, -1), h_n], dim=2)
        h_att = F.leaky_relu(self.Wa(h_cat))

        alpha = torch.nn.functional.softmax(h_att, dim=1)

        p = torch.sum(alpha * h_n, dim=1)
        return p, alpha

In [94]:
# GRN (Graph Recurrent Network)
class GRN(pl.LightningModule):

    def __init__(
        self,
        dims: BaselineDimensions,
        spatial_graph,
        temporal_graph,
        hidden_dim,
        n_heads,
        dropout_rate=0.9,
        device="cpu",
    ):
        super(GRN, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.dims = dims

        self.spatial_graph = spatial_graph.to(device)
        self.temporal_graph = temporal_graph.to(device)

        self.spatial_attn = GRNSelfAttention(hidden_dim, n_heads)
        self.temporal_attn = GRNSelfAttention(hidden_dim, n_heads)

        self.lstm = attn_LSTM(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout_rate)

        self.user_embedding = nn.Embedding(dims.nuser, hidden_dim, padding_idx=0)
        self.poi_embedding = nn.Embedding(dims.npoi, hidden_dim, padding_idx=0)
        self.g2_embed = nn.Embedding(dims.g2len, hidden_dim, padding_idx=0)
        self.g3_embed = nn.Embedding(dims.g3len, hidden_dim, padding_idx=0)
        self.g4_embed = nn.Embedding(dims.g4len, hidden_dim, padding_idx=0)
        self.g5_embed = nn.Embedding(dims.g5len, hidden_dim, padding_idx=0)
        self.g6_embed = nn.Embedding(dims.g6len, hidden_dim, padding_idx=0)

        self.linear_poi = nn.Linear(2 * hidden_dim, dims.npoi)
        self.linear_g2 = nn.Linear(2 * hidden_dim, dims.g2len)
        self.linear_g3 = nn.Linear(2 * hidden_dim, dims.g3len)
        self.linear_g4 = nn.Linear(2 * hidden_dim, dims.g4len)
        self.linear_g5 = nn.Linear(2 * hidden_dim, dims.g5len)
        self.linear_g6 = nn.Linear(2 * hidden_dim, dims.g6len)

        self.top1_argmax=AccuracyAt1()
        self.top1 = AccuracyAtK(1)
        self.top5 = AccuracyAtK(5)
        self.top10 = AccuracyAtK(10)
        self.top20 = AccuracyAtK(20)
        self.mrr = MeanReciprocalRank()

        # extract indices from one-hot neighbor list
        self.iota = torch.arange(self.dims.npoi, requires_grad=False, device=device)

        self.criterion = nn.CrossEntropyLoss()

        self.apply(self.init_weights)

    def init_weights(self, w):
        if type(w) == nn.Linear:
            nn.init.kaiming_normal_(w.weight)
            nn.init.constant_(w.bias, 0)
        elif type(w) == nn.LSTM:
            for name, param in w.named_parameters():
                if "bias" in name:
                    nn.init.constant_(param, 0)
                elif "weight" in name:
                    nn.init.kaiming_normal_(param)
        elif type(w) == nn.Embedding:
            nn.init.kaiming_normal_(w.weight)
            nn.init.constant_(w.weight[0], 0)

    def forward(self, x):

        users, poi, x_geoHash2, x_geoHash3, x_geoHash4, x_geoHash5, x_geoHash6 = x

        B, T = poi.shape

        users = users.repeat(T, 1).T

        neighbors_spatial = self.spatial_graph[poi]
        neighbors_temporal = self.temporal_graph[poi]

        e_user = self.dropout(self.user_embedding(users))
        e_poi = self.dropout(self.poi_embedding(poi))
        e_gap2 = self.dropout(self.g2_embed(x_geoHash2))
        e_gap3 = self.dropout(self.g3_embed(x_geoHash3))
        e_gap4 = self.dropout(self.g4_embed(x_geoHash4))
        e_gap5 = self.dropout(self.g5_embed(x_geoHash5))
        e_gap6 = self.dropout(self.g6_embed(x_geoHash6))

        spatial_atts = torch.empty((B, T, self.hidden_dim), device=self.device)
        temporal_atts = torch.empty((B, T, self.hidden_dim), device=self.device)

        for b in range(B):
            for t in range(T):

                spatial_neigh = neighbors_spatial[b, t] * self.iota
                temporal_neigh = neighbors_temporal[b, t] * self.iota

                spatial_neigh = spatial_neigh[spatial_neigh != 0]
                temporal_neigh = temporal_neigh[temporal_neigh != 0]

                spatial_neigh = spatial_neigh.unsqueeze(0)
                temporal_neigh = temporal_neigh.unsqueeze(0)

                e_spatial = self.dropout(self.poi_embedding(spatial_neigh))
                e_temporal = self.dropout(self.poi_embedding(temporal_neigh))

                curr_poi = e_poi[b, t].unsqueeze(0)

                spatial_p, _ = self.spatial_attn(curr_poi, e_spatial)
                temporal_p, _ = self.temporal_attn(curr_poi, e_temporal)

                # we are not using the batch dimension, so we squeeze it
                spatial_atts[b, t] = spatial_p.squeeze()
                temporal_atts[b, t] = temporal_p.squeeze()

        # zero-init LSTM states
        h_t = torch.zeros(B, T, self.hidden_dim, device=self.device)
        c_t = torch.zeros(B, T, self.hidden_dim, device=self.device)

        h_t, c_t = self.lstm(e_poi, (h_t, c_t), spatial_atts, temporal_atts, T)

        # Note:the prediction of the poi depends on the embedding o the user
        next_poi = self.linear_poi(torch.cat((h_t, e_user), dim=2))
        next_g2 = self.linear_g2(torch.cat((h_t, e_gap2), dim=2))
        next_g3 = self.linear_g3(torch.cat((h_t, e_gap3), dim=2))
        next_g4 = self.linear_g4(torch.cat((h_t, e_gap4), dim=2))
        next_g5 = self.linear_g5(torch.cat((h_t, e_gap5), dim=2))
        next_g6 = self.linear_g6(torch.cat((h_t, e_gap6), dim=2))

        return next_poi, next_g2, next_g3, next_g4, next_g5, next_g6

    def training_step(self, batch, batch_idx):
        x, y = batch
        (
            poi_pred,
            gap2_pred,
            gap3_pred,
            gap4_pred,
            gap5_pred,
            gap6_pred,
        ) = self(x)

        loss_poi = self.criterion(
            poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1)
        )
        loss_gap2 = self.criterion(
            gap2_pred.reshape(-1, self.dims.g2len), y[2].reshape(-1)
        )
        loss_gap3 = self.criterion(
            gap3_pred.reshape(-1, self.dims.g3len), y[3].reshape(-1)
        )
        loss_gap4 = self.criterion(
            gap4_pred.reshape(-1, self.dims.g4len), y[4].reshape(-1)
        )
        loss_gap5 = self.criterion(
            gap5_pred.reshape(-1, self.dims.g5len), y[5].reshape(-1)
        )
        loss_gap6 = self.criterion(
            gap6_pred.reshape(-1, self.dims.g6len), y[6].reshape(-1)
        )

        loss = (
            loss_poi + loss_gap2 + loss_gap3 + loss_gap4 + loss_gap5 + loss_gap6
        ) / 6
        self.log("train/loss", loss)
        self.log("train/loss_gap2", loss_gap2)
        self.log("train/loss_gap3", loss_gap3)
        self.log("train/loss_gap4", loss_gap4)
        self.log("train/loss_gap5", loss_gap5)
        self.log("train/loss_gap6", loss_gap6)
        self.log("train/loss_poi", loss_poi)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        (
            poi_pred,
            gap2_pred,
            gap3_pred,
            gap4_pred,
            gap5_pred,
            gap6_pred,
        ) = self(x)

        loss_poi = self.criterion(
            poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1)
        )
        loss_gap2 = self.criterion(
            gap2_pred.reshape(-1, self.dims.g2len), y[2].reshape(-1)
        )
        loss_gap3 = self.criterion(
            gap3_pred.reshape(-1, self.dims.g3len), y[3].reshape(-1)
        )
        loss_gap4 = self.criterion(
            gap4_pred.reshape(-1, self.dims.g4len), y[4].reshape(-1)
        )
        loss_gap5 = self.criterion(
            gap5_pred.reshape(-1, self.dims.g5len), y[5].reshape(-1)
        )
        loss_gap6 = self.criterion(
            gap6_pred.reshape(-1, self.dims.g6len), y[6].reshape(-1)
        )

        loss = (
            loss_poi + loss_gap2 + loss_gap3 + loss_gap4 + loss_gap5 + loss_gap6
        ) / 6

        self.log("val/loss", loss)
        self.log("val/loss_gap2", loss_gap2)
        self.log("val/loss_gap3", loss_gap3)
        self.log("val/loss_gap4", loss_gap4)
        self.log("val/loss_gap5", loss_gap5)
        self.log("val/loss_gap6", loss_gap6)
        self.log("val/loss_poi", loss_poi)

        top1_argmax=self.top1_argmax(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        top1_acc =self.top1(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        top5_acc = self.top5(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        top10_acc = self.top10(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        top20_acc = self.top20(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))
        mrr = self.mrr(poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1))

        # log "leaderboard" metrics
        self.log("val/top1_argmax", top1_argmax)
        self.log("val/top1", top1_acc)
        self.log("val/top5", top5_acc)
        self.log("val/top10", top10_acc)
        self.log("val/top20", top20_acc)
        self.log("val/mrr", mrr)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        (
            poi_pred,
            gap2_pred,
            gap3_pred,
            gap4_pred,
            gap5_pred,
            gap6_pred,
        ) = self(x)

        loss_poi = self.criterion(
            poi_pred.reshape(-1, self.dims.npoi), y[1].reshape(-1)
        )
        loss_gap2 = self.criterion(
            gap2_pred.reshape(-1, self.dims.g2len), y[2].reshape(-1)
        )
        loss_gap3 = self.criterion(
            gap3_pred.reshape(-1, self.dims.g3len), y[3].reshape(-1)
        )
        loss_gap4 = self.criterion(
            gap4_pred.reshape(-1, self.dims.g4len), y[4].reshape(-1)
        )
        loss_gap5 = self.criterion(
            gap5_pred.reshape(-1, self.dims.g5len), y[5].reshape(-1)
        )
        loss_gap6 = self.criterion(
            gap6_pred.reshape(-1, self.dims.g6len), y[6].reshape(-1)
        )

        loss = (
            loss_poi + loss_gap2 + loss_gap3 + loss_gap4 + loss_gap5 + loss_gap6
        ) / 6

        self.log("test/loss", loss)
        self.log("test/loss_gap2", loss_gap2)
        self.log("test/loss_gap3", loss_gap3)
        self.log("test/loss_gap4", loss_gap4)
        self.log("test/loss_gap5", loss_gap5)
        self.log("test/loss_gap6", loss_gap6)
        self.log("test/loss_poi", loss_poi)

        top1_acc = self.top1(poi_pred, y[1])
        top5_acc = self.top5(poi_pred, y[1])
        top10_acc = self.top10(poi_pred, y[1])
        top20_acc = self.top20(poi_pred, y[1])
        mrr = self.mrr(poi_pred, y[1])

        # log "leaderboard" metrics
        self.log("test/top1", top1_acc)
        self.log("test/top5", top5_acc)
        self.log("test/top10", top10_acc)
        self.log("test/top20", top20_acc)
        self.log("test/mrr", mrr)

        return {"loss": loss}

    def configure_optimizers(self):
        # Define optimizer and scheduler
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, amsgrad=True)
        return optimizer

## Training Loops

In [87]:
n_users = encoder_dict["users"].classes_.shape[0]
n_pois = encoder_dict["pois"].classes_.shape[0]
n_g2 = encoder_dict["g2"].classes_.shape[0]
n_g3 = encoder_dict["g3"].classes_.shape[0]
n_g4 = encoder_dict["g4"].classes_.shape[0]
n_g5 = encoder_dict["g5"].classes_.shape[0]
n_g6 = encoder_dict["g6"].classes_.shape[0]


# account for the padding token
dims = BaselineDimensions(
    n_users + 1, n_pois + 1, n_g2 + 1, n_g3 + 1, n_g4 + 1, n_g5 + 1, n_g6 + 1
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [88]:
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

TRAIN_BASELINE = True

wandb.finish()
torch.cuda.empty_cache()
# cargo-cult like stuff that is supposed to make you faster
torch.set_float32_matmul_precision("medium")
torch.backends.cudnn.benchmark = True

ds = CheckinModule(encoded_data_train, encoded_data_test, batch_size=32, workers=4)

wandb.init(project="trovailpoi")

classifier_baseline = HMT_RN(dims, embedding_dim=1024, lstm_hidden_dim=1024)
wandb_logger = WandbLogger(project="trovailpoi")
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",
    devices=[0],
    log_every_n_steps=10,
    logger=wandb_logger,
    strategy="auto",
    callbacks=[
        torchpl.callbacks.LearningRateMonitor(logging_interval="step"),
        torchpl.callbacks.ModelCheckpoint(
            monitor="val/loss",
            mode="min",
            save_top_k=1,
            save_last=True,
            filename="best_model",
        ),
        torchpl.callbacks.EarlyStopping(
            monitor="val/loss", patience=3, min_delta=0.0005, mode="min"
        ),
    ],
)

if TRAIN_BASELINE:
    trainer.fit(model=classifier_baseline, datamodule=ds)
wandb.finish()

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

[34m[1mwandb[0m: Currently logged in as: [33mmartinadoku[0m ([33mpoi-dl-airo[0m). Use [1m`wandb login --relogin`[0m to force relogin


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
   | Name           | Type               | Params
-------------------------------------------------------
0  | user_embedding | Embedding          | 20.3 M
1  | poi_embedding  | Embedding          | 4.6 M 
2  | g2_embed       | Embedding          | 114 K 
3  | g3_embed       | Embedding          | 287 K 
4  | g4_embed       | Embedding          | 561 K 
5  | g

Loading data from dict/dataframe
12060


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
lr-AdamW,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▅▄▂▃▂▃▃▂▂▂▂▃▂▂▂▂▃▂▂▁▁▃▃▂▁▃▃▂▁▁▂▂▂▂▁▂▂▂▂
train/loss_gap2,█▅▄▃▃▂▂▃▂▂▂▂▂▁▂▂▂▂▁▂▁▁▂▂▁▁▂▂▁▁▁▂▂▁▂▂▁▁▁▂
train/loss_gap3,█▆▄▃▃▂▃▃▃▂▃▂▃▂▂▂▂▃▂▂▁▁▃▃▂▁▂▃▁▁▁▂▂▂▂▂▂▂▁▂
train/loss_gap4,█▅▄▃▃▂▃▃▃▂▃▂▃▂▂▂▂▄▂▂▁▁▃▃▂▁▂▃▂▁▁▂▂▂▂▁▂▂▂▂
train/loss_gap5,█▅▄▂▃▂▃▃▂▂▃▂▃▂▂▂▂▄▃▂▁▁▃▃▂▁▃▃▂▁▁▂▂▂▃▂▃▂▂▂
train/loss_gap6,█▅▄▂▃▂▃▃▂▂▂▂▃▂▂▃▂▄▃▂▁▁▃▃▂▁▄▃▃▁▁▂▂▂▃▂▃▂▂▂
train/loss_poi,█▅▄▂▃▂▃▃▂▂▂▂▃▂▂▃▂▄▃▂▁▁▃▃▂▁▄▃▃▁▂▂▂▂▃▂▃▂▂▂
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
epoch,9.0
lr-AdamW,0.0001
train/loss,1.91603
train/loss_gap2,0.54152
train/loss_gap3,0.95938
train/loss_gap4,1.35265
train/loss_gap5,2.34223
train/loss_gap6,2.97279
train/loss_poi,3.32761
trainer/global_step,3769.0


In [None]:
TRAIN_GNN = True

batch_size = 60
wandb.finish()
torch.cuda.empty_cache()
# cargo-cult like stuff that is supposed to make you faster
torch.set_float32_matmul_precision("medium")
torch.backends.cudnn.benchmark = True

wandb.init(project="trovailpoi")

classifier_gnn = GRN(
    dims,
    spatial_graph,
    temporal_graph,
    hidden_dim=1024,
    n_heads=1,
    dropout_rate=0.9,
    device=device,
)
wandb_logger = WandbLogger(project="trovailpoi")
trainer = Trainer(
    max_epochs=40,
    accelerator="auto",
    devices=[0],
    log_every_n_steps=10,
    logger=wandb_logger,
    strategy="auto",
    callbacks=[
        torchpl.callbacks.LearningRateMonitor(logging_interval="step"),
        torchpl.callbacks.ModelCheckpoint(
            monitor="val/loss",
            mode="min",
            save_top_k=1,
            save_last=True,
            filename="best_model",
        ),
        torchpl.callbacks.EarlyStopping(
            monitor="val/loss", patience=3, min_delta=0.0005, mode="min"
        ),
    ],
)

if TRAIN_GNN:
    trainer.fit(model=classifier_gnn, datamodule=ds)
wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111259042222375, max=1.0)…

In [90]:
checkpoint_path = "/content/trovailpoi/dnpotgxc/checkpoints/best_model.ckpt"
# Load the trained model from the checkpoint
trained_model = HMT_RN.load_from_checkpoint(
    checkpoint_path,
    dimensions=dims,
    embedding_dim=1024,  # Example embedding dimension
    lstm_hidden_dim=1024,  # Example LSTM hidden dimension
    dropout_rate=0.9,  # Example dropout rate
)

# Create a test dataloader
# Assuming you have a method `test_dataloader` in your data module
test_loader = ds.test_dataloader()  # Replace `ds` with your actual data module instance

# Instantiate the trainer
trainer = Trainer(accelerator="auto", devices=[0])

# Test the model
results = trainer.test(trained_model, test_loader)

# Print the test results
print(results)

FileNotFoundError: [Errno 2] No such file or directory: '/content/trovailpoi/dnpotgxc/checkpoints/best_model.ckpt'

## Scrapbook for Experimentation

Ignore all code below, it's just for quick prototyping