In [None]:
!pip install -U datasets huggingface-hub

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from huggingface_hub import notebook_login
import torch
import json
import re

In [None]:
# Authenticate this notebook session with the Hugging Face Hub; the generation loop
# later uploads data with push_dataset -> push_to_hub, which needs write credentials.
notebook_login()

In [None]:
# Local (Kaggle input) copy of the Mistral 7B v0.1 checkpoint in Hugging Face format.
mistral_checkpoint = "/kaggle/input/mistral/pytorch/7b-v0.1-hf/1"

mistral_tokenizer = AutoTokenizer.from_pretrained(mistral_checkpoint)
mistral_model = AutoModelForCausalLM.from_pretrained(
    mistral_checkpoint,
    torch_dtype=torch.bfloat16,  # load weights in bfloat16
    device_map="auto",           # let transformers decide device placement
    trust_remote_code=True,      # NOTE(review): Mistral may be supported natively by the installed transformers; confirm whether remote code is needed
)

# Prompts are moved to the device holding the model's first parameter tensor.
first_param = next(mistral_model.parameters())
mistral_device = first_param.device
del first_param

In [None]:
# Pretrained T5 checkpoint used for English -> German translation.
t5_checkpoint = "t5-base"
# Explicit model_max_length overrides the tokenizer's default maximum input length.
t5_tokenizer = AutoTokenizer.from_pretrained(
    t5_checkpoint,
    model_max_length=1024,
)

In [None]:
# Translation pipeline used to build the German side of each pair.
# Pass the tokenizer configured in the previous cell (model_max_length=1024);
# without it, pipeline() loads its own default tokenizer from t5_checkpoint
# and t5_tokenizer is never used.
t5_translator = pipeline("translation_en_to_de"
                         , model = t5_checkpoint
                         , tokenizer = t5_tokenizer
                         , clean_up_tokenization_spaces = True)

In [None]:
def get_gen_text(prompt, model, tokenizer, device, max_length=1024, min_chars=4):
    """Sample a continuation of ``prompt`` and split it into candidate sentences.

    Args:
        prompt: Instruction text given to the causal language model.
        model: Model exposing ``generate``.
        tokenizer: Tokenizer paired with ``model``; its ``eos_token_id`` is used
            as the padding id during generation.
        device: Device the encoded prompt is moved to before generation.
        max_length: Total token budget (prompt + continuation) passed to
            ``generate``. Defaults to the previously hard-coded 1024.
        min_chars: Segments shorter than this many characters after stripping
            (list numbers, stray fragments) are discarded. The default of 4
            matches the original ``len(x) > 3`` filter.

    Returns:
        A list of stripped sentences taken from the generated continuation only.
        Sentences end at ``.``, ``!``, ``?`` or a line break. The final segment is
        dropped because ``max_length`` may have cut it off mid-sentence.
    """
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    prompt_len = model_inputs["input_ids"].shape[1]

    generated_ids = model.generate(**model_inputs
                                   , max_length = max_length
                                   , pad_token_id = tokenizer.eos_token_id
                                   , do_sample=True)
    # Decode only the tokens produced after the prompt. Stripping the prompt via
    # str.replace on the decoded text fails whenever detokenization does not
    # reproduce it exactly (e.g. its leading newline/indentation), which leaves
    # the prompt in the output. skip_special_tokens also removes <s> / </s>.
    gen_text = tokenizer.decode(generated_ids[0, prompt_len:], skip_special_tokens=True)

    # Line breaks act as sentence boundaries rather than being deleted: deleting
    # "\n" glued the last word of one line onto the first word of the next.
    segments = re.split(r'[.!?\r\n]', gen_text)[:-1]

    return [s.strip() for s in segments if len(s.strip()) >= min_chars]


def push_dataset(file_path, dataset_config, repo):
    """Build a dataset from a local data file and upload it to the Hugging Face Hub.

    Args:
        file_path: Passed as ``data_files`` to ``load_dataset``.
        dataset_config: Builder name given to ``load_dataset`` (this notebook uses ``"json"``).
        repo: Hub repository id the loaded dataset is pushed to.
    """
    load_dataset(dataset_config, data_files=file_path).push_to_hub(repo)

In [None]:
output_path = "/kaggle/working/output.jsonl"
dataset_repo = "jaymanvirk/synthetic_parallel_corpora"
dataset_config = "json"

# Fetch the already-published corpus, bypassing any cached copy, and seed the
# local JSONL file with every existing row so later pushes include them.
existing_rows = load_dataset(dataset_repo
                             , download_mode = "force_redownload")["train"]
# Number of existing rows; the generation loop uses it as the next record id.
last_index = existing_rows.num_rows

with open(output_path, 'w', encoding='utf-8') as json_file:
    json_file.writelines(json.dumps(row, ensure_ascii=False) + '\n'
                         for row in existing_rows)

del existing_rows

In [None]:
batch_size = 50        # sentences requested from the LM per prompt
num_iters = 200        # generation rounds in this run
upload_every = 500     # push to the Hub once this many new records have accumulated
last_pushed = last_index  # rows already on the Hub (loaded in the previous cell)

# The prompt does not change between iterations, so build it once.
prompt = f'''
                Create {batch_size} unique and diverse sentences that cover a wide range of topics, emotions, and styles.
                '''

for i in range(num_iters):
    gen_text = get_gen_text(prompt
                             , mistral_model
                             , mistral_tokenizer
                             , mistral_device)

    # The "translation_en_to_de" pipeline already prepends T5's task prefix
    # ("translate English to German: ") from the model config, so pass the raw
    # sentences; adding the prefix manually fed it to the model twice.
    translation = ([r['translation_text'] for r in t5_translator(gen_text)]
                   if gen_text else [])

    # UTF-8 + ensure_ascii=False matches how the seeding cell wrote existing rows.
    with open(output_path, "a", encoding="utf-8") as f:
        for j, (en, de) in enumerate(zip(gen_text, translation)):
            record = {'id': (j+last_index), 'translation': {'en': en, 'de': de}}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    last_index += len(gen_text)

    if last_index - last_pushed >= upload_every:
        print(f"uploading {last_index} records to HF")
        push_dataset(output_path, dataset_config, dataset_repo)
        last_pushed = last_index

    print(f"iteration: {i+1}/{num_iters} | completed: {last_index}")

# Push whatever accumulated after the last periodic upload so the tail of the
# run is not left only in the local file.
if last_index > last_pushed:
    print(f"uploading {last_index} records to HF")
    push_dataset(output_path, dataset_config, dataset_repo)
