This notebook contains the code for making a balanced 50/50 train/test split of a dataset using clustering.
Specifically the procedure for making these clusters is:
1. Split sequences into partitions based on the AHo positions at which they have gaps.
2. Run clustering on the sequences inside each partition. DBSCAN clustering is used with parameters that make it equivalent to single linkage clustering, with the minimum distance between clusters indicated by the `eps` values in the filenames of the fasta files in this folder.
3. For each partition, the 50/50 split is made by sorting the clusters in the partition according to size and iterating through them in a pairwise fashion, e.g. (1, 2), (3, 4), ... (n-1, n). The larger cluster of each pair is down-sampled to the size of the smaller, and the two clusters are then assigned to opposite splits (alternating between pairs which split receives the down-sampled cluster) and written to either the train or the test split fasta file. If a partition has an odd number of clusters, the last (smallest) cluster is left out.

In [None]:
# Generic imports:
from __future__ import print_function
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
#import pandas as pd
import math, random, re
import time
import pickle
from Bio import SeqIO
import multiprocessing
import sys, os

In [3]:
# Amino acid alphabet; the last character '-' is the gap symbol:
AA_ORDER = 'ACDEFGHIKLMNPQRSTVWY-'
AA_LIST = list(AA_ORDER)
# Lookup tables: residue -> index and index -> residue:
AA_DICT = dict(zip(AA_LIST, range(len(AA_LIST))))
AA_DICT_REV = dict(enumerate(AA_LIST))
AA_SET = set(AA_ORDER)

In [4]:
try:
    import jellyfish
except ImportError:
    # jellyfish is an optional dependency; only a failed import should
    # trigger the fallback (a bare `except:` would also hide unrelated
    # errors such as KeyboardInterrupt).
    def hamming_distance(seq1, seq2):
        '''Hamming distance between two sequences of equal length'''
        return sum(x != y for x, y in zip(seq1, seq2))
else:
    def hamming_distance(s1, s2):
        '''
        Hamming distance between two sequences, computed by jellyfish.
        Identical sequences are short-circuited to 0 without calling jellyfish.
        '''
        if s1 == s2:
            return 0
        else:
            return jellyfish.hamming_distance(s1, s2)

In [5]:
def calculate_pairwise_dists(x):
    '''
    Compute the Hamming distance for every pair of sequences in `x`
    and keep those that are smaller than or equal to DIST_CUT.

    x: indexable collection of sequences (a list in this notebook).

    Returns a dict mapping the index pair (i, j), with i <= j, to the distance.
    The self pairs (i, i) are included with distance 0.
    Raises AssertionError if more than MAX_DISTS distances are kept.
    '''
    N = 0
    dists = dict()
    n_seqs = len(x)
    for i in range(n_seqs):
        si = x[i]
        # Index directly instead of slicing x[i:], which would copy the
        # remainder of the list on every outer iteration:
        for j in range(i, n_seqs):
            # Use the module level wrapper so the pure Python fallback
            # is used when jellyfish is not installed:
            d = hamming_distance(si, x[j])
            if d <= DIST_CUT:
                # j never runs below i, so (i, j) is already in sorted order:
                dists[(i, j)] = d
                N += 1
                assert N <= MAX_DISTS, 'More than MAX_DISTS distances found in a partition'
    return(dists)

In [6]:
# Limits and thresholds used by the cells below:
MAX_SEQS = 10 ** 8            # Stop reading the input once this many records have been seen
MAX_LEN = 149                 # Longer sequences are skipped, shorter ones are gap-padded to this length
DIST_CUT = 15                 # Pairwise distances above this value are not stored
MAX_DISTS = 10 ** 9           # Abort if a partition yields more stored distances than this
MIN_SEQS_PER_PARTITION = 100  # Partitions with fewer sequences than this are dropped

# Cache files; the expensive cells are skipped when these exist:
seq_gap_dict_filename = 'seq_gap_dict.p'
dist_dict_filename = 'dist_dict.p'

In [7]:
# Read the sequences and partition into buckets according to their gap profile.
# The partitioning is cached in `seq_gap_dict_filename`; delete that file to force a re-read.
if os.path.isfile(seq_gap_dict_filename):
    # NOTE: unpickling can execute arbitrary code; only load a cache file written by this notebook.
    with open(seq_gap_dict_filename, 'rb') as fh:
        seq_gap_dict = pickle.load(fh)
else:
    # Read in some sequences:
    fnam = '../BCR_data/spurf_heavy_chain_AHo.fasta'
    seq_list = list()
    for i, record in enumerate(SeqIO.parse(fnam, 'fasta')):
        if i >= MAX_SEQS:
            break
        seq = str(record.seq)
        if len(seq) > MAX_LEN:
            continue
        # Right-pad with gaps so that all sequences have length MAX_LEN:
        seq_list.append(seq + '-' * (MAX_LEN - len(seq)))

    if not seq_list:
        raise ValueError('No usable sequences found in {}'.format(fnam))
    print('Input data has {} sequences.'.format(len(seq_list)))
    print('This is how a typical sequence looks:\n{}'.format(seq_list[0]))

    # Partition the sequences by gap positions:
    seq_gap_dict = dict()
    for seq in seq_list:
        gap_key = tuple(i for i, nt in enumerate(seq) if nt == '-')
        seq_gap_dict.setdefault(gap_key, []).append(seq)
    print('Found {} partitions.'.format(len(seq_gap_dict)))

    # Dump the results to the same file that the cache check above looks for:
    with open(seq_gap_dict_filename, 'wb') as fho:
        pickle.dump(seq_gap_dict, fho)

# Sort the gap keys according to number of sequences in the partition, largest first
# (the sort is stable, so equally sized partitions keep their dict order):
gap_key_sorted = sorted(seq_gap_dict, key=lambda ks: len(seq_gap_dict[ks]), reverse=True)
print('The largest partition has {} sequence.'.format(len(seq_gap_dict[gap_key_sorted[0]])))

Input data has 1602878 sequences.
This is how a typical sequence looks:
EVQLVES-GGGLVQPGGSLRLSCAASG-FPSNS-----YWMTWVRQAPGKGLEWVANINED---GSERYYVDSVKGRFTISRDNAKNSQYLQMNSLRAEDTAVYYCTRDVWF---------------------GFFDIWGQGTTVIVSS
Found 268 partitions.
The largest partition has 126369 sequence.


In [8]:
# Calculate all pairwise distances within a bucket
# and keep those that are smaller than or equal to DIST_CUT.
# The result is cached in `dist_dict_filename`; delete that file to force a recalculation.
if os.path.isfile(dist_dict_filename):
    # NOTE: unpickling can execute arbitrary code; only load a cache file written by this notebook.
    with open(dist_dict_filename, 'rb') as fh:
        dist_dict = pickle.load(fh)
else:  # This takes several hours and ~140 gb memory
    # Calculate pairwise distances:
    dist_dict = dict()
    for ks in gap_key_sorted:  # Start from the largest to the smallest
        dist_dict[ks] = calculate_pairwise_dists(seq_gap_dict[ks])
    # Dump results to the same file that the cache check above looks for:
    with open(dist_dict_filename, 'wb') as fho:
        pickle.dump(dist_dict, fho)

In [9]:
# Drop every partition that has fewer than MIN_SEQS_PER_PARTITION
# sequences from both seq_gap_dict and dist_dict:
if MIN_SEQS_PER_PARTITION > 0:
    # Collect the keys first so the dict is not modified while iterating over it:
    small_partitions = [k for k in seq_gap_dict if len(seq_gap_dict[k]) < MIN_SEQS_PER_PARTITION]
    for ks in small_partitions:
        del seq_gap_dict[ks]
        del dist_dict[ks]

In [12]:
# DBSCAN clustering code:
def find_neighbors(dist_dict, seq_gap_dict, eps):
    '''
    Build, for every partition in `dist_dict`, a list with one set per
    sequence holding the indices of all sequences within `eps` of it.

    dist_dict: partition key -> {(i, j): distance}
    seq_gap_dict: partition key -> list of sequences (only its length is used)
    eps: largest distance at which two sequences count as neighbors

    Returns a dict mapping partition key -> list of neighbor index sets.
    '''
    neighbor_list_dict = dict()
    for ks, pair_dists in dist_dict.items():
        n_seqs = len(seq_gap_dict[ks])
        neighbors = [set() for _ in range(n_seqs)]
        for ij, d in pair_dists.items():
            if d <= eps:
                first, second = ij[0], ij[1]
                # The relation is symmetric, so register it in both directions:
                neighbors[first].add(second)
                neighbors[second].add(first)
        neighbor_list_dict[ks] = neighbors
    return(neighbor_list_dict)

def expand_cluster(neighbor_list, minPts, core_neighbor, clusters, visited):
    '''
    Fully expand the newly created cluster based on the neighbor points
    of the first added core point.

    neighbor_list: list with one set of neighbor indices per point (not modified)
    minPts: minimum number of neighbors for a point to count as a core point
    core_neighbor: set of neighbors of the core point that seeded the cluster (not modified)
    clusters: list of (core point set, border point set) tuples; the last entry is filled in place
    visited: 0/1 array with one entry per point; updated in place

    Returns `clusters`.
    '''
    # Work on a private copy: `core_neighbor` is normally one of the sets
    # stored in `neighbor_list` and must not grow as the cluster expands.
    seen = set(core_neighbor)
    pending = list(seen)
    core_points, border_points = clusters[-1]
    # Keep looping until there are no more neighbors to the core points:
    while pending:
        # Extract a neighbor to a core point:
        point_num = pending.pop()
        # Skip points that are already part of a cluster:
        if visited[point_num] == 1:
            continue
        visited[point_num] = 1

        # Make a table lookup to find points
        # within 1 'eps' distance of the given point:
        neighborPts = neighbor_list[point_num]
        if len(neighborPts) >= minPts:  # Core point
            # Find new points that should be added to the search:
            new_points = neighborPts - seen
            # Remember them to avoid searching
            # the same points multiple times:
            seen.update(new_points)
            pending.extend(new_points)
            core_points.add(point_num)
        else:  # Border point
            border_points.add(point_num)
    return(clusters)

def dbscan(neighbor_list, minPts):
    '''
    Run DBSCAN on a list of pre-computed neighbors defined by an eps distance.

    neighbor_list: list with one set of neighbor indices per point
    minPts: minimum number of neighbors for a point to count as a core point

    Returns a list with one tuple per cluster containing two sets,
    one for core points and one for border points. Points that are neither
    core points nor neighbors of a core point (noise) are in no cluster.
    '''
    clusters = list()
    npoints = len(neighbor_list)  # Total number of points
    visited = np.zeros(npoints, dtype=np.int8)  # 0/1 switch to determine if a point has been assigned
    for point_num in range(npoints):
        # Skip points already assigned to a cluster:
        if visited[point_num] == 1:
            continue

        # Make a table lookup to find points
        # within 1 'eps' distance of the given point:
        neighborPts = neighbor_list[point_num]

        # A point without enough neighbors cannot start a cluster. It is
        # deliberately NOT marked as visited, so that a cluster seeded by a
        # later core point can still pick it up as a border point:
        if len(neighborPts) < minPts:
            continue

        # This is a core point, start a new cluster and fully expand it:
        visited[point_num] = 1
        clusters.append(({point_num}, set()))
        clusters = expand_cluster(neighbor_list, minPts, neighborPts, clusters, visited)

    return(clusters)

In [66]:
# Making the 50/50 train/test splits and down-sample to balance them:
def _downsample(cluster, size, rng):
    '''Return `cluster` untouched if it has at most `size` members, else a random sample of `size` members.'''
    if len(cluster) <= size:
        return cluster
    # random.sample() requires a sequence (sets raise TypeError from Python 3.11).
    # Sorting additionally makes the draw reproducible for a given seed:
    try:
        population = sorted(cluster)
    except TypeError:  # Members cannot be ordered
        population = list(cluster)
    return rng.sample(population, size)

def train_test_split(clusters_eps, seed=None):
    '''
    Make a balanced 50/50 train/test split of the clusters in each partition.

    The clusters of a partition are taken two at a time, (0, 1), (2, 3), ...
    The larger cluster of a pair is randomly down-sampled to the size of the
    smaller one, and the two clusters go to opposite splits. The first cluster
    of a pair goes to train for the 1st, 3rd, ... pair and to test for the
    2nd, 4th, ... pair. With an odd number of clusters the last one is left out.

    clusters_eps: partition key -> list of clusters (collections of sequence
                  indices), expected to be sorted from largest to smallest
    seed: optional seed for the down-sampling; the default (None) uses the
          global state of the `random` module

    Returns (train, test), two dicts mapping partition key -> list of members.
    '''
    rng = random if seed is None else random.Random(seed)
    train = dict()
    test = dict()
    for ks in clusters_eps:
        clusters = clusters_eps[ks]
        train_ks = list()
        test_ks = list()
        # There needs to be equal number of clusters in both train and test sets,
        # so only complete pairs are used:
        N_pairs = len(clusters) // 2
        for pair_i in range(N_pairs):
            first = clusters[2 * pair_i]
            second = clusters[2 * pair_i + 1]
            # Downsample whichever cluster is the larger one:
            size = min(len(first), len(second))
            first = _downsample(first, size, rng)
            second = _downsample(second, size, rng)
            # Alternate which split receives the first cluster of the pair:
            if pair_i % 2 == 0:
                train_clust, test_clust = first, second
            else:
                train_clust, test_clust = second, first
            train_ks.extend(train_clust)
            test_ks.extend(test_clust)
        train[ks] = train_ks
        test[ks] = test_ks
    return(train, test)

In [67]:
# Writing the train/test splits:
def dump_train_test_data(train, test, eps, seq_gap_dict, filename_prefix='data_split'):
    '''
    Write the train and test sequences to two fasta files inside the
    'clusters' folder (created if missing).

    train, test: partition key -> list of sequence indices in that partition
    eps: value inserted into the two filenames
    seq_gap_dict: partition key -> list of sequences
    filename_prefix: first part of the two filenames

    Each record is named '<gap positions joined by dashes>_<sequence index>'.
    Raises AssertionError if a partition has indices shared between train and
    test, or if the number of written records differs from the number requested.
    '''
    out_dir = 'clusters'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    train_fnam = 'clusters/{}_trainset_eps{}.fasta'.format(filename_prefix, eps)
    test_fnam = 'clusters/{}_testset_eps{}.fasta'.format(filename_prefix, eps)

    train_printed = 0
    test_printed = 0
    # The context manager closes both files even when an assertion fails:
    with open(train_fnam, 'w') as fh_train, open(test_fnam, 'w') as fh_test:
        for ks in train:
            train_idxs = set(train[ks])
            test_idxs = set(test[ks])
            assert(len(train_idxs.intersection(test_idxs)) == 0)
            # The name prefix is the same for all sequences in a partition:
            ks_str = '-'.join(map(str, ks))
            for idx, seq in enumerate(seq_gap_dict[ks]):
                if idx in train_idxs:
                    print('>{}_{}\n{}'.format(ks_str, idx, seq), file=fh_train)
                    train_printed += 1
                elif idx in test_idxs:
                    print('>{}_{}\n{}'.format(ks_str, idx, seq), file=fh_test)
                    test_printed += 1
    # Every requested index must have resulted in exactly one written record:
    assert(sum([len(train[ks]) for ks in train]) == train_printed)
    assert(sum([len(test[ks]) for ks in test]) == test_printed)

In [68]:
# Produce one train/test split per eps value, where eps is the
# minimum distance between clusters:
eps_list = list(range(1, DIST_CUT+1))
minPts = 1
for eps in eps_list:
    # Neighbors within eps, one list of sets per partition:
    neighbor_key_list = find_neighbors(dist_dict, seq_gap_dict, eps)
    clusters_eps = dict()
    for ks, neighbor_list in neighbor_key_list.items():
        # Keep only the core point set of each cluster (the "edge points"
        # are discarded) and order the clusters from largest to smallest:
        clusters_eps[ks] = sorted((cl[0] for cl in dbscan(neighbor_list, minPts)), key=len, reverse=True)
    # Split in train and test:
    train, test = train_test_split(clusters_eps)
    # Write the train/test data to files:
    dump_train_test_data(train, test, eps, seq_gap_dict)