In [None]:
# Install nightly PyTorch/XLA
# %pip install --user https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly-cp38-cp38-linux_x86_64.whl 'torch_xla[tpuvm] @ https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl'
# %pip install datasets transformers sacremoses evaluate sklearn sacrebleu

In [1]:
%env PJRT_DEVICE=TPU

env: PJRT_DEVICE=TPU


### Fine tune FairSeq Transformer pretrained on wmt19-de-en

In [2]:
# Define Parameters
FLAGS = {}
FLAGS['batch_size'] = 4
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 5e-5
FLAGS['num_epochs'] = 3
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False
FLAGS['source_lang'] = "de"
FLAGS['target_lang'] = "en"
FLAGS['metrics_debug'] = False

In [3]:
from datasets import load_dataset
from transformers import FSMTForConditionalGeneration, FSMTTokenizer, get_scheduler
import torch
from torch.optim import AdamW
import evaluate
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import time

model_name = "facebook/wmt19-de-en"
SERIAL_EXEC = xmp.MpSerialExecutor()
WRAPPED_MODEL = xmp.MpModelWrapper(FSMTForConditionalGeneration.from_pretrained(model_name))
tokenizer = FSMTTokenizer.from_pretrained(model_name)

def finetune():
  torch.manual_seed(1)

  def get_dataset():
    def preprocess(examples):
      inputs = [ex[FLAGS['source_lang']] for ex in examples["translation"]]
      targets = [ex[FLAGS['target_lang']] for ex in examples["translation"]]
      return tokenizer(text=inputs, text_target=targets, padding="max_length", truncation=True)

    ds = load_dataset("news_commentary", "de-en", split="train")
    ds = ds.map(preprocess, batched=True)
    ds = ds.remove_columns(["id", "translation"])
    ds = ds.train_test_split(test_size=0.2)
    ds.set_format("torch")
    return ds["train"].shuffle(seed=42).select(range(1000)), ds["test"].shuffle(seed=42).select(range(1000))

  # Using the serial executor avoids multiple processes to
  # download the same data.
  small_train_dataset, small_test_dataset = SERIAL_EXEC.run(get_dataset)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
      small_train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      small_train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  test_loader = torch.utils.data.DataLoader(
      small_test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  
  # Scale learning rate to world size
  lr = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get optimizer, scheduler and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = AdamW(model.parameters(), lr=lr)
  lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=FLAGS['num_epochs'] * len(train_loader))

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, batch in enumerate(loader):
      optimizer.zero_grad()
      batch = {k: v.to(device) for k, v in batch.items()}
      outputs = model(**batch)
      loss = outputs.loss
      loss.backward()
      xm.optimizer_step(optimizer)
      lr_scheduler.step()
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), tracker.rate(),
            tracker.global_rate(), time.asctime()), flush=True)
  
  def test_loop_fn(loader):
    metric = evaluate.load("sacrebleu")
    model.eval()
    for batch in loader:
      batch = {k: v.to(device) for k, v in batch.items()}
      with torch.no_grad():
        outputs = model(**batch)

      logits = outputs.logits
      predictions = torch.argmax(logits, dim=-1)
      
      decoded_preds = [pred.strip() for pred in tokenizer.batch_decode(predictions, skip_special_tokens=True)]
      decoded_labels = [[label.strip()] for label in tokenizer.batch_decode(batch["labels"], skip_special_tokens=True)]
      metric.add_batch(predictions=decoded_preds, references=decoded_labels)

    eval_metric = metric.compute()
    print('[xla:{}] Bleu={:.5f} Time={}'.format(
            xm.get_ordinal(), eval_metric["score"], time.asctime()), flush=True)
  
  # Train and eval loops
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    xm.master_print("Started training epoch {}".format(epoch))
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    xm.master_print("Evaluate epoch {}".format(epoch))
    para_loader = pl.ParallelLoader(test_loader, [device])
    test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

In [4]:
def mp_fn(rank, flags):
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  finetune()

In [5]:
xmp.spawn(mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')



  0%|          | 0/224 [00:00<?, ?ba/s]



Started training epoch 1
[xla:1](0) Loss=12.07927 Rate=0.16 GlobalRate=0.16 Time=Tue Nov 22 19:02:52 2022
[xla:3](0) Loss=12.36526 Rate=0.16 GlobalRate=0.16 Time=Tue Nov 22 19:02:53 2022
[xla:2](0) Loss=12.23764 Rate=0.15 GlobalRate=0.15 Time=Tue Nov 22 19:02:56 2022
[xla:0](0) Loss=12.01942 Rate=0.15 GlobalRate=0.15 Time=Tue Nov 22 19:02:59 2022
[xla:2](20) Loss=3.34164 Rate=0.10 GlobalRate=0.07 Time=Tue Nov 22 19:23:02 2022
[xla:3](20) Loss=3.29067 Rate=0.10 GlobalRate=0.07 Time=Tue Nov 22 19:23:02 2022
[xla:1](20) Loss=3.37073 Rate=0.10 GlobalRate=0.07 Time=Tue Nov 22 19:23:02 2022
[xla:0](20) Loss=3.30566 Rate=0.10 GlobalRate=0.07 Time=Tue Nov 22 19:23:02 2022
[xla:2](40) Loss=0.20429 Rate=4.92 GlobalRate=0.13 Time=Tue Nov 22 19:23:12 2022
[xla:0](40) Loss=0.24094 Rate=4.92 GlobalRate=0.13 Time=Tue Nov 22 19:23:12 2022
[xla:1](40) Loss=0.12618 Rate=4.91 GlobalRate=0.13 Time=Tue Nov 22 19:23:12 2022
[xla:3](40) Loss=0.25264 Rate=4.90 GlobalRate=0.13 Time=Tue Nov 22 19:23:12 2022
[xl