In [2]:
import os
import re
from pathlib import Path

import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer

from spoiler_generation.utils.dataset_class import Dataset

# Load the LLaMA checkpoint (8-bit) and its tokenizer.
# The checkpoint directory can be supplied via the LLAMA_PATH environment
# variable; the placeholder default must be replaced with a real path.
llama_path = os.environ.get("LLAMA_PATH", "path/to/your/llama/model")
model = LlamaForCausalLM.from_pretrained(llama_path, load_in_8bit=True)
tokenizer = LlamaTokenizer.from_pretrained(llama_path)
# Reuse EOS as the pad token so prompts of different lengths can be batched.
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
# Decoder-only models continue from the last position of each row, so batched
# prompts must be padded on the LEFT; right padding would place pad tokens
# between a short prompt and the text generated after it.
tokenizer.padding_side = "left"

In [None]:
# Fetch the training split, wrap it in the project's Dataset helper, and
# extract the clickbait post texts that will be rewritten as questions.
raw_splits = load_dataset(
    "MateuszW/spoiler_generation",
    data_files={"train": "clickbait_spoiling_data/train.jsonl"},
)
train_frame = pd.DataFrame(raw_splits["train"])
dataset = Dataset(train_frame)
clickbaits = dataset.df["postText"].tolist()

In [5]:
# Preview a handful of clickbaits instead of dumping the whole list into the notebook output.
clickbaits[:15]

['Wes Welker Wanted Dinner With Tom Brady, But Patriots QB Had Better Idea',
 'NASA sets date for full recovery of ozone hole',
 "This is what makes employees happy -- and it's not their paycheck",
 'Passion is overrated —\xa07 work habits you need instead',
 "The perfect way to cook rice so that it's perfectly fluffy and NEVER sticks to the pan",
 'What happens if your new AirPods get lost or stolen, will Apple do anything?',
 'The Reason Why Gabor Kiraly Wears THOSE Trackie Bottoms',
 'You’ll Never Believe What This Family Saw in the Sky Outside Their House in Finland.',
 'Should you drink Red Wine?',
 'Hot Sauce Taste Test: Find out which we named number 1',
 'Analysis: This may be the most brutal number in the CBO report',
 '#TeenMom2 star @PBandJenelley_1 reveals the sex of her second child through social media',
 "You're probably missing out on this major way to save money",
 "Target's $20 million answer to transgender bathroom boycott",
 "China invited a reporter to hit their ne

In [None]:
# Number of clickbait posts that will be sent to the model.
n_clickbaits = len(clickbaits)
n_clickbaits

In [None]:
# Generate one question per clickbait, BATCH_SIZE prompts at a time.
# NOTE(review): `PROMPT` is not defined anywhere in this notebook; it must be a
# format string with a `{question}` placeholder created in an earlier cell. TODO
# define it above so Restart & Run All works.
# The decoded text still contains the prompt, so the regex takes the first line
# that follows a "Question:\n" marker. If PROMPT itself contains that marker
# (e.g. few-shot examples), the match would come from the prompt — verify.
BATCH_SIZE = 50
MAX_NEW_TOKENS = 30
sen_ques = re.compile(r"Question:\n(.*)\n?", re.MULTILINE)

generated_questions = []
# Step over batch start offsets so the final, partial batch is processed too
# (`range(len(clickbaits) // 50)` silently dropped the last len % 50 items).
for start in tqdm(range(0, len(clickbaits), BATCH_SIZE)):
    clickbaits_batch = clickbaits[start : start + BATCH_SIZE]
    inputs = tokenizer(
        [PROMPT.format(question=c) for c in clickbaits_batch],
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # inputs must live on the same device as the model
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

    decoded_batch = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    for clickbait, decoded in zip(clickbaits_batch, decoded_batch):
        match = sen_ques.search(decoded)
        # Fall back to the original clickbait so the output stays aligned
        # row-for-row with `clickbaits`.
        generated_questions.append(match.group(1) if match else clickbait)

In [None]:
# Every clickbait must yield exactly one row (generated question or fallback),
# otherwise the CSV written below would not line up with the training data.
assert len(generated_questions) == len(clickbaits), (len(generated_questions), len(clickbaits))
len(generated_questions)

In [None]:
# Persist the generated questions, one row per clickbait in the original order.
OUTPUT_PATH = Path("data/spoiler_generation/vicuna/train_questions.csv")
# to_csv does not create missing parent directories.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
pd.DataFrame(generated_questions, columns=["generated_questions"]).to_csv(OUTPUT_PATH, index=False)