In [None]:
# Show the notebook's current working directory (relative paths below, such as "./.cache", resolve against it).
%pwd

In [None]:
# !nvidia-smi

In [None]:
# Install the package with pip from a local checkout in a parent directory. See: https://stackoverflow.com/questions/15031694/installing-python-packages-from-local-file-system-folder-to-virtualenv-with-pip
# %pip install -e ../../nano-askllm/

In [None]:
# Load the dotenv IPython extension and read variables from a .env file into the
# process environment; later cells fetch tokens and endpoints with os.getenv().
%load_ext dotenv
%dotenv

In [None]:
# Cache directory handed to the Hugging Face `from_pretrained` calls below.
# Caches should always live on team storage: set TEAM_DATASETS_CACHE_DIR
# (for example in the .env file loaded above) to point there, e.g.
#   TEAM_DATASETS_CACHE_DIR="/persistentshare/storage/team_kumagai/datasets"
# When the variable is unset or empty, the local "./.cache" directory is used.
import os  # imported here because this cell runs before the main import cell

TEAM_DATASETS_CACHE_DIR = os.environ.get("TEAM_DATASETS_CACHE_DIR") or "./.cache"

In [None]:
import json
import os
import sys
from datetime import datetime
import logging
import random

import numpy as np
import pandas as pd

import wandb
from huggingface_hub import login, whoami

import argilla as rg

from datasets import load_dataset, Dataset

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    AutoModel,
    set_seed,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    LlamaTokenizer,
    TrainerCallback,
)

from transformers import TrainingArguments
from trl import DPOTrainer

import torch.distributed as dist
import multiprocessing as mp

import torch
import transformers

from typing import Any
import torch

# Configure the root logger to emit INFO-and-above records to stdout.
# Notebook cells get re-run: blindly calling addHandler() each time would stack
# up handlers and print every record several times, so an already-attached
# stdout StreamHandler is reused when present.
logger = logging.getLogger()
handler = next(
    (
        existing
        for existing in logger.handlers
        if isinstance(existing, logging.StreamHandler)
        and getattr(existing, "stream", None) is sys.stdout
    ),
    None,
)
if handler is None:
    handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(handler)
handler.setLevel(logging.INFO)
logger.setLevel(logging.INFO)

logger.info("start logging...!")

In [None]:
# Authenticate against the Hugging Face Hub with the token stored in the HF_TOKEN
# environment variable. os.getenv() yields None when the variable is unset.
login(token=os.getenv("HF_TOKEN"))

In [None]:
# Start a Weights & Biases run; the project and entity are taken from the
# WANDB_PROJECT / WANDB_ENTITY environment variables (None when unset).
run = wandb.init(
    project=os.getenv("WANDB_PROJECT"),
    entity=os.getenv("WANDB_ENTITY"),
)

In [None]:
# Initialise the Argilla client; URL, API key and workspace are read from the
# RG_API_URL / RG_API_KEY / RG_WORKSPACE environment variables.
rg.init(
    api_url=os.getenv("RG_API_URL"),
    api_key=os.getenv("RG_API_KEY"),
    workspace=os.getenv("RG_WORKSPACE"),
)

In [None]:
from nano_askllm import AskLLM

In [None]:
# Hugging Face Hub identifiers of the checkpoints loaded in the next cell:
# model_id  -> tokenizer and `model`
# model2_id -> `model2`
model_id = "Rakuten/RakutenAI-7B-instruct"
model2_id = "Rakuten/RakutenAI-7B-chat"

In [None]:
# Load the tokenizer for `model_id` plus both causal-LM checkpoints. Every
# download is cached under TEAM_DATASETS_CACHE_DIR; dtype and device placement
# are left to the library ("auto").
causal_lm_kwargs = {
    "torch_dtype": "auto",
    "device_map": "auto",
    "cache_dir": TEAM_DATASETS_CACHE_DIR,
}
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=TEAM_DATASETS_CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(model_id, **causal_lm_kwargs)
model2 = AutoModelForCausalLM.from_pretrained(model2_id, **causal_lm_kwargs)

In [None]:
# Prompt template for the yes/no quality question. The manual walkthrough
# further down builds the full prompt as prefix + datapoint + postfix;
# presumably AskLLM does the same internally — verify against nano_askllm.
prompt_template_prefix = "###\n"
prompt_template_postfix = """
###

Does the previous paragraph demarcated within ### and ### contain informative signal for pre-training a large-language model? An informative datapoint should be well-formatted, contain some usable knowledge of the world, and strictly NOT have any harmful, racist, sexist, etc. content.

OPTIONS: yes/no
ANSWER:"""

# Token strings treated as a "yes" answer.
yes_tokens = ["yes", "Yes"]

# Scorer built on the instruct model and its tokenizer.
llm = AskLLM(
    tokenizer,
    model,
    prompt_template_prefix=prompt_template_prefix,
    prompt_template_postfix=prompt_template_postfix,
    yes_tokens=yes_tokens,
    max_tokens=512,  # You can increase it up to 8192 for Mistral-7B-v0.1 based models.
)

# NOTE(review): batch_size and num_ask are not referenced by any cell visible here.
batch_size = 2
num_ask = 5

In [None]:
# Sanity check: score one toy datapoint through the AskLLM wrapper and print
# each score next to a single-line preview (first 40 characters) of its text.
datapoints = ["sample"]

scores = llm.ask(datapoints)
for score, datapoint in zip(scores.tolist(), datapoints):
    # Newlines in the preview are flattened to spaces so each result stays on one line.
    text = " ".join(datapoint[:40].split("\n"))
    print(f"score: {score:.4f}\ttext: {text}")

### From here on: pull out the internals and reproduce the same computation step by step

In [None]:
# Implement it here: the single datapoint used for the manual walkthrough below.
datapoint = "sample"

In [None]:
# Token budget; same value as the max_tokens passed to AskLLM above.
# NOTE(review): not referenced by the later cells visible here.
max_tokens = 512

In [None]:
# Tokenize the raw datapoint (special tokens included) and display the token ids.
encoded_tokens = tokenizer.encode(datapoint, add_special_tokens=True)
encoded_tokens

In [None]:
# The datapoint is used as-is: no truncation is applied in this walkthrough.
# NOTE(review): encoded_tokens / max_tokens are not used to shorten the text here;
# harmless for this short sample, but confirm before trying longer inputs.
truncated = datapoint

In [None]:
# Sandwich the datapoint between the template prefix and postfix, then show
# the final prompt text exactly as it will be tokenized.
prompt = "".join((prompt_template_prefix, truncated, prompt_template_postfix))
print(prompt)

In [None]:
# Tokenize the single prompt into PyTorch tensors (return_tensors="pt") and move
# them to the device the model lives on; the last line displays the encoding.
inputs = tokenizer([prompt], return_tensors="pt", padding=True).to(model.device)
inputs

In [None]:
# Generate exactly one new token and ask generate() for a structured result
# (return_dict_in_generate=True) that also carries the logits (output_logits=True).
max_new_tokens = 1
outputs = model.generate(
    **inputs, max_new_tokens=max_new_tokens, output_logits=True, return_dict_in_generate=True
)
outputs

In [None]:
# Take the first entry of outputs.logits. With max_new_tokens=1 this is presumably
# the logits of the single generated step, shaped (batch, vocab) — TODO confirm.
logits = outputs.logits[0]
logits

In [None]:
# Normalise the logits into a probability distribution along the last axis.
probs = logits.softmax(dim=-1)
probs

In [None]:
# For every row of `probs`, print the k most probable next tokens together with
# their probabilities. The loops iterate directly over the values; the enumerate
# counters that used to wrap them were never read.
k = 10
for prob in probs:
    tops = torch.topk(prob, k, dim=-1)
    for idx, val in zip(tops.indices, tops.values):
        # Decode the token id to text, left-aligned in an 8-character column.
        print(f"{tokenizer.decode(idx):8s}: {val.item():.4f}")

In [None]:
# Tokenize every "yes" variant without special tokens, move the batch to the
# model's device, and keep only the first token id of each variant.
yes_encoding = tokenizer(yes_tokens, return_tensors="pt", add_special_tokens=False)
yes_ids: torch.Tensor = yes_encoding.to(model.device).input_ids[:, 0]

In [None]:
# For every row of `probs`, pick out the columns at the "yes" token ids.
yes_probs = probs[:, yes_ids]
yes_probs

In [None]:
# Number of tokens taken up by the fixed template parts (prefix + postfix),
# each tokenized separately and without special tokens.
prompt_len = sum(
    len(tokenizer.encode(template_part, add_special_tokens=False))
    for template_part in (prompt_template_prefix, prompt_template_postfix)
)
prompt_len

In [None]:
# Final score per row: the summed probability of all "yes" tokens.
# This rebinds `scores`, which earlier held the result of llm.ask().
scores = yes_probs.sum(dim=-1)
scores