# DataFrameDataModule

> A Lightning data module that builds its training dataset and batches from a pandas DataFrame.

In [None]:
#| default_exp core.dfdm

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os, math, inspect
import numpy as np, pandas as pd
from dataclasses import dataclass, field, KW_ONLY

import torch, torch.nn as nn, pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader

from beartype.typing import Optional

from iza.static import (LABEL, )
from littyping.core import Device

In [None]:
#| export
from litds.abc.dfdm.base import BaseDataFrameDataModule, set_dataset
from litds.core.dfds import DataFrameDataset
from litds.mocks.time import MockTimeSeries

## DataFrameDataModule

In [None]:
#| export
@set_dataset(DataFrameDataset)
@dataclass
class DataFrameDataModule(BaseDataFrameDataModule):
    """Data module that serves training batches built from ``self.df``.

    The training dataset is created on demand by ``train_dataloader`` (via
    ``self.make_dataset``) and cached on ``self.train_ds``.

    Attributes
    ----------
    label_key : str
        Label column name (default ``LABEL``). Not read directly in this
        class; presumably consumed by ``make_dataset`` -- TODO confirm.
    batch_size : int, optional
        Samples per training batch (keyword-only, default 64).
    include_time : bool, optional
        Keyword-only flag (default False). Not read in this class;
        presumably used by the base class / dataset factory -- TODO confirm.
    device : Device, optional
        Keyword-only (default None). Not read in this class; presumably
        used by the base class -- TODO confirm.
    """

    label_key: str = LABEL

    # Sentinel: every field declared after it is keyword-only in the
    # generated ``__init__``. It is not itself a dataclass field.
    _: KW_ONLY = field(default=None, init=False)
    batch_size: Optional[int] = 64
    include_time: Optional[bool] = False
    device: Optional[Device] = None

    def setup(self, stage: Optional[str] = None):
        """Stage hook; intentionally a no-op (``stage`` is ignored).

        Dataset construction is deferred to ``train_dataloader``.
        """
        return None

    def train_dataloader(self):
        """Build the training dataset from ``self.df`` and wrap it in a loader.

        A fresh dataset is created on every call and stored on
        ``self.train_ds`` before the ``DataLoader`` is constructed.
        """
        dataset = self.make_dataset(df=self.df)
        self.train_ds = dataset
        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
        )
        return loader

    def collate_fn(self, batch):
        """Turn a sequence of ``(sample, target)`` pairs into two tensors.

        Samples and targets are each combined with ``torch.stack``, which
        adds a new leading batch dimension.
        """
        sample_seq, target_seq = zip(*batch)
        return torch.stack(sample_seq), torch.stack(target_seq)

    def getall(self, pad: Optional[bool] = True):
        """Delegate to ``self.ds.getall``, forwarding ``pad``."""
        # NOTE(review): this class only assigns ``self.train_ds``; ``self.ds``
        # is presumably provided by BaseDataFrameDataModule -- confirm,
        # otherwise this raises AttributeError.
        dataset = self.ds
        return dataset.getall(pad=pad)

In [None]:
#| eval: False
# Build a small mock time-series frame. reset_index() turns the index
# levels into ordinary columns; the 'series' column is then dropped.
df = MockTimeSeries(set_index=True).df
df = df.reset_index().drop(columns='series')
df.head()

In [None]:
#| eval: False
# Wrap the frame in the data module, using the 'time' column as label_key.
dfm = DataFrameDataModule(df=df, label_key='time')

In [None]:
#| eval: False
# Inspect the frame held by the module (the one train_dataloader reads).
dfm.df.head()

Unnamed: 0,time,feature_0,feature_1,feature_2
0,0,4,8,0
1,1,0,0,0
2,2,0,2,2
3,3,6,0,4
4,4,7,4,3


In [None]:
#| eval: False
# Pull just the first batch to check tensor shapes. next(iter(...)) raises
# StopIteration on an empty loader, instead of silently keeping a stale `b`
# from a previous notebook run the way a `for ...: break` loop would.
b = next(iter(dfm.train_dataloader()))
b[0].shape, b[1].shape

(torch.Size([3, 9, 3]), torch.Size([3, 9]))

In [None]:
#| eval: False
# train_dataloader() caches the dataset it built on `train_ds`; show its frame's shape.
dfm.train_ds.df.shape

(23, 4)

In [None]:
#| hide
# Regenerate the library modules from the notebooks' `#| export` cells.
import nbdev; nbdev.nbdev_export()