In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import torch
import torch_geometric
import re
import matplotlib

# Seed for the formula-level train/val/test split that was used for training and
# for the hyperparameter search. The pretrained models were fit on the split this
# value produces; any other value moves training materials into the test set
# (data leakage).
seed: int = 0

# Matplotlib style sheet applied to every figure produced below.
plt.style.use("publication.mplstyle")


def pretty_formula(formula):
    """Return ``formula`` with every run of digits rendered as a mathtext subscript.

    Example: ``"Si2O12"`` -> ``"Si$_{2}$O$_{12}$"`` (used for plot legend titles).

    All digit runs are replaced in a single regex pass, so a short count such as
    ``2`` can never be substituted inside a longer one such as ``12``. The braces
    make multi-digit counts subscripted as a whole rather than only their first
    digit.
    """
    return re.sub(r"\d+", lambda match: f"$_{{{match.group(0)}}}$", formula)


def train_val_test_formula(data_df, graphs):
    """Split ``graphs`` 80/10/10 into train/val/test sets by chemical formula.

    All graphs that share a formula land in the same subset. The sorted unique
    formulas (``np.unique``) are shuffled with the legacy global NumPy RNG,
    reseeded from the module-level ``seed``, so the split is reproducible.

    Args:
        data_df: DataFrame with ``mat_id`` and ``formula`` columns.
        graphs: iterable of graphs, each with a ``mat_id`` present in ``data_df``.

    Returns:
        ``(TrainLoader, ValLoader, TestLoader, train_list, val_list, test_list)``.
        The loaders are shuffling PyG DataLoaders (batch size 512 for train,
        10 for val/test); the lists hold the underlying graphs.

    Raises:
        KeyError: if a graph's ``mat_id`` does not occur in ``data_df``.
    """
    # THIS USES THE SEED DEFINED ABOVE. Keep np.random.seed + np.random.shuffle
    # exactly as is: the saved models were trained on this very permutation.
    unique_formulas = np.unique(data_df.formula.values)
    np.random.seed(seed)
    np.random.shuffle(unique_formulas)
    n_formulas = len(unique_formulas)
    train_end = int(0.8 * n_formulas)
    val_end = int(0.9 * n_formulas)
    # Sets give O(1) membership tests instead of a linear scan per graph.
    train_form = set(unique_formulas[:train_end])
    val_form = set(unique_formulas[train_end:val_end])
    test_form = set(unique_formulas[val_end:])

    # mat_id -> formula, built once instead of scanning the DataFrame per graph.
    # setdefault keeps the first row for a repeated mat_id, as a `.values[0]`
    # lookup on the filtered frame would.
    formula_by_id = {}
    for mat_id, formula in zip(data_df["mat_id"].values, data_df["formula"].values):
        formula_by_id.setdefault(mat_id, formula)

    train_list = []
    val_list = []
    test_list = []
    for graph in graphs:
        graph_form = formula_by_id[graph.mat_id]
        if graph_form in train_form:
            train_list.append(graph)
        elif graph_form in val_form:
            val_list.append(graph)
        elif graph_form in test_form:
            test_list.append(graph)
        else:
            print("Something went wrong in building the train/val/test sets")

    TrainLoader = torch_geometric.loader.DataLoader(
        train_list, batch_size=512, shuffle=True
    )
    ValLoader = torch_geometric.loader.DataLoader(val_list, batch_size=10, shuffle=True)
    TestLoader = torch_geometric.loader.DataLoader(
        test_list, batch_size=10, shuffle=True
    )
    return TrainLoader, ValLoader, TestLoader, train_list, val_list, test_list


def model_eval(model, test_set, data_df):
    """Evaluate ``model`` graph by graph and collect per-material error metrics.

    Args:
        model: module whose output for a graph is flattened and compared with
            ``graph.y``. It is switched to eval mode, and inference runs without
            autograd.
        test_set: iterable of graphs, each with ``y`` (target tensor) and ``mat_id``.
        data_df: DataFrame with ``mat_id`` and ``formula`` columns. The formula
            of the first row matching each graph's ``mat_id`` is reported.

    Returns:
        DataFrame with one row per graph and the following columns:
        ``name`` (mat_id), ``mse``, ``mae``, ``mape``, ``sc`` and ``formulas``.
        ``sc`` is the trapezoid-integrated absolute error divided by the
        integrated absolute target. The metric columns are float32.

    Raises:
        KeyError: if a graph's ``mat_id`` does not occur in ``data_df``.
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists.
    trapz = getattr(np, "trapezoid", None) or np.trapz

    # mat_id -> formula, built once instead of filtering the DataFrame per graph.
    # setdefault keeps the first row for a repeated mat_id.
    formula_by_id = {}
    for mat_id, formula in zip(data_df["mat_id"].values, data_df["formula"].values):
        formula_by_id.setdefault(mat_id, formula)

    rows = []
    model.eval()
    with torch.no_grad():
        for graph in test_set:
            # One forward pass per graph; the prediction is reused by every metric.
            out = model(graph).flatten()
            target = graph.y
            mse = torch.nn.functional.mse_loss(out, target).item()
            mae = torch.nn.functional.l1_loss(out, target).item()
            # 1e-16 guards against division by an exactly-zero target value.
            mape = torch.mean(torch.abs((target - out) / (target + 1e-16))).item()
            out_np = out.detach().cpu().numpy()
            target_np = target.detach().cpu().numpy()
            sc = trapz(np.abs(out_np - target_np)) / trapz(np.abs(target_np))
            rows.append(
                (graph.mat_id, mse, mae, mape, sc, formula_by_id[graph.mat_id])
            )

    mse_df = pd.DataFrame(
        rows, columns=["name", "mse", "mae", "mape", "sc", "formulas"]
    )
    return mse_df.astype({col: np.float32 for col in ("mse", "mae", "mape", "sc")})

In [None]:
# Model definitions. Do not change the submodules (their names, layer sizes or head counts), otherwise the trained state_dicts in models/ can no longer be loaded.

class GATNN_attpool_100(torch.nn.Module):
    """Two-layer GATv2 network with attention pooling; loaded from models/eps_100.pt.

    Pipeline:
    - node MLP 13 -> 48;
    - GATv2 with 8 concatenated heads of 48 (= 384), edge features of dim 51;
    - GATv2 with 4 concatenated heads of 96 (= 384);
    - attention-weighted sum pooling per graph;
    - MLP 384 -> 2001;
    - leaky ReLU.

    The submodule names and sizes must match the checkpoint; only ``forward`` may
    be changed without breaking ``load_state_dict``.

    Args:
        dropout_frac: dropout passed to the node MLP and both GATv2 layers.
    """

    def __init__(self, dropout_frac=0.0):
        super().__init__()
        self.mlp_init_node = torch_geometric.nn.MLP(
            [13, 48, 48, 48], dropout=dropout_frac, act="relu"
        )
        self.gat1 = torch_geometric.nn.GATv2Conv(
            48,
            48,
            heads=8,
            edge_dim=51,
            concat=True,
            dropout=dropout_frac,
            add_self_loops=False,
        )
        self.gat2 = torch_geometric.nn.GATv2Conv(
            384,
            96,
            heads=4,
            edge_dim=51,
            concat=True,
            dropout=dropout_frac,
            add_self_loops=False,
        )

        self.mlp_att = torch_geometric.nn.MLP([384, 384], act="relu")
        self.mlp1 = torch_geometric.nn.MLP([384, 2048, 4096, 2001], act="relu")

    def forward(self, graph):
        """Predict a 2001-point output per graph.

        Args:
            graph: PyG data/batch object with ``x``, ``edge_index``, ``edge_attr``
                and ``batch``. ``batch`` is ``None`` for a single unbatched graph.

        Returns:
            Tensor of shape ``(num_graphs, 2001)``.
        """
        x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr
        x = self.mlp_init_node(x)
        x = self.gat1(x, edge_index, edge_attr)
        x = self.gat2(x, edge_index, edge_attr)
        batch = graph.batch
        if batch is None:
            # Unbatched graph: every node belongs to graph 0. Build the index on
            # the features' device instead of hard-coding CUDA.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # Per-node, per-channel weights normalised over the nodes of each graph.
        att = self.mlp_att(x)
        att = torch_geometric.utils.softmax(att, index=batch)
        x = torch_geometric.nn.pool.global_add_pool(x * att, batch)
        x = self.mlp1(x)
        x = torch.nn.functional.leaky_relu(x)
        return x


class GATNN_attpool_300(torch.nn.Module):
    """Two-layer GATv2 network with attention pooling; loaded from models/eps_300.pt.

    Pipeline:
    - node MLP 13 -> 48;
    - GATv2 with 4 concatenated heads of 48 (= 192), edge features of dim 51;
    - GATv2 with 4 concatenated heads of 96 (= 384);
    - attention-weighted sum pooling per graph;
    - MLP 384 -> 2001;
    - leaky ReLU.

    The submodule names and sizes must match the checkpoint; only ``forward`` may
    be changed without breaking ``load_state_dict``.

    Args:
        dropout_frac: dropout passed to the node MLP and both GATv2 layers.
    """

    def __init__(self, dropout_frac=0.0):
        super().__init__()
        self.mlp_init_node = torch_geometric.nn.MLP(
            [13, 48, 48], dropout=dropout_frac, act="relu"
        )
        self.gat1 = torch_geometric.nn.GATv2Conv(
            48,
            48,
            heads=4,
            edge_dim=51,
            concat=True,
            dropout=dropout_frac,
            add_self_loops=False,
        )
        self.gat2 = torch_geometric.nn.GATv2Conv(
            192,
            96,
            heads=4,
            edge_dim=51,
            concat=True,
            dropout=dropout_frac,
            add_self_loops=False,
        )

        self.mlp_att = torch_geometric.nn.MLP([384, 384], act="relu")
        self.mlp1 = torch_geometric.nn.MLP([384, 1024, 1024, 1024, 2001], act="relu")

    def forward(self, graph):
        """Predict a 2001-point output per graph.

        Args:
            graph: PyG data/batch object with ``x``, ``edge_index``, ``edge_attr``
                and ``batch``. ``batch`` is ``None`` for a single unbatched graph.

        Returns:
            Tensor of shape ``(num_graphs, 2001)``.
        """
        x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr
        x = self.mlp_init_node(x)
        x = self.gat1(x, edge_index, edge_attr)
        x = self.gat2(x, edge_index, edge_attr)
        batch = graph.batch
        if batch is None:
            # Unbatched graph: every node belongs to graph 0. Build the index on
            # the features' device instead of hard-coding CUDA.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # Per-node, per-channel weights normalised over the nodes of each graph.
        att = self.mlp_att(x)
        att = torch_geometric.utils.softmax(att, index=batch)
        x = torch_geometric.nn.pool.global_add_pool(x * att, batch)
        x = self.mlp1(x)
        x = torch.nn.functional.leaky_relu(x)
        return x


class GATNN_attpool_100_n(torch.nn.Module):
    """Single-layer GATv2 network with attention pooling; loaded from models/n_100.pt.

    Pipeline:
    - node MLP 13 -> 96 -> 192;
    - GATv2 with 8 concatenated heads of 48 (= 384), edge features of dim 51;
    - attention-weighted sum pooling per graph;
    - MLP 384 -> 2001;
    - leaky ReLU.

    The submodule names and sizes must match the checkpoint; only ``forward`` may
    be changed without breaking ``load_state_dict``.

    Args:
        dropout_frac: dropout passed to the node MLP and the GATv2 layer.
    """

    def __init__(self, dropout_frac=0.0):
        super().__init__()
        self.mlp_init_node = torch_geometric.nn.MLP(
            [13, 96, 192], dropout=dropout_frac, act="relu"
        )
        self.gat1 = torch_geometric.nn.GATv2Conv(
            192,
            48,
            heads=8,
            edge_dim=51,
            concat=True,
            dropout=dropout_frac,
            add_self_loops=False,
        )

        self.mlp_att = torch_geometric.nn.MLP([384, 384], act="relu")
        self.mlp1 = torch_geometric.nn.MLP([384, 2048, 2048, 2001], act="relu")

    def forward(self, graph):
        """Predict a 2001-point output per graph.

        Args:
            graph: PyG data/batch object with ``x``, ``edge_index``, ``edge_attr``
                and ``batch``. ``batch`` is ``None`` for a single unbatched graph.

        Returns:
            Tensor of shape ``(num_graphs, 2001)``.
        """
        x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr
        x = self.mlp_init_node(x)
        x = self.gat1(x, edge_index, edge_attr)
        batch = graph.batch
        if batch is None:
            # Unbatched graph: every node belongs to graph 0. Build the index on
            # the features' device instead of hard-coding CUDA.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # Per-node, per-channel weights normalised over the nodes of each graph.
        att = self.mlp_att(x)
        att = torch_geometric.utils.softmax(att, index=batch)
        x = torch_geometric.nn.pool.global_add_pool(x * att, batch)
        x = self.mlp1(x)
        x = torch.nn.functional.leaky_relu(x)
        return x


class GATNN_attpool_300_n(torch.nn.Module):
    """Single-layer GATv2 network with attention pooling; loaded from models/n_300.pt.

    Pipeline:
    - node MLP 13 -> 96 -> 96;
    - GATv2 with 12 concatenated heads of 96 (= 1152), edge features of dim 51;
    - attention-weighted sum pooling per graph;
    - MLP 1152 -> 2001;
    - leaky ReLU.

    The submodule names and sizes must match the checkpoint; only ``forward`` may
    be changed without breaking ``load_state_dict``.

    Args:
        dropout_frac: dropout passed to the node MLP and the GATv2 layer.
    """

    def __init__(self, dropout_frac=0.0):
        super().__init__()
        self.mlp_init_node = torch_geometric.nn.MLP(
            [13, 96, 96], dropout=dropout_frac, act="relu"
        )
        self.gat1 = torch_geometric.nn.GATv2Conv(
            96,
            96,
            heads=12,
            edge_dim=51,
            concat=True,
            dropout=dropout_frac,
            add_self_loops=False,
        )

        self.mlp_att = torch_geometric.nn.MLP([1152, 1152], act="relu")
        self.mlp1 = torch_geometric.nn.MLP([1152, 512, 512, 2001], act="relu")

    def forward(self, graph):
        """Predict a 2001-point output per graph.

        Args:
            graph: PyG data/batch object with ``x``, ``edge_index``, ``edge_attr``
                and ``batch``. ``batch`` is ``None`` for a single unbatched graph.

        Returns:
            Tensor of shape ``(num_graphs, 2001)``.
        """
        x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr
        x = self.mlp_init_node(x)
        x = self.gat1(x, edge_index, edge_attr)
        batch = graph.batch
        if batch is None:
            # Unbatched graph: every node belongs to graph 0. Build the index on
            # the features' device instead of hard-coding CUDA.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # Per-node, per-channel weights normalised over the nodes of each graph.
        att = self.mlp_att(x)
        att = torch_geometric.utils.softmax(att, index=batch)
        x = torch_geometric.nn.pool.global_add_pool(x * att, batch)
        x = self.mlp1(x)
        x = torch.nn.functional.leaky_relu(x)
        return x

In [None]:
# Load the material tables and the precomputed graphs, moving every graph to the GPU.
# NOTE: pd.read_pickle and pickle.load both unpickle, which can execute arbitrary
# code -- only run this cell on trusted files.
data_df_300 = pd.read_pickle("data/data_300.pckl")
data_df_100 = pd.read_pickle("data/data_100.pckl")


def _load_graphs_to_cuda(path):
    """Unpickle the graphs stored at ``path`` and return them as a list on CUDA.

    The object returned by ``.to("cuda")`` is kept. This works whether or not
    ``.to`` moves a graph in place.
    """
    with open(path, "rb") as graph_file:
        graphs = pickle.load(graph_file)
    return [graph.to("cuda") for graph in graphs]


graphs_300 = _load_graphs_to_cuda("graphs/graphs_300_eps.pckl")
graphs_100 = _load_graphs_to_cuda("graphs/graphs_100_eps.pckl")
graphs_300_n = _load_graphs_to_cuda("graphs/graphs_300_n.pckl")
graphs_100_n = _load_graphs_to_cuda("graphs/graphs_100_n.pckl")

In [None]:
# Build each architecture on the GPU, load its trained weights and switch it to eval mode.
# NOTE: torch.load may unpickle arbitrary objects on older torch versions -- only
# load trusted checkpoints.


def _load_model(model_cls, path):
    """Instantiate ``model_cls`` on CUDA, load the state_dict at ``path`` and return it in eval mode."""
    model = model_cls().to("cuda")
    # map_location keeps the tensors loadable even if the checkpoint was saved on
    # a different GPU index (or on the CPU).
    model.load_state_dict(torch.load(path, map_location="cuda"))
    model.eval()
    return model


model_300 = _load_model(GATNN_attpool_300, "models/eps_300.pt")
model_100 = _load_model(GATNN_attpool_100, "models/eps_100.pt")
model_300_n = _load_model(GATNN_attpool_300_n, "models/n_300.pt")
model_100_n = _load_model(GATNN_attpool_100_n, "models/n_100.pt")

In [None]:
# Do the train/val/test splits (alternatively, use the lists supplied in the GitHub repo).
# Other seeds (specified above) will result in data leakage.
# Each dataset is split once. train_val_test_formula reseeds before shuffling, so a
# single call gives exactly the lists that repeated calls would. It returns
# (TrainLoader, ValLoader, TestLoader, train_list, val_list, test_list).
_, _, _, train_300, _, test_300 = train_val_test_formula(data_df_300, graphs_300)
_, _, _, train_100, _, test_100 = train_val_test_formula(data_df_100, graphs_100)
_, _, _, train_300_n, _, test_300_n = train_val_test_formula(data_df_300, graphs_300_n)
_, _, _, train_100_n, _, test_100_n = train_val_test_formula(data_df_100, graphs_100_n)

In [None]:
# Per-material test-set metrics (mse, mae, mape, sc) for each of the four models.
# The tuple is evaluated left to right, so the models are evaluated in this order.
mse_df_300, mse_df_100, mse_df_300_n, mse_df_100_n = (
    model_eval(model_300, test_300, data_df_300),
    model_eval(model_100, test_100, data_df_100),
    model_eval(model_300_n, test_300_n, data_df_300),
    model_eval(model_100_n, test_100_n, data_df_100),
)

In [None]:
# Figure 1: DFT vs. ML spectra of the test materials whose SC error lies at the
# 20/40/60/80 % quantiles (rows), for each of the four models (columns).
c_dft = "tab:blue"
c_ml = "tab:orange"
lw = 1.5

# One entry per column, left to right:
# (metrics frame, model, test graphs, pin the y-axis at 0?)
fig1_columns = [
    (mse_df_100, model_100, test_100, True),
    (mse_df_300, model_300, test_300, False),
    (mse_df_100_n, model_100_n, test_100_n, False),
    (mse_df_300_n, model_300_n, test_300_n, False),
]

fig, axes = plt.subplots(4, 4, figsize=[(6 + 6 / 8), 0.85 * 5], squeeze=True)

for col, (mse_df, model, test_graphs, clip_at_zero) in enumerate(fig1_columns):
    for idx, ax in enumerate(axes[:, col]):
        # interpolation="nearest" returns an SC value that occurs in the frame,
        # so the exact-equality lookup below finds the corresponding material.
        quantile = mse_df.quantile(
            0.2 * (idx + 1), numeric_only=True, interpolation="nearest"
        ).sc
        selected = mse_df[mse_df["sc"] == quantile]
        name = selected["name"].values[0]
        formula = selected["formulas"].values[0]
        for graph in test_graphs:
            if graph.mat_id == name:
                with torch.no_grad():
                    prediction = model(graph)
                ax.plot(graph.y.cpu().detach().numpy().flatten(), color=c_dft, linewidth=lw)
                ax.plot(prediction.cpu().numpy().flatten(), color=c_ml, linewidth=lw)

        if idx != 3:
            plt.setp(ax.get_xticklabels(), visible=False)
        else:
            ax.set_xticks([0, 500, 1000, 1500, 2000])
            ax.set_xticklabels([0, 5, 10, 15, 20])
        ax.set_xlim([0, 2000])
        if clip_at_zero:
            ax.set_ylim(bottom=0)

        # The legend has no entries and only shows the formula as its title.
        # handles=[] avoids matplotlib's "No artists with labels" warning.
        ax.legend(handles=[], frameon=False, title=pretty_formula(formula))

for ax in axes[3, :]:
    ax.set_xlabel(r"Energy (eV)")

for row, q_label in enumerate(
    [r"$Q_{20\%}$", r"$Q_{40\%}$", r"$Q_{60\%}$", r"$Q_{80\%}$"]
):
    axes[row, 3].text(1.1, 0.5, q_label, transform=axes[row, 3].transAxes, size=12)

column_titles = [
    r"$\mathrm{Im}(\overline{\varepsilon}_{100})$",
    r"$\mathrm{Im}(\overline{\varepsilon}_{300})$",
    r"$\mathrm{Re}(\overline{n}_{100})$",
    r"$\mathrm{Re}(\overline{n}_{300})$",
]
for col, column_title in enumerate(column_titles):
    axes[0, col].text(
        0.3, 1.1, column_title, transform=axes[0, col].transAxes, size=12
    )

# set y-ticks manually
axes[0, 0].set_yticks([2, 4, 6, 8])
axes[1, 0].set_yticks([2, 4, 6, 8])
axes[1, 0].set_ylim(top=8.5)
axes[2, 0].set_yticks([3, 6, 9])
axes[3, 0].set_yticks([3, 6, 9])
axes[3, 0].set_ylim(top=10)

axes[0, 1].set_yticks([2, 4, 6, 8])
axes[1, 1].set_yticks([1, 3, 5])
axes[2, 1].set_yticks([3, 6, 9])
axes[3, 1].set_yticks([1, 2, 3, 4])

axes[0, 2].set_yticks([1, 2, 3])
axes[1, 2].set_yticks([1, 1.5, 2])
axes[2, 2].set_yticks([1, 2, 3])
axes[3, 2].set_yticks([1, 2, 3])

axes[0, 3].set_yticks([1, 1.5, 2])
axes[1, 3].set_yticks([1, 2, 3])
axes[2, 3].set_yticks([1, 2, 3])
axes[3, 3].set_yticks([1, 1.5, 2, 2.5])
axes[3, 3].set_ylim(top=2.8)

fig.tight_layout()
fig.subplots_adjust(hspace=0, wspace=0.22)
fig.savefig("plots/Fig1.pdf", dpi=600)

In [None]:
# Figure 2: histograms of SC = 1 - sc on the 100-set test graphs. They compare
# the ML model with a baseline that always predicts the mean training spectrum.
# Top panel: Im(eps_100); bottom panel: Re(n_100). Insets show the mean spectra.
train_mean = np.array([graph.y.cpu().detach().numpy() for graph in train_100]).mean(axis=0)
train_mean_n = np.array([graph.y.cpu().detach().numpy() for graph in train_100_n]).mean(axis=0)

# NOTE(review): the baseline normalises by trapz(y), while model_eval's sc uses
# trapz(|y|). The two agree only for non-negative targets -- confirm for the data.
test_vs_mean_100 = [
    1 - np.trapz(np.abs(y_true - train_mean)) / np.trapz(y_true)
    for y_true in (graph.y.cpu().detach().numpy() for graph in test_100)
]
test_vs_mean_100_n = [
    1 - np.trapz(np.abs(y_true - train_mean_n)) / np.trapz(y_true)
    for y_true in (graph.y.cpu().detach().numpy() for graph in test_100_n)
]

alpha = 0.8
# Unit-width bins of 0.01 centred on 0.00, 0.01, ..., 1.00.
sc_bins = np.arange(0, 1.01, 0.01) - 0.005

fig, axes = plt.subplots(2, 1, figsize=[3 + 3 / 8, 0.85 * 2.6], sharex=True)
axes = axes.ravel()

for ax, baseline_sc, model_df in [
    (axes[0], test_vs_mean_100, mse_df_100),
    (axes[1], test_vs_mean_100_n, mse_df_100_n),
]:
    ax.hist(
        baseline_sc,
        bins=sc_bins,
        alpha=alpha,
        label="Mean of training set",
        color="tab:blue",
    )
    ax.hist(
        1 - model_df["sc"].values,
        bins=sc_bins,
        alpha=alpha,
        label="ML model",
        color="tab:orange",
    )
plt.setp(axes[0].get_xticklabels(), visible=False)
axes[0].set_xlim([0, 1])

fig.subplots_adjust(hspace=0)
fig.supylabel("Counts", x=0.02)
# Raw string: "\m" in a plain literal is an invalid escape sequence
# (DeprecationWarning; SyntaxWarning since Python 3.12).
fig.supxlabel(r"$\mathrm{SC}$", y=0.02, x=0.02 + 0.5)

# set y-ticks manually
axes[0].set_yticks([20, 40, 60, 80])
axes[1].set_yticks([50, 100, 150])

# Hand-placed colour-coded labels used instead of a legend.
axes[0].text(
    0.87, 0.875, "Mean", transform=axes[0].transAxes, size=9, color="tab:Blue"
)
axes[0].text(
    0.92, 0.74, "DL", transform=axes[0].transAxes, size=9, color="tab:orange"
)

### Create insets showing the mean training spectrum of each target
bounds = [0.03, 0.45, 0.28, 0.49]
for parent_ax, mean_spectrum, inset_ylim, inset_yticks, inset_ylabel in (
    (axes[0], train_mean, [0, 4.2], [1, 3], r"$\mathrm{Im}(\overline{\varepsilon}_{100})$"),
    (axes[1], train_mean_n, [0.7, 2.5], [1, 2], r"$\mathrm{Re}(\overline{n}_{100})$"),
):
    inset = parent_ax.inset_axes(bounds)
    inset.tick_params(
        bottom=True, top=False, left=False, right=True,
        labelbottom=True, labeltop=False, labelleft=False, labelright=True,
        labelsize=6, pad=1,
    )
    inset.plot(mean_spectrum, color="tab:blue")
    inset.set_xlim([0, 2001])
    inset.set_xticks([501, 1501])
    inset.set_xticklabels([5, 15])
    inset.set_ylim(inset_ylim)
    inset.set_yticks(inset_yticks)
    inset.set_xlabel(r"Energy (eV)", fontsize=7)
    inset.set_ylabel(inset_ylabel, fontsize=7)
    inset.yaxis.set_label_position("right")
    inset.xaxis.set_label_coords(0.5, -0.2)
    inset.yaxis.set_label_coords(1.1, 0.5)

fig.savefig("plots/Fig2.pdf", dpi=600)

In [None]:
# Figure S3: DFT (blue) vs. ML (orange) Im(eps_100) spectra for 30 test materials
# drawn at random (with replacement) from the reseeded global NumPy RNG.
lw = 1.2
fig, axes = plt.subplots(5, 6, figsize=[6 + 6 / 8, 5])
np.random.seed(seed)
indices = np.arange(len(test_100))
for panel in axes.flat:
    graph_id = np.random.choice(indices)
    graph = test_100[graph_id]
    dft_curve = graph.y.cpu()
    panel.plot(dft_curve, color=c_dft, linewidth=lw)
    ml_curve = model_100(graph.cuda()).cpu().detach().numpy().flatten()
    panel.plot(ml_curve, color=c_ml, linewidth=lw)
    panel.set_title(graph["mat_id"])
    panel.set_xlim([0, 2000])
    panel.set_xticks([0, 500, 1000, 1500, 2000])
    panel.set_xticklabels([0, 5, 10, 15, 20])
    panel.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(3))
fig.supxlabel(r"Energy (eV)", x=0.52, size=12)
fig.supylabel(r"$\mathrm{Im}(\overline{\varepsilon}_{100})$", y=0.52, size=12)
fig.tight_layout()
fig.savefig("plots/S3.pdf")

In [None]:
# Figure S4: DFT (blue) vs. ML (orange) Im(eps_300) spectra for 30 test materials
# drawn at random (with replacement) from the reseeded global NumPy RNG.
lw = 1.2
fig, axes = plt.subplots(5, 6, figsize=[6 + 6 / 8, 5])
np.random.seed(seed)
# Sample from the set that is actually plotted. Using len(test_100) here could
# raise IndexError, or leave part of test_300 unreachable, when the sizes differ.
indices = np.arange(len(test_300))
for ax in axes.ravel():
    graph_id = np.random.choice(indices)
    graph = test_300[graph_id]
    with torch.no_grad():
        prediction = model_300(graph.cuda())
    ax.plot(graph.y.cpu(), color=c_dft, linewidth=lw)
    ax.plot(prediction.cpu().numpy().flatten(), color=c_ml, linewidth=lw)
    ax.set_title(graph["mat_id"])
    ax.set_xlim([0, 2000])
    ax.set_xticks([0, 500, 1000, 1500, 2000])
    ax.set_xticklabels([0, 5, 10, 15, 20])
    ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(3))
fig.supxlabel(r"Energy (eV)", x=0.52, size=12)
fig.supylabel(r"$\mathrm{Im}(\overline{\varepsilon}_{300})$", y=0.52, size=12)
fig.tight_layout()
fig.savefig("plots/S4.pdf")

In [None]:
# Figure S5: DFT (blue) vs. ML (orange) Re(n_100) spectra for 30 test materials
# drawn at random (with replacement) from the reseeded global NumPy RNG.
lw = 1.2
fig, axes = plt.subplots(5, 6, figsize=[6 + 6 / 8, 5])
np.random.seed(seed)
# Sample from the set that is actually plotted. Using len(test_100) here could
# raise IndexError, or leave part of test_100_n unreachable, when the sizes differ.
indices = np.arange(len(test_100_n))
for ax in axes.ravel():
    graph_id = np.random.choice(indices)
    graph = test_100_n[graph_id]
    with torch.no_grad():
        prediction = model_100_n(graph.cuda())
    ax.plot(graph.y.cpu(), color=c_dft, linewidth=lw)
    ax.plot(prediction.cpu().numpy().flatten(), color=c_ml, linewidth=lw)
    ax.set_title(graph["mat_id"])
    ax.set_xlim([0, 2000])
    ax.set_xticks([0, 500, 1000, 1500, 2000])
    ax.set_xticklabels([0, 5, 10, 15, 20])
    ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(3))
fig.supxlabel(r"Energy (eV)", x=0.52, size=12)
fig.supylabel(r"$\mathrm{Re}(\overline{n}_{100})$", y=0.52, size=12)
fig.tight_layout()
fig.savefig("plots/S5.pdf")

In [None]:
# Figure S6: DFT (blue) vs. ML (orange) Re(n_300) spectra for 30 test materials
# drawn at random (with replacement) from the reseeded global NumPy RNG.
lw = 1.2
fig, axes = plt.subplots(5, 6, figsize=[6 + 6 / 8, 5])
np.random.seed(seed)
# Sample from the set that is actually plotted. Using len(test_100) here could
# raise IndexError, or leave part of test_300_n unreachable, when the sizes differ.
indices = np.arange(len(test_300_n))
for ax in axes.ravel():
    graph_id = np.random.choice(indices)
    graph = test_300_n[graph_id]
    with torch.no_grad():
        prediction = model_300_n(graph.cuda())
    ax.plot(graph.y.cpu(), color=c_dft, linewidth=lw)
    ax.plot(prediction.cpu().numpy().flatten(), color=c_ml, linewidth=lw)
    ax.set_title(graph["mat_id"])
    ax.set_xlim([0, 2000])
    ax.set_xticks([0, 500, 1000, 1500, 2000])
    ax.set_xticklabels([0, 5, 10, 15, 20])
    ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(3))
fig.supxlabel(r"Energy (eV)", x=0.52, size=12)
fig.supylabel(r"$\mathrm{Re}(\overline{n}_{300})$", y=0.52, size=12)
fig.tight_layout()
fig.savefig("plots/S6.pdf")

In [None]:
# Figure S7: histograms of SC = 1 - sc on the 300-set test graphs. They compare
# the ML model with a baseline that always predicts the mean training spectrum.
# Top panel: Im(eps_300); bottom panel: Re(n_300). Insets show the mean spectra.
train_mean = np.array([graph.y.cpu().detach().numpy() for graph in train_300]).mean(axis=0)
train_mean_n = np.array([graph.y.cpu().detach().numpy() for graph in train_300_n]).mean(axis=0)

# NOTE(review): the baseline normalises by trapz(y), while model_eval's sc uses
# trapz(|y|). The two agree only for non-negative targets -- confirm for the data.
test_vs_mean_300 = [
    1 - np.trapz(np.abs(y_true - train_mean)) / np.trapz(y_true)
    for y_true in (graph.y.cpu().detach().numpy() for graph in test_300)
]
test_vs_mean_300_n = [
    1 - np.trapz(np.abs(y_true - train_mean_n)) / np.trapz(y_true)
    for y_true in (graph.y.cpu().detach().numpy() for graph in test_300_n)
]

alpha = 0.8
# Unit-width bins of 0.01 centred on 0.00, 0.01, ..., 1.00.
sc_bins = np.arange(0, 1.01, 0.01) - 0.005

fig, axes = plt.subplots(2, 1, figsize=[3 + 3 / 8, 0.85 * 2.6], sharex=True)
axes = axes.ravel()

for ax, baseline_sc, model_df in [
    (axes[0], test_vs_mean_300, mse_df_300),
    (axes[1], test_vs_mean_300_n, mse_df_300_n),
]:
    ax.hist(
        baseline_sc,
        bins=sc_bins,
        alpha=alpha,
        label="Mean of training set",
        color="tab:blue",
    )
    ax.hist(
        1 - model_df["sc"].values,
        bins=sc_bins,
        alpha=alpha,
        label="ML model",
        color="tab:orange",
    )
plt.setp(axes[0].get_xticklabels(), visible=False)
axes[0].set_xlim([0, 1])

fig.subplots_adjust(hspace=0)
fig.supylabel("Counts", x=0.02)
# Raw string: "\m" in a plain literal is an invalid escape sequence
# (DeprecationWarning; SyntaxWarning since Python 3.12).
fig.supxlabel(r"$\mathrm{SC}$", y=0.02, x=0.02 + 0.5)

# set y-ticks manually
axes[0].set_yticks([20, 40, 60, 80])
axes[1].set_yticks([50, 100, 150, 200])

# Hand-placed colour-coded labels used instead of a legend.
axes[0].text(
    0.87 - 0.15, 0.875, "Mean", transform=axes[0].transAxes, size=9, color="tab:Blue"
)
axes[0].text(
    0.92 - 0.15, 0.74, "DL", transform=axes[0].transAxes, size=9, color="tab:orange"
)

### Create insets showing the mean training spectrum of each target
bounds = [0.03, 0.45, 0.28, 0.49]
for parent_ax, mean_spectrum, inset_ylim, inset_yticks, inset_ylabel in (
    (axes[0], train_mean, [0, 4.2], [1, 3], r"$\mathrm{Im}(\overline{\varepsilon}_{300})$"),
    (axes[1], train_mean_n, [0.7, 2.5], [1, 2], r"$\mathrm{Re}(\overline{n}_{300})$"),
):
    inset = parent_ax.inset_axes(bounds)
    inset.tick_params(
        bottom=True, top=False, left=False, right=True,
        labelbottom=True, labeltop=False, labelleft=False, labelright=True,
        labelsize=6, pad=1,
    )
    inset.plot(mean_spectrum, color="tab:blue")
    inset.set_xlim([0, 2001])
    inset.set_xticks([501, 1501])
    inset.set_xticklabels([5, 15])
    inset.set_ylim(inset_ylim)
    inset.set_yticks(inset_yticks)
    inset.set_xlabel(r"Energy (eV)", fontsize=7)
    inset.set_ylabel(inset_ylabel, fontsize=7)
    inset.yaxis.set_label_position("right")
    inset.xaxis.set_label_coords(0.5, -0.2)
    inset.yaxis.set_label_coords(1.1, 0.5)

fig.savefig("plots/S7.pdf", dpi=600)