From 29f8e07b6164dc392b38a26af1c465b4ee04d640 Mon Sep 17 00:00:00 2001 From: Guillaume Gaullier Date: Tue, 3 Oct 2023 08:40:31 +0200 Subject: [PATCH 1/3] Make cryodrgn analyze produce a plot of the learning curve, fix #304 --- cryodrgn/commands/analyze.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/cryodrgn/commands/analyze.py b/cryodrgn/commands/analyze.py index e0fb9e56..a87bac00 100644 --- a/cryodrgn/commands/analyze.py +++ b/cryodrgn/commands/analyze.py @@ -100,7 +100,7 @@ def analyze_z1(z, outdir, vg): vg.gen_volumes(outdir, ztraj) -def analyze_zN(z, outdir, vg, skip_umap=False, num_pcs=2, num_ksamples=20): +def analyze_zN(z, outdir, vg, workdir, epoch, skip_umap=False, num_pcs=2, num_ksamples=20): zdim = z.shape[1] # Principal component analysis @@ -135,6 +135,17 @@ def analyze_zN(z, outdir, vg, skip_umap=False, num_pcs=2, num_ksamples=20): # Make some plots logger.info("Generating plots...") + # Plot learning curve + loss = analysis.parse_loss(f"{workdir}/run.log") + plt.figure(figsize=(4, 4)) + plt.plot(loss) + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.axvline(x=epoch, linestyle="--", color="black", label=f"Epoch {epoch}") + plt.legend() + plt.tight_layout() + plt.savefig(f"{outdir}/learning_curve_epoch{epoch}.png") + def plt_pc_labels(x=0, y=1): plt.xlabel(f"PC{x+1} ({pca.explained_variance_ratio_[x]:.2f})") plt.ylabel(f"PC{y+1} ({pca.explained_variance_ratio_[y]:.2f})") @@ -352,6 +363,7 @@ def main(args): t1 = dt.now() E = args.epoch workdir = args.workdir + epoch = args.epoch zfile = f"{workdir}/z.{E}.pkl" weights = f"{workdir}/weights.{E}.pkl" cfg = ( @@ -391,6 +403,8 @@ def main(args): z, outdir, vg, + workdir, + epoch, skip_umap=args.skip_umap, num_pcs=args.pc, num_ksamples=args.ksample, From ff503278850a88ec7a8c1da0cb16eab220f7548d Mon Sep 17 00:00:00 2001 From: Guillaume Gaullier Date: Wed, 11 Oct 2023 09:17:09 +0200 Subject: [PATCH 2/3] Update learning curve plot in Jupyter notebook templates --- .../templates/cryoDRGN_ET_viz_template.ipynb | 6 +++-- .../templates/cryoDRGN_figures_template.ipynb | 26 +++++++++++++++++++ .../cryoDRGN_filtering_template.ipynb | 6 +++-- .../templates/cryoDRGN_viz_template.ipynb | 6 +++-- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb b/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb index eab3df74..471ab338 100644 --- a/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb @@ -254,8 +254,10 @@ "source": [ "loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n", "plt.plot(loss)\n", - "plt.xlabel('epoch')\n", - "plt.ylabel('loss')" + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.axvline(x=epoch, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.legend()" ] }, { diff --git a/cryodrgn/templates/cryoDRGN_figures_template.ipynb b/cryodrgn/templates/cryoDRGN_figures_template.ipynb index eefb2a49..cc16dc09 100644 --- a/cryodrgn/templates/cryoDRGN_figures_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_figures_template.ipynb @@ -57,6 +57,32 @@ "umap = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/umap.pkl')" ] }, + { + "cell_type": "markdown", + "id": "83324d44-767e-47e2-a3c7-cab9b430fab5", + "metadata": {}, + "source": [ + "# Plot learning curve" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6ba1e22-696c-4c46-a2de-48c386ef8526", + "metadata": {}, + "outputs": [], + "source": [ + "loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n", + "plt.figure(figsize=(4, 4))\n", + "plt.plot(loss)\n", + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.axvline(x=epoch, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.legend()\n", + "plt.tight_layout()\n", + "#plt.savefig(f\"{WORKDIR}/analyze.{EPOCH}/learning_curve_epoch{EPOCH}.png\")" + ] + }, { "cell_type": "markdown", "id": "9cce7848", diff --git a/cryodrgn/templates/cryoDRGN_filtering_template.ipynb b/cryodrgn/templates/cryoDRGN_filtering_template.ipynb index 775355cc..0465dfc5 100644 --- a/cryodrgn/templates/cryoDRGN_filtering_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_filtering_template.ipynb @@ -317,8 +317,10 @@ "source": [ "loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n", "plt.plot(loss)\n", - "plt.xlabel('epoch')\n", - "plt.ylabel('loss')" + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.axvline(x=epoch, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.legend()" ] }, { diff --git a/cryodrgn/templates/cryoDRGN_viz_template.ipynb b/cryodrgn/templates/cryoDRGN_viz_template.ipynb index d0a7b814..80d0e277 100644 --- a/cryodrgn/templates/cryoDRGN_viz_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_viz_template.ipynb @@ -236,8 +236,10 @@ "source": [ "loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n", "plt.plot(loss)\n", - "plt.xlabel('epoch')\n", - "plt.ylabel('loss')" + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.axvline(x=epoch, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.legend()" ] }, { From 9d2e65c1c353cd006b0612ae342a4da4949083dd Mon Sep 17 00:00:00 2001 From: Guillaume Gaullier Date: Wed, 18 Oct 2023 09:12:12 +0200 Subject: [PATCH 3/3] Fix a variable name in learning curve plots in notebooks --- cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb | 2 +- cryodrgn/templates/cryoDRGN_figures_template.ipynb | 2 +- cryodrgn/templates/cryoDRGN_filtering_template.ipynb | 2 +- cryodrgn/templates/cryoDRGN_viz_template.ipynb | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb b/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb index 471ab338..eedf8a5d 100644 --- a/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb @@ -256,7 +256,7 @@ "plt.plot(loss)\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", - "plt.axvline(x=epoch, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", "plt.legend()" ] }, diff --git a/cryodrgn/templates/cryoDRGN_figures_template.ipynb b/cryodrgn/templates/cryoDRGN_figures_template.ipynb index cc16dc09..a8d35ce6 100644 --- a/cryodrgn/templates/cryoDRGN_figures_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_figures_template.ipynb @@ -77,7 +77,7 @@ "plt.plot(loss)\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", - "plt.axvline(x=epoch, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", "plt.legend()\n", "plt.tight_layout()\n", "#plt.savefig(f\"{WORKDIR}/analyze.{EPOCH}/learning_curve_epoch{EPOCH}.png\")" diff --git a/cryodrgn/templates/cryoDRGN_filtering_template.ipynb b/cryodrgn/templates/cryoDRGN_filtering_template.ipynb index 0465dfc5..d67f0a1f 100644 --- a/cryodrgn/templates/cryoDRGN_filtering_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_filtering_template.ipynb @@ -319,7 +319,7 @@ "plt.plot(loss)\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", - "plt.axvline(x=epoch, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", "plt.legend()" ] }, diff --git a/cryodrgn/templates/cryoDRGN_viz_template.ipynb b/cryodrgn/templates/cryoDRGN_viz_template.ipynb index 80d0e277..b3aa861a 100644 --- a/cryodrgn/templates/cryoDRGN_viz_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_viz_template.ipynb @@ -238,7 +238,7 @@ "plt.plot(loss)\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", - "plt.axvline(x=epoch, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", "plt.legend()" ] },