In [None]:
from time import perf_counter
import pickle

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from soynlp.tokenizer import LTokenizer
from models.modified_sample.gpt2 import GPT2ModifiedSampleForCausalLM

In [None]:
# model_path = "lexiconium/kogpt-trinity"
# revision = "punct_wrapper-related_words-minevalloss"


# device = "cuda" if torch.cuda.is_available() else "cpu"
# tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision, use_auth_token=True)
# model = AutoModelForCausalLM.from_pretrained(model_path, revision=revision, use_auth_token=True).to(device)
# model.eval()

In [None]:
# Local fine-tuned checkpoint to sample from (the commented-out cell above shows
# the equivalent Hub-hosted variant).
model_path = "/opt/ml/outputs/checkpoint-92"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Project-specific GPT-2 subclass; .to() returns the module itself.
model = GPT2ModifiedSampleForCausalLM.from_pretrained(model_path)
model = model.to(device)
# Switch to inference mode (e.g. disables dropout). Left as the last expression
# so the notebook displays the model summary as before.
model.eval()

In [None]:
@torch.no_grad()
def genreate_from_input(
    input_text: str,
    max_length: int = 64,
    top_k: int = 30,
    top_p: float = 0.95,
    temperature: float = 1.0,
    repetition_penalty: float = 2.0,
):
    """Sample one continuation of ``input_text`` with the notebook-level model.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` created
    in the model-loading cell.

    Args:
        input_text: Prompt string; tokenized with ``tokenizer``.
        max_length: Maximum total sequence length in tokens, forwarded to
            ``model.generate`` (counts the prompt as well as new tokens).
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling threshold.
        temperature: Sampling temperature.
        repetition_penalty: Repetition penalty forwarded to ``generate``.
            Defaults to 2.0, the value that used to be hard-coded.

    Returns:
        A tuple ``(generated_text, spent)``: the first generated sequence
        decoded with special tokens kept, and the wall-clock seconds spent on
        generation plus decoding.
    """
    # Full tokenizer call (rather than .encode) so we also get the attention
    # mask; otherwise generate() has to infer it and warns when it cannot.
    encoded = tokenizer(input_text, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # generate() itself falls back to eos when pad is None, but warns; do the
    # fallback explicitly to keep the output clean.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Check generation time
    t = perf_counter()

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )
    # Decoding moves the tokens to host memory, so the timing below also covers
    # any outstanding device work.
    generated_text = tokenizer.decode(output[0])

    spent = perf_counter() - t

    return generated_text, spent


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
generate_from_input = genreate_from_input


In [None]:
# Precomputed word-cohesion scores (filtered Namuwiki corpus, per the filename).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("/opt/ml/data/namuwiki_filtered_cohesion.pickle", "rb") as f:
    cohesions = pickle.load(f)

# Only the first element of each score entry is passed to LTokenizer;
# presumably the left-side (forward) cohesion. TODO confirm against the
# script that produced the pickle.
l_cohesions = {}
for word, score in cohesions.items():
    l_cohesions[word] = score[0]

l_tokenizer = LTokenizer(l_cohesions)

In [None]:
# Sweep settings: edit these instead of the loop body.
SEED = 42
NUM_SAMPLES = 5
TOP_P_GRID = [0.8]
TOP_K_GRID = [16]

prompt = "스키, 설산, 내림, 남자"

# Keyword prompt wrapped as "@...@<d>\n"; presumably mirrors the fine-tuning
# data format. TODO confirm against the training preprocessing.
input_text = f"@{prompt}@<d>\n"

# Alternative: segment a free-form prompt with the cohesion-based LTokenizer.
# words = l_tokenizer(prompt)
# input_text = f"@{', '.join(words)}@<d>\n"

# Seed once before the sweep: samples within the loop still differ from each
# other, but re-running the notebook reproduces the same outputs.
torch.manual_seed(SEED)

for top_p in TOP_P_GRID:
    for top_k in TOP_K_GRID:
        print(f"========\n{top_p=}\n{top_k=}\n========\n")
        for n in range(1, NUM_SAMPLES + 1):
            generated_text, spent = genreate_from_input(
                input_text, max_length=64, top_k=top_k, top_p=top_p, temperature=1.0
            )
            print(f"[{n}]:\n{generated_text}\n")
            # print(f"\ntime spent: {spent:.2f} sec")


In [None]:
# prompt = "스키를 타고 눈 덮인 비탈을 내려오는 남자"

# input_text = f"@{prompt}@<usr>\n"

# generated_text, spent = genreate_from_input(input_text, max_length=64, top_k=10, top_p=0.8, temperature=1.0)
# print(f"{generated_text}\n")
# # probs.shape=torch.Size([1, 51200])