# Setup

In [1]:
import numpy as np
import pandas as pd
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer, AutoModelForCausalLM, BartForConditionalGeneration, BartTokenizer, T5ForConditionalGeneration
from collections import OrderedDict
from torch.utils.data import Dataset
from os import walk, path
from tqdm import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
pretrained_model = ''
fields = []
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu"
)
generated_sequence = None
MAX_LEN = 700
model = None
model_id = None
# Load pre-trained model (weights)
# model_name = 'EleutherAI/gpt-neo-125M'
model_name = 'google/t5-v1_1-base'
# model_name = 'facebook/bart-base'
# model = T5ForConditionalGeneration.from_pretrained(
#     model_name,
#     output_attentions=True,
#     return_dict=True
# )
model_name = 'gpt2'
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# tokenizer = BartTokenizer.from_pretrained(model_name)
# model = BartForConditionalGeneration.from_pretrained(
#     model_name,
#     output_attentions=True,
#     return_dict=True
# )


tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, return_dict=True)
tokenizer.add_special_tokens({
    'pad_token': '<PAD>',
    'bos_token': '<BOS>',
    'eos_token': '<EOS>',
    'sep_token': '<SEP>',
    'additional_special_tokens': ['<SEPO>']
})

model.resize_token_embeddings(len(tokenizer))
checkpoint_name = 'checkpoint_gpt2_splitted_500_entities-epoch=02-val_loss=1.16.ckpt'
checkpoint = torch.load('../checkpoints/' + checkpoint_name, map_location='cpu')
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if k[:6] == 'model.':
        name = k[6:]
    else:
        name = k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
# # torch.save(self.model_3.state_dict(), '/content/drive/MyDrive/Muteffstage/Checkpoints/only-mutation-epoch=09.ckpt')


# # # model 2
# # model_3.load_state_dict(
# # torch.load('/content/drive/MyDrive/Tesi Polimi/GEO-metadata-translator-master/Checkpoints/checkpoint_2_hs_at_an-epoch=17-val_loss=0.156.ckpt')
# # )

model.eval()
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50262, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )


# inference entities classic


In [7]:
with torch.no_grad():
  filenames = next(walk('../data/trainingdata_v3/dev/'), (None, None, []))[2]
  filenames_txt = [filename for filename in filenames if 'txt' in filename]
  filenames_txt.sort()
  for filename_txt in tqdm(filenames_txt):
    error_log = ''
    df_ann = pd.read_csv('../data/trainingdata_v3/dev/' + filename_txt[:-4] + '.ann', sep ='\t', names=['entity-event-context', 'classification-type', 'value'])
    df_entities = pd.DataFrame()
    if len(df_ann.apply(lambda row: row['value'].lower() if row['entity-event-context'][0]== 'T' else '', axis=1)) > 0:
      df_entities['entities'] = df_ann.apply(lambda row: row['value'].lower() if row['entity-event-context'][0]== 'T' else '', axis=1)
      df_entities = df_entities[df_entities['entities'] != '']
      entities_trg = df_entities.drop_duplicates(['entities']).loc[:, 'entities'].tolist()
    else:
      entities_trg = []
    # print(entities_trg)
    with open('../data/trainingdata_v3/dev/'+filename_txt) as text_file:
      original_text = ''.join(text_file.readlines())
    # print(original_text)
    # replace space new lines and tabs with spaces the drop duplicate spaces
    text = original_text.replace('\t', ' ').replace('\n', ' ')
    text_filtered = ''
    for char_index, char in enumerate(text):
        if char_index < (len(text)-1):
            if not(text[char_index + 1] == ' ' and char == ' '):
                text_filtered += char
    text_filtered = text_filtered.lower() # uncase the text
    # print(text_filtered)
    attribute = 'medications: '
    # input_text = bos + text_filtered + sep + attribute
    input_text_ids = tokenizer.encode(
        text_filtered,
        return_tensors='pt'
    )

    # If a text is longer than 900 tokens I create slices to divide it
    text_slices = []
    start_slice = 0
    max_lenght = 500
    for end_slice in range(max_lenght, input_text_ids.shape[1], max_lenght):
      text_slices.append([start_slice, end_slice])
      start_slice = end_slice
    text_slices.append([start_slice, input_text_ids.shape[1]])
    # print(text_slices)

    # gpt2 small can only manage 1000 tokens, I divide the text in more parts and
    # give them as different input and the aggregate the results
    entity_list = []
    for text_slice in text_slices:

      input_ids = torch.cat(
          (
              torch.tensor([[tokenizer.bos_token_id]]),
              input_text_ids[:, text_slice[0]:text_slice[1]],
              torch.tensor([[tokenizer.sep_token_id]]),
              tokenizer.encode(attribute, return_tensors='pt')
          ),
          dim=-1
      )

      input_ids = input_ids.to(device)
      # print('inputs ids length:', input_ids.shape[1])

      results = tokenizer.decode(model.generate(
        input_ids,
        max_length=max_lenght + 250,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        # num_beanum_beams=2,
        # top_k=0,
        # top_p=0,
        # do_sample=False,
        # repetition_penalty=3.,
        # length_penalty=0.1,
        # early_stopping=True,
      )[0])
      # print(results)
      results = results.split('<SEP>')[1]
      # here I postprocessing the results droping duplicates and filtering empty results
      print(results)
      results = results.split('medications: ')[1]
      
      if '<s>' in results:
        results = results.split('<s>')[1].split('<EOS>')[0]
      if '<EOS>' in results:
        results = results.split('<EOS>')[0]
      else:
        results = ','.join(results.split(',')[:-1])
      # else:
      #   results = results.split('<EOS>')[0]
      # print(results)
      results = results.split(',')
      # print(results)
      results = list(set(results))
      results = [result.strip() for result in results if result.strip() != '']
      # print('results without duplicates:', results)
      entity_list += results

    # for each entity founded with the model I search the positions in the text
    ann_text = ''
    results_with_pos =[]
    index = 0

    for entity in entity_list:   
      for punctuation in [',', '.', ' ', '\n', '\t', ':', ';', '(', '[', '{']:
        search_start = 0
        # if entity == 'ativan':
        #   print(f"entity : {entity} with pucntuation: {punctuation}  was found with start point {original_text.lower().find(punctuation + entity, search_start)}")
        while original_text.lower().find(punctuation + entity, search_start) != -1:
          start = original_text.lower().find(punctuation + entity, search_start) + 1
          end = start + len(entity)
          search_start = end
          pos = [start, end]
          results_with_pos.append([entity, pos])
          ann_text += 'T'+ str(index) + '\t' + 'Drug' + ' ' + str(start) + ' ' + str(end) + '\t' + entity + '\n'
          index += 1


    error_log += 'target entities: ' + ','.join(entities_trg) + '\n'
    error_log += 'predicted entities: ' + ','.join(list(set(list(zip(*results_with_pos))[0]))) + '\n' if len(results_with_pos)>0 else 'predicted entities: \n'
    error_log += '\n'
    if len(results_with_pos) > 0:
      for result in list(set(list(zip(*results_with_pos))[0])):
        if result not in entities_trg:
          error_log += filename_txt[:-4] + '\t' +\
            'wrong prediction: ' + result + '\n'
    
    for entity_trg in entities_trg:
      if len(results_with_pos) > 0: 
        if entity_trg not in list(zip(*results_with_pos))[0]:
          error_log += filename_txt[:-4] + '\t' +\
            'missing prediction: ' + entity_trg +'\n'
    error_log += '********************original_text*******************' + '\n'
    error_log += original_text + '\n'

    for text_slice_index, text_slice in enumerate(text_slices):
      # print(tokenizer.decode(input_text_ids[:, text_slice[0]:text_slice[1]][0]))
      error_log += '*********** text_slice_' + str(text_slice_index) + ' ***********\n' 
      error_log += tokenizer.decode(input_text_ids[:, text_slice[0]:text_slice[1]][0]) + '\n'

    if not path.exists('../data/trainingdata_v3/error_logs/'):
      os.mkdir('../data/trainingdata_v3/error_logs/')

    log_dir_path = '../data/trainingdata_v3/error_logs/'+ checkpoint_name + '/'
    if not path.exists(log_dir_path):
      os.mkdir(log_dir_path)
    
    with open(log_dir_path + filename_txt[:-4] + '_error_log' + '.txt' , 'w') as file:
      file.write(error_log)

    inference_dir_path = '../data/trainingdata_v3/'+'inference/'
    # print(ann_text)
    if not path.exists(inference_dir_path):
      os.mkdir(inference_dir_path)
    with open(inference_dir_path + filename_txt[:-4] + '.ann' , 'w') as ann_file:
      ann_file.write(ann_text)

  4%|▍         | 2/50 [00:00<00:06,  7.39it/s]

<BOS> record date: 2106-02-12 campbell orthopedic associates 4 madera circle omak, ga 28172 habib valenzuela, m.d. valdez, harlan jr. 845-41-54-4 february 12, 2106 har is a 43 year old 6' 214 pound gentleman who is referred for consultation by dr. harlan oneil. about a week ago he slipped on the driveway at home and sustained an injury to his left ankle. he was seen at tri-city hospital and was told he had a fracture. he was placed in an air splint and advised to be partial weight bearing, and he is using a cane. he is here for routine follow-up. past medical history is notable for no ankle injuries previously. he has a history of diabetes and sleep apnea. he takes prozac, cardizem, glucophage and amaryl. he is also followed by dr. harold nutter for an arrhythmia. he does not smoke. he drinks minimally. he is a set designer at columbia pictures. on examination today he has slight tenderness of the left ankle about four fingerbreadths above the malleolus. the malleolus is non-tender med

  6%|▌         | 3/50 [00:00<00:13,  3.43it/s]

<BOS> 2. sleep disordered breathing: as evidenced during prior polysomnographic evaluations, mostly of obstructive and or mixed hypopnea. the patient appears largely refractory to a trial of cpap therapy, particularly in so far as he demonstrates associated claustrophobic symptoms in association with it's usage, despite relatively modest cpap water pressures (6 cm). in addition, he has tried various nasal cpap face mask, including the mallinckrodt "breeze" supportive head gear with "nasal pillows" and with limited success. one might consider repeating a polysomnographic evaluation in the future, and if so, utilizing a potential trial of bipap titration, which may help to improve claustrophobic symptoms, but the patient will still be left with the issues referable to "tangled tubing at night" and issues referable to nasal face mask usage, as noted above. 3. relative difficulties in sleep reinitiation and maintenance: the patient describes at least 2-4 early morning awakenings with diffi

  8%|▊         | 4/50 [00:00<00:12,  3.66it/s]

<BOS>ine 25 qhs, mirtazapine 45 qd meds on transfer: please see green sheets medications: asa, lipitor 20, lopressor 50 bid, folate, norvasc 5 qd; lithium, 300 bid; depakote 500 bid; sonata 10 mg qhs, doxylamine 25 qhs, mirtazapine 45 qd allergies: nkda family history: family h/o cad social history: no etoh, no tob, no illicits review of systems: per hpi allergies: nkda family history: family h/o cad social history: no etoh, no tob, no illicits review of systems: per hpi ccu course + plan: 1) cards a. rhythm - on night of admission patient was started on an esmolol drip as well as amio bloused and rhythm converted to nsr. esmolol drip as well as amio was stopped and bb was escalated and patient has remained in nsr. i. ramp up lopressor as tolerated by bp b. pump - patient has remained euvolemic and had a echo with ef 84% and aortic stenosis c. ischemia - was stented x 2 to the prox rca lesion and was on integrilin x 24hrs prior. he was started on plavix. i. cont plavix, lopressor, lisi

 12%|█▏        | 6/50 [00:01<00:12,  3.39it/s]

<BOS>5 11.5-14.5 % 12/13/11 09:51 other hematology esr 14 0-17 mm/hr 12/13/11 19:51 urinalysis negative therapeutic drugs therapeutic drug monitoring lithium <0.10l 0.50-1.50 mmol/l 12/13/11 20:40 <0.10(l) 12/13/11 20:40 mri brain: acute/subacute infarcts (dwi bright/adc dark/flair bright) in l cerebellum (punctate) and r precentral gyrus (small elliptical area). ct-a head/neck: has aberrant origin of r vert from cca, both aca's come off of l carotid, also w/ bilateral fetal pca's and likely congenitally small vertebrobasilar vessels. no significant focal stenoses or atheromatous calcifications. ekg: pending mri l/s spine: negative impression: 49yorhm w/ pmh signif for cad, dm, bipolar disorder, afib, s/p recent cardiac cath who presents w/ lue weakness 10d prior to admission resolving after 4-5 days, and rle weakness?/sensory deficit? imaging reveals r precentral gyrus small infarct, l cerebellar infarct, no significant vessel stenoses. neuro exam w/ brisker reflexes on r, equivocal r

 14%|█▍        | 7/50 [00:01<00:12,  3.44it/s]

<BOS> record date: 2084-12-27 jar night-float admission note internal medicine patient: mae paul mrn: 1005708 date of admission: 12/27/2084 renal attending: dr. john cc: cough, sob, failed z-pk x2 hpi: pt is a 73yo female with esrd on hs s/p kidney transplant x2. she has had a cough for over the past week. no fever; no sputum production. also reports 1 day of sore throat, now resolved and nasal congestion forh te past week. her grandson (age 4) has had a cough/head cold. her daughter is a surgeon and has prescribed her a z-pk x 2 which she has failed and yesterday gave her a dose of levaquin 500mg on 12/26/2084. last hd on 12/22; missed on 12/25. on presentation to the ed her vs were: t=97.9 p=118 bp=173/93 sat=91%ra. saturations decreased to 91%on 4l in setting of sbp to 221/119. she was treated with iv lopressor 2.5 x2 and lopressor 12.5 po. saturation improved to 95%4l. renal team evaluated her and she will have dialysis this am. she was admitted to floor for further evaluation. ros

 16%|█▌        | 8/50 [00:02<00:13,  3.15it/s]

<BOS>. elevated serum cholesterol on atorva 10 mg ldl <100. fasting lipids done today. benign prostatic hypertrophy significant sx from bph, helped some by cardura, but retention and turp 3/86 by fonseca. residual prostate 1 to 2+ w/o nodules. currently on proscar w sx benefit. health maintenance screening flex sig done 7/84 neg by xian. quinlivan screening flex sig 10/30 neg x tics. stool cards given. skin lesion many actinic lesions. bce excised 12/89 at st francis. note that brother has had melanoma. dr. tomlin excised ssc from scalp 3/93-->reexcision. 7/94 will see him again.. allergies * nkda past medical history stopped smoking about 2061. no etoh. no known med allergies. family history mother died age 90. father died 37, pneumonia. two full brothers, one with dm, who is s/p thr, 1 a&w. two half brothers and one half sister, all a&w. one of his brothers had had melanoma. social history insurance claims examiner with desk job in new marlborough. married with 4 children. currently 




KeyboardInterrupt: 

In [None]:
# spitted with spacy

with torch.no_grad():
  filenames = next(walk('../data/trainingdata_v3/dev/'), (None, None, []))[2]
  filenames_txt = [filename for filename in filenames if 'txt' in filename]
  filenames_txt.sort()
  for filename_txt in tqdm(filenames_txt):
    error_log = ''
    df_ann = pd.read_csv('../data/trainingdata_v3/dev/' + filename_txt[:-4] + '.ann', sep ='\t', names=['entity-event-context', 'classification-type', 'value'])
    df_entities = pd.DataFrame()
    if len(df_ann.apply(lambda row: row['value'].lower() if row['entity-event-context'][0]== 'T' else '', axis=1)) > 0:
      df_entities['entities'] = df_ann.apply(lambda row: row['value'].lower() if row['entity-event-context'][0]== 'T' else '', axis=1)
      df_entities = df_entities[df_entities['entities'] != '']
      entities_trg = df_entities.drop_duplicates(['entities']).loc[:, 'entities'].tolist()
    else:
      entities_trg = []
    # print(entities_trg)
    with open('../data/trainingdata_v3/dev/'+filename_txt) as text_file:
      text = ''.join(text_file.readlines())
    
    doc = nlp(text)
    max_lenght = 750
    text_slices = []
    sent_index = 0
    sents = [text_preprocess(sent.text) for sent in doc.sents]
    text_slice = ''
    for sent in sents:
        if len(tokenizer.encode(text_slice + sent)) > max_lenght:
            text_slices.append(text_slice)
            text_slice = sent
        else:
            text_slice += sent
    text_slices.append(text_slice)

    # gpt2 small can only manage 1000 tokens, I divide the text in more parts and
    # give them as different input and the aggregate the results
    attribute = 'medications: '
    entity_list = []
    for text_slice in text_slices:

      input_ids = torch.cat(
          (
              torch.tensor([[tokenizer.bos_token_id]]),
              tokenizer.encode(text[:len(text_slice)], return_tensors='pt'),
              torch.tensor([[tokenizer.sep_token_id]]),
              tokenizer.encode(attribute, return_tensors='pt')
          ),
          dim=-1
      )

      input_ids = input_ids.to(device)
      # print('inputs ids length:', input_ids.shape[1])

      results = tokenizer.decode(model.generate(
        input_ids,
        max_length=max_lenght + 250,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        # num_beams=2, 
        # early_stopping=True,
        top_k=1
      )[0])
      print(results) 
      results = results.split('<SEP>')[1]
      
      # here I postprocessing the results droping duplicates and filtering empty results
      results = results.split('medications: ')[1]
      # print(results)
      if '<s>' in results:
        results = results.split('<s>')[1].split('<EOS>')[0]
      if '<EOS>' in results:
        results = results.split('<EOS>')[0]
      else:
        results = ','.join(results.split(',')[:-1])
      # else:
      #   results = results.split('<EOS>')[0]
      # print(results)
      results = results.split(',')
      # print(results)
      results = list(set(results))
      results = [result.strip() for result in results if result.strip() != '']
      # print('results without duplicates:', results)
      entity_list += results

    # for each entity founded with the model I search the positions in the text
    ann_text = ''
    results_with_pos =[]
    index = 0

    for entity in entity_list:   
      for punctuation in [',', '.', ' ', '\n', '\t', ':', ';', '(', '[', '{']:
        search_start = 0
        # if entity == 'ativan':
        #   print(f"entity : {entity} with pucntuation: {punctuation}  was found with start point {original_text.lower().find(punctuation + entity, search_start)}")
        while original_text.lower().find(punctuation + entity, search_start) != -1:
          start = original_text.lower().find(punctuation + entity, search_start) + 1
          end = start + len(entity)
          search_start = end
          pos = [start, end]
          results_with_pos.append([entity, pos])
          ann_text += 'T'+ str(index) + '\t' + 'Drug' + ' ' + str(start) + ' ' + str(end) + '\t' + entity + '\n'
          index += 1


    error_log += 'target entities: ' + ','.join(entities_trg) + '\n'
    error_log += 'predicted entities: ' + ','.join(list(set(list(zip(*results_with_pos))[0]))) + '\n' if len(results_with_pos)>0 else 'predicted entities: \n'
    error_log += '\n'
    if len(results_with_pos) > 0:
      for result in list(set(list(zip(*results_with_pos))[0])):
        if result not in entities_trg:
          error_log += filename_txt[:-4] + '\t' +\
            'wrong prediction: ' + result + '\n'
    
    for entity_trg in entities_trg:
      if len(results_with_pos) > 0: 
        if entity_trg not in list(zip(*results_with_pos))[0]:
          error_log += filename_txt[:-4] + '\t' +\
            'missing prediction: ' + entity_trg +'\n'
    error_log += '********************original_text*******************' + '\n'
    error_log += text + '\n'

    for text_slice_index, text_slice in enumerate(text_slices):
      # print(tokenizer.decode(input_text_ids[:, text_slice[0]:text_slice[1]][0]))
      error_log += '*********** text_slice_' + str(text_slice_index) + ' ***********\n' 
      error_log += text_slice + '\n'

    if not path.exists('../data/trainingdata_v3/error_logs/'):
      os.mkdir('../data/trainingdata_v3/error_logs/')

    log_dir_path = '../data/trainingdata_v3/error_logs/'+ checkpoint_name + '/'
    if not path.exists(log_dir_path):
      os.mkdir(log_dir_path)
    
    with open(log_dir_path + filename_txt[:-4] + '_error_log' + '.txt' , 'w') as file:
      file.write(error_log)

    inference_dir_path = '../data/trainingdata_v3/'+'inference/'
    # print(ann_text)
    if not path.exists(inference_dir_path):
      os.mkdir(inference_dir_path)
    with open(inference_dir_path + filename_txt[:-4] + '.ann' , 'w') as ann_file:
      ann_file.write(ann_text)

# inference spacy entities


In [None]:
# splitted with spacy
def text_preprocess(text):
    for iter in range(10):
      text = text.replace('\n\n', '\n').replace('\n ', '\n').replace(' \n', '\n').replace('\t', ' ')
    text_filtered = ''
    for char_index, char in enumerate(text):
        if char_index < (len(text)-1):
            if not(text[char_index + 1] == ' ' and char == ' '):
                text_filtered += char
    return text_filtered.lower()

In [None]:
# context inference

with torch.no_grad():
  filenames = next(walk('../data/trainingdata_v3/dev/'), (None, None, []))[2]
  filenames_txt = [filename for filename in filenames if 'txt' in filename]
  filenames_ann = [filename for filename in filenames if 'ann' in filename]
  filenames_txt.sort()
  filenames_ann.sort()
  df_dataset = pd.read_csv('../data/trainingdata_v3/datasets/dev_dataset_gpt2_disposition.tsv', sep ='\t')
  
  for filename_txt, filename_ann in tqdm(zip(filenames_txt, filenames_ann)):
    a_index = 0
    # print(filename_txt)
    
    # print(original_text)
    # df_ann = pd.read_csv('../data/trainingdata_v3/dev/' + filename_ann, sep='\t', names=['entity-event-context', 'classification-type', 'value'])
    if len(df_dataset.loc[df_dataset['filename'] == filename_ann[:-4]]) > 0:
      df_ann = df_dataset.loc[df_dataset['filename'] == filename_ann[:-4]]
    else:
      df_ann = pd.DataFrame()
    ann_text = ''
    t_index = 0
    for index, row in df_ann.iterrows():
      text = row['text']

      attribute = 'medication: ' + row['value'] + '<SEPO>' + 'disposition:'
      # input_text = bos + text_filtered + sep + attribute
      input_text_ids = tokenizer.encode(
          text,
          return_tensors='pt',
          truncation=True,
          max_length=950
      )

      input_ids = torch.cat(
          (
              torch.tensor([[tokenizer.bos_token_id]]),
              input_text_ids,
              torch.tensor([[tokenizer.sep_token_id]]),
              tokenizer.encode(attribute, return_tensors='pt')
          ),
          dim=-1
      )

      input_ids = input_ids.to(device)
      # print('inputs ids length:', input_ids.shape[1])

      results = tokenizer.decode(model.generate(
        input_ids,
        max_length=1020,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        # num_beams=2, 
        # early_stopping=True,
        top_k=1
      )[0])
      # print(row['disposition-type'])
      results = results.split('<SEP>')[1]
      # print('model prediction:', results)
      ctxs = ['action', 'actor', 'negation', 'temporality', 'certainty']
      for ctx in ctxs:
        assert ctx in results, f'ctx: {ctx} is not in results'
      ann_text += 'T'+ str(t_index) + '\t' + 'Drug' + ' ' + str(row['start']) + ' ' + str(row['end']) + '\t' + str(row['value']) + '\n'
      ann_text += 'E'+ str(t_index) + '\t' + 'Disposition' + ':' + 'T'+ str(t_index) +'\n'
      for ctx in ctxs:
        prediction = results.split(ctx + ': ')[1].split('<SEPO>')[0].split('<EOS>')[0]
        ann_text += 'A' + str(a_index) + '\t' + ctx[0].upper() + ctx[1:] + ' ' + 'E'+ str(t_index) + ' ' + prediction +'\n'
        a_index += 1
      
      t_index += 1

    inference_dir_path = '../data//trainingdata_v3/'+'inference-context/'
    # print(ann_text)
    if not path.exists(inference_dir_path):
      os.mkdir(inference_dir_path)
    with open(inference_dir_path + filename_txt[:-4] + '.ann' , 'w') as ann_file:
      ann_file.write(ann_text)
    

In [None]:
# disposition inference

with torch.no_grad():
  filenames = next(walk('../data/trainingdata_v3/dev/'), (None, None, []))[2]
  filenames_txt = [filename for filename in filenames if 'txt' in filename]
  filenames_ann = [filename for filename in filenames if 'ann' in filename]
  filenames_txt.sort()
  filenames_ann.sort()
  df_dataset = pd.read_csv('../data/trainingdata_v3/datasets/dev_dataset_gpt2_disposition_only.tsv', sep ='\t')
  
  for filename_txt, filename_ann in tqdm(zip(filenames_txt, filenames_ann)):
    a_index = 0
    # print(filename_txt)
    
    # print(original_text)
    # df_ann = pd.read_csv('../data/trainingdata_v3/dev/' + filename_ann, sep='\t', names=['entity-event-context', 'classification-type', 'value'])
    if len(df_dataset.loc[df_dataset['filename'] == filename_ann[:-4]]) > 0:
      df_ann = df_dataset.loc[df_dataset['filename'] == filename_ann[:-4]]
    else:
      df_ann = pd.DataFrame()
    ann_text = ''
    t_index = 0
    for index, row in df_ann.iterrows():
      text = row['text']

      attribute = 'medication: ' + row['value'] + '<SEPO>' + 'disposition: '
      # input_text = bos + text_filtered + sep + attribute
      input_text_ids = tokenizer.encode(
          text,
          return_tensors='pt',
          truncation=True,
          max_length=950
      )
      print(input_text_ids.shape)
      input_ids = torch.cat(
          (
              torch.tensor([[tokenizer.bos_token_id]]),
              input_text_ids,
              torch.tensor([[tokenizer.sep_token_id]]),
              tokenizer.encode(attribute, return_tensors='pt')
          ),
          dim=-1
      )

      input_ids = input_ids.to(device)
      # print('inputs ids length:', input_ids.shape[1])

      results = tokenizer.decode(model.generate(
        input_ids,
        max_length=1020,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        # num_beams=2, 
        # early_stopping=True,
        top_k=1
      )[0])
      # print(results)
      results = results.split('<SEP>')[1]
      # print('model prediction:', results)
      prediction = results.split('disposition' + ': ')[1].split('<SEPO>')[0].split('<EOS>')[0].strip()
      # if prediction == 'Undetermined':
        # print('undeterminated trovato')
      ann_text += 'T'+ str(t_index) + '\t' + 'Drug' + ' ' + str(row['start']) + ' ' + str(row['end']) + '\t' + str(row['value']) + '\n'
      ann_text += 'E'+ str(t_index) + '\t' + prediction + ':' + 'T'+ str(t_index) +'\n'
      
      t_index += 1

    inference_dir_path = '../data/trainingdata_v3/'+'inference-disposition/'
    # print(ann_text)
    if not path.exists(inference_dir_path):
      os.mkdir(inference_dir_path)
    with open(inference_dir_path + filename_txt[:-4] + '.ann' , 'w') as ann_file:
      ann_file.write(ann_text)
    

# T5 entities


In [24]:
#inference entities


filenames = next(walk('../data/trainingdata_v3/dev'), (None, None, []))[2]
filenames_txt = [filename for filename in filenames if 'txt' in filename]
filenames_txt.sort()
# print(filenames_txt)
for filename_txt in tqdm(filenames_txt):
  error_log = ''
  df_ann = pd.read_csv('../data/trainingdata_v3/dev/' + filename_txt[:-4] + '.ann', sep ='\t', names=['entity-event-context', 'classification-type', 'value'])
  df_entities = pd.DataFrame()
  if len(df_ann.apply(lambda row: row['value'].lower() if row['entity-event-context'][0]== 'T' else '', axis=1)) > 0:
    df_entities['entities'] = df_ann.apply(lambda row: row['value'].lower() if row['entity-event-context'][0]== 'T' else '', axis=1)
    df_entities = df_entities[df_entities['entities'] != '']
    entities_trg = df_entities.drop_duplicates(['entities']).loc[:, 'entities'].tolist()
  else:
    entities_trg = []
  with open('../data/trainingdata_v3/dev/'+filename_txt) as text_file:
    original_text = ''.join(text_file.readlines())
  # print(original_text)
  # replace space new lines and tabs with spaces the drop duplicate spaces
  text = original_text.replace('\t', ' ').replace('\n', ' ')
  text_filtered = ''
  for char_index, char in enumerate(text):
      if char_index < (len(text)-1):
          if not(text[char_index + 1] == ' ' and char == ' '):
              text_filtered += char
  text_filtered = text_filtered.lower() # uncase the text
  # print(text_filtered)
  attribute = 'medications: '
  # input_text = bos + text_filtered + sep + attribute
  input_text_ids = tokenizer.encode(
      text_filtered,
      return_tensors='pt',
      truncation=True,
      max_length=750,
      add_special_tokens=True
  )

  # If a text is longer than 900 tokens I create slices to divide it
  text_slices = []
  start_slice = 0
  max_lenght = 550
  for end_slice in range(max_lenght, input_text_ids.shape[1], max_lenght):
    text_slices.append(tokenizer.decode(input_text_ids[0, start_slice:end_slice], skip_special_tokens=True))
    start_slice = end_slice
  text_slices.append(tokenizer.decode(input_text_ids[0,start_slice:], skip_special_tokens=True))

  # print(text_slices)
  # break
  # gpt2 small can only manage 1000 tokens, I divide the text in more parts and
  # give them as different input and the aggregate the results
  entity_list = []
  for text_slice in text_slices:

    # input_ids = torch.cat(
    #     (
    #         tokenizer(attribute + tokenizer.decode(input_text_ids[0, text_slice[0]:text_slice[1]]), return_tensors='pt',add_special_tokens=True,)['input_ids'],
            
    #     ),
    #     dim=-1
    # )
    # print('inputs ids length:', input_ids.shape[1])

    results = model.predict(
      attribute + text_filtered,
      top_k=0
      beam=1,
      num_beams=1
    )[0]
    # results = results.split('<SEP>')[1]
    # here I postprocessing the results droping duplicates and filtering empty results
    # print(results)
    # print('dopo parsing:', results)
    results = results.split(',')
    # print(results)
    results = list(set(results))
    results = [result.strip() for result in results if result.strip() != '']
    # print('results without duplicates:', results)
    entity_list += results

  ann_text = ''
  results_with_pos =[]
  index = 0

  for entity in entity_list:   
    for punctuation in [',', '.', ' ', '\n', '\t', ':', ';', '(', '[', '{']:
      search_start = 0
      # if entity == 'ativan':
      #   print(f"entity : {entity} with pucntuation: {punctuation}  was found with start point {original_text.lower().find(punctuation + entity, search_start)}")
      while original_text.lower().find(punctuation + entity, search_start) != -1:
        start = original_text.lower().find(punctuation + entity, search_start) + 1
        end = start + len(entity)
        search_start = end
        pos = [start, end]
        results_with_pos.append([entity, pos])
        ann_text += 'T'+ str(index) + '\t' + 'Drug' + ' ' + str(start) + ' ' + str(end) + '\t' + entity + '\n'
        index += 1

  # print(results_with_pos)
  # print(list(zip(*results_with_pos)))
  error_log += 'target entities: ' + ','.join(entities_trg) + '\n'
  error_log += 'predicted entities: ' + ','.join(list(set(list(zip(*results_with_pos))[0]))) + '\n' if len(results_with_pos)>0 else 'predicted entities: \n'
  error_log += '\n'
  if len(results_with_pos) > 0:
    for result in list(set(list(zip(*results_with_pos))[0])):
      if result not in entities_trg:
        error_log += filename_txt[:-4] + '\t' +\
          'wrong prediction: ' + result + '\n'
  
  for entity_trg in entities_trg:
    if len(results_with_pos) > 0: 
      if entity_trg not in list(zip(*results_with_pos))[0]:
        error_log += filename_txt[:-4] + '\t' +\
          'missing prediction: ' + entity_trg +'\n'
  error_log += '********************original_text*******************' + '\n'
  error_log += original_text + '\n'

  # for text_slice_index, text_slice in enumerate(text_slices):
    # print(tokenizer.decode(input_text_ids[:, text_slice[0]:text_slice[1]][0]))
    # error_log += '*********** text_slice_' + str(text_slice_index) + ' ***********\n' 
    # error_log += tokenizer.decode(input_text_ids[:, text_slice[0]:text_slice[1]][0]) + '\n'

  if not path.exists('../data/trainingdata_v3/error_logs/'):
    os.mkdir('../data/trainingdata_v3/error_logs/')

  log_dir_path = '../data/trainingdata_v3/error_logs/'+ checkpoint_name + '/'
  if not path.exists(log_dir_path):
    os.mkdir(log_dir_path)
  
  with open(log_dir_path + filename_txt[:-4] + '_error_log' + '.txt' , 'w') as file:
    file.write(error_log)

  inference_dir_path = '../data/trainingdata_v3/'+'inference/'
  # print(ann_text)
  if not path.exists(inference_dir_path):
    os.mkdir(inference_dir_path)
  with open(inference_dir_path + filename_txt[:-4] + '.ann' , 'w') as ann_file:
    ann_file.write(ann_text)

  
  

100%|██████████| 50/50 [20:19<00:00, 24.38s/it]


# inference con simple t5

In [5]:
from simplet5 import SimpleT5
from os import walk, path
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer

In [2]:
model = SimpleT5()
checkpoint_name = 'simplet5-epoch-6-train-loss-0.2724-val-loss-0.1477'
model.load_model("t5","../checkpoints/"+checkpoint_name, use_gpu=True)

In [6]:
tokenizer = AutoTokenizer.from_pretrained('t5-large')
# model = T5ForConditionalGeneration.from_pretrained('../checkpoints/simplet5-epoch-4-train-loss-0.5191-val-loss-0.342')
# model.to(device)


Downloading: 100%|██████████| 1.17k/1.17k [00:00<00:00, 373kB/s]
Downloading: 100%|██████████| 773k/773k [00:00<00:00, 1.09MB/s] 
Downloading: 100%|██████████| 1.32M/1.32M [00:00<00:00, 1.94MB/s]


In [7]:
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu"
)

In [12]:
from os import walk, path
from tqdm import tqdm
import pandas as pd

In [15]:
#inference entities


filenames = next(walk('../data/trainingdata_v3/dev'), (None, None, []))[2]
filenames_txt = [filename for filename in filenames if 'txt' in filename]
filenames_txt.sort()
# print(filenames_txt)
for filename_txt in tqdm(filenames_txt):
  error_log = ''
  df_ann = pd.read_csv('../data/trainingdata_v3/dev/' + filename_txt[:-4] + '.ann', sep ='\t', names=['entity-event-context', 'classification-type', 'value'])
  df_entities = pd.DataFrame()
  if len(df_ann.apply(lambda row: row['value'].lower() if row['entity-event-context'][0]== 'T' else '', axis=1)) > 0:
    df_entities['entities'] = df_ann.apply(lambda row: row['value'].lower() if row['entity-event-context'][0]== 'T' else '', axis=1)
    df_entities = df_entities[df_entities['entities'] != '']
    entities_trg = df_entities.drop_duplicates(['entities']).loc[:, 'entities'].tolist()
  else:
    entities_trg = []
  with open('../data/trainingdata_v3/dev/'+filename_txt) as text_file:
    original_text = ''.join(text_file.readlines())
  # print(original_text)
  # replace space new lines and tabs with spaces the drop duplicate spaces
  text = original_text.replace('\t', ' ').replace('\n', ' ')
  text_filtered = ''
  for char_index, char in enumerate(text):
      if char_index < (len(text)-1):
          if not(text[char_index + 1] == ' ' and char == ' '):
              text_filtered += char
  text_filtered = text_filtered.lower() # uncase the text
  # print(text_filtered)
  attribute = 'medications: '
  # input_text = bos + text_filtered + sep + attribute
  input_text_ids = tokenizer.encode(
      text_filtered,
      return_tensors='pt',
      add_special_tokens=True
  )

  # If a text is longer than 900 tokens I create slices to divide it
  text_slices = []
  start_slice = 0
  max_lenght = 250
  for end_slice in range(max_lenght, input_text_ids.shape[1], max_lenght):
    text_slices.append(tokenizer.decode(input_text_ids[0, start_slice:end_slice], skip_special_tokens=True))
    start_slice = end_slice
  text_slices.append(tokenizer.decode(input_text_ids[0,start_slice:], skip_special_tokens=True))

  # print(text_slices)
  # break
  # gpt2 small can only manage 1000 tokens, I divide the text in more parts and
  # give them as different input and the aggregate the results
  entity_list = []
  for text_slice in text_slices:

    # input_ids = torch.cat(
    #     (
    #         tokenizer(attribute + tokenizer.decode(input_text_ids[0, text_slice[0]:text_slice[1]]), return_tensors='pt',add_special_tokens=True,)['input_ids'],
            
    #     ),
    #     dim=-1
    # )
    # print('inputs ids length:', input_ids.shape[1])

    results = model.predict(
      attribute + text_slice,
      num_beams=2,
      top_k=0,
      top_p=1,
      do_sample=False,
      repetition_penalty=0.5,
      length_penalty=0.4,
      early_stopping=True,
    )[0]

    # print(results)
    # inp_ids = tokenizer.encode(
    #   attribute + text_filtered,
    #   return_tensors="pt",
    #   add_special_tokens=True,
      
    # )
    # inp_ids = inp_ids.to(device)

    # results_ids = model.generate(
    #   inp_ids,
    #   max_length=250,
      # num_beams=2,
      # top_k=50,
      # top_p=0.95,
      # do_sample=True,
      # repetition_penalty=2.5,
      # length_penalty=1.0,
      # early_stopping=True,
    # )

    # results = tokenizer.decode(
    #   results_ids[0],
    #   skip_special_tokens=True,
    #   clean_up_tokenization_spaces=True,
    # )


    # results = results.split('<SEP>')[1]
    # here I postprocessing the results droping duplicates and filtering empty results
    # print(results)
    # print('dopo parsing:', results)
    results = results.split(',')
    # print(results)
    results = list(set(results))
    results = [result.strip() for result in results if result.strip() != '']
    print('results without duplicates:', results)
    entity_list += results

  ann_text = ''
  results_with_pos =[]
  index = 0

  for entity in entity_list:   
    for punctuation in [',', '.', ' ', '\n', '\t', ':', ';', '(', '[', '{']:
      search_start = 0
      # if entity == 'ativan':
      #   print(f"entity : {entity} with pucntuation: {punctuation}  was found with start point {original_text.lower().find(punctuation + entity, search_start)}")
      while original_text.lower().find(punctuation + entity, search_start) != -1:
        start = original_text.lower().find(punctuation + entity, search_start) + 1
        end = start + len(entity)
        search_start = end
        pos = [start, end]
        results_with_pos.append([entity, pos])
        ann_text += 'T'+ str(index) + '\t' + 'Drug' + ' ' + str(start) + ' ' + str(end) + '\t' + entity + '\n'
        index += 1

  # print(results_with_pos)
  # print(list(zip(*results_with_pos)))
  error_log += 'target entities: ' + ','.join(entities_trg) + '\n'
  error_log += 'predicted entities: ' + ','.join(list(set(list(zip(*results_with_pos))[0]))) + '\n' if len(results_with_pos)>0 else 'predicted entities: \n'
  error_log += '\n'
  if len(results_with_pos) > 0:
    for result in list(set(list(zip(*results_with_pos))[0])):
      if result not in entities_trg:
        error_log += filename_txt[:-4] + '\t' +\
          'wrong prediction: ' + result + '\n'
  
  for entity_trg in entities_trg:
    if len(results_with_pos) > 0: 
      if entity_trg not in list(zip(*results_with_pos))[0]:
        error_log += filename_txt[:-4] + '\t' +\
          'missing prediction: ' + entity_trg +'\n'
  error_log += '********************original_text*******************' + '\n'
  error_log += original_text + '\n'

  # for text_slice_index, text_slice in enumerate(text_slices):
    # print(tokenizer.decode(input_text_ids[:, text_slice[0]:text_slice[1]][0]))
    # error_log += '*********** text_slice_' + str(text_slice_index) + ' ***********\n' 
    # error_log += tokenizer.decode(input_text_ids[:, text_slice[0]:text_slice[1]][0]) + '\n'

  if not path.exists('../data/trainingdata_v3/error_logs/'):
    os.mkdir('../data/trainingdata_v3/error_logs/')

  log_dir_path = '../data/trainingdata_v3/error_logs/'+ checkpoint_name + '/'
  if not path.exists(log_dir_path):
    os.mkdir(log_dir_path)
  
  with open(log_dir_path + filename_txt[:-4] + '_error_log' + '.txt' , 'w') as file:
    file.write(error_log)

  inference_dir_path = '../data/trainingdata_v3/'+'inference/'
  # print(ann_text)
  if not path.exists(inference_dir_path):
    os.mkdir(inference_dir_path)
  with open(inference_dir_path + filename_txt[:-4] + '.ann' , 'w') as ann_file:
    ann_file.write(ann_text)

  
  

  0%|          | 0/50 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 3.82 GiB total capacity; 2.89 GiB already allocated; 16.25 MiB free; 2.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF