In [2]:
import json

In [3]:
#function to load json
def load_json(file):
    with open(file) as f:
        data = json.load(f)
    return data

In [4]:
data = load_json('/home/so87pot/n0w0f/structllm_data/qmof_filtered.json')

In [24]:
from pymatgen.core.structure import Structure

def dict_to_structure(structure_dict):
    return Structure.from_dict(structure_dict).to(fmt='cif')

def filter_data(data,number_of_atoms):
    filtered_data = []
    for entry in data:
        if all(key in entry['data'] for key in ['natoms', 'pld', 'lcd', 'density', 'EgPBE', 'volume']):
            if entry['data']['natoms']['value'] < number_of_atoms:
                filtered_entry = {
                    'id': entry['id'],
                    'structure': dict_to_structure(entry['structure']),
                    'natoms': entry['data']['natoms']['value'],
                    'pld': entry['data']['pld']['value'],
                    'lcd': entry['data']['lcd']['value'],
                    'density': entry['data']['density']['value'],
                    'EgPBE': entry['data']['EgPBE']['value'],
                    'volume': entry['data']['volume']['value'],
                }
                filtered_data.append(filtered_entry)
    return filtered_data

# code  to call filter data and save json
data = load_json('/home/so87pot/n0w0f/structllm/src/structllm/dataprep/qmof_dataset/screened_mofs.json')

filtered_data = filter_data(data,100)
with open('filtered_data.json', 'w') as f:
    json.dump(filtered_data, f, indent=4)

In [25]:
filtered_data[0]

{'id': '6196c6e7a6be6ad338993951',
 'structure': "# generated using pymatgen\ndata_Ba2CuH14(C3O8)2\n_symmetry_space_group_name_H-M   'P 1'\n_cell_length_a   6.94195231\n_cell_length_b   7.17878940\n_cell_length_c   8.79165165\n_cell_angle_alpha   82.25599710\n_cell_angle_beta   71.35701702\n_cell_angle_gamma   81.58285545\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   Ba2CuH14(C3O8)2\n_chemical_formula_sum   'Ba2 Cu1 H14 C6 O16'\n_cell_volume   408.85747126\n_cell_formula_units_Z   1\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  Ba  Ba0  1  0.09830491  0.73435257  0.35340987  1\n  Ba  Ba1  1  0.90169454  0.26564683  0.64659097  1\n  Cu  Cu2  1  0.00000517  0.99999776  0.99999986  1\n  H  H3  1  0.59629623  0.90436682  0.07226429  1\n  H  H4  1  0.40370491  0.09

In [7]:
from pymatgen.core import Structure
from_file  = "/home/so87pot/n0w0f/xtal2txt/tests/data/InCuS2_p1.cif"
structure = Structure.from_file(str(from_file), "cif")
structure

Structure Summary
Lattice
    abc : 5.52040491 5.52040491 6.7959803201326165
 angles : 113.96335098610967 113.96335098610967 90.0
 volume : 169.53429042489148
      A : 5.52040491 -0.0 -0.0
      B : -3.380273101514416e-16 5.52040491 3.380273101514416e-16
      C : -2.7602024549999995 -2.760202455 5.563084875
    pbc : True True True
PeriodicSite: Cu4 (Cu+) (-6.882e-17, 2.76, 2.782) [0.25, 0.75, 0.5]
PeriodicSite: Cu5 (Cu+) (0.0, 0.0, 0.0) [0.0, 0.0, 0.0]
PeriodicSite: In0 (In3+) (2.76, 2.76, 1.69e-16) [0.5, 0.5, 0.0]
PeriodicSite: In1 (In3+) (2.76, -7.194e-16, 2.782) [0.75, 0.25, 0.5]
PeriodicSite: S8 (S2-) (4.299, 4.14, 1.391) [0.9038, 0.875, 0.25]
PeriodicSite: S9 (S2-) (1.38, -1.539, 4.172) [0.625, 0.0962, 0.75]
PeriodicSite: S10 (S2-) (1.221, 1.38, 1.391) [0.3462, 0.375, 0.25]
PeriodicSite: S11 (S2-) (-1.38, 1.539, 4.172) [0.125, 0.6538, 0.75]

In [8]:
from xtal2txt.core import TextRep
transformations = [
    #("permute_structure", {"seed": 42}),
     ("translate_single_atom", {"seed": 42}),
     ("perturb_structure", {"seed": 42, "max_distance":0.1}),
     ("translate_structure", {"seed": 42, "vector": [0.1,0.1,0.1], "frac_coords": True}),
    ]

text_rep = TextRep.from_input(structure, transformations)

In [None]:
structure.perturb(distance=0.1,)

In [8]:
structure.translate_sites(indices=[0], vector=[0.25, 0.25, 0.25],frac_coords=True)

In [99]:
from pymatgen.core.structure import Molecule

species = [s.element for s in structure.species]
coords = [c for c in structure.cart_coords]
molecule_ = Molecule(species, coords, )
zmatrix = molecule_.get_zmatrix()

In [90]:
def updated_zmatrix_rep(zmatrix, decimal_places=1):
    lines = zmatrix.split('\n')
    main_part = []
    variables_part = []

    # Determine the main part and the variables part of the Z-matrix
    for line in lines:
        if '=' in line:
            variables_part.append(line)
        else:
            if line.strip():  # Skip empty lines
                main_part.append(line)

    # Extract variables from the variables part
    variable_dict = {}
    for var_line in variables_part:
        var, value = var_line.split('=')
        if var.startswith('B'):
            rounded_value = round(float(value.strip()), decimal_places)
        else:
            rounded_value = int(round(float(value.strip())))
        variable_dict[var] = f"{rounded_value}" if var.startswith(('A', 'D')) else f"{rounded_value:.{decimal_places}f}"

    # Replace variables in the main part
    replaced_lines = []
    for line in main_part:
        parts = line.split()
        atom = parts[0]
        replaced_line = line
        for i in range(1, len(parts)):
            var = parts[i]
            if var in variable_dict:
                replaced_line = replaced_line.replace(var, variable_dict[var])
        replaced_lines.append(replaced_line)

    return '\n'.join(replaced_lines)



In [23]:
struct = Structure.from_str("data_ZnMoCl2\n_symmetry_space_group_name_H-M   'P 1'\n_cell_length_a   3.65\n_cell_length_b   3.23\n_cell_length_c   7.1\n_cell_angle_alpha   90.0\n_cell_angle_beta   101.56\n_cell_angle_gamma   90.0\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   ZnMoCl2\n_chemical_formula_sum   'Zn1 Mo1 Cl2'\n_cell_volume   82.06\n_cell_formula_units_Z   1\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  Zn  Zn3  1  0.25  0.0  0.02  1.0\n  Mo  Mo2  1  0.43  0.62  0.04  1.0\n  Cl  Cl0  1  0.15  0.62  0.56  1.0\n  Cl  Cl1  1  0.51  0.0  0.82  1.0\n", fmt="cif")

In [29]:
struct

Structure Summary
Lattice
    abc : 3.65 3.23 7.1
 angles : 90.0 101.56000000000002 90.0
 volume : 82.0075175703201
      A : 3.5759611725600706 0.0 -0.7314380987772003
      B : -1.9778045806229754e-16 3.23 1.9778045806229754e-16
      C : 0.0 0.0 7.1
    pbc : True True True
PeriodicSite: Zn3 (Zn) (0.894, 0.0, -0.04086) [0.25, 0.0, 0.02]
PeriodicSite: Mo2 (Mo) (1.538, 2.003, -0.03052) [0.43, 0.62, 0.04]
PeriodicSite: Cl0 (Cl) (0.5364, 2.003, 3.866) [0.15, 0.62, 0.56]
PeriodicSite: Cl1 (Cl) (1.824, 0.0, 5.449) [0.51, 0.0, 0.82]

In [27]:
species_ =  [s for s in structure.species]

In [28]:
species_

[Species Cu+,
 Species Cu+,
 Species In3+,
 Species In3+,
 Species S2-,
 Species S2-,
 Species S2-,
 Species S2-]

In [None]:
from scipy.spatial.distance import squareform

def get_distance(self, i: int, j: int) -> float:
        """Get distance between site i and j.

        Args:
            i (int): 1st site index
            j (int): 2nd site index

        Returns:
            Distance between the two sites.
        """
        return self[i].distance(self[j])

def _find_nn_pos_before_site(self, site_idx):
        """Returns index of nearest neighbor atoms."""
        all_dist = [(self.get_distance(site_idx, idx), idx) for idx in range(site_idx)]
        all_dist = sorted(all_dist, key=lambda x: x[0])
        return [d[1] for d in all_dist]

def get_zmatrix(self, molecule, decimal_places=6):
    """Returns a z-matrix representation of the molecule."""
    output = []
    output_var = []
    for idx, site in enumerate(molecule):
        if idx == 0:
            output.append(f"{site.specie}")
        else:
            nn = self._find_nn_pos_before_site(molecule, idx)
            bond_length = molecule.get_distance(idx, nn[0])
            bond_length_str = f"{bond_length:.{decimal_places}f}"
            bond_rep = f"{nn[0] + 1} {bond_length_str}"
            bond_var = f"B{idx}={bond_length_str}"
            if idx == 1:
                output.append(f"{molecule[idx].specie} {bond_rep}")
            elif idx == 2:
                angle = molecule.get_angle(idx, nn[0], nn[1])
                angle_str = f"{angle:.{decimal_places}f}"
                output.append(f"{molecule[idx].specie} {bond_rep} {nn[1] + 1} {angle_str}")
                output_var.append(f"A{idx}={angle_str}")
            else:
                angle = molecule.get_angle(idx, nn[0], nn[1])
                angle_str = f"{angle:.{decimal_places}f}"
                dihedral = molecule.get_dihedral(idx, nn[0], nn[1], nn[2])
                dihedral_str = f"{dihedral:.{decimal_places}f}"
                output.append(f"{molecule[idx].specie} {bond_rep} {nn[1] + 1} {angle_str} {nn[2] + 1} {dihedral_str}")
                output_var.extend([f"A{idx}={angle_str}", f"D{idx}={dihedral_str}"])
            output_var.append(bond_var)
    return "\n".join(output) + "\n\n" + "\n".join(output_var)


In [4]:
from xtal2txt.local_env import LocalEnvAnalyzer

In [5]:
lenv = LocalEnvAnalyzer()

In [26]:
local_envs_0 = lenv.structure_to_local_env_string(structure)

In [25]:
local_envs_0[0][1]

{'Site': 'In3+',
 'Wyckoff Label': '2b',
 'Environment': Coordination geometry type : Tetrahedron (IUPAC: T-4 || IUCr: [4t])
 
   - coordination number : 4
 ------------------------------------------------------------,
 'Molecule': Molecule Graph
 Molecule: 
 Molecule Summary
 Site: In0 (In3+) (0.0000, 0.0000, -0.0000)
 Site: S10 (S2-) (-1.5391, -1.3801, 1.3908)
 Site: S8 (S2-) (1.5391, 1.3801, 1.3908)
 Site: S9 (S2-) (1.3801, -1.5391, -1.3908)
 Site: S11 (S2-) (-1.3801, 1.5391, -1.3908)
 Graph: bonds
 from    to  to_image    
 ----  ----  ------------
    0     1  (0, 0, 0)   
    0     2  (0, 0, 0)   
    0     3  (0, 0, 0)   
    0     4  (0, 0, 0)   ,
 'SMILES': '[S][In]([S])[S].[S]'}

In [216]:

from typing import List
import re
class NumTokenizer:
    """Tokenize numbers as implemented in Regression Transformer.
        https://www.nature.com/articles/s42256-023-00639-z"""
        

    def __init__(self) -> None:
        """Tokenizer for numbers."""
        self.regex = re.compile(r"(\+|-)?(\d+)(\.)?(\d+)?\s*")

    import re

    def num_matcher(self, text: str) -> str:
        """Extract numbers from a sentence and replace them with tokens."""
        matches = re.findall(r'\b\d+(?:\.\d+)?\b', text)  # This regex captures both whole numbers and decimal numbers
        for match in matches:
            tokens = self.tokenize(match)
            replacement = ' '.join(tokens)
            text = re.sub(r'\b' + re.escape(match) + r'\b', replacement, text, count=1)  # replace only the first occurrence
        return text

    def tokenize(self, text: str) -> List[str]:
        """Tokenization of a property.

        Args:
            text: number as string to be tokenized.

        Returns:
            extracted tokens.
        """
        tokens = []
        matched = self.regex.match(text)
        if matched:
            sign, units, dot, decimals = matched.groups()
            tokens = []
            if sign:
                tokens += [f"_{sign}_"]
            tokens += [
                f"_{number}_{position}_" for position, number in enumerate(units[::-1])
            ][::-1]
            if dot:
                tokens += [f"_{dot}_"]
            if decimals:
                tokens += [
                    f"_{number}_-{position}_"
                    for position, number in enumerate(decimals, 1)
                ]
        return tokens
    
    @staticmethod
    def floating_tokens_to_float(token_ids: List[str]) -> float:
        """Converts tokens representing a float value into a float.
        NOTE: Expects that non-floating tokens are strippped off

        Args:
            token_ids: List of tokens, each representing a float.
                E.g.: ['_0_0_', '_._', '_9_-1_', '_3_-2_', '_1_-3_']

        Returns:
            float: Float representation for the list of tokens.
        """
        try:
            float_string = "".join([token.split("_")[1] for token in token_ids])
            float_value = float(float_string)
        except ValueError:
            float_value = -1
        return float_value

    
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts tokens to string.

        Args:
            tokens: List of tokens.

        Returns:
            str: String representation of the tokens.
        """
        return "".join([token.split("_")[1] for token in tokens])
        


In [9]:
cif

"data_InCuS2\n_symmetry_space_group_name_H-M   P1\n_cell_length_a   5.52\n_cell_length_b   5.52\n_cell_length_c   6.796\n_cell_angle_alpha   113.963\n_cell_angle_beta   113.963\n_cell_angle_gamma   90.0\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   InCuS2\n_chemical_formula_sum   'In2 Cu2 S4'\n_cell_volume   169.534\n_cell_formula_units_Z   2\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_type_symbol\n _atom_type_oxidation_number\n  In3+  3.0\n  Cu+  1.0\n  S2-  -2.0\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  In3+  In0  1  0.609  0.603  0.097  1.0\n  In3+  In1  1  0.854  0.341  0.594  1.0\n  Cu+  Cu2  1  0.36  0.347  0.348  1.0\n  Cu+  Cu3  1  0.361  0.853  0.609  1.0\n  S2-  S4  1  0.001  0.963  0.342  1.0\n  S2-  S5  1  0.218  0.749  0.839  1.0\n  S2-  S6  1  0.445  0.471  0.358  1.0\n 

In [38]:
import re
def num_matcher(text: str) -> str:
    """Extract numbers from a sentence and replace them with tokens."""
    matches = re.findall(r'(\d+\.\d+|\d+)', text)  # This regex captures both whole numbers and decimal numbers
    print(matches)

In [None]:
slice = c

In [39]:
num_matcher("data_InCuS2\n_symmetry_space_group_name_H-M   P1\n_cell_length_a   5.52\n_cell_length_b   5.52\n_cell_length_c  ")

['2', '1', '5.52', '5.52']


In [220]:
num_tokenizer = NumTokenizer()

In [221]:
mm = num_tokenizer.tokenize("5.67")

In [222]:
mm

['_5_0_', '_._', '_6_-1_', '_7_-2_']

In [223]:
c1

"data_InCuS2\n_symmetry_space_group_name_H-M   P1\n_cell_length_a   5.52\n_cell_length_b   5.52\n_cell_length_c   6.796\n_cell_angle_alpha   113.963\n_cell_angle_beta   113.963\n_cell_angle_gamma   90.0\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   InCuS2\n_chemical_formula_sum   'In2 Cu2 S4'\n_cell_volume   169.534\n_cell_formula_units_Z   2\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_type_symbol\n _atom_type_oxidation_number\n  In3+  3.0\n  Cu+  1.0\n  S2-  -2.0\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  In3+  In0  1  0.609  0.603  0.097  1.0\n  In3+  In1  1  0.854  0.341  0.594  1.0\n  Cu+  Cu2  1  0.36  0.347  0.348  1.0\n  Cu+  Cu3  1  0.361  0.853  0.609  1.0\n  S2-  S4  1  0.001  0.963  0.342  1.0\n  S2-  S5  1  0.218  0.749  0.839  1.0\n  S2-  S6  1  0.445  0.471  0.358  1.0\n 

In [214]:
new_c1 = num_tokenizer.num_matcher("+0.23")

In [224]:
new_c1 = num_tokenizer.num_matcher(c1)
new_c1

"data_InCuS2\n_symmetry_space_group_name_H-M   P1\n_cell_length_a   _5_0_ _._ _5_-1_ _2_-2_\n_cell_length_b   _5_0_ _._ _5_-1_ _2_-2_\n_cell_length_c   _6_0_ _._ _7_-1_ _9_-2_ _6_-3_\n_cell_angle_alpha   _1_2_ _1_1_ _3_0_ _._ _9_-1_ _6_-2_ _3_-3_\n_cell_angle_beta   _1_2_ _1_1_ _3_0_ _._ _9_-1_ _6_-2_ _3_-3_\n_cell_angle_gamma   _9_1_ _0_0_ _._ _0_-1_\n_symmetry_Int_Tables_number   _1_0_\n_chemical_formula_structural   InCuS2\n_chemical_formula_sum   'In2 Cu2 S4'\n_cell_volume   _1_2_ _6_1_ _9_0_ _._ _5_-1_ _3_-2_ _4_-3_\n_cell_formula_units_Z   _2_0_\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  _1_0_  'x, y, z'\nloop_\n _atom_type_symbol\n _atom_type_oxidation_number\n  In3+  _3_0_ _._ _0_-1_\n  Cu+  _1_0_ _._ _0_-1_\n  S2-  -_2_0_ _._ _0_-1_\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  In3+  In0  _1_0_  _0_0_ _._ _6_-1_ _0_-2_ _9_-3

In [170]:
new_c1 = num_tokenizer.num_matcher(c1)

In [129]:
tokens = num_tokenizer.tokenize("0.3")


In [6]:
import os
import json
import re

CIF_VOCAB = os.path.join("/home/so87pot/n0w0f/xtal2txt/src/xtal2txt/vocabs/cif_vocab_rt.json")
def load_vocab(vocab):
    with open(vocab, "r", encoding="utf-8") as file:
                    return json.load(file)
vocab = load_vocab(CIF_VOCAB)

In [90]:
tokens = list(vocab.keys())
string_tokens = [token for token in tokens if isinstance(token, str)]
string_tokens.sort(key=len, reverse=True)


In [93]:
transformations = [
    #("permute_structure", {"seed": 42}),
     ("translate_single_atom", {"seed": 42}),
     ("perturb_structure", {"seed": 42, "max_distance":0.1}),
     ("translate_structure", {"seed": 42, "vector": [0.1,0.1,0.1], "frac_coords": True}),
    ]

text_rep = TextRep.from_input(structure, transformations)
c1 = text_rep.get_cif_string()

In [None]:

import json

# Open the text file and read the lines
with open('/home/so87pot/n0w0f/xtal2txt/src/xtal2txt/vocabs/cif_vocab_rt.txt', 'r') as f:
    lines = [line.strip() for line in f]

# Create a dictionary where the keys are the lines and the values are sequential numbers
data = {line: i for i, line in enumerate(lines, start=164)}

# Write the dictionary to a JSON file
with open('/home/so87pot/n0w0f/xtal2txt/src/xtal2txt/vocabs/cif_vocab_rt.json', 'w') as f:
    json.dump(data, f)


In [1]:
from xtal2txt.tokenizer import CifTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [9]:

transformations = None
cif = TextRep.from_input(structure, transformations).get_cif_string()
slice = TextRep.from_input(structure, transformations).get_slice()
crystal = TextRep.from_input(structure, transformations).get_crystal_llm_rep()

In [10]:
cif

"data_InCuS2\n_symmetry_space_group_name_H-M   P1\n_cell_length_a   5.52\n_cell_length_b   5.52\n_cell_length_c   6.796\n_cell_angle_alpha   113.963\n_cell_angle_beta   113.963\n_cell_angle_gamma   90.0\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   InCuS2\n_chemical_formula_sum   'In2 Cu2 S4'\n_cell_volume   169.534\n_cell_formula_units_Z   2\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_type_symbol\n _atom_type_oxidation_number\n  In3+  3.0\n  Cu+  1.0\n  S2-  -2.0\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  In3+  In0  1  0.609  0.603  0.097  1.0\n  In3+  In1  1  0.854  0.341  0.594  1.0\n  Cu+  Cu2  1  0.36  0.347  0.348  1.0\n  Cu+  Cu3  1  0.361  0.853  0.609  1.0\n  S2-  S4  1  0.001  0.963  0.342  1.0\n  S2-  S5  1  0.218  0.749  0.839  1.0\n  S2-  S6  1  0.445  0.471  0.358  1.0\n 

In [11]:
tokenizer = CifTokenizer(special_num_token=True,model_max_length=512, truncation=False, padding=False)

In [None]:
ciftokenizer = CifTokenizer(special_num_token=True,model_max_length=512, truncation=False, padding=False)

slicetokenizer = SliceTokenizer(special_num_token=True,model_max_length=512, truncation=False, padding=False)

crystaltokenizer = CrysllmTokenizer(special_num_token=True,model_max_length=512, truncation=False, padding=False)

In [13]:
cif

"data_InCuS2\n_symmetry_space_group_name_H-M   P1\n_cell_length_a   5.52\n_cell_length_b   5.52\n_cell_length_c   6.796\n_cell_angle_alpha   113.963\n_cell_angle_beta   113.963\n_cell_angle_gamma   90.0\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   InCuS2\n_chemical_formula_sum   'In2 Cu2 S4'\n_cell_volume   169.534\n_cell_formula_units_Z   2\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_type_symbol\n _atom_type_oxidation_number\n  In3+  3.0\n  Cu+  1.0\n  S2-  -2.0\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  In3+  In0  1  0.609  0.603  0.097  1.0\n  In3+  In1  1  0.854  0.341  0.594  1.0\n  Cu+  Cu2  1  0.36  0.347  0.348  1.0\n  Cu+  Cu3  1  0.361  0.853  0.609  1.0\n  S2-  S4  1  0.001  0.963  0.342  1.0\n  S2-  S5  1  0.218  0.749  0.839  1.0\n  S2-  S6  1  0.445  0.471  0.358  1.0\n 

In [14]:
tokenizer.tokenize(cif)

['data_',
 'In',
 'Cu',
 'S',
 '_symmetry_space_group_name_H-M',
 '   ',
 'P',
 '_cell_length_a',
 '   ',
 '_5_0_',
 '_._',
 '_5_-1_',
 '_2_-2_',
 '_cell_length_b',
 '   ',
 '_5_0_',
 '_._',
 '_5_-1_',
 '_2_-2_',
 '_cell_length_c',
 '   ',
 '_6_0_',
 '_._',
 '_7_-1_',
 '_9_-2_',
 '_6_-3_',
 '_cell_angle_alpha',
 '   ',
 '_1_2_',
 '_1_1_',
 '_3_0_',
 '_._',
 '_9_-1_',
 '_6_-2_',
 '_3_-3_',
 '_cell_angle_beta',
 '   ',
 '_1_2_',
 '_1_1_',
 '_3_0_',
 '_._',
 '_9_-1_',
 '_6_-2_',
 '_3_-3_',
 '_cell_angle_gamma',
 '   ',
 '_9_1_',
 '_0_0_',
 '_._',
 '_0_-1_',
 '_symmetry_Int_Tables_number',
 '   ',
 '_1_0_',
 '_chemical_formula_structural',
 '   ',
 'In',
 'Cu',
 'S',
 '_chemical_formula_sum',
 '   ',
 "'",
 'In',
 ' ',
 'Cu',
 ' ',
 'S',
 "'",
 '_cell_volume',
 '   ',
 '_1_2_',
 '_6_1_',
 '_9_0_',
 '_._',
 '_5_-1_',
 '_3_-2_',
 '_4_-3_',
 '_cell_formula_units_Z',
 '   ',
 '_2_0_',
 'loop_',
 ' ',
 '_symmetry_equiv_pos_site_id',
 ' ',
 '_symmetry_equiv_pos_as_xyz',
 '  ',
 '_1_0_',
 '  ',
 

In [16]:
from xtal2txt.tokenizer import NumTokenizer

In [17]:
num_tokenizer = NumTokenizer()


In [18]:
num_tokenizer.num_matcher("P1\n_cell_length_a   5.52\n_cell_length_b   5.52\n_cell_length_c   6.796\n_cell_angle_alpha")

'P1\n_cell_length_a   _5_0__.__5_-1__2_-2_\n_cell_length_b   _5_0__.__5_-1__2_-2_\n_cell_length_c   _6_0__.__7_-1__9_-2__6_-3_\n_cell_angle_alpha'

In [37]:
num_tokenizer.tokenize("P1\n_cell_length_a   _5_0_ _._ _5_-1_ _2_-2_\n_cell_length_b   _5_0_ _._ _5_-1_ _2_-2_\n_cell_length_c   _6_0_ _._ _7_-1_ _9_-2_ _6_-3_\n_cell_angle_alpha")

[]

In [19]:
num_tokenizer.num_matcher("InCuS2")

'InCuS2'

In [22]:
import re
def num_matcher_( text: str) -> str:
        """Extract numbers from a sentence and replace them with tokens."""
        matches = re.findall(r'(\d+\.\d+|\d+)', text)  # This regex captures both whole numbers and decimal numbers
        print(matches)
        

In [23]:
num_matcher_("InCuS2")

['2']


In [198]:
import json
import os
import re

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from typing import List
import re


THIS_DIR = os.path.dirname("/home/so87pot/n0w0f/xtal2txt/src/xtal2txt/")

SLICE_VOCAB = os.path.join(THIS_DIR, "vocabs", "slice_vocab.txt")
SLICE_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "slice_vocab_rt.txt")

COMPOSITION_VOCAB = os.path.join(THIS_DIR, "vocabs", "composition_vocab.txt")

CIF_VOCAB = os.path.join(THIS_DIR, "vocabs", "cif_vocab.json")
CIF_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "cif_vocab_rt.json")

CRYSTAL_LLM_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab.json")
CRYSTAL_LLM_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab_rt.json")


ROBOCRYS_VOCAB = os.path.join(THIS_DIR, "vocabs", "robocrys_vocab.json")

from xtal2txt.analysis import (
    ANALYSIS_MASK_TOKENS,
    CIF_ANALYSIS_DICT,
    COMPOSITION_ANALYSIS_DICT,
    CRYSTAL_LLM_ANALYSIS_DICT,
    SLICE_ANALYSIS_DICT,
)


class NumTokenizer_():
    """Tokenize numbers as implemented in Regression Transformer.
        https://www.nature.com/articles/s42256-023-00639-z"""
        

    def __init__(self) -> None:
        """Tokenizer for numbers."""
        self.regex = re.compile(r"(\+|-)?(\d+)(\.)?(\d+)?\s*")

    def num_matcher(self, text: str) -> str:
        """Extract numbers from a sentence and replace them with tokens."""
        pattern = r"\d+(?:\.\d+)?"  # Match any number, whether it is part of a string or not
        matches = re.finditer(pattern, text)
        for match in matches:
            start, end = match.start(), match.end()
            tokens = self.tokenize(match.group())
            replacement = ''.join(tokens)
            text = text[:start] + replacement + text[end:]
        return text

    def tokenize(self, text: str) -> List[str]:
        """Tokenization of numbers as in RT.
         '0.9' -> '_0_0_', '_._', '_9_-1_'

        Args:
            text: number as string to be tokenized.

        Returns:
            extracted tokens.
        """
        tokens = []
        matched = self.regex.match(text)
        if matched:
            sign, units, dot, decimals = matched.groups()
            tokens = []
            if sign:
                tokens += [f"_{sign}_"]
            tokens += [
                f"_{number}_{position}_" for position, number in enumerate(units[::-1])
            ][::-1]
            if dot:
                tokens += [f"_{dot}_"]
            if decimals:
                tokens += [
                    f"_{number}_-{position}_"
                    for position, number in enumerate(decimals, 1)
                ]
        return tokens
    
    @staticmethod
    def convert_tokens_to_float(tokens: List[str]) -> float:
        """Converts tokens representing a float value into a float.
        NOTE: Expects that non-floating tokens are strippped off

        Args:
            tokens: List of tokens, each representing a float.
                E.g.: ['_0_0_', '_._', '_9_-1_', '_3_-2_', '_1_-3_']

        Returns:
            float: Float representation for the list of tokens.
        """
        try:
            float_string = "".join([token.split("_")[1] for token in tokens])
            float_value = float(float_string)
        except ValueError:
            float_value = -1
        return float_value

    
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts tokens to string.

        Args:
            tokens: List of tokens.

        Returns:
            str: String representation of the tokens.
        """
        return "".join([token.split("_")[1] for token in tokens])
        


class Xtal2txtTokenizer(PreTrainedTokenizer):
    def __init__(
        self, special_num_token:bool=False, vocab_file=None, model_max_length=None, padding_length=None, **kwargs
    ):
        super(Xtal2txtTokenizer, self).__init__(
         model_max_length=model_max_length, **kwargs
        )
        self.truncation = False
        self.padding = False
        self.padding_length = padding_length
        self.special_num_tokens = special_num_token
        self.vocab = self.load_vocab(vocab_file)
        self.vocab_file = vocab_file

    def load_vocab(self, vocab_file):
        _, file_extension = os.path.splitext(vocab_file)
        if file_extension == ".txt":
            with open(vocab_file, "r", encoding="utf-8") as file:
                vocab = file.read().splitlines()
            return {token: idx for idx, token in enumerate(vocab)}
        elif file_extension == ".json":
            with open(vocab_file, "r", encoding="utf-8") as file:
                return json.load(file)
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def get_vocab(self):
        return self.vocab
    
    def get_special_num_tokens(self,text):
        num_tokenizer = NumTokenizer()
        return num_tokenizer.num_matcher(text)


    def tokenize(self, text):
        if self.special_num_tokens:
            text = self.get_special_num_tokens(text)

        tokens = list(self.vocab.keys())
        string_tokens = [token for token in tokens if isinstance(token, str)]
        string_tokens.sort(key=len, reverse=True)
        escaped_tokens = [re.escape(token) for token in string_tokens]
        pattern_str = "|".join(escaped_tokens)
        pattern = re.compile(pattern_str)
        matches = pattern.findall(text)

        if self.truncation and len(matches) > self.model_max_length:
            matches = matches[: self.model_max_length]

        if self.padding and len(matches) < self.padding_length:
            matches += [self.pad_token] * (self.padding_length - len(matches))

        return matches

    def convert_tokens_to_string(self, tokens):
        return " ".join(tokens)

    def _add_tokens(self, new_tokens, **kwargs):
        for token in new_tokens:
            if token not in self.added_tokens_encoder:
                self.vocab[token] = len(self.vocab)

    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return list(self.vocab.keys())[index]

    def enable_truncation(self, max_length):
        self.model_max_length = max_length
        self.truncation = True

    def disable_truncation(self):
        self.truncation = False

    def enable_padding(self, length=None):
        self.padding = True
        self.padding_length = length

    def disable_padding(self):
        self.padding = False

    def add_special_tokens(self, special_tokens):
        for token, value in special_tokens.items():
            if value not in self.vocab:
                setattr(self, token, value)
                self.vocab[value] = len(self.vocab)
        self.save_vocabulary(os.path.dirname(self.vocab_file))

    def token_analysis(self, tokens):
        """This method should be implemented by the Downstream tokenizers."""
        raise NotImplementedError

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Save the vocabulary, ensures vocabularies are not overwritten. Filename follow the convention {index}-{filename_prefix}.json. Index keeps track of the latest vocabulary saved."""
        index = 0
        if os.path.isdir(save_directory):
            vocab_files = list(
                filter(lambda x: x.endswith(".json"), os.listdir(save_directory))
            )
            for vocab_file in vocab_files:
                try:
                    index = max(index, int(vocab_file.split("-")[0]))
                except ValueError:
                    pass  # Ignore files that do not start with an integer

        vocab_file = os.path.join(
            save_directory,
            f"{index + 1}-{filename_prefix}.json"
            if filename_prefix
            else f"{index + 1}.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False)

        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                vocab_files = list(
                    filter(
                        lambda x: x.endswith(".json"),
                        os.listdir(pretrained_model_name_or_path),
                    )
                )
                vocab_files.sort(key=lambda x: int(x.split("-")[0]))
                vocab_file = os.path.join(
                    pretrained_model_name_or_path, vocab_files[-1]
                )

        if vocab_file is None:
            raise ValueError("You should specify a path to a vocab file")

        with open(vocab_file, "r", encoding="utf-8") as f:
            vocab = json.load(f)

        tokenizer = cls(vocab_file, *inputs, **kwargs)
        tokenizer.vocab = vocab

        return tokenizer


class SliceTokenizer(Xtal2txtTokenizer):
    def __init__(
        self,
        special_num_token:bool=False,
        vocab_file=None,
        model_max_length=None,
        padding_length=None,
        **kwargs,
    ):  
        if special_num_token:
            vocab_file = SLICE_RT_VOCAB if vocab_file is None else vocab_file
        else:
            vocab_file = SLICE_VOCAB if vocab_file is None else vocab_file
        super(SliceTokenizer, self).__init__(
            special_num_token=special_num_token,
            vocab_file=vocab_file,
            model_max_length=model_max_length,
            padding_length=padding_length,
            **kwargs,
        )

    

class CifTokenizer(Xtal2txtTokenizer):
    def __init__(
        self, special_num_token:bool = False, vocab_file=None, model_max_length=None, padding_length=None, **kwargs
    ):
        if special_num_token:
            vocab_file = CIF_RT_VOCAB 
        else:
            vocab_file = CIF_VOCAB 
        super(CifTokenizer, self).__init__(
            special_num_token=special_num_token,
            vocab_file=vocab_file,
            model_max_length=model_max_length,
            padding_length=padding_length,
            **kwargs,
        )

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def token_analysis(self, list_of_tokens):
        """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
        analysis_masks = ANALYSIS_MASK_TOKENS
        token_type = CIF_ANALYSIS_DICT
        return [
            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
            for token in list_of_tokens
        ]


In [199]:
ciftokenizer = CifTokenizer(special_num_token=True,model_max_length=512, truncation=False, padding=False)

slicetokenizer = SliceTokenizer(special_num_token=True,model_max_length=512, truncation=False, padding=False)

In [204]:
cif

"data_InCuS2\n_symmetry_space_group_name_H-M   P1\n_cell_length_a   5.52\n_cell_length_b   5.52\n_cell_length_c   6.796\n_cell_angle_alpha   113.963\n_cell_angle_beta   113.963\n_cell_angle_gamma   90.0\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   InCuS2\n_chemical_formula_sum   'In2 Cu2 S4'\n_cell_volume   169.534\n_cell_formula_units_Z   2\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_type_symbol\n _atom_type_oxidation_number\n  In3+  3.0\n  Cu+  1.0\n  S2-  -2.0\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  In3+  In0  1  0.609  0.603  0.097  1.0\n  In3+  In1  1  0.854  0.341  0.594  1.0\n  Cu+  Cu2  1  0.36  0.347  0.348  1.0\n  Cu+  Cu3  1  0.361  0.853  0.609  1.0\n  S2-  S4  1  0.001  0.963  0.342  1.0\n  S2-  S5  1  0.218  0.749  0.839  1.0\n  S2-  S6  1  0.445  0.471  0.358  1.0\n 

In [205]:
num_tokenizer.num_matcher(cif)

"data_InCuS_2_0_\n_symmetry_space_group_name_H-M_1_0_  P1\n_cell_len_5_0__.__5_-1__2_-2_a _5_0__.__5_-1__2_-2_52_6_0__.__7_-1__9_-2__6_-3__1_2__1_1__3_0__.__9_-1__6__1_2__1_1__3_0__.__9_-1__6_-_9_1__0_0__.__0_-1__-3_3_h_b   5.52_1_0__cell_length_c   6.796\n_cell_angl_2_0__alpha   113.963\n_cell_a_2_0_2__4_0__gle_beta   1_1_2__6_1__9_0__.__5_-1__3_-2__4_2_0_-3__cell_angle_gamma   90.0\n_symmetry_Int_Tables_number   1\n_c_1_0_emical_formula_structural   InCuS2\n_chemical_formula_sum   'In2 Cu_3_0_3_0__.__0__1_0__._2_0__2_0__.__0_-1_-1_4'\n_cell_volume   169.534\n_cell_formula_units_Z   2\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_type_symbol\n_3_0___0__1__0_0__._0_0__._0_0__._1_0__.__3_0_0_1__1__0_0__._0_0__._0_0__._1_0__.__0_-1_2__1__0_0___0_0__._0_0__._1_0__.__0_-1_3__1__0_0__._0_0__._0_0__._1_0__._2_0__4__1__0_0__._0_0__._0_0__._1_0__._2_0__5__1__0_0__._0_0__._0_0__._1_0__._2_0__6__1__0_0__._0_0__._0_0__._1_0__._2_0__7__1__0_0__._0

In [202]:
num_tokenizer = NumTokenizer_()
num_tokenizer.num_matcher("InCuS2h Hi Cu2 2 0.2 ")


'InCuS_2_0_h H_2_2_0_0__.__2_-1_0_ Cu2 2 0.2 '

In [203]:
NumTokenizer().num_matcher("InCuS2h Hi Cu2 2 0.2 ")

'InCuS2h Hi Cu2 _2_0_ 0._2_0_ '

In [171]:
def hi(text: str) -> str:
        """Extract numbers from a sentence and replace them with tokens."""
        matches = re.findall(r'\d+(?:\.\d+)?', text)
        print(matches)

num_tokenizer = NumTokenizer_()
print(hi("InCuS2h Hi Cu2 2 0.2 "))
print(num_tokenizer.num_matcher("InCuS2h Hi Cu2 2 0.2 "))
#


['2', '2', '2', '0.2']
None
InCuS2h Hi Cu2 _2_0_ 0._2_0_ 


In [91]:
ciftokenizer.tokenize("InCuS2h Hi 2.0 5.2  h")

['In',
 'Cu',
 'S',
 ' ',
 'H',
 ' ',
 '_2_0_',
 '.',
 ' ',
 '_5_0_',
 '_._',
 '_2_-1_',
 '  ']

In [None]:
\b(\d+(?:\.\d*)?|\d+\w+)\b


In [100]:
import re

def find_numbers(text):
  """
  This function finds standalone numbers and numbers within a single letter in a text string.

  Args:
      text: The text string to search.

  Returns:
      A list of matched numbers.
  """
  pattern = r"[a-zA-Z]?(\d+(?:\.\d*)?|\d[a-zA-Z]?)\b"  # Match digit followed by single letter
  
  matches = re.findall(pattern, text)
  return matches

# Example usage
text = "In2S, Cu2, He3, 5.0, 123, FeO2"
numbers = find_numbers(text)
print(numbers)  # Output: ['2', '2', '5.0', '123']


['2S', '2', '3', '5.0', '123', '2']


In [115]:
import re

def find_numbers(text):
    """
    This function finds standalone numbers and numbers within a single letter in a text string.

    Args:
        text: The text string to search.

    Returns:
        A list of matched numbers.
    """
    pattern = r"(?<!\w)(?:(?<=\D)\d+(?:\.\d+)?|\d+(?=\D))(?!\w)"
    matches = re.findall(pattern, text)
    return matches

# Example usage
text = "In2S, Cu2, He3, 5.0, 123, FeO2, InCuS2h"
numbers = find_numbers(text)
print(numbers)  # Output: ['2', '2', '3', '5.0', '123', '2']


['5.0', '123']


In [149]:
NumTokenizer_

__main__.NumTokenizer_

In [163]:
import re

def find_numbers_in_strings(text):
    """
    This function finds numbers that are part of a string and standalone numbers.

    Args:
        text: The text string to search.

    Returns:
        A list of matched numbers.
    """
    pattern = r"\d+(?:\.\d+)?"  # Match any number, whether it is part of a string or not
    matches = re.findall(pattern, text)
    return matches

# Example usage
text = "In2S, Cu2, He3, 5.0, 123, FeO2, InCuS2h, Fe2So3, 0.2, 5.8"
numbers = find_numbers_in_strings(text)
print(numbers)  # Output: ['2', '2', '3', '5', '0', '123', '2', '2']



['2', '2', '3', '5.0', '123', '2', '2', '2', '3', '0.2', '5.8']
