In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from coconut import Coconut
import torch

In [6]:
model = AutoModelForCausalLM.from_pretrained('openai-community/gpt2')
tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_tokens("<|start-latent|>")
tokenizer.add_tokens("<|end-latent|>")
tokenizer.add_tokens("<|latent|>")
latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
start_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
end_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")

In [7]:
# if we need new tokens, initialize their embeddings and lm heads
model.resize_token_embeddings(len(tokenizer))
embeddings = model.get_input_embeddings()
target_id = tokenizer.convert_tokens_to_ids("<<")
# initialize the new token embeddings with a known token
# it helps stablize the training
for token_id in [latent_id, start_id, end_id]:
    target_embedding = embeddings.weight.data[token_id]
    embeddings.weight.data[token_id] = target_embedding
    # The input embeddings and lm heads are tied in GPT2. So the code below is not necessary
    lm_head = model.lm_head
    lm_head.weight.data[token_id] = lm_head.weight.data[target_id]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [8]:
model = Coconut(model, latent_id, start_id, end_id, tokenizer.eos_token_id)

In [10]:
saved_weights = torch.load(
    'YOUR_PATH_TO_SAVE_THE_MODEL/gsm-coconut/checkpoint_25', map_location=torch.device(0)
)
model.load_state_dict(saved_weights, strict=False)

  saved_weights = torch.load(


<All keys matched successfully>

In [12]:
model = model.to(0)

In [13]:
model

Coconut(
  (base_causallm): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50260, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2SdpaAttention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_f

In [76]:
tokens = tokenizer("I have 7 apples and I gave Luke 3 apples and he gives Paul 2 apples. How many apples do I have?")
output = model.generate(
    torch.tensor(tokens['input_ids'], device='cuda:0').unsqueeze(0),
    torch.tensor(tokens['attention_mask'], device='cuda:0').unsqueeze(0),
)
tokenizer.decode(output[0])

'I have 7 apples and I gave Luke 3 apples and he gives Paul 2 apples. How many apples do I have?### 4'

In [56]:
output[0]

tensor([   40,   423,   767, 22514,   290,   314,  2921, 11336,   513, 22514,
          290,   339,  3607,  3362,   362, 22514,    13,  1374,   867, 22514,
          466,   314,   423,    30, 21017,   604])

In [52]:
model_cot = AutoModelForCausalLM.from_pretrained('openai-community/gpt2')
saved_weights = torch.load(
    'YOUR_PATH_TO_SAVE_THE_MODEL/gsm-cot/checkpoint_15', map_location=torch.device(0)
)
model_cot.load_state_dict(saved_weights, strict=False)
model_cot = model_cot.to(0)

  saved_weights = torch.load(


In [53]:
tokenizer_cot = AutoTokenizer.from_pretrained('openai-community/gpt2')
#tokenizer.pad_token = tokenizer.eos_token

In [78]:
tokens = tokenizer_cot("I have 7 apples and I gave Luke 3 apples and he gives Paul 2 apples. How many apples do I have?")
output = model_cot.generate(
    torch.tensor(tokens['input_ids'], device='cuda:0').unsqueeze(0),
    attention_mask=torch.tensor(tokens['attention_mask'], device='cuda:0').unsqueeze(0),
    max_new_tokens=40,
    num_beams=5,
    num_return_sequences=5,
)
print('\n\n'.join(tokenizer_cot.decode(out) for out in output))

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


I have 7 apples and I gave Luke 3 apples and he gives Paul 2 apples. How many apples do I have?
<<3+2=5>>
<<7+5=12>>
### 12<|endoftext|>

I have 7 apples and I gave Luke 3 apples and he gives Paul 2 apples. How many apples do I have?
<<3+2=5>>
<<7-5=2>>
### 2<|endoftext|>

I have 7 apples and I gave Luke 3 apples and he gives Paul 2 apples. How many apples do I have?
<<7-3-2=2>>
### 2<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>

I have 7 apples and I gave Luke 3 apples and he gives Paul 2 apples. How many apples do I have?
<<7-3=4>>
<<4-2=2>>
### 2<|endoftext|>

I have 7 apples and I gave Luke 3 apples and he gives Paul 2 apples. How many apples do I have?<<3+2=5>>
<<7-5=2>>
### 2<|endoftext|><|endoftext|>


In [68]:
output

tensor([[   40,   423,   767, 22514,   290,   314,  2921, 11336,   513, 22514,
           290,   339,  3607,  3362,   362, 22514,    13,  1374,   867, 22514,
           466,   314,   423,    30,   198, 16791,    18,    10,    17,    28,
            20,  4211,   198, 16791,    22,    10,    20,    28,  1065,  4211,
           198, 21017,  1105, 50256],
        [   40,   423,   767, 22514,   290,   314,  2921, 11336,   513, 22514,
           290,   339,  3607,  3362,   362, 22514,    13,  1374,   867, 22514,
           466,   314,   423,    30,   198, 16791,    18,    10,    17,    28,
            20,  4211,   198, 16791,    22,    12,    20,    28,    17,  4211,
           198, 21017,   362, 50256],
        [   40,   423,   767, 22514,   290,   314,  2921, 11336,   513, 22514,
           290,   339,  3607,  3362,   362, 22514,    13,  1374,   867, 22514,
           466,   314,   423,    30,   198, 16791,    22,    12,    18,    12,
            17,    28,    17,  4211,   198, 21017,   36

In [47]:
output['logits']

tensor([[[-10.4375, -10.1337, -11.9935,  ..., -14.0181, -14.0241,  -7.5947],
         [-43.2610, -48.4011, -51.0240,  ..., -50.1080, -48.0704, -42.8104],
         [-34.1602, -33.7282, -40.7938,  ..., -41.7564, -39.8034, -29.6038],
         ...,
         [-46.7081, -45.7510, -48.9980,  ..., -54.1231, -53.9510, -43.4509],
         [-48.9440, -48.5899, -51.6528,  ..., -54.6584, -55.6218, -43.4965],
         [-39.4913, -37.9772, -41.3923,  ..., -45.1985, -43.9517, -25.0955]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)