In [1]:
# Global seed; passed to seed_everything() right before the metadata is generated.
seed = 123
import sys
from pathlib import Path
# NOTE(review): wildcard import; names such as Optional/List come from here implicitly.
from typing import *
import shutil
import time
import pandas as pd
import pickle
import torch
import numpy as np
import os
from torch.utils.data import random_split, Subset
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Make the parent directory importable so that `dataset.data` resolves.
sys.path.insert(0, "..")
from dataset.data import (
    RNADataset,
    ROOT_DIR,
    EasyPosAugment,
    RegionSpecNegAugment,
    InteractionSelectionPolicy,
    EasyNegAugment,
    HardPosAugment,
    HardNegAugment,
    plot_sample,
    #plot_sample2,
    seed_everything,
    MAX_RNA_SIZE,
)

import haiku as hk
import jax
import jax.numpy as jnp
import nucleotide_transformer
from nucleotide_transformer.mypretrained import get_pretrained_model
# NOTE(review): duplicate of the `import time` above; redundant but harmless.
import time
# Single JAX PRNG key, reused for every forward pass in infer().
# NOTE(review): presumably fine because inference is deterministic (no dropout) - confirm.
random_key = jax.random.PRNGKey(0)

In [2]:
# Window/multiplier configuration consumed by the augmentation policies below.
# NOTE(review): the key/value semantics are defined in dataset.data; the values
# look like sampling weights (each dict sums to 1.0) - confirm there.
pos_width_multipliers = {
    20: 0.5,
    MAX_RNA_SIZE: 0.5,
}
# Height deliberately shares the *same* dict object as width (aliased, not copied).
pos_height_multipliers = pos_width_multipliers

# Negative-sample windows: (low, high) ranges mapped to a weight.
neg_width_windows = {(50, 1000): 0.5, (1000, MAX_RNA_SIZE): 0.5}
neg_height_windows = neg_width_windows  # same object as neg_width_windows

In [14]:
# --- Paths (all rooted at the project's ROOT_DIR) ---
# Gene-pair split files (e.g. gene_pairs_val.txt).
rna_rna_files_dir = os.path.join(ROOT_DIR, "dataset", "rna_rna_pairs")
# Preprocessed tables (df_genes.csv, final_df.csv) and everything derived from them.
processed_files_dir = os.path.join(ROOT_DIR, "dataset", "processed_files")
nt_data_dir = os.path.join(processed_files_dir, "nt_data")
# Outputs of this notebook: per-layer embeddings and per-split metadata CSVs.
embedding_dir, metadata_dir = (
    os.path.join(nt_data_dir, "embeddings"),
    os.path.join(nt_data_dir, "metadata"),
)
# Nucleotide Transformer dependencies (checkpoints live under <nt_dir>/checkpoints).
nt_dir = os.path.join(ROOT_DIR, "NT_dependencies")

In [4]:
# Data split to process; it names the metadata CSV (<metadata_dir>/<split>.csv)
# and the per-layer embedding subfolder.
set_data = 'val'
assert set_data in ('training', 'val', 'test')

In [5]:
# Build one metadata row per augmented sample, for every enabled policy.
# Rows are keyed by a running `id_couple`, which later names the saved .npy files.
meta_records = {}
id_couple = 0
_SUBSET_SIZE: int = 1  # NOTE(review): not used anywhere in this notebook
# Per-policy cap on the number of samples drawn (the loop previously broke at
# i == 1000, i.e. after 1001 samples). Set to None to consume the whole dataset.
MAX_SAMPLES_PER_POLICY = 1000

# Gene-pair list for the selected split (previously hard-coded to the 'val' file,
# so other splits silently reused val pairs under a different output name).
# NOTE(review): assumes every split follows the `gene_pairs_<set_data>.txt` naming;
# only the 'val' file name is confirmed by this notebook.
subset_file = os.path.join(rna_rna_files_dir, f"gene_pairs_{set_data}.txt")
if not os.path.exists(subset_file):
    raise FileNotFoundError(f"No gene-pair file for split {set_data!r}: {subset_file}")


def _sample_to_record(sample, id_sample):
    """Flatten one RNADataset sample into a dict (one metadata CSV row).

    The cDNA strings are cropped to the sample's bounding box (x1:x2 on gene1,
    y1:y2 on gene2). The uncropped lengths are kept as original_length1/2.
    """
    bbox, seed_bbox = sample.bbox, sample.seed_interaction_bbox
    cdna1, cdna2 = sample.gene1_info["cdna"], sample.gene2_info["cdna"]
    return {
        'id_sample': id_sample,
        'couple': sample.couple_id,
        'gene1': sample.gene1,
        'gene2': sample.gene2,
        'x1': bbox.x1,
        'x2': bbox.x2,
        'y1': bbox.y1,
        'y2': bbox.y2,
        'interacting': sample.interacting,
        # NOTE(review): misspelled column name kept on purpose; it is part of the
        # saved CSV schema that downstream readers may depend on.
        'policty': sample.policy,
        'cdna1': cdna1[bbox.x1:bbox.x2],
        'cdna2': cdna2[bbox.y1:bbox.y2],
        'seed_x1': seed_bbox.x1,
        'seed_x2': seed_bbox.x2,
        'seed_y1': seed_bbox.y1,
        'seed_y2': seed_bbox.y2,
        'protein_coding1': sample.gene1_info["protein_coding"],
        'protein_coding2': sample.gene2_info["protein_coding"],
        'original_length1': len(cdna1),
        'original_length2': len(cdna2),
    }


seed_everything(seed)
for policy in (
    EasyPosAugment(
        per_sample=1,
        interaction_selection=InteractionSelectionPolicy.LARGEST,
        width_multipliers=pos_width_multipliers,
        height_multipliers=pos_height_multipliers,
    ),
    # EasyNegAugment(
    #     per_sample=1,
    #     width_windows=neg_width_windows,
    #     height_windows=neg_height_windows,
    # ),
    # HardPosAugment(
    #     per_sample=1,
    #     interaction_selection=InteractionSelectionPolicy.RANDOM_ONE,
    #     min_width_overlap=0.3,
    #     min_height_overlap=0.3,
    #     width_multipliers=pos_width_multipliers,
    #     height_multipliers=pos_height_multipliers,
    # ),
    # HardNegAugment(
    #     per_sample=1,
    #     width_windows=neg_width_windows,
    #     height_windows=neg_height_windows,
    # ),
    # NOTE(review): width_windows below uses pos_width_multipliers while
    # height_windows uses neg_height_windows - check before enabling.
    # RegionSpecNegAugment(
    #     per_sample=1,
    #     width_windows=pos_width_multipliers,
    #     height_windows=neg_height_windows,
    # ),
):
    dataset = RNADataset(
        gene_info_path=os.path.join(processed_files_dir, "df_genes.csv"),
        interactions_path=os.path.join(processed_files_dir, "final_df.csv"),
        subset_file=subset_file,
        augment_policies=[policy],
    )
    # tqdm wraps the dataset directly (not enumerate) so it can show a total.
    for n_seen, sample in enumerate(tqdm(dataset, desc=type(policy).__name__), start=1):
        meta_records[id_couple] = _sample_to_record(sample, id_couple)
        id_couple += 1
        if MAX_SAMPLES_PER_POLICY is not None and n_seen >= MAX_SAMPLES_PER_POLICY:
            break

meta = pd.DataFrame.from_dict(meta_records, orient='index')
os.makedirs(metadata_dir, exist_ok=True)  # to_csv does not create parent folders
meta.to_csv(os.path.join(metadata_dir, f'{set_data}.csv'), index=False)

0it [00:00, ?it/s]

### Nucleotide Transformer (NT) embeddings

In [38]:
def retrieve_embeddings(outs, tokens, layer, pad_token_id: Optional[int] = None):
    """Mean-pool one layer's token embeddings over the non-padding positions.

    Args:
        outs: model output mapping; must contain ``f"embeddings_{layer}"`` holding
            a rank-3 array indexed ``[batch, position, hidden]``.
        tokens: integer token ids ``[batch, position]`` that were fed to the model.
            Position 0 is the CLS token and is dropped, as in the embeddings.
        layer: layer number whose embeddings are pooled.
        pad_token_id: id marking padding positions. Defaults to the notebook-global
            ``tokenizer.pad_token_id`` (the original behaviour). Pass it explicitly
            to avoid depending on that global.

    Returns:
        ``[batch, hidden]`` array of mean embeddings.
        NOTE(review): a row made only of padding divides 0 by 0 and yields NaN.
    """
    if pad_token_id is None:
        # Backward-compatible fallback to the tokenizer from the model-loading cell.
        pad_token_id = tokenizer.pad_token_id
    embeddings = outs[f"embeddings_{layer}"][:, 1:, :]  # remove the CLS token
    padding_mask = jnp.expand_dims(tokens[:, 1:] != pad_token_id, axis=-1)
    masked_embeddings = embeddings * padding_mask  # zero out pad-token embeddings
    sequences_lengths = jnp.sum(padding_mask, axis=1)  # real tokens per row, shape (batch, 1)
    return jnp.sum(masked_embeddings, axis=1) / sequences_lengths

def infer(sequences, forward_fn, tokenizer, parameters, random_key):
    """Tokenize a batch of sequences and run one forward pass.

    Args:
        sequences: list of nucleotide strings.
        forward_fn: transformed model exposing ``apply(parameters, random_key, tokens)``.
        tokenizer: object whose ``batch_tokenize(sequences)`` returns one
            ``(token_strings, token_ids)`` pair per sequence.
        parameters: model parameters passed through to ``forward_fn.apply``.
        random_key: PRNG key passed through to ``forward_fn.apply``.

    Returns:
        ``(outs, tokens)``: the output of ``forward_fn.apply`` and the int32
        token-id array fed to it.
    """
    # Tokenize once; the token strings are not needed here.
    # NOTE(review): jnp.asarray requires all rows to have equal length, which
    # presumably relies on batch_tokenize padding the batch - confirm.
    tokenized = tokenizer.batch_tokenize(sequences)
    tokens = jnp.asarray([token_ids for _, token_ids in tokenized], dtype=jnp.int32)
    outs = forward_fn.apply(parameters, random_key, tokens)
    return outs, tokens

def save_data_to_folder(data, labels, ids, folder_path):
    """Save each item of ``data`` as ``<folder_path>/class_<k>/<id>.npy``.

    ``data``, ``labels`` and ``ids`` are indexed in parallel. A label equal to 0
    goes to ``class_0``; any other label goes to ``class_1``. Missing directories
    are created on demand.
    """
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    for idx in range(len(data)):
        row, label, id_sample = data[idx], labels[idx], ids[idx]
        subdir = 'class_0' if label == 0 else 'class_1'
        target_dir = os.path.join(folder_path, subdir)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        np.save(os.path.join(target_dir, f'{id_sample}.npy'), row)

In [46]:
# Split the metadata rows into batches for the NT forward passes: batch k covers
# rows slices[k]:slices[k+1]. Every batch has `batch_size` rows except possibly the
# last. (The previous np.linspace(0, N, N // batch_size) produced batches larger
# than batch_size, and zero batches whenever N < 2 * batch_size.)
batch_size = 12
slices = np.append(np.arange(0, meta.shape[0], batch_size), meta.shape[0]).astype(np.int64)
n_batch = len(slices) - 1  # number of batches (0 when `meta` is empty)

In [44]:
# Layers whose embeddings the model returns (saved to disk by the extraction loop).
embeddings_layers_to_save = (20, 24, 28, 32)
# Nucleotide Transformer checkpoint to load.
model_name = '2B5_multi_species'

# Load the pretrained weights, the forward function (wrapped with hk.transform
# below), the tokenizer and the model config.
# NOTE(review): max_positions=1000 presumably caps the token length per sequence;
# confirm it covers the longest cDNA windows produced by the metadata cell.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    embeddings_layers_to_save=embeddings_layers_to_save,
    attention_maps_to_save=(),
    max_positions=1000,
    mixed_precision=False,
    chkpt_dir=os.path.join(nt_dir, 'checkpoints'),
)
# Haiku transform: afterwards forward_fn.apply(parameters, random_key, tokens) runs the model.
forward_fn = hk.transform(forward_fn)

In [None]:
start_time = time.time()

# Output folder per saved layer; created once, outside the batch loop.
layer_folders = {layer: os.path.join(embedding_dir, str(layer)) for layer in embeddings_layers_to_save}
for folder in layer_folders.values():
    os.makedirs(folder, exist_ok=True)

for i in tqdm(range(len(slices) - 1)):
    # .iloc: slice bounds are row positions, not index labels.
    meta_slice = meta.iloc[slices[i]:slices[i + 1]]
    labels = list(meta_slice.interacting.values.astype(int))
    ids = list(meta_slice.id_sample.values)

    # Embed each side of the RNA pair separately.
    sequences1 = list(meta_slice.cdna1.values)
    outs1, tokens1 = infer(sequences1, forward_fn, tokenizer, parameters, random_key)

    sequences2 = list(meta_slice.cdna2.values)
    outs2, tokens2 = infer(sequences2, forward_fn, tokenizer, parameters, random_key)

    for layer in embeddings_layers_to_save:
        mean_embeddings1 = retrieve_embeddings(outs1, tokens1, layer)  # (batch, hidden), hidden = 2560
        mean_embeddings2 = retrieve_embeddings(outs2, tokens2, layer)  # (batch, hidden)

        # Concatenating along axis=1 puts gene1's and gene2's pooled embeddings of the
        # same sample side by side: (batch, hidden) + (batch, hidden) -> (batch, 2*hidden).
        embeddings = np.concatenate((mean_embeddings1, mean_embeddings2), axis=1)

        save_data_to_folder(embeddings, labels, ids, os.path.join(layer_folders[layer], set_data))
print(f"Total time to process all batches: {(time.time()-start_time)/60} minutes")