In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

In [None]:
# Colour palette, indexed by the plotting cell (one entry per dataset family).
c = [
    "C0",  # c[0]: "GO dataset" series
    "red",  # c[1]: "SPICE" series
    "#4c123b",  # c[2]: rMD17 molecules
    "#f516e6",  # c[3]: fullerenes
    "#EDB32B",  # c[4]: molecules encapsulated in C60
]

In [None]:
# Array read from train.npy; columns 0 and 1 are later used as the x/y
# coordinates of the "GO dataset" scatter series.
# NOTE(review): presumably PCA projections, going by the directory and
# variable names -- confirm against the script that wrote the file.
train_pca_coeffs = np.load(
    file="../../data/GO-MACE-23_PCA/train.npy",
)

In [None]:
# Names of the molecules plotted under the "rMD17" legend entry; each name
# doubles as the stem of its .npy file.
mols = [
    "aspirin",
    "ethanol",
    "malonaldehyde",
    "naphthalene",
    "salicylic",
    "toluene",
]

# Molecule name -> array loaded from "<name>.npy".
rmd17_pca_coeffs = {
    mol: np.load(f"../../data/GO-MACE-23_PCA/{mol}.npy") for mol in mols
}

In [None]:
# Array read from fullerenes.npy; columns 0 and 1 are later used as the x/y
# coordinates of the "fullerenes" scatter series.
fullerenes_pca_coeffs = np.load(
    file="../../data/GO-MACE-23_PCA/fullerenes.npy",
)

In [None]:
# Guest molecules of the "M @ C60" series; each name is the prefix of a
# "<name>_c60.npy" file.
en_mols = [
    "ch3cho",
    "co2",
    "h2o",
    "ch4",
    "ch2o",
]

# Guest-molecule name -> array loaded from "<name>_c60.npy".
encapsulated_pca_coeffs = {
    mol: np.load(f"../../data/GO-MACE-23_PCA/{mol}_c60.npy") for mol in en_mols
}

In [None]:
# Array read from spice.npy; columns 0 and 1 are later used as the x/y
# coordinates of the "SPICE" scatter series.
spice_pca_coeffs = np.load(
    file="../../data/GO-MACE-23_PCA/spice.npy",
)

In [None]:
# Array read from isolated.npy; columns 0 and 1 are later used as the x/y
# coordinates of the "isolated molecules" scatter series.
isolated_pca_coeffs = np.load(
    file="../../data/GO-MACE-23_PCA/isolated.npy",
)

In [None]:
# Scatter every dataset using columns 0 and 1 of its coefficient array as the
# x/y coordinates, then attach a hand-built legend with one entry per dataset
# family. Series are drawn in the order below, so later ones sit on top.
fig = plt.figure(figsize=(3.5, 3.5))
ax = fig.add_subplot(111)

# Marker size and edge setting shared by every filled-marker series.
filled_marker_kwargs = dict(s=16, edgecolors="none")

# The two background series are semi-transparent and rasterized
# (presumably to keep a vector export such as the SVG below lightweight).
ax.scatter(
    train_pca_coeffs[:, 0],
    train_pca_coeffs[:, 1],
    color=c[0],
    alpha=0.25,
    label="GO dataset",
    rasterized=True,
    **filled_marker_kwargs,
)

ax.scatter(
    spice_pca_coeffs[:, 0],
    spice_pca_coeffs[:, 1],
    color=c[1],
    marker="s",
    alpha=0.25,
    label="SPICE",
    rasterized=True,
    **filled_marker_kwargs,
)

# All rMD17 molecules share one colour and marker.
for mol in mols:
    ax.scatter(
        rmd17_pca_coeffs[mol][:, 0],
        rmd17_pca_coeffs[mol][:, 1],
        color=c[2],
        marker="^",
        alpha=1,
        label=mol,
        **filled_marker_kwargs,
    )

ax.scatter(
    fullerenes_pca_coeffs[:, 0],
    fullerenes_pca_coeffs[:, 1],
    color=c[3],
    alpha=1,
    label="fullerenes",
    **filled_marker_kwargs,
)

# All encapsulated molecules share one colour; no per-artist label is set.
for mol in en_mols:
    ax.scatter(
        encapsulated_pca_coeffs[mol][:, 0],
        encapsulated_pca_coeffs[mol][:, 1],
        color=c[4],
        alpha=1,
        **filled_marker_kwargs,
    )

# "x" is an unfilled marker, so the edgecolors setting used for the filled
# markers above is deliberately not passed here.
ax.scatter(
    isolated_pca_coeffs[:, 0],
    isolated_pca_coeffs[:, 1],
    s=16,
    alpha=1,
    label="isolated",
    marker="x",
    color="k",
)


ax.set_axis_off()

# Single source of truth for the legend: (label, colour, marker) per family.
# The legend is built from these proxy handles, so the per-artist labels set
# on the scatter calls above are not what the legend displays.
legend_entries = [
    ("GO dataset", c[0], "o"),
    ("SPICE", c[1], "s"),
    ("rMD17", c[2], "^"),
    ("fullerenes", c[3], "o"),
    ("M @ C$_{60}$", c[4], "o"),
    ("isolated molecules", "k", "x"),
]
labels = [text for text, _, _ in legend_entries]
custom_lines = [
    Line2D([0], [0], color=color, lw=0.75, ls="None", marker=marker, markersize=4)
    for _, color, marker in legend_entries
]
fig.legend(
    custom_lines,
    labels,
    fontsize=8,
    ncols=2,
    loc="upper left",
    bbox_to_anchor=(0.15, 1.05),
    columnspacing=0.5,
    frameon=False,
)

# fig.savefig("./fig1.svg", dpi=300, bbox_inches="tight")