In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
import shutil
import path

In [2]:
import numpy as np
import networkx as nx
import pickle
from latticeproteins.sequences import hamming_distance

# ----------------------------------------------------------
# Main Code
# ----------------------------------------------------------

def testing(dataset, d):
    """Load the actual and predicted walk datasets for index ``d`` and build graphs.

    Three pickle files are read from the ``example-prediction`` directory
    (relative to the current working directory):
    ``walks-actual-<d>.pickle``, ``walks-predicted-<d>.pickle`` and
    ``walks-predicted2-<d>.pickle``.

    Parameters
    ----------
    dataset :
        Not used by this function.
    d :
        Dataset index; converted with ``str`` to build the file names.

    Returns
    -------
    tuple
        ``(G0, G1, G2, seq)`` where the graphs are what ``build_graphs``
        returns for the "edges" entries of the actual, predicted and
        predicted2 files (in that order) and ``seq`` is the "seq" entry of
        the actual file.
    """
    # NOTE(security): pickle.load can execute arbitrary code. Only use this
    # on files that come from a trusted source.
    def _load(kind):
        # Read one of the three pickles belonging to dataset index ``d``.
        with open("example-prediction/walks-" + kind + "-" + str(d) + ".pickle", "rb") as f:
            return pickle.load(f)

    # Actual dataset: only the edges and the source sequence are needed here.
    actual = _load("actual")
    edges0 = actual["edges"]
    seq = actual["seq"]

    # The two predicted datasets only contribute their edges.
    edges1 = _load("predicted")["edges"]
    edges2 = _load("predicted2")["edges"]

    # Construct as networks
    G0, G1, G2 = build_graphs(edges0, edges1, edges2, seq)
    return G0, G1, G2, seq

    
# ----------------------------------------------------------
# Functions used in this code.
# ----------------------------------------------------------

def retrieve(dataset, d):
    """Download the three walk pickles for dataset index ``d`` if any is missing locally.

    Files are fetched over SFTP from
    ``epistasis-ensembles-scripts/<dataset>/results`` on ``aciss.uoregon.edu``
    into ``<cwd>/example-prediction``. Nothing is downloaded when all three
    files (actual, predicted, predicted2) already exist locally.

    Credentials are never stored in the notebook: the user name is read from
    the ``ACISS_USERNAME`` environment variable (default ``"zsailer"``) and
    the password from ``ACISS_PASSWORD``; if the latter is unset the user is
    prompted with ``getpass``.

    Parameters
    ----------
    dataset :
        Name of the remote dataset directory.
    d :
        Dataset index used in the file names.
    """
    import os
    import getpass

    remote_path = "epistasis-ensembles-scripts/{:}/results".format(dataset)
    local_path = os.path.join(os.getcwd(), "example-prediction")
    filenames = [
        "walks-actual-{:}.pickle".format(d),
        "walks-predicted-{:}.pickle".format(d),
        "walks-predicted2-{:}.pickle".format(d),
    ]

    # All three files are needed downstream, so a missing "actual" file must
    # trigger a download as well.
    if all(os.path.isfile(os.path.join(local_path, name)) for name in filenames):
        return

    # Imported lazily so the notebook works offline when the files are cached.
    import paramiko

    username = os.environ.get("ACISS_USERNAME", "zsailer")
    password = os.environ.get("ACISS_PASSWORD")
    if password is None:
        password = getpass.getpass("Password for {:}@aciss.uoregon.edu: ".format(username))

    if not os.path.isdir(local_path):
        os.makedirs(local_path)

    ssh = paramiko.SSHClient()
    # NOTE(security): AutoAddPolicy trusts unknown host keys without checking.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh.connect('aciss.uoregon.edu', username=username, password=password)
        sftp = ssh.open_sftp()
        try:
            for name in filenames:
                sftp.get(remote_path + "/" + name, os.path.join(local_path, name))
        finally:
            sftp.close()
    finally:
        # Always release the connection, even when a transfer fails.
        ssh.close()
    print("downloaded!")
    
# ----------------------------------------------------------
# Functions used in this code.
# ----------------------------------------------------------

def flux_out_of_node(G, node_i):
    """Distribute the flux held by ``node_i`` over its outgoing edges.

    Each outgoing edge receives a share of the node's "flux" proportional to
    its "weight"; the share is stored on the edge as "delta_flux" and added to
    the "flux" of the neighbor. When the outgoing weights do not sum to a
    positive number every outgoing edge gets a "delta_flux" of 0.

    Returns an (always empty) dictionary.
    """
    # Flux available at the node before anything is handed out.
    available = G.node[node_i]["flux"]

    targets = list(G.neighbors(node_i))
    out_weights = [G.edge[node_i][target]["weight"] for target in targets]
    total_weight = sum(out_weights)

    if total_weight > 0:
        # Hand out flux proportionally to the normalized edge weights.
        for target, weight in zip(targets, out_weights):
            share = (weight / total_weight) * available
            G.edge[node_i][target]["delta_flux"] = share
            G.node[target]["flux"] += share
    else:
        # Nothing can leave this node.
        for target in targets:
            G.edge[node_i][target]["delta_flux"] = 0

    return {}

def flux_from_source(G, source):
    """Propagate one unit of flux from ``source`` through ``G``.

    Every node's "flux" attribute is reset to 0, the source is set to 1, and
    the nodes are then visited level by level (as grouped by ``ring_levels``)
    with ``flux_out_of_node`` called on each. The graph is modified in place
    and also returned.
    """
    # Start from a clean slate: zero flux everywhere except at the source.
    nx.set_node_attributes(G, "flux", {node: 0 for node in G.nodes()})
    G.node[source]["flux"] = 1

    # Walk outwards, one level at a time.
    levels = ring_levels(G, source)
    for distance in levels:
        for origin in levels[distance]:
            extra = flux_out_of_node(G, origin)
            # Apply any additional per-edge flux reported by the helper.
            for (_, destination), amount in extra.items():
                G.node[destination]["flux"] += amount
    return G

def ring_levels(G, root):
    """Group nodes by their ``hamming_distance`` from ``root``.

    ``root`` itself is always placed in level 0. Every node that appears as a
    neighbor of some node in ``G`` is placed in the level given by
    ``hamming_distance(root, node)``.

    Returns
    -------
    dict
        Maps each level ``0, 1, ..., n-1`` to a sorted list of unique nodes
        (possibly empty). ``n`` is at least 20, and grows when a node lies
        further than 19 from the root, so the keys are always contiguous.
    """
    # Historical minimum number of levels; callers index levels 0..len-1.
    min_levels = 20

    buckets = {0: set([root])}
    for node in G.nodes():
        for neigh in G.neighbors(node):
            distance = hamming_distance(root, neigh)
            # Create levels on demand instead of failing past a fixed size.
            buckets.setdefault(distance, set()).add(neigh)

    n_levels = max(min_levels, max(buckets) + 1)
    return dict((level, sorted(buckets.get(level, ()))) for level in range(n_levels))

def build_graphs(edges0, edges1, edges2, source):
    """Construct three weighted directed networks and compute their fluxes.

    Each ``edges*`` argument is iterated as ``(key, attrs)`` pairs, where
    ``key[0]`` and ``key[1]`` are the edge endpoints and ``attrs["weight"]`` is
    the edge weight. After construction, ``flux_from_source`` is run on each
    graph with ``source`` as the starting node.

    Returns the three graphs in the same order as the edge arguments.
    """
    def _weighted_digraph(edge_records):
        # One directed edge per record, carrying only the "weight" attribute.
        graph = nx.DiGraph()
        for endpoints, attrs in edge_records:
            graph.add_edge(endpoints[0], endpoints[1], weight=attrs["weight"])
        return graph

    # -----------------------------------------------
    # build initial graphs
    # -----------------------------------------------
    graphs = [_weighted_digraph(records) for records in (edges0, edges1, edges2)]

    # -----------------------------------------------
    # Calculate the flux at each node and edge
    # -----------------------------------------------
    G0, G1, G2 = [flux_from_source(graph, source) for graph in graphs]

    return G0, G1, G2

In [200]:
def radial(r, theta):
    """Convert polar coordinates ``(r, theta)`` into a Cartesian ``(x, y)`` tuple."""
    x = r * np.cos(theta)
    y = r * np.sin(theta)
    return (x, y)

def ring_position(G, root):
    """Assign every node a position on a ring whose radius is its level.

    Levels come from ``ring_levels(G, root)``. Within a level the nodes are
    shuffled, spread evenly around the circle, and the whole ring is rotated
    by a random offset. Returns a dict mapping node -> ``(x, y)``.
    """
    import random
    levels = ring_levels(G, root)
    pos = {}
    for ring in range(len(levels)):
        members = list(levels[ring])
        # Random order within the ring and a random rotation of the ring.
        random.shuffle(members)
        offset = random.random()
        if not members:
            continue
        # Angular spacing between consecutive nodes on this ring.
        step = 2*np.pi / len(members)
        for slot, node in enumerate(members):
            pos[node] = radial(ring, slot*step + offset)
    return pos


def modified_Gdiff(G0, G2):
    """Build a directed graph describing how flux differs between ``G0`` and ``G2``.

    Edges
        For an edge present in both graphs, "weight" is the absolute change in
        "delta_flux" and "color" is ``"r"`` when ``G2`` carries less flux than
        ``G0`` and ``"b"`` otherwise. An edge only in ``G0`` is ``"r"`` with its
        ``G0`` flux as weight; an edge only in ``G2`` is ``"b"`` with its ``G2``
        flux as weight.

    Nodes
        For a node in both graphs, "outer" is its ``G2`` flux, "inner" its
        ``G0`` flux and "color" is ``"b"`` when the flux increased, else
        ``"r"``. A node only in ``G0`` is ``"r"`` with ``outer`` = ``G0`` flux
        and ``inner`` = 0; a node only in ``G2`` is ``"b"`` with ``outer`` =
        ``G2`` flux and ``inner`` = 0.

    Nodes that have no edges in either graph are still included in the result.
    """
    # Get a dictionary of change in fluxes along each edge.
    edges_0 = dict([((i, j), G0.edge[i][j]["delta_flux"]) for i, j in G0.edges()])
    edges_2 = dict([((i, j), G2.edge[i][j]["delta_flux"]) for i, j in G2.edges()])

    # -----------------------------------------------
    # Calculate the change in delta_flux on each edge
    # -----------------------------------------------
    edges_diff = {}
    for key, val in edges_0.items():
        if key in edges_2:
            change = edges_2[key] - val
            # Negative change: the edge lost flux in G2 (red).
            # Otherwise the edge gained (or kept) its flux (blue).
            color = "r" if change < 0 else "b"
            edges_diff[key] = dict(color=color, weight=abs(change))
        else:
            # This edge was lost in our predictions
            edges_diff[key] = dict(weight=val, color="r")

    for key, val in edges_2.items():
        if key not in edges_0:
            # This edge was gained in our predictions
            edges_diff[key] = dict(weight=val, color="b")

    # -----------------------------------------------
    # Calculate the change in flux at each node
    # -----------------------------------------------
    nodes_0 = dict([(i, G0.node[i]["flux"]) for i in G0.nodes()])
    nodes_2 = dict([(i, G2.node[i]["flux"]) for i in G2.nodes()])

    node_diff = {}
    for key, val in nodes_0.items():
        if key in nodes_2:
            color = "b" if nodes_2[key] - val > 0 else "r"
            node_diff[key] = dict(color=color, outer=nodes_2[key], inner=val)
        else:
            node_diff[key] = dict(color="r", outer=val, inner=0)

    for key, val in nodes_2.items():
        if key not in nodes_0:
            node_diff[key] = dict(color="b", outer=val, inner=0)

    # -----------------------------------------------
    # Construct a network of differences
    # -----------------------------------------------
    Gdiff = nx.DiGraph()
    for (i, j), attrs in edges_diff.items():
        Gdiff.add_edge(i, j, **attrs)

    for node, attrs in node_diff.items():
        # add_node updates the attributes of an existing node and creates the
        # node when it has no edges; indexing Gdiff.node[...] directly raised
        # KeyError for such isolated nodes.
        Gdiff.add_node(node, **attrs)

    return Gdiff

def top_flux_path(G, seq):
    """Greedily follow the outgoing edge with the largest "delta_flux" from ``seq``.

    Starting at ``seq``, repeatedly step to the neighbor reached through the
    edge with the highest "delta_flux" (the first one on ties). The walk stops
    when the current node has no neighbors, when the best edge carries zero
    flux, or when the best edge leads back to a node already on the path
    (which would otherwise loop forever).

    Returns
    -------
    list
        The visited nodes in order, always starting with ``seq``.
    """
    current = seq
    path = [seq]
    visited = set(path)
    while True:
        neighbors = list(G.neighbors(current))
        if len(neighbors) == 0:
            return path
        dflux = [G.edge[current][node]["delta_flux"] for node in neighbors]
        index = np.argmax(dflux)
        if dflux[index] == 0:
            return path
        best = neighbors[index]
        if best in visited:
            # Cycle detected: stop instead of walking around it indefinitely.
            return path
        current = best
        visited.add(current)
        path.append(current)

def plot_networks(G0, G1, G2, source, pos=None):
    """Draw a three-panel figure comparing the flux networks ``G0`` and ``G2``.

    Panels (left to right): ``G0``, ``G2``, and the difference network
    ``modified_Gdiff(G0, G2)``. In the first two panels the path returned by
    ``top_flux_path`` is drawn in orange, the remaining edges in gray, and
    nodes that exist only in the other network are marked with a magenta "x".

    ``G1`` is not drawn; it only contributes (through
    ``modified_Gdiff(G0, G1)``) to the set of nodes that get a layout position
    when ``pos`` is not supplied.

    Side effect: nodes whose "flux" is below the threshold (0.01) are removed
    from ``G0``, ``G1`` and ``G2`` in place.

    Parameters
    ----------
    G0, G1, G2 :
        Graphs carrying a "flux" attribute on nodes and a "delta_flux"
        attribute on edges.
    source :
        Node from which the top-flux paths and the ring layout start.
    pos : dict, optional
        Node -> (x, y) positions. Computed with ``ring_position`` when None.

    Returns
    -------
    tuple
        ``(fig, pos)``: the matplotlib figure and the positions used.
    """
    # NOTE(review): `G.node[...]` / `G.edge[...]` look like the networkx 1.x
    # attribute API -- confirm against the installed networkx version.
    # options
    node_scale = 300
    edge_scale = 20
    node_color = "k"
    threshold = 0.01
    
    # Collect the G0 nodes whose flux is below the threshold
    nodes_to_remove0 = []
    for node in G0.nodes():
        if G0.node[node]["flux"] < threshold:
            nodes_to_remove0.append(node)    
    
    # Collect the G1 nodes whose flux is below the threshold
    nodes_to_remove1 = []
    for node in G1.nodes():
        if G1.node[node]["flux"] < threshold:
            nodes_to_remove1.append(node)

    # Collect the G2 nodes whose flux is below the threshold
    nodes_to_remove2 = []
    for node in G2.nodes():
        if G2.node[node]["flux"] < threshold:
            nodes_to_remove2.append(node)
      
    # Prune the low-flux nodes; this mutates the caller's graphs.
    G0.remove_nodes_from(nodes_to_remove0)
    G1.remove_nodes_from(nodes_to_remove1)
    G2.remove_nodes_from(nodes_to_remove2)
   
    # Highest-flux path through each of the two drawn networks
    path0 = top_flux_path(G0, source)
    path2 = top_flux_path(G2, source)
    
    # Gdiff is what gets drawn in the third panel; Gdiff_ and the composed
    # Gdiff__ are only used below to compute layout positions.
    Gdiff = modified_Gdiff(G0, G2)
    Gdiff_ = modified_Gdiff(G0, G1)
    Gdiff__ = nx.compose(Gdiff, Gdiff_)
    
    from matplotlib.gridspec import GridSpec
    from matplotlib.patches import Circle
    
    def draw_circles(ax):
        """Add dashed guide circles of radius 0 through 7, centered on the origin, to ``ax``."""
        for i in range(0,8):
            circle = Circle((0, 0), i, facecolor='none',
                    edgecolor="k", linewidth=.5, alpha=0.5, linestyle="--")
            ax.add_patch(circle)

    
    # Initialize a figure
    fig = plt.figure(figsize=(20,8))
    
    # Initialize a gridspec (one row, three panels)
    gs = GridSpec(1, 3)
       
    seq = source
    # Calculate the positions for all nodes on rings
    if pos is None:
        pos = ring_position(Gdiff__, seq)

    # -------------------------------------------------
    # Draw the first network
    # -------------------------------------------------
    
    ax1 = plt.subplot(gs[0, 0])
    
    # Draw the top-flux path in orange, edge width proportional to delta_flux.
    # top_flux_path as defined in this file always returns a list, so this
    # guard is always true here.
    elist = []
    if path0 is not None:
        ewidths = [G0.edge[path0[i-1]][path0[i]]["delta_flux"] for i in range(1,len(path0))]
        elist = [(path0[i-1],path0[i]) for i in range(1,len(path0))]
        ewidths = np.array(ewidths) * edge_scale
        nx.draw_networkx_edges(G0, pos=pos, ax=ax1,
           edgelist=elist,
           width=ewidths,
           arrows=False,
           edge_color="orange",
           alpha=.6
        )    
    
    # Set the widths of the edges to the delta flux attribute of each edge.
    
    # Path edges were already drawn above, so leave them out of the gray set.
    edgelist = list(G0.edges())
    for edge in elist:
        edgelist.remove(edge)
        
    edge_widths = np.array([G0.edge[i][j]["delta_flux"] for i,j in edgelist])
    edge_widths = edge_widths * edge_scale
    nx.draw_networkx_edges(G0, pos=pos, ax=ax1,
        edgelist=edgelist,
        width=edge_widths,                
        arrows=False,
        edge_color="gray",
        alpha=0.6
    )
    
    # Set the node sizes to the amount of flux passing through each node.
    node_size = [G0.node[i]["flux"] * node_scale for i in G0.nodes()]
    nx.draw_networkx_nodes(G0, pos=pos, ax=ax1,
        node_size=node_size,                
        linewidths=None,
        node_color=node_color
    )
    
    # Nodes of the difference network that are absent from G0: mark them with
    # a magenta "x" sized by their flux in G2.
    bad_nodes1 = [node for node in Gdiff.nodes() if node not in G0.nodes()]
    bad_nodes1_size = [G2.node[node]["flux"] * node_scale for node in Gdiff.nodes() if node not in G0.nodes()]

    nx.draw_networkx_nodes(Gdiff, pos=pos, ax=ax1,
        nodelist = bad_nodes1,
        node_shape = "x",
        node_size = bad_nodes1_size,
        linewidths = None,
        node_color = "m"
    )

    
    # Draw circles
    draw_circles(ax1)
    ax1.axis("equal")
    ax1.axis("off")
    
    # -------------------------------------------------
    # Draw the second network
    # -------------------------------------------------
    
    ax2 = plt.subplot(gs[0, 1])
    
    # Draw the top-flux path of G2 in orange (same scheme as the first panel)
    elist = []
    if path2 is not None:
        ewidths = [G2.edge[path2[i-1]][path2[i]]["delta_flux"] for i in range(1,len(path2))]
        elist = [(path2[i-1],path2[i]) for i in range(1,len(path2))]
        ewidths = np.array(ewidths) * edge_scale
        nx.draw_networkx_edges(G2, pos=pos, ax=ax2,
           edgelist=elist,
           width=ewidths,
           arrows=False,
           edge_color="orange",
           alpha=0.6
        )    
    
    # Set the widths of the edges to the delta flux attribute of each edge.
    # Path edges were already drawn above, so leave them out of the gray set.
    edgelist = list(G2.edges())
    for edge in elist:
        edgelist.remove(edge)
    edge_widths = np.array([G2.edge[i][j]["delta_flux"] for i,j in edgelist])
    edge_widths = edge_widths * edge_scale
    nx.draw_networkx_edges(G2, pos=pos, ax=ax2,
        edgelist=edgelist,
        width=edge_widths,                
        arrows=False,
        edge_color="gray",
        alpha=0.6
    )    
    
    # Set the node sizes to the amount of flux passing through each node.
    node_size = [G2.node[i]["flux"] * node_scale for i in G2.nodes()]

    nx.draw_networkx_nodes(G2, pos=pos, ax=ax2,
        node_size=node_size,                
        linewidths=None,
        node_color=node_color
    )

    # Nodes of the difference network that are absent from G2: mark them with
    # a magenta "x" sized by their flux in G0.
    bad_nodes2 = [node for node in Gdiff.nodes() if node not in G2.nodes()]
    bad_nodes2_size = [G0.node[node]["flux"] * node_scale for node in Gdiff.nodes() if node not in G2.nodes()]

    nx.draw_networkx_nodes(Gdiff, pos=pos, ax=ax2,
        nodelist = bad_nodes2,
        node_shape = "x",
        node_size = bad_nodes2_size,
        linewidths = None,
        node_color = "m"
    )
        
    # Draw circles
    draw_circles(ax2) 
    ax2.axis("equal")
    ax2.axis("off")
    
    # -------------------------------------------------
    # Draw difference network
    # -------------------------------------------------
        
    ax3 = plt.subplot(gs[0, 2])

    
    # Edge widths come from the "weight" attribute of the difference network
    # and edge colors from its "color" attribute.
    edge_widths = np.array([Gdiff.edge[i][j]["weight"] for i,j in Gdiff.edges()])
    edge_widths = edge_widths * edge_scale
    edge_color = [Gdiff.edge[i][j]["color"] for i,j in Gdiff.edges()]

    nx.draw_networkx_edges(Gdiff, pos=pos, ax=ax3,
        width=edge_widths,                
        arrows=False,
        edge_color=edge_color,
        alpha=0.5
    )
    
    # Colored disc per node, sized by the node's "outer" attribute.
    # Note: node_color is rebound here from the single color "k" to a list.
    node_size = [Gdiff.node[i]["outer"] * node_scale for i in Gdiff.nodes()]
    node_color = [Gdiff.node[i]["color"]  for i in Gdiff.nodes()]
    nx.draw_networkx_nodes(Gdiff, pos=pos, ax=ax3,
        node_size=node_size,                
        linewidths=None,
        node_color=node_color
    )

    # White disc per node, sized by the node's "inner" attribute and drawn on
    # top of the colored disc.
    node_size = [Gdiff.node[i]["inner"] * node_scale for i in Gdiff.nodes()]
    nx.draw_networkx_nodes(Gdiff, pos=pos, ax=ax3,
        node_size=node_size,                
        linewidths=None,
        node_color="w"
    )    
    
    # Draw circles
    draw_circles(ax3) 

    ax3.axis("equal")
    ax3.axis("off")
    return fig, pos

def fixation(fitness1, fitness2,*args, **kwargs):
    """Fixation probability of a mutant with ``fitness2`` in a background of ``fitness1``.

    With the selection coefficient ``s = (f2 - f1) / |f1|`` this returns

    .. math::
        p_{fixation} = 1 - e^{-s} \\quad (s \\ge 0), \\qquad p_{fixation} = 0 \\quad (s < 0)

    which is the infinite-population limit of
    ``(1 - e^{-s}) / (1 - e^{-N s})``.

    Parameters
    ----------
    fitness1, fitness2 : float or numpy.ndarray
        Fitness of the resident and of the mutant.
    *args, **kwargs :
        Accepted and ignored.

    Returns
    -------
    float or numpy.ndarray
        Probability in ``[0, 1]``. NaN results are reported as 0, and
        deleterious moves (``s < 0``) get probability 0 for both scalar and
        array inputs.
    """
    sij = (fitness2 - fitness1)/abs(fitness1)
    prob = 1 - np.exp(-sij)
    if isinstance(prob, np.ndarray):
        # Replace NaNs before clamping the deleterious entries.
        prob = np.nan_to_num(prob)
        prob[sij < 0] = 0
    elif np.isnan(prob) or sij < 0:
        # Scalar input: apply the same clamping as the array branch so a
        # negative (or NaN) "probability" is never returned.
        prob = np.float64(0.0)
    return prob

* 41
* 87
* 180
* 236
* 237
* 284!
* 296!
* 417
* 433
* 486
* 487
* 497


In [224]:
# Counter used in the output file names of the plotting loop; the cell that
# runs that loop sets it back to 0 itself before starting.
z=0

In [225]:
import pickle

In [229]:
# Render 30 independently laid-out versions of the network figures for one
# dataset. Each pass saves the layout positions and two PDFs: "-2" compares
# G0 with G2, "-1" compares G0 with G1 using the same positions.
dataset = "full-state-predictions"
d = 417
z = 0
for j in range(30):
    z = j + 1
    retrieve(dataset, d)
    G0, G1, G2, seq = testing(dataset, d)

    # Common part of every output file name for this pass.
    tag = seq + "-" + str(d) + "-" + str(z)

    # First figure: fresh random layout, which is also pickled.
    fig, pos = plot_networks(G0, G1, G2, seq)
    with open("../figures/positions-" + tag + ".pickle", "wb") as f:
        pickle.dump(pos, f)
    plt.close()
    fig.savefig("../figures/network-" + tag + "-2.pdf", format="pdf")

    # Second figure: swap the roles of G1 and G2, reuse the layout.
    fig, pos = plot_networks(G0, G2, G1, seq, pos=pos)
    fig.savefig("../figures/network-" + tag + "-1.pdf", format="pdf")
    plt.close()

# Extract top trajectory

In [618]:
import itertools as it
from latticeproteins.thermodynamics import LatticeThermodynamics
from latticeproteins.interactions import miyazawa_jernigan
from latticeproteins.conformations import ConformationList, Conformations
from latticeproteins.sequences import find_differences, _residues
from latticeproteins.evolve import monte_carlo_fixation_walk, fixation
from latticeproteins.sequences import random_sequence, hamming_distance

class PredictedLattice(object):
    """Approximate lattice-model stabilities from single (and optionally pairwise) mutational effects.

    On construction, the stability change of every single substitution of the
    wildtype is computed with a ``LatticeThermodynamics`` calculator. When
    ``double`` is True the pairwise term of every ordered pair of
    substitutions at two different sites is computed as well. The stability of
    any other sequence is then predicted by adding the relevant terms to the
    wildtype stability.

    Attributes set by the constructor: ``wildtype``, ``temp``,
    ``conformations``, ``target``, ``double``, ``dG0`` (wildtype stability) and
    ``dGs`` (dict keyed by ``(site, residue)`` and, if ``double``,
    ``(site, residue, site2, residue2)``).
    """
    def __init__(self, wildtype, temp, confs, double=False, target=None):
        self.wildtype = wildtype
        self.temp = temp
        self.conformations = confs
        self.target = target
        self._lattice = LatticeThermodynamics(self.temp, self.conformations)
        self.double = double

        sites = list(range(self.conformations.length()))
        self.dG0 = self._lattice.stability(self.wildtype, target=self.target)

        #####  Build a dictionary of additive and pairwise mutational effects ####
        # First order: effect of each single substitution relative to wildtype.
        self.dGs = {}
        for site in sites:
            for aa in _residues:
                mutant = list(self.wildtype[:])
                mutant[site] = aa
                self.dGs[(site, aa)] = self._lattice.stability(mutant, target=self.target) - self.dG0

        if not self.double:
            return

        # Second order: what remains of a double mutant's stability after
        # removing the wildtype value and both single-substitution effects.
        for i in sites:
            partners = [j for j in sites if j != i]
            for aa in _residues:
                for j in partners:
                    for aa2 in _residues:
                        mutant = list(self.wildtype[:])
                        mutant[i] = aa
                        mutant[j] = aa2
                        additive = self.dG0 + self.dGs[(i, aa)] + self.dGs[(j, aa2)]
                        self.dGs[(i, aa, j, aa2)] = self._lattice.stability(mutant, target=self.target) - additive

    def stability(self, seq, target=None):
        """Predict the stability of ``seq`` from the stored mutational effects.

        Raises ``Exception`` when ``target`` differs from the target the
        predictor was built with.
        """
        if target != self.target:
            raise Exception("Target does not match wildtype target.")

        # Positions at which seq differs from the wildtype.
        loci = find_differences(self.wildtype, seq)

        # Keys of the additive terms, then (optionally) of the pairwise terms.
        keys = [(locus, seq[locus]) for locus in loci]
        if self.double:
            keys += [(a, seq[a], b, seq[b]) for a, b in it.combinations(loci, 2)]

        # Start from the wildtype stability and add every term in order.
        total = float(self.dG0)
        for key in keys:
            total += self.dGs[key]
        return total

    def fracfolded(self, seq, target=None):
        """Return ``1 / (1 + exp(stability / temp))`` for ``seq``."""
        dG = self.stability(seq, target=target)
        return 1.0 / (1.0 + np.exp(dG / self.temp))


In [589]:
# Build up landscape
# NOTE: `seq` comes from an earlier cell (it is returned by `testing`).
temp = 1
length = len(seq)

# Build the conformation set once and ask it for three conformations of
# `seq`; the first one is used as the folding target.
# NOTE(review): presumably these are the three lowest-energy conformations --
# confirm against the latticeproteins documentation.
confs = Conformations(length, "database")
cs = confs.k_lowest_confs(seq, temp, 3)
db1, db2, db3 = cs[0], cs[1], cs[2]
db = [db1, db2, db3]
target = cs[0]

# Construct a lattice model calculator and the two predictors
# (additive only, and additive + pairwise) around the same wildtype.
lattice = LatticeThermodynamics(temp, confs)
plattice1 = PredictedLattice(seq, temp, confs, double=False, target=target)
plattice2 = PredictedLattice(seq, temp, confs, double=True, target=target)