# Dataset Base and Meta

> Base classes for DataFrame-backed Lightning DataModules

In [None]:
#| default_exp abc.dfdm.base

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from dataclasses import dataclass, field
from beartype.typing import (Any, Optional,  Type)


import pytorch_lightning as pl
import numpy as np, pandas as pd

from iza.utils import (filter_kwargs_for_class)

In [None]:
#| export
from litds.types import BoolLike

## DataModules

In [None]:
#| export
from litds.abc.dfdm.meta import MetaDataFrameDataModule
from litds.abc.dfds.mixs import DataFrameArgsMixin, DataFrameKWArgsMixins
from litds.abc.dfds.base import BaseDataFrameDataset

In [None]:
#| export
@dataclass
class BaseDataFrameDataModule(
    pl.LightningDataModule,
    DataFrameArgsMixin,
    DataFrameKWArgsMixins,
    metaclass=MetaDataFrameDataModule,
):
    """Base Lightning data module that builds datasets of type `DatasetClass`.

    The constructor arguments come from the two mixins; this class only adds
    the `DatasetClass` attribute, no-op `prepare_data` / `setup` overrides and
    the `make_dataset` factory.
    """

    # Dataset type instantiated by `make_dataset`. Kept out of `__init__` and
    # `repr`; the `set_dataset` decorator can swap it on a subclass.
    DatasetClass: Type[BaseDataFrameDataset] = field(
        default=BaseDataFrameDataset, init=False, repr=False
    )

    def __post_init__(self):
        """Hand post-initialisation over to the parent classes."""
        super().__post_init__()

    def prepare_data(self):
        """Do nothing and return this module (override of the Lightning hook)."""
        return self

    def setup(self, stage: Optional[str] = None):
        """Do nothing and return this module; `stage` is accepted but unused."""
        return self

    def make_dataset(self, **kwargs):
        """Instantiate `DatasetClass` from this module's arguments.

        The mapping returned by `self.__kwargs__()` is updated with `kwargs`
        (so explicit keyword arguments win), passed through
        `filter_kwargs_for_class` for `DatasetClass` (presumably dropping
        entries its constructor does not accept), and the result is used to
        construct the dataset.
        """
        merged = self.__kwargs__()
        merged.update(kwargs)
        dataset_cls = self.DatasetClass
        accepted = filter_kwargs_for_class(dataset_cls, **merged)
        return dataset_cls(**accepted)

In [None]:
#| export
def set_dataset(DatasetClass: Any = BaseDataFrameDataset):
    """Return a class decorator that sets `DatasetClass` on the decorated class.

    Parameters
    ----------
    DatasetClass
        Object stored as the `DatasetClass` attribute of the decorated class.
        Defaults to `BaseDataFrameDataset`.

    Returns
    -------
    A callable that takes a class, assigns the attribute on it and returns
    that same class.
    """
    def decorate(cls):
        setattr(cls, 'DatasetClass', DatasetClass)
        return cls

    return decorate

### Example

In [None]:
#| eval: False
# Demo frame: 10 rows x 3 columns of random integers in [0, 10), with row
# labels drawn at random (repeats allowed) from 'a', 'b' and 'c'.
df = pd.DataFrame(
    data=np.random.randint(low=0, high=10, size=(10, 3)),
    index=np.random.choice(['a', 'b', 'c'], size=10),
    columns=['x', 'y', 'z'],
)
df.head()

In [None]:
#| eval: False
@set_dataset(BaseDataFrameDataset)
class DFModuleTest(BaseDataFrameDataModule):
    """Example data module: the base class with `DatasetClass` set by the decorator."""

In [None]:
#| eval: False
# Build the example module around the demo frame.
dm = DFModuleTest(df=df)

In [None]:
#| eval: False
# The base-class `setup` does no work and returns the module, so the cell echoes it.
dm.setup()

DataFrameDataModuleTest()

In [None]:
#| eval: False
# No overrides are passed, so the dataset is built from `dm.__kwargs__()` alone.
ds = dm.make_dataset()

In [None]:
#| eval: False
# Look up the rows labelled 'a' (assumes `ds.loc` mirrors `DataFrame.loc` — TODO confirm).
ds.loc['a']

Unnamed: 0,x,y,z
a,0,9,0
a,1,3,9


In [None]:
#| hide
# Run nbdev's export step, which writes the `#| export` cells to the library module.
import nbdev; nbdev.nbdev_export()