# Wyckoff based self-augmentation 

In [17]:
from pymatgen.core import Composition
import pandas as pd
import numpy as np
from sklearn import metrics
from sklearn import decomposition
from tqdm import tqdm
import time
import joblib
import re
import json 

In [10]:
# utils 

def extract_letters(input_dict):
    alphabet_pattern = re.compile(r'[a-zA-Z]+')
    extracted_letters = []
    
    for key, value_list in input_dict.items():
        for value in value_list:
            matches = re.findall(alphabet_pattern, value)
            extracted_letters.extend(matches)
    
    return extracted_letters


def create_compatibility_dict(initial_list, transformed_list):
    compatibility_dict = {}
    for initial_elem, transformed_elem in zip(initial_list, transformed_list):
        compatibility_dict[initial_elem] = transformed_elem
    return compatibility_dict


def enumerate_transformations(input_list, compatibility_dict):
    transformations = []
    for item in input_list:
        if item in compatibility_dict:
            transformed_item = compatibility_dict[item]
            transformations.append(transformed_item)
        else:
            transformations.append(item)
    return transformations


def generate_all_possible_transformations(input_list, compatibility_dicts):
    all_possible_outputs = []
    for compatibility_dict in compatibility_dicts:
        output_list = enumerate_transformations(input_list, compatibility_dict)
        all_possible_outputs.append(output_list)
    return all_possible_outputs


def process_input_string(input_string):
    input_lists = input_string.split(', ')
    initial_list = input_lists[0].strip('[]').split()
    compatibility_lists = [lst.strip('[]').split() for lst in input_lists[1:]]
    compatibility_dicts = [
        create_compatibility_dict(initial_list, transformed_list)
        for transformed_list in compatibility_lists
    ]
    return initial_list, compatibility_dicts


def process_input(input_dict, list1, list2):
    output_dict = {}
    for element, values in input_dict.items():
        output_values = []
        for value in values:
            if value[-1] in list1:
                index = list1.index(value[-1])
                output_values.append(value[:-1] + list2[index])
        output_dict[element] = output_values
    return output_dict


def wyckoff_augment(df):
    
    wyckoff_sets_file = '/your_path_to_file/wyckoff_sets.json'
    with open(wyckoff_sets_file) as f:
        wyckoff_sets = json.load(f)
    wyckoff_sets = {int(key): value for key, value
                      in wyckoff_sets.items()}
    
    for n in tqdm(range(len(df))):
        x = df['wyckoff_dic'].values[n]
        wyckoff_let = extract_letters(x)
        
        sgn = df['spacegroup_number'].values[n]
        input_string = wyckoff_sets[sgn]
        initial_list, compatibility_dicts = process_input_string(input_string)
        all_possible_outputs = generate_all_possible_transformations(wyckoff_let, compatibility_dicts)

        for i, output in enumerate(all_possible_outputs):
            output_dict = process_input(x, wyckoff_let, output)

            index_to_duplicate = n
            row_to_duplicate = df.iloc[index_to_duplicate].copy()
            row_to_duplicate['wyckoff_dic'] = output_dict

            df = df.append(row_to_duplicate, ignore_index=True)
    # drop duplicates         
    df['wyckoff_str'] = df['wyckoff_dic'].apply(str)
    df = df.drop_duplicates(subset=['wyckoff_str', 'ind'])
    df = df.drop(columns = ['wyckoff_str'])
    
    return df


def train_test_split_bysg(df, train_ratio=0.8):
    group_counts = df['spacegroup_number'].value_counts()
    train_size = int(train_ratio * len(df))
    
    group_ratios = group_counts / len(df_clean)
    train_group_sizes = (group_ratios * train_size).astype(int)
    test_group_sizes = group_counts - train_group_sizes

    train_dfs = []
    test_dfs = []

    for group, group_df in df.groupby('spacegroup_number'):
        train_group_df = group_df.sample(train_group_sizes[group])
        test_group_df = group_df.drop(train_group_df.index)
    
        train_dfs.append(train_group_df)
        test_dfs.append(test_group_df)

    train_df = pd.concat(train_dfs)
    test_df = pd.concat(test_dfs)
    
    return train_df, test_df


def kfold_split_bysg(df, n_splits=5):
    group_counts = df['spacegroup_number'].value_counts()
    group_ratios = group_counts / len(df)
    train_group_sizes = (group_ratios * len(df)).astype(int)
    
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_train_indices = []
    fold_test_indices = []

    for train_indices, test_indices in kf.split(df):
        train_df_fold = df.iloc[train_indices]
        test_df_fold = df.iloc[test_indices]

        train_indices_list = []
        test_indices_list = []

        for group, group_df in train_df_fold.groupby('spacegroup_number'):
            train_group_size = min(train_group_sizes[group], len(group_df))
            train_group_indices = group_df.sample(train_group_size).index
            test_group_indices = test_df_fold[test_df_fold['spacegroup_number'] == group].index
            train_indices_list.extend(train_group_indices)
            test_indices_list.extend(test_group_indices)

        fold_train_indices.append(train_indices_list)
        fold_test_indices.append(test_indices_list)

    return fold_train_indices, fold_test_indices

In [13]:
wyckoff_sets_file = '/your_path_to_file/wyckoff_sets.json'
with open(wyckoff_sets_file) as f:
    wyckoff_sets = json.load(f)
wyckoff_sets

{'1': '[a]',
 '2': '[a b c d e f g h i, d f e a c b h g i, c g a e d h b f i, b a g f h d c e i, e h d c a g f b i, f d h b g a e c i, g c b h f e a d i, h e f g b c d a i]',
 '3': '[a b c d e, c d a b e, b a d c e, d c b a e]',
 '4': '[a]',
 '5': '[a b c, b a c]',
 '6': '[a b c, b a c]',
 '7': '[a]',
 '8': '[a b]',
 '9': '[a]',
 '10': '[a b c d e f g h i j k l m n o, d e g a b h c f j i l k m n o, b a f e d c h g i j k l n m o, c f a g h b d e k l i j m n o, e d h b a g f c j i l k n m o, g h d c f e a b l k j i m n o, f c b h g a e d k l i j n m o, h g e f c d b a l k j i n m o]',
 '11': '[a b c d e f, b a d c e f, c d a b e f, d c b a e f]',
 '12': '[a b c d e f g h i j, b a d c e f g h i j, c d a b f e h g i j, d c b a f e h g i j]',
 '13': '[a b c d e f g, d c b a f e g, c d a b e f g, b a d c f e g]',
 '14': '[a b c d e, b a d c e, c d a b e, d c b a e]',
 '15': '[a b c d e f, b a d c e f, a b d c e f, b a c d e f]',
 '16': '[a b c d e f g h i j k l m n o p q r s t u, b a e f c d

In [20]:
# test dataframe 

data = {'pretty_formula': ['SrTiO3', 'Mg(BiO3)2', 'Co(AsO3)2'],
        'spacegroup_number': [221, 2, 162],
        'wyckoff_dic': [{'Sr': ['1b'], 'Ti': ['1a'], 'O': ['3d']}, {'Mg': ['1d'], 'Bi': ['2i', '1g', '1e'], 'O': ['2i', '2i', '2i', '2i']}, {'Co': ['1a'], 'As': ['2d'], 'O': ['6k']}],
       'ind': [0, 1, 2]}
df = pd.DataFrame(data)
df

Unnamed: 0,pretty_formula,spacegroup_number,wyckoff_dic,ind
0,SrTiO3,221,"{'Sr': ['1b'], 'Ti': ['1a'], 'O': ['3d']}",0
1,Mg(BiO3)2,2,"{'Mg': ['1d'], 'Bi': ['2i', '1g', '1e'], 'O': ...",1
2,Co(AsO3)2,162,"{'Co': ['1a'], 'As': ['2d'], 'O': ['6k']}",2


In [21]:
df_augment = wyckoff_augment(df)
df_augment 

  df = df.append(row_to_duplicate, ignore_index=True)
100%|████████████████████████████████████████████| 3/3 [00:00<00:00, 184.84it/s]


Unnamed: 0,pretty_formula,spacegroup_number,wyckoff_dic,ind
0,SrTiO3,221,"{'Sr': ['1b'], 'Ti': ['1a'], 'O': ['3d']}",0
1,Mg(BiO3)2,2,"{'Mg': ['1d'], 'Bi': ['2i', '1g', '1e'], 'O': ...",1
2,Co(AsO3)2,162,"{'Co': ['1a'], 'As': ['2d'], 'O': ['6k']}",2
3,SrTiO3,221,"{'Sr': ['1a'], 'Ti': ['1b'], 'O': ['3c']}",0
4,Mg(BiO3)2,2,"{'Mg': ['1a'], 'Bi': ['2i', '1h', '1c'], 'O': ...",1
5,Mg(BiO3)2,2,"{'Mg': ['1e'], 'Bi': ['2i', '1b', '1d'], 'O': ...",1
6,Mg(BiO3)2,2,"{'Mg': ['1f'], 'Bi': ['2i', '1c', '1h'], 'O': ...",1
7,Mg(BiO3)2,2,"{'Mg': ['1c'], 'Bi': ['2i', '1f', '1a'], 'O': ...",1
8,Mg(BiO3)2,2,"{'Mg': ['1b'], 'Bi': ['2i', '1e', '1g'], 'O': ...",1
9,Mg(BiO3)2,2,"{'Mg': ['1h'], 'Bi': ['2i', '1a', '1f'], 'O': ...",1


## test

In [2]:
import re

#SG2

def extract_alphabets(input_dict):
    alphabet_pattern = re.compile(r'[a-zA-Z]+')
    extracted_alphabets = []
    
    for key, value_list in input_dict.items():
        for value in value_list:
            matches = re.findall(alphabet_pattern, value)
            extracted_alphabets.extend(matches)
    
    return extracted_alphabets

input_dict = {'Mg': ['1d'], 'Bi': ['2i', '1g', '1e'], 'O': ['2i', '2i', '2i', '2i']}

wyc_let = extract_alphabets(input_dict)
print(wyc_let)

['d', 'i', 'g', 'e', 'i', 'i', 'i', 'i']


In [5]:
#SG2

def create_compatibility_dict(initial_list, transformed_list):
    compatibility_dict = {}
    for initial_elem, transformed_elem in zip(initial_list, transformed_list):
        compatibility_dict[initial_elem] = transformed_elem
    return compatibility_dict

def enumerate_transformations(input_list, compatibility_dict):
    transformations = []
    for item in input_list:
        if item in compatibility_dict:
            transformed_item = compatibility_dict[item]
            transformations.append(transformed_item)
        else:
            transformations.append(item)
    return transformations

def generate_all_possible_transformations(input_list, compatibility_dicts):
    all_possible_outputs = []
    for compatibility_dict in compatibility_dicts:
        output_list = enumerate_transformations(input_list, compatibility_dict)
        all_possible_outputs.append(output_list)
    return all_possible_outputs

def process_input_string(input_string):
    input_lists = input_string.split(', ')
    initial_list = input_lists[0].strip('[]').split()
    compatibility_lists = [lst.strip('[]').split() for lst in input_lists[1:]]
    compatibility_dicts = [
        create_compatibility_dict(initial_list, transformed_list)
        for transformed_list in compatibility_lists
    ]
    return initial_list, compatibility_dicts

# wyckoff sets of SG 2
wyckoff_string = "[a b c d e f g h i, d f e a c b h g i, c g a e d h b f i, b a g f h d c e i, e h d c a g f b i, f d h b g a e c i, g c b h f e a d i, h e f g b c d a i]"
initial_list, compatibility_dicts = process_input_string(wyckoff_string)


# Given input
input_1 = wyc_let

# Generate all possible transformations
all_possible_outputs = generate_all_possible_transformations(input_1, compatibility_dicts)

for i, output in enumerate(all_possible_outputs):
    print("Input:", input_1, "→ Possible Output", i+1, ":", output)

Input: ['d', 'i', 'g', 'e', 'i', 'i', 'i', 'i'] → Possible Output 1 : ['a', 'i', 'h', 'c', 'i', 'i', 'i', 'i']
Input: ['d', 'i', 'g', 'e', 'i', 'i', 'i', 'i'] → Possible Output 2 : ['e', 'i', 'b', 'd', 'i', 'i', 'i', 'i']
Input: ['d', 'i', 'g', 'e', 'i', 'i', 'i', 'i'] → Possible Output 3 : ['f', 'i', 'c', 'h', 'i', 'i', 'i', 'i']
Input: ['d', 'i', 'g', 'e', 'i', 'i', 'i', 'i'] → Possible Output 4 : ['c', 'i', 'f', 'a', 'i', 'i', 'i', 'i']
Input: ['d', 'i', 'g', 'e', 'i', 'i', 'i', 'i'] → Possible Output 5 : ['b', 'i', 'e', 'g', 'i', 'i', 'i', 'i']
Input: ['d', 'i', 'g', 'e', 'i', 'i', 'i', 'i'] → Possible Output 6 : ['h', 'i', 'a', 'f', 'i', 'i', 'i', 'i']
Input: ['d', 'i', 'g', 'e', 'i', 'i', 'i', 'i'] → Possible Output 7 : ['g', 'i', 'd', 'b', 'i', 'i', 'i', 'i']


In [4]:
def process_input(input_dict, list1, list2):
    output_dict = {}
    for element, values in input_dict.items():
        output_values = []
        for value in values:
            if value[-1] in list1:
                index = list1.index(value[-1])
                output_values.append(value[:-1] + list2[index])
        output_dict[element] = output_values
    return output_dict

# input data 
input_dict = {'Mg': ['1d'], 'Bi': ['2i', '1g', '1e'], 'O': ['2i', '2i', '2i', '2i']}
list1 = wyc_let

# output 
for i, output in enumerate(all_possible_outputs):
    output_dict = process_input(input_dict, list1, output)
    print(output_dict)

{'Mg': ['1a'], 'Bi': ['2i', '1h', '1c'], 'O': ['2i', '2i', '2i', '2i']}
{'Mg': ['1e'], 'Bi': ['2i', '1b', '1d'], 'O': ['2i', '2i', '2i', '2i']}
{'Mg': ['1f'], 'Bi': ['2i', '1c', '1h'], 'O': ['2i', '2i', '2i', '2i']}
{'Mg': ['1c'], 'Bi': ['2i', '1f', '1a'], 'O': ['2i', '2i', '2i', '2i']}
{'Mg': ['1b'], 'Bi': ['2i', '1e', '1g'], 'O': ['2i', '2i', '2i', '2i']}
{'Mg': ['1h'], 'Bi': ['2i', '1a', '1f'], 'O': ['2i', '2i', '2i', '2i']}
{'Mg': ['1g'], 'Bi': ['2i', '1d', '1b'], 'O': ['2i', '2i', '2i', '2i']}
