# Clustering TF Binding Data Rationales

In [1]:
import os
import numpy as np
from collections import defaultdict

from sklearn.cluster import DBSCAN

In [6]:
# Name of the dataset to analyze. It is used both as the sub-directory name
# under the rationale results and as the lookup key into the known-motifs map.
DATASET = 'wgEncodeAwgTfbsSydhK562MaffIggrabUniPk'

# Directory with this dataset's rationale outputs, and the two files read from it.
DATASET_PATH = os.path.join('../rationale_results/motif/motif_occupancy', DATASET)
DIST_MATRIX_FILENAME = 'rationales_greedy_dists.txt.gz'  # square rationale-by-rationale distance matrix (.gz text)
RATIONALES_FILENAME = 'rationales_greedy.txt'  # one "i [rationale]" entry per non-blank line

# File mapping a dataset name to a motif file name, and the directory that
# those motif file names are resolved against.
KNOWN_MOTIFS_MAP_PATH = '../data/motif/known_motifs/map'
KNOWN_MOTIFS_MEME_PATH = '../data/motif/known_motifs/ENCODEmotif'

In [7]:
def load_motifs_map(path):
    """Read a two-column, whitespace-delimited mapping file into a dict.

    Every non-blank line must consist of exactly two whitespace-separated
    fields, ``<motif> <loc>``. Blank and whitespace-only lines (for example a
    trailing empty line at the end of the file) are ignored.

    Args:
        path: Path of the text file to read.

    Returns:
        dict mapping the first field of each line to the second. If a first
        field occurs more than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields.
    """
    motifs_map = {}
    with open(path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Nothing on this line; unpacking it would otherwise crash.
                continue
            if len(fields) != 2:
                raise ValueError(
                    '%s:%d: expected 2 whitespace-separated fields, got %d'
                    % (path, line_no, len(fields)))
            motif, loc = fields
            motifs_map[motif] = loc
    return motifs_map

def parse_meme(filepath, replace_zeros_eps=None):
    """Extract the 4-column numeric matrix from a tab-separated motif file.

    A line contributes a matrix row only if, after stripping, it splits on
    tabs into exactly four fields that all parse as floats. Every other line
    (headers, annotations, blank lines) is skipped.

    Args:
        filepath: Path of the motif file to read.
        replace_zeros_eps: If not None, every entry equal to 0 is replaced by
            this value and each row is then divided by its L1 norm. If None,
            the parsed values are returned unchanged.

    Returns:
        float32 numpy array of shape (n_rows, 4); (0, 4) when the file holds
        no matrix rows.
    """
    rows = []
    with open(filepath, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) != 4:  # wrong column count, not part of matrix
                continue
            try:
                rows.append([float(x.strip()) for x in fields])
            except ValueError:
                # Four tab-separated fields that are not all numeric, e.g. a
                # column-header line: not part of the matrix either.
                continue
    # reshape keeps the result 2-D even with zero rows, so the axis=1 norm
    # below (and callers indexing columns) keep working.
    res = np.array(rows, dtype='float32').reshape(-1, 4)
    if replace_zeros_eps is not None:
        res[res == 0] = replace_zeros_eps
        # Re-normalize so each row sums (in absolute value) to 1 again.
        res = res / np.linalg.norm(res, axis=1, ord=1, keepdims=True)
    return res

In [8]:
# Look up which known-motif file belongs to DATASET and parse its matrix.
motifs_map = load_motifs_map(KNOWN_MOTIFS_MAP_PATH)
meme_path = os.path.join(KNOWN_MOTIFS_MEME_PATH, motifs_map[DATASET])
# Zeros are replaced by 1e-6 and each row is rescaled to unit L1 norm (see parse_meme).
motif = parse_meme(meme_path, replace_zeros_eps=1e-6)

## Load distance matrix and rationales

In [10]:
def load_distance_matrix(filepath):
    """Load an integer distance matrix from a whitespace-delimited text file.

    Args:
        filepath: Path of the text file to read; passed straight to
            ``np.loadtxt``.

    Returns:
        2-D integer numpy array. ``ndmin=2`` guarantees two dimensions even
        when the file holds a single row or a single value, so callers can
        always read ``shape[0]`` and ``shape[1]``.
    """
    dists = np.loadtxt(filepath, dtype=int, ndmin=2)
    return dists

def load_rationales(filepath):
    """Return the rationale string found on each non-blank line of a file.

    Each line has the form "i [rationale]"; only the text after the last
    space is kept (the whole stripped line if it contains no space).
    Lines that are empty after stripping are skipped.
    """
    with open(filepath, 'r') as f:
        stripped_lines = (raw.strip() for raw in f)
        return [text.rpartition(' ')[2] for text in stripped_lines if text != '']

In [11]:
# Both inputs live in the dataset's results directory.
dist_matrix_filepath = os.path.join(DATASET_PATH, DIST_MATRIX_FILENAME)
rationales_filepath = os.path.join(DATASET_PATH, RATIONALES_FILENAME)

dists = load_distance_matrix(dist_matrix_filepath)
rationales = load_rationales(rationales_filepath)

# The matrix must be square with exactly one row/column per rationale.
assert len(rationales) == dists.shape[0] == dists.shape[1]

print('Loaded %d rationales and distance matrix.' % len(rationales))

Loaded 2110 rationales and distance matrix.


## DBSCAN Clustering

In [12]:
# Median over every entry of the distance matrix (diagonal included).
# NOTE(review): presumably a scale reference for choosing DBSCAN's eps — confirm.
print('Distance matrix median: ', np.median(dists))

Distance matrix median:  10.0


In [13]:
# Compute DBSCAN on the precomputed distance matrix.
# The clustering hyperparameters are named so they can be tuned in one place
# instead of being buried in the call.
DBSCAN_EPS = 2.0         # passed as DBSCAN's `eps`, in the units of `dists`
DBSCAN_MIN_SAMPLES = 46  # passed as DBSCAN's `min_samples`

db = DBSCAN(eps=DBSCAN_EPS, min_samples=DBSCAN_MIN_SAMPLES, metric='precomputed').fit(dists)
labels = db.labels_

# Boolean mask over all samples: True where DBSCAN reported a core sample.
core_samples_mask = np.zeros_like(labels, dtype=bool)
core_samples_mask[db.core_sample_indices_] = True

# Number of clusters in labels, ignoring noise (label -1) if present.
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)

print('Estimated number of clusters: %d' % n_clusters_)

print(labels)

Estimated number of clusters: 2
[-1 -1 -1 ... -1  0 -1]


In [14]:
# Dump the clusters: highlight core sample rationales in each cluster, followed by random examples in cluster

# Indices of the core samples, grouped by the cluster label of each one.
core_sample_idxs = np.arange(len(core_samples_mask))[core_samples_mask]
cluster_label_to_core_sample_idxs = defaultdict(list)
for i in core_sample_idxs:
    cluster_label_to_core_sample_idxs[labels[i]].append(i)

# Single pass over every sample: record which cluster it belongs to and count
# how often each trimmed rationale occurs within that cluster.
cluster_label_to_cluster = defaultdict(list)
cluster_label_to_rationale_freqs = defaultdict(lambda: defaultdict(int))
for i in range(dists.shape[0]):
    label = labels[i]
    cluster_label_to_cluster[label].append(i)
    # 'N' positions are shown as '-', and leading/trailing '-' are dropped.
    rationale = rationales[i].replace('N', '-').strip('-')
    cluster_label_to_rationale_freqs[label][rationale] += 1

In [15]:
max_num_to_print_per_cluster = 15

# Walk the clusters in ascending label order (noise, label -1, comes first).
for label, freq_dict in sorted(cluster_label_to_rationale_freqs.items(), key=lambda kv: kv[0]):
    header = '--NOISE--' if label == -1 else '--Cluster %d--' % label
    print(header)
    # Most frequent first; ties broken by the rationale string, also descending.
    ranked = sorted(freq_dict.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
    for rationale, freq in ranked[:max_num_to_print_per_cluster]:
        print('%s\t\t%d' % (rationale, freq))
    print()

--NOISE--
G-TGACTCAGCA--T		14
AAA-TGC----TCAGCA--A		9
AAA-TGC---GTCAGC		8
AAA-TGC----TCAGCA--T		8
TGCTGACTCA-C---T		7
TGCTGA-TCAGC---T		7
A--TGCTGA----GCA-ATT		7
GCTGACTCAGC---T		6
AAT-TGC---GTCAGC		6
A--TGC--AGTCATC		6
A--TG-TGAGTCAGC		6
TT-TGC---GTCAGC--TT		5
TGCTGACTCA-CA--A		5
TGACTCAGCA-AA		5
TGCTGA-TCA-CA-AA		4

--Cluster 0--
GCTGAGTCAT		197
ATGACTCAGC		185
GCTGAGTCA-C		83
GCTGAGTCAC		53
GCTGACTCAGCA		42
TGCTGA-TCAT		32
TGCTGAGTCA		30
TGCTGA-TCAGCA		25
GTGACTCAGCA		15
G-TGAGTCATC		15
TG-TGAGTCAT		13
TGCTG-GTCAT		12
CGCTGAGTCA		11
ATGAC-CAGCA-A		11
ATGA-TCAGCA-A		11

--Cluster 1--
TGCTGA----GCA-TTT		12
GCTGAC---GCA-TTT		8
TGCTGAC---GCA-TT		6
TGCTGAC---GCA-AA		5
TGCTGAC---GCA-AT		4
GCTGAC---GCA-TAT		3
GCTGAC---GCA-ATT		3
TGCTGAC---GCA-TA		2
TGCTGAC---GCA--ATT		1
GCTGAC---GCA-AT		1
CGCTGAC---GCA-AT		1

