In [1]:
from datasets import Dataset
raw_dataset = {
    "input": [
        "a b c d", "e f g h", "i j k l", "m n o p", "q r s t",
        "u v w x", "y z a b", "c d e f", "g h i j", "k l m n"
    ],
    "target": [
        "d c b a", "h g f e", "l k j i", "p o n m", "t s r q",
        "x w v u", "b a z y", "f e d c", "j i h g", "n m l k"
    ]
}
raw_dataset = Dataset.from_dict(raw_dataset)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [3]:
cfg = {
    'vocab_size': tokenizer.vocab_size,
    'max_length': 20,
    'dmodel': 128,
    'dff': 128,
    'h': 4,
    'num_layers_encoder': 4,
    'num_layers_decoder': 4,
}

In [4]:
def preprocess(examples):
    model_inputs = tokenizer(examples['input'], padding="max_length", truncation=True, max_length=cfg['max_length'])

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples['target'], padding="max_length", truncation=True, max_length=cfg['max_length'])
        
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = raw_dataset.map(preprocess, batched=True)
print(tokenized_dataset[0])

Map: 100%|██████████| 10/10 [00:00<00:00, 616.00 examples/s]

{'input': 'a b c d', 'target': 'd c b a', 'input_ids': [101, 1037, 1038, 1039, 1040, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [101, 1040, 1039, 1038, 1037, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}





In [5]:
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

config_encoder = BertConfig(
    vocab_size=cfg['vocab_size'],
    hidden_size=cfg['dmodel'],
    intermediate_size=cfg['dff'],
    num_hidden_layers=cfg['num_layers_encoder'],
    num_attention_heads=cfg['h']
)
config_decoder = BertConfig(
    vocab_size=cfg['vocab_size'],
    hidden_size=cfg['dmodel'],
    intermediate_size=cfg['dff'],
    num_hidden_layers=cfg['num_layers_encoder'],
    num_attention_heads=cfg['h'],
    is_decoder=True,
    add_cross_attention=True
)

config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

model = EncoderDecoderModel(config=config) # random

model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

In [6]:
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
training_args = TrainingArguments(
    output_dir="./test_model",
    per_device_train_batch_size=2,
    num_train_epochs=100,
    learning_rate=5e-4,
    save_strategy="no",
    logging_strategy="epoch"
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model)
)

In [7]:
print("Starting test training...")
trainer.train()

Starting test training...


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Step,Training Loss
5,10.1685
10,9.6035
15,9.0405
20,8.4735
25,7.8943
30,7.3065
35,6.7169
40,6.1282
45,5.5382
50,4.9594


TrainOutput(global_step=500, training_loss=1.146910214871168, metrics={'train_runtime': 16.2554, 'train_samples_per_second': 61.518, 'train_steps_per_second': 30.759, 'total_flos': 135144240000.0, 'train_loss': 1.146910214871168, 'epoch': 100.0})

In [8]:
import torch

model.eval()

# 1. Prepare inputs (passing the attention mask is safer)
input_text = "a b c d"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

# 2. Generate with more explicit control
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask, # Explicitly pass the mask
        max_new_tokens=10,                    # Generate up to 10 new tokens
        decoder_start_token_id=tokenizer.cls_token_id,
        eos_token_id=tokenizer.sep_token_id,  # Tell it when to stop
        pad_token_id=tokenizer.pad_token_id,
        num_beams=5,
        early_stopping=True
    )

# 3. DEBUG: Print the raw numbers
print(f"Input IDs:  {inputs.input_ids[0].tolist()}")
print(f"Output IDs: {generated_ids[0].tolist()}")

# 4. Final Output
decoded_output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Decoded:   '{decoded_output}'")

Input IDs:  [101, 1037, 1038, 1039, 1040, 102]
Output IDs: [101, 101, 1040, 1039, 1038, 1037, 102]
Decoded:   'd c b a'
