Skip to content

Commit

Permalink
Merge pull request #9 from ml-struct-bio/develop
Browse files Browse the repository at this point in the history
CryoDRGN-ET
Imagesource refactor

---------

Co-authored-by: Adam Lerer <alerer@fb.com>
Co-authored-by: Adam Lerer <adam.lerer@gmail.com>
Co-authored-by: Ramya Rangan <ramyar@Ramyas-MacBook-Pro.local>
Co-authored-by: Ramya Rangan <ramyar@vpn12-client-172-20-208-6.princeton.edu>
Co-authored-by: Ramya Rangan <rr1992@della-gpu.princeton.edu>
Co-authored-by: Ramya Rangan <ramyar@vpn12-client-172-20-208-132.princeton.edu>
Co-authored-by: Ramya Rangan <ramyar@vpn12-client-172-20-208-26.princeton.edu>
Co-authored-by: Ramya Rangan <ramyar@vpn12-client-172-20-214-131.princeton.edu>
Co-authored-by: Ryan Feathers <rf2366@princeton.edu>
Co-authored-by: Michal R. Grzadkowski <mgrzad@gmail.com>
  • Loading branch information
11 people committed Sep 6, 2023
2 parents 845d2b5 + 2bdbf83 commit d3e2da4
Show file tree
Hide file tree
Showing 28 changed files with 2,156 additions and 309 deletions.
23 changes: 15 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[![CI](https://github.com/zhonge/cryodrgn/actions/workflows/main.yml/badge.svg)](https://github.com/zhonge/cryodrgn/actions/workflows/main.yml)

# :snowflake::dragon: cryoDRGN: Deep Reconstructing Generative Networks for cryo-EM heterogeneous reconstruction
# :snowflake::dragon: cryoDRGN: Deep Reconstructing Generative Networks for cryo-EM and cryo-ET heterogeneous reconstruction

CryoDRGN is a neural network based algorithm for heterogeneous cryo-EM reconstruction. In particular, the method models a *continuous* distribution over 3D structures by using a neural network based representation for the volume.

Expand All @@ -10,19 +10,28 @@ The latest documentation for cryoDRGN is available [on gitbook](https://ez-lab.g

For any feedback, questions, or bugs, please file a Github issue, start a Github discussion, or email the [google group](https://groups.google.com/g/cryodrgn).

## New in Version 2.x
## New in Version 3.x

The official cryoDRGN2 release. Version 2.x includes new tools for ab initio reconstruction and significant codebase improvements.
The official [cryoDRGN-ET](https://www.biorxiv.org/content/10.1101/2023.08.18.553799v1) release for heterogeneous subtomogram analysis.

### Version 2.3
* [NEW] Heterogeneous reconstruction of subtomograms. See documentation [on gitbook](https://ez-lab.gitbook.io/cryodrgn/)
* [NEW] `cryodrgn direct_traversal` for making movies
* Updated `cryodrgn backproject_voxel` for voxel-based homogeneous reconstruction
* Major refactor of dataset loading for handling large datasets

### Previous versions

<details><summary>Version 2.3</summary>

* Model configuration files are now saved as human-readable config.yaml files (https://github.com/zhonge/cryodrgn/issues/235)
* Fix machine stamp in output .mrc files for better compatibility with downstream tools (https://github.com/zhonge/cryodrgn/pull/260)
* Better documentation of help flags in ab initio reconstruction tools (https://github.com/zhonge/cryodrgn/issues/258)
* [FIX] By default, window images in `cryodrgn abinit_homo` (now consistent with other reconstruction tools) (https://github.com/zhonge/cryodrgn/issues/258)
* [FIX] Reduce memory usage when using `--preprocessed` and `--ind` (https://github.com/zhonge/cryodrgn/pull/272)

### Version 2.2
</details>

<details><summary>Version 2.2</summary>

* [NEW] Tools for ab initio homogeneous and heterogeneous reconstruction:

Expand All @@ -43,9 +52,7 @@ The official cryoDRGN2 release. Version 2.x includes new tools for ab initio rec

* Note: we are working on a major refactor of data loading for handling large datasets for the next minor version (v2.4). This will entail an API change for the mrc.py library module


### Previous versions

</details>

<details><summary>Version 1.1.x</summary>

Expand Down
11 changes: 8 additions & 3 deletions analysis_scripts/fsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@ def main(args):
vol1 = ImageSource.from_file(args.vol1)
vol2 = ImageSource.from_file(args.vol2)

assert isinstance(vol1, np.ndarray)
assert isinstance(vol2, np.ndarray)
vol1 = vol1.images()
vol2 = vol2.images()

# assert isinstance(vol1, np.ndarray)
# assert isinstance(vol2, np.ndarray)

if args.mask:
mask = ImageSource.from_file(args.mask)
assert isinstance(mask, np.ndarray)
mask = mask.images()

# assert isinstance(mask, np.ndarray)
vol1 *= mask
vol2 *= mask

Expand Down
82 changes: 56 additions & 26 deletions analysis_scripts/plotfsc.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,72 @@
"""Plot FSC txtfile"""

import argparse

import matplotlib.pyplot as plt
import numpy as np
import argparse
import os


# Load data from file
def load_data(file):
data = np.loadtxt(file)
x = data[:, 0]
y = data[:, 1]
return x, y


# Plot data
def plot_data(x, y, label):
plt.plot(x, y, label=label)


def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("input", nargs="*", help="Input")
parser = argparse.ArgumentParser(description="Plot FSC data.")
parser.add_argument(
"-i", "--input", nargs="+", help="input cryoDRGN fsc text files", required=True
)
parser.add_argument(
"-t",
type=float,
default=0.143,
help="Cutoff for resolution estimation (default: %(default)s)",
"-a", "--angpix", type=float, default=0, help="physical pixel size in angstrom"
)
parser.add_argument("--labels", nargs="*", help="Labels for plotting")
parser.add_argument("-o")
parser.add_argument("-o", "--output", type=str, help="output file name")
return parser


def main(args):
labels = args.labels if args.labels else args.input
assert len(labels) == len(args.input)
for i, f in enumerate(args.input):
print(f)
x = np.loadtxt(f)
plt.plot(x[:, 0], x[:, 1], label=labels[i])
w = np.where(x[:, 1] < args.t)
print(w)
print(x[:, 0][w])
print(1 / x[:, 0][w])
plt.legend(loc="best")
plt.ylim((0, 1))
plt.ylabel("FSC")
plt.xlabel("frequency")
if args.o:
plt.savefig(args.o)
# Create a subplot
fig, ax = plt.subplots(figsize=(10, 5))

# Load and plot data from each file
for file in args.input:
x, y = load_data(file)
plot_data(x, y, os.path.basename(file))

ax.set_aspect(0.3) # Set the aspect ratio on the plot specifically

if args.angpix != 0:
freq = np.arange(1, 6) * 0.1
res = ["1/{:.1f}".format(val) for val in ((1 / freq) * args.angpix)]
print(res)
res_text = res
plt.xticks(np.arange(1, 6) * 0.1, res_text)
plt.xlabel("1/resolution (1/Å)")
plt.ylabel("Fourier shell correlation")
else:
plt.xlabel("Spatial Frequency")
plt.ylabel("Fourier shell correlation")

plt.ylim(0, 1.0)
plt.xlim(0, 0.5)

# Create the legend on the figure, not the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", prop={"size": 6})

plt.grid(True)

plt.tight_layout()
plt.subplots_adjust(right=0.8)

if args.output:
plt.savefig(args.output, dpi=300, bbox_inches="tight")
else:
plt.show()

Expand Down
1 change: 0 additions & 1 deletion cryodrgn/commands/abinit_het.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,6 @@ def main(args):

data = dataset.ImageDataset(
mrcfile=args.particles,
tilt_mrcfile=args.tilt,
norm=args.norm,
invert_data=args.invert_data,
ind=ind,
Expand Down
1 change: 0 additions & 1 deletion cryodrgn/commands/abinit_homo.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,6 @@ def main(args):
data = dataset.ImageDataset(
mrcfile=args.particles,
lazy=args.lazy,
tilt_mrcfile=args.tilt,
norm=args.norm,
invert_data=args.invert_data,
ind=args.ind,
Expand Down
66 changes: 41 additions & 25 deletions cryodrgn/commands/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import seaborn as sns
import cryodrgn
from cryodrgn import analysis, utils
from cryodrgn import analysis, utils, config

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -161,13 +161,13 @@ def plt_umap_labels_jointplot(g):
plt.savefig(f"{outdir}/z_pca.png")

# PCA -- Style 2 -- Scatter, with marginals
g = sns.jointplot(pc[:, 0], pc[:, 1], alpha=0.1, s=1, rasterized=True, height=4)
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], alpha=0.1, s=1, rasterized=True, height=4)
plt_pc_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/z_pca_marginals.png")

# PCA -- Style 3 -- Hexbin
g = sns.jointplot(pc[:, 0], pc[:, 1], height=4, kind="hex")
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], height=4, kind="hex")
plt_pc_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/z_pca_hexbin.png")
Expand All @@ -182,14 +182,19 @@ def plt_umap_labels_jointplot(g):

# Style 2 -- Scatter with marginal distributions
g = sns.jointplot(
umap_emb[:, 0], umap_emb[:, 1], alpha=0.1, s=1, rasterized=True, height=4
x=umap_emb[:, 0],
y=umap_emb[:, 1],
alpha=0.1,
s=1,
rasterized=True,
height=4,
)
plt_umap_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/umap_marginals.png")

# Style 3 -- Hexbin / heatmap
g = sns.jointplot(umap_emb[:, 0], umap_emb[:, 1], kind="hex", height=4)
g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], kind="hex", height=4)
plt_umap_labels_jointplot(g)
plt.tight_layout()
plt.savefig(f"{outdir}/umap_hexbin.png")
Expand Down Expand Up @@ -316,7 +321,7 @@ def plt_umap_labels_jointplot(g):
plt_pc_labels_jointplot(g, i - 1, i)
else:
g = sns.jointplot(
pc[:, i], pc[:, i + 1], alpha=0.1, s=1, rasterized=True, height=4
x=pc[:, i], y=pc[:, i + 1], alpha=0.1, s=1, rasterized=True, height=4
)
g.ax_joint.scatter(t, np.zeros(10), c="cornflowerblue", edgecolor="white")
plt_pc_labels_jointplot(g)
Expand Down Expand Up @@ -349,7 +354,7 @@ def main(args):
workdir = args.workdir
zfile = f"{workdir}/z.{E}.pkl"
weights = f"{workdir}/weights.{E}.pkl"
config = (
cfg = (
f"{workdir}/config.yaml"
if os.path.exists(f"{workdir}/config.yaml")
else f"{workdir}/config.pkl"
Expand Down Expand Up @@ -377,7 +382,7 @@ def main(args):
invert=args.invert,
vol_start_index=args.vol_start_index,
)
vg = VolumeGenerator(weights, config, vol_args, skip_vol=args.skip_vol)
vg = VolumeGenerator(weights, cfg, vol_args, skip_vol=args.skip_vol)

if zdim == 1:
analyze_z1(z, outdir, vg)
Expand All @@ -392,24 +397,35 @@ def main(args):
)

# copy over template if file doesn't exist
out_ipynb = f"{outdir}/cryoDRGN_viz.ipynb"
if not os.path.exists(out_ipynb):
logger.info("Creating jupyter notebook...")
ipynb = f"{cryodrgn._ROOT}/templates/cryoDRGN_viz_template.ipynb"
shutil.copyfile(ipynb, out_ipynb)
else:
logger.info(f"{out_ipynb} already exists. Skipping")
logger.info(out_ipynb)

# copy over template if file doesn't exist
out_ipynb = f"{outdir}/cryoDRGN_filtering.ipynb"
if not os.path.exists(out_ipynb):
logger.info("Creating jupyter notebook...")
ipynb = f"{cryodrgn._ROOT}/templates/cryoDRGN_filtering_template.ipynb"
shutil.copyfile(ipynb, out_ipynb)
cfg = config.load(cfg)
if cfg["model_args"]["encode_mode"] == "tilt":
out_ipynb = f"{outdir}/cryoDRGN_ET_viz.ipynb"
if not os.path.exists(out_ipynb):
logger.info("Creating jupyter notebook...")
ipynb = f"{cryodrgn._ROOT}/templates/cryoDRGN_ET_viz_template.ipynb"
shutil.copyfile(ipynb, out_ipynb)
else:
logger.info(f"{out_ipynb} already exists. Skipping")
logger.info(out_ipynb)
else:
logger.info(f"{out_ipynb} already exists. Skipping")
logger.info(out_ipynb)
out_ipynb = f"{outdir}/cryoDRGN_viz.ipynb"
if not os.path.exists(out_ipynb):
logger.info("Creating jupyter notebook...")
ipynb = f"{cryodrgn._ROOT}/templates/cryoDRGN_viz_template.ipynb"
shutil.copyfile(ipynb, out_ipynb)
else:
logger.info(f"{out_ipynb} already exists. Skipping")
logger.info(out_ipynb)

# copy over template if file doesn't exist
out_ipynb = f"{outdir}/cryoDRGN_filtering.ipynb"
if not os.path.exists(out_ipynb):
logger.info("Creating jupyter notebook...")
ipynb = f"{cryodrgn._ROOT}/templates/cryoDRGN_filtering_template.ipynb"
shutil.copyfile(ipynb, out_ipynb)
else:
logger.info(f"{out_ipynb} already exists. Skipping")
logger.info(out_ipynb)

# copy over template if file doesn't exist
out_ipynb = f"{outdir}/cryoDRGN_figures.ipynb"
Expand Down

0 comments on commit d3e2da4

Please sign in to comment.