In [1]:
import numpy as np
import pandas as pd
from pymatgen.core.structure import Structure, Molecule
from pymatgen.io.cif import CifParser
from collections import Counter

In [2]:
num_metal_dict = {1: 4, 2: 2, 3: 2, 4: 3, 5: 4, 6: 4, 7: 6, 8: 3, 9: 2, 10: 2, 11: 1, 12: 4}
specie_metal_dict = {1: 'Zn', 2: 'Cu', 3: 'Zn', 4: 'Cr', 5: 'Cd', 6: 'Mn', 7: 'Zr', 8: 'Al', 9: 'V', 10: 'Ba', 11: 'In', 12: 'Ni'}
links_ratio = {'acs': 3, 'bcu': 4, 'etb': 2, 'fof': 1, 'nbo': 2, 'pcu': 3, 'pts': 1, 'rht': 0.666667, 'sra': 2, 'tbo': 1.333333, 'the': 2.666667}
metal_links = {'acs': 6, 'bcu': 8, 'etb': 8, 'fof': 4, 'nbo': 4, 'pcu': 6, 'pts': 4, 'rht': 4, 'sra': 4, 'tbo': 4, 'the': 8}
organic_links = {'acs': 2, 'bcu': 2, 'etb': 4, 'fof': 4, 'nbo': 2, 'pcu': 2, 'pts': 4, 'rht': 6, 'sra': 2, 'tbo': 3, 'the': 3}


class DeriveLinkers:
    def __init__(self, MOFname, file, cif_path):
        self.MOFname = MOFname
        df = pd.read_csv(file)
        self.cif_path = cif_path
        self.mof_info = df[df['MOFname'] == MOFname]
        self.num_m = self.mof_info['metal_linker'].item()
        self.num_l1 = self.mof_info['organic_linker1'].item()
        self.num_l2 = self.mof_info['organic_linker2'].item()    
        self.fg_info = pd.read_csv('Functional_Groups.csv')
            
    
    def _get_extra_elem(self, Z_FG):
        extra_elem = []
        for fg in Z_FG:
            for key, val in fg.items():
                if val != 0 and key not in ['Name', 'Total', 'H', 'C', 'O']:
                    extra_elem.append(key)
        return extra_elem

    def get_fg_info(self, fg_list=None):
        FG = []
        if fg_list == None:
            fg_list = self.mof_info['functional_groups'].item()
        if isinstance(fg_list, str):
            fg_list = fg_list.split('-')
            for item in fg_list:
                FG.append(list(self.fg_info[self.fg_info['Name']==item].to_dict('index').values())[0]) 
        return FG
    
    def structure_info(self):
        try:
            struct = Structure.from_file(self.cif_path + self.MOFname + '.cif')
        except:
            struct = CifParser(self.cif_path + self.MOFname + '.cif', occupancy_tolerance=2.).get_structures()
            struct = struct[0]
        metal_linker = Molecule.from_file('linkers/metal_linker_' + str(self.num_m) + '.xyz')
        organic_linker1 = Molecule.from_file('linkers/organic_linker_' + str(self.num_l1) + '.xyz')
        organic_linker2 = Molecule.from_file('linkers/organic_linker_' + str(self.num_l2) + '.xyz')
        return struct, metal_linker, organic_linker1, organic_linker2
    
    def counter(self):
        struct, metal_linker, organic_linker1, organic_linker2 = self.structure_info()
        Z_struct = dict(Counter([site.specie.symbol for site in struct.sites]))
        Z_metal_linker = dict(Counter([site.specie.symbol for site in metal_linker.sites]))
        Z_organic_linker1 = dict(Counter([site.specie.symbol for site in organic_linker1.sites]))
        Z_organic_linker2 = dict(Counter([site.specie.symbol for site in organic_linker2.sites]))
        
        Z_FG = self.get_fg_info()
        if self.mof_info['topology'].item() == 'etb':
            Z_FG += self.get_fg_info('OEt')
        Z_FG = {fg['Name']: {key: val for key, val in fg.items() if key not in ['Name', 'Total'] and val != 0} for fg in Z_FG}
        
        Z_all = {'N_a': Z_struct, 'N_M': Z_metal_linker, 'N_L1': Z_organic_linker1, 'N_L2': Z_organic_linker2} 
        Z_all.update(Z_FG)
        return Z_all

    def number_of_linkers(self):
        struct, metal_linker, organic_linker1, organic_linker2 = self.structure_info()
        num_metal = sum([site.specie.is_metal for site in struct.sites])
        M = num_metal / num_metal_dict[self.mof_info['metal_linker'].item()]

        Z_struct = Counter([site.specie.symbol for site in struct.sites])
        elems_for_solve = list(Z_struct.keys())
        Z_FG = self.get_fg_info()
        
        if self.mof_info['topology'].item() == 'etb':
            Z_struct['O'] -= 6 * Z_struct['P'] / 8
            Z_etb = self.get_fg_info('OEt')[0]
            for elem in Z_etb.keys():
                if elem in ['Name', 'Total']:
                    continue
                Z_struct[elem] -= Z_etb[elem] * M * 8
        
        if len(Z_FG) == 0:
            Z_FG = [{key: 0 for key in Z_struct.keys()}]
            Z_FG[0].update({'Name': 'No FG'})
        else:
            for i in range(len(Z_FG)):                    
                Z_FG[i]['H'] -= 1
        
        for fg in Z_FG:
            if self.num_l1 == 9 and fg['Name'] == 'Ph':
                self.num_l1 = 1
                _, _, organic_linker1, _ = self.structure_info()
            elif self.num_l2 == 9 and fg['Name'] == 'Ph':
                self.num_l2 = 1
                _, _, _, organic_linker2 = self.structure_info()
        
        Z_metal_linker = Counter([site.specie.symbol for site in metal_linker.sites])
        Z_organic_linker1 = Counter([site.specie.symbol for site in organic_linker1.sites])
        Z_organic_linker2 = Counter([site.specie.symbol for site in organic_linker2.sites])  
        
        if self.num_l1 == 50:
            Z_organic_linker1['H'] -= 2
        if self.num_l2 == 50:
            Z_organic_linker2['H'] -= 2
            
        if self.mof_info['topology'].item() == 'rht':
            Z_organic_linker1['C'] += 18
            Z_organic_linker1['H'] += 9
            Z_organic_linker2['C'] += 18
            Z_organic_linker2['H'] += 9
        
        N_a = np.array([Z_struct[elem] for elem in elems_for_solve])
        const = links_ratio[self.mof_info['topology'].item()]
        RHS = np.array([[Z_metal_linker[elem] + const * Z_organic_linker2[elem],\
                         Z_organic_linker1[elem] - Z_organic_linker2[elem], \
                         *[fg[elem] for fg in Z_FG]] for elem in elems_for_solve])
        
        if self.mof_info['topology'].item() == 'pcu' and self.mof_info['metal_linker'].item() != 1:
            N_a[elems_for_solve.index('C')] += 2 * M
            N_a[elems_for_solve.index('N')] -= 2 * M
        
        solution = np.linalg.lstsq(RHS, N_a, rcond=None)[0]        
        L2 = const * solution[0] - solution[1]
        solution_dict = {'M': solution[0], 'L1': solution[1], 'L2': L2}
        
        for i in range(len(Z_FG)):
            #solution_dict[Z_FG[i]['Name']] = solution[i+2]
            solution_dict['FG'+str(i+1)] = solution[i+2]
        
        if len(Z_FG) == 1:
            solution_dict['FG2'] = 0
        
        #print(f'Elems: \t{elems_for_solve}')
        #print(f'N_a: \t{N_a}')
        #print(f'RHS: \n{RHS}')
        #print(f'\nsolution: {solution}')
        return solution_dict

    
class BagofLinkers(DeriveLinkers):
    def __init__(self, MOFname, file, cif_path):
        super().__init__(MOFname, file, cif_path)    
        self.fg_list = list(self.fg_info['Name'])
    
    def _one_hot(self, size, idx, num):
        one_hot = np.zeros(size)
        one_hot[idx] = num
        return one_hot
    
    def vectorize(self):
        linkers = self.number_of_linkers()
        m_vec = self._one_hot(12, self.num_m - 1, linkers['M'])
        l_vec = self._one_hot(59, self.num_l1 - 1, linkers['L1'])
        l_vec += self._one_hot(59, self.num_l2 - 1, linkers['L2'])
        
        fg_vec = np.zeros(len(self.fg_list))
        fg = self.mof_info['functional_groups'].item()  
        key = list(linkers.keys())
        if isinstance(fg, str):
            fg = fg.split('-')
            for name in fg:
                fg_vec += self._one_hot(len(self.fg_list), self.fg_list.index(name) - 1, linkers[name])
        return m_vec, l_vec, fg_vec
    
    def embedding(self):
        return np.concatenate(self.vectorize())
        

In [6]:
num = 10815
MOFname = 'mof_unit_' + str(num)
file = 'train.csv'
linkers = DeriveLinkers(MOFname, file, 'mof_cif_train/')
linkers.counter()

{'N_a': {'Cu': 24, 'H': 171, 'C': 408, 'N': 21, 'O': 159},
 'N_M': {'O': 8, 'C': 4, 'Cu': 2},
 'N_L1': {'C': 24, 'H': 15},
 'N_L2': {'C': 21, 'N': 3, 'H': 12},
 'NO2': {'O': 2, 'N': 1},
 'HCO': {'C': 1, 'O': 1, 'H': 1}}

In [7]:
linkers.number_of_linkers()

{'M': 11.999994181820911,
 'L1': 7.000008606056134,
 'L2': 0.9999915151558678,
 'FG1': 18.000026909076396,
 'FG2': 26.999992727279903}