6.1) In the "split.py" module of the "model_selection" subpackage, add the "stratified_train_test_split" function
(consider the structure of the function presented in the next slide).

In [11]:
import sys
import numpy as np
sys.path.append("/Users/utilizador/Documents/GitHub/si/src")
from sklearn.utils import shuffle
from typing import Tuple
import numpy as np
from si.data.dataset import Dataset
from si.io.csv_file import read_csv

def stratified_train_test_split(dataset: Dataset, test_size: float = 0.2, random_state: int = None) -> Tuple[Dataset, Dataset]:
    """
    Split a dataset into training and testing sets while maintaining the class distribution.

    The samples of every class found in ``dataset.y`` are shuffled and split
    independently, so both resulting datasets keep (up to rounding) the class
    proportions of the input dataset.

    Parameters
    ----------
    dataset: Dataset
        The dataset to split
    test_size: float
        The proportion of each class to include in the test split. Must be in
        the interval [0, 1]; the per-class test count is rounded down.
    random_state: int
        The seed of the random number generator. When None, the global NumPy
        random state is used.

    Returns
    -------
    train: Dataset
        The training dataset
    test: Dataset
        The testing dataset

    Raises
    ------
    ValueError
        If the dataset has no labels or ``test_size`` is outside [0, 1].
    """
    X = dataset.X
    y = dataset.y

    if y is None:
        raise ValueError("stratified splitting requires labels, but dataset.y is None")
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be in the interval [0, 1], got {test_size!r}")

    # One generator is shared by all classes so that each class receives a
    # different permutation. A seeded generator is kept local to avoid
    # reseeding the global NumPy state; without a seed the global state is
    # used, so callers that seeded NumPy themselves remain reproducible.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    unique_classes, class_counts = np.unique(y, return_counts=True)

    # Seeding both lists with an empty integer array keeps np.concatenate
    # valid (and integer-typed) even when the dataset has no samples.
    train_parts = [np.empty(0, dtype=int)]
    test_parts = [np.empty(0, dtype=int)]

    for label, count in zip(unique_classes, class_counts):
        class_idxs = rng.permutation(np.where(y == label)[0])

        # Round down so a class never contributes more than its share to the test set.
        num_test = int(np.floor(test_size * count))

        test_parts.append(class_idxs[:num_test])
        train_parts.append(class_idxs[num_test:])

    train_idxs = np.concatenate(train_parts)
    test_idxs = np.concatenate(test_parts)

    train_dataset = Dataset(X[train_idxs], y[train_idxs], features=dataset.features, label=dataset.label)
    test_dataset = Dataset(X[test_idxs], y[test_idxs], features=dataset.features, label=dataset.label)

    return train_dataset, test_dataset

In [7]:
# Load the iris CSV from a hard-coded local directory (machine-specific path).
# NOTE(review): label=True presumably tells read_csv to treat one column as
# the label (dataset.y) - confirm against si.io.csv_file.read_csv.
Path= "/Users/utilizador/Documents/GitHub/si/datasets/iris/"
data = read_csv(Path + "iris.csv", sep=",", label=True)
# Show the summary of the loaded dataset (rendered by the notebook).
data.summary()

Unnamed: 0,feat_0,feat_1,feat_2,feat_3
mean,5.843333,3.054,3.758667,1.198667
median,5.8,3.0,4.35,1.3
min,4.3,2.0,1.0,0.1
max,7.9,4.4,6.9,2.5
var,0.681122,0.186751,3.092425,0.578532


In [17]:
# Stratified split of the loaded dataset with a 0.2 test proportion; no seed is
# passed, so the split is not reproducible between runs.
train_dataset, test_dataset = stratified_train_test_split(data, test_size=0.2)

In [19]:
# Display the X array of the training split.
train_dataset.X

array([[4.9, 3. , 1.4, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [5.8, 4. , 1.2, 0.2],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [4.4, 3.2, 1.3, 0.2],
       [5.1, 3.8, 1.5, 0.3],
       [4.8, 3. , 1.4, 0.1],
       [5. , 3.4, 1.5, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.5, 4.2, 1.4, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [5.4, 3.9, 1.3, 0.4],
       [5. , 3.3, 1.4, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [4.6, 3.4, 1.4, 0.3],
       [4.3, 3. , 1.1, 0.1],
       [4.8, 3.4, 1.9, 0.2],
       [5.2, 4.1, 1.5, 0.1],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.5, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.7, 3.8, 1.7, 0.3],
       [4.6, 3.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.4, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [5.1, 3

In [20]:
# Display the X array of the testing split.
test_dataset.X

array([[4.8, 3.1, 1.6, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5. , 3. , 1.6, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.9, 1.7, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [4.8, 3. , 1.4, 0.3],
       [4.7, 3.2, 1.3, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [4.5, 2.3, 1.3, 0.3],
       [6.1, 2.8, 4.7, 1.2],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.9, 4.2, 1.3],
       [5.5, 2.4, 3.8, 1.1],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.2, 2.2, 4.5, 1.5],
       [7. , 3.2, 4.7, 1.4],
       [6. , 2.7, 5.1, 1.6],
       [6.9, 3.1, 4.9, 1.5],
       [5.6, 2.8, 4.9, 2. ],
       [6.3, 3.4, 5.6, 2.4],
       [6.8, 3.2, 5.9, 2.3],
       [6. , 3. , 4.8, 1.8],
       [7.6, 3. , 6.6, 2.1],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [6.4, 3.2, 5.3, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6. , 2.2, 5. , 1.5]])