# Foursquare dataset next-POI Recommendation System

First off we import all the necessary libraries:

In [58]:
import polars as rs
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as pl
import lightning.pytorch as torchpl
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import LabelEncoder
from dataclasses import dataclass

Next, we load the data, we utilize `polars` since it is much more efficient than `pandas` and can handle large datasets with ease.

In [4]:
columns = ["user", "poi", "date", "TZ"]
data = rs.read_csv(
    "dataset_TIST2015/dataset_TIST2015_Checkins.txt",
    has_header=False,
    low_memory=True,
    separator="\t",
)
data.columns = columns

In [5]:
data

user,poi,date,TZ
i64,str,str,i64
50756,"""4f5e3a72e4b053fd6a4313f6""","""Tue Apr 03 18:00:06 +0000 2012""",240
190571,"""4b4b87b5f964a5204a9f26e3""","""Tue Apr 03 18:00:07 +0000 2012""",180
221021,"""4a85b1b3f964a520eefe1fe3""","""Tue Apr 03 18:00:08 +0000 2012""",-240
66981,"""4b4606f2f964a520751426e3""","""Tue Apr 03 18:00:08 +0000 2012""",-300
21010,"""4c2b4e8a9a559c74832f0de2""","""Tue Apr 03 18:00:09 +0000 2012""",240
…,…,…,…
16349,"""4c957755c8a1bfb7e89024f3""","""Mon Sep 16 23:24:11 +0000 2013""",-240
256757,"""4c8bbb6d9ef0224bd2d6667b""","""Mon Sep 16 23:24:13 +0000 2013""",-180
66425,"""513e82a5e4b0ed4f0f3bcf2d""","""Mon Sep 16 23:24:14 +0000 2013""",-180
1830,"""4b447865f964a5204cf525e3""","""Mon Sep 16 23:24:14 +0000 2013""",120


Differently from what suggested by the professor, we utilize the full TIST2015 dataset, which has a far greater scale compared to the reduced NY one. However, by following the pruning steps detailed in the paper (http://dx.doi.org/10.1145/3477495.3531989, section 5.1), we obtain sequences that are much smaller in size, resulting in a dataset that is usable on Google Colab's free tier (as required by the assignment).

In [6]:
data_users = (
    data.lazy()
    .group_by("user")
    .agg(
        [
            rs.col("poi").n_unique().alias("n_pois"),
            rs.col("poi").count().alias("n_checkins"),
            # turn the rest into a list
            rs.col("poi").alias("pois"),
            rs.col("date").alias("dates"),
            rs.col("TZ").alias("TZs"),
        ]
    )
).collect()

In [7]:
data_users.describe()

statistic,user,n_pois,n_checkins,pois,dates,TZs
str,f64,f64,f64,f64,f64,f64
"""count""",266909.0,266909.0,266909.0,266909.0,266909.0,266909.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",133455.0,56.477459,124.62537,,,
"""std""",77050.135837,45.968603,140.692138,,,
"""min""",1.0,1.0,1.0,,,
"""25%""",66728.0,30.0,61.0,,,
"""50%""",133455.0,49.0,93.0,,,
"""75%""",200182.0,71.0,148.0,,,
"""max""",266909.0,1246.0,5430.0,,,


## Data Preprocessing

In [8]:
data_culled = data_users.filter(
    (rs.col("n_checkins") > 20) & (rs.col("n_checkins") < 50)
).drop_nulls()

Since the original dataset is huge, we delete it and call the python garbage collector to free up memory. We then proceed with the second pruning step (frequency-based pruning) as detailed in the paper.

In [9]:
del data
del data_users

import gc

gc.collect()

0

In [10]:
# extract unique elements from each lists in data_culled["pois"]
out = data_culled.with_columns(
    [
        rs.col("pois").list.unique(),
        rs.col("pois").list.unique().list.len().alias("n_unique_pois"),
    ]
)

In [11]:
out

user,n_pois,n_checkins,pois,dates,TZs,n_unique_pois
i64,u32,u32,list[str],list[str],list[i64],u32
228150,30,46,"[""4ac0fa3df964a520449520e3"", ""4a80aab7f964a520def51fe3"", … ""4c08e0a1bbc676b09a2a47d5""]","[""Sat Apr 14 20:54:33 +0000 2012"", ""Wed Apr 25 22:25:08 +0000 2012"", … ""Sat Aug 31 19:53:21 +0000 2013""]","[-300, -300, … -300]",30
255812,25,38,"[""4bae2c85f964a520a98d3be3"", ""4e4fc1bf8877402b06c775fb"", … ""4b058707f964a520307c22e3""]","[""Sat Apr 07 19:45:54 +0000 2012"", ""Sat Apr 07 19:52:31 +0000 2012"", … ""Sun Jun 16 16:47:52 +0000 2013""]","[-300, -300, … -300]",25
232026,30,44,"[""4b694d6ff964a5203c9e2be3"", ""4f68714ce4b01977701c8758"", … ""4f63da9fe4b0dc771f7af760""]","[""Sat Apr 28 23:40:46 +0000 2012"", ""Sat May 19 18:37:49 +0000 2012"", … ""Sun Nov 18 14:58:41 +0000 2012""]","[-180, -180, … -120]",30
5678,21,23,"[""4fdb0fdee4b0fb17898c1073"", ""4ba286c6f964a5201f0138e3"", … ""4cdd4782df986ea8d785d416""]","[""Tue Jul 24 15:38:54 +0000 2012"", ""Thu Aug 09 08:52:53 +0000 2012"", … ""Fri Aug 30 22:58:27 +0000 2013""]","[-240, -240, … -240]",21
91802,24,26,"[""4bba08c898c7ef3be1973202"", ""4b788451f964a52097d32ee3"", … ""500721e9e4b06d6fcf6c6e95""]","[""Sat Mar 02 14:04:14 +0000 2013"", ""Sat Mar 02 14:04:48 +0000 2013"", … ""Thu Mar 14 16:22:26 +0000 2013""]","[-180, -180, … -180]",24
…,…,…,…,…,…,…
30582,24,28,"[""50076321e4b0d5e2f527a27a"", ""4ea93b3f61af654198a8b9d3"", … ""4bc1fffcf8219c74b811b410""]","[""Sat Aug 18 19:48:06 +0000 2012"", ""Sat Aug 18 22:33:27 +0000 2012"", … ""Sat Sep 14 17:13:05 +0000 2013""]","[-240, -240, … -180]",24
78314,13,29,"[""40e0b100f964a5206c051fe3"", ""4bbf7417006dc9b669a0fc3f"", … ""4ba553b1f964a5201dfb38e3""]","[""Tue Dec 18 01:19:53 +0000 2012"", ""Fri Jan 04 20:01:45 +0000 2013"", … ""Sun Sep 15 03:43:14 +0000 2013""]","[-300, -300, … -240]",13
227893,37,45,"[""4bfc815eda7120a1c7354afd"", ""4a63be12f964a520aec51fe3"", … ""4da0d31f63b5a35d99e2da19""]","[""Wed Apr 04 18:06:13 +0000 2012"", ""Tue Apr 17 13:25:25 +0000 2012"", … ""Mon Aug 12 22:39:53 +0000 2013""]","[-240, -240, … -240]",37
138405,26,28,"[""444999c5f964a52076321fe3"", ""49bc22ebf964a5201a541fe3"", … ""4aab0694f964a520985820e3""]","[""Thu Apr 04 00:02:34 +0000 2013"", ""Thu Apr 04 00:18:04 +0000 2013"", … ""Thu Aug 08 14:06:55 +0000 2013""]","[-180, -180, … -180]",26


In [12]:
l = out["pois"][0].to_list()
len(set(l))  # print number of unique POIs in first sequence

30

In [13]:
l2 = data_culled["pois"][0].to_list()
len(l2)  # print sequence length of first user

46

In [14]:
len(set(l2))  # confirm that the two match

30

In [15]:
# run a Polars query to obtain all the frequent POIs, the ones expected to survive the filtering
unique_pois = out["pois"]
frequent_pois = unique_pois.list.explode().value_counts().filter(rs.col("count") >= 10)

In [16]:
frequent_pois

pois,count
str,u32
"""4b92005ef964a520ebe233e3""",14
"""4ccec1f6ee23a14391d92da8""",10
"""4b5dc322f964a5205f6a29e3""",16
"""4b0587fcf964a520eeaa22e3""",11
"""4d9c7fd7b0fccbffd8d834cf""",25
…,…
"""4ca1c143542b224b0bcd10a0""",131
"""44014894f964a5201d301fe3""",19
"""4b15507bf964a520a7b023e3""",73
"""5000582fe4b07988e18344b5""",21


In [17]:
frequent_pois = frequent_pois["pois"]
frequent_pois = set(frequent_pois.to_list())

In [18]:
data_culled

user,n_pois,n_checkins,pois,dates,TZs
i64,u32,u32,list[str],list[str],list[i64]
228150,30,46,"[""49cbc9f5f964a52018591fe3"", ""4b41143af964a520acc025e3"", … ""4b67290df964a520953e2be3""]","[""Sat Apr 14 20:54:33 +0000 2012"", ""Wed Apr 25 22:25:08 +0000 2012"", … ""Sat Aug 31 19:53:21 +0000 2013""]","[-300, -300, … -300]"
255812,25,38,"[""4ca95b6dae1eef3bd0f33247"", ""4de4fda7b3ad5eaa5aaeb7f0"", … ""51619b8be4b07ee81b5f0d10""]","[""Sat Apr 07 19:45:54 +0000 2012"", ""Sat Apr 07 19:52:31 +0000 2012"", … ""Sun Jun 16 16:47:52 +0000 2013""]","[-300, -300, … -300]"
232026,30,44,"[""4d5f0a9c9f67f04d0ace67fb"", ""4bd6230e6f649521526870ec"", … ""4c3ddcfb7d00d13aaccd3a50""]","[""Sat Apr 28 23:40:46 +0000 2012"", ""Sat May 19 18:37:49 +0000 2012"", … ""Sun Nov 18 14:58:41 +0000 2012""]","[-180, -180, … -120]"
5678,21,23,"[""500ad0f8e4b0ca8f7be9d0ff"", ""4bf81a008d30d13a4f1c0018"", … ""4c6bf59399b9236af82ce1c9""]","[""Tue Jul 24 15:38:54 +0000 2012"", ""Thu Aug 09 08:52:53 +0000 2012"", … ""Fri Aug 30 22:58:27 +0000 2013""]","[-240, -240, … -240]"
91802,24,26,"[""50abca92e4b0146313c60ffc"", ""4b870115f964a520ebaa31e3"", … ""4bba08c898c7ef3be1973202""]","[""Sat Mar 02 14:04:14 +0000 2013"", ""Sat Mar 02 14:04:48 +0000 2013"", … ""Thu Mar 14 16:22:26 +0000 2013""]","[-180, -180, … -180]"
…,…,…,…,…,…
30582,24,28,"[""4b6f27adf964a520c5e02ce3"", ""501c4a5de4b03fc9220d980b"", … ""4bd5c1b94e32d13a8c60c180""]","[""Sat Aug 18 19:48:06 +0000 2012"", ""Sat Aug 18 22:33:27 +0000 2012"", … ""Sat Sep 14 17:13:05 +0000 2013""]","[-240, -240, … -180]"
78314,13,29,"[""4ba553b1f964a5201dfb38e3"", ""4b08658df964a520e30a23e3"", … ""4db53a3cf7b121c29f59ab7d""]","[""Tue Dec 18 01:19:53 +0000 2012"", ""Fri Jan 04 20:01:45 +0000 2013"", … ""Sun Sep 15 03:43:14 +0000 2013""]","[-300, -300, … -240]"
227893,37,45,"[""4b98f18ef964a520b05735e3"", ""4b98f18ef964a520b05735e3"", … ""4b52459ff964a5200d7427e3""]","[""Wed Apr 04 18:06:13 +0000 2012"", ""Tue Apr 17 13:25:25 +0000 2012"", … ""Mon Aug 12 22:39:53 +0000 2013""]","[-240, -240, … -240]"
138405,26,28,"[""4e42b1e4b0fb09b608717ae8"", ""4ba7b75cf964a5207bac39e3"", … ""4c730ff5376da0937935a8c6""]","[""Thu Apr 04 00:02:34 +0000 2013"", ""Thu Apr 04 00:18:04 +0000 2013"", … ""Thu Aug 08 14:06:55 +0000 2013""]","[-180, -180, … -180]"


In [19]:
data_culled = data_culled.with_columns(
    [
        rs.col("pois")
        .list.eval(
            rs.element().is_in(frequent_pois),
        )
        .alias("is_frequent")
    ]
)  # prep mask

In [20]:
final_data = (
    data_culled.lazy()
    .with_row_index()
    .explode(
        [
            "pois",
            "dates",
            "TZs",
            "is_frequent",
        ]
    )
    .group_by("user")
    .agg(
        [
            rs.col("pois").filter(rs.col("is_frequent")).alias("pois"),
            rs.col("dates").filter(rs.col("is_frequent")).alias("dates"),
            rs.col("TZs").filter(rs.col("is_frequent")).alias("TZs"),
            rs.col("pois").filter(rs.col("is_frequent")).n_unique().alias("n_pois"),
            rs.col("pois").filter(rs.col("is_frequent")).count().alias("n_checkins"),
        ]
    )
    .filter(rs.col("n_checkins") > 0)
    .filter(rs.col("n_pois") > 0)
    .collect()
)  # filter out infrequent pois and users with no pois

In [21]:
final_data.describe()

statistic,user,pois,dates,TZs,n_pois,n_checkins
str,f64,f64,f64,f64,f64,f64
"""count""",19862.0,19862.0,19862.0,19862.0,19862.0,19862.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",156852.822274,,,,6.123452,8.831437
"""std""",76314.892884,,,,4.609024,6.877662
"""min""",49.0,,,,1.0,1.0
"""25%""",95613.0,,,,3.0,4.0
"""50%""",167846.0,,,,5.0,7.0
"""75%""",224576.0,,,,8.0,12.0
"""max""",266909.0,,,,32.0,46.0


At this stage, culling is done, we can appreciate that `polars`'s SQL/functional-style API is different from Pandas, but it is very powerful and efficient.

The next step is geohashing the POIs, that is, we want to convert the latitude-longitude positions of the POIs into a grid-based geohash representation, which will form the basis for our network's embeddings.

In [22]:
import geohash2 as gh

pois = rs.read_csv(
    "dataset_TIST2015/dataset_TIST2015_POIs.txt",
    has_header=False,
    low_memory=True,
    separator="\t",
)
pois.columns = ["poi", "lat", "long", "category", "country"]
pois = pois.drop("category").drop("country")

In [23]:
pois = (
    pois.lazy()
    .filter(rs.col("poi").is_in(frequent_pois))
    .select(
        [
            rs.col("poi"),
            rs.struct(
                [
                    rs.col("lat").cast(rs.Float32),
                    rs.col("long").cast(rs.Float32),
                ]
            )
            .alias("location")
            .map_elements(
                lambda s: gh.encode(s["lat"], s["long"], precision=6),
                return_dtype=rs.String,
            )
            .alias("geohash"),
        ]
    )
    .collect()
)
poi_geo_dict = dict(zip(pois["poi"], pois["geohash"]))

In [24]:
# for each row in final_data, add the geohash of the pois by hitting the poi_geo_dict

final_data = final_data.with_columns(
    [
        rs.col("pois")
        .map_elements(
            lambda s: [poi_geo_dict[s] for s in s],
            return_dtype=rs.List(rs.String),
        )
        .alias("geohashes")
    ]
)

In [25]:
final_data["dates"][79].to_list()  # check out a temporal sequence

['Sat Dec 08 09:23:13 +0000 2012',
 'Thu May 02 23:03:58 +0000 2013',
 'Sun Jun 16 22:42:00 +0000 2013']

In [26]:
final_data["TZs"][79].to_list()  # ... and the corresponding timezones

[-180, -240, -240]

The work *might* seem over, however, we still have timezones to account for, we want to normalize everything according to GMT, so we convert the timestamps accordingly.

In [27]:
import datetime


def UTC_to_local(utc, tz):

    date = datetime.datetime.strptime(utc, "%a %b %d %H:%M:%S %z %Y")
    date = date.replace(tzinfo=datetime.timezone.utc)

    # shift by tz offset
    date = date.astimezone(datetime.timezone(datetime.timedelta(minutes=tz)))

    date_s = datetime.datetime.strftime(date, "%Y-%m-%d %H:%M:%S")
    return date_s


def to_UNIX_time(date):
    return datetime.datetime.strptime(
        date, "%Y-%m-%d %H:%M:%S"
    ).timestamp()  # we use UNIX time as a key for sorting the POIs in our polars query

In [28]:
UTC_to_local("Mon May 21 15:53:01 +0000 2012", -420)  # example of usage

'2012-05-21 08:53:01'

In [29]:
final_data = final_data.with_columns(
    [
        rs.struct([rs.col("dates"), rs.col("TZs")])
        .alias("times")
        .map_elements(
            lambda struct: [
                UTC_to_local(date, tz)
                for date, tz in zip(struct["dates"], struct["TZs"])
            ],
            return_dtype=rs.List(rs.String),
        )
    ]
)  # This performs timezone conversion

In [30]:
final_sorted = final_data.select(  # sort the times
    [
        rs.col("user"),
        rs.struct(
            [
                rs.col("pois"),
                rs.col("times"),
            ]
        ).map_elements(
            lambda struct: [
                poi
                for poi, _ in sorted(
                    zip(  # here we sort the POIs struct by UNIX timestamps of the GMT times
                        struct["pois"], [to_UNIX_time(date) for date in struct["times"]]
                    ),
                    key=lambda s: s[1],
                )
            ],
            return_dtype=rs.List(rs.String),
        ),
        rs.struct(
            [
                rs.col("geohashes"),
                rs.col("times"),
            ]
        ).map_elements(
            lambda struct: [
                geo
                for geo, _ in sorted(
                    zip(
                        struct["geohashes"],  # same thing goes on for geohashes
                        [to_UNIX_time(date) for date in struct["times"]],
                    ),
                    key=lambda s: s[1],
                )
            ],
            return_dtype=rs.List(rs.String),
        ),
        rs.col("times")
        .map_elements(
            lambda dates: sorted(dates, key=to_UNIX_time),
            return_dtype=rs.List(rs.String),
        )
        .alias("times_sorted"),
        rs.col("n_checkins"),
    ]
)

# P.S, admittedly, it would have been more efficient to encode the geohashes *after* sorting the POIs,
# so that we could save on the sorting of the geohashes. Tough luck, you can't win 'em all.

In [31]:
final_sorted

user,pois,geohashes,times_sorted,n_checkins
i64,list[str],list[str],list[str],u32
247249,"[""4b6fec73f964a5202b002de3"", ""4be139c2c1732d7f797c5b9a"", … ""4bd171a4caff95213ca8d0f0""]","[""sxk9st"", ""sxk9kp"", … ""sxk97k""]","[""2012-12-22 14:10:04"", ""2013-01-06 16:43:04"", … ""2013-07-29 10:27:00""]",22
130772,"[""49e36665f964a52078621fe3"", ""40b28c80f964a5204eff1ee3"", … ""3fd66200f964a520bbf11ee3""]","[""dqcjq6"", ""dp3wt7"", … ""dqcjrk""]","[""2012-05-05 13:25:31"", ""2012-05-25 17:39:22"", … ""2012-06-10 00:25:10""]",7
243075,"[""4d15d1a285fc6dcb88b7a24e"", ""4c45dd45429a0f47a8ba4a1e"", … ""4c77eeb293ef236a3978aa0f""]","[""sxk91r"", ""sxk9hq"", … ""sxk9kp""]","[""2012-11-11 17:03:51"", ""2013-02-14 08:40:20"", … ""2013-05-30 16:10:53""]",7
34281,"[""4c0e115bc700c9b6ce42a3dd"", ""4c45ed693f0276b051c051e7"", … ""5121dd01e4b09c4652a6f9fc""]","[""sxp75h"", ""swxp7w"", … ""sz3fg4""]","[""2012-05-19 11:27:26"", ""2012-07-22 13:58:29"", … ""2013-05-10 22:45:25""]",11
27966,"[""4bcc3184b6c49c74c04b9391"", ""4b5fa00af964a52048c529e3"", … ""4bea97446295c9b6138c8608""]","[""w2q2xr"", ""w2q2xp"", … ""w949bf""]","[""2012-04-28 14:45:19"", ""2012-04-28 19:49:13"", … ""2012-12-15 16:04:57""]",5
…,…,…,…,…
144050,"[""4ea0717bdab4b59126c1a08c"", ""4ca95b6dae1eef3bd0f33247"", … ""4b15503df964a5202eb023e3""]","[""9g3w3r"", ""9g6ht2"", … ""9mudn4""]","[""2012-07-24 06:15:09"", ""2012-11-09 15:04:50"", … ""2013-07-14 14:23:43""]",8
260908,"[""4c6e4cddd5c3a1cdcfe9c72b"", ""4b7d1431f964a52007ae2fe3"", … ""4c6fe7a6fa49a1cde13ba4e3""]","[""w283d2"", ""w2834w"", … ""w29mxn""]","[""2012-05-27 11:03:18"", ""2012-05-27 11:55:49"", … ""2012-11-06 22:47:58""]",11
205837,"[""4bd5ddf3637ba5939a2ef770"", ""4b77523af964a520cd912ee3"", … ""4b12abfcf964a520118c23e3""]","[""7nxpke"", ""7jswx8"", … ""6gycfp""]","[""2012-12-18 02:02:24"", ""2012-12-18 07:54:46"", … ""2013-01-31 05:19:40""]",9
36772,"[""4cace5f3ae1eef3bc1fc3447"", ""4b5ec87ff964a520909829e3"", … ""4b723eeef964a5206d752de3""]","[""7h2ww2"", ""7h2y8j"", … ""6gycc7""]","[""2012-04-27 15:26:52"", ""2012-04-28 14:08:35"", … ""2013-03-30 12:56:18""]",4


we now need to obtain a dataframe containing: each POI, it's geohash, and a set of all the check-ins it appears in
this is just one `polars` query away!

In [32]:
pois_checkins = final_sorted.explode(["pois", "geohashes"]).drop("n_checkins")

pois_checkins = (
    pois_checkins.with_columns(
        [
            rs.col("geohashes").map_elements(lambda s: s[:4], rs.String).alias("g4"),
        ]
    )
    .drop("geohashes")
    .group_by(["pois", "g4"])
    .agg([rs.col("times_sorted").flatten().alias("checkin_times")])
)

In [33]:
pois_checkins  # with this we can *efficiently* build our POI-POI spatial-temporal graphs

pois,g4,checkin_times
str,str,list[str]
"""4e5e2d8e1838f7255271dd55""","""75cn""","[""2012-04-21 10:17:24"", ""2012-04-21 16:34:07"", … ""2013-08-16 17:16:13""]"
"""49fe6d4ef964a520a16f1fe3""","""9q5c""","[""2012-05-04 09:42:37"", ""2012-05-26 16:19:53"", … ""2013-09-16 20:35:13""]"
"""44cf0ff8f964a5201c361fe3""","""drt2""","[""2012-04-09 10:53:27"", ""2012-04-10 19:14:49"", … ""2013-04-24 05:01:44""]"
"""4d8394e899b78cfaab11a41f""","""swgs""","[""2012-09-14 18:33:06"", ""2012-11-20 23:34:32"", … ""2013-09-15 11:59:42""]"
"""4baa8948f964a52095723ae3""","""swxp""","[""2012-07-08 11:54:13"", ""2012-07-30 05:42:59"", … ""2013-09-02 07:46:58""]"
…,…,…
"""4ef058520e61196dc8dd2abe""","""sxk3""","[""2012-11-24 09:12:16"", ""2012-11-24 11:52:37"", … ""2012-11-27 19:02:13""]"
"""447bf8f1f964a520ec331fe3""","""dr5r""","[""2012-04-24 12:47:41"", ""2012-06-17 15:53:36"", … ""2013-05-17 07:04:50""]"
"""4b46b71ff964a520452726e3""","""de30""","[""2012-05-01 11:32:48"", ""2012-05-06 12:23:37"", … ""2013-06-02 17:41:26""]"
"""4a88af3ef964a520600720e3""","""dr5r""","[""2012-11-09 21:44:43"", ""2012-11-10 15:08:48"", … ""2013-05-08 09:23:10""]"


In [34]:
def UTC_to_weekslot(utc: str) -> int:
    """UTC_to_weekslot converts a UTC timestamp to a weekslot.

    Parameters
    ----------
    utc : str
        A string representing a UTC timestamp.

    Returns
    -------
    int
        A weekslot in the range [0, 56).
    """

    date = datetime.datetime.strptime(utc, "%Y-%m-%d %H:%M:%S")
    week = date.weekday()
    hour = date.hour

    return week * 8 + hour // 3

Next, we want to encode all of our inputs for our neural networks, this could *probably* be done 
with polars magic, but it's too delicate and we prefer classic for-looping.

In [35]:
encoder_dict = {
    "users": LabelEncoder(),
    "pois": LabelEncoder(),
    "g2": LabelEncoder(),
    "g3": LabelEncoder(),
    "g4": LabelEncoder(),
    "g5": LabelEncoder(),
    "g6": LabelEncoder(),
}

encoded_data = {
    "users": [],
    "pois": [],
    "g2": [],
    "g3": [],
    "g4": [],
    "g5": [],
    "g6": [],
}

unique_data = {
    "users": set(),
    "pois": set(),
    "g2": set(),
    "g3": set(),
    "g4": set(),
    "g5": set(),
    "g6": set(),
}

# quick and dirty encoding:
# 1. put every unique symbol in a list
# 2. fit the respective encoder
# 3. transform the lists

for i, row in enumerate(final_sorted.iter_rows()):

    user, pois, geohashes, times_sorted, n_checkins = row

    g2 = [geo[:2] for geo in geohashes]
    g3 = [geo[:3] for geo in geohashes]
    g4 = [geo[:4] for geo in geohashes]
    g5 = [geo[:5] for geo in geohashes]
    g6 = [geo[:6] for geo in geohashes]  # redundant, but I like symmetry

    unique_data["users"].add(user)
    unique_data["pois"].update(pois)
    unique_data["g2"].update(g2)
    unique_data["g3"].update(g3)
    unique_data["g4"].update(g4)
    unique_data["g5"].update(g5)
    unique_data["g6"].update(g6)

for property, enc, data in zip(
    encoder_dict.keys(), encoder_dict.values(), unique_data.values()
):
    enc.fit(list(data))
    encoder_dict[property] = enc

In [36]:
# this could be optimized, right now it takes a while, at least we have a nice progress bar to look at

ds_size = len(final_sorted)

for i, row in tqdm(enumerate(final_sorted.iter_rows()), total=ds_size):

    user, pois, geohashes, times_sorted, n_checkins = row

    g2 = [geo[:2] for geo in geohashes]
    g3 = [geo[:3] for geo in geohashes]
    g4 = [geo[:4] for geo in geohashes]
    g5 = [geo[:5] for geo in geohashes]
    g6 = [geo[:6] for geo in geohashes]

    encoded_data["users"].append(encoder_dict["users"].transform([user])[0])
    encoded_data["pois"].append(encoder_dict["pois"].transform(pois))
    encoded_data["g2"].append(encoder_dict["g2"].transform(g2))
    encoded_data["g3"].append(encoder_dict["g3"].transform(g3))
    encoded_data["g4"].append(encoder_dict["g4"].transform(g4))
    encoded_data["g5"].append(encoder_dict["g5"].transform(g5))
    encoded_data["g6"].append(encoder_dict["g6"].transform(g6))

    # sum 1 to all values to avoid 0s
    encoded_data["users"][-1] += 1
    encoded_data["pois"][-1] += 1
    encoded_data["g2"][-1] += 1
    encoded_data["g3"][-1] += 1
    encoded_data["g4"][-1] += 1
    encoded_data["g5"][-1] += 1
    encoded_data["g6"][-1] += 1

100%|██████████| 19862/19862 [03:02<00:00, 108.59it/s]


In [37]:
# check that we left space for the padding token
min((arr.min() for arr in encoded_data["pois"]))

1

In [38]:
pois_checkins

pois,g4,checkin_times
str,str,list[str]
"""4e5e2d8e1838f7255271dd55""","""75cn""","[""2012-04-21 10:17:24"", ""2012-04-21 16:34:07"", … ""2013-08-16 17:16:13""]"
"""49fe6d4ef964a520a16f1fe3""","""9q5c""","[""2012-05-04 09:42:37"", ""2012-05-26 16:19:53"", … ""2013-09-16 20:35:13""]"
"""44cf0ff8f964a5201c361fe3""","""drt2""","[""2012-04-09 10:53:27"", ""2012-04-10 19:14:49"", … ""2013-04-24 05:01:44""]"
"""4d8394e899b78cfaab11a41f""","""swgs""","[""2012-09-14 18:33:06"", ""2012-11-20 23:34:32"", … ""2013-09-15 11:59:42""]"
"""4baa8948f964a52095723ae3""","""swxp""","[""2012-07-08 11:54:13"", ""2012-07-30 05:42:59"", … ""2013-09-02 07:46:58""]"
…,…,…
"""4ef058520e61196dc8dd2abe""","""sxk3""","[""2012-11-24 09:12:16"", ""2012-11-24 11:52:37"", … ""2012-11-27 19:02:13""]"
"""447bf8f1f964a520ec331fe3""","""dr5r""","[""2012-04-24 12:47:41"", ""2012-06-17 15:53:36"", … ""2013-05-17 07:04:50""]"
"""4b46b71ff964a520452726e3""","""de30""","[""2012-05-01 11:32:48"", ""2012-05-06 12:23:37"", … ""2013-06-02 17:41:26""]"
"""4a88af3ef964a520600720e3""","""dr5r""","[""2012-11-09 21:44:43"", ""2012-11-10 15:08:48"", … ""2013-05-08 09:23:10""]"


In [39]:
# we also encode the graph dataframe so we can build the graphs

pois_checkins = (
    pois_checkins.lazy()
    .with_columns(
        [
            rs.col("pois").map_elements(
                lambda s: encoder_dict["pois"].transform([s])[0] + 1, rs.Int64
            ),
            rs.col("g4").map_elements(
                lambda s: encoder_dict["g4"].transform([s])[0] + 1, rs.Int64
            ),  # apply utc_to_weekslot to each timestamp in the list
            rs.col("checkin_times").map_elements(
                lambda s: [UTC_to_weekslot(date) for date in s], rs.List(rs.Int64)
            ),
        ]
    )
    .sort("pois")
    .collect()
)

In [40]:
# add fictitious POI 0 to the graph, with nonexistent geohash and no timeslot, so we get a 0 row and column for the padding token
fake_datapoint = rs.DataFrame(
    {
        "pois": [0],
        "g4": [pois_checkins["g4"].max() + 42],
        "checkin_times": [[43]],
    }
)
# this is a lot of work since polars dataframes are immutable by default, we have to run a query to change the 43 into an empty list
# we NEED the 43 otherwise polars won't infer the datatype of the list

fake_datapoint = fake_datapoint.with_columns(
    [rs.col("checkin_times").map_elements(lambda s: [], rs.List(rs.Int64))]
)

pois_checkins = fake_datapoint.vstack(pois_checkins)

In [41]:
spatial_row = np.array(pois_checkins["g4"].to_list()).reshape(-1, 1)

In [42]:
# outer product using equality
spatial_graph = (spatial_row == spatial_row.T).astype(np.int32)
spatial_graph[0, 0] = (
    0  # the fake g4 is still equal to itself, we suppress this equality
)

In [43]:
temporal_row = pois_checkins["checkin_times"].to_list()

In [44]:
temporal_graph = np.zeros((spatial_row.shape[0], spatial_row.shape[0]))

In [45]:
temporal_sets = [np.array(list(set(row))) for row in temporal_row]

In [46]:
time_sets = torch.zeros((len(temporal_sets), 56), dtype=torch.int8)

for i, r in enumerate(temporal_row):
    indices = torch.tensor(r, dtype=torch.long)
    time_sets[i, indices] = 1

In [47]:
time_sets.shape

torch.Size([4456, 56])

In [48]:
# AND outer product

intersection = time_sets @ time_sets.T
union = time_sets.unsqueeze(1) | time_sets.unsqueeze(0)
union = union.sum(dim=2)
iou = intersection / union

In [49]:
temporal_graph = iou >= 0.9
# cast to int
temporal_graph = temporal_graph.int()

In [50]:
temporal_graph[0, :].sum()

tensor(0)

We print information about the sparsity of the graphs, we note that 
the sparsity of the graphs is similar to that of the paper.

In [51]:
temporal_density = (
    temporal_graph.sum() / (temporal_graph.shape[0] * temporal_graph.shape[1])
).item()
spatial_density = (
    spatial_graph.sum() / (spatial_graph.shape[0] * spatial_graph.shape[1])
).item()

print(f"Temporal sparsity: {(1 - temporal_density) * 100:.2f}%")

print(f"Spatial sparsity: {(1 - spatial_density) * 100:.2f}%")

Temporal sparsity: 97.05%
Spatial sparsity: 96.99%


## Dataset and Datamodule

We then define a pytorch dataset and a custom collation function that allows us to dynamically
pad sequences to the longest one in the batch (as opposed to the longest one in the dataset)
as they are loaded during training, this gives us an edge in performance by dramatically reducing the 
sparsity of our inputs.

In [52]:
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence


def rnn_collation_fn(batch):

    users = []
    pois = []
    g2 = []
    g3 = []
    g4 = []
    g5 = []
    g6 = []

    for user, poi, geo2, geo3, geo4, geo5, geo6 in batch:
        users.append(user)
        pois.append(poi)
        g2.append(geo2)
        g3.append(geo3)
        g4.append(geo4)
        g5.append(geo5)
        g6.append(geo6)
    seq = (
        users,
        pad_sequence(pois, batch_first=True, padding_value=0),
        pad_sequence(g2, batch_first=True, padding_value=0),
        pad_sequence(g3, batch_first=True, padding_value=0),
        pad_sequence(g4, batch_first=True, padding_value=0),
        pad_sequence(g5, batch_first=True, padding_value=0),
        pad_sequence(g6, batch_first=True, padding_value=0),
    )  # build a sequence

    x = (
        seq[0],
        seq[1][:, :-1],
        seq[2][:, :-1],
        seq[3][:, :-1],
        seq[4][:, :-1],
        seq[5][:, :-1],
        seq[6][:, :-1],
    )  # omit the last one for sample

    y = (
        seq[0],
        seq[1][:, 1:],
        seq[2][:, 1:],
        seq[3][:, 1:],
        seq[4][:, 1:],
        seq[5][:, 1:],
        seq[6][:, 1:],
    )  # omit the first one for ground truth

    return x, y


class CheckinDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data["users"])

    def __getitem__(self, idx):

        x = (
            torch.tensor(encoded_data["users"][idx], dtype=torch.long),
            torch.tensor(encoded_data["pois"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g2"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g3"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g4"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g5"][idx], dtype=torch.long),
            torch.tensor(encoded_data["g6"][idx], dtype=torch.long),
        )

        return x

In [53]:
class CheckinModule(pl.LightningDataModule):
    def __init__(self, encoded_data, batch_size=32, workers=4):
        super().__init__()
        self.encoded_data = encoded_data
        self.batch_size = batch_size
        self.workers = workers

    def setup(self, stage=None):
        self.whole_dataset = CheckinDataset(self.encoded_data)

        l = len(self.whole_dataset)

        train_size = int(0.8 * l)
        val_size = int(0.1 * l)
        test_size = l - train_size - val_size

        # generate train, val, test datasets by random split
        self.train_dataset, self.val_dataset, self.test_dataset = (
            torch.utils.data.random_split(
                self.whole_dataset, [train_size, val_size, test_size]
            )
        )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers,
            collate_fn=rnn_collation_fn,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.workers,
            collate_fn=rnn_collation_fn,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.workers,
            collate_fn=rnn_collation_fn,
        )

    def save(self, whole_path, train_path, val_path, test_path):
        torch.save(self.whole_dataset, whole_path)
        torch.save(self.train_dataset, train_path)
        torch.save(self.val_dataset, val_path)
        torch.save(self.test_dataset, test_path)

    @staticmethod  # load without instantiating
    def load(whole_path, train_path, val_path, test_path):
        whole_dataset = torch.load(whole_path)
        train_dataset = torch.load(train_path)
        val_dataset = torch.load(val_path)
        test_dataset = torch.load(test_path)
        return whole_dataset, train_dataset, val_dataset, test_dataset

## Baseline model: LSTM

## Scrapbook for Experimentation

Ignore all code below, it's just for quick prototyping