# 01: QM9 Loader Tutorial
In this notebook, we demonstrate how to use this repository to load and analyze data from the `QM9` dataset, which can be downloaded [here](http://quantum-machine.org/datasets/).

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
complete_path = os.getcwd()
if 'tutorials' in complete_path:
    os.chdir("..")

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import networkx as nx
import dgl

In [None]:
from crescendo.datasets.qm9 import QMXDataset
# from crescendo.models.mpnn import MPNN  # MATT TODO

## Load the data
Here, we load the internally stored test data from the local `data` directory, but you can modify this path to point to your own copy of the `QM9` database. Note that the `debug` flag constrains `load` to loading only a small number of data points (for example, `debug=9` loads at most 10), which speeds up the runtime. For a real analysis of the entire database, leave `debug` at its default value of `-1`.

### Default path
By default, the program looks for a `$QM9_DATA_PATH` environment variable; if it is not set, a `path` parameter must be passed to the `load()` method. To set this environment variable permanently:
1. Open your `~/.bash_profile`.
2. Add the line `export QM9_DATA_PATH='/my/path/to/qm9'`.
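The path lookup described above can be sketched as follows. `resolve_qm9_path` is a hypothetical helper, not part of `crescendo`, and letting an explicit `path` take precedence over the environment variable is an assumption about how `load()` resolves the two.

```python
import os
from typing import Optional


def resolve_qm9_path(path: Optional[str] = None) -> str:
    """Hypothetical helper: an explicit path wins, else fall back to $QM9_DATA_PATH."""
    if path is not None:
        return path
    env_path = os.environ.get("QM9_DATA_PATH")
    if env_path is None:
        # Neither source is available, so there is nothing to load from.
        raise ValueError("Set $QM9_DATA_PATH or pass `path` explicitly.")
    return env_path
```

Remember that a new line in `~/.bash_profile` only takes effect in new shells (or after running `source ~/.bash_profile`), so a Jupyter kernel started earlier will not see the variable.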

In [None]:
# In case the user accidentally loads from the path in their environment
# variable, cap the number of structures loaded by default
LOAD_MAXIMUM = 20

In [None]:
test_path = 'data/qm9_test_data'
qm9_dat = QMXDataset(debug=LOAD_MAXIMUM)

The `load` method accepts several useful flags for loading subsets of `QM9`:
* `max_heavy_atoms`: the maximum number of heavy atoms (C, N, O and F) allowed per molecule. Default is 9, which corresponds to the entire `QM9` dataset.
* `keep_zwitter`: if `True`, keeps [zwitterionic](https://en.wikipedia.org/wiki/Zwitterion) compounds. Default is `False`.
* `canonical`: if `True`, loads the canonical `SMILES` string rather than the non-canonical one. Default is `True`.
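The first two flags amount to a per-molecule filter. The sketch below shows that filtering logic under stated assumptions: each molecule is represented as a list of element symbols plus a list of formal charges, and a zwitterion is taken to be any molecule carrying both a positive and a negative formal charge. The function names are hypothetical and do not mirror `QMXDataset` internals.

```python
from typing import Sequence

HEAVY_ATOMS = {"C", "N", "O", "F"}


def count_heavy_atoms(elements: Sequence[str]) -> int:
    """Count C, N, O and F atoms; hydrogens are ignored."""
    return sum(1 for el in elements if el in HEAVY_ATOMS)


def is_zwitterion(formal_charges: Sequence[int]) -> bool:
    """Assumed definition: the molecule has both a positive and a negative charge."""
    return any(q > 0 for q in formal_charges) and any(q < 0 for q in formal_charges)


def keep_molecule(
    elements: Sequence[str],
    formal_charges: Sequence[int],
    max_heavy_atoms: int = 9,
    keep_zwitter: bool = False,
) -> bool:
    """Mimic the `max_heavy_atoms` and `keep_zwitter` flags for one molecule."""
    if count_heavy_atoms(elements) > max_heavy_atoms:
        return False
    if is_zwitterion(formal_charges) and not keep_zwitter:
        return False
    return True
```

For example, zwitterionic glycine (`[NH3+]CC(=O)[O-]`) is dropped under the defaults but kept with `keep_zwitter=True`.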

In [None]:
qm9_dat.load(path=test_path)

## Examining the raw data
The raw data loaded into `qm9_dat` is stored in the `qm9_dat.raw` dictionary, which maps QM9 IDs to `QM9SmilesDatum` objects. The data come with several analysis methods, including:
* `has_n_membered_ring`
* `is_aromatic`
* `has_double_bond`
* etc...

as well as a `to_graph` method.
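To make the idea behind a method like `has_n_membered_ring` concrete, here is a brute-force sketch that searches a molecular graph (given as a list of bonds between integer atom indices) for a simple cycle of exactly `n` atoms. This is an illustration, not the repository's implementation: it assumes integer atom labels, and because it finds any simple cycle, a fused ring system can also report its outer perimeter, unlike a smallest-set-of-smallest-rings approach.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set, Tuple


def has_n_membered_ring(bonds: Iterable[Tuple[int, int]], n: int) -> bool:
    """Brute-force check for a simple cycle through exactly n atoms."""
    if n < 3:
        return False  # a ring needs at least three atoms
    adjacency: Dict[int, Set[int]] = defaultdict(set)
    for i, j in bonds:
        adjacency[i].add(j)
        adjacency[j].add(i)

    def extend(start: int, current: int, visited: Set[int]) -> bool:
        for neighbour in adjacency[current]:
            if neighbour == start and len(visited) == n:
                return True  # closed a cycle through exactly n atoms
            # Only visit atoms with a larger index than `start`, so each cycle is
            # explored from its smallest atom; stop once the path holds n atoms.
            if neighbour > start and neighbour not in visited and len(visited) < n:
                visited.add(neighbour)
                if extend(start, neighbour, visited):
                    return True
                visited.remove(neighbour)
        return False

    return any(extend(atom, atom, {atom}) for atom in list(adjacency))
```

The exhaustive search is affordable here because QM9 molecules contain at most nine heavy atoms.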

In [None]:
qm9_dat.analyze()

## Featurize the data
The `QMXDataset` class can load auxiliary data from other datasets and intersect it with the QM9 dataset. This is critical for machine learning, where each set of features must be paired with its targets. For example, the QM8 dataset is a subset of QM9, and its entries can be matched to QM9 structures through their QM9 IDs. The `qm8_EP` featurizer method handles this.
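Conceptually, the intersection keeps only the IDs present in both datasets and pairs each structure with its target. A minimal sketch, assuming both datasets are plain dictionaries keyed by QM9 ID (`pair_by_id` is a hypothetical name, not a `QMXDataset` method):

```python
from typing import Dict, List, Tuple, TypeVar

F = TypeVar("F")
T = TypeVar("T")


def pair_by_id(features: Dict[int, F], targets: Dict[int, T]) -> List[Tuple[int, F, T]]:
    """Keep IDs found in both dictionaries, sorted so the output is reproducible."""
    shared = sorted(features.keys() & targets.keys())
    return [(qm9_id, features[qm9_id], targets[qm9_id]) for qm9_id in shared]
```

Molecules that appear in only one of the two sources are silently dropped, which is why the featurized dataset can be smaller than the raw QM9 data.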

### Load in the QM8 data
We first need to load in the QM8 data before we execute the featurizer.

In [None]:
qm9_dat.load_qm8_electronic_properties(path='data/qm8_test_data.txt')

## Make ML-ready
The `ml_ready` method pairs the QM9 data with the QM8 electronic properties. Note that we must tell `ml_ready` which node (atom) and edge (bond) features to use when featurizing the structures.

In [None]:
qm9_dat.ml_ready(
    'qm8_EP',
    atom_feature_list=['type', 'hybridization'],
    bond_feature_list=['type']
)
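Categorical features such as atom type and hybridization are commonly one-hot encoded and concatenated into a single node-feature vector. The sketch below shows that technique; the vocabularies `ATOM_TYPES` and `HYBRIDIZATIONS` are assumptions chosen for illustration and are not taken from `crescendo`.

```python
from typing import Dict, List, Sequence

# Assumed vocabularies, for illustration only.
ATOM_TYPES = ["H", "C", "N", "O", "F"]
HYBRIDIZATIONS = ["S", "SP", "SP2", "SP3"]


def one_hot(value: str, vocabulary: Sequence[str]) -> List[int]:
    """Return a 0/1 vector with a single 1 at the position of `value`."""
    if value not in vocabulary:
        raise ValueError(f"{value!r} not in vocabulary {list(vocabulary)}")
    return [int(value == v) for v in vocabulary]


def featurize_atom(atom: Dict[str, str]) -> List[int]:
    """Concatenate the one-hot encodings of the atom's type and hybridization."""
    return one_hot(atom["type"], ATOM_TYPES) + one_hot(atom["hybridization"], HYBRIDIZATIONS)
```

With these vocabularies every atom maps to a vector of length 5 + 4 = 9; a node-feature size like this is presumably what a model's input dimension (such as `n_node_features` in the training cell) has to match.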

### Get the `DataLoaders`
`PyTorch` provides the `DataLoader` class, which handles batching, shuffling and iteration over the data during training.

In [None]:
loaders = qm9_dat.get_data_loaders(
    p_tvt=(0.2, 0.2, 0.6),
    seed=123,
    method='random'
)
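A seeded random split like the one requested above can be sketched as follows. Reading `p_tvt` as the (train, validation, test) fractions and naming the pieces `"train"`, `"valid"` and `"test"` are assumptions; only the `'train'` key is confirmed by this notebook, and `random_split` is a hypothetical helper.

```python
import random
from typing import Dict, List, Sequence, Tuple


def random_split(
    ids: Sequence[int], p_tvt: Tuple[float, float, float], seed: int
) -> Dict[str, List[int]]:
    """Shuffle IDs with a fixed seed, then cut them into train/valid/test pieces."""
    if abs(sum(p_tvt) - 1.0) > 1e-8:
        raise ValueError("p_tvt must sum to 1")
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)  # local RNG: same seed, same split
    n_train = int(round(p_tvt[0] * len(shuffled)))
    n_valid = int(round(p_tvt[1] * len(shuffled)))
    return {
        "train": shuffled[:n_train],
        "valid": shuffled[n_train:n_train + n_valid],
        "test": shuffled[n_train + n_valid:],
    }
```

Using a dedicated `random.Random(seed)` instance rather than the global generator keeps the split reproducible regardless of other random calls in the notebook.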

In [None]:
train_loader = loaders['train']

The first entry of every batch is the batched graph, the second entry holds the targets and the third entry holds the QM9 IDs. You can always "draw" a batch of graphs via `networkx` using something like `nx.draw(batch[0].to_networkx(), with_labels=True)`.

In [None]:
for batch in train_loader:
    print(batch[0], '\n')
    print(batch[1], '\n')
    print(batch[2].tolist(), '\n')
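The three-entry batch structure is what a `collate_fn` passed to `torch.utils.data.DataLoader` would typically produce from a list of `(graph, target, qm9_id)` samples. The dependency-free sketch below shows only the regrouping step and is not the repository's collate function; in a real DGL pipeline the graphs would usually be merged with `dgl.batch`, the targets stacked with `torch.stack` and the IDs wrapped in `torch.tensor` (consistent with the `.tolist()` call above).

```python
from typing import Any, List, Sequence, Tuple


def collate(samples: Sequence[Tuple[Any, Any, int]]) -> Tuple[List[Any], List[Any], List[int]]:
    """Regroup (graph, target, qm9_id) samples into (graphs, targets, ids)."""
    graphs, targets, ids = zip(*samples)
    return list(graphs), list(targets), list(ids)
```

Such a function would be plugged in as `DataLoader(dataset, batch_size=..., collate_fn=collate)`.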

## Training
**Matt TODO**: training will look something like the following.

In [None]:
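# Placeholder: `MPNN` is not importable yet (see the commented-out import above),
# and `n_class_per_feature` is not defined anywhere in this notebook yet.
mod = MPNN(
    n_node_features=n_class_per_feature[0],
    n_edge_features=n_class_per_feature[1],
    output_size=4
)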

In [None]:
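# Placeholder: `g`, `n` and `e` presumably denote a batched graph and its
# node and edge feature tensors; none of them is defined in this notebook yet.
res = mod.forward(g, n, e)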
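Until the `MPNN` model is in place, the core idea of message passing can be illustrated with a toy example: each atom updates its state by adding a weighted sum of its neighbours' states, with the weight standing in for an edge feature. This is a hypothetical, scalar-valued sketch, not the repository's model; real MPNNs replace the fixed weighted sum with learned message and update functions and repeat the step several times.

```python
from typing import List, Sequence, Tuple


def message_passing_step(
    node_features: Sequence[float],
    edges: Sequence[Tuple[int, int]],
    edge_weights: Sequence[float],
) -> List[float]:
    """One sum-aggregation step: h_i' = h_i + sum_j w_ij * h_j over neighbours j."""
    messages = [0.0] * len(node_features)
    for (i, j), w in zip(edges, edge_weights):
        # Treat every bond as undirected: each end receives a message from the other.
        messages[i] += w * node_features[j]
        messages[j] += w * node_features[i]
    return [h + m for h, m in zip(node_features, messages)]
```

For a three-atom chain with states `[1.0, 2.0, 3.0]` and unit weights, the middle atom receives messages from both ends while the terminal atoms hear only from the middle one.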