In [1]:
import wandb, os, sys

# Detect a Google Colab runtime by checking whether its package is already loaded.
IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
    print("Not running in Colab.")
else:
    print("Running in Colab!")
    # Mount Google Drive under /content/drive (no forced remount).
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    # Imported here for the rest of the notebook; not used within this cell.
    from google.colab import userdata

def resolve_path_gdrive(relativePath):
    """Turn a project-relative path into a full path string.

    If the Colab Drive mount point ``/content/drive`` exists, the path is
    resolved inside the project's folder on Google Drive. Otherwise it is
    appended to the value returned by ``utils.get_project_root()``.

    Args:
        relativePath: Path relative to the project root, as a string.

    Returns:
        The prefix for the active environment concatenated with ``relativePath``.
    """
    drive_mounted = os.path.exists('/content/drive')
    if not drive_mounted:
        # Project-local helper, imported only on the non-Drive branch.
        from utils import get_project_root
        return get_project_root() + "/" + relativePath
    return '/content/drive/MyDrive/work/gdrive-workspaces/git/nn_catalyst/' + relativePath


Not running in Colab.


In [None]:
import pandas as pd
def get_nan_columns(df, threshold):
    """
    Gets the columns whose fraction of NaN values exceeds a threshold.

    Args:
        df: The DataFrame to check.
        threshold: The NaN fraction a column must strictly exceed to be
            returned, expressed between 0 and 1 (e.g., 0.2 for 20%).

    Returns:
        A list of column names, in DataFrame column order, whose NaN
        fraction is strictly greater than ``threshold``. A DataFrame with
        no rows yields an empty list.
    """
    # Compare fractions directly rather than scaling both sides by 100.
    # The scaling rounds differently on each side (0.29 * 100 evaluates to
    # 28.999999999999996), which let columns sitting exactly on the
    # threshold slip through the strict comparison.
    nan_fraction = df.isnull().sum() / len(df)
    return nan_fraction[nan_fraction > threshold].index.tolist()

In [4]:
# Read the descriptor and target tables from the resolved project location
descriptors_path = 'descriptors.csv'
targets_path = 'compiled_data.csv'
# Descriptors are loaded as raw strings; targets are parsed as floats on read.
descriptors_df = pd.read_csv(resolve_path_gdrive(descriptors_path), dtype=str)
targets_df = pd.read_csv(resolve_path_gdrive(targets_path), dtype=float)
# Preview the first rows of each table
print("\nSample Rows from Descriptors DataFrame:", descriptors_df.head(), sep="\n")
print("\nSample Rows from Targets DataFrame:", targets_df.head(), sep="\n")


Sample Rows from Descriptors DataFrame:
   Label                 ABC              ABCGG nAcid nBase  \
0   9268   4.719396554912958  5.004087722255558     0     0   
1  10488  10.334062109951281  9.836417242300065     0     0   
2  25579   5.875633848974738  5.566041006395633     0     0   
3   8952  6.6112502008627025  6.890735261322646     1     0   
4  23681   7.249407296827953  6.976305832589716     0     0   

              SpAbs_A             SpMax_A           SpDiam_A  \
0   6.720566232730447  2.1010029896154583  4.202005979230917   
1  16.752497538971177  2.3623398328574394  4.724679665714879   
2    9.43114762028933  2.1753277471610764  4.350655494322151   
3   10.68725972618713    2.28774942353935  4.425414875225794   
4  11.945821561028193  2.2671838628844996     4.534367725769   

               SpAD_A             SpMAD_A  ...              SRW10  \
0   6.720566232730447  0.9600808903900638  ...   8.12355783506165   
1  16.752497538971177  1.1966069670693698  ...   9.472627

In [43]:
# Coerce every descriptor column to a numeric dtype; unparseable entries become NaN.
descriptors_df = descriptors_df.apply(pd.to_numeric, errors='coerce')

# Drop descriptor columns whose NaN fraction exceeds 25%.
columns_with_over_X_percent_nan = get_nan_columns(descriptors_df, 0.25)
# The original bare len(...) expression was a no-op mid-cell; report the count explicitly.
print(f"Dropping {len(columns_with_over_X_percent_nan)} descriptor columns with more than 25% NaN values")
descriptors_df = descriptors_df.drop(columns=columns_with_over_X_percent_nan)

# Mean-impute the remaining gaps in the feature columns only. 'Label' is the
# merge key, not a feature: filling it with the column mean would fabricate
# identifiers, so it is left out of the fill values.
feature_means = descriptors_df.mean().drop(labels='Label', errors='ignore')
descriptors_df = descriptors_df.fillna(feature_means)

# Rows whose label could not be parsed cannot be matched to a target row; drop
# them rather than letting a NaN key take part in the later merge.
if 'Label' in descriptors_df.columns:
    descriptors_df = descriptors_df.dropna(subset=['Label'])

In [47]:
# Inspect the cleaned descriptor table (numeric, high-NaN columns dropped, remaining NaNs mean-filled).
descriptors_df

Unnamed: 0,Label,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,...,SRW10,TSRW10,MW,AMW,WPath,WPol,Zagreb1,Zagreb2,mZagreb1,mZagreb2
0,9268,4.719397,5.004088,0,0,6.720566,2.101003,4.202006,6.720566,0.960081,...,8.123558,33.343946,136.047505,6.802375,46,4,28.0,26.0,4.562500,1.625000
1,10488,10.334062,9.836417,0,0,16.752498,2.362340,4.724680,16.752498,1.196607,...,9.472628,45.579501,215.042359,8.960098,296,20,68.0,77.0,6.645833,3.152778
2,25579,5.875634,5.566041,0,0,9.431148,2.175328,4.350655,9.431148,1.178893,...,8.479907,35.755147,111.048427,7.932031,61,7,36.0,38.0,3.222222,1.833333
3,8952,6.611250,6.890735,1,0,10.687260,2.287749,4.425415,10.687260,1.187473,...,8.735364,51.247665,143.004099,10.214579,84,9,42.0,47.0,4.083333,2.055556
4,23681,7.249407,6.976306,0,0,11.945822,2.267184,4.534368,11.945822,1.194582,...,8.912069,39.310842,219.935762,13.745985,116,12,46.0,51.0,4.333333,2.361111
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28726,16555,18.930123,15.134694,0,1,20.727298,2.363418,4.675169,20.727298,1.232732,...,10.155219,60.792100,379.155036,7.291443,2500001534,42,128.0,152.0,6.389452,5.861111
28727,6241,11.838245,10.433687,0,0,19.148384,2.324965,4.649929,19.148384,1.196774,...,9.415971,47.809076,262.006658,9.703950,503,20,76.0,83.0,7.145833,3.583333
28728,34083,12.825381,11.178801,1,0,22.014979,2.385270,4.770540,22.014979,1.294999,...,9.622251,49.493130,237.100108,7.409378,504,25,84.0,97.0,5.805556,3.916667
28729,1475,13.044161,11.771267,1,0,21.465002,2.367989,4.667357,21.465002,1.262647,...,9.580316,63.381945,229.110279,7.159696,531,23,86.0,99.0,6.416667,3.750000


In [62]:
# Merge the numeric dataframes on the common label column
numeric_data = pd.merge(descriptors_df, targets_df, left_on='Label', right_on='mol_num')
numeric_data = numeric_data.drop(columns=['Label', 'mol_num'])

In [63]:
# Inspect the merged table: descriptor columns followed by the target columns.
numeric_data

Unnamed: 0,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,LogEE_A,...,homo_spin_down_r,lumo_spin_down_r,max_charge_pos_r,max_charge_neg_r,max_spin_r,dipole_r,gibbs_r,elec_en_r,ddg_ox,ddg_red
0,4.719397,5.004088,0,0,6.720566,2.101003,4.202006,6.720566,0.960081,2.779033,...,-0.020575,0.084241,0.776662,-0.786469,0.819695,7.82223,-947.774409,-947.901675,8.772307,0.872504
1,10.334062,9.836417,0,0,16.752498,2.362340,4.724680,16.752498,1.196607,3.534071,...,-0.057754,0.092397,0.781351,-0.707508,0.393427,4.22675,-1051.047766,-1051.177809,8.096243,-0.229233
2,5.875634,5.566041,0,0,9.431148,2.175328,4.350655,9.431148,1.178893,2.979741,...,-0.016276,0.150350,0.177492,-0.609541,0.390859,1.28048,-386.580976,-386.648066,8.820537,0.584135
3,6.611250,6.890735,1,0,10.687260,2.287749,4.425415,10.687260,1.187473,3.103710,...,-0.030881,0.105986,0.391843,-0.557651,0.513850,3.95845,-796.662432,-796.720716,8.867437,-0.211576
4,7.249407,6.976306,0,0,11.945822,2.267184,4.534368,11.945822,1.194582,3.197666,...,-0.032088,0.062253,0.488503,-0.677500,0.687528,8.14008,-3343.196598,-3343.263585,7.578382,-0.510448
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
26228,7.985024,7.826624,0,0,13.404811,2.333527,4.667054,13.404811,1.218619,3.295415,...,0.056183,0.096731,0.383872,-0.732787,0.351100,8.40102,-515.072039,-515.187291,6.923090,-0.779938
26229,14.567764,13.061164,0,0,23.187437,2.369725,4.720068,23.187437,1.220391,3.859935,...,-0.038718,0.104386,0.326480,-0.593591,0.675688,9.12559,-866.368550,-866.702386,8.019006,0.474847
26230,11.838245,10.433687,0,0,19.148384,2.324965,4.649929,19.148384,1.196774,3.658465,...,-0.053150,0.006134,0.681842,-0.652126,0.326542,18.39088,-1546.460302,-1546.609496,8.658371,-2.021268
26231,12.825381,11.178801,1,0,22.014979,2.385270,4.770540,22.014979,1.294999,3.747192,...,-0.038013,0.100376,0.384700,-0.588087,0.272390,8.78065,-821.360312,-821.572805,6.655491,-0.071545


In [65]:
# Last 29 columns are the targets (they come from targets_df, which the merge
# places after the descriptor columns); every column before them is a feature.
# NOTE(review): the count of 29 is not checked in code — confirm against targets_df.
numeric_data

Unnamed: 0,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,LogEE_A,...,homo_spin_down_r,lumo_spin_down_r,max_charge_pos_r,max_charge_neg_r,max_spin_r,dipole_r,gibbs_r,elec_en_r,ddg_ox,ddg_red
0,4.719397,5.004088,0,0,6.720566,2.101003,4.202006,6.720566,0.960081,2.779033,...,-0.020575,0.084241,0.776662,-0.786469,0.819695,7.82223,-947.774409,-947.901675,8.772307,0.872504
1,10.334062,9.836417,0,0,16.752498,2.362340,4.724680,16.752498,1.196607,3.534071,...,-0.057754,0.092397,0.781351,-0.707508,0.393427,4.22675,-1051.047766,-1051.177809,8.096243,-0.229233
2,5.875634,5.566041,0,0,9.431148,2.175328,4.350655,9.431148,1.178893,2.979741,...,-0.016276,0.150350,0.177492,-0.609541,0.390859,1.28048,-386.580976,-386.648066,8.820537,0.584135
3,6.611250,6.890735,1,0,10.687260,2.287749,4.425415,10.687260,1.187473,3.103710,...,-0.030881,0.105986,0.391843,-0.557651,0.513850,3.95845,-796.662432,-796.720716,8.867437,-0.211576
4,7.249407,6.976306,0,0,11.945822,2.267184,4.534368,11.945822,1.194582,3.197666,...,-0.032088,0.062253,0.488503,-0.677500,0.687528,8.14008,-3343.196598,-3343.263585,7.578382,-0.510448
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
26228,7.985024,7.826624,0,0,13.404811,2.333527,4.667054,13.404811,1.218619,3.295415,...,0.056183,0.096731,0.383872,-0.732787,0.351100,8.40102,-515.072039,-515.187291,6.923090,-0.779938
26229,14.567764,13.061164,0,0,23.187437,2.369725,4.720068,23.187437,1.220391,3.859935,...,-0.038718,0.104386,0.326480,-0.593591,0.675688,9.12559,-866.368550,-866.702386,8.019006,0.474847
26230,11.838245,10.433687,0,0,19.148384,2.324965,4.649929,19.148384,1.196774,3.658465,...,-0.053150,0.006134,0.681842,-0.652126,0.326542,18.39088,-1546.460302,-1546.609496,8.658371,-2.021268
26231,12.825381,11.178801,1,0,22.014979,2.385270,4.770540,22.014979,1.294999,3.747192,...,-0.038013,0.100376,0.384700,-0.588087,0.272390,8.78065,-821.360312,-821.572805,6.655491,-0.071545


In [73]:
# Write the merged table to the same resolved location that the np.loadtxt cell
# reads back from. A bare relative filename depends on the kernel's working
# directory and lands in a different place from the read path (e.g. on Colab).
numeric_data.to_csv(resolve_path_gdrive('src/pl/merged_data_last29.csv'), index=False)

In [74]:
import numpy as np

# Report the in-memory shape, then reload the saved CSV as a plain float array.
print(numeric_data.shape)
merged_csv_path = resolve_path_gdrive('src/pl/merged_data_last29.csv')
# skiprows=1 skips the first line of the file (the CSV header row).
xy = np.loadtxt(merged_csv_path, delimiter=',', skiprows=1)

(26233, 1508)


In [75]:
# Shape of the array reloaded from the CSV, for comparison with numeric_data.shape printed earlier.
xy.shape

(26233, 1508)

In [76]:
# Preview the reloaded array (features and targets side by side).
xy

array([[ 4.71939655e+00,  5.00408772e+00,  0.00000000e+00, ...,
        -9.47901675e+02,  8.77230721e+00,  8.72504388e-01],
       [ 1.03340621e+01,  9.83641724e+00,  0.00000000e+00, ...,
        -1.05117781e+03,  8.09624339e+00, -2.29233083e-01],
       [ 5.87563385e+00,  5.56604101e+00,  0.00000000e+00, ...,
        -3.86648066e+02,  8.82053653e+00,  5.84135000e-01],
       ...,
       [ 1.18382446e+01,  1.04336871e+01,  0.00000000e+00, ...,
        -1.54660950e+03,  8.65837095e+00, -2.02126764e+00],
       [ 1.28253813e+01,  1.11788009e+01,  1.00000000e+00, ...,
        -8.21572805e+02,  6.65549141e+00, -7.15453380e-02],
       [ 1.30441609e+01,  1.17712666e+01,  1.00000000e+00, ...,
        -7.47409158e+02,  6.94039385e+00, -3.70383343e-01]])

In [78]:
# Feature block: every column except the last 29 (the target columns).
xy[:,:-29]

array([[ 4.71939655,  5.00408772,  0.        , ..., 26.        ,
         4.5625    ,  1.625     ],
       [10.33406211,  9.83641724,  0.        , ..., 77.        ,
         6.64583333,  3.15277778],
       [ 5.87563385,  5.56604101,  0.        , ..., 38.        ,
         3.22222222,  1.83333333],
       ...,
       [11.83824461, 10.43368714,  0.        , ..., 83.        ,
         7.14583333,  3.58333333],
       [12.82538132, 11.1788009 ,  1.        , ..., 97.        ,
         5.80555556,  3.91666667],
       [13.04416092, 11.77126656,  1.        , ..., 99.        ,
         6.41666667,  3.75      ]])

In [79]:
# Target block: the last 29 columns.
xy[:,-29:]

array([[-2.48949000e-01,  1.44670000e-02,  9.84272000e-01, ...,
        -9.47901675e+02,  8.77230721e+00,  8.72504388e-01],
       [-2.24233000e-01, -7.42270000e-02,  8.87076000e-01, ...,
        -1.05117781e+03,  8.09624339e+00, -2.29233083e-01],
       [-2.30413000e-01, -6.20650000e-02,  2.42509000e-01, ...,
        -3.86648066e+02,  8.82053653e+00,  5.84135000e-01],
       ...,
       [-2.45138000e-01, -1.16174000e-01,  7.25381000e-01, ...,
        -1.54660950e+03,  8.65837095e+00, -2.02126764e+00],
       [-1.82995000e-01, -7.28580000e-02,  4.35501000e-01, ...,
        -8.21572805e+02,  6.65549141e+00, -7.15453380e-02],
       [-1.70878000e-01, -8.44480000e-02,  4.37395000e-01, ...,
        -7.47409158e+02,  6.94039385e+00, -3.70383343e-01]])