In [5]:
import os, sys

# True when the google.colab package has already been loaded into this interpreter.
IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
    print("Not running in Colab.")
else:
    print("Running in Colab!")
    # Mount Google Drive at /content/drive (an existing mount is reused, not remounted).
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    from google.colab import userdata

def resolve_path_gdrive(relativePath):
    """
    Prefix a project-relative path with the project location for this environment.

    When '/content/drive' exists, the project folder on the mounted Google
    Drive is used as the prefix. Otherwise the prefix is whatever
    utils.get_project_root() returns, followed by "/".

    Args:
        relativePath: Path relative to the project root (e.g. 'descriptors.csv').

    Returns:
        The prefix concatenated with relativePath.
    """
    if not os.path.exists('/content/drive'):
        # `utils` is only imported on this branch.
        from utils import get_project_root
        project_root = get_project_root()
        return project_root + "/" + relativePath

    drive_project_dir = '/content/drive/MyDrive/work/gdrive-workspaces/git/nn_catalyst/'
    return drive_project_dir + relativePath


Not running in Colab.


In [3]:
import pandas as pd
def get_nan_columns(df, threshold):
    """
    Gets the columns whose fraction of NaN values is strictly above a threshold.

    Args:
        df: The DataFrame to check.
        threshold: The NaN fraction, between 0 and 1 (e.g., 0.2 for 20%).

    Returns:
        A list of column names whose NaN fraction exceeds the threshold, in
        the DataFrame's column order. Empty when the DataFrame has no rows.
    """
    n_rows = len(df)
    if n_rows == 0:
        # Without rows there is no NaN fraction to measure (and nothing to divide by).
        return []

    # Compare fractions directly. Scaling both sides by 100 adds rounding error
    # (0.29 * 100 evaluates to slightly less than 29), which lets a column that
    # sits exactly on the threshold slip through the strict comparison.
    nan_fraction = df.isna().sum() / n_rows
    return nan_fraction[nan_fraction > threshold].index.tolist()

In [4]:
# Load the data
# Both file names are relative to the project location chosen by resolve_path_gdrive().
descriptors_path = 'descriptors.csv'
targets_path = 'compiled_data.csv'
# Every descriptor column is read as a string (dtype=str).
descriptors_df = pd.read_csv(resolve_path_gdrive(descriptors_path), dtype=str)
# Every target column is parsed as float (dtype=float).
targets_df = pd.read_csv(resolve_path_gdrive(targets_path), dtype=float)
# Show sample rows
print("\nSample Rows from Descriptors DataFrame:")
print(descriptors_df.head())
print("\nSample Rows from Targets DataFrame:")
print(targets_df.head())


Sample Rows from Descriptors DataFrame:
   Label                 ABC              ABCGG nAcid nBase  \
0   9268   4.719396554912958  5.004087722255558     0     0   
1  10488  10.334062109951281  9.836417242300065     0     0   
2  25579   5.875633848974738  5.566041006395633     0     0   
3   8952  6.6112502008627025  6.890735261322646     1     0   
4  23681   7.249407296827953  6.976305832589716     0     0   

              SpAbs_A             SpMax_A           SpDiam_A  \
0   6.720566232730447  2.1010029896154583  4.202005979230917   
1  16.752497538971177  2.3623398328574394  4.724679665714879   
2    9.43114762028933  2.1753277471610764  4.350655494322151   
3   10.68725972618713    2.28774942353935  4.425414875225794   
4  11.945821561028193  2.2671838628844996     4.534367725769   

               SpAD_A             SpMAD_A  ...              SRW10  \
0   6.720566232730447  0.9600808903900638  ...   8.12355783506165   
1  16.752497538971177  1.1966069670693698  ...   9.472627

In [43]:
# Convert every column to numbers; values that cannot be parsed become NaN.
descriptors_numeric = descriptors_df.apply(pd.to_numeric, errors='coerce')

# Drop the columns that are NaN in more than 25% of the rows.
columns_with_over_X_percent_nan = get_nan_columns(descriptors_numeric, 0.25)
print(f"Dropping {len(columns_with_over_X_percent_nan)} columns with more than 25% NaN values")
descriptors_numeric = descriptors_numeric.drop(columns=columns_with_over_X_percent_nan)

# Fill the remaining gaps with each column's mean. 'Label' is the key used to
# merge with the targets, so it is left out of the imputation: an averaged
# label would not identify any real row.
column_means = descriptors_numeric.mean().drop('Label', errors='ignore')
descriptors_df = descriptors_numeric.fillna(column_means)

In [47]:
# Display the descriptor table (pandas shows a truncated preview for large frames).
descriptors_df

Unnamed: 0,Label,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,...,SRW10,TSRW10,MW,AMW,WPath,WPol,Zagreb1,Zagreb2,mZagreb1,mZagreb2
0,9268,4.719397,5.004088,0,0,6.720566,2.101003,4.202006,6.720566,0.960081,...,8.123558,33.343946,136.047505,6.802375,46,4,28.0,26.0,4.562500,1.625000
1,10488,10.334062,9.836417,0,0,16.752498,2.362340,4.724680,16.752498,1.196607,...,9.472628,45.579501,215.042359,8.960098,296,20,68.0,77.0,6.645833,3.152778
2,25579,5.875634,5.566041,0,0,9.431148,2.175328,4.350655,9.431148,1.178893,...,8.479907,35.755147,111.048427,7.932031,61,7,36.0,38.0,3.222222,1.833333
3,8952,6.611250,6.890735,1,0,10.687260,2.287749,4.425415,10.687260,1.187473,...,8.735364,51.247665,143.004099,10.214579,84,9,42.0,47.0,4.083333,2.055556
4,23681,7.249407,6.976306,0,0,11.945822,2.267184,4.534368,11.945822,1.194582,...,8.912069,39.310842,219.935762,13.745985,116,12,46.0,51.0,4.333333,2.361111
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28726,16555,18.930123,15.134694,0,1,20.727298,2.363418,4.675169,20.727298,1.232732,...,10.155219,60.792100,379.155036,7.291443,2500001534,42,128.0,152.0,6.389452,5.861111
28727,6241,11.838245,10.433687,0,0,19.148384,2.324965,4.649929,19.148384,1.196774,...,9.415971,47.809076,262.006658,9.703950,503,20,76.0,83.0,7.145833,3.583333
28728,34083,12.825381,11.178801,1,0,22.014979,2.385270,4.770540,22.014979,1.294999,...,9.622251,49.493130,237.100108,7.409378,504,25,84.0,97.0,5.805556,3.916667
28729,1475,13.044161,11.771267,1,0,21.465002,2.367989,4.667357,21.465002,1.262647,...,9.580316,63.381945,229.110279,7.159696,531,23,86.0,99.0,6.416667,3.750000


In [62]:
# Merge the numeric dataframes on the common label column
# pd.merge defaults to an inner join, so rows whose 'Label' has no matching
# 'mol_num' (and vice versa) are dropped from the result.
numeric_data = pd.merge(descriptors_df, targets_df, left_on='Label', right_on='mol_num')
# Remove both key columns once the rows have been matched.
numeric_data = numeric_data.drop(columns=['Label', 'mol_num'])

In [63]:
# Display the merged table: descriptor columns first, target columns after them.
numeric_data

Unnamed: 0,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,LogEE_A,...,homo_spin_down_r,lumo_spin_down_r,max_charge_pos_r,max_charge_neg_r,max_spin_r,dipole_r,gibbs_r,elec_en_r,ddg_ox,ddg_red
0,4.719397,5.004088,0,0,6.720566,2.101003,4.202006,6.720566,0.960081,2.779033,...,-0.020575,0.084241,0.776662,-0.786469,0.819695,7.82223,-947.774409,-947.901675,8.772307,0.872504
1,10.334062,9.836417,0,0,16.752498,2.362340,4.724680,16.752498,1.196607,3.534071,...,-0.057754,0.092397,0.781351,-0.707508,0.393427,4.22675,-1051.047766,-1051.177809,8.096243,-0.229233
2,5.875634,5.566041,0,0,9.431148,2.175328,4.350655,9.431148,1.178893,2.979741,...,-0.016276,0.150350,0.177492,-0.609541,0.390859,1.28048,-386.580976,-386.648066,8.820537,0.584135
3,6.611250,6.890735,1,0,10.687260,2.287749,4.425415,10.687260,1.187473,3.103710,...,-0.030881,0.105986,0.391843,-0.557651,0.513850,3.95845,-796.662432,-796.720716,8.867437,-0.211576
4,7.249407,6.976306,0,0,11.945822,2.267184,4.534368,11.945822,1.194582,3.197666,...,-0.032088,0.062253,0.488503,-0.677500,0.687528,8.14008,-3343.196598,-3343.263585,7.578382,-0.510448
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
26228,7.985024,7.826624,0,0,13.404811,2.333527,4.667054,13.404811,1.218619,3.295415,...,0.056183,0.096731,0.383872,-0.732787,0.351100,8.40102,-515.072039,-515.187291,6.923090,-0.779938
26229,14.567764,13.061164,0,0,23.187437,2.369725,4.720068,23.187437,1.220391,3.859935,...,-0.038718,0.104386,0.326480,-0.593591,0.675688,9.12559,-866.368550,-866.702386,8.019006,0.474847
26230,11.838245,10.433687,0,0,19.148384,2.324965,4.649929,19.148384,1.196774,3.658465,...,-0.053150,0.006134,0.681842,-0.652126,0.326542,18.39088,-1546.460302,-1546.609496,8.658371,-2.021268
26231,12.825381,11.178801,1,0,22.014979,2.385270,4.770540,22.014979,1.294999,3.747192,...,-0.038013,0.100376,0.384700,-0.588087,0.272390,8.78065,-821.360312,-821.572805,6.655491,-0.071545


In [73]:
# Write to the same resolved location that the loading cells read from
# ('src/pl/merged_data_last29.csv'); a bare file name would land in the current
# working directory, which is not necessarily that folder.
numeric_data.to_csv(resolve_path_gdrive('src/pl/merged_data_last29.csv'), index=False)
print(numeric_data.shape)

In [5]:
import numpy as np
# Reload the merged CSV as a plain NumPy array; skiprows=1 skips the header line.
xy = np.loadtxt(resolve_path_gdrive('src/pl/merged_data_last29.csv'), delimiter=',', skiprows=1)

In [6]:
# (rows, columns) of the reloaded array.
xy.shape

(26233, 1508)

In [76]:
# Display the reloaded array (NumPy abbreviates large arrays with "...").
xy

array([[ 4.71939655e+00,  5.00408772e+00,  0.00000000e+00, ...,
        -9.47901675e+02,  8.77230721e+00,  8.72504388e-01],
       [ 1.03340621e+01,  9.83641724e+00,  0.00000000e+00, ...,
        -1.05117781e+03,  8.09624339e+00, -2.29233083e-01],
       [ 5.87563385e+00,  5.56604101e+00,  0.00000000e+00, ...,
        -3.86648066e+02,  8.82053653e+00,  5.84135000e-01],
       ...,
       [ 1.18382446e+01,  1.04336871e+01,  0.00000000e+00, ...,
        -1.54660950e+03,  8.65837095e+00, -2.02126764e+00],
       [ 1.28253813e+01,  1.11788009e+01,  1.00000000e+00, ...,
        -8.21572805e+02,  6.65549141e+00, -7.15453380e-02],
       [ 1.30441609e+01,  1.17712666e+01,  1.00000000e+00, ...,
        -7.47409158e+02,  6.94039385e+00, -3.70383343e-01]])

In [78]:
# Every column except the last 29, i.e. the descriptor part of the table.
xy[:,:-29]

array([[ 4.71939655,  5.00408772,  0.        , ..., 26.        ,
         4.5625    ,  1.625     ],
       [10.33406211,  9.83641724,  0.        , ..., 77.        ,
         6.64583333,  3.15277778],
       [ 5.87563385,  5.56604101,  0.        , ..., 38.        ,
         3.22222222,  1.83333333],
       ...,
       [11.83824461, 10.43368714,  0.        , ..., 83.        ,
         7.14583333,  3.58333333],
       [12.82538132, 11.1788009 ,  1.        , ..., 97.        ,
         5.80555556,  3.91666667],
       [13.04416092, 11.77126656,  1.        , ..., 99.        ,
         6.41666667,  3.75      ]])

In [79]:
# The last 29 columns, i.e. the target part of the table.
xy[:,-29:]

array([[-2.48949000e-01,  1.44670000e-02,  9.84272000e-01, ...,
        -9.47901675e+02,  8.77230721e+00,  8.72504388e-01],
       [-2.24233000e-01, -7.42270000e-02,  8.87076000e-01, ...,
        -1.05117781e+03,  8.09624339e+00, -2.29233083e-01],
       [-2.30413000e-01, -6.20650000e-02,  2.42509000e-01, ...,
        -3.86648066e+02,  8.82053653e+00,  5.84135000e-01],
       ...,
       [-2.45138000e-01, -1.16174000e-01,  7.25381000e-01, ...,
        -1.54660950e+03,  8.65837095e+00, -2.02126764e+00],
       [-1.82995000e-01, -7.28580000e-02,  4.35501000e-01, ...,
        -8.21572805e+02,  6.65549141e+00, -7.15453380e-02],
       [-1.70878000e-01, -8.44480000e-02,  4.37395000e-01, ...,
        -7.47409158e+02,  6.94039385e+00, -3.70383343e-01]])

## Create a new merged dataset with reordered target columns

In [8]:
# Import the dataset from the csv file.
# read_csv uses the first line as column names, so the columns can be addressed by name.
df_orig = pd.read_csv(resolve_path_gdrive('src/pl/merged_data_last29.csv'))

In [39]:
# `df` is a second name for the same DataFrame object as `df_orig`, not a copy.
df = df_orig
print(df_orig.shape)
# Split by position: the last 29 columns go to y, everything before them to X.
X = df[df.columns[:-29]]
y = df[df.columns[-29:]]
print(X.shape)
print(y.shape)

(26233, 1508)
(26233, 1479)
(26233, 29)


In [30]:
# List the column names of y in their current order.
print(y.columns)

Index(['homo_n', 'lumo_n', 'max_charge_pos_n', 'max_charge_neg_n', 'dipole_n',
       'gibbs_n', 'elec_en_n', 'homo_spin_up_o', 'lumo_spin_up_o',
       'homo_spin_down_o', 'lumo_spin_down_o', 'max_charge_pos_o',
       'max_charge_neg_o', 'max_spin_o', 'dipole_o', 'gibbs_o', 'elec_en_o',
       'homo_spin_up_r', 'lumo_spin_up_r', 'homo_spin_down_r',
       'lumo_spin_down_r', 'max_charge_pos_r', 'max_charge_neg_r',
       'max_spin_r', 'dipole_r', 'gibbs_r', 'elec_en_r', 'ddg_ox', 'ddg_red'],
      dtype='object')


In [43]:
# Desired order of the target columns, written as 1-based positions within y.
new_order = [27, 7, 26, 17, 6, 16, 10, 8, 28, 2, 9, 3, 12, 29, 11, 1, 19, 22, 21, 4, 13, 18, 20, 23, 24, 14, 25, 15, 5]
# Convert to the 0-based positions that iloc expects.
new_order = [x - 1 for x in new_order]
# The list is typed by hand: a repeated or missing position would silently
# duplicate or drop a target column, so require a full permutation.
assert sorted(new_order) == list(range(y.shape[1])), \
    "new_order must list every target column position exactly once"
y_new = y.iloc[:, new_order]

In [44]:
# The shape must be unchanged by the reordering; only the column order differs.
print(y_new.shape)
y_new.columns

(26233, 29)


Index(['elec_en_r', 'elec_en_n', 'gibbs_r', 'elec_en_o', 'gibbs_n', 'gibbs_o',
       'homo_spin_down_o', 'homo_spin_up_o', 'ddg_ox', 'lumo_n',
       'lumo_spin_up_o', 'max_charge_pos_n', 'max_charge_pos_o', 'ddg_red',
       'lumo_spin_down_o', 'homo_n', 'lumo_spin_up_r', 'max_charge_pos_r',
       'lumo_spin_down_r', 'max_charge_neg_n', 'max_charge_neg_o',
       'homo_spin_up_r', 'homo_spin_down_r', 'max_charge_neg_r', 'max_spin_r',
       'max_spin_o', 'dipole_r', 'dipole_o', 'dipole_n'],
      dtype='object')

In [53]:
# Put the reordered targets back to the right of the unchanged feature columns.
# axis=1 concatenates column-wise, aligning rows on the index.
df_merged = pd.concat([X, y_new], axis=1)
print(df_merged.shape)
df_merged

(26233, 1508)


Unnamed: 0,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,LogEE_A,...,max_charge_neg_n,max_charge_neg_o,homo_spin_up_r,homo_spin_down_r,max_charge_neg_r,max_spin_r,max_spin_o,dipole_r,dipole_o,dipole_n
0,4.719397,5.004088,0,0,6.720566,2.101003,4.202006,6.720566,0.960081,2.779033,...,-0.795555,-0.757062,0.041076,-0.020575,-0.786469,0.819695,0.448475,7.82223,3.77466,2.38501
1,10.334062,9.836417,0,0,16.752498,2.362340,4.724680,16.752498,1.196607,3.534071,...,-0.677296,-0.665821,0.068416,-0.057754,-0.707508,0.393427,0.523409,4.22675,5.61772,5.46422
2,5.875634,5.566041,0,0,9.431148,2.175328,4.350655,9.431148,1.178893,2.979741,...,-0.595579,-0.551924,0.121335,-0.016276,-0.609541,0.390859,0.367176,1.28048,4.87947,3.23992
3,6.611250,6.890735,1,0,10.687260,2.287749,4.425415,10.687260,1.187473,3.103710,...,-0.503715,-0.439583,0.079421,-0.030881,-0.557651,0.513850,0.309175,3.95845,4.16712,4.78338
4,7.249407,6.976306,0,0,11.945822,2.267184,4.534368,11.945822,1.194582,3.197666,...,-0.557242,-0.537023,0.011237,-0.032088,-0.677500,0.687528,0.435499,8.14008,6.22044,1.46339
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
26228,7.985024,7.826624,0,0,13.404811,2.333527,4.667054,13.404811,1.218619,3.295415,...,-0.720828,-0.695303,0.064812,0.056183,-0.732787,0.351100,0.522121,8.40102,2.10622,5.30719
26229,14.567764,13.061164,0,0,23.187437,2.369725,4.720068,23.187437,1.220391,3.859935,...,-0.602053,-0.590933,0.075004,-0.038718,-0.593591,0.675688,0.371838,9.12559,2.28224,6.13081
26230,11.838245,10.433687,0,0,19.148384,2.324965,4.649929,19.148384,1.196774,3.658465,...,-0.657408,-0.667147,-0.024233,-0.053150,-0.652126,0.326542,0.280797,18.39088,9.55022,5.79727
26231,12.825381,11.178801,1,0,22.014979,2.385270,4.770540,22.014979,1.294999,3.747192,...,-0.523937,-0.520914,0.078543,-0.038013,-0.588087,0.272390,0.379390,8.78065,4.76712,2.50315


In [54]:
# Write to the same resolved location that the verification cell reads from
# ('src/pl/merged_data_last29_reordered_byR2.csv'); a bare file name would land
# in the current working directory, which is not necessarily that folder.
df_merged.to_csv(resolve_path_gdrive('src/pl/merged_data_last29_reordered_byR2.csv'), index=False)

In [6]:
import pandas as pd
# Read only the first 10 rows of the original and of the reordered file.
df_last29 = pd.read_csv(resolve_path_gdrive('src/pl/merged_data_last29.csv'), nrows=10)
df_last29_reordered = pd.read_csv(resolve_path_gdrive('src/pl/merged_data_last29_reordered_byR2.csv'), nrows=10)

In [16]:
# Last 13 columns of the original file's sample rows.
df_last29.iloc[:, -13:]

Unnamed: 0,elec_en_o,homo_spin_up_r,lumo_spin_up_r,homo_spin_down_r,lumo_spin_down_r,max_charge_pos_r,max_charge_neg_r,max_spin_r,dipole_r,gibbs_r,elec_en_r,ddg_ox,ddg_red
0,-947.612948,0.041076,0.135704,-0.020575,0.084241,0.776662,-0.786469,0.819695,7.82223,-947.774409,-947.901675,8.772307,0.872504
1,-1050.874821,0.068416,0.097778,-0.057754,0.092397,0.781351,-0.707508,0.393427,4.22675,-1051.047766,-1051.177809,8.096243,-0.229233
2,-386.35072,0.121335,0.157118,-0.016276,0.15035,0.177492,-0.609541,0.390859,1.28048,-386.580976,-386.648066,8.820537,0.584135
3,-796.389167,0.079421,0.116174,-0.030881,0.105986,0.391843,-0.557651,0.51385,3.95845,-796.662432,-796.720716,8.867437,-0.211576
4,-3342.971231,0.011237,0.076315,-0.032088,0.062253,0.488503,-0.6775,0.687528,8.14008,-3343.196598,-3343.263585,7.578382,-0.510448
5,-2027.99644,0.035542,0.063488,-0.035806,0.056888,0.334323,-0.446168,0.420106,4.65223,-2028.139831,-2028.282185,7.10473,-0.784362
6,-3511.861904,-0.004721,0.071423,-0.045213,0.049832,0.432329,-0.658101,0.739521,5.92762,-3512.09548,-3512.189311,8.215709,-0.812342
7,-747.171555,0.075703,0.086869,-0.034518,0.086775,0.528934,-0.698429,0.289324,1.24896,-747.304552,-747.461104,7.886484,-0.15435
8,-543.818415,0.124699,0.126322,-0.04452,0.094732,0.115681,-0.412319,0.210536,4.90508,-543.666264,-543.976337,6.36543,1.84119
9,-1177.527132,0.062698,0.074428,-0.079391,0.072709,0.55753,-0.803741,0.241492,1.15805,-1177.482948,-1177.792814,7.311857,-0.055645


In [11]:
# Last 29 columns of the reordered file's sample rows.
df_last29_reordered.iloc[:, -29:]

Unnamed: 0,elec_en_r,elec_en_n,gibbs_r,elec_en_o,gibbs_n,gibbs_o,homo_spin_down_o,homo_spin_up_o,ddg_ox,lumo_n,...,max_charge_neg_n,max_charge_neg_o,homo_spin_up_r,homo_spin_down_r,max_charge_neg_r,max_spin_r,max_spin_o,dipole_r,dipole_o,dipole_n
0,-947.901675,-947.939146,-947.774409,-947.612948,-947.806473,-947.484092,-0.44008,-0.427002,8.772307,0.014467,...,-0.795555,-0.757062,0.041076,-0.020575,-0.786469,0.819695,0.448475,7.82223,3.77466,2.38501
1,-1051.177809,-1051.175838,-1051.047766,-1050.874821,-1051.039341,-1050.741806,-0.398813,-0.401269,8.096243,-0.074227,...,-0.677296,-0.665821,0.068416,-0.057754,-0.707508,0.393427,0.523409,4.22675,5.61772,5.46422
2,-386.648066,-386.678452,-386.580976,-386.35072,-386.602443,-386.27829,-0.446113,-0.447334,8.820537,-0.062065,...,-0.595579,-0.551924,0.121335,-0.016276,-0.609541,0.390859,0.367176,1.28048,4.87947,3.23992
3,-796.720716,-796.717606,-796.662432,-796.389167,-796.654657,-796.32878,-0.438997,-0.446121,8.867437,-0.086854,...,-0.503715,-0.439583,0.079421,-0.030881,-0.557651,0.51385,0.309175,3.95845,4.16712,4.78338
4,-3343.263585,-3343.250411,-3343.196598,-3342.971231,-3343.177839,-3342.899335,-0.412475,-0.380143,7.578382,-0.064958,...,-0.557242,-0.537023,0.011237,-0.032088,-0.6775,0.687528,0.435499,8.14008,6.22044,1.46339
5,-2028.282185,-2028.257802,-2028.139831,-2027.99644,-2028.111006,-2027.849909,-0.353538,-0.346968,7.10473,-0.081833,...,-0.466706,-0.40674,0.035542,-0.035806,-0.446168,0.420106,0.603036,4.65223,6.01536,4.66541
6,-3512.189311,-3512.165141,-3512.09548,-3511.861904,-3512.065626,-3511.7637,-0.38813,-0.391367,8.215709,-0.069075,...,-0.594036,-0.604785,-0.004721,-0.045213,-0.658101,0.739521,0.215973,5.92762,3.85286,2.71244
7,-747.461104,-747.46344,-747.304552,-747.171555,-747.29888,-747.009053,-0.396167,-0.392643,7.886484,-0.068555,...,-0.680876,-0.665433,0.075703,-0.034518,-0.698429,0.289324,0.313012,1.24896,1.62555,0.65319
8,-543.976337,-544.053511,-543.666264,-543.818415,-543.733928,-543.499999,-0.314721,-0.330876,6.36543,0.019323,...,-0.641368,-0.568822,0.124699,-0.04452,-0.412319,0.210536,0.566077,4.90508,2.19214,1.47716
9,-1177.792814,-1177.798239,-1177.482948,-1177.527132,-1177.480903,-1177.212194,-0.338987,-0.344523,7.311857,-0.053451,...,-0.821138,-0.806507,0.062698,-0.079391,-0.803741,0.241492,0.147,1.15805,1.28063,1.60491
