In [None]:
%matplotlib inline

import random
from collections import Counter
from itertools import islice

import numpy as np
import matplotlib.pyplot as plt
from umi_tools import UMIClusterer


Define a function that simulates sequence amplification. Start with a list of random sequences, each with its own "intrinsic" amplification bias. A group is one starting sequence together with all of its descendants. At every cycle, each copy in a group is duplicated with a random probability that is shifted by the group's intrinsic bias. Each duplication introduces random point mutations at a constant per-base rate.

In [None]:
def mutated_sequence(seq, mut_rate, alphabet='ACGT'):
  """Return a copy of `seq` with random point substitutions.

  Each position is independently substituted with probability `mut_rate`.
  A substituted position always receives a symbol from `alphabet` that differs
  from the original one. This makes `mut_rate` the true per-base substitution
  rate; redrawing from the full alphabet would re-pick the original base
  1/len(alphabet) of the time and silently lower the effective rate.

  Args:
    seq: Sequence string to copy.
    mut_rate: Per-base substitution probability, in [0, 1].
    alphabet: Symbols to substitute from (default: the four DNA bases).

  Returns:
    The possibly-mutated sequence, with the same length as `seq`.
  """
  mut_seq = []
  for base in seq:
    # random.random() is in [0, 1), so `<` fires with probability exactly
    # mut_rate and never fires when mut_rate == 0.
    if random.random() < mut_rate:
      candidates = [b for b in alphabet if b != base]
      # If no other symbol is available, the base cannot change.
      mut_seq.append(random.choice(candidates) if candidates else base)
    else:
      mut_seq.append(base)
  return "".join(mut_seq)

def amplify_library(size, length, cycles, bias_mean, bias_sigma, dup_mean, dup_sigma, mut_rate):
  """Simulate cycles of biased, error-prone duplication of random sequences.

  Args:
    size: Number of random starting sequences.
    length: Length of each starting sequence.
    cycles: Number of duplication cycles.
    bias_mean, bias_sigma: Normal distribution of the per-group intrinsic
      bias. One draw per starting sequence, shared by all its descendants.
    dup_mean, dup_sigma: Normal distribution of the per-copy, per-cycle
      duplication probability, to which the group bias is added.
    mut_rate: Per-base mutation rate passed to mutated_sequence for each
      new copy.

  Returns:
    dict mapping each starting sequence to a Counter of every sequence
    descended from it, including itself. Identical starting sequences are
    likely when `size` is large relative to 4**length. Their groups are merged
    under the shared key, so no simulated copies are silently dropped.
  """
  bases = ['A', 'C', 'G', 'T']
  starting_seqs = ["".join(random.choices(bases, k=length)) for _ in range(size)]
  biases = np.random.normal(bias_mean, bias_sigma, size)
  amplified_seqs = [[seq] for seq in starting_seqs]

  for _ in range(cycles):
    prev_seqs = amplified_seqs
    amplified_seqs = []
    for seq_group, bias in zip(prev_seqs, biases):
      amplified_group = []
      for seq in seq_group:
        amplified_group.append(seq)
        # Draws of p <= 0 never duplicate and p >= 1 always duplicate, so
        # p does not need explicit clipping to [0, 1].
        p = np.random.normal(dup_mean, dup_sigma) + bias
        if random.random() < p:
          amplified_group.append(mutated_sequence(seq, mut_rate))
      amplified_seqs.append(amplified_group)

  # Merge instead of overwrite: a dict comprehension here would keep only the
  # last group for any duplicated starting sequence.
  amplified_counts = {}
  for start_seq, seq_group in zip(starting_seqs, amplified_seqs):
    amplified_counts.setdefault(start_seq, Counter()).update(seq_group)

  return amplified_counts



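Test the output of the function; re-run this cell with different parameters to see their effect. The histogram shows how many copies descend from each starting sequence.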

In [None]:
# Seed both RNGs used by the simulation so this cell is reproducible.
random.seed(0)
np.random.seed(0)

amplified = amplify_library(size=30000, length=10, cycles=8, bias_mean=0, bias_sigma=0.05, dup_mean=0.5, dup_sigma=0.05, mut_rate=0.005)
# Total number of copies descended from each starting sequence.
unique_counts = [counts.total() for counts in amplified.values()]

fig, ax = plt.subplots()
# One integer-wide bin per copy number. The edges must run past the maximum,
# otherwise the largest groups fall outside the last bin and are not drawn.
ax.hist(unique_counts, bins=range(max(unique_counts) + 2))
ax.set(title='Simulated amplification', xlabel='Copies per starting sequence', ylabel='Number of starting sequences')
plt.show()

Make sure `amplify_library` is outputting something that makes sense.

In [None]:
# Print the first 10 groups; islice stops early instead of walking the whole dict.
for seq in islice(amplified, 10):
  print(f'{seq}: {amplified[seq]}')

Test converting the dictionary of `Counter`s into a single `Counter` containing all observed sequences. Summing the `Counter` objects with `sum()` is very slow, because every `+` builds a brand-new `Counter` from both operands, so the total work grows quadratically. Instead, flatten all groups into one list and build a single `Counter` from it.

In [None]:
# Flatten every group into one list of sequences and count it once (linear),
# rather than repeatedly adding Counters together.
all_seqs = [seq for seq_group in amplified.values() for seq in seq_group.elements()]
all_seqs_counts = Counter(all_seqs)
# For the first 10 starting sequences, compare the count inside its own group
# with its pooled count across all groups. The pooled count is never smaller;
# it is larger when other groups also contain copies of that sequence.
for seq in islice(amplified, 10):
  print(f'{seq}: {amplified[seq][seq] :>3} {all_seqs_counts[seq] :>3}')

Define a function that computes the per-position consensus of a set of sequences, so that each cluster's consensus can be compared with its most frequent sequence.

In [None]:
def get_consensus(seqs):
  """Return the per-position majority consensus of a list of sequences.

  Columns are built with zip, so every sequence is effectively truncated to
  the length of the shortest one. An empty input yields "".

  Within a column, the most common base wins. Counter.most_common orders equal
  counts by first appearance, so a tie goes to the tied base that occurs in
  the earliest sequence. For the expanded cluster list built by cluster_seqs,
  that is the cluster's first member — presumably its most abundant UMI in
  umi_tools' ordering (TODO confirm).
  """
  return "".join(
    Counter(column).most_common(1)[0][0]
    for column in zip(*seqs)
  )

get_consensus(['ACGT', 'ACTG', 'TGCA'])

Use umi_tools to cluster the sequences and see if the output matches the expected values for different clustering parameters.

In [None]:
def cluster_seqs(seqs, method, threshold, consensus=False):
  """Cluster all observed sequences with umi_tools and regroup their counts.

  Args:
    seqs: dict of Counters as returned by amplify_library (group key ->
      Counter of sequence -> copies). Group keys are ignored; the counts of
      every sequence are pooled across groups before clustering.
    method: cluster_method passed to umi_tools' UMIClusterer
      (e.g. 'directional').
    threshold: Threshold passed to the clusterer call.
    consensus: If True, key each cluster by its count-weighted consensus
      sequence (get_consensus). Otherwise, key it by the cluster's first
      member as returned by umi_tools, which is expected to be the sequence
      with the highest frequency in the cluster.

  Returns:
    dict mapping each cluster key to a Counter of that cluster's sequences
    (decoded str) — the same structure as amplify_library's output. If two
    clusters produce the same key (possible with consensus=True), their
    Counters are merged instead of one overwriting the other.
  """
  # umi_tools expects bytes keys. Pool counts directly instead of expanding
  # one list element per copy. Non-positive counts are skipped, matching what
  # Counter.elements() would yield; insertion order matches the first time
  # each sequence is seen.
  all_seqs_counts = Counter()
  for seq_group in seqs.values():
    for seq, count in seq_group.items():
      if count > 0:
        all_seqs_counts[seq.encode()] += count
  clusterer = UMIClusterer(cluster_method=method)
  clustered = clusterer(all_seqs_counts, threshold=threshold)

  cluster_counts = {}
  for cluster in clustered:
    # Keeps the cluster's member order, which get_consensus uses for ties.
    cluster_counter = Counter({member.decode(): all_seqs_counts[member] for member in cluster})
    if consensus:
      # Expand each member by its count so the consensus is count-weighted.
      key = get_consensus(list(cluster_counter.elements()))
    else:
      key = cluster[0].decode()
    cluster_counts.setdefault(key, Counter()).update(cluster_counter)

  return cluster_counts

Test `cluster_seqs` on simulated data. Try with low mutation rate and long sequences first to show that `cluster_seqs` is working properly.

In [None]:
# Seed both RNGs so this run is reproducible.
random.seed(0)
np.random.seed(0)

amplified = amplify_library(size=3000, length=20, cycles=8, bias_mean=0, bias_sigma=0.05, dup_mean=0.5, dup_sigma=0.05, mut_rate=0.001)
# Directional clustering with threshold 3.
clustered = cluster_seqs(amplified, 'directional', 3)

# For the first 10 simulated groups, show the cluster keyed by the same sequence.
for seq in islice(amplified, 10):
  print(f'Original : {amplified[seq]}')
  if seq in clustered:
    print(f'Clustered: {clustered[seq]}')
  else:
    print(f'{seq} not in clustered seqs')


Try again with shorter sequences and higher mutation rate.

In [None]:
# Seed both RNGs so this run is reproducible.
random.seed(0)
np.random.seed(0)

amplified = amplify_library(size=3000, length=12, cycles=8, bias_mean=0, bias_sigma=0.05, dup_mean=0.5, dup_sigma=0.05, mut_rate=0.01)
# Directional clustering with threshold 3.
clustered = cluster_seqs(amplified, 'directional', 3)

# For the first 10 simulated groups, show the cluster keyed by the same sequence.
for seq in islice(amplified, 10):
  print(f'Original : {amplified[seq]}')
  if seq in clustered:
    print(f'Clustered: {clustered[seq]}')
  else:
    print(f'{seq} not in clustered seqs')