# Testing BLOOM for Next Token Prediction 

***

## Imports

In [11]:
from transformers import BloomTokenizerFast, BloomForCausalLM
from datasets import load_dataset
#from tqdm.notebook import tqdm
from tqdm import tqdm
import torch
import os

## Use Pretrained Model

**Load Model and Tokenizer:**

The list of available BLOOM models (checkpoint sizes) can be found here: https://huggingface.co/docs/transformers/model_doc/bloom

In [3]:
# Checkpoint to load from the Hugging Face Hub; other sizes are listed on the page linked above.
model_name = "bloom-560m"
checkpoint = f"bigscience/{model_name}"

# The tokenizer is created with add_prefix_space=True (see the BloomTokenizerFast docs
# for its effect on how the first word of the input is tokenized).
tokenizer = BloomTokenizerFast.from_pretrained(checkpoint, add_prefix_space=True)
model = BloomForCausalLM.from_pretrained(checkpoint)

## Predict Token:

Since BLOOM has not been fine-tuned for any specific task here, we are using the raw pretrained model, so the prediction quality is expected to be limited.

In [5]:
# Run a single forward pass on the prompt. Passing the input ids as `labels` makes
# the model also return a loss for this text (stored in `loss`).
# This is inference only, so autograd is disabled: no computation graph is recorded,
# which saves memory and time.
# NOTE(review): the prompt ends with a space; that space is presumably tokenized on its
# own, which would explain the double space before "John" in the generated text below.
inputs = tokenizer("Hello, my name is ", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
logits = outputs.logits

In [8]:
# Greedy choice: the token id with the highest logit at the last sequence position.
last_position_logits = logits[:, -1, :]
new_id = last_position_logits.argmax(dim=-1)
predicted_token = tokenizer.decode(new_id)

print(f"Predicted Token: {predicted_token}")

Predicted Token:  John


## Predict Tokens Continuously

In [9]:
# Prompts to try with the generation loop below. Change PROMPT_INDEX to pick another
# one instead of editing commented-out lines.
EXAMPLE_PROMPTS = [
    "Hello, my name is ",
    "SQL command for finding persons with names starting with F and are older than 20:",
    "JavaScript code for creating a D3 scatter plot:",
    "Python class representing people with names and age:",
    "Python code to sort an array of integers according to value in ascending order:",
]
PROMPT_INDEX = 0
text_input = EXAMPLE_PROMPTS[PROMPT_INDEX]

In [14]:
# Greedy decoding, one token at a time: run the model on the whole sequence so far,
# take the arg-max token at the last position and append it to the inputs.
# Autograd is disabled because no gradients are needed, and no labels are passed
# because the loss would never be used.
# NOTE(review): the full sequence is re-encoded on every step (no key/value cache), so
# the cost grows with sequence length; `model.generate(**inputs, max_new_tokens=...,
# do_sample=False)` is the cached alternative (note that it may stop early at EOS).
inputs = tokenizer(text_input, return_tensors="pt")
num_tokens = 10

with torch.no_grad():
    for _ in tqdm(range(num_tokens)):
        outputs = model(**inputs)
        logits = outputs.logits
        new_id = torch.argmax(logits[:, -1, :], dim=-1)
        # new_id has one entry per batch row; unsqueeze to (batch, 1) so it can be
        # concatenated along the sequence dimension, keeping its dtype and device.
        inputs["input_ids"] = torch.cat((inputs["input_ids"], new_id.unsqueeze(-1)), dim=1)
        # Mark the appended token as attended; new_ones keeps the mask's dtype and device.
        new_mask = inputs["attention_mask"].new_ones((new_id.shape[0], 1))
        inputs["attention_mask"] = torch.cat((inputs["attention_mask"], new_mask), dim=1)

text_output = tokenizer.decode(inputs["input_ids"][0])

100%|██████████| 10/10 [00:03<00:00,  2.69it/s]


In [15]:
# Show the prompt followed by the greedily generated continuation.
print(f"{text_output}")

Hello, my name is  John.
I am a student at the University of
