# Simple Use Case

We have all copied this code at least once in our lives:

```python
>>> X_train, X_test, y_train, y_test = train_test_split(
...     X, y, test_size=0.33, random_state=42)
```

Even though it's not that long, it's still a bit annoying to write or copy. There must be a better way to do this, and there is! With the `Dataset` class, you can do this:



In [1]:
from skdataset.dataset import Dataset
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification

# make some data
X, y = make_classification(
    n_samples=200, 
    n_features=4, 
    n_classes=2, 
    n_clusters_per_class=1,
    random_state=42,
)

ds = Dataset(X=X, y=y)

In [2]:
train_ds, test_ds = train_test_split(ds, test_size=0.33, random_state=42)

In [3]:
len(train_ds) + len(test_ds) == len(X)

True

This by itself is not that impressive, but it's a good start. Let's see what else we can do with this class.
We can also manipulate the data in the dataset with the `transform` method without changing the original data.

In [4]:
def multiply_by_2(dataset: Dataset):
    dataset.X = dataset.X * 2
    return dataset

ds.transform(multiply_by_2)[:3, 'X'] 

array([[-0.63315862, -0.74435214,  1.13876335, -0.46847537],
       [-0.92523065, -0.29979364,  7.23459721, -5.21427765],
       [ 2.60788321,  4.11773656,  2.7461706 , -4.11749094]])

In [5]:
ds[:3, 'X']

array([[-0.31657931, -0.37217607,  0.56938167, -0.23423769],
       [-0.46261533, -0.14989682,  3.61729861, -2.60713883],
       [ 1.3039416 ,  2.05886828,  1.3730853 , -2.05874547]])
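
One way a non-mutating `transform` like this could work is to copy the dataset before handing it to the function. The sketch below is only an illustration of that idea: `TinyDataset` is a hypothetical stand-in written for this example (it is not the real `Dataset` implementation), and it uses plain lists instead of arrays.

```python
import copy


class TinyDataset:
    """Hypothetical stand-in for Dataset; not the library's implementation."""

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def transform(self, func):
        # Hand the function a deep copy so the original object stays untouched.
        return func(copy.deepcopy(self))


def multiply_by_2(dataset):
    dataset.X = [[value * 2 for value in row] for row in dataset.X]
    return dataset


tiny = TinyDataset(X=[[1.0, 2.0], [3.0, 4.0]], y=[0, 1])
doubled = tiny.transform(multiply_by_2)

print(doubled.X)  # [[2.0, 4.0], [6.0, 8.0]]
print(tiny.X)     # [[1.0, 2.0], [3.0, 4.0]]
```

Because the function only ever sees the copy, it is free to assign to `dataset.X` without affecting the caller's data.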

You can also filter rows based on a condition, and that filter will be applied to all the variables:


In [6]:
ds.filter(lambda x: x.y == 1)

{'X': array([[-3.16579308e-01, -3.72176068e-01,  5.69381673e-01,
         -2.34237685e-01],
        [ 1.30394160e+00,  2.05886828e+00,  1.37308530e+00,
         -2.05874547e+00],
        [ 2.34871791e+00,  3.66332221e+00,  2.15367981e+00,
         -3.44843438e+00],
        [-1.27793199e-01, -1.24782220e-01,  4.09796127e-01,
         -2.40885141e-01],
        [-5.65769548e-02,  7.15787270e-03,  6.22599748e-01,
         -4.65387410e-01],
        [ 1.30414855e+00,  1.85205111e+00, -9.11810620e-02,
         -8.68221180e-01],
        [ 1.22461511e+00,  1.88944759e+00,  9.77292933e-01,
         -1.67958481e+00],
        [ 8.78841517e-01,  1.55214461e+00,  2.08837827e+00,
         -2.33321540e+00],
        [ 6.13003086e-01,  9.44105156e-01,  4.77240904e-01,
         -8.31020512e-01],
        [-5.77861782e-01, -7.69655932e-01,  4.00816021e-01,
          9.16323690e-02],
        [ 3.33406079e-01,  5.05135335e-01,  2.00506290e-01,
         -4.03958542e-01],
        [ 6.41812359e-01,  1.10808220e

The obvious benefit of this is that all of your data always stays together, so you don't have to filter one part and then map the same filter onto the others (e.g. filtering some rows of X and then applying the same filter to y or sample_weight).
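
For comparison, here is roughly what the manual version looks like with plain NumPy arrays. The arrays and the `sample_weight` values are made up for illustration; the point is that the same mask has to be applied to every array by hand.

```python
import numpy as np

# Small made-up arrays that must stay aligned row by row.
X = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
y = np.array([1, 0, 1, 0])
sample_weight = np.array([1.0, 2.0, 0.5, 1.5])

# Build the mask once from y ...
mask = y == 1

# ... then remember to apply it to every array separately.
X_pos = X[mask]
y_pos = y[mask]
sample_weight_pos = sample_weight[mask]

print(X_pos.shape, y_pos.shape, sample_weight_pos.shape)  # (2, 2) (2,) (2,)
```

Forgetting one of these lines leaves the arrays with different lengths, which is the kind of mistake that keeping everything in one object helps avoid.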

## DatasetDict

Here we introduce `DatasetDict`, a dictionary-like object that holds multiple `Dataset` objects. It's very useful when you have multiple datasets (e.g. train, val and test) that you want to keep together.

You can create one by just passing a dict with some keys and `Dataset` objects as values, or by calling `split` with a splitter function on your dataset. Here is an example:

In [7]:
from skdataset import DatasetDict

ds_dict = DatasetDict({'train': train_ds, 'test': test_ds})

ds_dict.keys()

dict_keys(['train', 'test'])

In [8]:

ds_dict = ds.split(train_test_split, test_size=0.33, random_state=42)

ds_dict.keys()

dict_keys(['train', 'test'])

This new object has some handy attributes like `X_train` or `y_test`, which are all automatically generated from the keys in the dict:

In [9]:
ds_dict.X_train

array([[ 1.08384298e+00,  1.63367133e+00,  5.92193278e-01,
        -1.26471893e+00],
       [ 2.06290890e+00,  3.58425588e+00,  4.48422581e+00,
        -5.13700395e+00],
       [ 1.39912667e+00,  2.25751204e+00,  1.81515089e+00,
        -2.48699264e+00],
       [ 2.25697284e-01,  6.73925831e-01,  2.48277463e+00,
        -2.18196447e+00],
       [ 1.12705269e+00,  1.85756778e+00,  1.73826470e+00,
        -2.22787424e+00],
       [ 1.24766859e-01,  3.26881584e-01,  1.04961997e+00,
        -9.43657871e-01],
       [ 3.45250244e-01,  6.56967077e-01,  1.15419481e+00,
        -1.18801251e+00],
       [-1.36849939e+00, -1.96954030e+00, -8.88658098e-02,
         1.06112629e+00],
       [-7.13347532e-01, -7.41314420e-01,  1.97095155e+00,
        -1.08722866e+00],
       [-6.47747668e-01, -1.07817905e+00, -1.07385964e+00,
         1.34126944e+00],
       [ 6.41812359e-01,  1.10808220e+00,  1.34527769e+00,
        -1.55768528e+00],
       [ 1.86442194e+00,  2.74022947e+00,  5.23762269e-01,
      

This is equivalent to `ds_dict['train']['X']`.
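
To give an intuition for how attributes such as `X_train` can be derived from the keys, here is a minimal sketch built on Python's `__getattr__` hook. `TinyDatasetDict` is a hypothetical class written for this example, with plain dicts standing in for `Dataset` objects; the real `DatasetDict` may be implemented differently.

```python
class TinyDatasetDict(dict):
    """Hypothetical stand-in for DatasetDict; not the library's implementation."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        # Match names of the form "<variable>_<split>", e.g. "X_train".
        for split, dataset in self.items():
            suffix = "_" + split
            if name.endswith(suffix):
                variable = name[: -len(suffix)]
                if variable in dataset:
                    return dataset[variable]
        raise AttributeError(name)


splits = TinyDatasetDict(
    train={"X": [[1, 2], [3, 4]], "y": [0, 1]},
    test={"X": [[5, 6]], "y": [1]},
)

print(splits.X_train)  # [[1, 2], [3, 4]]
print(splits.y_test)   # [1]
print(splits.X_train is splits["train"]["X"])  # True
```

Matching on the suffix rather than splitting at the first underscore keeps names like `sample_weight_train` working in this sketch.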

Now imagine you have a function that you want to apply to all of your splits. Instead of looping over them, you can call `ds_dict.transform(func)`, which applies the function to every split and returns a new `DatasetDict` object:

In [10]:
ds_dict.transform(multiply_by_2)

{'train': {'X': array([[ 2.16768595e+00,  3.26734266e+00,  1.18438656e+00,
          -2.52943786e+00],
         [ 4.12581780e+00,  7.16851176e+00,  8.96845162e+00,
          -1.02740079e+01],
         [ 2.79825335e+00,  4.51502408e+00,  3.63030177e+00,
          -4.97398527e+00],
         [ 4.51394568e-01,  1.34785166e+00,  4.96554927e+00,
          -4.36392895e+00],
         [ 2.25410537e+00,  3.71513556e+00,  3.47652939e+00,
          -4.45574849e+00],
         [ 2.49533719e-01,  6.53763168e-01,  2.09923994e+00,
          -1.88731574e+00],
         [ 6.90500489e-01,  1.31393415e+00,  2.30838961e+00,
          -2.37602502e+00],
         [-2.73699877e+00, -3.93908061e+00, -1.77731620e-01,
           2.12225258e+00],
         [-1.42669506e+00, -1.48262884e+00,  3.94190311e+00,
          -2.17445733e+00],
         [-1.29549534e+00, -2.15635810e+00, -2.14771927e+00,
           2.68253887e+00],
         [ 1.28362472e+00,  2.21616441e+00,  2.69055539e+00,
          -3.11537055e+00],
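
Conceptually, this replaces a hand-written loop over the splits. The sketch below shows that loop with plain dictionaries standing in for the real objects; the `double_X` helper and the data are made up for illustration.

```python
def double_X(split):
    # Return a new dict instead of modifying the input split.
    return {**split, "X": [[value * 2 for value in row] for row in split["X"]]}


splits = {
    "train": {"X": [[1, 2], [3, 4]], "y": [0, 1]},
    "test": {"X": [[5, 6]], "y": [1]},
}

# Apply the same function to every split and keep the results under the same keys.
transformed = {name: double_X(split) for name, split in splits.items()}

print(transformed["train"]["X"])  # [[2, 4], [6, 8]]
print(transformed["test"]["X"])   # [[10, 12]]
print(splits["test"]["X"])        # [[5, 6]]
```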
       

## Put It All Together

Let's see how we can use all of these together. Without `Dataset`, it would look something like this:

```python
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=200, 
    n_features=4, 
    n_classes=2, 
    n_clusters_per_class=1,
    random_state=42,
)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42
)

model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)
print('Train score:', model.score(X_train, y_train))
print('Test score:', model.score(X_test, y_test))
```

With `Dataset`, the same workflow becomes:

In [11]:
from skdataset.dataset import Dataset
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=200, 
    n_features=4, 
    n_classes=2, 
    n_clusters_per_class=1,
    random_state=42,
)

ds_dict = Dataset(X=X, y=y).split(train_test_split, test_size=0.33, random_state=42)

model = RandomForestClassifier(random_state=42)
model.fit(**ds_dict['train'])

print('Train score:', model.score(**ds_dict['train']))
print('Test score:', model.score(**ds_dict['test']))

Train score: 1.0
Test score: 0.8181818181818182
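
Calls like `model.fit(**ds_dict['train'])` rely on Python's `**` unpacking, which works on any object that provides `keys()` and `__getitem__`. The sketch below illustrates the mechanism with a hypothetical `UnpackableDataset` class and a stand-in `fit` function; neither is part of `skdataset` or scikit-learn.

```python
class UnpackableDataset:
    """Hypothetical stand-in for Dataset; not the library's implementation."""

    def __init__(self, **variables):
        self._variables = variables

    def keys(self):
        # ** unpacking asks for the keys first ...
        return self._variables.keys()

    def __getitem__(self, key):
        # ... and then looks up each key to get the matching value.
        return self._variables[key]


def fit(X, y):
    # Stand-in for an estimator's fit(X, y) method.
    return f"fitted on {len(X)} rows and {len(y)} labels"


train = UnpackableDataset(X=[[1, 2], [3, 4], [5, 6]], y=[0, 1, 0])

# Equivalent to fit(X=train["X"], y=train["y"]).
message = fit(**train)
print(message)  # fitted on 3 rows and 3 labels
```

One consequence of this pattern is that every key in the object is passed as a keyword argument, so the keys need to match parameter names the called function accepts.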
