In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

# Load tokenizer and model.
# Seed first: resize_token_embeddings below initializes the rows for the new
# special tokens at random (see the multivariate-normal warning it prints),
# so without a seed every run starts from different embeddings.
SEED = 42
torch.manual_seed(SEED)

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Register the role markers as special tokens so each maps to a single id.
# NOTE(review): BART still appends its own </s> to every sequence; '<eos>' is an
# extra in-text marker on top of that, not a replacement for it.
special_tokens_dict = {'additional_special_tokens': ['<persona>', '<query>', '<answer>', '<eos>']}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))  # Resize for new tokens

# Input text: each role marker is written inline so it tokenizes to its special id.
persona = "<persona> I like pizza. I'm a developer."
query = "<query> What's your favorite food?"
answer = "<answer> I love pizza! <eos>"

# Explicit truncation cap. Without max_length, truncation=True is a no-op
# ("Default to no truncation" warning); the model's position embeddings
# (config.max_position_embeddings) bound the usable sequence length.
MAX_SEQ_LEN = model.config.max_position_embeddings

# 1. ENCODER INPUT: the persona only.
encoder_input = tokenizer(persona, return_tensors='pt', padding=True,
                          truncation=True, max_length=MAX_SEQ_LEN)

# 2. DECODER INPUT: query and answer as one sequence (teacher forcing).
decoder_input_text = f"{query} {answer}"
decoder_inputs = tokenizer(decoder_input_text, return_tensors='pt', padding=True,
                           truncation=True, max_length=MAX_SEQ_LEN)

# 3. LABELS — same positions as decoder_inputs['input_ids'], with the query part
#    (up to and including the <answer> marker) set to -100, the loss ignore index.
# NOTE(review): labels are aligned token-for-token with decoder_inputs['input_ids'].
# Passing both unshifted trains each position to predict its own input token;
# passing only labels makes the model shift them itself, turning -100 into <pad>
# in the decoder input and hiding the query. Confirm the training cell builds
# decoder_input_ids via shift_tokens_right on the *unmasked* ids.
labels = decoder_inputs['input_ids'].clone()
answer_token_id = tokenizer.convert_tokens_to_ids('<answer>')

# Mask everything before and including <answer> token
for i in range(labels.size(0)):
    idx = (labels[i] == answer_token_id).nonzero(as_tuple=True)[0]
    if idx.numel() > 0:
        # Use the first <answer> as a plain int; slicing with the raw index
        # tensor raises as soon as a row contains more than one <answer>.
        labels[i, :idx[0].item() + 1] = -100
    else:
        labels[i, :] = -100  # no <answer> found: nothing to learn from this row

# Padding positions are never targets (matters once the batch has >1 row).
labels[decoder_inputs['attention_mask'] == 0] = -100

# Show the pieces of the example in order: encoder ids, decoder ids, then labels.
for heading, tensor_to_show in (
    ("\n[ENCODER INPUT IDS]", encoder_input['input_ids']),
    ("\n[DECODER INPUT IDS]", decoder_inputs['input_ids']),
    ("\n[LABELS]", labels),
):
    print(heading)
    print(tensor_to_show)


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.



[ENCODER INPUT IDS]
tensor([[    0, 50265,    38,   101,  9366,     4,    38,   437,    10,  6596,
             4,     2]])

[DECODER INPUT IDS]
tensor([[    0, 50266,   653,    18,   110,  2674,   689,   116,  1437, 50267,
            38,   657,  9366,   328,  1437, 50268,     2]])

[LABELS]
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
            38,   657,  9366,   328,  1437, 50268,     2]])


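### Multi-turn Dialogue

Same setup as above, but the decoder sequence holds several `<query>` / `<answer>` turns, and only the answer spans are kept as training targets.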

In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

# Load tokenizer and model.
# Seed first: resize_token_embeddings below initializes the rows for the new
# special tokens at random (see the multivariate-normal warning it prints),
# so without a seed every run starts from different embeddings.
SEED = 42
torch.manual_seed(SEED)

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Register the role markers as special tokens so each maps to a single id.
special_tokens_dict = {'additional_special_tokens': ['<persona>', '<query>', '<answer>', '<eos>']}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

# === Multi-turn dialogue ===
persona = "<persona> I like pizza. I'm a developer."

# Three query/answer turns in one decoder sequence; <eos> closes the last turn.
conversation = (
    "<query> What's your favorite food? <answer> I love pizza! "
    "<query> What do you do? <answer> I'm a developer. "
    "<query> How was your day? <answer> It was great! <eos>"
)

# Explicit truncation cap. Without max_length, truncation=True is a no-op
# ("Default to no truncation" warning).
# NOTE(review): a conversation longer than this is cut mid-turn; consider
# dropping the oldest turns instead once real dialogues are used.
MAX_SEQ_LEN = model.config.max_position_embeddings

# Encoder input
encoder_inputs = tokenizer(persona, return_tensors='pt', padding=True,
                           truncation=True, max_length=MAX_SEQ_LEN)

# Decoder input (teacher forcing)
decoder_inputs = tokenizer(conversation, return_tensors='pt', padding=True,
                           truncation=True, max_length=MAX_SEQ_LEN)

# === Construct labels: keep answer spans, set everything else to -100 ===
# A kept span starts at an <answer> marker and runs up to (not including) the
# next <query>; the final span also keeps the closing </s>, matching the
# single-turn labels, so the model learns where generation ends.
# NOTE(review): here the <answer> marker itself is a target, whereas the
# single-turn cell masks it — confirm which is intended. Intermediate answers are
# never followed by a target, so the model gets no learned end-of-turn signal
# except on the last turn.
# NOTE(review): labels are aligned token-for-token with decoder_inputs['input_ids'];
# build decoder_input_ids with shift_tokens_right on the unmasked ids when training.

# Get token IDs for special tokens
query_token_id = tokenizer.convert_tokens_to_ids("<query>")
answer_token_id = tokenizer.convert_tokens_to_ids("<answer>")

input_ids = decoder_inputs['input_ids']
positions = torch.arange(input_ids.size(1), device=input_ids.device).expand_as(input_ids)

# For each position: index of the latest <answer> / <query> at or before it (-1 if none).
# A token is inside an answer span when the latest marker seen is <answer>.
last_answer_pos = positions.masked_fill(input_ids != answer_token_id, -1).cummax(dim=1).values
last_query_pos = positions.masked_fill(input_ids != query_token_id, -1).cummax(dim=1).values
in_answer = last_answer_pos > last_query_pos

# Keep the first </s> itself but nothing after it, and never train on padding.
is_eos = input_ids == tokenizer.eos_token_id
after_eos = (is_eos.long().cumsum(dim=1) - is_eos.long()) > 0
keep = in_answer & ~after_eos & decoder_inputs['attention_mask'].bool()

labels = input_ids.masked_fill(~keep, -100)


# Print for debug
print("\nENCODER INPUT IDs:\n", encoder_inputs['input_ids'])
print("\nDECODER INPUT IDs:\n", decoder_inputs['input_ids'])
# batch_decode works for any batch size; decode(ids.squeeze()) only handles one row.
for decoded_text in tokenizer.batch_decode(decoder_inputs['input_ids']):
    print(decoded_text)
print("\nLABELS (Masked Queries):\n", labels)


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.



ENCODER INPUT IDs:
 tensor([[    0, 50265,    38,   101,  9366,     4,    38,   437,    10,  6596,
             4,     2]])

DECODER INPUT IDs:
 tensor([[    0, 50266,   653,    18,   110,  2674,   689,   116,  1437, 50267,
            38,   657,  9366,   328,  1437, 50266,   653,   109,    47,   109,
           116,  1437, 50267,    38,   437,    10,  6596,     4,  1437, 50266,
          1336,    21,   110,   183,   116,  1437, 50267,    85,    21,   372,
           328,  1437, 50268,     2]])
<s> <query>  What's your favorite food?  <answer>  I love pizza!  <query>  What do you do?  <answer>  I'm a developer.  <query>  How was your day?  <answer>  It was great!  <eos> </s>

LABELS (Masked Queries):
 tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 50267,
            38,   657,  9366,   328,  1437,  -100,  -100,  -100,  -100,  -100,
          -100,  -100, 50267,    38,   437,    10,  6596,     4,  1437,  -100,
          -100,  -100,  -100,  -100,  -100,  -100, 