In [None]:
#| default_exp data

In [None]:
#| hide
from nbdev.showdoc import *

# data
> Data reading and wrangling functionality

In [None]:
#| export
import types

from fastcore.utils import patch

import torch
import torch.distributions
from torch.distributions import uniform

# from latent_ode.generate_timeseries import Periodic_1d
from uc3m_ml_healthcare.generate_timeseries import Periodic_1d

These `import`s are not actually required by this module but only used in tests.

In [None]:
# import

## Synthetic data

Except for the first and last lines, everything else comes from [Rubanova's implementation](https://github.com/YuliaRubanova/latent_ode/blob/c0682d4f52b806fb88d965755892eadd9783f936/lib/parse_datasets.py) (comments mine)

In [None]:
#| export
def make_periodic_dataset(
    timepoints: int, # Number of time instants
    extrap: bool, # Whether extrapolation is performed
    max_t: float, # Maximum value of time instants
    n: int, # Number of examples
    noise_weight: float # Standard deviation of the noise to be added
): # Time and observations
# ) -> tuple[torch.Tensor, torch.Tensor]: # Time and observations # <-------------- Python 3.10
    "Sample a sorted, irregular time axis and `n` trajectories produced by `Periodic_1d` on it"

    # the parameters are bundled as `args` so that the code below stays close to the original one
    args = types.SimpleNamespace(
        timepoints=timepoints, extrap=extrap, max_t=max_t, n=n, noise_weight=noise_weight)

    # --------- adapted from Rubanova (see above)

    # a `True` `extrap` counts as 1, i.e., one extra time instant is generated
    n_time_instants = args.timepoints + args.extrap

    # the time horizon is stretched proportionally to the number of time instants actually generated;
    # if `args.extrap` is `False`, then `n_time_instants == args.timepoints` and the horizon is just `args.max_t`
    time_horizon = args.max_t / args.timepoints * n_time_instants

    # every time instant but the first one is drawn uniformly between 0 and the time horizon
    time_sampler = uniform.Uniform(torch.Tensor([0.0]), torch.Tensor([time_horizon]))
    sampled_instants = time_sampler.sample(torch.Size([n_time_instants - 1]))[:, 0] # trailing singleton dimension is dropped

    # 0 is always there, and the time axis must be increasing
    unsorted_instants = torch.cat((torch.Tensor([0.0]), sampled_instants))
    time_steps, _ = torch.sort(unsorted_instants)

    # no frequencies are passed (presumably `Periodic_1d` samples them internally -- see its implementation)
    generator = Periodic_1d(
        init_freq=None, final_freq=None,
        init_amplitude=1., final_amplitude=1.,
        z0=1.)

    trajectories = generator.sample_traj(time_steps, n_samples=args.n, noise_weight=args.noise_weight)

    # ---------

    return time_steps, trajectories

In [None]:
# `extrap=True` adds one extra time instant (100 + 1 = 101), and `n=200` trajectories are sampled on that time axis
time, observations = make_periodic_dataset(timepoints=100, extrap=True, max_t=5.0, n=200, noise_weight=0.01)
time.shape, observations.shape

(torch.Size([101]), torch.Size([200, 101, 1]))

## PyTorch

A class defining a (somewhat complex) *collate function* for a PyTorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

In [None]:
#| export
class CollateFunction:
    "Collate function that splits every time series in a batch into an observed (first) part and a to-be-predicted (second) part"
    
    def __init__(self,
                 time: torch.Tensor, # Time axis [time]
                 # n_points_to_subsample: int | None = None # Number of points to be "subsampled" # <-------------- Python 3.10
                 n_points_to_subsample = None # Number of points to be "subsampled"
                ):
        "Store the time axis and decide how many observed points are to be subsampled (by default, all of them)"
        
        self.time = time
        
        self._n_time_instants = len(time)
        
        # the first half of the time axis is "observed" and the rest is "to predict";
        # for an odd number of time instants the extra one goes to the latter
        self._half_n_time_instants: int = self._n_time_instants // 2
        
        # by default every point in the observed part is kept
        if n_points_to_subsample is None:
            n_points_to_subsample = self._half_n_time_instants
        
        self.n_points_to_subsample = n_points_to_subsample
    
    # TODO: drop time instants with no data (`get_next_batch` in Rubanova's code)
        
    def __call__(self,
                 batch: list # Observations [batch]
                ) -> dict:
        "Stack the examples in `batch` and split them, along with the time axis, into observed and to-predict parts"
        
        split_at = self._half_n_time_instants
        
        # [batch, time, feature]
        stacked = torch.stack(batch)
        
        # ----------- splitting on training and to-predict
        
        # for observations (cloned so that the outputs don't share memory with the stacked batch)
        observed_data = stacked[:, :split_at, :].clone()
        to_predict_data = stacked[:, split_at:, :].clone()
        
        # for time
        observed_time = self.time[:split_at].clone()
        to_predict_at_time = self.time[split_at:].clone()
        
        # ----------- mask
        
        # CAVEAT: only on observed data (`ones_like` already matches its dtype and device)
        observed_mask = torch.ones_like(observed_data)
        
        # only sampling ALL the points in the observed data is supported (the mask is left untouched);
        # `NotImplementedError` is still an `Exception`, so callers catching the latter are unaffected
        if self.n_points_to_subsample != split_at:
            
            raise NotImplementedError('not implemented')
            
        # ----------- observation-less time instants
        
        # (see the TODO above) a starting point would be summing across "batch" and "feature" dimensions:
        # non_missing = (observed_data.sum(dim=(0, 2)) != 0.)
        
        return dict(
            observed_time=observed_time, observed_data=observed_data,
            to_predict_at_time=to_predict_at_time, to_predict_data=to_predict_data,
            observed_mask=observed_mask)
        
    def __str__(self):
        "Human-readable summary of the expected number of time instants and the subsampling size"
        
        return f'Collate function expecting {self._n_time_instants} time instants, subsampling {self.n_points_to_subsample}.'
    
    # an object is represented by its string
    __repr__ = __str__
    
    # NOTE: this method is defined again, with the same body, through `@patch` in the "GPU support" section
    def to(self, device):
        "Move the time axis to `device`"
        
        self.time = self.time.to(device=device)

Let us build an object for testing

In [None]:
# 50 is exactly `len(time) // 2` (101 // 2), the only value currently supported: anything else makes `__call__` raise
collate_fn = CollateFunction(time, n_points_to_subsample=50)
collate_fn

Collate function expecting 101 time instants, subsampling 50.

We also need a PyTorch `DataLoader`

In [None]:
# the observations tensor is used directly as the dataset (it is indexed along its first, "example", dimension);
# 200 examples in batches of 10 yield 20 batches, each one assembled by `collate_fn`
dataloader = torch.utils.data.DataLoader(observations, batch_size = 10, shuffle=False, collate_fn=collate_fn)
dataloader

<torch.utils.data.dataloader.DataLoader>

How many batches is this `DataLoader` providing?

In [None]:
n_batches = len(dataloader)
n_batches

20

Let us get the first batch

In [None]:
batch_bundle = next(iter(dataloader))
type(batch_bundle)

dict

Notice that, as seen from the prototype of `CollateFunction.__call__`, the returned object is a dictionary. It contains the following fields

In [None]:
print(batch_bundle.keys())

dict_keys(['observed_time', 'observed_data', 'to_predict_at_time', 'to_predict_data', 'observed_mask'])


- `observed_time` and `observed_data` are the **first part** of a time series we want to learn, whereas
- `to_predict_at_time` and `to_predict_data` are the **second part** of the *same* time series we aim at predicting; on the other hand
- `observed_mask` is `1` for every observation that is available (it only applies to the *observed* data)

If one must think of this in terms of an input, $x$, that is given, and a related output, $y$, that is to be predicted, the latter would be `to_predict_data` and the former would encompass the rest of the fields.

We can check the size of every component

In [None]:
for k, v in batch_bundle.items():
    print(f'Dimensions of {k}: {tuple(v.shape)}')

Dimensions of observed_time: (50,)
Dimensions of observed_data: (10, 50, 1)
Dimensions of to_predict_at_time: (51,)
Dimensions of to_predict_data: (10, 51, 1)
Dimensions of observed_mask: (10, 50, 1)


In this simple example, every observation is available

In [None]:
(batch_bundle['observed_mask'] == 1.).all()

tensor(True)

### GPU support

If one wants to *move* this object to another device, this function will do that for all the relevant internal state.

In [None]:
#| export
@patch
def to(self: CollateFunction, device):
    "Move the time axis held by the collate function to `device`"
    
    # same behavior as the `to` method written in the class body; `@patch` attaches this one to `CollateFunction`
    time_on_device = self.time.to(device=device)
    self.time = time_on_device

In [None]:
#| hide
from nbdev.doclinks import nbdev_export

In [None]:
#| hide
nbdev_export('10_data.ipynb')