## Dealing with unbalanced data
## Creating random 5-fold splits
Use the NEK2 Binding data set as an example

In [1]:
import math
import torch
import numpy as np
import gpytorch
import pandas as pd
from matplotlib import pyplot as plt

from sklearn.model_selection import KFold

%matplotlib inline
%load_ext autoreload
%autoreload 2


Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
# Get data
data_path = "/global/scratch/users/fan4/datasets/scaled_descriptors/"

binding_df = pd.read_csv(data_path+"NEK2_1_uM_min_50_pct_binding_with_moe_descriptors.csv") 
binding_df


Unnamed: 0,compound_id,base_rdkit_smiles,active,ASA+_per_atom,ASA-,ASA_H_per_atom,ASA_P,ASA_per_atom,BCUT_PEOE_0,BCUT_PEOE_1,...,vsurf_Wp2_per_atom,vsurf_Wp3,vsurf_Wp4,vsurf_Wp5,vsurf_Wp6,vsurf_Wp7,vsurf_Wp8,weinerPath,weinerPol_per_atom,zagreb_per_atom
0,kdb_2562,Cn1cnc2c(N)ncnc21,0,14.897949,53.326256,8.903966,161.21796,17.860519,-2.222484,-0.418766,...,18.020833,95.250,30.125,12.500,3.750,0.250,0.0,137,0.833333,3.222222
1,kdb_2536,FC(F)(F)c1ccc(/C=C/c2cncnc2Nc2ccc3[nH]c(Cc4ccc...,0,6.511466,299.143010,8.385662,196.06219,11.950429,-2.311277,-0.649332,...,10.279545,141.125,30.125,12.250,2.875,0.000,0.0,4495,0.945455,3.418182
2,kdb_3056,CNCc1ccc(-c2cc(-c3nc(-c4ccc(S(=O)(=O)C(C)C)cc4...,0,8.496225,275.100830,8.517174,278.04828,13.229856,-2.568112,-0.618748,...,12.677966,231.750,54.000,20.250,4.625,0.125,0.0,3645,0.932203,2.983051
3,kdb_3510,CNC(=O)Nc1ccc2c(c1)CC[C@@]21OC(=O)N(CC(=O)N(Cc...,0,6.791036,344.713500,8.136228,261.31158,12.350931,-2.586196,-0.562002,...,7.979839,113.875,27.125,10.625,2.500,0.000,0.0,4847,1.064516,3.322581
4,kdb_3024,CCCNc1ccc2ncc(-c3ccnc(C)c3)n2n1,0,9.173611,133.466250,9.677048,114.33565,12.685881,-2.513905,-0.540779,...,9.536184,96.250,22.875,7.500,0.750,0.000,0.0,828,0.763158,2.736842
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1403,kdb_2305,C[C@@H]1CCN(C(=O)CO)C[C@@H]1N(C)c1ncnc2[nH]ccc12,0,8.819834,119.169170,6.812851,206.03860,11.604447,-2.808345,-0.568651,...,8.514535,100.000,24.000,8.250,2.125,0.000,0.0,1016,0.837209,2.697674
1404,kdb_2496,O=C(Nc1n[nH]c2nc(-c3nccs3)c(Br)cc12)C1CCCC1,0,8.136458,252.025010,11.508347,131.47646,15.061764,-2.494025,-0.587177,...,12.851351,144.875,36.125,14.125,4.250,0.250,0.0,1234,0.891892,3.405405
1405,kdb_2226,CC(C)COC(=O)c1ccc2c(c1)/C(=C/Nc1ccc(S(N)(=O)=O...,0,7.304606,311.213350,7.477071,302.59012,13.528873,-2.591978,-0.620893,...,12.157500,208.750,58.500,23.125,7.250,0.750,0.0,2549,0.880000,3.040000
1406,kdb_2219,COc1cc2ncnc(Nc3ccc4c(cnn4Cc4ccccc4)c3)c2cc1OC,0,9.033107,139.988530,9.740264,106.82767,11.794643,-2.335035,-0.642214,...,9.331731,101.125,16.125,6.750,1.625,0.000,0.0,2957,0.980769,3.230769


In [3]:
print("Dataset shape:",binding_df.shape)

print(binding_df.active.value_counts())

num_gap = (binding_df.loc[binding_df['active']==0].shape[0]) - (binding_df.loc[binding_df['active']==1].shape[0])
print("\nDifference in class sample sizes: ",num_gap)

num_minority = binding_df.loc[binding_df['active']==1].shape[0]
print("Number of minority samples: ",num_minority)

Dataset shape: (1408, 309)
active
0    1351
1      57
Name: count, dtype: int64

Difference in class sample sizes:  1294
Number of minority samples:  57


## Obtain splits

In [4]:
# Separate majority and minority classes
df_majority = binding_df[binding_df['active']==0]
df_minority = binding_df[binding_df['active']==1]

print("Shape of the majority: ",df_majority.shape)
df_minority

Shape of the majority:  (1351, 309)


Unnamed: 0,compound_id,base_rdkit_smiles,active,ASA+_per_atom,ASA-,ASA_H_per_atom,ASA_P,ASA_per_atom,BCUT_PEOE_0,BCUT_PEOE_1,...,vsurf_Wp2_per_atom,vsurf_Wp3,vsurf_Wp4,vsurf_Wp5,vsurf_Wp6,vsurf_Wp7,vsurf_Wp8,weinerPath,weinerPol_per_atom,zagreb_per_atom
8,kdb_2785,Nc1ncc(-c2cccnc2)c2scc(-c3ccc4c(c3)CCN4C(=O)Cc...,1,8.325928,262.59103,11.143368,111.40368,13.13272,-2.446876,-0.625692,...,11.229911,147.625,22.5,7.375,2.125,0.375,0.0,3699,1.035714,3.392857
20,kdb_2014,CN[C@@H]1C[C@H]2O[C@@](C)([C@@H]1OC)n1c3ccccc3...,1,6.903796,177.77444,7.091759,170.00287,9.833741,-2.766122,-0.529788,...,8.270161,145.0,34.75,10.75,1.5,0.0,0.0,2737,1.370968,3.548387
162,kdb_980,Cc1cnc(-c2cnc(NCCNc3ccc(C#N)cn3)nc2-c2ccc(Cl)c...,1,8.781019,264.83151,10.362926,187.3172,14.10927,-2.510593,-0.653889,...,12.1825,176.875,44.25,18.875,9.125,2.875,0.5,3332,0.96,3.32
171,kdb_2654,Cc1ccc(C)c(Oc2nccc(-c3c(-c4ccc(F)cc4)ncn3C3CCN...,1,7.216254,191.20229,8.520195,112.96584,10.402959,-2.754877,-0.641516,...,8.4875,121.75,25.25,9.375,2.0,0.0,0.0,3148,0.883333,2.966667
202,kdb_2692,C[C@@H](Oc1cc(-n2cnc3cc(-c4ccncc4)ccc32)sc1C(N...,1,6.934025,297.44644,9.521226,168.0475,12.752907,-2.441788,-0.622655,...,11.141827,154.125,29.5,11.125,4.0,0.5,0.0,3402,1.038462,3.461538
280,kdb_51,CCOc1cc2ncc(C#N)c(Nc3ccc(F)c(Cl)c3)c2cc1NC(=O)...,1,8.191867,208.23019,7.842065,231.18535,11.828019,-2.620959,-0.57441,...,10.795259,233.25,58.375,23.25,6.5,0.25,0.0,3320,0.896552,2.862069
298,kdb_2696,COc1cccc(C2=C(Nc3cc(Cl)c(O)c(Cl)c3)C(=O)NC2=O)c1,1,7.519988,293.28033,11.027122,163.5164,15.446484,-2.276777,-0.658796,...,11.709459,135.125,29.5,10.375,2.875,0.0,0.0,1471,1.108108,3.567568
307,kdb_2448,Nc1ncnc2occ(-c3ccc4c(c3)CCN4C(=O)Cc3cc(F)ccc3F...,1,7.673137,268.34943,9.745125,177.60167,13.606031,-2.443113,-0.720864,...,10.913043,120.875,24.25,9.625,2.5,0.125,0.0,2615,1.108696,3.652174
343,kdb_2155,O=c1c2cnc(Nc3ccc(N4CCN(CC5CCCCC5)CC4)cc3)nc2n2...,1,6.629744,242.89304,8.084622,133.77724,9.868318,-2.852835,-0.609622,...,7.68,133.0,25.625,13.375,5.375,1.125,0.0,6863,0.946667,3.04
362,kdb_2334,CN1CCC(NC(=O)c2ccc(Nc3ncc4c(n3)N(C3CCCC3)CC(F)...,1,8.011797,218.21959,7.862388,228.82764,11.085313,-2.801621,-0.584651,...,8.890845,203.125,49.125,18.0,4.875,0.0,0.0,4822,0.915493,2.84507


In [5]:
# Create 5-fold splits
# Use random_state=0
kf = KFold(n_splits=5, shuffle=True, random_state=0)

# majority
for i, (_, v_ind) in enumerate(kf.split(df_majority)):
    df_majority.loc[df_majority.index[v_ind], 'fold'] = f"fold{i+1}"
    
df_majority['fold'].value_counts()



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_majority.loc[df_majority.index[v_ind], 'fold'] = f"fold{i+1}"


fold
fold1    271
fold4    270
fold2    270
fold3    270
fold5    270
Name: count, dtype: int64

In [6]:
df_majority

Unnamed: 0,compound_id,base_rdkit_smiles,active,ASA+_per_atom,ASA-,ASA_H_per_atom,ASA_P,ASA_per_atom,BCUT_PEOE_0,BCUT_PEOE_1,...,vsurf_Wp3,vsurf_Wp4,vsurf_Wp5,vsurf_Wp6,vsurf_Wp7,vsurf_Wp8,weinerPath,weinerPol_per_atom,zagreb_per_atom,fold
0,kdb_2562,Cn1cnc2c(N)ncnc21,0,14.897949,53.326256,8.903966,161.21796,17.860519,-2.222484,-0.418766,...,95.250,30.125,12.500,3.750,0.250,0.0,137,0.833333,3.222222,fold4
1,kdb_2536,FC(F)(F)c1ccc(/C=C/c2cncnc2Nc2ccc3[nH]c(Cc4ccc...,0,6.511466,299.143010,8.385662,196.06219,11.950429,-2.311277,-0.649332,...,141.125,30.125,12.250,2.875,0.000,0.0,4495,0.945455,3.418182,fold1
2,kdb_3056,CNCc1ccc(-c2cc(-c3nc(-c4ccc(S(=O)(=O)C(C)C)cc4...,0,8.496225,275.100830,8.517174,278.04828,13.229856,-2.568112,-0.618748,...,231.750,54.000,20.250,4.625,0.125,0.0,3645,0.932203,2.983051,fold2
3,kdb_3510,CNC(=O)Nc1ccc2c(c1)CC[C@@]21OC(=O)N(CC(=O)N(Cc...,0,6.791036,344.713500,8.136228,261.31158,12.350931,-2.586196,-0.562002,...,113.875,27.125,10.625,2.500,0.000,0.0,4847,1.064516,3.322581,fold4
4,kdb_3024,CCCNc1ccc2ncc(-c3ccnc(C)c3)n2n1,0,9.173611,133.466250,9.677048,114.33565,12.685881,-2.513905,-0.540779,...,96.250,22.875,7.500,0.750,0.000,0.0,828,0.763158,2.736842,fold1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1403,kdb_2305,C[C@@H]1CCN(C(=O)CO)C[C@@H]1N(C)c1ncnc2[nH]ccc12,0,8.819834,119.169170,6.812851,206.03860,11.604447,-2.808345,-0.568651,...,100.000,24.000,8.250,2.125,0.000,0.0,1016,0.837209,2.697674,fold5
1404,kdb_2496,O=C(Nc1n[nH]c2nc(-c3nccs3)c(Br)cc12)C1CCCC1,0,8.136458,252.025010,11.508347,131.47646,15.061764,-2.494025,-0.587177,...,144.875,36.125,14.125,4.250,0.250,0.0,1234,0.891892,3.405405,fold1
1405,kdb_2226,CC(C)COC(=O)c1ccc2c(c1)/C(=C/Nc1ccc(S(N)(=O)=O...,0,7.304606,311.213350,7.477071,302.59012,13.528873,-2.591978,-0.620893,...,208.750,58.500,23.125,7.250,0.750,0.0,2549,0.880000,3.040000,fold4
1406,kdb_2219,COc1cc2ncnc(Nc3ccc4c(cnn4Cc4ccccc4)c3)c2cc1OC,0,9.033107,139.988530,9.740264,106.82767,11.794643,-2.335035,-0.642214,...,101.125,16.125,6.750,1.625,0.000,0.0,2957,0.980769,3.230769,fold3


In [7]:
# minority
for i, (_, v_ind) in enumerate(kf.split(df_minority)):
    df_minority.loc[df_minority.index[v_ind], 'fold'] = f"fold{i+1}"
    
df_minority['fold'].value_counts()




A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_minority.loc[df_minority.index[v_ind], 'fold'] = f"fold{i+1}"


fold
fold1    12
fold2    12
fold5    11
fold4    11
fold3    11
Name: count, dtype: int64

In [8]:
df_minority['fold']

8       fold5
20      fold4
162     fold1
171     fold5
202     fold1
280     fold3
298     fold4
307     fold2
343     fold3
362     fold5
366     fold1
369     fold1
381     fold4
397     fold3
414     fold2
416     fold3
434     fold3
456     fold4
536     fold2
554     fold5
609     fold3
612     fold5
735     fold1
748     fold4
750     fold4
778     fold4
781     fold3
782     fold1
788     fold2
891     fold2
906     fold3
907     fold2
919     fold2
944     fold1
981     fold1
989     fold1
1002    fold5
1077    fold2
1086    fold4
1096    fold5
1111    fold2
1121    fold1
1147    fold4
1157    fold4
1158    fold5
1169    fold4
1259    fold2
1264    fold5
1275    fold3
1289    fold3
1293    fold2
1295    fold2
1302    fold5
1322    fold5
1342    fold3
1352    fold1
1389    fold1
Name: fold, dtype: object

In [9]:
all_fold_df = pd.concat([df_majority,df_minority])
print(all_fold_df.shape)
print(all_fold_df.active.value_counts())


(1408, 310)
active
0    1351
1      57
Name: count, dtype: int64


In [10]:
# Save to file
split_path = "/global/scratch/users/fan4/NEK_data_4Berkeley/NEK2"
#all_fold_df.to_csv(split_path+"/NEK2_1_uM_min_50_pct_binding_5fold_random_imbalanced.csv", index=False)


In [None]:
# Using one fold as a test set, and the other folds as a training set.