In [1]:
import numpy as np
import pickle

from deepchem.molnet import load_tox21

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
_, tox21_datasets, _ = load_tox21(featurizer='GraphConv')
train_dataset, val_dataset, test_dataset = tox21_datasets

In [3]:
X = np.concatenate((train_dataset.X, val_dataset.X, test_dataset.X))
y = np.concatenate((train_dataset.y, val_dataset.y, test_dataset.y))
adj_list = [x.get_adjacency_list() for x in X]
print(X.shape, y.shape)

(7831,) (7831, 12)


In [4]:
node_attributes = np.concatenate([x.get_atom_features() for x in X])
print(node_attributes.shape)
with open('./datas/Tox21/Tox21_node_attributes.pkl', 'wb') as f:
    pickle.dump(node_attributes, f)
node_attributes

(145459, 75)


array([[1., 0., 0., ..., 0., 1., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [5]:
graph_labels = y.argmax(-1)
print(graph_labels.shape)

with open('./datas/Tox21/Tox21_graph_labels.txt', 'w') as f:
    for label in graph_labels:
        f.write(f'{label}\n')

graph_labels

(7831,)


array([0, 0, 0, ..., 2, 7, 0])

In [6]:
graph_indicator = np.array([i + 1 for i in range(len(adj_list)) for j in range(len(adj_list[i]))])
print(graph_indicator.shape)

with open('./datas/Tox21/Tox21_graph_indicator.txt', 'w') as f:
    for graph_idx in graph_indicator:
        f.write(f'{graph_idx}\n')
        
graph_indicator

(145459,)


array([   1,    1,    1, ..., 7831, 7831, 7831])

In [7]:
def adj_list_edge_list(adj_list):
    n = len(adj_list)
    A = []
    for u, adj_u in enumerate(adj_list):
        for v in adj_u:
            A.append([u + 1, v + 1])
    return np.array(A)

A = [adj_list_edge_list(x.get_adjacency_list()) for x in X]
A = [el for el in A if el.size > 0]
A = np.concatenate(A)

with open('./datas/Tox21/Tox21_A.txt', 'w') as f:
    for u, v in A:
        f.write(f'{u}, {v}\n')

A

array([[ 1, 10],
       [ 2, 11],
       [ 3, 11],
       ...,
       [44,  4],
       [44, 10],
       [44, 19]])

In [11]:
y.shape

(7831, 12)

In [13]:
X[0].get_atom_features().shape

(11, 75)