In [117]:
%load_ext autoreload
%autoreload 2

import pickle
import pandas as pd
import os
import openai
import numpy as np
import ipdb
import re
from tqdm import tqdm
import matplotlib.pyplot as plt

from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
import spacy
import scipy
import sklearn
from sklearn import feature_extraction

# Read the OpenAI key from the environment. If it is missing, prompt for it
# without echo rather than hardcoding it anywhere in the notebook, so this cell
# also works on a fresh kernel.
if 'OPENAI_KEY' not in os.environ:
    import getpass
    os.environ['OPENAI_KEY'] = getpass.getpass('OpenAI API key: ')
openai.api_key = os.environ['OPENAI_KEY']

from data_utils import *
from gpt3_utils import *
from eval_utils import *

from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

import torch

import copy

# Use fully-qualified option names. The short pattern 'max_rows' is ambiguous in
# pandas >= 1.4 (it also matches 'styler.render.max_rows') and raises OptionError.
pd.set_option('display.max_rows', 500, 'display.max_colwidth', 1000)
pd.options.display.float_format = "{:,.2f}".format

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# NOTE(review): never hardcode API keys. The key that was written here is exposed
# in the notebook and should be revoked. Take it from the environment, prompting
# without echo only when it is missing, and hand it to the openai client.
if 'OPENAI_KEY' not in os.environ:
    import getpass
    os.environ['OPENAI_KEY'] = getpass.getpass('OpenAI API key: ')
openai.api_key = os.environ['OPENAI_KEY']

## Loading the Full Training and Dev Datasets (Entities, Filled-Out Prompts, and Empty Prompts)

In [53]:
# Processed BC5CDR disease splits (read as tab-separated files below).
train_filename, dev_filename = (
    '../data/bc5cdr_disease.train.processed.tsv',
    '../data/bc5cdr_disease.dev.processed.tsv',
)

In [54]:
# Load the full processed splits; both files are tab-separated.
train_df = pd.read_csv(train_filename, delimiter='\t')
dev_df = pd.read_csv(dev_filename, delimiter='\t')

### Retrieving Random Subsets for Training and Evaluation 

In [55]:
# Random subset sizes: training rows available as few-shot examples, and dev rows to evaluate.
training_subset_size, dev_subset_size = 100, 200

In [56]:
# Fixed-seed random subsets. An integer random_state makes pandas build
# np.random.RandomState(42) internally, so this is equivalent to passing one explicitly.
# NOTE(review): this overwrites the full frames. Re-running this cell alone re-samples
# the already-sampled subsets, so re-run the loading cell first.
train_df = train_df.sample(training_subset_size, random_state=42)
dev_df = dev_df.sample(dev_subset_size, random_state=42)

### Testing GPT-3 with random, random-stratified, and kNN prompts (BM25, SBERT, RoBERTa, SciBERT) without entity filtering

In [57]:
# Logit-bias strengths forwarded to run_gpt3_on_df. Judging by the names, they
# cover entity tokens, the ',' separator and the newline.
# TODO confirm against gpt3_utils.
logit_bias = sep_logit_bias = new_line_logit_bias = 10

In [58]:
# Number of in-context examples per prompt, and the GPT-3 engine to query.
few_shot_size = 5
engine = 'ada'

In [93]:
def get_embedding(sent, model, tokenizer, mode='cls'):
    """Embed a single sentence with a SentenceTransformer or a Hugging Face encoder.

    Args:
        sent: Sentence text to embed.
        model: A SentenceTransformer when ``mode == 'sbert'``. Otherwise a Hugging
            Face model whose output exposes ``'last_hidden_state'`` and which has a
            ``device`` attribute.
        tokenizer: Tokenizer matching ``model``; unused for ``'sbert'``.
        mode: ``'sbert'`` returns ``model.encode(sent)``.
            ``'cls'`` returns the hidden state of the first token.
            ``'avg'`` returns the mean of all token hidden states.

    Returns:
        The (unnormalized) embedding. For ``'cls'``/``'avg'`` this is a numpy vector.
        For ``'sbert'`` it is whatever ``model.encode`` returns.

    Raises:
        ValueError: if ``mode`` is not ``'sbert'``, ``'cls'`` or ``'avg'``.
    """
    if mode == 'sbert':
        return model.encode(sent)
    if mode not in ('cls', 'avg'):
        # Previously an unknown mode silently returned the whole per-token matrix.
        raise ValueError(f"Unknown embedding mode {mode!r}; expected 'sbert', 'cls' or 'avg'.")

    # Send inputs to wherever the model lives instead of hard-coding 'cuda'.
    # Truncate to the model's maximum input length so over-long sentences do not crash.
    inputs = tokenizer(sent, return_tensors='pt', truncation=True).to(model.device)
    with torch.no_grad():
        # [0] drops the batch dimension of the single-sentence batch.
        hidden = model(**inputs)['last_hidden_state'].cpu().numpy()[0]

    if mode == 'cls':
        return hidden[0]
    return np.mean(hidden, axis=0)

def get_embeddings(sents, model, tokenizer, mode='cls'):
    """Embed every sentence in ``sents`` and L2-normalize each embedding.

    Args:
        sents: Iterable of sentence strings.
        model, tokenizer, mode: Forwarded to ``get_embedding`` for each sentence.

    Returns:
        A 2-D numpy array with one row per sentence, each row scaled to unit L2 norm.
        A dot product with another unit vector is therefore cosine similarity.
    """
    # Import the progress-bar callable locally. Line 12 binds the global name
    # ``tqdm`` to the class, on which ``tqdm.tqdm`` does not exist. The old call
    # appears to have worked only because a later star import rebound the name.
    from tqdm import tqdm as progress

    with torch.no_grad():
        embeddings = np.array([get_embedding(sent, model, tokenizer, mode=mode)
                               for sent in progress(sents)])

    # Row-wise L2 normalization (same values as the former transpose/divide/transpose).
    return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

In [90]:
def run_ner_gpt_from_training_set(train,
                                  dev,
                                  engine,
                                  prompt_size,
                                  logit_bias,
                                  sep_logit_bias,
                                  new_line_logit_bias,
                                  max_tokens,
                                  sampling_strategy='random',
                                  stratified=False):
    """Build few-shot NER prompts for ``dev`` from ``train`` and run GPT-3 on them.

    Args:
        train: Training DataFrame the in-context examples are drawn from.
        dev: DataFrame to predict on. The prompt builders add ``prompt_samples``
            and ``test_ready_prompts`` columns to it in place.
        engine: GPT-3 engine name (e.g. 'ada').
        prompt_size: Number of in-context examples per prompt.
        logit_bias, sep_logit_bias, new_line_logit_bias: Forwarded unchanged to
            ``run_gpt3_on_df``.
        max_tokens: Forwarded to ``run_gpt3_on_df`` (presumably the completion
            length limit).
        sampling_strategy: 'random', 'bm25', or an encoder key understood by
            ``get_bert_knn_prompts``.
        stratified: Only consulted by the 'random' strategy.

    Returns:
        Whatever ``run_gpt3_on_df`` returns for the prompted frame.
    """
    prompted = get_prompts_from_df(train, dev, prompt_size, sampling_strategy, stratified)
    # ',' is passed positionally to run_gpt3_on_df (presumably the entity separator).
    return run_gpt3_on_df(engine, prompted, prompted.test_ready_prompts, max_tokens, ',',
                          logit_bias, sep_logit_bias, new_line_logit_bias)

def get_prompts_from_df(train, dev, prompt_size, sampling_strategy, stratified):
    """Dispatch to the prompt builder selected by ``sampling_strategy``.

    The mapping is:
    - 'random' goes to ``get_random_prompts``.
    - 'bm25' goes to ``get_bm25_knn_prompts``.
    - Any other value is treated as an encoder key for ``get_bert_knn_prompts``.

    Returns the (in-place updated) ``dev`` frame from the chosen builder.
    """
    if sampling_strategy == 'random':
        return get_random_prompts(train, dev, prompt_size, stratified)

    knn_builder = get_bm25_knn_prompts if sampling_strategy == 'bm25' else get_bert_knn_prompts
    return knn_builder(train, dev, prompt_size, sampling_strategy, stratified)

def get_random_prompts(train, dev, prompt_size, stratified):
    """Attach one shared few-shot prompt, drawn from ``train``, to every dev row.

    Args:
        train: Training DataFrame with a ``prompts`` column, plus ``num_entities``
            when ``stratified`` is True.
        dev: DataFrame with an ``empty_prompts`` column; modified in place.
        prompt_size: Number of training examples in the prompt.
        stratified: If False, sample ``prompt_size`` rows with seed 42.
            If True, take one example per distinct ``num_entities`` value, from the
            smallest value upward. When there are fewer distinct values than
            ``prompt_size``, cycle through them again, picking a different example
            from each repeated bucket where possible.

    Returns:
        ``dev`` with two new columns:
        - ``prompt_samples``: the joined examples.
        - ``test_ready_prompts``: the examples, then '\\n\\n', then the row's empty prompt.
    """
    # Use the ``train`` argument, not the notebook-global ``train_df``.
    if not stratified:
        prompt_samples = train.sample(prompt_size, random_state=np.random.RandomState(42)).prompts.values
    else:
        buckets = np.sort(train.num_entities.unique())
        chosen_inds = []
        for i in range(prompt_size):
            # Cycle over the available buckets. The old ``i % prompt_size`` was a
            # no-op and raised IndexError when there were fewer buckets than prompt_size.
            bucket_inds = train[train.num_entities == buckets[i % len(buckets)]].index
            # A fresh seed-42 permutation keeps picks reproducible. On later passes
            # over the same bucket, take the next element instead of repeating the
            # first one, wrapping around if the bucket is exhausted.
            order = np.random.RandomState(42).permutation(bucket_inds)
            chosen_inds.append(order[(i // len(buckets)) % len(order)])
        prompt_samples = train.loc[chosen_inds, 'prompts'].values

    prompt = '\n\n'.join(prompt_samples)

    dev['prompt_samples'] = prompt
    dev['test_ready_prompts'] = [prompt + '\n\n' + empty_prompt for empty_prompt in dev.empty_prompts]

    return dev

def get_bm25_knn_prompts(train, dev, prompt_size, sampling_strategy, stratified):
    """Build a per-row few-shot prompt from the BM25-nearest training sentences.

    Args:
        train: Training DataFrame with ``sents`` and ``prompts`` columns.
        dev: DataFrame with ``sents`` and ``empty_prompts`` columns; modified in place.
        prompt_size: Number of nearest training examples per prompt.
        sampling_strategy, stratified: Unused. They are accepted for signature
            parity with the other prompt builders.

    Returns:
        ``dev`` with two new columns:
        - ``prompt_samples``: a tuple of the selected prompts and their BM25 scores,
          most similar first.
        - ``test_ready_prompts``: the selected prompts followed by the row's empty prompt.
    """
    bm25 = BM25()
    bm25.fit(train.sents)

    knn_prompt_samples = []
    knn_prompts = []

    for _, row in dev.iterrows():
        # Assumes transform returns a numpy array of one score per training
        # sentence (it is fancy-indexed below). TODO confirm in the BM25 helper.
        sims = bm25.transform(row['sents'], train.sents)
        sorted_sims = np.argsort(sims)[::-1]  # most similar first

        # A real assertion message. The old ``print(...)`` printed and then passed None.
        assert sims[sorted_sims[0]] >= sims[sorted_sims[-1]], (
            f'BM25 scores not sorted descending: {sims[sorted_sims[0]]} < {sims[sorted_sims[-1]]}')
        selected_ids = sorted_sims[:prompt_size]
        selected_prompts = train.prompts.values[selected_ids]

        knn_prompt_samples.append((selected_prompts, sims[selected_ids]))
        # NOTE(review): the most similar example comes first, i.e. farthest from the query.
        knn_prompts.append('\n\n'.join(selected_prompts) + '\n\n' + row['empty_prompts'])

    dev['prompt_samples'] = knn_prompt_samples
    dev['test_ready_prompts'] = knn_prompts

    return dev

def get_bert_knn_prompts(train, dev, prompt_size, sampling_strategy, stratified):
    """Build a per-row few-shot prompt from the nearest training sentences in an
    encoder's embedding space.

    Args:
        train: Training DataFrame with ``sents`` and ``prompts`` columns.
        dev: DataFrame with ``sents`` and ``empty_prompts`` columns; modified in place.
        prompt_size: Number of nearest training examples per prompt.
        sampling_strategy: Encoder key: 'scibert', 'sbert', 'roberta' or 'pubmed_bert'.
        stratified: Unused. It is kept for signature parity with the other prompt builders.

    Returns:
        ``dev`` with two new columns:
        - ``prompt_samples``: a tuple of the selected prompts and their cosine
          similarities, most similar first.
        - ``test_ready_prompts``: the selected prompts followed by the row's empty prompt.

    Raises:
        ValueError: if ``sampling_strategy`` is not a known encoder key.
    """
    # Encoder key -> (Hugging Face model id, pooling mode for get_embedding).
    bert_parameters = {
        'scibert':('allenai/scibert_scivocab_uncased', 'cls'),
        'sbert':('sentence-transformers/paraphrase-mpnet-base-v2', 'sbert'),
        'roberta':('roberta-large','avg'),
        'pubmed_bert':('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract','cls')
    }
    if sampling_strategy not in bert_parameters:
        raise ValueError(
            f'Unknown sampling_strategy {sampling_strategy!r}; expected one of {sorted(bert_parameters)}.')
    bert_model, mode = bert_parameters[sampling_strategy]

    if sampling_strategy == 'sbert':
        model = SentenceTransformer(bert_model)
        tokenizer = None
    else:
        model = AutoModel.from_pretrained(bert_model).to('cuda')
        tokenizer = AutoTokenizer.from_pretrained(bert_model)

    # get_embeddings L2-normalizes each row, so the dot products below are cosine similarities.
    train_embeddings = get_embeddings(train.sents.values, model, tokenizer, mode)

    knn_prompt_samples = []
    knn_prompts = []

    with torch.no_grad():
        for _, row in dev.iterrows():
            sent_emb = get_embedding(row['sents'], model, tokenizer, mode=mode)
            sent_emb = sent_emb / np.linalg.norm(sent_emb)

            sims = train_embeddings.dot(sent_emb)
            sorted_sims = np.argsort(sims)[::-1]  # most similar first

            # Sanity-check the descending order with >=. The old strict > failed
            # spuriously whenever all similarities tie (e.g. a single training row).
            assert sims[sorted_sims[0]] >= sims[sorted_sims[-1]], (
                f'similarities not sorted descending: {sims[sorted_sims[0]]} < {sims[sorted_sims[-1]]}')
            selected_ids = sorted_sims[:prompt_size]
            selected_prompts = train.prompts.values[selected_ids]

            knn_prompt_samples.append((selected_prompts, sims[selected_ids]))
            knn_prompts.append('\n\n'.join(selected_prompts) + '\n\n' + row['empty_prompts'])

    dev['prompt_samples'] = knn_prompt_samples
    dev['test_ready_prompts'] = knn_prompts

    return dev

In [106]:
# Display the GPT-3 engine the experiments below will run with.
engine

'ada'

In [101]:
# (sampling_strategy, stratified) pairs to compare. The first two use one fixed prompt
# (random, then entity-count-stratified). The rest use per-sentence kNN prompts
# retrieved with BM25 or with SBERT / RoBERTa / SciBERT embeddings.
experiments = [
    ('random', False),
    ('random', True),
    ('bm25', False),
    ('sbert', False),
    ('roberta', False),
    ('scibert', False),
]

# One prediction frame per experiment, aligned with `experiments`. Each run gets a
# fresh copy of dev_df because the prompt builders add columns to the frame they
# receive. A plain loop keeps already-finished runs if a later one fails.
outputs = []
for strategy, stratified in experiments:
    outputs.append(
        run_ner_gpt_from_training_set(
            train=train_df,
            dev=dev_df.copy(),
            engine=engine,
            prompt_size=few_shot_size,
            logit_bias=logit_bias,
            sep_logit_bias=sep_logit_bias,
            new_line_logit_bias=new_line_logit_bias,
            max_tokens=30,
            sampling_strategy=strategy,
            stratified=stratified,
        )
    )

200it [00:28,  6.95it/s]
200it [00:34,  5.76it/s]
200it [00:27,  7.18it/s]
100%|██████████| 100/100 [00:00<00:00, 112.63it/s]
200it [00:26,  7.45it/s]
Some weights of the model checkpoint at roberta-large were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 100/100 [00:01<00:00, 85.86it/s]
200it [00:26,  7.44it/s]
Some weights of the model checkpoint at allenai/scibert_scivocab_uncas

In [121]:
# Score each experiment twice with conlleval, which prints its own report:
# - 'span': the BIO tag sequences as produced.
# - 'token': every 'I' character in the tag strings replaced by 'B'. Assuming plain
#   B/I/O tags, each entity token is then scored as its own phrase (a lenient,
#   token-level view).
# The (f1, precision, recall) values used to be overwritten and discarded. They are
# now collected in `results_df` for side-by-side comparison.
eval_rows = []
for output, exp in zip(outputs, experiments):
    print(exp)
    df = create_bio_preds(output.copy(), 'predictions')

    scorings = [
        ('span', df.ner_seq, df.bio_preds),
        ('token',
         [s.replace('I','B') for s in df.ner_seq],
         [s.replace('I','B') for s in df.bio_preds]),
    ]
    for eval_name, y_true, y_pred in scorings:
        f1, precision, recall = conlleval_eval(y_true, y_pred)
        eval_rows.append({'strategy': exp[0], 'stratified': exp[1], 'eval': eval_name,
                          'f1': f1, 'precision': precision, 'recall': recall})

results_df = pd.DataFrame(eval_rows)
results_df

('random', False)
processed 4878 tokens with 183 phrases; found: 283 phrases; correct: 113.
accuracy:  92.50%; (non-O)
accuracy:  92.50%; precision:  39.93%; recall:  61.75%; FB1:  48.50%
                X: precision:  39.93%; recall:  61.75%; FB1:  48.50%  283


Unnamed: 0,F1,Precision,Recall
0,48.5,39.93,61.75


processed 4878 tokens with 279 phrases; found: 428 phrases; correct: 179.
accuracy:  92.85%; (non-O)
accuracy:  92.85%; precision:  41.82%; recall:  64.16%; FB1:  50.64%
                X: precision:  41.82%; recall:  64.16%; FB1:  50.64%  428


Unnamed: 0,F1,Precision,Recall
0,50.64,41.82,64.16


('random', True)
processed 4878 tokens with 183 phrases; found: 357 phrases; correct: 102.
accuracy:  88.89%; (non-O)
accuracy:  88.89%; precision:  28.57%; recall:  55.74%; FB1:  37.78%
                X: precision:  28.57%; recall:  55.74%; FB1:  37.78%  357


Unnamed: 0,F1,Precision,Recall
0,37.78,28.57,55.74


processed 4878 tokens with 279 phrases; found: 622 phrases; correct: 189.
accuracy:  89.28%; (non-O)
accuracy:  89.28%; precision:  30.39%; recall:  67.74%; FB1:  41.95%
                X: precision:  30.39%; recall:  67.74%; FB1:  41.95%  622


Unnamed: 0,F1,Precision,Recall
0,41.95,30.39,67.74


('bm25', False)
processed 4878 tokens with 183 phrases; found: 206 phrases; correct: 89.
accuracy:  93.42%; (non-O)
accuracy:  93.42%; precision:  43.20%; recall:  48.63%; FB1:  45.76%
                X: precision:  43.20%; recall:  48.63%; FB1:  45.76%  206


Unnamed: 0,F1,Precision,Recall
0,45.76,43.2,48.63


processed 4878 tokens with 279 phrases; found: 331 phrases; correct: 151.
accuracy:  93.69%; (non-O)
accuracy:  93.69%; precision:  45.62%; recall:  54.12%; FB1:  49.51%
                X: precision:  45.62%; recall:  54.12%; FB1:  49.51%  331


Unnamed: 0,F1,Precision,Recall
0,49.51,45.62,54.12


('sbert', False)
processed 4878 tokens with 183 phrases; found: 196 phrases; correct: 96.
accuracy:  94.05%; (non-O)
accuracy:  94.05%; precision:  48.98%; recall:  52.46%; FB1:  50.66%
                X: precision:  48.98%; recall:  52.46%; FB1:  50.66%  196


Unnamed: 0,F1,Precision,Recall
0,50.66,48.98,52.46


processed 4878 tokens with 279 phrases; found: 304 phrases; correct: 155.
accuracy:  94.40%; (non-O)
accuracy:  94.40%; precision:  50.99%; recall:  55.56%; FB1:  53.17%
                X: precision:  50.99%; recall:  55.56%; FB1:  53.17%  304


Unnamed: 0,F1,Precision,Recall
0,53.17,50.99,55.56


('roberta', False)
processed 4878 tokens with 183 phrases; found: 195 phrases; correct: 92.
accuracy:  93.99%; (non-O)
accuracy:  93.99%; precision:  47.18%; recall:  50.27%; FB1:  48.68%
                X: precision:  47.18%; recall:  50.27%; FB1:  48.68%  195


Unnamed: 0,F1,Precision,Recall
0,48.68,47.18,50.27


processed 4878 tokens with 279 phrases; found: 294 phrases; correct: 148.
accuracy:  94.32%; (non-O)
accuracy:  94.32%; precision:  50.34%; recall:  53.05%; FB1:  51.66%
                X: precision:  50.34%; recall:  53.05%; FB1:  51.66%  294


Unnamed: 0,F1,Precision,Recall
0,51.66,50.34,53.05


('scibert', False)
processed 4878 tokens with 183 phrases; found: 175 phrases; correct: 85.
accuracy:  94.14%; (non-O)
accuracy:  94.14%; precision:  48.57%; recall:  46.45%; FB1:  47.49%
                X: precision:  48.57%; recall:  46.45%; FB1:  47.49%  175


Unnamed: 0,F1,Precision,Recall
0,47.49,48.57,46.45


processed 4878 tokens with 279 phrases; found: 272 phrases; correct: 139.
accuracy:  94.40%; (non-O)
accuracy:  94.40%; precision:  51.10%; recall:  49.82%; FB1:  50.45%
                X: precision:  51.10%; recall:  49.82%; FB1:  50.45%  272


Unnamed: 0,F1,Precision,Recall
0,50.45,51.1,49.82
