In [None]:
# Register inference for MHC-II-binding peptides
## Adapted from Huisman et al., eLife 2022

In [None]:
import numpy as np
import random
from scipy.special import softmax
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
#Take a list of (sequences, counts) and convert them into one-hot encodings
def encodeSequences(seqs):
    dic = {}
    v0 = np.zeros(20)
    for c, v in zip("ACDEFGHIKLMNPQRSTVWY", np.eye(20)):
        dic[c] = v
    encSeqs = []
    for seq, ct in seqs:
        # Pad entries to allow shifting to work
        encSeqs.append(( np.array([dic[c] for c in seq] + [v0]*(n_padded-(13-n_clusters+1))), ct )) 
    return encSeqs

# Updates all p1 positions
def iteration(model, seqs, beta, background):
    #p1selection = [-1,0,1,2,3,4]
    p1selection = list(range(-1,n_clusters)) #include trash cluster
    for seqinfo in seqs:
        seq, n, shift = seqinfo
        if shift != -1:
            model -= seq[shift:shift+n_padded] * n
        sums = np.sum(model, axis = 1).reshape((-1,1))
        pwm = np.log(model/sums) - background
        p = [0]
        for i in range(n_clusters):
            p.append( n * beta * np.sum( pwm * seq[i:i+n_padded] ) )
        shift = random.choices( p1selection, weights = softmax(p), k = 1 )[0]
        seqinfo[2] = shift
        if shift != -1:
            model += seq[shift:shift+n_padded] * n

# Updates all p1 positions with beta = inf (0 temperature)
def finalIteration(model, seqs, background):
    p1selection = list(range(-1,n_clusters))
    for seqinfo in seqs:
        seq, n, shift = seqinfo
        if shift != -1:
            model -= seq[shift:shift+n_padded] * n
        sums = np.sum(model, axis = 1).reshape((-1,1))
        pwm = np.log(model/sums) - background
        p = [0]
        for i in range(n_clusters):
            p.append( n * np.sum( pwm * seq[i:i+n_padded] ) )
            
        shift = p1selection[np.argmax(p)]
        seqinfo[2] = shift
        if shift != -1:
            model += seq[shift:shift+n_padded] * n
            
def reportScore(model, seqs, background):
    sums = np.sum(model, axis = 1).reshape((-1,1))
    pwm = np.log(model/sums) - background
    
    score = 0
    for seq, n, shift in seqs:
        if shift != -1:
            score += np.sum( pwm * seq[shift:shift+n_padded] ) * n
    return score
    
# seqs : list of (sequence, count) tuples. Sequences must be length 13
#
# iterations : number of iterations to carry out (not including the final iteration)
#
# betamin, betamax : the first iteration will be done with beta=betamin, and one before the final iteration
# with beta=betamax. Beta is linearly interpolated linearly in between
# Beta is the inverse of temprature, so this simulates cooling.
#
def alignSequences(seqs, iterations, betamin, betamax):
    model = np.ones((n_padded,20))
    result = []
    assignments = np.random.randint(-1,n_clusters,len(seqs))
    for (seq, n), shift in zip(encodeSequences(seqs), assignments):
        if shift != -1:
            model += seq[shift:shift+n_padded] * n
        result.append([seq, n, shift])
    background = np.log(1/20)
    
    betas = np.linspace(betamin,betamax,iterations)
    for i,beta in enumerate(betas):
        iteration(model, result, beta, background)        
        print ("Iteration {}, beta = {}, score = {}".format(i, beta, reportScore(model, result, background)))
    
    finalIteration(model, result, background)
    print ("Final iteration, score = {}".format(reportScore(model, result, background)))
    
    return result

def readInput(filename):
    seqs = []
    with open(filename, 'rt') as fin:
        for i, line in enumerate(fin):
            if i == 0: continue
            seq, ct = line.split(',')
            seqs.append((seq, int(ct)))
    return seqs


def writeOutput(filename, alignment, sequences):
    lines = ["sequence,count,P1"]
    for (_,ct,p1), (seq,_) in zip(alignment, sequences):
        lines.append("{},{},{}".format(seq,ct,p1))
    with open(filename, 'wt') as fout:
        fout.write('\n'.join(lines))

In [None]:
# Set parameters. These are module-level globals that the functions above read
# when they are called, so this cell must run before alignSequences is called.
n_padded = 9 # number of rows in the model and length of the window scored at each register
n_clusters = 8 # number of candidate registers (shifts 0 .. n_clusters-1); -1 is the extra "trash" cluster

In [None]:
# HLA-DR15 Data
# Read the (sequence, count) pairs, run the annealed register alignment, and
# write one "sequence,count,P1" row per peptide to a CSV named after save_prefix.
# NOTE(review): no random seed is set in the cells shown here, so the
# assignments can differ from run to run -- confirm whether that is intended.
save_prefix = 'align-v3_NGS6_DR15'
sequences = readInput('R3rep-DR15_cdhit-corrected_data-noStop.csv')
# 20 sampling iterations with beta going linearly from 0.05 to 1, then one greedy pass.
alignment = alignSequences(sequences, 20, 0.05, 1)
fname = "peptidedisplay_{}_clusters.csv".format(save_prefix)
writeOutput(fname, alignment, sequences)
