In [35]:
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

# Splitting strategy used by the cells below. The helper functions in this
# notebook recognise 'random_split' (split by (sample_id, drug_id) pairs) and
# 'sample_split' (split by unique sample_id).
split_type = 'sample_split'

In [53]:
def split_by_pairs(df, col1, col2, train_pairs, test_pairs):
    """Split ``df`` into train/test rows according to (col1, col2) value pairs.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to split. It is NOT modified by this function.
    col1, col2 : str
        Names of the two columns whose row-wise values form the pair key.
    train_pairs, test_pairs : iterable of tuple
        ``(col1_value, col2_value)`` pairs selecting the train and test rows.

    Returns
    -------
    (train_df, test_df) : tuple of pandas.DataFrame
        Row subsets of ``df`` with an additional ``'pair'`` column holding
        the ``(col1, col2)`` tuple of each row.

    Raises
    ------
    AssertionError
        If a pair ends up in both subsets, or if the two subsets together do
        not account for every row of ``df``.
    """
    # Build the key on a copy (``assign`` returns a new frame) so the caller's
    # DataFrame does not silently gain a 'pair' column as a side effect.
    keyed = df.assign(pair=list(zip(df[col1], df[col2])))
    train_df = keyed.loc[keyed['pair'].isin(train_pairs)]
    test_df = keyed.loc[keyed['pair'].isin(test_pairs)]
    assert set(train_df['pair']).isdisjoint(set(test_df['pair'])), \
        'train and test subsets share at least one pair'
    # Every row must land in exactly one subset.
    assert len(train_df) + len(test_df) == len(df), \
        'train and test subsets do not cover every row of df'
    return train_df, test_df

def split_by_samples(df, train_samples, test_samples):
    """Partition the rows of ``df`` by the value of their ``'sample_id'`` column.

    Rows whose ``sample_id`` is in ``train_samples`` form the first returned
    frame; rows whose ``sample_id`` is in ``test_samples`` form the second.
    An ``AssertionError`` is raised if a sample id appears in both subsets or
    if the subsets together do not contain exactly ``len(df)`` rows.
    """
    sample_ids = df['sample_id']
    in_train = sample_ids.isin(train_samples)
    in_test = sample_ids.isin(test_samples)
    train_part = df.loc[in_train]
    test_part = df.loc[in_test]

    train_ids = set(train_part['sample_id'])
    test_ids = set(test_part['sample_id'])
    # No sample may contribute rows to both sides of the split.
    assert train_ids.isdisjoint(test_ids)
    # Together the two parts must account for every row.
    assert len(df) == len(train_part) + len(test_part)
    return train_part, test_part

def split_dataframe(df, col1, col2, X, split_type, train_index, test_index):
    """Split ``df`` into train/test frames for one cross-validation fold.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to split.
    col1, col2 : str
        Pair-key column names; only used when ``split_type == 'random_split'``.
    X : numpy.ndarray
        Items being split, indexed by ``train_index`` / ``test_index``. For
        ``'random_split'`` each row of ``X`` is converted to a tuple and used
        as a pair; for ``'sample_split'`` the selected entries are passed on
        as sample ids.
    split_type : str
        Either ``'random_split'`` or ``'sample_split'``.
    train_index, test_index : array-like of int
        Positions into ``X`` selecting the train and test items.

    Returns
    -------
    (train_df, test_df) : tuple of pandas.DataFrame

    Raises
    ------
    ValueError
        If ``split_type`` is not one of the supported values.
    """
    # Fail fast with a clear error instead of printing and returning None,
    # which made callers crash later while unpacking the result.
    if split_type not in ('random_split', 'sample_split'):
        raise ValueError(
            f"Error! Need valid split_type ('random_split' or 'sample_split'); "
            f"got {split_type!r}."
        )

    train_vals = X[train_index]
    test_vals = X[test_index]
    if split_type == 'random_split':
        # Rows of X become hashable tuples so they can be matched as pairs.
        train_pairs = list(map(tuple, train_vals))
        test_pairs = list(map(tuple, test_vals))
        return split_by_pairs(df, col1, col2, train_pairs, test_pairs)
    return split_by_samples(df, train_vals, test_vals)
        
def get_items_to_split(df, split_type):
    """Return the array of items that the cross-validation splitter should partition.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'sample_id'`` column, and a ``'drug_id'`` column when
        ``split_type == 'random_split'``.
    split_type : str
        ``'random_split'`` returns one ``[sample_id, drug_id]`` row per row of
        ``df``; ``'sample_split'`` returns the unique ``sample_id`` values in
        order of first appearance.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    ValueError
        If ``split_type`` is not one of the supported values.
    """
    if split_type == 'random_split':
        # split by pairs
        return df[['sample_id', 'drug_id']].to_numpy()
    if split_type == 'sample_split':
        return np.array(list(df['sample_id'].unique()))
    # Raise instead of printing and returning None, so a typo in split_type
    # surfaces here rather than as an obscure failure further downstream.
    raise ValueError(
        f'Error! split_type must be in random_split or sample_split; got {split_type!r}.'
    )

In [50]:
# Toy dataset for exercising the split helpers: sample_ids 1-5 with two rows
# each, drug_ids drawn from {1, 2, 3}, and a numeric 'value' column.
data = {'sample_id': [1, 1, 2, 2, 3, 3, 4, 4, 5, 5], 'drug_id': [1, 2, 3, 1, 2, 3, 1, 2, 3, 1], 
       'value': [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]}
df = pd.DataFrame(data=data)
# Preview the first rows of the frame.
df.head()

Unnamed: 0,sample_id,drug_id,value
0,1,1,10
1,1,2,20
2,2,3,30
3,2,1,40
4,3,2,50


In [51]:
# Items that KFold will partition; what they are depends on split_type
# (unique sample ids for 'sample_split', (sample_id, drug_id) rows for 'random_split').
X = get_items_to_split(df, split_type)
X

array([1, 2, 3, 4, 5])

In [56]:
# Shuffled 5-fold cross-validation over the items in X; random_state is fixed
# so the fold assignment is the same on every run.
kf = KFold(n_splits=5, random_state=0, shuffle=True)
# For each fold, map the index arrays produced by KFold back to rows of df
# via split_dataframe and show the resulting train / validation frames.
for i, (train_index, val_index) in enumerate(kf.split(X)):
    print(f"Fold {i}:")
    print(f" Train: index={train_index}")
    print(f" Test: index={val_index}")
    train_df, val_df = split_dataframe(df, 'sample_id', 'drug_id', X, split_type, train_index, val_index)
    print('train_df')
    print(train_df)
    print('test_df')
    print(val_df)

Fold 0:
 Train: index=[0 1 3 4]
 Test: index=[2]
train_df
   sample_id  drug_id  value
0          1        1     10
1          1        2     20
2          2        3     30
3          2        1     40
6          4        1     70
7          4        2     80
8          5        3     90
9          5        1    100
test_df
   sample_id  drug_id  value
4          3        2     50
5          3        3     60
Fold 1:
 Train: index=[1 2 3 4]
 Test: index=[0]
train_df
   sample_id  drug_id  value
2          2        3     30
3          2        1     40
4          3        2     50
5          3        3     60
6          4        1     70
7          4        2     80
8          5        3     90
9          5        1    100
test_df
   sample_id  drug_id  value
0          1        1     10
1          1        2     20
Fold 2:
 Train: index=[0 2 3 4]
 Test: index=[1]
train_df
   sample_id  drug_id  value
0          1        1     10
1          1        2     20
4          3        2     5