In [None]:
%load_ext autoreload
%autoreload 2
from ridge_utils.DataSequence import DataSequence
import pandas as pd
import matplotlib.pyplot as plt
from os.path import dirname
import os
from tqdm import tqdm
from neuro.features import qa_questions, feature_spaces
from neuro.data import story_names, response_utils
from neuro.features.stim_utils import load_story_wordseqs, load_story_wordseqs_huge
import neuro.config
import seaborn as sns
import numpy as np
import joblib
from collections import defaultdict
from os.path import join

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from torch.utils.data import Dataset, DataLoader
import decode
# Decoding data lives under <repo>/data/decoding.
decoding_path_parts = (neuro.config.repo_dir, 'data', 'decoding')
data_dir = join(*decoding_path_parts)

# Load every subject, then keep a single subject for the rest of the notebook.
data_by_subject = decode.load_data_by_subject(data_dir)
data = data_by_subject['uts03']

# Optimize prompts: learn a soft prompt that makes the LLM emit a target sentence

In [None]:
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM


class PromptTuningModel(nn.Module):
    """Frozen causal LM with a small learnable continuous ("soft") prompt.

    The sequence fed to the LM is, at the embedding level:

        prefix_text tokens + learned prompt vectors + suffix_text tokens
        + target-sentence tokens

    Only ``prompt_embeddings`` requires gradients; every LM weight is frozen.

    Args:
        model_name: checkpoint name or path passed to ``from_pretrained``.
        prefix_text: fixed text placed before the learned prompt.
        suffix_text: fixed text placed after the learned prompt.
        prompt_length: number of learnable prompt vectors.
    """

    def __init__(
        self, model_name,
            prefix_text='<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nDecode this message:',
            suffix_text='<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nDecoded message:',
            prompt_length=1,
    ):
        super().__init__()
        # Load tokenizer and model.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto")
        # eval() only switches off train-time behaviour such as dropout; the
        # weights are actually frozen by the requires_grad loop below.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        # Everything we build must live where the embedding layer lives.
        embed_layer = self.model.get_input_embeddings()
        embed_device = embed_layer.weight.device
        self.embedding_dim = embed_layer.embedding_dim
        self.prompt_length = prompt_length

        # Learnable continuous prompt of `prompt_length` vectors. It is
        # created directly on the target device: calling `.to(...)` on an
        # nn.Parameter returns a plain non-leaf tensor that is neither
        # registered on the module nor usable by an optimizer.
        self.prompt_embeddings = nn.Parameter(torch.randn(
            prompt_length, self.embedding_dim, device=embed_device))

        # Tokenize the fixed prefix. If the text already spells out the BOS
        # token, do not let the tokenizer prepend a second one.
        self.prefix_text = prefix_text
        bos_token = self.tokenizer.bos_token
        prefix_has_bos = bool(bos_token) and prefix_text.startswith(bos_token)
        self.prefix_ids = self._encode(
            prefix_text, add_special_tokens=not prefix_has_bos).to(embed_device)

        # The suffix sits mid-sequence, so it must never get an automatic BOS.
        self.suffix_text = suffix_text
        self.suffix_ids = self._encode(
            suffix_text, add_special_tokens=False).to(embed_device)

    def _encode(self, text, add_special_tokens):
        """Tokenize ``text`` into a 1-D tensor of token ids."""
        return self.tokenizer.encode(
            text, return_tensors="pt",
            add_special_tokens=add_special_tokens).squeeze(0)

    def forward(self, target_sentence):
        """Compute the LM loss of ``target_sentence`` given the soft prompt.

        The input is prefix tokens + learned prompt + suffix tokens + target
        tokens; labels are -100 everywhere except the target positions, so
        the loss covers the target sentence only.

        Args:
            target_sentence: text the model should produce after the suffix.

        Returns:
            ``(loss, logits)`` taken from the wrapped model's output.

        Raises:
            ValueError: if ``target_sentence`` tokenizes to zero tokens.
        """
        embed_layer = self.model.get_input_embeddings()
        embed_device = embed_layer.weight.device

        # The target follows the suffix, so no automatic BOS here either.
        target_ids = self._encode(
            target_sentence, add_special_tokens=False).to(embed_device)
        if target_ids.numel() == 0:
            raise ValueError("target_sentence produced no tokens")

        # Each embedding below has shape [1, seq_len, emb_dim].
        prefix_embeds = embed_layer(self.prefix_ids.to(embed_device).unsqueeze(0))
        suffix_embeds = embed_layer(self.suffix_ids.to(embed_device).unsqueeze(0))
        target_embeds = embed_layer(target_ids.unsqueeze(0))

        # Add a batch dimension and match the LM's device/dtype.
        learned_prompt = self.prompt_embeddings.unsqueeze(0).to(
            device=embed_device, dtype=prefix_embeds.dtype)

        # [prefix] + [learned prompt] + [suffix] + [target]
        inputs_embeds = torch.cat(
            [prefix_embeds, learned_prompt, suffix_embeds, target_embeds], dim=1)

        # No padding is involved, so every position is attended to.
        attention_mask = torch.ones(
            inputs_embeds.shape[:-1], dtype=torch.long, device=embed_device)

        # Mask out the context so only target tokens contribute to the loss.
        context_len = inputs_embeds.shape[1] - target_ids.shape[0]
        labels = torch.full(
            inputs_embeds.shape[:-1], -100, dtype=torch.long, device=embed_device)
        labels[0, context_len:] = target_ids

        outputs = self.model(inputs_embeds=inputs_embeds,
                             attention_mask=attention_mask,
                             labels=labels)
        return outputs.loss, outputs.logits

# Example usage: tune a soft prompt so the frozen LLM reproduces one sentence.


# adjust to your model checkpoint as needed
# model_name = "meta-llama/Meta-Llama-3-8B"
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # updated model checkpoint
target_sentence = " hello<|eot_id|>"  # the sentence to be decoded by the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize our prompt tuning model.
pt_model = PromptTuningModel(model_name)  # .to(device)
# Make sure the prompt is a leaf Parameter the optimizer can update
# (harmless if the class already provides one).
pt_model.prompt_embeddings = nn.Parameter(
    pt_model.prompt_embeddings.detach().clone())

# Set up an optimizer to update only the prompt embeddings.
optimizer = torch.optim.Adam([pt_model.prompt_embeddings], lr=1e-4)


def generate_from_prompt(pt_model, max_new_tokens):
    """Greedy/sampled continuation after prefix + learned prompt + suffix.

    The context mirrors what ``PromptTuningModel.forward`` trains on (the
    suffix is included), so the generated text is directly comparable to the
    training target.

    Args:
        pt_model: a ``PromptTuningModel`` instance.
        max_new_tokens: number of tokens to generate.

    Returns:
        The decoded continuation with special tokens stripped.
    """
    embed_layer = pt_model.model.get_input_embeddings()
    embed_device = embed_layer.weight.device
    with torch.no_grad():
        prefix_embeds = embed_layer(
            pt_model.prefix_ids.to(embed_device).unsqueeze(0))
        suffix_embeds = embed_layer(
            pt_model.suffix_ids.to(embed_device).unsqueeze(0))
        learned_prompt = pt_model.prompt_embeddings.unsqueeze(0).to(
            device=embed_device, dtype=prefix_embeds.dtype)
        inputs_embeds = torch.cat(
            [prefix_embeds, learned_prompt, suffix_embeds], dim=1)
        attention_mask = torch.ones(
            inputs_embeds.shape[:-1], dtype=torch.long, device=embed_device)

        # Generate text (note: adjust generation parameters as needed).
        generated_ids = pt_model.model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens
        )
    return pt_model.tokenizer.decode(
        generated_ids[0], skip_special_tokens=True)


# Training loop: optimize the continuous prompt so that the model's output
# (given prefix + learned prompt + suffix) reproduces the target sentence.
num_steps = 1000
for step in range(num_steps):
    optimizer.zero_grad()
    loss, _ = pt_model(target_sentence)
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss.item():.4f}")
        # example generation
        print(repr(generate_from_prompt(pt_model, max_new_tokens=5)))

# After training, test the learned prompt via generation.
generated_text = generate_from_prompt(pt_model, max_new_tokens=10)
print(repr(generated_text))

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Instruct checkpoint used for the chat-template experiments below.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Weights are loaded in bfloat16 and placed automatically across devices.
load_kwargs = dict(torch_dtype=torch.bfloat16, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

In [None]:
# Render a user turn plus an assistant turn through the chat template to
# inspect the exact special-token layout the Instruct tokenizer produces.
messages = [
    # {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Decode this message:"},
    {"role": "assistant", "content": "Decoded message:"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)
# A bare expression in the middle of a cell is not displayed, so print the
# rendered template explicitly.
# NOTE(review): with add_generation_prompt=True after an assistant turn the
# template presumably closes that turn and opens a new assistant header
# rather than continuing "Decoded message:" -- check the printed output.
print(tokenizer.batch_decode(input_ids))


# Hand-written prefix/suffix in the same special-token format.
prefix = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nDecode this message:'
suffix = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nDecoded message:'

In [None]:


# Stop generation at either the tokenizer's EOS id or the <|eot_id|> token.
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
terminators = [tokenizer.eos_token_id, eot_id]

# Nucleus sampling settings for the reply.
sampling_kwargs = dict(
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
outputs = model.generate(input_ids, **sampling_kwargs)

# Keep only the tokens produced after the prompt.
prompt_token_count = input_ids.shape[-1]
response = outputs[0][prompt_token_count:]
print(tokenizer.decode(response, skip_special_tokens=True))

# E2E: train a mapper from input vectors to soft prompts

In [None]:
class PromptMapper(nn.Module):
    """Two-layer MLP turning an input vector into a soft prompt.

    Each input row of size ``input_dim`` becomes a tensor of shape
    ``(prompt_length, hidden_size)``.
    """

    def __init__(self, input_dim, prompt_length, hidden_size):
        super().__init__()
        self.prompt_length = prompt_length
        self.hidden_size = hidden_size

        # Width of the flattened prompt, shared by both linear layers.
        flat_prompt_dim = hidden_size * prompt_length
        self.fc = nn.Sequential(
            nn.Linear(input_dim, flat_prompt_dim),
            nn.ReLU(),
            nn.Linear(flat_prompt_dim, flat_prompt_dim),
        )

    def forward(self, x):
        """Map ``x`` (batch_size, input_dim) to (batch_size, prompt_length, hidden_size)."""
        flat_prompt = self.fc(x)
        return flat_prompt.view(-1, self.prompt_length, self.hidden_size)


class PromptDataset(Dataset):
    """Pairs each input vector with its tokenized target sentence.

    Args:
        vectors: indexable collection of feature vectors; its length defines
            the dataset length.
        sentences: indexable collection of strings, read at the same index
            as ``vectors``.
        tokenizer: callable tokenizer returning an object with ``input_ids``
            and ``attention_mask`` tensors.
        max_length: fixed token length sentences are padded/truncated to.
        device: device the returned tensors are moved to. Defaults to
            ``'cuda'`` (the previous hard-coded behaviour); pass ``None`` to
            leave the tensors where they were created.
    """

    def __init__(self, vectors, sentences, tokenizer, max_length=50, device='cuda'):
        self.vectors = vectors
        self.sentences = sentences
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device

    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, idx):
        """Return ``(vector, input_ids, attention_mask)`` for example ``idx``."""
        vec = torch.tensor(self.vectors[idx])
        # Pad/truncate to a fixed length so examples can be stacked in a batch.
        tokenized = self.tokenizer(self.sentences[idx],
                                   return_tensors="pt",
                                   padding="max_length",
                                   max_length=self.max_length,
                                   truncation=True)
        # Drop the leading batch dimension: both become (max_length,).
        input_ids = tokenized.input_ids.squeeze(0)
        attention_mask = tokenized.attention_mask.squeeze(0)
        if self.device is None:
            return vec, input_ids, attention_mask
        return (vec.to(self.device),
                input_ids.to(self.device),
                attention_mask.to(self.device))

In [None]:
# data
n = 32
vectors = data['df_train'].values.astype(np.float32)
sentences = data['texts_train'].values.tolist()
vectors = vectors[:n]  # the dataset length follows `vectors`

# pre-trained model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", device_map="auto",)
model.eval()
# Freeze the LLM so only the mapping network is updated.
for param in model.parameters():
    param.requires_grad = False
embed_layer = model.get_input_embeddings()
embed_device = embed_layer.weight.device

# model params
d = vectors.shape[1]  # input dimension taken from the data, not hard-coded
prompt_length = 6
batch_size = 32
hidden_size = model.config.hidden_size
mapper = PromptMapper(d, prompt_length, hidden_size).to(embed_device)

tokenizer.pad_token = tokenizer.eos_token
dataset = PromptDataset(vectors, sentences, tokenizer)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(mapper.parameters(), lr=1e-4)

num_epochs = 10
# Longest sentence in characters. Kept because later cells read `max_length`;
# the preview generation below is bounded in tokens instead.
max_length = max([len(s) for s in sentences])

for epoch in range(num_epochs):
    for vecs, input_ids, attention_mask in tqdm(dataloader):
        vecs = vecs.to(embed_device)
        input_ids = input_ids.to(embed_device)
        attention_mask = attention_mask.to(embed_device)
        # Separate name so the configured `batch_size` is not clobbered.
        cur_batch_size = input_ids.shape[0]

        # soft_prompt: (cur_batch_size, prompt_length, hidden_size)
        soft_prompt = mapper(vecs)

        # (cur_batch_size, seq_len, hidden_size)
        token_embeddings = embed_layer(input_ids)

        # New input: [soft prompt tokens] + [tokenized sentence]
        inputs_embeds = torch.cat([soft_prompt, token_embeddings], dim=1)

        # The soft prompt positions are always attended to.
        prompt_mask = torch.ones(
            (cur_batch_size, prompt_length),
            dtype=attention_mask.dtype, device=attention_mask.device)
        extended_attention_mask = torch.cat(
            [prompt_mask, attention_mask], dim=1)

        # Labels: -100 on the soft prompt AND on padding, so the loss only
        # covers real sentence tokens.
        sentence_labels = input_ids.masked_fill(attention_mask == 0, -100)
        prompt_labels = torch.full(
            (cur_batch_size, prompt_length), -100,
            dtype=input_ids.dtype, device=input_ids.device)
        labels = torch.cat([prompt_labels, sentence_labels], dim=1)

        # Forward pass through the model.
        outputs = model(inputs_embeds=inputs_embeds,
                        attention_mask=extended_attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print decoded sentence examples
        if epoch % 5 == 0:
            with torch.no_grad():
                # Condition on soft prompt + BOS only. Feeding the sentence
                # embeddings here would hand the model the answer.
                bos_ids = torch.full(
                    (cur_batch_size, 1), tokenizer.bos_token_id,
                    dtype=input_ids.dtype, device=embed_device)
                gen_embeds = torch.cat(
                    [mapper(vecs), embed_layer(bos_ids)], dim=1)
                gen_mask = torch.ones(
                    gen_embeds.shape[:-1],
                    dtype=attention_mask.dtype, device=embed_device)
                generated = model.generate(
                    inputs_embeds=gen_embeds,
                    attention_mask=gen_mask,
                    max_new_tokens=dataset.max_length,
                    min_new_tokens=4,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id)
                # With `inputs_embeds` as the only prompt, generate() returns
                # just the new tokens, so nothing is sliced off.
                decoded_sentences = tokenizer.batch_decode(
                    generated, skip_special_tokens=True)
                # The batch is shuffled, so recover the ground truth from
                # the batch itself rather than from `sentences[i]`.
                gt_sentences = tokenizer.batch_decode(
                    input_ids, skip_special_tokens=True)
                for i in range(min(3, len(decoded_sentences))):
                    print(f"GT sentence {i+1}: {gt_sentences[i]}")
                    print(f"Decoded sentence {i+1}: {decoded_sentences[i]}")
                    print()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item():.4f}")

    # save the model
    torch.save(mapper.state_dict(), f"mapper_epoch_{epoch+1}.pth")

In [None]:
# Sample continuations using whatever `inputs_embeds` and
# `extended_attention_mask` the previous cell left in the namespace.
sample_kwargs = dict(
    attention_mask=extended_attention_mask,
    max_length=max_length,
    min_length=20,
    do_sample=True,
)
model.generate(inputs_embeds=inputs_embeds, **sample_kwargs)

In [None]:
# -----------------------------
# 5. Inference Example
# -----------------------------
# To use the learned mapping:
# Given a new vector, map it to a continuous prompt and generate the corresponding sentence.
mapper.eval()
mapper_device = next(mapper.parameters()).device
with torch.no_grad():
    vecs_test = data['df_test'].values.astype(np.float32)
    # The reference text must come from the same split as the vector.
    # NOTE(review): assumes `data` has a 'texts_test' entry mirroring
    # 'texts_train' -- confirm against decode.load_data_by_subject.
    sentences_test = data['texts_test'].values.tolist()
    new_vector = torch.tensor(vecs_test[0]).unsqueeze(
        0).to(mapper_device)  # shape (1, d)

    # new_vector = torch.tensor(np.random.randn(d).astype(
    # np.float32)).unsqueeze(0)  # shape (1, d)
    soft_prompt = mapper(new_vector)  # (1, prompt_length, hidden_size)

    # Follow the soft prompt with a single beginning-of-sequence token and
    # let the model generate from there.
    dummy_input = torch.tensor(
        [[tokenizer.bos_token_id]], device=soft_prompt.device)
    token_embeddings = model.get_input_embeddings()(dummy_input)

    inputs_embeds = torch.cat([soft_prompt, token_embeddings], dim=1)

    prompt_mask = torch.ones(
        (1, prompt_length), dtype=torch.long, device=soft_prompt.device)
    dummy_mask = torch.ones_like(dummy_input)
    extended_attention_mask = torch.cat([prompt_mask, dummy_mask], dim=1)

    # Generate output tokens (adjust generation parameters as needed)
    generated_ids = model.generate(inputs_embeds=inputs_embeds,
                                   attention_mask=extended_attention_mask,
                                   max_length=max_length + prompt_length)

    # With `inputs_embeds` as the only prompt, generate() returns just the
    # newly generated tokens, so decode all of them; slicing off
    # `prompt_length` ids would discard the start of the sentence.
    output_text = tokenizer.decode(
        generated_ids[0], skip_special_tokens=True)
    print('Original text:', sentences_test[0])
    print("Generated sentence:", output_text)

In [None]:
# Decode the first sequence in `generated_ids` (left by an earlier cell);
# `skip_special_tokens` is not passed here, unlike the decode calls above.
tokenizer.decode(generated_ids[0])