# Creating a dataset

In [1]:
import sys
import numpy as np
from importlib import reload
import networkx as nx

import torch
from torch.utils.data import DataLoader

sys.path.append('../')

In [2]:
import nff.data as d

### Example data: ethanol CCSD(T) training set (provided by Wujie)

In [3]:
# Load the ethanol CCSD(T) training split. The `with` block closes the underlying
# .npz file handle once the arrays have been read into memory.
with np.load('ethanol_ccsd_t-train.npz') as ethanol_data:
    atomic_numbers = ethanol_data['z']  # one atomic number per atom, shared by every frame
    coords = ethanol_data['R']          # per-frame coordinates, (n_frames, n_atoms, 3)
    force_data = ethanol_data['F']
    energies = ethanol_data['E']

# Derive the number of geometries from the data instead of hard-coding it.
n_frames = len(coords)

# Repeat the atomic numbers for every frame and prepend them as column 0,
# giving nxyz rows of the form [Z, x, y, z] with shape (n_frames, n_atoms, 4).
z_column = np.tile(atomic_numbers.reshape(1, -1, 1), (n_frames, 1, 1))
nxyz_data = np.concatenate((z_column, coords), axis=-1)

# Energies relative to the mean over this split.
energy_data = energies.squeeze() - energies.mean()

# Ethanol is C-C-O ("CCO"); "COC" would be its isomer dimethyl ether.
smiles_data = ["CCO"] * n_frames

In [4]:
# Per-frame properties handed to nff's Dataset below.
# The energy gradient is stored as the negated force of each frame.
props = dict(
    nxyz=nxyz_data,
    energy=energy_data,
    energy_grad=[-frame_forces for frame_forces in force_data],
    smiles=smiles_data,
)

### Building a `Dataset` (`nff.data`)

In [5]:
dataset = d.Dataset(props.copy(), units='atomic')

In [7]:
dataset[0]

{'nxyz': tensor([[ 6.0000e+00,  5.5206e-03,  5.9149e-01, -8.1382e-04],
         [ 6.0000e+00, -1.2536e+00, -2.5536e-01, -2.9801e-02],
         [ 8.0000e+00,  1.0878e+00, -3.0755e-01,  4.8230e-02],
         [ 1.0000e+00,  6.2821e-02,  1.2838e+00, -8.4279e-01],
         [ 1.0000e+00,  6.0567e-03,  1.2303e+00,  8.8535e-01],
         [ 1.0000e+00, -2.2182e+00,  1.8981e-01, -5.8160e-02],
         [ 1.0000e+00, -9.1097e-01, -1.0539e+00, -7.8160e-01],
         [ 1.0000e+00, -1.1920e+00, -7.4248e-01,  9.2197e-01],
         [ 1.0000e+00,  1.8488e+00, -2.8632e-02, -5.2569e-01]]),
 'energy': tensor(-2742.2935),
 'energy_grad': tensor([[  19.2809,  -63.6013,    3.6154],
         [-112.6047,   80.1408,   19.0383],
         [ -69.8792,    3.7907,   35.5264],
         [  -3.4364,   10.6967,    8.8279],
         [  -3.5309,    1.8254,   -3.7656],
         [  20.8410,  -22.0867,   34.9517],
         [  71.8365,  -46.4502,  -31.4148],
         [  11.0559,   29.6765,  -24.5629],
         [  66.4369,    6

In [None]:
len(dataset)

In [None]:
dataset.generate_neighbor_list(cutoff=5)

In [None]:
dataset[0]

### Plotting an example graph from the neighbor list

In [None]:
%matplotlib inline
nbr_list = dataset[0]['nbr_list'].numpy()
G = nx.from_edgelist(nbr_list)
nx.draw_kamada_kawai(G)

### Batching with a `DataLoader` (`nff.data.collate_dicts`)

In [None]:
loader = DataLoader(dataset, batch_size=5, collate_fn=d.collate_dicts)

In [None]:
len(loader)

In [None]:
next(iter(loader))

# Saving and loading a dataset file

In [None]:
dataset = d.Dataset.from_file('dataset.pth.tar')

In [None]:
dataset.save('dataset.pth.tar')