Skip to content

Commit

Permalink
CHGNet-matgl implementation (#242)
Browse files Browse the repository at this point in the history
* ENH: fixing chgnet dset

* MAINT: create tensors in lg device

* MAINT: use register buffer in Potential and LightningPotential

* MAIN: rename chgnet graph feats

* FIX: clamp cos values to -1, 1 with eps

* ENH: start implementing chgnetdset

* Fix loading graphs

* use dgl path attrs in chgnet dataset

* TST: add chgnetdataset test and fix errors

* TST assert that unnormalized predictions are not the same

* TST: clamp cos values to -1, 1 with eps in tests

* ENH: use torch.nan for None magmoms

* BUG: fix setting lg node data

* use no_grad in directed line graph

* FIX: set lg data using num nodes

* TST: test up to 4 decimals

* MAINT: update to renamed DEFAULT_ELEMENTS

* FIX: directed lg compatibility

* maint: update to new dataset interface

* MAINT: update to new dataset interface

* TST: fix graph test

* MAINT: minor edit in directed line graph

* update to use dtype interface

* add tol to threebody cutoff

* add tol to threebody cutoff

* FiX: remove tol and set pbc_offshift to float64

* ENH: chunked chgnet dataset

* remove state attr in has_cache

* fix chunk_sizes

* trange when loading indices

* singular keys in collate

* hard code label keys

* run pre-commit

* change chgnet default elements

* FIX: create nan tensor for missing magmoms

* add tol to threebody cutoff

* add tol to threebody cutoff

* FiX: remove tol and set pbc_offshift to float64

* ENH: chunked chgnet dataset

* remove state attr in has_cache

* fix chunk_sizes

* trange when loading indices

* singular keys in collate

* hard code label keys

* run pre-commit

* change chgnet default elements

* FIX: nan tensor shape

* FIX: allow skipping nan tensors

* add xavier normal and update chunked dataset

* fix getitem

* fix getitem

* fix getitem

* fix getitem

* fix getitem

* fix getitem

* huber loss

* MAINT: use torch instead of numpy

* MAINT: keep onehot matrix as attribute

* MAINT: remove unnecessary statements

* MAINT: remove unnecessary statements

* MAINT: onehot as buffer

* MAINT: property offset as buffer

* MAINT: onehot as buffer

* MAINT: property offset as buffer

* change order in init

* TST update tests

* ENH use lstsq to avoid constructing full normal eqs

* change order in init

* TST update tests

* ENH use lstsq to avoid constructing full normal eqs

* remove numpy import

* remove print

* STY: fix lint

* FIX: backwards compat with pre-trained models

* ENH: raise load_model error from baseexception

* TST: fix atomref tests

* STY: ruff

* FIX: use tuple in isinstance for 3.9 compat

* remove numpy import

* STY: ruff

* remove numpy import

* STY: ruff

* remove assert in compat (fails for some batched graphs)

* ENH: messy graphnorm mess

* FIX: fix allow missing labels

* use lg num_nodes() directly

* use lg num_nodes() directly

* do not assert

* FIX: fix ensuring line graph for bonds right at cutoff

* remove numpy import

* STY: ruff

* Remove wheel and release.

* Bump pymatgen from 2023.9.2 to 2023.9.10 (#162)

Bumps [pymatgen](https://github.com/materialsproject/pymatgen) from 2023.9.2 to 2023.9.10.
- [Release notes](https://github.com/materialsproject/pymatgen/releases)
- [Changelog](https://github.com/materialsproject/pymatgen/blob/master/CHANGES.md)
- [Commits](materialsproject/pymatgen@v2023.9.2...v2023.9.10)

---
updated-dependencies:
- dependency-name: pymatgen
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Add united test for trainer.test and description in the example (#165)

* ENH: allow skipping label keys

* use tuple

* ENH: allow skipping label keys

* use tuple

* use skip labels in chunked dataset

* add empty axis to magmoms

* add empty axis to magmoms

* ENH: graph norm implementation

* TST: add graph_norm test

* remove adding extra axis to magmoms

* remove adding extra axis to magmoms

* add skip label keys to chunked dataset

* fix chunked dset

* add OOM dataset

* len w state_attr

* int idx

* increase compatibility tol

* lintings

* STY: fix some linting errors

* STY: fix mypy errors

* remove numpy import

* STY: ruff

* remove numpy import

* STY: ruff

* TYP: use Sequence instead of list

* lint

* MAINT: use sequential in MLP

* ENH: norm gated MLP

* MAINT: use sequential in MLP

* store linear layers and activation separately in MLP

* use MLP in gated MLP

* remove unnecessary Sequential

* correct magmom training index!

* revert magmom index bc it was correct!

* ENH: graphnorm in mlp and gmlp

* remove numpy import

* STY: ruff

* remove numpy import

* STY: ruff

* FIX: remove repeated bond expansion

* hack to load new state dicts in PL checkpoints

* allow site_wise loss options

* only set grad enabled in forward

* adapt core to allow normalization of different layers

* remove some TODOS

* allow normalization in chgnet

* always normalize last

* always normalize last

* fix normalization inputs

* fix mlp forward

* fix mlp forward

* messy norm

* allow norm kwargs and allow batching by edges or nodes in graphnorm

* test graphnorm

* graph norm in chgnet

* allow layernorm in chgnet

* allow layernorm in chgnet

* rename args

* rename args

* fix mypy errors

* add tolerance in lg compatibility

* add tolerance in lg compatibility

* raise runtime error for incompatible graph

* raise runtime error for incompatible graph

* create tensors on same device in norm

* create tensors on same device in norm

* update chgnet to use new line graph interface

* update chgnet paper link

* update line graph in dataset

* no bias in output of conv layers

* some docstrings

* moved mlp_out from InteractionBlock to ConvFunctions and added non-linearity

* fix typo

* moved out_layer to linear

* solved bug

* solved bug

* removed normalization from bondgraph layer

* uploaded pretrained model and modified ASE interface

* fix linting

* fixed chgnet dataset by adding lattice

* hot fix

* add frac_coords to pre-processed graphs

* hot fix

* solved bug

* remove ignore model

* add 11M model weights

* renamed pretrained weights

* Adding CHGNet-matgl implementation

* corrected texts and comments

* fix more texts

* more texts fixes

* refactor CHGNet path in test

* fixed linting

* fixed texts

* remove unused CHGNetDataset

* restructure matgl modules for CHGNet implementations

* fix ruff

* update model versioning for Potential class

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: lbluque <lbluque@berkeley.edu>
Co-authored-by: Shyue Ping Ong <shyuep@users.noreply.github.com>
Co-authored-by: Shyue Ping Ong <sp@ong.ai>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com>
Co-authored-by: lbluque <lbluque@meta.com>
Co-authored-by: kenko911 <kenko911@gmail.com>
  • Loading branch information
8 people committed May 6, 2024
1 parent 8ed58f9 commit 3d94dd4
Show file tree
Hide file tree
Showing 38 changed files with 2,334 additions and 1,357 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
"predicted = []\n",
"mp = []\n",
"os.environ[\"MPRESTER_MUTE_PROGRESS_BARS\"] = \"true\"\n",
"mpr = MPRester(\"FwTXcju8unkI2VbInEgZDTN8coDB6S6U\")\n",
"mpr = MPRester(\"YOUR_API_KEY\")\n",
"\n",
"# Load the pre-trained M3GNet Potential\n",
"pot = matgl.load_model(\"M3GNet-MP-2021.2.8-PES\")\n",
Expand Down Expand Up @@ -265,7 +265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.14"
},
"vscode": {
"interpreter": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
"\n",
"import matgl\n",
"from matgl.ext.pymatgen import Structure2Graph, get_element_list\n",
"from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_efs\n",
"from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes\n",
"from matgl.models import M3GNet\n",
"from matgl.utils.training import PotentialLightningModule\n",
"from matgl.config import DEFAULT_ELEMENTS\n",
"\n",
"# To suppress warnings for clearer output\n",
"warnings.simplefilter(\"ignore\")"
Expand Down Expand Up @@ -123,7 +124,7 @@
},
"outputs": [],
"source": [
"element_types = get_element_list(structures)\n",
"element_types = DEFAULT_ELEMENTS\n",
"converter = Structure2Graph(element_types=element_types, cutoff=5.0)\n",
"dataset = MGLDataset(\n",
" threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels, include_line_graph=True\n",
Expand All @@ -134,7 +135,7 @@
" shuffle=True,\n",
" random_state=42,\n",
")\n",
"my_collate_fn = partial(collate_fn_efs, include_line_graph=True)\n",
"my_collate_fn = partial(collate_fn_pes, include_line_graph=True)\n",
"train_loader, val_loader, test_loader = MGLDataLoader(\n",
" train_data=train_data,\n",
" val_data=val_data,\n",
Expand Down Expand Up @@ -239,7 +240,7 @@
"source": [
"# save trained model\n",
"model_export_path = \"./trained_model/\"\n",
"model.save(model_export_path)\n",
"lit_module.model.save(model_export_path)\n",
"\n",
"# load trained model\n",
"model = matgl.load_model(path=model_export_path)"
Expand Down Expand Up @@ -335,7 +336,7 @@
"source": [
"# save trained model\n",
"model_save_path = \"./finetuned_model/\"\n",
"model_pretrained.save(model_save_path)\n",
"lit_module_finetune.model.save(model_save_path)\n",
"# load trained model\n",
"trained_model = matgl.load_model(path=model_save_path)"
]
Expand Down Expand Up @@ -382,7 +383,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 3d94dd4

Please sign in to comment.