# Candidate sampling (for all of train/valid/test)

## Set-up

In [None]:
# Switches selecting which data splits the sampling cells further down process.
DO_TRAIN = True  # run the cell guarded by `if DO_TRAIN:`
DO_VALID = True  # run the cell guarded by `if DO_VALID:` (reads the "dev" split)
DO_TEST = True  # run the cell guarded by `if DO_TEST:`

In [None]:
# Name of the data directory: used as-is when running locally and appended to
# the Google Drive project folder when running in Colab.
DATA_DIR = "data"

In [None]:
# Detect Google Colab: when the `google.colab` package is importable, mount
# Google Drive and read/write data there; otherwise use the local directory.
try:
    import google.colab
    from google.colab import drive
    drive.mount('/content/drive')
    FULL_DATA_DIR = f'/content/drive/My Drive/mbr-reranking/{DATA_DIR}'

    IN_COLAB = True
except Exception:
    # Best-effort fallback (import or mount failed). `Exception` rather than a
    # bare `except` so that KeyboardInterrupt / SystemExit still propagate.
    FULL_DATA_DIR = DATA_DIR

    IN_COLAB = False

In [None]:
import locale


def getpreferredencoding(do_setlocale=True):
    """Report UTF-8 as the preferred encoding, ignoring the actual locale.

    The `do_setlocale` argument exists only to mirror the signature of
    `locale.getpreferredencoding`; it is not used.
    """
    return "UTF-8"


# Monkey-patch the standard library so every later call returns "UTF-8".
# NOTE(review): presumably a workaround for environments whose locale is not
# UTF-8 -- confirm it is still needed.
locale.getpreferredencoding = getpreferredencoding

In [None]:
!pip install datasets transformers[sentencepiece]
!pip install sentencepiece
!pip install nltk

In [None]:
import torch
from tqdm import tqdm
import json

# Device string used for the model and its inputs: the GPU when CUDA is
# available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Model

In [None]:
# Which pretrained model to load: "mbart" selects the mBART-50 checkpoint, any
# other value selects the t5-small branch.
PRETRAINED_MODEL = "mbart"
# When True, `sample_sentences` expects batches that were already tokenized by
# `collator`. NOTE(review): `get_dataset` currently builds its DataLoader
# without a `collate_fn` (that code is commented out), so True will not work
# unless the loader is changed accordingly.
USE_COLLATE_FN = False

In [None]:
# Load the tokenizer and the model (in eval mode, on `device`) for the
# configured PRETRAINED_MODEL.
if PRETRAINED_MODEL == "mbart":
    # The mBART-50 checkpoint must be paired with the mBART-50 tokenizer: it
    # formats text as "[lang_code] X </s>", whereas the original mBART
    # tokenizer (`MBartTokenizer`) appends the language code after the
    # sentence, which does not match how this checkpoint was trained.
    from transformers import MBartForConditionalGeneration, MBart50Tokenizer
    tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt", src_lang="de_DE")
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
    model = model.eval()
else:
    from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoConfig
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
    # Raise the default generation length limit stored in the model config.
    model.config.max_length = 512
    model = model.eval()

In [None]:
def get_sentence_pairs(split="train", subsample_factor=1):
    """Load the German/English sentence files of one data split.

    Reads `{FULL_DATA_DIR}/{split}.deu` and `{FULL_DATA_DIR}/{split}.eng`,
    one sentence per line with surrounding whitespace stripped, and prints
    either the number of pairs or a warning when the two files differ in
    length.

    Args:
        split: Base name of the files to read (e.g. "train", "dev", "test").
        subsample_factor: Only the first `len(source) // subsample_factor`
            entries of each list are returned.

    Returns:
        A `(source_sentences, target_sentences)` tuple of lists.
    """
    def read_stripped(extension):
        # One entry per line of the file, stripped of surrounding whitespace.
        with open(f"{FULL_DATA_DIR}/{split}.{extension}", 'r', encoding='utf-8') as handle:
            return [row.strip() for row in handle]

    sentences_src = read_stripped("deu")
    sentences_dst = read_stripped("eng")

    n_src = len(sentences_src)
    n_dst = len(sentences_dst)
    if n_src != n_dst:
        print(f"WARNING: different number of {split} sentences: {n_src} vs {n_dst}")
    else:
        print(n_src, f"{split} pairs")

    # The cut-off is derived from the source side only.
    keep = n_src // subsample_factor
    return sentences_src[:keep], sentences_dst[:keep]

In [None]:
from torch.utils.data import Sampler, DataLoader
import math

class LengthBatchSampler(Sampler):
    """Batch sampler that yields consecutive chunks of a pre-ordered sequence.

    Only the first element of every `(first, second)` pair passed to the
    constructor is kept; batches are slices of that list in its given order,
    so the caller decides the ordering (e.g. by sorting beforehand).
    """

    def __init__(self, sorted_data_with_indices, batch_size):
        """Store the first element of each pair and the batch size.

        Args:
            sorted_data_with_indices: Iterable of 2-tuples, already ordered.
            batch_size: Maximum number of entries per batch; must be positive.

        Raises:
            ValueError: If `batch_size` is not positive (iteration would
                otherwise never terminate).
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.sorted_data_with_indices = [first for first, _ in sorted_data_with_indices]
        self.batch_size = batch_size

    def __iter__(self):
        """Yield successive batches; the last one may be shorter."""
        entries = self.sorted_data_with_indices
        for start in range(0, len(entries), self.batch_size):
            yield entries[start:start + self.batch_size]

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        return math.ceil(len(self.sorted_data_with_indices) / self.batch_size)

from torch.utils.data import Dataset

class SimpleDataset(Dataset):
    """Map-style dataset that exposes an in-memory sequence unchanged."""

    def __init__(self, data):
        """Keep a reference to `data`; no copy or transformation is made."""
        self.data = data

    def __getitem__(self, idx):
        """Return the stored item at position `idx`."""
        stored = self.data
        return stored[idx]

    def __len__(self):
        """Return the number of stored items."""
        stored = self.data
        return len(stored)


def _build_collator(prefix):
    """Create a collate function that prepends `prefix` to every sentence."""

    def collator(batch):
        """Tokenize a batch of `(index, sentence)` pairs.

        Returns a dict with the original indices (a tuple) plus the padded
        `input_ids` and `attention_mask` tensors produced by the tokenizer.
        """
        # Split the pairs into parallel tuples of indices and sentences.
        indices, sentences = zip(*batch)
        texts = [prefix + text for text in sentences]

        # Padding and truncation are handled by the tokenizer itself.
        encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        return {
            "indices": indices,
            "input_ids": encoded['input_ids'],
            "attention_mask": encoded['attention_mask'],
        }

    return collator


# mBART takes the raw sentence; the other (T5) branch needs a task prefix.
# The choice is made once, when this cell runs.
# NOTE(review): the prefix reads "English to German" although the source
# sentences are loaded from `.deu` files -- confirm this is intended.
collator = _build_collator("" if PRETRAINED_MODEL == "mbart" else "translate English to German: ")

## Dataset

In [None]:
def get_dataset(split="train", batch_size=1, subsample_factor=1):
    """Build a DataLoader over the source sentences of a split, shortest first.

    The source sentences are paired with their original position and sorted
    by character length, so each batch holds sentences of similar length. The
    loader uses the default collate function and no custom sampler
    (`LengthBatchSampler` / `collator` are not wired in here).

    Args:
        split: Data split passed on to `get_sentence_pairs`.
        batch_size: Number of sentences per batch.
        subsample_factor: Passed on to `get_sentence_pairs`.

    Returns:
        `(sentences_src, data_loader)`: the source sentences in their original
        order, and a loader over `(original_index, sentence)` items.
    """
    sentences_src, _ = get_sentence_pairs(split=split, subsample_factor=subsample_factor)

    def sentence_length(indexed_sentence):
        # Items are (original_index, sentence) pairs.
        return len(indexed_sentence[1])

    by_length = sorted(enumerate(sentences_src), key=sentence_length)
    data_loader = DataLoader(SimpleDataset(by_length), batch_size=batch_size, shuffle=False)

    return sentences_src, data_loader

## Sampling

In [None]:
def _split_batch(batch, max_length):
    """Turn one DataLoader batch into `(indices, model_inputs)` on `device`."""
    if USE_COLLATE_FN:
        # `collator` already tokenized the batch; only move the tensors over.
        model_inputs = {
            "input_ids": batch['input_ids'].to(device),
            "attention_mask": batch['attention_mask'].to(device),
        }
        return batch['indices'], model_inputs

    # Default collation: a collection of indices plus the raw sentences.
    indices, sentences = batch
    sentences = list(sentences)
    if PRETRAINED_MODEL != "mbart":
        # NOTE(review): the prefix reads "English to German" although the
        # source sentences are loaded from `.deu` files -- confirm intent.
        sentences = ["translate English to German: " + sentence for sentence in sentences]
    model_inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=max_length).to(device)
    return indices, model_inputs


def sample_sentences(data_loader, num_sentences, num_samples=1, num_beams=1, do_sample=True, max_length=384, temperature=1.0):
    """Generate `num_samples` candidate translations for every source sentence.

    Args:
        data_loader: Loader yielding either `(indices, sentences)` batches or,
            when `USE_COLLATE_FN` is set, the dicts produced by `collator`.
        num_sentences: Total number of source sentences; sizes the result.
        num_samples: Number of sequences returned per source sentence.
        num_beams: Beam count passed to `model.generate`.
        do_sample: Whether `model.generate` samples.
        max_length: Maximum length for tokenization (non-collate path) and
            for generation.
        temperature: Temperature passed to `model.generate`.

    Returns:
        A flat list of `num_sentences * num_samples` strings in which the
        candidates for source sentence `i` occupy positions
        `[i * num_samples, (i + 1) * num_samples)`.
    """
    sampled_sentences = [""] * (num_sentences * num_samples)

    generate_kwargs = {
        "do_sample": do_sample,
        "num_return_sequences": num_samples,
        "num_beams": num_beams,
        "eos_token_id": tokenizer.eos_token_id,
        "max_length": max_length,
        "temperature": temperature,
    }
    if PRETRAINED_MODEL == "mbart":
        # Force English as the target language of the multilingual model.
        generate_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id["en_XX"]

    # Fixed seed so repeated calls draw the same samples.
    torch.manual_seed(0)
    with torch.cuda.amp.autocast(enabled=True), torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader)):
            indices, model_inputs = _split_batch(batch, max_length)

            translated_tokens = model.generate(**model_inputs, **generate_kwargs).detach().cpu()
            decoded = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)

            # `decoded` holds `num_samples` consecutive entries per input
            # sentence; copy each group to the slot of its original index.
            for j, index in enumerate(indices):
                slot = int(index)
                sampled_sentences[num_samples*slot:num_samples*(slot+1)] = decoded[num_samples*j:num_samples*(j+1)]

            torch.cuda.empty_cache()

    return sampled_sentences

Train:

In [None]:
# Sample candidates for the training split and record how many were drawn per
# source sentence.
if DO_TRAIN:
    import os

    SAMPLES_PER_SENTENCE = 2
    BATCH_SIZE = 128 // SAMPLES_PER_SENTENCE

    # Create the output directory before sampling so that a missing folder
    # cannot make the write fail after the expensive generation step.
    os.makedirs(f'{FULL_DATA_DIR}/sampled', exist_ok=True)

    sentences_src, data_loader = get_dataset("train", batch_size=BATCH_SIZE, subsample_factor=1) # subsample_factor=SAMPLES_PER_SENTENCE
    sampled_sentences = sample_sentences(data_loader, num_sentences=len(sentences_src), num_samples=SAMPLES_PER_SENTENCE)
    # One candidate per line, SAMPLES_PER_SENTENCE consecutive lines per source.
    with open(f'{FULL_DATA_DIR}/sampled/train.eng', 'w', encoding='utf-8') as file:
        file.writelines(sentence + '\n' for sentence in sampled_sentences)
    with open(f'{FULL_DATA_DIR}/sampled/train.info.json', 'w', encoding='utf-8') as file:
        json.dump({
            "samples_per_sentence": SAMPLES_PER_SENTENCE
        }, file)

Validation:

In [None]:
# Produce three candidate files for the "dev" split: samples at the default
# temperature, samples at temperature 0.7, and a single beam-search output.
if DO_VALID:
    import os

    # Create the output directories before sampling so that a missing folder
    # cannot make a write fail after the expensive generation step.
    os.makedirs(f'{FULL_DATA_DIR}/sampled', exist_ok=True)
    os.makedirs(f'{FULL_DATA_DIR}/beams', exist_ok=True)

    SAMPLES_PER_SENTENCE = 30
    BATCH_SIZE = 4
    sentences_src, data_loader = get_dataset("dev", batch_size=BATCH_SIZE)
    sampled_sentences = sample_sentences(data_loader, num_sentences=len(sentences_src), num_samples=SAMPLES_PER_SENTENCE)
    with open(f'{FULL_DATA_DIR}/sampled/dev.eng', 'w', encoding='utf-8') as file:
        file.writelines(sentence + '\n' for sentence in sampled_sentences)
    # Same loader, lower temperature ("cold" samples).
    sampled_sentences = sample_sentences(data_loader, num_sentences=len(sentences_src), num_samples=SAMPLES_PER_SENTENCE, temperature=0.7)
    with open(f'{FULL_DATA_DIR}/sampled/dev-cold.eng', 'w', encoding='utf-8') as file:
        file.writelines(sentence + '\n' for sentence in sampled_sentences)

    # Deterministic baseline: one beam-search output per sentence.
    SAMPLES_PER_SENTENCE = 1
    BATCH_SIZE = 16
    sentences_src, data_loader = get_dataset("dev", batch_size=BATCH_SIZE)
    sampled_sentences = sample_sentences(data_loader, num_sentences=len(sentences_src), num_samples=SAMPLES_PER_SENTENCE, num_beams=10, do_sample=False)
    with open(f'{FULL_DATA_DIR}/beams/dev.eng', 'w', encoding='utf-8') as file:
        file.writelines(sentence + '\n' for sentence in sampled_sentences)

Test:

In [None]:
# Produce three candidate files for the "test" split: samples at the default
# temperature, samples at temperature 0.7, and a single beam-search output.
if DO_TEST:
    import os

    # Create the output directories before sampling so that a missing folder
    # cannot make a write fail after the expensive generation step.
    os.makedirs(f'{FULL_DATA_DIR}/sampled', exist_ok=True)
    os.makedirs(f'{FULL_DATA_DIR}/beams', exist_ok=True)

    SAMPLES_PER_SENTENCE = 30
    BATCH_SIZE = 4
    sentences_src, data_loader = get_dataset("test", batch_size=BATCH_SIZE)
    sampled_sentences = sample_sentences(data_loader, num_sentences=len(sentences_src), num_samples=SAMPLES_PER_SENTENCE)
    with open(f'{FULL_DATA_DIR}/sampled/test.eng', 'w', encoding='utf-8') as file:
        file.writelines(sentence + '\n' for sentence in sampled_sentences)
    # Same loader, lower temperature ("cold" samples).
    sampled_sentences = sample_sentences(data_loader, num_sentences=len(sentences_src), num_samples=SAMPLES_PER_SENTENCE, temperature=0.7)
    with open(f'{FULL_DATA_DIR}/sampled/test-cold.eng', 'w', encoding='utf-8') as file:
        file.writelines(sentence + '\n' for sentence in sampled_sentences)

    # Deterministic baseline: one beam-search output per sentence.
    SAMPLES_PER_SENTENCE = 1
    BATCH_SIZE = 16
    sentences_src, data_loader = get_dataset("test", batch_size=BATCH_SIZE)
    sampled_sentences = sample_sentences(data_loader, num_sentences=len(sentences_src), num_samples=SAMPLES_PER_SENTENCE, num_beams=10, do_sample=False)
    with open(f'{FULL_DATA_DIR}/beams/test.eng', 'w', encoding='utf-8') as file:
        file.writelines(sentence + '\n' for sentence in sampled_sentences)

## End

If in Google Colab, kill the session:

In [None]:
# Release the Colab runtime once the notebook has finished; does nothing when
# running outside Colab.
if IN_COLAB:
    import time
    # NOTE(review): presumably gives pending writes to the mounted Drive time
    # to complete before the runtime goes away -- confirm.
    time.sleep(15)

    from google.colab import runtime
    # Ends the session, as announced by the notebook text above this cell.
    runtime.unassign()