In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def print_param_precision(model):
    """Print a per-dtype breakdown of the parameter elements in ``model``.

    Walks ``model.named_parameters()`` and accumulates ``numel()`` per
    ``dtype``. One line is printed per dtype, in the order each dtype is first
    encountered, showing the dtype, its element count in millions, and its
    percentage of all parameter elements.
    """
    elements_by_dtype = {}
    for _name, param in model.named_parameters():
        elements_by_dtype[param.dtype] = (
            elements_by_dtype.get(param.dtype, 0) + param.numel()
        )

    grand_total = sum(elements_by_dtype.values())

    for dtype, count in elements_by_dtype.items():
        millions = count / 10**6
        share = count / grand_total * 100
        print(f"{dtype}, {millions:.4f} M, {share:.2f} %")

def print_parameters(model):
    """Print the total number of parameter elements in ``model``, in millions.

    The total is the sum of ``numel()`` over everything yielded by
    ``model.parameters()``.
    """
    element_count = 0
    for param in model.parameters():
        element_count += param.numel()

    millions = element_count / 10**6
    print(f"Total parameters: {millions:.4f} M")

In [None]:
# Name or path of the merged checkpoint handed to from_pretrained.
merged_model = "merged_bloom-1b1"
# Place everything on the first CUDA device when one is available, otherwise on the CPU.
device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
# No torch_dtype is given, so from_pretrained uses its default precision.
# To load in half precision instead, add torch_dtype=torch.float16 to this call.
ft_model = AutoModelForCausalLM.from_pretrained(merged_model, device_map=device_map)
# device_map is a model-placement argument; a tokenizer has no device placement,
# so it is loaded without it. The encoded tensors are moved to the device at call time.
ft_tokenizer = AutoTokenizer.from_pretrained(merged_model)
print(f"{device_map} Memory Used: {ft_model.get_memory_footprint() / 1024**2:.4f} MB")
print("\nParameters:")
print_parameters(ft_model)
print("\nData types:")
print_param_precision(ft_model)

In [None]:
# Alternative task to try:
#mytask="CREATE TABLE trip (bus_stop VARCHAR, duration INTEGER), list all the bus stops from which a trip of duration below 100 started."
mytask = (
    "CREATE TABLE book (Title VARCHAR, Writer VARCHAR). "
    "What are the titles of the books whose writer is not Dennis Lee?"
)

# The prompt starts and ends with a newline: the empty first and last entries
# produce the leading and trailing "\n" when the lines are joined.
prompt = "\n".join(
    [
        "",
        "# Instruction:",
        "Use the context below to produce the result",
        "# context:",
        mytask,
        "# result:",
        "",
    ]
)

In [None]:
# Encode the prompt as a batch of one sequence and move it to the model's device.
input_id1 = ft_tokenizer.encode(prompt, return_tensors="pt").to(device_map)
# Single unpadded sequence, so every position is attended to. ones_like keeps
# the mask on the same device as the ids; ids are assumed int64, matching the
# explicit long dtype used before (TODO confirm).
attention_mask1 = torch.ones_like(input_id1)
print("--------------------------------------")
print(f"Prompt:\n{prompt}")
print("--------------------------------------")

print("Fine-tuned Model Result :\n")
# Sampling is enabled, so the completion differs from run to run unless a seed is set beforehand.
output_ft = ft_model.generate(
    input_ids=input_id1,
    attention_mask=attention_mask1,
    do_sample=True,
    max_new_tokens=100,
    top_p=0.9,
    temperature=0.5,
)
# A causal LM returns the prompt tokens followed by the newly generated ones.
# Cut at the prompt's token length and decode only the continuation; slicing
# the decoded text by len(prompt) breaks whenever decoding does not reproduce
# the prompt character for character.
prompt_token_count = input_id1.shape[-1]
generated_tokens = output_ft[0, prompt_token_count:]
print(ft_tokenizer.decode(generated_tokens, skip_special_tokens=True))
print("--------------------------------------")