Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 62 additions & 77 deletions notebooks/02_bbmep/create_figures.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,15 @@
"metadata": {},
"outputs": [],
"source": [
"cost_option = \"hour\" #'wallclock_hour', 'energy', 'hour'\n",
"# {\"Orbformer scratch, single point\": '#74baae', \"Orbformer scratch, all points\": '#4e938d',\n",
"# \"Orbformer finetune LAC (200k), all points\": '#366c6c', \"Orbformer finetune LAC (400k), all points\":'#27474a',\n",
"# \"Orbformer finetune LAC (1000k), all points\":'#1d2429', \"Psiformer scratch, single point\":'#8de971'}\n",
"qmc_colors = [\"#74baae\", \"#4e938d\", \"#27474a\", \"#8de971\"]\n",
"\n",
"# Single color for all non-QMC methods, different markers\n",
"non_qmc_color = \"#E05A33\" # red/orange\n",
"non_qmc_markers = ['s', '^', 'v', 'D', 'p', 'h', '*', 'X', 'P', '<', '>'] # variety of markers\n",
"non_qmc_marker_sizes = {'s': 180, '^': 240, 'v': 240, 'D': 180, 'p': 180, 'h': 180, '*': 300, 'X': 180, 'P': 180, '<': 180, '>': 180}\n",
"\n",
"fig, ax = plt.subplots(3, 2, figsize=(18.5, 19), sharey=True)\n",
"compute_sec = {\n",
Expand All @@ -180,8 +183,7 @@
" \"Ethane\": 296013.404 / 200000 / (4 * 16265.031 / 32000),\n",
" \"Formamide\": 182497.587 / 200000 / (10323.08825 / 32000),\n",
"}\n",
"rcParams[\"axes.spines.top\"] = True\n",
"ax_top = {}\n",
"rcParams[\"axes.spines.top\"] = False\n",
"print(MAE.keys())\n",
"for m, i in enumerate(MAE.keys()):\n",
" for n, j in enumerate(MAE[i].keys()):\n",
Expand All @@ -190,58 +192,44 @@
" x = x * psiformer_factor[i]\n",
" if \"all\" in j.split():\n",
" x = x / 20\n",
" # Convert to GPU hours (x is already in GPU-relative units)\n",
" x_gpu_hr = x * 4 * compute_sec[i] / 3600\n",
" y = 627.5 * MAE[i][j][:, 1]\n",
" h, v = np.divmod(m, 3)\n",
" ax[v][h].scatter(x, y, marker=\"o\", s=140, label=j, color=qmc_colors[n])\n",
" a, b = np.polyfit(np.log(x), np.log(y), 1)\n",
" ax[v][h].plot(x, np.exp(a * np.log(x) + b), color=qmc_colors[n], linewidth=4)\n",
" ax[v][h].set_xscale(\"log\", base=4)\n",
" ax[v][h].scatter(x_gpu_hr, y, marker=\"o\", s=140, label=j, color=qmc_colors[n], zorder=3)\n",
" a, b = np.polyfit(np.log(x_gpu_hr), np.log(y), 1)\n",
" ax[v][h].plot(x_gpu_hr, np.exp(a * np.log(x_gpu_hr) + b), color=qmc_colors[n], linewidth=4, zorder=2)\n",
" ax[v][h].set_xscale(\"log\", base=10)\n",
" ax[v][h].set_yscale(\"log\", base=4)\n",
" ticks = [10, 50, 250, 1000, 4000, 16000, 64000, 300000]\n",
" ticks = [0.01, 0.1, 1, 10, 100]\n",
" cost_factor = 300 / (280 / 64)\n",
" for k in MAE_es[i].keys():\n",
" for idx, k in enumerate(MAE_es[i].keys()):\n",
" # Convert CPU cost to equivalent GPU hours for plotting\n",
" cpu_hr = MAE_es[i][k][0]\n",
" gpu_hr_equiv = cpu_hr / cost_factor # convert CPU hr to GPU hr equivalent\n",
" ax[v][h].scatter(\n",
" MAE_es[i][k][0] / (4 * compute_sec[i] / 3600 * cost_factor),\n",
" gpu_hr_equiv,\n",
" 627.5 * MAE_es[i][k][1],\n",
" label=k,\n",
" s=180,\n",
" color=colormaps_es[k],\n",
" marker=marker_es[k],\n",
" s=non_qmc_marker_sizes[non_qmc_markers[idx % len(non_qmc_markers)]],\n",
" color=non_qmc_color,\n",
" marker=non_qmc_markers[idx % len(non_qmc_markers)],\n",
" zorder=4,\n",
" edgecolors='white',\n",
" linewidths=1.5,\n",
" )\n",
" ax[v][h].set_xticks(ticks, ticks, fontsize=16)\n",
" ax[v][h].axhspan(0.2, 5, color=\"grey\", alpha=0.2, lw=0)\n",
" ax[v][h].set_xlim([6, 580000])\n",
" ax_top[m] = ax[v][h].twiny()\n",
" ax_top[m].set_xscale(\"log\", base=4)\n",
" ax_top[m].set_xlim([6, 580000])\n",
" if cost_option == \"hour\":\n",
" ax[v][h].set_xticklabels(\n",
" np.round(np.array(ticks) * 4 * compute_sec[i] / 3600 * cost_factor, 1)\n",
" )\n",
" else:\n",
" ax[v][h].set_xticklabels(cost)\n",
" ax[v][h].axhline(1, color=\"k\", ls=\":\", linewidth=4)\n",
" ax[v][h].set_xticks(ticks)\n",
" ax[v][h].set_xticklabels(ticks, fontsize=16)\n",
" ax[v][h].axhspan(0.2, 5, color=\"grey\", alpha=0.2, lw=0, zorder=0)\n",
" ax[v][h].set_xlim([0.0025, 300] if h == 0 else [0.005, 1100])\n",
" ax[v][h].axhline(1, color=\"darkgray\", ls=\":\", linewidth=4, zorder=1)\n",
" ax[v][h].set_yticks(\n",
" [0.5, 1, 2, 5, 10, 20, 50], [0.5, 1, 2, 5, 10, 20, 50], fontsize=16\n",
" [0.5, 1, 2, 5, 10, 20, 50, 100, 200], [0.5, 1, 2, 5, 10, 20, 50, 100, 200], fontsize=16\n",
" )\n",
" ax[v][h].set_ylim([0.2, 99])\n",
" if v == 0:\n",
" if cost_option == \"hour\":\n",
" ax_top[m].set_xlabel(\"Orbformer A100 GPU hr/structure\", fontsize=22)\n",
" else:\n",
" ax_top[m].set_xlabel(\n",
" \"Scratch training or finetuning steps per structure\", fontsize=22\n",
" )\n",
" ax[v][h].minorticks_off()\n",
" ax[v][h].set_ylim([0.25, 850])\n",
" if h == 0:\n",
" ax[v][h].set_ylabel(\"MARE over entire MEP (kcal/mol)\", fontsize=22)\n",
" if cost_option == \"hour\":\n",
" ax_top[m].set_xticks(ticks)\n",
" ax_top[m].set_xticklabels(\n",
" np.round(np.array(ticks) * 4 * compute_sec[i] / 3600, 2), fontsize=16\n",
" )\n",
" else:\n",
" ax_top[m].set_xticks(ticks)\n",
" ax_top[m].set_xticklabels(ticks)\n",
" ax[v][h].set_ylabel(\"MARE over MEP (kcal/mol)\", fontsize=22)\n",
"ax[2][1].set_axis_off()\n",
"\n",
"ax[0][0].legend(\n",
Expand All @@ -251,89 +239,82 @@
" borderpad=0.16,\n",
" labelspacing=0.4,\n",
")\n",
"ax[0][0].text(0.03, 0.9, \"(g) Ethane\", transform=ax[0][0].transAxes, fontsize=24)\n",
"ax[0][0].text(0.03, 1.02, \"(g) Ethane\", transform=ax[0][0].transAxes, fontsize=24)\n",
"arrimg = mpimg.imread(f\"{data_dir}/molecule_images/bbmep/Ethane/Ethane_0.png\")\n",
"imagebox = OffsetImage(arrimg, zoom=0.12)\n",
"ab = AnnotationBbox(imagebox, (2500, 20), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.62, 0.88), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[0][0].add_artist(ab)\n",
"arrimg = mpimg.imread(f\"{data_dir}/molecule_images/bbmep/Ethane/Ethane_19.png\")\n",
"imagebox = OffsetImage(arrimg, zoom=0.2)\n",
"ab = AnnotationBbox(imagebox, (35000, 15), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.88, 0.82), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[0][0].add_artist(ab)\n",
"ax[0][0].annotate(\n",
" \"\", xytext=(5500, 20), xy=(12000, 20), arrowprops=dict(arrowstyle=\"->\")\n",
" \"\", xytext=(0.72, 0.88), xy=(0.78, 0.88), xycoords='axes fraction', arrowprops=dict(arrowstyle=\"->\")\n",
")\n",
"\n",
"ax[1][0].text(0.03, 0.9, \"(h) Formamide\", transform=ax[1][0].transAxes, fontsize=24)\n",
"arrimg = mpimg.imread(f\"{data_dir}/molecule_images/bbmep/Formamide/Formamide_0.png\")\n",
"imagebox = OffsetImage(arrimg, zoom=0.12)\n",
"ab = AnnotationBbox(imagebox, (10000, 30), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.68, 0.88), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[1][0].add_artist(ab)\n",
"arrimg = mpimg.imread(f\"{data_dir}/molecule_images/bbmep/Formamide/Formamide_19.png\")\n",
"imagebox = OffsetImage(arrimg, zoom=0.2)\n",
"ab = AnnotationBbox(imagebox, (100000, 25), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.92, 0.82), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[1][0].add_artist(ab)\n",
"ax[1][0].annotate(\n",
" \"\", xytext=(22000, 30), xy=(45000, 30), arrowprops=dict(arrowstyle=\"->\")\n",
" \"\", xytext=(0.78, 0.88), xy=(0.84, 0.88), xycoords='axes fraction', arrowprops=dict(arrowstyle=\"->\")\n",
")\n",
"\n",
"ax[2][0].text(0.03, 0.9, \"(i) 1-Propanol\", transform=ax[2][0].transAxes, fontsize=24)\n",
"arrimg = mpimg.imread(f\"{data_dir}/molecule_images/bbmep/1-Propanol/1-Propanol_0.png\")\n",
"imagebox = OffsetImage(arrimg, zoom=0.17)\n",
"ab = AnnotationBbox(imagebox, (12000, 38), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.68, 0.88), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[2][0].add_artist(ab)\n",
"arrimg = mpimg.imread(f\"{data_dir}/molecule_images/bbmep/1-Propanol/1-Propanol_19.png\")\n",
"arrimg = np.rot90(arrimg)\n",
"imagebox = OffsetImage(arrimg, zoom=0.25)\n",
"ab = AnnotationBbox(imagebox, (120000, 28), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.92, 0.82), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[2][0].add_artist(ab)\n",
"ax[2][0].annotate(\n",
" \"\", xytext=(30000, 38), xy=(55000, 38), arrowprops=dict(arrowstyle=\"->\")\n",
" \"\", xytext=(0.78, 0.88), xy=(0.84, 0.88), xycoords='axes fraction', arrowprops=dict(arrowstyle=\"->\")\n",
")\n",
"\n",
"ax[0][1].text(\n",
" 0.03, 0.9, \"(j) 2-Aminopropan-2-ol\", transform=ax[0][1].transAxes, fontsize=24\n",
" 0.03, 1.02, \"(j) 2-Aminopropan-2-ol\", transform=ax[0][1].transAxes, fontsize=24\n",
")\n",
"arrimg = mpimg.imread(\n",
" f\"{data_dir}/molecule_images/bbmep/2-Aminopropan-2-ol/2-Aminopropan-2-ol_0.png\"\n",
")\n",
"imagebox = OffsetImage(arrimg, zoom=0.18)\n",
"ab = AnnotationBbox(imagebox, (14000, 35), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.68, 0.88), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[0][1].add_artist(ab)\n",
"arrimg = mpimg.imread(\n",
" f\"{data_dir}/molecule_images/bbmep/2-Aminopropan-2-ol/2-Aminopropan-2-ol_19.png\"\n",
")\n",
"arrimg = remove_black_edge(rotate(arrimg, angle=60, reshape=True))\n",
"imagebox = OffsetImage(arrimg, zoom=0.23)\n",
"ab = AnnotationBbox(imagebox, (120000, 28), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.92, 0.82), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[0][1].add_artist(ab)\n",
"ax[0][1].annotate(\n",
" \"\", xytext=(28000, 30), xy=(60000, 30), arrowprops=dict(arrowstyle=\"->\")\n",
" \"\", xytext=(0.78, 0.88), xy=(0.84, 0.88), xycoords='axes fraction', arrowprops=dict(arrowstyle=\"->\")\n",
")\n",
"\n",
"ax[1][1].text(0.03, 0.9, \"(k) L-Alanine\", transform=ax[1][1].transAxes, fontsize=24)\n",
"arrimg = mpimg.imread(f\"{data_dir}/molecule_images/bbmep/L-Alanine/L-Alanine_0.png\")\n",
"imagebox = OffsetImage(arrimg, zoom=0.18)\n",
"ab = AnnotationBbox(imagebox, (28000, 40), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.7, 0.88), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[1][1].add_artist(ab)\n",
"arrimg = mpimg.imread(f\"{data_dir}/molecule_images/bbmep/L-Alanine/L-Alanine_19.png\")\n",
"arrimg = remove_black_edge(rotate(arrimg, angle=60))\n",
"imagebox = OffsetImage(arrimg, zoom=0.16)\n",
"ab = AnnotationBbox(imagebox, (140000, 25), frameon=False)\n",
"ab = AnnotationBbox(imagebox, (0.92, 0.82), xycoords='axes fraction', frameon=False, clip_on=False)\n",
"ax[1][1].add_artist(ab)\n",
"ax[1][1].annotate(\n",
" \"\", xytext=(44000, 30), xy=(85000, 30), arrowprops=dict(arrowstyle=\"->\")\n",
" \"\", xytext=(0.8, 0.88), xy=(0.86, 0.88), xycoords='axes fraction', arrowprops=dict(arrowstyle=\"->\")\n",
")\n",
"\n",
"if cost_option == \"wallclock_hour\":\n",
" ax[2][0].set_xlabel(\"GPU or CPU hr/structure\", fontsize=22)\n",
" ax[1][1].set_xlabel(\"GPU or CPU hr/structure\", fontsize=22)\n",
"elif cost_option == \"energy\":\n",
" ax[2][0].set_xlabel(\"Energy cost/structure (KWh)\", fontsize=22)\n",
" ax[1][1].set_xlabel(\"Energy cost/structure (KWh)\", fontsize=22)\n",
"else:\n",
" ax[2][0].set_xlabel(\"Other methods AMD EPYC 7763 CPU hr/structure\", fontsize=22)\n",
" ax[1][1].set_xlabel(\"Other methods AMD EPYC 7763 CPU hr/structure\", fontsize=22)\n",
"ax[2][0].set_xlabel(\"A100 GPU hr/structure\", fontsize=22)\n",
"ax[1][1].set_xlabel(\"A100 GPU hr/structure\", fontsize=22)\n",
"\n",
"ax[0][0].text(\n",
" 0.55,\n",
Expand All @@ -347,7 +328,7 @@
"# plt.draw()\n",
"if save_figures:\n",
" plt.savefig(f\"{data_dir}/BBMEP_result_fig_hours.pdf\", dpi=600, bbox_inches='tight')\n",
"# plt.show()"
"plt.show()\n"
]
},
{
Expand All @@ -359,7 +340,7 @@
"source": [
"## Get reaction profile figure\n",
"\n",
"fig, ax = plt.subplots(3, 2, figsize=(13, 17.5), sharex=True)\n",
"fig, ax = plt.subplots(3, 2, figsize=(13, 19), sharex=True)\n",
"\n",
"for m, j in enumerate([\"Ethane\", \"Formamide\", \"2-Aminopropan-2-ol\"]):\n",
" step = int(MAE[j][\"Orbformer fine-tune LAC (400k), all structures\"][-1, 0])\n",
Expand Down Expand Up @@ -464,8 +445,8 @@
" )\n",
" ax[m][0].set_ylabel(\"Relative E (kcal/mol)\", fontsize=22)\n",
" ax[m][1].set_ylabel(\"Relative E Error (kcal/mol)\", fontsize=22)\n",
"ax[0][0].text(0.03, 0.9, \"(a) Ethane\", transform=ax[0][0].transAxes, fontsize=24)\n",
"ax[0][1].text(0.03, 0.9, \"(d) Ethane\", transform=ax[0][1].transAxes, fontsize=24)\n",
"ax[0][0].text(0.03, 1.02, \"(a) Ethane\", transform=ax[0][0].transAxes, fontsize=24)\n",
"ax[0][1].text(0.03, 1.02, \"(d) Ethane\", transform=ax[0][1].transAxes, fontsize=24)\n",
"ax[1][0].text(0.03, 0.9, \"(b) Formamide\", transform=ax[1][0].transAxes, fontsize=24)\n",
"ax[1][1].text(0.03, 0.9, \"(e) Formamide\", transform=ax[1][1].transAxes, fontsize=24)\n",
"ax[2][0].text(\n",
Expand Down Expand Up @@ -532,14 +513,18 @@
"ax[0][1].text(\n",
" 0.18, 1.1, \"Relative Energy Error\", transform=ax[0][1].transAxes, fontsize=26\n",
")\n",
"ax[0][0].set_ylim([-100, 55])\n",
"ax[2][0].set_ylim([-65, 58])\n",
"plt.subplots_adjust(wspace=0.2, hspace=0.02)\n",
"fig.align_ylabels(ax[:, 0])\n",
"fig.align_ylabels(ax[:, 1])\n",
"plt.subplots_adjust(wspace=0.2, hspace=0.15)\n",
"ax[2][0].set_xticks(2 * np.arange(11))\n",
"ax[2][0].set_xlim([-0.5, 19.5])\n",
"ax[2][0].set_xticklabels(2 * np.arange(11))\n",
"ax[2][0].set_xlabel(\"Image ID on MEP\", fontsize=22)\n",
"ax[2][1].set_xlabel(\"Image ID on MEP\", fontsize=22)\n",
"plt.draw()\n",
"\n",
"if save_figures:\n",
" plt.savefig(f\"{data_dir}/BBMEP_energy_profile.pdf\", dpi=600, bbox_inches='tight')\n",
"plt.show()"
Expand All @@ -556,7 +541,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "oneqmc",
"language": "python",
"name": "python3"
},
Expand All @@ -570,7 +555,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.11.14"
}
},
"nbformat": 4,
Expand Down
Loading