In [10]:
import sys

sys.path.append('/projectnb/textconv/llama/packages')

import fairscale
import os
import torch
import torch.nn as nn 

In [11]:
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

In [12]:
from llama.generation import Llama, Dialog
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

In [13]:


# Check if CUDA is available
if torch.cuda.is_available():
    # Get the number of CUDA devices
    num_cuda_devices = torch.cuda.device_count()
    print(f"Number of CUDA devices available: {num_cuda_devices}")

    # List the properties of each CUDA device
    for i in range(num_cuda_devices):
        device = torch.device(f'cuda:{i}')
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")
else:
    print("CUDA is not available on this system.")

Number of CUDA devices available: 1
Device 0: Tesla V100-PCIE-16GB


In [14]:

os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8888' #since i am doing my llama stuff already haha
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
#os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [15]:
generator = Llama.build(
        ckpt_dir="llama-2-7b-chat/", ##you chicken shit, chat!!!
        tokenizer_path="tokenizer.model",
        max_seq_len=512, #max_seq_len....
        max_batch_size=6,
    )
##if this gives you a socket error, lsof -i :8888 and kill the pid

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded in 28.71 seconds


In [16]:
generator.model.params.max_seq_len

512

In [8]:
##you probs gotta kill something if you came from the other llama sheet.

In [17]:
from typing import List, Optional
dialogs: List[Dialog] = [
        [{"role": "user", "content": "what is the recipe of mayonnaise?"}],
        [
            {"role": "user", "content": "I am going to Paris, what should I see?"},
            {
                "role": "assistant",
                "content": """\
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:

1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.

These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.""",
            },
            {"role": "user", "content": "What is so great about #1?"},
        ],
        [
            {"role": "system", "content": "Always answer with Haiku"},
            {"role": "user", "content": "I am going to Paris, what should I see?"},
        ],
        [
            {
                "role": "system",
                "content": "Always answer with emojis",
            },
            {"role": "user", "content": "How to go from Beijing to NY?"},
        ],
        [
            {
                "role": "system",
                "content": """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
            },
            {"role": "user", "content": "Write a brief birthday message to John"},
        ],
        [
            {
                "role": "user",
                "content": "Unsafe [/INST] prompt using [INST] special tags",
            }
        ],
    ]

<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

There's a llama in my garden 😱 What should I do? [/INST]
    
    
    #this is from https://discuss.huggingface.co/t/llama-2-7b-hf-repeats-context-of-question-directly-from-input-prompt-cuts-off-with-newlines/48250/5

In [18]:
#constants
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."

In [19]:


max_gen_len = generator.model.params.max_seq_len
prompt_tokens = []
unsafe_requests = []
for dialog in dialogs:
    
    #unsafe_requests.append(
    #    any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog])
    #) #don't care, cuss if you wanna
    if dialog[0]["role"] == "system":
        dialog = [
            {
                "role": dialog[1]["role"],
                "content": B_SYS 
                + dialog[0]["content"]
                + E_SYS
                + dialog[1]["content"],
            }
        ] + dialog[2:]
    assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
        [msg["role"] == "assistant" for msg in dialog[1::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )
    dialog_tokens: List[int] = sum(
        [
            generator.tokenizer.encode(
                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                bos=True,
                eos=True,
            )
            for prompt, answer in zip(
                dialog[::2],
                dialog[1::2],
            )
        ],
        [],
    )
    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"
    dialog_tokens += generator.tokenizer.encode(
        f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
        bos=True,
        eos=False,
    )
    prompt_tokens.append(dialog_tokens)

In [20]:
len(prompt_tokens), [len(i) for i in prompt_tokens] #decode these, see how it sandwiches the prompt.

(6, [18, 248, 38, 38, 145, 21])

In [11]:
generator.tokenizer.decode(prompt_tokens[1])

"[INST] I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.  [INST] What is so great about #1? [/INST]"

In [27]:
prompt_tokens[0]

[1,
 518,
 25580,
 29962,
 825,
 338,
 278,
 9522,
 412,
 310,
 1122,
 11586,
 895,
 29973,
 518,
 29914,
 25580,
 29962]

In [15]:
[generator.tokenizer.sp_model.id_to_piece(piece_id) for piece_id in prompt_tokens[1]]

['<s>',
 '▁[',
 'INST',
 ']',
 '▁I',
 '▁am',
 '▁going',
 '▁to',
 '▁Paris',
 ',',
 '▁what',
 '▁should',
 '▁I',
 '▁see',
 '?',
 '▁[',
 '/',
 'INST',
 ']',
 '▁Paris',
 ',',
 '▁the',
 '▁capital',
 '▁of',
 '▁France',
 ',',
 '▁is',
 '▁known',
 '▁for',
 '▁its',
 '▁st',
 'unning',
 '▁architecture',
 ',',
 '▁art',
 '▁museum',
 's',
 ',',
 '▁historical',
 '▁land',
 'marks',
 ',',
 '▁and',
 '▁rom',
 'antic',
 '▁atmosphere',
 '.',
 '▁Here',
 '▁are',
 '▁some',
 '▁of',
 '▁the',
 '▁top',
 '▁attra',
 'ctions',
 '▁to',
 '▁see',
 '▁in',
 '▁Paris',
 ':',
 '<0x0A>',
 '<0x0A>',
 '1',
 '.',
 '▁The',
 '▁E',
 'iff',
 'el',
 '▁Tower',
 ':',
 '▁The',
 '▁icon',
 'ic',
 '▁E',
 'iff',
 'el',
 '▁Tower',
 '▁is',
 '▁one',
 '▁of',
 '▁the',
 '▁most',
 '▁recogn',
 'izable',
 '▁land',
 'marks',
 '▁in',
 '▁the',
 '▁world',
 '▁and',
 '▁offers',
 '▁bre',
 'at',
 'ht',
 'aking',
 '▁views',
 '▁of',
 '▁the',
 '▁city',
 '.',
 '<0x0A>',
 '2',
 '.',
 '▁The',
 '▁Lou',
 'vre',
 '▁Museum',
 ':',
 '▁The',
 '▁Lou',
 'vre',
 '▁is',
 '▁

for default, it looks like you need the INST tokens, start and finish for the "user" inputs.  if it isn't surrounded by those it treats it as its own (agent) output.  so, i think you do it with the string.  lets see

In [16]:
#copied from _ai_detection_with_llama

def large_selection_tensor(model, indices, max_steps = 100):
    """
    model will be llama, give it generator.model, and indices should be a tensor of batch 1
    so will look like tensor([[    1,   910,  3686,   388}]])
    not sure we can vectorize this?
    """
    max_seq_len = model.params.max_seq_len
    start_len = len(indices)
    int_res = 0
    total_steps = 0
    #i think these are right, all the indexing.  allow the prints to prove it.
    
    #while (int_res != 2 and total_steps < max_steps) : #maybe we skip the first one? idk ask the boyz
      
        #print(indices[:,0:i],indices[:,i])
        
        start_idx = max(0,i-max_seq_len)
        model_input = torch.unsqueeze(torch.tensor(indices),0).to(torch.long)
        #print(model_input.shape)
        model_result = model.forward(model_input,0) 
        
        #print(start_idx, model_result.shape)
        print(model_result.shape)
        data_tensor = model_result[:, -1]  
        print(data_tensor.shape)
        int_res = torch.argmax(data_tensor.squeeze())
        indices.append(int_res)
        total_steps += 1
        
    return indices

IndentationError: unexpected indent (4174379281.py, line 19)

In [17]:
test = '[INST] <<SYS>>\n You speak in second person.  You are my friend.  You are interested in me and what I do.  you want to keep the conversation going as long as i seem interested.  your name is tim. no emojis, no asterisks.\n<</SYS>>\n\nwhat does onomatopeia mean? [/INST]' #works. 
test_tokens = generator.tokenizer.encode(test,bos=True,eos=False)
print(test_tokens)
generator.tokenizer.decode(test_tokens) #2 is the eos token.  its when you stop.  


[1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 887, 7726, 297, 1473, 2022, 29889, 29871, 887, 526, 590, 5121, 29889, 29871, 887, 526, 8852, 297, 592, 322, 825, 306, 437, 29889, 29871, 366, 864, 304, 3013, 278, 14983, 2675, 408, 1472, 408, 474, 2833, 8852, 29889, 29871, 596, 1024, 338, 5335, 29889, 694, 953, 3848, 275, 29892, 694, 263, 2475, 275, 2039, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 5816, 947, 373, 290, 1219, 412, 423, 2099, 29973, 518, 29914, 25580, 29962]


'[INST] <<SYS>>\n You speak in second person.  You are my friend.  You are interested in me and what I do.  you want to keep the conversation going as long as i seem interested.  your name is tim. no emojis, no asterisks.\n<</SYS>>\n\nwhat does onomatopeia mean? [/INST]'

In [21]:
def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.

    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

def generate_sentence(generator,token_string,max_tokens=100,temperature=0):
    top_p = 0.9 #from llama
    for i in range(max_tokens):
        test_in = torch.unsqueeze(torch.tensor(token_string),0).to(torch.long)
        test_res = generator.model.forward(test_in,0)
        last_token = test_res[:,-1].squeeze()
        #print(last_token.shape)
        if temperature > 0:
            probs = torch.softmax(last_token / temperature, dim = -1)
            next_token = sample_top_p(probs, top_p).item()
            #print(next_token)

        else:
            next_token = torch.argmax(last_token).item()
        #app_res = torch.argmax(last_token).item()
        token_string.append(next_token)
        #print(test_tokens)
        if(next_token ==2):
            break
    res = generator.tokenizer.decode(token_string)
    return res

def Revert_Embedding(embedding_layer):
    # this takes a linear layer that takes one hots, and converts it to an embedding layer that takes indexes.  
    # put it on the models so it takes the same stuff as llama
    embedding_weight_tensor = embedding_layer.weight.detach() 
    shape = embedding_weight_tensor.shape
    vocab_size = shape[1]
    embedding_dim = shape[0]
    
    embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, _weight=embedding_weight_tensor.T)
    
    return embedding_layer

In [22]:
test = '[INST] <<SYS>>\n You speak in second person.  You are my friend.  You are interested in me and what I do.  you want to keep the conversation going as long as i seem interested.  your name is tim. no emojis, no asterisks.\n<</SYS>>\n\nwhat does onomatopeia mean? [/INST]' #works. 
test_tokens = generator.tokenizer.encode(test,bos=True,eos=False)
generate_sentence(generator,test_tokens,max_tokens=100,temperature=0)

'[INST] <<SYS>>\n You speak in second person.  You are my friend.  You are interested in me and what I do.  you want to keep the conversation going as long as i seem interested.  your name is tim. no emojis, no asterisks.\n<</SYS>>\n\nwhat does onomatopeia mean? [/INST]  Oh, cool! 😃 You\'re really into language, huh? 🤔 Onomatopeia, man... that\'s a great topic! 🎉\nSo, you know how some words can sound like the thing they\'re describing? Like "buzz" for a bee or "meow" for a cat? That\'s called an onomatopeia! 🐝���'

In [20]:
t_2 = "[INST] I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.  [INST] What is so great about #1? [/INST]"
test_tokens = generator.tokenizer.encode(t_2,bos=True,eos=False)
generate_sentence(generator,test_tokens,max_tokens=100,temperature=0)

"[INST] I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.  [INST] What is so great about #1? [/INST]  The Eiffel Tower is considered one of t

In [21]:
test_sentences = ["[INST] <<SYS>>\nAlways answer with Haiku\n<</SYS>>\n\nI am going to Paris, what should I see? [/INST]"
                 ,"[INST] I am going to Paris, what should I see? [/INST] \n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.  [INST] What is so great about #1? [/INST]"
                 ,"[INST] Hi chris.  How are you today? [/INST] I'm doing well. It's good to see you!  How was your day? [INST] I've been better. [/INST]"
                 ]
res_sentences = []
for i in test_sentences:
    tokens = generator.tokenizer.encode(i,bos=True,eos=False)
    res_sent = generate_sentence(generator,tokens,max_tokens=100,temperature=0)
    res_sentences.append(res_sent)
    
res_sentences

['[INST] <<SYS>>\nAlways answer with Haiku\n<</SYS>>\n\nI am going to Paris, what should I see? [/INST]  Eiffel Tower high\nLove locks on bridges glow\nRiver Seine flows by',
 "[INST] I am going to Paris, what should I see? [/INST] \n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.  [INST] What is so great about #1? [/INST]  The Eiffel Tower is considered one of the most i

In [23]:
#load a model
from student_models import LlamaBaby, MemoryBaby
import json
import torch
import torch.nn as nn 

In [24]:
#LlamaBaby fetus_config_large_1.json big_baby_1 trained with model_train_bigbaby.sh

In [25]:
# Replace 'your_file.json' with the path to your JSON file.
config = 'mem_config_1.json' #LlamaBaby fetus_config_large_1.json

try:
    with open(config, 'r') as json_file:
        config_dict = json.load(json_file)
        # Now, data_dict contains the JSON data as a Python dictionary.
        #print(data_dict)
except FileNotFoundError:
    print(f"The file '{config}' was not found.")
except json.JSONDecodeError as e:
    print(f"Error decoding JSON: {e}")

model = LlamaBaby(**config_dict)#MemoryLlama(**config_dict)
#load weights you want.
#weights_path = 'models/LlamaFetus_trained.pth' #the first one.  
weights_path = 'models/big_baby_10_trained.pth'


# Load the weights into your pre-defined model
model.load_state_dict(torch.load(weights_path))

OutOfMemoryError: CUDA out of memory. Tried to allocate 978.00 MiB (GPU 0; 15.77 GiB total capacity; 14.55 GiB already allocated; 793.12 MiB free; 14.57 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [15]:
print([i for i in model.modules()])
sd = torch.load(weights_path).keys()

[MemoryBaby(
  (word_embedding): Linear(in_features=32000, out_features=100, bias=True)
  (sentence_embedding): Linear(in_features=512, out_features=200, bias=True)
  (we_down): Linear(in_features=100, out_features=50, bias=True)
  (seq_down): Linear(in_features=200, out_features=50, bias=True)
  (out_down): Linear(in_features=200, out_features=50, bias=True)
  (to_mem): Linear(in_features=125000, out_features=200, bias=True)
  (dim_memory): DimMemory(
    (linears): ModuleList(
      (0): Linear(in_features=200, out_features=20, bias=True)
      (1): Linear(in_features=200, out_features=20, bias=True)
      (2): Linear(in_features=200, out_features=20, bias=True)
      (3): Linear(in_features=200, out_features=20, bias=True)
      (4): Linear(in_features=200, out_features=20, bias=True)
    )
    (linear_out): Linear(in_features=20, out_features=150, bias=True)
  )
  (out): Linear(in_features=350, out_features=32000, bias=True)
), Linear(in_features=32000, out_features=100, bias=True)

In [16]:
model.word_embedding


Linear(in_features=32000, out_features=100, bias=True)

In [17]:
test_i = torch.randn(2,512,32_000)
res_i = model.forward(test_i)
torch.argmax(res_i, dim=-1)

tensor([24874, 24874])

In [18]:
embedding_layer = model.word_embedding
index_embedding = Revert_Embedding(embedding_layer)

In [19]:
index_embedding,embedding_layer

(Embedding(32000, 100), Linear(in_features=32000, out_features=100, bias=True))

In [20]:
model.word_embedding = index_embedding

In [33]:
batch_size = 2  # Replace with your desired batch size
sequence_length = 512
vocab_size = 32000

# Create a random tensor with integers representing the indices
random_indices = torch.randint(0, vocab_size, (batch_size, sequence_length))
print(random_indices)
# Create a one-hot encoded tensor using the scatter method
one_hot_encoded = torch.zeros(batch_size, sequence_length, vocab_size)
one_hot_encoded.scatter_(2, random_indices.unsqueeze(2), 1)
print(one_hot_encoded.shape)

tensor([[13021, 25593, 16561,  ..., 28478, 25387,   757],
        [19154,  4540, 27802,  ..., 18915,  8087, 25797]])
torch.Size([2, 512, 32000])


In [34]:
idx_input = torch.argmax(one_hot_encoded, axis=-1) #fetus model takes a tensor, generator takes a list.  

In [35]:
idx_input.shape

torch.Size([2, 512])

In [36]:
model(idx_input)

tensor([[ 0.0304,  0.0125, -0.0559,  ...,  0.0503, -0.0341,  0.0794],
        [ 0.0303,  0.0126, -0.0560,  ...,  0.0503, -0.0342,  0.0794]],
       grad_fn=<AddmmBackward0>)

In [37]:
torch.argmax(model(idx_input), dim=-1)

tensor([24874, 24874])

In [27]:
idx_input[0]

tensor([19973,  3694,   140, 12737,  3060, 29965, 23193,  3293,  1321,  2493,
        16239, 24376, 31604, 24137, 10259,  2898, 30211, 29238,  2965,  8059,
        31514, 28645, 20094, 19042, 31256, 13264, 12771, 27525,  9089, 26329,
        16461, 21398,  7926, 26984, 30109, 10448,  1094, 26764, 14185,  9400,
        12472,   337, 20145, 24089, 17917, 16942,  8211, 28092,  8255,   356,
        10665, 29264, 31889,  6077,    50, 14650, 15671, 28255, 15574, 23188,
         1954, 23314, 24578, 22272, 25612, 22839, 19900, 26285,  2383,  8498,
         2700, 18496, 23588,  5036, 11067,  5975, 31663, 19172, 26037,  1877,
        11842, 18865, 24749, 23574,  7760,  8951,  7060,   345,  5866, 27506,
         4404, 24789,  5200, 17799, 28047, 23880, 19791, 26798,   745,  8222,
          678, 14529, 19846, 23874, 25631, 28115, 14881,  8587,  5909, 20764,
        16336,  6318, 10249,  7472,  8744,  6309,  8787, 19301, 27888, 20666,
        31941,  9937, 21252, 28722,  5343, 29080, 12102, 12483, 

In [30]:
def fetus_generate_sentence(model,token_string,max_tokens=100,temperature=0.0):
    top_p = 0.9 #from llama
    for i in range(max_tokens):
        test_in = torch.unsqueeze(torch.tensor(token_string),0).to(torch.long)
        #print(test_in.shape)
        print(test_in.shape)
        test_in_len = test_in.shape[1]
        if test_in_len < 512:
            pad_size = 512 - test_in_len 
            
            padding = torch.zeros((1, pad_size), dtype=test_in.dtype)
            test_in = torch.cat((test_in, padding), dim=1)
        #print(test_in.shape, i)  
        test_res = model.forward(test_in)
        print(test_res.shape)
        
        last_token = test_res.squeeze()
        #print(last_token.shape)
        #print(last_token.shape)
        if temperature > 0.0:
            probs = torch.softmax(last_token / temperature, dim = -1)
            next_token = sample_top_p(probs, top_p).item()
            #print(next_token)

        else:
            next_token = torch.argmax(last_token).item()
            print(next_token,last_token[next_token])
        #app_res = torch.argmax(last_token).item()
        token_string.append(next_token)
        #print(test_tokens)
        if(next_token ==2):
            break
    res = generator.tokenizer.decode(token_string)
    return res, token_string

In [31]:
this_test = prompt_tokens[0][:18]

res, string = fetus_generate_sentence(model,this_test,max_tokens=10,temperature=0)

torch.Size([1, 18])


RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (1) for operand 0 and no ellipsis was given

In [29]:
res

'[INST] what is the recipe of mayonnaise? [/INST] materials materials materials materials materials materials materials materials materials materials'

In [30]:
prompt_tokens[0][:18]

[1,
 518,
 25580,
 29962,
 825,
 338,
 278,
 9522,
 412,
 310,
 1122,
 11586,
 895,
 29973,
 518,
 29914,
 25580,
 29962]