In [14]:
import os
import sys

# Root of the simg-2 checkout: added to the import path and used as the
# working directory for the rest of the notebook.
SIMG_REPO_ROOT = '/home/dboiko/simg-2'

# Guard the append so re-running this cell does not pile up duplicate
# sys.path entries.
if SIMG_REPO_ROOT not in sys.path:
    sys.path.append(SIMG_REPO_ROOT)

os.chdir(SIMG_REPO_ROOT)

In [6]:
import logging
logging.basicConfig(level=logging.ERROR)

from tqdm.auto import tqdm
import numpy as np
import pandas as pd

import torch

from simg.model_utils import pipeline
from simg.data import get_connectivity_info
from simg.graph_construction import convert_NBO_graph_to_downstream

# Directory that is scanned for the QM7 `.xyz` geometry files.
QM7_GRAPH_PATH = '/home/dboiko/SPAHM/SPAHM/1_QM7/xyz'

def xyz_to_graph(xyz_path, target, type_):
    """Build a downstream-ready graph for the molecule stored in an XYZ file.

    The first two lines of the file are skipped as header; every remaining
    non-blank line is parsed as ``<symbol> <number> <number> ...``. Symbols,
    coordinates and connectivity are fed to ``pipeline`` (thresholding
    disabled), its predictions are attached to the graph, and the graph is
    finally passed through ``convert_NBO_graph_to_downstream``.

    Args:
        xyz_path: Path to the ``.xyz`` file. The part of the file name before
            its first ``.`` is stored as ``graph.qm9_id``.
        target: Scalar stored as a one-element float tensor in
            ``graph.normalized_targets``.
        type_: Value stored in ``graph.type``.

    Returns:
        The graph returned by ``convert_NBO_graph_to_downstream``.
    """
    with open(xyz_path, "r") as f:
        xyz = f.read()

    # Skip the two header lines and ignore blank lines so the atom list is
    # the same whether or not the file ends with a newline. Slicing with
    # [2:-1] dropped the last atom when the final newline was missing and
    # crashed on extra trailing blank lines. Each record keeps a trailing
    # '\n', which is the form that was always handed to get_connectivity_info.
    xyz_data = [line + '\n' for line in xyz.splitlines()[2:] if line.strip()]
    symbols = [line.split()[0] for line in xyz_data]
    # Every whitespace-separated field after the symbol is parsed as a float.
    coordinates = np.array([[float(num) for num in line.split()[1:]] for line in xyz_data])
    connectivity = get_connectivity_info(xyz_data)

    graph, _, _, (a2b_preds, node_preds, int_preds) = pipeline(symbols, coordinates, connectivity, use_threshold=False)

    # Attach the pipeline predictions under the attribute names set below.
    graph.y = torch.FloatTensor(node_preds)
    graph.a2b_targets = torch.FloatTensor(a2b_preds)
    graph.interaction_targets = torch.FloatTensor(int_preds)
    graph.qm9_id = xyz_path.split('/')[-1].split('.')[0]

    graph.type = type_
    graph.normalized_targets = torch.FloatTensor([target])
    graph = convert_NBO_graph_to_downstream(graph)

    return graph

In [7]:
# os.listdir returns entries in arbitrary order; sort the .xyz names so that
# positions in all_graphs (and slices such as all_graphs[:500]) are
# reproducible across runs and machines. Filtering first also makes the
# progress bar total count only the files that are actually converted.
xyz_files = sorted(
    name for name in os.listdir(QM7_GRAPH_PATH) if name.endswith('.xyz')
)

# Every graph is created with a target of 0 and type 'train'.
all_graphs = []
for file in tqdm(xyz_files):
    all_graphs.append(
        xyz_to_graph(os.path.join(QM7_GRAPH_PATH, file), 0, 'train')
    )

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 7165/7165 [05:17<00:00, 22.55it/s]


In [8]:
# Persist the full list of converted graphs in a single file.
torch.save(all_graphs, '/home/dboiko/reproducing_experiments/qm7_graphs_from_repo.pt')

In [13]:
# Display the contents of QM7_test_indices.dat, cast to integers.
list(
    np.loadtxt('/home/dboiko/SPAHM/test-indices/QM7_test_indices.dat').astype(int)
)

[1892,
 3754,
 834,
 4871,
 257,
 2396,
 4339,
 635,
 451,
 882,
 5052,
 864,
 1644,
 5803,
 2913,
 2310,
 311,
 3039,
 462,
 2004,
 5063,
 6510,
 5676,
 4818,
 1278,
 6350,
 521,
 5625,
 6759,
 4766,
 916,
 4757,
 6106,
 2591,
 1294,
 4616,
 4252,
 3194,
 202,
 2175,
 2735,
 4639,
 6887,
 3616,
 7099,
 2232,
 394,
 2241,
 5959,
 1788,
 2853,
 2738,
 4793,
 3509,
 1736,
 3060,
 4486,
 3579,
 963,
 3648,
 2806,
 5422,
 6015,
 3453,
 6724,
 6523,
 134,
 3266,
 217,
 976,
 2447,
 5789,
 6157,
 1861,
 1379,
 3803,
 200,
 1702,
 4076,
 5506,
 6283,
 1718,
 541,
 295,
 2784,
 378,
 6463,
 5455,
 6873,
 2939,
 144,
 4780,
 6577,
 5717,
 5106,
 7146,
 2137,
 1457,
 1281,
 1578,
 2141,
 5333,
 5355,
 578,
 1741,
 3654,
 3294,
 3303,
 938,
 533,
 4498,
 5335,
 2584,
 7040,
 2346,
 4631,
 2731,
 4712,
 4200,
 6880,
 5969,
 5075,
 6983,
 6805,
 39,
 6896,
 5165,
 464,
 1656,
 1259,
 4752,
 4800,
 3688,
 4668,
 5978,
 1216,
 613,
 3663,
 4314,
 1828,
 1893,
 779,
 1696,
 7016,
 4869,
 4002,
 7087,


In [9]:
# Inspect the first graph in the list as a quick sanity check.
all_graphs[0]

Data(qm9_id='0001', x=[21, 58], edge_index=[2, 64], edge_attr=[64, 36], is_atom=[9], is_lp=[9], is_bond=[9], type='train', normalized_targets=[1], symbol=[9])

In [10]:
# NOTE(review): within the visible notebook `files` is first assigned in a
# later cell (In [12]); on a fresh kernel run top-to-bottom this cell would
# raise NameError — move it below the cell that loads `files`.
files

Unnamed: 0,0,1,2,type
0,0001.xyz,-417.031,-431.787272,train
1,0002.xyz,-711.117,-730.835895,test
2,0003.xyz,-563.084,-573.104249,train
3,0004.xyz,-403.695,-412.902289,train
4,0005.xyz,-858.499,-869.887692,train
...,...,...,...,...
7096,7168.xyz,-1224.850,-1339.264266,train
7097,7169.xyz,-1065.110,-1180.593348,train
7098,7170.xyz,-1309.130,-1446.772902,train
7099,7171.xyz,-1296.090,-1430.446486,train


In [12]:
# Read the space-separated, header-less table of QM7 entries.
files = pd.read_csv(
    '/home/dboiko/SPAHM/qmllib/tests/assets/hof_qm7.txt',
    header=None,
    sep=' ',
)
# Reduce column 0 from a path to the bare file stem (text before the first '.').
files[0] = files[0].map(lambda path: path.rsplit('/', 1)[-1].partition('.')[0])
# Lookup table: file stem -> value from column 1.
file2y = {stem: value for stem, value in zip(files[0], files[1])}

In [13]:
# Look up each graph's target once (as a one-element float32 array) and reuse
# the list for both statistics instead of rebuilding it for each of them.
graph_targets = [torch.FloatTensor([file2y[graph.qm9_id]]).numpy() for graph in all_graphs]
y_mean = np.mean(graph_targets, axis=0)
y_std = np.std(graph_targets, axis=0)

In [15]:
# Check the shape of the computed mean.
y_mean.shape

(1,)

In [17]:
# Show the first graph's target divided by the standard deviation.
# NOTE(review): y_mean is not subtracted here — confirm whether plain scaling
# (rather than (y - y_mean) / y_std) is the intended normalisation.
torch.FloatTensor([file2y[all_graphs[0].qm9_id]])/y_std

tensor([-1.8687])

In [18]:
# Persist only the first 500 graphs (in list order) as a smaller file.
torch.save(all_graphs[:500], '/home/dboiko/reproducing_experiments/qm7_graphs_500.pt')