# code for sampling from evcouplings model

- simulated annealing code written by Nathan Rollins.
- modified to suit toxin antitoxin by David Ding.

### Inputs:
- evCouplings model

### Outputs:
- sampled sequences at defined antitoxin positions, written as one CSV annealing report per sampling temperature

In [1]:
from evcouplings.couplings import CouplingsModel
from evcouplings.couplings.model import _single_mutant_hamiltonians
import numpy as np
from copy import deepcopy
import pandas as pd
import matplotlib.pyplot as plt
import viz_seqs
%matplotlib inline

# Base output directory; the temperature-screen loop writes its CSV files beneath it.
dout = './samples/'
# Path to the EVcouplings model parameter file that the annealer and CouplingsModel are built from.
fin_evc_model = "./parED_e1_3_m80_f80.model"

In [2]:
def hamming(seq1,seq2):
    '''count the positions at which two equal-length sequences differ

    works on any pair of equal-length iterables (strings, lists, integer
    vectors); raises AssertionError when the lengths do not match.
    '''
    assert len(seq1) == len(seq2)
    # add 1 per mismatching pair so the result is always a plain int
    return sum(1 for left, right in zip(seq1, seq2) if left != right)

class EVHannealer():
    '''optimize a focus sequence (uppercase only) according to statistical energy

    Wraps an EVcouplings CouplingsModel and offers helpers to
    - convert between sequences and integer vectors,
    - build boolean (L x AA) masks over (position, amino acid) pairs,
    - draw one Boltzmann-weighted single-site substitution at a time (`step`).

    Attributes:
        model: the CouplingsModel loaded from `model_file`
        L: number of model positions (len(model.index_list))
        AA: alphabet size (len(model.alphabet))
    '''
    def __init__(self, model_file):
        self.model = CouplingsModel(model_file)
        self.L = len(self.model.index_list)
        self.AA = len(self.model.alphabet)

    def encode(self,seq):
        '''convert sequence to integer vector'''
        return self.model.convert_sequences([seq])[0]

    def decode(self,seq_vector):
        '''convert integer vector to sequence'''
        return ''.join(self.model.alphabet[i] for i in seq_vector)

    def random_seq_vector(self):
        '''draw a uniformly random integer vector, one symbol per position of model.target_seq'''
        symbols = np.arange(len(self.model.alphabet))
        # one np.random.choice call per position, as before, so the global
        # random stream is consumed in the same way
        return np.array([np.random.choice(symbols)
                         for _ in range(len(self.model.target_seq))])

    def residuelist_to_mask(self,position_list,amino_acid_list):
        '''for listed positions, mask=1 for listed amino acid

        position_list holds model positions (keys of model.index_map),
        amino_acid_list the matching symbols (keys of model.alphabet_map).
        Returns a boolean (L x AA) array.
        '''
        assert len(position_list) == len(amino_acid_list)
        mask = np.zeros((self.L,self.AA), dtype=bool)
        for i,aa in zip(position_list,amino_acid_list):
            mask[self.model.index_map[i],self.model.alphabet_map[aa]] = True
        return mask

    def residuelist_to_negative_sitemask(self,position_list,amino_acid_list):
        '''for listed positions, mask=1 for amino acids mismatching listed amino acid

        Rows of positions that are not listed stay all-False.
        Returns a boolean (L x AA) array.
        '''
        assert len(position_list) == len(amino_acid_list)
        mask = np.zeros((self.L,self.AA), dtype=bool)
        for i,aa in zip(position_list,amino_acid_list):
            row = self.model.index_map[i]
            mask[row,:] = True
            mask[row,self.model.alphabet_map[aa]] = False
        return mask

    def get_hamming_matrix_fxn(self, target_focus_sequence, min_hamming, max_hamming, weight=10):
        '''create a fxn that favors match if hamming >= max,
            penalizes match if hamming <= min

        The returned function maps a sequence vector to an (L x AA) matrix that
        is non-zero only at the residues of target_focus_sequence:
        +weight there once the Hamming distance to the target is >= max_hamming,
        -weight once it is <= min_hamming, and 0 in between.
        It is meant to be passed to `step` as penalty_matrix, where it is added
        to the energies before exponentiation.

        Side effect: plots the penalty as a function of Hamming distance on the
        current matplotlib figure.
        '''
        target_seq_vector = self.encode(target_focus_sequence)
        target_seq_mask = self.residuelist_to_mask(self.model.index_list,target_focus_sequence)

        def penalty_well(h):
            return -weight*(h<=min_hamming) + weight*(h>=max_hamming)

        # visual sanity check of the well over a window around [min, max]
        x = np.arange(min_hamming-5,max_hamming+5)
        plt.plot(x, np.vectorize(penalty_well)(x))
        plt.gcf().set_size_inches(3,3)

        def mask_fxn(seq_vector):
            return penalty_well(hamming(target_seq_vector, seq_vector)) * target_seq_mask

        return mask_fxn

    def mutate_seq(self, seq, residue_list, amino_acid_list):
        '''return seq as a list of characters with the listed model positions set to the listed amino acids'''
        mseq = list(seq)
        for i,aa in zip(residue_list,amino_acid_list):
            mseq[self.model.index_map[i]] = aa
        return mseq

    def mutate_seq_vector(self, seq_vector, residue_list, amino_acid_list):
        '''return a copy of seq_vector with the listed model positions set to the listed amino acids'''
        mseq = deepcopy(seq_vector)
        for i,aa in zip(residue_list,amino_acid_list):
            mseq[self.model.index_map[i]] = self.model.alphabet_map[aa]
        return mseq

    @staticmethod
    def _boltzmann_probabilities(E, T, forbidden=None):
        '''turn an energy matrix into normalized probabilities proportional to exp(E/T)

        The maximum of E/T is subtracted before exponentiating. This leaves the
        normalized result unchanged but keeps np.exp from overflowing to inf
        (and the probabilities from becoming NaN) at low temperature.

        E: array of energies; it is not modified
        T: temperature
        forbidden: optional boolean array (same shape as E); True entries get probability 0

        Raises ValueError if no entry is left with a finite weight (everything
        forbidden, or T == 0).
        '''
        logits = np.asarray(E, dtype=float) / T   # new array, safe to edit in place
        if forbidden is not None:
            logits[forbidden] = -np.inf           # exp(-inf) == 0
        peak = np.max(logits)
        if not np.isfinite(peak):
            raise ValueError('no move with a finite Boltzmann weight (check T and avoid_mask)')
        weights = np.exp(logits - peak)
        return weights / np.sum(weights)

    def step(self,seq_vector,T,force_mutation=False,avoid_mask=None, penalty_matrix=None):
        '''introduce a mutation with boltzmann probability
        higher temperature = closer to uniform distribution

        seq_vector: current sequence as integer vector (not modified)
        T: temperature; each (position, amino acid) is drawn with probability
            proportional to exp(E/T)
        force_mutation: accepted for interface compatibility, currently not used
        avoid_mask: optional callable seq_vector -> boolean (L x AA) mask of
            substitutions that must never be drawn
        penalty_matrix: optional callable seq_vector -> (L x AA) matrix that is
            added to the single-mutant energies before sampling

        Returns (mut_vector, mutation_i, mutation_aa, previous_aa, mutation_EVH, mutation_E):
            the mutated copy, the drawn row index and symbol index, the symbol
            previously at that row, the model energy change of the move and the
            energy change including the penalty.
        '''
        # slice 0 of the last axis is used as the energy change of every single substitution
        EVH = _single_mutant_hamiltonians(seq_vector, self.model.J_ij, self.model.h_i)[:,:,0]
        E = deepcopy(EVH)
        if penalty_matrix is not None:
            E += penalty_matrix(seq_vector)

        forbidden = avoid_mask(seq_vector) if avoid_mask is not None else None
        P = self._boltzmann_probabilities(E, T, forbidden)

        # draw one flat index and map it back to (position, amino acid)
        P_flat = P.ravel()
        mutation_index = np.random.choice(np.arange(len(P_flat)), p=P_flat)
        mutation_i, mutation_aa = np.unravel_index(mutation_index, P.shape)

        mut_vector = deepcopy(seq_vector)
        mut_vector[mutation_i] = mutation_aa
        mutation_EVH = EVH[mutation_i, mutation_aa]
        mutation_E = E[mutation_i, mutation_aa]
        previous_aa = seq_vector[mutation_i]

        return mut_vector, mutation_i, mutation_aa, previous_aa, mutation_EVH, mutation_E
        

def make_annealed_samples(t_correction, distance_penalty,
                          wt_muts='L48L:D52D:I53I:R55R:L56L:F74F:R78R:E80E:A81A:R82R'):
    '''run one simulated-annealing trajectory and return it as a table

    Relies on notebook-level names: `annealer` (EVHannealer), `c`
    (CouplingsModel), `offset`, `fixed_pos`, `fixed_wt_aa` and
    `fixed_position_mask`.

    t_correction: factor applied to the temperature schedule; the schedule is
        L steps each at 1.0, 0.5 and 0.2 times t_correction, with L the
        sequence length
    distance_penalty: passed to annealer.step as penalty_matrix (may be None)
    wt_muts: colon-separated wild-type "mutations" (e.g. 'D52D') naming the
        sites reported in the full_mut column; defaults to the sites used in
        this notebook

    Returns a pandas DataFrame with one row per annealing step: the sequence
    vector after the step, the move that was drawn, energies of the resulting
    sequence, and its mutations relative to the model target sequence.
    '''
    # Start with a random sequence and set fixed positions to their fixed AA
    # values (mutate_seq_vector works on a copy of its input)
    start_seq_vector = annealer.random_seq_vector()
    sample_seq_vector = annealer.mutate_seq_vector(
        start_seq_vector,fixed_pos,fixed_wt_aa
    )

    # Set temperature cycle for annealing: three stages of L steps each
    L = int(len(start_seq_vector))
    T_cycle = [stage*t_correction for stage in (1.0, 0.5, 0.2) for _ in range(L)]

    # The mask does not depend on the sequence, so one callable serves all steps
    def avoid_fixed(seq_vector):
        return fixed_position_mask

    # Record progress during annealing
    annealing_report = []
    for T in T_cycle:
        step_result = annealer.step(
            sample_seq_vector,
            T,
            penalty_matrix = distance_penalty,
            avoid_mask = avoid_fixed
        )
        sample_seq_vector = step_result[0]
        annealing_report.append(list(step_result))

    # Make progress report human-readable
    annealing_report = pd.DataFrame(annealing_report, columns=['seq_vector','n_i','n_aa_mut','n_aa_prev','delta_EVH','delta_E'])
    index_list = annealer.model.index_list
    alphabet = annealer.model.alphabet
    annealing_report['i'] = annealing_report.n_i.apply(lambda n: index_list[n])
    annealing_report['aa_prev'] = annealing_report.n_aa_prev.apply(lambda n: alphabet[n])
    annealing_report['aa_mut'] = annealing_report.n_aa_mut.apply(lambda n: alphabet[n])

    # Calculate EVH scores (annealer returns delta EVHs)
    target_seq = annealer.model.target_seq
    scores = annealer.model.hamiltonians(np.array(list(annealing_report.seq_vector)))
    wt_energy = annealer.model.hamiltonians([target_seq])[0,0]
    # evcouplings returns the columns as (total, J_ij sum, h_i sum); the
    # couplings part therefore is column 1 and the fields part column 2
    annealing_report['E']  = scores[:,0]
    annealing_report['Eh'] = scores[:,2]
    annealing_report['Ej'] = scores[:,1]
    annealing_report['E-Ewt'] = annealing_report['E'] - wt_energy

    # Compare sequences to EVH model target sequence
    annealing_report['seq'] = annealing_report.seq_vector.apply(annealer.decode)
    annealing_report['Nmuts'] = annealing_report.seq.apply(lambda x: sum(a!=b for a,b in zip(target_seq,x)))

    # Mutation strings relative to the wild type of `c`, e.g. 'D52E:I53V';
    # positions are shifted by `offset`. Looked up once, not once per row.
    wt_positions = c.index_list
    wt_residues = c.seq()

    def get_muts(mut_seq):
        return ':'.join(
            v_wt + str(pos-offset) + v_mut
            for pos, v_wt, v_mut in zip(wt_positions, wt_residues, mut_seq)
            if v_wt != v_mut
        )

    annealing_report['muts'] = annealing_report.seq.apply(get_muts)

    # Site key ('D52') -> wild-type entry ('D52D'), in the order given by wt_muts
    wt_by_site = {entry[:-1]: entry for entry in wt_muts.split(':')}

    def make_full_mut(mut_muts):
        # expecting something like mut_muts = 'D52E:I53V:F74M:R78K'
        # returns one entry per wt_muts site: the observed mutation where there
        # is one, the wild-type entry otherwise
        observed = {entry[:-1]: entry for entry in mut_muts.split(':')}
        return ':'.join(observed.get(site, wt_entry) for site, wt_entry in wt_by_site.items())

    annealing_report['full_mut'] = annealing_report.muts.apply(make_full_mut)

    return annealing_report


In [3]:
# load EVH annealer and couplingsmodel
# (EVHannealer builds its own CouplingsModel from this file and keeps it as annealer.model)
annealer = EVHannealer(fin_evc_model)

# load parameters from file to create a pairwise model
# NOTE(review): this reads the same model file a second time; `c` is a separate
# object from annealer.model
c = CouplingsModel(fin_evc_model)

In [11]:
# inspect which positions can be sampled in evcouplings

# numbering shift added to the positions in wt_muts to obtain model positions
offset = 103
wt_muts = 'L48L:D52D:I53I:R55R:L56L:F74F:R78R:E80E:A81A:R82R'

# start with every model position fixed to its wild-type residue ...
fixed_pos = list(c.index_list)
fixed_wt_aa = list(c.seq())

# ... then collect the list indices of the sites in wt_muts, which are to be
# released from the fixed position list of uppercase indices and wt_aa
idxs_to_remove = []
for m in wt_muts.split(':'):
    wt_aa, mut_pos, mut_aa = m[0], int(m[1:-1]), m[-1]
    complex_mut_pos = offset + mut_pos
    print(m, complex_mut_pos)

    if complex_mut_pos not in c.index_list:
        print(m, 'not uppercase in couplingsmodel')
        continue

    # the wild-type letter in wt_muts must agree with the model
    assert wt_aa == c.seq(complex_mut_pos)

    to_remove_idx = list(c.index_list).index(complex_mut_pos)
    idxs_to_remove.append(to_remove_idx)
    print(to_remove_idx)

fixed_pos = [pos for idx, pos in enumerate(fixed_pos) if idx not in idxs_to_remove]
fixed_wt_aa = [aa for idx, aa in enumerate(fixed_wt_aa) if idx not in idxs_to_remove]

assert len(fixed_pos) == len(fixed_wt_aa)

print(len(fixed_pos))

# fix all non-mutated positions apart from AT L48, which is lowercase (not enough column coverage)
fixed_position_mask = annealer.residuelist_to_negative_sitemask(
    fixed_pos,fixed_wt_aa
)


L48L 151
L48L not uppercase in couplingsmodel
D52D 155
129
I53I 156
130
R55R 158
132
L56L 159
133
F74F 177
151
R78R 181
155
E80E 183
157
A81A 184
158
R82R 185
159
151


In [16]:
# sample at varying temperatures
import os

t_screen_dir = dout + 't_screen/'
# to_csv does not create missing directories, so make sure the target exists
os.makedirs(t_screen_dir, exist_ok=True)

# one annealing run (and one CSV) per temperature factor, without distance penalty
for t in [1e-2, 1e-1, 1, 2,3,5,10]:
    df_anneal_t = make_annealed_samples(t, distance_penalty=None)
    df_anneal_t.to_csv(t_screen_dir + 'df_anneal_{}.csv'.format(t))



