In [None]:
import polars as rs
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as pl
import lightning.pytorch as torchpl
from tqdm import tqdm
import numpy as np

In [None]:
from sklearn.preprocessing import LabelEncoder

In [None]:
# Column names for the headerless, tab-separated TIST2015 check-in file:
# user id, POI id, raw check-in timestamp string, and a timezone offset.
# NOTE(review): "TZ" is later used as an offset in minutes (UTC_to_local builds
# timedelta(minutes=tz)) — confirm against the dataset documentation.
columns = ["user", "poi", "date", "TZ"]

In [None]:
# The raw file has no header row; column names are assigned in the next cell.
data = rs.read_csv(
    "dataset_TIST2015/dataset_TIST2015_Checkins.txt",
    has_header=False,
    low_memory=True,  # polars: lower memory pressure at the cost of speed
    separator="\t",
)

In [None]:
data.columns = columns

In [None]:
data

In [None]:
# One row per user: distinct-POI count, check-in count, and the user's whole
# history kept as parallel lists. Polars preserves the input row order inside
# each group, so pois[k], dates[k] and TZs[k] describe the same check-in.
data_users = (
    data.lazy()
    .group_by("user")
    .agg(
        n_pois=rs.col("poi").n_unique(),
        n_checkins=rs.col("poi").count(),
        pois=rs.col("poi"),
        dates=rs.col("date"),
        TZs=rs.col("TZ"),
    )
    .collect()
)

In [None]:
data_users.describe()

In [None]:
# Keep users with a moderate history length (strictly more than 20 and fewer
# than 50 check-ins), then drop any row that contains a null.
data_culled = data_users.filter(
    (rs.col("n_checkins") > 20) & (rs.col("n_checkins") < 50)
).drop_nulls()

In [None]:
# Free the large per-check-in and per-user frames; later cells read from data_culled.
# NOTE(review): not idempotent — re-running this cell raises NameError because
# `data` and `data_users` are already deleted.
del data
del data_users

import gc

gc.collect()

In [None]:
# Deduplicate each user's POI list and record how many distinct POIs each user has.
out = data_culled.with_columns(
    [
        rs.col("pois").list.unique(),
        rs.col("pois").list.unique().list.len().alias("n_unique_pois"),
    ]
)

In [None]:
data_culled

In [None]:
out

In [None]:
# Sanity check on the first row: the deduplicated list should contain no repeats...
l = out["pois"][0].to_list()

In [None]:
len(set(l))

In [None]:
# ...whereas the raw list for the same row may contain repeat visits.
l2 = data_culled["pois"][0].to_list()

In [None]:
len(l2)

In [None]:
len(set(l2))

In [None]:
out

In [None]:
unique_pois = out["pois"]

In [None]:
# Each user's list was deduplicated above, so these counts are the number of
# distinct users who visited each POI; keep POIs visited by at least 10 users.
frequent_pois = unique_pois.list.explode().value_counts().filter(rs.col("count") >= 10)

In [None]:
frequent_pois

In [None]:
# Reduce to a plain Python set of POI ids for the membership tests below.
# NOTE(review): not idempotent — `frequent_pois` is rebound from a DataFrame to
# a set, so re-running this cell alone fails.
frequent_pois = frequent_pois["pois"]
frequent_pois = set(frequent_pois.to_list())

In [None]:
data_culled

In [None]:
# Boolean mask parallel to each user's "pois" list: True where that check-in's
# POI belongs to the frequent set. Used below to drop infrequent check-ins.
data_culled = data_culled.with_columns(
    rs.col("pois").list.eval(rs.element().is_in(frequent_pois)).alias("is_frequent")
)

In [None]:
# Drop every check-in at an infrequent POI, then rebuild the per-user lists.
#
# - Filtering the exploded rows once replaces five identical per-column
#   `.filter(is_frequent)` expressions.
# - A user whose check-ins are all infrequent simply produces no group, so no
#   separate "n_checkins > 0" / "n_pois > 0" filters are needed.
# - `maintain_order=True` keeps users in first-appearance order. Without it,
#   group order changes between runs, so row-indexed checks such as
#   `final_data[...][79]` below are not reproducible. Within each group,
#   polars keeps the original row order, so the lists stay parallel.
final_data = (
    data_culled.lazy()
    .explode(["pois", "dates", "TZs", "is_frequent"])
    .filter(rs.col("is_frequent"))
    .group_by("user", maintain_order=True)
    .agg(
        rs.col("pois"),
        rs.col("dates"),
        rs.col("TZs"),
        rs.col("pois").n_unique().alias("n_pois"),
        rs.col("pois").count().alias("n_checkins"),
    )
    .collect()
)

In [None]:
final_data.describe()

In [None]:
import geohash2 as gh

# POI metadata file (headerless, tab-separated); columns are named below.
# Only id and coordinates are used later, so category and country are dropped.
pois = rs.read_csv(
    "dataset_TIST2015/dataset_TIST2015_POIs.txt",
    has_header=False,
    low_memory=True,
    separator="\t",
)
pois.columns = ["poi", "lat", "long", "category", "country"]
pois = pois.drop("category").drop("country")

In [None]:
# Keep only POIs that survived the frequency filter and encode each one's
# (lat, long) as a 6-character geohash string.
# NOTE: coordinates are cast to Float32 before encoding, which may move a point
# sitting right on a geohash cell boundary into the neighbouring cell.
pois = (
    pois.lazy()
    .filter(rs.col("poi").is_in(frequent_pois))
    .select(
        rs.col("poi"),
        rs.struct(
            rs.col("lat").cast(rs.Float32),
            rs.col("long").cast(rs.Float32),
        )
        .map_elements(
            lambda loc: gh.encode(loc["lat"], loc["long"], precision=6),
            return_dtype=rs.String,
        )
        .alias("geohash"),
    )
    .collect()
)

In [None]:
poi_geo_dict = {poi: geohash for poi, geohash in zip(pois["poi"], pois["geohash"])}  # POI id -> 6-char geohash

In [None]:
# For each user, map every check-in's POI id to that POI's 6-character geohash,
# producing a "geohashes" list parallel to "pois".
# The explicit return_dtype avoids polars having to infer the output type
# (which it warns about); the lambda no longer shadows its own argument.
final_data = final_data.with_columns(
    rs.col("pois")
    .map_elements(
        lambda poi_ids: [poi_geo_dict[poi_id] for poi_id in poi_ids],
        return_dtype=rs.List(rs.String),
    )
    .alias("geohashes")
)

In [None]:
# Spot-check one user's raw timestamps and their timezone offsets.
# NOTE(review): row 79 is only the same user across runs if final_data's row order is deterministic.
final_data["dates"][79].to_list()

In [None]:
final_data["TZs"][79].to_list()

In [None]:
import datetime


def UTC_to_local(utc, tz):
    """Convert a raw check-in timestamp into local wall-clock time.

    Args:
        utc: timestamp string in the form ``"Mon May 21 15:53:01 +0000 2012"``
            (format ``%a %b %d %H:%M:%S %z %Y``).
        tz: offset of the local zone from UTC, in minutes (e.g. ``-420`` = UTC-7).

    Returns:
        The local time as a ``"%Y-%m-%d %H:%M:%S"`` string (no offset suffix).
    """
    # `%z` already yields an offset-aware datetime, so convert directly from the
    # offset the string carries. Overwriting tzinfo with UTC instead would
    # silently discard any non-zero offset in the input.
    date = datetime.datetime.strptime(utc, "%a %b %d %H:%M:%S %z %Y")
    local_zone = datetime.timezone(datetime.timedelta(minutes=tz))
    return date.astimezone(local_zone).strftime("%Y-%m-%d %H:%M:%S")

In [None]:
UTC_to_local("Mon May 21 15:53:01 +0000 2012", -420)  # sanity check: 15:53:01 UTC at UTC-7 should give "2012-05-21 08:53:01"

In [None]:
# Local wall-clock time of every check-in, parallel to "dates"/"TZs": each raw
# timestamp is shifted by its own per-check-in offset via UTC_to_local.
final_data = final_data.with_columns(
    rs.struct("dates", "TZs")
    .map_elements(
        lambda row: [
            UTC_to_local(stamp, offset_minutes)
            for stamp, offset_minutes in zip(row["dates"], row["TZs"])
        ],
        return_dtype=rs.List(rs.String),
    )
    .alias("times")
)

In [None]:
def to_UNIX_time(date):
    """Return a POSIX timestamp for a ``"%Y-%m-%d %H:%M:%S"`` string.

    The string is interpreted as UTC. In this notebook the value is only used
    as a sort key for local wall-clock strings, so what matters is a fixed,
    host-independent zone. Calling ``timestamp()`` on a naive datetime would
    apply the machine's own timezone, so results would vary by host. It would
    also distort the ordering around DST transitions.
    """
    parsed = datetime.datetime.strptime(date, "%Y-%m-%d %H:%M:%S")
    return parsed.replace(tzinfo=datetime.timezone.utc).timestamp()

In [None]:
def _time_order(times):
    """Stable permutation that sorts the local-time strings in `times` chronologically."""
    keys = [to_UNIX_time(t) for t in times]
    return sorted(range(len(keys)), key=keys.__getitem__)


def _in_time_order(values, times):
    """`values` reordered to follow the chronological order of the parallel list `times`."""
    values = list(values)
    return [values[i] for i in _time_order(times)]


# Reorder each user's parallel check-in lists chronologically. The same stable
# permutation is applied to all three lists, so position k in "pois",
# "geohashes" and "times_sorted" still refers to a single check-in.
final_sorted = final_data.select(
    rs.col("user"),
    rs.struct("pois", "times")
    .map_elements(
        lambda row: _in_time_order(row["pois"], row["times"]),
        return_dtype=rs.List(rs.String),
    )
    .alias("pois"),
    rs.struct("geohashes", "times")
    .map_elements(
        lambda row: _in_time_order(row["geohashes"], row["times"]),
        return_dtype=rs.List(rs.String),
    )
    .alias("geohashes"),
    rs.col("times")
    .map_elements(
        lambda times: _in_time_order(times, times),
        return_dtype=rs.List(rs.String),
    )
    .alias("times_sorted"),
    rs.col("n_checkins"),
)

In [None]:
final_sorted  # per-user check-in sequences, ordered by local time

In [None]:
# Build one row per POI, holding its 4-character geohash cell ("g4") and every
# local check-in time recorded at that POI. These rows feed the POI-POI spatial
# (shared g4 cell) and temporal (similar visit-time slots) graphs below.
#
# "pois", "geohashes" and "times_sorted" are parallel, identically ordered
# lists, so they are exploded together and each exploded row is exactly one
# check-in. If "times_sorted" were left unexploded, every POI would inherit the
# visiting user's *entire* time list rather than its own check-in times.
pois_checkins = (
    final_sorted.explode(["pois", "geohashes", "times_sorted"])
    .drop("n_checkins")
    .with_columns(rs.col("geohashes").str.slice(0, 4).alias("g4"))
    .drop("geohashes")
    .group_by(["pois", "g4"])
    .agg(rs.col("times_sorted").alias("checkin_times"))
)

In [None]:
pois_checkins # one row per (POI, g4 cell) with its check-in times: input for the POI-POI spatial/temporal graphs

In [None]:
def UTC_to_weekslot(utc: str) -> int:
    """Map a ``"%Y-%m-%d %H:%M:%S"`` timestamp to its 3-hour slot of the week.

    A week has 7 days x 8 three-hour slots = 56 slots, numbered 0..55:
    Monday 00:00-02:59 is slot 0, Monday 21:00-23:59 is slot 7,
    Tuesday 00:00-02:59 is slot 8, ..., Sunday 21:00-23:59 is slot 55.

    Despite the name, the callers in this notebook pass local wall-clock
    strings (derived from "times_sorted"), so slots refer to local time.
    """
    date = datetime.datetime.strptime(utc, "%Y-%m-%d %H:%M:%S")
    # Each day is 24 // 3 = 8 slots wide; a smaller multiplier would make
    # consecutive days share slot ids (e.g. Monday 12:00 and Tuesday 00:00).
    slots_per_day = 24 // 3
    return date.weekday() * slots_per_day + date.hour // 3
    

In [None]:
# Vocabularies to label-encode: user ids, POI ids, and geohash prefixes of
# length 2..6 (g6 is the full 6-character geohash).
GEOHASH_LEVELS = (2, 3, 4, 5, 6)
ENCODED_KEYS = ("users", "pois", *(f"g{level}" for level in GEOHASH_LEVELS))

encoder_dict = {key: LabelEncoder() for key in ENCODED_KEYS}
encoded_data = {key: [] for key in ENCODED_KEYS}  # filled by the next cell
unique_data = {key: set() for key in ENCODED_KEYS}

# Quick-and-dirty encoding:
# 1. collect every unique symbol per vocabulary,
# 2. fit the matching encoder,
# 3. transform the per-user lists (next cell).
# Loop names are chosen so they do not clobber the `pois` DataFrame or shadow
# the `property` builtin.
for user, user_pois, geohashes, _times_sorted, _n_checkins in final_sorted.iter_rows():
    unique_data["users"].add(user)
    unique_data["pois"].update(user_pois)
    for level in GEOHASH_LEVELS:
        unique_data[f"g{level}"].update(geo[:level] for geo in geohashes)

for key, encoder in encoder_dict.items():
    encoder.fit(list(unique_data[key]))

In [None]:
# Encode every user's sequences with the fitted vocabularies. Ids are shifted
# by +1 so that 0 stays free for padding in the collate function.
#
# Calling LabelEncoder.transform once per row per vocabulary was the slow part.
# Instead, build a plain {label: id + 1} dict from each encoder's classes_ once
# and look labels up directly. The ids are identical, because transform maps a
# label to its index in the sorted classes_ array.
#
# encoded_data is rebuilt from scratch, so re-running this cell does not append
# a second copy of every sequence.
label_to_id = {
    key: {label: index + 1 for index, label in enumerate(encoder.classes_)}
    for key, encoder in encoder_dict.items()
}
encoded_data = {key: [] for key in encoder_dict}

for user, user_pois, geohashes, _times_sorted, _n_checkins in tqdm(
    final_sorted.iter_rows(), total=final_sorted.height
):
    encoded_data["users"].append(label_to_id["users"][user])
    encoded_data["pois"].append(
        np.array([label_to_id["pois"][poi] for poi in user_pois], dtype=np.int64)
    )
    for level in (2, 3, 4, 5, 6):
        key = f"g{level}"
        encoded_data[key].append(
            np.array([label_to_id[key][geo[:level]] for geo in geohashes], dtype=np.int64)
        )
    


In [None]:
pois_checkins  # POI table before label encoding

In [None]:
# Encode the graph table with the same vocabularies as the sequences, but
# WITHOUT the +1 padding shift. Each POI has a single geohash, so it appears
# in exactly one row. After sorting by "pois", row i therefore holds POI id i
# and lines up with row/column i of the adjacency matrices built below
# (sequence token t corresponds to row t - 1).
# Plain dict lookups replace a LabelEncoder.transform call per element.
# NOTE(review): not idempotent — re-running this cell fails because the
# columns are already integer ids.
poi_to_index = {label: index for index, label in enumerate(encoder_dict["pois"].classes_)}
g4_to_index = {label: index for index, label in enumerate(encoder_dict["g4"].classes_)}

pois_checkins = (
    pois_checkins.lazy()
    .with_columns(
        rs.col("pois").map_elements(lambda poi: poi_to_index[poi], rs.Int64),
        rs.col("g4").map_elements(lambda cell: g4_to_index[cell], rs.Int64),
        # convert every check-in time in the list to its weekly 3-hour slot
        rs.col("checkin_times").map_elements(
            lambda times: [UTC_to_weekslot(t) for t in times], rs.List(rs.Int64)
        ),
    )
    .sort("pois")
    .collect()
)

In [None]:
pois_checkins

In [None]:
# Encoded g4 cell id of every POI, in the table's row order (sorted by POI id above).
spatial_row = np.array(pois_checkins["g4"].to_list())

In [None]:
# Dense POI x POI float64 adjacency matrix of zeros, filled in the next cell.
# NOTE: memory grows quadratically with the number of POIs.
spatial_graph = np.zeros((spatial_row.shape[0], spatial_row.shape[0]))

In [None]:
# Spatial graph: POIs i and j are linked (1.0) when they fall in the same
# 4-character geohash cell. np.equal.outer compares all pairs in one
# vectorised step instead of an n^2 Python double loop. The result is a
# symmetric matrix with ones on the diagonal, written into the existing array.
spatial_graph[:] = np.equal.outer(spatial_row, spatial_row)

In [None]:
# Per-POI lists of weekly time-slot ids, in the same row order as spatial_row.
temporal_row = pois_checkins["checkin_times"].to_list()

In [None]:
# Same shape as the spatial graph: one row/column per POI.
temporal_graph = np.zeros((spatial_row.shape[0], spatial_row.shape[0]))

In [None]:
def iaccard_sim(A, B, threshold=0.9):
    """Return True when the Jaccard similarity of A and B is at least `threshold`.

    The name is a misspelling of "jaccard"; it is kept so existing callers keep working.

    Args:
        A, B: collections of hashable items. They are converted to sets, so
            duplicates are ignored.
        threshold: minimum |A & B| / |A | B| for the pair to count as similar
            (default 0.9).

    Returns:
        bool. Two empty collections are treated as identical (similarity 1.0)
        rather than raising ZeroDivisionError.
    """
    A, B = set(A), set(B)
    union = A | B
    if not union:
        return 1.0 >= threshold
    return len(A & B) / len(union) >= threshold


In [None]:
# Temporal graph: POIs i and j are linked (1.0) when iaccard_sim judges their
# sets of weekly time slots similar.
# Each slot list is converted to a set once up front, not n times inside the
# inner loop. The similarity is symmetric, so only pairs with j >= i are
# evaluated and each result is mirrored.
temporal_sets = [set(slots) for slots in temporal_row]
for i, slots_i in enumerate(temporal_sets):
    for j in range(i, len(temporal_sets)):
        if iaccard_sim(slots_i, temporal_sets[j]):
            temporal_graph[i, j] = 1
            temporal_graph[j, i] = 1

In [None]:
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence


def rnn_collation_fn(batch):
    """Collate per-user check-in sequences into next-step prediction pairs.

    Each item of ``batch`` is a tuple ``(user, pois, g2, g3, g4, g5, g6)``: a
    user-id tensor followed by 1-D id sequences of equal length within the item.

    Sequences are right-padded with 0, which the encoding step keeps free by
    shifting all ids by +1. They are then shifted by one step:

    * ``x`` holds positions ``[0, L-1)`` (all but the last step),
    * ``y`` holds positions ``[1, L)`` (all but the first step),

    where L is the longest sequence in the batch.

    Returns:
        ``(x, y)``: two tuples of the same length as each item. Element 0 is
        the stacked user ids, a ``(B,)`` tensor when each item's user id is a
        0-d tensor; it is the same tensor in x and y. The remaining elements
        are ``(B, L-1)`` padded tensors.
    """
    users, *sequences = zip(*batch)
    # Stack into one tensor (rather than a Python list of 0-d tensors) so the
    # ids can be fed straight into an embedding layer.
    users = torch.stack(users)
    padded = [
        pad_sequence(list(seq), batch_first=True, padding_value=0)
        for seq in sequences
    ]
    x = (users, *(seq[:, :-1] for seq in padded))
    y = (users, *(seq[:, 1:] for seq in padded))
    return x, y

class CheckinDataset(Dataset):
    """Map-style dataset over per-user encoded check-in sequences.

    ``data`` is a dict of parallel per-user lists keyed by "users", "pois",
    "g2", "g3", "g4", "g5" and "g6" (the layout of ``encoded_data``). Item
    ``idx`` is a 7-tuple of LongTensors in that key order: the user id,
    followed by the user's POI and geohash-prefix id sequences.
    """

    FIELDS = ("users", "pois", "g2", "g3", "g4", "g5", "g6")

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data["users"])

    def __getitem__(self, idx):
        # Read from the dict this dataset was constructed with (not the
        # notebook-level `encoded_data` global), so separately built datasets,
        # e.g. train/validation splits, return their own items.
        return tuple(
            torch.tensor(self.data[field][idx], dtype=torch.long)
            for field in self.FIELDS
        )

In [None]:
# Wrap the encoded per-user sequences in a Dataset and peek at one item.
ds = CheckinDataset(encoded_data)

In [None]:
ds[0]

In [None]:
# Batches of 8 users; rnn_collation_fn pads the variable-length sequences and
# splits them into (input, next-step target) pairs.
# NOTE(review): shuffle=True without a seeded generator makes batch order differ between runs.
loader = torch.utils.data.DataLoader(
    ds, batch_size=8, shuffle=True, collate_fn=rnn_collation_fn
)

In [None]:
batch = next(iter(loader))

In [None]:
x, y = batch

In [None]:
# POI inputs vs. targets for the item at index 7 of the batch: y is x shifted left by one step.
x[1][7]

In [None]:
y[1][7]