In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
import torch

In [2]:
# Hand PyTorch's cached-but-unused CUDA blocks back to the driver (useful when
# re-running this notebook in a kernel that already held a model). The call is a
# no-op when CUDA was never initialised, so the guard only makes that explicit.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [3]:
# trust_remote_code=True executes the modelling code shipped in the hub repo;
# only use it with repositories you trust.
model = AutoModelForCausalLM.from_pretrained("apple/OpenELM-270M-Instruct", trust_remote_code=True).to("cuda")
# The tokenizer comes from a different repo than the model (presumably because the
# OpenELM repo does not bundle one). NOTE(review): confirm its vocabulary matches the model's.
tokenizer = AutoTokenizer.from_pretrained("TheBloke/Dolphin-Llama2-7B-GPTQ")

model.eval()
tokenizer.pad_token = tokenizer.eos_token
# filter_logits / generate locate each row's last real token with
# attention_mask.sum(dim=1) and write new tokens at that index, which is only
# correct for right padding. Pin it instead of relying on the tokenizer's default.
tokenizer.padding_side = "right"

In [4]:
def filter_logits(logits, attention_mask, dtype=None):
    """Zero out logits at positions past each row's real-token count.

    For row ``i`` the first ``attention_mask[i].sum()`` positions are kept and all
    later positions are set to 0. The result is trimmed to the longest row's length.
    This matches the contiguous right-padding layout the rest of the notebook assumes.

    Args:
        logits: tensor indexed as ``[batch, position, vocab]``.
        attention_mask: (batch, seq_len) tensor whose row sums give each row's length.
        dtype: dtype of the returned tensor. Defaults to ``torch.get_default_dtype()``,
            which is what the previous implementation produced implicitly.

    Returns:
        Tensor of shape (batch, max_length, vocab) on ``logits.device``, where
        ``max_length`` is the largest row length (0 rows -> shape (0, 0, vocab)).
    """
    if dtype is None:
        dtype = torch.get_default_dtype()

    lengths = attention_mask.sum(dim=1).to(logits.device)
    if lengths.numel() == 0:
        # Empty batch: nothing to keep (the loop-based version crashed in max()).
        return logits.new_zeros((0, 0, logits.size(2)), dtype=dtype)

    max_length = int(lengths.max())
    positions = torch.arange(max_length, device=logits.device)
    keep = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, max_length)

    trimmed = logits[:, :max_length, :].to(dtype)
    # masked_fill is out-of-place, so the caller's logits are never modified.
    return trimmed.masked_fill(~keep.unsqueeze(-1), 0)

In [5]:
def generate(model, tokenizer, context_ids, attention_mask, unique_token_ids, n=50, temperature=0.1, repetition_penalty=1.1):
    """Extend every row of a right-padded batch with one shared, greedily chosen token per step.

    Each step works as follows:

    - Run the model on the whole batch and read the logits at each row's last
      real (attended) position.
    - Mask out tokens not in ``unique_token_ids``, then softmax at ``temperature``.
    - For each row, divide the probability of any token already present in that
      row's attended context by ``repetition_penalty`` (once per distinct token).
    - Take the single highest probability over the WHOLE batch (not per row).
      Write that token into the first padding slot of every row.

    Args:
        model: callable invoked as ``model(input_ids=..., attention_mask=...)``. It must
            return an object with a ``.logits`` tensor indexed as
            ``[batch, position, vocab]``.
        tokenizer: object providing ``decode`` and ``pad_token_id``.
        context_ids: (batch, seq_len) int64 token ids, padded on the right.
        attention_mask: (batch, seq_len) mask with 1 for real tokens and 0 for padding.
            Assumes every row has at least one real token.
        unique_token_ids: token ids the model is allowed to emit.
        n: maximum number of tokens to generate.
        temperature: softmax temperature.
        repetition_penalty: divisor applied to the probabilities of already-seen tokens.

    Returns:
        (batch, seq_len') token ids. The input tensors are not modified. ``seq_len'``
        exceeds ``seq_len`` only if some row ran out of padding slots. In that case the
        batch is widened with ``tokenizer.pad_token_id`` (0 if it is unset).
        Generation stops early once the emitted token decodes to text containing '"'.
    """
    # Only enable mixed precision for CUDA tensors, as torch.cuda.amp.autocast() did.
    with torch.inference_mode(), torch.autocast(device_type=context_ids.device.type, enabled=context_ids.is_cuda):
        # Work on copies so the caller's tensors stay untouched.
        context_ids = context_ids.clone()
        attention_mask = attention_mask.clone()
        rows = torch.arange(context_ids.size(0), device=context_ids.device)
        vocab_mask = None  # built once the vocab size is known from the first forward pass

        for _ in range(n):
            logits = model(input_ids=context_ids, attention_mask=attention_mask).logits
            lengths = attention_mask.sum(dim=1)  # right padding: also the first free slot

            # Logits at each row's last real token; float32 as the old filter_logits copy was.
            last_logits = logits[rows, lengths - 1, :].float()

            if vocab_mask is None:
                vocab_mask = torch.full((last_logits.size(-1),), -float('inf'), dtype=last_logits.dtype, device=last_logits.device)
                vocab_mask[unique_token_ids] = 0
            next_token_probs = F.softmax((last_logits + vocab_mask) / temperature, dim=-1)

            # Per-row repetition penalty. seen[r, t] > 0 iff token t occurs at an
            # attended position of row r; padding positions contribute 0.
            seen = torch.zeros_like(next_token_probs)
            seen.scatter_add_(1, context_ids, attention_mask.to(seen.dtype))
            next_token_probs = torch.where(seen > 0, next_token_probs / repetition_penalty, next_token_probs)

            # NOTE: batch-wide argmax; the winning row's token is given to every row.
            vocab_size = next_token_probs.size(1)
            next_token = (torch.argmax(next_token_probs.view(-1)) % vocab_size).unsqueeze(0)

            # Widen the batch by one padding column if some row has no free slot left.
            if int(lengths.max()) >= context_ids.size(1):
                pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
                batch = context_ids.size(0)
                context_ids = torch.cat([context_ids, context_ids.new_full((batch, 1), pad_id)], dim=1)
                attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((batch, 1))], dim=1)

            context_ids[rows, lengths] = next_token
            attention_mask[rows, lengths] = 1

            if '"' in tokenizer.decode(next_token):
                break

    return context_ids

In [6]:
MAX_LENGTH = 200  # every prompt is padded (and, if longer, truncated) to this many tokens

# The conversation that actually states the answer.
needle = '''human said, "george's birthday is on the 31st of may". agent replied, "okay, i'll remember that".
human said, "when is george's birthday?" agent replied, "on the'''

# A distractor conversation asking the same question without the answer.
haystack = '''human said, "hello". agent replied, "hi, how are you?".
human said, "when is george's birthday?" agent replied, "on the'''

# 99 haystack rows followed by a single needle row (the last row of the batch).
context_texts = [haystack] * 99 + [needle]

encoded_output = tokenizer(context_texts, return_tensors="pt", padding="max_length", truncation=True, max_length=MAX_LENGTH)
encoded_output = {k: v.to("cuda") for k, v in encoded_output.items()}

context_ids = encoded_output['input_ids']
attention_mask = encoded_output['attention_mask']

# generate() writes new tokens into the padding slots. A prompt that fills the whole
# window (possibly because it was truncated, losing the trailing question) leaves no room.
assert int(attention_mask.sum(dim=1).max()) < MAX_LENGTH, "a prompt filled the whole window (possibly truncated)"

# Allowed output vocabulary: every token id occurring anywhere in the padded batch,
# including the pad (= EOS) id. clone() replaces the former no-op torch.cat([...]) copy.
concatenated_ids = context_ids.clone()
unique_token_ids = list(set(concatenated_ids.view(-1).tolist()))

pad_token_id = tokenizer.pad_token_id

In [7]:
# Generate at most 25 tokens for the whole batch, then show row 0 (a haystack row).
with torch.inference_mode():
    output = generate(
        model,
        tokenizer,
        context_ids,
        attention_mask,
        unique_token_ids,
        n=25,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))

human said, "hello". agent replied, "hi, how are you?".
human said, "when is george's birthday?" agent replied, "on the 31st of may".
