In [1]:
import pandas as pd
import numpy as np
import gemmi

In [2]:
def sg_symbol_to_number(x):
    """Convert a space-group symbol into its space-group number.

    The symbol is looked up with ``gemmi.find_spacegroup_by_name``. If the
    symbol as given is not recognised, alternative spellings are tried in
    this order and the first one that resolves wins:

    1. a colon inserted before the last character (``'R -3H'`` is retried as
       ``'R -3:H'``), i.e. the extended H-M notation with a setting suffix;
    2. a trailing ``'S'`` or ``'Z'`` removed;
    3. a trailing ``'HR'`` removed (only tried when rule 2 does not apply).

    Parameters
    ----------
    x : str
        Space-group symbol.

    Returns
    -------
    int or None
        The ``number`` attribute of the matching space group, or ``None``
        when no spelling is recognised. ``None`` is also returned for an
        empty string and for non-string input (e.g. a missing CSV cell),
        so one unresolvable entry cannot abort a whole column conversion.
    """
    # Missing values (None / NaN) cannot be sliced or looked up.
    if not isinstance(x, str):
        return None

    def _spellings():
        # Yield candidate spellings lazily, most specific first.
        yield x
        if not x:
            return
        yield x[:-1] + ':' + x[-1]
        if x[-1] in ('S', 'Z'):
            yield x[:-1]
        elif x[-2:] == 'HR':
            yield x[:-2]

    for name in _spellings():
        # An explicit None check replaces the former bare `except:` that
        # relied on `.number` failing on a missing lookup result.
        match = gemmi.find_spacegroup_by_name(name)
        if match is not None:
            return match.number
    return None

In [3]:
# input tables and destination of the processed table
icsd_file = '../data/icsd_manual.csv'
topo_file = '../data/topo_manual.csv'
outfile = '../data/icsd_manual_process.csv'

# load both tables and report their sizes
icsd_data = pd.read_csv(icsd_file)
print('number of icsd samples:', len(icsd_data))

topo_data = pd.read_csv(topo_file)
print('number of topological samples:', len(topo_data))

# replace the symbol column by a numeric 'spacegroup' column, then
# discard rows that have become exact duplicates
icsd_data = (
    icsd_data
    .assign(spacegroup=icsd_data['sg_symbol'].map(sg_symbol_to_number))
    .drop(columns=['sg_symbol'])
    .drop_duplicates()
    .reset_index(drop=True)
)
print('number of icsd samples (no duplicates):', len(icsd_data))

# outer join keeps rows that appear in only one of the two tables
data = pd.merge(icsd_data, topo_data, how='outer', on=['icsd', 'spacegroup', 'formula'])
print('number of samples:', len(data))

# rows without a band classification are labelled 0 (trivial)
data['band'] = data['band'].fillna(0).astype(int)

number of icsd samples: 68352
number of topological samples: 7365
number of icsd samples (no duplicates): 68344
number of samples: 71548


In [4]:
# write the merged table to disk; the row index is left out of the file
data.to_csv(path_or_buf=outfile, index=False)