# Setup

In [37]:
import os
import json
import random

import numpy as np
import pandas as pd

from IPython.display import display

In [26]:
# Directory from which the chunk CSV files of sequences and annotations are read.
SEQS_AND_ANNOTATIONS_CHUNKS_DIR = '/cs/phd/nadavb/cafa_project/data/seqs_and_annotations'

# Token ids reserved for non-residue symbols; residue tokens are numbered right after them.
PAD_TOKEN, MASK_TOKEN = 0, 1
N_SPECIAL_TOKENS = 2

# One residue symbol per uppercase Latin letter, 'A' through 'Z'.
PROT_LETTERS = [chr(code) for code in range(ord('A'), ord('Z') + 1)]

# Longest sequence kept for training, and the number of most frequent annotations that get encoded.
MAX_SEQ_LEN = 1000
N_ANNOTATIONS = 5000

# Set relevant annotations

In [29]:
# Load the occurrence count of every annotation as a Series indexed by the annotation id.
# The `squeeze` keyword of read_csv was deprecated in pandas 1.4 and removed in pandas 2.0, so the
# single-column frame is squeezed into a Series explicitly instead.
annotation_counts = pd.read_csv('/cs/phd/nadavb/cafa_project/data/unique_annotations_counts.csv', \
        index_col = 0, names = ['count']).squeeze('columns')
display(annotation_counts)

10019    14077131
4205      4645778
2706      3605141
4396      2874126
24685     2413000
           ...   
44088           1
1871            1
11143           1
43873           1
11461           1
Name: count, Length: 26032, dtype: int64

In [34]:
# Keep the first N_ANNOTATIONS entries of the counts (listed from most to least frequent in the output above).
assert N_ANNOTATIONS <= len(annotation_counts)
used_annotation_counts = annotation_counts.head(N_ANNOTATIONS)
print('The rarest used annotation appears %d times.' % (used_annotation_counts.min(),))

# The ids of the annotations that will be encoded, in the same order as their counts.
unique_annotations = [annotation for annotation in used_annotation_counts.index]

The rarest used annotation appears 477 times.


# Encoding functions

In [88]:
# Residue letter -> token id; ids start right after the special tokens and follow the PROT_LETTERS order.
aa_to_token = dict(zip(PROT_LETTERS, range(N_SPECIAL_TOKENS, N_SPECIAL_TOKENS + len(PROT_LETTERS))))
# Annotation id -> its position within unique_annotations.
annotation_to_index = dict(zip(unique_annotations, range(len(unique_annotations))))
# Set form of the used annotations, for fast membership and intersection tests.
unique_annotations_set = set(annotation_to_index)

def encode_seqs(batch_seqs, max_len = MAX_SEQ_LEN):
    '''
    Turn a batch of sequences into a fixed-width matrix of token ids.
    
    batch_seqs: a sized iterable of sequences (strings), each at most max_len letters long.
    max_len: the number of columns of the returned matrix.
    
    Returns an int8 array of shape (len(batch_seqs), max_len). Row i starts with the aa_to_token ids of the
    letters of sequence i; the rest of the row holds PAD_TOKEN. A sequence longer than max_len fails the
    assertion, and a letter missing from aa_to_token raises KeyError.
    '''
    
    token_matrix = np.full((len(batch_seqs), max_len), PAD_TOKEN, dtype = np.int8)
    
    # Each row is a view into token_matrix, so writing into it fills the matrix in place.
    for row, seq in zip(token_matrix, batch_seqs):
        seq_len = len(seq)
        assert seq_len <= max_len
        row[:seq_len] = [aa_to_token[letter] for letter in seq]
    
    return token_matrix

def encode_annotations(batch_annotations, annotation_index = None):
    '''
    Encode the annotations of each sample as a binary (multi-hot) row vector.
    
    batch_annotations: a sized iterable (e.g. a list or a pandas Series) holding one iterable of annotation
    ids per sample.
    annotation_index: optional dict mapping annotation id -> column index; its values are expected to be the
    integers 0..len(annotation_index) - 1. When omitted, the module-level annotation_to_index is used, with
    one column per entry of unique_annotations.
    
    Returns an int8 array of shape (len(batch_annotations), n_columns) in which entry [i, j] is 1 if sample i
    has the annotation mapped to column j and 0 otherwise. Annotations absent from the mapping are ignored.
    '''
    
    if annotation_index is None:
        annotation_index = annotation_to_index
        n_columns = len(unique_annotations)
    else:
        n_columns = len(annotation_index)
    
    encoded_annotations = np.zeros((len(batch_annotations), n_columns), dtype = np.int8)
    
    # Walk over the input annotations. Enumerating the freshly allocated zero matrix instead would never
    # look at the input at all.
    for i, annotations in enumerate(batch_annotations):
        for annotation in annotations:
            if annotation in annotation_index:
                encoded_annotations[i, annotation_index[annotation]] = 1
                
    return encoded_annotations

def _load_chunk_sample(chunk_file_name, min_encoded_annotations_per_seq, max_consecutive_samples_per_file):
    '''
    Load one chunk CSV, drop the rows that cannot be used, and return a random sample of the remaining rows.
    
    Rows are dropped when their sequence is longer than MAX_SEQ_LEN, or when fewer than
    min_encoded_annotations_per_seq of their annotations belong to unique_annotations_set. At most
    max_consecutive_samples_per_file of the remaining rows are returned, in random order.
    '''
    
    print('Loading %s.' % chunk_file_name)
    chunk = pd.read_csv(os.path.join(SEQS_AND_ANNOTATIONS_CHUNKS_DIR, chunk_file_name))
    # The annotation lists are stored as JSON text in the CSV.
    chunk['complete_go_annotation_indices'] = chunk['complete_go_annotation_indices'].apply(json.loads)
    
    seq_len_mask = (chunk['seq'].str.len() <= MAX_SEQ_LEN)
    print('Filtering out %d of %d sequences, due to too long (>%d) length.' % ((~seq_len_mask).sum(), \
            len(chunk), MAX_SEQ_LEN))
    chunk = chunk[seq_len_mask]
    
    # astype(bool) keeps the mask boolean even when apply() runs on an empty chunk and so cannot infer the
    # dtype from any value.
    sufficient_annotations_mask = chunk['complete_go_annotation_indices'].apply(lambda annotations: \
            len(set(annotations) & unique_annotations_set) >= min_encoded_annotations_per_seq).astype(bool)
    print('Filtering out %d of %d sequences, due to insufficient (at least %d) encoded annotations.' % \
            ((~sufficient_annotations_mask).sum(), len(chunk), min_encoded_annotations_per_seq))
    chunk = chunk[sufficient_annotations_mask]
    
    sample_size = min(max_consecutive_samples_per_file, len(chunk))
    print('Sampling %d of %d sequences.' % (sample_size, len(chunk)))
    return chunk.sample(sample_size)

def generate_batches(batch_size = 32, min_encoded_annotations_per_seq = 2, max_consecutive_samples_per_file = 10000):
    '''
    Endlessly yield training batches drawn from the chunk files in SEQS_AND_ANNOTATIONS_CHUNKS_DIR.
    
    batch_size: the number of samples in every yielded batch (must be positive).
    min_encoded_annotations_per_seq: samples with fewer annotations among the used ones are skipped.
    max_consecutive_samples_per_file: the maximal number of rows sampled from a file each time it is visited.
    
    Yields (encoded_seqs, encoded_annotations) tuples as produced by encode_seqs and encode_annotations.
    The files are visited in a new random order on every pass over the dataset, and rows left over after the
    last full batch of a file are carried over to the next file.
    
    Raises ValueError (on the first next() call, since this is a generator) if batch_size is not positive or
    if the chunks directory is empty.
    '''
    
    if batch_size < 1:
        # A non-positive batch size would make the batching loop below yield empty batches forever.
        raise ValueError('batch_size must be positive (got %r).' % (batch_size,))
    
    chunk_file_names = os.listdir(SEQS_AND_ANNOTATIONS_CHUNKS_DIR)
    
    if len(chunk_file_names) == 0:
        # Without any file the endless loop below would spin forever without yielding anything.
        raise ValueError('No chunk files found in %s.' % SEQS_AND_ANNOTATIONS_CHUNKS_DIR)
    
    current_chunk = None
    
    while True:
        
        print('Starting an iteration on the entire dataset.')
        random.shuffle(chunk_file_names)
        
        for chunk_file_name in chunk_file_names:
            
            new_chunk = _load_chunk_sample(chunk_file_name, min_encoded_annotations_per_seq, \
                    max_consecutive_samples_per_file)
            
            if current_chunk is None:
                current_chunk = new_chunk
            else:
                current_chunk = pd.concat([current_chunk, new_chunk]) 
                
            while len(current_chunk) >= batch_size:
                
                batch_data = current_chunk.iloc[:batch_size]
                current_chunk = current_chunk.iloc[batch_size:]
                
                batch_encoded_seqs = encode_seqs(batch_data['seq'])
                batch_encoded_annotations = encode_annotations(batch_data['complete_go_annotation_indices'])
                
                yield batch_encoded_seqs, batch_encoded_annotations

# Encode target seqs

In [94]:
# NOTE(review): `target_seqs` is not assigned in any cell visible here, so this cell presumably relies on
# kernel state left by a cell that is not shown (or was removed) - confirm it survives Restart & Run All.
display(target_seqs)

Unnamed: 0,taxa_id,cafa_id,uniprot_name,seq,complete_go_annotation_indices,seq_len
49385,10090,T100900015051,TITIN_MOUSE,MTTQAPMFTQPLQSVVVLEGSTATFEAHVSGSPVPEVSWFRDGQVI...,[],35213
31041,9606,T96060017620,TITIN_HUMAN,MTTQAPTFTQPLQSVVVLEGSTATFEAHISGFPVPEVSWFRDGQVI...,[],34350
97089,6239,T62390003241,TTN1_CAEEL,MEGNEKKGGGLPPTQQRHLNIDTTVGGSISQPVSPSMSYSTDRETV...,"[135, 2818, 3561, 4178, 4205, 6231, 10256, 106...",18562
86722,7227,T72270003175,TITIN_DROME,MQRQNPNPYQQQNQQHQQVQQFSSQEYSHSSQEQHQEQRISRTEQH...,[],18141
24127,9606,T96060010706,MUC16_HUMAN,MLKPSGLPGSSSPTRSLMTGSRSTKATPEMDSGLTGATLSPKTSTG...,"[1541, 1542, 1545, 1573, 1696, 1749, 2001, 200...",14507
...,...,...,...,...,...,...
752,4577,T45770000753,UC22_MAIZE,IFFEV,[],5
30735,9606,T96060017314,TDB01_HUMAN,GTGG,[],4
31923,9606,T96060018502,TUFT_HUMAN,TKPR,[],4
57851,9823,T98230001354,TRH_PIG,QHP,[],3


In [120]:
TARGET_SEQS_CSV_FILE_PATH = '/cs/phd/nadavb/cafa_project/data/target_seqs_expanded_annotations.csv.gz'

def generate_target_seqs_batches(batch_size = 32, min_max_len = 1000):
    '''
    Yield all the target sequences once, in batches, from the longest sequence to the shortest.
    
    batch_size: the maximal number of sequences per batch (the last batch may be smaller).
    min_max_len: the minimal width of the encoded sequence matrix; a batch containing a longer sequence is
    encoded with the width of its longest sequence instead.
    
    Yields (row_index_values, cafa_ids, encoded_seqs, encoded_annotations) tuples, where the last two are
    produced by encode_seqs and encode_annotations.
    '''
    
    target_seqs = pd.read_csv(TARGET_SEQS_CSV_FILE_PATH)
    # read_csv loads the annotation lists as text, so parse them into real lists (as generate_batches does for
    # the training chunks). Otherwise encode_annotations would iterate over the characters of the text.
    target_seqs['complete_go_annotation_indices'] = target_seqs['complete_go_annotation_indices'].apply(\
            lambda raw_annotations: json.loads(raw_annotations) if isinstance(raw_annotations, str) \
            else raw_annotations)
    target_seqs['seq_len'] = target_seqs['seq'].str.len()
    target_seqs = target_seqs.sort_values('seq_len', ascending = False)
    
    while len(target_seqs) > 0:
        
        batch_seqs = target_seqs.iloc[:batch_size]
        target_seqs = target_seqs.iloc[batch_size:]
    
        # Sorting by length keeps similarly sized sequences together, so each batch is padded only up to its
        # own longest sequence (but never below min_max_len).
        batch_encoded_seqs = encode_seqs(batch_seqs['seq'], max_len = max(min_max_len, batch_seqs['seq_len'].max()))
        batch_encoded_annotations = encode_annotations(batch_seqs['complete_go_annotation_indices'])

        yield batch_seqs.index.values, batch_seqs['cafa_id'].values, batch_encoded_seqs, batch_encoded_annotations