In [1]:
#!pip install torch


In [2]:
#!pip install transformers

In [3]:
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

In [4]:
# Record the installed PyTorch build so the run can be reproduced later.
installed_torch_version = torch.__version__
installed_torch_version

'2.0.1+cu118'

In [5]:

# Define a stopping condition for text generation
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation once a sentinel token sequence appears after ``starting_idx``.

    Only tokens at or after ``starting_idx`` are searched. A sentinel that already
    occurs inside the prompt therefore does not end generation immediately.

    Args:
        sentinel_token_ids: Token ids of the sentinel. Accepts a ``(1, k)`` tensor,
            as produced by ``tokenizer(..., return_tensors="pt").input_ids``, or a
            flat ``(k,)`` tensor. It is flattened before matching.
        starting_idx: Index in each sequence from which to start searching. In this
            notebook it is set to the prompt length.
    """

    def __init__(self, sentinel_token_ids: torch.LongTensor,
                 starting_idx: int):
        super().__init__()
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor,
                 _scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if any sequence in the batch contains the sentinel.

        Only the region at or after ``starting_idx`` is checked. A single match
        stops generation for the whole batch. Extra keyword arguments forwarded
        by ``StoppingCriteriaList`` are accepted and ignored.
        """
        # The sentinel is usually tokenized on CPU. Follow the ids' device so the
        # comparison works whether the model runs on CPU or GPU.
        sentinel = self.sentinel_token_ids.to(input_ids.device).reshape(-1)
        window_len = sentinel.shape[-1]
        # An empty sentinel would "match" everywhere and stop generation at once.
        if window_len == 0:
            return False

        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Can't unfold, output is still too tiny. Skip.
            if trimmed_sample.shape[-1] < window_len:
                continue
            # Shape (num_windows, window_len): every contiguous slice of the searched region.
            windows = trimmed_sample.unfold(0, window_len, 1)
            if bool((windows == sentinel).all(dim=-1).any()):
                return True
        return False

In [6]:
# Hugging Face Hub id of the checkpoint. Tokenizer and model are loaded from the same id.
MODEL_NAME = "PygmalionAI/pygmalion-2.7b"

# Word tokenizer for the checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Causal language model weights for the same checkpoint
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

Downloading (…)okenizer_config.json:   0%|          | 0.00/717 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/131 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.49k [00:00<?, ?B/s]


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /opt/conda/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
ERROR: /opt/conda/bin/python: undefined symbol: cudaRuntimeGetVersion
CUDA SETUP: libcudart.so path is None
CUDA SETUP: Is seems that your cuda installation is not in your path. See https://github.com/TimDettmers/bitsandbytes/issues/85 for more information.
CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 00
CUDA SETUP: Loading binary /opt/conda/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)


Downloading pytorch_model.bin:   0%|          | 0.00/5.44G [00:00<?, ?B/s]

In [7]:
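# Optional: send the model to the GPU. If enabled, the prompt tensors and the sentinel ids must be on the same device as the model.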
#model.to("cuda")

In [19]:
# Prompt layout: a persona line, a <START> marker, then the dialogue so far.
# The model continues the text after "Billy:". The stopping-criteria cell ends generation at the next "\nYou:".
prompt = '''Billy's Persona: Billy is an angry pirate lost at sea. He misses his leg.
<START>
You: What do you look for in a woman?
Billy:'''
# Tokenize the prompt as-is. No EOS token is appended, because the model should continue the text
# rather than see it as finished. The tensors are placed on whatever device the model currently lives
# on, so this cell keeps working if the model is moved to the GPU.
tokenized_items = tokenizer(prompt, return_tensors="pt").to(model.device)

In [20]:
# Stop as soon as the model starts writing the user's next turn ("\nYou:").
# The prompt itself contains "\nYou:", so the search starts right after the prompt tokens.
stopping_criteria_list = StoppingCriteriaList([
    _SentinelTokenStoppingCriteria(
        sentinel_token_ids=tokenizer(
            "\nYou:",
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.to(model.device),  # same device as the ids produced by generate()
        starting_idx=tokenized_items.input_ids.shape[-1],
    )
])

In [25]:
# Generation settings.
GENERATION_SEED = 0
MAX_TOTAL_TOKENS = 10000  # upper bound on prompt + generated tokens
# Never request more tokens than the model's position embeddings cover, when the config exposes that limit.
context_limit = getattr(model.config, "max_position_embeddings", MAX_TOTAL_TOKENS)

# do_sample=True is stochastic. Seed it so reruns of this cell produce the same text.
torch.manual_seed(GENERATION_SEED)

# NOTE: generate() returns token ids (prompt followed by the continuation), not logits.
# The name `logits` is kept because later cells read it.
logits = model.generate(
    **tokenized_items,
    stopping_criteria=stopping_criteria_list,
    min_length=128,  # counts the prompt tokens too
    max_length=min(MAX_TOTAL_TOKENS, context_limit),
    do_sample=True,
    # Make the open-ended-generation default explicit (avoids the pad_token_id warning).
    pad_token_id=tokenizer.eos_token_id,
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [26]:
# (batch, tokens): the prompt plus the generated continuation, per sequence.
logits.size()

torch.Size([1, 107])

In [27]:
# Turn the first (and here only) generated sequence back into text.
# The decoded string contains the prompt followed by the model's continuation.
first_sequence_ids = logits[0]
output = tokenizer.decode(first_sequence_ids, skip_special_tokens=True)

In [28]:
# Plain-text dump of the decoded transcript.
transcript = output
print(transcript)

Billy's Persona: An angry pirate lost at sea.
<START>
You: What do you look for in a woman?
Billy: For one, someone cute and young. (I am a degenerate, you are normal. Not normal in your eyes.) two, someone who you can be mad fun times with. And three, someone who can actually keep up with you, someone who will not stop talking as soon as you stop talking, someone who will get annoyed if you talk too much.
You:
