<a href="https://colab.research.google.com/github/kangmg/OverlayMol/blob/main/test/with_plotly.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<font size="6" color="skyblue">Visualization with Plotly package</font>

<font size="5" color="pink">TODO</font>

- Save xyz files (add a function that converts the json format back to xyz format)
- Rewrite the docstrings

In [69]:
#@title test xyz files
!wget -q https://raw.githubusercontent.com/kangmg/OverlayMol/main/examples/sn2.xyz -O sn2.xyz
!wget -q https://raw.githubusercontent.com/kangmg/OverlayMol/main/examples/DA.xyz -O DA.xyz
!wget -q https://raw.githubusercontent.com/kangmg/OverlayMol/main/examples/butadiene.xyz -O butadiene.xyz

In [2]:
%%writefile sn2_1.xyz
6
 0.000000
  C -1.277168 0.545365 -0.000063
  Br 0.648058 0.543727 0.000199
  H -1.652166 0.593222 1.017641
  H -1.652215 -0.359651 -0.467952
  H -1.651698 1.403205 -0.550042
  Cl -4.402752 0.572053 0.000227

Overwriting sn2_1.xyz


In [162]:
%%writefile sn2_2.xyz
6
 0.000000
  C -1.515108 0.536986 0.005389
  Br 0.921841 0.546011 -0.004956
  H -1.598352 0.608067 1.086686
  H -1.600902 -0.415977 -0.510068
  H -1.592975 1.438888 -0.572824
  Cl -4.106575 0.572011 -0.004277

Overwriting sn2_2.xyz


In [4]:
#@title atomic informations


# Covalent radii keyed by element symbol (H .. Cm), taken from Alvarez (2008)
# DOI: 10.1039/b801115j
# NOTE(review): values are presumably in angstrom, i.e. the same unit as the xyz coordinates - confirm
# `xyz2molecular_graph` looks radii up here by element symbol and scales them by a percentage
covalent_radii = {
'H':  0.31, 'He': 0.28, 'Li': 1.28, 'Be': 0.96, 'B':  0.84, 'C':  0.76,
'N':  0.71, 'O':  0.66, 'F':  0.57, 'Ne': 0.58, 'Na': 1.66, 'Mg': 1.41,
'Al': 1.21, 'Si': 1.11, 'P':  1.07, 'S':  1.05, 'Cl': 1.02, 'Ar': 1.06,
'K':  2.03, 'Ca': 1.76, 'Sc': 1.70, 'Ti': 1.60, 'V':  1.53, 'Cr': 1.39,
'Mn': 1.61, 'Fe': 1.52, 'Co': 1.50, 'Ni': 1.24, 'Cu': 1.32, 'Zn': 1.22,
'Ga': 1.22, 'Ge': 1.20, 'As': 1.19, 'Se': 1.20, 'Br': 1.20, 'Kr': 1.16,
'Rb': 2.20, 'Sr': 1.95, 'Y':  1.90, 'Zr': 1.75, 'Nb': 1.64, 'Mo': 1.54,
'Tc': 1.47, 'Ru': 1.46, 'Rh': 1.42, 'Pd': 1.39, 'Ag': 1.45, 'Cd': 1.44,
'In': 1.42, 'Sn': 1.39, 'Sb': 1.39, 'Te': 1.38, 'I':  1.39, 'Xe': 1.40,
'Cs': 2.44, 'Ba': 2.15, 'La': 2.07, 'Ce': 2.04, 'Pr': 2.03, 'Nd': 2.01,
'Pm': 1.99, 'Sm': 1.98, 'Eu': 1.98, 'Gd': 1.96, 'Tb': 1.94, 'Dy': 1.92,
'Ho': 1.92, 'Er': 1.89, 'Tm': 1.90, 'Yb': 1.87, 'Lu': 1.87, 'Hf': 1.75,
'Ta': 1.70, 'W':  1.62, 'Re': 1.51, 'Os': 1.44, 'Ir': 1.41, 'Pt': 1.36,
'Au': 1.36, 'Hg': 1.32, 'Tl': 1.45, 'Pb': 1.46, 'Bi': 1.48, 'Po': 1.40,
'At': 1.50, 'Rn': 1.50, 'Fr': 2.60, 'Ra': 2.21, 'Ac': 2.15, 'Th': 2.06,
'Pa': 2.00, 'U':  1.96, 'Np': 1.90, 'Pu': 1.87, 'Am': 1.80, 'Cm': 1.69
}


# Ref. https://github.com/dralgroup/mlatom/blob/main/mlatom/data.py
periodic_table = """X
    H                                                                                                                           He
    Li  Be                                                                                                  B   C   N   O   F   Ne
    Na  Mg                                                                                                  Al  Si  P   S   Cl  Ar
    K   Ca  Sc                                                          Ti  V   Cr  Mn  Fe  Co  Ni  Cu  Zn  Ga  Ge  As  Se  Br  Kr
    Rb  Sr  Y                                                           Zr  Nb  Mo  Tc  Ru  Rh  Pd  Ag  Cd  In  Sn  Sb  Te  I   Xe
    Cs  Ba  La  Ce  Pr  Nd  Pm  Sm  Eu  Gd  Tb  Dy  Ho  Er  Tm  Yb  Lu  Hf  Ta  W   Re  Os  Ir  Pt  Au  Hg  Tl  Pb  Bi  Po  At  Rn
    Fr  Ra  Ac  Th  Pa  U   Np  Pu  Am  Cm
""".strip().split()

# atomic number -> element symbol ; the list position is the atomic number ( 0 is the placeholder "X" )
atomic_number2element_symbol = dict(enumerate(periodic_table))
# element symbol -> atomic number ( inverse of the mapping above )
element_symbol2atomic_number = {symbol: number for number, symbol in atomic_number2element_symbol.items()}


# Jmol CPK coloring
# Ref. https://jmol.sourceforge.net/jscolors/
# Hex color string per element symbol. The table has its own name so that
# `atomic_number2hex` is only ever bound to the atomic-number-keyed mapping.
element_symbol2hex = {
'H': '#FFFFFF',   'He': '#D9FFFF',    'Li': '#CC80FF',    'Be': '#C2FF00',
'B': '#FFB5B5',   'C': '#909090',     'N': '#3050F8',   'O': '#FF0D0D',
'F': '#90E050',   'Ne': '#B3E3F5',    'Na': '#AB5CF2',  'Mg': '#8AFF00',
'Al': '#BFA6A6',  'Si': '#F0C8A0',    'P': '#FF8000',   'S': '#FFFF30',
'Cl': '#1FF01F',  'Ar': '#80D1E3',    'K': '#8F40D4',   'Ca': '#3DFF00',
'Sc': '#E6E6E6',  'Ti': '#BFC2C7',    'V': '#A6A6AB',   'Cr': '#8A99C7',
'Mn': '#9C7AC7',  'Fe': '#E06633',    'Co': '#F090A0',  'Ni': '#50D050',
'Cu': '#C88033',  'Zn': '#7D80B0',    'Ga': '#C28F8F',  'Ge': '#668F8F',
'As': '#BD80E3',  'Se': '#FFA100',    'Br': '#A62929',  'Kr': '#5CB8D1',
'Rb': '#702EB0',  'Sr': '#00FF00',    'Y': '#94FFFF',   'Zr': '#94E0E0',
'Nb': '#73C2C9',  'Mo': '#54B5B5',    'Tc': '#3B9E9E',  'Ru': '#248F8F',
'Rh': '#0A7D8C',  'Pd': '#006985',    'Ag': '#C0C0C0',  'Cd': '#FFD98F',
'In': '#A67573',  'Sn': '#668080',    'Sb': '#9E63B5',  'Te': '#D47A00',
'I': '#940094',   'Xe': '#429EB0',    'Cs': '#57178F',  'Ba': '#00C900',
'La': '#70D4FF',  'Ce': '#FFFFC7',    'Pr': '#D9FFC7',  'Nd': '#C7FFC7',
'Pm': '#A3FFC7',  'Sm': '#8FFFC7',    'Eu': '#61FFC7',  'Gd': '#45FFC7',
'Tb': '#30FFC7',  'Dy': '#1FFFC7',    'Ho': '#00FF9C',  'Er': '#00E675',
'Tm': '#00D452',  'Yb': '#00BF38',    'Lu': '#00AB24',  'Hf': '#4DC2FF',
'Ta': '#4DA6FF',  'W': '#2194D6',     'Re': '#267DAB',  'Os': '#266696',
'Ir': '#175487',  'Pt': '#D0D0E0',    'Au': '#FFD123',  'Hg': '#B8B8D0',
'Tl': '#A6544D',  'Pb': '#575961',    'Bi': '#9E4FB5',  'Po': '#AB5C00',
'At': '#754F45',  'Rn': '#428296',    'Fr': '#420066',  'Ra': '#007D00',
'Ac': '#70ABFA',  'Th': '#00BAFF',    'Pa': '#00A1FF',  'U': '#008FFF',
'Np': '#0080FF',  'Pu': '#006BFF',    'Am': '#545CF2',  'Cm': '#785CE3'
}

# atomic number -> hex color string, e.g. atomic_number2hex[6] == '#909090'
atomic_number2hex = {
    element_symbol2atomic_number[symbol]: hex_color
    for symbol, hex_color in element_symbol2hex.items()
}

"""
atomic_number2rgb = {
    'H': [255,255,255],
    'He': [217,255,255],
    'Li': [204,128,255],
    'Be': [194,255,0],
    'B': [255,181,181],
    'C': [144,144,144],
    'N': [48,80,248],
    'O': [255,13,13],
    'F': [144,224,80],
    'Ne': [179,227,245],
    'Na': [171,92,242],
    'Mg': [138,255,0],
    'Al': [191,166,166],
    'Si': [240,200,160],
    'P': [255,128,0],
    'S': [255,255,48],
    'Cl': [31,240,31],
    'Ar': [128,209,227],
    'K': [143,64,212],
    'Ca': [61,255,0],
    'Sc': [230,230,230],
    'Ti': [191,194,199],
    'V': [166,166,171],
    'Cr': [138,153,199],
    'Mn': [156,122,199],
    'Fe': [224,102,51],
    'Co': [240,144,160],
    'Ni': [80,208,80],
    'Cu': [200,128,51],
    'Zn': [125,128,176],
    'Ga': [194,143,143],
    'Ge': [102,143,143],
    'As': [189,128,227],
    'Se': [255,161,0],
    'Br': [166,41,41],
    'Kr': [92,184,209],
    'Rb': [112,46,176],
    'Sr': [0,255,0],
    'Y': [148,255,255],
    'Zr': [148,224,224],
    'Nb': [115,194,201],
    'Mo': [84,181,181],
    'Tc': [59,158,158],
    'Ru': [36,143,143],
    'Rh': [10,125,140],
    'Pd': [0,105,133],
    'Ag': [192,192,192],
    'Cd': [255,217,143],
    'In': [166,117,115],
    'Sn': [102,128,128],
    'Sb': [158,99,181],
    'Te': [212,122,0],
    'I': [148,0,148],
    'Xe': [66,158,176],
    'Cs': [87,23,143],
    'Ba': [0,201,0],
    'La': [112,212,255],
    'Ce': [255,255,199],
    'Pr': [217,255,199],
    'Nd': [199,255,199],
    'Pm': [163,255,199],
    'Sm': [143,255,199],
    'Eu': [97,255,199],
    'Gd': [69,255,199],
    'Tb': [48,255,199],
    'Dy': [31,255,199],
    'Ho': [0,255,156],
    'Er': [0,230,117],
    'Tm': [0,212,82],
    'Yb': [0,191,56],
    'Lu': [0,171,36],
    'Hf': [77,194,255],
    'Ta': [77,166,255],
    'W': [33,148,214],
    'Re': [38,125,171],
    'Os': [38,102,150],
    'Ir': [23,84,135],
    'Pt': [208,208,224],
    'Au': [255,209,35],
    'Hg': [184,184,208],
    'Tl': [166,84,77],
    'Pb': [87,89,97],
    'Bi': [158,79,181],
    'Po': [171,92,0],
    'At': [117,79,69],
    'Rn': [66,130,150],
    'Fr': [66,0,102],
    'Ra': [0,125,0],
    'Ac': [112,171,250],
    'Th': [0,186,255],
    'Pa': [0,161,255],
    'U': [0,143,255],
    'Np': [0,128,255],
    'Pu': [0,107,255],
    'Am': [84,92,242],
    'Cm': [120,92,227]
    }

atomic_number2rgb = {element_symbol2atomic_number[symbol]: rgb for symbol, rgb in atomic_number2rgb.items()}
"""
None

In [5]:
#@title import packages

# import sys                                                      #sys.exit
# import os.path                                                  #filename split
# import numpy as np                                              #calculations
# from scipy.spatial.distance import pdist, squareform, #cosine    #for the calculations of the distance matrix and angles (cosine)
# from cycler import cycler                                       #generate color cycle
# from itertools import cycle                                     #for color cycling
# import io                                                       #IO for (easy) saving multi xyz
# import re                                                       #regex (get atom x y z from multi xyz)

from numpy import ndarray
import numpy as np
from os.path import isfile, basename
import re
from collections.abc import Iterable
from copy import deepcopy
import plotly.colors as pc
import plotly.graph_objects as go

from itertools import cycle
from scipy.spatial.distance import pdist, squareform #for the calculations of the distance matrix and angles (cosine)

In [6]:
#@title backup old

'''
def kabsch(P, Q):
    """
    Description
    -----------
    Compute the optimal rotation matrix using the Kabsch algorithm to align
    two sets of points P and Q.

    Parameters
    ----------
    P : ndarray
        An array of shape (N, 3) representing the first set of points.
    Q : ndarray
        An array of shape (N, 3) representing the second set of points.

    Returns
    -------
    ndarray
        A 3x3 rotation matrix.

    Ref. https://github.com/charnley/rmsd
    """
    # Compute the covariance matrix
    C = np.dot(np.transpose(P), Q)

    # Perform singular value decomposition
    V, S, W = np.linalg.svd(C)

    # Check if a reflection is needed
    d = (np.linalg.det(V) * np.linalg.det(W)) < 0.0
    if d:
        # Adjust the singular values and V matrix if reflection is needed
        S[-1] = -S[-1]
        V[:, -1] = -V[:, -1]

    # Compute the rotation matrix
    U = np.dot(V, W)

    return U


def align_xyz(vec1, vec2, coord)->ndarray:
    """
    Align a set of coordinates by computing the rotation matrix using the Kabsch
    algorithm with two reference vectors and applying it to the coordinates.

    Parameters
    ----------
    vec1 : ndarray
        An array of shape (N, 3) representing the first reference vector.
    vec2 : ndarray
        An array of shape (N, 3) representing the second reference vector.
    coord : ndarray
        An array of shape (M, 3) representing the coordinates to be aligned.

    Returns
    -------
    ndarray
        The aligned coordinates.
    """
    # Compute the rotation matrix using the Kabsch algorithm
    rotmatrix = kabsch(vec1, vec2)

    # Apply the rotation matrix to the coordinates
    return np.dot(coord, rotmatrix)

'''

# def xyz_format_to_json(xyz_coord:str|dict)->dict:
#     """
#     Description
#     -----------
#     Converts xyz format coordinate(string or file) to json

#     Parameters
#     ----------
#     xyz_coord : str or dict
#         str : xyz format string or xyz format file path
#         dict : {name: xyz format}

#     Returns
#     -------
#     xyz_json : dict
#         json format xyz
#         {
#             "name": str,
#             "n_atoms": int,
#             "coordinate": ndarray
#         }

#     Usage
#     -----
#     >>> molecule = {
#     >>>   'name': 'aspirin',
#     >>>   'coordinate': '''
#     >>>   2
#     >>>
#     >>>   H 0.0 0.0 0.7
#     >>>   H 0.0 0.0 0.0
#     >>>   '''
#     >>>     }
#     >>> xyz_json = xyz_format_to_json(molecule)
#     """
#     def _read_string(xyz:str)->str:
#         """read xyz string or filepath
#         """
#         # xyz -> filepath
#         if isfile(xyz):
#             with open(xyz, "r") as file:
#                 xyz_string = file.read()
#             name = basename(xyz)
#             return name, xyz_string
#         # xyz -> xyz format string
#         else:
#             name = ""
#             return name, xyz

#     if isinstance(xyz_coord, dict):
#         name = xyz_coord["name"]
#         _, xyz_string = _read_string(xyz_coord["coordinate"])

#     if isinstance(xyz_coord, str):
#         name, xyz_string = _read_string(xyz_coord)

#     # number of atoms
#     n_atoms = re.search(r"(\d+)", xyz_string).group(0)

#     # split xyz string
#     pattern = re.compile("([a-zA-Z]{1,2}(\s+-?\d+.\d+){3,3})+")
#     xyz_lines = np.array(np.array(list(re.split(r'\s+', tup[0]) for tup in pattern.findall(xyz_string))))

#     # converts atomic_symbol to atomic_number
#     xyz_lines[:, 0] = np.array(list(element_symbol2atomic_number[symbol] for symbol in xyz_lines[:, 0]))

#     # json format xyz
#     xyz_json = {
#         "name": name,
#         "n_atoms": n_atoms,
#         "coordinate": xyz_lines.astype(float)
#     }

#     return xyz_json


'''
def open_xyz_files(xyz_coordinates:str|list)->list:
    """
    Description
    -----------
    Open and read XYZ files, extract headers and atomic coordinates.

    Parameters
    ----------
    filenames : str or list
        str : xyz format traj file path
        list : list of xyz format strings or file paths

    Returns
    -------
    xyz_format_jsons : list
        list of json format xyz

    Usage
    -----
    >>> xyz_files = [
    >>>     {'reactant': 'sn2_reac.xyz'},
    >>>     {'TS': 'sn2_TS.xyz'},
    >>>     {'prod': 'sn2_prod'}
    >>>  ]
    >>> xyz_format_jsons = open_xyz_files(xyz_files)
    >>>
    >>> xyz_format_jsons_from_traj = open_xyz_files('sn2_traj.xyz')
    """
    # traj file
    if isinstance(xyz_coordinates, str):
        with open(xyz_coordinates, "r") as file:
            traj_string = file.read()
        # find all xyz format strings
        pattern = re.compile("(\s?\d+\n.*\n(\s*[a-zA-Z]{1,2}(\s+-?\d+.\d+){3,3}\n?)+)")
        matched_xyz_formats = pattern.findall(traj_string)
        xyz_format_strings = list(tup[0] for tup in matched_xyz_formats)

        # convert xyz format stirng to json format
        xyz_format_jsons = list(map(xyz_format_to_json, xyz_format_strings))

        return xyz_format_jsons

    elif isinstance(xyz_coordinates, Iterable):
        return list(map(xyz_format_to_json, xyz_coordinates))

    else:
        raise TypeError("xyz_coordinates must be str or list")



def superimpose(xyz_format_jsons:list, option="aa", option_param:None|list=None)->dict:
    """
    Description
    -----------
    Superimpose molecules

    Parameters
    ----------
    option : str
        supported options : ["aa", "a", "sa"]
        - aa  : all atoms
        - a   : atoms
        - sa  : same atoms

    option_param : list
        list of atom index to superimpose
        index starts with 1
        - option="aa"   : None
        - option="a"    : e.g. [[1, 2, 3], [4, 5, 6]] # same order as xyz files
        - option="sa"   : e.g. [1, 2, 3]
    """
    # copy xyz_format_jsons
    _xyz_format_jsons = deepcopy(xyz_format_jsons)

    # all atoms ( -aa option )
    if option=="aa":
        if option_param: print("\033[31m[WARNING]\033[0m", "`aa` option does not require `option_param`. `option_param` is ignored.")

        # `aa` option expects that each coordinates has the same order of atoms
        atomic_indice_list = list(map(lambda xyz_json : xyz_json.get("coordinate")[:, 0], _xyz_format_jsons))
        symbols_list = list(atomic_number2element_symbol[atomic_index[0]] for atomic_index in atomic_indice_list)
        # check each coordinates has same order
        if not np.all(np.array(symbols_list) == symbols_list[0]):
            raise ValueError("The `aa` option expects that each coordinates has the same order of atoms. \n Try other options like `sa` or `a`")

        # every molecule is overlaid on the first molecule
        first_molecule = _xyz_format_jsons[0]["coordinate"][:, 1:]
        centroid = np.mean(first_molecule, axis=0)

        # center the first(reference) molecule
        first_molecule -= centroid
        _xyz_format_jsons[0]["coordinate"][:, 1:] = first_molecule

        for mol_idx in range(len(_xyz_format_jsons)):
            # center the molecule
            centroid_mol = np.mean(_xyz_format_jsons[mol_idx]["coordinate"][:, 1:], axis=0)
            _xyz_format_jsons[mol_idx]["coordinate"][:, 1:] -= centroid_mol

            # overlay each molecule on the first molecule
            _xyz_format_jsons[mol_idx]["coordinate"][:, 1:] = align_xyz(
                vec1=_xyz_format_jsons[mol_idx]["coordinate"][:, 1:], # molecule to align
                vec2=first_molecule, # reference molecule
                coord=_xyz_format_jsons[mol_idx]["coordinate"][:, 1:] # coordinates to align
              )

    # atoms ( -a option )
    elif option=="a":
        # `option_param` should be a list with the same size as the number of molecules to overlay
        if not option_param: raise ValueError("`a` option requires `option_param`")
        if not isinstance(option_param, Iterable): raise TypeError("`option_param` must be list")
        if not np.all(np.array(list(len(param) for param in option_param)) == len(option_param[0])): raise ValueError("all elements in `option_param` must have the same size")
        if not len(_xyz_format_jsons) == len(option_param): raise ValueError(f"""`a` option requires `option_param` to have the same length as `xyz_format_jsons`\n
         Number of molecules to overlay : {len(_xyz_format_jsons)}
         Length of option_param : {len(option_param)}""")
        if any(list((0 in param) for param in option_param)): raise ValueError("atomic indices start with 1, but 0 was found in `option_param`")

        # reset atomic indice
        option_param = list(
            list(param - 1 for param in mol_param) for mol_param in option_param
            )

        # every molecule is overlaid on the first molecule
        first_molecule_selected_atoms = _xyz_format_jsons[0]["coordinate"][:, 1:][[option_param[0]]][0]
        centroid = np.mean(first_molecule_selected_atoms, axis=0)

        # center the first(reference) molecule
        _xyz_format_jsons[0]["coordinate"][:, 1:] -= centroid

        for mol_idx in range(len(_xyz_format_jsons)):
            # center the molecule
            centroid_mol = np.mean(_xyz_format_jsons[mol_idx]["coordinate"][:, 1:][option_param[mol_idx]], axis=0)
            _xyz_format_jsons[mol_idx]["coordinate"][:, 1:] -= centroid_mol

            # overlay each molecule on the first molecule
            _xyz_format_jsons[mol_idx]["coordinate"][:, 1:] = align_xyz(
                vec1=_xyz_format_jsons[mol_idx]["coordinate"][:, 1:][[option_param[mol_idx]]][0], # molecule to align
                vec2=first_molecule_selected_atoms,
                coord=_xyz_format_jsons[mol_idx]["coordinate"][:, 1:] # coordinates to align
              )

    # same atom ( -sa option )
    elif option=="sa":
        # `option_param` should be a list
        if not option_param: raise ValueError("`sa` option requires `option_param`")
        if not isinstance(option_param, Iterable): raise TypeError("`option_param` must be list")
        if 0 in option_param: raise ValueError("atomic indices start with 1, but 0 was found in `option_param`")

        # reset atomic indice
        option_param = list(param - 1 for param in option_param)

        # `sa` option expects that each coordinates has the same order of selected atoms
        selected_atomic_indice_list = list(map(lambda xyz_json : xyz_json.get("coordinate")[:, 0][[option_param]], _xyz_format_jsons))
        selected_symbols_list = list(atomic_number2element_symbol[atomic_number[0][0]] for atomic_number in selected_atomic_indice_list)

        # check each coordinates has same order
        if not np.all(np.array(selected_symbols_list) == selected_symbols_list[0]):
            raise ValueError("The `aa` option expects that each coordinates has the same order of atoms. \n Try other options like `sa` or `a`")

        # every molecule is overlaid on the first molecule
        first_molecule_selected_atoms = _xyz_format_jsons[0]["coordinate"][:, 1:][[option_param]][0]
        centroid = np.mean(first_molecule_selected_atoms, axis=0)

        # center the first(reference) molecule
        _xyz_format_jsons[0]["coordinate"][:, 1:] -= centroid

        for mol_idx in range(len(_xyz_format_jsons)):
            # center the molecule
            #centroid_mol = np.mean(_xyz_format_jsons[mol_idx]["coordinate"][:, 1:][[option_param]], axis=0)
            centroid_mol = np.mean(_xyz_format_jsons[mol_idx]["coordinate"][:, 1:][option_param], axis=0)
            _xyz_format_jsons[mol_idx]["coordinate"][:, 1:] -= centroid_mol
            # overlay each molecule on the first molecule
            _xyz_format_jsons[mol_idx]["coordinate"][:, 1:] = align_xyz(
                vec1=_xyz_format_jsons[mol_idx]["coordinate"][:, 1:][[option_param]][0], # molecule to align
                vec2=first_molecule_selected_atoms,
                coord=_xyz_format_jsons[mol_idx]["coordinate"][:, 1:] # coordinates to align
              )

    else:
        raise ValueError(f"Unsupported option : {option}")

    return _xyz_format_jsons
'''

None

In [98]:
#@title UDFs

def kabsch(P, Q):
    """
    Description
    -----------
    Kabsch algorithm : rotation matrix that superimposes the point set P
    onto the point set Q. No centering is done here, the caller passes
    the points as they should be used.

    Parameters
    ----------
    P : ndarray
        (N, 3) array, points to rotate.
    Q : ndarray
        (N, 3) array, target points ( same atom order as P ).

    Returns
    -------
    ndarray
        3x3 rotation matrix U, to be applied as `np.dot(P, U)`.

    Ref. https://github.com/charnley/rmsd
    """
    # SVD of the 3x3 covariance matrix P^T Q ; the singular values are not used
    left, _, right = np.linalg.svd(np.dot(np.transpose(P), Q))

    # a negative determinant product means the optimum would be a reflection,
    # flipping the last column of `left` turns it into a proper rotation
    if np.linalg.det(left) * np.linalg.det(right) < 0.0:
        left[:, -1] = -left[:, -1]

    return np.dot(left, right)


def align_xyz(vec1, vec2, coord)->ndarray:
    """
    Rotate `coord` with the Kabsch rotation that maps `vec1` onto `vec2`.

    Parameters
    ----------
    vec1 : ndarray
        (N, 3) array, reference points of the structure being moved.
    vec2 : ndarray
        (N, 3) array, reference points of the target structure.
    coord : ndarray
        (M, 3) array, coordinates the rotation is applied to.

    Returns
    -------
    ndarray
        (M, 3) array, `coord` multiplied by the rotation matrix.
    """
    # rotation from `kabsch`, applied from the right to row-vector coordinates
    return np.dot(coord, kabsch(vec1, vec2))


def xyz_format_to_json(xyz_coord:str|dict)->dict:
    """
    Description
    -----------
    Converts xyz format coordinate(string or file) to json

    Parameters
    ----------
    xyz_coord : str or dict
        str : xyz format string or xyz format file path
        dict : {"name": molecule name, "coordinate": xyz format string or file path}

    Returns
    -------
    xyz_json : dict
        json format xyz
        {
            "name": str,
            "n_atoms": int,
            "coordinate": ndarray
        }
        - name : file basename for a file path, "" for a plain xyz string,
          or the "name" value of the input dict
        - n_atoms : first integer found in the xyz string ( header line )
        - coordinate : float array of shape (N, 4), one row per atom
          [atomic_number, x, y, z]

    Raises
    ------
    TypeError
        `xyz_coord` is neither str nor dict
    ValueError
        no atom count or no atom line was found in the xyz string
    KeyError
        an element symbol is not in `element_symbol2atomic_number`

    Usage
    -----
    >>> molecule = {
    >>>   'name': 'aspirin',
    >>>   'coordinate': '''
    >>>   2
    >>>
    >>>   H 0.0 0.0 0.7
    >>>   H 0.0 0.0 0.0
    >>>   '''
    >>>     }
    >>> xyz_json = xyz_format_to_json(molecule)
    """
    def _read_string(xyz:str)->tuple:
        """return (name, xyz format string) for an xyz string or a file path
        """
        # xyz -> filepath
        if isfile(xyz):
            with open(xyz, "r") as file:
                return basename(xyz), file.read()
        # xyz -> xyz format string
        return "", xyz

    if isinstance(xyz_coord, dict):
        name = xyz_coord["name"]
        _, xyz_string = _read_string(xyz_coord["coordinate"])
    elif isinstance(xyz_coord, str):
        name, xyz_string = _read_string(xyz_coord)
    else:
        raise TypeError("xyz_coord must be str or dict")

    # number of atoms : first integer in the string, returned as int
    n_atoms_match = re.search(r"\d+", xyz_string)
    if n_atoms_match is None:
        raise ValueError("number of atoms not found in xyz format string")
    n_atoms = int(n_atoms_match.group(0))

    # atom line : element symbol followed by three decimal numbers.
    # - raw strings with an escaped dot ( a bare `.` matches any character )
    # - optional exponent, otherwise `1.0e-05` would be read as `1.0`
    # - the lookbehind stops the tail of a longer word from being read as a symbol
    number = r"[-+]?\d+\.\d+(?:[eE][-+]?\d+)?"
    pattern = re.compile(
        r"(?<![a-zA-Z])([a-zA-Z]{1,2})\s+(" + number + r")\s+(" + number + r")\s+(" + number + r")"
        )
    atoms = pattern.findall(xyz_string)
    if not atoms:
        raise ValueError("no atomic coordinates found in xyz format string")

    # converts atomic_symbol to atomic_number ( `capitalize` accepts e.g. "CL" or "cl" )
    coordinate = np.array(
        [
            [element_symbol2atomic_number[symbol.capitalize()], float(x), float(y), float(z)]
            for symbol, x, y, z in atoms
        ],
        dtype=float
        )

    # json format xyz
    xyz_json = {
        "name": name,
        "n_atoms": n_atoms,
        "coordinate": coordinate
    }

    return xyz_json



def open_xyz_files(xyz_coordinates:str|list)->list:
    """
    Description
    -----------
    Open and read XYZ files, extract headers and atomic coordinates.

    Parameters
    ----------
    xyz_coordinates : str or list
        str : xyz format traj file path
        list : list of xyz format strings, file paths or
               {'name': ..., 'coordinate': ...} dicts
        A single {'name': ..., 'coordinate': ...} dict is treated as a
        one-element list.

    Returns
    -------
    xyz_format_jsons : list
        list of json format xyz ( see `xyz_format_to_json` ).
        For a traj file the frames are named 1, 2, 3, ... in file order;
        the list is empty when no frame is found.

    Raises
    ------
    TypeError
        `xyz_coordinates` is neither str nor an iterable

    Usage
    -----
    >>> xyz_files = [
    >>>     {'name': 'reactant', 'coordinate': 'sn2_reac.xyz'},
    >>>     {'name': 'TS', 'coordinate': 'sn2_TS.xyz'},
    >>>     {'name': 'prod', 'coordinate': 'sn2_prod.xyz'}
    >>>  ]
    >>> xyz_format_jsons = open_xyz_files(xyz_files)
    >>>
    >>> xyz_format_jsons_from_traj = open_xyz_files('sn2_traj.xyz')
    """
    # traj file
    if isinstance(xyz_coordinates, str):
        with open(xyz_coordinates, "r") as file:
            traj_string = file.read()

        # one frame = atom-count line, comment line, one or more atom lines.
        # The count line is anchored to a line start ( re.MULTILINE ) so that the
        # digits ending a coordinate line can never be mistaken for an atom count.
        number = r"[-+]?\d+\.\d+(?:[eE][-+]?\d+)?"
        atom_line = r"\s*[a-zA-Z]{1,2}(?:\s+" + number + r"){3}[ \t]*\n?"
        frame_pattern = re.compile(
            r"^[ \t]*\d+[ \t]*\n.*\n(?:" + atom_line + r")+",
            re.MULTILINE
            )
        # the pattern has no capturing group -> findall returns the whole frames
        xyz_format_strings = frame_pattern.findall(traj_string)

        # convert xyz format string to json format, frames are named 1, 2, 3, ...
        return [
            xyz_format_to_json({'name': idx, 'coordinate': xyz_string})
            for idx, xyz_string in enumerate(xyz_format_strings, start=1)
        ]

    # a single molecule dict ; iterating it would walk over its keys
    elif isinstance(xyz_coordinates, dict):
        return [xyz_format_to_json(xyz_coordinates)]

    elif isinstance(xyz_coordinates, Iterable):
        return list(map(xyz_format_to_json, xyz_coordinates))

    else:
        raise TypeError("xyz_coordinates must be str or list")



def _overlay_on_first_molecule(molecules:list, atom_selections:list)->None:
    """Center every molecule on its selected atoms and rotate it onto the first molecule ( in place ).

    `atom_selections[i]` is the row index ( list of 0-based indices or a slice )
    of the atoms of molecule `i` that are used for the fit.
    """
    # reference atoms of the first molecule, centered on their own centroid.
    # This is a separate array, so it stays fixed while the molecules are updated.
    reference_atoms = molecules[0]["coordinate"][:, 1:][atom_selections[0]]
    reference_atoms = reference_atoms - np.mean(reference_atoms, axis=0)

    for molecule, selection in zip(molecules, atom_selections):
        xyz = molecule["coordinate"][:, 1:]  # view : in-place edits reach the stored array

        # center the molecule on the centroid of its selected atoms
        xyz -= np.mean(xyz[selection], axis=0)

        # overlay the molecule on the first molecule
        xyz[:] = align_xyz(
            vec1=xyz[selection],   # atoms to align
            vec2=reference_atoms,  # reference atoms
            coord=xyz              # coordinates to rotate
            )


def superimpose(xyz_format_jsons:list, option="aa", option_param:None|list=None)->list:
    """
    Description
    -----------
    Superimpose molecules. Every molecule is centered on the atoms used for
    the fit and rotated ( Kabsch ) onto the first molecule.

    Parameters
    ----------
    xyz_format_jsons : list
        list of json format xyz ( "coordinate" : [atomic_number, x, y, z] rows )

    option : str or None
        supported options : ["aa", "a", "sa", None]
        - aa  : all atoms
        - a   : atoms
        - sa  : same atoms
        - None : no superimposition, the input is returned as it is

    option_param : list
        list of atom index to superimpose
        index starts with 1
        - option="aa"   : None
        - option="a"    : e.g. [[1, 2, 3], [4, 5, 6]] # same order as xyz files
        - option="sa"   : e.g. [1, 2, 3]

    Returns
    -------
    list
        superimposed copy of `xyz_format_jsons` ( the input is not modified ),
        or the input object itself when `option` is None

    Raises
    ------
    ValueError
        unsupported option, missing / inconsistent `option_param`, or atoms
        that are not in the same order for the `aa` / `sa` options
    TypeError
        `option_param` is not iterable
    """
    # nothing to do : return the input before paying for a deep copy
    if option is None:
        return xyz_format_jsons

    # copy xyz_format_jsons
    _xyz_format_jsons = deepcopy(xyz_format_jsons)
    n_molecules = len(_xyz_format_jsons)

    # all atoms ( -aa option )
    if option=="aa":
        if option_param: print("\033[31m[WARNING]\033[0m", "`aa` option does not require `option_param`. `option_param` is ignored.")

        # `aa` option expects that each coordinates has the same order of atoms :
        # compare the whole atomic number column, not only the first atom
        atomic_numbers = [xyz_json["coordinate"][:, 0] for xyz_json in _xyz_format_jsons]
        if not all(np.array_equal(numbers, atomic_numbers[0]) for numbers in atomic_numbers):
            raise ValueError("The `aa` option expects that each coordinates has the same order of atoms. \n Try other options like `sa` or `a`")

        atom_selections = [slice(None)] * n_molecules

    # atoms ( -a option )
    elif option=="a":
        # `option_param` should be a list with the same size as the number of molecules to overlay
        if not option_param: raise ValueError("`a` option requires `option_param`")
        if not isinstance(option_param, Iterable): raise TypeError("`option_param` must be list")
        if any(len(param) != len(option_param[0]) for param in option_param): raise ValueError("all elements in `option_param` must have the same size")
        if not n_molecules == len(option_param): raise ValueError(f"""`a` option requires `option_param` to have the same length as `xyz_format_jsons`\n
         Number of molecules to overlay : {len(_xyz_format_jsons)}
         Length of option_param : {len(option_param)}""")
        if any((0 in param) for param in option_param): raise ValueError("atomic indices start with 1, but 0 was found in `option_param`")

        # reset atomic indice ( 1-based -> 0-based )
        atom_selections = [[atom_idx - 1 for atom_idx in mol_param] for mol_param in option_param]

    # same atom ( -sa option )
    elif option=="sa":
        # `option_param` should be a list
        if not option_param: raise ValueError("`sa` option requires `option_param`")
        if not isinstance(option_param, Iterable): raise TypeError("`option_param` must be list")
        if 0 in option_param: raise ValueError("atomic indices start with 1, but 0 was found in `option_param`")

        # reset atomic indice ( 1-based -> 0-based )
        selection = [atom_idx - 1 for atom_idx in option_param]

        # `sa` option expects that each coordinates has the same order of selected atoms
        selected_numbers = [xyz_json["coordinate"][:, 0][selection] for xyz_json in _xyz_format_jsons]
        if not all(np.array_equal(numbers, selected_numbers[0]) for numbers in selected_numbers):
            raise ValueError("The `sa` option expects that each coordinates has the same order of selected atoms. \n Try other options like `a`")

        atom_selections = [selection] * n_molecules

    else:
        raise ValueError(f"Unsupported option : {option}")

    # every molecule is overlaid on the first molecule
    if _xyz_format_jsons:
        _overlay_on_first_molecule(_xyz_format_jsons, atom_selections)

    return _xyz_format_jsons


def xyz2molecular_graph(xyz_format_jsons:list, covalent_radius_percent:float=108.):
    """
    Description
    -----------
    Detect the bonds of every molecule and store them on the molecule itself.

    Two atoms are treated as bonded when their distance does not exceed the
    sum of their covalent radii, each radius being resized by
    `covalent_radius_percent`. Two keys are written to every molecule :

    - "adjacency_matrix"  : (N, N) boolean matrix, True where two atoms are bonded
    - "bond_length_table" : (M, 3) table, one row per bond,
                            ["atom_1_idx", "atom_2_idx", "distance"], idx start with 1

    Parameters
    ----------
    - xyz_format_jsons : list
        list of json format xyz.
        The "coordinate" array of each item holds one row per atom :
        [atomic_number, x, y, z]

    - covalent_radius_percent : float
        resize covalent radii by this percent
        default : 108%

    Returns
    -------
    None ( `xyz_format_jsons` is modified in place )
    """
    # resize factor shared by every covalent radius
    resize_factor = covalent_radius_percent / 100

    for xyz_json in xyz_format_jsons:
        coordinate = xyz_json["coordinate"]

        # interatomic distance (L2 norm) matrix
        distance_matrix = squareform(pdist(coordinate[:, 1:], 'euclidean')) # (N, N)

        # resized covalent radius of every atom
        resized_radii = np.array([
            covalent_radii[atomic_number2element_symbol[atomic_number]] * resize_factor
            for atomic_number in coordinate[:, 0]
        ]) # (N,)
        radii_sum_matrix = resized_radii[:, None] + resized_radii[None, :] # (N, N)

        # adjacency(bond) matrix, an atom is never bonded to itself
        is_bonded = distance_matrix <= radii_sum_matrix # (N, N)
        np.fill_diagonal(is_bonded, False)
        xyz_json["adjacency_matrix"] = is_bonded

        # Rba = Rab ( symmetric matrix ), so only the strictly lower triangle
        # is kept to list every bond exactly once
        lower_bond_lengths = np.tril(is_bonded * distance_matrix, k=-1) # (N, N)
        atom_1_idx, atom_2_idx = np.nonzero(lower_bond_lengths)

        # bond length table ["atom_1_idx", "atom_2_idx", "distance"] # idx start with 1
        xyz_json["bond_length_table"] = np.column_stack(
            (atom_1_idx + 1, atom_2_idx + 1, lower_bond_lengths[atom_1_idx, atom_2_idx])
        )




In [166]:
#@title plot function

def plot_overlay(xyz_format_jsons:list, colorby:str="molecule", exclude_elements:list=None, exclude_atomic_idx:list=None, cmap:str=None, covalent_radius_percent:float=108., **kwargs):
    """
    Description
    -----------
    Visualization of molecular structures in 3D using Plotly.
    Every structure is drawn in the same interactive figure.

    Parameters
    ----------
    - xyz_format_jsons : list
        list of json format xyz. The input is deep-copied, it is never modified.

    - colorby : str
        supported options : ["molecule", "atom"]
        - molecule  : color by molecule
        - atom      : color by atom ( `cmap` is not used in this mode )

    - exclude_elements : list
        list of elements to exclude from visualization. e.g. ["H"]

    - exclude_atomic_idx : list
        list of atoms to exclude from visualization. e.g. [1, 3, 4]
        atomic indices start with 1

    - cmap : str or list
        name of a plotly colormap to use for coloring.
        Refer)
        https://plotly.com/python/builtin-colorscales/

        or

        iterable color list, cycled when shorter than the number of molecules
        e.g. ['red', 'blue', 'green']
        matplotlib one-letter shorthands ( 'r', 'g', 'b', ... ) are translated
        to the matching color names
        Refer)
        https://community.plotly.com/t/plotly-colours-list/11730/3

    - covalent_radius_percent : float
        resize covalent radii by this percent
        default : 108%

    - **kwargs
        - alpha_atoms : float, atoms opacity                    ( default : 0.55 )
        - alpha_bonds : float, bonds opacity                    ( default : 0.35 )
        - atom_scaler : float, scaler of the atom marker size   ( default : 4e1 )
        - bond_scaler : float, scaler of the bond line width    ( default : 7e4 )
        - legend      : bool, add legend                        ( default : False )
        - bgcolor     : str, background color                   ( default : 'black' )

    Returns
    -------
    None ( the figure is displayed with `fig.show()` )

    Raises
    ------
    ValueError
        - `xyz_format_jsons` is empty
        - `exclude_atomic_idx` holds an index below 1 or beyond a molecule size
        - `colorby` is not a supported option
    """
    warning_tag = "\033[31m[WARNING]\033[0m"

    def _get_colors(cmap:str|list, n:int):
        """get n size color list from a plotly colormap name or from a color list
        """
        default_cmap = 'Plotly3'
        if not cmap:
            cmap = default_cmap

        # A color list has to be handled before the colorscale lookup :
        # `pc.get_colorscale` only accepts names, so a list would be rejected
        # there and silently replaced by the default colormap.
        if isinstance(cmap, (list, tuple)):
            # matplotlib one-letter shorthands are not valid plotly colors
            shorthands = {
                "r": "red", "g": "green", "b": "blue", "c": "cyan",
                "m": "magenta", "y": "yellow", "k": "black", "w": "white",
            }
            cyclic_iterator = cycle(
                shorthands.get(color, color) if isinstance(color, str) else color for color in cmap
            )
            return list(next(cyclic_iterator) for _ in range(n))

        try:
            colors = pc.get_colorscale(cmap)
        except Exception:
            print(warning_tag, f"`{cmap}` is not a valid plotly colormap. Applying default colormap instead.")
            colors = pc.get_colorscale(default_cmap)

        return list(pc.sample_colorscale(colors, list(np.linspace(0, 1, n+1)[1:]), colortype='rgb'))

    def _bond_lines(xyz_json:dict):
        """get x, y, z polylines of every bond, the segments are separated by None
        """
        xyz = xyz_json["coordinate"][:, 1:]
        bond_x, bond_y, bond_z = [], [], []
        # the bond table is 1-based, internally idx start with 0
        for atom_1, atom_2 in xyz_json["bond_length_table"][:, :2].astype(int) - 1:
            bond_x.extend([xyz[atom_1][0], xyz[atom_2][0], None])
            bond_y.extend([xyz[atom_1][1], xyz[atom_2][1], None])
            bond_z.extend([xyz[atom_1][2], xyz[atom_2][2], None])
        return bond_x, bond_y, bond_z

    # set default values
    alpha_atoms = kwargs.get("alpha_atoms", 0.55) # atoms opacity
    alpha_bonds = kwargs.get("alpha_bonds", 0.35) # bonds opacity
    atom_scaler = kwargs.get("atom_scaler", 4e1) # sphere radius for atom view, change exponent
    bond_scaler = kwargs.get("bond_scaler", 7e4) # cylinder radius for bond view, change exponent
    legend = kwargs.get("legend", False) # add legend
    bgcolor = kwargs.get("bgcolor", 'black') # background color

    if colorby not in ("molecule", "atom"):
        raise ValueError(f"Unsupported colorby : {colorby}")

    # copy xyz_format_jsons, so the caller's data is never modified
    _xyz_format_jsons = deepcopy(xyz_format_jsons)
    if not _xyz_format_jsons:
        raise ValueError("No molecule to plot : `xyz_format_jsons` is empty")

    # plotly figure
    fig = go.Figure()

    # exclude atoms
    if exclude_atomic_idx:
        # `exclude_atomic_idx` option expects that each coordinates has the same order of atoms
        # ( compared one by one, molecules may differ in size )
        atomic_numbers_list = [np.asarray(xyz_json["coordinate"])[:, 0] for xyz_json in _xyz_format_jsons]
        if not all(np.array_equal(atomic_numbers, atomic_numbers_list[0]) for atomic_numbers in atomic_numbers_list):
            print(warning_tag, "`exclude_atomic_idx` option expects that each coordinates has the same order of atoms")

        # atomic indice start with 1
        invalid_indice = [idx for idx in exclude_atomic_idx if idx < 1]
        if invalid_indice:
            raise ValueError(f"atomic indices start with 1, but {invalid_indice[0]} was found in `exclude_atomic_idx`")

        # check if atomic index is out of range ( still 1-based : the last valid index is the number of atoms )
        largest_idx = max(exclude_atomic_idx)
        if any(largest_idx > len(xyz_json["coordinate"]) for xyz_json in _xyz_format_jsons):
            raise ValueError(f"Atomic index {largest_idx} provided in `exclude_atomic_idx` is out of range in your molecule.")

        # reset atomic indice
        zero_based_indice = sorted({int(idx) - 1 for idx in exclude_atomic_idx})

        for xyz_json in _xyz_format_jsons:
            # the coordinate has to stay a 2D array, it is sliced with [:, 1:] later on
            atom_filtered_coordinate = np.delete(np.asarray(xyz_json["coordinate"]), zero_based_indice, axis=0)
            # overwrite filtered coordinate
            xyz_json["coordinate"] = atom_filtered_coordinate
            # adjust number of atoms : n_atoms
            xyz_json["n_atoms"] = len(atom_filtered_coordinate)

    # exclude elements
    if exclude_elements:
        for xyz_json in _xyz_format_jsons:
            coordinate = np.asarray(xyz_json["coordinate"])
            # boolean mask keeps the (n, 4) shape even when every atom is removed
            keep_mask = np.array(
                [atomic_number2element_symbol[atomic_number] not in exclude_elements for atomic_number in coordinate[:, 0]],
                dtype=bool
            )
            element_filtered_coordinate = coordinate[keep_mask]
            # overwrite filtered coordinate
            xyz_json["coordinate"] = element_filtered_coordinate
            # adjust number of atoms : n_atoms
            xyz_json["n_atoms"] = len(element_filtered_coordinate)

    # number of molecules
    num_of_xyz = len(_xyz_format_jsons)

    # max number of atoms ( at least 1, it is used as a divisor below )
    num_atom_xyz = max(1, max(len(xyz_json["coordinate"]) for xyz_json in _xyz_format_jsons))

    # analyze molecular connectivity
    xyz2molecular_graph(_xyz_format_jsons, covalent_radius_percent)

    # bond thickness and atom size, both shrink as the molecules get larger
    bond_thickness = np.maximum(np.log10(bond_scaler / num_atom_xyz) * 2, 1)
    atom_size = np.maximum(np.log10(atom_scaler / num_atom_xyz) * 5, 2)

    if colorby == "molecule":
        # set color palette
        palette = _get_colors(cmap, num_of_xyz)

        # plot atoms & bonds
        for xyz_json, color in zip(_xyz_format_jsons, palette):
            molecule_name = xyz_json["name"]

            # Add atoms to plot
            fig.add_trace(go.Scatter3d(
                x=xyz_json["coordinate"][:, 1],
                y=xyz_json["coordinate"][:, 2],
                z=xyz_json["coordinate"][:, 3],
                mode='markers',
                opacity=alpha_atoms,
                marker=dict(size=atom_size, color=color),
                name=f'{molecule_name} atoms'
            ))

            # Add bonds to plot, one trace per molecule instead of one trace per bond
            bond_x, bond_y, bond_z = _bond_lines(xyz_json)
            if bond_x:
                fig.add_trace(go.Scatter3d(
                    x=bond_x, y=bond_y, z=bond_z,
                    mode='lines',
                    opacity=alpha_bonds,
                    line=dict(width=bond_thickness, color=color),
                    legendgroup=molecule_name,
                    name=f"{molecule_name} bonds",
                    showlegend=True
                ))

    else:
        # `legend` is only working when colorby='molecule'.
        if legend: print(warning_tag, f"`legend`=True is not a applicable when colorby='atom'.")

        # plot atoms
        all_coordinates = np.vstack([xyz_json["coordinate"] for xyz_json in _xyz_format_jsons])
        all_atomic_numbers = all_coordinates[:, 0].astype(int)

        # plot element-wise
        for element in sorted(set(all_atomic_numbers.tolist())):
            element_coordinates = all_coordinates[all_atomic_numbers == element]

            # Add atoms to plot
            fig.add_trace(go.Scatter3d(
                x=element_coordinates[:, 1],
                y=element_coordinates[:, 2],
                z=element_coordinates[:, 3],
                mode='markers',
                opacity=alpha_atoms,
                marker=dict(size=atom_size, color=atomic_number2hex[element]),
                showlegend=False
            ))

        # plot bonds, one trace per molecule instead of one trace per bond
        for xyz_json in _xyz_format_jsons:
            bond_x, bond_y, bond_z = _bond_lines(xyz_json)
            if bond_x:
                fig.add_trace(go.Scatter3d(
                    x=bond_x, y=bond_y, z=bond_z,
                    mode='lines',
                    opacity=alpha_bonds,
                    line=dict(width=bond_thickness, color='gray'),
                    showlegend=False
                ))

    # figure layout setting
    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            aspectmode='data',
            camera_projection=dict(type='orthographic'),
        ),
        showlegend=True if legend else False,
        paper_bgcolor=bgcolor
    )

    fig.show()



# TODO: rewrite the rest of the function.


# TODO: set default opacity values
# TODO: set default scaler values
# TODO: split the function above into two functions:
# xyz_filter()
# overlay_plot()

In [127]:
#@title animation plot | now testing

def plot_animation(xyz_format_jsons:list, colorby:str="molecule", exclude_elements:list=None, exclude_atomic_idx:list=None, cmap:str=None, covalent_radius_percent:float=108., **kwargs):
    """
    Description
    -----------
    Animated visualization of molecular structures in 3D using Plotly.
    Every structure becomes one frame of the animation ( play / pause buttons and a frame slider ).

    Parameters
    ----------
    - xyz_format_jsons : list
        list of json format xyz. The input is deep-copied, it is never modified.

    - colorby : str
        supported options : ["molecule", "atom"]
        - molecule  : color by molecule
        - atom      : color by atom ( `cmap` is not used in this mode )

    - exclude_elements : list
        list of elements to exclude from visualization. e.g. ["H"]

    - exclude_atomic_idx : list
        list of atoms to exclude from visualization. e.g. [1, 3, 4]
        atomic indices start with 1

    - cmap : str or list
        name of a plotly colormap to use for coloring.
        Refer)
        https://plotly.com/python/builtin-colorscales/

        or

        iterable color list, cycled when shorter than the number of molecules
        e.g. ['red', 'blue', 'green']
        matplotlib one-letter shorthands ( 'r', 'g', 'b', ... ) are translated
        to the matching color names
        Refer)
        https://community.plotly.com/t/plotly-colours-list/11730/3

    - covalent_radius_percent : float
        resize covalent radii by this percent
        default : 108%

    - **kwargs
        - alpha_atoms : float, atoms opacity                    ( default : 0.55 )
        - alpha_bonds : float, bonds opacity                    ( default : 0.55 )
        - atom_scaler : float, scaler of the atom marker size   ( default : 2e1 )
        - bond_scaler : float, scaler of the bond line width    ( default : 1e4 )
        - legend      : bool, add legend                        ( default : False )

    Returns
    -------
    None ( the figure is displayed with `fig.show()` )

    Raises
    ------
    ValueError
        - `xyz_format_jsons` is empty, or no atom is left after the exclusions
        - `exclude_atomic_idx` holds an index below 1 or beyond a molecule size
        - `colorby` is not a supported option
    """
    warning_tag = "\033[31m[WARNING]\033[0m"

    STABLE = False
    if not STABLE:
        print(warning_tag, f"plot_animation function is now testing. It is not working perfectly.")

    def _get_colors(cmap:str|list, n:int):
        """get n size color list from a plotly colormap name or from a color list
        """
        default_cmap = 'Plotly3'
        if not cmap:
            cmap = default_cmap

        # A color list has to be handled before the colorscale lookup :
        # `pc.get_colorscale` only accepts names, so a list would be rejected
        # there and silently replaced by the default colormap.
        if isinstance(cmap, (list, tuple)):
            # matplotlib one-letter shorthands are not valid plotly colors
            shorthands = {
                "r": "red", "g": "green", "b": "blue", "c": "cyan",
                "m": "magenta", "y": "yellow", "k": "black", "w": "white",
            }
            cyclic_iterator = cycle(
                shorthands.get(color, color) if isinstance(color, str) else color for color in cmap
            )
            return list(next(cyclic_iterator) for _ in range(n))

        try:
            colors = pc.get_colorscale(cmap)
        except Exception:
            print(warning_tag, f"`{cmap}` is not a valid plotly colormap. Applying default colormap instead.")
            colors = pc.get_colorscale(default_cmap)

        return list(pc.sample_colorscale(colors, list(np.linspace(0, 1, n+1)[1:]), colortype='rgb'))

    # set default values
    alpha_atoms = kwargs.get("alpha_atoms", 0.55) # atoms opacity
    alpha_bonds = kwargs.get("alpha_bonds", 0.55) # bonds opacity
    atom_scaler = kwargs.get("atom_scaler", 2e1) # sphere radius for atom view, change exponent
    bond_scaler = kwargs.get("bond_scaler", 1e4) # cylinder radius for bond view, change exponent
    legend = kwargs.get("legend", False) # add legend

    # an unknown option used to end up in an empty figure without any message
    if colorby not in ("molecule", "atom"):
        raise ValueError(f"Unsupported colorby : {colorby}")

    # copy xyz_format_jsons, so the caller's data is never modified
    _xyz_format_jsons = deepcopy(xyz_format_jsons)
    if not _xyz_format_jsons:
        raise ValueError("No molecule to plot : `xyz_format_jsons` is empty")

    # plotly figure
    fig = go.Figure()

    # exclude atoms
    if exclude_atomic_idx:
        # `exclude_atomic_idx` option expects that each coordinates has the same order of atoms
        # ( compared one by one, molecules may differ in size )
        atomic_numbers_list = [np.asarray(xyz_json["coordinate"])[:, 0] for xyz_json in _xyz_format_jsons]
        if not all(np.array_equal(atomic_numbers, atomic_numbers_list[0]) for atomic_numbers in atomic_numbers_list):
            print(warning_tag, "`exclude_atomic_idx` option expects that each coordinates has the same order of atoms")

        # atomic indice start with 1
        invalid_indice = [idx for idx in exclude_atomic_idx if idx < 1]
        if invalid_indice:
            raise ValueError(f"atomic indices start with 1, but {invalid_indice[0]} was found in `exclude_atomic_idx`")

        # check if atomic index is out of range ( still 1-based : the last valid index is the number of atoms )
        largest_idx = max(exclude_atomic_idx)
        if any(largest_idx > len(xyz_json["coordinate"]) for xyz_json in _xyz_format_jsons):
            raise ValueError(f"Atomic index {largest_idx} provided in `exclude_atomic_idx` is out of range in your molecule.")

        # reset atomic indice
        zero_based_indice = sorted({int(idx) - 1 for idx in exclude_atomic_idx})

        for xyz_json in _xyz_format_jsons:
            # the coordinate has to stay a 2D array, it is sliced with [:, 1:] later on
            atom_filtered_coordinate = np.delete(np.asarray(xyz_json["coordinate"]), zero_based_indice, axis=0)
            # overwrite filtered coordinate
            xyz_json["coordinate"] = atom_filtered_coordinate
            # adjust number of atoms : n_atoms
            xyz_json["n_atoms"] = len(atom_filtered_coordinate)

    # exclude elements
    if exclude_elements:
        for xyz_json in _xyz_format_jsons:
            coordinate = np.asarray(xyz_json["coordinate"])
            # boolean mask keeps the (n, 4) shape even when every atom is removed
            keep_mask = np.array(
                [atomic_number2element_symbol[atomic_number] not in exclude_elements for atomic_number in coordinate[:, 0]],
                dtype=bool
            )
            element_filtered_coordinate = coordinate[keep_mask]
            # overwrite filtered coordinate
            xyz_json["coordinate"] = element_filtered_coordinate
            # adjust number of atoms : n_atoms
            xyz_json["n_atoms"] = len(element_filtered_coordinate)

    # overall coordinate range of every structure, so the axes stay fixed during the animation
    all_coords = np.vstack([xyz_json["coordinate"][:, 1:] for xyz_json in _xyz_format_jsons])
    if len(all_coords) == 0:
        raise ValueError("No atom left to plot after applying `exclude_elements` / `exclude_atomic_idx`")

    # add a small margin around the structures
    padding = 0.1
    range_min = all_coords.min(axis=0) - padding
    range_max = all_coords.max(axis=0) + padding
    x_range, y_range, z_range = ([low, high] for low, high in zip(range_min, range_max))

    # number of molecules
    num_of_xyz = len(_xyz_format_jsons)
    # max number of atoms ( not 0 here : at least one atom is left )
    num_atom_xyz = max(len(xyz_json["coordinate"]) for xyz_json in _xyz_format_jsons)

    # set color palette, only needed when each molecule gets its own color
    palette = _get_colors(cmap, num_of_xyz) if colorby == "molecule" else None

    # analyze molecular connectivity
    xyz2molecular_graph(_xyz_format_jsons, covalent_radius_percent)

    # atom size and bond thickness are the same for every frame
    atom_size = np.maximum(np.log10(atom_scaler / num_atom_xyz) * 5, 2)
    bond_thickness = np.maximum(np.log10(bond_scaler / num_atom_xyz) * 2, 1)

    # one frame per molecule, every frame holds exactly two traces ( atoms, bonds )
    frames = []
    for mol_idx, xyz_json in enumerate(_xyz_format_jsons):
        coordinate = xyz_json["coordinate"]
        molecule_name = xyz_json["name"]

        if colorby == "molecule":
            atom_color = bond_color = palette[mol_idx]
        else:
            # one color per atom, bonds are gray
            atom_color = [atomic_number2hex[int(atomic_number)] for atomic_number in coordinate[:, 0]]
            bond_color = 'gray'

        # atoms
        atoms_trace = go.Scatter3d(
            x=coordinate[:, 1],
            y=coordinate[:, 2],
            z=coordinate[:, 3],
            mode='markers',
            opacity=alpha_atoms,
            marker=dict(size=atom_size, color=atom_color),
            name=f'{molecule_name} atoms'
        )

        # bonds, drawn as a single trace whose segments are separated by None
        xyz = coordinate[:, 1:]
        bond_x, bond_y, bond_z = [], [], []
        # the bond table is 1-based, internally idx start with 0
        for atom_1, atom_2 in xyz_json["bond_length_table"][:, :2].astype(int) - 1:
            bond_x.extend([xyz[atom_1][0], xyz[atom_2][0], None])
            bond_y.extend([xyz[atom_1][1], xyz[atom_2][1], None])
            bond_z.extend([xyz[atom_1][2], xyz[atom_2][2], None])

        bonds_trace = go.Scatter3d(
            x=bond_x, y=bond_y, z=bond_z,
            mode='lines',
            opacity=alpha_bonds,
            line=dict(width=bond_thickness, color=bond_color),
            name=f"{molecule_name} bonds"
        )

        # add frame
        frames.append(go.Frame(data=[atoms_trace, bonds_trace], name=str(mol_idx)))

    # the first frame is the initial state of the figure
    fig.add_traces(frames[0].data)

    # layout setting
    fig.frames = frames
    fig.update_layout(
        updatemenus=[{
            'buttons': [
                {
                    'args': [None, {'frame': {'duration': 50, 'redraw': True}, 'fromcurrent': True, 'transition': {'duration': 30, 'easing': 'quadratic-in-out'}}],
                    'label': 'Play',
                    'method': 'animate',
                },
                {
                    'args': [[None], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate', 'transition': {'duration': 0}}],
                    'label': 'Pause',
                    'method': 'animate',
                }
            ],
            'direction': 'left',
            'pad': {'r': 10, 't': 87},
            'showactive': False,
            'type': 'buttons',
            'x': 0.1,
            'xanchor': 'right',
            'y': 0,
            'yanchor': 'top'
        }],
        sliders=[{
            'active': 0,
            'yanchor': 'top',
            'xanchor': 'left',
            'currentvalue': {
                'font': {'size': 20},
                'prefix': 'Frame: ',
                'visible': True,
                'xanchor': 'right'
            },
            'transition': {'duration': 300, 'easing': 'cubic-in-out'},
            'pad': {'b': 10, 't': 50},
            'len': 0.9,
            'x': 0.1,
            'y': 0,
            'steps': [{'args': [[f.name], {'frame': {'duration': 300, 'redraw': True}, 'mode': 'immediate', 'transition': {'duration': 300}}], 'label': str(k), 'method': 'animate'} for k, f in enumerate(fig.frames)]
        }],
        scene=dict(
            xaxis=dict(range=x_range , visible=False),
            yaxis=dict(range=y_range , visible=False),
            zaxis=dict(range=z_range , visible=False),
            aspectmode='data',
            camera_projection=dict(type='orthographic'),
            aspectratio=dict(x=1, y=1, z=1),
        ),
        showlegend=True if legend else False
    )

    fig.show()

> TEST

In [123]:
# Animate the structures read from butadiene.xyz as they are:
# `option=None` makes `superimpose` hand the input back without aligning it.
jsons = open_xyz_files("butadiene.xyz")
superimposed_jsons = superimpose(jsons, option=None)
plot_animation(superimposed_jsons, colorby="molecule", legend=True)



In [124]:
# Superimpose the structures read from DA.xyz with the `aa` option,
# then animate them colored by molecule with the 'Plotly3' colormap.
jsons = open_xyz_files("DA.xyz")
superimposed_jsons = superimpose(jsons, option="aa")
plot_animation(superimposed_jsons, colorby='molecule', cmap='Plotly3', legend=True)



In [167]:
# Overlay the structures read from DA.xyz (superimposed with the `aa` option),
# one color per molecule.
jsons = open_xyz_files("DA.xyz")
superimposed_jsons = superimpose(jsons, option="aa")
plot_overlay(superimposed_jsons, colorby="molecule", legend=True)

In [168]:
# Same overlay as above, colored by element instead of by molecule.
# `plot_overlay` prints a warning here: `legend=True` is not applicable when colorby="atom".
jsons = open_xyz_files("DA.xyz")
superimposed_jsons = superimpose(jsons, option="aa")
plot_overlay(superimposed_jsons, colorby="atom", legend=True)



In [170]:
# Superimpose with the `a` option and overlay on a gray background with a custom color list.
# `option_param` holds three atom-index lists — presumably one per structure in DA.xyz;
# verify against `superimpose`.
jsons = open_xyz_files("DA.xyz")
superimposed_jsons = superimpose(jsons, option="a", option_param=[[1,2,3,4,5,6,7,8], [1,2,3,4,5,6,7,8], [6,7,8,1,2,3,4,5]])
plot_overlay(superimposed_jsons, colorby="molecule", cmap=["r", "g", "b"], legend=True, bgcolor='gray')



In [171]:
# Load the two SN2 geometries written above (sn2_1.xyz, sn2_2.xyz) under explicit names,
# superimpose them with the `aa` option and overlay them, one color per molecule.
jsons = open_xyz_files([
{
    "name": "molecule_1",
    "coordinate": "sn2_1.xyz"
    },
{
    "name": "molecule_2",
    "coordinate": "sn2_2.xyz"
    }
  ])
superimposed_jsons = superimpose(jsons, option="aa")
plot_overlay(superimposed_jsons, colorby="molecule", legend=True)

In [172]:
# Overlay the structures read from sn2.xyz colored by element on a dark gray background.
# Note: `cmap` has no effect when colorby="atom" (atoms use per-element colors, bonds are gray).
jsons = open_xyz_files("sn2.xyz")
superimposed_jsons = superimpose(jsons, option="aa")
plot_overlay(superimposed_jsons, colorby="atom", cmap="YlOrRd", legend=True,  bgcolor='darkgray')



In [173]:
# Align the butadiene structures on atoms 1, 2 and 3 only (`sa` option, 1-based indices)
# and overlay them on a white background.
jsons = open_xyz_files("butadiene.xyz")
superimposed_jsons = superimpose(jsons, option="sa", option_param=[1, 2, 3])
plot_overlay(superimposed_jsons, colorby="molecule", legend=True, bgcolor='white')

In [90]:
# Download the sn2.xyz example of the aimDIAS repository and save it as tmp.xyz.
!wget -q https://raw.githubusercontent.com/kangmg/aimDIAS/main/examples/sn2.xyz -O tmp.xyz

In [174]:
# Animate the structures of the downloaded tmp.xyz without aligning them (`option=None`).
jsons = open_xyz_files("tmp.xyz")
superimposed_jsons = superimpose(jsons, option=None)
plot_animation(superimposed_jsons, colorby="molecule", legend=True)



In [94]:
# Download the wittig.xyz example of the aimDIAS repository and save it as tmp2.xyz.
!wget -q https://raw.githubusercontent.com/kangmg/aimDIAS/main/examples/wittig.xyz -O tmp2.xyz

In [175]:
# Animate the structures of the downloaded tmp2.xyz without aligning them (`option=None`).
jsons = open_xyz_files("tmp2.xyz")
superimposed_jsons = superimpose(jsons, option=None)
plot_animation(superimposed_jsons, colorby="molecule", legend=True)



In [177]:
# Overlay the butadiene structures aligned on atoms 1, 2 and 3, colored with the
# 'Blues' colormap; `atom_scaler=2e1` is lower than the `plot_overlay` default (4e1).
jsons = open_xyz_files("butadiene.xyz")
superimposed_jsons = superimpose(jsons, option="sa", option_param=[1, 2, 3])
plot_overlay(superimposed_jsons, colorby="molecule", cmap="Blues", legend=True, atom_scaler=2e1)

In [181]:
# Overlay colored by element with larger scalers than the defaults
# (`atom_scaler=8e1`, `bond_scaler=1e9`).
# Note: `cmap` has no effect when colorby="atom".
jsons = open_xyz_files("butadiene.xyz")
superimposed_jsons = superimpose(jsons, option="sa", option_param=[1, 2, 3])
plot_overlay(superimposed_jsons, colorby="atom", cmap="Blues", legend=True, bgcolor="black", atom_scaler=8e1, bond_scaler=1e9)

