## Dealing with imbalanced data
Use the NEK2 Binding data set as an example

In [1]:
import math
import torch
import numpy as np
import gpytorch
import pandas as pd
import os
from matplotlib import pyplot as plt

from sklearn.model_selection import KFold

%matplotlib inline
%load_ext autoreload
%autoreload 2


In [2]:
# Binding
nekAll = ["2","3","5","9"]
#nekAll = ["3","5","9"]

for nek in nekAll:
    # Get training data
    data_path = "/p/lustre2/fan4/NEK_data/NEK"+nek+"/scaled_descriptors/"

    binding_df = pd.read_csv(data_path+"NEK"+nek+"_1_uM_min_50_pct_binding_with_moe_descriptors.csv") 
    print(binding_df.shape)

    print(binding_df.active.value_counts())
    num_gap = (binding_df.loc[binding_df['active']==0].shape[0]) - (binding_df.loc[binding_df['active']==1].shape[0])
    print(num_gap)
    num_minority = binding_df.loc[binding_df['active']==1].shape[0]
    print(num_minority)

    # Separate majority and minority classes
    df_majority = binding_df[binding_df['active']==0]
    df_minority = binding_df[binding_df['active']==1]

    #=======================
    # Create 5-fold splits
    #=======================
    kf = KFold(n_splits=5, shuffle=True, random_state=0)

    # majority
    for i, (_, v_ind) in enumerate(kf.split(df_majority)):
        df_majority.loc[df_majority.index[v_ind], 'fold'] = f"fold{i+1}"

    # minority
    for i, (_, v_ind) in enumerate(kf.split(df_minority)):
        df_minority.loc[df_minority.index[v_ind], 'fold'] = f"fold{i+1}"


    print(df_majority['fold'].value_counts())
    print(df_minority['fold'].value_counts())


    # Concat
    all_fold_df = pd.concat([df_majority,df_minority])
    print(all_fold_df.shape)
    print(all_fold_df.active.value_counts())


    # Save to file
    split_path = "/p/lustre2/fan4/NEK_data/NEK_data_4Berkeley/NEK"+nek
    
    if not os.path.exists(split_path):
        os.makedirs(split_path)

    all_fold_df.to_csv(split_path+"/NEK"+nek+"_1_uM_min_50_pct_binding_5fold_random_imbalanced.csv", index=False)


FileNotFoundError: [Errno 2] No such file or directory: '/p/lustre2/fan4/NEK_data/NEK2/scaled_descriptors/NEK2_1_uM_min_50_pct_binding_with_moe_descriptors.csv'

In [3]:
# inhibition
source = "inhibition"
nekAll = ["2","9"]

for nek in nekAll:
    # Get training data
    data_path = "/p/lustre2/fan4/NEK_data/NEK"+nek+"/scaled_descriptors/"

    binding_df = pd.read_csv(data_path+"NEK"+nek+"_1_uM_min_50_pct_"+source+"_with_moe_descriptors.csv") 
    print(binding_df.shape)

    print(binding_df.active.value_counts())
    num_gap = (binding_df.loc[binding_df['active']==0].shape[0]) - (binding_df.loc[binding_df['active']==1].shape[0])
    print(num_gap)
    num_minority = binding_df.loc[binding_df['active']==1].shape[0]
    print(num_minority)

    # Separate majority and minority classes
    df_majority = binding_df[binding_df['active']==0]
    df_minority = binding_df[binding_df['active']==1]

    #=======================
    # Create 5-fold splits
    #=======================
    kf = KFold(n_splits=5, shuffle=True, random_state=0)

    # majority
    for i, (_, v_ind) in enumerate(kf.split(df_majority)):
        df_majority.loc[df_majority.index[v_ind], 'fold'] = f"fold{i+1}"

    # minority
    for i, (_, v_ind) in enumerate(kf.split(df_minority)):
        df_minority.loc[df_minority.index[v_ind], 'fold'] = f"fold{i+1}"


    print(df_majority['fold'].value_counts())
    print(df_minority['fold'].value_counts())


    # Concat
    all_fold_df = pd.concat([df_majority,df_minority])
    print(all_fold_df.shape)
    print(all_fold_df.active.value_counts())


    # Save to file
    split_path = "/p/lustre2/fan4/NEK_data/NEK_data_4Berkeley/NEK"+nek
    all_fold_df.to_csv(split_path+"/NEK"+nek+"_1_uM_min_50_pct_"+source+"_5fold_random_imbalanced.csv", index=False)


(2044, 309)
active
0    1904
1     140
Name: count, dtype: int64
1764
140
fold
fold4    381
fold1    381
fold2    381
fold3    381
fold5    380
Name: count, dtype: int64
fold
fold4    28
fold3    28
fold1    28
fold5    28
fold2    28
Name: count, dtype: int64
(2044, 310)
active
0    1904
1     140
Name: count, dtype: int64


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_majority.loc[df_majority.index[v_ind], 'fold'] = f"fold{i+1}"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_minority.loc[df_minority.index[v_ind], 'fold'] = f"fold{i+1}"


(393, 309)
active
0    351
1     42
Name: count, dtype: int64
309
42
fold
fold1    71
fold4    70
fold2    70
fold3    70
fold5    70
Name: count, dtype: int64
fold
fold2    9
fold1    9
fold5    8
fold4    8
fold3    8
Name: count, dtype: int64
(393, 310)
active
0    351
1     42
Name: count, dtype: int64


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_majority.loc[df_majority.index[v_ind], 'fold'] = f"fold{i+1}"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_minority.loc[df_minority.index[v_ind], 'fold'] = f"fold{i+1}"


## 5-fold Split

In [4]:
# Binding
method = "binding"
nekAll = ["2","3","5","9"]
foldAll = ["fold1","fold2","fold3","fold4","fold5"]

for nek in nekAll:
    # Get training data
    split_path = "/p/lustre2/fan4/NEK_data/NEK_data_4Berkeley/NEK"+nek
    random_df = pd.read_csv(split_path+"/NEK"+nek+"_1_uM_min_50_pct_binding_5fold_random_imbalanced.csv")
    random_df.head

    moe_columns = random_df.columns[3:]
    moe_columns = moe_columns[:-1]
    moe_columns

    # Use the fold 0 as the test set, fold 1,2,3,4 as the train set
    for fold in foldAll:
        test_moe_df = random_df.loc[random_df['fold'] == fold]
        train_moe_df = random_df.loc[random_df['fold'] != fold]
        print(test_moe_df.shape)
        print(train_moe_df.shape)

        test_x_df = test_moe_df[moe_columns]
        test_y_df = test_moe_df['active']
        print(test_x_df)
        print(test_y_df.value_counts())

        train_x_df = train_moe_df[moe_columns]
        train_y_df = train_moe_df['active']
        print(train_x_df.shape)
        print(train_y_df.value_counts())


        # Save to file
        uq_path = "/p/lustre2/fan4/NEK_data/NEK_data_4Berkeley/NEK"+nek
        #uq_path = "/p/lustre2/fan4/myGPyTorch/classification_NEK/data/"

        train_x_df.to_csv(uq_path+"/NEK"+nek+"_binding_random_"+fold+"_trainX.csv", index=False)
        train_y_df.to_csv(uq_path+"/NEK"+nek+"_binding_random_"+fold+"_trainY.csv", index=False)
        test_x_df.to_csv(uq_path+"/NEK"+nek+"_binding_random_"+fold+"_testX.csv", index=False)
        test_y_df.to_csv(uq_path+"/NEK"+nek+"_binding_random_"+fold+"_testY.csv", index=False)





(283, 310)
(1125, 310)
      ASA+_per_atom       ASA-  ASA_H_per_atom       ASA_P  ASA_per_atom  \
1          6.511466  299.14301        8.385662  196.062190     11.950429   
4          9.173611  133.46625        9.677048  114.335650     12.685881   
5          8.296552  237.38503        9.749090  192.898510     14.825367   
8          9.887776  210.51236       13.607906   81.717201     15.877828   
9          7.022085  194.07881        7.802077  156.386930     10.452703   
...             ...        ...             ...         ...           ...   
1385       6.990230  193.16020        8.167764  135.191250     10.348269   
1386       9.424202  155.31209       13.011127   73.114189     15.935694   
1392       9.204168  170.90587       10.350260  128.671800     12.923696   
1406       7.843355  171.71243        8.605889  130.535580     11.023215   
1407       8.132017  268.25793        9.818810  181.742920     13.184420   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
1   

(281, 310)
(1127, 310)
      ASA+_per_atom        ASA-  ASA_H_per_atom      ASA_P  ASA_per_atom  \
0         14.897949   53.326256        8.903966  161.21796     17.860519   
3          6.791036  344.713500        8.136228  261.31158     12.350931   
7          9.629118  190.869540       11.334554  127.85452     14.530917   
22         8.697605  166.925110        8.602409  171.78012     11.970647   
26         8.426145  166.821290        9.244890  121.79031     11.459260   
...             ...         ...             ...        ...           ...   
1376       8.806886  163.197460        9.159060  146.61690     11.824821   
1389       8.246741  241.523150       10.113879  152.88306     13.171541   
1393       8.426281  172.753750        7.767350  212.94861     11.258310   
1394       8.104833  161.416690        7.435220  190.21004     11.858710   
1396       9.440269  132.677580        8.752351  178.08022     11.450536   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
0   

(281, 310)
(1123, 310)
      ASA+_per_atom       ASA-  ASA_H_per_atom      ASA_P  ASA_per_atom  \
2          9.101535  110.52470        7.935420  185.15604     10.828484   
10         6.007288  174.51620        9.371731  160.40347     12.398212   
15         7.162030  314.78915        9.065344  210.13132     12.567533   
16         8.974845  158.77829        9.987769  119.27426     13.046083   
19         8.747375  266.50034       10.316172  189.91547     14.114481   
...             ...        ...             ...        ...           ...   
1389       8.018614  208.74669       10.170798  119.24756     12.763136   
1392       6.055027  197.22391        8.698582  180.47252     11.276761   
1395       9.515752  166.51373       10.709387  111.35814     12.771575   
1399       6.836198  264.85251       10.641759  143.07455     15.112839   
1403       9.397039  184.02939        8.185000  257.96371     12.413913   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
2       -2.76337

(280, 310)
(1124, 310)
      ASA+_per_atom       ASA-  ASA_H_per_atom       ASA_P  ASA_per_atom  \
11         8.917764  153.35754        8.920795  157.043400     12.572967   
23         8.739346  194.10822       10.344430  106.722490     12.250189   
24         8.615438  159.62811        6.797009  240.598010     12.392311   
25         7.796604  186.91588        8.385761  149.798980     10.763523   
28         8.697266  216.59297       11.393547  100.652900     13.734311   
...             ...        ...             ...         ...           ...   
1387       7.199203  186.93825        7.093256  193.189190     10.367649   
1388       8.156430  237.23343       10.997156  123.878350     14.094115   
1390       7.852167  172.38380        8.631871  130.279750     11.044460   
1393       9.718043  193.89278       12.875704   88.634361     15.408114   
1400       6.507985  213.30025        9.257843  134.728740     11.396395   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
11  

(247, 310)
(990, 310)
      ASA+_per_atom       ASA-  ASA_H_per_atom       ASA_P  ASA_per_atom  \
0          7.389352  231.03418        9.362611  144.210750     12.640128   
10         9.876893  155.97939        8.081039  219.703520     14.946774   
12         8.459283  232.54727       11.844980   93.733719     14.131168   
17         8.738267  246.53464       11.181508  153.691470     15.226021   
20         7.977551  217.33803       10.854941   90.579193     12.782158   
...             ...        ...             ...         ...           ...   
1182       8.969400  188.22676        8.778990  200.032230     12.005316   
1190       6.463359  216.07686        7.572915  133.969790      9.383317   
1199       8.268882  212.95699        9.480789  139.585820     11.696436   
1201       7.494627  183.99098        7.828725  166.952000     11.102294   
1228      10.239554  175.10945       11.455413  118.781010     13.831033   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
0    

(283, 310)
(1126, 310)
      ASA+_per_atom       ASA-  ASA_H_per_atom      ASA_P  ASA_per_atom  \
1          8.863331  200.55731        8.744996  211.27695     13.146599   
5          5.710754  308.49966        8.336615  103.68244      9.665877   
8          7.641164  234.83014       10.412645  114.83887     13.022620   
9          7.728748  250.53065        8.216459  234.36145     13.424491   
14         7.777926  360.22958        9.405464  269.08743     14.210598   
...             ...        ...             ...        ...           ...   
1382       8.375013  138.32080        8.325641  141.62868     10.439502   
1383       7.279829  393.08694        9.612618  271.78192     14.839193   
1388       7.053889  166.09090        6.883618  178.69096      9.298361   
1407       8.538722  185.80338        8.776142  171.08333     11.535550   
1408       9.692602  179.73247        9.938689  168.41252     13.599830   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
1       -2.57943

(281, 310)
(1128, 310)
      ASA+_per_atom       ASA-  ASA_H_per_atom       ASA_P  ASA_per_atom  \
0          7.883086  153.95410        8.725966  113.495890     11.090464   
3          8.411578  222.33913       12.737668   66.599869     14.587665   
4          8.287487  183.06439        7.755779  213.420750     13.228106   
7          6.692531  237.76653        8.087123  133.172120      9.862751   
22         6.452667  275.56476        7.350988  238.733610     13.173760   
...             ...        ...             ...         ...           ...   
1386       5.955156  299.83179        7.854027  193.495000     11.309295   
1393       7.732639  211.90147        9.358958  119.201280     11.450209   
1394       6.764417  205.41473        7.854536  124.745920      9.540291   
1396       8.313942  281.27042        9.581525  228.031910     15.010857   
1402       5.512654  330.02377        7.531693  188.691060     10.227279   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
0   

In [5]:
# inhibition
method = "inhibition"
nekAll = ["2","9"]
foldAll = ["fold1","fold2","fold3","fold4","fold5"]

for nek in nekAll:
    # Get training data
    split_path = "/p/lustre2/fan4/NEK_data/NEK_data_4Berkeley/NEK"+nek
    random_df = pd.read_csv(split_path+"/NEK"+nek+"_1_uM_min_50_pct_"+method+"_5fold_random_imbalanced.csv")
    random_df.head

    moe_columns = random_df.columns[3:]
    moe_columns = moe_columns[:-1]
    moe_columns

    # Use the fold 0 as the test set, fold 1,2,3,4 as the train set
    for fold in foldAll:
        test_moe_df = random_df.loc[random_df['fold'] == fold]
        train_moe_df = random_df.loc[random_df['fold'] != fold]
        print(test_moe_df.shape)
        print(train_moe_df.shape)

        test_x_df = test_moe_df[moe_columns]
        test_y_df = test_moe_df['active']
        print(test_x_df)
        print(test_y_df.value_counts())

        train_x_df = train_moe_df[moe_columns]
        train_y_df = train_moe_df['active']
        print(train_x_df.shape)
        print(train_y_df.value_counts())


        # Save to file
        uq_path = "/p/lustre2/fan4/NEK_data/NEK_data_4Berkeley/NEK"+nek
        #uq_path = "/p/lustre2/fan4/myGPyTorch/classification_NEK/data/"

        train_x_df.to_csv(uq_path+"/NEK"+nek+"_"+method+"_random_"+fold+"_trainX.csv", index=False)
        train_y_df.to_csv(uq_path+"/NEK"+nek+"_"+method+"_random_"+fold+"_trainY.csv", index=False)
        test_x_df.to_csv(uq_path+"/NEK"+nek+"_"+method+"_random_"+fold+"_testX.csv", index=False)
        test_y_df.to_csv(uq_path+"/NEK"+nek+"_"+method+"_random_"+fold+"_testY.csv", index=False)





(409, 310)
(1635, 310)
      ASA+_per_atom       ASA-  ASA_H_per_atom      ASA_P  ASA_per_atom  \
1          7.934616  158.87631        7.420585  189.45982     10.865310   
4          7.463830  237.68829       10.271510  108.22626     12.574197   
5          8.182241  281.36035       11.570551  134.44797     14.558284   
6          9.252768  169.27931       10.768620  118.69076     13.735889   
9          8.561838  226.01562       10.099210  161.08264     13.601007   
...             ...        ...             ...        ...           ...   
2015       7.525463  138.65553        7.565813  140.20396      9.340547   
2026       8.021498  323.03540        8.850160  277.65530     13.477749   
2027       7.554388  255.61058        7.652859  252.85341     16.683338   
2035       7.552055  179.63304        8.890672  104.67045     10.759787   
2041       9.273098  136.72115        9.210555  140.78642     11.376500   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
1       -2.74318

(409, 310)
(1635, 310)
      ASA+_per_atom       ASA-  ASA_H_per_atom       ASA_P  ASA_per_atom  \
0         10.828989  121.34690       10.714393  124.440990     15.323319   
3          6.847276  246.11134        8.917394  130.184710     11.242121   
7          9.007080  211.36525       12.433657  123.350210     17.002184   
12         9.210639  220.20757       14.834701   62.733829     17.075195   
20         9.840159  204.23480       12.099922  131.922410     16.222498   
...             ...        ...             ...         ...           ...   
2012       7.195989  211.14612        8.421276  144.613800     11.050618   
2023       7.600487  132.85730        7.883537  116.157370      9.852305   
2029       7.663664  272.46390       10.050291  189.469210     14.671492   
2032      10.058188  120.01466        9.615865  134.611300     13.694995   
2037       8.319601  180.82919        8.728214  159.029740     11.101793   

      BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
0   

     ASA+_per_atom       ASA-  ASA_H_per_atom       ASA_P  ASA_per_atom  \
1         8.584988  255.29848        9.869930  183.341670     13.143889   
5         8.867048  181.31946        9.370875  167.430450     13.010667   
8         9.482415  214.35144       13.958587   71.113968     16.180899   
12        7.607816  304.18521       10.844918  143.639310     13.717705   
17        7.102113  410.88959        8.947090  286.857850     13.045060   
..             ...        ...             ...         ...           ...   
373       9.660049  156.93013       10.041747  140.898820     13.396480   
380       7.282692  275.65182        7.766187  242.290650     11.277645   
382       8.039699  155.56841        7.579986  185.193070     10.566971   
390       9.148384  172.83737        9.838884  130.716900     11.981783   
392       6.766839  194.49434        6.565108  211.939940      9.825722   

     BCUT_PEOE_0  BCUT_PEOE_1  BCUT_PEOE_2  BCUT_PEOE_3  \
1      -2.429073    -0.628973     0.6599