# Импорты

In [None]:
import pandas as pd
import pymatgen
from pathlib import Path
from sklearn.model_selection import train_test_split
import json
from pymatgen.core import Structure
import tensorflow as tf
import numpy as np

# Определение функций

In [None]:
def read_pymatgen_dict(file):
    """Load a pymatgen ``Structure`` from a JSON file holding its dict form.

    Args:
        file: Path (``str`` or ``os.PathLike``) to a JSON file whose content
            ``Structure.from_dict`` accepts (presumably written from
            ``Structure.as_dict()`` — TODO confirm the producer).

    Returns:
        The ``pymatgen.core.Structure`` rebuilt from the JSON dict.
    """
    # JSON is UTF-8 by spec (RFC 8259); without an explicit encoding, open()
    # falls back to the platform locale (e.g. cp1251 on Russian-locale Windows).
    with open(file, "r", encoding="utf-8") as f:
        d = json.load(f)
    return Structure.from_dict(d)

In [None]:
def prepare_dataset(dataset_path, test_size=0.2, random_state=42):
    """Load structures and targets from *dataset_path* and split them for training.

    Layout read below: ``<dataset_path>/targets.csv`` (its first column is used
    as the index) and ``<dataset_path>/structures/*.json`` files holding
    pymatgen structure dicts. Each structure is reduced to its fractional
    coordinates, concatenated row by row into a single 1-D array (assuming
    ``frac_coords`` is 2-D with one row per site — TODO confirm).

    Args:
        dataset_path: Directory containing ``targets.csv`` and ``structures/``.
        test_size: Fraction of rows placed in the test split (default 0.2).
        random_state: Seed forwarded to ``train_test_split`` (default 42).

    Returns:
        ``[train, test]`` DataFrames indexed by structure ID (the JSON file
        name without its ``.json`` suffix), with a ``structures`` column of
        1-D coordinate arrays and a ``targets`` column.
    """
    dataset_path = Path(dataset_path)
    targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)

    # Sort so the row order -- and therefore the seeded split -- does not
    # depend on the filesystem's directory-listing order. Only *.json files
    # are read, so stray entries such as .DS_Store do not crash json.load.
    structure_files = sorted((dataset_path / "structures").glob("*.json"))
    coords_by_id = {
        # `stem` drops exactly the ".json" suffix. str.strip(".json") would
        # also strip any leading/trailing '.', 'j', 's', 'o', 'n' characters
        # from the ID itself and break alignment with `targets`.
        item.stem: np.concatenate(read_pymatgen_dict(item).frac_coords)
        for item in structure_files
    }

    data = pd.DataFrame(columns=["structures"], index=list(coords_by_id))
    # NOTE(review): assigning the `targets` DataFrame aligns it on the index and
    # requires it to have exactly one column. IDs in targets.csv must parse to
    # the same strings as the file stems (purely numeric IDs would be read as
    # ints and align to NaN) -- confirm against the data.
    data = data.assign(structures=list(coords_by_id.values()), targets=targets)

    return train_test_split(data, test_size=test_size, random_state=random_state)

In [None]:
from tensorflow.keras.preprocessing.sequence import pad_sequences
def pad_func(dataset):
    """Pad variable-length 1-D sequences into a rectangular float64 matrix.

    Every sequence is padded (and, if needed, truncated) at the end to the
    length of the longest sequence in *dataset*. The fill value is the mean of
    the per-sequence means, not the mean over all elements pooled together.

    Args:
        dataset: Sized iterable of 1-D numeric sequences (e.g. a pandas Series
            of NumPy arrays); it is iterated here and again by
            ``pad_sequences``.

    Returns:
        The array produced by ``pad_sequences`` with ``dtype=np.float64``.
    """
    # One pass collects each sequence's mean together with its length.
    seq_stats = [(np.mean(seq), len(seq)) for seq in dataset]
    seq_means = [seq_mean for seq_mean, _ in seq_stats]
    longest = max((seq_len for _, seq_len in seq_stats), default=0)

    fill_value = np.mean(seq_means, dtype=np.float64)
    return pad_sequences(
        dataset,
        maxlen=longest,
        dtype=np.float64,
        value=fill_value,
        padding='post',
        truncating='post',
    )

In [None]:
def energy_within_threshold(prediction, target, e_thresh=0.02):
    """Return the fraction of systems whose absolute energy error is below a threshold.

    Both inputs are flattened to 1-D before comparison. A column-vector
    prediction of shape ``(n, 1)`` (typical Keras ``predict`` output) is
    therefore compared element-wise with a target of shape ``(n,)``, instead
    of being broadcast into an ``(n, n)`` matrix of pairwise differences.

    Args:
        prediction: Predicted energies; anything ``np.asarray`` accepts
            (NumPy array, list, pandas Series, eager TensorFlow tensor).
        target: True energies with the same number of elements as
            ``prediction`` (a single-element input broadcasts against the
            other).
        e_thresh: Strict upper bound on ``|target - prediction|`` for a system
            to count as a success. Defaults to 0.02.

    Returns:
        The number of successes divided by the number of compared systems, as
        a Python float, or ``nan`` when there is nothing to compare.

    Raises:
        ValueError: if the flattened inputs have incompatible sizes.
    """
    # The result is converted with float(), which already required eager
    # evaluation in a TF implementation, so plain NumPy loses nothing. Casting
    # both sides to float64 avoids float32/float64 mismatches between Keras
    # predictions and pandas targets.
    prediction = np.asarray(prediction, dtype=np.float64).reshape(-1)
    target = np.asarray(target, dtype=np.float64).reshape(-1)

    error_energy = np.abs(target - prediction)
    if error_energy.size == 0:
        # The metric is undefined on zero systems.
        return float("nan")

    success = np.count_nonzero(error_energy < e_thresh)
    return float(success / error_energy.size)

# Подготовка датасета

In [None]:
# Split once, then build target vectors and padded feature matrices.
train, test = prepare_dataset('dichalcogenides_public')
train_y = train.targets
test_y = test.targets

# Training features: padded to the longest training structure.
train_x = pad_func(train.structures)

# Test features: padded with the *training* width and fill value. Calling
# pad_func(test.structures) would use the test split's own maximum length,
# so test_x could be narrower or wider than train_x and not fit a model
# trained on it. It would also use the test split's own mean as the fill
# value. The fill value below mirrors pad_func's formula (mean of the
# per-structure means). Test structures longer than the training maximum are
# truncated at the end.
train_fill_value = np.mean([np.mean(s) for s in train.structures], dtype=np.float64)
test_x = pad_sequences(
    test.structures,
    maxlen=train_x.shape[1],
    value=train_fill_value,
    dtype=np.float64,
    padding='post',
    truncating='post',
)