# Requirements

In [None]:
!pip install datasets transformers sentencepiece torch tqdm

In [None]:
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets, load_metric
from transformers import AutoTokenizer, AutoModelForConditionalGeneration, AdamW, get_scheduler
from functools import partial
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.cuda import is_available
from torch import device
from tqdm.auto import tqdm
import json
import random
import os

# Utils

In [None]:
def train_test_split(dataset, test_ratio = 0.1, train_size = None, test_size = None):
    """Split *dataset* into a train and a test part, optionally truncating each.

    ``test_ratio`` is passed as the first positional argument (``test_size``)
    of ``dataset.train_test_split``. ``train_size`` / ``test_size`` cap how many
    leading rows of each split are kept (see ``slice_if_available``).

    Returns:
        A ``(train, test)`` tuple.
    """
    splits = dataset.train_test_split(test_ratio)
    train_part = slice_if_available(splits["train"], train_size)
    test_part = slice_if_available(splits["test"], test_size)
    return train_part, test_part

In [None]:
def slice_if_available(data, size):
    """Return the first *size* rows of *data*, or *data* itself when no cut is needed.

    Args:
        data: presumably a ``datasets.Dataset``; anything with ``__len__`` and
            ``select`` works.
        size: number of leading rows to keep. A falsy value (``None`` or 0)
            means "keep everything".

    Returns:
        ``data`` unchanged when *size* is falsy or equals ``len(data)``,
        otherwise a dataset holding rows ``0 .. size-1``.

    Raises:
        ValueError: if *size* is negative or larger than ``len(data)``.
    """
    if not size:
        return data

    n_rows = len(data)
    if size < 0 or size > n_rows:
        raise ValueError(f"'size' must be between 0 and len(data). size = {size}, len(data) = {n_rows}")
    if size == n_rows:
        # Asking for every row is valid; nothing to cut.
        return data
    # select() keeps the original column features, unlike rebuilding the
    # dataset from a plain dict of Python lists.
    return data.select(range(size))

In [None]:
def get_dataloader(dataset, generate = False, max_length = 32, batch_size = 32, shuffle = True):
    """Tokenize *dataset* and wrap it in a PyTorch ``DataLoader``.

    Rows are encoded in batches by ``tokenize_seq2seq``. Every column except
    ``input_ids`` / ``attention_mask`` / ``labels`` is then dropped, and the
    remaining columns are returned as torch tensors.
    """
    tokenize = partial(tokenize_seq2seq, generate = generate, max_length = max_length)
    encoded = dataset.map(tokenize, batched = True)

    model_columns = ('input_ids', 'attention_mask', 'labels')
    extra_columns = [name for name in encoded.column_names if name not in model_columns]
    encoded = encoded.remove_columns(extra_columns)
    encoded.set_format("torch")

    return DataLoader(encoded, batch_size = batch_size, shuffle = shuffle)

In [None]:
def tokenize_seq2seq(example, generate = False, max_length = 32, padding = "max_length", truncation = True):
    """Encode a batch of ``src``/``tgt`` pairs for seq2seq training or generation.

    Intended for ``Dataset.map(..., batched=True)``: *example* holds parallel
    lists under ``"src"`` and ``"tgt"``. Reads the module-level ``tokenizer``
    and ``architecture``.

    * mT5 sources get the ``"paraphrase: "`` task prefix.
    * BART sources and targets are lower-cased after mapping ``I -> ı`` and
      ``İ -> i`` (Turkish casing rules).
    * ``generate=False`` (training): right padding with ``"<pad>"``; padded
      label positions become the loss ``ignore_index``.
    * ``generate=True``: left padding with the EOS token; labels stay plain
      token ids.

    The tokenizer's ``padding_side`` and ``pad_token`` are switched only for
    the duration of this call and restored afterwards. The shared tokenizer is
    therefore not left in whatever mode the last call used.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` as lists.
    """
    task_prefix = "paraphrase: " if "mt5" in architecture else ""
    lower_map = {ord(u'I'): u'ı', ord(u'İ'): u'i'}

    def normalize(text, prefix = ""):
        # Map I/İ first so Turkish text lower-cases to ı/i rather than the
        # default str.lower results.
        text = prefix + text
        if "bart" in architecture:
            text = text.translate(lower_map).lower()
        return text

    # Build fresh lists instead of mutating the incoming batch in place.
    sources = [normalize(text, task_prefix) for text in example["src"]]
    targets = [normalize(text) for text in example["tgt"]]

    saved_padding_side, saved_pad_token = tokenizer.padding_side, tokenizer.pad_token
    if generate:
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer.padding_side = "right"
        tokenizer.pad_token = "<pad>"

    try:
        encoded_input = tokenizer(sources, padding = padding, truncation = truncation, max_length = max_length)
        labels = tokenizer(targets, padding = padding, truncation = truncation, max_length = max_length).input_ids
        # Read the pad id while the temporary pad token is still active.
        pad_id = tokenizer.pad_token_id
    finally:
        tokenizer.padding_side = saved_padding_side
        tokenizer.pad_token = saved_pad_token

    if not generate:
        # Padded label positions must not contribute to the loss.
        ignore_index = CrossEntropyLoss().ignore_index
        labels = [[ignore_index if token == pad_id else token for token in seq] for seq in labels]

    return {"input_ids": encoded_input.input_ids, "attention_mask": encoded_input.attention_mask, "labels": labels}

In [None]:
def train(dataloader, num_epochs = 2):
    """Run *num_epochs* optimisation passes over *dataloader*.

    Uses the module-level ``model``, ``optimizer`` and ``device``. For each
    batch it does the following:
    * moves every tensor to ``device``;
    * calls ``model(**batch)``;
    * back-propagates the returned ``loss``;
    * takes one optimizer step.
    """
    model.train()
    # Size the bar from this call's own arguments. Deriving it from globals
    # gives a wrong total whenever num_epochs differs from the global setting.
    progress_bar = tqdm(total = num_epochs * len(dataloader))

    # Clear any gradients left behind, e.g. by a cell interrupted between
    # backward() and step().
    optimizer.zero_grad()

    for _ in range(num_epochs):
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch).loss
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            progress_bar.update(1)

    progress_bar.close()

In [None]:
def test(dataloader):
    """Generate outputs for *dataloader* and score them with BLEU.

    Returns:
        ``(inputs, candidates, references, score)``. The first three are the
        per-batch decoded string lists produced by ``generate``. ``score`` is
        whatever the ``"bleu"`` metric's ``compute`` returns.
    """
    generated = generate(dataloader)
    inputs, candidates, references = generated
    predictions, bleu_refs = get_bleu_inputs(candidates, references)
    metric = load_metric("bleu")
    score = metric.compute(predictions = predictions, references = bleu_refs)
    return inputs, candidates, references, score

In [None]:
def generate(dataloader, max_length = 128, top_k = 50, top_p = 0.95, num_return_sequences = 1):
    """Sample outputs for every batch in *dataloader* and decode them.

    Uses the module-level ``model``, ``tokenizer`` and ``device``. The model is
    put in eval mode first, because ``model.generate`` does not do so itself.
    After ``train()`` has run, dropout would otherwise stay active while
    sampling.

    Args:
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors (as built by ``get_dataloader``).
        max_length, top_k, top_p: forwarded to ``model.generate``.
        num_return_sequences: samples drawn per input. Each decoded input and
            reference is repeated that many times, so the three returned lists
            stay aligned item-for-item.

    Returns:
        ``(inputs, candidates, references)``: three lists with one entry per
        batch, each entry a list of decoded strings.
    """
    model.eval()
    ignore_index = CrossEntropyLoss().ignore_index
    inputs, candidates, references = [], [], []

    def decode(ids):
        return tokenizer.batch_decode(ids, skip_special_tokens = True, clean_up_tokenization_spaces = True)

    def repeat(texts):
        # generate() returns the samples for one input contiguously, so repeat
        # each source/reference in place to stay aligned with the candidates.
        return [text for text in texts for _ in range(num_return_sequences)]

    progress_bar = tqdm(range(len(dataloader)))

    for batch in dataloader:
        outputs = model.generate(
                input_ids = batch["input_ids"].to(device),
                attention_mask = batch["attention_mask"].to(device),
                do_sample = True,
                max_length = max_length,
                top_k = top_k,
                top_p = top_p,
                early_stopping = True,
                num_return_sequences = num_return_sequences)

        # Training-style labels hold the loss ignore_index, which is not a valid
        # token id; map it back to padding so decoding does not fail.
        labels = batch["labels"].masked_fill(batch["labels"] == ignore_index, tokenizer.pad_token_id)

        inputs += [repeat(decode(batch["input_ids"]))]
        candidates += [decode(outputs)]
        references += [repeat(decode(labels))]

        progress_bar.update(1)

    progress_bar.close()

    return inputs, candidates, references

In [None]:
def get_bleu_inputs(candidates, references):
    """Flatten batched decoded strings into the token lists BLEU expects.

    Args:
        candidates: list of batches, each a list of generated strings.
        references: list of batches, each a list of reference strings, aligned
            with *candidates* (``zip`` semantics: extra items are ignored).

    Returns:
        ``(bleu_candidates, bleu_references)``. Each candidate becomes a list
        of lower-cased whitespace tokens. Each reference becomes a one-element
        list wrapping its token list (a single reference per prediction).
    """
    pairs = [
        (candidate, reference)
        for candidate_batch, reference_batch in zip(candidates, references)
        for candidate, reference in zip(candidate_batch, reference_batch)
    ]
    bleu_candidates = [candidate.lower().split() for candidate, _ in pairs]
    bleu_references = [[reference.lower().split()] for _, reference in pairs]
    return bleu_candidates, bleu_references

In [None]:
def generate_paraphrase(src, max_length = 128, num_return_sequences = 5, num_beams = 5):
    """Return ``num_return_sequences`` beam-search paraphrases of *src*.

    Uses the module-level ``model``, ``tokenizer``, ``architecture`` and
    ``device``. The input is preprocessed the same way ``tokenize_seq2seq``
    preprocesses training sources:
    * mT5 gets the ``"paraphrase: "`` prefix;
    * BART input is lower-cased after mapping ``I -> ı`` and ``İ -> i``.
    """
    if "mt5" in architecture:
        src = 'paraphrase: ' + src
    if "bart" in architecture:
        # Training targets/sources were lower-cased this way; feed the model
        # the same kind of text at inference time.
        src = src.translate({ord(u'I'): u'ı', ord(u'İ'): u'i'}).lower()

    # generate() does not switch modes itself. After train() dropout would
    # otherwise still be active.
    model.eval()

    tokenized = tokenizer(src, return_tensors = 'pt').to(device)
    outputs = model.generate(
        input_ids = tokenized['input_ids'],
        attention_mask = tokenized['attention_mask'],
        max_length = max_length,
        num_return_sequences = num_return_sequences,
        num_beams = num_beams)

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Arguments

In [None]:
# Experiment identity. dataset/architecture name the checkpoint folders, and
# architecture is also matched by substring ("mt5", "bart") to pick preprocessing.
dataset = 'name_of_dataset'
architecture = "model_architecture" # mt5-base, mt5-small, bart
model_checkpoint = 'model_checkpoint' # google/mt5-base, mukayese/bart-base-turkish-sum
root = '/path/to/project'

# Schedule: non-zero starting_epoch resumes from a saved model-{starting_epoch} folder.
starting_epoch = 0
num_train_loops = 2  # train -> save -> evaluate rounds
num_epochs = 1       # epochs per round
lr = 1e-4

# Preprocessing

In [None]:
# Load the train/test CSV splits. Each file is expected to provide the "src"
# and "tgt" columns read by the tokenization helper.
data = load_dataset(
    "csv",
    data_files = {
        "train": f"{root}/path/to/train_dataset.csv",
        "test": f"{root}/path/to/test_dataset.csv",
    },
)

In [None]:
# The tokenizer is loaded from the base checkpoint name. Later cells may
# repoint model_checkpoint at a saved folder, but only the model is saved there.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = model_checkpoint)

In [None]:
# Training batches: right-padded, padded label positions masked out of the loss.
train_dataloader = get_dataloader(data["train"], generate = False)
# Evaluation batches: left-padded for generation and kept in file order.
test_dataloader = get_dataloader(data["test"], shuffle = False, generate = True)

In [None]:
# Go through the torch module instead of the imported `device` constructor.
# This cell rebinds the global name `device` to a device instance, so calling
# `device(...)` here would raise TypeError whenever the cell is re-run.
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Resume from the folder saved after `starting_epoch` epochs, if one was requested.
if starting_epoch:
    model_checkpoint = f'{root}/models/{architecture}-{dataset}/model-{starting_epoch}'

In [None]:
# transformers has no `AutoModelForConditionalGeneration`. The auto class for
# encoder-decoder generation models such as mT5 and BART is AutoModelForSeq2SeqLM.
# NOTE(review): the top-of-file import of AutoModelForConditionalGeneration
# fails for the same reason and should be removed there.
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)

In [None]:
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to the transformers defaults (1e-6, 0.0), because
# torch's own defaults differ (1e-8, 0.01).
import torch
optimizer = torch.optim.AdamW(model.parameters(), lr = lr, eps = 1e-6, weight_decay = 0.0)

In [None]:
# Total optimizer steps for the whole run: batches per epoch x epochs per round x rounds.
num_training_steps = len(train_dataloader) * num_epochs * num_train_loops

In [None]:
# Alternate training and evaluation. Each round:
# 1. trains for num_epochs;
# 2. saves a checkpoint;
# 3. scores the test set with BLEU;
# 4. writes the score plus one randomly chosen batch of
#    (input, candidate, reference) triples to result.json in that checkpoint folder.
task_prefix = "paraphrase: " if "mt5" in architecture else ""
prefix_length = len(task_prefix)
results = []
for i in range(num_train_loops):
    train(train_dataloader, num_epochs = num_epochs)

    foldername = f'{root}/models/{architecture}-{dataset}/model-{starting_epoch + num_epochs*(i+1)}'
    os.makedirs(foldername, exist_ok=True)
    model.save_pretrained(foldername)

    result = test(test_dataloader)
    results.append(result)

    batches = list(zip(result[0], result[1], result[2]))
    sample_inputs, sample_candidates, sample_references = random.choice(batches) if batches else ([], [], [])
    # Named `report`, not `data`, so the loaded DatasetDict held in the global
    # `data` is not overwritten.
    report = {
        'score': result[-1],
        'data': [
            # Strip the task prefix only when the decoded input actually starts with it.
            {'i': text[prefix_length:] if text.startswith(task_prefix) else text, 'c': candidate, 'r': reference}
            for text, candidate, reference in zip(sample_inputs, sample_candidates, sample_references)
        ],
    }
    with open(f'{foldername}/result.json', 'w', encoding='utf-8') as f:
        json.dump(report, f, ensure_ascii=False, indent=4)

# Inference

In [None]:
# Sanity check: paraphrase a single Turkish sentence with the current model.
generate_paraphrase(src = "Erkek işini yapmak için çocuk yollayamazsın.")