In [1]:
import pandas as pd
from src.data.hex_utils import * 
import pickle
import numpy as np

In [2]:
# Load the preprocessed hexagon data from disk into `hdf`.
# `hdf` is presumably a pandas DataFrame: a later cell indexes it by column
# name and saves it with `.to_pickle` -- TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that were produced by this project's own pipeline.
with open("../data/processed/geolife_hex_100.pkl", 'rb') as f: 
    hdf = pickle.load(f)

In [5]:
# Shift the cell and user ids up by one and persist the result, presumably so
# that id 0 stays free for padding (the embeddings in the model cell below use
# padding_idx=0).
#
# This cell overwrites the very file it was loaded from, so an unguarded
# `+= 1` shifts the ids again on every re-run. The shift is therefore applied
# only while at least one id column still contains a 0, which makes the cell
# idempotent under "Restart & Run All".
# NOTE(review): this assumes the unshifted data is 0-based in at least one of
# these columns -- confirm against the preprocessing step.
id_cols = [f"cell{i}" for i in range(4)] + ["user"]

if (hdf[id_cols].min() == 0).any():
    for col in id_cols:
        hdf[col] += 1

    hdf.to_pickle("../data/processed/geolife_hex_100.pkl")

In [45]:
import torch 
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence, unpack_sequence
import torch.nn.functional as F

In [70]:
def f1(seq, lens):
    """Pad a batch of 1-D sequences and mark the end of each one with -1.

    Args:
        seq: list of 1-D tensors (one per sequence), possibly of different
            lengths.
        lens: integer tensor (or list) holding the length of each sequence;
            used as the column index at which the -1 marker is written.

    Returns:
        Tensor of shape (batch_size, max_len + 1): the zero-padded batch with
        one extra trailing column, where ``out[i, lens[i]] == -1``.
    """
    t = pad_sequence(seq, batch_first=True)
    batch_size, _ = t.shape
    # Append one zero column so that even the longest sequence has a free slot
    # for its marker. new_zeros keeps the dtype/device of `t`; a plain
    # torch.zeros would silently promote integer ids to float.
    t = torch.cat([t, t.new_zeros((batch_size, 1))], -1)
    t[torch.arange(batch_size), lens] = -1
    # The result must be returned, otherwise the function computes nothing
    # observable.
    return t

In [71]:
# NOTE(review): f1 is defined as f1(seq, lens); calling it without arguments
# raises TypeError on a fresh kernel, so the timing shown below was presumably
# measured against a different version of f1. Pass sample inputs before
# relying on this number.
%timeit f1()

832 μs ± 31.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [20]:
# NOTE(review): stray bare expression -- presumably the leftover file handle
# `f` from the pickle-loading cell, or a typo. It is not used by the class
# defined below; consider removing it.
f


class TULVAE(nn.Module):
    """Hierarchical LSTM variational auto-encoder over location sequences.

    Pipeline implemented by this class:

    1. ``short_encoder`` (bidirectional LSTM) reads the embedded locations.
    2. Its outputs are sub-sampled at 3-hour boundaries (see
       ``get_idcs_long``) and fed to ``long_encoder`` (bidirectional LSTM).
    3. A "short" and a "long" Gaussian latent are sampled from the two
       encoder summaries and concatenated.
    4. ``decoder`` (bidirectional LSTM) is initialised from the latent and
       reads the location embeddings concatenated with a one-hot user vector;
       its outputs are projected to location logits.

    NOTE(review): ``n_layers``, ``timesteps_split``, ``time_embed``,
    ``fc_clf``, ``dropout`` (the module) and the sos/eos tokens are created
    but not used by any method of this class yet. All LSTMs are single-layer,
    so a non-zero ``dropout`` passed to them only triggers a PyTorch warning.
    """

    def __init__(
        self,
        n_hidden,
        embedding_dim,
        embedding_dim_time,
        latent_dim,
        n_locs,
        n_times,
        n_users,
        dropout,
        n_layers,
        timesteps_split,
        device,
    ):
        """Build all sub-modules.

        Args:
            n_hidden: hidden size per direction of every LSTM.
            embedding_dim: size of the location embedding.
            embedding_dim_time: size of the time embedding.
            latent_dim: size of each of the two latent vectors.
            n_locs: number of distinct locations (ids 1..n_locs; 0 is padding).
            n_times: number of distinct time ids (0 is padding).
            n_users: number of users (size of the one-hot user vector).
            dropout: dropout probability.
            n_layers: stored only -- see class note.
            timesteps_split: stored only -- see class note.
            device: stored only; callers are expected to move the module and
                its inputs themselves.
        """
        super().__init__()
        self.device = device
        self.n_hidden = n_hidden
        self.latent_dim = latent_dim
        self.n_locs = n_locs
        self.n_times = n_times
        self.n_layers = n_layers
        self.n_users = n_users
        self.embedding_dim = embedding_dim
        self.timesteps_split = timesteps_split
        # Vocabulary layout: 0 = padding, 1..n_locs = locations,
        # n_locs + 1 = start-of-sequence, n_locs + 2 = end-of-sequence.
        self.sos_token = torch.tensor(n_locs + 1)
        self.eos_token = torch.tensor(n_locs + 2)

        self.short_encoder = nn.LSTM(
            self.embedding_dim,
            n_hidden,
            bidirectional=True,
            batch_first=True,
            dropout=dropout,
        )
        # The long encoder consumes outputs of the bidirectional short encoder,
        # so its input size is 2 * n_hidden (not the embedding size).
        self.long_encoder = nn.LSTM(
            2 * n_hidden,
            n_hidden,
            bidirectional=True,
            batch_first=True,
            dropout=dropout,
        )
        self.decoder = nn.LSTM(
            self.embedding_dim + self.n_users,
            n_hidden,
            bidirectional=True,
            batch_first=True,
            dropout=dropout,
        )
        self.fc_clf = nn.Linear(2 * n_hidden, n_users)
        self.fc_mu_short = nn.Linear(2 * n_hidden, latent_dim)
        self.fc_logvar_short = nn.Linear(2 * n_hidden, latent_dim)
        self.fc_mu_long = nn.Linear(2 * n_hidden, latent_dim)
        self.fc_logvar_long = nn.Linear(2 * n_hidden, latent_dim)
        # Maps the concatenated latent to the decoder's initial hidden state
        # (2 directions * n_hidden).
        # NOTE(review): the original comment called this a "softplus unit", but
        # no activation is applied anywhere -- confirm whether one is wanted.
        self.fc_h_decoder_in = nn.Linear(2 * latent_dim, 2 * n_hidden)

        # Projects each decoder output to logits over the location vocabulary;
        # the size matches loc_embed (n_locs + padding + sos + eos).
        self.fc_decoder_out = nn.Linear(2 * n_hidden, n_locs + 3)

        self.loc_embed = nn.Embedding(n_locs + 3, embedding_dim, padding_idx=0)
        self.time_embed = nn.Embedding(n_times + 1, embedding_dim_time, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, t, u):
        """Encode a batch, sample the latents and decode.

        Args:
            x: list of 1-D integer tensors with location ids, one per sequence.
            t: list of 1-D tensors with the matching time stamps; assumed to be
                non-decreasing and expressed in hours -- TODO confirm.
            u: integer tensor of shape (batch_size,) with user ids in
                [0, n_users).

        Returns:
            Tuple ``(logits, (mean_short, logvar_short), (mean_long,
            logvar_long))`` where ``logits`` has shape
            (batch_size, max_len, n_locs + 3) and every mean/logvar has shape
            (batch_size, latent_dim).
        """
        input_long, h_short, c_short = self.short_forward(x, t)
        h_long = self.long_forward(input_long)

        # Mean and log-variance come from *separate* heads.
        mean_short = self.fc_mu_short(h_short)
        logvar_short = self.fc_logvar_short(h_short)
        mean_long = self.fc_mu_long(h_long)
        logvar_long = self.fc_logvar_long(h_long)
        z_short = self.reparameterize(mean_short, logvar_short)
        z_long = self.reparameterize(mean_long, logvar_long)
        z = torch.cat([z_short, z_long], -1)
        h_dec = self.fc_h_decoder_in(z)

        # NOTE(review): the decoder is fed the un-shifted input embeddings; if
        # teacher forcing with sos/eos tokens is intended, shift x here.
        x_embed = self.loc_embed(pad_sequence(x, batch_first=True))
        logits = self.decode(h_dec, c_short, x_embed, u)
        return logits, (mean_short, logvar_short), (mean_long, logvar_long)

    def decode(self, h_enc, c_enc, x_embed, u):
        """Run the decoder over user-conditioned embeddings.

        Args:
            h_enc: tensor (batch_size, 2 * n_hidden) used as the initial hidden
                state; it is split into the two LSTM directions.
            c_enc: tensor (2, batch_size, n_hidden) used as the initial cell
                state.
            x_embed: tensor (batch_size, max_len, embedding_dim).
            u: integer tensor (batch_size,) of user ids.

        Returns:
            Logits of shape (batch_size, max_len, n_locs + 3).
        """
        batch_size, max_seq_len, _ = x_embed.shape
        # one_hot yields integers; cast so it can be concatenated with the
        # floating point embeddings.
        u_one_hot = F.one_hot(u, self.n_users).to(x_embed.dtype)
        u_one_hot = u_one_hot[:, None].expand(-1, max_seq_len, -1)
        xu_embed = torch.cat([x_embed, u_one_hot], -1)
        # (batch, 2 * n_hidden) -> (2 directions, batch, n_hidden)
        h_dec = h_enc.reshape(batch_size, 2, self.n_hidden).transpose(0, 1).contiguous()
        out, _ = self.decoder(xu_embed, (h_dec, c_enc.contiguous()))
        return self.fc_decoder_out(out)

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    @staticmethod
    def _last_valid_step(out, lens):
        """Pick ``out[i, lens[i] - 1]`` for every row i of a padded batch."""
        last = (lens - 1).clamp(min=0).to(out.device)
        rows = torch.arange(out.size(0), device=out.device)
        return out[rows, last]

    def long_forward(self, input_packed):
        """Encode the sub-sampled sequence with the long encoder.

        Args:
            input_packed: PackedSequence with 2 * n_hidden features per step.

        Returns:
            Tensor (batch_size, 2 * n_hidden): long-encoder output at the last
            valid step of every sequence.
        """
        out_packed, _ = self.long_encoder(input_packed)
        out, lens_long = pad_packed_sequence(out_packed, batch_first=True)
        return self._last_valid_step(out, lens_long)

    def short_forward(self, x, t):
        """Encode raw sequences and build the input of the long encoder.

        Args:
            x: list of 1-D integer tensors with location ids.
            t: list of 1-D tensors with the matching time stamps.

        Returns:
            Tuple ``(input_long_packed, h_short, c)``: the packed short-encoder
            outputs selected at the 3-hour boundaries, the short-encoder output
            at each sequence's last valid step (batch_size, 2 * n_hidden), and
            the short encoder's final cell state (2, batch_size, n_hidden).
        """
        # TODO: Move this to higher scope
        lens_short = torch.tensor([len(xi) for xi in x])
        x_pad = pad_sequence(x, batch_first=True)
        t_pad = pad_sequence(t, batch_first=True)
        x_embed = self.loc_embed(x_pad)

        # Encode trajectory
        x_packed = pack_padded_sequence(
            x_embed, lens_short, batch_first=True, enforce_sorted=False
        )
        out_packed, (h, c) = self.short_encoder(x_packed)
        out, lens_unpacked = pad_packed_sequence(out_packed, batch_first=True)
        # out.shape = (batch_size, max_len, 2*n_hidden)
        idcs_long, lens_long = self.get_idcs_long(t_pad, lens_unpacked)
        # searchsorted may return an index one past the end of a sequence;
        # clamp it to that sequence's last valid step before gathering.
        last_valid = (lens_unpacked - 1).to(out.device)
        idcs_long = torch.minimum(idcs_long.to(out.device), last_valid[:, None])
        # gather needs an index with the same rank as `out`.
        gather_idx = idcs_long[:, :, None].expand(-1, -1, out.size(-1))
        input_long_pad = out.gather(1, gather_idx)
        input_long_packed = pack_padded_sequence(
            input_long_pad, lens_long.cpu(), batch_first=True, enforce_sorted=False
        )
        h_short = self._last_valid_step(out, lens_unpacked)
        return input_long_packed, h_short, c

    def get_idcs_long(self, t_pad, lens=None):
        """Find, per sequence, the first step at or after each 3-hour boundary.

        Time is measured relative to the first step of each row, and the
        boundaries are 0, 3, ..., 21 (eight values).

        Args:
            t_pad: tensor (batch_size, max_len) of time stamps, non-decreasing
                within the valid part of every row.
            lens: optional integer tensor (batch_size,) with the valid length
                of each row. When given, padded positions are ignored; when
                omitted, every position is treated as valid.

        Returns:
            Tuple ``(idcs, n_steps)``. ``idcs`` has shape
            (batch_size, n_steps.max()) and may contain the value ``lens[i]``
            (one past the end) for boundaries a sequence never reaches.
            ``n_steps[i]`` is the number of boundaries reached by row i plus
            one, capped at the number of boundaries.
        """
        batch_size, max_len = t_pad.shape
        delta_t = t_pad - t_pad[:, 0, None]
        if lens is None:
            lens = torch.full(
                (batch_size,), max_len, dtype=torch.long, device=t_pad.device
            )
        else:
            lens = lens.to(t_pad.device)
            # Zero padding would make delta_t negative after the subtraction
            # and break the sortedness searchsorted relies on; push padded
            # positions to the largest representable value instead.
            is_pad = torch.arange(max_len, device=t_pad.device)[None, :] >= lens[:, None]
            dtype = delta_t.dtype
            sentinel = (
                torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max
            )
            delta_t = delta_t.masked_fill(is_pad, sentinel)
        # Boundaries must share dtype/device with delta_t and are repeated once
        # per row of the actual batch.
        boundaries = torch.arange(0, 24, 3, dtype=delta_t.dtype, device=delta_t.device)
        idcs = torch.searchsorted(
            delta_t.contiguous(),
            boundaries[None, :].expand(batch_size, -1).contiguous(),
        )
        # Get place to cut off idcs
        n_steps = ((idcs < lens[:, None]).sum(-1) + 1).clamp(max=idcs.size(1))
        return idcs[:, : n_steps.max()], n_steps

tensor([[0, 3, 8, 8],
        [0, 3, 6, 8]])