In [None]:
# Make sequence logos. Similar to WebLogo, although the WebLogo web interface limits the number of peptides.
# Code adapted from Zheng Dai (Dai et al., Bioinformatics; https://github.com/zheng-dai/MHC2-optimization/tree/main/ValidationDataAnalysis/LogoGeneration)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
#functions to make sequence logos

# One-letter codes of the 20 amino acids, in alphabetical order.
global_aa = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
    'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
]

# Residue groups used to pick letter colors (matching the WebLogo defaults).
polars = set("GSTYC")
polars_sub = set("QN")
basics = set("KRH")
acids = set("DE")
# Every letter outside these groups is hydrophobic and drawn in black.

def readimg(fname, col):
    """Load a glyph image and turn it into a solid-colored RGBA mask.

    Every output pixel gets the RGB value ``col``. Its alpha channel is
    ``255 - 255 * v``, where ``v`` is the pixel's first channel in the source
    image (or the pixel value itself for a 2-D grayscale image), so pixels
    with ``v == 0`` become fully opaque and pixels with ``v == 1`` become
    fully transparent.

    Parameters
    ----------
    fname : str
        Path handed to ``plt.imread``.
    col : sequence of three ints
        (R, G, B) color applied to every pixel.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape (height, width, 4).

    Notes
    -----
    Assumes the image values are floats in [0, 1], which is what
    ``plt.imread`` is documented to return for PNG files -- TODO confirm
    before using other image formats.
    """
    arr = plt.imread(fname)

    # Use the first channel as the "ink" intensity; a 2-D (grayscale) image
    # has no channel axis, so it is used as-is.
    ink = arr if arr.ndim == 2 else arr[..., 0]

    newarr = np.zeros(shape=ink.shape[:2] + (4,), dtype="uint8")
    newarr[..., 0] = col[0]
    newarr[..., 1] = col[1]
    newarr[..., 2] = col[2]
    # Truncating float -> uint8 cast, same as assigning element by element.
    newarr[..., 3] = (255 - ink * 255).astype("uint8")
    return newarr

# Letter glyphs are read from "<LETTER>_crop.png" files, resolved relative to
# the current working directory. The images come from:
# https://github.com/zheng-dai/MHC2-optimization/blob/main/ValidationDataAnalysis/LogoGeneration/T_crop.png
imdic = {}
updic = {}
downdic = {}

# (residue group, RGB color) pairs, checked in this order; a letter that
# belongs to no group keeps the default black.
_group_colors = (
    (polars, (0,192,0)),
    (basics, (24,32,255)),
    (acids, (255,0,0)),
    (polars_sub, (204,0,204)),
)

# NOTE(review): all 26 letters A-Z are loaded, not only the 20 in global_aa,
# so an image file must exist for every letter -- confirm this is intended.
for i, c in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZ"):
    col = next((rgb for members, rgb in _group_colors if c in members), (0,0,0))
    imdic[c] = readimg("{}_crop.png".format(c), col)
    
def drawColumn(ax, seqs, hgts, shift, wide):
    """Draw stacked letter glyphs for up to nine logo positions on ``ax``.

    Position ``k`` (1-based) occupies the horizontal span
    ``[k - wide/2 + shift, k + wide/2 + shift]``. Within a position the
    letters are stacked from the bottom up in the reverse of their order in
    ``seqs``, so the first listed letter ends up on top. Each letter's
    height is ``hgts[i] * fraction``; letters with zero height are skipped.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    seqs : sequence
        ``seqs[i]`` is a sequence of ``(letter, fraction)`` pairs for
        position ``i``. Letters are looked up in the module-level ``imdic``.
    hgts : sequence of numbers
        Total stack height for each position.
    shift : number
        Horizontal offset applied to every column.
    wide : number
        Width of each column.

    Raises
    ------
    ValueError
        If ``hgts`` is empty (from ``max``).

    Notes
    -----
    At most the first nine positions are drawn. If fewer than nine are
    supplied, only those are drawn (previously this raised ``IndexError``).
    The x ticks and axis limits set at the end are fixed regardless.
    """
    n_cols = min(9, len(seqs), len(hgts))
    for i in range(n_cols):
        x = i + 1
        l = x - wide/2 + shift
        r = x + wide/2 + shift
        b = 0
        for c, p in seqs[i][::-1]:
            hgt = hgts[i] * p
            if hgt == 0:
                # Nothing to draw and the stack height is unchanged.
                continue
            t = b + hgt
            ax.imshow(imdic[c], aspect = "auto", extent = (l, r, b, t), interpolation = 'bilinear')
            b = t

    # imshow adjusts the view for each image, so set it explicitly afterwards.
    ax.set_xticks(range(1,10))
    ax.set(xlim = (0,16), ylim = (0,max(hgts)*1.1))

    
def drawLogo(seqlist, title_val):
    """Create a 10x6 figure and draw one nine-position sequence logo on it.

    Parameters
    ----------
    seqlist : sequence
        Its first element must be a ``(fractions, heights)`` pair, passed on
        to ``drawColumn``. Any further elements are ignored.
    title_val : object
        Converted with ``str`` and used as the plot title.

    Returns
    -------
    None
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    column_width = 0.99
    label_size = 34
    label_font = 'Arial'

    # Single logo, drawn without any horizontal shift.
    fractions, heights = seqlist[0]
    drawColumn(ax, fractions, heights, 0.0, column_width)
    tallest = max(0, max(heights))
    ax.set(ylim = (0, 1.1*tallest), xlim = (0.5, 9.5))

    # One tick per position, labelled 1..9.
    positions = range(1, 10)
    ax.set_xticks([float(pos) for pos in positions])
    ax.set_xticklabels([str(pos) for pos in positions], fontsize=label_size, fontname=label_font)

    ax.set_ylabel("Bits", fontsize=label_size, fontname=label_font)
    ax.tick_params(axis='y', which='major', labelsize=label_size)

    # Hide the right and top borders of the plot.
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    plt.title(str(title_val))


def viz_fn(fname, weighting_kw):
    """Read peptide data from a CSV file and draw its sequence logo.

    For each of the first 10 sequence positions, computes the fraction of
    each amino acid in ``global_aa`` and the column height in bits,
    ``log2(20) + sum(f * log2(f))`` over the non-zero fractions ``f``.
    The per-position letters are ordered from most to least abundant and
    the result is passed to ``drawLogo`` with ``fname`` as the title.

    Parameters
    ----------
    fname : str or file-like
        CSV source for ``pd.read_csv``. Must contain a ``sequence`` column,
        and a ``count`` column when weighting is requested.
    weighting_kw : bool
        If truthy, each sequence is weighted by its ``count`` value;
        otherwise every row counts once.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the file has no rows, or the counts sum to zero when weighting.
    """
    data_sub = pd.read_csv(fname)

    # Resolve the weighting mode once so `frac` is always defined, whatever
    # value the flag has.
    if weighting_kw:
        weights = data_sub['count']
        total = weights.sum()
    else:
        weights = None
        total = len(data_sub)
    if total == 0:
        raise ValueError("cannot build a sequence logo: no sequences (or zero total count) in input")

    max_bits = np.log2(20)
    frac_list = []
    height_list = []

    # NOTE(review): 10 positions are computed here, but drawColumn renders
    # only positions 1-9; the 10th entry only takes part in the max() that
    # drawLogo uses for the y-axis limit. Kept as-is to leave the resulting
    # figures unchanged -- confirm whether 9 was intended.
    for j in range(10):  # iterate over positions
        residues = data_sub['sequence'].str[j]  # extracted once per position
        column = []
        height = max_bits
        for c in global_aa:
            matches = residues == c
            if weights is None:
                # not weighting by count
                frac = np.sum(matches) / total
            else:
                # multiply by count
                frac = np.sum(matches * weights) / total

            column.append((c, frac))
            if frac != 0:  # skip zero fractions: log2(0) is undefined
                height += frac * np.log2(frac)

        # Most abundant letter first (drawn on top, as in WebLogo). The sort
        # is stable, so ties keep their global_aa order.
        column.sort(key=lambda pair: pair[1], reverse=True)
        height_list.append(height)
        frac_list.append(column)

    drawLogo([(frac_list, height_list)], fname)

    return None

In [None]:
# Draw the logo for this data file, weighting each sequence by its count.
viz_fn('3312A3_cdhit-corrected_data.csv', weighting_kw=True)