In [1]:
import pandas as pd
from src.data.hex_utils import * 
import pickle
import numpy as np

In [2]:
# Read the pickled frame from disk into `hdf`.
# NOTE(review): pickle.load executes arbitrary code from the file; only load trusted data.
with open("../data/processed/geolife_hex_100.pkl", mode="rb") as pkl_file:
    hdf = pickle.load(pkl_file)

In [4]:
# Show which columns the loaded frame provides.
hdf.columns

Index(['lat', 'lon', 'datetime', 'trajectory', 'user', 't_idx', 'timediff',
       'x', 'y', 'dist', 'speed', 'q0', 'r0', 'cell0', 'q1', 'r1', 'cell1',
       'q2', 'r2', 'cell2', 'q3', 'r3', 'cell3', 'is_workday', 'is_in_time_0',
       'is_in_time_1', 'is_in_time_2', 'is_in_time_3', 'time_label'],
      dtype='object')

In [5]:
# Shift the hex-cell ids and the user ids up by one, presumably so that id 0 stays
# free for the padding index (the embeddings defined further down use padding_idx=0).
id_cols = [f"cell{i}" for i in range(4)] + ["user"]

# This cell overwrites the very pickle that is loaded at the top of the notebook, so
# an unguarded re-run would shift the ids again and corrupt the file. Only shift (and
# save) while some id is still below 1.
# NOTE(review): assumes the unshifted ids are 0-based in at least one of these
# columns -- confirm against the preprocessing step.
if hdf[id_cols].min().min() < 1:
    for col in id_cols:
        hdf[col] += 1
    hdf.to_pickle("../data/processed/geolife_hex_100.pkl")

In [45]:
import torch 
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence, unpack_sequence

In [20]:
# Toy batch: two sequences of different length.
#   x -> random integer ids in [1, 10)
#   t -> consecutive integer time stamps starting at 6
lens = [5, 8]
x = [torch.randint(1, 10, (seq_len,)) for seq_len in lens]
t = [torch.arange(6, 6 + seq_len) for seq_len in lens]


class HierarchicEncoder(nn.Module):
    """Two-level bidirectional LSTM encoder over (location, time) sequences.

    The short encoder runs over every step of a sequence. A sub-sampled set of
    its outputs (see ``get_idcs_long``) is then fed to the long encoder.

    Args:
        n_hidden: Hidden size of each LSTM direction.
        embedding_dim_loc: Size of the location embedding.
        embedding_dim_time: Size of the time embedding.
        n_locs: Number of location ids; index 0 is reserved for padding.
        n_times: Number of time ids; index 0 is reserved for padding.
        dropout: Dropout probability for ``self.dropout`` and, when
            ``n_layers > 1``, between stacked LSTM layers.
        n_layers: Number of stacked layers in each LSTM.
        timesteps_split: Stored on the instance; not used by the code below.
        device: Stored on the instance; not used by the code below.
    """

    def __init__(
        self,
        n_hidden,
        embedding_dim_loc,
        embedding_dim_time,
        n_locs,
        n_times,
        dropout,
        n_layers,
        timesteps_split,
        device,
    ):
        super().__init__()
        self.device = device
        self.n_hidden = n_hidden
        self.n_locs = n_locs
        self.n_times = n_times
        self.n_layers = n_layers
        self.embedding_dim_loc = embedding_dim_loc
        self.embedding_dim_time = embedding_dim_time
        self.timesteps_split = timesteps_split
        self.embedding_dim = embedding_dim_time + embedding_dim_loc

        # nn.LSTM only applies dropout between stacked layers and warns when it
        # is non-zero for a single-layer network.
        lstm_dropout = dropout if n_layers > 1 else 0.0

        self.short_encoder = nn.LSTM(
            self.embedding_dim,
            n_hidden,
            num_layers=n_layers,
            bidirectional=True,
            batch_first=True,
            dropout=lstm_dropout,
        )
        # The long encoder consumes outputs of the bidirectional short encoder,
        # so its input size is 2 * n_hidden rather than the embedding size.
        self.long_encoder = nn.LSTM(
            2 * n_hidden,
            n_hidden,
            num_layers=n_layers,
            bidirectional=True,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.loc_embed = nn.Embedding(n_locs + 1, embedding_dim_loc, padding_idx=0)
        self.time_embed = nn.Embedding(n_times + 1, embedding_dim_time, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _last_valid_step(out, lens):
        """Pick ``out[b, lens[b] - 1]`` for every batch row ``b``.

        Args:
            out: Tensor of shape (batch, max_len, features).
            lens: 1-D tensor with the valid length of each row.

        Returns:
            Tensor of shape (batch, features).
        """
        # NOTE(review): with a bidirectional LSTM the backward half of the last
        # step has only seen one element; the final hidden state may be a better
        # summary -- confirm the intent.
        last = (lens.to(out.device) - 1).clamp(min=0)
        index = last.view(-1, 1, 1).expand(-1, 1, out.size(-1))
        return out.gather(1, index).squeeze(1)

    def decode(self, z):
        """Decode a latent code ``z``; not implemented yet."""
        raise NotImplementedError("HierarchicEncoder.decode is not implemented yet")

    def long_forward(self, input_packed):
        """Run the long encoder over a packed sequence.

        Args:
            input_packed: PackedSequence with 2 * n_hidden features per step,
                as returned by ``short_forward``.

        Returns:
            Tensor of shape (batch, 2 * n_hidden): the long encoder output at
            the last valid step of every sequence.
        """
        out_packed, _ = self.long_encoder(input_packed)
        out, lens_long = pad_packed_sequence(out_packed, batch_first=True)
        return self._last_valid_step(out, lens_long)

    def short_forward(self, x, t):
        """Embed and encode a batch of variable-length sequences.

        Args:
            x: List of 1-D integer tensors with location ids (0 = padding).
            t: List of 1-D integer tensors with time ids, same lengths as ``x``.

        Returns:
            Tuple ``(input_long_packed, input_decoder)``: the sub-sampled short
            encoder outputs as a PackedSequence for ``long_forward``, and a
            (batch, 2 * n_hidden) tensor holding the short encoder output at the
            last valid step of every sequence.
        """
        lens_short = torch.tensor([len(xi) for xi in x])
        x_pad = pad_sequence(x, batch_first=True)
        t_pad = pad_sequence(t, batch_first=True)

        t_embed = self.time_embed(t_pad)
        x_embed = self.loc_embed(x_pad)
        xt_embed = torch.cat([x_embed, t_embed], dim=-1)

        # Encode trajectory
        x_packed = pack_padded_sequence(
            xt_embed, lens_short, batch_first=True, enforce_sorted=False
        )
        out_packed, _ = self.short_encoder(x_packed)
        # batch_first=True so that out.shape = (batch_size, max_len, 2*n_hidden)
        out, lens_unpacked = pad_packed_sequence(out_packed, batch_first=True)

        idcs_long, lens_long = self.get_idcs_long(t_pad, lens_short)
        # gather needs an index with the same number of dims as `out`, so the
        # (batch, k) step indices are broadcast over the feature axis.
        gather_index = idcs_long.to(out.device).unsqueeze(-1).expand(-1, -1, out.size(-1))
        input_long_pad = out.gather(1, gather_index)
        input_long_packed = pack_padded_sequence(
            input_long_pad, lens_long.cpu(), batch_first=True, enforce_sorted=False
        )
        input_decoder = self._last_valid_step(out, lens_unpacked)
        return input_long_packed, input_decoder

    def get_idcs_long(self, t_pad, lens=None):
        """Select the step indices that are passed on to the long encoder.

        For every sequence, the time elapsed since its first step is searched
        for the boundaries 0, 3, ..., 21. Each boundary inside the sequence
        contributes the first step at or after it, and one extra entry points
        at the last valid step.

        Args:
            t_pad: (batch, max_len) zero-padded tensor of time stamps.
                Assumed non-decreasing within each sequence, as required by
                ``torch.searchsorted`` -- TODO confirm.
            lens: Optional 1-D tensor of valid lengths. When omitted they are
                taken as the number of non-zero entries per row of ``t_pad``.

        Returns:
            Tuple ``(idcs, lens_long)``: ``idcs`` has shape
            (batch, lens_long.max()) with every entry a valid step index, and
            ``lens_long`` holds the number of meaningful entries per row.
        """
        batch_size, max_len = t_pad.shape
        if lens is None:
            lens = (t_pad != 0).sum(-1)
        lens = torch.as_tensor(lens, device=t_pad.device)

        delta_t = t_pad - t_pad[:, :1]
        # Zero padding makes the tail of delta_t negative, which breaks the
        # sortedness that searchsorted relies on; push padded slots to the
        # largest representable value instead.
        is_pad = torch.arange(max_len, device=t_pad.device)[None, :] >= lens[:, None]
        if delta_t.dtype.is_floating_point:
            sentinel = torch.finfo(delta_t.dtype).max
        else:
            sentinel = torch.iinfo(delta_t.dtype).max
        delta_t = delta_t.masked_fill(is_pad, sentinel)

        # NOTE(review): window (24) and step (3) are hard-coded; `timesteps_split`
        # may have been meant to control them -- confirm.
        boundaries = torch.arange(0, 24, 3, device=t_pad.device, dtype=delta_t.dtype)
        idcs = torch.searchsorted(
            delta_t, boundaries[None, :].expand(batch_size, -1).contiguous()
        )

        # Boundaries that fall inside the sequence, plus one entry for its end,
        # capped at the number of available index columns.
        n_inside = (idcs < lens[:, None]).sum(-1)
        lens_long = (n_inside + 1).clamp(max=boundaries.numel())
        # Indices past the end are redirected to the last valid step so they can
        # be used safely with gather.
        idcs = torch.minimum(idcs, (lens[:, None] - 1).clamp(min=0))
        return idcs[:, : int(lens_long.max())], lens_long

tensor([[0, 3, 8, 8],
        [0, 3, 6, 8]])

In [37]:
# Gather demo: row 0 holds 0..9, row 1 holds the same values doubled.
# Index 8 appears as the "end" marker in the index rows below.
base_row = torch.arange(10)
t = torch.stack([base_row, base_row.clone()])
t[1] *= 2
idx = torch.tensor([[0, 3, 8, 8], [0, 3, 6, 8]])
t.gather(dim=1, index=idx)

tensor([[ 0,  3,  8,  8],
        [ 0,  6, 12, 16]])

In [52]:
# Check that packing discards whatever sits in the padded tail: `t` has its tail
# zeroed, `t2` keeps the random values, yet both pack to the same data.
seq_lens = torch.tensor([5, 6])
t = torch.randn(2, 8, 4)
t2 = t.clone()
t[0, 5:].zero_()
t[1, 6:].zero_()
tp, tp2 = (
    pack_padded_sequence(seq, lengths=seq_lens, batch_first=True, enforce_sorted=False)
    for seq in (t, t2)
)
torch.allclose(tp.data, tp2.data)

True