In [None]:
%load_ext autoreload
%autoreload 2
from ridge_utils.DataSequence import DataSequence
import pandas as pd
import matplotlib.pyplot as plt
from os.path import dirname
import os
from tqdm import tqdm
from neuro.features import qa_questions, feature_spaces
from neuro.data import story_names, response_utils
from neuro.features.stim_utils import load_story_wordseqs, load_story_wordseqs_huge
import neuro.config
import seaborn as sns
import numpy as np
import joblib
from collections import defaultdict
from os.path import join

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from torch.utils.data import Dataset, DataLoader
import decode
data_dir = join(neuro.config.repo_dir, 'data', 'decoding')

In [10]:
class PromptMapper(nn.Module):
    def __init__(self, input_dim, prompt_length, hidden_size):
        super(PromptMapper, self).__init__()
        # A simple MLP that outputs a flattened prompt vector which is reshaped into (prompt_length, hidden_size)
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_size * prompt_length),
            nn.ReLU(),
            nn.Linear(hidden_size * prompt_length, hidden_size * prompt_length)
        )
        self.prompt_length = prompt_length
        self.hidden_size = hidden_size

    def forward(self, x):
        # x is of shape (batch_size, d)
        out = self.fc(x)  # (batch_size, prompt_length * hidden_size)
        out = out.view(-1, self.prompt_length, self.hidden_size)
        return out


class PromptDataset(Dataset):
    def __init__(self, vectors, sentences, tokenizer, max_length=50):
        self.vectors = vectors
        self.sentences = sentences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, idx):
        vec = torch.tensor(self.vectors[idx])
        sentence = self.sentences[idx]
        # Tokenize the sentence. We use padding/truncation to a fixed length.
        tokenized = self.tokenizer(sentence,
                                   return_tensors="pt",
                                   padding="max_length",
                                   max_length=self.max_length,
                                   truncation=True)
        input_ids = tokenized.input_ids.squeeze(
            0)        # Shape: (max_length,)
        attention_mask = tokenized.attention_mask.squeeze(
            0)  # Shape: (max_length,)
        return vec.to('cuda'), input_ids.to('cuda'), attention_mask.to('cuda')

In [None]:
# data
n = 32
data_by_subject = decode.load_data_by_subject(data_dir)
data = data_by_subject['uts03']
vectors = data['df_train'].values.astype(np.float32)
sentences = data['texts_train'].values.tolist()
vectors = vectors[:n]

# pre-trained model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", device_map="auto",)
model.eval()  # We'll freeze the model so we only update the mapping network

# model params
d = 800
prompt_length = 6
batch_size = 32
hidden_size = model.config.hidden_size
mapper = PromptMapper(d, prompt_length, hidden_size).to('cuda')
for param in model.parameters():
    param.requires_grad = False

tokenizer.pad_token = tokenizer.eos_token
dataset = PromptDataset(vectors, sentences, tokenizer)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(mapper.parameters(), lr=1e-4)

# We'll compute the loss only for the target sentence tokens.
# To do so, we create labels that are -100 for the soft prompt positions.
num_epochs = 10
max_length = max([len(s) for s in sentences])  # from our dataset

for epoch in range(num_epochs):
    for vecs, input_ids, attention_mask in tqdm(dataloader):
        batch_size = input_ids.shape[0]

        # Get the soft prompt embeddings from our mapper network.
        # soft_prompt: (batch_size, prompt_length, hidden_size)
        soft_prompt = mapper(vecs)

        # Obtain the input token embeddings from the LLaMA model.
        token_embeddings = model.get_input_embeddings()(
            input_ids)  # (batch_size, max_length, hidden_size)

        # Concatenate the soft prompt with the token embeddings.
        # New input: [soft prompt tokens] + [tokenized sentence]
        inputs_embeds = torch.cat(
            [soft_prompt, token_embeddings], dim=1)  # .to('cuda')

        # Adjust the attention mask: add ones for the soft prompt positions.
        prompt_mask = torch.ones(
            (batch_size, prompt_length), dtype=attention_mask.dtype, device=attention_mask.device)
        extended_attention_mask = torch.cat(
            [prompt_mask, attention_mask], dim=1)

        # Create labels: we ignore the soft prompt positions (set them to -100) so that loss is only computed for the sentence tokens.
        labels = torch.cat([
            torch.full((batch_size, prompt_length), -100,
                       dtype=input_ids.dtype, device=input_ids.device),
            input_ids
        ], dim=1)

        # Forward pass through the model.
        outputs = model(inputs_embeds=inputs_embeds,
                        attention_mask=extended_attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print decoded sentence examples
        if epoch % 5 == 0:
            with torch.no_grad():
                soft_prompt = mapper(vecs)
                inputs_embeds = torch.cat(
                    [soft_prompt, token_embeddings], dim=1)
                outputs = model.generate(
                    inputs_embeds=inputs_embeds,
                    attention_mask=extended_attention_mask,
                    max_length=max_length,
                    min_length=4,
                    do_sample=False)
                decoded_sentences = tokenizer.batch_decode(outputs[:, prompt_length:],
                                                           skip_special_tokens=True)
                for i in range(min(3, len(decoded_sentences))):
                    print(f"GT sentence {i+1}: {sentences[i]}")
                    print(f"Decoded sentence {i+1}: {decoded_sentences[i]}")
                    print()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item():.4f}")

    # save the model
    torch.save(mapper.state_dict(), f"mapper_epoch_{epoch+1}.pth")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=extended_attention_mask,
    max_length=max_length,
    min_length=20,
    do_sample=True)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


tensor([[128001],
        [128001],
        [128001],
        [128001],
        [128001],
        [128001],
        [128001],
        [128001],
        [128001],
        [128001],
        [128001],
        [128001]], device='cuda:0')

In [None]:
# -----------------------------
# 5. Inference Example
# -----------------------------
# To use the learned mapping:
# Given a new vector, map it to a continuous prompt and generate the corresponding sentence.
mapper.eval()
with torch.no_grad():
    vecs_test = data['df_test'].values.astype(np.float32)
    sentences_test = data['texts_train'].values.tolist()
    new_vector = torch.tensor(vecs_test[0]).unsqueeze(
        0).to('cuda')  # shape (1, d)

    # new_vector = torch.tensor(np.random.randn(d).astype(
    # np.float32)).unsqueeze(0)  # shape (1, d)
    soft_prompt = mapper(new_vector)  # (1, prompt_length, hidden_size)

    # Start with an empty input (or a beginning-of-sentence token, if desired).
    # Here we assume the model will generate based on the soft prompt.
    dummy_input = torch.tensor(
        [[tokenizer.bos_token_id]], device=soft_prompt.device)
    token_embeddings = model.get_input_embeddings()(dummy_input)

    inputs_embeds = torch.cat([soft_prompt, token_embeddings], dim=1)

    prompt_mask = torch.ones(
        (1, prompt_length), dtype=torch.long, device=soft_prompt.device)
    dummy_mask = torch.ones_like(dummy_input)
    extended_attention_mask = torch.cat([prompt_mask, dummy_mask], dim=1)

    # Generate output tokens (adjust generation parameters as needed)
    generated_ids = model.generate(inputs_embeds=inputs_embeds,
                                   attention_mask=extended_attention_mask,
                                   max_length=max_length + prompt_length)

    # Decode the generated tokens (skipping the soft prompt positions)
    # Since our soft prompt is continuous, the generated text typically follows after the dummy token.
    output_text = tokenizer.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True)
    print('Original text:', sentences_test[0])
    print("Generated sentence:", output_text)

In [None]:
tokenizer.decode(generated_ids[0])