In [None]:
import pytorch_lightning as L
import torch_geometric
import torch
import pandas as pd
import pickle
import numpy as np
import os
import matplotlib.pyplot as plt
import random
import re

import src.utils as utils
import src.models as models

seed = 0


def get_sc(X, Y):
    """SC score of each array in ``X`` relative to the matching array in ``Y``.

    ``X`` and ``Y`` are pandas Series of equally shaped arrays. They are stacked
    row-wise and, per row, ``1 - trapz(|X - Y|) / trapz(|Y|)`` is returned, both
    integrals taken with the trapezoidal rule (unit spacing) along axis 1.
    A value of 1 means the two rows are identical.
    """
    stacked_x, stacked_y = np.stack(X.values), np.stack(Y.values)
    abs_error_area = np.trapz(np.abs(stacked_x - stacked_y), axis=1)
    reference_area = np.trapz(abs(stacked_y), axis=1)
    return 1 - abs_error_area / reference_area

## Load data and set parameters

In [None]:
# Pre-built graphs and the per-material metadata table.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open("database/graphs_300_rpa.pckl", "rb") as pickled_graphs:
    graphs = pickle.load(pickled_graphs)
data_df = pd.read_pickle("database/data_300_rpa.pckl")

In [None]:
train, val, test = utils.train_val_test(graphs)

# Seed immediately before sampling so the training subset is reproducible.
random.seed(seed)
# select here on how many samples you want to train
curr_train = random.sample(train, 300)

# optimal parameters
# Architectures used by the models in this notebook; the suffix presumably is the
# training-set size each architecture was tuned for (params_ipa: IPA network).
params_100 = [[96, 96, 96], [192, 2], [512, 512, 512]]
params_300 = [[96, 96, 96], [384, 4], [512, 512]]
params_1000 = [[96, 96, 96], [192, 8], [96, 8], [96, 2], [2048, 2048, 2048]]
params_3000 = [[48, 48, 48], [384, 8], [2048, 2048, 2048]]
params_4610 = [[96, 96], [192, 8], [192, 2], [2048, 2048, 2048]]
params_ipa = [[48, 48], [48, 4], [96, 4], [1024, 1024, 1024]]

for split_label, split_graphs in (
    ("training", train),
    ("validation", val),
    ("test", test),
):
    print("Length of full " + split_label + " set: " + str(len(split_graphs)))

## Train various models

These cells do not need to be run, because trained models are already supplied.
For the same reason, the lines that save the `state_dict` are commented out.
If you want to train your own models, take the optimal hyperparameters from the Supplemental Information of the paper.
With these, you should obtain similar results.

In [None]:
# Train a TL model
# Transfer learning: fine-tune the network constructed from
# IPA_results/eps_300.pt on `curr_train`, reload the checkpoint with the lowest
# logged "valid" metric and evaluate it on the test set.
checkpoint_callback = L.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="valid",
    mode="min",
    dirpath="ckpts",
    filename="{epoch:02d}-{valid:.2f}",
)

TrainLoader = torch_geometric.loader.DataLoader(
    curr_train, batch_size=128, shuffle=True
)
ValLoader = torch_geometric.loader.DataLoader(val, batch_size=128, shuffle=False)
TestLoader = torch_geometric.loader.DataLoader(test, batch_size=1, shuffle=False)

model_tl = models.LitGatNN_pre("IPA_results/eps_300.pt", lr=1e-5, decay=1e-4)
trainer = L.Trainer(
    max_epochs=500,
    check_val_every_n_epoch=20,
    precision="bf16-mixed",
    callbacks=[checkpoint_callback],
)
trainer.fit(model=model_tl, train_dataloaders=TrainLoader, val_dataloaders=ValLoader)
checkpoint = torch.load(checkpoint_callback.best_model_path)
model_tl.load_state_dict(checkpoint["state_dict"])
results = utils.model_eval(model_tl, TestLoader, data_df)
results.describe()
# Name the file after the actual training-set size so that a 300-sample model
# does not overwrite e.g. the supplied tl4610.pt.
# torch.save(model_tl.state_dict(), f"trained_models/tl{len(curr_train)}.pt")

In [None]:
# Train a DL model
# Direct learning: train a LitGatNN from scratch on `curr_train`, reload the
# checkpoint with the lowest logged "valid" metric and evaluate it on the test set.

TrainLoader = torch_geometric.loader.DataLoader(
    curr_train, batch_size=20, shuffle=True, drop_last=True
)
ValLoader = torch_geometric.loader.DataLoader(val, batch_size=32, shuffle=False)
TestLoader = torch_geometric.loader.DataLoader(test, batch_size=1, shuffle=False)

checkpoint_callback = L.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="valid",
    mode="min",
    dirpath="ckpts",
    filename="{epoch:02d}-{valid:.2f}",
)

# Use the tuned architecture for the chosen training-set size; sizes without a
# tuned architecture fall back to params_300 (the default used so far).
dl_params_by_size = {
    100: params_100,
    300: params_300,
    1000: params_1000,
    3000: params_3000,
    4610: params_4610,
}
model_dl = models.LitGatNN(
    lr=1e-5, decay=1e-4, params=dl_params_by_size.get(len(curr_train), params_300)
)
trainer = L.Trainer(
    max_epochs=500,
    check_val_every_n_epoch=20,
    precision="bf16-mixed",
    callbacks=[checkpoint_callback],
)
trainer.fit(model=model_dl, train_dataloaders=TrainLoader, val_dataloaders=ValLoader)
checkpoint = torch.load(checkpoint_callback.best_model_path)
model_dl.load_state_dict(checkpoint["state_dict"])
results = utils.model_eval(model_dl, TestLoader, data_df)
results.describe()
# Name the file after the actual training-set size.
# torch.save(model_dl.state_dict(), f"trained_models/dl{len(curr_train)}.pt")

In [None]:
# TL on small cells
# For each limit max_size = 2..6, fine-tune the network constructed from
# IPA_results/eps_300.pt only on training graphs with at most max_size nodes,
# and evaluate it on the full test set.

# Validation and test loaders do not depend on the size limit: build them once.
ValLoader = torch_geometric.loader.DataLoader(val, batch_size=128, shuffle=False)
TestLoader = torch_geometric.loader.DataLoader(test, batch_size=1, shuffle=False)

for max_size in range(2, 7):
    # select only small cells (number of nodes = rows of graph.x)
    small_train = [graph for graph in train if graph.x.shape[0] <= max_size]

    # Train a TL model
    checkpoint_callback = L.callbacks.ModelCheckpoint(
        save_top_k=1,
        monitor="valid",
        mode="min",
        dirpath="ckpts",
        filename="{epoch:02d}-{valid:.2f}",
    )

    TrainLoader = torch_geometric.loader.DataLoader(
        small_train, batch_size=128, shuffle=True
    )

    model_tl = models.LitGatNN_pre("IPA_results/eps_300.pt", lr=1e-5, decay=1e-4)
    trainer = L.Trainer(
        max_epochs=500,
        check_val_every_n_epoch=20,
        precision="bf16-mixed",
        callbacks=[checkpoint_callback],
    )
    trainer.fit(
        model=model_tl, train_dataloaders=TrainLoader, val_dataloaders=ValLoader
    )
    checkpoint = torch.load(checkpoint_callback.best_model_path)
    model_tl.load_state_dict(checkpoint["state_dict"])
    results = utils.model_eval(model_tl, TestLoader, data_df)
    # A bare expression inside a loop is never rendered by Jupyter, so show each
    # summary explicitly (`display` is provided by the IPython kernel).
    print(
        f"TL on cells with at most {max_size} sites "
        f"({len(small_train)} training graphs)"
    )
    display(results.describe())
    # torch.save(model_tl.state_dict(), f"trained_models/tl{max_size}.pt")

## Load trained models

In [None]:
# Load an IPA model
# Network constructed from the IPA weights file, used for inference on the GPU.
ipa_weights_path = "IPA_results/eps_300.pt"
model_ipa = models.LitGatNN_pre(ipa_weights_path, lr=2e-3, decay=0)
model_ipa.eval()
model_ipa.cuda()

In [None]:
# Load DL models
# One direct-learning network per training-set size, each with its matching
# architecture, moved to the GPU and switched to inference mode.


def _load_dl_model(params, n_train):
    """Build a LitGatNN with `params` and load trained_models/dl{n_train}.pt into it.

    The model is moved to the GPU and put in eval mode before it is returned.
    """
    net = models.LitGatNN(lr=1e-3, decay=0, params=params)
    net.load_state_dict(torch.load(f"trained_models/dl{n_train}.pt"))
    net.cuda()
    net.eval()
    return net


model_dl_100 = _load_dl_model(params_100, 100)
model_dl_300 = _load_dl_model(params_300, 300)
model_dl_1000 = _load_dl_model(params_1000, 1000)
model_dl_3000 = _load_dl_model(params_3000, 3000)
model_dl_4610 = _load_dl_model(params_4610, 4610)

In [None]:
# Load TL models
# Transfer-learning networks: each is constructed from IPA_results/eps_300.pt and
# then receives its fine-tuned state dict from trained_models/tl{N}.pt.
tl_models = {}
for n_train in (100, 300, 1000, 3000, 4610):
    tl_net = models.LitGatNN_pre("IPA_results/eps_300.pt", lr=1e-5, decay=1e-4)
    state_dict = torch.load(f"trained_models/tl{n_train}.pt")
    tl_net.load_state_dict(state_dict)
    tl_models[n_train] = tl_net

model_tl_100 = tl_models[100]
model_tl_300 = tl_models[300]
model_tl_1000 = tl_models[1000]
model_tl_3000 = tl_models[3000]
model_tl_4610 = tl_models[4610]

# Move everything to the GPU first, then switch to inference mode.
for tl_net in tl_models.values():
    tl_net.cuda()
for tl_net in tl_models.values():
    tl_net.eval()

In [None]:
# Load small models
# TL networks fine-tuned only on training cells with at most 2..6 sites
# (the "TL on small cells" loop saves them as trained_models/tl{max_size}.pt).


def _load_small_cell_tl_model(max_sites):
    """Construct the TL network, load trained_models/tl{max_sites}.pt into it,
    move it to the GPU and switch it to inference mode."""
    net = models.LitGatNN_pre("IPA_results/eps_300.pt", lr=1e-5, decay=1e-4)
    net.load_state_dict(torch.load(f"trained_models/tl{max_sites}.pt"))
    net.cuda()
    # torch modules start in training mode; put these in eval mode like every
    # other loaded model so their test-set evaluation is deterministic.
    net.eval()
    return net


model_tl_2 = _load_small_cell_tl_model(2)
model_tl_3 = _load_small_cell_tl_model(3)
model_tl_4 = _load_small_cell_tl_model(4)
model_tl_5 = _load_small_cell_tl_model(5)
model_tl_6 = _load_small_cell_tl_model(6)

## Evaluate models

In [None]:
# Move every test graph to the GPU, then evaluate all loaded models on the same
# single-graph test loader.
for graph in test:
    graph.cuda()
TestLoader = torch_geometric.loader.DataLoader(test, batch_size=1, shuffle=False)


def _evaluate_on_test(model):
    """Run utils.model_eval for `model` with the shared test loader and data table."""
    return utils.model_eval(model, TestLoader, data_df)


results_dl_100 = _evaluate_on_test(model_dl_100)
results_dl_300 = _evaluate_on_test(model_dl_300)
results_dl_1000 = _evaluate_on_test(model_dl_1000)
results_dl_3000 = _evaluate_on_test(model_dl_3000)
results_dl_4610 = _evaluate_on_test(model_dl_4610)

results_tl_100 = _evaluate_on_test(model_tl_100)
results_tl_300 = _evaluate_on_test(model_tl_300)
results_tl_1000 = _evaluate_on_test(model_tl_1000)
results_tl_3000 = _evaluate_on_test(model_tl_3000)
results_tl_4610 = _evaluate_on_test(model_tl_4610)

results_tl_2 = _evaluate_on_test(model_tl_2)
results_tl_3 = _evaluate_on_test(model_tl_3)
results_tl_4 = _evaluate_on_test(model_tl_4)
results_tl_5 = _evaluate_on_test(model_tl_5)
results_tl_6 = _evaluate_on_test(model_tl_6)

In [None]:
# Summary statistics of the test-set error measures, e.g. for DL with N = 100:
summary_dl_100 = results_dl_100.describe()
summary_dl_100

In [None]:
# Evaluate model trained on small cells for separate cell sizes
# For the TL model fine-tuned on cells with at most 4 sites, collect
# [mse, sc, n_sites] per test material and group the (mse, sc) pairs by n_sites.

# One lookup table instead of a boolean mask over data_df for every row; keep
# the first entry per mat_id, matching the previous `.values[0]`.
nsites_by_id = data_df.drop_duplicates("mat_id").set_index("mat_id")["nsites"]

errors_and_size = []
for _, row in results_tl_4.iterrows():
    n_sites = nsites_by_id[row["name"][0]]
    errors_and_size.append([row.mse, row.sc, n_sites])

# list_of_lists[k - 1] holds the (mse, sc) pairs of materials with k sites.
# Keep at least 8 groups (the Fig. 2 cell indexes groups 2..7) and grow the
# list if the test set contains larger cells instead of raising IndexError.
n_groups = max([8] + [int(size) for _, _, size in errors_and_size])
list_of_lists = [[] for _ in range(n_groups)]
for mse, sc, size in errors_and_size:
    list_of_lists[int(size) - 1].append([mse, sc])
# reshape keeps empty groups 2-D (shape (0, 2)) so column slicing still works
list_of_lists = [np.array(group).reshape(-1, 2) for group in list_of_lists]
errors_and_size = np.array(errors_and_size)

In [None]:
# IPA as baseline
# Compare the DFT IPA spectra directly with the DFT RPA targets on the test set.
test_ipa = np.stack([graph.ipa.cpu().detach().numpy() for graph in test])
test_rpa = np.stack([graph.y.cpu().detach().numpy() for graph in test])

ipa_vs_rpa_mse = ((test_ipa - test_rpa) ** 2).mean(axis=1)
ipa_vs_rpa_sc = 1 - np.trapz(np.abs(test_ipa - test_rpa), axis=1) / np.trapz(
    abs(test_rpa), axis=1
)

print("Median MSE[IPA_DFT,RPA_DFT]: " + str(np.median(ipa_vs_rpa_mse)))
print("Median SC[IPA_DFT,RPA_DFT]: " + str(np.median(ipa_vs_rpa_sc)))

In [None]:
# Evaluate train errors
# Per-sample errors (prediction - target) of each model on a training subset.
# Each subset is drawn exactly like `curr_train` in the data-split cell
# (random.seed(seed) followed by random.sample(train, n)).


def _prediction_errors(model, graphs):
    """Return np.array of model(graph) - graph.y for every graph in `graphs`.

    Graphs are moved to the GPU in place. Autograd is disabled because only
    forward passes are needed.
    """
    errors = []
    with torch.no_grad():
        for graph in graphs:
            graph.cuda()
            true = graph.y.cpu().detach().numpy()
            errors.append(model(graph).cpu().detach().numpy() - true)
    return np.array(errors)


def _print_median_train_errors(label, errors):
    """Print the median over samples of mean |error| (MAE) and mean error**2 (MSE) along axis 1."""
    print(
        "Median train MAE for "
        + label
        + ": "
        + str(np.median(np.mean(np.abs(errors), axis=1)))
    )
    print(
        "Median train MSE for "
        + label
        + ": "
        + str(np.median(np.mean((errors) ** 2, axis=1)))
    )


random.seed(seed)
train_100 = random.sample(train, 100)
random.seed(seed)
train_300 = random.sample(train, 300)
random.seed(seed)
train_1000 = random.sample(train, 1000)
random.seed(seed)
train_3000 = random.sample(train, 3000)

errors_dl_100 = _prediction_errors(model_dl_100, train_100)
errors_dl_300 = _prediction_errors(model_dl_300, train_300)
errors_dl_1000 = _prediction_errors(model_dl_1000, train_1000)
errors_dl_3000 = _prediction_errors(model_dl_3000, train_3000)
errors_dl_4610 = _prediction_errors(model_dl_4610, train)

errors_tl_100 = _prediction_errors(model_tl_100, train_100)
errors_tl_300 = _prediction_errors(model_tl_300, train_300)
errors_tl_1000 = _prediction_errors(model_tl_1000, train_1000)
errors_tl_3000 = _prediction_errors(model_tl_3000, train_3000)
errors_tl_4610 = _prediction_errors(model_tl_4610, train)

for label, errors in [
    ("DL 100", errors_dl_100),
    ("DL 300", errors_dl_300),
    ("DL 1000", errors_dl_1000),
    ("DL 3000", errors_dl_3000),
    ("DL 4610", errors_dl_4610),
]:
    _print_median_train_errors(label, errors)

print("\n")

for label, errors in [
    ("TL 100", errors_tl_100),
    ("TL 300", errors_tl_300),
    ("TL 1000", errors_tl_1000),
    ("TL 3000", errors_tl_3000),
    ("TL 4610", errors_tl_4610),
]:
    _print_median_train_errors(label, errors)

## Plots for the paper figures

In [None]:
# Apply the publication.mplstyle style sheet (resolved relative to the working directory).
style_sheet = "publication.mplstyle"
plt.style.use(style_sheet)

In [None]:
# Aliases used by the Fig. 1 cell below.
model_dl, model_tl = model_dl_300, model_tl_300

In [None]:
from matplotlib.lines import Line2D

lw = 1
fig, axes = plt.subplots(3, 3, figsize=[(3 + 3 / 8) * 2, 3])
model_tl.cuda()
model_dl.cuda()
model_dl.eval()
model_tl.eval()

# Panel i shows the test material at row int(ipa_rpa_vals[i] * N + 1) of the
# test results sorted by their "ipa_rpa" column (N = number of results).
ipa_rpa_vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
sorted_by_ipa_rpa = results_dl_100.sort_values("ipa_rpa")

# Forward passes only: no autograd graph is needed for plotting.
with torch.no_grad():
    for idx, ax in enumerate(axes.ravel()):
        graph_name = sorted_by_ipa_rpa.iloc[
            int(ipa_rpa_vals[idx] * len(results_dl_100) + 1)
        ]["name"][0]
        for graph in test:
            if graph.mat_id == graph_name:
                # Black: DFT IPA (thin) and DFT RPA target.
                ax.plot(graph.ipa.cpu(), "k", linewidth=lw * 0.5)
                ax.plot(graph.y.cpu(), "k", linewidth=lw)
                # Green = DL, orange = TL; thin solid = N=300, dashed = N=4610.
                ax.plot(
                    model_dl_300(graph.cuda()).cpu().detach().numpy().flatten(),
                    color="limegreen",
                    linewidth=lw * 0.5,
                )
                ax.plot(
                    model_dl_4610(graph.cuda()).cpu().detach().numpy().flatten(),
                    color="limegreen",
                    linestyle="--",
                    linewidth=lw,
                )
                ax.plot(
                    model_tl_300(graph.cuda()).cpu().detach().numpy().flatten(),
                    color="tab:orange",
                    linewidth=lw * 0.5,
                )
                ax.plot(
                    model_tl_4610(graph.cuda()).cpu().detach().numpy().flatten(),
                    color="tab:orange",
                    linestyle="--",
                    linewidth=lw,
                )

                ax.set_ylim(bottom=0, top=ax.get_ylim()[1] * 1.2)
                ax.set_xlim([0, 1000])
                # Grid points 0..1000 are labelled 0..10 on the energy axis.
                ax.set_xticks([0, 200, 400, 600, 800, 1000])
                ax.set_xticklabels([0, 2, 4, 6, 8, 10])
                # Chemical formula with every digit run rendered as a subscript.
                ax.annotate(
                    re.sub(
                        r"(\d+)",
                        r"$_{\1}$",
                        data_df[data_df["mat_id"] == graph_name].formula.values[0],
                    ),
                    xy=(1, 1),
                    xycoords="axes fraction",
                    xytext=(-4, -4),
                    textcoords="offset points",
                    ha="right",
                    va="top",
                    fontsize=8,
                    clip_on=False,
                )


# Quantile labels ($Q_{10\%}$ ... $Q_{90\%}$). Raw strings: "\%" is not a valid
# Python escape and raises a SyntaxWarning (Python >= 3.12) in a plain literal.
for ax, quantile in zip(axes.ravel(), ipa_rpa_vals):
    ax.annotate(
        r"$Q_{" + str(round(quantile * 100)) + r"\%}$ ",
        xy=[0.05, 0.85],
        xycoords="axes fraction",
    )


# NOTE(review): the RPA handle is drawn with lw * 1.5 although the RPA curves
# are plotted with lw — kept as is; confirm this is intended.
legend_handles = [
    Line2D([0], [0], color="k", lw=lw * 1.5, label="RPA (target)"),
    Line2D([0], [0], color="k", lw=lw * 0.5, label="IPA"),
    Line2D([0], [0], color="limegreen", lw=lw * 0.5, label="DL ($N=300$)"),
    Line2D([0], [0], color="limegreen", lw=lw, linestyle="--", label="DL ($N=4610$)"),
    Line2D([0], [0], color="tab:orange", lw=lw * 0.5, label="TL ($N=300$)"),
    Line2D([0], [0], color="tab:orange", lw=lw, linestyle="--", label="TL ($N=4610$)"),
]


fig.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(0.85, 0.51),
    frameon=False,
)


# set y-ticks manually
axes[0, 0].set_yticks([0, 5, 10])
axes[1, 0].set_yticks([0, 5, 10])
axes[2, 0].set_yticks([0, 10, 20])
axes[0, 1].set_yticks([0, 3, 6])
axes[1, 1].set_yticks([0, 5, 10, 15])
axes[2, 1].set_yticks([0, 5, 10])
axes[0, 2].set_yticks([0, 2, 4])
axes[1, 2].set_yticks([0, 5, 10])
axes[2, 2].set_yticks([0, 3, 6])

fig.supxlabel(r"Energy (eV)", x=0.45, y=-0.01)
fig.supylabel(r"$\mathrm{Im}(\overline{\varepsilon})$", x=-0.005, y=0.53)
fig.tight_layout(pad=0.0, w_pad=0.2, h_pad=0.3, rect=[0, 0, 0.85, 1])
fig.show()
fig.savefig("plots/Fig1_revised.pdf")

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

lw = 1
fig, axes = plt.subplots(3, 1, figsize=[(3 + 3 / 8), 7])

# Mathtext labels use raw strings: "\m" is not a valid Python escape and raises
# a SyntaxWarning (Python >= 3.12) in a plain string literal.

# Top plot: median SC vs. training-set size (hard-coded results).
# x_vals_small presumably are the numbers of training graphs with at most
# 2..6 sites (the numbers annotated next to the points) — confirm.
x_vals = np.array([100, 300, 1000, 3000, 4610])
x_vals_small = np.array([74, 343, 1470, 2191, 2699])
y_vals_small = np.array([0.791, 0.848, 0.870, 0.879, 0.884])
y_vals_dl = np.array([0.716, 0.754, 0.827, 0.874, 0.900])
y_vals_tl = np.array([0.834, 0.863, 0.881, 0.893, 0.893])
axes[0].set_xscale("log")

axes[0].plot(x_vals, y_vals_dl, marker="*", color="limegreen", label="Direct Learning")
axes[0].plot(
    x_vals, y_vals_tl, marker="*", color="tab:orange", label="Transfer Learning"
)
axes[0].plot(
    x_vals_small, y_vals_small, marker="*", color="tab:blue", label="TL on small cells"
)
# Presumably the median SC[IPA_DFT,RPA_DFT] printed by the IPA-baseline cell.
axes[0].plot([0, 10000], [0.6896362, 0.6896362], "k--", label="IPA baseline")

axes[0].set_xlabel("Training Set Size")
axes[0].set_ylabel(r"Median SC[RPA$_\mathrm{ML}$;RPA$_\mathrm{DFT}$]")
axes[0].set_yticks([0.7, 0.75, 0.8, 0.85, 0.9])
axes[0].legend()

axes[0].annotate(2, xy=[x_vals_small[0], y_vals_small[0] - 0.01])
axes[0].annotate(3, xy=[x_vals_small[1], y_vals_small[1] - 0.01])
axes[0].annotate(4, xy=[x_vals_small[2], y_vals_small[2] - 0.01])
axes[0].annotate(5, xy=[x_vals_small[3], y_vals_small[3] - 0.01])
axes[0].annotate(6, xy=[x_vals_small[4] + 200, y_vals_small[4] - 0.005])

axes[0].set_xlim([60, 6000])

# Middle plot: median MSE vs. training-set size (hard-coded results).
y_vals_small = np.array([0.225, 0.117, 0.086, 0.079, 0.075])
y_vals_dl = np.array([0.394, 0.262, 0.160, 0.075, 0.052])
y_vals_tl = np.array([0.133, 0.101, 0.077, 0.059, 0.061])
axes[1].set_xscale("log")
axes[1].set_yscale("log")

axes[1].plot(x_vals, y_vals_dl, marker="*", color="limegreen", label="Direct Learning")
axes[1].plot(
    x_vals, y_vals_tl, marker="*", color="tab:orange", label="Transfer Learning"
)
axes[1].plot(
    x_vals_small, y_vals_small, marker="*", color="tab:blue", label="TL on small cells"
)
# Presumably the median MSE[IPA_DFT,RPA_DFT] printed by the IPA-baseline cell.
axes[1].plot([0, 10000], [0.46935984, 0.46935984], "k--", label="IPA baseline")

axes[1].set_xlabel("Training Set Size")
axes[1].set_ylabel(r"Median MSE[RPA$_\mathrm{ML}$;RPA$_\mathrm{DFT}$]", labelpad=-3)
axes[1].set_yticks([5e-2, 6e-2, 0.1, 0.2, 0.3, 0.4])
axes[1].set_yticklabels(["0.05", "", "0.1", "0.2", "0.3", "0.4"])

axes[1].annotate(2, xy=[x_vals_small[0] + 1, y_vals_small[0] + 0.005])
axes[1].annotate(3, xy=[x_vals_small[1], y_vals_small[1] + 0.005])
axes[1].annotate(4, xy=[x_vals_small[2], y_vals_small[2] + 0.005])
axes[1].annotate(5, xy=[x_vals_small[3], y_vals_small[3] + 0.005])
axes[1].annotate(6, xy=[x_vals_small[4], y_vals_small[4] + 0.005])

axes[1].set_xlim([60, 6000])

# Bottom plot (version without inset): SC box plots per cell size, with the
# cell-size histogram of the test set on a twin y-axis drawn behind them.
axes2 = axes[2].twinx()
axes2.set_zorder(0)
axes[2].set_zorder(1)
axes[2].patch.set_alpha(0)
axes2.patch.set_visible(False)

# list_of_lists[k] holds materials with k + 1 sites, so k = 2..7 -> 3..8 atoms.
sc_list = [list_of_lists[k][:, 1] for k in range(2, 8)]
mse_list = [list_of_lists[k][:, 0] for k in range(2, 8)]
axes[2].boxplot(sc_list, showfliers=False)
axes[2].set_xlabel("Atoms per unit cell")
axes[2].set_ylabel(r"SC[RPA$_\mathrm{ML}$;RPA$_\mathrm{DFT}$]")
axes[2].set_xticks([1, 2, 3, 4, 5, 6])
axes[2].set_xticklabels([3, 4, 5, 6, 7, 8])
axes[2].set_ylim([0.3, 1])
axes2.set_ylim([0, 300])
axes2.set_xlim([0, 7])

# Shift sizes by 2 so that n sites lines up with box-plot position n - 2.
axes2.hist(
    errors_and_size[:, 2] - 2,
    bins=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
    color="tab:blue",
)
axes2.set_ylabel("Occurrence in test set")


fig.tight_layout(w_pad=1, h_pad=0.5)
plt.show()
fig.savefig("plots/Fig2_revised.pdf")

In [None]:
# Spectra, predictions and SC scores per test material for the N=300 and
# N=4610 DL/TL model pairs.


def _build_comparison_df(dl_model, tl_model):
    """Return a DataFrame with one row per test graph.

    Columns ipa_true / rpa_true hold graph.ipa / graph.y, ipa_pred holds the
    output of model_ipa, and dl_pred / tl_pred hold the outputs of `dl_model` /
    `tl_model`. The columns iptrue_rptrue, ippred_iptrue, dlpred_rptrue and
    tlpred_rptrue are get_sc(first, second) for the named pairs.
    """
    for net in (model_ipa, dl_model, tl_model):
        net.eval()
        net.cuda()
    rows = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for graph in test:
            graph.cuda()
            rows.append(
                {
                    "ipa_true": graph.ipa.cpu().numpy(),
                    "rpa_true": graph.y.cpu().numpy(),
                    "ipa_pred": model_ipa(graph).cpu().detach().numpy(),
                    "dl_pred": dl_model(graph).cpu().detach().numpy(),
                    "tl_pred": tl_model(graph).cpu().detach().numpy(),
                }
            )
    frame = pd.DataFrame(rows)
    frame["iptrue_rptrue"] = get_sc(frame["ipa_true"], frame["rpa_true"])
    frame["ippred_iptrue"] = get_sc(frame["ipa_pred"], frame["ipa_true"])
    frame["dlpred_rptrue"] = get_sc(frame["dl_pred"], frame["rpa_true"])
    frame["tlpred_rptrue"] = get_sc(frame["tl_pred"], frame["rpa_true"])
    return frame


comp_df_300 = _build_comparison_df(model_dl_300, model_tl_300)
comp_df_4610 = _build_comparison_df(model_dl_4610, model_tl_4610)

In [None]:
fig, axes = plt.subplots(
    3,
    3,
    figsize=[2 * (3 + 3 / 8), 6],
    gridspec_kw={"width_ratios": [2, 2, 0.4], "height_ratios": [0.4, 2, 2]},
)
axes = axes.ravel()
bins = np.linspace(0, 1, 51)
cmp = "Blues"
comp_df = comp_df_300

# Layout (row-major indices): 0, 1 = marginal histograms on top; 5, 8 =
# horizontal marginal histograms on the right; 3, 4, 6, 7 = joint 2D
# histograms; 2 = unused corner.
axes[0].hist(comp_df["dlpred_rptrue"], bins=bins, color="tab:blue")
axes[1].hist(comp_df["tlpred_rptrue"], bins=bins, color="tab:blue")
axes[5].hist(
    comp_df["iptrue_rptrue"], bins=bins, orientation="horizontal", color="tab:blue"
)
axes[8].hist(
    comp_df["ippred_iptrue"], bins=bins, orientation="horizontal", color="tab:blue"
)

# (panel, x column, y column) of the joint histograms; the dashed diagonal
# marks equal SC on both axes.
for panel, X, Y in [
    (3, "dlpred_rptrue", "iptrue_rptrue"),
    (4, "tlpred_rptrue", "iptrue_rptrue"),
    (6, "dlpred_rptrue", "ippred_iptrue"),
    (7, "tlpred_rptrue", "ippred_iptrue"),
]:
    axes[panel].hist2d(comp_df[X], comp_df[Y], bins=bins, cmap=cmp)
    axes[panel].plot([0, 1], [0, 1], color="tab:orange", ls="--", alpha=0.7)

# Raw strings: "\m" is not a valid Python escape (SyntaxWarning on Python >= 3.12).
axes[3].set_ylabel(r"SC[IPA$_\mathrm{DFT}$;RPA$_\mathrm{DFT}$]")
axes[6].set_xlabel(r"SC[RPA$_\mathrm{DL}$;RPA$_\mathrm{DFT}$]")
axes[6].set_ylabel(r"SC[IPA$_\mathrm{ML}$;IPA$_\mathrm{DFT}$]")
axes[7].set_xlabel(r"SC[RPA$_\mathrm{TL}$;RPA$_\mathrm{DFT}$]")

for panel in (0, 1, 5, 8):
    axes[panel].grid(False)

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

axes[2].axis("off")

axes[0].set_xlim([0, 1])
axes[1].set_xlim([0, 1])
axes[5].set_ylim([0, 1])
axes[8].set_ylim([0, 1])

sc_ticks = [0, 0.2, 0.4, 0.6, 0.8, 1]
axes[3].set_yticks(sc_ticks)
axes[6].set_yticks(sc_ticks)
axes[6].set_xticks(sc_ticks)
axes[7].set_xticks(sc_ticks)

labelfs = 10
# (panel, label, position in axes-fraction coordinates)
for panel, panel_label, position in [
    (0, r"$\mathbf{a}$", [0.02, 0.7]),
    (1, r"$\mathbf{b}$", [0.02, 0.7]),
    (5, r"$\mathbf{c}$", [0.72, 0.95]),
    (8, r"$\mathbf{d}$", [0.72, 0.95]),
    (3, r"$\mathbf{ac}$", [0.02, 0.93]),
    (4, r"$\mathbf{bc}$", [0.02, 0.93]),
    (6, r"$\mathbf{ad}$", [0.02, 0.93]),
    (7, r"$\mathbf{bd}$", [0.02, 0.93]),
]:
    axes[panel].annotate(
        panel_label, position, xycoords="axes fraction", fontsize=labelfs
    )

fig.tight_layout()
fig.subplots_adjust(wspace=0.07, hspace=0.07)

fig.savefig("plots/Fig3.pdf", transparent=True)

In [None]:
fig, axes = plt.subplots(
    3,
    3,
    figsize=[2 * (3 + 3 / 8), 6],
    gridspec_kw={"width_ratios": [2, 2, 0.4], "height_ratios": [0.4, 2, 2]},
)
axes = axes.ravel()
bins = np.linspace(0, 1, 51)
cmp = "Blues"
comp_df = comp_df_4610

# Same layout as Fig. 3 but for the N=4610 models. Row-major indices:
# 0, 1 = marginal histograms on top; 5, 8 = horizontal marginal histograms on
# the right; 3, 4, 6, 7 = joint 2D histograms; 2 = unused corner.
axes[0].hist(comp_df["dlpred_rptrue"], bins=bins, color="tab:blue")
axes[1].hist(comp_df["tlpred_rptrue"], bins=bins, color="tab:blue")
axes[5].hist(
    comp_df["iptrue_rptrue"], bins=bins, orientation="horizontal", color="tab:blue"
)
axes[8].hist(
    comp_df["ippred_iptrue"], bins=bins, orientation="horizontal", color="tab:blue"
)

# (panel, x column, y column) of the joint histograms; the dashed diagonal
# marks equal SC on both axes.
for panel, X, Y in [
    (3, "dlpred_rptrue", "iptrue_rptrue"),
    (4, "tlpred_rptrue", "iptrue_rptrue"),
    (6, "dlpred_rptrue", "ippred_iptrue"),
    (7, "tlpred_rptrue", "ippred_iptrue"),
]:
    axes[panel].hist2d(comp_df[X], comp_df[Y], bins=bins, cmap=cmp)
    axes[panel].plot([0, 1], [0, 1], color="tab:orange", ls="--", alpha=0.7)

# Raw strings: "\m" is not a valid Python escape (SyntaxWarning on Python >= 3.12).
axes[3].set_ylabel(r"SC[IPA$_\mathrm{DFT}$;RPA$_\mathrm{DFT}$]")
axes[6].set_xlabel(r"SC[RPA$_\mathrm{DL}$;RPA$_\mathrm{DFT}$]")
axes[6].set_ylabel(r"SC[IPA$_\mathrm{ML}$;IPA$_\mathrm{DFT}$]")
axes[7].set_xlabel(r"SC[RPA$_\mathrm{TL}$;RPA$_\mathrm{DFT}$]")

for panel in (0, 1, 5, 8):
    axes[panel].grid(False)

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

axes[2].axis("off")

axes[0].set_xlim([0, 1])
axes[1].set_xlim([0, 1])
axes[5].set_ylim([0, 1])
axes[8].set_ylim([0, 1])

sc_ticks = [0, 0.2, 0.4, 0.6, 0.8, 1]
axes[3].set_yticks(sc_ticks)
axes[6].set_yticks(sc_ticks)
axes[6].set_xticks(sc_ticks)
axes[7].set_xticks(sc_ticks)

labelfs = 10
# (panel, label, position in axes-fraction coordinates)
for panel, panel_label, position in [
    (0, r"$\mathbf{a}$", [0.02, 0.7]),
    (1, r"$\mathbf{b}$", [0.02, 0.7]),
    (5, r"$\mathbf{c}$", [0.72, 0.95]),
    (8, r"$\mathbf{d}$", [0.72, 0.95]),
    (3, r"$\mathbf{ac}$", [0.02, 0.93]),
    (4, r"$\mathbf{bc}$", [0.02, 0.93]),
    (6, r"$\mathbf{ad}$", [0.02, 0.93]),
    (7, r"$\mathbf{bd}$", [0.02, 0.93]),
]:
    axes[panel].annotate(
        panel_label, position, xycoords="axes fraction", fontsize=labelfs
    )

fig.tight_layout()
fig.subplots_adjust(wspace=0.07, hspace=0.07)

fig.savefig("plots/SFig1.pdf", transparent=True)

In [None]:
X = "dlpred_rptrue"
Y = "tlpred_rptrue"

cmp = "Blues"
bins = np.linspace(0, 1, 51)
fig, axes = plt.subplots(1, 2, figsize=[2 * (3 + 3 / 8), 3])
axes = axes.ravel()

# Left: comp_df_300 (the *_300 models); right: comp_df_4610 (the *_4610 models).
# Labels are raw strings: "\m" is not a valid Python escape (SyntaxWarning on
# Python >= 3.12).
for ax, comp_df_n in zip(axes, (comp_df_300, comp_df_4610)):
    ax.hist2d(comp_df_n[X], comp_df_n[Y], bins=bins, cmap=cmp)
    ax.plot([0, 1], [0, 1], color="tab:orange", ls="--", alpha=0.7)
    ax.set_xlabel(r"SC[RPA$_\mathrm{DL}$;RPA$_\mathrm{DFT}$]")
    ax.set_ylabel(r"SC[RPA$_\mathrm{TL}$;RPA$_\mathrm{DFT}$]")


plt.tight_layout()
fig.savefig("plots/Fig4.pdf")

In [None]:
# Target spectra and predictions of the N=300 / N=4610 TL and DL models for
# every test graph.
true = []
pred_tl_300 = []
pred_tl_4610 = []
pred_dl_300 = []
pred_dl_4610 = []

# Inference only: disable autograd for the forward passes.
with torch.no_grad():
    for graph in test:
        graph.cuda()
        true.append(graph.y.cpu().detach().numpy())
        pred_tl_300.append(model_tl_300(graph).cpu().detach().numpy())
        pred_tl_4610.append(model_tl_4610(graph).cpu().detach().numpy())
        pred_dl_300.append(model_dl_300(graph).cpu().detach().numpy())
        pred_dl_4610.append(model_dl_4610(graph).cpu().detach().numpy())

In [None]:
# QW: trapezoidal integral (dx = 0.01) of each DFT spectrum vs. the same
# integral of the ML prediction, one panel per model.
qw_dft = np.trapz(true, dx=0.01, axis=1)
qw_panels = [
    ((0, 0), pred_tl_300, "TL: 300"),
    ((0, 1), pred_dl_300, "DL: 300"),
    ((1, 0), pred_tl_4610, "TL: 4610"),
    ((1, 1), pred_dl_4610, "DL: 4610"),
]

fig, axes = plt.subplots(2, 2, figsize=[2 * (3 + 3 / 8), 3 * 2])
for position, predictions, panel_title in qw_panels:
    panel_ax = axes[position]
    panel_ax.scatter(
        qw_dft,
        np.trapz(predictions, dx=0.01, axis=1),
        s=0.5,
        color="tab:blue",
    )
    panel_ax.set_title(panel_title)

for ax in axes.ravel():
    ax.plot([0, 100], [0, 100], color="tab:orange")
    ax.set_xlim([0, 160])
    ax.set_ylim([0, 100])
    ax.set_xlabel(r"QW$_\mathrm{DFT}$")
    ax.set_ylabel(r"QW$_\mathrm{ML}$")
fig.tight_layout()
fig.savefig("plots/SFig4.png")

## Check the runtime of a model on your machine

In [None]:
from pymatgen.core import Structure
from torch_geometric.data import Data

# Put the model into inference mode; Module.eval() is defined as train(False).
model_dl_4610.train(False)

In [None]:
%%timeit
# Load the cif
# Times parsing of the example CIF file. Assignments made under %%timeit are
# not kept in the notebook namespace, so the next cell loads it again.
structure = Structure.from_file("otherstructs/agm004850436.cif")

In [None]:
# Load the example structure for real; the %%timeit cell above does not keep it.
cif_path = "otherstructs/agm004850436.cif"
structure = Structure.from_file(cif_path)

In [None]:
%%timeit
# Create the graph
# Times the structure -> graph featurization. Assignments made inside a
# %%timeit cell do not persist in the notebook namespace, which is why the
# next cell repeats this code to actually build `graph`.
nbr_fea_idx = []
nbr_fea = []

self_fea_idx = []

# one neighbor list per site, all neighbors within radius 5
all_nbrs = structure.get_all_neighbors(5)
for site, nbr in enumerate(all_nbrs):
    nbr_fea_idx_sub, nbr_fea_sub, self_fea_idx_sub = [], [], []

    # source site of every edge leaving this site
    for n in range(len(nbr)):
        self_fea_idx_sub.append(site)

    # nbr[j][2]: site index of the j-th neighbor (edge target)
    for j in range(len(nbr)):
        nbr_fea_idx_sub.append(nbr[j][2])

    # nbr[j][1]: distance to the j-th neighbor (expanded into edge features below)
    for j in range(len(nbr)):
        nbr_fea_sub.append(nbr[j][1])

    nbr_fea_idx.append(nbr_fea_idx_sub)
    nbr_fea.append(nbr_fea_sub)

    self_fea_idx.append(self_fea_idx_sub)

# edge_index of shape (2, n_edges): row 0 = source site, row 1 = neighbor site
edges = torch.stack(
    (
        torch.tensor(
            [item for items in self_fea_idx for item in items], dtype=torch.long
        ),
        torch.tensor(
            [item for items in nbr_fea_idx for item in items], dtype=torch.long
        ),
    )
)
nbr_fea = [item for items in nbr_fea for item in items]
# Expand each distance d on 51 Gaussians centred at mu = 0, 0.1, ..., 5:
# sqrt(10/pi) * exp(-10 (d - mu)^2), giving a (51, n_edges) array.
# NOTE(review): `nbr_fea` is a plain list here; `nbr_fea - val` only works
# because `val` is a NumPy scalar whose reflected subtraction converts the list.
x_vals = np.linspace(0, 5, 51)
edge_attr = np.sqrt(10 / np.pi) * np.array(
    [np.exp(-10 * (nbr_fea - val) ** 2) for val in x_vals]
)
# (51, n_edges) -> (n_edges, 51)
edge_attr = torch.tensor(edge_attr, dtype=torch.float)
edge_attr = torch.transpose(edge_attr, 0, 1)

atoms = np.array(range(len(all_nbrs)))
self_fea = []

for atom_id in atoms:
    # encode atom information by group and row
    # Group index: groups 1-2 -> 0-1, groups 13-18 -> 2-7 (e.g. 13 - 1 - 10 = 2).
    # NOTE(review): groups 11-12 also land on 0-1 and groups 3-10 become
    # negative (one_hot rejects them); presumably only main-group elements in
    # rows 1-5 are expected.
    group = torch.tensor(
        structure.species[atom_id].group - 1, dtype=torch.int64)
    if group > 1:
        group -= 10
    row = torch.tensor(
        structure.species[atom_id].row - 1, dtype=torch.int64)
    group = torch.nn.functional.one_hot(group, num_classes=8)
    row = torch.nn.functional.one_hot(row, num_classes=5)
    self_fea.append(torch.hstack([group, row]))

# node features: (n_atoms, 8 + 5) one-hot group and row
self_fea = torch.vstack(self_fea)
# NOTE(review): the next two lines do not affect `graph` (`nbr_fea` is not
# passed to Data and the edge_attr line is a no-op); they are left untouched so
# the timed code stays what was originally benchmarked.
nbr_fea = torch.tensor(nbr_fea, dtype=torch.float)
edge_attr = edge_attr

graph = Data(
    x=self_fea.to(torch.float32),
    edge_index=edges,
    edge_attr=edge_attr,
    )

In [None]:
# Create the graph
def structure_to_graph(structure, cutoff=5, n_gaussians=51):
    """Convert a pymatgen ``Structure`` into a torch_geometric ``Data`` graph.

    Parameters
    ----------
    structure : pymatgen.core.Structure
        Crystal structure to featurize.
    cutoff : float, default 5
        Neighbor search radius passed to ``Structure.get_all_neighbors``; also
        the position of the last Gaussian centre.
    n_gaussians : int, default 51
        Number of Gaussian centres spaced evenly on ``[0, cutoff]``.

    Returns
    -------
    torch_geometric.data.Data
        ``x``: (n_atoms, 13) float32, one-hot group (8) + one-hot row (5);
        ``edge_index``: (2, n_edges) long, row 0 = site, row 1 = neighbor site;
        ``edge_attr``: (n_edges, n_gaussians) float32 Gaussian distance expansion.

    Raises
    ------
    ValueError
        If a site's group/row falls outside the one-hot encoding.
    """
    all_nbrs = structure.get_all_neighbors(cutoff)

    # pymatgen neighbors index as (site, distance, index, image): [2] is the
    # neighbor's site index (edge target), [1] its distance (edge feature).
    src = [site for site, nbrs in enumerate(all_nbrs) for _ in nbrs]
    dst = [nbr[2] for nbrs in all_nbrs for nbr in nbrs]
    dists = np.array([nbr[1] for nbrs in all_nbrs for nbr in nbrs], dtype=float)
    edge_index = torch.stack(
        (torch.tensor(src, dtype=torch.long), torch.tensor(dst, dtype=torch.long))
    )

    # sqrt(10/pi) * exp(-10 (d - mu)^2) for every distance d and centre mu,
    # broadcast (n_edges, 1) against (n_gaussians,) -> (n_edges, n_gaussians).
    centers = np.linspace(0, cutoff, n_gaussians)
    edge_attr = np.sqrt(10 / np.pi) * np.exp(-10 * (dists[:, None] - centers) ** 2)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    # Group index: groups 1-2 -> 0-1, groups 13-18 -> 2-7.
    # NOTE(review): groups 11-12 also map onto 0-1; this is kept unchanged so
    # the features stay consistent with the stored training graphs.
    group_idx, row_idx = [], []
    for specie in structure.species:
        group = specie.group - 1
        if group > 1:
            group -= 10
        row = specie.row - 1
        if not (0 <= group < 8 and 0 <= row < 5):
            raise ValueError(
                f"{specie}: group {specie.group}, row {specie.row} "
                "is outside the supported one-hot encoding"
            )
        group_idx.append(group)
        row_idx.append(row)

    x = torch.hstack(
        (
            torch.nn.functional.one_hot(
                torch.tensor(group_idx, dtype=torch.int64), num_classes=8
            ),
            torch.nn.functional.one_hot(
                torch.tensor(row_idx, dtype=torch.int64), num_classes=5
            ),
        )
    ).to(torch.float32)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


graph = structure_to_graph(structure)

In [None]:
%%timeit
model_dl_4610(graph.cuda())

## Learn the SC between IPA and RPA

In [None]:
# Modify the graphs to have SC[RPA;IPA] as target
with open("database/graphs_300_rpa.pckl", "rb") as pickle_fh:
    graphs = pickle.load(pickle_fh)
data_df = pd.read_pickle("database/data_300_rpa.pckl")

for graph in graphs:
    # SC = 1 - trapz|y - ipa| / trapz|ipa| (unit spacing; it cancels in the ratio)
    diff_area = np.trapz(np.abs(graph.y - graph.ipa))
    ipa_area = np.trapz(abs(graph.ipa))
    graph.y = torch.tensor(1 - diff_area / ipa_area)

train, val, test = utils.train_val_test(graphs)
random.seed(seed)

In [None]:
# create the SC model
model_sc = models.LitGatNN(lr=1e-3, decay=0, params=params_ipa)

# Replace the model's `mlp1` head with an MLP whose last layer has a single
# unit, so the network predicts one scalar (the SC).
# Input width = product of the two entries of params_ipa[-2]
# (presumably channels x heads of the last GAT layer — confirm in src.models).
head_in = params_ipa[-2][0] * params_ipa[-2][1]
head_channels = [head_in, *params_ipa[-1], 1]
out = torch_geometric.nn.Sequential(
    "x",
    [(torch_geometric.nn.MLP(head_channels, act="relu"), "x -> x")],
)
model_sc.gatnn.mlp1 = out

In [None]:
# train a SC model
# as before, you don't need to run this cell
# as a trained state_dict is already provided

# keep only the checkpoint with the lowest validation loss
checkpoint_callback = L.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="valid",
    mode="min",
    dirpath="ckpts",
    filename="{epoch:02d}-{valid:.2f}",
)

TrainLoader = torch_geometric.loader.DataLoader(
    train, batch_size=64, shuffle=True, drop_last=True
)
ValLoader = torch_geometric.loader.DataLoader(val, batch_size=64, shuffle=False)
# Bug fix: this loader previously wrapped `val`, so anything evaluated on
# TestLoader would silently have reported validation-set numbers.
TestLoader = torch_geometric.loader.DataLoader(test, batch_size=1, shuffle=False)

trainer = L.Trainer(
    max_epochs=200,
    check_val_every_n_epoch=20,
    precision="bf16-mixed",
    callbacks=[checkpoint_callback],
)
trainer.fit(model=model_sc, train_dataloaders=TrainLoader, val_dataloaders=ValLoader)

# restore the best (lowest "valid") checkpoint instead of the last epoch
checkpoint = torch.load(checkpoint_callback.best_model_path)
model_sc.load_state_dict(checkpoint["state_dict"])

# torch.save(model_sc.state_dict(), f"trained_models/dl_sc.pt")

In [None]:
# load the pre-trained SC model weights
sc_weights_path = "trained_models/dl_sc.pt"
state_dict = torch.load(sc_weights_path)
model_sc.load_state_dict(state_dict)

In [None]:
# evaluate the model on the test split
model_sc.cuda()
model_sc.eval()
real = []
preds = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for graph in test:
        graph = graph.cuda()
        real.append(graph.y.cpu().numpy())
        preds.append(model_sc(graph).cpu().numpy()[0])
# Flatten both to 1-D so element-wise comparisons later cannot silently
# broadcast a (N,) array against a (N, 1) array into an (N, N) matrix.
real = np.array(real).ravel()
preds = np.array(preds).ravel()
assert real.shape == preds.shape, (real.shape, preds.shape)

In [None]:
# Figure 5: model prediction vs. true SC[RPA;IPA] on the test split.
# plt.subplots creates a pyplot-managed figure, so figsize is honoured.
# The previous plt.Figure(...) built a detached Figure object, and the plt.*
# calls drew into a different, default-sized figure.
fig, ax = plt.subplots(figsize=[3 + 3 / 8, 3 + 3 / 8])
ax.hist2d(
    np.ravel(real),
    np.ravel(preds),
    bins=np.linspace(0.18, 1, 50),
    cmap="Blues",
)
ax.plot([0.0, 1], [0.0, 1], c="tab:orange", linestyle="--", linewidth=1)
ax.set_xlim([0.18, 1])
ax.set_ylim([0.18, 1])
ax.set_xlabel(r"SC[RPA$_\mathrm{DFT}$;IPA$_\mathrm{DFT}$]")
ax.set_ylabel("Model Prediction")
fig.savefig("plots/Fig5_revised.pdf")

In [None]:
# Absolute errors on the test split. Flatten first so a (N,) vs (N, 1) shape
# mismatch cannot broadcast into an (N, N) matrix of all pairwise differences.
real_flat = np.ravel(real)
preds_flat = np.ravel(preds)
assert real_flat.shape == preds_flat.shape, (real_flat.shape, preds_flat.shape)
abs_err = np.abs(real_flat - preds_flat)
print("The mean absolute error is: " + str(np.mean(abs_err)))
print("The median absolute error is: " + str(np.median(abs_err)))

## Check for correlations between SC[RPA;IPA] and the materials

In [None]:
# train a sc-model on all graphs
# this helps with properly placing the materials in the validation/test set
# as before, you don't need to run this cell
# as a trained state_dict is already provided
# NOTE(review): `val` and `test` were split from `graphs`, so the "valid"
# metric monitored below is not held out here; the kept checkpoint is
# effectively selected on training data.
# NOTE(review): model_sc is not re-created in this cell, so training continues
# from whatever weights it currently holds — confirm that is intended.

checkpoint_callback = L.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="valid",
    mode="min",
    dirpath="ckpts",
    filename="{epoch:02d}-{valid:.2f}",
)

TrainLoader = torch_geometric.loader.DataLoader(
    graphs, batch_size=64, shuffle=True, drop_last=True
)
# Built here so the loader wraps the current SC-target `val` split instead of
# whatever ValLoader an earlier cell (possibly with spectral targets) left behind.
ValLoader = torch_geometric.loader.DataLoader(val, batch_size=64, shuffle=False)

trainer = L.Trainer(
    max_epochs=200,
    check_val_every_n_epoch=20,
    precision="bf16-mixed",
    callbacks=[checkpoint_callback],
)
trainer.fit(model=model_sc, train_dataloaders=TrainLoader, val_dataloaders=ValLoader)
checkpoint = torch.load(checkpoint_callback.best_model_path)
model_sc.load_state_dict(checkpoint["state_dict"])
# torch.save(model_sc.state_dict(), f"trained_models/dl_sc_all.pt")

In [None]:
# load the SC model weights trained on all graphs
sc_all_weights_path = "trained_models/dl_sc_all.pt"
state_dict = torch.load(sc_all_weights_path)
model_sc.load_state_dict(state_dict)

In [None]:
# Generate latent embeddings, true SC values and formulas for every graph

# As the model is very slightly non-deterministic,
# the latents have been pre-generated and the save command is commented out
model_sc.cuda()
model_sc.eval()

# mat_id -> formula lookup keeping the first row per mat_id (the row the former
# `.values[0]` picked); avoids rescanning data_df once per graph (O(N^2)).
formula_by_id = data_df.drop_duplicates("mat_id").set_index("mat_id")["formula"]

lats = []
true_sc = []
formula = []
with torch.no_grad():  # inference only
    for graph in graphs:
        graph = graph.cuda()
        lats.append(model_sc.gatnn.latents(graph).cpu().numpy())
        true_sc.append(graph.y.cpu().numpy())
        formula.append(formula_by_id.loc[graph.mat_id])
lats = np.stack(lats).squeeze(1)
true_sc = np.stack(true_sc)
# np.save("sc_latents.npy",lats)

In [None]:
# generate the UMAP
import umap

# Use the pre-generated latents (the model is slightly non-deterministic).
lats = np.load("sc_latents.npy")
# The embedding is paired row-by-row with `true_sc`; make sure the cached file
# still matches the dataset that was just loaded.
assert len(lats) == len(true_sc), (
    f"sc_latents.npy has {len(lats)} rows but {len(true_sc)} graphs were loaded"
)

reducer = umap.UMAP(n_neighbors=30, min_dist=0.1, random_state=1)
reducer.fit(lats)
embeds = reducer.embedding_

In [None]:
# Figure 6: UMAP projection of the SC-model latents, coloured by the true SC;
# marker area grows as exp(5 * (1 - SC)) so low-SC materials stand out.
fig, ax = plt.subplots(figsize=[2 * (3 + 3 / 8), 3 + 3 / 8])
points = ax.scatter(
    embeds[:, 0],
    embeds[:, 1],
    c=true_sc,
    s=np.exp(5 * (1 - true_sc)),
    cmap="jet",
)
ax.set_xlabel("UMAP component 1")
ax.set_ylabel("UMAP component 2")
cbar = fig.colorbar(points, ax=ax)
cbar.set_label("SC[RPA,IPA]", rotation=270, labelpad=15)
fig.savefig("plots/Fig6.png")

### Interactive UMAP plot

In [None]:
from bokeh.plotting import figure, show, output_notebook
from bokeh.palettes import Spectral10
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from io import BytesIO
from PIL import Image
import base64

try:
    from itertools import izip
except ImportError:  # Python 3
    izip = zip
from PIL import Image, ImageDraw


def embeddable_image2(y, scale=128):
    """Render two curves into a small PNG and return it as a base64 data URI.

    Parameters
    ----------
    y : array-like, shape (>=2, n_points)
        Curves sharing one evenly spaced x grid; only ``y[0]`` (blue) and
        ``y[1]`` (orange) are drawn. Both are normalised by the maximum over
        the whole array, so their relative heights are preserved.
    scale : int, default 128
        Width and height of the square image in pixels.

    Returns
    -------
    str
        ``"data:image/png;base64,..."`` usable as an HTML ``<img src>``.
    """
    y = np.asarray(y, dtype=float)
    # One x position per sample across the full width. This used to be
    # hard-coded to 2001 points, which squeezed/stretched curves of any other
    # length (zip silently truncated to the shorter sequence).
    x = np.linspace(0, scale, y.shape[-1])
    # PIL's y axis points down: the maximum maps to row 0 (top) and zero maps
    # to 90% of the height.
    y = 0.9 * scale * (1 - y / np.max(y))
    im = Image.new("RGB", (scale, scale), (255, 255, 255))
    draw = ImageDraw.Draw(im)
    draw.line(list(zip(x, y[0])), fill=(31, 119, 180), width=2)
    draw.line(list(zip(x, y[1])), fill=(255, 127, 14), width=2)
    with BytesIO() as buffer:
        im.save(buffer, format="png")
        for_encoding = buffer.getvalue()
    return "data:image/png;base64," + base64.b64encode(for_encoding).decode()


def load_image(mat_id):
    """Load ``images/<mat_id>.png`` as a base64 PNG data URI, shrunk to fit 128x128 px.

    Parameters
    ----------
    mat_id : str
        Material identifier; the image is read from ``"images/" + mat_id + ".png"``.

    Returns
    -------
    str
        ``"data:image/png;base64,..."`` usable as an HTML ``<img src>``.
    """
    # Context managers close both the image file and the in-memory buffer.
    with Image.open("images/" + mat_id + ".png") as image, BytesIO() as buffer:
        image.thumbnail(size=(128, 128))  # in place, keeps the aspect ratio
        image.save(buffer, format="png")
        for_encoding = buffer.getvalue()
    return "data:image/png;base64," + base64.b64encode(for_encoding).decode()


# Collect everything into a dataframe for plotting.
# Rows follow the order of `graphs`, which must match `embeds`/`formula`/`true_sc`.
plot_df = pd.DataFrame(embeds, columns=("x", "y"))
plot_df["formula"] = formula
plot_df["sc"] = true_sc

# The former per-graph `data_df` lookup (`entry`), the nine never-filled lists
# and the redundant `seed = 0` did not influence the plot and were dropped.
mat_ids = [graph.mat_id for graph in graphs]
plot_df["markers"] = ["circle"] * len(mat_ids)
plot_df["sizes"] = [6] * len(mat_ids)
plot_df["ids"] = mat_ids
plot_df["crystal"] = [load_image(mat_id) for mat_id in mat_ids]

In [None]:
from bokeh.plotting import output_file, save
import bokeh
from bokeh.palettes import Cividis3, Viridis256
from bokeh.transform import linear_cmap

output_notebook()

datasource = ColumnDataSource(plot_df)

# Colour every formula by its SC value on the Turbo palette, spanning the
# observed SC range.
sc_low = min(true_sc)
sc_high = max(true_sc)
colormapper_sc = linear_cmap(
    field_name="sc", palette=bokeh.palettes.Turbo256, low=sc_low, high=sc_high
)

# set width and height to fixed values to visualize the map in the notebook
# use stretch_both to export the map for viewing in e.g. a browser
plot_figure = figure(
    title="UMAP projection",
    tools=("pan, wheel_zoom, reset, box_zoom"),
    # width=2000,
    # height=1000,
    sizing_mode="stretch_both",
)

# Hover card: formula and ID above the crystal thumbnail stored in `crystal`.
hover_html = """
    <figure>
        <figcaption style='font-size: 18px; text-align: center;'> 
            Formula: @formula <br> 
            ID: @ids 
        </figcaption>
        <img src='@crystal' style='display: block; margin: 5px auto'/>
    </figure>
"""
plot_figure.add_tools(HoverTool(tooltips=hover_html))

# Draw each material's formula at its UMAP position, coloured by SC
plot_figure.text(
    "x",
    "y",
    source=datasource,
    text="formula",
    color=colormapper_sc,
    text_font_size="20pt",
    text_align="center",
    text_baseline="middle",
    x_offset=0,
    y_offset=0,
)

show(plot_figure)

# Save the image as HTML
output_file("interactive_plot_sc.html")
save(plot_figure)