# Basic example of latticeproteins

In [None]:
import latticeproteins as lp # envt is latticeproteins
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import binom
from helpers import DotAccessibleDict
import multiprocessing as mp

!pwd

In [None]:
seq_length = 12
lattice = lp.thermodynamics.LatticeThermodynamics.from_length(seq_length, 1.0)

# Draw random sequences until native_conf() returns something other than None.
while True:
    seq = lp.random_sequence(seq_length)
    conf = lattice.native_conf(seq)
    if conf is not None:
        break
print(seq, conf)

# Report each thermodynamic quantity of the chosen sequence in its native
# conformation; each one is computed immediately before it is printed.
for template, measure in (
    ("Energy of native conformation: %f", lattice.nativeE),
    ("stability of native conformation: %f", lattice.stability),
    ("fraction folded: %f", lattice.fracfolded),
):
    print(template % measure(seq, target=conf))

lp.draw.in_notebook(seq, conf)

In [None]:
lig_length = 6
lig_lattice = lp.thermodynamics.LatticeThermodynamics.from_length(lig_length, 1.0)

# Draw random ligand sequences until native_conf() returns something other
# than None.
while True:
    lig_seq = lp.random_sequence(lig_length)
    lig_conf = lig_lattice.native_conf(lig_seq)
    if lig_conf is not None:
        break
print(lig_seq, lig_conf)

# Same three quantities as reported for the protein, now for the ligand.
for template, measure in (
    ("Energy of native conformation: %f", lig_lattice.nativeE),
    ("stability of native conformation: %f", lig_lattice.stability),
    ("fraction folded: %f", lig_lattice.fracfolded),
):
    print(template % measure(lig_seq, target=lig_conf))

# Draw the ligand with lower-case residue letters and an all-"r" colour
# sequence so it is visually distinct from the protein drawing.
lowercase_ligand = [residue.lower() for residue in lig_seq]
lp.draw.in_notebook(lowercase_ligand, lig_conf, color_sequence="r" * lig_length)

In [None]:
# Bind the ligand to the protein.
# NOTE: this rebinds ``lig_conf`` to the conformation returned by BindLigand.
binding = lp.conformations.BindLigand(seq, conf, lig_seq, lig_conf)
be, xshift, yshift, lig_conf = binding

# Tuple layout expected by the drawing/printing helpers; the None slot is the
# ligand colour sequence.
ligand_tup = (lig_seq, None, lig_conf, xshift, yshift)
print(f"be: {be:0.3f}, xshift: {xshift}, yshift: {yshift}")

# 1) plain-text rendering
with open("bound_ligand.txt", "w", encoding="utf-8") as f:
    lp.conformations.PrintConformation(
        seq,
        conf,
        file=f,
        latex_format=False,
        ligand_tup=ligand_tup,
    )

# 2) SVG rendering
lp.draw.to_file(seq, conf, "bound_ligand.svg", ligand_tup=ligand_tup)

# 3) inline rendering in the notebook
lp.draw.in_notebook(seq, conf, ligand_tup=ligand_tup)

---
---

In [None]:
# evolution!
 
def plot_prots(args, prots):
    """Plot a histogram of fitness over the proteins in ``prots``.

    Relies on the module-level ``lattice``, ``lig_seq`` and ``lig_conf`` and on
    ``get_fitness_etc``. For every protein the native conformation is looked
    up again from its sequence; proteins for which ``native_conf`` returns
    None are reported, counted, and given fitness 0.0.

    Args:
        args: run configuration. Currently unused by the active code.
        prots: iterable of objects exposing a ``seq`` attribute.
    """
    # Native energy, stability, fraction folded and binding energy are still
    # collected so their histograms (and a clade-index histogram) can be
    # re-enabled easily; only the fitness histogram is drawn at the moment.
    energies, stabilities, fold_fractions, binding_energies = [], [], [], []
    fitnesses = []
    n_without_native = 0

    for prot in prots:
        seq = prot.seq
        conf = lattice.native_conf(seq)

        if conf is None:
            print("No unique native conformation for seq:", seq)
            n_without_native += 1
            fitnesses.append(0.0)
            continue

        E, S, FF, be, fit = get_fitness_etc(lattice, seq, conf, lig_seq, lig_conf)
        energies.append(E)
        stabilities.append(S)
        fold_fractions.append(FF)
        binding_energies.append(be)
        fitnesses.append(fit)

    print(f"Number of sequences with no unique native conformation: {n_without_native}")

    plt.hist(fitnesses, bins=20)
    plt.xlabel("Fitness (BE * FF)")
    plt.ylabel("Count")
    plt.show()


def get_fitness_etc(lattice, seq, conf, lig_seq, lig_conf):
    """Return thermodynamic and binding quantities for one folded sequence.

    Args:
        lattice: object providing ``nativeE``, ``stability`` and
            ``fracfolded``, each called as ``f(seq, target=conf)``.
        seq: protein sequence.
        conf: target conformation of ``seq``; must not be None.
        lig_seq: ligand sequence.
        lig_conf: ligand conformation.

    Returns:
        Tuple ``(E, S, FF, be, fit)``: native energy, stability, fraction
        folded, binding energy from ``BindLigand``, and ``fit = -be * FF``.
    """
    assert conf is not None

    native_energy = lattice.nativeE(seq, target=conf)
    stability = lattice.stability(seq, target=conf)
    frac_folded = lattice.fracfolded(seq, target=conf)

    # Only the binding energy is used; the shifts and the returned ligand
    # conformation are discarded.
    binding_energy, _xshift, _yshift, _bound_conf = lp.conformations.BindLigand(
        seq, conf, lig_seq, lig_conf
    )
    fitness = -binding_energy * frac_folded

    return (native_energy, stability, frac_folded, binding_energy, fitness)

class Prot():
    """One protein in the evolving population.

    Holds the sequence, the conformation and minimum energy reported by the
    module-level ``lattice._nativeE``, the cached fitness, and lineage
    bookkeeping (own index, parent index, founding clade).
    """

    # Index handed to the next Prot created; incremented on every construction.
    global_idx = 0

    def __init__(self, args, clade_idx, parent_idx, seq=None):
        """Create a protein from ``seq``, or from a random sequence if None.

        Args:
            args: run configuration; ``args.seq_length`` is read when a random
                sequence has to be drawn, and ``args.mu`` later by ``mutate``.
            clade_idx: index of the original founding ancestor.
            parent_idx: index of the immediate parent (None for founders).
            seq: sequence to use; when None, random sequences are drawn until
                one has ``minE < -1`` and a non-None conformation.
        """
        if seq is not None:
            # Use the provided sequence as-is; conf may come back as None.
            minE, conf, _partitionsum, _folds = lattice._nativeE(seq)
        else:
            # Random founder. Earlier acceptance criteria that were tried and
            # disabled: fraction folded above args.minFoldFrac, and the
            # ``folds`` flag returned by _nativeE. The active criterion is an
            # energy threshold.
            conf = None
            while conf is None:
                minE = 999  # sentinel that forces at least one draw
                while minE >= -1:
                    seq = lp.random_sequence(args.seq_length)
                    minE, conf, _partitionsum, _folds = lattice._nativeE(seq)

        self.args = args
        self.seq = seq
        self.conf = conf  # might be None if no unique native conf
        self.minE = minE
        self.fit = None  # filled in by compute_fitness()
        self.clade_idx = clade_idx
        self.parent_idx = parent_idx

        # Give each created Prot a unique index.
        self.idx = Prot.global_idx
        Prot.global_idx += 1

    def __repr__(self):
        return (
            f"Prot(idx={self.idx}, clade={self.clade_idx}, parent={self.parent_idx}, "
            f"seq={''.join(self.seq)}, conf={self.conf}, minE={self.minE}, fit={self.fit})"
        )

    def compute_fitness(self, lattice, lig_seq, lig_conf):
        """Compute, store in ``self.fit``, and return this protein's fitness.

        Fitness is 0.0 when there is no conformation or the fraction folded is
        exactly 0.0; otherwise it is ``exp(-be)`` with ``be`` the binding
        energy returned by ``BindLigand``.
        """
        fit = 0.0
        if self.conf is not None:
            FF = lattice.fracfolded(self.seq, target=self.conf)
            if FF != 0.0:
                be, _xshift, _yshift, _bound_conf = lp.conformations.BindLigand(
                    self.seq, self.conf, lig_seq, lig_conf
                )
                # Alternatives that were tried:
                #   max(0.0, -be * FF)          (as in Bloom 2004)
                #   np.exp(max(0.0, -be * FF))  (other)
                fit = np.exp(-be)  # as in Palmer 2013
        self.fit = fit
        return fit

    def mutate(self):
        """Return a child Prot built from a mutated copy of this sequence.

        ``args.mu`` is passed to ``mutate_sequence`` unscaled (a division by
        ``seq_length`` was tried and disabled). The child keeps this protein's
        clade and records this protein as its parent.
        """
        mutated_seq = lp.sequences.mutate_sequence(self.seq, self.args.mu)
        return Prot(
            self.args,
            clade_idx=self.clade_idx,
            parent_idx=self.idx,
            seq=mutated_seq,
        )

    
###

args = DotAccessibleDict({'seq_length': 12, 'minFoldFrac': 0.9, 'N': 16384, 'clades': 4, 'mu': 0.0005, 'softFit': 0.1})
print(args)
lattice = lp.thermodynamics.LatticeThermodynamics.from_length(args.seq_length, 1.0)

###

# Build the ligand lattice here, for both branches below, so this cell does
# not depend on a ``lig_lattice`` left over from an earlier cell.
lig_length = 6
lig_lattice = lp.thermodynamics.LatticeThermodynamics.from_length(lig_length, 1.0)
seq_length = 12

# HACK: use the ligand and founding proteins from the Palmer paper.
# Set to False to draw a random ligand and random founders instead.
use_palmer_system = True
if use_palmer_system:
    lig_seq = 'LIVKRS'
    lig_conf = lig_lattice.native_conf(lig_seq)

    # One founder per clade; the position in this list is the clade index.
    palmer_founder_seqs = ['FCTFKIINCEWV', 'MVNLTLFSVTLM', 'FLELTCLNNPCF', 'IWPKAHMLSHNY']
    founders = [
        Prot(args, clade_idx=clade_idx, parent_idx=None, seq=founder_seq)
        for clade_idx, founder_seq in enumerate(palmer_founder_seqs)
    ]

else:
    # random ligand: redraw until native_conf() returns a conformation
    lig_conf = None
    while lig_conf is None:
        lig_seq = lp.random_sequence(lig_length)
        lig_conf = lig_lattice.native_conf(lig_seq)

    # random founders
    founders = [Prot(args, clade_idx=clade_idx, parent_idx=None) for clade_idx in range(args.clades)]

###

print("ligand:", lig_seq, lig_conf)
print("ligand energy of native conformation: %f" % lig_lattice.nativeE(lig_seq, target=lig_conf))
print("ligand stability of native conformation: %f" % lig_lattice.stability(lig_seq, target=lig_conf))
print("ligand fraction folded: %f" % lig_lattice.fracfolded(lig_seq, target=lig_conf))
#lp.draw.in_notebook([x.lower() for x in lig_seq], lig_conf, color_sequence="r"*lig_length)

print("Founders:")
for prot in founders:
    # compute_fitness() stores the value on the Prot, so the repr printed
    # below shows it.
    prot.compute_fitness(lattice, lig_seq, lig_conf)
    print(prot)

# Initial population: equal numbers of identical copies of each founder.
# Integer division means the total is below args.N when N is not a multiple
# of args.clades.
prots = []
copies = args.N // args.clades
print(f"Making {copies} copies of each of {args.clades} founders")
for founder in founders:
    for _ in range(copies):
        prot = Prot(args, clade_idx=founder.clade_idx, parent_idx=None, seq=founder.seq)
        prot.compute_fitness(lattice, lig_seq, lig_conf)
        prots.append(prot)

# def _compute_fit(args):
#     prot, lattice, lig_seq, lig_conf = args
#     return prot.compute_fitness(lattice, lig_seq, lig_conf)

def one_gen(prots, lattice, lig_seq, lig_conf, args):
    """Run one generation of selection, reproduction and mutation.

    Each parent's fitness is shifted so the minimum is zero, offset by
    ``args.softFit`` (so the least-fit parent can still reproduce), and
    normalised to a probability. The number of offspring of each parent is
    then drawn independently from ``Binomial(args.N, p_i)``, so the size of
    the next generation fluctuates around ``args.N`` rather than being fixed.
    Every offspring is produced by ``prot.mutate()``.

    Args:
        prots: current population; each item provides
            ``compute_fitness(lattice, lig_seq, lig_conf)`` and ``mutate()``.
        lattice: passed through to ``compute_fitness``.
        lig_seq: ligand sequence, passed through to ``compute_fitness``.
        lig_conf: ligand conformation, passed through to ``compute_fitness``.
        args: configuration providing ``N`` (target population size) and
            ``softFit`` (additive fitness offset).

    Returns:
        ``(new_prots, fits)``: the list of offspring and the array of raw
        fitnesses of the parents. An empty population yields ``([], fits)``
        with an empty ``fits`` array.
    """
    # selection
    fits = np.array(
        [prot.compute_fitness(lattice, lig_seq, lig_conf) for prot in prots],
        dtype=float,
    )
    if fits.size == 0:
        # Nothing to select from; np.min() would raise on an empty array.
        return [], fits

    # normalize: shift so the minimum is zero, then add softFit
    weights = fits - np.min(fits) + args.softFit
    total = np.sum(weights)
    if total == 0:
        # All weights are zero (softFit == 0 and identical fitnesses):
        # 0/0 would give NaN probabilities, so fall back to neutral selection.
        relfits = np.full(fits.shape, 1.0 / fits.size)
    else:
        relfits = weights / total

    # reproduction: independent binomial draw per parent
    offspring = np.atleast_1d(binom.rvs(n=round(args.N), p=relfits))
    new_prots = []
    for prot, n_children in zip(prots, offspring):
        for _ in range(n_children):
            new_prots.append(prot.mutate())

    return new_prots, fits

maxGen = 50
avgFitList = []   # mean parent fitness, one entry per generation
cladeCounts = {}  # {clade_idx: [count in generation 0, count in generation 1, ...]}
for i in range(maxGen):
    # Record how many members each clade has at the start of this generation.
    cladeList = [prot.clade_idx for prot in prots]
    for cl in range(args.clades):
        cnt = cladeList.count(cl)
        if cl not in cladeCounts:
            cladeCounts[cl] = []
        cladeCounts[cl].append(cnt)

    # To inspect the population periodically, call plot_prots(args, prots) here.

    # Advance one generation; ``fits`` are the fitnesses of the parents.
    prots, fits = one_gen(prots, lattice, lig_seq, lig_conf, args)
    avgFit = fits.mean()
    avgFitList.append(avgFit)

# Announce the index of the last generation that was run.
print(f"\n=== Generation {i} ===")
#plot_prots(args, prots)

args.logscale = True

# Plot the number of members of each clade over the generations.
#print(cladeCounts)
plt.figure(figsize=(8,4))
# Every clade has one count per generation, so any entry gives the number of
# recorded generations; an empty dict simply yields an empty plot.
gens = list(range(len(next(iter(cladeCounts.values()), []))))

if args.logscale:
    plt.xscale('log')
    #plt.yscale('log')
    gens = gens[1:] # skip generation 0 to avoid log(0)

colors = "rgbm"
for cl, counts in sorted(cladeCounts.items()):
    if args.logscale:
        counts = counts[1:] # skip generation 0 to avoid log(0)
    # Cycle through the palette so more than len(colors) clades still plot.
    plt.plot(gens, counts, marker='o', ms=1, label=f"clade {cl}", color=colors[cl % len(colors)])

plt.xlabel("Generation")
plt.ylabel("Count")
plt.title("Clade counts over generations")
# On a linear axis, show 11 evenly spaced integer ticks across the x range.
if not args.logscale and gens:
    ticks = np.linspace(gens[0], gens[-1]+1, 11, dtype=int)
    plt.xticks(ticks)
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left')
plt.tight_layout()
plt.show()


# Plot the average fitness per generation (avgFitList[g] belongs to generation g).
plt.figure(figsize=(6,4))
# Derive the x values from the data actually recorded rather than from maxGen.
fit_gens = list(range(len(avgFitList)))
fit_vals = list(avgFitList)
if args.logscale:
    plt.xscale('log')
    # Generation 0 has no position on a log axis; skip it, as the clade-count
    # plot does.
    fit_gens = fit_gens[1:]
    fit_vals = fit_vals[1:]
plt.plot(fit_gens, fit_vals, marker='o', ms=3, lw=1)
plt.xlabel("Generation")
plt.ylabel("Average fitness")
plt.title("Average fitness over generations")
plt.grid(alpha=0.3)
plt.show()

# print("Final:")
# [prot.compute_fitness(lattice, lig_seq, lig_conf) for prot in prots]
# [print(prot) for prot in prots]
None