In [1]:
# Mount Google Drive at /content/drive; every data, vocabulary, config and
# checkpoint path used later in this notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): both installs are unpinned. The model cell relies on the
# `vocab_size_or_config_json_file` keyword of TransfoXLConfig and on a
# three-element return from TransfoXLModel, so a matching transformers
# release should be pinned here -- confirm which version was used.
!pip3 install torch
!pip3 install transformers

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
from transformers import BertTokenizer, BertModel, BertForTokenClassification, BertConfig
from transformers import TransfoXLTokenizer, TransfoXLModel, TransfoXLConfig
from keras.preprocessing.sequence import pad_sequences
import sys
import numpy as np
import pandas as pd
import itertools
from operator import itemgetter 

Using TensorFlow backend.


In [0]:
def read_data(filepath):
  """Load k-mer sequences and their per-position labels from a tab-separated file.

  The first line of the file is skipped and only the columns at positions 1
  and 2 are read. Each cell is expected to be a bracketed, ", "-separated
  list: single-quoted k-mers in the first column, numbers in the second.

  Args:
    filepath: path of the tab-separated file.

  Returns:
    (genes, labels): a list of k-mer string lists and a parallel list of
    float lists, one entry per row of the file.
  """
  print("reading ",filepath)
  table = pd.read_csv(filepath,usecols=[1,2],sep="\t",header=None,skiprows=1)
  print(table.head())

  def split_cell(cell):
    # Drop the enclosing brackets, then split on the list separator.
    return cell[1:-1].split(", ")

  genes = []
  labels = []
  for kmer_cell, label_cell in zip(table.iloc[:, 0], table.iloc[:, 1]):
    genes.append([token.strip("\'") for token in split_cell(kmer_cell)])
    labels.append([float(value) for value in split_cell(label_cell)])
  return genes, labels
	

def tokenize_samples(genes):
  """Convert every k-mer of every gene into its vocabulary id.

  The k-mer length is inferred from the first k-mer of the first gene and
  selects which vocabulary file the TransfoXLTokenizer is built from.

  Args:
    genes: list of genes, each a list of k-mer strings.

  Returns:
    A list with one list of token ids per gene.

  Raises:
    ValueError: if `genes` (or its first gene) is empty, or if the inferred
      k-mer length has no vocabulary file configured.
  """
  if len(genes) == 0 or len(genes[0]) == 0:
    raise ValueError("tokenize_samples needs at least one non-empty gene to infer k")
  k = len(genes[0][0])
  # NOTE(review): k == 2 is paired with a file named "fourmersXL.txt";
  # confirm whether this key was meant to be 4.
  vocab_by_k = {
    2: '/content/drive/My Drive/fourmersXL.txt',
    6: '/content/drive/My Drive/hexamersXL.txt',
  }
  if k not in vocab_by_k:
    # Previously this fell through and crashed later on an unbound local.
    raise ValueError(
        "unsupported k-mer length %d; expected one of %s" % (k, sorted(vocab_by_k)))
  kmer_filepath = vocab_by_k[k]

  tokenizer=TransfoXLTokenizer(vocab_file=kmer_filepath)
  print("TOKENIZER LENGTH", len(tokenizer))
  seq_ids = [tokenizer.convert_tokens_to_ids(gene) for gene in genes]
  return seq_ids

In [4]:
# Parse the k-mer / label file from Drive, then map every k-mer to its
# vocabulary id. `genes`, `labels` and `gene_ids` are parallel lists that the
# attention-tally cell further down slices together.
genes,labels = read_data('/content/drive/My Drive/all_samples_6-mer_test.txt')
gene_ids = tokenize_samples(genes)
# Number of tokenized samples.
print(len(gene_ids))
print("Finished making data")

reading  /content/drive/My Drive/all_samples_6-mer_test.txt
                                                   1                                                  2
0  ['GGGAGG', 'GGGGCA', 'CGGGCT', 'ATAAAC', 'GCTC...  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
1  ['TCTGCC', 'CGGCTC', 'CCCAGC', 'GCCCCC', 'GGGC...  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
2  ['GGGGCC', 'GCCGCT', 'CTGGCC', 'CGCGTG', 'GGGC...  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
3  ['CGACGC', 'CGACAA', 'CTTTGC', 'GATGGA', 'GTTT...  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
4  ['GGGCCG', 'CTCTTG', 'CCCGGC', 'GTGGCG', 'ACTC...  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
TOKENIZER LENGTH 4098
3027
Finished making data


In [0]:
from transformers import TransfoXLTokenizer, TransfoXLModel, TransfoXLConfig
class Model(torch.nn.Module):
  """Transformer-XL wrapper that exposes attention maps instead of predictions.

  The attribute names `model` and `out_layer` are kept as-is because the
  saved state_dict loaded later in the notebook is keyed by them.
  """

  def __init__(self):
    super().__init__()
    config = TransfoXLConfig(vocab_size_or_config_json_file='/content/drive/My Drive/classification_XL_configuration.json')
    # Override the vocabulary size from the JSON file and ask the model to
    # return its attention tensors alongside the hidden states.
    config.vocab_size=204098
    config.output_attentions=True
    self.config = config
    self.model = TransfoXLModel(config)
    # Two-way linear head; it is constructed here but forward() does not call it.
    self.out_layer = torch.nn.Linear(self.model.d_model, 2)

  def forward(self, input_ids, mems=None):
    """Run one window through the network.

    Args:
      input_ids: token-id tensor for the window.
      mems: memory returned by the previous call, or None.

    Returns:
      (mems, first_layer_att, second_layer_att): the updated memory and the
      first two attention entries as NumPy arrays, each with a leading
      size-1 dimension squeezed away.
    """
    # The underlying model is expected to return exactly three values.
    _, new_mems, attentions = self.model(input_ids, mems)
    layer_maps = [attentions[layer].squeeze(0).detach().cpu().numpy() for layer in (0, 1)]
    return new_mems, layer_maps[0], layer_maps[1]

# A CUDA device is required; there is no CPU fallback in this cell.
device = torch.device('cuda')
#building model
model = Model()
model.to(device)
# Restore the trained weights from the checkpoint stored on Drive.
f="/content/drive/My Drive/xl_classification_6mer_sd_ftMKII.pt"
model.load_state_dict(torch.load(f))
# Number of tokens fed to the model per forward pass
# (the stride used by get_most_attended_hexamers).
window_size=1012
# Number of leading attention columns that the most_attended_* helpers drop
# as "attention to memory".
mem_length = 1012

In [0]:
#only returns the indices in the seq (have to be converted to actual hexamers afterwards)
def most_attended_pos_per_head_in_window(atts):
  """Return, for each of the first six heads, the 20 top-ranked window positions.

  Args:
    atts: per-head attention maps for one window; `atts[h]` must be 2-D.

  Returns:
    A list of six index arrays (ascending by summed attention, at most 20
    entries each). The indices are positions, not hexamers.
  """
  def top_positions(head_map):
    # Skip the first `mem_length` columns, i.e. the attention paid to memory.
    in_window = head_map[:, mem_length:]
    # NOTE(review): axis=1 yields one total per *row* of the map. If rows are
    # query positions (TODO confirm against the model's attention layout) this
    # ranks positions by how much they attend inside the window, not by how
    # much attention they receive; that would need axis=0.
    totals = np.sum(in_window, axis=1)
    return np.argsort(totals)[-20:]

  return [top_positions(atts[head]) for head in range(6)]

#this one actually gets hexamers
def most_attended_hexamers_per_head_in_window(atts,gene):
  """Return, for each of the first six heads, the hexamers at the 20 top-ranked positions.

  Args:
    atts: per-head attention maps for one window; `atts[h]` must be 2-D.
    gene: the window's hexamers, indexable by the ranked positions.

  Returns:
    A list of six hexamer lists (ascending by summed attention, at most 20
    entries each).
  """
  hexamers_per_head = []
  for head in range(6):
    # Skip the first `mem_length` columns, i.e. the attention paid to memory.
    window_only = atts[head][:, mem_length:]
    # NOTE(review): axis=1 yields one total per *row* of the map. If rows are
    # query positions (TODO confirm against the model's attention layout) this
    # ranks positions by how much they attend inside the window, not by how
    # much attention they receive; that would need axis=0.
    ranking = np.argsort(np.sum(window_only, axis=1))
    hexamers_per_head.append([gene[position] for position in ranking[-20:]])
  return hexamers_per_head
    

def make_dict_with_all_hexamers():
  """Build a fresh counter dict holding every DNA hexamer, each mapped to 0.

  Keys are generated in itertools.product order over the alphabet
  A, T, C, G, so every dict produced here shares the same key order.
  """
  alphabet = ['A','T','C','G']
  every_hexamer = (''.join(letters) for letters in itertools.product(alphabet, repeat=6))
  return dict.fromkeys(every_hexamer, 0)

def add_hexamers_close_to_ss(labels,genes,dic):
  """Count the k-mers located at, or directly next to, every position labelled 1.

  Args:
    labels: per-position labels of one gene; positions equal to 1 are the
      sites of interest.
    genes: the k-mers of that same gene, parallel to `labels`.
    dic: counter dict that is updated in place; it must already contain
      every k-mer that can occur (a missing key raises KeyError).

  Returns:
    None. `dic` is mutated.
  """
  num_kmers = len(genes)
  one_indices = np.where(np.array(labels)==1)[0]
  for i in one_indices:
    for neighbour in (i-1, i, i+1):
      # Stay inside the sequence: without this check a site at index 0 would
      # wrap around to the last k-mer (genes[-1]) and a site at the final
      # index would raise IndexError.
      if 0 <= neighbour < num_kmers:
        dic[genes[neighbour]]+=1


#want to both compute how often each hexamer is paid attention to as well as how often each hexamer is at or adjacent to an actual ss. Then we can correlate these

def _tally_hexamers(freqs_per_head, hexamers_per_head):
  """Add one count per listed hexamer to the matching head's frequency dict."""
  for head_freqs, head_hexamers in zip(freqs_per_head, hexamers_per_head):
    for hexamer in head_hexamers:
      head_freqs[hexamer]+=1


def get_most_attended_hexamers(genes,tok_genes,labels):
  """Tally which hexamers each attention head ranks highest, and which sit near labelled sites.

  Every gene is fed to the global `model` in consecutive windows of
  `window_size` tokens; the memory returned by one window is passed to the
  next window and reset to None at the start of each gene.

  Args:
    genes: list of genes, each a list of hexamer strings.
    tok_genes: the same genes as lists of token ids (parallel to `genes`).
    labels: per-position labels for each gene (parallel to `genes`).

  Returns:
    (first_layer_freqs, second_layer_freqs, hexamers_near_ss): two lists of
    six hexamer->count dicts (one per head) and a single hexamer->count dict
    for hexamers at or adjacent to positions labelled 1.
  """
  model.eval()
  num_samples=len(genes)
  first_layer_hexamer_freqs_per_head=[make_dict_with_all_hexamers() for _ in range(6)]
  second_layer_hexamer_freqs_per_head=[make_dict_with_all_hexamers() for _ in range(6)]
  hexamers_near_ss=make_dict_with_all_hexamers()
  for sample_idx in range(num_samples):
    # Progress marker every 100 samples.
    if sample_idx%100==0:
      print(sample_idx)
    mems = None
    curr_tok_gene = tok_genes[sample_idx]
    curr_gene = genes[sample_idx]
    curr_labels = labels[sample_idx]
    add_hexamers_close_to_ss(curr_labels,curr_gene,hexamers_near_ss)
    #looping over all windows
    for w in range(0, len(curr_gene), window_size):
      toks = curr_tok_gene[w:w+window_size]
      #ignore tiny windows(leads to dimensionality issues)
      if len(toks)<2:
        continue
      # Place the input on the same device the model was moved to.
      window_input_ids = torch.tensor(toks).unsqueeze(0).to(device)
      window_hexamers = curr_gene[w:w+window_size]
      with torch.no_grad():
        mems, first_layer_att, second_layer_att = model(window_input_ids, mems)
      first_layer_hexamers = most_attended_hexamers_per_head_in_window(first_layer_att,window_hexamers)
      second_layer_hexamers = most_attended_hexamers_per_head_in_window(second_layer_att,window_hexamers)
      # The head loop lives in the helper so it no longer reuses (and
      # overwrites) the sample index of the outer loop.
      _tally_hexamers(first_layer_hexamer_freqs_per_head, first_layer_hexamers)
      _tally_hexamers(second_layer_hexamer_freqs_per_head, second_layer_hexamers)
  return first_layer_hexamer_freqs_per_head, second_layer_hexamer_freqs_per_head, hexamers_near_ss


In [8]:

# Run the attention tally on the first 1800 samples only; the three inputs
# are sliced identically so they stay parallel.
first_layer, second_layer, hexamers_near_ss=get_most_attended_hexamers(genes[:1800],gene_ids[:1800],labels[:1800])

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700


In [9]:
# Grand totals of the head-0 tallies of both layers and of the
# near-splice-site tally. The loop walks the keys of first_layer[0] and uses
# them to index the other two dicts, so all three must share the same keys.
first_layer_sum=0
second_layer_sum=0
nearby_sum=0
for el in first_layer[0]:
  first_layer_sum+=first_layer[0][el]
  second_layer_sum+=second_layer[0][el]
  nearby_sum+=hexamers_near_ss[el]
print(first_layer_sum)
print(second_layer_sum)
print(nearby_sum)

# Twenty highest-count hexamers of each tally, in descending order.
# NOTE(review): the first-layer ranking uses head index 0 while the
# second-layer ranking uses head index 5 -- confirm the mismatch is intended.
topn_first = dict(sorted(first_layer[0].items(), key = itemgetter(1), reverse = True)[:20])
topn_second = dict(sorted(second_layer[5].items(), key = itemgetter(1), reverse = True)[:20])
topn_labels = dict(sorted(hexamers_near_ss.items(), key = itemgetter(1), reverse = True)[:20])
print("Top N hexamers (first layer) " + str(topn_first))
print("Top N hexamers(second layer) " + str(topn_second))
print("Most significant irl " + str(topn_labels))


157458
157458
147369
Top N hexamers (first layer) {'TTGAGA': 2060, 'CTCAAA': 1999, 'TGTGTT': 1937, 'TCTCAA': 1841, 'GGTGGA': 1762, 'TTCTCT': 1574, 'TCTCCC': 1445, 'CATGAG': 1259, 'CTTCCT': 1235, 'CTGCCC': 1171, 'TTCCCC': 1153, 'CTCCAC': 1071, 'ACGGGG': 920, 'CCGGCC': 898, 'GAGGGA': 859, 'TTGTCA': 794, 'CCCCTT': 789, 'GTTTGG': 786, 'CCCCGC': 782, 'CTCTGC': 767}
Top N hexamers(second layer) {'CCCAGG': 427, 'AAAAAA': 375, 'TACAGG': 369, 'CAGGTG': 319, 'GGTGGG': 313, 'TTTTTT': 291, 'GCCAGG': 284, 'GGCAGG': 250, 'AGGAGG': 249, 'CTGAGG': 247, 'GGCTGG': 232, 'CCCCAG': 226, 'CTCAGG': 224, 'CCAGGC': 221, 'TCCAGG': 218, 'CCTGGG': 212, 'CAGGCT': 211, 'GCTGGG': 211, 'GGGTGG': 209, 'GGTGTG': 208}
Most significant irl {'GGTGAG': 921, 'CAGGTG': 610, 'AGGTGA': 532, 'GTGAGT': 530, 'CCAGGT': 406, 'CCCAGG': 395, 'CCCCAG': 356, 'GGTAAG': 349, 'CAGGTA': 343, 'AAAAAA': 316, 'GTGAGG': 301, 'CTGCAG': 297, 'GCAGGT': 295, 'AGGTAA': 273, 'TGAGTG': 267, 'CCACAG': 261, 'TCCAGG': 260, 'TGCAGG': 257, 'GGTGGG': 257, 

In [168]:
#Correlation Visualization
import matplotlib.pyplot as plt
import seaborn as sns
# Both dicts come from make_dict_with_all_hexamers(), so their values are in
# the same hexamer order and can be paired position by position.
near_label_values=list(hexamers_near_ss.values())
# Head index 5 of the first layer is the one being compared here.
first_layer_values=list(first_layer[5].values())
print(np.corrcoef(near_label_values,first_layer_values))
# x and y are passed by keyword: recent seaborn releases reject positional
# x/y with a TypeError, while the keyword form works on old and new versions.
f=sns.scatterplot(x=near_label_values,y=first_layer_values)
f.set_title("Attention Weights vs Proximity to Splice Sites")
f.set_xlabel("Counts of Hexamer Adjacent to SS")
f.set_ylabel("Counts of Hexamer Given High Importance")
plt.show()

TypeError: ignored

In [41]:


# Normalise every tally by its own total. The totals are computed from the
# dicts instead of being hard-coded, so the fractions stay correct when the
# tally is re-run on a different slice of the data. `or 1` keeps an all-zero
# tally from dividing by zero.
label_total = sum(hexamers_near_ss.values()) or 1
normalized_label_counts = {k: round(v / label_total,4) for k, v in hexamers_near_ss.items()}
top5_labels= dict(sorted(normalized_label_counts.items(), key = itemgetter(1), reverse = True)[:5])
print("Most common hexamers at/adjacent to splice site " + str(top5_labels))
for i in range(6):
  layer1_total = sum(first_layer[i].values()) or 1
  layer2_total = sum(second_layer[i].values()) or 1
  # Total first-layer count of hexamers containing the substring 'AGGT'.
  num_AGGs = sum(v for k, v in first_layer[i].items() if 'AGGT' in k)
  normalized_layer1 = {k: round(v / layer1_total,4) for k, v in first_layer[i].items()}
  normalized_layer2 = {k: round(v / layer2_total,4) for k, v in second_layer[i].items()}
  print("Head ",i)
  topn_first = dict(sorted(normalized_layer1.items(), key = itemgetter(1), reverse = True)[:5])
  topn_second = dict(sorted(normalized_layer2.items(), key = itemgetter(1), reverse = True)[:5])
  print("Most attended to hexamers (first layer)    \t" + str(topn_first))
  print("Most attended to hexamers (second layer)   \t" + str(topn_second))
  print(num_AGGs/layer1_total)

Most common hexamers at/adjacent to splice site {'GGTGAG': 0.0062, 'CAGGTG': 0.0041, 'AGGTGA': 0.0036, 'GTGAGT': 0.0036, 'CCAGGT': 0.0028}
Head  0
Most attended to hexamers (first layer)    	{'TTGAGA': 0.0131, 'CTCAAA': 0.0127, 'TGTGTT': 0.0123, 'TCTCAA': 0.0117, 'GGTGGA': 0.0112}
Most attended to hexamers (second layer)   	{'TTTTTT': 0.0079, 'AAAAAA': 0.0053, 'ATTTTT': 0.0023, 'CCTCCC': 0.0018, 'TTTTTG': 0.0016}
0.0031436954616469153
Head  1
Most attended to hexamers (first layer)    	{'TGGGAG': 0.0127, 'TGTAAT': 0.0108, 'GGAGGG': 0.0108, 'CCTGTA': 0.0106, 'CACACA': 0.0095}
Most attended to hexamers (second layer)   	{'TTTTTT': 0.0077, 'AAAAAA': 0.0032, 'ATTTTT': 0.0022, 'CCCAGG': 0.002, 'AGGCTG': 0.0016}
0.002680079767303027
Head  2
Most attended to hexamers (first layer)    	{'TTTTTT': 0.1187, 'CTGCCT': 0.0295, 'TTTGTT': 0.019, 'AATACA': 0.0141, 'TTTCAC': 0.0137}
Most attended to hexamers (second layer)   	{'AAAAAA': 0.0068, 'TTTTTT': 0.0049, 'CAGGAG': 0.0022, 'CAGCCT': 0.0019, 'CCC

In [54]:
# Provides the `Bio` package imported by the PWM analysis cell.
# NOTE(review): the install is unpinned; consider fixing the version.
!pip install biopython

Collecting biopython
[?25l  Downloading https://files.pythonhosted.org/packages/96/01/7e5858a1e54bd0bd0d179cd74654740f07e86fb921a43dd20fb8beabe69d/biopython-1.75-cp36-cp36m-manylinux1_x86_64.whl (2.3MB)
[K     |▏                               | 10kB 24.5MB/s eta 0:00:01[K     |▎                               | 20kB 30.4MB/s eta 0:00:01[K     |▍                               | 30kB 35.7MB/s eta 0:00:01[K     |▋                               | 40kB 39.6MB/s eta 0:00:01[K     |▊                               | 51kB 40.8MB/s eta 0:00:01[K     |▉                               | 61kB 42.7MB/s eta 0:00:01[K     |█                               | 71kB 34.2MB/s eta 0:00:01[K     |█▏                              | 81kB 35.6MB/s eta 0:00:01[K     |█▎                              | 92kB 37.1MB/s eta 0:00:01[K     |█▍                              | 102kB 35.4MB/s eta 0:00:01[K     |█▋                              | 112kB 35.4MB/s eta 0:00:01[K     |█▊                       

In [15]:
#PWM analysis
from Bio import motifs
from Bio.Seq import Seq

# Tally to turn into a motif; swap in one of the per-head dicts
# (e.g. first_layer[i]) to analyse an attention head instead.
layer_of_interest = hexamers_near_ss
seq_instances = []
for el in layer_of_interest:
  # Each hexamer is repeated count // 40 times, so the instance list is a
  # down-scaled copy of the tally; hexamers counted fewer than 40 times
  # contribute nothing.
  seq_instances.extend([Seq(el) for i in range(layer_of_interest[el]//40)])
print(len(seq_instances))
m = motifs.create(seq_instances)
# Per-position base frequencies and the consensus derived from them.
print(m.pwm)
print(m.consensus)
# Writes a sequence-logo image to Drive.
# NOTE(review): presumably this contacts the WebLogo web service and needs
# network access -- confirm.
m.weblogo("/content/drive/My Drive/testmotif.png")

1883
        0      1      2      3      4      5
A:   0.18   0.18   0.19   0.20   0.20   0.20
C:   0.29   0.28   0.27   0.27   0.27   0.26
G:   0.31   0.33   0.33   0.34   0.34   0.37
T:   0.22   0.21   0.20   0.19   0.19   0.17

GGGGGG


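The third head of the first layer (printed as `Head  2` above) attends overwhelmingly to all-T hexamers: `TTTTTT` alone accounts for 0.1187 of its top-attended counts.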

The fourth and fifth heads look for all-G hexamers.