In [3]:
#Imports
import os
import torch
import torch.nn as nn
from collections import OrderedDict
import transformers
from transformers import BloomForCausalLM
from transformers import BloomTokenizerFast

def get_state_dict(shard_num, prefix=None):
    d = torch.load(model, f"pytorch_model_{shard_num:05d}-of-00072.bin")
    return d if prefix is None else OrderedDict((k.replace(prefix, ''), v) for k, v in d.items())

from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
from transformers.models.bloom.modeling_bloom import BloomBlock, build_alibi_tensor

model = BloomForCausalLM.from_pretrained("bigscience/bloom-1b1")
config = BloomConfig.from_pretrained(model)
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b1")
device = 'cpu'

OSError: Incorrect path_or_model_id: 'BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 1536)
    (word_embeddings_layernorm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-23): 24 x BloomBlock(
        (input_layernorm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=1536, out_features=4608, bias=True)
          (dense): Linear(in_features=1536, out_features=1536, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=1536, out_features=6144, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=6144, out_features=1536, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1536, out_features=250880, bias=False)
)'. Please provide either the path to a local folder or the repo_id of a model on the Hub.

In [None]:
#3 methods to load state dictionaries into different objects. This loads only specific parts to RAM to save memory.
def load_embeddings():
    state_dict = get_state_dict(shard_num=1, prefix="word_embeddings_layernorm.")
    embeddings = nn.Embedding.from_pretrained(state_dict.pop('word_embeddings.weight'))
    lnorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=torch.bfloat16)
    lnorm.load_state_dict(state_dict)
    return embeddings.to(device), lnorm.to(device)

def load_causal_lm_head():
    linear = nn.utils.skip_init(
        nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.bfloat16)
    linear.load_state_dict(get_state_dict(shard_num=1, prefix="word_embeddings."), strict=False)
    return linear.bfloat16().to(device)

def load_block(block_obj, block_num):
    block_obj.load_state_dict(get_state_dict(shard_num=block_num + 2, prefix=f"h.{block_num}."))
    block_obj.to(device)

final_lnorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon, dtype=torch.bfloat16)
final_lnorm.load_state_dict(get_state_dict(shard_num=72, prefix="ln_f."))
final_lnorm.to(device)
block = BloomBlock(config, layer_number=1).bfloat16()

In [None]:
def forward(input_ids):
    # 1. Create attention mask and position encodings
    attention_mask = torch.ones(len(input_ids)).unsqueeze(0).bfloat16().to(device)
    alibi = build_alibi_tensor(input_ids.shape[1], config.num_attention_heads,
                               torch.bfloat16).to(device)
    # 2. Load and use word embeddings
    embeddings, lnorm = load_embeddings()
    hidden_states = lnorm(embeddings(input_ids))
    del embeddings, lnorm

    # 3. Load and use the BLOOM blocks sequentially
    for block_num in range(70):
        load_block(block, block_num)
        hidden_states = block(hidden_states, attention_mask=attention_mask, alibi=alibi)[0]
        print(".", end='')
    
    hidden_states = final_lnorm(hidden_states)
    
    #4. Load and use language model head
    lm_head = load_causal_lm_head()
    logits = lm_head(hidden_states)

    # 5. Compute next token 
    return torch.argmax(logits[:, -1, :], dim=-1)

In [None]:
input_sentence = "The SQL command to extract all the users whose name starts with A is: "
input_ids = tokenizer.encode(input_sentence, return_tensors='pt').to(device)
max_tokens = 10
for i in range(max_tokens): 
    print(f"Token {i + 1} ", end='')
    new_id = forward(input_ids)
    input_ids = torch.cat([input_ids, new_id.unsqueeze(-1)], dim=-1)
    print(tokenizer.decode(new_id))

print(tokenizer.decode(input_ids.squeeze(), skip_special_tokens=True))