diff --git a/cryodrgn/commands/analyze.py b/cryodrgn/commands/analyze.py index e0fb9e56..a87bac00 100644 --- a/cryodrgn/commands/analyze.py +++ b/cryodrgn/commands/analyze.py @@ -100,7 +100,7 @@ def analyze_z1(z, outdir, vg): vg.gen_volumes(outdir, ztraj) -def analyze_zN(z, outdir, vg, skip_umap=False, num_pcs=2, num_ksamples=20): +def analyze_zN(z, outdir, vg, workdir, epoch, skip_umap=False, num_pcs=2, num_ksamples=20): zdim = z.shape[1] # Principal component analysis @@ -135,6 +135,17 @@ def analyze_zN(z, outdir, vg, skip_umap=False, num_pcs=2, num_ksamples=20): # Make some plots logger.info("Generating plots...") + # Plot learning curve + loss = analysis.parse_loss(f"{workdir}/run.log") + plt.figure(figsize=(4, 4)) + plt.plot(loss) + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.axvline(x=epoch, linestyle="--", color="black", label=f"Epoch {epoch}") + plt.legend() + plt.tight_layout() + plt.savefig(f"{outdir}/learning_curve_epoch{epoch}.png") + def plt_pc_labels(x=0, y=1): plt.xlabel(f"PC{x+1} ({pca.explained_variance_ratio_[x]:.2f})") plt.ylabel(f"PC{y+1} ({pca.explained_variance_ratio_[y]:.2f})") @@ -352,6 +363,7 @@ def main(args): t1 = dt.now() E = args.epoch workdir = args.workdir + epoch = args.epoch zfile = f"{workdir}/z.{E}.pkl" weights = f"{workdir}/weights.{E}.pkl" cfg = ( @@ -391,6 +403,8 @@ def main(args): z, outdir, vg, + workdir, + epoch, skip_umap=args.skip_umap, num_pcs=args.pc, num_ksamples=args.ksample, diff --git a/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb b/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb index eab3df74..eedf8a5d 100644 --- a/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb @@ -254,8 +254,10 @@ "source": [ "loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n", "plt.plot(loss)\n", - "plt.xlabel('epoch')\n", - "plt.ylabel('loss')" + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.legend()" ] }, { diff --git a/cryodrgn/templates/cryoDRGN_figures_template.ipynb b/cryodrgn/templates/cryoDRGN_figures_template.ipynb index eefb2a49..a8d35ce6 100644 --- a/cryodrgn/templates/cryoDRGN_figures_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_figures_template.ipynb @@ -57,6 +57,32 @@ "umap = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/umap.pkl')" ] }, + { + "cell_type": "markdown", + "id": "83324d44-767e-47e2-a3c7-cab9b430fab5", + "metadata": {}, + "source": [ + "# Plot learning curve" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6ba1e22-696c-4c46-a2de-48c386ef8526", + "metadata": {}, + "outputs": [], + "source": [ + "loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n", + "plt.figure(figsize=(4, 4))\n", + "plt.plot(loss)\n", + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.legend()\n", + "plt.tight_layout()\n", + "#plt.savefig(f\"{WORKDIR}/analyze.{EPOCH}/learning_curve_epoch{EPOCH}.png\")" + ] + }, { "cell_type": "markdown", "id": "9cce7848", diff --git a/cryodrgn/templates/cryoDRGN_filtering_template.ipynb b/cryodrgn/templates/cryoDRGN_filtering_template.ipynb index 19ecdb82..cd350896 100644 --- a/cryodrgn/templates/cryoDRGN_filtering_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_filtering_template.ipynb @@ -318,8 +318,10 @@ "source": [ "loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n", "plt.plot(loss)\n", - "plt.xlabel('epoch')\n", - "plt.ylabel('loss')" + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.legend()" ] }, { diff --git a/cryodrgn/templates/cryoDRGN_viz_template.ipynb b/cryodrgn/templates/cryoDRGN_viz_template.ipynb index d0a7b814..b3aa861a 100644 --- a/cryodrgn/templates/cryoDRGN_viz_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_viz_template.ipynb @@ -236,8 +236,10 @@ "source": [ "loss = analysis.parse_loss(f'{WORKDIR}/run.log')\n", "plt.plot(loss)\n", - "plt.xlabel('epoch')\n", - "plt.ylabel('loss')" + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.axvline(x=EPOCH, linestyle=\"--\", color=\"black\", label=f\"Epoch {EPOCH}\")\n", + "plt.legend()" ] }, {