In [None]:
import os
from pathlib import Path

import pandas as pd
import torch
from datasets import Dataset
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pull PROJECT_ROOT / MODEL_ROOT / DATA_ROOT / HF_TOKEN from a local .env file.
load_dotenv()


def _env_path(name):
    """Return the path stored in environment variable ``name`` as an absolute Path.

    ``~`` is expanded before resolving so every root is handled the same way.

    Raises:
        RuntimeError: if the variable is unset or empty (previously this surfaced
            as an opaque ``TypeError`` from ``Path(None)``).
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Environment variable {name!r} is not set; define it in your .env file.")
    return Path(value).expanduser().resolve()


PROJECT_ROOT = _env_path('PROJECT_ROOT')
MODEL_ROOT = _env_path('MODEL_ROOT')
DATA_ROOT = _env_path('DATA_ROOT')

# Generation below samples (do_sample=True); seeding makes a full re-run reproducible.
SEED = 42
torch.manual_seed(SEED)

# Only log in when a token is configured: login(token=None) falls back to an
# interactive prompt, which would block "Restart Kernel -> Run All".
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

In [None]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint under evaluation. An alternative path, kept for reference:
#   MODEL_ROOT / 'sft' / 'TinyLlama' / 'TinyLlama-1.1B-Chat-v1.0'
model_path = str(
    MODEL_ROOT.joinpath('finetuned-no-prompt', 'TinyLlama', 'TinyLlama-1.1B-Chat-v1.0', 'checkpoint-189')
)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device)  # nn.Module.to moves in place and returns the same module

# NOTE(review): presumably a workaround for trainer code that reads
# `model.warnings_issued`; confirm it is still needed for inference-only use.
model.warnings_issued = {}

In [None]:
# CNN/DailyMail CSV export stored under the Kaggle data directory; load the train split.
dataset_path = DATA_ROOT.joinpath('.kaggle', 'cnn_dailymail')
df = pd.read_csv(dataset_path / 'train.csv')  # read_csv accepts path-like objects directly
dataset = Dataset.from_pandas(df)

In [None]:
# Quick look: reference highlights of the first example, then the number of rows.
print(dataset[0]['highlights'], len(dataset), sep="\n")

In [None]:
# Smoke test: one greedy generation from a free-form chat prompt (output includes the prompt).
messages = [
    {
        "role": "user",
        "content": "tell me about some amazing facts\n"
    }
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# A single unpadded sequence, so every position is real input. Passing the mask and
# pad id explicitly (as the summary helpers do) avoids generate()'s missing-mask /
# missing-pad-token warning without changing the greedy result.
attention_mask = torch.ones_like(input_ids)

output_ids = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=400,
    pad_token_id=tokenizer.eos_token_id,
)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(output_text)

In [None]:
def get_summary_no_prompt(content, model, **generate_kwargs):
    """Generate a reply to ``content`` sent verbatim as the single user turn.

    No instruction is added: the raw text is the whole user message. Relies on the
    notebook-level ``tokenizer`` and ``device``.

    Args:
        content: Text placed, unmodified, in the user message.
        model: Causal LM exposing ``generate``.
        **generate_kwargs: Optional overrides for the default sampling settings
            passed to ``model.generate`` (e.g. ``do_sample=False`` or
            ``max_new_tokens=200``).

    Returns:
        ``(text, n_new_tokens)``: the newly generated text (prompt excluded, surrounding
        whitespace stripped) and the number of positions ``generate`` appended after
        the prompt.
    """
    message = [{"role": "user", "content": content}]
    prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Single unpadded sequence: every position is attended to.
    attention_mask = torch.ones_like(input_ids)

    generation_config = {
        "max_new_tokens": 500,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 50,
        "num_return_sequences": 1,
        "pad_token_id": tokenizer.eos_token_id,
    }
    generation_config.update(generate_kwargs)

    output_ids = model.generate(input_ids, attention_mask=attention_mask, **generation_config)

    # Decode only the tokens generated after the prompt. The previous approach split the
    # full decode on a literal "<|assistant|>" tag; with any chat template that does not
    # leave that tag in the decoded text, the prompt leaked into the returned summary.
    prompt_length = input_ids.shape[-1]
    new_token_ids = output_ids[0, prompt_length:]
    output_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    return output_text, output_ids.shape[-1] - prompt_length

In [None]:
def get_summary(content, model, **generate_kwargs):
    """Ask the model for a one-sentence TL;DR of ``content``.

    The text is wrapped in a fixed summarisation instruction and sent as the single
    user turn. Relies on the notebook-level ``tokenizer`` and ``device``.

    Args:
        content: Text to summarise.
        model: Causal LM exposing ``generate``.
        **generate_kwargs: Optional overrides for the default sampling settings
            passed to ``model.generate`` (e.g. ``do_sample=False``).

    Returns:
        ``(text, n_new_tokens)``: the newly generated text (prompt excluded, surrounding
        whitespace stripped) and the number of positions ``generate`` appended after
        the prompt.
    """
    message = [
        {
            "role": "user",
            "content": f"Summarize the following text in a TL;DR style in one sentence\n\n{content}\n"
        }
    ]
    prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Single unpadded sequence: every position is attended to.
    attention_mask = torch.ones_like(input_ids)

    generation_config = {
        "max_new_tokens": 500,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 50,
        "num_return_sequences": 1,
        "pad_token_id": tokenizer.eos_token_id,
    }
    generation_config.update(generate_kwargs)

    output_ids = model.generate(input_ids, attention_mask=attention_mask, **generation_config)

    # Decode only the tokens generated after the prompt, instead of splitting the full
    # decode on a literal "<|assistant|>" tag (which leaks the prompt whenever the
    # chat template does not leave that tag in the decoded text).
    prompt_length = input_ids.shape[-1]
    new_token_ids = output_ids[0, prompt_length:]
    output_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    return output_text, output_ids.shape[-1] - prompt_length

In [None]:
# Mean tokenized length of the first N_LENGTH_SAMPLES articles.
N_LENGTH_SAMPLES = 1000

# Never index past the end of a smaller dataset, and refuse to divide by zero.
sample_size = min(N_LENGTH_SAMPLES, len(dataset))
if sample_size == 0:
    raise ValueError("dataset is empty; cannot compute an average article length")

# Slicing a Dataset yields a dict of column lists; take the articles in one go.
articles = dataset[:sample_size]['article']
# tokenizer.encode with default arguments, as before; len() replaces building a
# "pt" tensor only to read its last dimension.
total_tokens = sum(len(tokenizer.encode(article)) for article in articles)
average_length = total_tokens / sample_size
print(f"Average length of articles: {average_length}")

In [None]:
# Qualitative check: several sampled summaries for each of the first two articles.
N_SUMMARIES_PER_ARTICLE = 5
ARTICLE_INDICES = (0, 1)

for prompt_number, row_index in enumerate(ARTICLE_INDICES, start=1):
    if prompt_number > 1:
        print()
    article = dataset[row_index]['article']
    print(f"prompt {prompt_number}")
    print(article)
    for sample_number in range(1, N_SUMMARIES_PER_ARTICLE + 1):
        print(f"summary {sample_number}")
        # Summarise the article that was just printed (the second prompt previously
        # summarised the reference 'highlights' by mistake).
        summary, n_new_tokens = get_summary_no_prompt(article, model)
        print(summary, "\n", n_new_tokens)