In [None]:
import time
import matplotlib.pyplot as plt
from ase.io import read
from sklearn.manifold import TSNE
import chemiscope
import numpy as np
import os
from dscribe.descriptors import SOAP
from mace.calculators.foundations_models import mace_mp
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# Record the wall-clock start so the total runtime can be reported at the end.
start_time = time.time()

# Read every configuration stored in the ASE database (index=':' selects all).
a = read('li_train_MatGen_gen_data.db', index=':')

# Tag each structure with its 0-based position in the loaded list.
# Any pre-existing 'id' entry in atoms.info is overwritten.
for position, structure in enumerate(a):
    structure.info['id'] = position

def soap_tnse_with_environments(frames, environments, *, max_perplexity=50, random_state=None):
    """Embed structures in 2-D with t-SNE on structure-averaged SOAP descriptors.

    Parameters
    ----------
    frames : list of ase.Atoms
        Candidate structures; only those referenced by ``environments`` are used.
    environments : iterable of (structure_index, atom_index, cutoff)
        Atom-centred environments, e.g. from ``chemiscope.all_atomic_environments``.
        The cutoff entry is ignored; the SOAP cutoff is fixed by ``r_cut`` below.
    max_perplexity : float, optional
        Upper bound on the t-SNE perplexity (default 50, the previous hard-coded
        value). The perplexity used is ``min(max_perplexity, n_structures - 1)``.
    random_state : int or None, optional
        Seed forwarded to ``sklearn.manifold.TSNE`` for reproducible embeddings.

    Returns
    -------
    numpy.ndarray
        One 2-D point per referenced structure. Row ``k`` belongs to the
        ``k``-th referenced structure in ascending structure-index order.

    Raises
    ------
    ValueError
        If ``environments`` is None or empty, or references fewer than two
        structures (t-SNE needs at least two samples).
    """
    if environments is None:
        raise ValueError("'environments' must be provided")

    # Group atom-centre indices by the structure they belong to.
    grouped_envs = {}
    for env_index, atom_index, _cutoff in environments:
        grouped_envs.setdefault(env_index, []).append(atom_index)

    if not grouped_envs:
        raise ValueError("'environments' must not be empty")
    if len(grouped_envs) < 2:
        raise ValueError("t-SNE needs environments from at least two structures")

    # Use one fixed (ascending) structure order for BOTH the frames and their
    # centers, so centers[k] always belongs to selected_frames[k] regardless of
    # the order in which the environments were supplied.
    structure_indices = sorted(grouped_envs)
    selected_frames = [frames[index] for index in structure_indices]
    centers = [grouped_envs[index] for index in structure_indices]

    # Global species list, sorted so the descriptor setup is deterministic.
    species = sorted(
        {symbol for frame in selected_frames for symbol in frame.get_chemical_symbols()}
    )

    # average="outer" averages over the centres of each structure, giving one
    # descriptor row per structure.
    soap = SOAP(
        species=species,
        r_cut=4.5,
        n_max=8,
        l_max=6,
        sigma=0.2,
        rbf="gto",
        average="outer",
        periodic=True,
        weighting={"function": "pow", "c": 1, "m": 5, "d": 1, "r0": 3.5},
        compression={"mode": "mu1nu1"},
    )

    feats = soap.create(selected_frames, centers=centers)

    # TSNE requires perplexity < n_samples.
    perplexity = min(max_perplexity, feats.shape[0] - 1)
    reducer = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    return reducer.fit_transform(feats)

# Embed every structure: chemiscope's helper lists environments for the atoms of
# every frame, so the function returns one 2-D point per structure.
tsne_coords = soap_tnse_with_environments(a, chemiscope.all_atomic_environments(a))

# Structures with info['id'] <= ID_SPLIT are labelled "Li_MP_data", the rest "DFT_data".
# NOTE(review): confirm 52873 is the last ID of the Li MP portion of the database.
ID_SPLIT = 52873

# Row i of tsne_coords belongs to a[i] only if every frame contributed
# environments; fail loudly instead of silently mislabelling points.
if len(tsne_coords) != len(a):
    raise RuntimeError(
        f"Expected one t-SNE point per structure ({len(a)}), got {len(tsne_coords)}"
    )

ids = np.array([atoms.info['id'] for atoms in a])
is_mp = ids <= ID_SPLIT

# Scatter plot coloured by data source (red drawn on top of blue).
plt.figure(figsize=(10, 8))
plt.scatter(tsne_coords[is_mp, 0], tsne_coords[is_mp, 1], c='blue', label='Li_MP_data')
plt.scatter(tsne_coords[~is_mp, 0], tsne_coords[~is_mp, 1], c='red', label='DFT_data')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.title('t-SNE Visualization Colored by ID')
plt.legend()

# Save as LZW-compressed TIFF.
plt.savefig("SOAP_li_train_MatGen_gen_data.tiff", dpi=300, format='tiff', pil_kwargs={'compression': 'tiff_lzw'})

# Report total wall-clock time since start_time.
end_time = time.time()
elapsed = end_time - start_time
print(f"\n Done! Total time: {elapsed:.2f} seconds.")