In [None]:
import datetime
from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from e3nn import o3
from matplotlib import ticker

import e3psi
import hubbardml
from hubbardml import datasets
from hubbardml import graphs
from hubbardml import keys
from hubbardml import plots
from hubbardml import plots

In [None]:
# Build the graph and its dataset from the Hydra config in this directory,
# selecting the `u` model via a command-line-style override.
# Placeholder so the name exists (as None) if loading below raises.
graph_data = None
with hydra.initialize(version_base="1.3", config_path="."):
    # Composition happens inside the `initialize` context.
    cfg = hydra.compose(config_name="config.yaml", overrides=["model=u"])
    # Instantiate the object described by the config's `graph` section ...
    graph = hydra.utils.instantiate(cfg["graph"])
    # ... and pair it with the config's `dataset` section.
    graph_data = hubbardml.GraphData(graph, cfg["dataset"])

In [None]:
# Alias of the loaded dataset (the same object, not a copy); `df` is rebound to the
# duplicate-labelled frame in a later cell.
df = graph_data.dataset

In [None]:
# Overview parameter histogram from the packaged plotting helper: raw counts in
# 30 bins, resized to a compact 5 x 2 inch figure.
fig = plots.plot_param_histogram(df, bins=30, density=False)
fig.set_size_inches(5, 2)
# To keep a copy, save it e.g. with:
#     fig.savefig(<output dir> / "plots" / "param_histogram.pdf", bbox_inches="tight")

In [None]:

# Directories of MnO2 entries whose first atom is Fe.
is_fe_in_mno2 = df['dir'].str.contains('MnO2') & df[keys.ATOM_1_ELEMENT].eq("Fe")
df.loc[is_fe_in_mno2, 'dir'].unique()

In [None]:
# Unique `dir` values of the full dataset (read from graph_data, not from `df`).
graph_data.dataset['dir'].unique()

In [None]:
# Show the species present in the full dataset.
print(graph_data.dataset[keys.SPECIES].unique())

# Label duplicate entries within each species group; `df` now refers to the
# labelled frame that the following cells use.
df = graph_data.identify_duplicates(graph_data.dataset, group_by=[keys.SPECIES])

In [None]:
# Number of entries per species once rows labelled as duplicates are excluded.
is_unique_entry = df[keys.TRAINING_LABEL].ne(keys.DUPLICATE)
species_counts = df.loc[is_unique_entry, keys.SPECIES].value_counts()
species_counts

In [None]:
# Histogram of the keys.DIST_IN column (NOTE(review): presumably an input distance feature — confirm).
df[keys.DIST_IN].hist()

In [None]:
# Concatenate the per-species similarity frames. Group by the same `keys.SPECIES`
# column used for duplicate detection instead of a hard-coded 'species' literal.
# NOTE(review): assumes element 2 of each entry yielded by get_similarity_frames is
# the similarity DataFrame — confirm against GraphData.
similarities = pd.concat(
    [entry[2] for entry in graph_data.get_similarity_frames(group_by=[keys.SPECIES])]
)

In [None]:
# Histogram of the "occs_sum" rotationally invariant distances, zoomed to [0, 5e-3];
# the vertical line marks hubbardml.graphs.DEFAULT_OCCS_TOL.
# Very fine bins because only the small-distance region is shown after zooming.
ax = similarities["occs_sum"].hist(bins=40000, log=False)
ax.get_figure().set_size_inches(12, 1.5)
ax.set_xlim([0, 5e-3])
ax.axvline(hubbardml.graphs.DEFAULT_OCCS_TOL)
ax.set_xlabel("Rotationally invariant distance")
ax.set_ylabel("Occurrences")
# Save this axes' own figure (not pyplot's implicit "current" one), creating the
# output directory first so savefig does not fail with FileNotFoundError.
Path("img").mkdir(exist_ok=True)
ax.get_figure().savefig("img/occs_sum_distances.pdf", bbox_inches="tight")

In [None]:
# Re-display the figure of the most recently plotted distance histogram. Its
# tolerance line is drawn by the plotting cell, so it is not added a second time.
ax.get_figure()

In [None]:
# Histogram of the "occs_prod" rotationally invariant distances, zoomed to [0, 5e-3];
# the vertical line marks hubbardml.graphs.DEFAULT_OCCS_TOL.
# Very fine bins because only the small-distance region is shown after zooming.
ax = similarities["occs_prod"].hist(bins=40000, log=False)
ax.get_figure().set_size_inches(12, 1.5)
ax.set_xlim([0, 5e-3])
ax.axvline(hubbardml.graphs.DEFAULT_OCCS_TOL)
ax.set_xlabel("Rotationally invariant distance")
ax.set_ylabel("Occurrences")
# Save this axes' own figure (not pyplot's implicit "current" one), creating the
# output directory first so savefig does not fail with FileNotFoundError.
Path("img").mkdir(exist_ok=True)
ax.get_figure().savefig("img/occs_prod_distances.pdf", bbox_inches="tight")

In [None]:

def plot_param_histogram(
    df: pd.DataFrame,
    x_label: str = "Hubbard param. (eV)",
    y_label: str = "Frequency",
    title: str = None,
    param_col: str = keys.PARAM_OUT,
    bins=20,
    density: bool = True,
) -> plt.Figure:
    """Overlay one filled histogram of ``param_col`` per species group.

    Args:
        df: Frame containing ``param_col`` plus the ``keys.SPECIES`` and
            ``keys.COLOUR`` columns. Each species group is drawn in the colour
            found in its first row.
        x_label: X-axis label; a falsy value leaves the axis unlabelled.
        y_label: Y-axis label; a falsy value leaves the axis unlabelled.
        title: Optional figure title (set with ``fig.suptitle``).
        param_col: Column whose values are histogrammed.
        bins: Number of bins, a numpy binning-strategy name, or explicit bin
            edges. Scalar values are turned into ONE set of edges spanning the
            whole column, so every species is binned identically.
        density: If True (default, the previous fixed behaviour) each species
            histogram is normalised to unit area; if False raw counts are shown.

    Returns:
        The newly created figure. It is not closed here; the caller decides
        whether to display, save or close it.
    """
    fig, ax = plt.subplots()
    if title:
        fig.suptitle(title)

    # Passing a scalar `bins` to every ax.hist call lets each species choose edges
    # from its own min/max, producing bars of different widths that cannot be
    # compared. Derive shared edges from the finite values of the whole column
    # instead (NaN/inf are ignored, as in matplotlib's own range detection).
    if np.ndim(bins) == 0:
        values = df[param_col].to_numpy(dtype=float)
        finite = values[np.isfinite(values)]
        if finite.size:
            bins = np.histogram_bin_edges(finite, bins=bins)

    # `stacked=True` was dropped: each ax.hist call below receives a single
    # dataset, for which stacking has no effect.
    hist_kwargs = dict(
        histtype="stepfilled",
        alpha=0.8,
        density=density,
        bins=bins,
        ec="k",
    )
    for species, frame in df.groupby(keys.SPECIES):
        ax.hist(
            frame[param_col],
            color=frame.iloc[0][keys.COLOUR],
            # NOTE(review): assumes each species key is a sequence of element
            # symbols (e.g. a tuple); a plain string would be joined per character.
            label="-".join(species),
            **hist_kwargs,
        )

    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)
    fig.legend()
    return fig


In [None]:
# Keep, for every self-consistent chain, only the rows from its final iteration.
last_iter_subframes = [
    sc_rows.loc[sc_rows[keys.UV_ITER].eq(sc_rows[keys.UV_ITER].max())]
    for _sc_path, sc_rows in datasets.iter_self_consistent_paths(df)
]

# Merge the chains and keep the first row of each similarity cluster.
last_iter_frame = pd.concat(last_iter_subframes).drop_duplicates(
    hubbardml.similarities.CLUSTER_ID
)

# Histogram of the final-iteration parameters (the trailing `;` hides the Figure repr).
plot_param_histogram(last_iter_frame, bins=20);

In [None]:
# Parameter vs. self-consistency iteration for the Li1.00MnPO4 runs.
# regex=False: "." in the formula must match a literal dot, not any character.
# na=False: rows with a missing directory are simply excluded.
is_mn_olivine = df[keys.DIR].str.contains("Li1.00MnPO4", regex=False, na=False)
mn_olivines = df[is_mn_olivine]
# Follow a single site (atom index 1) across iterations.
# NOTE(review): presumably the Mn site — confirm.
mn_olivines = mn_olivines[mn_olivines[keys.ATOM_1_IDX] == 1]

fig, ax = plt.subplots(figsize=(4, 3))
# Iterations are integers, so only use integer tick positions.
ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
ax.set_xlabel("Iteration")
ax.set_ylabel("Hubbard $U$ (eV)")
plots.plot_series(
    ax,
    mn_olivines[keys.UV_ITER],
    mn_olivines[keys.PARAM_OUT],
    plots.plot_colours[1],
    # NOTE(review): placeholder label; replace with a descriptive one if a legend is shown.
    label="test",
)
fig.savefig("self_consistent.pdf", bbox_inches="tight")


In [None]:
# Per-species parameter histogram of every row in `df`. The trailing `;` stops Jupyter
# from rendering the returned Figure a second time (the inline backend already shows it).
plot_param_histogram(df);

In [None]:
# Orthogonal matrix with determinant +1 (a proper rotation). The original note
# describes it as the change of basis zxy -> xyz.
change_of_coord = torch.tensor([
    [0., -1., 0.],
    [0., 0., -1.],
    [1., 0., 0.],
])

# Representation of this rotation on the e3nn irrep Irrep(2, 1), i.e. l=2 (5x5).
D = o3.Irrep(2, 1).D_from_matrix(change_of_coord)

d_fig, d_ax = plt.subplots()
d_img = d_ax.imshow(D, cmap="RdBu", vmin=-1, vmax=1)
d_ax.set_title("D (l=2) of change_of_coord")
d_fig.colorbar(d_img, ax=d_ax);

In [None]:
# Example 5x5 occupation matrix, shown on a fixed +/-0.05 colour scale.
# NOTE(review): presumably d-shell occupations rounded to 3 decimals — confirm the source.
occu1 = torch.tensor([
    [0.018,  0.002,  0.000,  0.005,  0.000],
    [0.002,  0.053, -0.000, -0.026,  0.000],
    [0.000, -0.000,  0.048,  0.000,  0.024],
    [0.005, -0.026,  0.000,  0.050, -0.000],
    [0.000,  0.000,  0.024, -0.000,  0.040],
])
occu1_fig, occu1_ax = plt.subplots()
occu1_img = occu1_ax.imshow(occu1, cmap="RdBu", vmin=-0.05, vmax=0.05)
occu1_fig.colorbar(occu1_img, ax=occu1_ax);

In [None]:
# Apply D on both sides (D^T @ occ @ D) and plot on the same colour scale as occu1.
# NOTE(review): whether D^T·occ·D or D·occ·D^T is the intended change of basis
# depends on e3nn's D convention — confirm.
transformed_occu1 = D.T @ occu1 @ D
transformed_fig, transformed_ax = plt.subplots()
transformed_fig.colorbar(
    transformed_ax.imshow(transformed_occu1, cmap="RdBu", vmin=-0.05, vmax=.05),
    ax=transformed_ax,
);