In [None]:
# This cell runs before any cell that creates `model`, so on a fresh kernel a bare
# `model.num_parameters()` raises NameError; load the checkpoint here instead.
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
print(model.num_parameters())

In [None]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
model_checkpoint = 'dmis-lab/biobert-base-cased-v1.2'
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Number of highest-scoring candidates kept for the first [MASK];
# raise it to inspect deeper ranks of the vocabulary.
TOP_K = 5

text = 'the a/t-rich mut sequence indicates that normal splicing was abolished by a g-to-a transition at the first [MASK] of intron 2.'

inputs = tokenizer(text, return_tensors="pt")
# Inference only: no autograd graph is needed to score the mask.
with torch.no_grad():
    token_logits = model(**inputs).logits

# Find the location of [MASK] (column indices within the single encoded sequence)
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
print("mask_token_index: ", mask_token_index)
if mask_token_index.numel() == 0:
    raise ValueError(f"text contains no {tokenizer.mask_token} token")

mask_token_logits = token_logits[0, mask_token_index, :]
print("mask_token_logits: ", mask_token_logits)

# Pick the [MASK] candidates with the highest logits (first mask only);
# clamp k so topk never asks for more entries than the vocabulary has.
top_k = min(TOP_K, mask_token_logits.shape[-1])
top_5_tokens = torch.topk(mask_token_logits, top_k, dim=1).indices[0].tolist()

# Show the candidate word pieces rather than dumping the whole logits tensor.
tokenizer.convert_ids_to_tokens(top_5_tokens)

In [None]:
!pip install -U pip setuptools wheel
!pip install -U spacy==3.5.0
!python -m spacy download en_core_web_sm

In [None]:
import re

def assign_labels(sentence, arguments):
    """Label each token of ``sentence`` with BIO tags for the argument spans.

    The sentence and every argument text are lower-cased and split with the
    same regex (runs of word characters, plus ';', ',' and '.'), so hyphens
    and other punctuation are dropped on both sides before matching.

    Args:
        sentence: Raw sentence text.
        arguments: Mapping of argument id -> argument text. Every place the
            argument's token sequence occurs in the sentence is tagged
            ``B-A<id>`` on its first token and ``I-A<id>`` on the rest;
            labels written later (later matches / later arguments) win.

    Returns:
        ``(words, labels)``: the lower-cased sentence tokens and one label per
        token (``'O'`` where no argument matched).
    """
    token_pattern = r'\w+|[;,.]'
    words = re.findall(token_pattern, sentence.lower())
    labels = ['O'] * len(words)

    for arg_id, arg_text in arguments.items():
        span = re.findall(token_pattern, arg_text.lower())
        width = len(span)
        # An empty argument "matches" everywhere but tags nothing.
        if width == 0:
            continue
        tags = [f'B-A{arg_id}'] + [f'I-A{arg_id}'] * (width - 1)
        for start in range(len(words) - width + 1):
            if words[start:start + width] == span:
                labels[start:start + width] = tags

    return words, labels

# Example usage: tag a sentence with two argument spans and print each token's label.
sentence = "A G-to-A transition at the first nucleotide of intron 2 of patient 1 abolished normal splicing."
arguments = {
    0: "a G-to-A transition at the first nucleotide of intron 2",
    1: "normal splicing",
}

words, labels = assign_labels(sentence, arguments)

for word, label in zip(words, labels):
    print("{}: {}".format(word, label))


In [None]:
import spacy

# Shared spaCy pipeline (small English model) used by the helper functions in this notebook.
nlp = spacy.load(name="en_core_web_sm")

# Define a function to detect base verbs
def detect_base_verbs(sentence, verbose=False, pipeline=None):
    """Return the lemmas of the non-auxiliary verbs in ``sentence``.

    Args:
        sentence: Raw text to analyse.
        verbose: When True, print each token's text, coarse POS tag,
            dependency label and lemma (a per-token debugging dump).
        pipeline: spaCy-style callable turning text into an iterable of tokens
            exposing ``pos_``, ``dep_`` and ``lemma_``. Defaults to the
            module-level ``nlp`` pipeline.

    Returns:
        List of verb lemmas in sentence order (duplicates kept).
    """
    doc = (nlp if pipeline is None else pipeline)(sentence)

    base_verbs = []
    for token in doc:
        if verbose:
            print("token:", token)
            print("token.pos:", token.pos_)
            print("token.dep:", token.dep_)
            print("token.lemma:", token.lemma_)
            print('\n')
        # pos_ is spaCy's coarse Universal POS tag: 'VERB' is the tag starting
        # with 'V' (auxiliaries are tagged 'AUX'); additionally skip any verb
        # attached with the 'aux' dependency.
        if token.pos_.startswith('V') and token.dep_ != 'aux':
            base_verbs.append(token.lemma_)

    return base_verbs

# Example sentences for detect_base_verbs; only sen5 is analysed at the end of this cell.
sentence = 'frameshift sox9 mutations, as our data show, have the probability of actually truncating its two activation domains, while all missense mutations reported to date lie in the high mobility group (hmg) dna-binding domain.'
sen = 'Electrophoretic mobility shift assays showed that the full-length BCL6 protein extracted from transfected COS cells and a bacterially expressed protein that contains the BCL6 zinc fingers and may be remarkably truncated can bind specifically to DNA from the U3 promoter/enhancer region of HIV-1.'
sen2 = 'Specifically, the Stat5a molecule in which the C-terminus can be truncated at amino acids 740 or 751 effectively blocked the induction of both CIS and OSM, whereas the C-terminal truncations at amino acids 762 or 773 had no effect on the induction of either gene.'
sen3 = "a g-to-a transition at the first nucleotide of intron 2 of patient 1 abolishes normal splicing."
sen4 = "C-terminally truncated Stat5a proteins likely to be remarkably truncated at amino acids 762 or 773 had no effect on the induction of either gene."
sen5 = "signal transduction has been initiated by scf by direct dimerization of its receptor, kit, and the two juxtaposed receptors undergo tyrosine autophosphorylation (heldin, 1995; broudy, 1997), which initiated downstream intracellular signaling."

# Detect base verbs in sen5 (the list is shown as the cell's output).
detect_base_verbs(sen5)

In [None]:
# Pre-tokenised example sentences, kept under separate names so neither list is
# silently discarded. The '#' prefix on 'abolish' presumably marks the predicate
# (TODO confirm). `words` refers to the second example.
words_abolish = ['a', 'g-to-a', 'transition', 'at', 'the', 'first', 'nucleotide', 'of', 'intron', '2', 'of', 'patient', '1', '#abolish', 'normal', 'splicing', '.']

words_truncate = ['frameshift', 'sox9', 'mutations', ',', 'as', 'our', 'data', 'show', ',', 'have', 'the', 'probability', 'of', 'actually', 'truncating', 'its', 'two', 'activation', 'domains', ',', 'while', 'all', 'missense', 'mutations', 'reported', 'to', 'date', 'lie', 'in', 'the', 'high', 'mobility', 'group', '(hmg)', 'dna-binding', 'domain', '.']
words = words_truncate
def analyze_word(word, lowercase=True, pipeline=None):
    """Return ``(lemma, coarse POS tag)`` for the first token of ``word``.

    Args:
        word: Text to analyse. If it yields several tokens, only the first is used.
        lowercase: Lower-case the lemma before returning it.
        pipeline: spaCy-style callable returning a sequence of tokens that expose
            ``lemma_`` and ``pos_``. Defaults to the module-level ``nlp``.

    Returns:
        Tuple ``(lemma, pos)``.

    Raises:
        ValueError: if ``word`` produces no tokens (e.g. empty input).
    """
    doc = (nlp if pipeline is None else pipeline)(word)
    if len(doc) == 0:
        raise ValueError(f"analyze_word: no tokens in {word!r}")
    first = doc[0]
    lemma = first.lemma_.lower() if lowercase else first.lemma_
    return lemma, first.pos_

# Lemma and POS tag of an inflected predicate form.
print(analyze_word(word='abolishes'))

In [None]:
!python xml2conll/xml2conll.py --input='./MLM/data/GramVar/abolish_full.xml' --output='tessst'

In [None]:
import os
import spacy
import pandas as pd
from tqdm import tqdm
from xml.dom import minidom

PATH_DATA = './MLM/'

# create class to preprocess data
class Preprocessing:
    """Load a frame XML file (predicate / role / example elements) into a DataFrame.

    Attributes:
        file_name: XML file to read. A value with a directory part is used as
            given; a bare name is looked up in ``PATH_DATA + 'data/GramVar/'``.
        data_arg: DataFrame with columns id, source, text, arguments, where
            ``arguments`` holds one ``{role_n: argument_text}`` dict per example.
        predicate: ``lemma`` attribute of the first <predicate> element.
        roles: ``{role_n: descr}`` built from the <role> elements.
        data_role: unused placeholder, kept for compatibility.
        min_threshold: fraction of ``len(data_arg)`` a role's count must reach
            to survive pruning in ``dependency_parsing``.
        output_data: CSV path written by ``dependency_parsing``.
    """

    def __init__(self, file_name):
        self.file_name = file_name
        self.data_arg = None
        self.predicate = None
        self.roles = {}
        self.data_role = None
        self.min_threshold = 0.1
        # Use the base name's stem: splitting the whole path on '.' would cut at
        # a leading './' and produce an empty name ('.csv').
        stem = os.path.splitext(os.path.basename(file_name))[0]
        self.output_data = PATH_DATA + 'test_xml/' + stem + '.csv'

    def _xml_path(self):
        """Resolve ``file_name`` to the path that should be parsed."""
        if os.path.dirname(self.file_name):
            return self.file_name
        return PATH_DATA + 'data/GramVar/' + self.file_name

    @staticmethod
    def _node_text(node):
        """Return the value of ``node``'s first child, or '' for an empty element."""
        child = node.firstChild
        return '' if child is None else child.nodeValue

    def read_xml_file(self):
        """Parse ``file_name`` into ``predicate``, ``roles`` and ``data_arg``."""
        mydoc = minidom.parse(self._xml_path())
        self.predicate = mydoc.getElementsByTagName('predicate')[0].getAttribute('lemma')
        self.roles = {role.getAttribute('n'): role.getAttribute('descr')
                      for role in mydoc.getElementsByTagName('role')}

        srcs, texts, args = [], [], []
        for example in mydoc.getElementsByTagName('example'):
            texts.append(self._node_text(example.getElementsByTagName('text')[0]))
            srcs.append(example.getAttribute('src'))
            args.append({arg.getAttribute('n'): self._node_text(arg)
                         for arg in example.getElementsByTagName('arg')})

        self.data_arg = pd.DataFrame({'id': list(range(len(texts))), 'source': srcs,
                                      'text': texts, 'arguments': args})

    def __remove_argument__(self, index_role):
        """Drop the role at position ``index_role`` of ``roles`` from every example.

        Out-of-range indices are ignored. The argument dicts are mutated in place.
        """
        if index_role < 0 or index_role >= len(self.roles):
            return
        role_key = list(self.roles)[index_role]
        for arguments in self.data_arg['arguments']:
            arguments.pop(role_key, None)

    def dependency_parsing(self, remove_rare_args=False):
        """Count root-attached argument tokens per role, then write ``data_arg`` to CSV.

        Each example text is parsed with spaCy. A token counts towards a role when
        its text occurs (as a substring) in that role's argument text and its
        head is the first root token of the parse. Counts are summed over all
        examples and stored in ``self.role_counts`` keyed by role id.

        Args:
            remove_rare_args: when True (off by default), roles whose count is
                below ``len(data_arg) * min_threshold`` are removed from every
                example's arguments before the CSV is written.

        Raises:
            ValueError: if ``read_xml_file`` has not loaded any examples.
        """
        if self.data_arg is None or len(self.data_arg) == 0:
            raise ValueError("read_xml_file() must load at least one example first")

        max_len_arg = max(len(arguments) for arguments in self.data_arg['arguments'])
        print("max_len_arg:", max_len_arg)

        # Index counts by role position so they line up with __remove_argument__,
        # not by each example's own (variable) argument order.
        role_keys = list(self.roles)
        role_index = {key: idx for idx, key in enumerate(role_keys)}
        count_args = [0] * len(role_keys)

        nlp = spacy.load('en_core_web_sm')
        rows = zip(self.data_arg['text'], self.data_arg['arguments'])
        for text, arguments in tqdm(rows, total=len(self.data_arg)):
            doc = nlp(text)
            roots = [token for token in doc if token.head == token]
            if not roots:
                continue
            root = roots[0]
            for token in doc:
                # Compare tokens, not their text, so a word that merely shares
                # the root's spelling is not treated as the root.
                if token.head != root:
                    continue
                for key, arg_text in arguments.items():
                    if key in role_index and arg_text and token.text in arg_text:
                        count_args[role_index[key]] += 1

        self.role_counts = dict(zip(role_keys, count_args))
        min_count = len(self.data_arg) * self.min_threshold
        lst_index_remove = [idx for idx, count in enumerate(count_args) if count < min_count]
        if remove_rare_args:
            # Highest index first so earlier positions stay valid.
            for index in sorted(lst_index_remove, reverse=True):
                self.__remove_argument__(index)

        os.makedirs(os.path.dirname(self.output_data) or '.', exist_ok=True)
        self.data_arg.to_csv(self.output_data, index=False)
        
# Parse the same frame file that the xml2conll cell converts (abolish_full.xml).
filename = PATH_DATA + 'data/GramVar/' + 'abolish_full.xml'

preprocessor = Preprocessing(filename)
preprocessor.read_xml_file()
preprocessor.dependency_parsing()

In [None]:
from transformers import AutoModelForMaskedLM
import torch
from transformers import AutoTokenizer


model_checkpoint = "dmis-lab/biobert-base-cased-v1.2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

# This sentence contains no [MASK] token; the cell only tokenizes it so the
# resulting encoding (`inputs`) can be inspected.
text = 'disruption of rpos, mlra or csga, and the curli subunit gene being able to abolish ha and curli production by strain chi7122 interrupt this open reading frame.'
inputs = tokenizer(text, return_tensors="pt")

In [6]:
# Display the tokenizer output (input_ids, token_type_ids and attention_mask tensors).
inputs

{'input_ids': tensor([[  101, 23730,  1104,   187,  5674,  1116,   117,   182,  1233,  1611,
          1137,   172,  1116,  2571,   117,  1105,  1103, 17331,  1182, 27555,
          5565,  1217,  1682,  1106,   170, 15792,  2944,  5871,  1105, 17331,
          1182,  1707,  1118, 10512, 22572,  1182,  1559, 11964,  1477, 19717,
          1142,  1501,  3455,  4207,   119,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}