In [3]:
import os

import torch
import transformers
import wandb
from datasets import load_dataset
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BloomTokenizerFast, get_scheduler

from petals import DistributedBloomForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Choose a model you'd like to prompt-tune. We recommend starting with
# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.
# Once your code is ready, you can switch to full-scale
# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).
# MODEL_NAME = "bigscience/bloom-7b1-petals"
#
# The default below is a machine-local checkpoint directory. Set the
# PETALS_MODEL_NAME environment variable to use a different checkpoint path
# or Hub model id without editing this cell.
MODEL_NAME = os.environ.get(
    "PETALS_MODEL_NAME",
    "/home/kosenko/deepspeed/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/models/bloomz_ru_v1/step=1999",
)

# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').
# The latter fine-tunes separate prefixes for each transformer block,
# so prompt-tuning will take more time but yield better results.
# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf
TUNING_MODE = "ptune"

# Number of trainable prefix tokens (only referenced by a commented-out
# `pre_seq_len` argument in the model-loading cell).
NUM_PREFIX_TOKENS = 16

# The second GPU is preferred; fall back to the first GPU or to the CPU so the
# notebook still runs on machines that do not have two CUDA devices.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Training hyperparameters.
BATCH_SIZE = 8
LR = 1e-2
WEIGHT_DECAY = 0.0
NUM_SAMPLES = 1000
SEED = 42

# Assigned to `tokenizer.model_max_length` in the model-loading cell.
MODEL_MAX_LENGTH = 1024

In [7]:
# Tokenizer: loaded from the public "bigscience/bloom-petals" repo rather than
# from MODEL_NAME, then configured for right-side padding and the notebook's
# maximum sequence length.
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-petals")
tokenizer.model_max_length = MODEL_MAX_LENGTH
tokenizer.padding_side = "right"

# Model: built in the selected prompt-tuning mode, then moved to DEVICE.
# NOTE(review): `pre_seq_len=NUM_PREFIX_TOKENS` was previously commented out
# here; presumably the prefix length then comes from the checkpoint config —
# TODO confirm.
model_load_kwargs = {"tuning_mode": TUNING_MODE}
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, **model_load_kwargs)
model = model.to(DEVICE)

May 14 18:05:17.144 [INFO] Prompt embeddings and their optimizer statistics will be kept in float32 to increase ptune quality
Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.83s/it]


In [None]:
# Interactive generation loop.
#
# Reads one prompt per turn from standard input and streams the model's reply
# token by token until a newline token is produced. Submitting an empty line
# ends the session.
TOP_K = 10
TEMPERATURE = 0.1
# Token budget shared by every turn inside one inference session.
SESSION_MAX_LENGTH = 1024

with model.inference_session(max_length=SESSION_MAX_LENGTH) as sess:
    # Running estimate of how many tokens (prompts + generated) the session
    # has consumed; used to stop before the session budget is exceeded.
    tokens_used = 0
    while tokens_used < SESSION_MAX_LENGTH:
        # Prompt the user; a hard-coded phrase here would make the empty-input
        # exit below unreachable and loop forever.
        user_phrase = input()
        if len(user_phrase) == 0:
            break
        inputs = tokenizer(user_phrase, return_tensors="pt")["input_ids"].to(DEVICE)
        tokens_used += inputs.shape[1]
        # Bounded by the session budget so a reply that never produces "\n"
        # cannot run indefinitely.
        while tokens_used < SESSION_MAX_LENGTH:
            outputs = model.generate(
                inputs,
                max_new_tokens=1,
                # NOTE(review): sampling parameters are expected to take effect
                # only with do_sample=True — confirm against the installed
                # Petals version.
                do_sample=True,
                temperature=TEMPERATURE,
                top_k=TOP_K,
                session=sess,
            )
            tokens_used += 1
            bloom_answer_token = tokenizer.decode(outputs[0, -1:])
            print(bloom_answer_token, end="", flush=True)
            if bloom_answer_token == "\n":
                break
            # Later steps pass no new inputs; presumably the session carries
            # the already-processed context forward — TODO confirm.
            inputs = None