## Quick Start

In [None]:
import sys
sys.path.append("../zo2")

from tqdm.auto import tqdm
import torch
from transformers import (
    GPT2Tokenizer,
    TrainingArguments,
    Trainer,
    pipeline,
)
from zo2 import (
    ZOConfig,
    zo2_hf_init,
)

In [None]:
# Hyperparameters for the quick-start run.
# batch_size: number of rows in the synthetic batch
# max_step:   number of iterations of the training loop
# device:     where the model and the batch are moved to
batch_size, max_step, device = 1, 20, "cuda"

In [None]:
# Build the zeroth-order (ZO) config and load the model under the zo2 context.
zo_config = ZOConfig(
    method="mezo-sgd",
    zo2=False,
)
pretrained_name = "facebook/opt-125m"
with zo2_hf_init(zo_config):
    # The import is done inside the context on purpose; presumably
    # zo2_hf_init swaps in a ZO-aware OPT class while active -- TODO confirm.
    from transformers import OPTForCausalLM

    model = OPTForCausalLM.from_pretrained(pretrained_name)
    # Attach the ZO state to the model using the same config.
    model.zo_init(zo_config)
# Move the weights to the configured device.
model.to(device)

In [None]:
# Build one synthetic batch of random token ids.
B = batch_size
V = model.config.vocab_size
T = model.config.max_position_embeddings
# Draw T + 1 ids per row so inputs and labels can be offset by one position.
data_batch = torch.randint(0, V, (B, T + 1)).to(device)
# Inputs drop the last column; labels drop the first one.
# NOTE(review): Hugging Face causal-LM heads usually shift labels internally;
# confirm that the zo2-patched model expects pre-shifted labels.
input_ids = data_batch[:, :-1]
labels = data_batch[:, 1:]

In [None]:
# Training loop: each iteration runs one ZO training call, then one eval call
# on the same batch, and logs both through tqdm so the progress bar stays intact.
train_template = "Iteration {}, loss: {}, projected grad: {}"
eval_template = "Iteration {}, eval loss: {}"

for step in tqdm(range(max_step)):
    # train: with zo_training enabled the model call performs the ZO step
    model.zo_training = True
    loss = model(input_ids=input_ids, labels=labels)
    tqdm.write(train_template.format(step, loss, model.opt.projected_grad))

    # eval: same call with zo_training disabled
    # NOTE(review): the return value here may be a model output object rather
    # than a bare loss -- confirm what gets printed.
    model.zo_training = False
    loss = model(input_ids=input_ids, labels=labels)
    tqdm.write(eval_template.format(step, loss))

In [None]:
# inference
# Load the tokenizer files that ship with the OPT checkpoint. OPT reuses the
# GPT2Tokenizer class but has its own vocabulary and special-token ids, so the
# stock "gpt2" vocabulary would feed the model mismatched token ids.
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-125m")
# Disable the ZO training path and move the model to the CPU before generating.
model.zo_training = False
model.cpu()
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
prompt = "What are we having for dinner?"
# The pipeline returns a list of dicts; take the text of the first candidate.
generated_text = generator(prompt)[0]["generated_text"]
print(f"Question: {prompt}\nAnswer: {generated_text}")

## Using the Hugging Face Trainer

In [None]:
# ZO steps: rebuild the config and reload the model under the zo2 context.
zo_config = ZOConfig(method="mezo-sgd", zo2=False)
with zo2_hf_init(zo_config):
    # Imported inside the context, as in the quick-start cell, so this cell
    # also works when run on its own.
    from transformers import OPTForCausalLM
    model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
# Single quotes inside the f-string: reusing the outer double quotes is a
# SyntaxError on Python versions before 3.12.
print(f"Check if zo2 init correctly: {hasattr(model, 'zo_training')}")

In [None]:
# Standard Hugging Face Trainer setup around the ZO-initialised model.
# NOTE(review): `tokenized_datasets`, `data_collator` and `compute_metrics`
# are not defined in the visible cells -- they must be provided beforehand.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version.
training_args = TrainingArguments(
    output_dir="test-trainer",
    evaluation_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Start the training run on the `trainer` object.
trainer.train()