In [47]:
import os, sys
import numpy as np
import networkx as nx
import itertools as it
import pickle
import argparse
from multiprocessing import Process

from latticeproteins.thermodynamics import LatticeThermodynamics
from latticeproteins.interactions import miyazawa_jernigan
from latticeproteins.conformations import ConformationList, Conformations
from latticeproteins.sequences import find_differences, _residues
from latticeproteins.evolve import monte_carlo_fixation_walk, fixation
from latticeproteins.sequences import random_sequence, hamming_distance

In [39]:
def fixation(fitness1, fitness2,*args, **kwargs):
    """ Simple Gillespie fixation probability between two organism with fitnesses 1 and 2.
    (With infinite population size!)

    .. math::
        p_{\\text{fixation}} = \\frac{1 - e^{-N \\frac{f_2-f_1}{f1}}}{1 - e^{-\\frac{f_2-f_1}{f1}}}
    """
    sij = (fitness2 - fitness1)/abs(fitness1)
    # Check if any nans exist if an array of fitnesses is given.
    fixation = 1 - np.exp(-sij)
    if type(fixation) == np.ndarray:
        fixation = np.nan_to_num(fixation)
        fixation[sij < 0] = 0
    return  fixation

class PredictedLattice(object):
    """Predict the stability and fraction folded of any sequence with respect
    to some wildtype lattice model. Calculates the independent effect of all mutations
    (and pairwise effects if `double` is True) and sums those effects to predict
    other sequences.
    """
    def __init__(self, wildtype, temp, confs, double=False, target=None):
        self.wildtype = wildtype
        self.temp = temp
        self.conformations = confs
        self.target = target
        self._lattice = LatticeThermodynamics(self.temp, self.conformations)
        self.double = double

        combos = []
        sites = list(range(self.conformations.length()))
        self.dG0 = self._lattice.stability(self.wildtype, target=self.target)


        #####  Build a dictionary of additive and pairwise mutational effects ####
        # Calculate first order coefs
        self.dGs = {}
        for i in sites:
            other_sites = sites[:]
            other_sites.remove(i)
            for aa in _residues:
                combos.append((i, aa))

        for c in combos:
            seq = list(self.wildtype[:])
            seq[c[0]] = c[1]
            # Calculate dG as dG_wt -
            self.dGs[c] = self._lattice.stability(seq, target=self.target) - self.dG0

        if self.double:
            # Calculate second order coefs
            combos = []
            sites = list(range(self.conformations.length()))
            for i in sites:
                other_sites = sites[:]
                other_sites.remove(i)
                for aa in _residues:
                    for j in other_sites:
                        for aa2 in _residues:
                            combos.append((i,aa,j,aa2))

            for c in combos:
                seq = list(self.wildtype[:])
                seq[c[0]] = c[1]
                seq[c[2]] = c[3]
                # Calculate dG2
                self.dGs[c] = self._lattice.stability(seq, target=self.target) - (self.dG0 + self.dGs[(c[0],c[1])]+ self.dGs[(c[2],c[3])])

    def stability(self, seq, target=None):
        """Calculate the stability of a given sequence using the Lattice predictor"""
        # Get additive coefs to build predictions
        if target != self.target:
            raise Exception("Target does not match wildtype target.")
        loci = find_differences(self.wildtype, seq)
        # Get all additive combinations for the sequence given
        add = [(pair[0], seq[pair[0]]) for pair in it.combinations(loci, 1)]
        if self.double:
            # Get all pairwise effects for the sequence given
            pairs = [(pair[0], seq[pair[0]], pair[1], seq[pair[1]]) for pair in it.combinations(loci, 2)]
            dgs = add + pairs
        else:
            dgs = add
        # Get the wildtype stability
        stability = float(self.dG0)
        # Sum the mutational effects
        for coef in dgs:
            stability += self.dGs[coef]
        return stability

    def fracfolded(self, seq, target=None):
        """Calculate the fraction folded for a given sequence"""
        return 1.0 / (1.0 + np.exp(self.stability(seq, target=target) / self.temp))


def enumerate_walks(seq, lattice, selected_trait="fracfolded", max_mutations=5, target=None):
    """Use Monte Carlo method to walk

    Parameters
    ----------
    seq : str
        seq
    lattice : LatticeThermodynamics object
        Lattice protein calculator
    selected_trait : str
        The trait to select.
    max_mutations : int (default = 15)
        Max number of mutations to make in the walk.
    target : str
        selected lattice target conformation. If None, the lattice will
        fold to the natural native conformation.

    Returns
    -------
    edges : list
        A list of all edges by the random walks out. Each element in the lists
        is a tuple. Tuple[0] is (seq_i, seq_j), Tuple[1] is {"weight" : <fixation probability>}.
    """
    length = len(seq)
    fitness_method = getattr(lattice, selected_trait)
    fitness0 = fitness_method(seq, target=target)
    finished = False

    moves = [seq]
    fitnesses = [fitness0]
    edges = []
    ends = []
    z = 0
    while len(moves) != 0 and z < max_mutations:

        new_moves = []
        new_fitnesses = []

        for i, m in enumerate(moves):
            sequence = list(m[:])
            fitness0 = fitnesses[i]
            # Construct grid of all stabilities of all amino acids at all sites
            AA_grid = np.array([_residues]*length)
            fits = np.zeros(AA_grid.shape, dtype=float)
            for (i,j), AA in np.ndenumerate(AA_grid):
                seq1 = sequence[:]
                seq1[i] = AA_grid[i,j]
                fits[i,j] = fitness_method(seq1, target=target)

            # Calculate fitness for all neighbors in sequence space
            fix = fixation(fitness0, fits)*(1./fits.size) # multplied by flat prior for all mutations
            site, aa_index = np.where(fix > 1.0e-20) # Select fixations that are 1 percent of max
            #print(aa_index)
            AA = AA_grid[site, aa_index]
            FF = fits[site, aa_index]
            prob = fix[site, aa_index]
            #print(fix)
            for i in range(len(site)):
                move = sequence[:]
                move[site[i]] = AA[i]
                if move != sequence and move not in ends:
                    ends.append("".join(move))
                    new_moves.append("".join(move))
                    new_fitnesses.append(FF[i])
                    edges.append((("".join(sequence[:]),"".join(move)), {"weight" : prob[i]}))
                    

        moves, indices = np.unique(new_moves, return_index=True)
        print(len(moves))
        fitnesses = np.array(new_fitnesses)[indices]
        z += 1
        if len(moves) > 1000:
            break

    return edges


# ----------------------------------------------------------
# Functions used in this code.
# ----------------------------------------------------------

def flux_out_of_node(G, node_i):
    """Determine """
    # Get flux coming from source
    total_flux_avail = G.node[node_i]["flux"]
    edges = {}
    # Normalize the transition probability from source
    norm = sum([G.edge[node_i][node_j]["weight"] for node_j in G.neighbors(node_i)])
    # Iterate over neighbors divvy up flux across neighbors
    for node_j in G.neighbors(node_i):
        if norm != 0:
            fixation = G.edge[node_i][node_j]["weight"]
            dflux = (fixation/norm) * total_flux_avail
            G.edge[node_i][node_j]["delta_flux"] = dflux
            G.node[node_j]["flux"] += dflux
        else:
            G.edge[node_i][node_j]["delta_flux"] = 0
    return edges

def flux_from_source(G, source):
    # Reset the flux of each node
    init_flux = dict([(node, 0) for node in G.nodes()])
    nx.set_node_attributes(G, "flux", init_flux)
    G.node[source]["flux"] = 1
    # Add flux to each node.
    levels = ring_levels(G, source)
    for l in levels:
        for node_i in levels[l]:
            edges = flux_out_of_node(G, node_i)
            for key, flux_to_add in edges.items():
                node_i, node_j = key
                G.node[node_j]["flux"] += flux_to_add
    return G


def ring_levels(G, root):
    levels = dict([(i,[]) for i in range(20)])
    levels[0].append(root)
    for node in G.nodes():
        neighbors = G.neighbors(node)
        for neigh in neighbors:
            key = hamming_distance(root, neigh)
            levels[key].append(neigh)
    for key, val in levels.items():
        z = sorted(list(set(val)))
        levels[key] = z
    return levels


def build_graphs(edges1, edges2, source):
    """Construct two different networks from a set of edges.
    """
    # -----------------------------------------------
    # build initial graphs
    # -----------------------------------------------
    edges0 = edges1
    # Build Graph
    G0 = nx.DiGraph()
    for key, weight in edges0:
        i,j = key[0], key[1]
        G0.add_edge(i,j, weight=weight["weight"])

    # Build Graph
    G2 = nx.DiGraph()
    for key, weight in edges2:
        i,j = key[0], key[1]
        G2.add_edge(i,j, weight=weight["weight"])

    seq = source
    # -----------------------------------------------
    # Calculate the flux at each node and edge
    # -----------------------------------------------
    G0 = flux_from_source(G0, seq)
    G2 = flux_from_source(G2, seq)

    # Get a dictionary of change in fluxes along each edge.
    edges_0 = {(i,j):  G0.edge[i][j]["delta_flux"] for i,j in G0.edges()}
    edges_2 = {(i,j):  G2.edge[i][j]["delta_flux"] for i,j in G2.edges()}

    # -----------------------------------------------
    # Calculate the change in delta_flux on each edge
    # -----------------------------------------------
    edges_diff = {}
    # See what edges we lost
    for key, val in edges_0.items():
        if key in edges_2:
            weight = edges_2[key] - edges_0[key]
            if weight < 0:
                # This edge gained flux
                color = "r"
            else:
                # This edge lost flux
                color = "b"
            edges_diff[key] = dict(color=color, weight=abs(weight))
        else:
            # This edge was lost in our predictions
            edges_diff[key] = dict(weight=val, color="r")

    # See what edges we gained.
    for key, val in edges_2.items():
        if key in edges_0:
            pass
        else:
            # This edge was gained in our predictions
            edges_diff[key] = dict(weight=val, color="b")

    # -----------------------------------------------
    # Calculate the change in flux at each node
    # -----------------------------------------------
    nodes_0 = {i: G0.node[i]["flux"] for i in G0.nodes()}
    nodes_2 = {i: G2.node[i]["flux"] for i in G2.nodes()}

    node_diff = {}
    for key, val in nodes_0.items():
        if key in nodes_2:
            diff = nodes_2[key] - val
            if diff > 0:
                color = "b"
            else:
                color = "r"
            node_diff[key] = dict(color=color, outer=nodes_2[key], inner=val)
        else:
            node_diff[key] = dict(color="r", outer=nodes_0[key], inner=0)

    for key, val in nodes_2.items():
        if key in nodes_0:
            pass
        else:
            node_diff[key] = dict(color="b", outer=val, inner=0)

    # -----------------------------------------------
    # Construct a network of differences
    # -----------------------------------------------
    Gdiff = nx.DiGraph()
    for key, val in edges_diff.items():
        Gdiff.add_edge(key[0],key[1],**val)

    for key, val in node_diff.items():
        Gdiff.node[key].update(**val)

    return G0, G2, Gdiff

In [43]:
seq = "WIKSKCMFCSWH"
temp = 1
length = len(seq)

confs = Conformations(length, "database")
cs = confs.k_lowest_confs(seq, temp, 3)
db1 = cs[0]
db2 = cs[1]
db3 = cs[2]

target = cs[0]
db = [db1, db2, db3]

#c = [db[0], "U"*length]
#confs = ConformationList(length, c)

In [44]:
lattice = LatticeThermodynamics(temp, confs)
edges0 = enumerate_walks(seq, lattice, target=target, max_mutations=7)

plattice1 = PredictedLattice(seq, temp, confs, double=False, target=target)
edges1 = enumerate_walks(seq, plattice1, target=target, max_mutations=7)

11
48
118
184
236
261
270
11
51
123
173
162
120
78


In [45]:
plattice1 = PredictedLattice(seq, temp, confs, double=False, target=target)
edges1 = enumerate_walks(seq, plattice1, target=target, max_mutations=7)

11
51
123
173
162
120
78


In [48]:
Gactual, Gpredict, Gdiff = build_graphs(edges0, edges1, seq)

In [63]:


def plot_networks(G1, G2, Gdiff, source):
    """"""
    # options
    node_scale = 600
    edge_scale = 40
    node_color = "k"
    

    # Remove nodes that have small flux
    nodes_to_remove = []
    for node in G1.nodes():
        if G1.node[node]["flux"] < 0.01:
            nodes_to_remove.append(node)

    nodes_to_remove2 = []
    for node in G2.nodes():
        if G2.node[node]["flux"] < 0.01:
            nodes_to_remove2.append(node)
        
    G1.remove_nodes_from(nodes_to_remove)
    G2.remove_nodes_from(nodes_to_remove2)
    Gdiff = modified_Gdiff(G1, G2)

    from matplotlib.gridspec import GridSpec
    from matplotlib.patches import Circle
    
    def draw_circles(ax):
        """Draw circles add increasing hamming distances for each network."""
        for i in range(0,8):
            circle = Circle((0, 0), i, facecolor='none',
                    edgecolor="k", linewidth=.5, alpha=0.5, linestyle="--")
            ax.add_patch(circle)

    
    # Initialize a figure
    fig = plt.figure(figsize=(20,8))
    
    # Initialize a gridspec
    gs = GridSpec(1, 3)
       
    seq = source
    # Calculate the positions for all nodes on rings
    pos = ring_position(Gdiff, seq)

    # -------------------------------------------------
    # Draw the first network
    # -------------------------------------------------
    
    ax1 = plt.subplot(gs[0, 0])
    
    # Set the widths of the edges to the delta flux attribute of each edge.
    edge_widths = np.array([G1.edge[i][j]["delta_flux"] for i,j in G1.edges()])
    edge_widths = edge_widths * edge_scale
    #edge_widths = np.ma.log10(edge_widths).filled(0) * edge_scale
    nx.draw_networkx_edges(G1, pos=pos, ax=ax1,
        width=edge_widths,                
        arrows=False,
        edge_color="gray",
        alpha=0.5
    )
    
    # Set the node sizes to the amount of flux passing through each node.
    node_size = [G1.node[i]["flux"] * node_scale for i in G1.nodes()]
    #node_size = np.ma.log10(node_size).filled(0)

    nx.draw_networkx_nodes(G1, pos=pos, ax=ax1,
        node_size=node_size,                
        linewidths=None,
        node_color=node_color
    )
    
    bad_nodes1 = [node for node in Gdiff.nodes() if node not in G1.nodes()]
    bad_nodes1_size = [G2.node[node]["flux"] * node_scale for node in Gdiff.nodes() if node not in G1.nodes()]

    nx.draw_networkx_nodes(Gdiff, pos=pos, ax=ax1,
        nodelist = bad_nodes1,
        node_shape = "x",
        node_size = bad_nodes1_size,
        linewidths = None,
        node_color = "m"
    )
    
    # Draw circles
    draw_circles(ax1)
    ax1.axis("equal")
    ax1.axis("off")
    
    # -------------------------------------------------
    # Draw the second network
    # -------------------------------------------------
    
    ax2 = plt.subplot(gs[0, 1])

    
    # Set the widths of the edges to the delta flux attribute of each edge.
    edge_widths = np.array([G2.edge[i][j]["delta_flux"] for i,j in G2.edges()])
    edge_widths = edge_widths * edge_scale
    #edge_widths = np.ma.log10(edge_widths).filled(0) * edge_scale
    nx.draw_networkx_edges(G2, pos=pos, ax=ax2,
        width=edge_widths,                
        arrows=False,
        edge_color="gray",
        alpha=0.5
    )
    
    # Set the node sizes to the amount of flux passing through each node.
    node_size = [G2.node[i]["flux"] * node_scale for i in G2.nodes()]
    #node_size = np.ma.log10(node_size).filled(0)

    nx.draw_networkx_nodes(G2, pos=pos, ax=ax2,
        node_size=node_size,                
        linewidths=None,
        node_color=node_color
    )

    bad_nodes2 = [node for node in Gdiff.nodes() if node not in G2.nodes()]
    bad_nodes2_size = [G1.node[node]["flux"] * node_scale for node in Gdiff.nodes() if node not in G2.nodes()]

    nx.draw_networkx_nodes(Gdiff, pos=pos, ax=ax2,
        nodelist = bad_nodes2,
        node_shape = "x",
        node_size = bad_nodes2_size,
        linewidths = None,
        node_color = "m"
    )
        
    # Draw circles
    draw_circles(ax2) 
    ax2.axis("equal")
    ax2.axis("off")
    
    # -------------------------------------------------
    # Draw difference network
    # -------------------------------------------------
        
    ax3 = plt.subplot(gs[0, 2])

    
    # Set the widths of the edges to the delta flux attribute of each edge.
    edge_widths = np.array([Gdiff.edge[i][j]["weight"] for i,j in Gdiff.edges()])
    edge_widths = edge_widths * edge_scale
    #edge_widths = np.ma.log10(edge_widths).filled(0) * edge_scale 
    edge_color = [Gdiff.edge[i][j]["color"] for i,j in Gdiff.edges()]

    nx.draw_networkx_edges(Gdiff, pos=pos, ax=ax3,
        width=edge_widths,                
        arrows=False,
        edge_color=edge_color,
        alpha=0.5
    )
    
    # Set the node sizes to the amount of flux passing through each node.
    node_size = [Gdiff.node[i]["outer"] * node_scale for i in Gdiff.nodes()]
    node_color = [Gdiff.node[i]["color"]  for i in Gdiff.nodes()]
    nx.draw_networkx_nodes(Gdiff, pos=pos, ax=ax3,
        node_size=node_size,                
        linewidths=None,
        node_color=node_color
    )

    # Set the node sizes to the amount of flux passing through each node.
    node_size = [Gdiff.node[i]["inner"] * node_scale for i in Gdiff.nodes()]
    nx.draw_networkx_nodes(Gdiff, pos=pos, ax=ax3,
        node_size=node_size,                
        linewidths=None,
        node_color="w"
    )    
    
    # Draw circles
    draw_circles(ax3) 

    ax3.axis("equal")
    ax3.axis("off")
    return fig

In [64]:
plot_networks(Gactual, Gpredict, seq)

TypeError: plot_networks() missing 1 required positional argument: 'source'