# Dependencies

In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive (Colab only).
from google.colab import drive
drive.mount('/content/drive')
# `foldername` is the path, relative to "My Drive", of the directory containing this notebook.
# Change it to match your own Drive layout; it is reused below when extending sys.path.
foldername = '/Longform-Question-Generation/'
# Make that directory the working directory so relative paths such as ./data resolve inside it.
%cd /content/drive/My\ Drive/$foldername

In [None]:
# Log into Hugging Face so trained models can be uploaded
# (TrainModel.train_model calls push_to_hub on the model and tokenizer).
!pip install huggingface_hub
from huggingface_hub import notebook_login
notebook_login()

In [None]:
#Install all the necessary dependencies

!pip install transformers[sentencepiece] datasets evaluate
!pip install wandb
!pip install rouge_score
!pip install bert_score
!pip install seaborn
!pip install -U sentence-transformers
!pip install git+http://github.com/LIAAD/yake

import yake

from sentence_transformers import SentenceTransformer

import wandb, transformers, torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from urllib.parse import urlparse

import sys
import os

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Make Python modules stored in the project folder on Drive importable.
sys.path.append('/content/drive/My Drive/{}'.format(foldername))

# Create the 'results' and 'data' directories if they don't exist.
# exist_ok=True replaces the separate exists() check, so re-running the cell is
# safe and there is no gap between checking and creating.
os.makedirs('results', exist_ok=True)
os.makedirs('data', exist_ok=True)

import datasets
from datasets import load_from_disk, load_dataset, Dataset, load_metric,DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorWithPadding,
    DataCollatorForSeq2Seq,
    BartForConditionalGeneration,
    BartTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)

from torch.utils.data import Dataset, DataLoader

import evaluate
from transformers.trainer_callback import (EarlyStoppingCallback, 
                                           TrainerCallback)

import evaluate
import re

# Load the ROUGE and BERTScore metrics once; Compute_Metrics.compute_metrics reads these globals.
rouge=evaluate.load('rouge')
bertscore=evaluate.load('bertscore')

# Set WANDB_LOG_MODEL for this kernel — presumably so the Trainer's W&B integration
# uploads model checkpoints; confirm against the wandb documentation.
%env WANDB_LOG_MODEL=true

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers[sentencepiece]
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m38.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m63.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollectin

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

env: WANDB_LOG_MODEL=true


# AskScience

## Downloading data

In [None]:
def download_raw_data(subreddit='asks',
                      overwrite = False):
    """Return the raw train/validation/test splits for one subreddit.

    The splits are cached at ``./data/{subreddit}_raw_data``. When that cache
    exists and ``overwrite`` is False it is loaded from disk; otherwise the
    'eli5' dataset is loaded, the three ``{split}_{subreddit}`` splits are
    collected into a DatasetDict, flattened, saved to the cache location and
    logged to Weights & Biases as a 'dataset' artifact.

    Args:
        subreddit: Split suffix inside the 'eli5' dataset ('asks', 'askh' or
            'eli5'), or any other name for which a cache already exists.
        overwrite: When True, ignore an existing cache and rebuild it.

    Returns:
        The DatasetDict with 'train', 'validation' and 'test' splits.

    Raises:
        ValueError: If ``subreddit`` is not one of the known names and no
            cached copy exists on disk.
    """
    raw_file_name = f'./data/{subreddit}_raw_data'
    known_subreddits = ('asks', 'askh', 'eli5')

    if not os.path.exists(raw_file_name) and subreddit not in known_subreddits:
        # The message is assembled from adjacent literals so it carries no
        # stray indentation, and it names the path that was actually checked.
        raise ValueError(
            f"File for subreddit '{subreddit}' does not exist. "
            "Use 'asks', 'askh' or 'eli5' to choose a subreddit contained in the "
            "eli5 dataset. Alternatively, use the scraping method of "
            "https://github.com/facebookresearch/ELI5 "
            f"and save the data at location '{raw_file_name}'.")

    if os.path.exists(raw_file_name) and not overwrite:
        return load_from_disk(raw_file_name)

    # os.makedirs (not the non-existent os.makedir); exist_ok makes the call idempotent.
    os.makedirs('./data', exist_ok=True)

    dataset_eli5 = load_dataset('eli5')

    dataset = DatasetDict()
    for split in ('train', 'validation', 'test'):
        dataset[split] = dataset_eli5[f'{split}_{subreddit}']

    # Later steps address columns by dotted names such as 'answers.text'.
    dataset = dataset.flatten()

    dataset.save_to_disk(raw_file_name)

    with wandb.init(project='Question_Generation',
                    entity=None,
                    job_type='logging_data',
                    name='logging_data') as run:

        raw_data_art = wandb.Artifact(subreddit+'_raw_data', 'dataset')
        raw_data_art.add_dir(raw_file_name)
        run.log_artifact(raw_data_art)

    return dataset

In [None]:
def preprocess_func(example):
    """Normalise one example in place and return it.

    Keeps only the first entry of 'answers.text' and 'answers.score'. In the
    answer, every span from a '>' up to the next newline is replaced by a
    space. The answer, 'title' and 'selftext' are lower-cased and every run of
    whitespace is collapsed to a single space.
    """
    def squash(text):
        # Lower-case, then rebuild the string from its whitespace-separated tokens.
        return ' '.join(text.lower().split())

    first_answer = example['answers.text'][0]
    first_answer = re.sub('>.*?\n', ' ', first_answer)
    example['answers.text'] = squash(first_answer)

    example['answers.score'] = example['answers.score'][0]

    for field in ('title', 'selftext'):
        example[field] = squash(example[field])

    return example

def preprocess_data(dataset):
    """Run preprocess_func over `dataset` through its `map` method and return the mapped result."""
    return dataset.map(preprocess_func)

def log_processed_data(subreddit='asks',
                       overwrite=False,
                       min_sent_length=20):
    """Build, cache and log the preprocessed dataset for `subreddit`.

    A cached copy at ``./data/{subreddit}_processed_data`` is returned directly
    unless `overwrite` is set. Otherwise the raw data is loaded (from disk or
    via download_raw_data), the unused columns are dropped, preprocess_data is
    applied, and two filters run: answers must have more than
    `min_sent_length` whitespace-separated tokens, and titles containing
    'ask anything wednesday' are discarded. The result is saved to disk and
    logged to Weights & Biases as a 'dataset' artifact.
    """
    raw_path = f'./data/{subreddit}_raw_data'
    processed_path = f'./data/{subreddit}_processed_data'
    may_use_cache = not overwrite

    if may_use_cache and os.path.exists(processed_path):
        return load_from_disk(processed_path)

    if may_use_cache and os.path.exists(raw_path):
        raw = load_from_disk(raw_path)
    else:
        raw = download_raw_data(subreddit=subreddit)

    def long_enough(row):
        # Token count is based on plain whitespace splitting.
        return len(row['answers.text'].split()) > min_sent_length

    def not_weekly_thread(row):
        return 'ask anything wednesday' not in row['title']

    unused_columns = ['subreddit', 'document', 'answers.a_id', 'q_id']
    processed = preprocess_data(raw.remove_columns(unused_columns))
    processed = processed.filter(long_enough)
    processed = processed.filter(not_weekly_thread)

    processed.save_to_disk(processed_path)

    with wandb.init(project='Question_Generation',
                    entity=None,
                    job_type='logging_processed_data',
                    name='processed_data') as run:

        artifact = wandb.Artifact(subreddit+'_processed_data', type='dataset')
        artifact.add_dir(processed_path)
        run.log_artifact(artifact)

    return processed

## Checking Data Leakage

In [None]:
def par_to_vec(model,sent,chunk_length=128):
    """Embed `sent` piecewise and return the sum of the piece embeddings.

    `sent` is sliced into consecutive pieces of at most `chunk_length` items
    (characters when `sent` is a string), all pieces are passed to
    `model.encode` in one call, and the results are summed along axis 0.
    """
    pieces = []
    for start in range(0, len(sent), chunk_length):
        pieces.append(sent[start:start + chunk_length])

    piece_vectors = model.encode(pieces)
    return np.sum(piece_vectors, axis=0, keepdims=False)

def dataset_par_to_vec(model,example,chunk_length=128):
    """Store the par_to_vec embedding of example['answers.text'] under 'sent_vec' and return the example."""
    answer_text = example['answers.text']
    example['sent_vec'] = par_to_vec(model, answer_text, chunk_length=chunk_length)
    return example

In [None]:
def _near_duplicate_pairs(row_vecs, col_vecs, cutoff):
    """Return (row_idx, col_idx) index tensors of pairs with cosine similarity above `cutoff`."""
    row_unit = row_vecs / torch.sqrt(torch.sum(row_vecs**2, dim=1, keepdim=True))
    col_unit = col_vecs / torch.sqrt(torch.sum(col_vecs**2, dim=1, keepdim=True))
    similarity = torch.matmul(row_unit, torch.transpose(col_unit, 0, 1))
    # For a 2-D condition torch.where yields one index tensor per axis:
    # element 0 indexes rows (row_vecs), element 1 indexes columns (col_vecs).
    return torch.where(similarity > cutoff)


def clean_and_embed_data(sent_model_checkpoint='paraphrase-MiniLM-L6-v2',
                         subreddit='asks',
                         overwrite=False,
                         cutoff=.9,
                         min_sent_length=20):
    """Embed the answers, drop near-duplicates across splits, cache and log the result.

    Steps, skipped entirely when ``./data/{subreddit}_cleaned_data`` exists and
    `overwrite` is False:

    1. Load the processed dataset from disk, or build it with log_processed_data.
    2. Add a 'sent_vec' embedding to every example (dataset_par_to_vec).
    3. Remove train examples whose answer embedding has cosine similarity
       above `cutoff` with any test or validation answer, and validation
       examples that are that similar to any test answer. The test split is
       never reduced by this step.
    4. In every split, drop answers containing
       "your submission has been removed".
    5. Save the dataset and log it to Weights & Biases as a 'dataset' artifact.

    NOTE: earlier revisions removed rows using the wrong axis of the similarity
    matches; a cache written by them has to be rebuilt with ``overwrite=True``
    to pick up the corrected filtering.

    Args:
        sent_model_checkpoint: SentenceTransformer checkpoint used for embedding.
        subreddit: Dataset name used to build the cache paths.
        overwrite: Rebuild even when cached data exists.
        cutoff: Cosine-similarity threshold above which two answers count as duplicates.
        min_sent_length: Forwarded to log_processed_data.

    Returns:
        The cleaned dataset, including the 'sent_vec' column.
    """
    processed_file_name = f'./data/{subreddit}_processed_data'
    cleaned_file_name = f'./data/{subreddit}_cleaned_data'

    if os.path.exists(cleaned_file_name) and not overwrite:
        return load_from_disk(cleaned_file_name)

    if os.path.exists(processed_file_name) and not overwrite:
        ds_reduced = load_from_disk(processed_file_name)
    else:
        ds_reduced = log_processed_data(subreddit=subreddit,
                                        overwrite=overwrite,
                                        min_sent_length=min_sent_length)

    # Load the sentence model only once it is known that embeddings are needed.
    sent_model = SentenceTransformer(sent_model_checkpoint)

    ds_reduced_emb = ds_reduced.map(lambda x: dataset_par_to_vec(sent_model, x))
    ds_reduced_emb.set_format('torch')

    train_vecs = ds_reduced_emb['train']['sent_vec']
    valid_vecs = ds_reduced_emb['validation']['sent_vec']
    test_vecs = ds_reduced_emb['test']['sent_vec']

    # Only the column indices (second element) identify the rows of the split
    # that is being pruned; the row indices belong to the reference split.
    _, train_like_test = _near_duplicate_pairs(test_vecs, train_vecs, cutoff)
    _, train_like_valid = _near_duplicate_pairs(valid_vecs, train_vecs, cutoff)
    _, valid_like_test = _near_duplicate_pairs(test_vecs, valid_vecs, cutoff)

    train_rem_idxs = set(train_like_test.tolist()) | set(train_like_valid.tolist())
    valid_rem_idxs = set(valid_like_test.tolist())

    ds_reduced_emb['train'] = ds_reduced_emb['train'].filter(
        lambda _, idx: idx not in train_rem_idxs, with_indices=True)

    ds_reduced_emb['validation'] = ds_reduced_emb['validation'].filter(
        lambda _, idx: idx not in valid_rem_idxs, with_indices=True)

    removed_sent = "your submission has been removed"
    for split in ['train', 'validation', 'test']:
        ds_reduced_emb[split] = ds_reduced_emb[split].filter(
            lambda example: removed_sent not in example['answers.text'])

    ds_reduced_emb.save_to_disk(cleaned_file_name)

    with wandb.init(project='Question_Generation',
                    entity=None,
                    job_type='logging_cleaned_data',
                    name='cleaned_data') as run:

        cleaned_data_art = wandb.Artifact(subreddit+'_cleaned_data', type='dataset')
        cleaned_data_art.add_dir(cleaned_file_name)

        run.log_artifact(cleaned_data_art)

    return ds_reduced_emb

## Adding keywords

In [None]:
def keybert_top3(text):
    """Return `text` prefixed with up to three KeyBERT keyphrases.

    The result has the form ``"kw1, kw2, kw3 ====== <text>"``, the same layout
    yake_top3 produces, so either function can be passed as `keyword_model` to
    add_keywords without losing the answer text.

    NOTE(review): `KeyBERT` is not imported anywhere in the visible part of
    this notebook; it has to be installed and imported
    (presumably ``from keybert import KeyBERT``) before this function is called.
    """
    kw_model = KeyBERT()
    keywords = kw_model.extract_keywords(text,
                                         keyphrase_ngram_range=(1, 3),
                                         use_mmr=True,
                                         diversity=0.9,
                                         stop_words='english')

    # Each returned item is indexable; element 0 is used as the keyphrase.
    prefix = ', '.join(item[0] for item in keywords[:3]) + ' ====== '
    # Keep the original text after the prefix: add_keywords_to_dataset replaces
    # 'answers.text' with this return value.
    return prefix + text

def yake_top3(text):
    """Return `text` prefixed with up to three YAKE keyphrases.

    A yake.KeywordExtractor is built with fixed settings and run on `text`.
    The first element of each of the first three returned items is joined with
    ', ' and followed by ' ====== ' and then the unchanged text.
    """
    extractor_settings = dict(
        lan="en",
        n=3,
        dedupLim=.7,
        dedupFunc='seqm',
        windowsSize=1,
        top=20,
        features=None,
    )
    ranked = yake.KeywordExtractor(**extractor_settings).extract_keywords(text)

    leading_phrases = [item[0] for item in ranked[:3]]
    return ', '.join(leading_phrases) + ' ====== ' + text

def add_keywords_to_dataset(keyword_model,example):
    """Replace example['answers.text'] with keyword_model(<that text>) and return the example."""
    answer_text = example['answers.text']
    example['answers.text'] = keyword_model(answer_text)
    return example


def add_keywords(keyword_model,
                 subreddit='asks',
                 overwrite=False,
                 sent_model_checkpoint='paraphrase-MiniLM-L6-v2',
                 cutoff=.9,
                 min_sent_length=20):
    """Apply `keyword_model` to every answer of the cleaned dataset and cache the result.

    The output is stored at
    ``./data/{keyword_model.__name__}_{subreddit}_cleaned_data`` and loaded from
    there on later calls unless `overwrite` is set. The cleaned input comes from
    ``./data/{subreddit}_cleaned_data`` when it exists (and `overwrite` is
    False), otherwise from clean_and_embed_data, which receives the remaining
    arguments.
    """
    cleaned_path = f'./data/{subreddit}_cleaned_data'
    keyword_path = f'./data/{keyword_model.__name__}_{subreddit}_cleaned_data'
    may_use_cache = not overwrite

    if may_use_cache and os.path.exists(keyword_path):
        return load_from_disk(keyword_path)

    if may_use_cache and os.path.exists(cleaned_path):
        base_dataset = load_from_disk(cleaned_path)
    else:
        base_dataset = clean_and_embed_data(
            sent_model_checkpoint=sent_model_checkpoint,
            subreddit=subreddit,
            overwrite=overwrite,
            cutoff=cutoff,
            min_sent_length=min_sent_length,
        )

    def attach_keywords(row):
        return add_keywords_to_dataset(keyword_model, row)

    enriched = base_dataset.map(attach_keywords)
    enriched.save_to_disk(keyword_path)

    return enriched

# Full Fine-tuning

## Training functions

In [None]:
class Compute_Metrics:
    """Bundle a tokenizer with the metric callback handed to the trainer."""

    def __init__(self,tokenizer):
        # Tokenizer used to turn predicted and reference token ids back into text.
        self.tokenizer=tokenizer

    def compute_metrics(self,eval_pred):
        """Decode predictions and labels and score them with BERTScore and ROUGE.

        Args:
            eval_pred: A ``(predictions, labels)`` pair of token-id arrays.
                `predictions` may itself be a tuple, in which case its first
                element is used.

        Returns:
            A flat dict mapping ``'<metric>_<submetric>'`` (for example
            'bertscore_f1' or 'rouge_rouge1') to the mean of that score.
            BERTScore's 'hashcode' entry is skipped.
        """
        predictions,labels=eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        pad_id = self.tokenizer.pad_token_id
        # -100 marks ignored positions and is not a valid token id, so swap it
        # for the pad token before decoding - in the predictions as well.
        predictions = np.where(predictions != -100, predictions, pad_id)
        labels = np.where(labels != -100, labels, pad_id)

        # Strip special tokens from both sides; otherwise markers such as
        # padding/BOS/EOS in the predictions are scored against clean references.
        decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        result = {
            'bertscore': bertscore.compute(predictions=decoded_preds,
                                           references=decoded_labels,
                                           lang='en'),
            'rouge': rouge.compute(predictions=decoded_preds,
                                   references=decoded_labels),
        }

        output = {}
        for metric_name, scores in result.items():
            for sub_name, value in scores.items():
                if sub_name != 'hashcode':
                    # Scores may be per-example lists or scalars; np.mean handles both.
                    output[metric_name+'_'+sub_name] = np.mean(value)

        return output

class Tokenizer_Wrapper:
    """Pair a tokenizer with a maximum length and expose a batch tokenisation function."""

    def __init__(self,tokenizer,max_length=512):
        self.tok, self.max_length = tokenizer, max_length

    def _encode(self, texts):
        # One tokenizer call, truncated to the configured maximum length.
        return self.tok(texts,
                        max_length=self.max_length,
                        truncation=True)

    def tokenizer_func(self,examples):
        """Tokenise 'answers.text' as model inputs and 'title' as labels.

        Returns the tokenizer output for the answers with an added 'labels'
        entry holding the 'input_ids' of the tokenised titles.
        """
        encoded_inputs = self._encode(examples["answers.text"])
        encoded_targets = self._encode(examples["title"])

        encoded_inputs["labels"] = encoded_targets["input_ids"]
        return encoded_inputs

In [None]:
class TrainModel:
    """Load a dataset and a seq2seq checkpoint, tokenise the data and fine-tune the model.

    Typical use: construct, call prepare_data(), then train_model().
    """

    def __init__(self,
                 checkpoint,
                 subreddit='asks',
                 keyword_model=None,
                 tok=None,
                 fp16=True,
                 sent_model_checkpoint='paraphrase-MiniLM-L6-v2',
                 overwrite=False,
                 cutoff=.9,
                 min_sent_length=20):
        """Load the dataset, model, tokenizer and data collator.

        Args:
            checkpoint: Model checkpoint, either 'namespace/name' or a bare 'name'.
            subreddit: Dataset name used to build cache paths.
            keyword_model: Optional function (e.g. yake_top3) applied to each
                answer; its ``__name__`` becomes part of the dataset name.
            tok: Optional tokenizer; loaded from `checkpoint` when None.
            fp16: Request mixed precision. Forced to False when the model name
                contains 'flan'.
            sent_model_checkpoint, overwrite, cutoff, min_sent_length:
                Forwarded to clean_and_embed_data / add_keywords when the
                dataset has to be built.
        """
        self.keyword_model=keyword_model

        if keyword_model:
            self.data_name=keyword_model.__name__+'_'+subreddit
        else:
            self.data_name=subreddit

        cleaned_path = f'./data/{self.data_name}_cleaned_data'

        if os.path.exists(cleaned_path) and not overwrite:
            self.dataset=load_from_disk(cleaned_path)

        elif not keyword_model:
            self.dataset=clean_and_embed_data(sent_model_checkpoint=sent_model_checkpoint,
                                              subreddit=subreddit,
                                              overwrite=overwrite,
                                              cutoff=cutoff,
                                              min_sent_length=min_sent_length)

        else:
            self.dataset=add_keywords(keyword_model,
                                      subreddit=subreddit,
                                      overwrite=overwrite,
                                      sent_model_checkpoint=sent_model_checkpoint,
                                      cutoff=cutoff,
                                      min_sent_length=min_sent_length)

        # Use the last path component so checkpoints without a namespace
        # (no '/') work too; index [1] raised IndexError for them.
        self.model_name = checkpoint.rstrip('/').split('/')[-1]

        if 'flan' in self.model_name.lower():
            self.fp16=False
        else:
            self.fp16=fp16

        self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        if tok is None:
            self.tok = AutoTokenizer.from_pretrained(checkpoint)
        else:
            self.tok = tok

        self.data_collator = DataCollatorForSeq2Seq(self.tok,
                                                    self.model)

    def prepare_data(self,
                     max_length=512,
                     prompt=None,
                     overwrite=False):
        """Tokenise the dataset into ``self.tok_datasets``, using a disk cache.

        The cache lives at ``./data/{data_name}/{model_name}``. It is keyed only
        by those two names, so changing `prompt` or `max_length` requires
        ``overwrite=True`` to take effect.

        Args:
            max_length: Truncation length for both answers and titles.
            prompt: Optional string prepended to every answer before tokenising.
            overwrite: Rebuild the tokenised data even when a cache exists.
        """
        tokenized_path = f'./data/{self.data_name}/{self.model_name}'

        if os.path.exists(tokenized_path) and not overwrite:
            self.tok_datasets=load_from_disk(tokenized_path)
            return

        if prompt:
            print('adding new prompt')
            def add_prompt(example):
                example['answers.text']=prompt+example['answers.text']
                return example

            # NOTE: this replaces self.dataset, so a second call with a prompt
            # on the same instance prepends the prompt again.
            self.dataset=self.dataset.map(add_prompt)

        tok_func=Tokenizer_Wrapper(self.tok,max_length).tokenizer_func
        self.tok_datasets = self.dataset.map(tok_func,batched=True)

        # Keep only the columns produced by the tokenizer.
        keep_columns=['input_ids','attention_mask','labels']
        drop_cols=[col for col in list(self.dataset['train'].features)
                   if col not in keep_columns]

        self.tok_datasets = self.tok_datasets.remove_columns(drop_cols)
        self.tok_datasets.save_to_disk(tokenized_path)

        with wandb.init(project='Question_Generation',
                        entity = None,
                        job_type = 'logging_tokenized_data',
                        name = 'tok_'+self.model_name+'_'+self.data_name) as run:

            tok_data_art=wandb.Artifact(self.data_name+'_'+self.model_name,type='dataset')
            tok_data_art.add_dir(tokenized_path)

            run.log_artifact(tok_data_art)

    def train_model(self,
                    batch_size=4,
                    num_epochs=8,
                    eval_strat='epoch',
                    lr=5.5e-5,
                    weight_decay=0.01,
                    save_limit=3,
                    gradient_accumulation_steps=2,
                    early_stopping_patience=3,
                    hub_namespace='dhmeltzer'):
        """Fine-tune the model inside a Weights & Biases run and publish it.

        Requires prepare_data() to have populated ``self.tok_datasets``.
        After training, the model is evaluated, saved under
        ``./models/{model_name}_{data_name}_qg``, the evaluation results are
        logged as a W&B table, and model and tokenizer are pushed to the Hub
        as ``{hub_namespace}/{model_name}_{data_name}_qg``.

        Args:
            batch_size: Per-device train and eval batch size.
            num_epochs: Number of training epochs.
            eval_strat: Used for both the evaluation and the save strategy.
            lr: Learning rate.
            weight_decay: Weight decay.
            save_limit: Maximum number of checkpoints kept on disk.
            gradient_accumulation_steps: Gradient accumulation steps.
            early_stopping_patience: Patience passed to EarlyStoppingCallback.
            hub_namespace: Hub user or organisation to push to.
        """
        run_label = self.model_name+'_'+self.data_name
        model_id = run_label+'_qg'
        output_dir = "./models/"+model_id
        hub_id = hub_namespace+'/'+model_id

        with wandb.init(project='Question_Generation',
                        entity=None,
                        job_type='training',
                        name='train_'+run_label) as run:

            # Log about twice per pass over the data; at least 1 so that tiny
            # datasets do not produce an invalid value of 0.
            logging_steps=max(1, len(self.tok_datasets['train'])//(2*batch_size))

            args=Seq2SeqTrainingArguments(
                output_dir=output_dir,
                evaluation_strategy=eval_strat,
                save_strategy=eval_strat,
                learning_rate=lr,
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                weight_decay=weight_decay,
                save_total_limit=save_limit,
                num_train_epochs=num_epochs,
                predict_with_generate=True,
                logging_steps=logging_steps,
                # Mixed precision only when not running on the CPU.
                fp16=bool(device!='cpu' and self.fp16),
                logging_dir=model_id+'/logs',
                report_to='wandb',
                metric_for_best_model='bertscore_f1',
                load_best_model_at_end=True,
                gradient_accumulation_steps=gradient_accumulation_steps
                )

            trainer = Seq2SeqTrainer(
                self.model,
                args,
                train_dataset=self.tok_datasets['train'],
                eval_dataset=self.tok_datasets['validation'],
                data_collator=self.data_collator,
                tokenizer=self.tok,
                callbacks=[EarlyStoppingCallback(
                    early_stopping_patience=early_stopping_patience)],
                compute_metrics=Compute_Metrics(self.tok).compute_metrics
                )

            trainer.train()

            outputs=trainer.evaluate()
            trainer.save_model(output_dir)

            run.log({run_label+"_Performance-data": wandb.Table(dataframe=pd.DataFrame(outputs, index=["Performance"]))})
            self.model.push_to_hub(hub_id)
            self.tok.push_to_hub(hub_id)

            # NOTE(review): this artifact is created and given metadata but is
            # never passed to run.log_artifact, so it is not recorded in W&B.
            # Left as in the original; confirm whether it should be logged.
            trained_model_art=wandb.Artifact(model_id,type='model')
            trained_model_art.metadata={"hub_id":hub_id}

In [None]:
def complete_train(checkpoint,
                 subreddit='asks',
                 keyword_model=None,
                 tok=None,
                 fp16=True,
                 sent_model_checkpoint='paraphrase-MiniLM-L6-v2',
                 overwrite=False,
                 cutoff=.9,
                 min_sent_length=20,
                 batch_size=4,
                 num_epochs=8,
                 eval_strat='epoch',
                 lr=5.5e-5,
                 weight_decay=0.01,
                 save_limit=3,
                 gradient_accumulation_steps=2,
                 early_stopping_patience=3,
                 prompt='',
                 max_length=512):
    """Run the whole pipeline for one checkpoint: build a TrainModel, tokenise, train.

    The arguments are split three ways: dataset/model options go to the
    TrainModel constructor, `prompt`, `overwrite` and `max_length` go to
    prepare_data, and the optimisation options go to train_model.
    Returns None.
    """
    setup_options = dict(
        checkpoint=checkpoint,
        subreddit=subreddit,
        keyword_model=keyword_model,
        tok=tok,
        fp16=fp16,
        sent_model_checkpoint=sent_model_checkpoint,
        overwrite=overwrite,
        cutoff=cutoff,
        min_sent_length=min_sent_length,
    )
    optimisation_options = dict(
        batch_size=batch_size,
        num_epochs=num_epochs,
        eval_strat=eval_strat,
        lr=lr,
        weight_decay=weight_decay,
        save_limit=save_limit,
        gradient_accumulation_steps=gradient_accumulation_steps,
        early_stopping_patience=early_stopping_patience,
    )

    pipeline = TrainModel(**setup_options)
    pipeline.prepare_data(max_length=max_length,
                          prompt=prompt,
                          overwrite=overwrite)
    pipeline.train_model(**optimisation_options)

## Prepare Data

In [None]:
# Build (or load from cache) the cleaned dataset, then the variant with YAKE keyword prefixes.
clean_and_embed_data(overwrite=False)
add_keywords(keyword_model=yake_top3, overwrite=False)

# Examples

## Bart-Tiny (sshleifer)

In [None]:
# Baseline run without keyword prefixes.
# The stale `datadict`, `dataset` and `data_name` arguments were dropped: complete_train
# does not accept them, and the variables they referenced are not defined in this notebook.
complete_train(checkpoint='sshleifer/bart-tiny-random',
               batch_size=32,
               num_epochs=8,
               early_stopping_patience=10)

In [None]:
# Same checkpoint, trained on answers prefixed with YAKE keywords.
complete_train(
    'sshleifer/bart-tiny-random',
    keyword_model=yake_top3,
    batch_size=32,
    num_epochs=8,
    early_stopping_patience=10,
)



Map:   0%|          | 0/2060 [00:00<?, ? examples/s]



Saving the dataset (0/1 shards):   0%|          | 0/125323 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2060 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4058 [00:00<?, ? examples/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Adding directory to artifact (./data/yake_top3_asks/bart-tiny-random)... Done. 0.7s


[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m. Use [1m`wandb login --relogin`[0m to force relogin


You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,9.0065,7.701253,0.686878,0.799051,0.738577,0.028482,6.3e-05,0.02836,0.028359
2,6.9659,6.929829,0.760325,0.80111,0.78009,0.010461,0.0,0.010414,0.010475
2,6.6238,6.771379,0.842492,0.804162,0.82278,0.004914,0.0,0.004898,0.004919
4,6.3919,6.502172,0.842492,0.804162,0.82278,0.004914,0.0,0.004898,0.004919
4,6.1575,6.334644,0.842492,0.804162,0.82278,0.004914,0.0,0.004898,0.004919
6,6.0231,6.255635,0.685347,0.822304,0.747479,0.052252,0.000249,0.050667,0.050748
6,5.9576,6.219351,0.685176,0.822322,0.7474,0.052292,0.000249,0.050694,0.050763
7,5.9303,6.20875,0.685176,0.822322,0.7474,0.052292,0.000249,0.050694,0.050763


Downloading (…)lve/main/config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.43G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/5.33M [00:00<?, ?B/s]

VBox(children=(Label(value='8.297 MB of 8.318 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.997536…

0,1
eval/bertscore_f1,▁▄███▂▂▂█
eval/bertscore_precision,▁▄███▁▁▁█
eval/bertscore_recall,▁▂▃▃▃███▃
eval/loss,█▄▄▂▂▁▁▁▄
eval/rouge_rouge1,▄▂▁▁▁███▁
eval/rouge_rouge2,▃▁▁▁▁███▁
eval/rouge_rougeL,▅▂▁▁▁███▁
eval/rouge_rougeLsum,▅▂▁▁▁███▁
eval/runtime,█▁▁▁▂▁▁▁▂
eval/samples_per_second,▁▇██▇▇█▇▆

0,1
eval/bertscore_f1,0.82278
eval/bertscore_precision,0.84249
eval/bertscore_recall,0.80416
eval/loss,6.77138
eval/rouge_rouge1,0.00491
eval/rouge_rouge2,0.0
eval/rouge_rougeL,0.0049
eval/rouge_rougeLsum,0.00492
eval/runtime,20.9179
eval/samples_per_second,98.48


## Tinier Bart

In [None]:
# Baseline run without keyword prefixes.
# The stale `datadict`, `dataset` and `data_name` arguments were dropped: complete_train
# does not accept them, and the variables they referenced are not defined in this notebook.
complete_train(checkpoint='sshleifer/tinier_bart',
               batch_size=64,
               num_epochs=8,
               early_stopping_patience=10)

In [None]:
# Same checkpoint, trained on answers prefixed with YAKE keywords.
complete_train(
    'sshleifer/tinier_bart',
    keyword_model=yake_top3,
    batch_size=64,
    num_epochs=8,
    early_stopping_patience=10,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.46k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/5.29M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]



Map:   0%|          | 0/2060 [00:00<?, ? examples/s]



Saving the dataset (0/1 shards):   0%|          | 0/125323 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2060 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4058 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/yake_top3_asks/tinier_bart)... Done. 1.0s


You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,9.8699,9.002449,0.760325,0.80111,0.78009,0.010461,0.0,0.010414,0.010475
2,8.2638,7.818701,0.760325,0.80111,0.78009,0.010461,0.0,0.010414,0.010475
2,7.3246,7.231643,0.760325,0.80111,0.78009,0.010461,0.0,0.010414,0.010475
4,6.8919,7.004455,0.760325,0.80111,0.78009,0.010461,0.0,0.010414,0.010475
4,6.7248,6.926886,0.760325,0.80111,0.78009,0.010461,0.0,0.010414,0.010475
6,6.6603,6.888157,0.760325,0.80111,0.78009,0.010461,0.0,0.010414,0.010475


## Bart-Small

In [None]:
# Baseline run without keyword prefixes.
# The stale `datadict`, `dataset` and `data_name` arguments were dropped: complete_train
# does not accept them, and the variables they referenced are not defined in this notebook.
complete_train(checkpoint='lucadiliello/bart-small',
               batch_size=32,
               num_epochs=8,
               early_stopping_patience=10)

In [None]:
# Same checkpoint, trained on answers prefixed with YAKE keywords.
complete_train(
    'lucadiliello/bart-small',
    keyword_model=yake_top3,
    batch_size=32,
    num_epochs=8,
    early_stopping_patience=10,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.71k [00:00<?, ?B/s]

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.


Downloading pytorch_model.bin:   0%|          | 0.00/282M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.35k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/999k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/957 [00:00<?, ?B/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,3.4185,3.588686,0.868306,0.851131,0.859318,0.126295,0.029457,0.111196,0.111144
2,3.1541,3.512452,0.869045,0.852333,0.860309,0.135885,0.032381,0.119183,0.119131
2,3.0248,3.463767,0.869233,0.853472,0.860973,0.138932,0.034255,0.122218,0.122211
4,2.9305,3.450398,0.869713,0.854169,0.861576,0.14317,0.036415,0.125137,0.125179
4,2.8565,3.439936,0.869005,0.854354,0.86131,0.142302,0.036331,0.124865,0.124931


Downloading (…)lve/main/config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.43G [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,3.4185,3.588686,0.868306,0.851131,0.859318,0.126295,0.029457,0.111196,0.111144
2,3.1541,3.512452,0.869045,0.852333,0.860309,0.135885,0.032381,0.119183,0.119131
2,3.0248,3.463767,0.869233,0.853472,0.860973,0.138932,0.034255,0.122218,0.122211
4,2.9305,3.450398,0.869713,0.854169,0.861576,0.14317,0.036415,0.125137,0.125179
4,2.8565,3.439936,0.869005,0.854354,0.86131,0.142302,0.036331,0.124865,0.124931
6,2.7989,3.437783,0.868972,0.854499,0.861382,0.14307,0.035432,0.125328,0.12546
6,2.753,3.431431,0.869023,0.85467,0.86148,0.144419,0.036106,0.126519,0.126641
7,2.7204,3.435036,0.868945,0.854524,0.861368,0.143095,0.035884,0.125416,0.125558


Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/282M [00:00<?, ?B/s]

VBox(children=(Label(value='272.060 MB of 272.060 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
eval/bertscore_f1,▁▄▆█▇▇█▇█
eval/bertscore_precision,▁▅▆█▄▄▅▄█
eval/bertscore_recall,▁▃▆▇▇███▇
eval/loss,█▅▂▂▁▁▁▁▂
eval/rouge_rouge1,▁▅▆█▇▇█▇█
eval/rouge_rouge2,▁▄▆██▇█▇█
eval/rouge_rougeL,▁▅▆▇▇▇█▇▇
eval/rouge_rougeLsum,▁▅▆▇▇▇██▇
eval/runtime,█▁▁▁▁▁▁▁▁
eval/samples_per_second,▁████████

0,1
eval/bertscore_f1,0.86158
eval/bertscore_precision,0.86971
eval/bertscore_recall,0.85417
eval/loss,3.4504
eval/rouge_rouge1,0.14317
eval/rouge_rouge2,0.03641
eval/rouge_rougeL,0.12514
eval/rouge_rougeLsum,0.12518
eval/runtime,58.761
eval/samples_per_second,35.057


## Bart-base

In [None]:
# Fine-tune facebook/bart-base on the 'askscience' data: 3 epochs, batch size 16.
complete_train(
    checkpoint='facebook/bart-base',
    data_name='askscience',
    datadict=train_asks,
    dataset=ds_reduced_emb,
    batch_size=16,
    num_epochs=3,
)

In [None]:
# Same base checkpoint, this time passing `yake_top3` as the keyword model.
complete_train(
    checkpoint='facebook/bart-base',
    keyword_model=yake_top3,
    num_epochs=3,
    batch_size=16,
    early_stopping_patience=10,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/558M [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/125323 [00:00<?, ? examples/s]

Map:   0%|          | 0/2060 [00:00<?, ? examples/s]

Map:   0%|          | 0/4058 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/125323 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2060 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4058 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Adding directory to artifact (./data/yake_top3_asks/bart-base)... Done. 0.5s


You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,2.8595,2.91852,0.871112,0.856896,0.863637,0.151118,0.042164,0.132975,0.132894
2,2.5382,2.85686,0.871586,0.857815,0.864331,0.153839,0.044256,0.135257,0.135275
2,2.3755,2.837399,0.87129,0.858589,0.864579,0.159937,0.046725,0.139604,0.13962


pytorch_model.bin:   0%|          | 0.00/558M [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

VBox(children=(Label(value='535.341 MB of 535.341 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
eval/bertscore_f1,▁▆██
eval/bertscore_precision,▁█▄▄
eval/bertscore_recall,▁▅██
eval/loss,█▃▁▁
eval/rouge_rouge1,▁▃██
eval/rouge_rouge2,▁▄██
eval/rouge_rougeL,▁▃██
eval/rouge_rougeLsum,▁▃██
eval/runtime,▇▁▅█
eval/samples_per_second,▂█▄▁

0,1
eval/bertscore_f1,0.86458
eval/bertscore_precision,0.87129
eval/bertscore_recall,0.85859
eval/loss,2.8374
eval/rouge_rouge1,0.15994
eval/rouge_rouge2,0.04673
eval/rouge_rougeL,0.1396
eval/rouge_rougeLsum,0.13962
eval/runtime,77.3535
eval/samples_per_second,26.631


## Bart-large

In [None]:
# Fine-tune facebook/bart-large on the 'askscience' data: 3 epochs, batch size 32.
complete_train(
    checkpoint='facebook/bart-large',
    data_name='askscience',
    datadict=train_asks,
    dataset=ds_reduced_emb,
    batch_size=32,
    num_epochs=3,
)

In [None]:
# bart-large again, this time passing `yake_top3` as the keyword model.
complete_train(
    checkpoint='facebook/bart-large',
    keyword_model=yake_top3,
    num_epochs=3,
    batch_size=32,
    early_stopping_patience=10,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/125323 [00:00<?, ? examples/s]

Map:   0%|          | 0/2060 [00:00<?, ? examples/s]

Map:   0%|          | 0/4058 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/125323 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2060 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4058 [00:00<?, ? examples/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Adding directory to artifact (./data/yake_top3_asks/bart-large)... Done. 0.5s


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m. Use [1m`wandb login --relogin`[0m to force relogin


You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,2.6036,2.714948,0.872522,0.860253,0.866114,0.162669,0.047312,0.14141,0.141508
2,2.2601,2.662288,0.873183,0.861605,0.867141,0.171188,0.051995,0.148899,0.148928


Downloading (…)lve/main/config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.43G [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,2.6036,2.714948,0.872522,0.860253,0.866114,0.162669,0.047312,0.14141,0.141508
2,2.2601,2.662288,0.873183,0.861605,0.867141,0.171188,0.051995,0.148899,0.148928
2,2.0619,2.644391,0.872302,0.861903,0.866821,0.173584,0.053659,0.151757,0.15174


Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

0,1
eval/bertscore_f1,▁█▆█
eval/bertscore_precision,▃█▁█
eval/bertscore_recall,▁▇█▇
eval/loss,█▃▁▃
eval/rouge_rouge1,▁▆█▆
eval/rouge_rouge2,▁▆█▆
eval/rouge_rougeL,▁▆█▆
eval/rouge_rougeLsum,▁▆█▆
eval/runtime,█▁▁▂
eval/samples_per_second,▁██▇

0,1
eval/bertscore_f1,0.86714
eval/bertscore_precision,0.87318
eval/bertscore_recall,0.8616
eval/loss,2.66229
eval/rouge_rouge1,0.17119
eval/rouge_rouge2,0.05199
eval/rouge_rougeL,0.1489
eval/rouge_rougeLsum,0.14893
eval/runtime,103.613
eval/samples_per_second,19.882


## pegasus-arxiv

In [None]:
# Pegasus (arXiv checkpoint) on the 'askscience' data with 512-token inputs.
checkpoint = 'google/pegasus-arxiv'

complete_train(
    checkpoint=checkpoint,
    data_name='askscience',
    dataset=ds_reduced_emb,
    max_length=512,
    batch_size=16,
    num_epochs=3,
    overwrite=True,
)

In [None]:
# Pegasus (arXiv checkpoint) with the `yake_top3` keyword model; existing
# artefacts are kept (overwrite=False).
checkpoint = 'google/pegasus-arxiv'

complete_train(
    checkpoint=checkpoint,
    keyword_model=yake_top3,
    max_length=512,
    batch_size=16,
    num_epochs=3,
    overwrite=False,
)

# Alternative configuration kept for reference: batch size 2 combined with
# 16 gradient-accumulation steps.
#complete_train(checkpoint=checkpoint,
#               num_epochs=3,
#               batch_size=2,
#               max_length=512,
#               keyword_model=yake_top3,
#               overwrite=False,
#               gradient_accumulation_steps=16)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/2.28G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/88.0 [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/1.91M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


You're using a PegasusTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,2.9426,3.07416,0.795657,0.832977,0.811984,0.029721,0.006427,0.026947,0.026925
2,2.7112,3.003621,0.80964,0.832948,0.818789,0.031059,0.006763,0.027257,0.027239
2,2.6348,2.987781,0.813815,0.832461,0.820632,0.030811,0.006816,0.026882,0.026888


Downloading (…)lve/main/config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.43G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/2.28G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

spiece.model:   0%|          | 0.00/1.91M [00:00<?, ?B/s]

VBox(children=(Label(value='2186.140 MB of 2186.140 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.…

0,1
eval/bertscore_f1,▁▇██
eval/bertscore_precision,▁▆██
eval/bertscore_recall,██▁▁
eval/loss,█▂▁▁
eval/rouge_rouge1,▁█▇▇
eval/rouge_rouge2,▁▇██
eval/rouge_rougeL,▂█▁▁
eval/rouge_rougeLsum,▂█▁▁
eval/runtime,█▃▁▁
eval/samples_per_second,▁▆██

0,1
eval/bertscore_f1,0.82063
eval/bertscore_precision,0.81382
eval/bertscore_recall,0.83246
eval/loss,2.98778
eval/rouge_rouge1,0.03081
eval/rouge_rouge2,0.00682
eval/rouge_rougeL,0.02688
eval/rouge_rougeLsum,0.02689
eval/runtime,1242.6788
eval/samples_per_second,1.658


In [None]:
# Release the Colab runtime once the preceding training cell has finished.
from google.colab import runtime
import time
# Pause for two minutes before disconnecting
# (presumably to let pending uploads/logging finish — confirm).
time.sleep(120)

runtime.unassign()

## flan-t5-small

In [None]:
# Task prefix handed to complete_train via `prompt`.
prompt = 'generate a question: '

# flan-t5-small on the 'askscience' data: 3 epochs, batch size 16, overwrite=True.
complete_train(
    checkpoint='google/flan-t5-small',
    data_name='askscience',
    dataset=ds_reduced_emb,
    prompt=prompt,
    batch_size=16,
    num_epochs=3,
    overwrite=True,
)

In [None]:
# Task prefix handed to complete_train via `prompt`.
prompt = 'generate a question: '

# flan-t5-small with the `yake_top3` keyword model, overwrite=True.
complete_train(
    checkpoint='google/flan-t5-small',
    keyword_model=yake_top3,
    prompt=prompt,
    num_epochs=3,
    batch_size=16,
    early_stopping_patience=10,
    overwrite=True,
)



Saving the dataset (0/1 shards):   0%|          | 0/125395 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2105 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4060 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/asks_processed_data)... Done. 0.7s


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…



Map:   0%|          | 0/2105 [00:00<?, ? examples/s]



Filter:   0%|          | 0/2105 [00:00<?, ? examples/s]



Filter:   0%|          | 0/2060 [00:00<?, ? examples/s]



Saving the dataset (0/1 shards):   0%|          | 0/125323 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2060 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4058 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/asks_cleaned_data)... Done. 41.6s


Map:   0%|          | 0/125323 [00:00<?, ? examples/s]

Map:   0%|          | 0/2060 [00:00<?, ? examples/s]

Map:   0%|          | 0/4058 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/125323 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2060 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4058 [00:00<?, ? examples/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/308M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

adding new prompt


Map:   0%|          | 0/125323 [00:00<?, ? examples/s]

Map:   0%|          | 0/2060 [00:00<?, ? examples/s]

Map:   0%|          | 0/4058 [00:00<?, ? examples/s]

Map:   0%|          | 0/125323 [00:00<?, ? examples/s]

Map:   0%|          | 0/2060 [00:00<?, ? examples/s]

Map:   0%|          | 0/4058 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/125323 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2060 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4058 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/yake_top3_asks/flan-t5-small)... Done. 0.6s


You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,3.0245,3.081095,0.867996,0.85197,0.859516,0.126682,0.025157,0.112192,0.112023
2,2.9028,3.040322,0.869057,0.853485,0.860824,0.133014,0.027752,0.116161,0.116115


Downloading (…)lve/main/config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.43G [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,3.0245,3.081095,0.867996,0.85197,0.859516,0.126682,0.025157,0.112192,0.112023
2,2.9028,3.040322,0.869057,0.853485,0.860824,0.133014,0.027752,0.116161,0.116115


In [None]:
# Release the Colab runtime immediately after the preceding cell completes.
from google.colab import runtime
runtime.unassign()

## flan-t5-base

In [None]:
# Task prefix handed to complete_train via `prompt`.
prompt = 'generate a question: '

# flan-t5-base on the 'askscience' data: 3 epochs, batch size 16, overwrite=False.
complete_train(
    checkpoint='google/flan-t5-base',
    data_name='askscience',
    dataset=ds_reduced_emb,
    prompt=prompt,
    batch_size=16,
    num_epochs=3,
    overwrite=False,
)

In [None]:
# Task prefix handed to complete_train via `prompt`.
prompt = 'generate a question: '

# flan-t5-base with the `yake_top3` keyword model.
complete_train(
    checkpoint='google/flan-t5-base',
    keyword_model=yake_top3,
    prompt=prompt,
    num_epochs=3,
    batch_size=16,
    early_stopping_patience=10,
)

Downloading pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,2.6421,2.70389,0.871438,0.856479,0.863491,0.143566,0.035767,0.125887,0.126157
2,2.4992,2.663896,0.871501,0.858005,0.864309,0.152529,0.038694,0.133777,0.134023


Downloading (…)lve/main/config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.43G [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Bertscore Precision,Bertscore Recall,Bertscore F1,Rouge Rouge1,Rouge Rouge2,Rouge Rougel,Rouge Rougelsum
0,2.6421,2.70389,0.871438,0.856479,0.863491,0.143566,0.035767,0.125887,0.126157
2,2.4992,2.663896,0.871501,0.858005,0.864309,0.152529,0.038694,0.133777,0.134023
2,2.4453,2.651978,0.871331,0.85822,0.864346,0.154404,0.039334,0.135742,0.135887


pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

VBox(children=(Label(value='947.607 MB of 947.607 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
eval/bertscore_f1,▁███
eval/bertscore_precision,▅█▁▁
eval/bertscore_recall,▁▇██
eval/loss,█▃▁▁
eval/rouge_rouge1,▁▇██
eval/rouge_rouge2,▁▇██
eval/rouge_rougeL,▁▇██
eval/rouge_rougeLsum,▁▇██
eval/runtime,█▁▁▂
eval/samples_per_second,▁██▇

0,1
eval/bertscore_f1,0.86435
eval/bertscore_precision,0.87133
eval/bertscore_recall,0.85822
eval/loss,2.65198
eval/rouge_rouge1,0.1544
eval/rouge_rouge2,0.03933
eval/rouge_rougeL,0.13574
eval/rouge_rougeLsum,0.13589
eval/runtime,91.7501
eval/samples_per_second,22.452


# Zero-shot Generation

## FLAN-T5-XXL

In [None]:
import requests
from getpass import getpass

# Prompt for the Hugging Face token at run time instead of hard-coding it.
secret_hf = getpass('Enter your Huggingface key: ')


# Hosted inference endpoint for google/flan-t5-xxl; requests carry the token
# as a bearer Authorization header.
API_URL = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
headers = {"Authorization": f"Bearer {secret_hf}"}

def query(payload, timeout=120):
    """POST ``payload`` to the hosted FLAN-T5-XXL endpoint and return the decoded JSON.

    Args:
        payload: JSON-serialisable request body, e.g. ``{'inputs': '<prompt>'}``.
        timeout: Seconds to wait for the server. Without a timeout ``requests``
            waits forever, so one stalled connection would hang the notebook;
            on expiry ``requests.exceptions.Timeout`` is raised.

    Returns:
        The parsed JSON body, returned as-is. No status check is performed, so
        callers should be prepared for a non-list payload when the service
        reports a problem (NOTE(review): confirm the exact error shape against
        the Inference API documentation).
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()

In [None]:
# Load the dataset previously saved under ./data/asks_cleaned_data.
asks_ds=load_from_disk('./data/asks_cleaned_data')

In [None]:
# Load the dataset previously saved under ./data/yake_top3_asks_cleaned_data.
asks_ds_yake_top3=load_from_disk('./data/yake_top3_asks_cleaned_data')

In [None]:
# Build a filtered copy of every split, dropping rows whose title contains any
# of the listed phrases.
# NOTE(review): the match is case-sensitive; presumably titles were lower-cased
# during cleaning — confirm.
asks_ds_filtered = DatasetDict()
for key in asks_ds:
    asks_ds_filtered[key] = asks_ds[key].filter(
        lambda example: not any(
            phrase in example['title']
            for phrase in ('ask me anything', 'megathread', 'ama series')
        )
    )

Filter:   0%|          | 0/125323 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2060 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4058 [00:00<?, ? examples/s]

In [None]:
# Build a filtered copy of every split, dropping rows whose title contains any
# of the listed phrases.
# NOTE(review): the match is case-sensitive; presumably titles were lower-cased
# during cleaning — confirm.
asks_ds_yake_top3_filtered = DatasetDict()
for key in asks_ds_yake_top3:
    asks_ds_yake_top3_filtered[key] = asks_ds_yake_top3[key].filter(
        lambda example: not any(
            phrase in example['title']
            for phrase in ('ask me anything', 'megathread', 'ama series')
        )
    )

Filter:   0%|          | 0/125323 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2060 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4058 [00:00<?, ? examples/s]

In [None]:
# Disk cache for the filtered splits: reuse the saved copy when it exists,
# otherwise persist the in-memory result of the filtering cell. This replaces
# the manual comment toggle, which failed on a fresh environment where nothing
# had been saved yet.
# NOTE: an existing saved copy takes precedence; delete the directory to force
# a refresh after changing the filter.
if os.path.exists('./data/asks_filtered_data'):
    asks_ds_filtered = load_from_disk('./data/asks_filtered_data')
else:
    asks_ds_filtered.save_to_disk('./data/asks_filtered_data')

In [None]:
# Disk cache for the filtered keyword splits: reuse the saved copy when it
# exists, otherwise persist the in-memory result of the filtering cell. This
# replaces the manual comment toggle, which failed on a fresh environment
# where nothing had been saved yet.
# NOTE: an existing saved copy takes precedence; delete the directory to force
# a refresh after changing the filter.
if os.path.exists('./data/asks_yaketop3_filtered_data'):
    asks_ds_yake_top3_filtered = load_from_disk('./data/asks_yaketop3_filtered_data')
else:
    asks_ds_yake_top3_filtered.save_to_disk('./data/asks_yaketop3_filtered_data')

In [None]:
import time

# Generate one question per validation answer with the hosted FLAN-T5-XXL model.
flan_t5_xxl_preds = []

# The split size is fixed, so read it once instead of materialising the whole
# 'answers.text' column on every pass through the loop.
n_validation = len(asks_ds_filtered['validation'])

# Upper bound on consecutive failed calls for a single example.
max_retries = 5
failures = 0

i = 0
while len(flan_t5_xxl_preds) < n_validation:
    # The next index to process is simply the number of predictions so far,
    # which makes retrying the same example a plain `continue`.
    i = len(flan_t5_xxl_preds)
    if i % 20 == 0 and failures == 0:
        print(i)

    answer = asks_ds_filtered['validation'][i]['answers.text']
    # The full prompt (prefix included) is cut to 1000 characters before sending.
    question = query({'inputs': f"generate a question: {answer}"[:1000]})

    # A successful call is indexed as question[0]['generated_text']; any other
    # payload (e.g. an error object from the service) is retried with a growing
    # pause instead of crashing on the indexing below.
    if not (isinstance(question, list) and question
            and isinstance(question[0], dict)
            and 'generated_text' in question[0]):
        failures += 1
        if failures > max_retries:
            raise RuntimeError(
                f"query failed {failures} times in a row at index {i}: {question}"
            )
        time.sleep(10 * failures)
        continue
    failures = 0

    question = question[0]['generated_text']

    flan_t5_xxl_preds.append(question)

In [None]:
# Display the generated questions.
flan_t5_xxl_preds

In [None]:
# Collect answer, reference title and model prediction side by side, then
# write the table to ./results.
flanT5_preds_V1 = pd.DataFrame(
    {
        'answer': asks_ds_filtered['validation']['answers.text'],
        'title': asks_ds_filtered['validation']['title'],
        'prediction': flan_t5_xxl_preds,
    },
    columns=['answer', 'title', 'prediction'],
)
flanT5_preds_V1.to_csv('./results/flanT5_xxl_V1')

In [None]:
# BERTScore of the FLAN-T5-XXL predictions against the original titles.
bertscore_flant5xxl = bertscore.compute(
    references=flanT5_preds_V1['title'],
    predictions=flanT5_preds_V1['prediction'],
    lang='en',
)

In [None]:
import pickle

In [None]:
# Print the mean of every BERTScore entry, skipping any key containing 'hash'.
for metric_name, scores in bertscore_flant5xxl.items():
    if 'hash' in metric_name:
        continue
    print((metric_name, np.mean(scores)))

('precision', 0.8774760780230194)
('recall', 0.8578001879614817)
('f1', 0.8673236250699633)


In [None]:
# Display the raw BERTScore result object.
bertscore_flant5xxl

In [None]:
# ROUGE of the FLAN-T5-XXL predictions against the original titles.
rouge_flant5xxl = rouge.compute(
    references=flanT5_preds_V1['title'],
    predictions=flanT5_preds_V1['prediction'],
)

In [None]:
# Bundle both metric results and pickle them under ./results.
flan_t5_results={}
flan_t5_results['bertscore']=bertscore_flant5xxl
flan_t5_results['rouge']=rouge_flant5xxl

# The context manager closes the file even if pickling raises.
with open("./results/flan_t5_xxl_results.pkl",'wb') as f:
    pickle.dump(flan_t5_results,f)

In [None]:
# Print every ROUGE entry for the FLAN-T5-XXL predictions.
for metric_name in rouge_flant5xxl:
    print((metric_name, np.mean(rouge_flant5xxl[metric_name])))

('rouge1', 0.18364659593165306)
('rouge2', 0.04928913518083794)
('rougeL', 0.16304952530114208)
('rougeLsum', 0.1630151221519669)


## GPT-3

In [None]:
# Load the filtered dataset saved under ./data/asks_filtered_data.
asks_ds_filtered=load_from_disk('./data/asks_filtered_data')

In [None]:
# Rough cost estimate over the validation answers:
# whitespace-split word count * 4/3 * 0.0004 / 1000.
# NOTE(review): 4/3 looks like a tokens-per-word heuristic and .0004 a price
# per 1K tokens — confirm against the pricing table in use.
sum(len(text.split())
    for text in asks_ds_filtered['validation']['answers.text'])*4/3*.0004*1/1000

0.1613850666666667

In [None]:
# Rough cost estimate over the validation titles:
# whitespace-split word count * 4/3 * 0.0004 / 1000.
# NOTE(review): 4/3 looks like a tokens-per-word heuristic and .0004 a price
# per 1K tokens — confirm against the pricing table in use.
sum(len(text.split())
    for text in asks_ds_filtered['validation']['title'])*4/3*.0004*1/1000

0.015553066666666665

In [None]:
# Rough cost estimate over the validation answers:
# whitespace-split word count * 4/3 * 0.02 / 1000.
# NOTE(review): 4/3 looks like a tokens-per-word heuristic and .02 a price
# per 1K tokens — confirm against the pricing table in use.
sum(len(text.split())
    for text in asks_ds_filtered['validation']['answers.text'])*4/3*.02*1/1000

8.069253333333334

In [None]:
# Rough cost estimate over the validation titles:
# whitespace-split word count * 4/3 * 0.02 / 1000.
# NOTE(review): 4/3 looks like a tokens-per-word heuristic and .02 a price
# per 1K tokens — confirm against the pricing table in use.
sum(len(text.split())
    for text in asks_ds_filtered['validation']['title'])*4/3*.02*1/1000

0.7776533333333333

In [None]:
# Rough cost estimate over the validation answers:
# whitespace-split word count * 4/3 * 0.03 / 1000.
# NOTE(review): 4/3 looks like a tokens-per-word heuristic and .03 a price
# per 1K tokens — confirm against the pricing table in use.
sum(len(text.split())
    for text in asks_ds_filtered['validation']['answers.text'])*4/3*.03*1/1000

12.10388

In [None]:
# Rough cost estimate over the validation titles:
# whitespace-split word count * 4/3 * 0.06 / 1000.
# NOTE(review): 4/3 looks like a tokens-per-word heuristic and .06 a price
# per 1K tokens — confirm against the pricing table in use.
sum(len(text.split())
    for text in asks_ds_filtered['validation']['title'])*4/3*.06*1/1000

2.3329599999999995

In [None]:
# Install and import the OpenAI client.
!pip install openai

import openai
import os

# Read the API key interactively rather than hard-coding it. `getpass` is the
# name imported in the FLAN-T5-XXL setup cell, which must have been run first.
openai.api_key = getpass('Enter your OpenAI key: ')

In [None]:
# Maps a model name to its list of generated questions; re-running this cell
# discards anything collected so far.
GPT_results={}

In [None]:
from tqdm import tqdm

In [None]:
# Generate one question per validation answer with the completion endpoint.
model_engine = "text-ada-001"
max_tokens = 50

GPT_results[model_engine] = []

for answer in tqdm(asks_ds_filtered['validation']['answers.text']):
    prompt = f"generate a question: {answer}"

    # Only the first 2000 characters of the prompt are sent.
    response = openai.Completion.create(
        engine=model_engine,
        prompt=prompt[:2000],
        max_tokens=max_tokens,
    )

    # Keep the first choice with surrounding whitespace stripped.
    question = response.choices[0].text.strip()
    GPT_results[model_engine].append(question)

100%|██████████| 2012/2012 [12:07<00:00,  2.77it/s]


In [None]:
# Generate one question per validation answer with the chat endpoint.
model_engine = "gpt-3.5-turbo"
max_tokens = 50

GPT_results[model_engine]=[]

for answer in tqdm(asks_ds_filtered['validation']['answers.text']):

    prompt = f"generate a question: {answer}"

    response=openai.ChatCompletion.create(
        model=model_engine,
        messages=[
            {"role": "system", "content": "You are a helpful assistant that generates questions from text."},
            {"role": "user", "content": prompt},
        ],
        # Cap the reply length. `max_tokens` was defined above but never sent,
        # so this run was unbounded while the text-ada-001 run was limited to
        # 50 tokens.
        max_tokens=max_tokens,
    )

    question=response['choices'][0]['message']['content']

    GPT_results[model_engine].append(question)

100%|██████████| 2012/2012 [55:47<00:00,  1.66s/it]


In [None]:
# Display the questions generated by gpt-3.5-turbo.
GPT_results["gpt-3.5-turbo"]

In [None]:
# List the models for which generations have been collected.
GPT_results.keys()

dict_keys(['text-ada-001', 'gpt-3.5-turbo'])

In [None]:
# One column of generations per model, plus the source answer and reference
# title; the table is written to ./results.
df_GPT_results = pd.DataFrame(GPT_results).assign(
    answer=asks_ds_filtered['validation']['answers.text'],
    title=asks_ds_filtered['validation']['title'],
)
df_GPT_results.to_csv('./results/GPT_results')

In [None]:
# Score every model's generations against the reference titles.
bertscore_GPT = {}
rouge_GPT = {}

for key in GPT_results:
    bertscore_GPT[key] = bertscore.compute(
        references=df_GPT_results['title'],
        predictions=df_GPT_results[key],
        lang='en',
    )
    rouge_GPT[key] = rouge.compute(
        references=df_GPT_results['title'],
        predictions=df_GPT_results[key],
    )

In [None]:
# Mean BERTScore precision / recall / F1 for text-ada-001.
# Reads `bertscore_GPT` directly: `GPT_metrics` is only assembled in a later
# cell, so referencing it here failed on a top-to-bottom run.
print(np.mean(bertscore_GPT['text-ada-001']['precision']))
print(np.mean(bertscore_GPT['text-ada-001']['recall']))
np.mean(bertscore_GPT['text-ada-001']['f1'])

0.6210231705463193
0.625972832292022


0.6232945738623203

In [None]:
# Mean BERTScore precision / recall / F1 for gpt-3.5-turbo.
# Reads `bertscore_GPT` directly: `GPT_metrics` is only assembled in a later
# cell, so referencing it here failed on a top-to-bottom run.
print(np.mean(bertscore_GPT['gpt-3.5-turbo']['precision']))
print(np.mean(bertscore_GPT['gpt-3.5-turbo']['recall']))
np.mean(bertscore_GPT['gpt-3.5-turbo']['f1'])

0.8615672170345636
0.8672474204131196


0.8642142173013915

In [None]:
# ROUGE scores for text-ada-001.
# Reads `rouge_GPT` directly: `GPT_metrics` is only assembled in a later cell,
# so referencing it here failed on a top-to-bottom run.
print(np.mean(rouge_GPT['text-ada-001']['rouge1']))
print(np.mean(rouge_GPT['text-ada-001']['rouge2']))
print(np.mean(rouge_GPT['text-ada-001']['rougeL']))
print(np.mean(rouge_GPT['text-ada-001']['rougeLsum']))

0.09752219365997786
0.017561066699303712
0.08274926574950985
0.08376101127857803


In [None]:
# ROUGE scores for gpt-3.5-turbo.
# Reads `rouge_GPT` directly: `GPT_metrics` is only assembled in a later cell,
# so referencing it here failed on a top-to-bottom run.
print(np.mean(rouge_GPT['gpt-3.5-turbo']['rouge1']))
print(np.mean(rouge_GPT['gpt-3.5-turbo']['rouge2']))
print(np.mean(rouge_GPT['gpt-3.5-turbo']['rougeL']))
print(np.mean(rouge_GPT['gpt-3.5-turbo']['rougeLsum']))

0.18025551303753928
0.03973726618066363
0.1476850246547054
0.14762872575740563


In [None]:
# Print every entry of the gpt-3.5-turbo ROUGE result.
# The previous version indexed 'precision' / 'recall' / 'f1', which are
# BERTScore keys; the ROUGE results printed elsewhere in this notebook expose
# only rouge1 / rouge2 / rougeL / rougeLsum, so the lookup raised KeyError.
# It also went through `GPT_metrics`, which is not defined until a later cell.
for metric_name, score in rouge_GPT['gpt-3.5-turbo'].items():
    print((metric_name, np.mean(score)))

In [None]:
# Bundle both metric dictionaries and pickle them under ./results.
GPT_metrics={}
GPT_metrics['bertscore']=bertscore_GPT
GPT_metrics['rouge']=rouge_GPT

# The context manager closes the file even if pickling raises.
with open("./results/GPT_metrics.pkl",'wb') as f:
    pickle.dump(GPT_metrics,f)

# LoRA

In [None]:
!pip install peft

In [None]:
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType
from torch.utils.data import DataLoader
from transformers import default_data_collator, get_linear_schedule_with_warmup
from tqdm import tqdm

checkpoint = 'lucadiliello/bart-small'

# LoRA adapter settings for a seq2seq LM, configured for training
# (inference_mode=False).
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

# Load the base model, then wrap it with the adapter configuration.
base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = get_peft_model(base_model, peft_config)
del base_model  # only the wrapped `model` name is kept in the namespace

dataset = load_from_disk('./data/askscience/bart-small')

# Every column of the train split that is not in `keep_cols` is listed for removal.
keep_cols = ['input_ids', 'attention_mask', 'labels']
drop_cols = list(
    filter(lambda col: col not in keep_cols, dataset['train'].features)
)

In [None]:
# Drop the non-model columns. Only columns still present are passed on, so
# re-running this cell is a no-op rather than an error about columns that were
# already removed on the first run.
dataset = dataset.remove_columns(
    [col for col in drop_cols if col in dataset['train'].column_names]
)

In [None]:
# Have the dataset return PyTorch tensors.
dataset.set_format('pt')


# Training hyper-parameters. `max_length` is not referenced in this cell.
max_length = 128
lr = 1e-3
num_epochs = 3
batch_size = 32

# Seq2seq collator built from the tokenizer and model; shared by both loaders.
collate_fn=DataCollatorForSeq2Seq(tokenizer, model=model)

train_dataloader = DataLoader(
    dataset['train'],
    shuffle=True, 
    collate_fn=collate_fn,
    batch_size=batch_size,
    pin_memory=True
)
# Use the same collator for evaluation. `default_data_collator` does no padding
# and returns a plain dict, which has no `.to(device)` method, so it does not
# match how the training cell consumes batches.
eval_dataloader = DataLoader(dataset['validation'],
                             collate_fn=collate_fn,
                             batch_size=batch_size,
                             pin_memory=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Linear schedule with no warm-up, spanning every optimisation step of the run.
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [None]:
# training and evaluation
model = model.to(device)

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        # Move each tensor explicitly. This works for a plain dict as well as a
        # BatchEncoding, whereas `batch.to(device)` exists only on the latter.
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    # ---- evaluation pass ----
    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        # Greedy argmax over the forward-pass logits (not `generate`), decoded
        # to text. The list is collected but not used further in this cell.
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )

    # Per-epoch mean loss over batches and the corresponding perplexity.
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")