In [144]:
import pandas as pd
import pickle
from itertools import combinations
import random

In [10]:
goldstd_file = '../ppi_ml/data/gold_stds/all.gold.cmplx.noRibos.merged.txt'
annot_file = '../ppi_ml/annotations/leca_eunog_annots_complete.030721.csv'

## Functions

In [53]:
def map_ids(x, id_dict):
    nog_id1 = x.split(' ')[0]
    nog_id2 = x.split(' ')[1]
    gene_id1 = id_dict.get(nog_id1, nog_id1)
    gene_id2 = id_dict.get(nog_id2, nog_id2)
    return(gene_id1+' '+gene_id2)

def make_fset(x, drop=True):
    if len(set(x.split(' '))) < 2:
        print(f"WARNING: Features for '{x}-{x}' (self-self PPI) detected ...")
        if drop == False:
            x1 = x.split(' ')[0]
            fset = frozenset({x1,x1})
            return(fset)
        else:
            return(None)
    else:
        x1 = x.split(' ')[0]
        x2 = x.split(' ')[1]
        fset = frozenset({x1,x2})
        return(fset)

## Pick test complexes

In [11]:
annots = pd.read_csv(annot_file)
annots

Unnamed: 0,ID,amor_11,exca_2,tsar_5,viri_13,old_status,old_cmplx_assignment,go_gene_name,go_cmplx_name,corum_cmplx_name,...,human_protein_names,human_length,human_function_cc,human_annotscore_1to5,human_subcellular_location_cc,arath_entry,arath_gene_names_primary,arath_protein_names,arath_function_cc,arath_subcellular_location_cc
0,ENOG502QPHT,0.000,0.5,0.0,0.000,,,,,,...,,,,,,,,,,
1,ENOG502QPHW,0.000,0.0,0.0,0.923,,,,,,...,,,,,,"A0A1P8B9Z1, A0A1P8BDX3, F4KCX6","NA, MQM1.23","Transmembrane protein, Uncharacterized protein",,
2,ENOG502QPHZ,0.000,0.0,0.0,0.308,,,,,,...,,,,,,,,,,
3,ENOG502QPIA,0.455,0.5,0.0,0.000,,,,,,...,"Secernin-3, Secernin-2","424, 425",,"2, 4",,,,,,
4,ENOG502QPIC,0.091,0.0,0.2,1.000,,,,,,...,,,,,,"O82811, Q1G385, Q9FJH7, Q9FJH8, Q9LMZ9, Q9LPV5...","NRT2.1, NA, NRT2.3, NRT2.4, NRT2.2, NRT2.5, NR...",High-affinity nitrate transporter 2.1 (AtNRT2:...,"FUNCTION: Involved in nitrate transport, but d...",SUBCELLULAR LOCATION: Cell membrane {ECO:00002...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5983,KOG4845,0.818,0.0,0.0,0.846,Known,Mitochondrial respiratory chain complex I,MT-ND4,mitochondrial respiratory chain complex I,Respiratory chain complex I (beta subunit) mit...,...,NADH-ubiquinone oxidoreductase chain 4 (EC 7.1...,459,FUNCTION: Core subunit of the mitochondrial me...,5,SUBCELLULAR LOCATION: Mitochondrion inner memb...,"A0A1P8B2B4, A0A2P2CLG4, P26288, P93313","NA, nad4, ndhD, ND4","NADH-plastoquinone oxidoreductase subunit, NAD...","NA, FUNCTION: Core subunit of the mitochondria...","NA, SUBCELLULAR LOCATION: Membrane {ECO:000025..."
5984,KOG4846,0.727,0.0,0.0,0.000,,,"NR1D1, NR1D2",chromatin,,...,Nuclear receptor subfamily 1 group D member 1 ...,"614, 579",FUNCTION: Transcriptional repressor which coor...,5,SUBCELLULAR LOCATION: Nucleus {ECO:0000250|Uni...,,,,,
5985,KOG4847,0.727,0.0,0.0,0.000,,,,,,...,"IQCJ-SCHIP1 readthrough transcript protein, Sc...","563, 487",FUNCTION: May play a role in action potential ...,5,"SUBCELLULAR LOCATION: Cell projection, axon {E...",,,,,
5986,KOG4849,0.818,0.0,0.0,0.000,,,"CPSF6, CPSF7","ribonucleoprotein complex, mRNA cleavage and p...","Spliceosome, CF IIAm complex (Cleavage factor ...",...,Cleavage and polyadenylation specificity facto...,"551, 471",FUNCTION: Component of the cleavage factor Im ...,5,SUBCELLULAR LOCATION: Nucleus {ECO:0000269|Pub...,,,,,


In [27]:
merged = merged[['ID', 'group', 'super_group', 'label']]
merged = merged.sort_values(by=['super_group'])
merged

Unnamed: 0,ID,group,super_group,label
26515,KOG1894 KOG3400,1444,1,1
656,KOG0052 KOG0468,1418,1,1
625,KOG2146 KOG0050,1320,1,1
626,KOG0050 KOG2297,1320,1,1
627,KOG0050 KOG2330,1320,1,1
...,...,...,...,...
21049,KOG3048 KOG2778,1463,79,-1
15178,KOG4270 KOG1957,1463,79,-1
24178,KOG0971 KOG0171,1463,79,-1
8170,KOG0771 KOG2910,1463,79,-1


In [14]:
id_dict = dict()
for i in range(len(annots)):
    nog_id = annots['ID'][i]
    gene_id = annots['human_gene_names_primary'][i]
    if not gene_id:
        gene_id = annots['arath_gene_names_primary'][i]
    gene_id_fmt = str(gene_id).replace(', ', '/')
    id_dict[nog_id] = gene_id_fmt

{'ENOG502QPHT': 'nan',
 'ENOG502QPHW': 'nan',
 'ENOG502QPHZ': 'nan',
 'ENOG502QPIA': 'SCRN3/SCRN2',
 'ENOG502QPIC': 'nan',
 'ENOG502QPIP': 'nan',
 'ENOG502QPIQ': 'nan',
 'ENOG502QPJ0': 'nan',
 'ENOG502QPJ2': 'nan',
 'ENOG502QPJ7': 'nan',
 'ENOG502QPJC': 'nan',
 'ENOG502QPJI': 'nan',
 'ENOG502QPJV': 'MANEA/MANEAL',
 'ENOG502QPJZ': 'nan',
 'ENOG502QPKB': 'nan',
 'ENOG502QPKF': 'nan',
 'ENOG502QPKK': 'DNMT1',
 'ENOG502QPKQ': 'nan',
 'ENOG502QPKZ': 'nan',
 'ENOG502QPM4': 'CCDC146',
 'ENOG502QPMF': 'nan',
 'ENOG502QPMK': 'PCDH11X/PCDH11Y/PCDH9',
 'ENOG502QPN1': 'nan',
 'ENOG502QPNA': 'IFT46',
 'ENOG502QPNJ': 'nan',
 'ENOG502QPNY': 'DNASE1L1/DNASE1L3',
 'ENOG502QPP0': 'MVP',
 'ENOG502QPP4': 'EIF4E3',
 'ENOG502QPP7': 'nan',
 'ENOG502QPPH': 'nan',
 'ENOG502QPPT': 'ADCY10',
 'ENOG502QPPV': 'nan',
 'ENOG502QPPW': 'nan',
 'ENOG502QPPX': 'VASH1/VASH2',
 'ENOG502QPPY': 'nan',
 'ENOG502QPQ4': 'nan',
 'ENOG502QPQ6': 'nan',
 'ENOG502QPQC': 'nan',
 'ENOG502QPQG': 'nan',
 'ENOG502QPQR': 'nan',
 'ENOG502

In [28]:
merged['gene_ID'] = [map_ids(i, id_dict) for i in merged['ID']]

In [29]:
merged_pos = merged[merged.label == 1]
merged_pos

Unnamed: 0,ID,group,super_group,label,gene_ID
26515,KOG1894 KOG3400,1444,1,1,RPAP1 POLR2H
656,KOG0052 KOG0468,1418,1,1,EEF1A1/EEF1A2/EEF1A1P5 EFTUD2
625,KOG2146 KOG0050,1320,1,1,SRRM1 CDC5L
626,KOG0050 KOG2297,1320,1,1,CDC5L BZW1/BZW2
627,KOG0050 KOG2330,1320,1,1,CDC5L SF3B2
...,...,...,...,...,...
26119,KOG3595 KOG3929,521,77,1,DNAH14/DYNC1H1/DNAH12/DNAH10/DYNC2H1/DNAH3/DNA...
26125,KOG3599 KOG4203,"[1, 1065]",78,1,PKD1/PKD2/PKD1L2/PKD1L3/LOXHD1/PKD1L1/PKDREJ/P...
13681,KOG4203 KOG1703,1065,78,1,JUP/CTNNB1/UCK2/UCK1/UCKL1 PDLIM1/TGFB1I1/LPXN...
24321,KOG0266 KOG4838,24,79,1,WDR27/WDR5/NWD1/WDR38/WDR88/WDR5B/WDR49/AHI1/C...


In [30]:
merged[merged.super_group == 75]

Unnamed: 0,ID,group,super_group,label,gene_ID
7639,KOG1833 KOG0688,122,75,-1,NUP210L/NUP210 ETF1
13522,KOG1670 KOG2438,122,75,-1,EIF4E1B/EIF4E GATB
19246,KOG4604 KOG0354,122,75,-1,MICOS10 DDX58/FANCM/DHX58/IFIH1
15083,KOG2799 KOG1942,860,75,-1,SUCLA2 RUVBL1
13219,KOG1861 KOG1625,860,75,-1,LENG8 POLA2
4782,KOG4270 KOG0365,122,75,-1,FAM13A/INPP5B/ARHGAP25/ARHGAP44/ARHGAP17/BARGI...
6211,KOG0526 KOG2397,122,75,-1,SSRP1 PRKCSH/GNPTG
1693,KOG2830 KOG0122,860,75,-1,NA/IGBP1 EIF3G
24589,KOG0466 KOG3887,122,75,-1,EIF2S3/EIF2S3B RRAGC/RRAGD
20356,KOG1761 KOG2165,860,75,-1,SRP14 ANAPC2


In [26]:
out_file = '../ppi_ml/data/featmats/test_group_merge/positive_group_labels_annoted.csv'
merged_pos.to_csv(out_file)

# Generate test feature matrix

In [None]:
fmat_file = '../ppi_ml/data/featmats/test_group_merge/featmat_labeled_all_groups_traintest.pkl'
with open(fmat_file, 'rb') as handle:
    fmat = pickle.load(handle)

In [155]:
## Generate positive pairs
cmplx_list = []
fs_list = []
id_list = []
with open('../ppi_ml/data/featmats/test_group_merge/syntest/gold_cmplx_test.txt') as f:
    lines = f.readlines()
    for line in lines:
        line = line.strip('\n')
        cmplx_list.append(line)
        ids = line.split(' ')
        id_list.extend([ids])
        fsets = [frozenset({i, j}) for i,j in list(combinations(ids, 2))]
        fs_list.extend(fsets)
flat_id_list = [p for cmplx in id_list for p in cmplx]
print('Complex list:')
print(cmplx_list)
print()

## Generate negative pairs
all_pairs = [frozenset({i, j}) for i,j in list(combinations(flat_id_list, 2))]
neg_pairs = set(all_pairs).difference(set(fs_list))
print('Negative pairs:')
print(neg_pairs)

## Subset feature matrix for positive pairs
fmat['frozen_pairs'] = [make_fset(i) for i in fmat['ID']]
fmat_test = fmat[fmat['frozen_pairs'].isin(fs_list)]
test_cols = ['neg_ln_pval_bioplex2_Z4','leca_euks_elut.filtdollo.filt150p.raw.pearsonR.feat','plants_concat.raw.150p.pearsonR.feat', 'animals_concat.raw.150p.braycurtis.feat']
fmat_test = fmat_test[['ID']+test_cols]

## Subset feature matrix for observed negative pairs
fmat_test_neg = fmat[fmat['frozen_pairs'].isin(neg_pairs)]
fmat_test_neg = fmat_test_neg[['ID']+test_cols]
fmat_test_neg

## Ope, there are no observed negatives from this set
## Need to generate fake data
neg_pair_list = []
for p in neg_pairs:
    if len(p) == 2:
        pair_fmt = ' '.join(p)
        neg_pair_list.append(pair_fmt)
print(neg_pair_list)
neg_pair_sample = random.sample(neg_pair_list, 20)
fmat_test_neg['ID'] = neg_pair_sample
fmat_test_neg.fillna(0, inplace=True)

## Join pos + neg and write out
fmat_test.reset_index(inplace=True, drop=True)
fmat_test_neg.reset_index(inplace=True, drop=True)
fmat_test_all = pd.concat([fmat_test, fmat_test_neg], ignore_index=True)
fmat_test_all.to_csv('../ppi_ml/data/featmats/test_group_merge/syntest/test_featmat', index=False)
fmat_test_all

Complex list:
['KOG0371 KOG2830 KOG2867', 'KOG2867 KOG0373 KOG2830', 'KOG1104 KOG0096 KOG0121', 'KOG1104 KOG0121 KOG2314 KOG0328', 'KOG2363 KOG4639 KOG3322 KOG3387 KOG2567']

Negative pairs:
{frozenset({'KOG0371', 'KOG2363'}), frozenset({'KOG3322', 'KOG0096'}), frozenset({'KOG1104', 'KOG2867'}), frozenset({'KOG2867'}), frozenset({'KOG4639', 'KOG0121'}), frozenset({'KOG0371', 'KOG0096'}), frozenset({'KOG2830', 'KOG0121'}), frozenset({'KOG2830', 'KOG2567'}), frozenset({'KOG4639', 'KOG0096'}), frozenset({'KOG0373', 'KOG2363'}), frozenset({'KOG0371', 'KOG0328'}), frozenset({'KOG2314', 'KOG3322'}), frozenset({'KOG2567', 'KOG0096'}), frozenset({'KOG2830', 'KOG4639'}), frozenset({'KOG1104'}), frozenset({'KOG2314', 'KOG2867'}), frozenset({'KOG0371', 'KOG0121'}), frozenset({'KOG2567', 'KOG0371'}), frozenset({'KOG1104', 'KOG2567'}), frozenset({'KOG2830', 'KOG0328'}), frozenset({'KOG0121'}), frozenset({'KOG4639', 'KOG0328'}), frozenset({'KOG2314', 'KOG2363'}), frozenset({'KOG3387', 'KOG0121'}), f

Unnamed: 0,ID,neg_ln_pval_bioplex2_Z4,leca_euks_elut.filtdollo.filt150p.raw.pearsonR.feat,plants_concat.raw.150p.pearsonR.feat,animals_concat.raw.150p.braycurtis.feat
0,KOG0096 KOG0121,0.0,0.0086,-0.0025,0.0999
1,KOG0328 KOG0121,21.3288,0.0346,0.1496,0.1079
2,KOG0121 KOG1104,28.4474,0.2943,0.4169,0.3157
3,KOG2314 KOG0121,0.0,0.0553,0.0307,0.0758
4,KOG0328 KOG1104,69.2498,0.0171,0.1268,0.1955
5,KOG0328 KOG2314,2.6493,0.0465,0.0735,0.122
6,KOG2830 KOG0371,3.1537,0.1101,0.115,0.2274
7,KOG2867 KOG0371,0.0,0.0479,0.0157,0.0437
8,KOG2830 KOG0373,2.2092,0.1107,0.119,0.2282
9,KOG2867 KOG0373,0.0,0.0405,0.021,0.0352


## Generate encoded test feature matrix

In [131]:
alph_list = [chr(i) for i in range(ord('A'),ord('Z')+1)]
encode_dict = dict()
dupes = []
skip = 0
for i in range(len(flat_id_list)):
    nog_id = flat_id_list[i]
    alph_id = alph_list[i-skip]
    if nog_id not in encode_dict.keys():
        encode_dict[nog_id] = alph_id
    else:
        dupes.append(nog_id)
        skip += 1

In [132]:
encode_dict

{'KOG0371': 'A',
 'KOG2830': 'B',
 'KOG2867': 'C',
 'KOG0373': 'D',
 'KOG1104': 'E',
 'KOG0096': 'F',
 'KOG0121': 'G',
 'KOG2314': 'H',
 'KOG0328': 'I',
 'KOG2363': 'J',
 'KOG4639': 'K',
 'KOG3322': 'L',
 'KOG3387': 'M',
 'KOG2567': 'N'}

In [111]:
dupes

['KOG2867', 'KOG2830', 'KOG1104', 'KOG0121']

In [156]:
fmat_test_all['coded_ID'] = [map_ids(i, encode_dict) for i in fmat_test_all['ID']]
fmat_test_all

Unnamed: 0,ID,neg_ln_pval_bioplex2_Z4,leca_euks_elut.filtdollo.filt150p.raw.pearsonR.feat,plants_concat.raw.150p.pearsonR.feat,animals_concat.raw.150p.braycurtis.feat,coded_ID
0,KOG0096 KOG0121,0.0,0.0086,-0.0025,0.0999,F G
1,KOG0328 KOG0121,21.3288,0.0346,0.1496,0.1079,I G
2,KOG0121 KOG1104,28.4474,0.2943,0.4169,0.3157,G E
3,KOG2314 KOG0121,0.0,0.0553,0.0307,0.0758,H G
4,KOG0328 KOG1104,69.2498,0.0171,0.1268,0.1955,I E
5,KOG0328 KOG2314,2.6493,0.0465,0.0735,0.122,I H
6,KOG2830 KOG0371,3.1537,0.1101,0.115,0.2274,B A
7,KOG2867 KOG0371,0.0,0.0479,0.0157,0.0437,C A
8,KOG2830 KOG0373,2.2092,0.1107,0.119,0.2282,B D
9,KOG2867 KOG0373,0.0,0.0405,0.021,0.0352,C D


In [159]:
fmat_test_coded = fmat_test_all.drop('ID', axis=1)
fmat_test_coded = fmat_test_coded.rename(columns={'coded_ID':'ID'})
fmat_test_coded = fmat_test_coded[['ID']+test_cols]
fmat_test_coded.to_csv('../ppi_ml/data/featmats/test_group_merge/syntest_encoded/test_featmat_coded', index=False)
fmat_test_coded.to_pickle('../ppi_ml/data/featmats/test_group_merge/syntest_encoded/test_featmat_coded.pkl')

In [160]:
fmat_test_coded

Unnamed: 0,ID,neg_ln_pval_bioplex2_Z4,leca_euks_elut.filtdollo.filt150p.raw.pearsonR.feat,plants_concat.raw.150p.pearsonR.feat,animals_concat.raw.150p.braycurtis.feat
0,F G,0.0,0.0086,-0.0025,0.0999
1,I G,21.3288,0.0346,0.1496,0.1079
2,G E,28.4474,0.2943,0.4169,0.3157
3,H G,0.0,0.0553,0.0307,0.0758
4,I E,69.2498,0.0171,0.1268,0.1955
5,I H,2.6493,0.0465,0.0735,0.122
6,B A,3.1537,0.1101,0.115,0.2274
7,C A,0.0,0.0479,0.0157,0.0437
8,B D,2.2092,0.1107,0.119,0.2282
9,C D,0.0,0.0405,0.021,0.0352


In [113]:
test_pairs_coded = [map_ids(i, encode_dict) for i in fmat_test['ID']]
print(test_pairs_coded)
with open('../ppi_ml/data/featmats/test_group_merge/syntest_encoded/test_pairs_coded.txt', 'w') as f:
    for i in test_pairs_coded:
        f.write(i+'\n')

['F G', 'I G', 'G E', 'H G', 'I E', 'I H', 'B A', 'C A', 'B D', 'C D', 'H E', 'J N', 'J L', 'J M', 'N L', 'N M', 'C B', 'M L', 'J K', 'K N', 'K M']


## Generate encoded gold standard file

In [118]:
coded_cmplx_lst = []
with open('../ppi_ml/data/featmats/test_group_merge/syntest/gold_cmplx_test.txt', 'r') as f:
    lines = f.readlines()
    for line in lines:
        cmplx = line.strip('\n').split(' ')
        coded_cmplx = []
        for p in cmplx:
            new_id = encode_dict[p]
            coded_cmplx.append(new_id)
        coded_cmplx_lst.append(coded_cmplx)
        print(cmplx,'\t',coded_cmplx)
        
with open('../ppi_ml/data/featmats/test_group_merge/syntest_encoded/gold_cmplx_test_coded.txt', 'w') as f:
    for cmplx in coded_cmplx_lst:
        cmplx_str = ' '.join(cmplx)
        f.write(cmplx_str+'\n')

['KOG0371', 'KOG2830', 'KOG2867'] 	 ['A', 'B', 'C']
['KOG2867', 'KOG0373', 'KOG2830'] 	 ['C', 'D', 'B']
['KOG1104', 'KOG0096', 'KOG0121'] 	 ['E', 'F', 'G']
['KOG1104', 'KOG0121', 'KOG2314', 'KOG0328'] 	 ['E', 'G', 'H', 'I']
['KOG2363', 'KOG4639', 'KOG3322', 'KOG3387', 'KOG2567'] 	 ['J', 'K', 'L', 'M', 'N']


## Generate grouped train/test split example

In [161]:
import numpy as np
from sklearn.model_selection import GroupShuffleSplit
from sklearn.model_selection import GroupKFold
from sklearn.model_selection import StratifiedGroupKFold

In [166]:
labeled_file = '../ppi_ml/data/featmats/test_group_merge/syntest_encoded/featmat_labeled_traintest'
labeled = pd.read_csv(labeled_file)

In [167]:
X = labeled[test_cols].to_numpy()
y = labeled['label'].to_numpy()
groups = labeled['super_group'].to_numpy()

In [173]:
gs = GroupShuffleSplit(n_splits = 3, train_size=0.7, random_state=13)
for i, (train_idx, test_idx) in enumerate(gs.split(X, y, groups)):
    print(f"Fold {i+1}:")
    print(f"  Train: index={train_idx}, group={groups[train_idx]}")
    print(f"  Test:  index={test_idx}, group={groups[test_idx]}")
    print()

Fold 1:
  Train: index=[ 0  1  2  3  4  5 10 11 12 13 14 15 17 18 19 20 21 22 23 24 25 26 28 29
 30], group=[1 1 1 1 1 1 1 5 5 5 5 5 5 5 5 5 5 1 1 1 5 5 1 1 1]
  Test:  index=[ 6  7  8  9 16 27], group=[2 2 2 2 2 2]

Fold 2:
  Train: index=[ 0  1  2  3  4  5 10 11 12 13 14 15 17 18 19 20 21 22 23 24 25 26 28 29
 30], group=[1 1 1 1 1 1 1 5 5 5 5 5 5 5 5 5 5 1 1 1 5 5 1 1 1]
  Test:  index=[ 6  7  8  9 16 27], group=[2 2 2 2 2 2]

Fold 3:
  Train: index=[ 0  1  2  3  4  5 10 11 12 13 14 15 17 18 19 20 21 22 23 24 25 26 28 29
 30], group=[1 1 1 1 1 1 1 5 5 5 5 5 5 5 5 5 5 1 1 1 5 5 1 1 1]
  Test:  index=[ 6  7  8  9 16 27], group=[2 2 2 2 2 2]



In [172]:
gs = GroupKFold(n_splits = 3)
for i, (train_idx, test_idx) in enumerate(gs.split(X, y, groups)):
    print(f"Fold {i+1}:")
    print(f"  Train: index={train_idx}, group={groups[train_idx]}")
    print(f"  Test:  index={test_idx}, group={groups[test_idx]}")
    print()

Fold 1:
  Train: index=[ 6  7  8  9 11 12 13 14 15 16 17 18 19 20 21 25 26 27], group=[2 2 2 2 5 5 5 5 5 2 5 5 5 5 5 5 5 2]
  Test:  index=[ 0  1  2  3  4  5 10 22 23 24 28 29 30], group=[1 1 1 1 1 1 1 1 1 1 1 1 1]

Fold 2:
  Train: index=[ 0  1  2  3  4  5  6  7  8  9 10 16 22 23 24 27 28 29 30], group=[1 1 1 1 1 1 2 2 2 2 1 2 1 1 1 2 1 1 1]
  Test:  index=[11 12 13 14 15 17 18 19 20 21 25 26], group=[5 5 5 5 5 5 5 5 5 5 5 5]

Fold 3:
  Train: index=[ 0  1  2  3  4  5 10 11 12 13 14 15 17 18 19 20 21 22 23 24 25 26 28 29
 30], group=[1 1 1 1 1 1 1 5 5 5 5 5 5 5 5 5 5 1 1 1 5 5 1 1 1]
  Test:  index=[ 6  7  8  9 16 27], group=[2 2 2 2 2 2]



In [174]:
gs = StratifiedGroupKFold(n_splits = 3)
for i, (train_idx, test_idx) in enumerate(gs.split(X, y, groups)):
    print(f"Fold {i+1}:")
    print(f"  Train: index={train_idx}, group={groups[train_idx]}")
    print(f"  Test:  index={test_idx}, group={groups[test_idx]}")
    print()

Fold 1:
  Train: index=[ 0  1  2  3  4  5  6  7  8  9 10 16 22 23 24 27 28 29 30], group=[1 1 1 1 1 1 2 2 2 2 1 2 1 1 1 2 1 1 1]
  Test:  index=[11 12 13 14 15 17 18 19 20 21 25 26], group=[5 5 5 5 5 5 5 5 5 5 5 5]

Fold 2:
  Train: index=[ 0  1  2  3  4  5 10 11 12 13 14 15 17 18 19 20 21 22 23 24 25 26 28 29
 30], group=[1 1 1 1 1 1 1 5 5 5 5 5 5 5 5 5 5 1 1 1 5 5 1 1 1]
  Test:  index=[ 6  7  8  9 16 27], group=[2 2 2 2 2 2]

Fold 3:
  Train: index=[ 6  7  8  9 11 12 13 14 15 16 17 18 19 20 21 25 26 27], group=[2 2 2 2 5 5 5 5 5 2 5 5 5 5 5 5 5 2]
  Test:  index=[ 0  1  2  3  4  5 10 22 23 24 28 29 30], group=[1 1 1 1 1 1 1 1 1 1 1 1 1]

