# T5 training for combined concatenated outputs (thing + property)

Refer to `t5_train_tp.py` and `guide_for_tp.md` for a faster training workflow.

In [1]:
# Load the pre-split dataset that was previously saved to disk.
from datasets import load_from_disk
# Path to the saved combined dataset (relative to the notebook's working directory)
file_path = 'combined_data'
# Later cells index this object by split name ("train", "validation"),
# so it is expected to be a DatasetDict -- TODO confirm against the script that saved it.
split_datasets = load_from_disk(file_path)

In [2]:
# Prepare the tokenizer: start from the pretrained "t5-base" tokenizer.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-base")
# Marker tokens to register as additional special tokens.
# presumably they delimit the thing / property spans in the target strings -- verify against the dataset
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>"]
# Register the markers with the tokenizer as additional special tokens
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

# Token limit handed to the tokenizer in preprocess_function; because that call
# uses truncation=True, longer inputs and targets are cut to this length.
max_length = 64

def preprocess_function(examples):
    """Tokenize one batch of translation pairs.

    Args:
        examples: batch whose 'translation' entry is a sequence of dicts, each
            providing a "tag_description" (source text) and a
            "thing_property" (target text).

    Returns:
        The tokenizer output for the batch. The targets are passed through
        ``text_target`` so the tokenizer produces the labels itself; no
        separate 'labels' column has to be built by hand.
    """
    pairs = examples['translation']
    source_texts = [pair["tag_description"] for pair in pairs]
    target_texts = [pair["thing_property"] for pair in pairs]
    # Both sides are truncated to the notebook-level max_length.
    return tokenizer(
        source_texts,
        text_target=target_texts,
        max_length=max_length,
        truncation=True,
    )

# Apply preprocess_function to every split held in split_datasets
# (original note: the train, valid and test splits of the DatasetDict).
# batched=True hands the function a batch of examples per call; remove_columns
# drops the original columns (names taken from the "train" split) from the result.
tokenized_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=split_datasets["train"].column_names,
)



Map:   0%|          | 0/7163 [00:00<?, ? examples/s]

Map:   0%|          | 0/983 [00:00<?, ? examples/s]

Map:   0%|          | 0/831 [00:00<?, ? examples/s]

In [3]:
import torch
import os
import subprocess

# we use the pre-trained t5-base model
from transformers import AutoModelForSeq2SeqLM
model_checkpoint = "t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# The tokenizer was extended with extra special tokens, so the model's embedding
# matrix must have a row for every token id the tokenizer can emit. Only grow
# the matrix when it is too small: if the checkpoint already has enough rows
# this is a no-op and the pretrained weights are left untouched.
# NOTE(review): t5-base is expected to ship with spare embedding rows, so this
# guard should not trigger for it -- it protects against swapping model_checkpoint.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

# data collator: pads each batch dynamically; it is given the model as well
# (note that this makes the collator hold a reference to the model).
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Evaluation metric: load sacreBLEU through the `evaluate` library.
# compute_metrics (defined right after) scores decoded predictions against references with it.
import evaluate
metric = evaluate.load("sacrebleu")
import numpy as np

def compute_metrics(eval_preds):
    """Compute the sacreBLEU score for a batch of evaluation predictions.

    Args:
        eval_preds: pair ``(preds, labels)`` of token-id arrays. ``preds`` may
            also be a tuple, in which case its first element is used.

    Returns:
        dict with a single key ``"bleu"`` holding the sacreBLEU score.

    Notes:
        Decoding uses ``skip_special_tokens=True``. The marker tokens added to
        the tokenizer are registered as special tokens, so they are stripped
        from both predictions and references before scoring.
        NOTE(review): confirm that scoring without the markers is intended.
    """
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id

    # -100 is not a valid token id and cannot be decoded. It is used as the
    # ignore index in the labels, and generated predictions can be padded with
    # it as well, so replace it on both sides before decoding.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing: sacreBLEU expects a list of references per
    # prediction, hence the single-element inner lists.
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

from transformers import Seq2SeqTrainingArguments

# Multi-GPU training on machines without GPU peer-to-peer support: switch the
# corresponding NCCL features off through environment variables.
# Not required for single-GPU training.
import os
os.environ.update({
    'NCCL_P2P_DISABLE': '1',
    'NCCL_IB_DISABLE': '1',
})

# Training configuration. The first positional argument is the checkpoint
# output directory.
args = Seq2SeqTrainingArguments(
    "train_tp_checkpoint_1",
    # --- schedule and optimisation ---
    num_train_epochs=40,
    learning_rate=2e-5,
    weight_decay=0.01,
    bf16=True,
    # --- batching ---
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    auto_find_batch_size=True,
    ddp_find_unused_parameters=False,
    # --- evaluation / generation ---
    # no evaluation runs during training
    evaluation_strategy="no",
    predict_with_generate=True,
    # --- checkpointing and logging ---
    # optional logging settings, currently disabled:
    # logging_dir="tensorboard-log",
    # logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    push_to_hub=False,
)

from transformers import Seq2SeqTrainer

# Wire model, training arguments, data and metric function into the trainer.
# NOTE(review): args sets evaluation_strategy="no", so eval_dataset and
# compute_metrics are presumably only used if evaluation is triggered manually.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Release GPU memory after training.
# Deleting the `model` and `trainer` names is not sufficient on its own:
#  * data_collator was constructed with model=model, so it still holds a
#    reference to the model -- clear it (the collator object itself is kept);
#  * objects caught in reference cycles stay alive until the cyclic garbage
#    collector runs, so collect explicitly before emptying the CUDA cache.
import gc
data_collator.model = None
del model
del trainer
gc.collect()
torch.cuda.empty_cache()

# Check GPU memory usage
subprocess.run(['nvidia-smi'])



Step,Training Loss
500,3.5038
1000,0.8844
1500,0.6039
2000,0.5014




Sun Jun  2 17:25:32 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
| 32%   57C    P2    35W / 200W |   2105MiB / 12282MiB |     13%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:2D:00.0 Off |                  N/A |
| 43%   71C    P2    31W / 200W |   1167MiB / 12282MiB |      6%      Default |
|       

CompletedProcess(args=['nvidia-smi'], returncode=0)

: 