In [None]:
import copy
import pathlib

import numpy as np
import torch

import hubbardml
from hubbardml import keys, plots, similarities

# Fix random seeds via hubbardml's helper (presumably numpy/torch RNGs) so runs are repeatable.
hubbardml.utils.random_seed()

# Single precision throughout: newly created tensors inherit this default dtype.
dtype = torch.float32
torch.set_default_dtype(dtype)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# (Set device = "cpu" by hand here to force CPU execution.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

device, torch.get_default_dtype()

In [None]:
# Fraction of the data held out for validation.
VALIDATE_PERCENTAGE = 0.2
# Previously used dataset: "../data/data_uv_unique_inout_2023_2_8.json"
DATASET = "../data/data_uv_2023_8_2.json"

SAVEFIGS = False  # set True to write figures to disk (paths built by `plotfile`)
TARGET_PARAM = 'U'  # Hubbard parameter being learned; used in axis labels


def plotfile(label: str) -> str:
    """Return the PDF path under ``plots/`` for the figure identified by ``label``.

    Only the dataset's file stem is used as a prefix. Interpolating the full
    ``DATASET`` path would produce ``plots/../data/<name>.json_<label>.pdf``,
    which writes next to the data and keeps the ``.json`` suffix in the name.

    Args:
        label: Short figure identifier appended to the dataset stem.

    Returns:
        Relative path of the form ``plots/<dataset-stem>_<label>.pdf``.
    """
    return f'plots/{pathlib.Path(DATASET).stem}_{label}.pdf'

In [None]:
# Read the dataset selected by DATASET into the working frame `df`.
df = hubbardml.datasets.load(
    DATASET,
)

## Input creation

Prepare the dataframe for the $U$ graph, keeping only the entries it can use, and flag duplicate entries.

In [None]:
# Build the U graph over every element that appears as the first atom.
species = list(df[keys.ATOM_1_ELEMENT].unique())
graph = hubbardml.graphs.UGraph(species)

# Let the graph prepare the frame (presumably filtering to rows it supports), then report the row count.
df = graph.prepare_dataset(df)
print(len(df))

# Flag duplicates in the training-label column. Custom tolerances can be passed,
# e.g. tolerances=dict(occs_tol=2e-4, param_tol=1e-3).
df = graph.identify_duplicates(df)
print((df[keys.TRAINING_LABEL] == keys.DUPLICATE).sum())
print(df[similarities.CLUSTER_ID].nunique(dropna=False))

In [None]:
# Rows per training label at this point (the train/validate split is only made further down).
split_counts = df[keys.TRAINING_LABEL].value_counts()
print(f"Data splits set:\n{split_counts}")

## Model creation

In [None]:
# Output rescaler built from the target column with method="mean".
# NOTE(review): it is fitted on the whole frame, before the train/validate split
# below, so validation targets influence it — consider fitting on training rows only.
rescaler = hubbardml.models.Rescaler.from_data(df[keys.PARAM_OUT], method="mean")

model = hubbardml.models.UModel(
    graph,
    hidden_layers=2,
    irrep_normalization="component",
    feature_irreps="4x0e + 4x1e + 4x2e+ 4x3e",
    rescaler=rescaler,
)
model.to(dtype=dtype, device=device)

## Split train/validation

In [None]:
# Assign train/validate labels, holding out VALIDATE_PERCENTAGE of the data so the
# split matches the fraction quoted in the plot titles (previously a separate literal 0.2).
# `ignore_already_labelled=True` presumably leaves existing labels (e.g. the
# duplicates flagged earlier) untouched — confirm in hubbardml.datasets.
df = hubbardml.datasets.split_by_cluster(
    df,
    frac=VALIDATE_PERCENTAGE,
    category=["species", keys.SC_PATHS],
    ignore_already_labelled=True,
)

In [None]:
# Row labels of the training and validation entries; predictions are written back through these.
split_labels = df[keys.TRAINING_LABEL]
train_idx = df.index[split_labels == keys.TRAIN]
validate_idx = df.index[split_labels == keys.VALIDATE]

# Entry counts per split and per first-atom element.
print(df.groupby([keys.TRAINING_LABEL, keys.ATOM_1_ELEMENT]).size())

In [None]:
# AdamW over all model weights, MSE loss against the PARAM_OUT column.
optimiser = torch.optim.AdamW(model.parameters(), lr=0.01)
trainer = hubbardml.Trainer.from_frame(
    model=model,
    opt=optimiser,
    loss_fn=torch.nn.MSELoss(),
    frame=df,
    target_column=keys.PARAM_OUT,
    batch_size=128,
)

In [None]:
def _print_status(current_trainer):
    """Print the trainer's status line; passed to `trainer.train` as its progress callback."""
    print(current_trainer.status())


# Presumably the number of epochs considered when checking for overfitting — confirm in Trainer.
trainer.overfitting_window = 400
trainer.train(
    callback=_print_status,
    callback_period=50,
    max_epochs=10_000,
)

In [None]:
# Plot the trainer's learning curves with a logarithmic x axis.
fig = trainer.plot_training_curves()
# Apply the log scale *before* saving so the PDF matches the displayed figure
# (previously the scale was changed after savefig, so the file stayed linear).
fig.gca().set_xscale("log")
if SAVEFIGS:
    fig.savefig(plotfile('+U_training'), bbox_inches='tight')

In [None]:
def _predict(loader):
    """Evaluate `trainer.best_model` over `loader` and return a flat NumPy array."""
    with torch.no_grad():
        return hubbardml.engines.evaluate(trainer.best_model, loader).detach().cpu().numpy().reshape(-1)


# NOTE(review): predictions are written back positionally, which assumes each loader
# yields rows in the same order as train_idx / validate_idx (i.e. no shuffling) —
# confirm this holds for trainer.train_loader.
train_predicted = _predict(trainer.train_loader)
val_predicted = _predict(trainer.validate_loader)

df.loc[validate_idx, keys.PARAM_OUT_PREDICTED] = val_predicted
df.loc[train_idx, keys.PARAM_OUT_PREDICTED] = train_predicted

In [None]:
def rmse(y1, y2):
    """Return the root-mean-square difference between ``y1`` and ``y2``.

    Both arguments must support element-wise subtraction and the result must
    have a ``.mean()`` method (NumPy arrays, pandas Series).
    """
    residual = y1 - y2
    mean_squared = (residual ** 2).mean()
    return np.sqrt(mean_squared)


# Score the held-out rows, then show a parity plot of the whole frame titled with that score.
df_validate = df.loc[validate_idx]
validate_rmse = hubbardml.datasets.rmse(df_validate)

parity_title = f'RMSE = {validate_rmse:.3f} ({VALIDATE_PERCENTAGE} holdout)'
plots.create_parity_plot(
    df,
    title=parity_title,
    axis_label=f'${TARGET_PARAM}$ value (eV)',
);

In [None]:
# Overlay per-element density histograms of the predicted parameter on one set of axes.
# Grouping on ATOM_1_ELEMENT, the column the element names come from, keeps the
# filter consistent with the loop key. Every group is non-empty, so `frame.iloc[0]`
# is always valid. The old code filtered on a different column (keys.LABEL), which
# could yield an empty frame and an IndexError.
ax = None
for element, frame in df.groupby(keys.ATOM_1_ELEMENT, sort=False):
    ax = frame[keys.PARAM_OUT_PREDICTED].plot.hist(
        ax=ax,
        alpha=0.6,
        label=element,
        color=frame.iloc[0][keys.COLOUR],
        density=True,
        # The plotted values are predicted Hubbard parameters, not energy differences.
        xlabel=f"Predicted Hubbard {TARGET_PARAM} (eV)",
    )

# Without a legend the per-element labels are never shown.
if ax is not None:
    ax.legend()

## Training data

In [None]:
# Training subset of the frame, used for the per-element parity plot below.
df_train = df.loc[train_idx, :]

In [None]:
# Per-element parity plot of the training subset.
# Previously this was titled "Validation data" with the validation RMSE (copy-paste).
# It also saved to the same file as the validation plot, which then overwrote it.
# The RMSE is computed directly from the target and prediction columns.
train_rmse = rmse(df_train[keys.PARAM_OUT], df_train[keys.PARAM_OUT_PREDICTED])
fig = plots.split_plot(
    df_train,
    keys.ATOM_1_ELEMENT,
    axis_label='$U$ value (eV)',
    title=f'Training data, RMSE = {train_rmse:.2f} eV',
)

if SAVEFIGS:
    fig.savefig(plotfile('+U_parity_species_train'), bbox_inches='tight')

## Validation data

In [None]:
# Per-element parity plot of the held-out validation subset.
validation_title = f'Validation data ({VALIDATE_PERCENTAGE * 100:.0f}%), RMSE = {validate_rmse:.2f} eV'
fig = plots.split_plot(
    df_validate,
    keys.ATOM_1_ELEMENT,
    axis_label='$U$ value (eV)',
    title=validation_title,
)

if SAVEFIGS:
    fig.savefig(plotfile('+U_parity_species'), bbox_inches='tight')

In [None]:
# Baseline "model": use each row's input parameter as the prediction of its output
# parameter, excluding rows whose UV_ITER equals 1.
df_ref = df.copy()
df_ref[keys.PARAM_OUT_PREDICTED] = df_ref[keys.PARAM_IN]
df_ref = df_ref[df_ref[keys.UV_ITER] != 1]

baseline_rmse = hubbardml.datasets.rmse(df_ref)
fig = plots.split_plot(
    df_ref,
    keys.ATOM_1_ELEMENT,
    axis_label=f'Hubbard {TARGET_PARAM} (eV)',
    title=f'Baseline model, RMSE = {baseline_rmse:.2f} eV',
)

if SAVEFIGS:
    fig.savefig(plotfile('U_parity_species_ref'), bbox_inches='tight')

In [None]:
# Baseline again (input parameter used as the prediction, UV_ITER == 1 rows dropped),
# but scored with hubbardml.datasets.rmse(..., label="both").
df_ref = df.copy()
df_ref[keys.PARAM_OUT_PREDICTED] = df_ref[keys.PARAM_IN]
df_ref = df_ref[~(df_ref[keys.UV_ITER] == 1)]
fig = plots.split_plot(
    df_ref,
    keys.ATOM_1_ELEMENT,
    axis_label=f'Hubbard {TARGET_PARAM} (eV)',
    title=f'Baseline model, RMSE = {hubbardml.datasets.rmse(df_ref, label="both"):.2f} eV',
);

if SAVEFIGS:
    # Distinct file name: the previous baseline cell already writes
    # 'U_parity_species_ref', which this figure used to overwrite.
    fig.savefig(plotfile('U_parity_species_ref_both'), bbox_inches='tight')

In [None]:
# Distinct parent directories of the entries' DIR values.
set(str(pathlib.Path(run_dir).parent) for run_dir in df[keys.DIR].unique())

In [None]:
# Largest spread of the target and predicted columns, printed as a guide for `yrange`.
# NOTE(review): max_range is not passed to the plot — yrange is fixed at 1.2.
max_range = max(
    df[keys.PARAM_OUT].max() - df[keys.PARAM_OUT].min(),
    df[keys.PARAM_OUT_PREDICTED].max() - df[keys.PARAM_OUT_PREDICTED].min(),
)
print(max_range)

# NOTE(review): absolute, user-specific path; make it relative to the data directory.
PROGRESSION_DIR = '/home/azadoks/Projects/uv_ml/data/iurii/Olivines/LiMnPO4/B2_Li0.25MnPO4/DFT_plus_UV'

# Keep the returned figures: the SAVEFIGS cell below iterates over `res`,
# which previously was never assigned (NameError when SAVEFIGS is True).
res = plots.create_progression_plots(
    df,
    PROGRESSION_DIR,
    yrange=1.2,
)

In [None]:
if SAVEFIGS:
    # Save every figure of every group in `res`, numbering the files consecutively.
    idx = 0
    for fig in (group_fig for group in res for group_fig in group.values()):
        fig.savefig(f'plots/hubbard_u/steps_{idx}_+U.pdf', bbox_inches='tight')
        idx += 1

In [None]:
# Visualise the model's `tp1` tensor-product module. A deep copy is moved to the CPU
# so the trained model itself stays on its device (Module.cpu() moves in place).
# `copy` is imported with the other imports at the top of the notebook.
# NOTE(review): this shows `model`'s final weights, which may differ from
# `trainer.best_model` used for the predictions above — confirm intended.
fig, ax = copy.deepcopy(model.tp1).cpu().visualize()
for patch in ax.patches:
    patch.set_color(plots.plot_colours[2])

if SAVEFIGS:
    fig.savefig('plots/hubbard_u_tp.pdf', bbox_inches='tight')

In [None]:
# Training label of every entry whose input parameter is exactly zero.
df.loc[df[keys.PARAM_IN] == 0., keys.TRAINING_LABEL]
