In [1]:
from typing import Tuple
import json
import time
import sys
import gc
import os

import torch

from non_parallel_llama.model import ModelArgs, Transformer
from non_parallel_llama.tokenizer import Tokenizer
from non_parallel_llama import Llama

In [2]:
# Collect unreachable Python objects first so that any CUDA tensors they still hold are
# released, then ask PyTorch's caching allocator to return the now-unused cached blocks.
# (Running empty_cache() before gc.collect() misses memory held by cyclic garbage.)
gc.collect()
torch.cuda.empty_cache()

0

In [3]:
# Root directory of the 7B-chat checkpoint release (relative to the notebook's working dir).
CKPT_PATH = 'checkpoints/7B-chat'
# Directory passed to build(): must contain the *.pth shard(s) and params.json.
MODEL_PATH = '/'.join((CKPT_PATH, 'llama-2-7b-chat'))
# Tokenizer model file passed to build() as tokenizer_path.
TOKENIZER_PATH = '/'.join((CKPT_PATH, 'tokenizer.model'))

In [4]:
def build(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int,
          max_seq_len: int = 1024, max_batch_size: int = 32) -> Llama:
    """Load one checkpoint shard plus the tokenizer and wrap them in a ``Llama`` generator.

    Args:
        ckpt_dir: Directory containing the ``*.pth`` checkpoint shard(s) and ``params.json``.
        tokenizer_path: Path to the tokenizer model file passed to ``Tokenizer``.
        local_rank: Index of the shard to load; also used as the CUDA device index.
        world_size: Expected number of ``*.pth`` shards found in ``ckpt_dir``.
        max_seq_len: ``max_seq_len`` passed to ``ModelArgs`` (default 1024).
        max_batch_size: ``max_batch_size`` passed to ``ModelArgs`` (default 32).

    Returns:
        A ``Llama`` generator wrapping the weight-loaded ``Transformer`` and the tokenizer.

    Raises:
        AssertionError: if the number of ``*.pth`` files in ``ckpt_dir`` != ``world_size``.
    """
    import warnings

    # endswith() avoids picking up e.g. 'consolidated.00.pth.bak'.
    checkpoints = sorted(path for path in os.listdir(ckpt_dir) if path.endswith('.pth'))
    assert world_size == len(checkpoints), (
        f'Loading a checkpoint for MP = {len(checkpoints)} but world_size is {world_size}'
    )

    torch.cuda.set_device(local_rank)

    ckpt_path = f'{ckpt_dir}/{checkpoints[local_rank]}'
    ckpt = torch.load(ckpt_path, map_location='cpu')
    # Context manager so the params file handle is closed deterministically.
    with open(f'{ckpt_dir}/params.json', 'r') as params_file:
        params = json.load(params_file)
    model_args = ModelArgs(max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params)
    tokenizer = Tokenizer(model_path=tokenizer_path)

    model_args.vocab_size = tokenizer.n_words
    model = Transformer(model_args)

    # Apply the loaded weights; without this the model runs with random initialisation.
    # strict=False mirrors the reference LLaMA loader (the checkpoint may carry extra
    # entries); still surface missing parameters, which would leave layers uninitialised.
    incompatible = model.load_state_dict(ckpt, strict=False)
    if incompatible.missing_keys:
        warnings.warn(
            f'{len(incompatible.missing_keys)} model parameters were not found in '
            f'{ckpt_path}, e.g. {incompatible.missing_keys[:5]}'
        )
    del ckpt  # release the CPU copy of the weights once they live in the model

    return Llama(model, tokenizer)


def main(ckpt_dir: str, tokenizer_path: str,
         temperature: float = 0.8, top_p: float = 0.95):
    """Run an interactive chat loop against a locally built model.

    Each prompt read from ``input`` is sent as a fresh single-message dialog; no
    conversation history is carried between turns. Entering ``<end>`` (surrounding
    whitespace ignored) or closing stdin stops the loop.

    Args:
        ckpt_dir: Checkpoint directory, passed to ``build``.
        tokenizer_path: Tokenizer model path, passed to ``build``.
        temperature: Forwarded to ``generator.chat_completion``.
        top_p: Forwarded to ``generator.chat_completion``.
    """
    generator = build(ckpt_dir, tokenizer_path, 0, 1)  # single process: rank 0 of world size 1

    while True:
        # Read the prompt once per iteration into the same variable the loop tests, so each
        # answer matches the latest prompt and '<end>' actually terminates the loop.
        try:
            prompt = input("Enter prompt : ")
        except EOFError:  # stdin closed: behave as if '<end>' was entered
            break
        if prompt.strip() == '<end>':
            break

        # chat_completion takes a batch of dialogs; this batch holds one dialog.
        dialog = [[{"role"    : "user",
                    "content" : prompt}]]

        print(f'Q. {prompt}')

        results = generator.chat_completion(
                        dialog, max_gen_len = None,
                        temperature = temperature, top_p = top_p
                    )

        for result in results:
            print(f'A. {result["generation"]["content"]}')
            print("="*30, '\n')

In [None]:
# Start the interactive chat session; type <end> at the prompt to stop it.
main(ckpt_dir=MODEL_PATH, tokenizer_path=TOKENIZER_PATH)



Enter prompt :  Do you know Pigeon?


Q. Do you know Pigeon?
