Skip to content

Commit

Permalink
Merge pull request #309 from Guillawme/main
Browse files Browse the repository at this point in the history
Make cryodrgn analyze produce a plot of the learning curve, fix #304
  • Loading branch information
michal-g authored Nov 13, 2023
2 parents 9e77638 + 9d2e65c commit b5d9e4b
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 7 deletions.
16 changes: 15 additions & 1 deletion cryodrgn/commands/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def analyze_z1(z, outdir, vg):
vg.gen_volumes(outdir, ztraj)


def analyze_zN(z, outdir, vg, skip_umap=False, num_pcs=2, num_ksamples=20):
def analyze_zN(z, outdir, vg, workdir, epoch, skip_umap=False, num_pcs=2, num_ksamples=20):
zdim = z.shape[1]

# Principal component analysis
Expand Down Expand Up @@ -135,6 +135,17 @@ def analyze_zN(z, outdir, vg, skip_umap=False, num_pcs=2, num_ksamples=20):
# Make some plots
logger.info("Generating plots...")

# Plot learning curve
loss = analysis.parse_loss(f"{workdir}/run.log")
plt.figure(figsize=(4, 4))
plt.plot(loss)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.axvline(x=epoch, linestyle="--", color="black", label=f"Epoch {epoch}")
plt.legend()
plt.tight_layout()
plt.savefig(f"{outdir}/learning_curve_epoch{epoch}.png")

def plt_pc_labels(x=0, y=1):
plt.xlabel(f"PC{x+1} ({pca.explained_variance_ratio_[x]:.2f})")
plt.ylabel(f"PC{y+1} ({pca.explained_variance_ratio_[y]:.2f})")
Expand Down Expand Up @@ -352,6 +363,7 @@ def main(args):
t1 = dt.now()
E = args.epoch
workdir = args.workdir
epoch = args.epoch
zfile = f"{workdir}/z.{E}.pkl"
weights = f"{workdir}/weights.{E}.pkl"
cfg = (
Expand Down Expand Up @@ -391,6 +403,8 @@ def main(args):
z,
outdir,
vg,
workdir,
epoch,
skip_umap=args.skip_umap,
num_pcs=args.pc,
num_ksamples=args.ksample,
Expand Down
6 changes: 4 additions & 2 deletions cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,10 @@
"source": [
"loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n",
"plt.plot(loss)\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('loss')"
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n",
"plt.legend()"
]
},
{
Expand Down
26 changes: 26 additions & 0 deletions cryodrgn/templates/cryoDRGN_figures_template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,32 @@
"umap = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/umap.pkl')"
]
},
{
"cell_type": "markdown",
"id": "83324d44-767e-47e2-a3c7-cab9b430fab5",
"metadata": {},
"source": [
"# Plot learning curve"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6ba1e22-696c-4c46-a2de-48c386ef8526",
"metadata": {},
"outputs": [],
"source": [
"loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n",
"plt.figure(figsize=(4, 4))\n",
"plt.plot(loss)\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n",
"plt.legend()\n",
"plt.tight_layout()\n",
"#plt.savefig(f\"{WORKDIR}/analyze.{EPOCH}/learning_curve_epoch{EPOCH}.png\")"
]
},
{
"cell_type": "markdown",
"id": "9cce7848",
Expand Down
6 changes: 4 additions & 2 deletions cryodrgn/templates/cryoDRGN_filtering_template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,10 @@
"source": [
"loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n",
"plt.plot(loss)\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('loss')"
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n",
"plt.legend()"
]
},
{
Expand Down
6 changes: 4 additions & 2 deletions cryodrgn/templates/cryoDRGN_viz_template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,10 @@
"source": [
"loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n",
"plt.plot(loss)\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('loss')"
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n",
"plt.legend()"
]
},
{
Expand Down

0 comments on commit b5d9e4b

Please sign in to comment.