# Objective

To create a pipeline for extrapolated structure factor refinement that modifies phases by solvent flattening, with the aim of better approximating the excited-state structure factors.




# Workflow

1) ${F}_{\Delta} = {F}_{timepoint} - {F}_{ground}$

2) $\Delta$ Electron Density Map = Fourier transform using ${F}_{\Delta}$ and ${\phi}_{ground}$

3) Solvent flattening of difference map
  
      difference map voxels farther than a defined cutoff distance from the excited model are set to a new solvent value (0?)
  
4) Fourier transform the flattened difference map to get $\phi_\Delta$

5) $\overrightarrow{F}_{excited} = \overrightarrow{F}_{ground} + n*\overrightarrow{F}_{\Delta}$
  
      n is the extrapolation factor and will be sampled across a range of values
  
6) phenix.refine ground state model (round 0) or previous round's model (round 1..n) against $\overrightarrow{F}_{excited}$

7) new model from step 6 is used to generate $\phi_{excited}$

8) $\phi_{new} = (1-n)*\phi_{ground} + n*\phi_{excited}$

9) Return to step 2, replacing $\phi_{ground}$ with $\phi_{new}$ and iterate until convergence
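
Step 3 can be sketched on a toy grid as follows. This is a minimal illustration under stated assumptions, not the pipeline's implementation: the function name `flatten_far_from_model`, its parameters, the orthorhombic cell, and the use of fractional atom coordinates are all hypothetical, and symmetry mates of the model are ignored.

```python
import numpy as np

def flatten_far_from_model(density, atom_frac_coords, cell_lengths, cutoff, solvent_value=0.0):
    """Set voxels farther than `cutoff` (in Angstrom) from every atom to `solvent_value`.

    Assumption: orthorhombic cell (all angles 90 degrees), so fractional offsets
    convert to distances by scaling with the cell edge lengths.
    """
    density = np.asarray(density, dtype=float)
    cell_lengths = np.asarray(cell_lengths, dtype=float)
    # Fractional coordinates of every voxel, shape (nx, ny, nz, 3)
    axes = [np.arange(n) / n for n in density.shape]
    voxels = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
    keep = np.zeros(density.shape, dtype=bool)
    for atom in np.atleast_2d(atom_frac_coords):
        delta = voxels - atom
        delta -= np.round(delta)  # minimum-image convention for the periodic cell
        dist = np.linalg.norm(delta * cell_lengths, axis=-1)
        keep |= dist <= cutoff
    return np.where(keep, density, solvent_value)

# Toy example: 10 Angstrom cubic cell on a 10x10x10 grid (1 Angstrom voxels),
# uniform density of 1, a single atom at the origin
toy_map = np.ones((10, 10, 10))
flattened = flatten_far_from_model(toy_map, [[0.0, 0.0, 0.0]], (10.0, 10.0, 10.0), cutoff=1.5)
```

Only voxels within the cutoff of the atom (counting periodic images across the cell boundary) keep their value; everything else is set to the solvent value. A real implementation would also need to handle non-orthogonal cells and confirm that the model and map share the same origin and orientation.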



# Required Functions

1. **subtract_structure_factor_magnitudes()**
    1. input: two sets of Fourier magnitudes, weighting value
    2. output: one set of Fourier magnitudes
    3. notes: HKLs must match up; any other scaling considerations?
    
2. **structure_factors_to_electron_density()**
    1. input: one set of Fourier magnitudes, one set of phase angles, unit cell dimensions
    2. output: array containing electron density values for one unit cell
    3. notes: review reciprocalspaceship documentation to confirm details
    
3. **electron_density_solvent_flattening()**
    1. input: one array containing density values, one PDB mapped to same unit cell as array
    2. output: one array containing adjusted density values
    3. notes: review details for mapping of model and density to ensure orientation is correct etc
    
4. **electron_density_to_structure_factor_vectors()**
    1. input: one MTZ containing electron density coefficients, one PDB
    2. output: complex structure factors, separated into magnitudes and phases
    3. notes: ...details to be determined
    
5. **structure_factor_vectors_extrapolation()**
    1. input: two sets of complex structure factors, extrapolation value, weighting value
    2. output: one set of complex structure factors
    3. notes: ...
    
6. **add_structure_factor_phases()**
    1. input: two sets of phase angles, two weighting values corresponding to occupancy for both excited and ground states
    2. output: one set of phase angles
    3. notes: double check trig and details for combining phases
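
The arithmetic behind functions 5 and 6 can be sketched with plain NumPy. The names `extrapolate_structure_factors` and `combine_phases_circular` and their signatures are hypothetical, and phases are assumed to be in degrees. Note that the second function is an alternative to the linear formula in workflow step 8, not a transcription of it: it averages unit vectors instead of raw angles, which may be worth considering because raw angles misbehave near the ±180° wrap.

```python
import numpy as np

def extrapolate_structure_factors(f_ground, phi_ground, f_delta, phi_delta, n):
    """Workflow step 5 as complex vector addition (phases in degrees)."""
    ground = f_ground * np.exp(1j * np.deg2rad(phi_ground))
    delta = f_delta * np.exp(1j * np.deg2rad(phi_delta))
    excited = ground + n * delta
    return np.abs(excited), np.rad2deg(np.angle(excited))

def combine_phases_circular(phi_ground, phi_excited, w_ground, w_excited):
    """Weighted circular mean of two phase sets (degrees), via unit vectors."""
    total = (w_ground * np.exp(1j * np.deg2rad(phi_ground))
             + w_excited * np.exp(1j * np.deg2rad(phi_excited)))
    return np.rad2deg(np.angle(total))

# One reflection: |F_ground| = 100 at 0 deg, |F_delta| = 10 at 90 deg, n = 4
f_ex, phi_ex = extrapolate_structure_factors(np.array([100.0]), np.array([0.0]),
                                             np.array([10.0]), np.array([90.0]), n=4)

# Two phases on either side of the +/-180 deg wrap, equally weighted
phi_lin = 0.5 * 170.0 + 0.5 * (-170.0)                       # linear combination of angles
phi_circ = combine_phases_circular(170.0, -170.0, 0.5, 0.5)  # circular mean of the same angles
```

In the wrap example the linear combination lands at 0°, on the opposite side of the circle from both inputs, while the circular mean stays at the wrap point between them. The circular mean is itself ill-defined when the two weighted vectors nearly cancel, so that case would need separate handling.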
    
    

# Import Required Modules and Document Versions

In [3]:
import reciprocalspaceship as rs
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import matplotlib ### for grabbing version; comment out for efficiency
import numpy as np
from scipy.stats import norm, shapiro
import scipy ### for grabbing version; comment out for efficiency
import scipy.spatial.distance as spsd
import pandas as pd
import gemmi
import glob
import subprocess

%matplotlib inline


In [4]:
print("reciprocalspaceship version: {}".format(rs.__version__))
print("seaborn version: {}".format(sns.__version__))
print("matplotlib version: {}".format(matplotlib.__version__))
print("numpy version: {}".format(np.__version__))
print("scipy version: {}".format(scipy.__version__))
print("pandas version: {}".format(pd.__version__))
print("gemmi version: {}".format(gemmi.__version__))

reciprocalspaceship version: 0.9.5
seaborn version: 0.11.1
matplotlib version: 3.3.2
numpy version: 1.19.2
scipy version: 1.5.2
pandas version: 1.1.5
gemmi version: 0.4.5


# Library of OLD Functions (Reference Only)

In [None]:
'''
def compute_weights(df, sigdf, alpha=0):
    """
    Compute weights for each structure factor based on deltaF and its uncertainty
    """
    w = (1 + (sigdf**2 / (sigdf**2).mean()) + alpha*(df**2 / (df**2).mean()))
    return w**-1

def difference_map(ground_state_mtz, excited_state_mtz, alpha, phase_pdb_mtz, output_string_mtz):
    ground = rs.read_mtz(ground_state_mtz)
    excited = rs.read_mtz(excited_state_mtz)
    diff = ground.merge(excited, left_index=True, right_index=True, suffixes=("_ground", "_excited"))
    diff["DF"] = (diff["FOBS_excited"] - diff["FOBS_ground"]).astype("SFAmplitude")
    diff["SigDF"] = np.sqrt(diff["SIGFOBS_excited"]**2 + diff["SIGFOBS_ground"]**2).astype("Stddev")
    diff["W"] = compute_weights(diff["DF"], diff["SigDF"], alpha)
    diff["WDF"] = (diff["W"]*diff["DF"]).astype("F")
    ref = rs.read_mtz(phase_pdb_mtz)
    diff["PHIFMODEL"] = ref.loc[diff.index, "PHIFMODEL"]
    diff.write_mtz(output_string_mtz)
    return

def extrapolated_map(ground_state_mtz, excited_state_mtz, r_free_mtz, alpha, N, phase_pdb_mtz, output_string_mtz):
    ground = rs.read_mtz(ground_state_mtz)
    excited = rs.read_mtz(excited_state_mtz)
    r_free_mtz = rs.read_mtz(r_free_mtz)
    extrap = ground.merge(excited, left_index=True, right_index=True, suffixes=("_ground", "_excited"))
    extrap["DF"] = (extrap["FOBS_excited"] - extrap["FOBS_ground"]).astype("SFAmplitude")
    extrap["SigDF"] = np.sqrt(extrap["SIGFOBS_excited"]**2 + extrap["SIGFOBS_ground"]**2).astype("Stddev")
    extrap["W"] = compute_weights(extrap["DF"], extrap["SigDF"], alpha)
    extrap["WDF"] = (extrap["W"]*extrap["DF"]).astype("F")
    extrap["WSigDF"] = np.sqrt(extrap["W"]**2 * extrap["SigDF"]**2).astype("Stddev")
    extrap["ExWDF"] = (extrap["FOBS_ground"] + N*extrap["WDF"]).astype("F")
    extrap["ExWSigDF"] = np.sqrt(extrap["SIGFOBS_ground"]**2 + N**2 * extrap["WSigDF"]**2).astype("Stddev")
    ref = rs.read_mtz(phase_pdb_mtz)
    extrap["PHIFMODEL"] = ref.loc[extrap.index, "PHIFMODEL"]
    extrap["FreeR_flag"] = r_free_mtz["FreeR_flag"].reindex(extrap.index, fill_value=0)
    extrap.write_mtz(output_string_mtz)
    return


def find_sites(realmap, threshold, cell):
    """
    Find local peaks in map.

    Parameters
    ----------
    realmap : np.ndarray
        3D array with voxelized electron density
    threshold : float
        Minimum voxelized density to consider for peaks
    cell : gemmi.UnitCell
        Cell parameters for crystal

    Returns
    -------
    pd.DataFrame
        DataFrame with coordinates and peak height for each site
    """
    from skimage import feature
    peaks = feature.peak_local_max(realmap, threshold_abs=threshold, exclude_border=False)
    data = []
    for p in peaks:
        pf = p/np.array(realmap.shape)
        pf_2 = pf + np.array([-1, 0, 0])
        pf_3 = pf + np.array([0, 0, 1])
        pf_4 = pf + np.array([-1, 0, -1])
        pos = cell.orthogonalize(gemmi.Fractional(*pf))
        pos_2 = cell.orthogonalize(gemmi.Fractional(*pf_2))
        pos_3 = cell.orthogonalize(gemmi.Fractional(*pf_3))
        pos_4 = cell.orthogonalize(gemmi.Fractional(*pf_4))
        d  = {"x": pos.x, "y": pos.y, "z": pos.z}
        d_2  = {"x": pos_2.x, "y": pos_2.y, "z": pos_2.z}
        d_3  = {"x": pos_3.x, "y": pos_3.y, "z": pos_3.z}
        d_4  = {"x": pos_4.x, "y": pos_4.y, "z": pos_4.z}
        d["height"] = realmap[p[0], p[1], p[2]]
        d_2["height"] = realmap[p[0], p[1], p[2]]
        d_3["height"] = realmap[p[0], p[1], p[2]]
        d_4["height"] = realmap[p[0], p[1], p[2]]
        data.append(d)
        data.append(d_2)
        data.append(d_3)
        data.append(d_4)
    return pd.DataFrame(data)


def write_pdb(sites, cell):
    s = gemmi.Structure()
    s.cell = cell
    m = gemmi.Model("substructure")
    c = gemmi.Chain("A")
    for i, site in sites.iterrows():
        a = gemmi.Atom()
        a.element = gemmi.Element("S")
        a.name = "S"
        a.pos = gemmi.Position(*site[["x", "y", "z"]].values)
        r = gemmi.Residue()
        r.name = "S"
        r.seqid = gemmi.SeqId(str(i+1))
        r.add_atom(a)
        c.append_residues([r])
    m.add_chain(c)
    s.add_model(m)
    s.assign_label_seq_id(force=True)
    s.write_minimal_pdb("sites_quad_pos.pdb")
    return

def IADDAT(input_PDB_filename, input_MTZ_filename, threshold_dict, distance_cutoff, average_out=False):
    name_init = input_MTZ_filename.replace("./diff_maps_refinedPHI/F_internal_","")
    name = name_init.replace("_extrapmap_W0-05_N1.mtz","")
    elements = [gemmi.Element('C'), gemmi.Element('N'), gemmi.Element('O'), gemmi.Element('S')]
    
    input_PDB = gemmi.read_structure(input_PDB_filename)
    input_MTZ = rs.read_mtz(input_MTZ_filename)
    input_MTZ.compute_dHKL(inplace=True)
    grid_sampling = 0.25
    a_sampling = int(input_MTZ.cell.a/(input_MTZ.dHKL.min()*grid_sampling))
    b_sampling = int(input_MTZ.cell.b/(input_MTZ.dHKL.min()*grid_sampling))
    c_sampling = int(input_MTZ.cell.c/(input_MTZ.dHKL.min()*grid_sampling))
    input_MTZ["sf"] = input_MTZ.to_structurefactor("WDF", "PHIFMODEL")
    reciprocalgrid = input_MTZ.to_reciprocalgrid("sf", gridsize=(a_sampling, b_sampling, c_sampling))
    realmap = np.real(np.fft.fftn(reciprocalgrid))
    sites = find_sites(realmap, realmap.std()*threshold_dict[name], input_MTZ.cell)
    sites_neg = find_sites(realmap*-1, realmap.std()*threshold_dict[name], input_MTZ.cell)
    sites_coords = np.array([sites.x, sites.y, sites.z])
    sites_neg_coords = np.array([sites_neg.x, sites_neg.y, sites_neg.z])
    
    IADDAT = []
    for residue in input_PDB[0]["A"]:
        atom_coords = []
        for atom in residue:
            if atom.element in elements:
                pos  = {"x": atom.pos.x, "y": atom.pos.y, "z": atom.pos.z}
                atom_coords.append(pos)
            else:
                pass
        atom_coords_df = pd.DataFrame(atom_coords)
        
        if atom_coords_df.empty:
            print("non AA processed")
            pass
        else:
            distances_pos = spsd.cdist(atom_coords_df, sites_coords.transpose())
            filter_pos = []
            for row in distances_pos.transpose():
                filt_pos = np.any(row <= distance_cutoff)
                filter_pos.append(filt_pos)
            filt_pos_df = pd.DataFrame(filter_pos)

            sites["filtered"] = filt_pos_df
            integrate_pos = sites.loc[sites['filtered'] == True]
            int_pos_value = integrate_pos.height.sum()

            distances_neg = spsd.cdist(atom_coords_df, sites_neg_coords.transpose())
            filter_neg = []
            for row in distances_neg.transpose():
                filt_neg = np.any(row <= distance_cutoff)
                filter_neg.append(filt_neg)
            filt_neg_df = pd.DataFrame(filter_neg)

            sites_neg["filtered"] = filt_neg_df
            integrate_neg = sites_neg.loc[sites_neg['filtered'] == True]
            int_neg_value = integrate_neg.height.sum()

            int_value = int_pos_value + int_neg_value
            if average_out:
                new_int_value = int_value / len(atom_coords)
                IADDAT.append(new_int_value)
            else:
                IADDAT.append(int_value)
    
    return IADDAT
'''

# Library of NEW Functions

In [None]:
def compute_weights(df, sigdf, alpha=0):
    """
    Compute weights for each structure factor based on DeltaF and its uncertainty
    """
    w = (1 + (sigdf**2 / (sigdf**2).mean()) + alpha*(df**2 / (df**2).mean()))
    return w**-1

def subtract_structure_factor_magnitudes(ground_state_mtz, excited_state_mtz, alpha, phase_pdb_mtz, output_string_mtz):
    """
    Computes DeltaF, sigDeltaF, and weighted DeltaF. Adds PHI values from molecular model for conversion to real space.
    """
    ground = rs.read_mtz(ground_state_mtz)
    excited = rs.read_mtz(excited_state_mtz)
    diff = ground.merge(excited, left_index=True, right_index=True, suffixes=("_ground", "_excited"))
    diff["DF"] = (diff["FOBS_excited"] - diff["FOBS_ground"]).astype("SFAmplitude")
    diff["SigDF"] = np.sqrt(diff["SIGFOBS_excited"]**2 + diff["SIGFOBS_ground"]**2).astype("Stddev")
    diff["W"] = compute_weights(diff["DF"], diff["SigDF"], alpha)
    diff["WDF"] = (diff["W"]*diff["DF"]).astype("F")
    ref = rs.read_mtz(phase_pdb_mtz)
    diff["PHIFMODEL"] = ref.loc[diff.index, "PHIFMODEL"]
    diff.write_mtz(output_string_mtz)
    return

    
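def structure_factors_to_electron_density(input_mtz_filename, f_label="WDF", phi_label="PHIFMODEL", grid_sampling=0.25):
    """
    Computes a real-space map for one unit cell from amplitudes and phases (same FFT steps as IADDAT below; still to be checked against reciprocalspaceship docs).
    """
    mtz = rs.read_mtz(input_mtz_filename)
    mtz.compute_dHKL(inplace=True)
    spacing = mtz.dHKL.min()*grid_sampling
    gridsize = tuple(int(edge/spacing) for edge in (mtz.cell.a, mtz.cell.b, mtz.cell.c))
    mtz["sf"] = mtz.to_structurefactor(f_label, phi_label)
    return np.real(np.fft.fftn(mtz.to_reciprocalgrid("sf", gridsize=gridsize)))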

def extrapolated_map(ground_state_mtz, excited_state_mtz, r_free_mtz, alpha, N, phase_pdb_mtz, output_string_mtz):
    ground = rs.read_mtz(ground_state_mtz)
    excited = rs.read_mtz(excited_state_mtz)
    r_free_mtz = rs.read_mtz(r_free_mtz)
    extrap = ground.merge(excited, left_index=True, right_index=True, suffixes=("_ground", "_excited"))
    extrap["DF"] = (extrap["FOBS_excited"] - extrap["FOBS_ground"]).astype("SFAmplitude")
    extrap["SigDF"] = np.sqrt(extrap["SIGFOBS_excited"]**2 + extrap["SIGFOBS_ground"]**2).astype("Stddev")
    extrap["W"] = compute_weights(extrap["DF"], extrap["SigDF"], alpha)
    extrap["WDF"] = (extrap["W"]*extrap["DF"]).astype("F")
    extrap["WSigDF"] = np.sqrt(extrap["W"]**2 * extrap["SigDF"]**2).astype("Stddev")
    extrap["ExWDF"] = (extrap["FOBS_ground"] + N*extrap["WDF"]).astype("F")
    extrap["ExWSigDF"] = np.sqrt(extrap["SIGFOBS_ground"]**2 + N**2 * extrap["WSigDF"]**2).astype("Stddev")
    ref = rs.read_mtz(phase_pdb_mtz)
    extrap["PHIFMODEL"] = ref.loc[extrap.index, "PHIFMODEL"]
    extrap["FreeR_flag"] = r_free_mtz["FreeR_flag"].reindex(extrap.index, fill_value=0)
    extrap.write_mtz(output_string_mtz)
    return


def find_sites(realmap, threshold, cell):
    """
    Find local peaks in map.

    Parameters
    ----------
    realmap : np.ndarray
        3D array with voxelized electron density
    threshold : float
        Minimum voxelized density to consider for peaks
    cell : gemmi.UnitCell
        Cell parameters for crystal

    Returns
    -------
    pd.DataFrame
        DataFrame with coordinates and peak height for each site
    """
    from skimage import feature
    peaks = feature.peak_local_max(realmap, threshold_abs=threshold, exclude_border=False)
    data = []
    for p in peaks:
        pf = p/np.array(realmap.shape)
        pf_2 = pf + np.array([-1, 0, 0])
        pf_3 = pf + np.array([0, 0, 1])
        pf_4 = pf + np.array([-1, 0, -1])
        pos = cell.orthogonalize(gemmi.Fractional(*pf))
        pos_2 = cell.orthogonalize(gemmi.Fractional(*pf_2))
        pos_3 = cell.orthogonalize(gemmi.Fractional(*pf_3))
        pos_4 = cell.orthogonalize(gemmi.Fractional(*pf_4))
        d  = {"x": pos.x, "y": pos.y, "z": pos.z}
        d_2  = {"x": pos_2.x, "y": pos_2.y, "z": pos_2.z}
        d_3  = {"x": pos_3.x, "y": pos_3.y, "z": pos_3.z}
        d_4  = {"x": pos_4.x, "y": pos_4.y, "z": pos_4.z}
        d["height"] = realmap[p[0], p[1], p[2]]
        d_2["height"] = realmap[p[0], p[1], p[2]]
        d_3["height"] = realmap[p[0], p[1], p[2]]
        d_4["height"] = realmap[p[0], p[1], p[2]]
        data.append(d)
        data.append(d_2)
        data.append(d_3)
        data.append(d_4)
    return pd.DataFrame(data)


def write_pdb(sites, cell):
    s = gemmi.Structure()
    s.cell = cell
    m = gemmi.Model("substructure")
    c = gemmi.Chain("A")
    for i, site in sites.iterrows():
        a = gemmi.Atom()
        a.element = gemmi.Element("S")
        a.name = "S"
        a.pos = gemmi.Position(*site[["x", "y", "z"]].values)
        r = gemmi.Residue()
        r.name = "S"
        r.seqid = gemmi.SeqId(str(i+1))
        r.add_atom(a)
        c.append_residues([r])
    m.add_chain(c)
    s.add_model(m)
    s.assign_label_seq_id(force=True)
    s.write_minimal_pdb("sites_quad_pos.pdb")
    return

def IADDAT(input_PDB_filename, input_MTZ_filename, threshold_dict, distance_cutoff, average_out=False):
    name_init = input_MTZ_filename.replace("./diff_maps_refinedPHI/F_internal_","")
    name = name_init.replace("_extrapmap_W0-05_N1.mtz","")
    elements = [gemmi.Element('C'), gemmi.Element('N'), gemmi.Element('O'), gemmi.Element('S')]
    
    input_PDB = gemmi.read_structure(input_PDB_filename)
    input_MTZ = rs.read_mtz(input_MTZ_filename)
    input_MTZ.compute_dHKL(inplace=True)
    grid_sampling = 0.25
    a_sampling = int(input_MTZ.cell.a/(input_MTZ.dHKL.min()*grid_sampling))
    b_sampling = int(input_MTZ.cell.b/(input_MTZ.dHKL.min()*grid_sampling))
    c_sampling = int(input_MTZ.cell.c/(input_MTZ.dHKL.min()*grid_sampling))
    input_MTZ["sf"] = input_MTZ.to_structurefactor("WDF", "PHIFMODEL")
    reciprocalgrid = input_MTZ.to_reciprocalgrid("sf", gridsize=(a_sampling, b_sampling, c_sampling))
    realmap = np.real(np.fft.fftn(reciprocalgrid))
    sites = find_sites(realmap, realmap.std()*threshold_dict[name], input_MTZ.cell)
    sites_neg = find_sites(realmap*-1, realmap.std()*threshold_dict[name], input_MTZ.cell)
    sites_coords = np.array([sites.x, sites.y, sites.z])
    sites_neg_coords = np.array([sites_neg.x, sites_neg.y, sites_neg.z])
    
    IADDAT = []
    for residue in input_PDB[0]["A"]:
        atom_coords = []
        for atom in residue:
            if atom.element in elements:
                pos  = {"x": atom.pos.x, "y": atom.pos.y, "z": atom.pos.z}
                atom_coords.append(pos)
            else:
                pass
        atom_coords_df = pd.DataFrame(atom_coords)
        
        if atom_coords_df.empty:
            print("non AA processed")
            pass
        else:
            distances_pos = spsd.cdist(atom_coords_df, sites_coords.transpose())
            filter_pos = []
            for row in distances_pos.transpose():
                filt_pos = np.any(row <= distance_cutoff)
                filter_pos.append(filt_pos)
            filt_pos_df = pd.DataFrame(filter_pos)

            sites["filtered"] = filt_pos_df
            integrate_pos = sites.loc[sites['filtered'] == True]
            int_pos_value = integrate_pos.height.sum()

            distances_neg = spsd.cdist(atom_coords_df, sites_neg_coords.transpose())
            filter_neg = []
            for row in distances_neg.transpose():
                filt_neg = np.any(row <= distance_cutoff)
                filter_neg.append(filt_neg)
            filt_neg_df = pd.DataFrame(filter_neg)

            sites_neg["filtered"] = filt_neg_df
            integrate_neg = sites_neg.loc[sites_neg['filtered'] == True]
            int_neg_value = integrate_neg.height.sum()

            int_value = int_pos_value + int_neg_value
            if average_out:
                new_int_value = int_value / len(atom_coords)
                IADDAT.append(new_int_value)
            else:
                IADDAT.append(int_value)
    
    return IADDAT
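
The behavior of `compute_weights()` can be checked on toy values before running it on real data. The block below copies the function from the library above so that it runs on its own; the three difference amplitudes and the `alpha=0.05` value are made up for illustration only.

```python
import numpy as np

def compute_weights(df, sigdf, alpha=0):
    # Copied from the function library above so this example is self-contained
    w = (1 + (sigdf**2 / (sigdf**2).mean()) + alpha*(df**2 / (df**2).mean()))
    return w**-1

# Toy data: three difference amplitudes with identical uncertainties
df = np.array([1.0, -2.0, 10.0])
sigdf = np.array([1.0, 1.0, 1.0])

w_alpha0 = compute_weights(df, sigdf, alpha=0)      # uncertainty term only
w_alpha = compute_weights(df, sigdf, alpha=0.05)    # also down-weights large |DF|
```

With equal uncertainties and `alpha=0`, every reflection gets the same weight of 0.5, since the uncertainty term equals 1 for each. A nonzero `alpha` lowers the weight of the largest difference the most, which is the outlier-damping role of the second term.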