In [1]:
import os
os.chdir("../../crystal/")

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio

from tqdm import tqdm
from math import lcm
from typing import List, Dict
from ripser import Rips
from utils import wasserstein_distance_matrix, plot_distance_matrix

from Symmetry import Symmetry
from UnitCell import UnitCell
from RandomCrystal import RandomCrystal
from Fractional import FractionalCoordinate, FractionalCoordinateList
from Positional import PositionalCoordinate, PositionalCoordinateList

from sklearn.manifold import TSNE, Isomap, MDS, SpectralEmbedding
from sklearn.multiclass import OneVsOneClassifier
from sklearn.svm import SVC

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

plt.rcParams['text.usetex'] = False

In [2]:
df = pd.read_csv("output/space_groups.csv", sep=" ")
df

Unnamed: 0,Space Group,Crystal System,Asymmetric Unit,Symmetries,Group Order,Unit Cell
0,2,Triclinic,30,"x,y,z;-x,-y,-z",2,"13.027,11.1822,9.2189,92.35,107.3,104.5"
1,9,Monoclinic,71,"x,y,z;x,-y,z+1/2;x+1/2,y+1/2,z;x+1/2,-y+1/2,z+1/2",4,"20.408,13.304,20.166,90.0,102.33,90.0"
2,10,Monoclinic,50,"x,y,z;-x,y,-z;-x,-y,-z;x,-y,z",4,"12.7411,12.6989,20.9991,90.0,96.29,90.0"
3,11,Monoclinic,25,"x,y,z;-x,y+1/2,-z;-x,-y,-z;x,-y+1/2,z",4,"11.454,21.695,7.227,90.0,93.15,90.0"
4,12,Monoclinic,24,"x,y,z;-x,y,-z;-x,-y,-z;x,-y,z;x+1/2,y+1/2,z;-x...",8,"22.684,13.373,12.553,90.0,69.48,90.0"
...,...,...,...,...,...,...
68,223,Cubic,7,"x,y,z;-x,-y,z;x,-y,-z;-x,y,-z;z,x,y;y,z,x;-z,-...",48,"13.705,13.705,13.705,90.0,90.0,90.0"
69,225,Cubic,4,"x,y,z;-x,-y,z;x,-y,-z;-x,y,-z;z,x,y;y,z,x;-z,-...",192,"13.624,13.624,13.624,90.0,90.0,90.0"
70,227,Cubic,5,"x,y,z;-x+1/4,-y+1/4,z;x,-y+1/4,-z+1/4;-x+1/4,y...",192,"24.345,24.345,24.345,90.0,90.0,90.0"
71,229,Cubic,5,"x,y,z;-x,-y,z;x,-y,-z;-x,y,-z;z,x,y;y,z,x;-z,-...",96,"18.578,18.578,18.578,90.0,90.0,90.0"


# Trigonal vs Hexagonal

In [28]:
crystal_system_1: str = 'Trigonal'
crystal_system_2: str = 'Hexagonal'

space_groups_1: np.ndarray = df.loc[df['Crystal System'] == crystal_system_1]['Space Group'].unique()
space_groups_2: np.ndarray = df.loc[df['Crystal System'] == crystal_system_2]['Space Group'].unique()

random_space_group_index: np.ndarray = np.random.randint(low = 0, high = len(space_groups_1), size = 1, dtype = int)[0]
space_group_1: int = space_groups_1[random_space_group_index]

random_space_group_index: np.ndarray = np.random.randint(low = 0, high = len(space_groups_2), size = 1, dtype = int)[0]
space_group_2: int = space_groups_2[random_space_group_index]

row_1 = df.loc[df['Space Group'] == space_group_1]
row_2 = df.loc[df['Space Group'] == space_group_2]

group_order_1: int = row_1['Group Order'].values[0]
group_order_2: int = row_2['Group Order'].values[0]

least_common_multiple: int = lcm(group_order_1, group_order_2)
print(f'Least common multiple: {least_common_multiple}')

crystal_dict = {}
n = 96

Least common multiple: 36


In [29]:
symmetries: list = row_1['Symmetries'].values[0].split(sep=";")
symmetries: List[Symmetry] = [Symmetry(sym) for sym in symmetries]

unit_cell: List[str] = row_1['Unit Cell'].values[0].split(sep=",")
unit_cell: UnitCell = UnitCell(*[float(x) for x in unit_cell])
normalised_unit_cell: UnitCell = unit_cell.normalise()

normalising_constant: float = unit_cell.normalising_constant
k = least_common_multiple // group_order_1
if least_common_multiple < n:
    k *= (n // least_common_multiple)

rips = Rips(maxdim=2, verbose=False)

for _ in tqdm(range(50)):
    random_crystal: RandomCrystal = RandomCrystal(symmetries, k)
    positional_coordinates: PositionalCoordinateList = random_crystal.fractional_coordinates.orthogonalise(unit_cell)
    normalised_positional_coordinates: PositionalCoordinateList = positional_coordinates.normalise(normalising_constant)

    distance_matrix: np.ndarray = normalised_positional_coordinates.calculate_distance_matrix(normalised_unit_cell, boundary_conditions=True)
    persistence = rips.fit_transform(X = distance_matrix, distance_matrix=True)

    for dim, intervals in enumerate(persistence):
        persistence[dim] = np.array(list(filter(lambda i: i[1] < float('inf'), intervals)))

    crystal_dict[f'{space_group_1}_{_+1}'] = {
        'system': crystal_system_1,
        'persistence': persistence
    }

100%|██████████| 50/50 [00:02<00:00, 19.55it/s]


In [30]:
symmetries: list = row_2['Symmetries'].values[0].split(sep=";")
symmetries: List[Symmetry] = [Symmetry(sym) for sym in symmetries]

unit_cell: List[str] = row_2['Unit Cell'].values[0].split(sep=",")
unit_cell: UnitCell = UnitCell(*[float(x) for x in unit_cell])
normalised_unit_cell: UnitCell = unit_cell.normalise()

normalising_constant: float = unit_cell.normalising_constant
k = least_common_multiple // group_order_2
if least_common_multiple < n:
    k *= (n // least_common_multiple)

rips = Rips(maxdim=2, verbose=False)

for _ in tqdm(range(50)):
    random_crystal: RandomCrystal = RandomCrystal(symmetries, k)
    positional_coordinates: PositionalCoordinateList = random_crystal.fractional_coordinates.orthogonalise(unit_cell)
    normalised_positional_coordinates: PositionalCoordinateList = positional_coordinates.normalise(normalising_constant)

    distance_matrix: np.ndarray = normalised_positional_coordinates.calculate_distance_matrix(normalised_unit_cell, boundary_conditions=True)
    persistence = rips.fit_transform(X = distance_matrix, distance_matrix=True)

    for dim, intervals in enumerate(persistence):
        persistence[dim] = np.array(list(filter(lambda i: i[1] < float('inf'), intervals)))

    crystal_dict[f'{space_group_2}_{_+1}'] = {
        'system': crystal_system_2,
        'persistence': persistence
    }

100%|██████████| 50/50 [00:01<00:00, 25.03it/s]


In [31]:
distance_matrix_0 = wasserstein_distance_matrix(crystal_dict, 0)
distance_matrix_1 = wasserstein_distance_matrix(crystal_dict, 1)
distance_matrix_2 = wasserstein_distance_matrix(crystal_dict, 2)

distance_matrix: np.ndarray = np.maximum.reduce([distance_matrix_0, distance_matrix_1, distance_matrix_2])

100%|██████████| 100/100 [00:02<00:00, 40.88it/s]
100%|██████████| 100/100 [00:01<00:00, 57.02it/s]
100%|██████████| 100/100 [00:01<00:00, 62.49it/s]


In [32]:
system = [v['system'] for v in crystal_dict.values()]
crystals = list(crystal_dict.keys())

n = 5
mds = MDS(n_components=n, dissimilarity='precomputed', metric=True)
embedding = mds.fit_transform(distance_matrix)

embedding_df = pd.DataFrame(embedding, columns=[str(_+1) for _ in range(n)])
embedding_df['Crystal System'] = system
embedding_df['Name'] = crystals

X = embedding
y = embedding_df['Crystal System'].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

embedding_df.head()

Unnamed: 0,1,2,3,4,5,Crystal System,Name
0,-1.47984,-0.126551,-0.305966,1.596748,-1.206594,Trigonal,166_1
1,2.135977,-0.906353,-0.541101,-0.147913,-0.748144,Trigonal,166_2
2,-0.141921,1.128429,-0.494201,1.470201,-0.001851,Trigonal,166_3
3,-1.438334,1.972812,1.608184,0.956288,-0.32925,Trigonal,166_4
4,0.020802,0.401047,-0.651734,-0.432748,0.971603,Trigonal,166_5


In [33]:
clf = OneVsOneClassifier(SVC())
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')

class_report = classification_report(y_test, y_pred, target_names = embedding_df['Crystal System'].unique())
print('Classification Report:')
print(class_report)

Accuracy: 0.90
Classification Report:
              precision    recall  f1-score   support

    Trigonal       0.86      0.92      0.89        13
   Hexagonal       0.94      0.88      0.91        17

    accuracy                           0.90        30
   macro avg       0.90      0.90      0.90        30
weighted avg       0.90      0.90      0.90        30



# Trigonal vs Cubic

In [34]:
crystal_system_1: str = 'Trigonal'
crystal_system_2: str = 'Cubic'

space_groups_1: np.ndarray = df.loc[df['Crystal System'] == crystal_system_1]['Space Group'].unique()
space_groups_2: np.ndarray = df.loc[df['Crystal System'] == crystal_system_2]['Space Group'].unique()

random_space_group_index: np.ndarray = np.random.randint(low = 0, high = len(space_groups_1), size = 1, dtype = int)[0]
space_group_1: int = space_groups_1[random_space_group_index]

random_space_group_index: np.ndarray = np.random.randint(low = 0, high = len(space_groups_2), size = 1, dtype = int)[0]
space_group_2: int = space_groups_2[random_space_group_index]

row_1 = df.loc[df['Space Group'] == space_group_1]
row_2 = df.loc[df['Space Group'] == space_group_2]

group_order_1: int = row_1['Group Order'].values[0]
group_order_2: int = row_2['Group Order'].values[0]

least_common_multiple: int = lcm(group_order_1, group_order_2)
print(f'Least common multiple: {least_common_multiple}')

crystal_dict = {}
n = 96

Least common multiple: 192


In [35]:
symmetries: list = row_1['Symmetries'].values[0].split(sep=";")
symmetries: List[Symmetry] = [Symmetry(sym) for sym in symmetries]

unit_cell: List[str] = row_1['Unit Cell'].values[0].split(sep=",")
unit_cell: UnitCell = UnitCell(*[float(x) for x in unit_cell])
normalised_unit_cell: UnitCell = unit_cell.normalise()

normalising_constant: float = unit_cell.normalising_constant
k = least_common_multiple // group_order_1
if least_common_multiple < n:
    k *= (n // least_common_multiple)

rips = Rips(maxdim=2, verbose=False)

for _ in tqdm(range(50)):
    random_crystal: RandomCrystal = RandomCrystal(symmetries, k)
    positional_coordinates: PositionalCoordinateList = random_crystal.fractional_coordinates.orthogonalise(unit_cell)
    normalised_positional_coordinates: PositionalCoordinateList = positional_coordinates.normalise(normalising_constant)

    distance_matrix: np.ndarray = normalised_positional_coordinates.calculate_distance_matrix(normalised_unit_cell, boundary_conditions=True)
    persistence = rips.fit_transform(X = distance_matrix, distance_matrix=True)

    for dim, intervals in enumerate(persistence):
        persistence[dim] = np.array(list(filter(lambda i: i[1] < float('inf'), intervals)))

    crystal_dict[f'{space_group_1}_{_+1}'] = {
        'system': crystal_system_1,
        'persistence': persistence
    }

100%|██████████| 50/50 [00:36<00:00,  1.37it/s]


In [36]:
symmetries: list = row_2['Symmetries'].values[0].split(sep=";")
symmetries: List[Symmetry] = [Symmetry(sym) for sym in symmetries]

unit_cell: List[str] = row_2['Unit Cell'].values[0].split(sep=",")
unit_cell: UnitCell = UnitCell(*[float(x) for x in unit_cell])
normalised_unit_cell: UnitCell = unit_cell.normalise()

normalising_constant: float = unit_cell.normalising_constant
k = least_common_multiple // group_order_2
if least_common_multiple < n:
    k *= (n // least_common_multiple)

rips = Rips(maxdim=2, verbose=False)

for _ in tqdm(range(50)):
    random_crystal: RandomCrystal = RandomCrystal(symmetries, k)
    positional_coordinates: PositionalCoordinateList = random_crystal.fractional_coordinates.orthogonalise(unit_cell)
    normalised_positional_coordinates: PositionalCoordinateList = positional_coordinates.normalise(normalising_constant)

    distance_matrix: np.ndarray = normalised_positional_coordinates.calculate_distance_matrix(normalised_unit_cell, boundary_conditions=True)
    persistence = rips.fit_transform(X = distance_matrix, distance_matrix=True)

    for dim, intervals in enumerate(persistence):
        persistence[dim] = np.array(list(filter(lambda i: i[1] < float('inf'), intervals)))

    crystal_dict[f'{space_group_2}_{_+1}'] = {
        'system': crystal_system_2,
        'persistence': persistence
    }

100%|██████████| 50/50 [00:59<00:00,  1.19s/it]


In [37]:
distance_matrix_0 = wasserstein_distance_matrix(crystal_dict, 0)
distance_matrix_1 = wasserstein_distance_matrix(crystal_dict, 1)
distance_matrix_2 = wasserstein_distance_matrix(crystal_dict, 2)

distance_matrix: np.ndarray = np.maximum.reduce([distance_matrix_0, distance_matrix_1, distance_matrix_2])

100%|██████████| 100/100 [00:14<00:00,  6.95it/s]
100%|██████████| 100/100 [00:13<00:00,  7.29it/s]
100%|██████████| 100/100 [00:05<00:00, 17.85it/s]


In [38]:
system = [v['system'] for v in crystal_dict.values()]
crystals = list(crystal_dict.keys())

n = 5
mds = MDS(n_components=n, dissimilarity='precomputed', metric=True)
embedding = mds.fit_transform(distance_matrix)

embedding_df = pd.DataFrame(embedding, columns=[str(_+1) for _ in range(n)])
embedding_df['Crystal System'] = system
embedding_df['Name'] = crystals

X = embedding
y = embedding_df['Crystal System'].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

embedding_df.head()

Unnamed: 0,1,2,3,4,5,Crystal System,Name
0,-2.766093,2.075187,0.657874,1.418536,-1.168205,Trigonal,164_1
1,-2.774214,2.037055,1.558659,1.671768,0.558134,Trigonal,164_2
2,-1.022356,-0.09201,0.878556,0.193382,-0.331694,Trigonal,164_3
3,-1.942407,0.964149,-0.041594,-0.326254,-0.549604,Trigonal,164_4
4,-4.242157,0.861057,-0.81071,1.511052,-0.91877,Trigonal,164_5


In [39]:
clf = OneVsOneClassifier(SVC())
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')

class_report = classification_report(y_test, y_pred, target_names = embedding_df['Crystal System'].unique())
print('Classification Report:')
print(class_report)

Accuracy: 1.00
Classification Report:
              precision    recall  f1-score   support

    Trigonal       1.00      1.00      1.00        13
       Cubic       1.00      1.00      1.00        17

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

