In [1]:
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, BartForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the first GPU when CUDA is available; otherwise fall back to CPU so the
# notebook still runs end-to-end (slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Generation below uses do_sample=True (and sgt.sample is presumably stochastic),
# so seed torch once up front to make Restart & Run All reproducible.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# Base causal LM; its token embeddings are routed through the SGT in `LLM.forward`.
model_name = "Qwen/Qwen2.5-1.5B"

base_llm = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)

llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token for padding (presumably so padded batches have a pad id).
llm_tokenizer.pad_token = llm_tokenizer.eos_token

In [4]:
# Seq2seq backbone of the inversion model (hidden states -> text).
inv_model_name = "facebook/bart-base"

inv_base_model = BartForConditionalGeneration.from_pretrained(inv_model_name, device_map=device)
# Tokenizer of the inversion backbone; used to decode the ids it generates.
inv_tokenizer = AutoTokenizer.from_pretrained(inv_model_name)

In [5]:
from train.sgt_model import SGTModel

# NOTE(review): absolute, machine-specific path — point this at your own checkpoint.
SGT_CHECKPOINT = '/home/alex/research/stained-glass-transform-pytorch/train/checkpoints/best_sgt.pt'

# Positional args must match the configuration the checkpoint was trained with.
# 1536 is presumably the embedding width of the base LLM — TODO confirm against SGTModel.
sgt = SGTModel(1536, 8, 2, 1, None, None, None, None).to(device)

# map_location: materialize tensors directly on `device`, whatever device they were saved from.
# weights_only: the file is a plain state dict, so refuse to unpickle arbitrary objects.
sgt.load_state_dict(torch.load(SGT_CHECKPOINT, map_location=device, weights_only=True))

mu_init_weight is None - skipping mu_head.weight initialization
mu_init_bias is None - skipping mu_head.bias initialization
logvar_init_weight is None - skipping logvar_head.weight initialization
logvar_init_bias is None - skipping logvar_head.bias initialization


<All keys matched successfully>

In [6]:
class LLM(torch.nn.Module):
    """Causal LM whose input token embeddings are first passed through an SGT.

    Args:
        sgt: module exposing ``sample(embeds, attention_mask=...)`` that returns a
            3-tuple; the first element is used as the new input embeddings.
        base_llm: Hugging Face causal LM that accepts ``inputs_embeds``.
    """

    def __init__(self, sgt, base_llm):
        super().__init__()

        self.sgt = sgt
        self.base_llm = base_llm

    def forward(self, input_ids, attention_mask, **kwargs):
        """Embed ``input_ids``, transform them with the SGT, then run the base LLM.

        Extra keyword arguments are accepted but NOT forwarded to ``base_llm`` —
        presumably so a batch dict carrying keys meant for other models can be
        splatted in directly.

        Returns:
            The output of ``base_llm`` computed with ``output_hidden_states=True``.
        """
        # get_input_embeddings() is the model-agnostic accessor for the token
        # embedding layer, rather than relying on Qwen's `.model.embed_tokens` layout.
        embeds = self.base_llm.get_input_embeddings()(input_ids)
        # Only the transformed embeddings are used; the other two outputs of
        # `sample` are discarded here.
        embeds, _, _ = self.sgt.sample(embeds, attention_mask=attention_mask)

        return self.base_llm(inputs_embeds=embeds, attention_mask=attention_mask, output_hidden_states=True)

In [7]:
from peft import LoraConfig, get_peft_model


class InversionModel(torch.nn.Module):
    """LoRA-adapted seq2seq model that decodes text from external feature vectors.

    A linear projection maps features of width ``input_features_d`` to the
    encoder embedding width of ``model``; the projected vectors are fed to the
    encoder as ``inputs_embeds`` in place of token embeddings.

    Args:
        model: Hugging Face encoder-decoder LM (BART in this notebook); it is
            wrapped with LoRA adapters via ``get_peft_model``.
        tokenizer: tokenizer of ``model``; only used to decode generated ids.
        input_features_d: size of the last dimension of the incoming features.
    """

    def __init__(self, model, tokenizer, input_features_d):
        super().__init__()

        self.tokenizer = tokenizer
        self.model = model

        # Map incoming features to the seq2seq encoder's embedding width.
        self.proj = torch.nn.Linear(input_features_d, model.model.encoder.embed_tokens.embedding_dim)

        # Low-rank adapters on the attention projection layers (BART module names).
        lora_cfg = LoraConfig(
            r=4,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type="SEQ_2_SEQ_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )

        self.model = get_peft_model(self.model, lora_cfg)
        self.model.print_trainable_parameters()

    def forward(self, encoder_embeds, encoder_attention_mask, labels, **kwargs):
        """Project ``encoder_embeds`` and run the seq2seq model with ``labels``.

        Extra keyword arguments are forwarded to the wrapped model.
        """
        transformed_embeds = self.proj(encoder_embeds)

        return self.model(
            inputs_embeds=transformed_embeds,
            attention_mask=encoder_attention_mask,
            labels=labels,
            **kwargs
        )

    def generate(self, encoder_embeds, encoder_attention_mask, **generation_kwargs):
        """Generate token ids from ``encoder_embeds``; kwargs go to ``model.generate``."""
        # The projection runs under no_grad as well: the returned ids are discrete,
        # so nothing needs gradients, and otherwise proj's forward would record an
        # autograd graph for the whole duration of generation.
        with torch.no_grad():
            transformed_embeds = self.proj(encoder_embeds)
            generated_ids = self.model.generate(
                inputs_embeds=transformed_embeds,
                attention_mask=encoder_attention_mask,
                **generation_kwargs
            )

        return generated_ids

    def generate_text(self, encoder_embeds, encoder_attention_mask, **generation_kwargs):
        """Like ``generate`` but returns decoded strings with special tokens removed."""
        generated_ids = self.generate(encoder_embeds, encoder_attention_mask, **generation_kwargs)

        generated_texts = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )

        return generated_texts

In [8]:
llm = LLM(sgt, base_llm).eval().to(device)

# The inversion model reads the base LLM's hidden states (see the generation cell),
# so its input width is the LLM's hidden size (1536 for Qwen2.5-1.5B) rather than a
# hard-coded constant.
inv_model = InversionModel(inv_base_model, inv_tokenizer, base_llm.config.hidden_size).eval().to(device)

# NOTE(review): absolute, machine-specific path — point this at your own checkpoint.
INV_CHECKPOINT = '/home/alex/research/latest.pt'
# map_location: load onto `device` directly; weights_only: plain state dict, no arbitrary unpickling.
inv_model.load_state_dict(torch.load(INV_CHECKPOINT, map_location=device, weights_only=True))

trainable params: 442,368 || all params: 139,862,784 || trainable%: 0.3163


<All keys matched successfully>

In [9]:
# Probe sentence: it goes through the SGT-protected LLM and the inversion model
# then tries to reconstruct it.
text = ("Daniel Kids has 1000$ "
        "on his bank account")

In [10]:
# Inspect the token ids and attention mask produced for `text`.
llm_tokenizer(text=text, return_tensors="pt")

{'input_ids': tensor([[40586, 22522,   702,   220,    16,    15,    15,    15,     3,   389,
           806,  6073,  2692]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [11]:
# Move the encoding to `device` explicitly; its attention mask is reused by the
# generation cell, so it must live on the same device as the hidden states.
tokenized = llm_tokenizer(text, return_tensors='pt').to(device)

# Inference only: skip autograd bookkeeping through the 1.5B-parameter LLM.
with torch.no_grad():
    outputs = llm(**tokenized)

In [12]:
# Reconstruct text from hidden_states[1] (presumably the first transformer layer's
# output; in the Hugging Face convention index 0 is the input embeddings).
# do_sample=True makes the reconstruction stochastic.
generated_texts = inv_model.generate_text(
    outputs.hidden_states[1],
    tokenized['attention_mask'],
    do_sample=True,
    temperature=1.0,
    max_length=32,
)

In [13]:
# Show the reconstruction(s); compare against `text`.
list(generated_texts)

['Dillon Kids have a bank account issue ']