Skip to content

Commit

Permalink
Add README
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586754774
  • Loading branch information
JesseFarebro committed Feb 20, 2024
1 parent 088022c commit 3d4dacd
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 11 deletions.
Binary file added .github/assets/method-overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
93 changes: 93 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,97 @@
[![Unittests](https://github.com/google/putting-dune/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google/putting-dune/actions/workflows/pytest_and_autopublish.yml)
[![PyPI version](https://badge.fury.io/py/putting_dune.svg)](https://badge.fury.io/py/putting_dune)

This repository contains the code of the paper ["Learning and Controlling Silicon Dopant Transitions in Graphene using Scanning Transmission Electron Microscopy"](https://arxiv.org/abs/2311.17894).
The so-called Putting Dune, provides the simulator and methods to learn transition rates of 3-fold silicon-doped graphene.

<img alt="Method Overview" src="/.github/assets/method-overview.png" />

## Data Representation

The protocol buffer data representation used in Putting Dune can be found in
`putting_dune/putting_dune.proto`. You can find a one-to-one correspondence
of the protocol buffer messages and a Python dataclass in
`putting_dune/microscope_utils.py`.

When performing alignment and rate learning we expect a sequence of serialized
`Trajectory` objects. As an example, we can serialize trajectories as follows:

```py
from putting_dune import io as pdio
from putting_dune import microscope_utils

trajectories = [
microscope_utils.Trajectory(
observations=[
microscope_utils.Observation(...),
...,
],
),
...,
]

pdio.write_records("my-recorded-trajectories.tfrecords", trajectories)
```

## Image Aligner

The first step in our pipeline is to perform image alignment.
To train the image alignment model you can follow the steps in
`putting_dune/image_alignment/train.py`.

Once the image aligner is trained you can perform image alignment on the
recorded trajectories via the script `putting_dune/pipeline/align_trajectories.py`.
For example,

```sh
python -m putting_dune.pipeline.align_trajectories \
--source_path my-recorded-trajectories.tfrecords \
--target_path my-aligned-recorded-trajectories.tfrecords \
--aligner_path my_trained_aligner \
--alignment_iterations 5
```

## Rate Learner

Once the trajectories have been aligned you can now train the rate model.
This can be done with `putting_dune/pipeline/train_rate_learner.py`.
For example,

```sh
python -m putting_dune.pipeline.train_rate_learner \
--source_path my-aligned-recorded-trajectories.tfrecords \
--workdir my-rate-model
```

Once training is complete there will be various plots and checkpoints
that are saved to the working directory. This model can then be used
to derive a greedy controller or predict learned rates.

## Citation

```bib
@article{schwarzer23stem,
author = {
Max Schwarzer and
Jesse Farebrother and
Joshua Greaves and
Ekin Dogus Cubuk and
Rishabh Agarwal and
Aaron Courville and
Marc G. Bellemare and
Sergei Kalinin and
Igor Mordatch and
Pablo Samuel Castro and
Kevin M. Roccapriore
},
title = {Learning and Controlling Silicon Dopant Transitions in Graphene
using Scanning Transmission Electron Microscopy},
journal = {CoRR},
volume = {abs/2311.17894},
year = {2023},
}
```

## Note

*This is not an officially supported Google product.*
14 changes: 6 additions & 8 deletions putting_dune/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
import collections
import copy
import functools
import io
import typing
from typing import Any, Deque, Optional, Sequence, Tuple
import urllib
import zipfile

import cv2
Expand Down Expand Up @@ -460,19 +458,19 @@ def process_detection_predictions(
return microscope_utils.AtomicGridMicroscopeFrame(grid)

@classmethod
def from_url(
def from_path(
cls,
url: str = 'https://storage.googleapis.com/spr_data_bucket_public/alignment/20230403-image-aligner.zip',
path: epath.Path,
workdir: Optional[str] = None,
reload: bool = False,
**kwargs,
) -> 'ImageAligner':
"""Construct model from URL.
Args:
url: Model URL, expected to be a zip file.
path: Model path, expected to be a zip file.
workdir: Optional, locatioon (e.g., temp dir) to extract weights to.
reload: Optional, whether to force-redownload aligner.
reload: Optional, whether to force-reload the aligner.
**kwargs: Optional arguments for the ImageAligner.
Returns:
Expand All @@ -484,8 +482,8 @@ def from_url(
model_path = epath.Path(workdir) / 'model_weights' / 'image-alignment-model'
if not model_path.exists() or reload:
model_path.mkdir(parents=True, exist_ok=True)
with urllib.request.urlopen(url) as request:
with zipfile.ZipFile(io.BytesIO(request.read())) as model_zip:
with path.open('rb') as model_zip_fp:
with zipfile.ZipFile(model_zip_fp) as model_zip:
model_zip.extractall(model_path.parent)

return ImageAligner(model_path=model_path, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions putting_dune/pipeline/align_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Args:

source_path: epath.Path
target_path: epath.Path
aligner_url: str
aligner_path: epath.Path
history_length: int = 5
alignment_iterations: int = 1
base_step_size: float = 1
Expand Down Expand Up @@ -101,8 +101,8 @@ def main(args: Args) -> None:
trajectories.extend(pdio.read_records(file, microscope_utils.Trajectory))

with tempfile.TemporaryDirectory() as tmpdir:
aligner = alignment.ImageAligner.from_url(
args.aligner_url, workdir=tmpdir, hybrid=args.hybrid
aligner = alignment.ImageAligner.from_path(
args.aligner_path, workdir=tmpdir, hybrid=args.hybrid
)

aligned_trajectories = []
Expand Down

0 comments on commit 3d4dacd

Please sign in to comment.