In [1]:
# Run all the cells up to the Experiments Heading to load the algorithm and sample data
# The subsequent cell generates the designs that are needed to run the experiments, which will
# be dumped into the "outputs" folder.

In [2]:
import os
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import gzip

from copy import deepcopy
from matplotlib import pyplot as plt
import Levenshtein as lev

import torch
from torch.nn import parallel

<h3>Import Data</h3>

In [3]:
# Load the pickled sample data; the file holds a pair that is unpacked into the
# MHC class 1 dataset and the MHC class 2 dataset.
# NOTE(review): pickle.load can execute arbitrary code while unpickling — only
# open a data.pkl that comes from a trusted source.
with open("./data.pkl", 'rb') as fin:
    mhc1_data, mhc2_data = pickle.load(fin)

In [4]:
# Script for processing data

def makeProductDistribution(population, dataset):
    """Pair up the haplotypes of one population into a diplotype distribution.

    population: sequence of (genotype, probability) pairs. Element [1] of every
        item in a genotype is looked up among the dataset's alleles.
    dataset: 3-tuple whose middle element is the list of alleles.

    Returns a dict mapping a frozenset of allele indexes to its accumulated
    probability. Alleles missing from the dataset share the index len(alleles).
    """
    _, alleles, _ = dataset
    alleleToIndex = {allele: position for position, allele in enumerate(alleles)}
    unknownIndex = len(alleles)

    # Replace every genotype by the set of allele indexes it contains
    haplotypes = [
        (frozenset(alleleToIndex.get(entry[1], unknownIndex) for entry in genotype), probability)
        for genotype, probability in population
    ]

    diploid = {}
    for position, (first, pFirst) in tqdm(enumerate(haplotypes), position = 0, leave = True):
        # The haplotype paired with itself
        diploid[first] = diploid.get(first, 0) + pFirst ** 2
        # Each unordered pair with an earlier haplotype is counted once, hence the factor 2
        for second, pSecond in haplotypes[:position]:
            pair = first.union(second)
            diploid[pair] = diploid.get(pair, 0) + 2 * pFirst * pSecond
    return diploid
    
def makeDiploidDistribution(populations, dataset):
    """Average the diplotype distributions of several populations.

    populations: non-empty iterable of populations, each in the form accepted by
        makeProductDistribution.
    dataset: 3-tuple whose middle element is the list of alleles.

    Returns (indexes, weights):
        indexes: int16 array with one row of 6 allele indexes per diplotype;
            diplotypes with fewer than 6 distinct alleles are padded with
            len(alleles), the same index used for unknown alleles.
        weights: array with the probability of each diplotype, summed over the
            populations and divided by the number of populations.

    Raises ValueError if no population is given.
    """
    default = len(dataset[1])
    distributions = [makeProductDistribution(population, dataset) for population in populations]
    if not distributions:
        raise ValueError("makeDiploidDistribution needs at least one population")
    # Divide by the number of populations actually combined instead of a fixed 3,
    # so the weights stay an average for any number of populations.
    numPopulations = len(distributions)

    # Fold every later population into the first one
    combined = distributions[0]
    for distribution in distributions[1:]:
        for genotype, probability in distribution.items():
            combined[genotype] = combined.get(genotype, 0) + probability

    indexes = []
    weights = []
    for genotype, probability in combined.items():
        weights.append( probability/numPopulations )
        row = tuple(genotype)
        # Rows must all have 6 entries so they can be stacked into one array
        if len(row) != 6:
            row = row + ( (default,) * (6-len(row)) )
        indexes.append(row)
    indexes = np.array(indexes, dtype = np.int16)
    weights = np.array(weights)
    return indexes, weights

def trimDataset(data, filtered):
    """Restrict a dataset to the peptides listed in `filtered`.

    data: (matrix, alleles, peptides) where row i of matrix belongs to peptides[i].
    filtered: iterable of peptides to keep.

    Returns (matrix rows of the kept peptides, alleles, kept peptides), with the
    kept peptides in their original order.
    """
    matrix, alleles, peptides = data
    wanted = set(filtered)
    kept = [(row, peptide) for row, peptide in enumerate(peptides) if peptide in wanted]
    rows = [row for row, _ in kept]
    keptPeptides = [peptide for _, peptide in kept]
    return matrix[rows], alleles, keptPeptides
    
def reformatData(data):
    """Convert one raw MHC dataset into the inputs used by the greedy algorithm.

    data: (credence_data, binary_data, peptides, population). The credence
        dataset is first trimmed to `peptides`; the binary dataset is used as is.

    Returns (credence inputs, binary inputs). Each is a 4-tuple of
        float32 tensor holding 1 - the dataset's matrix,
        the dataset's peptide list,
        the diplotype index array,
        float32 tensor of diplotype weights.
    Also prints the sizes of the two matrix tensors.
    """
    credence_data, binary_data, peptides, population = data

    def toGreedyInputs(dataset):
        # Same packing for both datasets: distribution first, then the tensors
        indexes, weights = makeDiploidDistribution(population, dataset)
        return (torch.tensor(1-dataset[0], dtype = torch.float32),
                dataset[2],
                indexes,
                torch.tensor(weights, dtype = torch.float32))

    credence_inputs = toGreedyInputs(trimDataset(credence_data, peptides))
    binary_inputs = toGreedyInputs(binary_data)

    print (credence_inputs[0].size(), binary_inputs[0].size())
    return credence_inputs, binary_inputs

# Build the credence-based and binarized greedy inputs for each MHC class.
# Every call prints the sizes of its two converted matrix tensors.
inputs1_credences, inputs1_binarized = reformatData(mhc1_data)
inputs2_credences, inputs2_binarized = reformatData(mhc2_data)

779it [00:00, 1287.24it/s]
1200it [00:01, 651.71it/s]
440it [00:00, 3434.99it/s]
779it [00:00, 1260.70it/s]
1200it [00:02, 599.11it/s]
440it [00:00, 787.40it/s]
537it [00:00, 2761.30it/s]

torch.Size([1043, 298]) torch.Size([1100, 234])



920it [00:01, 799.18it/s] 
502it [00:00, 3279.87it/s]
537it [00:00, 1616.17it/s]
920it [00:01, 891.37it/s] 
502it [00:00, 1323.38it/s]


torch.Size([3934, 281]) torch.Size([4195, 281])


<h3>Greedy Algorithm</h3>

In [5]:
# Utility functions

# f(x) = min(n, x)
def getThresholdUtility(n):
    """Return the float32 tensor [0, 1, ..., n]: the utility of 0 to n hits."""
    hitCounts = np.arange(n + 1)
    return torch.tensor(hitCounts, dtype = torch.float32)

def getMarginalImprovement(utility):
    """Return the gain of each step of a utility tensor.

    Entry i is utility[i+1] - utility[i]; a trailing 0 is appended so the result
    has the same length as `utility`.
    """
    stepGains = utility[1:] - utility[:-1]
    noFurtherGain = torch.zeros(1)
    return torch.cat( [stepGains, noFurtherGain] )

In [6]:
#candidates: [candidate, 1 - pMHC hit probability]
#columnIndex: [diplotype, allele in diplotype]
#columnWeights: [diplotype]
#distributions: [dummy, diplotype, distribution]
#marginalImprovement: [improvement (shifting from i to i+1, so last entry should be 0)]

def evaluateCandidates(candidates, columnIndex, columnWeights, distributions, marginalImprovement):
    """Score every candidate by the weighted utility gained from adding it.

    candidates: [candidate, 1 - pMHC hit probability]
    columnIndex: [diplotype, allele in diplotype]
    columnWeights: [diplotype]
    distributions: [dummy, diplotype, distribution]
    marginalImprovement: [improvement] for moving from i to i+1 hits

    Returns a 1-D tensor with one weighted improvement per candidate.
    """
    # A candidate misses a diplotype only if it misses every allele of that diplotype
    missProbability = candidates[:, columnIndex].prod(dim = 2, keepdim = True)
    hitProbability = 1 - missProbability

    # Mass that would move up by one hit, valued at the gain of that move
    gainPerState = (distributions * hitProbability) * marginalImprovement
    gainPerDiplotype = gainPerState.sum(dim = 2)
    return (gainPerDiplotype * columnWeights).sum(dim = 1)

def updateDistribution(newRow, columnIndex, distributions):
    """Fold one selected candidate into the per-diplotype hit-count distributions.

    newRow: the candidate's row of 1 - pMHC hit probabilities
    columnIndex: [diplotype, allele in diplotype]
    distributions: [dummy, diplotype, distribution]

    Returns a new tensor of the same shape; `distributions` is left untouched.
    """
    missProbability = newRow[columnIndex].prod(dim = 1).reshape(1, -1, 1)
    hitProbability = 1 - missProbability

    # Mass that stays at its current hit count
    updated = distributions * missProbability
    # Mass that gains one hit moves one state to the right ...
    advanced = distributions * hitProbability
    updated[:, :, 1:] += advanced[:, :, :-1]
    # ... except in the last state, which keeps its own mass
    updated[:, :, -1] += advanced[:, :, -1]
    return updated

def evaluateDesign(candidates, seqs, columnIndex, columnWeights, design, utility):
    """Compute the weighted expected utility of a finished design.

    candidates: [candidate, 1 - pMHC hit probability]
    seqs: sequences, seqs[i] belongs to row i of candidates
    columnIndex: [diplotype, allele in diplotype]
    columnWeights: [diplotype]
    design: sequences of the design; each must occur in seqs
    utility: utility of 0, 1, 2, ... hits

    Returns the score as a numpy value.
    """
    rowOfSeq = {seq: row for row, seq in enumerate(seqs)}

    # Before any sequence is added every diplotype has zero hits with certainty
    coverage = torch.zeros( (1, len(columnIndex), len(utility)) )
    coverage[:, :, 0] = 1

    for row in (rowOfSeq[seq] for seq in design):
        coverage = updateDistribution(candidates[row], columnIndex, coverage)

    # Expected utility per diplotype, then the weighted total
    expectedUtility = (coverage * utility.view(1,1,-1)).sum(dim = 2).reshape(-1)
    return (expectedUtility * columnWeights).sum().numpy()

class evaluateCandidatesModule(torch.nn.Module):
    """Scores candidates on one CUDA device; the per-device form of evaluateCandidates.

    The weights and marginal improvements are moved to the device once at
    construction. The coverage distributions change every round and are pushed
    with updateDistributions before forward is called.
    """

    def __init__(self, columnIndex, columnWeights, marginalImprovement, device):
        super().__init__()
        self.device = device
        # columnIndex is only used for indexing and is stored as given
        self.columnIndex = columnIndex
        self.columnWeights = columnWeights.cuda(device)
        self.marginalImprovement = marginalImprovement.cuda(device)

    def updateDistributions(self, distributions):
        """Store a copy of the current coverage distributions on this module's device."""
        self.distributions = distributions.cuda(self.device)

    def forward(self, candidates):
        """Return the weighted improvement of each candidate row as a CPU tensor."""
        # A candidate misses a diplotype only if it misses every allele of that diplotype
        missProbability = candidates[:, self.columnIndex].prod(dim = 2, keepdim = True)
        hitProbability = 1 - missProbability
        gainPerState = (self.distributions * hitProbability) * self.marginalImprovement
        gainPerDiplotype = gainPerState.sum(dim = 2)
        weightedGain = (gainPerDiplotype * self.columnWeights).sum(dim = 1)
        return weightedGain.cpu()

In [7]:
def greedySelectionMulticore(candidates,
                             seqs,
                             columnIndex,
                             columnWeights,
                             designSize,
                             marginalImprovement,
                             threshold,
                             batchSize,
                             devices):
    """Greedily build a design of up to `designSize` sequences, scoring candidates on several GPUs.

    Every round scores all candidate rows with one evaluateCandidatesModule per
    device, adds the best still-selectable sequence, folds it into the coverage
    distributions and rules out every sequence within `threshold` Levenshtein
    distance of the sequence just added.

    candidates: [candidate, 1 - pMHC hit probability]
    seqs: sequences, seqs[i] belongs to row i of candidates
    columnIndex: [diplotype, allele in diplotype]
    columnWeights: [diplotype]
    designSize: maximum number of sequences to select
    marginalImprovement: [improvement] for moving from i to i+1 hits (last entry should be 0)
    threshold: sequences at or below this Levenshtein distance from a selected
        sequence are no longer considered
    batchSize: controls how many candidate rows are sent to a device at once
    devices: list of CUDA devices to use

    Returns the selected sequences in the order they were chosen. The list is
    shorter than designSize when no selectable candidate is left (and empty when
    there are no candidates at all).

    Raises ValueError if `devices` is empty or `batchSize` is not positive.
    """
    if len(devices) == 0:
        raise ValueError("devices must contain at least one CUDA device")
    if batchSize <= 0:
        raise ValueError("batchSize must be a positive integer")

    numRows = candidates.shape[0]
    if numRows == 0:
        # Nothing to choose from
        return []

    # Set up modules on different devices
    modules = [evaluateCandidatesModule(columnIndex, columnWeights, marginalImprovement, device)
               for device in devices]

    # Distribute the computation between devices.
    # At least one round of slices is needed even when there are fewer rows than
    # len(devices) * batchSize; without the max() the next line divides by zero.
    numVertical = max(1, numRows//(len(devices) * batchSize))
    sliceSize = (numRows//(len(devices) * numVertical)) + 1
    lastSliceIndex = numVertical * len(devices) - 1
    slices = []
    z = 0
    for _ in range(numVertical):
        singleSlice = []
        if z*sliceSize >= numRows: break
        # Every round holds exactly one (possibly empty) slice per device, because
        # parallel_apply pairs modules and inputs one to one.
        for device in devices:
            if z == lastSliceIndex:
                singleSlice.append( candidates[z*sliceSize:].cuda(device) )
            else:
                singleSlice.append( candidates[z*sliceSize:(z+1)*sliceSize].cuda(device) )
            z += 1
        slices.append(singleSlice)

    # Initialize selected set and score
    selectedSet = []
    score = 0
    selectable = np.ones(numRows, dtype = bool)

    # Initialize coverage distributions
    distributions = torch.zeros( (1, len(columnIndex), len(marginalImprovement)) )
    distributions[:, :, 0] = 1

    numberOfSlices = len(slices)
    pbarDescription = "Sequence added: None, Objective: 0.00000, Delta: 0.00000, Iteration: {}/{}".format(
        "{}", numberOfSlices)
    with tqdm(range(designSize), position = 0, leave = True) as pbar:
        for _ in pbar:
            # Stop early once every candidate has been selected or excluded
            if not selectable.any():
                break

            # Update distributions in modules
            for module in modules:
                module.updateDistributions(distributions)

            # Compute marginal utilities
            allImprovements = []
            # We need to batch the following vector operations due to space limitations
            for sliceIndex, singleSlice in enumerate(slices):
                improvements = parallel.parallel_apply(modules, singleSlice)
                allImprovements.append(torch.cat(improvements))
                pbar.set_description(pbarDescription.format(sliceIndex+1))
            allImprovements = torch.cat(allImprovements).numpy()

            # Argmax over selectable candidates only. Excluded rows get -inf rather
            # than 0, so they cannot be chosen even when every remaining
            # improvement is 0.
            maskedImprovements = np.where(selectable, allImprovements, -np.inf)
            selection = np.argmax(maskedImprovements)

            # Add best sequence
            selectedSeq = seqs[selection]
            selectedSet.append(selectedSeq)

            # Update score
            delta = allImprovements[selection]
            score += delta

            pbarDescription = "Sequence added: {}, Objective: {:.5f}, Delta: {:.5f}, Iteration: {}/{}".format(
                selectedSeq, score, delta, "{}", numberOfSlices)

            # Update distributions for next round
            distributions = updateDistribution(candidates[selection], columnIndex, distributions)

            # Remove invalid candidates from consideration
            selectable[selection] = False
            for i, seq in enumerate(seqs):
                # Rows that are already excluded need no further distance check
                if selectable[i] and lev.distance(seq, selectedSeq) <= threshold:
                    selectable[i] = False

    torch.cuda.empty_cache()
    return selectedSet

In [8]:
# Indices of the available GPUs (one entry per CUDA device reported by torch)
list(range(torch.cuda.device_count()))

[0, 1, 2, 3, 4, 5, 6, 7]

<h3>Experiments</h3>

In [52]:
# The arguments to greedySelectionMulticore are as follows:
# 1. A tuple consisting of credences for individual allele binding, a list of peptides, a list of diplotypes
#   where each diplotype is given as a list of alleles, and a set of weights for each diplotype
#     We provide "inputs#_credences" and "inputs#_binarized" as possible inputs
#     which correspond to the credences we derived (see Section 3.2) and credences from Liu et al. respectively.
# 2. The number of peptides in the vaccine design
# 3. The marginal improvement of the utility as an array. The ith entry of the array should
#   contain the marginal improvement for going from i to i+1
# 4. The Levenshtein distance threshold. Peptides that are within this threshold of the peptides
#   that have already been selected will not be considered for inclusion
# 5. The batch size. This is required because of memory limitations on the GPU. The larger this value the better.
#   It seems values between 10-50 work relatively well
# 6. The list of devices to use. list(range(torch.cuda.device_count())) should enumerate all available devices

In [None]:
# Generate designs using both credences that we derived and using 0-1 binarized credences that
# match with those used in Liu et al.
# Designs will be dumped in ./outputs/

# Create the output folder up front so that a finished GPU run is not lost to a
# FileNotFoundError when its design is written.
os.makedirs("./outputs", exist_ok = True)

# (output file label, greedy inputs, Levenshtein distance threshold), in the order
# the experiments are run for every utility threshold.
experiments = [
    ("mhc1_credences", inputs1_credences, 3),   # MHC Class 1, using credences
    ("mhc2_credences", inputs2_credences, 5),   # MHC Class 2, using credences
    ("mhc1_binarized", inputs1_binarized, 3),   # MHC Class 1, using binarized values
    ("mhc2_binarized", inputs2_binarized, 5),   # MHC Class 2, using binarized values
]

for threshold in range(3, 21):
    for label, inputs, levenshteinThreshold in experiments:
        x = greedySelectionMulticore(*inputs,
                   151,
                   getMarginalImprovement( getThresholdUtility(threshold) ),
                   levenshteinThreshold,
                   20,
                   list(range(torch.cuda.device_count())))
        with open("./outputs/{}_threshold_{}.pkl".format(label, threshold), 'wb') as fout:
            pickle.dump(x, fout)