import os
from IPython.display import Image, display, Code
os.chdir('..')

import hydra
from omegaconf import DictConfig, OmegaConf

from chemtorch.utils import DataSplit, load_model, set_seed
from chemtorch.utils.hydra import safe_instantiate

OmegaConf.register_new_resolver("eval", eval)

# ChemTorch

## How to add a new dataset

#### 1. Drag and drop your dataset into the `data` folder, e.g. `data/e2/barriers/forward/data.csv`
#### 2. Create a config file in the `conf/data_pipeline` folder:
```yaml
# conf/data_pipeline/e2.yaml
_target_: chemtorch.data_pipeline.SimpleDataPipeline

defaults:
  - data_source: single_csv_source                          # how to read the data
  - column_mapper: column_filter_and_rename                 # how to map the columns    
  - data_splitter: ratio_splitter                           # how to split the data 

data_source:
  data_path: "data/e2/barriers/forward/data.csv"

column_mapper:
  column_mapping:
    smiles: "smiles"
    label: "ea"

data_splitter:
  train_ratio: 0.9
  val_ratio: 0.05
  test_ratio: 0.05
```
#### 3. Run it from the command line (or set the option in a config file, or create a new config file):
```bash
python chemtorch_cli.py +experiment=graph dataset.subsample=0.05 data_pipeline=e2
```
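
To give some intuition for what the `column_mapper` and `data_splitter` entries in `e2.yaml` ask for, here is a minimal, dependency-free sketch of a filter-and-rename step followed by a ratio split. It is not ChemTorch's implementation: the helper names `filter_and_rename` and `ratio_split` are hypothetical. The sketch also assumes that `column_mapping` maps the new column name (left) to the source CSV column (right), and that the splitter shuffles rows before slicing them by ratio.

```python
import random


def filter_and_rename(rows, column_mapping):
    """Keep only the mapped columns, renamed to their new names.

    Assumption: column_mapping is {new_name: source_column}, as in the
    YAML example (smiles <- "smiles", label <- "ea").
    """
    return [{new: row[src] for new, src in column_mapping.items()} for row in rows]


def ratio_split(rows, train_ratio, val_ratio, test_ratio, seed=0):
    """Shuffle rows with a fixed seed, then slice them into train/val/test."""
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-8:
        raise ValueError("ratios must sum to 1")
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    n_train = int(train_ratio * len(rows))
    n_val = int(val_ratio * len(rows))
    train = rows[:n_train]
    val = rows[n_train:n_train + n_val]
    test = rows[n_train + n_val:]  # any remainder from truncation goes to test
    return train, val, test


# Toy rows standing in for data/e2/barriers/forward/data.csv
rows = [{"smiles": f"C{i}", "ea": float(i), "extra": "dropped"} for i in range(100)]
mapped = filter_and_rename(rows, {"smiles": "smiles", "label": "ea"})
train, val, test = ratio_split(mapped, 0.9, 0.05, 0.05)
print(len(train), len(val), len(test))  # 90 5 5
```

Because of integer truncation, leftover rows land in the test split in this sketch. The real `ratio_splitter` may round differently.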


## How to add your own model

#### 1. Drag and drop your model file into the `src/chemtorch/model` folder, e.g. `src/chemtorch/model/gat.py`
#### 2. Create a config file in the `conf/model` folder:
```yaml
# conf/model/gat.yaml
_target_: chemtorch.model.gnn.GNN

defaults:
  - encoder: linear_enc
  - layer_stack: layer_stack
  - layer_stack/gnn_block/gat_block@layer_stack.layer
  - pool: global_pool
  - head: mlp

hidden_channels: 64
```
#### 3. Run it from the command line (or set the option in a config file, or create a new config file):
```bash
python chemtorch_cli.py +experiment=graph dataset.subsample=0.05 model=gat
```
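
In both configs, `_target_` is the dotted import path of the class that Hydra instantiates, so the config for a new model should point `_target_` at the class defined in your module. The following is only a rough mental model. Hydra's real `hydra.utils.instantiate` also handles `_partial_`, `_convert_`, `_args_` and other options. In this model, instantiation imports the target, first builds any nested entries that carry their own `_target_`, and passes everything else as keyword arguments. The hypothetical `toy_instantiate` below shows that recursion with standard-library targets so that it runs without ChemTorch. In `gat.yaml`, the composed `encoder`, `layer_stack`, `pool` and `head` entries would play the role of the nested configs, assuming each declares its own `_target_`.

```python
import importlib
from typing import Any


def toy_instantiate(cfg: Any) -> Any:
    """Build objects from a nested dict config, Hydra-style (heavily simplified)."""
    if isinstance(cfg, dict):
        # Recurse first so nested configs are built before their parent.
        built = {k: toy_instantiate(v) for k, v in cfg.items() if k != "_target_"}
        if "_target_" not in cfg:
            return built  # plain mapping: nothing to instantiate at this level
        module_path, _, attr = cfg["_target_"].rpartition(".")
        target = getattr(importlib.import_module(module_path), attr)
        return target(**built)  # remaining keys become keyword arguments
    if isinstance(cfg, list):
        return [toy_instantiate(v) for v in cfg]
    return cfg  # scalars pass through unchanged


# A nested config: the parent receives an already-built child object.
cfg = {
    "_target_": "types.SimpleNamespace",
    "child": {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4},
    "hidden_channels": 64,
}
obj = toy_instantiate(cfg)
print(obj.child, obj.hidden_channels)  # 3/4 64
```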

## Simple sweeps

```bash
python chemtorch_cli.py --multirun +experiment=graph dataset.subsample=0.05 model=gat,dmpnn data_pipeline=e2,sn2
```
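
With `--multirun`, Hydra's default basic sweeper launches one job per combination of the comma-separated values, so the sweep above runs 2 models × 2 data pipelines = 4 jobs. The expansion amounts to a Cartesian product, sketched here with `itertools.product`. `expand_sweep` is a hypothetical helper, and its naive comma split ignores Hydra's richer sweep syntax (quoting, `range(...)`, `choice(...)`).

```python
from itertools import product


def expand_sweep(overrides):
    """Expand 'key=a,b' overrides into one list of 'key=value' strings per job."""
    keys, choices = [], []
    for override in overrides:
        key, _, value = override.partition("=")
        keys.append(key)
        choices.append(value.split(","))  # naive: no quoting or sweep functions
    return [[f"{k}={v}" for k, v in zip(keys, combo)] for combo in product(*choices)]


jobs = expand_sweep([
    "+experiment=graph",
    "dataset.subsample=0.05",
    "model=gat,dmpnn",
    "data_pipeline=e2,sn2",
])
for job in jobs:
    print(" ".join(job))
```

The printed order follows `itertools.product`. Hydra may number or order the jobs differently.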

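## Override parameters

Overrides of the form `group=option` (e.g. `dataset/transform=randomwalk`, `model/encoder=rwpe_and_linear_enc`) select a different config from a config group. Dotted keys (e.g. `model.hidden_channels=128`) override a single value in the composed config: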

```bash
python chemtorch_cli.py --multirun +experiment=graph dataset.subsample=0.05 model=gat data_pipeline=e2 dataset/transform=randomwalk model/encoder=rwpe_and_linear_enc model.hidden_channels=128
```
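
A dotted value override such as `model.hidden_channels=128` walks the nested config key by key and replaces a single leaf. The sketch below uses the hypothetical helper `apply_dotted_override` to illustrate that walk on a plain dict. It mirrors Hydra's rule that a key not already present needs a `+` prefix (as in `+experiment=graph`) by refusing to create new keys. Its type parsing is far cruder than Hydra's override grammar.

```python
from typing import Any, Dict


def apply_dotted_override(cfg: Dict[str, Any], override: str) -> None:
    """Apply an 'a.b.c=value' override to a nested dict in place (simplified)."""
    dotted_key, _, raw = override.partition("=")
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node[key]  # raises KeyError if an intermediate key is missing
    if leaf not in node:
        # Like Hydra's default behaviour, refuse to add keys that are not already there.
        raise KeyError(f"{dotted_key!r} not in config; use '+{dotted_key}=...' to add it")
    # Crude type parsing: try int, then float, else keep the raw string.
    for cast in (int, float):
        try:
            node[leaf] = cast(raw)
            return
        except ValueError:
            pass
    node[leaf] = raw


cfg = {"model": {"hidden_channels": 64}}
apply_dotted_override(cfg, "model.hidden_channels=128")
print(cfg["model"]["hidden_channels"])  # 128
```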