# Base TimeDataset and TimeDataModule

> Base `TimeDataset` and `TimeDataModule` classes for DataFrame-backed time series

In [None]:
#| default_exp time.base

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import numpy as np, pandas as pd

import torch, torch.nn as nn, pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [None]:
#| export
from dataclasses import dataclass, field, KW_ONLY
from beartype.typing import Optional, Union, Iterable

from iza.static import TIME, SERIES
from iza.utils import Slice

In [None]:
#| export
from littyping.core import (Device)

from litds.abc.dfdm.base import set_dataset, BaseDataFrameDataModule
from litds.abc.dfds.base import BaseDataFrameDataset
from litds.types import (
    SequenceWithLength, SequencesWithLengths
)
from litds.mocks.time import MockTimeSeries
from litds.time.mixs import TimeDatasetMixin

In [None]:
#| eval: False
# Build a mock time-series frame indexed by `series` and preview its first rows.
df = MockTimeSeries(set_index=True).df
df.head()

series,time,feature_0,feature_1,feature_2
0,0,4,8,0
0,1,0,0,0
0,2,0,2,2
0,3,6,0,4
0,4,7,4,3


## TimeDataset

In [None]:
#| export
@dataclass
class TimeDataset(TimeDatasetMixin, BaseDataFrameDataset):
    """Concrete time-series dataset built from a DataFrame.

    Declares no fields or methods of its own: all behaviour comes from
    ``TimeDatasetMixin`` (first in the MRO) and ``BaseDataFrameDataset``.
    ``@dataclass`` regenerates ``__init__``/``__repr__``/``__eq__`` over the
    fields inherited from any dataclass bases.
    """

## TimeDataModule

In [None]:
#| export
@set_dataset(TimeDataset)
class TimeDataModule(BaseDataFrameDataModule):
    """Data module that serves a DataFrame through ``TimeDataset``.

    ``set_dataset(TimeDataset)`` presumably registers ``TimeDataset`` as the
    class returned by ``make_dataset`` — confirm in ``litds.abc.dfdm.base``.
    """

    # Name of the time column; defaults to ``iza.static.TIME``.
    time_key: str = TIME
    # KW_ONLY marker: if this class is processed as a dataclass (by the base
    # class or decorator — confirm), the fields below are keyword-only.
    # NOTE(review): ``dataclasses`` ignores any value assigned to a KW_ONLY
    # marker; the ``field(...)`` here is kept unchanged.
    _: KW_ONLY = field(default=None, init=False)
    # Samples per batch for ``train_dataloader``.
    batch_size: Optional[int] = 64
    # NOTE(review): not read in this class; presumably consumed by the
    # dataset or base class — confirm.
    include_time: Optional[bool] = False
    # NOTE(review): not read in this class; presumably consumed by the base
    # class — confirm.
    device: Optional[Device] = None

    def setup(self, stage: Optional[str] = None):
        """No-op: the dataset is built on demand in ``train_dataloader``."""

    def train_dataloader(self):
        """Build a fresh dataset and wrap it in a ``DataLoader``.

        A new dataset is created via ``make_dataset`` on every call and kept
        on ``self.train_ds``. Batches are assembled by ``collate_fn``.
        """
        ds = self.make_dataset()
        self.train_ds = ds
        return DataLoader(ds, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def collate_fn(self, batch):
        """Stack a batch of ``(seq, time)`` pairs into two batched tensors.

        Args:
            batch: sequence of 2-tuples of tensors, as yielded by the dataset.

        Returns:
            ``(seqs, times)``: each is ``torch.stack`` of the corresponding
            items along a new leading batch dimension.

        Raises:
            RuntimeError: from ``torch.stack`` if items in the batch do not all
                share the same shape. No padding is applied.
        """
        seqs, times = zip(*batch)
        return torch.stack(seqs), torch.stack(times)

    def getall(
        self, pad: Optional[bool] = None
    ) -> Union[SequenceWithLength, SequencesWithLengths]:
        """Return every item of ``self.ds`` in one call.

        Args:
            pad: forwarded to ``self.ds.getall``. An explicit value is always
                honoured. When omitted (``None``), falls back to ``self.pad``
                if the instance defines it, otherwise ``True``.

        NOTE(review): ``self.ds`` is not assigned in this class
        (``train_dataloader`` stores ``self.train_ds``); presumably it is
        provided by the base class — confirm.
        """
        # Only fall back to the instance setting when the caller did not
        # choose; an explicit argument must not be silently overridden.
        if pad is None:
            pad = getattr(self, 'pad', True)
        return self.ds.getall(pad=pad)

### Example

In [None]:
#| eval: False
# Small example frame: the first five rows with the `series` index discarded,
# leaving a default RangeIndex over the `time`/feature columns.
df = MockTimeSeries(set_index=True).df.head().reset_index(drop=True)

In [None]:
#| eval: False
# Wrap the five-row example frame in a data module.
# NOTE(review): `df` is passed positionally — presumably the first field of
# BaseDataFrameDataModule; confirm.
tdm = TimeDataModule(df)

In [None]:
#| eval: False
# Display the frame held by the data module.
tdm.df

In [None]:
#| hide
# Regenerate the library modules from the notebooks' `#| export` cells.
import nbdev; nbdev.nbdev_export()