In [None]:
import random

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch

import hubbardml
from hubbardml import models
from hubbardml import datasets
from hubbardml import keys

# Seed every RNG the notebook may touch so that Restart & Run All is reproducible.
# NumPy is seeded as well: it is imported above, and pandas helpers such as
# `DataFrame.sample` fall back to NumPy's global RNG when no random_state is given
# (the hubbardml dataset split may rely on this — TODO confirm).
SEED = 0xDEADBEEF
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use single precision by default and run on the GPU when one is available.
dtype = torch.float32
torch.set_default_dtype(dtype)
device = "cuda" if torch.cuda.is_available() else "cpu"
device, torch.get_default_dtype()

In [None]:
# Fraction of the data held out for validation (quoted in plot titles).
TEST_PERCENTAGE = 0.2

# Column names used to group results by element; D_ELEMENT is used below.
P_ELEMENT = 'p_element'
D_ELEMENT = 'd_element'

# Set to True to write the figures produced below to disk.
SAVEFIGS = False

# Hubbard parameter type targeted by this notebook.
TARGET_PARAM_TYPE = 'V'

# Dataset file name, resolved relative to ../data. Previously used:
# DATASET = 'data_uv_unique_inout_2022_10_13.json'
DATASET = 'data_uv_unique_inout_2023_2_8.json'


def plotfile(label: str):
    """Build the PDF path under ``plots/`` for a figure tagged with ``label``.

    The dataset file name is used as a prefix so figures from different
    datasets do not overwrite each other.
    """
    filename = f'plots/{DATASET}_{label}.pdf'
    return filename

# Inputs

Load the dataset used throughout this notebook.

In [None]:
# Read the configured dataset from the shared data directory.
dataset_path = f'../data/{DATASET}'
df = hubbardml.datasets.load(dataset_path)

## Filtering

Prepare the DataFrame for the V model and keep only the pairs whose input parameter exceeds 0.5.

In [None]:
# Rows whose input parameter is at or below this threshold are discarded.
MIN_PARAM_IN = 0.5

# Presumably converts the raw frame into the form VGraph expects — see hubbardml.
df = hubbardml.VGraph.prepare_dataset(df)
df = df.loc[df[keys.PARAM_IN] > MIN_PARAM_IN]
print(df.shape[0])

In [None]:
# Display the rows whose output value is exactly equal to their input value.
# Uses the `keys` column constants like the rest of the notebook instead of
# attribute access on a hard-coded column name.
df[df[keys.PARAM_IN] == df[keys.PARAM_OUT]]

## Model creation

Collect every atomic species that appears in the dataset; the model must support all of them.

In [None]:
# Elements from both sides of each pair, in first-appearance order:
# all atom-1 elements first, then any atom-2 elements not seen yet.
both_atoms = pd.concat([df[keys.ATOM_1_ELEMENT], df[keys.ATOM_2_ELEMENT]])
species = list(both_atoms.unique())
print(f'Found species {species}')

In [None]:
# Number of rows for each value of the D_ELEMENT column.
d_element_counts = df[D_ELEMENT].value_counts()
d_element_counts

## Model

In [None]:
# Graph over all species found above, and an output rescaler built from the
# target column with method="mean".
graph = hubbardml.VGraph(species)
rescaler = hubbardml.models.Rescaler.from_data(df[keys.PARAM_OUT], method="mean")

model = hubbardml.VModel(
    graph,
    feature_irreps="4x0e + 4x1e + 4x2e",
    rescaler=rescaler,
    hidden_layers=2,
)
# Move to the configured precision/device; the returned module is displayed.
model.to(dtype=dtype, device=device)

## Split test/train

In [None]:
# Assign training/validation labels, stratified on the 'species' column.
# Use the configured TEST_PERCENTAGE rather than a duplicated literal so the
# split and the "holdout" figures quoted in plot titles cannot drift apart.
# The return value is discarded, so this relies on `split` labelling `df` in
# place (the next cell reads keys.TRAINING_LABEL from `df`).
hubbardml.datasets.split(df, method='category', frac=TEST_PERCENTAGE, category=['species'])

In [None]:
# Index labels of the training and validation rows, taken from the split labels.
labels = df[keys.TRAINING_LABEL]
train_idx = labels[labels == keys.TRAIN].index
validate_idx = labels[labels == keys.VALIDATE].index

# Per-species row counts in each subset, to eyeball the balance of the split.
for subset_idx in (train_idx, validate_idx):
    print(df.loc[subset_idx, 'species'].value_counts())

In [None]:
# Build the trainer from the labelled frame with an Adam optimiser (lr=0.001)
# and a mean-squared-error loss.
initial_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
trainer = hubbardml.training.Trainer.from_frame(
    model=model,
    opt=initial_optimizer,
    loss_fn=torch.nn.MSELoss(),
    frame=df,
)

In [None]:
# Replace the trainer's optimiser with Adam at lr=0.01 (10x the lr used when the
# trainer was constructed). NOTE(review): this writes the private `_opt`
# attribute; prefer passing the desired optimiser to `Trainer.from_frame`
# if that is the intended configuration.
trainer._opt = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
# Presumably the number of epochs without validation improvement tolerated
# before training stops — confirm against hubbardml.training.Trainer.
trainer.overfitting_window = 600


def progress(trainer):
    """Print one progress line: epoch, train MSE, validation MSE, validation RMSE."""
    train_mse = trainer.training.metrics['mse']
    val_metrics = trainer.validation.metrics
    print(f"{trainer.epoch} {train_mse:.5f} {val_metrics['mse']:.5f} {val_metrics['rmse']:.4f}")


trainer.train(
    callback=progress,
    callback_period=50,  # presumably: invoke `progress` every 50 epochs
    max_epochs=1000,
)

In [None]:
# Plot the curves the trainer recorded during training (presumably train and
# validation loss per epoch); the trailing `;` suppresses the return-value repr.
trainer.plot_training_curves();

In [None]:
# Run inference on both subsets with autograd disabled: only the forward pass is
# needed, and no_grad avoids building an autograd graph over the whole dataset.
# Predictions are flattened to 1-D NumPy arrays.
with torch.no_grad():
    predicted = model(trainer.validation_data.all_inputs()).detach().cpu().numpy().reshape(-1)
    # NOTE: despite its name, `input_train` holds the model's predictions on the training rows.
    input_train = model(trainer.training_data.all_inputs()).detach().cpu().numpy().reshape(-1)

# Fail with a clear message rather than a pandas length error if the subsets disagree.
assert len(predicted) == len(validate_idx), 'validation predictions do not match validation rows'
assert len(input_train) == len(train_idx), 'training predictions do not match training rows'

# Store predictions next to the reference values. NOTE(review): assumes
# `all_inputs()` yields rows in the same order as validate_idx / train_idx — confirm.
df.loc[validate_idx, keys.PARAM_OUT_PREDICTED] = predicted
df.loc[train_idx, keys.PARAM_OUT_PREDICTED] = input_train

In [None]:
# Hold-out rows and their RMSE. The RMSE must be computed on the validation rows
# only: computing it on the full frame mixes in training rows (which the model
# has fitted) and typically understates the generalisation error.
df_test = df.loc[validate_idx]
test_rmse = hubbardml.datasets.rmse(df_test)

In [None]:
# Parity plot over the whole frame, titled with the hold-out RMSE. The held-out
# fraction is rendered as a percentage with units, matching the per-species
# plot below (previously it printed as "0.2 holdout").
hubbardml.plots.create_parity_plot(
    df,
    axis_label='Hubbard V (eV)',
    title=f'RMSE = {test_rmse:.3f} eV ({TEST_PERCENTAGE * 100:.0f}% holdout)'
);

In [None]:
# Parity plot of the hold-out rows only, grouped by the D_ELEMENT column.
parity_title = f'Test data ({TEST_PERCENTAGE * 100:.0f}%), RMSE = {test_rmse:.3f} eV'
fig = hubbardml.plots.split_plot(
    df_test,
    D_ELEMENT,
    axis_label='Hubbard V (eV)',
    title=parity_title,
)
if SAVEFIGS:
    fig.savefig(plotfile('+V_parity_species'), bbox_inches='tight')

In [None]:
# Baseline "model": predict the output parameter to be equal to the input one,
# evaluated on the hold-out rows excluding UV_ITER == 1.
# NOTE(review): presumably iteration 1 is excluded because its input is a
# starting guess rather than a self-consistent value — confirm.
df_ref = df_test[df_test[keys.UV_ITER] != 1].copy()
df_ref[keys.PARAM_OUT_PREDICTED] = df_ref[keys.PARAM_IN]

baseline_rmse = hubbardml.datasets.rmse(df_ref)
fig = hubbardml.plots.split_plot(
    df_ref,
    D_ELEMENT,
    axis_label='Hubbard V (eV)',
    title=f'Baseline model, RMSE {baseline_rmse:.3f} eV',
)
if SAVEFIGS:
    fig.savefig(plotfile('+V_parity_species_ref'), bbox_inches='tight')

In [None]:
def _parent_dir(path: str) -> str:
    """Return ``path`` with its last '/'-separated component removed."""
    return '/'.join(path.split('/')[:-1])


# One set of progression plots per parent directory of the 'dir' column.
# Sorted so the order of `res` (and therefore the indices used when saving the
# figures) is deterministic: iteration order of a set of strings depends on the
# per-process hash seed.
dirs = sorted({_parent_dir(directory) for directory in df['dir'].unique()})

res = []
for directory in dirs:
    # Match on the directory plus a trailing '/' so that e.g. 'runs/a1' does not
    # also pick up rows from the sibling 'runs/a10'. An empty parent (a 'dir'
    # without any '/') matches every row.
    prefix = f'{directory}/' if directory else ''
    group = df[df['dir'].str.startswith(prefix)]
    res.append(hubbardml.plots.create_progression_plots(group, yrange=0.4))

In [None]:
if SAVEFIGS:
    # Number figures with one running index across all directory groups so
    # every saved file name is unique.
    idx = 0
    for fig in (figure for group in res for figure in group.values()):
        fig.savefig(f'plots/hubbard_v/steps_{idx}_+V.pdf', bbox_inches='tight')
        idx += 1
