In [None]:
import matplotlib.pyplot as plt
import numpy as np
import yaml
from cryojax.data import RelionParticleParameterDataset, RelionParticleStackDataset
from cryojax.image import irfftn, rfftn
from cryojax.image.operators import LowpassFilter
from cryojax_ensemble_refinement.internal import DatasetGeneratorConfig

# Let's load the config file

In [None]:
# Read the YAML configuration from disk; the file handle is only needed for parsing.
# NOTE(review): the file name here is spelled "geneation", while the command-line
# example further down uses "config_data_generation.yaml" -- confirm which name
# actually exists on disk.
with open("./config_data_geneation.yaml", "r") as f:
    config_json = yaml.safe_load(f)

# Build a DatasetGeneratorConfig from the parsed entries and keep its
# `model_dump()` output as a plain dict for key-based access in later cells.
config = dict(DatasetGeneratorConfig(**config_json).model_dump())

In [None]:
# Show the loaded configuration dictionary as this cell's output.
config

The image generation can be run from the command line as follows:

`cryojax_er generate_data --config config_data_generation.yaml`

# Visualize the images!

In [None]:
# Particle parameters are read from the STAR file and RELION project directory
# named in the config; `loads_envelope` is switched off.
parameter_dataset = RelionParticleParameterDataset(
    path_to_starfile=config["path_to_starfile"],
    path_to_relion_project=config["path_to_relion_project"],
    loads_envelope=False,
)

# Wrap the parameter dataset in a stack dataset; later cells index it to get
# `.images` and `.parameters`.
stack_dataset = RelionParticleStackDataset(parameter_dataset)

In [None]:
# Frequency grid taken from the instrument configuration of the first particle.
frequency_grid = stack_dataset[0].parameters.instrument_config.frequency_grid_in_pixels

# Low-pass filter built on that grid with `frequency_cutoff_fraction=0.7`.
lowpass_filter = LowpassFilter(frequency_grid, frequency_cutoff_fraction=0.7)

In [None]:
# 2x2 grid of panels, one per displayed particle image.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))

# Take the images from the `0:4` slice of the stack and apply the low-pass
# filter between `rfftn` and `irfftn`.
images = irfftn(lowpass_filter(rfftn(stack_dataset[0:4].images)))

# Flatten the axes grid once and draw image `i` on panel `i` in grayscale.
for i, panel in enumerate(ax.flatten()):
    panel.imshow(images[i], cmap="gray")

## Metadata

Information about the ensemble and other generation parameters is saved to a metadata file.

In [None]:
# Load the metadata archive; the path is relative to the current working directory.
metadata = np.load("tutorial_data/metadata.npz")

# List the names of the arrays stored in the archive.
metadata.files

In [None]:
# Display the `ensemble_indices_per_image` array stored in the metadata archive.
metadata["ensemble_indices_per_image"]

In [None]:
# Read the index array once, then compute the fraction of its entries that are
# numerically equal (per `np.isclose`) to 0 and to 1.
ensemble_indices = metadata["ensemble_indices_per_image"]
weight_0 = np.mean(np.isclose(ensemble_indices, 0))
weight_1 = np.mean(np.isclose(ensemble_indices, 1))

# Show both fractions as the cell output.
weight_0, weight_1