In [None]:
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM, TrainingArguments, Trainer

### Train

In [None]:
# LoRA adapter configuration for a sequence-to-sequence language model.
peft_config = LoraConfig(
    # LoRA hyperparameters
    r=8,               # rank of the low-rank update matrices
    lora_alpha=32,     # scaling numerator (PEFT scales the update by lora_alpha / r)
    lora_dropout=0.1,  # dropout probability on the LoRA path
    # How PEFT should wrap the model
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,  # training mode, not inference mode
)

In [None]:
# Load the mt0-large base model and wrap it with the LoRA adapter config from the
# previous cell; print_trainable_parameters reports how many parameters remain
# trainable after wrapping.
model = get_peft_model(
    AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-large"),
    peft_config,
)
model.print_trainable_parameters()

In [None]:
# Trainer configuration: where to write checkpoints, optimisation hyperparameters,
# then evaluation / checkpointing behaviour.
training_args = TrainingArguments(
    output_dir="seunggwan/bigscience/mt0-large-lora",
    # Optimisation
    learning_rate=1e-3,
    weight_decay=0.01,
    num_train_epochs=2,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Evaluate and checkpoint once per epoch. load_best_model_at_end requires the
    # eval and save strategies to match; with no metric_for_best_model set, the
    # "best" checkpoint is chosen by evaluation loss.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

In [None]:
# NOTE(review): `tokenized_datasets`, `tokenizer`, `data_collator` and
# `compute_metrics` are not defined in any cell of this notebook shown here; they
# must come from a data-preparation step run beforehand, otherwise this cell fails
# on a fresh kernel. Make sure `tokenizer` is the mt0-large tokenizer: the
# Inference section later rebinds `tokenizer` to the OPT-350m tokenizer.
trainer = Trainer(
    model=model,
    args=training_args,
    # Data splits (assumes a DatasetDict-like object with "train"/"test" keys).
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    # Preprocessing, batching and metric hooks.
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Write the trained PEFT adapter (weights + adapter config) to ./output.
model.save_pretrained(save_directory="output")

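### Inference

This section demonstrates inference with a pre-trained LoRA adapter from the Hub (`ybelkada/opt-350m-lora` on top of `facebook/opt-350m`), not with the mt0-large adapter trained above.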

In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

In [None]:
# Load a pre-trained LoRA adapter from the Hub together with its base model, and
# the tokenizer of that base model (facebook/opt-350m).
model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Prefer the GPU but fall back to CPU instead of hard-coding "cuda", so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Switch to eval mode (disables dropout); the trailing `;` suppresses the very
# large model repr that would otherwise be displayed as cell output.
model.eval();

In [None]:
# Tokenize the prompt and move every returned tensor (input_ids and attention_mask)
# to whatever device the model is on, so this cell works on GPU or CPU.
inputs = tokenizer(
    "Preheat the oven to 350 degrees and place the cookie dough",
    return_tensors="pt",
).to(model.device)

# Unpacking passes both `input_ids` and `attention_mask` to generate().
outputs = model.generate(**inputs, max_new_tokens=50)

# batch_decode accepts the output tensor directly; skip_special_tokens=True drops
# special tokens from the decoded text.
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])