### Fine-tune the FairSeq Transformer (FSMT) pretrained on WMT19 de-en

In [7]:
# from datasets import load_dataset
# from transformers import FSMTForConditionalGeneration, FSMTTokenizer, get_scheduler
# import torch
# from torch.utils.data import DataLoader
# from torch.utils.data.distributed import DistributedSampler
# from torch.optim import AdamW
# import evaluate
# import torch_xla.core.xla_model as xm
# import torch_xla.debug.metrics as met
# import torch_xla.distributed.parallel_loader as pl
# import torch_xla.distributed.xla_multiprocessing as xmp
# import time


In [3]:
# Get pretrained tokenizer and model
# mname = "facebook/wmt19-de-en"
# tokenizer = FSMTTokenizer.from_pretrained(mname)
# model = FSMTForConditionalGeneration.from_pretrained(mname)
# optimizer = AdamW(model.parameters(), lr=5e-5)

In [4]:
# Example
# input = "Maschinelles Lernen ist großartig, oder?"
# input_ids = tokenizer.encode(input, return_tensors="pt", max_length=128, truncation=True)
# outputs = model.generate(input_ids)
# decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
# print(decoded) # "Machine learning is great, isn't it?"



Machine learning is great, isn't it?


In [5]:
# source_lang = "de"
# target_lang = "en"

# def preprocess(examples):
#   inputs = [ex[source_lang] for ex in examples["translation"]]
#   targets = [ex[target_lang] for ex in examples["translation"]]
#   return tokenizer(text=inputs, text_target=targets, padding="max_length", truncation=True)

In [6]:
# Get and preprocess the data
# ds = load_dataset("news_commentary", "de-en", split="train")
# ds = ds.map(preprocess, batched=True)
# ds = ds.remove_columns(["id", "translation"])
# ds = ds.train_test_split(test_size=0.2)
# ds.set_format("torch")



In [7]:
# small_train_dataset = ds["train"].shuffle(seed=42).select(range(1000))
# small_eval_dataset = ds["test"].shuffle(seed=42).select(range(1000))
# train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
# eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)

In [8]:
# num_epochs = 3
# num_training_steps = num_epochs * len(train_dataloader)
# lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [9]:
# device = torch.device("cpu")
# model = model.to(device)

In [None]:
# model.train()
# for epoch in range(num_epochs):
#   print('epoch {} begin'.format(epoch), flush=True)
#   for x, batch in enumerate(train_dataloader):
#     batch = {k: v.to(device) for k, v in batch.items()}
#     outputs = model(**batch)
#     loss = outputs.loss
#     loss.backward()

#     optimizer.step()
#     lr_scheduler.step()
#     optimizer.zero_grad()
#     if x % 100 == 0:
#       print('epoch {}: Step = {} Loss={:.5f}'.format(epoch, x, loss.item()), flush=True)
#   print('epoch {} end'.format(epoch), flush=True)

In [None]:

# metric = evaluate.load("accuracy")
# model.eval()
# for batch in eval_dataloader:
#   batch = {k: v.to(device) for k, v in batch.items()}
#   with torch.no_grad():
#     outputs = model(**batch)

#   logits = outputs.logits
#   predictions = torch.argmax(logits, dim=-1)
#   metric.add_batch(predictions=predictions, references=batch["labels"])

# metric.compute()

In [7]:
from datasets import load_dataset
from transformers import FSMTForConditionalGeneration, FSMTTokenizer, get_scheduler
import torch
from torch.optim import AdamW
import evaluate
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import time

# Hugging Face checkpoint id used to load both the model and the tokenizer below.
model_name = "facebook/wmt19-de-en"
# finetune() runs its dataset preparation through this executor so that the
# XLA processes do not download/process the same data at the same time.
SERIAL_EXEC = xmp.MpSerialExecutor()
# Pretrained weights are loaded once at module level; each process obtains its
# device copy with `WRAPPED_MODEL.to(device)` inside finetune().
# NOTE(review): MpModelWrapper is presumably used to avoid one host-memory copy
# of the weights per process -- confirm against the torch_xla docs.
WRAPPED_MODEL = xmp.MpModelWrapper(FSMTForConditionalGeneration.from_pretrained(model_name))
# Shared by finetune() for both preprocessing (encoding) and BLEU decoding.
tokenizer = FSMTTokenizer.from_pretrained(model_name)

def finetune():
  """Fine-tune the shared FSMT model on a small news_commentary de-en subset.

  Intended to run once per XLA process (it is called from ``mp_fn``). All
  hyper-parameters come from the process-global ``FLAGS`` dict: ``source_lang``,
  ``target_lang``, ``batch_size``, ``num_workers``, ``learning_rate``,
  ``num_epochs``, ``log_steps`` and ``metrics_debug``.

  Each epoch trains on this replica's shard of a 1000-example training subset,
  then prints a (teacher-forced) sacreBLEU score over a 1000-example held-out
  subset.
  """
  torch.manual_seed(1)

  def get_dataset():
    """Tokenize news_commentary de-en and return (train, test) 1000-example subsets."""
    pad_id = tokenizer.pad_token_id

    def preprocess(examples):
      inputs = [ex[FLAGS['source_lang']] for ex in examples["translation"]]
      targets = [ex[FLAGS['target_lang']] for ex in examples["translation"]]
      # NOTE(review): FSMTTokenizer keeps separate source/target vocabularies;
      # confirm that `text_target` really encodes targets with the target vocab.
      model_inputs = tokenizer(text=inputs, text_target=targets, padding="max_length", truncation=True)
      # Padded label positions hold pad_token_id, but the loss only ignores -100;
      # mask them so the model is not trained to predict padding.
      model_inputs["labels"] = [
          [-100 if tok == pad_id else tok for tok in seq] for seq in model_inputs["labels"]
      ]
      return model_inputs

    ds = load_dataset("news_commentary", "de-en", split="train")
    ds = ds.map(preprocess, batched=True)
    ds = ds.remove_columns(["id", "translation"])
    # Every process runs this function, so the split must be seeded: otherwise
    # each replica gets its own random partition and one replica's test rows end
    # up in another replica's training data.
    ds = ds.train_test_split(test_size=0.2, seed=42)
    ds.set_format("torch")
    return ds["train"].shuffle(seed=42).select(range(1000)), ds["test"].shuffle(seed=42).select(range(1000))

  # Run through the serial executor so that the processes do not download and
  # preprocess the same data concurrently.
  small_train_dataset, small_test_dataset = SERIAL_EXEC.run(get_dataset)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
      small_train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      small_train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  # Not sharded: every replica evaluates the whole test subset (BLEU is a
  # corpus-level score, so per-shard scores could not simply be averaged).
  # drop_last presumably keeps batch shapes static for XLA, at the cost of a
  # trailing partial batch.
  test_loader = torch.utils.data.DataLoader(
      small_test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  # Scale learning rate to world size
  lr = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get optimizer, scheduler and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = AdamW(model.parameters(), lr=lr)
  # len(train_loader) is the per-replica step count, matching one
  # lr_scheduler.step() per local optimizer step.
  lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=FLAGS['num_epochs'] * len(train_loader))

  pad_token_id = model.config.pad_token_id
  decoder_start_token_id = model.config.decoder_start_token_id
  if decoder_start_token_id is None:
    decoder_start_token_id = model.config.eos_token_id

  # Loaded once per process (serially, like the dataset) instead of every epoch.
  metric = SERIAL_EXEC.run(lambda: evaluate.load("sacrebleu"))

  def prepare_batch(batch):
    """Move a batch to the XLA device and add teacher-forcing decoder inputs."""
    batch = {k: v.to(device) for k, v in batch.items()}
    labels = batch["labels"].masked_fill(batch["labels"] == -100, pad_token_id)
    start = torch.full_like(labels[:, :1], decoder_start_token_id)
    # FSMTForConditionalGeneration.forward does not derive decoder inputs from
    # `labels` (when they are absent it shifts the *source* ids instead); verify
    # against the installed transformers version. Passing the right-shifted
    # labels explicitly is correct either way.
    batch["decoder_input_ids"] = torch.cat([start, labels[:, :-1]], dim=-1)
    return batch

  def train_loop_fn(loader):
    """Run one training epoch over this replica's shard, logging every log_steps."""
    tracker = xm.RateTracker()
    model.train()
    for step, batch in enumerate(loader):
      optimizer.zero_grad()
      outputs = model(**prepare_batch(batch))
      loss = outputs.loss
      loss.backward()
      xm.optimizer_step(optimizer)
      lr_scheduler.step()
      tracker.add(FLAGS['batch_size'])
      if step % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), step, loss.item(), tracker.rate(),
            tracker.global_rate(), time.asctime()), flush=True)

  def test_loop_fn(loader):
    """Compute and print sacreBLEU of teacher-forced argmax predictions."""
    model.eval()
    for batch in loader:
      batch = prepare_batch(batch)
      with torch.no_grad():
        outputs = model(**batch)

      # NOTE: argmax over teacher-forced logits, not model.generate(); this is an
      # optimistic proxy rather than a true translation BLEU.
      ignored = batch["labels"] == -100
      predictions = torch.argmax(outputs.logits, dim=-1)
      # Positions past the reference's end are noise; turn them into padding so
      # skip_special_tokens drops them from the hypothesis.
      predictions = predictions.masked_fill(ignored, tokenizer.pad_token_id)
      labels = batch["labels"].masked_fill(ignored, tokenizer.pad_token_id)

      decoded_preds = [pred.strip() for pred in tokenizer.batch_decode(predictions, skip_special_tokens=True)]
      decoded_labels = [[label.strip()] for label in tokenizer.batch_decode(labels, skip_special_tokens=True)]
      metric.add_batch(predictions=decoded_preds, references=decoded_labels)

    eval_metric = metric.compute()
    print('[xla:{}] Bleu={:.5f} Time={}'.format(
            xm.get_ordinal(), eval_metric["score"], time.asctime()), flush=True)

  # Train and eval loops
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    # Without set_epoch the DistributedSampler yields the same order every epoch.
    train_sampler.set_epoch(epoch)
    xm.master_print("Started training epoch {}".format(epoch))
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    xm.master_print("Evaluate epoch {}".format(epoch))
    para_loader = pl.ParallelLoader(test_loader, [device])
    test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

In [5]:
def mp_fn(rank, flags):
  """Per-process entry point: publish ``flags`` as the global FLAGS, then fine-tune.

  ``rank`` is the process index passed by the launcher (presumably
  ``xmp.spawn``); it is not used here. ``flags`` is the hyper-parameter dict
  that ``finetune`` reads through the module-global ``FLAGS``.
  """
  torch.set_default_tensor_type('torch.FloatTensor')
  global FLAGS
  FLAGS = flags
  return finetune()