In [None]:
!pip install datasets
!pip install transformers
!pip install sentencepiece
!pip install openai

In [1]:
import re
import torch

from datasets import load_dataset
from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, AdamW, get_linear_schedule_with_warmup

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
dataset = load_dataset('rungalileo/20_Newsgroups_Fixed')
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'text', 'label'],
        num_rows: 11314
    })
    test: Dataset({
        features: ['id', 'text', 'label'],
        num_rows: 7532
    })
})

In [4]:
train_ds = dataset['train']
test_ds = dataset['test']

Let's look at a few examples from the Newsgroup 20 dataset

In [5]:
train_ds[:3]

{'id': [0, 1, 2],
 'text': ['I was wondering if anyone out there could enlighten me on this car I saw\nthe other day. It was a 2-door sports car, looked to be from the late 60s/\nearly 70s. It was called a Bricklin. The doors were really small. In addition,\nthe front bumper was separate from the rest of the body. This is \nall I know. If anyone can tellme a model name, engine specs, years\nof production, where this car is made, history, or whatever info you\nhave on this funky looking car, please e-mail.',
  "A fair number of brave souls who upgraded their SI clock oscillator have\nshared their experiences for this poll. Please send a brief message detailing\nyour experiences with the procedure. Top speed attained, CPU rated speed,\nadd on cards and adapters, heat sinks, hour of usage per day, floppy disk\nfunctionality with 800 and 1.4 m floppies are especially requested.\n\nI will be summarizing in the next two days, so please add to the network\nknowledge base if you have done the 

Let's clean the text of the dataset before we use it We lowercase all text and remove special characters.

In [6]:
def clean_dataset(example):
  if example['text']:
    example['text'] =' '.join(example['text'].splitlines())
    example['text'] = re.sub(r'[^\w\s]', '', example['text'])
    example['text'] = example['text'].strip()
  return example

train_ds = train_ds.map(clean_dataset)

Map:   0%|          | 0/11314 [00:00<?, ? examples/s]

In [7]:
train_ds['text'][:3]

['I was wondering if anyone out there could enlighten me on this car I saw the other day It was a 2door sports car looked to be from the late 60s early 70s It was called a Bricklin The doors were really small In addition the front bumper was separate from the rest of the body This is  all I know If anyone can tellme a model name engine specs years of production where this car is made history or whatever info you have on this funky looking car please email',
 'A fair number of brave souls who upgraded their SI clock oscillator have shared their experiences for this poll Please send a brief message detailing your experiences with the procedure Top speed attained CPU rated speed add on cards and adapters heat sinks hour of usage per day floppy disk functionality with 800 and 14 m floppies are especially requested  I will be summarizing in the next two days so please add to the network knowledge base if you have done the clock upgrade and havent answered this poll Thanks',
 'well folks my 

Let's try to take an off the shelf finetuned model from HF hub and try tagging our documents.

In [8]:
# Model parameters
from transformers import (
    Text2TextGenerationPipeline,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)


class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
    def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
        super().__init__(
            model=AutoModelForSeq2SeqLM.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs
        )
        self.keyphrase_sep_token = keyphrase_sep_token

    def postprocess(self, model_outputs):
        results = super().postprocess(
            model_outputs=model_outputs
        )
        return [[keyphrase.strip() for keyphrase in result.get("generated_text").split(self.keyphrase_sep_token) if keyphrase != ""] for result in results]


In [9]:
model_name = "ml6team/keyphrase-generation-t5-small-inspec"
generator = KeyphraseGenerationPipeline(model=model_name)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.92k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

In [10]:
for i in range(5):
  text = train_ds['text'][i]
  keyphrases = generator(text)
  print("text: ", text)
  print("keyphrases: ", keyphrases)

text:  I was wondering if anyone out there could enlighten me on this car I saw the other day It was a 2door sports car looked to be from the late 60s early 70s It was called a Bricklin The doors were really small In addition the front bumper was separate from the rest of the body This is  all I know If anyone can tellme a model name engine specs years of production where this car is made history or whatever info you have on this funky looking car please email
keyphrases:  [['car', '2door sports car', 'bricklin', 'engine specs years of production']]
text:  A fair number of brave souls who upgraded their SI clock oscillator have shared their experiences for this poll Please send a brief message detailing your experiences with the procedure Top speed attained CPU rated speed add on cards and adapters heat sinks hour of usage per day floppy disk functionality with 800 and 14 m floppies are especially requested  I will be summarizing in the next two days so please add to the network knowle

Improve performance by fine-tuning a model

In [4]:
import openai, os
openai.api_key = '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'
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # for exponential backoff
from joblib import Parallel, delayed
from tqdm import tqdm
import ast, random
import pandas as pd

In [25]:
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def completion_with_backoff(**kwargs):
    return openai.Completion.create(**kwargs)

In [None]:
prompt = '''Generate 5 keywords to describe the concept of this text, separate them with comma: {} =>'''
api_response = Parallel(n_jobs=15, backend='multiprocessing')(delayed(completion_with_backoff)(model="text-davinci-003", prompt=prompt.format(s)) \
                                                                    for s in train_ds['text'][:100])
api_response

In [36]:
keywords = [[x.choices[0].text.strip()] for x in api_response]
keywords[:5]

[['Bricklin, 2-door sports car, late 60s,'],
 ['Clock Upgrade, Top Speed, CPU Rated Speed, Heat Sinks, Floppy'],
 ['MacPlus, Powerbook, Display, Disk, Hellcats'],
 ['Weitek, addressphone, chip, information, number'],

In [38]:
text = train_ds[:100]['text']
data = {'text': text, 'api_response': api_response, 'keywords': keywords}
df = pd.DataFrame(data)
df.head()

Unnamed: 0,text,api_response,keywords
0,I was wondering if anyone out there could enli...,"{'id': 'cmpl-7zfPvZ6wFzJNUgryuP1OGrJXVBbAg', '...","[Bricklin, 2-door sports car, late 60s,]"
1,A fair number of brave souls who upgraded thei...,"{'id': 'cmpl-7zfPvzvNuTb7ukPTvBACmMEQYC9zY', '...","[Clock Upgrade, Top Speed, CPU Rated Speed, He..."
2,well folks my mac plus finally gave up the gho...,"{'id': 'cmpl-7zfPv2ZQqN27WPGZ7FuWaiDTLwvNN', '...","[MacPlus, Powerbook, Display, Disk, Hellcats]"
3,Do you have Weiteks addressphone number Id li...,"{'id': 'cmpl-7zfPvuBgpTtSY3uuQe2qdZ89fVEgd', '...","[Weitek, addressphone, chip, information, number]"
4,From article C5owCBn3pworldstdcom by tombakerw...,"{'id': 'cmpl-7zfPvBxDlGRRUSjdngh99KMLFiwK5', '...","[warning system, software bugs, known bugs, er..."


In [39]:
df.to_csv("generated-dataset", index=False)

In [None]:
# generate test data
prompt = '''Generate 5 keywords to describe the concept of this text, separate them with comma: {} =>'''
api_response = Parallel(n_jobs=15, backend='multiprocessing')(delayed(completion_with_backoff)(model="text-davinci-003", prompt=prompt.format(s)) \
                                                                    for s in train_ds['text'][200:300])
api_response

In [27]:
keywords = [[x.choices[0].text.strip()] for x in api_response]
keywords[:5]

[['guns, firearm, consultation, 1991, March'],
 ['GM, Diesel, Rings, 10W-40, Warranty'],
 ['Benjamin Netanyahu, CNN, Larry King Live, Charismatic, Artic'],
 ['0:1\n\nApology, Confusing, Moderator, Violence,'],
 ['Clipper Chip, Key Management, Session Key, Key Choice, Communication']]

Generate a test dataset

In [28]:
text = train_ds[200:300]['text']
data = {'text': text, 'api_response': api_response, 'keywords': keywords}
df_test = pd.DataFrame(data)
df_test.head()

Unnamed: 0,text,api_response,keywords
0,\nI first read and consulted rec.guns in the s...,"{'id': 'cmpl-80GVp4388hUmZS893cuMyJtzZd6Mg', '...","[guns, firearm, consultation, 1991, March]"
1,\n\nSeveral years ago GM was having trouble wi...,"{'id': 'cmpl-80GVpKpHm9uZCXGO3hIUmR06hZ6fF', '...","[GM, Diesel, Rings, 10W-40, Warranty]"
2,\n Great interview with Benjamin Netanyahu o...,"{'id': 'cmpl-80GVp6KV3BqPO0Kl9AnS4ut8srwkZ', '...","[Benjamin Netanyahu, CNN, Larry King Live, Cha..."
3,I apologize if this article is slightly confus...,"{'id': 'cmpl-80GVqhWDjVkTHsE7o97yVmdj2wsee', '...","[0:1\n\nApology, Confusing, Moderator, Violence,]"
4,\nOh? Hellman said ``each user will get to cho...,"{'id': 'cmpl-80GVpocuXyVCQtCDQdAxkFKqZYdP6', '...","[Clipper Chip, Key Management, Session Key, Ke..."


In [29]:
df_test.to_csv("generated-dataset-test.csv", index=False)

Read Data

In [12]:
df = pd.read_csv("generated-dataset")
df = df.dropna()
df.head()

Unnamed: 0,text,api_response,keywords
0,I was wondering if anyone out there could enli...,"{\n ""id"": ""cmpl-7zfPvZ6wFzJNUgryuP1OGrJXVBbAg...","['Bricklin, 2-door sports car, late 60s,']"
1,A fair number of brave souls who upgraded thei...,"{\n ""id"": ""cmpl-7zfPvzvNuTb7ukPTvBACmMEQYC9zY...","['Clock Upgrade, Top Speed, CPU Rated Speed, H..."
2,well folks my mac plus finally gave up the gho...,"{\n ""id"": ""cmpl-7zfPv2ZQqN27WPGZ7FuWaiDTLwvNN...","['MacPlus, Powerbook, Display, Disk, Hellcats']"
3,Do you have Weiteks addressphone number Id li...,"{\n ""id"": ""cmpl-7zfPvuBgpTtSY3uuQe2qdZ89fVEgd...","['Weitek, addressphone, chip, information, num..."
4,From article C5owCBn3pworldstdcom by tombakerw...,"{\n ""id"": ""cmpl-7zfPvBxDlGRRUSjdngh99KMLFiwK5...","['warning system, software bugs, known bugs, e..."


In [13]:
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [14]:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4, eps=1e-8)



In [15]:
model.train()

epochs = 10

for epoch in range(epochs):
  print ("epoch ",epoch)
  for idx, row in df.iterrows():
    input = str(row['text'])
    keywords = ast.literal_eval(row['keywords'])
    output = ' '.join(keywords)
    input_sent = "tag: "+ input + '</s>'
    ouput_sent = output + '</s>'
    # print(ouput_sent)

    tokenized_inp = tokenizer(input_sent,  max_length=140, pad_to_max_length=True,return_tensors="pt")
    tokenized_output = tokenizer(ouput_sent, max_length=100, pad_to_max_length=True,return_tensors="pt")

    input_ids  = tokenized_inp["input_ids"].to(device)
    attention_mask = tokenized_inp["attention_mask"].to(device)

    lm_labels= tokenized_output["input_ids"].to(device)
    decoder_attention_mask=  tokenized_output["attention_mask"].to(device)


    # the forward function automatically creates the correct decoder_input_ids
    output = model(input_ids=input_ids, labels=lm_labels,decoder_attention_mask=decoder_attention_mask,attention_mask=attention_mask)
    loss = output[0]

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
  print("loss: ", loss)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


epoch  0




epoch  1
epoch  2
epoch  3
epoch  4
epoch  5
epoch  6
epoch  7
epoch  8
epoch  9


In [27]:
model.save_pretrained("t5-base-keyphrase-gen")

In [28]:
tokenizer.save_pretrained("t5-base-tokenizer")

('t5-base-tokenizer/tokenizer_config.json',
 't5-base-tokenizer/special_tokens_map.json',
 't5-base-tokenizer/spiece.model',
 't5-base-tokenizer/added_tokens.json')

Evaluation

In [42]:
## Inference on a single example

In [16]:
test_sent = 'tag: A fair number of brave souls who upgraded their SI clock oscillator have shared their experiences for this poll Please send a brief message detailing your experiences with the procedure Top speed attained CPU rated speed add on cards and adapters heat sinks hour of usage per day floppy disk functionality with 800 and 14 m floppies are especially requested  I will be summarizing in the next two days so please add to the network knowledge base if you have done the clock upgrade and havent answered this poll Thanks </s>'
test_tokenized = tokenizer(test_sent, return_tensors="pt")

test_input_ids  = test_tokenized["input_ids"].to(device)
test_attention_mask = test_tokenized["attention_mask"].to(device)

model.eval()
beam_outputs = model.generate(
    input_ids=test_input_ids,attention_mask=test_attention_mask,
    max_length=10,
    early_stopping=True,
    num_beams=10,
    num_return_sequences=3,
    no_repeat_ngram_size=2
)

for beam_output in beam_outputs:
    sent = tokenizer.decode(beam_output, skip_special_tokens=True,clean_up_tokenization_spaces=True)
    print ("sent: ", sent)



sent:  Clock Upgrade, Top Speed, CPU Rated
sent:  Clock upgrade, Top Speed, CPU Rated
sent:  Clock upgrade, Top speed, CPU Rated


In [22]:
df_test = pd.read_csv("generated-dataset-test.csv")
df_test = df_test.dropna()
df_test.head()

Unnamed: 0,text,api_response,keywords
0,\nI first read and consulted rec.guns in the s...,"{\n ""id"": ""cmpl-80GVp4388hUmZS893cuMyJtzZd6Mg...","['guns, firearm, consultation, 1991, March']"
1,\n\nSeveral years ago GM was having trouble wi...,"{\n ""id"": ""cmpl-80GVpKpHm9uZCXGO3hIUmR06hZ6fF...","['GM, Diesel, Rings, 10W-40, Warranty']"
2,\n Great interview with Benjamin Netanyahu o...,"{\n ""id"": ""cmpl-80GVp6KV3BqPO0Kl9AnS4ut8srwkZ...","['Benjamin Netanyahu, CNN, Larry King Live, Ch..."
3,I apologize if this article is slightly confus...,"{\n ""id"": ""cmpl-80GVqhWDjVkTHsE7o97yVmdj2wsee...","['0:1\n\nApology, Confusing, Moderator, Violen..."
4,\nOh? Hellman said ``each user will get to cho...,"{\n ""id"": ""cmpl-80GVpocuXyVCQtCDQdAxkFKqZYdP6...","['Clipper Chip, Key Management, Session Key, K..."


In [25]:
model.eval()

t5_generated = []
for idx, row in df_test.iterrows():
  test_sent = "tag:" + str(row['text']) + "</s>"
  test_tokenized = tokenizer(test_sent, return_tensors="pt")

  test_input_ids  = test_tokenized["input_ids"].to(device)
  test_attention_mask = test_tokenized["attention_mask"].to(device)


  beam_outputs = model.generate(
    input_ids=test_input_ids,attention_mask=test_attention_mask,
    max_length=15,
    early_stopping=True,
    num_beams=5,
    num_return_sequences=1,
    no_repeat_ngram_size=2
)

  for beam_output in beam_outputs:
      sent = tokenizer.decode(beam_output, skip_special_tokens=True,clean_up_tokenization_spaces=True)
      print ("sent: ", sent)
      t5_generated.append(sent)

sent:  Rec.guns, 1991, read, consulted, purchased
sent:  GM, 5.7, diesel, 10W-40, warranty,
sent:  Benjamin Netanyahu, CNN, Larry King Live,
sent:  "Reality, Illusion, Purpose, God,
sent:  Key, Clipper Chip, Protocol, Key-Management,
sent:  Windows, Communication Software, Interrupts, Cache,
sent:  MX-6, 90, gasket, tail light, known problem
sent:  FLYERS, NHL, Hall of Fame, Goalie,
sent:  Energy recovery, 'perpetual motion', cost, time
sent:  beliefs, language, culture, truth, belief, argument.
sent:  Annotation, Documents, Xt, Oclock,
sent:  Gun Control, America, Crime, Murder,
sent:  Olivetti Quaderno, Sound Digitisation,
sent:  Video, IIci, Color, Max
sent:  Mono sodium glutamate,
sent:  Castrol 20W50, Car, Truck, Bus,
sent:  Apple, System 7, Software, Hardware, Microsoft
sent:  Essential, Non-essential, Supplement
sent:  American taxpayers put up at least 30% of the money, "private
sent:  comp.graphics, c.s.amiga,
sent:  blind, seeing eye dog, trust, faith, defiance, logic
sent:  

In [26]:
df_test['t5-keywords'] = t5_generated
df_test.to_csv("generated-dataset-test-t5.csv", index=False)

In [5]:
df_test = pd.read_csv("generated-dataset-test-t5.csv")
df_test.head()

Unnamed: 0,text,api_response,keywords,t5-keywords
0,\nI first read and consulted rec.guns in the s...,"{\n ""id"": ""cmpl-80GVp4388hUmZS893cuMyJtzZd6Mg...","['guns, firearm, consultation, 1991, March']","Rec.guns, 1991, read, consulted, purchased"
1,\n\nSeveral years ago GM was having trouble wi...,"{\n ""id"": ""cmpl-80GVpKpHm9uZCXGO3hIUmR06hZ6fF...","['GM, Diesel, Rings, 10W-40, Warranty']","GM, 5.7, diesel, 10W-40, warranty,"
2,\n Great interview with Benjamin Netanyahu o...,"{\n ""id"": ""cmpl-80GVp6KV3BqPO0Kl9AnS4ut8srwkZ...","['Benjamin Netanyahu, CNN, Larry King Live, Ch...","Benjamin Netanyahu, CNN, Larry King Live,"
3,I apologize if this article is slightly confus...,"{\n ""id"": ""cmpl-80GVqhWDjVkTHsE7o97yVmdj2wsee...","['0:1\n\nApology, Confusing, Moderator, Violen...","""Reality, Illusion, Purpose, God,"
4,\nOh? Hellman said ``each user will get to cho...,"{\n ""id"": ""cmpl-80GVpocuXyVCQtCDQdAxkFKqZYdP6...","['Clipper Chip, Key Management, Session Key, K...","Key, Clipper Chip, Protocol, Key-Management,"


In [76]:
# evaluate exact match scores
avg_match = 0

for idx, row in df_test.iterrows():
    labels = ast.literal_eval(row['keywords'])
    preds = row['t5-keywords'].split(',')
    labels = labels[0].split(',')

    clean_pred = []
    clean_label = []

    for p in preds:
      if len(p):
        p = ''.join(letter for letter in p if letter.isalnum())
        p = p.lower()
        clean_pred.append(p)

    for l in labels:
      if len(l):
        l = ''.join(letter for letter in l if letter.isalnum())
        l = l.lower()
        clean_label.append(l)

    matches = set(clean_pred).intersection(set(clean_label))
    avg_match+= (len(matches)/ len(labels))
print(avg_match/len(df_test))

0.24419642857142854


In [78]:
!pip install evaluate
!pip install bert_score

Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: bert_score
Successfully installed bert_score-0.3.13


In [80]:
from evaluate import load
bertscore = load("bertscore")

In [90]:
all_preds = []
all_labels = []

for idx, row in df_test.iterrows():
  labels = ast.literal_eval(row['keywords'])
  labels = ' '.join(labels)
  pred = str(row['t5-keywords'])

  all_preds.append(pred)
  all_labels.append(labels)

In [92]:
scores = bertscore.compute(predictions=all_preds, references=all_labels, lang="en")

In [96]:
print("Avg precision: ", sum(scores['precision'])/len(scores['precision']))
print("Avg recall: ", sum(scores['recall'])/len(scores['recall']))
print("Avg F1: ", sum(scores['f1'])/len(scores['f1']))

Avg precision:  0.8818408648173014
Avg recall:  0.8660932344694933
Avg F1:  0.8735948223620653
