In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import re
import random
from pathlib import Path
import time
import dotenv
import warnings
import datetime

from tqdm import tqdm
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from ase.visualize import view

# Silence every warning category for the rest of the notebook session.
# NOTE(review): this also hides deprecation/parse warnings from the libraries
# imported above — confirm that is intended.
warnings.filterwarnings("ignore")

In [3]:
# Directory that every CSV/TXT artifact in this notebook is read from and
# written to (the current working directory).
dir_mp = Path(".")

# 1. Materials Project

Save your Materials Project API key in a `.env` file in the same directory as this notebook. The file should look like this:

```
MP_API_KEY=your_api_key
```

In [None]:
# Load variables from a .env file into the process environment, then read the
# Materials Project key from it. os.getenv returns None when the variable is
# not set.
dotenv.load_dotenv()
MP_API_KEY = os.getenv("MP_API_KEY")

### 1.1 Retrieving created_at from the Materials Project API

`created_at` is only available from `mpr.materials.search`, so we use that endpoint to retrieve the registration date of each material.

In [None]:
# Query the materials endpoint for the id and registration date of every
# entry whose num_sites falls in the [0, 40] range filter.
# Only the two requested fields are fetched to keep the response small.
with MPRester(MP_API_KEY) as mpr:
    total_docs = mpr.materials.search(
        num_sites=[0, 40],
        fields=[
            "material_id",
            "created_at",
        ],
    )

In [None]:
# Flatten the API documents into plain records (one dict per material).
data = [
    {"material_id": doc.material_id, "created_at": doc.created_at}
    for doc in total_docs
]

# Build the table and keep a single row per material_id.
df_mp_created_at = pd.DataFrame(data).drop_duplicates(subset=["material_id"])

# Persist the registration dates so later cells can reload them from disk.
df_mp_created_at.to_csv(dir_mp / "mp-created-at.csv", index=False)

### 1.2. Download snapshot with constraints: num_sites <= 40, energy above the convex hull <= 0.25 eV, and experimental = True

In [None]:
# Query the summary endpoint with the three constraints named in the section
# heading: num_sites range [0, 40], energy_above_hull range [0, 0.25] and
# theoretical=False. Only the fields consumed by the next cells
# are requested.
with MPRester(MP_API_KEY) as mpr:
    docs = mpr.summary.search(
        num_sites=[0, 40],
        energy_above_hull=[0, 0.25],
        theoretical=False,
        fields=[
            "material_id",
            "structure",
            "energy_above_hull",
            "band_gap",
            "theoretical",
        ],
    )

In [None]:
# Element symbols whose single-element entries are dropped from the snapshot.
# NOTE(review): "Fr" (francium) is listed alongside the gases — confirm this
# is intended rather than a typo.
excluded_gas_list = [
    "H", "He",
    "N", "O", "F", "Ne",
    "Cl", "Ar",
    "Kr", "Xe", "Rn",
    "Fr", "Og",
]

In [None]:
# Convert every downloaded document into a flat record, skipping
#   * single-element entries whose element is in excluded_gas_list, and
#   * cells with any lattice length above 20.
# Skipped entries are printed so they can be inspected.
data = []
for doc in tqdm(docs):
    st = doc.structure
    elements = [elmt.symbol for elmt in st.composition.elements]

    is_excluded_element = len(elements) == 1 and elements[0] in excluded_gas_list
    if is_excluded_element:
        print(elements)
        continue

    abc = st.lattice.abc
    if max(abc) > 20:
        print(st.formula, abc)
        continue

    data.append(
        {
            "material_id": doc.material_id,
            "energy_above_hull": doc.energy_above_hull,
            "band_gap": doc.band_gap,
            "cif": st.to(fmt="cif"),
        }
    )

# One row per material_id, shuffled with a fixed seed. reset_index() keeps the
# pre-shuffle position as an "index" column, which later cells refer to.
df_mp_api = (
    pd.DataFrame(data)
    .drop_duplicates(subset="material_id")
    .sample(frac=1, random_state=42)
    .reset_index()
)
df_mp_api.to_csv(dir_mp / "mp-api.csv", index=False)

In [None]:
# calculate properties
from pandarallel import pandarallel

pandarallel.initialize(progress_bar=True)


def calculate_property(data):
    """Add composition, geometry and symmetry columns to one row.

    Parameters
    ----------
    data : row of the ``mp-api.csv`` table
        Must expose a ``cif`` attribute holding the structure as a CIF string.

    Returns
    -------
    The same row with these entries added: ``composition``, ``volume``,
    ``density``, ``atomic_density``, ``crystal_system``,
    ``space_group_symbol`` and ``space_group_number``.
    """
    st = Structure.from_str(data.cif, fmt="cif")
    # Symmetry is detected with a tolerance of symprec=0.1.
    sg = SpacegroupAnalyzer(st, symprec=0.1)
    data["composition"] = st.composition.reduced_composition.alphabetical_formula
    data["volume"] = st.volume
    data["density"] = st.density
    # Number of sites per unit volume. The previous code stored st.density
    # here, which made this column an exact duplicate of "density".
    data["atomic_density"] = st.num_sites / st.volume
    data["crystal_system"] = sg.get_crystal_system()
    data["space_group_symbol"] = sg.get_space_group_symbol()
    data["space_group_number"] = sg.get_space_group_number()
    return data


# Reload the snapshot from disk, compute the properties row by row in
# parallel, and save the enriched table.
df_mp_api = pd.read_csv(dir_mp / "mp-api.csv")
df_mp_total = df_mp_api.parallel_apply(calculate_property, axis=1)
df_mp_total.to_csv(dir_mp / "mp-total.csv", index=False)

### 1.3. Make test set from materials registered on or after the cutoff date (2018-08-04)

In [None]:
# Attach the registration date to every material. The default (inner) join
# keeps only material_ids present in both tables.
df_mp_created_at = pd.read_csv(dir_mp / "mp-created-at.csv")
df_mp_total = pd.read_csv(dir_mp / "mp-total.csv").merge(
    df_mp_created_at, on="material_id"
)
print(len(df_mp_total))

In [None]:
# Bar chart of the number of entries per registration year.
plt.rcParams["font.size"] = 25

# The year is the first four characters of the created_at string.
entries_per_year = (
    df_mp_total["created_at"]
    .apply(lambda x: int(x[:4]))
    .value_counts()
    .sort_index()
)
entries_per_year.plot(
    kind="bar",
    color="skyblue",
    title="Materials Project API",
    figsize=(12, 6),
    xlabel="Year",
    ylabel="Number of Entries",
)

In [None]:
def convert_to_datetime(date_str):
    """Parse a ``YYYY-MM-DD HH:MM:SS[.fraction]`` string into a datetime.

    Everything from the first "." onwards (the fractional seconds) is
    discarded before parsing, so the result has whole-second resolution.
    """
    whole_seconds, _, _ = date_str.partition(".")
    return datetime.datetime.strptime(whole_seconds, "%Y-%m-%d %H:%M:%S")


# Materials registered on or after this date form the test set.
cutoff_date = pd.to_datetime("2018-08-04")

# Parsed copy of the created_at strings, used for the time-based split.
df_mp_total["created_at_datetime"] = df_mp_total["created_at"].apply(convert_to_datetime)

In [None]:
# train / val: everything registered before the cutoff date; the last 10 % of
# those rows (in the table's current order) become the validation set.
df_train_val = df_mp_total[df_mp_total["created_at_datetime"] < cutoff_date]
num_val = int(len(df_train_val) * 0.1)
# Slice with an explicit positive boundary. Negative slicing breaks when
# num_val == 0: iloc[:-0] is empty and iloc[-0:] is the whole frame, which
# would leave the train set empty and put every row into validation.
num_train = len(df_train_val) - num_val
df_train = df_train_val.iloc[:num_train]
df_val = df_train_val.iloc[num_train:]
# test: everything registered on or after the cutoff date
df_test = df_mp_total[df_mp_total["created_at_datetime"] >= cutoff_date]
print(len(df_train), len(df_val), len(df_test))

In [None]:
# Write the three time-based splits next to the notebook.
for split_name, split_frame in (
    ("train", df_train),
    ("val", df_val),
    ("test", df_test),
):
    split_frame.to_csv(dir_mp / f"{split_name}.csv", index=False)

In [None]:
# Side-by-side bar charts of the crystal-system counts in train and test.
fig, axes = plt.subplots(1, 2, figsize=(24, 12))
pastel = sns.color_palette("pastel")

train_counts = df_train["crystal_system"].value_counts().sort_index()
train_counts.plot(kind="bar", color=pastel[0], ax=axes[0], title="Train")

test_counts = df_test["crystal_system"].value_counts().sort_index()
test_counts.plot(kind="bar", color=pastel[1], ax=axes[1], title="Test")

# 2. Text Prompts

In [None]:
# The cells below read one .txt file of prompts per material.
# NOTE(review): presumably those files are produced by running the script
# named below — confirm how it is invoked before re-enabling this line.
# ! generate_text_prompt.py

In [None]:
path_prompts = Path("../mp-50/prompts/")  # TODO: change the path
# Sorted so that the file order (and therefore the seeded draws below) does
# not depend on the order in which the filesystem lists the directory.
text_files = sorted(path_prompts.glob("*.txt"))


def select_prompt(text, rng=random):
    """Pick one prompt from the raw contents of a prompt file.

    Parameters
    ----------
    text : str
        File contents with one prompt per line, optionally numbered
        ("1. ...", "2. ...").
    rng : object with a ``choice`` method, optional
        Source of randomness; defaults to the ``random`` module.

    Returns
    -------
    str
        One non-empty prompt with its list numbering removed.

    Raises
    ------
    ValueError
        If the text contains no non-empty line.
    """
    # Strip the list numbering only at the start of a line, so a number that
    # ends a sentence inside a prompt (e.g. "... of 5. The ...") is kept.
    revised_text = re.sub(r"^[ \t]*\d+\.\s", "", text, flags=re.MULTILINE)
    # Ignore blank lines (e.g. from a trailing newline); otherwise an empty
    # string could be selected as the prompt.
    candidates = [line.strip() for line in revised_text.split("\n") if line.strip()]
    if not candidates:
        raise ValueError("no non-empty prompt found")
    return rng.choice(candidates)


# read and make df
# A dedicated, seeded generator makes the selected prompts reproducible
# without touching the global random state.
prompt_rng = random.Random(42)
prompts = {}
for text_file in text_files:
    material_id = text_file.stem  # file name without ".txt" is the material id
    with open(text_file, "r") as f:
        text = f.read()
    try:
        prompts[material_id] = select_prompt(text, prompt_rng)  # one prompt per material
    except ValueError as err:
        raise ValueError(f"{text_file}: {err}") from err

df_prompts = pd.DataFrame(list(prompts.items()), columns=["material_id", "prompt"])

In [None]:
# update train, val, test: add the selected prompt to every split file in place
for split in ["train", "val", "test"]:
    path_split = dir_mp / f"{split}.csv"
    df = pd.read_csv(path_split)
    # Drop a "prompt" column left by a previous run of this cell. Without this,
    # re-running merges onto the existing column and yields prompt_x / prompt_y.
    df = df.drop(columns=["prompt"], errors="ignore")
    # Inner join: rows without a prompt file are removed from the split.
    df = pd.merge(df, df_prompts, on="material_id")
    df.to_csv(path_split, index=False)

# Lattice parameter statistics (mean and std over the training set)

In [97]:
# Reload the training split and parse every CIF string back into a Structure.
df_train = pd.read_csv(dir_mp / "train.csv")
st_list = []
for cif in df_train["cif"]:
    st_list.append(Structure.from_str(cif, fmt="cif"))

In [101]:
# One row of lattice parameters per training structure; the statistics are
# taken column-wise (axis=0), i.e. per parameter across all structures.
lattice_params = np.array([structure.lattice.parameters for structure in st_list])
lattice_params_mean = np.mean(lattice_params, axis=0)
lattice_params_std = np.std(lattice_params, axis=0)
print(lattice_params_mean, lattice_params_std)

In [105]:
# write the statistics as plain Python lists to a small text file
# np.asarray(...) accepts both the arrays from the previous cell and the lists
# this cell assigns, so re-running the cell no longer fails with
# "'list' object has no attribute 'tolist'".
lattice_params_mean = np.asarray(lattice_params_mean).tolist()
lattice_params_std = np.asarray(lattice_params_std).tolist()
with open(dir_mp / "lattice_params.txt", "w") as f:
    f.write(f"mean: {lattice_params_mean}\n")
    f.write(f"std: {lattice_params_std}\n")

### random split dataset (not time based)

In [None]:
# Reload the three time-based splits from disk and report their sizes.
df_train, df_val, df_test = (
    pd.read_csv(dir_mp / f"{split_name}.csv")
    for split_name in ("train", "val", "test")
)
print(len(df_train), len(df_val), len(df_test))

In [92]:
# Pool all three splits, drop the leftover "index" column, and shuffle with a
# fixed seed so the random split is reproducible.
df_total = (
    pd.concat([df_train, df_val, df_test])
    .drop(columns=["index"])
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

# Re-cut the shuffled table into splits of the same sizes as the time-based
# train and val sets; whatever remains is the new test set.
num_train = len(df_train)
num_val = len(df_val)
val_end = num_train + num_val

new_train = df_total.iloc[:num_train]
new_val = df_total.iloc[num_train:val_end]
new_test = df_total.iloc[val_end:]

In [93]:
# Store the random split in its own sub-directory (created if missing).
save_dir = dir_mp / "random_split"
save_dir.mkdir(exist_ok=True)
for split_name, split_frame in (
    ("train", new_train),
    ("val", new_val),
    ("test", new_test),
):
    split_frame.to_csv(save_dir / f"{split_name}.csv", index=False)

### mineral dataset

In [None]:
# Reload the time-based splits, using the first CSV column as the row index,
# and pool them into a single table.
df_train, df_val, df_test = (
    pd.read_csv(dir_mp / f"{split_name}.csv", index_col=0)
    for split_name in ("train", "val", "test")
)
print(len(df_train), len(df_val), len(df_test))
df_total = pd.concat([df_train, df_val, df_test])
print(len(df_total))

In [None]:
# Load the mineral labels, keeping only the id and label columns.
mineral_columns = ["material_id", "mineral"]
df_mineral = pd.read_csv(dir_mp / "mineral/mineral.csv")[mineral_columns]
print(len(df_mineral))

# Drop rows without a mineral label.
df_mineral = df_mineral.dropna(subset=["mineral"])
print(len(df_mineral))

# Keep only minerals that occur more than 50 times.
df_mineral = df_mineral.groupby("mineral").filter(lambda group: len(group) > 50)
print(len(df_mineral))

In [None]:
# Keep only the materials that have one of the retained mineral labels
# (default inner join on material_id).
df_total = df_total.merge(df_mineral, on="material_id")
print(len(df_total))

In [None]:
# Bar chart of how many materials carry each mineral label.
ax = df_total["mineral"].value_counts().plot(kind="bar", figsize=(12, 6))
# The title now states the threshold used when filtering the labels
# (more than 50 occurrences); it previously said 40.
ax.set_title("Mineral Distribution (occurrence > 50)")

In [None]:
# Split: train: 6,000 | val = test = total - train
# Shuffle with a fixed seed, take the first 6,000 rows for training, and use
# the remainder as validation; the test set is a copy of the validation set.
num_train = 6000
df_total = df_total.sample(frac=1, random_state=42).reset_index(drop=True)
df_train, df_val = df_total.iloc[:num_train], df_total.iloc[num_train:]
df_test = df_val.copy()
print(len(df_train), len(df_val), len(df_test))

In [None]:
# Write the mineral splits into the existing "mineral" sub-directory.
for split_name, split_frame in (
    ("train", df_train),
    ("val", df_val),
    ("test", df_test),
):
    split_frame.to_csv(dir_mp / f"mineral/{split_name}.csv", index=False)