In [254]:
import os
import sys
from collections import defaultdict
from itertools import groupby

import numpy as np
from datasets import load_dataset

# The local `connect` module is resolved via the parent directory.
sys.path.append('..')
from connect import OpenAI

In [253]:
!pip install openai

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting openai
  Downloading openai-0.26.4.tar.gz (55 kB)
[K     |████████████████████████████████| 55 kB 833 kB/s eta 0:00:01
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h    Preparing wheel metadata ... [?25ldone
Building wheels for collected packages: openai
  Building wheel for openai (PEP 517) ... [?25ldone
[?25h  Created wheel for openai: filename=openai-0.26.4-py3-none-any.whl size=67722 sha256=274370cd366b6f9063c1b0d2bcb6ec7a65718399cd3227130635dc67b33b4bc0
  Stored in directory: /tmp/pip-ephem-wheel-cache-z__q8tjb/wheels/2b/d8/4e/268f029bd3277c1dd9e8781a0e0296e0a63822665bfa2429fc
Successfully built openai
Installing collected packages: openai
Successfully installed openai-0.26.4


In [5]:
# Load the "conllpp" dataset through `datasets.load_dataset`; per the output
# below it is downloaded once and cached for subsequent calls.
conlpp_dataset = load_dataset("conllpp")

Downloading builder script:   0%|          | 0.00/8.73k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/3.35k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.70k [00:00<?, ?B/s]

Downloading and preparing dataset conllpp/conllpp to /root/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/650k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/163k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/141k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]

Dataset conllpp downloaded and prepared to /root/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
# Pick each split by name rather than unpacking `.values()`, which silently
# depends on the order in which the splits happen to be stored.
train_ds = conlpp_dataset["train"]
validation_ds = conlpp_dataset["validation"]
test_ds = conlpp_dataset["test"]

In [295]:
def preprocess_example(example):
    """
    Convert one annotated example into a (text, entities) pair suitable for a prompt.

    Args:
        example: dictionary with a 'tokens' list (the words of the sentence) and a
            parallel 'ner_tags' list of integer tag ids (see ner_tags_dict below).

    Returns:
        A tuple (text, ner_output_dict) where text is the tokens joined by single
        spaces and ner_output_dict is a defaultdict(list) mapping each entity type
        ('PER', 'ORG', 'LOC', 'MISC') to the list of entity mentions, in order of
        appearance. Tokens tagged 'O' are not included.

    Raises:
        KeyError: if a tag id is not present in ner_tags_dict.
    """
    ner_tags_dict = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-ORG',  4: 'I-ORG', 5: 'B-LOC', 6: 'I-LOC', 7: 'B-MISC', 8:'I-MISC'}
    tokens = example['tokens']

    # Each span is (entity_type, [tokens of that mention]).
    spans = []
    previous_type = 'O'
    for token, tag_id in zip(tokens, example['ner_tags']):
        label = ner_tags_dict[tag_id]
        if label == 'O':
            previous_type = 'O'
            continue
        prefix, entity_type = label.split('-', 1)
        # A 'B-' tag always opens a new mention, even when the previous token had
        # the same entity type; grouping on the bare type alone would glue two
        # adjacent entities (e.g. B-LOC B-LOC) into a single mention.
        # An 'I-' tag that does not continue a mention of the same type is treated
        # as the start of a new one.
        if prefix == 'B' or entity_type != previous_type:
            spans.append((entity_type, [token]))
        else:
            spans[-1][1].append(token)
        previous_type = entity_type

    ner_output_dict = defaultdict(list)
    for entity_type, words in spans:
        ner_output_dict[entity_type].append(' '.join(words))

    text = ' '.join(tokens)
    return text, ner_output_dict

In [308]:
def compose_prompt(target, n_shot_examples):
    """
    Build a few-shot prompt from solved examples followed by an unsolved target.

    target is a dictionary whose 'tokens' list is joined with spaces to form the
    sentence to be completed. n_shot_examples is an iterable of (text, ner_tags)
    pairs, where ner_tags maps entity types ("PER", "ORG", "LOC") to lists of
    mentions; other keys are ignored. The iterable is consumed exactly once.
    Returns the prompt as a single string ending in 'Named Entities:'.
    """
    # Entity type -> heading used in the prompt, in the order they are listed.
    headings = (("PER", "People"), ("ORG", "Organizations"), ("LOC", "Locations"))

    parts = ['Retrieve the people, organizations and locations mentioned in the text below\n']

    for sentence, entities in n_shot_examples:
        parts.append(f'text:{sentence}\n')
        parts.append('Named Entities: ')
        for tag, heading in headings:
            mentions = entities.get(tag)
            if not mentions:
                # Skip categories with no mentions for this sentence.
                continue
            listed = ', '.join(mentions)
            parts.append(f'{heading} - {listed}; ')
        parts.append('\n')

    target_sentence = " ".join(target["tokens"])
    parts.append(f'text:{target_sentence}\nNamed Entities:')
    return ''.join(parts)

def choose_examples(dataset, n_examples=4, min_tokens=10, min_ner=1, seed=None):
    """
    Randomly pick n_examples distinct examples that are long and entity-rich enough.

    Args:
        dataset: iterable of dictionaries with 'tokens' and 'ner_tags' lists.
        n_examples: number of examples to draw (without replacement).
        min_tokens: minimum number of tokens an example must have.
        min_ner: minimum number of tokens carrying a non-zero NER tag.
        seed: optional integer seed for a reproducible draw. When None the
            global NumPy random state is used.

    Returns:
        A list of n_examples items taken from the filtered dataset.

    Raises:
        ValueError: if fewer than n_examples examples satisfy the filters.
    """
    # Both thresholds are inclusive minimums; the NER threshold used to be a
    # strict '>' and therefore required one more tagged token than requested.
    candidates = [
        example for example in dataset
        if len(example['tokens']) >= min_tokens
        and np.count_nonzero(example['ner_tags']) >= min_ner
    ]

    if n_examples > len(candidates):
        raise ValueError(
            f'Requested {n_examples} examples but only {len(candidates)} satisfy '
            f'min_tokens={min_tokens} and min_ner={min_ner}'
        )

    sampler = np.random if seed is None else np.random.RandomState(seed)
    selected_idx = sampler.choice(len(candidates), size=n_examples, replace=False)
    return [candidates[i] for i in selected_idx]

def generate_prompt(data, n_shot, min_tokens, min_ner):
    """
    Sample n_shot + 1 examples from data and turn them into a few-shot prompt.

    The last sampled example becomes the target to complete; the others are
    preprocessed into (text, entities) pairs and used as solved demonstrations.
    min_tokens and min_ner are forwarded to choose_examples as filters.
    """
    sampled = choose_examples(data, n_examples=n_shot+1, min_tokens=min_tokens, min_ner=min_ner)
    demonstrations = sampled[:-1]
    target = sampled[-1]

    # preprocess_example already returns the (text, entities) pair the prompt expects.
    n_shot_examples = [preprocess_example(demo) for demo in demonstrations]

    return compose_prompt(target, n_shot_examples)

In [309]:
# 3-shot prompt drawn from the training split, using the 30 / 5 filter thresholds.
prompt = generate_prompt(train_ds, n_shot=3, min_tokens=30, min_ner=5)

In [311]:
# Show the generated few-shot prompt exactly as it will be sent.
print(prompt)

Retrieve the people, organizations and locations mentioned in the text below
text:The Government Housing Bank will issue bonds worth three billion baht and the metropolitan Waterworks Authority will issue bonds worth 730 million , an investment banker at Siam Commercial Bank told Reuters .
Named Entities: Organizations - Government Housing Bank, Waterworks Authority, Siam Commercial Bank, Reuters; 
text:Tight bowling from Glamorgan off-spinner Robert Croft helped England to restrict Pakistan to 225 for five in their 50 overs in the first one-day international at Old Trafford on Thursday .
Named Entities: People - Robert Croft; Organizations - Glamorgan; Locations - England, Pakistan, Old Trafford; 
text:He said Stallone , best known for the " Rocky " and " Rambo " movies , left the set of " Copland , " which is filming in New York and New Jersey , to be with Flavin for the birth .
Named Entities: People - Stallone, Flavin; Locations - New York, New Jersey; 
text:India blocked the Compr

In [320]:
# Never commit credentials in a notebook: read the key from the environment.
# A KeyError here means OPENAI_API_KEY has not been exported for the kernel.
# NOTE: the key that was previously hard-coded in this cell must be treated as
# compromised and revoked.
oi = OpenAI(api_key=os.environ['OPENAI_API_KEY'], model='ada')

In [270]:
# Inspect one raw training example (fields: id, tokens, pos_tags, chunk_tags, ner_tags).
train_ds[50]

{'id': '50',
 'tokens': ['Opel',
  'AG',
  'together',
  'with',
  'General',
  'Motors',
  'came',
  'in',
  'second',
  'place',
  'with',
  '49,269',
  'registrations',
  ',',
  '16.4',
  'percent',
  'of',
  'the',
  'overall',
  'figure',
  '.'],
 'pos_tags': [22,
  22,
  30,
  15,
  22,
  23,
  38,
  15,
  16,
  21,
  15,
  11,
  24,
  6,
  11,
  21,
  15,
  12,
  16,
  21,
  7],
 'chunk_tags': [11,
  12,
  3,
  13,
  11,
  12,
  21,
  13,
  11,
  12,
  13,
  11,
  12,
  0,
  11,
  12,
  13,
  11,
  12,
  12,
  0],
 'ner_tags': [3, 4, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

In [313]:
# Send the prompt through the project's OpenAI wrapper.
# NOTE(review): max_tokens=20 presumably caps the completion length — confirm
# against connect.OpenAI.ask.
response = oi.ask(prompt, max_tokens=20)

In [271]:
# Print the prompt again.
# NOTE(review): the saved output below comes from an earlier execution (In [271])
# and does not match the prompt generated in In [309]; re-run for a fresh view.
print(prompt)

Retrieve the people, organizations and locations mentioned in the text below
text:News Corp said British newspaper operating profits rose 10 percent for the year , as higher cover prices at The Sun and The Times and higher advertising volumes offset increased newsprint costs .
Named Entities: Organizations - News Corp The Sun The Times; 
text:Gente said Ducruet , a keen racing driver , met Houteman during a race in Belgium and photographers had been on their trail ever since .
Named Entities: People - Ducruet Houteman; Organizations - Gente; Locations - Belgium; 
text:Defending champions Ajax Amsterdam were defeated 2-0 loss away to Heerenveen on Saturday .
Named Entities: Organizations - Ajax Amsterdam Heerenveen; 
text:" Roy agreed a new deal before last night 's game against Everton and we are delighted , " said United manager Alex Ferguson on Thursday .
Named Entities: People - Roy Alex Ferguson; Organizations - Everton United; 
text:Czech Republic 's Havel to tour Brazil in Septem

In [316]:
 # Raw completion text of the first returned choice.
 response['choices'][0]['text']

' Organizations - Comprehensive Test Ban Treaty, Conference on Disarmament; Locations - India, Geneva;'

In [337]:
# NOTE(review): parse_response is not defined in any cell visible here; this only
# works if it is already present in the kernel namespace — confirm where it lives.
parse_response(response)

{'Organizations': ['Comprehensive Test Ban Treaty',
  'Conference on Disarmament'],
 'Locations': ['India', 'Geneva']}

In [321]:
# Ad-hoc parse of the completion: split on ';' into per-category segments, keep
# the text after the last '-' in each segment, then split that on ','.
# The trailing ';' in the completion produces a final [''] entry, items keep
# their leading space, and an entity containing '-' would be truncated.
[s.split('-')[-1].split(',') for s in response['choices'][0]['text'].split(';')]

[[' Comprehensive Test Ban Treaty', ' Conference on Disarmament'],
 [' India', ' Geneva'],
 ['']]