In [None]:
import copy

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split

from gsm_dataset import GSMDataset, gsm_collate, gsm_prompt, sample
from biscuit import Biscuit

In [None]:
# Build the model and restore the trained weights of its latent trunk.
biscuit_model = Biscuit()
checkpoint_path = None # Add checkpoint path here
if checkpoint_path is None:
    # Fail with an actionable message instead of the opaque error torch.load(None) raises.
    raise ValueError("checkpoint_path is not set: point it at the latent trunk checkpoint before running this cell")

# map_location="cpu" lets a checkpoint saved on a GPU be read on any machine;
# load_state_dict then copies the values into the trunk's existing parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
latent_trunk_state = torch.load(checkpoint_path, map_location="cpu")
biscuit_model.latent_trunk.load_state_dict(latent_trunk_state)

In [None]:
# Restore the learned beginning-of-thought / end-of-thought embeddings.
# They are later passed to the token trunk as `inputs_embeds` without any device
# move, so load them straight onto the model's device.
# NOTE: torch.load unpickles the files, so only load them from a trusted source.
biscuit_model.bot_embedding = torch.load('BOT_EMBEDDING_PATH', map_location=biscuit_model.device)
biscuit_model.eot_embedding = torch.load('EOT_EMBEDDING_PATH', map_location=biscuit_model.device)

In [None]:
# Split the dataset into train / few-shot-example / test subsets and build loaders.
SPLIT_SEED = 0  # fixes the partition so it is identical after every kernel restart

dataset = GSMDataset()

train_size = int(0.9 * len(dataset))
example_size = int(0.02 * len(dataset)) # reserve some data for few shot prompting
# The remainder goes to the test split, so the three sizes always sum to len(dataset).
test_size = len(dataset) - train_size - example_size

# A dedicated, seeded generator keeps the split reproducible without touching the
# global RNG state (batch shuffling in the loaders below stays random).
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, example_dataset, test_dataset = random_split(
    dataset, [train_size, example_size, test_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=gsm_collate)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True, collate_fn=gsm_collate)

In [None]:
# Qualitative walk-through of continuous chain-of-thought decoding on one training batch:
# after each text segment the model emits a random-length run of latent "thought" steps,
# a text continuation is sampled for inspection, and then the real next segment is fed in.
COT_MAX_LENGTH = 6            # upper bound on the number of latent thought steps per segment
PREVIEW_MAX_NEW_TOKENS = 50   # tokens sampled per sequence when previewing the next segment

# `segments[0]` holds the first text segment of every sequence; each later entry is paired
# with an index tensor selecting the sequences that are still active at that point.
segments, keep_indices_lst = next(iter(train_loader))
examples = sample(example_dataset, num_samples=4)
prompt = gsm_prompt(examples)

# Applied to (batch, vocab) logits of the last position, i.e. normalises over the vocabulary.
softmax = nn.Softmax(dim=1)

with torch.no_grad():
    # Step 0: just process the first segment without decoding the next token
    for seg in segments[0]:
        print(seg)
    first_segment = [prompt + first for first in segments[0]]
    inputs = biscuit_model.tokenizer(first_segment, return_tensors="pt", padding=True).to(biscuit_model.device)
    outputs = biscuit_model.token_trunk(**inputs)
    kv_cache = outputs.past_key_values
    attn_mask = inputs.attention_mask

    # continuous CoT loop: produce CoT -> use it to predict next segment -> repeat
    for segment, keep_indices in zip(segments[1:], keep_indices_lst):
        # Step 1: drop sequences that are done (cache rows and mask rows must stay aligned)
        kv_cache.batch_select_indices(keep_indices)
        attn_mask = attn_mask[keep_indices]
        batch_size = keep_indices.shape[0]
        # One extra attended position per sequence, appended to the mask for every new step.
        attn_ones = torch.ones(batch_size, 1, dtype=torch.long, device=biscuit_model.device)

        # Step 2: then autoregressively predict a continuous chain of thought sequence
        last_hidden_state = None
        k = np.random.randint(1, COT_MAX_LENGTH + 1) # the CoT sequence has a random length
        print("num_latents: ", k)
        latent_tokens = [[] for _ in range(batch_size)]
        for i in range(k + 2):  # k latent steps framed by the begin/end-of-thought embeddings
            attn_mask = torch.cat((attn_mask, attn_ones), dim=1)
            if i == 0 or i == k + 1: # process beginning of thought or end of thought token
                inp = biscuit_model.bot_embedding if i == 0 else biscuit_model.eot_embedding
                outputs = biscuit_model.token_trunk(inputs_embeds=inp.repeat(batch_size, 1, 1),
                                                    attention_mask=attn_mask,
                                                    past_key_values=kv_cache)
            else: # process new continuous thought token: feed back the previous last hidden state
                outputs = biscuit_model.latent_trunk(inputs_embeds=last_hidden_state,
                                                     attention_mask=attn_mask,
                                                     past_key_values=kv_cache)
            last_hidden_state = outputs.hidden_states[-1][:, -1:]
            # For inspection only: sample the token each step would emit if it were read out as text.
            sampled_text = biscuit_model.tokenizer.batch_decode(torch.multinomial(softmax(outputs.logits[:, -1]), 1))
            latent_tokens = [so_far + [piece] for so_far, piece in zip(latent_tokens, sampled_text)]
            kv_cache = outputs.past_key_values
        for i, a in enumerate(latent_tokens):
            print(f"latent {i}", a)

        # Step 3: preview what the model would write next. The preview runs on a deep copy of
        # the cache so the real cache is never mutated and needs no manual restore afterwards.
        preview_cache = copy.deepcopy(kv_cache)
        preview_mask = attn_mask
        preview_text = [' ' for _ in range(batch_size)]
        # NOTE(review): as before, generation is seeded with a single space rather than with a
        # token sampled from the end-of-thought logits - confirm this is intended.
        step_ids = biscuit_model.tokenizer(preview_text, return_tensors="pt").to(biscuit_model.device).input_ids
        for _ in range(PREVIEW_MAX_NEW_TOKENS):
            # Grow the mask by exactly as many columns as are fed in (the seed may be >1 token).
            preview_mask = torch.cat((preview_mask, torch.ones_like(step_ids)), dim=1)
            outputs = biscuit_model.token_trunk(input_ids=step_ids,
                                                attention_mask=preview_mask,
                                                past_key_values=preview_cache)
            preview_cache = outputs.past_key_values
            # Feed the sampled ids back directly: decoding to text and re-tokenising is not
            # guaranteed to give exactly one id per sequence, which breaks batching and the mask.
            step_ids = torch.multinomial(softmax(outputs.logits[:, -1]), 1)
            new_pieces = biscuit_model.tokenizer.batch_decode(step_ids)
            preview_text = [so_far + piece for so_far, piece in zip(preview_text, new_pieces)]
        for i, (a, b) in enumerate(zip(preview_text, segment)):
            print(i)
            print("model output:", a)
            print('real:', b)

        # Step 4: teacher-force the real next segment into the (untouched) cache.
        # pad on the right side so that the CoT and the new input are contiguous
        inputs = biscuit_model.tokenizer(segment, return_tensors="pt", padding=True,
                                         padding_side='right').to(biscuit_model.device)
        attn_mask = torch.cat((attn_mask, inputs.attention_mask), dim=1)
        outputs = biscuit_model.token_trunk(input_ids=inputs.input_ids, attention_mask=attn_mask, past_key_values=kv_cache)
        kv_cache = outputs.past_key_values

In [None]:
# Probe: project the end-of-thought embedding through the LM head and sample 40 distinct
# tokens from the resulting distribution to see which vocabulary items it sits closest to.
with torch.no_grad():
    logits = biscuit_model.token_trunk.lm_head(biscuit_model.eot_embedding)
    # Flatten any leading dims to (rows, vocab) and normalise over the vocabulary axis, so the
    # probe works whether the embedding is stored as (dim,), (1, dim) or (1, 1, dim).
    eot_token_probs = torch.softmax(logits.reshape(-1, logits.shape[-1]), dim=-1)
    eot_token_ids = torch.multinomial(eot_token_probs, 40)
biscuit_model.tokenizer.batch_decode(eot_token_ids)