In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, random_split
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM, IntervalStrategy


import wandb

In [2]:

# Seed PyTorch's RNG so runs are repeatable.
torch.manual_seed(42)
# Checkpoint to fine-tune: exactly one `modelName` assignment is active; the
# commented-out lines are the other candidates (a local checkpoint and other
# EleutherAI model sizes).
#modelName = "./results/checkpoint-22428"
# modelName = "EleutherAI/gpt-neo-125M"
#modelName = "EleutherAI/gpt-neo-1.3B"
modelName = "EleutherAI/gpt-neo-2.7B"
#modelName = "EleutherAI/gpt-j-6B"# , revision="sharded")
#modelName = "EleutherAI/gpt-neox-20b"


# Register explicit BOS / EOS / PAD special tokens on the tokenizer; the same
# token strings are used when building the training strings in DialogDataset.
tokenizer = AutoTokenizer.from_pretrained(modelName, bos_token='<|startoftext|>',
                                          eos_token='<|endoftext|>', pad_token='<|pad|>')
# Load the causal LM weights and move the model to the GPU.
model = AutoModelForCausalLM.from_pretrained(modelName).cuda()
# The tokenizer now has extra special tokens, so the embedding matrix is
# resized to match the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Embedding(50259, 768)

In [None]:
# Start a W&B run whose name is the checkpoint id followed by 'dialogSum'.
runName = f"{modelName}dialogSum"
wandb.init(project='gptNeo_dialogSum_ScaleStudy', name=runName)

In [4]:
# Load the DialogSum training split; the file is JSON Lines (one record per line).
trainFiles = pd.read_json('../11_dialogsum/dialogsum/DialogSum_Data/dialogsum.train.jsonl', lines=True)

In [5]:


# One training string per example: the dialogue, a "SUMMARY:" separator line,
# then the reference summary. The same separator is appended to the test
# dialogues to form generation prompts, and compute_metrics searches for
# 'SUMMARY:' to isolate the summary part of decoded text.
texts = trainFiles['dialogue'] + "\nSUMMARY: \n" + trainFiles['summary']

# The ROUGE metric is imported and loaded in the cell that defines
# compute_metrics, right before its first use; loading it a second time here
# was redundant, so the duplicate import/load has been dropped.

In [6]:

# Every example is truncated / padded to this many tokens.
max_length = 1024


class DialogDataset(Dataset):
    """Pre-tokenized text dataset yielding ``(input_ids, attention_mask)`` pairs.

    Each input string is wrapped as ``<|startoftext|>`` + text +
    ``<|endoftext|>`` and handed to ``tokenizer`` with truncation and
    ``padding="max_length"``. The resulting ``input_ids`` and
    ``attention_mask`` are stored as tensors at construction time.

    Args:
        txt_list: iterable of strings to encode.
        tokenizer: callable returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` entries.
        max_length: length passed to the tokenizer for truncation/padding.
    """

    def __init__(self, txt_list, tokenizer, max_length):
        self.input_ids, self.attn_masks = [], []
        # Kept for compatibility; this list is never filled by the class.
        self.labels = []
        for raw_text in txt_list:
            wrapped = '<|startoftext|>' + raw_text + '<|endoftext|>'
            encoded = tokenizer(wrapped, truncation=True,
                                max_length=max_length, padding="max_length")
            ids_tensor = torch.tensor(encoded['input_ids'])
            self.input_ids.append(ids_tensor)
            mask_tensor = torch.tensor(encoded['attention_mask'])
            self.attn_masks.append(mask_tensor)

    def __len__(self):
        """Number of encoded examples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` tensors for example ``idx``."""
        return (self.input_ids[idx], self.attn_masks[idx])



In [7]:
# ROUGE scorer used by compute_metrics below.
# NOTE(review): `load_metric` may be deprecated or removed in newer `datasets`
# releases (the `evaluate` package is the suggested replacement) — confirm
# against the installed version before upgrading.
from datasets import load_metric
metric = load_metric("rouge")

def compute_metrics(pred):
    """Compute ROUGE-1 / ROUGE-2 / ROUGE-L F-measures for an evaluation pass.

    Predictions and labels are decoded to text, the part after the
    ``'SUMMARY:'`` marker is extracted from each string, and the module-level
    ``metric`` (ROUGE) is run on the resulting pairs.

    Args:
        pred: object exposing ``predictions`` (token ids, here the argmax ids
            produced by ``preprocess_logits_for_metrics``) and ``label_ids``.

    Returns:
        dict with ``rouge1_fmeasure``, ``rouge2_fmeasure`` and
        ``rougeL_fmeasure``, each rounded to 4 decimals.
    """
    import numpy as np  # local import so this function does not depend on cell order

    marker = 'SUMMARY:'
    pad_id = tokenizer.pad_token_id

    def _summary_part(text):
        # str.find returns -1 when the marker is absent; the previous code then
        # sliced from index 8 and silently dropped the first characters.
        # Fall back to the full string instead.
        start = text.find(marker)
        if start == -1:
            return text
        # Skip the marker plus the single character that follows it.
        return text[start + len(marker) + 1:]

    # Work on copies so the arrays handed over by the Trainer are not modified.
    labels_ids = np.array(pred.label_ids)
    pred_ids = np.array(pred.predictions)

    # -100 marks ignored positions and is not a valid token id for decoding.
    labels_ids[labels_ids == -100] = pad_id
    pred_ids[pred_ids == -100] = pad_id
    # Argmax output at padded label positions carries no information; blank it
    # so it is stripped together with the other special tokens when decoding.
    if pred_ids.shape == labels_ids.shape:
        pred_ids[labels_ids == pad_id] = pad_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    pred_str = [_summary_part(one_str) for one_str in pred_str]
    label_str = [_summary_part(one_str) for one_str in label_str]

    rouge_output = metric.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge1", "rouge2", "rougeL"]
    )

    return {
        "rouge1_fmeasure": round(rouge_output['rouge1'].mid.fmeasure, 4),
        "rouge2_fmeasure": round(rouge_output['rouge2'].mid.fmeasure, 4),
        "rougeL_fmeasure": round(rouge_output['rougeL'].mid.fmeasure, 4),
    }

In [8]:
 def preprocess_logits_for_metrics(logits, labels):
     """Reduce raw model outputs to predicted token ids before metric computation.

     When ``logits`` is a tuple, only its first element is used; the argmax
     over the last dimension of that tensor is returned. ``labels`` is accepted
     for the Trainer callback signature and is not used.
     """
     # Tuple outputs may carry extra tensors after the logits; keep the first entry.
     primary = logits[0] if isinstance(logits, tuple) else logits
     return primary.argmax(dim=-1)

In [9]:

# Encode every training string (dialogue + separator + summary).
dataset = DialogDataset(texts, tokenizer, max_length=max_length)

# Take the first `evalLength` test examples: the prompt is the dialogue
# followed by the summary separator, and `summary1` is kept as the reference.
evalLength = 10
testFiles = pd.read_json(
    '../11_dialogsum/dialogsum/DialogSum_Data/dialogsum.test.jsonl',
    lines=True,
)
evalRows = testFiles[:evalLength]
dialogueOnly = evalRows['dialogue'] + "\nSUMMARY: \n"
realSummaries = evalRows['summary1']

# Encoded version of the prompts.
testDataset = DialogDataset(dialogueOnly, tokenizer, max_length=max_length)

In [None]:
batch_size = 1

# Hold out 1% of the encoded training examples for per-epoch evaluation.
train_size = int(0.99 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

training_args = TrainingArguments(output_dir='./results',
                                  num_train_epochs=5,
                                  logging_steps=5000,
                                  save_strategy="epoch",
                                  save_total_limit=5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  warmup_steps=100,
                                  weight_decay=0.01,
                                  evaluation_strategy="epoch",
                                  fp16=True,
                                  #eval_steps=1000,
                                  deepspeed='./dsconfig.json', # needed for 2.7B and other deepspeed runs
                                  logging_dir='./logs')


def collate_batch(data):
    """Stack ``(input_ids, attention_mask)`` pairs into a model-ready batch.

    Args:
        data: list of ``(input_ids, attention_mask)`` tensor pairs as returned
            by ``DialogDataset.__getitem__``.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``. Labels are
        a copy of ``input_ids`` in which padded positions (attention mask 0)
        are set to -100 so they are excluded from the language-modeling loss.
        Previously the labels were the raw padded ids, so the loss was also
        computed on the ``<|pad|>`` positions that fill each padded sequence.
    """
    input_ids = torch.stack([f[0] for f in data])
    attention_mask = torch.stack([f[1] for f in data])
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    return {'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels}


trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=train_dataset,
                  eval_dataset=val_dataset,
                  preprocess_logits_for_metrics=preprocess_logits_for_metrics,
                  compute_metrics=compute_metrics,
                  data_collator=collate_batch)




In [None]:
# Register the model with W&B; presumably log='all' tracks both gradients and
# parameters — confirm against the wandb.watch documentation.
wandb.watch(model, log='all')

# NOTE(review): mid-notebook import; consider moving it to the import cell at the top.
import numpy as np


In [None]:
# Run fine-tuning with the Trainer configured in the previous cells.
trainer.train()

In [None]:
# Ask PyTorch to release its cached, currently unused GPU memory before the
# generation cells run.
torch.cuda.empty_cache()

In [None]:

# Generate one sampled summary per held-out dialogue and collect
# (prompt, generated summary, reference summary) rows for a W&B table.
table_rows = []
logs = {}
tablename = 'sampleSummaries'

for file in range(len(dialogueOnly)):
    text = dialogueOnly.iloc[file]

    # Training strings start with the BOS token (see DialogDataset), so the
    # prompt is given the same prefix here.
    encoded = tokenizer('<|startoftext|>' + text, return_tensors="pt")
    prompt_ids = encoded.input_ids.cuda()
    prompt_mask = encoded.attention_mask.cuda()

    # generate() takes the special tokens as integer ids (`*_token_id`); the
    # string-valued `bos_token` / `eos_token` / `pad_token` keywords used
    # previously are not generation parameters.
    summary_candidate = model.generate(prompt_ids,
                                       attention_mask=prompt_mask,
                                       do_sample=True, top_k=50,
                                       bos_token_id=tokenizer.bos_token_id,
                                       eos_token_id=tokenizer.eos_token_id,
                                       pad_token_id=tokenizer.pad_token_id,
                                       max_length=1024, top_p=0.95, temperature=1.9)

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count rather than by character length of the decoded text, which is
    # only correct if decoding reproduces the prompt exactly.
    continuation_ids = summary_candidate[0][prompt_ids.shape[1]:]
    summary_text = tokenizer.decode(continuation_ids, skip_special_tokens=True)

    sampleSumm = {
        'prompt': text,
        'summary candidate': summary_text,
        'real summary': realSummaries.iloc[file],
    }
    print(sampleSumm)

    table_rows.append([sampleSumm['prompt'],
                       sampleSumm['summary candidate'],
                       sampleSumm['real summary']])

# Build the table once, after all rows have been collected.
logs.update({tablename: wandb.Table(
    columns=['query', 'response', 'actual summary'],
    rows=table_rows)})


In [None]:
# Unconditional sampling: start from the BOS token alone and draw 20 samples.
generated = tokenizer("<|startoftext|>", return_tensors="pt").input_ids.cuda()
# generate() takes the special tokens as integer ids (`*_token_id`); the
# string-valued `bos_token` / `eos_token` / `pad_token` keywords used
# previously are not generation parameters.
sample_outputs = model.generate(generated, do_sample=True, top_k=50,
                                bos_token_id=tokenizer.bos_token_id,
                                eos_token_id=tokenizer.eos_token_id,
                                pad_token_id=tokenizer.pad_token_id,
                                max_length=300, top_p=0.95, temperature=1.9, num_return_sequences=20)
for i, sample_output in enumerate(sample_outputs):
    print("\n\n")
    print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

In [None]:

# Display the W&B table of sample summaries stored in `logs` by the generation loop.
logs['sampleSummaries']
