In [None]:
import numpy as np
from tqdm import tqdm, trange
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split


from gsm_dataset import GSMDataset, gsm_collate, gsm_prompt, sample
from biscuit import Biscuit

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Restore the weights stored in the epoch_9 checkpoint into a freshly built Biscuit wrapper.
checkpoint_path = 'checkpoints/epoch_9.pth'

biscuit_model = Biscuit()
# map_location remaps the saved tensors onto the wrapper's device, so a checkpoint
# written on a different device (e.g. a GPU run) still loads in this kernel.
# The call is left as the last expression so the key-matching report is displayed.
biscuit_model.model.load_state_dict(
    torch.load(checkpoint_path, map_location=biscuit_model.device)
)

<All keys matched successfully>

In [13]:
# Fixed seed so the train / few-shot / test partition is identical after every
# kernel restart.
# NOTE(review): if the checkpoint was trained on a specific split, this seed should
# reproduce that split to avoid evaluating on training examples — TODO confirm.
SPLIT_SEED = 0

dataset = GSMDataset()

# 90% train, 2% few-shot examples, remainder (about 8%) test.
train_size = int(0.9 * len(dataset))
example_size = int(0.02 * len(dataset))  # reserve some data for few shot prompting
test_size = len(dataset) - train_size - example_size  # absorbs the rounding remainder

train_dataset, example_dataset, test_dataset = random_split(
    dataset,
    [train_size, example_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=gsm_collate)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True, collate_fn=gsm_collate)

In [40]:
# Walk one training batch through the continuous chain-of-thought (CoT) procedure:
# encode the few-shot prompt plus the first segment, then for every later segment
# (1) drop finished sequences, (2) roll out a random-length continuous thought,
# (3) sample a short text preview to compare against the real segment, and
# (4) teacher-force the real segment into the cache.
COT_MAX_LENGTH = 6        # upper bound (inclusive) on the number of continuous thought steps
MAX_PREVIEW_TOKENS = 50   # how many tokens to sample when previewing the model's continuation

segments, keep_indices_lst = next(iter(train_loader))
examples = sample(example_dataset, num_samples=4)
prompt = gsm_prompt(examples)

# Applied to logits[:, -1] below, so dim=1 is the last (per-row) dimension there.
softmax = nn.Softmax(dim=1)

with torch.no_grad():
    # Step 0: just process the first segment without decoding the next token
    for seg in segments[0]:
        print(seg)
    first_segment = [prompt + segment for segment in segments[0]]
    inputs = biscuit_model.tokenizer(first_segment, return_tensors="pt", padding=True).to(biscuit_model.device)
    outputs = biscuit_model.model(**inputs)
    kv_cache = outputs.past_key_values
    attn_mask = inputs.attention_mask

    # continuous CoT loop: produce CoT -> use it to predict next segment -> repeat
    for segment, keep_indices in zip(segments[1:], keep_indices_lst):
        # Step 1: drop sequences that are done
        kv_cache.batch_select_indices(keep_indices)
        attn_mask = attn_mask[keep_indices]
        batch_size = keep_indices.shape[0]
        # One extra "attend" column per sequence; appended to the mask for every single-step input.
        attn_ones = torch.ones(batch_size, 1, dtype=torch.long, device=biscuit_model.device)

        # Step 2: then autoregressively predict a continuous chain of thought sequence
        last_hidden_state = None
        k = np.random.randint(1, COT_MAX_LENGTH + 1) # the CoT sequence has a random length
        print(k)
        # k + 2 steps: begin-of-thought token, k continuous thoughts, end-of-thought token.
        for i in range(k + 2):
            attn_mask = torch.cat((attn_mask, attn_ones), dim=1)
            if i == 0 or i == k + 1: # process beginning of thought or end of thought token
                seq = [biscuit_model.bot if i == 0 else biscuit_model.eot] * batch_size
                inputs = biscuit_model.tokenizer(seq, return_tensors="pt").to(biscuit_model.device)
                args = {'input_ids': inputs.input_ids}
            else: # process new continuous thought token
                # The previous step's final hidden state is fed back as the next input embedding.
                args = {'inputs_embeds': last_hidden_state}

            outputs = biscuit_model.model(**args, attention_mask=attn_mask, past_key_values=kv_cache)
            # Last entry of hidden_states, last position only ([:, -1:] keeps the sequence axis).
            # NOTE(review): assumes the model is configured to return hidden states — TODO confirm.
            last_hidden_state = outputs.hidden_states[-1][:, -1:]
            kv_cache = outputs.past_key_values

        # Step 3: preview what the model would write next. Snapshot the cache first so it
        # can be restored afterwards; presumably the forward calls below extend kv_cache in
        # place, and the sampled tokens must not leak into the next segment — TODO confirm.
        key_cache_copy = [t.clone() for t in kv_cache.key_cache]
        value_cache_copy = [t.clone() for t in kv_cache.value_cache]

        text_output = [' ' for _ in range(batch_size)]
        # Only the seed string is tokenized. Afterwards the sampled ids are fed back
        # directly: decoding a sampled id to text and re-tokenizing it is not guaranteed to
        # give back exactly one token per row, which would desynchronize the input from
        # temp_mask (it grows by exactly one column per step).
        # NOTE(review): assumes ' ' encodes to a single token per row — TODO confirm.
        next_ids = biscuit_model.tokenizer(text_output, return_tensors="pt").to(biscuit_model.device).input_ids
        temp_mask = attn_mask.clone()
        for _ in range(MAX_PREVIEW_TOKENS):
            temp_mask = torch.cat((temp_mask, attn_ones), dim=1)
            outputs = biscuit_model.model(input_ids=next_ids,
                                          attention_mask=temp_mask,
                                          past_key_values=kv_cache)
            # Sample one token id per sequence from the distribution at the last position.
            next_ids = torch.multinomial(softmax(outputs.logits[:, -1]), 1)
            next_token = biscuit_model.tokenizer.batch_decode(next_ids)
            text_output = [a + b for a, b in zip(text_output, next_token)]
        for a, b in zip(text_output, segment):
            print("model output:", a)
            print('real:', b)

        # Restore the pre-preview cache contents.
        kv_cache.key_cache = key_cache_copy
        kv_cache.value_cache = value_cache_copy

        # Step 4: teacher-force the real segment.
        # pad on the right side so that the CoT and the new input are contiguous
        inputs = biscuit_model.tokenizer(segment, return_tensors="pt", padding=True,
                                padding_side='right').to(biscuit_model.device)
        attn_mask = torch.cat((attn_mask, inputs.attention_mask), dim=1)
        outputs = biscuit_model.model(input_ids=inputs.input_ids, attention_mask=attn_mask, past_key_values=kv_cache)
        kv_cache = outputs.past_key_values

Question: Jimmy and Irene go shopping for clothes on a Tuesday, where senior citizens get a 10% discount on their purchases.  Jimmy picks out 3 shorts from the $15 rack.  Irene grabs 5 shirts from the $17 rack.  How much money do they give to the cashier?

Answer: Jimmy’s shorts cost 3 x $15 = $
Question: Emma has saved $230 in her bank account. She withdrew $60 to buy a new pair of shoes. The next week, she deposited twice as much money as she withdrew. How much is in her bank account now?

Answer: Emma had $230 - $60 = $
Question: Angela has a collection of 24 pieces of rare action figures. She sold off a quarter of them at the pawnshop and gave one-third of the remainder to her daughter. How many does she have left?

Answer: One-quarter of 24 action figures is 24*(1/4) = 
Question: A mailman has to deliver 48 pieces of junk mail.  There are 8 houses on the block.  2 of the houses have white mailboxes and 3 have red mailboxes.  How many pieces of junk mail will each of those houses g

In [None]:
# NOTE(review): scratch cell — calls the wrapped model with no inputs at all.
# The output recorded below is a RuntimeError, so this cell likely breaks a
# Restart & Run All; presumably leftover debugging — confirm and remove.
biscuit_model.model()

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

In [10]:
# Print the source of the wrapped model's `generate` method for reference.
# `inspect` is not imported by the notebook's imports cell, so it is imported here
# to keep this cell runnable on a fresh kernel.
import inspect

print(inspect.getsource(biscuit_model.model.generate))

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[str] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""

        Generates sequences of token ids for models with a language modeling head.


        Most generation-controlling parameters are set in `gene

In [None]:
# Show the token ids the tokenizer assigns to the literal string 'bot'.
# NOTE(review): the output recorded below is a full encoding dict (input_ids plus
# attention_mask) rather than just the ids, so it looks stale — re-run to confirm.
biscuit_model.tokenizer('bot').input_ids

{'input_ids': [6331], 'attention_mask': [1]}