In [None]:
!pip install tensordict -q
!pip install torchode -q
!pip install missingno -q
!pip install pytorch_lightning -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m414.3/414.3 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m80.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m52.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import pandas as pd
import numpy as np
from IPython.display import display
import missingno as msno
import torch
from torch import nn

### 1. Код утилит

In [None]:
"""Batch-element-wise masking utilities."""

import torch
from torch import Tensor


def maskmax(t: Tensor, m: Tensor, dim: int, **kwargs):
    """Maximum of ``t`` along ``dim``, taken only over positions where ``m`` is True.

    Positions excluded by the mask are replaced with ``-inf`` before the
    reduction, so they can never be selected; a slice in which nothing is
    selected therefore reduces to ``-inf``. Any extra keyword arguments are
    forwarded unchanged to ``torch.amax`` (e.g. ``keepdim=True``).

    Examples
    --------
    >>> t = torch.tensor([[1, 2, 3], [4, 5, 6]])
    >>> m = torch.tensor([[1, 1, 0], [1, 1, 0]]).bool()
    >>> maskmax(t, m, 1) # tensor([2., 5.])

    """
    lowest = float("-inf")
    candidates = torch.where(m, t, lowest)
    return torch.amax(candidates, dim=dim, **kwargs)


def maskmin(t: Tensor, m: Tensor, dim: int, **kwargs):
    """Minimum of ``t`` along ``dim``, taken only over positions where ``m`` is True.

    Positions excluded by the mask are replaced with ``+inf`` before the
    reduction, so they can never be selected; a slice in which nothing is
    selected therefore reduces to ``+inf``. Any extra keyword arguments are
    forwarded unchanged to ``torch.amin`` (e.g. ``keepdim=True``).

    Examples
    --------
    >>> t = torch.tensor([[3, 2, 1], [6, 5, 4]])
    >>> m = torch.tensor([[1, 1, 0], [1, 1, 0]]).bool()
    >>> maskmin(t, m, 1) # tensor([2., 5.])

    """
    highest = float("inf")
    candidates = torch.where(m, t, highest)
    return torch.amin(candidates, dim=dim, **kwargs)


def maskmean(t: Tensor, m: Tensor, dim: int, **kwargs):
    """Mean of ``t`` along ``dim``, averaged only over positions where ``m`` is True.

    Excluded positions are turned into NaN and the reduction is done with
    ``torch.nanmean``, which skips NaNs. As a consequence, NaNs already
    present in ``t`` at selected positions are skipped as well. Extra keyword
    arguments are forwarded to ``torch.nanmean``.

    Examples
    --------
    >>> t = torch.tensor([[1, 1, 9], [2, 2, 9]])
    >>> m = torch.tensor([[1, 1, 0], [1, 1, 0]]).bool()
    >>> maskmean(t, m, 1) # tensor([1., 2.])

    """
    with_holes = torch.where(m, t, float("nan"))
    return torch.nanmean(with_holes, dim=dim, **kwargs)


def masklast(t: Tensor, m: Tensor, dim: int, *, keepdim: bool = False):
    """Get the last element along dim, ignoring elements not in mask.

    The mask is assumed to be right-padded along ``dim`` (all True entries
    first): the selected position is ``m.sum(dim) - 1``. ``t`` may have extra
    trailing dimensions that ``m`` does not have (e.g. ``t`` is
    ``[B, T, C]`` and ``m`` is ``[B, T]``); those are carried through.

    Parameters
    ----------
    t : Tensor
        Values to pick from; its leading dimensions must match ``m``.
    m : Tensor
        Boolean mask, True for valid entries.
    dim : int
        Dimension of ``m`` (and ``t``) to reduce. Negative values are
        interpreted relative to ``m``.
    keepdim : bool
        If True, the reduced dimension is kept with size 1.

    Returns
    -------
    Tensor
        Tensor of the same dtype as ``t`` with ``dim`` removed (or of size 1
        when ``keepdim`` is True). For a slice with no valid entry the index
        ``-1`` wraps around, i.e. the final element along ``dim`` is returned.

    Examples
    --------
    >>> t = torch.tensor([[2, 1, 3],
                          [4, 1, 6]])
    >>> m = torch.tensor([[1, 1, 1],
                          [1, 1, 0]]).bool()
    >>> masklast(t, m, 1) # tensor([3, 1])

    """
    # Normalise a negative dim relative to the mask; the reduction below is
    # defined on the mask's dimensions.
    if dim < 0:
        dim += m.dim()

    # Position of the last valid entry; the modulo maps -1 (nothing valid) to
    # the final position, because gather does not accept negative indices.
    last_idx = (torch.sum(m, dim) - 1) % t.size(dim)

    # Shape the index like `t` with size 1 along `dim` so that gather picks one
    # element per slice, including any trailing feature dimensions of `t`.
    index = last_idx.to(t.device).unsqueeze(dim)
    index = index.reshape(tuple(index.shape) + (1,) * (t.dim() - index.dim()))
    target_shape = list(t.shape)
    target_shape[dim] = 1
    index = index.expand(target_shape)

    picked = t.gather(dim, index)
    if not keepdim:
        picked = picked.squeeze(dim)

    return picked


def complex_log(float_input, eps=1e-6):
    """Elementwise complex logarithm of a real tensor.

    The real part is ``log(max(|x|, eps))`` (so zeros are clamped to ``eps``
    instead of producing ``-inf``); the imaginary part is ``pi`` for negative
    inputs and ``0`` otherwise. Used in associative_scan.
    """
    floor = float_input.new_tensor(eps)
    magnitude = torch.maximum(float_input.abs(), floor)
    is_negative = float_input < 0
    phase = is_negative.to(float_input.dtype) * torch.pi
    return torch.complex(magnitude.log(), phase)


def associative_scan(values: torch.Tensor, coeffs: torch.Tensor, dim: int):
    """Calculate cumsum with resets.

    Evaluates the linear recurrence ``x[i] = coeffs[i] * x[i-1] + values[i]``
    along ``dim`` in log-space, using complex logarithms so that negative
    numbers are representable. With 0/1 coefficients this is a cumulative sum
    that restarts wherever the coefficient is 0. Because ``complex_log``
    clamps magnitudes to a small epsilon, a zero coefficient is only
    approximately zero and results are accurate up to that tolerance.

    Source: https://github.com/pytorch/pytorch/issues/53095#issuecomment-2102409471.

    Examples
    --------
    >>> input = torch.tensor([1, 2, 3, 4, 5])
    >>> inverted_reset_mask = torch.tensor([0, 1, 1, 0, 1])
    >>> output = associative_scan(input, inverted_reset_mask, dim=0)
    >>> print(output)
    tensor([1.0000, 3.0000, 6.0000, 4.0000, 9.0000])

    """
    log_v = complex_log(values.float())
    log_c = complex_log(coeffs.float())
    # Running log-product of the coefficients.
    cum_log_c = log_c.cumsum(dim=dim)
    # Log of the running sum of values rescaled by that running product.
    log_rescaled_sum = torch.logcumsumexp(log_v - cum_log_c, dim=dim)
    return (cum_log_c + log_rescaled_sum).exp().real


def roll_padding(t: Tensor, m: Tensor):
    """Pack the entries selected by `m` to the front of the last dimension of `m`.

    Within every row the selected elements keep their relative order and are
    moved to the leading positions; the remaining positions are filled with
    zeros (i.e. the padding ends up at the end).

    Returns
    -------
        tuple of new tensor and new padding mask

    """
    positions = torch.arange(m.size(-1), device=m.device)
    kept_per_row = m.sum(-1, keepdim=True)
    # True for the first `kept_per_row` slots of every row.
    packed_mask = positions < kept_per_row
    packed = torch.zeros_like(t)
    packed[packed_mask] = t[m]
    return packed, packed_mask


def sum_simultaneous(x: Tensor, t: Tensor, m):
    """Sum simultaneous events.

    Consecutive positions along dimension 1 that share the same timestamp in
    ``t`` are merged: their ``x`` values are accumulated with
    ``associative_scan`` and only the last position of every run (that is
    also selected by ``m``) is kept. The kept entries are then packed to the
    front with ``roll_padding``.

    Returns the merged ``x``, the merged ``t`` and the new padding mask.
    """
    # NOTE(review): if padded positions carry the same timestamp as the last
    # valid event, that event is treated as "not the end of its run" and gets
    # dropped -- confirm the padding convention used by callers.
    same_as_previous = torch.zeros_like(t, dtype=torch.bool)
    same_as_previous[:, 1:] = t[:, :-1] == t[:, 1:]

    # Coefficient 1 continues the running sum, 0 restarts it.
    run_sums = associative_scan(x, same_as_previous.unsqueeze(-1), 1)

    # A position closes its run when the next position starts a new one.
    closes_run = torch.ones_like(same_as_previous)
    closes_run[:, :-1] = ~same_as_previous[:, 1:]
    keep = closes_run & m

    x_packed, packed_mask = roll_padding(run_sums, keep)
    t_packed, packed_mask = roll_padding(t, keep)
    return x_packed, t_packed, packed_mask

### 2. Код TPP ODE

In [None]:
# Torch ODE example

import torch
import torch.nn as nn
import torchode as to

torch.random.manual_seed(180819023)

class Model(nn.Module):
    """Small time-independent MLP used as the ODE right-hand side in the torchode example.

    Maps a state of size ``n_features`` back to ``n_features`` through two
    hidden layers of width ``n_hidden`` with Softplus activations.
    """

    def __init__(self, n_features, n_hidden):
        super().__init__()
        stack = [
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, t, y):
        """Return dy/dt for state ``y``; the time argument ``t`` is ignored."""
        derivative = self.layers(y)
        return derivative

# Demo: solve dy/dt = Model(y) with torchode, once eagerly and once through
# TorchScript, and check that both solvers agree.
n_features = 5
model = Model(n_features=n_features, n_hidden=32)


# Solver pieces: the ODE term wrapping the model, a Dopri5 stepper and an
# integral step-size controller with the given tolerances.
dev = torch.device("cpu")
term = to.ODETerm(model)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)

# Eager solver and its TorchScript-compiled counterpart.
adjoint = to.AutoDiffAdjoint(step_method, step_size_controller).to(dev)
adjoint_jit = torch.jit.script(adjoint)

# Every batch element is evaluated at the same 10 points in [0, 3] and starts from y0 = 0.
batch_size = 3
t_eval = torch.tile(torch.linspace(0.0, 3.0, 10), (batch_size, 1))
problem = to.InitialValueProblem(y0=torch.zeros((batch_size, n_features)).to(dev), t_eval=t_eval.to(dev))


sol = adjoint.solve(problem)
sol_jit = adjoint_jit.solve(problem)

# Solver statistics and the largest deviation between the two solutions.
print(sol.stats)
print(sol_jit.stats)
print("Max absolute difference", float((sol.ys - sol_jit.ys).abs().max()))

{'n_f_evals': tensor([38, 38, 38]), 'n_steps': tensor([6, 6, 6]), 'n_accepted': tensor([6, 6, 6]), 'n_initialized': tensor([10, 10, 10])}
{'n_f_evals': tensor([38, 38, 38]), 'n_steps': tensor([6, 6, 6]), 'n_accepted': tensor([6, 6, 6]), 'n_initialized': tensor([10, 10, 10])}
Max absolute difference 2.205371856689453e-06


In [None]:
sol.ts.shape # [batch_size, T]

torch.Size([3, 10])

In [None]:
sol.ys.shape # [batch_size, T, hidden_size]

torch.Size([3, 10, 5])

In [None]:
class TransformerBackbone(nn.Module):
    """Placeholder for a transformer-based backbone (not implemented yet).

    ``nn.Module.__init__`` must run before the instance can be used at all
    (registering submodules, ``parameters()``, ``repr`` ...), so the
    constructor calls it even though no layers are defined yet.
    """

    def __init__(self):
        super().__init__()

class CatEmbedding(nn.Module):
    """Learnable lookup table turning integer class ids into dense vectors."""

    def __init__(self, n_classes: int, embedding_dim: int = 32):
        super().__init__()
        self.n_classes, self.embedding_dim = n_classes, embedding_dim
        self.embedding = nn.Embedding(
            num_embeddings=n_classes, embedding_dim=embedding_dim
        )

    def forward(self, batch):
        """Embed a ``[batch_size, T]`` tensor of class ids.

        Returns a ``[batch_size, T, embedding_dim]`` tensor. Inputs that are
        not two-dimensional are rejected with an ``AssertionError``.
        """
        assert batch.dim() == 2

        embedded = self.embedding(batch)
        return embedded

# Smoke test for CatEmbedding: embed random class ids and check the output shape.
n_classes = 3
T = 100
emb_dim = 8
batch_size = 2
cat_emb = CatEmbedding(n_classes, emb_dim)
# Random ids in [0, n_classes) of shape [batch_size, T].
res = cat_emb(torch.randint(low=0, high=n_classes, size=(batch_size, T)))
print(res)
assert res.shape == (batch_size, T, emb_dim)

tensor([[[ 1.1834,  0.5987, -0.0172,  ...,  0.9349,  0.3069,  0.3437],
         [ 1.2050, -2.5556,  0.7688,  ..., -0.1799,  0.1548,  1.4784],
         [ 0.2382,  0.4838, -0.4177,  ..., -0.6549,  0.2083,  0.4805],
         ...,
         [ 0.2382,  0.4838, -0.4177,  ..., -0.6549,  0.2083,  0.4805],
         [ 1.2050, -2.5556,  0.7688,  ..., -0.1799,  0.1548,  1.4784],
         [ 0.2382,  0.4838, -0.4177,  ..., -0.6549,  0.2083,  0.4805]],

        [[ 1.2050, -2.5556,  0.7688,  ..., -0.1799,  0.1548,  1.4784],
         [ 1.1834,  0.5987, -0.0172,  ...,  0.9349,  0.3069,  0.3437],
         [ 1.2050, -2.5556,  0.7688,  ..., -0.1799,  0.1548,  1.4784],
         ...,
         [ 1.1834,  0.5987, -0.0172,  ...,  0.9349,  0.3069,  0.3437],
         [ 0.2382,  0.4838, -0.4177,  ..., -0.6549,  0.2083,  0.4805],
         [ 1.2050, -2.5556,  0.7688,  ..., -0.1799,  0.1548,  1.4784]]],
       grad_fn=<EmbeddingBackward0>)


In [None]:
import torch

# Original tensor of shape (1, 3)
tensor = torch.tensor([[1, 2, 3]])
print("Оригинальная форма тензора:", tensor.shape)

# Expand the tensor to shape (4, 3); the single row is repeated along dim 0
expanded_tensor = tensor.expand(4, 3)
print("Расширенная форма тензора:", expanded_tensor.shape)


Оригинальная форма тензора: torch.Size([1, 3])
Расширенная форма тензора: torch.Size([4, 3])


In [None]:
expanded_tensor

tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

In [None]:
# Build a [batch_size, 10] grid: the same 10 evenly spaced points in [0, 3] repeated for every batch row.
batch_size = 3
t_eval = torch.tile(torch.linspace(0.0, 3.0, 10), (batch_size, 1))
t_eval

tensor([[0.0000, 0.3333, 0.6667, 1.0000, 1.3333, 1.6667, 2.0000, 2.3333, 2.6667,
         3.0000],
        [0.0000, 0.3333, 0.6667, 1.0000, 1.3333, 1.6667, 2.0000, 2.3333, 2.6667,
         3.0000],
        [0.0000, 0.3333, 0.6667, 1.0000, 1.3333, 1.6667, 2.0000, 2.3333, 2.6667,
         3.0000]])

In [None]:
t_eval.shape

torch.Size([3, 10])

In [None]:
from tensordict import TensorDict
# TensorDict demo: 'a' has shape [3, 4, 5]; the first two dims are declared as batch dims.
td = TensorDict({'a': torch.zeros(3, 4, 5)}, batch_size=[3, 4])
# returns a TensorDict of batch size [3, 4, 1]:
td_unsqueeze = td.unsqueeze(-1)
# returns a TensorDict of batch size [12]
td_view = td.view(-1)
# returns the tensor stored under "a" with shape [12, 5]
# (the two batch dims 3 * 4 are flattened, the trailing dim of size 5 is kept)
a_view = td.view(-1).get("a")

In [None]:
td_view.shape

torch.Size([12])

In [None]:
class VectorField_time(nn.Module):
    """Time-dependent vector field: an MLP applied to the state concatenated with the time.

    The network takes ``input_output_dim + 1`` inputs (state plus one time
    feature), has two hidden ReLU layers of width ``hidden_dim`` and returns a
    tensor with the same shape as the state.
    """

    def __init__(self, input_output_dim: int, hidden_dim: int = 32):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_output_dim = input_output_dim
        stack = [
            nn.Linear(input_output_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_output_dim),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, times, y_s):
        """Evaluate the field.

        times: [batch_size]
        y_s: [batch_size, input_output_dim]
        Returns a tensor shaped like ``y_s``.
        """
        # Append the time as an extra trailing feature of the state.
        time_feature = times.unsqueeze(-1)
        net_input = torch.cat((y_s, time_feature), dim=-1)
        res = self.layers(net_input)  # [batch_size, input_output_dim]
        assert res.shape == y_s.shape

        return res

class VectorField(nn.Module):
    """Time-independent vector field: an MLP mapping the state to its time derivative.

    Two hidden ReLU layers of width ``hidden_dim``; input and output both have
    ``input_output_dim`` features.
    """

    def __init__(self, input_output_dim: int, hidden_dim: int = 32):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_output_dim = input_output_dim
        self.layers = nn.Sequential(
            nn.Linear(input_output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_output_dim)
        )

    def forward(self, times, y_s):
        """Evaluate the field at state ``y_s``.

        times: [batch_size] -- accepted to match the ODE-term call signature,
            but not used by this field.
        y_s: [batch_size, input_output_dim]
        Returns a tensor shaped like ``y_s``.
        """
        res = self.layers(y_s) # [batch_size, input_output_dim]
        assert res.shape == y_s.shape
        # No printing here: the solver calls this function many times per
        # solve, so a debug print floods the output.
        return res


# Smoke test for VectorField: feed random times/states and print the output shape.
input_output_dim = 5
batch_size = 8

vf = VectorField(input_output_dim)

times = torch.randn(size=(batch_size, ))
y_s = torch.randn(size=(batch_size, input_output_dim))

res = vf(times, y_s)
print(res.shape)

Res: tensor([[ 0.0967, -0.1183,  0.1953,  0.2361, -0.0527],
        [-0.0622, -0.0880,  0.0976,  0.2218, -0.0546],
        [ 0.1182,  0.0478,  0.3133,  0.2378, -0.0678],
        [ 0.1800,  0.1763,  0.0573,  0.2568, -0.0333],
        [ 0.2010,  0.0228,  0.1759,  0.2285,  0.1710],
        [ 0.2070,  0.0522,  0.2412,  0.3855, -0.0352],
        [ 0.1309, -0.0157,  0.1974,  0.2547, -0.0070],
        [ 0.1089,  0.0159,  0.2136,  0.3113, -0.1271]],
       grad_fn=<AddmmBackward0>)
torch.Size([8, 5])


In [None]:
# Solve the ODE defined by VectorField with the eager torchode solver.
n_features = 5
model = VectorField(n_features)
# model = Model(n_features=n_features, n_hidden=32)


dev = torch.device("cpu")
term = to.ODETerm(model)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)

adjoint = to.AutoDiffAdjoint(step_method, step_size_controller).to(dev)
# adjoint_jit = torch.jit.script(adjoint)

# Same 10 evaluation points in [0, 3] for every batch element, starting from y0 = 0.
batch_size = 3
t_eval = torch.tile(torch.linspace(0.0, 3.0, 10), (batch_size, 1))
problem = to.InitialValueProblem(y0=torch.zeros((batch_size, n_features)).to(dev), t_eval=t_eval.to(dev))


sol = adjoint.solve(problem)
# sol_jit = adjoint_jit.solve(problem)

print(sol.stats)
# print(sol_jit.stats)
# The eager-vs-JIT comparison is disabled together with the JIT solver above:
# `sol_jit` is not computed in this cell, so comparing against it would use a
# stale solution of a different model from an earlier cell.
# print("Max absolute difference", float((sol.ys - sol_jit.ys).abs().max()))

VF Result norm: 0.340
VF res: tensor([[ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],
        [ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],
        [ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588]],
       grad_fn=<AddmmBackward0>)
VF Result norm: 0.340
VF res: tensor([[ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],
        [ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],
        [ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588]],
       grad_fn=<AddmmBackward0>)
VF Result norm: 0.340
VF res: tensor([[ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],
        [ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],
        [ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588]],
       grad_fn=<AddmmBackward0>)
VF Result norm: 0.340
VF res: tensor([[ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],
        [ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],
        [ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588]],
       grad_fn=<AddmmBackward0>)
VF Result norm: 0.340
VF res: tensor([[ 0.1325, -0.0948, -0.0794,  0.0468,  0.0588],

In [None]:
class Backbone_mean(nn.Module):
    """Parameter-free backbone that averages embeddings over the valid time steps.

    ``hidden_dim`` is stored for reference only; the aggregation itself has no
    learnable parameters and works for any embedding width.
    """

    def __init__(self,
                 hidden_dim: int = 32):
        super().__init__()
        self.hidden_dim = hidden_dim

    def forward(self, batch):
        """Masked mean over the time dimension.

        batch['embedding']: [batch_size, time, hidden_dim]
        batch['mask']: [batch_size, time], truthy for valid steps.

        Returns a dict with the unchanged ``'mask'`` and the aggregated
        ``'embedding'`` of shape [batch_size, hidden_dim]. Rows without any
        valid step yield a zero vector instead of NaN.
        """
        embedding = batch['embedding'] # [batch_size, time, hidden_dim]
        mask = batch['mask'] # [batch_size, time] 1 for good

        # Clamp the denominator so that fully masked rows give 0 / 1 = 0
        # rather than 0 / 0 = NaN, which would poison the loss and gradients.
        valid_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
        aggregated = (embedding * mask[:, :, None]).sum(dim=1) / valid_counts
        # output: [batch_size, hidden_dim]

        return {
            'mask': mask,
            'embedding': aggregated
        }



# Smoke test for Backbone_mean: random mask and embeddings, print the aggregated shape.
batch_size = 8
T = 20
hidden_dim = 4

bb = Backbone_mean()

# Boolean mask: True where a standard-normal sample is positive.
mask = (torch.randn(size=(batch_size, T)) > 0)
embedding = torch.randn(size=(batch_size, T, hidden_dim))

res = bb({'mask': mask, 'embedding': embedding})
print(res['embedding'].shape)

torch.Size([8, 4])


### 3. Код TPP ODE

https://torchode.readthedocs.io/en/latest/

Какие классы необходимо дописать:

0. `batch_collator` --- который бы выдавал батчи в виде TensorDict:  `['time', 'mask', 'type']` из исходных данных

1. `self.encoder` --- берет TensorDict с ключами `['time', 'mask', 'type']` и возвращает, батч с добавленным полем `'embedding'`, посчитанным по `type`

2. `self.backbone` --- агрегатор (что он принимает и что выдаёт?)


Обучаемые параметры модуля ниже:
1. `encoder` --- кодирует события в векторы
2. `vf` --- векторное поле в ODE (нужно для вычисления финальных эмбеддингов)
3. `backbone` ---
4. `intensity_layer`

### 2. Код TPP ODE (старый, nn.Module)

In [None]:
import pytorch_lightning as pl

In [None]:
from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal

import torch
from tensordict import TensorDict
from torch import Tensor, nn
from torch.nn import functional as F

import torchode as to
from __future__ import annotations

# from ..mask_utils import masklast
# from ..tensordict_utils import decollate_unpad_batch
# from .base import BaseModule

# class ODETPPModule(BaseModule):
class ODETPPModule(nn.Module): # использовать lightinig_module (чтобы training_step не прописывать) -> training_step, validation_step: в зависимости от того, что проще
    vf: nn.Module
    neg_count: int

    def __init__(
        self,
        neg_count: int,
        n_classes: int,
        hidden_dim: int = 128,
        device: str = 'cpu',
        *args,
        **kwargs,
    ):
        """Build the encoder, vector field, ODE solver and output heads.

        Parameters
        ----------
        neg_count : int
            Stored as ``self.neg_count``; ``forward`` uses it as the number of
            random (non-event) time points sampled per sequence.
        n_classes : int
            Number of event types: size of the embedding table and of the
            type head's output.
        hidden_dim : int
            Width of event embeddings, of the vector field and of the heads' input.
        device : str
            Stored as ``self.device``; ``forward`` uses it when allocating
            the aggregated-embedding buffer.
        *args, **kwargs
            Forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.device = device
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        self.backbone = Backbone_mean()
        # self.encoder: cat features -> vector = embedding
        # self.backbone: transformer (encoder-only) / aggregate with softmax <--- output is n_classes


        ### 1. Define vector field and embedding
        self.vf = VectorField(hidden_dim) # originally vf_factory(self.backbone.output_dim)
        # self.backbone.output_dim = n_classes???

        # Add encoder params to make them learnable.
        self.encoder = CatEmbedding(n_classes, hidden_dim)
        # self.vf._prev = self.encoder # happens BEFORE the integration starts

        ### 2. Set up black box ODE solver (from vector factory)
        term = to.ODETerm(self.vf)
        step_method = to.Dopri5(term=term)
        step_size_controller = to.IntegralController(atol=1e-3, rtol=1e-3, term=term)
        self.solver = torch.compile(
            to.BacksolveAdjoint(term, step_method, step_size_controller) # to.AutoDiffAdjoint may help (faster, but more memory)
        )
        # self.joint_solver = torch.compile(
        #     to.JointBacksolveAdjoint(term, step_method, step_size_controller) # only if all have the same evaluation points
        # )


        ### 3. Define layers to calculate decoupled conditional intensity
        # Intensity head: one sigmoid unit per position; Flatten merges the trailing size-1 dim.
        self.intensity_layer = nn.Sequential(
            nn.Linear(hidden_dim, 1), nn.Sigmoid(), nn.Flatten()
        ) # self.encoder.output_dim = hidden_size [in the article authors used softplus]

        self.type_layer = nn.Sequential(
            nn.Linear(hidden_dim, n_classes),
            nn.Softmax(-1),
        ) # outputs probabilities of types


        # Alias used by forward(); currently the plain backbone module.
        self.bb_td = self.backbone # self.backbone.to_tensordictmodule()

        self.neg_count = neg_count # number of random time points sampled per sequence in forward()

        # NOTE(review): presumably read by checkpoint/early-stopping callbacks -- not used anywhere in this class.
        self.monitor_name = "val_loss"
        self.monitor_mode = "min"

    def shared_step(
        self,
        stage: None | Literal["train"] | Literal["val"] | Literal["test"],
        batch: TensorDict,
        *args,
        **kwargs,
    ):
        """Encode the event types of a batch and return the total loss.

        ``batch`` is indexed by key and must provide:
            - ``'time'``: event timestamps (float),
            - ``'types'``: integer event types, fed to the encoder,
            - ``'mask'``: padding mask (right padding is assumed).

        The encoded types are written back into ``batch['embedding']``.
        ``stage`` is accepted for interface compatibility but is not used.

        Returns the sum of the three loss terms produced by ``forward``
        (positive intensity, positive type, negative intensity).
        """
        event_types: Tensor = batch['types']
        # The encoder just maps event types to embeddings.
        event_embedding = self.encoder(event_types)

        event_time: Tensor = batch["time"]
        valid_mask: Tensor = batch["mask"]
        batch['embedding'] = event_embedding

        loss_terms = self.forward(event_time, valid_mask, event_embedding, event_types)
        pos_int, pos_type, neg_int = loss_terms

        return pos_int + pos_type + neg_int

    def forward(
        self, time: Tensor, nonpadding_mask: Tensor, embedding: Tensor, types: Tensor
    ):
        '''Compute the three loss terms of the ODE-based temporal point process.

        Parameters
        ----------
        time: [batch_size, T]
            Timestamps of events (ascending). The tensor is not modified.
        nonpadding_mask: [batch_size, T]
            True for real events, False for (right) padding.
        embedding: [batch_size, T, hidden_dim]
            Embedding of the event at the corresponding timestamp.
        types: [batch_size, T]
            Integer event types used as targets of the type loss.

        Returns
        -------
        tuple of three scalar tensors
            (positive intensity loss, positive type loss, negative intensity loss).
        '''
        # Lower bound for probabilities before taking a logarithm.
        eps = 1e-6

        B, T = time.shape # batch_size, T
        N = self.neg_count # number of random (non-event) times per sequence

        # Work on a copy: padding is overwritten in place below and the
        # caller's tensor must stay untouched.
        time = time.clone()

        ### 1. Fill padding in time with last times in corresponding timeseries
        t_start_int = time # [batch_size, T]; shares storage with `time` on purpose
        t_end_int = masklast(time, nonpadding_mask, dim=1, keepdim=True).expand(B, T) # last valid timestamp of every series, repeated T times
        time[~nonpadding_mask] = t_end_int[~nonpadding_mask] # padded positions get the last valid time of their row

        ### 2. Get total times, at which to calculate the hidden state
        t_neg_01 = torch.rand(B, N, device=time.device, dtype=time.dtype) # [batch_size, N] ~ U[0, 1]
        t_neg = t_neg_01 * (time[:, -1:] - time[:, :1]) + time[:, :1] # [batch_size, N]: uniform in [first time, last time] of each series

        # Every start event is evaluated at all T event times and all N random times.
        t_eval = torch.cat([time, t_neg], dim=1).unsqueeze(1).expand(B, T, -1) # [batch_size, T, T + N]


        ### 3. Compute hidden states in aforementioned period of times
        solution = self.solver.solve(
            to.InitialValueProblem(
                embedding.flatten(0, -2), # [batch_size * T, hidden_dim] initial states
                t_start_int.flatten(), # [batch_size * T]
                t_end_int.flatten(), # [batch_size * T] <--- integrate to the end of the corresponding series
                t_eval.flatten(0, -2), # [batch_size * T, T + N] evaluation points
            )
        )
        ys = solution.ys.reshape(B, T, T + N, -1) # [batch_size, T (start event), T + N (eval time), hidden_dim]
        ts = solution.ts.reshape(B, T, T + N) # [batch_size, T (start event), T + N (eval time)]

        ### 4. Filter computed hidden states (some should not be considered)
        # B x T x (T + N): keep eval times inside [start time, end time) of each start event
        mask = (t_start_int.unsqueeze(2) <= ts) & (ts < t_end_int.unsqueeze(2))

        # Remove diagonal elements, to prevent data leak
        mask[:, range(T), range(T)] = False

        mask[..., :T] = (
            mask[..., :T] & nonpadding_mask.unsqueeze(1) & nonpadding_mask.unsqueeze(2)
        )  # B x T x P
        mask[..., T:] = mask[..., T:] & nonpadding_mask.unsqueeze(2)  # B x T x N
        # Where the trajectory is not valid, fall back to the raw event embedding.
        ys = torch.where(
            mask.unsqueeze(-1),
            ys,
            embedding.unsqueeze(2),
        )

        # Flattened td B * (P + N) x T x C;
        # to be fed into backbone to aggregate over the start-event dimension.
        td4backbone = TensorDict(
            {
                "embedding": ys.transpose(1, 2).reshape(B * (T + N), T, -1), # [batch_size * (T + N), T, hidden_dim]
                "mask": mask.transpose(1, 2).reshape(B * (T + N), T),
            },
            batch_size=(B * (T + N),),
        )

        # Aggregate dimension #1 (T) => B x (T + N) x C
        # Pass only batch elements that have at least one unmasked entry.
        batch_mask = td4backbone["mask"].any(1)
        aggr_td = self.bb_td(td4backbone[batch_mask])

        # Allocate on the same device as the data rather than on the device
        # string stored at construction time, which may be stale.
        aggr_embeddings = torch.zeros(
            B * (T + N),
            self.hidden_dim, # self.backbone.output_dim
            device=ys.device,
            dtype=ys.dtype,
        )
        aggr_embeddings[batch_mask] = aggr_td["embedding"]


        aggr_embeddings = aggr_embeddings.reshape(B, (T + N), -1) # [batch_size, T + N, hidden_dim]


        ### 5. Calculate loss using computed embeddings
        pos, neg = torch.split(aggr_embeddings, [T, N], dim=1) # [batch_size, T, hidden_dim], [batch_size, N, hidden_dim]

        # intensity loss; the clamp keeps -log finite when the sigmoid underflows to 0
        pos_intensity_loss = -torch.log(self.intensity_layer(pos).clamp_min(eps))

        # event type loss: `type_layer` already outputs probabilities (it ends
        # with Softmax), so use NLL on their logarithm; `F.cross_entropy` would
        # apply a second softmax on top of them.
        # NOTE(review): ignore_index=0 excludes targets equal to 0 from the
        # loss -- confirm that 0 is a padding id and not a real event type.
        type_log_probs = torch.log(self.type_layer(pos).clamp_min(eps))
        pos_type_loss = F.nll_loss(
            type_log_probs.flatten(0, -2),
            types.flatten().long(),
            reduction="none",
            ignore_index=0,
        ).reshape(B, T)
        neg_intensity_loss: Tensor = self.intensity_layer(neg)

        pos_intensity_loss_reduced = pos_intensity_loss[nonpadding_mask].sum(-1).mean()
        pos_type_loss_reduced = pos_type_loss[nonpadding_mask].sum(-1).mean()

        # Monte Carlo estimate of the integrated intensity: mean intensity at
        # the sampled times multiplied by the length of the interval they were
        # sampled from (see step 2), i.e. last time minus first time.
        interval_length = time[:, -1] - time[:, 0]
        neg_intensity_loss_reduced = (neg_intensity_loss.mean(-1) * interval_length).mean()

        return (
            pos_intensity_loss_reduced,
            pos_type_loss_reduced,
            neg_intensity_loss_reduced,
        )

#     def predict_step(
#         self, batch: TensorDict, *args: torch.Any, **kwargs: torch.Any
#     ) -> torch.Any:
#         batch = self.encoder(batch)
#         time: Tensor = batch["time"]
#         mask: Tensor = batch["mask"]
#         embedding: Tensor = batch["embedding"]

#         t_end = masklast(time, mask, dim=1, keepdim=True).expand_as(time)
#         time[~mask] = t_end[~mask]
#         trajectory_event = self.joint_solver.solve(
#             to.InitialValueProblem(
#                 y0=embedding.flatten(0, -2),
#                 t_start=time.flatten(),
#                 t_end=t_end.flatten(),
#             )
#         )

#         embedding = torch.where(
#             (time == t_end).unsqueeze(-1),
#             embedding,
#             trajectory_event.ys[:, -1].view_as(embedding),
#         )

#         batch = self.bb_td(batch)
#         return decollate_unpad_batch(batch.detach().cpu())

# # decollate_unpad_batch --- удаляет паддинг

In [None]:
# Instantiate the TPP module with a small configuration for manual testing.
neg_count = 20  # number of random time points sampled per sequence in forward()
n_classes = 2
hidden_dim = 128
device = 'cpu'

tpp = ODETPPModule(
    neg_count=neg_count,
    n_classes=n_classes,
    hidden_dim=hidden_dim,
    device=device,
)

In [None]:
# Smoke test of ODETPPModule.forward on synthetic data without padding.
batch_size = 2
T = 5

# forward() expects ascending timestamps; cumulative sums of U[0, 1) gaps give
# increasing, non-negative times (plain `randn` samples are unsorted and can
# be negative).
time = torch.rand(size=(batch_size, T)).cumsum(dim=1)
nonpadding_mask = torch.ones_like(time, dtype=torch.bool)
types = torch.randint(0, n_classes, size=(batch_size, T))

embedding = tpp.encoder(types)

forw_res = tpp.forward(time, nonpadding_mask, embedding, types)
forw_res

(tensor(inf, grad_fn=<MeanBackward0>),
 tensor(3.2976, grad_fn=<MeanBackward0>),
 tensor(0.4410, grad_fn=<MeanBackward0>))

In [None]:
nonpadding_mask

tensor([[True, True, True, True, True],
        [True, True, True, True, True]])

In [None]:
tpp.shared_step("train", {'time': time, 'mask': nonpadding_mask, 'types': types})

tensor(10.6109, grad_fn=<AddBackward0>)

In [None]:
forw_res

(tensor(7.0654, grad_fn=<MeanBackward0>),
 tensor(4.3057, grad_fn=<MeanBackward0>),
 tensor(-0.8223, grad_fn=<MeanBackward0>))

In [None]:
(forw_res[0] + forw_res[1] + forw_res[2]).backward()

### 2. Код TPP ODE (новый, Lightning)

In [None]:
from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal

import torch
from tensordict import TensorDict
from torch import Tensor, nn
from torch.nn import functional as F

import torchode as to
from __future__ import annotations
from torchmetrics import MetricCollection

from torch.optim import AdamW
import pytorch_lightning as pl
from functools import partial

EPS = 1e-6

# from ..mask_utils import masklast
# from ..tensordict_utils import decollate_unpad_batch
# from .base import BaseModule
from pytorch_lightning.callbacks import ModelCheckpoint
# class ODETPPModule(BaseModule):
class ODETPPModule(pl.LightningModule):
    """Neural-ODE temporal point process (TPP) wrapped as a LightningModule.

    ``forward`` implements the following pipeline:

    1. event types are embedded by ``self.encoder``; every embedding is the
       initial state of an ODE whose right-hand side is ``self.vf``;
    2. each state is integrated from its own event time up to the last
       non-padded time of its sequence and sampled at all event times plus
       ``neg_count`` random times drawn inside the sequence's time window;
    3. the sampled states are aggregated over the source events by
       ``self.backbone``;
    4. ``self.intensity_layer`` and ``self.type_layer`` turn the aggregated
       states into three loss terms (event intensity, event type, and the
       sampled estimate of the intensity integral).

    Attributes
    ----------
    vf : nn.Module
        Vector field used as the ODE right-hand side.
    neg_count : int
        Number of random ("negative") time points sampled per sequence.
    """

    vf: nn.Module
    neg_count: int

    def __init__(
        self,
        neg_count: int,
        n_classes: int,
        hidden_dim: int = 128,
        *args,
        **kwargs,
    ):
        """Build the encoder, vector field, ODE solver and output heads.

        Parameters
        ----------
        neg_count : int
            Number of random time points sampled per sequence in ``forward``.
        n_classes : int
            Number of event types (size of the embedding table and of the
            type head output).
        hidden_dim : int
            Dimension of the event embeddings / ODE state.
        *args, **kwargs
            Forwarded unchanged to ``pl.LightningModule.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        # Optimizer factory used by `configure_optimizers`; no scheduler by default.
        self.optimizer_partial = partial(AdamW, lr=1e-4)
        self.scheduler_partial = False

        # Aggregates the per-source-event states into one state per evaluation time.
        self.backbone = Backbone_mean()

        ### 1. Define vector field and embedding
        self.vf = VectorField(hidden_dim)
        # Registered as a submodule so that the embedding table is trained.
        self.encoder = CatEmbedding(n_classes, hidden_dim)

        ### 2. Set up the black-box ODE solver
        term = to.ODETerm(self.vf)
        step_method = to.Dopri5(term=term)
        step_size_controller = to.IntegralController(atol=1e-3, rtol=1e-3, term=term)
        # NOTE(review): an earlier comment suggested `to.AutoDiffAdjoint` as a
        # faster but more memory-hungry alternative — not verified here.
        self.solver = torch.compile(
            to.BacksolveAdjoint(term, step_method, step_size_controller)
        )

        ### 3. Heads of the decoupled conditional intensity
        # Linear -> Sigmoid -> Flatten: one intensity value in (0, 1) per state.
        self.intensity_layer = nn.Sequential(
            nn.Linear(hidden_dim, 1), nn.Sigmoid(), nn.Flatten()
        )
        # Outputs event-type probabilities (already soft-maxed).
        self.type_layer = nn.Sequential(
            nn.Linear(hidden_dim, n_classes),
            nn.Softmax(-1),
        )

        self.bb_td = self.backbone

        self.neg_count = neg_count

        # Quantity tracked by the ModelCheckpoint built in `checkpoint_callback`.
        self.monitor_name = "val_loss"
        self.monitor_mode = "min"

    def shared_step(
        self,
        stage: None | Literal["train"] | Literal["val"] | Literal["test"],
        batch: TensorDict,
        *args,
        **kwargs,
    ):
        """Compute and log the loss for one batch.

        Parameters
        ----------
        stage
            Prefix used for the logged metric names (``f"{stage}_loss"`` etc.).
        batch
            Mapping that provides ``"time"``, ``"mask"`` and ``"types"``; the
            computed embedding is written back under ``"embedding"``.

        Returns
        -------
        Tensor | None
            Sum of the three loss terms returned by ``forward``. ``None`` is
            returned when the embeddings contain NaN; Lightning treats a
            ``None`` step output as "skip this step", whereas the previously
            returned integer ``0`` is not a valid step output.
        """
        types: Tensor = batch['types']
        embedding = self.encoder(types)

        if embedding.isnan().any():
            return None

        time: Tensor = batch["time"]
        nonpadding_mask: Tensor = batch["mask"]
        batch['embedding'] = embedding

        pos_int, pos_type, neg_int = self.forward(
            time, nonpadding_mask, embedding, types
        )
        loss = pos_int + pos_type + neg_int

        logged = (
            ("loss", loss),
            ("posint_loss", pos_int),
            ("postype_loss", pos_type),
            ("negint_loss", neg_int),
        )
        for suffix, value in logged:
            self.log(
                f"{stage}_{suffix}",
                value.item(),
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                # Explicit, so Lightning does not have to infer it from the batch object.
                batch_size=time.shape[0],
            )

        return loss

    def forward(
        self, time: Tensor, nonpadding_mask: Tensor, embedding: Tensor, types: Tensor
    ):
        """Compute the three loss terms of the point process.

        Parameters
        ----------
        time : Tensor, [batch_size, T]
            Event timestamps (expected ascending within a sequence). The tensor
            is not modified; padding is filled on an internal copy.
        nonpadding_mask : Tensor, [batch_size, T]
            ``True`` for real events, ``False`` for padding.
        embedding : Tensor, [batch_size, T, hidden_dim]
            Embedding of the event at each timestamp.
        types : Tensor, [batch_size, T]
            Integer event types used as targets of the type head.

        Returns
        -------
        tuple[Tensor, Tensor, Tensor]
            ``(pos_intensity_loss, pos_type_loss, neg_intensity_loss)``, each a
            scalar: the first two are per-sequence sums over real events
            averaged over the batch, the last is the sampled estimate of the
            intensity integral averaged over the batch.
        """
        B, T = time.shape
        N = self.neg_count

        # Work on a copy: the padding fill below must not leak into the caller's tensor.
        time = time.clone()

        ### 1. Fill padding in time with the last valid time of each sequence
        t_start_int = time  # [B, T]; alias on purpose, so it sees the fill below
        t_end_int = masklast(time, nonpadding_mask, dim=1, keepdim=True).expand(B, T)
        time[~nonpadding_mask] = t_end_int[~nonpadding_mask]

        ### 2. Times at which the hidden state is evaluated
        t_first, t_last = time[:, :1], time[:, -1:]
        t_neg_01 = torch.rand(B, N, device=time.device, dtype=time.dtype)  # U[0, 1)
        t_neg = t_neg_01 * (t_last - t_first) + t_first  # [B, N], inside each window

        # Every start event of a sequence is evaluated on the same grid of
        # T event times followed by N random times.
        t_eval = torch.cat([time, t_neg], dim=1).unsqueeze(1).expand(B, T, -1)

        assert (t_start_int <= t_end_int).all(), (
            "every event time must not exceed the last time of its sequence"
        )

        ### 3. Integrate every event state up to the end of its sequence
        solution = self.solver.solve(
            to.InitialValueProblem(
                embedding.flatten(0, -2),  # [B * T, hidden_dim]
                t_start_int.flatten(),  # [B * T]
                t_end_int.flatten(),  # [B * T]
                t_eval.flatten(0, -2),  # [B * T, T + N]
            )
        )
        # [B, T (start event), T + N (evaluation time), hidden_dim]
        ys = solution.ys.reshape(B, T, T + N, -1)
        # [B, T (start event), T + N (evaluation time)]
        ts = solution.ts.reshape(B, T, T + N)

        ### 4. Keep only the states that may be used (logic kept as in the original)
        # An evaluation time counts for a start event only inside
        # [start time, end time).
        mask = (t_start_int.unsqueeze(2) <= ts) & (ts < t_end_int.unsqueeze(2))

        # Remove diagonal elements to prevent data leak: an event must not
        # contribute to the state evaluated at its own timestamp.
        mask[:, range(T), range(T)] = False

        mask[..., :T] = (
            mask[..., :T] & nonpadding_mask.unsqueeze(1) & nonpadding_mask.unsqueeze(2)
        )  # B x T x T
        mask[..., T:] = mask[..., T:] & nonpadding_mask.unsqueeze(2)  # B x T x N
        ys = torch.where(
            mask.unsqueeze(-1),
            ys,
            embedding.unsqueeze(2),
        )

        # Flattened to B * (T + N) x T x C so the backbone aggregates over the
        # start-event dimension.
        td4backbone = TensorDict(
            {
                "embedding": ys.transpose(1, 2).reshape(B * (T + N), T, -1),
                "mask": mask.transpose(1, 2).reshape(B * (T + N), T),
            },
            batch_size=(B * (T + N),),
        )

        # Pass only evaluation points that have at least one usable start event;
        # the others keep an all-zero aggregated state.
        batch_mask = td4backbone["mask"].any(1)
        aggr_td = self.bb_td(td4backbone[batch_mask])

        aggr_embeddings = torch.zeros(
            B * (T + N),
            self.hidden_dim,
            device=ys.device,
            dtype=ys.dtype,
        )
        aggr_embeddings[batch_mask] = aggr_td["embedding"]
        aggr_embeddings = aggr_embeddings.reshape(B, (T + N), -1)  # [B, T + N, hidden_dim]

        ### 5. Loss terms
        pos, neg = torch.split(aggr_embeddings, [T, N], dim=1)  # [B, T, C], [B, N, C]

        # Intensity at the observed events. The sigmoid output is clamped so
        # that a saturated value cannot produce an infinite loss.
        pos_intensity_loss = -torch.log(self.intensity_layer(pos).clamp_min(EPS))

        # Event-type loss. `type_layer` already ends with a Softmax, so the
        # log-probabilities go to `nll_loss`; `cross_entropy` would apply a
        # second softmax on top of them.
        # NOTE(review): `ignore_index=0` drops every target equal to 0 —
        # confirm that type 0 is reserved for padding and is not a real class.
        type_log_probs = torch.log(self.type_layer(pos).clamp_min(EPS))
        pos_type_loss = F.nll_loss(
            type_log_probs.flatten(0, -2),
            types.flatten().long(),
            reduction="none",
            ignore_index=0,
        ).reshape(B, T)

        # Intensity at the random times, used to estimate the integral term.
        neg_intensity_loss: Tensor = self.intensity_layer(neg)  # [B, N]

        # Zero out padding first, then sum per sequence and average over the
        # batch. Indexing with the boolean mask would flatten to 1-D and turn
        # `.sum(-1).mean()` into a sum over the whole batch.
        padding = ~nonpadding_mask
        pos_intensity_loss_reduced = (
            pos_intensity_loss.masked_fill(padding, 0.0).sum(-1).mean()
        )
        pos_type_loss_reduced = pos_type_loss.masked_fill(padding, 0.0).sum(-1).mean()

        # Monte-Carlo estimate of the intensity integral: mean sampled
        # intensity times the length of the window the samples were drawn from.
        window = time[:, -1] - time[:, 0]
        neg_intensity_loss_reduced = (neg_intensity_loss.mean(-1) * window).mean()

        return (
            pos_intensity_loss_reduced,
            pos_type_loss_reduced,
            neg_intensity_loss_reduced,
        )

    def checkpoint_callback(self):
        """Construct (once) and return a matching ModelCheckpoint callback."""
        if not hasattr(self, "_checkpoint_callback"):
            self._checkpoint_callback = ModelCheckpoint(
                monitor=self.monitor_name, mode=self.monitor_mode
            )

        return self._checkpoint_callback

    def load_best_checkpoint(self):
        """Load the best state dict recorded by the checkpoint callback.

        Raises
        ------
        RuntimeError
            If the callback has not saved any checkpoint yet.
        """
        best_path = self.checkpoint_callback().best_model_path
        if not best_path:
            raise RuntimeError(
                "No best checkpoint available: train with `checkpoint_callback()` "
                "attached to the Trainer first."
            )
        checkpoint = torch.load(best_path, map_location=self.device)
        self.load_state_dict(checkpoint["state_dict"])

    def init_metrics(self, metric_collection: MetricCollection, *stages: str):
        """Store the metric collection and clone it once per stage.

        Each clone is stored as ``<stage>_metrics`` with ``stage`` as prefix.
        """
        self.metric_collection = metric_collection
        for stage in stages:
            setattr(self, f"{stage}_metrics", self.metric_collection.clone(stage))

    def metrics(self, stage: Literal["train", "val", "test"] | None):
        """Get the metrics of ``stage``, or an unprefixed clone when ``stage`` is None.

        Raises
        ------
        ValueError
            If ``stage`` is None and ``init_metrics`` has not been called.
        """
        if stage is None:
            # `metric_collection` only exists after `init_metrics`.
            collection = getattr(self, "metric_collection", None)
            if collection is None:
                raise ValueError(
                    "Set metric_factory before calling metrics with stage=None."
                )
            return collection.clone("")

        return getattr(self, f"{stage}_metrics")

    def training_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the training step of this model."""
        return self.shared_step("train", batch, batch_idx, dataloader_idx)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the validation step of this model."""
        return self.shared_step("val", batch, batch_idx, dataloader_idx)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the test step of this model."""
        return self.shared_step("test", batch, batch_idx, dataloader_idx)

    def configure_optimizers(self):
        """Build the optimizer and, if ``scheduler_partial`` is set, its scheduler.

        Extra scheduler options are read from ``self.scheduler_config`` when
        that attribute exists; otherwise no extra options are passed.
        """
        optimizer = self.optimizer_partial(self.parameters())
        if self.scheduler_partial:
            scheduler = self.scheduler_partial(optimizer)
            extra_config = getattr(self, "scheduler_config", None) or {}
            scheduler_config = dict(scheduler=scheduler, **extra_config)
            return [optimizer], [scheduler_config]

        return optimizer

    def predict_step(self, batch: TensorDict, *args, **kwargs):
        """Integrate every event embedding to the end of its sequence and aggregate.

        The embedding is computed from ``batch["types"]`` (as in
        ``shared_step``), integrated from each event time to the last valid
        time of the sequence, stored in ``batch["embedding"]`` and passed
        through the backbone.

        NOTE(review): ``decollate_unpad_batch`` must be importable where this
        class is defined; its import is commented out in this notebook.
        """
        embedding: Tensor = self.encoder(batch["types"])
        mask: Tensor = batch["mask"]
        # Copy so the padding fill does not modify the batch's own tensor.
        time: Tensor = batch["time"].clone()

        t_end = masklast(time, mask, dim=1, keepdim=True).expand_as(time)
        time[~mask] = t_end[~mask]

        # `joint_solver` is not created in `__init__`; use it if it was
        # attached elsewhere, otherwise fall back to the regular solver.
        solver = getattr(self, "joint_solver", self.solver)
        trajectory_event = solver.solve(
            to.InitialValueProblem(
                y0=embedding.flatten(0, -2),
                t_start=time.flatten(),
                t_end=t_end.flatten(),
            )
        )

        # Events that already sit at the end time keep their raw embedding;
        # the others take the last state of their trajectory.
        batch["embedding"] = torch.where(
            (time == t_end).unsqueeze(-1),
            embedding,
            trajectory_event.ys[:, -1].view_as(embedding),
        )

        batch = self.bb_td(batch)
        return decollate_unpad_batch(batch.detach().cpu())

# # decollate_unpad_batch --- удаляет паддинг

In [None]:
def print_stats(t):
    """Print ``min---max; has_nan`` computed over all elements of tensor ``t``."""
    values = t.flatten()
    smallest = values.min()
    largest = values.max()
    any_nan = values.isnan().any()
    print(f"{smallest}---{largest}; {any_nan}")

def get_stats(t):
    """Return ``"min---max; has_nan"`` computed over all elements of tensor ``t``."""
    values = t.flatten()
    smallest = values.min()
    largest = values.max()
    any_nan = values.isnan().any()
    return f"{smallest}---{largest}; {any_nan}"

In [None]:
# Scratch check: `flatten().isnan()` returns one boolean per element of the tensor.
torch.randn(size=(2, 3, 4)).flatten().isnan()

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False])

In [None]:
# Probe the encoder with a handful of type ids to inspect the embeddings it returns.
tpp.encoder(torch.tensor([[1, 0, 0],
        [0, 0, 0]]).cpu())

tensor([[[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]]], grad_fn=<EmbeddingBackward0>)

In [None]:
# Dump every encoder parameter to see which embedding rows are affected.
for par_name, par in tpp.encoder.named_parameters():
    print(par_name, par)

embedding.weight Parameter containing:
tensor([[    nan,     nan,     nan,     nan],
        [    nan,     nan,     nan,     nan],
        [    nan,     nan,     nan,     nan],
        [    nan,     nan,     nan,     nan],
        [    nan,     nan,     nan,     nan],
        [    nan,     nan,     nan,     nan],
        [-1.9999,  0.0287, -0.3504,  0.4536],
        [    nan,     nan,     nan,     nan],
        [    nan,     nan,     nan,     nan],
        [-1.0465, -2.5754,  0.8889, -0.2293]], requires_grad=True)


In [None]:
# Tiny debugging configuration: short sequences, small hidden state.
neg_count = 2
n_classes = 10
hidden_dim = 4
max_length = 3
batch_size = 2
# Not used below: the `lr=lr` argument of the model is commented out.
lr = 1e-4

# NOTE(review): `EventSequenceDataset`, `DataLoader` and `collate_fn` come from
# a cell further down in this notebook, and `dataset` is not defined in the
# visible part — confirm the intended execution order.
torch_dataset = EventSequenceDataset(dataset, max_length)

dataloader = DataLoader(
        torch_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
)

tpp = ODETPPModule(
    neg_count=neg_count,
    n_classes=n_classes,
    hidden_dim=hidden_dim,
    # lr=lr,
    # device=device,
)

# Build the optimizer by hand for the manual training-step cells that follow.
opt = tpp.configure_optimizers()

In [None]:
# Build a random batch with valid timestamps and run a forward pass on it.
batch_size = 2
T = 5

times = torch.randn(size=(batch_size, T))
time = torch.zeros_like(times)

# Per sequence: sort ascending, shift so the first time is 0, scale so the last is 1.
for i in range(batch_size):
    time[i] = torch.tensor(sorted(times[i]))
    time[i] = time[i] - time[i].min()
    time[i] = time[i] / time[i].max()

# All-True mask: no position is treated as padding.
nonpadding_mask = (torch.ones_like(time) > 0)
types = torch.randint(0, n_classes, size=(batch_size, T))

embedding = tpp.encoder(types)

forw_res = tpp.forward(time, nonpadding_mask, embedding, types)
# Display the tuple returned by `forward`.
forw_res

t_start_int: 0.0---1.0; False
t_end_int: 1.0---1.0; False
t_eval: 0.0---1.0; False
tensor([[ 1.3201,  0.5140, -0.6002,  0.3601],
        [ 0.6993,  0.9179,  0.5372,  1.9960],
        [ 0.6993,  0.9179,  0.5372,  1.9960],
        [-1.1232,  0.3883, -0.1905, -1.4667],
        [-0.7341,  0.7822,  0.7711,  0.0690],
        [ 0.3828, -0.4969, -0.6063,  0.4326],
        [ 0.6993,  0.9179,  0.5372,  1.9960],
        [ 0.8078,  0.9735,  1.9105,  0.2795],
        [ 1.3201,  0.5140, -0.6002,  0.3601],
        [ 0.3828, -0.4969, -0.6063,  0.4326]], grad_fn=<ViewBackward0>)
tensor([0.0000, 0.0243, 0.1964, 0.8294, 1.0000, 0.0000, 0.0615, 0.1327, 0.1671,
        1.0000])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([[0.0000, 0.0243, 0.1964, 0.8294, 1.0000, 0.0830, 0.4389],
        [0.0000, 0.0243, 0.1964, 0.8294, 1.0000, 0.0830, 0.4389],
        [0.0000, 0.0243, 0.1964, 0.8294, 1.0000, 0.0830, 0.4389],
        [0.0000, 0.0243, 0.1964, 0.8294, 1.0000, 0.0830, 0.4389],
        [0.0000, 0.02

  t_end_int[t_end_int == t_start_int] = t_end_int[t_end_int == t_start_int] + EPS


(tensor(8.8851, grad_fn=<MeanBackward0>),
 tensor(22.9223, grad_fn=<MeanBackward0>),
 tensor(0.3970, grad_fn=<MeanBackward0>))

In [None]:
# Manual optimization step: sum the three loss terms, back-propagate, update.
loss = forw_res[0] + forw_res[1] + forw_res[2]
print(loss)
loss.backward()
opt.step()

tensor(32.2044, grad_fn=<AddBackward0>)


In [None]:
# Reset the gradients accumulated by the previous backward pass.
opt.zero_grad()

In [None]:
# Scratch: call the optimizer step once more.
opt.step()

In [None]:
# Raises RuntimeError: `loss.backward()` was already called on this graph in an
# earlier cell and the graph was not retained.
loss.backward()

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
# Raises AttributeError: `optimizer_partial` is a `functools.partial` factory,
# not an optimizer instance; the optimizer itself is what `configure_optimizers` returns.
tpp.optimizer_partial.step()

AttributeError: 'functools.partial' object has no attribute 'step'

In [None]:
# Dump the encoder parameters again, after the manual optimizer step.
for par_name, par in tpp.encoder.named_parameters():
    print(par_name, par)

embedding.weight Parameter containing:
tensor([[    nan,     nan,     nan,     nan],
        [ 0.8081, -0.2189,  1.6855,  1.2377],
        [    nan,     nan,     nan,     nan],
        [-1.3312,  1.4097, -1.7261, -0.3673],
        [ 1.2208,  2.1358, -0.7320,  0.7653],
        [ 0.0409,  2.2860,  1.1739,  0.4008],
        [ 0.4618,  1.2737, -0.2647,  1.0148],
        [ 1.1140, -0.7972,  0.5095,  0.0975],
        [ 0.4329,  0.5912,  0.5195,  1.8779],
        [-0.1359, -0.5337,  1.0371, -2.5003]], requires_grad=True)


In [None]:
# Tiny debugging configuration followed by a Lightning training run on the GPU.
neg_count = 2
n_classes = 10
hidden_dim = 4
max_length = 3
batch_size = 2
# Not used below: the `lr=lr` argument of the model is commented out.
lr = 1e-4

# NOTE(review): `EventSequenceDataset`, `DataLoader` and `collate_fn` come from
# a cell further down in this notebook, and `dataset` is not defined in the
# visible part — confirm the intended execution order.
torch_dataset = EventSequenceDataset(dataset, max_length)

dataloader = DataLoader(
        torch_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
)

# NOTE(review): the constructor instantiates `Backbone_mean`, which must
# already be defined when this cell runs (the recorded output is a NameError).
tpp = ODETPPModule(
    neg_count=neg_count,
    n_classes=n_classes,
    hidden_dim=hidden_dim,
    # lr=lr,
    # device=device,
)

# No validation dataloader is passed, only the training one.
trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=tpp.cuda(), train_dataloaders=dataloader)

NameError: name 'Backbone_mean' is not defined

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from tensordict import TensorDict
import pandas as pd
import ast

class EventSequenceDataset(Dataset):
    """Map-style dataset over pre-built event-sequence records."""

    def __init__(self, dict_dataset, max_length: int = 100):
        """
        Args:
            dict_dataset: Indexable collection of records; ``dict_dataset[idx]``
                must provide:
                - 'time': sequence of event timestamps
                - 'type': sequence of event type indices
                - 'target': single integer target value
            max_length (int): Sequences are truncated to their first
                ``max_length`` events.
        """
        self.df = dict_dataset
        self.max_length = max_length

    def __len__(self):
        """Number of records in the underlying collection."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return record ``idx`` as tensors under 'time', 'types' and 'target'."""
        record = self.df[idx]
        limit = self.max_length

        timestamps = torch.tensor(record['time'], dtype=torch.float32)
        event_types = torch.tensor(record['type'], dtype=torch.long)
        label = torch.tensor(record['target'], dtype=torch.long)

        # Only the sequences are truncated; the target is a single value.
        return {
            'time': timestamps[:limit],
            'types': event_types[:limit],
            'target': label,
        }

def collate_fn(batch):
    """Right-pad a list of samples to a common length and pack them in a TensorDict.

    Each sample must provide the tensors 'time', 'types' and 'target'. The
    result holds zero-padded 'time' (float32) and 'types' (long) of shape
    [batch, longest], the stacked 'target', and a boolean 'mask' that is True
    where real events are stored.
    """
    n_items = len(batch)
    longest = max(len(sample['time']) for sample in batch)

    # Pre-allocate zero-filled outputs; positions past a sequence's end stay 0 / False.
    time_pad = torch.zeros((n_items, longest), dtype=torch.float32)
    type_pad = torch.zeros((n_items, longest), dtype=torch.long)
    valid = torch.zeros((n_items, longest), dtype=torch.bool)

    for row, sample in enumerate(batch):
        n_events = sample['time'].shape[0]
        time_pad[row, :n_events] = sample['time']
        type_pad[row, :n_events] = sample['types']
        valid[row, :n_events] = True

    labels = torch.stack([sample['target'] for sample in batch])

    return TensorDict(
        {
            'time': time_pad,
            'types': type_pad,
            'target': labels,
            'mask': valid,
        },
        batch_size=n_items,
    )

# # Example usage
# if __name__ == "__main__":
#     # Initialize dataset and dataloader
#     dataset = EventSequenceDataset(records)  # indexable collection of records, not a CSV path
#     dataloader = DataLoader(
#         dataset,
#         batch_size=32,
#         shuffle=True,
#         collate_fn=collate_fn
#     )

#     # Test one batch
#     batch = next(iter(dataloader))
#     print("Batch structure:")
#     print(f"- Time tensor shape: {batch['time'].shape}")
#     print(f"- Type tensor shape: {batch['types'].shape}")
#     print(f"- Target tensor shape: {batch['target'].shape}")
#     print(f"- Mask tensor shape: {batch['mask'].shape}")



In [None]:
# Wrap the records with the default truncation length (max_length=100).
torch_dataset = EventSequenceDataset(dataset)

In [None]:
# Rebuild the dataset with sequences truncated to 10 events and wrap it in a
# shuffling DataLoader that pads each batch with `collate_fn`.
max_length = 10
torch_dataset = EventSequenceDataset(dataset, max_length)

dataloader = DataLoader(
        torch_dataset,
        batch_size=5,
        shuffle=True,
        collate_fn=collate_fn
)

In [None]:
# Pull one batch and print the shape of every field produced by `collate_fn`.
batch = next(iter(dataloader))
print("Batch structure:")
print(f"- Time tensor shape: {batch['time'].shape}")
print(f"- Type tensor shape: {batch['types'].shape}")
print(f"- Target tensor shape: {batch['target'].shape}")
print(f"- Mask tensor shape: {batch['mask'].shape}")

Batch structure:
- Time tensor shape: torch.Size([5, 10])
- Type tensor shape: torch.Size([5, 10])
- Target tensor shape: torch.Size([5])
- Mask tensor shape: torch.Size([5, 10])


In [None]:
# Display the per-sequence targets of the sampled batch.
batch['target']

tensor([1, 1, 0, 1, 0])

In [None]:
# Configuration of the actual training run.
neg_count = 20
n_classes = 10
hidden_dim = 16
max_length = 10
batch_size = 32

torch_dataset = EventSequenceDataset(dataset, max_length)

dataloader = DataLoader(
        torch_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
)

tpp = ODETPPModule(
    neg_count=neg_count,
    n_classes=n_classes,
    hidden_dim=hidden_dim,
    # device=device,
)

In [None]:
# Number of entries in the category-to-id mapping.
# NOTE(review): presumably this should equal `n_classes` — `trx_category_to_id`
# is defined outside the visible part of the notebook, confirm.
len(trx_category_to_id)

10

In [None]:
# Train on the GPU. No `max_epochs` is set here; the log recorded for the
# re-run cell below reports that Lightning then defaults to 1000 epochs.
trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=tpp.cuda(), train_dataloaders=dataloader)

INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name            | Type            | Params | Mode 
------------------------------------------------------------
0 | backbone        | Backbone_mean   | 0      | train
1 | vf              | VectorField     | 2.2 K  | train
2 | encoder         | CatEmbedding    | 160    | train
3 | solver          | OptimizedModule | 2.2 K  | train
4 | intensity_layer | Sequential      | 17     | train
5 | type_layer      | Sequential

Training: |          | 0/? [00:00<?, ?it/s]

  t_end_int[t_end_int == t_start_int] = t_end_int[t_end_int == t_start_int] + EPS


Types: tensor([[2, 2, 2, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 6, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 6, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
        [1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 0, 0, 0, 0, 0, 0],
        [6, 0, 0, 0, 0, 6, 2, 2, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 0, 0, 0, 0, 6, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 0, 2, 4, 0, 0, 0, 2],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 

RuntimeError: vmap: It looks like you're either (1) calling .item() on a Tensor or (2) attempting to use a Tensor in some data-dependent control flow or (3) encountering this error in PyTorch internals. For (1): we don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. For (2): If you're doing some control flow instead, we don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . For (3): please file an issue.

In [None]:
# Re-run training with the trainer created in the previous cell.
trainer.fit(model=tpp.cuda(), train_dataloaders=dataloader)

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name            | Type            | Params | Mode 
------------------------------------------------------------
0 | backbone        | Backbone_mean   | 0      | train
1 | vf              | VectorField     | 2.2 K  | train
2 | encoder         | CatEmbedding    | 160    | train
3 | solver          | OptimizedModule | 2.2 K  | train
4 | intensity_layer | Sequential      | 17     | train
5 | type_layer      | Sequential      | 170    | train
---------------------------------------------------------

Training: |          | 0/? [00:00<?, ?it/s]

TypeError: unsupported format string passed to Tensor.__format__

In [None]:
# Scratch probe. Per the recorded AttributeError, `model` is a `VectorField`
# instance here and has no `_metrics` attribute.
model._metrics

AttributeError: 'VectorField' object has no attribute '_metrics'

In [None]:
# Inspect the metrics logged during the last run.
trainer.logged_metrics

{'train_loss': tensor(nan),
 'train_posint_loss': tensor(nan),
 'train_postype_loss': tensor(nan),
 'train_negint_loss': tensor(nan)}

### 3. Загрузка датасета Churn

In [None]:
import requests
from urllib.parse import urlencode

# Yandex.Disk public-resources endpoint: given a public share link, it returns a
# JSON document whose 'href' field is used below as the direct download URL.
base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
public_key = 'https://disk.yandex.ru/i/qpmYnFRGE7mQPw'  # public share link of the file to download

# Resolve the direct download link.
final_url = base_url + urlencode(dict(public_key=public_key))
response = requests.get(final_url, timeout=60)
response.raise_for_status()  # fail here instead of with a KeyError on 'href' below
download_url = response.json()['href']

# Download the file and save it.
download_response = requests.get(download_url, timeout=60)
download_response.raise_for_status()  # never write an HTTP error page to disk as CSV
with open('train_churn.csv', 'wb') as f:  # destination path of the downloaded file
    f.write(download_response.content)

In [None]:
# Activity filter: a client is kept only with strictly more than this many transactions.
MIN_TRANSACTIONS = 100

# Read the file from the same relative path the download step writes to, instead
# of a hard-coded absolute Colab path.
churn_dataset = pd.read_csv("train_churn.csv")
churn_dataset['PERIOD'] = pd.to_datetime(churn_dataset['PERIOD'])
# The regex turns the ':' that follows a DDMONYY date prefix into a space so the
# whole value can be parsed with a single strptime format.
churn_dataset['TRDATETIME'] = pd.to_datetime(
    churn_dataset['TRDATETIME'].str.replace(r'^(\d{2}[A-Z]{3}\d{2}):', r'\1 ', regex=True),
    format='%d%b%y %H:%M:%S',
)

churn_dataset['trx_category'] = churn_dataset['trx_category'].astype(str)
churn_dataset = churn_dataset.sort_values(by=['PERIOD', 'TRDATETIME'])
# Drop every column that contains at least one missing value.
churn_dataset = churn_dataset.dropna(axis=1)

# Keep only clients with enough transactions.
vc = churn_dataset['cl_id'].value_counts()
valid_cl_ids = vc.index[vc > MIN_TRANSACTIONS]

churn_dataset = churn_dataset[churn_dataset['cl_id'].isin(valid_cl_ids)]
churn_dataset = churn_dataset.drop(['PERIOD', 'currency', 'amount'], axis=1)

# Renumber rows after filtering, then preview both ends and the final shape.
churn_dataset = churn_dataset.reset_index(drop=True)
display(churn_dataset.head(2))
display(churn_dataset.tail(2))
display(churn_dataset.shape)

Unnamed: 0,cl_id,MCC,TRDATETIME,trx_category,target_flag,target_sum
0,1290,5411,2016-10-07 00:00:00,POS,1,321242.09
1,1290,6011,2016-10-07 18:57:17,WD_ATM_ROS,1,321242.09


Unnamed: 0,cl_id,MCC,TRDATETIME,trx_category,target_flag,target_sum
363714,791,6011,2018-04-02 15:06:41,WD_ATM_ROS,1,32714.44
363715,851,6011,2018-04-02 19:07:05,WD_ATM_ROS,0,0.0


(363716, 6)

In [None]:
# List the distinct transaction categories, ordered from most to least frequent.
churn_dataset['trx_category'].value_counts().index

Index(['POS', 'DEPOSIT', 'WD_ATM_ROS', 'WD_ATM_PARTNER', 'C2C_IN',
       'WD_ATM_OTHER', 'C2C_OUT', 'BACK_TRX', 'CAT', 'CASH_ADV'],
      dtype='object', name='trx_category')

In [None]:
# Fixed category vocabulary: a category's position in this list is its integer id.
id_to_trx_category = [
    'POS', 'DEPOSIT', 'WD_ATM_ROS', 'WD_ATM_PARTNER', 'C2C_IN',
    'WD_ATM_OTHER', 'C2C_OUT', 'BACK_TRX', 'CAT', 'CASH_ADV',
]

# Inverse lookup: category name -> integer id.
trx_category_to_id = {name: idx for idx, name in enumerate(id_to_trx_category)}

`target_flag` — бинарная метка оттока: ушёл клиент или не ушёл.

Время событий нормируем на отрезок [0, 1] (отдельно для каждого клиента).

In [None]:
# Peek at the first ten client ids that passed the transaction-count filter.
valid_cl_ids[:10]

Index([2143, 5373, 5630, 4564, 1261, 5398, 10, 5847, 757, 1839], dtype='int64', name='cl_id')

Для каждого клиента формируем последовательность транзакций вида (время, категория).

In [None]:
# Split the transactions per client and materialize the (group key, frame) pairs.
# In the recorded run each key is a 1-tuple such as (1,), which the dataset-building loop unpacks with [0].
groupped = list(churn_dataset.groupby(by=['cl_id'])[['TRDATETIME', 'trx_category', 'target_flag']])

In [None]:
# Inspect the first (group key, frame) pair.
groupped[0]

((1,),
        TRDATETIME trx_category  target_flag
 265860 2017-07-19          POS            0
 267344 2017-07-20          POS            0
 270488 2017-07-22          POS            0
 273085 2017-07-24          POS            0
 274408 2017-07-25          POS            0
 ...           ...          ...          ...
 347440 2017-10-16          POS            0
 347684 2017-10-17          POS            0
 347685 2017-10-17          POS            0
 347916 2017-10-18          POS            0
 348154 2017-10-19          POS            0
 
 [104 rows x 3 columns])

In [None]:
# Per-client event sequences, keyed by the client's position in `groupped`.
# Each record holds:
#   'time'   - event times in days since the client's first event, rescaled to [0, 1]
#   'type'   - integer category id of each event (via trx_category_to_id)
#   'target' - the client's single target_flag value
dataset = {}

for i, (group_key, group_frame) in enumerate(groupped):
    client_id = group_key[0]
    # The source frame is sorted by PERIOD first, so enforce chronological order
    # here; a stable sort keeps the existing order of same-timestamp events.
    transactions_ordered_in_time = group_frame.sort_values('TRDATETIME', kind='stable')

    target = transactions_ordered_in_time['target_flag'].values
    # The label is per client, so every row of the group must carry the same value.
    assert len(set(target)) == 1, f"client {client_id} has inconsistent target_flag values"
    target = target[0]

    times = transactions_ordered_in_time['TRDATETIME'].values
    times = (times - times.min()) / np.timedelta64(1, 'D')
    # Rescale to [0, 1]; if every event shares one timestamp the span is zero and
    # dividing would produce NaN, so such sequences are left as all zeros.
    span = times.max()
    if span > 0:
        times = times / span

    dataset[i] = {
        'time': times,
        'type': np.array([trx_category_to_id[x] for x in transactions_ordered_in_time['trx_category'].values]),
        'target': target
    }

In [None]:
# Show the normalized event times stored for the first client.
dataset[0]['time']

array([0.        , 0.01086957, 0.0326087 , 0.05434783, 0.06521739,
       0.07432116, 0.07608696, 0.07608696, 0.08695652, 0.09782609,
       0.13043478, 0.15217391, 0.29347826, 0.35869565, 0.36956522,
       0.36956522, 0.36956522, 0.45652174, 0.47826087, 0.47826087,
       0.51086957, 0.51086957, 0.51086957, 0.51086957, 0.51086957,
       0.5326087 , 0.54347826, 0.58695652, 0.59782609, 0.60869565,
       0.60869565, 0.60869565, 0.61956522, 0.63043478, 0.64130435,
       0.64130435, 0.64130435, 0.64130435, 0.64130435, 0.65217391,
       0.65217391, 0.66304348, 0.66304348, 0.67391304, 0.68478261,
       0.68478261, 0.69565217, 0.70652174, 0.70652174, 0.7173913 ,
       0.73703905, 0.73704748, 0.73913043, 0.73913043, 0.75      ,
       0.75      , 0.75      , 0.75      , 0.76086957, 0.76086957,
       0.76086957, 0.77173913, 0.77173913, 0.77173913, 0.77173913,
       0.77173913, 0.77173913, 0.77173913, 0.77173913, 0.7826087 ,
       0.7826087 , 0.7826087 , 0.7826087 , 0.79347826, 0.79347

### 4. MBD dataset
https://huggingface.co/datasets/ai-lab/MBD