In [1]:
import os
import random
from argparse import ArgumentParser
import logging

import torch
from trl import SFTConfig, SFTTrainer

from lima_dataset import load_lima_dataset, tokenize_text, format_prompt_func, EOT_TOKEN
from utils import (
    read_yaml,
    get_model_config,
    get_tokenizer_config,
    get_split_config,
    get_dataset_config,
    get_trainer_config,
    get_generation_config,
    get_generation_samples,
    get_lora_config,
    _handle_seed,
    DEVICE,
)
from model import (
    tokenize_text,
    load_model,
    load_tokenizer,
    load_pretrained_base_llama2_model,
    load_lora_model,
    generate,
    compute_metrics,
)

In [2]:
config = read_yaml("./configs/generate_config_llama.yaml")

In [3]:
tokenizer_name, tokenizer_path, tokenizer_config = get_tokenizer_config(config)
tokenizer = load_tokenizer(
    tokenizer_name=tokenizer_name,
    tokenizer_path=tokenizer_path,
    tokenizer_config=tokenizer_config,
)
tokenizer_name, tokenizer_path, tokenizer_config

('llama2',
 'meta-llama/Llama-2-7b-hf',
 {'special_token_kwargs': {'pad_token': 'eos_token',
   'additional_tokens': ['EOT_TOKEN']}})

In [4]:
# tokenizer.pad_token_id

In [5]:
model_name, model_path, base_model_path, model_config = get_model_config(config)
model_config

{'force_download': False,
 'device_map': 'cuda:0',
 'bnb_config': {'load_in_4bit': True,
  'bnb_4bit_quant_type': 'nf4',
  'bnb_4bit_compute_dtype': 'float16',
  'bnb_4bit_use_double_quant': False}}

In [6]:
model_config['pad_token_id'] = tokenizer.pad_token_id
model_config['tokenizer_length'] = len(tokenizer)

In [7]:
# base_model = load_pretrained_base_llama2_model(
#     base_model_path, **model_config
# )

In [8]:
model = load_model(
    model_string=model_name,
    model_path=model_path,
    base_model_path=base_model_path,
    model_config=model_config,
)
# base_model.config.pad_token_id = tokenizer.pad_token_id
# base_model.resize_token_embeddings(len(tokenizer))

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [9]:
model.config.pad_token_id

2

In [10]:
# from peft import PeftModel
# model = PeftModel.from_pretrained(
#     base_model, "/home/hmankodi/instruct_tuning/TrainingLogs/checkpoint-468"
# )
# model

In [11]:
generation_config = get_generation_config(config)
generation_config

{'max_length': 2048,
 'top_p': 0.9,
 'temperature': 0.7,
 'num_beams': 1,
 'top_k': None,
 'do_sample': True,
 'repetition_penalty': 1.2}

In [12]:
from model import generate

In [17]:
# prompt = "I'm writing a NeurIPS paper about a new model architecture for processing and generating long texts. Here are some facts about the paper:\n* The main trick is to replace some of the attention heads with an exponential moving average, where the decay rate is learned for each head. We call this architecture ExeMA.\n* On language modeling, the perplexity difference between our model and a vanilla transformer is negligible, but that's because next-token prediction is almost always a local task, so perplexity won't be sensitive enough to detect any improvements in long-range understanding.\n* However, on the SCROLLS benchmark, our model improves by 10% over the baseline.\n* We also have a new metric for measuring coherence in generated text (CoGnaTe), where our model generates text that is 43% more coherent than the baseline.\nHelp me write the paper's introduction."
prompt = "Plan a day trip in Tokyo. The spots need to be within walking distance to each other."
# prompt = "What medicine should I take when I get a cold?"
# prompt = f"{prompt}{EOT_TOKEN}"
outs = generate(
    model,
    tokenizer,
    prompt_samples=prompt,
    generation_config=generation_config,
    use_encode=True,
)
model.config.pad_token_id

2

In [18]:
print(outs[0])

Plan a day trip in Tokyo. The spots need to be within walking distance to each other. [EOT] Here's an itinerary:

* Start your morning by visiting the Meiji Shrine, which is located at 1-chomei, Minato City and surrounded by trees of the Imperial Palace Garden. Then walk around the shrines and learn about Japanese history and culture from the signboards on display.
* From there you can take a bus or train to reach Asakusa district where you will find one of Japan’s most famous temples – Sensoji Temple. This temple has been standing for more than 200 years and features a large red gate that was built as part of its restoration work after World War II. If you are lucky enough to come here during cherry blossom season then make sure not miss out viewing some beautiful sakura flowers! There are also many stalls offering traditional food like okonomiyaki (a savory pancake) so don’t forget try them before leaving this place;
* After exploring all these attractions head over towards Akihabara