## Software dependencies
To make sure the following code works properly, please use the software versions specified below:

In [None]:
# Pinned package versions, for reference only: every line in this cell is a comment,
# so running the cell installs nothing.
# python==3.10.14
# jupyterlab==4.1.5
# ipykernel==6.29.3
# dill==0.3.8
# matminer==0.9.0
# scikit-learn==1.4.1.post1
# pymatgen==2023.9.25
# numpy==1.26.4
# pandas==1.5.3
# NOTE(review): xgboost is imported by the next cell but has no pinned version in this list
# -- TODO add the xgboost version the pickled model was created with

In [None]:
from voltage_mining_model import VoltageMiningModel
from xgboost import XGBRegressor
import dill

## Model demo

In [None]:
# load the voltage mining model python instance with dill
# NOTE(review): presumably the VoltageMiningModel and XGBRegressor imports in the cell above
# are what let dill resolve those classes while unpickling -- confirm before removing them
# SECURITY: dill.load, like pickle.load, can execute arbitrary code stored in the file;
# only load a vmm_demo.pkl that comes from a source you trust
with open('data/vmm_demo.pkl', 'rb') as file:  # 'rb': the pickle is read as binary data
    vmm = dill.load(file)

The following cell loads a CSV file and returns a pandas DataFrame with the chemical formulae and the predicted voltage. \
Make sure that your CSV file is formatted the same way as the provided `demo.csv` file: it should have a column of indices (with no ID), a column of charged-phase formulae, and a column of discharged-phase formulae.

In [None]:
# this cell loads a csv file (the demo file here) and predicts voltage from the provided charged & discharged phase formulae
# you may change the file path to your own csv file with chemical formulae for voltage prediction
# you may change the "output_csv" option to True to save a copy of the predictions
pred_df = vmm.pred_from_file(file_path="data/demo.csv", output_csv=False)
pred_df.head()  # last expression of the cell, so the first rows are shown as the cell output

## Reproducing results from publication

In [None]:
# Run the loaded model over the training and test csv files.
# Each call predicts voltage from the formulae in the file and hands back the
# predictions; output_csv=False means no copy of the predictions is saved.
# NOTE: this cell can take several minutes to run.
train_pred_df = vmm.pred_from_file(file_path="data/train.csv", output_csv=False)
test_pred_df = vmm.pred_from_file(file_path="data/test.csv", output_csv=False)

In [None]:
# Recompute the performance numbers reported in the manuscript.
from sklearn.metrics import r2_score, mean_absolute_error

# reference ("voltage") and model ("predicted_voltage") columns for each split
train_true_v = train_pred_df["voltage"]
train_pred_v = train_pred_df["predicted_voltage"]
test_true_v = test_pred_df["voltage"]
test_pred_v = test_pred_df["predicted_voltage"]

# training split: coefficient of determination, then mean absolute error
train_r2 = r2_score(train_true_v, train_pred_v)
train_mae = mean_absolute_error(train_true_v, train_pred_v)
print(f"training set performance: r^2={train_r2}, mae={train_mae}")

# test split: same two metrics, computed in the same order
test_r2 = r2_score(test_true_v, test_pred_v)
test_mae = mean_absolute_error(test_true_v, test_pred_v)
print(f"test set performance: r^2={test_r2}, mae={test_mae}")