### Train

In [None]:
from gnn_lib.models import CGCNN
from gnn_lib.data_utils import get_atomic_types_mapper, build_dataloader
from gnn_lib.utils import load_config
from gnn_lib.training import Trainer

# Load the experiment configuration (see configs/config_cgcnn.yaml).
config = load_config('config.yaml')

# Build the CGCNN model from the config and the atomic-type mapper derived
# from that same config, then report its size.
mapper = get_atomic_types_mapper(config)
model = CGCNN.from_config(mapper, config)
print(model.size())

# One loader per split, built in the same order as before: train, then val.
train_loader = build_dataloader(config, 'train')
val_loader = build_dataloader(config, 'val')

# Fit the model. What `verbose=False` controls is defined in
# gnn_lib.training — presumably it quiets progress output (confirm).
trainer = Trainer(model, config, verbose=False)
trainer.train(train_loader, val_loader)

### Eval

In [None]:
import torch

# Use the GPU when one is available, otherwise fall back to the CPU so the
# evaluation cells below still run on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
import torch
from gnn_lib.models import CGCNN
from gnn_lib.utils import load_config
from gnn_lib.data_utils import get_atomic_types_mapper, build_dataloader

# Rebuild the architecture from the same config used for training, then
# restore the saved best checkpoint.
config = load_config('config.yaml')
mapper = get_atomic_types_mapper(config)
model = CGCNN.from_config(mapper, config)
# NOTE(review): the return value of from_checkpoint is discarded, so this
# assumes it loads the weights into `model` in place — confirm. If it is an
# alternate constructor instead, assign its result back to `model`.
model.from_checkpoint('checkpoints/best_checkpoint.pt')
test_loader = build_dataloader(config, 'test')

# Run inference over the test split, collecting one flat list of predictions
# and one flat list of targets (as CPU numpy values) for the metrics cell.
preds = []
labels = []
model.to(device)  # `device` is defined in the previous cell
model.eval()  # inference mode for layers such as dropout/batch-norm (assumes a torch.nn.Module)
with torch.no_grad():  # no autograd graph is built, so outputs need no .detach()
    for batch in test_loader:
        out = model(batch.to(device))
        preds.extend(out.cpu().flatten().numpy())
        # Flatten the targets exactly like the predictions. A (B, 1)-shaped
        # target would otherwise add 1-element arrays instead of scalars, and
        # these silently broadcast against the flat predictions in the metrics.
        labels.extend(batch.energy.cpu().flatten().numpy())

In [None]:
from gnn_lib.metrics import get_metrics

# Score the collected test-set predictions against their labels. The argument
# order is (labels, preds) — presumably (y_true, y_pred); confirm against
# gnn_lib.metrics. As the cell's last expression, the notebook displays it.
get_metrics(
    labels,
    preds,
)