In [57]:
from transforna import GeneEmbeddModel,load
import yaml
import torch
mapping_dict_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json'
model = "Seq-Seq"
model_name = f"Yak-hbdx/{model}-TransfoRNA"
model_dir = f"/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/{model}/"
model_path = model_dir+"/ckpt/model_params_tcga.pt"
model_config_path = model_dir+"meta/hp_settings.yaml"
cfg = load(model_config_path)
mapping_dict = load(mapping_dict_path)
with open(model_dir+"/seq_tokens_ids_dict.yaml") as file:
    token_to_ids = yaml.load(file, Loader=yaml.FullLoader)
if 'struct' in model.lower():
    with open(model_dir+"/second_input_tokens_ids_dict.yaml") as file:
        second_input_token_to_ids = yaml.load(file, Loader=yaml.FullLoader)

    token_to_ids.update(second_input_token_to_ids)

In [58]:

#main_config = cfg#{"train_config":cfg['train_config'],"model_config":cfg['model_config']}
cfg["train_config"]["device"] = 'cpu'
cfg["mapping_dict"] = mapping_dict
model = GeneEmbeddModel(cfg)
#load state dict
model.load_state_dict(torch.load(model_path))
model.push_to_hub(model_name)
model = GeneEmbeddModel.from_pretrained(model_name)

In [59]:
from typing import List, Tuple, Dict
from transformers.tokenization_utils import PreTrainedTokenizer
import os
import numpy as np
class Tokenizer(PreTrainedTokenizer):

    model_input_names = ["input_ids"]#, "attention_mask"]
    do_upper_case: bool = True

    def __init__(
        self,
        do_upper_case: bool = True,
        model_max_length: int = 30,
        **kwargs,
    ):
        self._token_to_id = token_to_ids
        self._id_to_token = {id: token for token, id in self._token_to_id.items()}

        super().__init__(
            model_max_length=model_max_length,
            **kwargs,
        )
        self.do_upper_case = do_upper_case


    def _convert_id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, None)

    def _convert_token_to_id(self, token: str) -> int:
        return self._token_to_id.get(token, self._token_to_id.get(None))  # type: ignore[arg-type]

    def _tokenize(self, rnas: str, **kwargs):
        if self.do_upper_case:
            rnas = rnas.upper()
        return list(rnas)

    def get_vocab(self):
        return self._token_to_id.copy()

    def token_to_id(self, token: str) -> int:
        return self._token_to_id.get(token, self._token_to_id.get(None))  # type: ignore[arg-type]

    def id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, None)

    def save_vocabulary(self, save_directory: str, filename_prefix: str  = None):
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
        with open(vocab_file, "w") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)
    
    @property
    def all_tokens(self) -> List[str]:
        return list(self.get_vocab().keys())

    @property
    def vocab_size(self) -> int:
        return len(self.all_tokens)
    
class RnaTokenizer(Tokenizer):
   
    model_input_names = ["input_ids"]#, "attention_mask"]

    def __init__(
        self,
        nmers: int = 2,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        model_max_length: int = 30,
        model_name: str = "",

        **kwargs,
    ):

        super().__init__(
            do_upper_case=do_upper_case,
            model_max_length=model_max_length,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.model_name = model_name.lower()

    def chunkstring_overlap(self, string):
        return (
            string[0 + i : self.nmers + i] for i in range(0, len(string) - self.nmers + 1, 1)
        )
    
    def _tokenize(self, rnas: str, **kwargs):
        if self.do_upper_case:
            rnas = rnas.upper()
        if self.replace_U_with_T:
            rnas = rnas.replace("U", "T")

        return list(self.chunkstring_overlap(rnas))
    
    def custom_roll(self,arr, n_shifts_per_row):
        '''
        shifts each row of a numpy array according to n_shifts_per_row
        '''
        from numpy.lib.stride_tricks import as_strided

        m = np.asarray(n_shifts_per_row)
        arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].copy() #need `copy`
        strd_0, strd_1 = arr_roll.strides
        n = arr.shape[1]
        result = as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1))

        return result[np.arange(arr.shape[0]), (n-m)%n]
    
    def __call__(
        self,
        rnas: str,
        return_tensors: str = "pt",
        padding: bool = "max_length",
        truncation: bool = True,
        **kwargs,
    ) -> Dict[str, List[int]]:
        seq_lens = np.array([len(rna) for rna in rnas])
        
        result =  super().__call__(
            rnas,
            return_tensors=return_tensors,
            padding=padding,
            truncation=truncation,
            **kwargs,
        )
        rna_token_ids = np.array(result["input_ids"])
        second_token_ids = np.zeros_like(rna_token_ids)

        if 'struct' in self.model_name:
            from transforna import fold_sequences
            rnas_ss = list(fold_sequences(rnas)['structure_37'].values)
            result =  super().__call__(
                rnas_ss,
                return_tensors=return_tensors,
                padding=padding,
                truncation=truncation,
                **kwargs,
            )
            second_token_ids = np.array(result["input_ids"])
        elif 'rev' in self.model_name:
            sample_token_ids_rev = rna_token_ids[:,::-1]
            n_zeros = np.count_nonzero(sample_token_ids_rev==0, axis=1)
            second_token_ids = self.custom_roll(sample_token_ids_rev, -n_zeros)
        
        elif 'seq-seq' in self.model_name:
            phase0 = rna_token_ids[:,::2]
            phase1 = rna_token_ids[:,1::2]
            #in case max_length is an odd number phase 0 will be 1 entry larger than phase 1 @ dim=1 
            if phase0.shape!= phase1.shape:
                phase1 = np.concatenate([phase1,np.zeros(phase1.shape[0])[...,np.newaxis]],axis=1)
            rna_token_ids = phase0
            second_token_ids = phase1
        else:
            #seq
            pass
            

        result['input_ids'] = torch.tensor(np.concatenate([rna_token_ids,second_token_ids,seq_lens[...,np.newaxis]],axis=1))

        return result
        

In [60]:
tokenizer = RnaTokenizer(model_max_length=29,model_name=model_name)
tokenizer.add_special_tokens({'pad_token': 'pad'})
x = tokenizer(['AACGAAGCTCGACTTTTAAGG'\
            ,'GTCCACCCCAAAGCGTAGG'])

In [61]:
x

{'input_ids': tensor([[ 5., 15.,  5., 14.,  1.,  4.,  2., 10.,  7.,  8.,  0.,  0.,  0.,  0.,
          0., 16.,  4.,  8.,  2., 15., 16., 10., 10.,  5., 13.,  0.,  0.,  0.,
          0.,  0., 21.],
        [ 9., 11., 16., 11., 12.,  5., 14.,  9.,  8.,  0.,  0.,  0.,  0.,  0.,
          0.,  1., 12., 11., 11.,  5.,  8., 15.,  7., 13.,  0.,  0.,  0.,  0.,
          0.,  0., 19.]], dtype=torch.float64)}

In [62]:
#save tokenizer and push to hub
tokenizer.push_to_hub(model_name)

CommitInfo(commit_url='https://huggingface.co/Yak-hbdx/Seq-Seq-TransfoRNA/commit/14ef8d30ec6ae0dd06304c096e0e800386ad2c21', commit_message='Upload tokenizer', commit_description='', oid='14ef8d30ec6ae0dd06304c096e0e800386ad2c21', pr_url=None, pr_revision=None, pr_num=None)

In [63]:
#load model and tokenizer
model = GeneEmbeddModel.from_pretrained(model_name)
model.eval()
tokenizer = RnaTokenizer.from_pretrained(model_name,model_name=model_name)
output = tokenizer(['AAAGTCGGAGGTTCGAAGACGATCAGATAC','TTTTCGGAACTGAGGCCATGATTAAGAGGG'])
gene_embedd, second_input_embedd, activations,attn_scores_first,attn_scores_second = model(output['input_ids'])
#get the idx of the maximum value in the gene_embedd tensor at each row
class_ids = torch.argmax(activations,dim=1).numpy()
class_labels = model.convert_ids_to_labels(class_ids)
#asset ['18S_bin-38', '18S_bin-33']
assert class_labels == ['18S_bin-38', '18S_bin-33'], print('\033[91m' + 'FAILED')
print('\033[92m' + 'PASSED')

[92mPASSED


In [64]:
#load model and tokenizer
model = GeneEmbeddModel.from_pretrained(model_name)
model.eval()
tokenizer = RnaTokenizer.from_pretrained(model_name,model_name=model_name)
output = tokenizer(['AACGAAGCTCGACTTTTAAGG','GTCCACCCCAAAGCGTAGG'])
gene_embedd, second_input_embedd, activations,attn_scores_first,attn_scores_second = model(output['input_ids'])
#get the idx of the maximum value in the gene_embedd tensor at each row
class_ids = torch.argmax(activations,dim=1).numpy()
class_labels = model.convert_ids_to_labels(class_ids)
major_class = model.convert_subclass_to_majorclass(class_labels)
assert class_labels == ['28S_bin-80', 'miR-629-3p'], print('\033[91m' + 'FAILED')
print('\033[92m' + 'PASSED')

[91mFAILED


AssertionError: None