In [2]:
import pandas as pd
import numpy as np
from openff.toolkit.topology import Molecule
from openff.units import unit
import seaborn as sb
from scipy import stats
import matplotlib.pyplot as plt
from rdkit.Chem import Draw
from sklearn.metrics import r2_score
from matplotlib.colors import Normalize
import matplotlib.cm as cm
import pandas as pd
import seaborn as sns
from matplotlib.colors import LogNorm
from matplotlib.collections import QuadMesh
import matplotlib as mpl
import polars as pl

def annotate_metrics(x, y, ax=None, **kwargs):
    """
    Annotate MAE, RMSE, and R² directly on the graph as red text.

    Parameters
    ----------
    x, y : array-like
        Paired values. They are converted to float NumPy arrays so that lists
        are accepted and pandas Series are compared position by position
        (the same pairing a scatter plot uses) rather than by index label.
    ax : matplotlib Axes, optional
        Axes to write on; the current axes are used when omitted.
    **kwargs
        Accepted and ignored, so the function can be used as a grid-mapping
        callback that receives extra keyword arguments.

    NOTE(review): ``r2_score`` is not symmetric and ``x`` is passed in its
    first (``y_true``) slot — confirm ``x`` really holds the reference values.
    """
    if ax is None:
        ax = plt.gca()

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    residual = x - y

    mae = np.mean(np.abs(residual))
    rmse = np.sqrt(np.mean(residual ** 2))
    r2 = r2_score(x, y)

    # Axes-fraction coordinates pin the text to the top-left corner of the panel.
    ax.text(0.05, 0.95, f'MAE: {mae:.2f}\nRMSE: {rmse:.2f}\nR²: {r2:.2f}',
            transform=ax.transAxes, fontsize=15, color='red', ha='left', va='top')

def density_hist2d(x, y, ax=None, bins=1000, range=None, **kwargs):
    """
    Draw a 2D histogram whose colour scale is logarithmic in the bin count.

    ``bins`` and ``range`` are handed to ``ax.hist2d`` unchanged; the colormap
    is fixed to viridis. Extra keyword arguments are accepted but not used.

    Returns the full ``hist2d`` result, i.e. the tuple
    ``(counts, xedges, yedges, image)``. The last element is the image a caller
    can use to build a shared colorbar outside this function.
    """
    if not ax:
        ax = plt.gca()

    hist_options = dict(bins=bins, range=range, norm=LogNorm(), cmap='viridis')
    return ax.hist2d(x, y, **hist_options)


def density_scatter(x, y, ax=None, bins=100, range=None, **kwargs):
    """
    Scatter the points, colouring each by log10(count + 1) of its 2D histogram bin.

    Parameters
    ----------
    x, y : array-like
        Point coordinates. Converted to NumPy arrays so lists and pandas
        Series are indexed positionally.
    ax : matplotlib Axes, optional
        Target axes; the current axes are used when omitted.
    bins, range
        Forwarded unchanged to ``np.histogram2d``.
    **kwargs
        Forwarded to ``ax.scatter``. ``color`` and ``c`` are discarded because
        the colours are computed here.

    Returns
    -------
    The object returned by ``ax.scatter``. Its ``norm`` and ``cmap`` attributes
    are set to the normalisation and colormap used for the point colours so a
    caller can build a matching colorbar. Points outside the histogram range
    are not drawn.
    """
    if ax is None:
        ax = plt.gca()
    # Remove any existing 'color' or 'c' from kwargs to avoid conflicts
    kwargs.pop('color', None)
    kwargs.pop('c', None)

    x = np.asarray(x)
    y = np.asarray(y)

    counts, xedges, yedges = np.histogram2d(x, y, bins=bins, range=range)
    n_xbins, n_ybins = counts.shape

    def _bin_index(values, edges):
        # Bins are half-open [lo, hi) except the last one, which histogram2d
        # treats as closed. Without the correction a value equal to the final
        # edge lands one bin past the end and the point would be discarded.
        idx = np.searchsorted(edges, values, side='right') - 1
        idx[values == edges[-1]] = len(edges) - 2
        return idx

    x_bin = _bin_index(x, xedges)
    y_bin = _bin_index(y, yedges)
    in_range = (x_bin >= 0) & (x_bin < n_xbins) & (y_bin >= 0) & (y_bin < n_ybins)

    point_density = counts[x_bin[in_range], y_bin[in_range]]
    log_density = np.log10(point_density + 1)

    if log_density.size:
        norm = Normalize(vmin=log_density.min(), vmax=log_density.max())
    else:
        # No point falls inside the range: min()/max() would raise on an empty array.
        norm = Normalize(vmin=0.0, vmax=1.0)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap('viridis')
    colors = cmap(norm(log_density))

    scatter = ax.scatter(x[in_range], y[in_range], c=colors, **kwargs)

    # Store the norm and cmap as attributes on the scatter for later retrieval
    scatter.norm = norm
    scatter.cmap = cmap

    return scatter

def equalityline(x, y, ax = None, **kwargs):
    """
    Draw the y = x reference line across the combined range of ``x`` and ``y``.

    The line runs from the smallest to the largest value found in either
    input. NaN entries are ignored when computing that range, so a single
    missing value no longer decides (or erases) the line.

    Parameters
    ----------
    x, y : array-like
        Values whose joint extent defines the line's end points.
    ax : matplotlib Axes, optional
        Axes to draw on; the current axes are used when omitted.
    **kwargs
        Forwarded to ``ax.plot``.
    """
    if ax is None:
        ax = plt.gca()
    values = np.concatenate([np.ravel(x), np.ravel(y)])
    lims = [np.nanmin(values), np.nanmax(values)]
    ax.plot(lims, lims, **kwargs)

def m(x, y, ax = None, **kwargs):
    """
    Write the root-mean-square difference between ``x`` and ``y`` on the axes.

    Both inputs are copied into NumPy arrays first. The value is rendered with
    two decimals in the top-left corner of ``ax`` (the current axes when none
    is given). Extra keyword arguments are accepted but not used.
    """
    x_arr = np.array(x)
    y_arr = np.array(y)
    if not ax:
        ax = plt.gca()

    squared_error_total = np.sum((x_arr - y_arr) ** 2, axis=0)
    rmse = (squared_error_total / y_arr.shape[0]) ** 0.5
    label = f'$RMSE= {float(rmse):.2f}$'

    ax.annotate(
        label,
        xy=(.01, .99),
        xycoords=ax.transAxes,
        fontsize=8,
        color='darkred',
        backgroundcolor='#FFFFFF99',
        ha='left',
        va='top',
    )
    
def kde(x, y, ax = None, **kwargs):
    """
    Scatter the (x, y) points coloured by their Gaussian-KDE point density.

    The density is the joint 2D estimate built from all points and evaluated
    at each point. (Previously the 1D density of ``x`` alone was evaluated at
    the ``y`` values, which does not describe how crowded a point is.)

    Parameters
    ----------
    x, y : array-like
        Point coordinates; converted to float NumPy arrays.
    ax : matplotlib Axes, optional
        Axes to draw on; the current axes are used when omitted.
    **kwargs
        Accepted and ignored; marker size is fixed at 5.

    NOTE(review): a 2D KDE needs x and y that are not perfectly collinear —
    confirm no caller plots a column against itself.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if ax is None:
        ax = plt.gca()
    # gaussian_kde expects shape (n_dims, n_points): stack the coordinates as rows.
    points = np.vstack([x, y])
    density = stats.gaussian_kde(points)(points)
    ax.scatter(x, y, c=density, s=5)

In [3]:
# Scan the test-set parquet lazily with polars, collect it with the streaming
# engine, then hand a pandas copy to the rest of the notebook.
# NOTE(review): presumably polars replaced a plain pd.read_parquet call for
# speed or memory reasons — confirm.
parquet_scan = pl.scan_parquet('./charge_models_test_withgeoms.parquet')
pldf = parquet_scan.collect(engine='streaming')
df = pldf.to_pandas()


In [4]:
# Preview the loaded frame; the rendered table abbreviates the middle rows and
# the long list-valued cells.
df

Unnamed: 0,mbis_charges,am1bcc_charges,espaloma_charges,riniker_monopoles,resp_charges,qm_dipoles,mbis_dipoles,am1bcc_dipole,espaloma_dipole,riniker_dipoles,resp_dipole,am1bcc_esp_rms,espaloma_esp_rms,resp_esp_rms,mbis_esp_rms,molecule,geometry,grid,qm_esp,riniker_esp_rms
0,"[-0.307478181466797, -0.32475236546718506, 0.0...","[-0.1269, -0.1269, -0.111, -0.0814, 0.1528, -0...","[-0.17321058894906724, -0.17321058894906724, -...","[-0.3678119480609894, -0.37190431356430054, 0....","[-0.3026, -0.3026, 0.1671, -0.0841, 0.0157, -0...",0.463489,0.328383,0.411711,1.121593,0.320236,0.474776,1.711427,4.318221,1.336640,1.854834,[C:1]1([H:9])([H:10])[C:2]([H:11])([H:12])[C:3...,"[-0.7528675459701196, -1.1051199921456485, -1....","[[-4.618226513262415, 0.6687491157891656, -0.0...","[0.003807534095750853, 0.005536206114819642, 0...",0.928624
1,"[-0.27024134000385275, -0.2514930826279807, 0....","[-0.1094, -0.0984, 0.1401, -0.4256, 0.1264, 0....","[-0.22874532377018647, -0.18081164097084718, 0...","[-0.31097567081451416, -0.3029680848121643, 0....","[-0.1146, -0.0616, 0.4422, -0.553, 0.1441, 0.0...",0.424558,0.453174,0.424885,0.729356,0.384860,0.437629,1.308036,3.308412,0.787228,1.472415,[C:1]1([H:8])([H:9])[C:2]([H:10])([H:11])[C:3]...,"[0.9644572333804696, -0.4156491196082709, 0.32...","[[-3.135893501910544, -4.689287925870755, 1.74...","[-0.012030632055308743, -0.012702403534801121,...",0.980107
2,"[-0.22269026300546416, 0.0986548857271199, -0....","[-0.0894, 0.1548, -0.7962, 0.1548, -0.0894, -0...","[-0.07576356260549455, 0.0697487351440248, -0....","[-0.30706286430358887, 0.01613098382949829, -0...","[-0.055, 0.0874, -0.8195, 0.0874, -0.055, 0.10...",0.482231,0.612847,0.657241,0.245546,0.616496,0.497217,1.766681,4.192248,1.061804,1.769024,[C:1]1([H:11])([H:12])[C:2]([H:13])([H:14])[N:...,"[-0.04154411000466645, 1.3246999855602708, -0....","[[-4.6840821751273385, -1.1985072844843319, 1....","[0.0002998244321243959, -0.0002435208943376565...",0.983287
3,"[0.18318003045342993, -0.2598634431943842, -0....","[0.1214, -0.0273, -0.3572, 0.1952, -0.3976, 0....","[0.10370072921117147, -0.2081667164961497, -0....","[0.07321783900260925, -0.32436901330947876, -0...","[-0.074, 0.0126, -0.2653, 0.0956, -0.3139, 0.0...",0.743610,0.811633,1.006174,1.447500,0.817276,0.717458,2.712812,4.618196,1.487021,1.810721,[C:1]1([H:8])([H:9])[C:2]([H:10])([H:11])[S:3]...,"[0.9020521480808357, -0.6100668318963072, -0.2...","[[-3.4603634180881917, -4.794984987767819, -3....","[-0.01648513102157345, -0.016546033230945945, ...",1.183654
4,"[-0.20271783777574814, 0.11143350543999422, -0...","[-0.0914, 0.1428, -0.7922, 0.1578, -0.138, 0.1...","[-0.12884797396040276, -0.012868405858937063, ...","[-0.27458471059799194, 0.034048937261104584, -...","[-0.153, 0.2201, -0.809, 0.1195, 0.2035, 0.096...",1.166569,1.194966,1.317575,0.915113,1.009524,1.239203,2.152583,6.542101,1.646385,1.397241,[C:1]1([H:9])([H:10])[C:2]([H:11])([H:12])[N:3...,"[0.640770333977381, 0.6951131028378826, 1.1031...","[[-5.101983923803556, -1.0107996737098657, -0....","[-0.026226996881657527, -0.024682465380372776,...",1.584835
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
34923,"[-0.4307936125045173, 0.35266299719448524, -0....","[-0.10315252631578949, 0.13044747368421053, -0...","[-0.16938045071928123, 0.23442702544362923, -0...","[-0.5688874125480652, 0.35258910059928894, -0....","[-0.3234, 0.422, -0.3234, 0.0835, -0.1508, -0....",0.869328,0.872096,1.217937,1.813061,0.668468,1.099153,2.691785,7.475896,1.678970,1.273094,[C:1]([C:2]([C:3]([H:13])([H:14])[H:15])([C:4]...,"[-0.10727328016408692, -0.8953934001088166, 0....","[[-3.7874826050688455, -0.902746520613953, 2.4...","[-0.0055729953891621165, -0.002821495455125777...",1.554715
34924,"[-0.316128375134428, -0.30531925078170524, 0.5...","[-0.08515876470588235, -0.2840587647058823, 0....","[-0.11128084328682984, -0.2857638296178159, 0....","[-0.41569340229034424, -0.336565226316452, 0.4...","[-0.2141, 0.1522, 0.4158, -0.4675, 0.4107, -0....",1.686318,1.713325,1.565514,2.143319,1.563467,1.668456,2.146896,6.335759,0.789686,1.393538,[C:1]([C:2]([C:3](=[O:4])[H:14])([C:5](=[N:6][...,"[1.5695909800648318, -0.6741383385672546, -1.1...","[[-4.805235936047538, 0.958310833137348, -0.95...","[-0.004068417656111656, -0.0004770630350874327...",1.319067
34925,"[-0.446016524842485, 0.4470529110725375, -0.40...","[-0.10454736842105264, 0.23015263157894736, -0...","[-0.16086909488627785, 0.28039979620983724, -0...","[-0.5875690579414368, 0.4356086552143097, -0.5...","[-0.3489, 0.6182, -0.3489, -0.2262, 0.0329, -0...",0.875990,0.837351,0.886616,0.316271,0.765047,1.067602,1.592437,5.912196,1.543108,1.421052,[C:1]([C:2]([C:3]([H:13])([H:14])[H:15])([C:4]...,"[-0.12929613281863148, -0.9008950227326125, 0....","[[-3.731207475715633, -0.8517897910100047, 2.0...","[-0.00211648004266074, -0.004617378798517535, ...",1.000175
34926,"[-0.3033673079675967, 0.126306495949047, -0.70...","[-0.11685454545454546, 0.15534545454545454, -0...","[-0.15900079397992653, 0.0887342444197698, -0....","[-0.35558021068573, 0.039914488792419434, -0.6...","[-0.1849, 0.2184, -0.8421, 0.2184, -0.1849, 0....",1.491349,1.537531,1.622049,2.268048,1.488856,1.721564,1.839313,5.170227,1.885332,1.597153,[C:1]1([H:12])([H:13])[C:2]([H:14])([H:15])[N:...,"[0.49443089547526725, 1.0991997480138507, 0.30...","[[-4.341168544169625, -1.0123746616058171, 1.1...","[0.021969246435640244, 0.02297250447568544, 0....",1.395748


In [5]:
# Show the `grid` entry of row 1 as nested Python lists (one three-number
# coordinate per grid point) instead of an array of arrays.
[sec.tolist() for sec in df.iloc[1]['grid'].tolist()]

[[-3.135893501910544, -4.689287925870755, 1.7470477911189017],
 [-3.135893501910544, -4.689287925870755, 2.454154572305449],
 [-3.135893501910544, -4.335734535277481, 2.1006011817121752],
 [-3.135893501910544, -4.335734535277481, 2.807707962898723],
 [-3.135893501910544, -3.9821811446842084, 1.7470477911189017],
 [-3.135893501910544, -3.9821811446842084, 2.454154572305449],
 [-3.135893501910544, -3.9821811446842084, 3.161261353491996],
 [-3.135893501910544, -3.6286277540909344, 2.1006011817121752],
 [-3.135893501910544, -3.6286277540909344, 2.807707962898723],
 [-3.135893501910544, -3.2750743634976605, 2.454154572305449],
 [-3.135893501910544, -2.921520972904387, -0.020719161847466916],
 [-3.135893501910544, -2.921520972904387, 0.6863876193390805],
 [-3.135893501910544, -2.5679675823111134, 0.3328342287458068],
 [-3.135893501910544, -2.2144141917178395, -0.020719161847466916],
 [-3.135893501910544, -2.2144141917178395, 0.6863876193390805],
 [-2.7823401113172705, -5.749948097650577, 2.4

In [7]:
# Inspect the `qm_esp` array stored for the first row.
df.iloc[0]['qm_esp']

array([0.00380753, 0.00553621, 0.00572025, ..., 0.00240039, 0.00112108,
       0.00205172])

In [10]:
import openff.nagl_models
from openff.nagl import GNNModel

# Locate the model inside the installed openff-nagl-models package rather than
# through an absolute path that only exists on one machine.
# NOTE(review): index 0 is assumed to be the am1bcc model directory — confirm
# if additional NAGL model packages are installed.
nagl_model_dir = openff.nagl_models.get_nagl_model_dirs_paths()[0]
nagl_model = GNNModel.load(f"{nagl_model_dir}/openff-gnn-am1bcc-0.1.0-rc.3.pt")


In [None]:
# Smoke-test the loaded model on ethanol (SMILES 'CCO'); the result is a dict
# keyed by property name with one value per atom.
nagl_model.compute_properties(Molecule.from_smiles('CCO'))

{'am1bcc_charges': array([-0.09629,  0.13245, -0.60293,  0.04465,  0.04465,  0.04465,
         0.01728,  0.01728,  0.39826], dtype=float32)}

In [18]:
# `from openff.<pkg> import ...` never binds the bare name `openff`, so import
# the subpackage explicitly; otherwise this cell raises NameError on a fresh kernel.
import openff.nagl_models

dir_path = openff.nagl_models.get_nagl_model_dirs_paths()[0]
nagl_model = GNNModel.load(str(dir_path) + "/openff-gnn-am1bcc-0.1.0-rc.3.pt")

In [None]:
# Import the subpackage explicitly: the bare name `openff` is not bound by the
# notebook's `from openff.<pkg> import ...` lines.
import openff.nagl_models

dir_path = openff.nagl_models.get_nagl_model_dirs_paths()[0]
# End on an expression so the cell actually displays the resolved directory.
str(dir_path)


'/Users/k2584788/.local/share/mamba/envs/charge_model_env/lib/python3.11/site-packages/openff/nagl_models/models/am1bcc'