In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from huggingface_hub import notebook_login
import torch
import json
import re



In [2]:
# Interactive Hugging Face Hub login (renders a widget); the stored token is
# what lets the later `push_to_hub` call upload the generated dataset.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
# Mistral-7B checkpoint mounted as a Kaggle input dataset; it generates the
# English source sentences.
mistral_checkpoint = "/kaggle/input/mistral/pytorch/7b-v0.1-hf/1"
mistral_tokenizer = AutoTokenizer.from_pretrained(mistral_checkpoint)

# bfloat16 weights, with `device_map="auto"` deciding layer placement.
# NOTE(review): `trust_remote_code=True` allows custom code shipped with the
# checkpoint to run; it is probably unnecessary for this checkpoint -- confirm
# and drop it if so.
mistral_model = AutoModelForCausalLM.from_pretrained(
    mistral_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Device of the first parameter; prompts are moved here before generation.
mistral_device = next(mistral_model.parameters()).device

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Checkpoint name for the English -> German translator; it is also passed to
# the translation pipeline created in the next cell.
t5_checkpoint = "t5-base"
# NOTE(review): `t5_tokenizer` is not referenced again in the visible cells
# (the pipeline is built from the checkpoint name) -- confirm whether it is
# still needed.
t5_tokenizer = AutoTokenizer.from_pretrained(t5_checkpoint)

Downloading config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [5]:
# English -> German translation pipeline built from the T5 checkpoint.
t5_translator = pipeline(
    "translation_en_to_de",
    model=t5_checkpoint,
    clean_up_tokenization_spaces=True,
)

Downloading model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [6]:
def get_gen_text(prompt, model, tokenizer, device):
    """Sample a continuation of `prompt` and split it into sentences.

    Args:
        prompt: Instruction text fed to the causal language model.
        model: Model exposing `generate(**inputs, ...)`.
        tokenizer: Tokenizer matching `model`; must provide `eos_token_id`
            and `decode`.
        device: Device the encoded prompt is moved to before generation.

    Returns:
        A list of non-empty, whitespace-stripped sentence strings. The text
        is split on '.', '!' and '?'; the fragment after the last terminator
        is discarded because it is usually a sentence cut off by the length
        limit.
    """
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    prompt_len = len(model_inputs["input_ids"][0])

    generated_ids = model.generate(
        **model_inputs,
        max_length=1000,  # counts prompt tokens as well as generated ones
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    # A causal LM returns the prompt tokens followed by the continuation.
    # Slicing the prompt off by token count is more reliable than removing it
    # with str.replace(), which silently fails whenever decoding does not
    # reproduce the prompt character-for-character.
    new_token_ids = generated_ids[0][prompt_len:]
    gen_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)

    # Newlines become spaces so words on adjacent lines are not glued together.
    fragments = re.split(r'[.!?]', gen_text.replace("\n", " "))[:-1]

    # Drop empty fragments (e.g. from "..." or "?!") so they are never
    # translated or stored as records.
    return [sentence for sentence in (frag.strip() for frag in fragments) if sentence]


def push_dataset(file_path, dataset_config, repo):
    """Load a local data file as a dataset and upload it to the Hub.

    Args:
        file_path: Path of the data file(s), passed to `load_dataset` as
            `data_files`.
        dataset_config: Loader name passed to `load_dataset` as its first
            argument.
        repo: Hub repository id passed to `push_to_hub`.
    """
    load_dataset(dataset_config, data_files=file_path).push_to_hub(repo)

In [None]:
# Hub repository holding the generated dataset. It is defined here, with the
# same value as in the generation cell's configuration, so this cell also runs
# on a fresh kernel instead of failing with NameError.
dataset_repo = "jaymanvirk/synthetic_text_en_de"

# Number of rows already published in the train split; the download is forced
# so a stale local cache cannot report an outdated count.
start = load_dataset(
    dataset_repo,
    download_mode="force_redownload",
)["train"].num_rows

In [29]:
# --- Run configuration -------------------------------------------------------
output_path = "/kaggle/working/output.jsonl"
dataset_repo = "jaymanvirk/synthetic_text_en_de"
dataset_config = "json"
batch_size = 50       # number of sentences requested from the model per prompt
num_iters = 200       # number of generate -> translate -> append rounds
upload_every = 1000   # push to the Hub each time this many new records exist
last_index = 0        # records written so far during this run
threshold = upload_every  # value of `last_index` that triggers the next upload
last_uploaded = 0     # value of `last_index` at the most recent upload

# The prompt does not change between iterations, so it is built once.
prompt = f'''
                Write {batch_size} different short sentences.
                '''

# NOTE(review): every upload replaces the Hub dataset with the contents of
# `output_path`. If that file does not already hold the rows counted in
# `start`, those rows are lost on the first upload -- confirm the file persists.
for i in range(num_iters):
    # Skip empty strings so blank sentences are never translated or stored.
    gen_text = [
        sentence
        for sentence in get_gen_text(prompt, mistral_model, mistral_tokenizer, mistral_device)
        if sentence
    ]

    # The "translation_en_to_de" pipeline already prepends T5's task prefix
    # ("translate English to German: "); adding it here would duplicate it.
    translation = [t5_translator(x)[0]['translation_text'] for x in gen_text]

    with open(output_path, "a") as f:
        for j, (en, de) in enumerate(zip(gen_text, translation)):
            # Offset ids by the rows already published (`start`) so appended
            # records do not reuse ids 0..N on every run.
            record = {'id': start + last_index + j, 'translation': {'en': en, 'de': de}}
            f.write(json.dumps(record) + "\n")

    last_index += len(gen_text)

    if last_index >= threshold:
        print(f"uploading {last_index} records to HF")
        # Advance by a fixed step; the previous `threshold += threshold`
        # doubled the interval after every upload.
        threshold = last_index + upload_every
        push_dataset(output_path, dataset_config, dataset_repo)
        last_uploaded = last_index

    print(f"iteration: {i+1}/{num_iters} | completed: {last_index}")

# Upload whatever was written after the last threshold-triggered push.
if last_index > last_uploaded:
    print(f"uploading {last_index} records to HF")
    push_dataset(output_path, dataset_config, dataset_repo)
    last_uploaded = last_index


iteration: 1/20 | completed: 9
iteration: 2/20 | completed: 32
iteration: 3/20 | completed: 52
iteration: 4/20 | completed: 72
iteration: 5/20 | completed: 78
uploading to HF...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

iteration: 6/20 | completed: 204
uploading to HF...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/368 [00:00<?, ?B/s]

iteration: 7/20 | completed: 257
iteration: 8/20 | completed: 376
uploading to HF...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/368 [00:00<?, ?B/s]

iteration: 9/20 | completed: 436
iteration: 10/20 | completed: 486
iteration: 11/20 | completed: 500
iteration: 12/20 | completed: 583
iteration: 13/20 | completed: 608
iteration: 14/20 | completed: 662


KeyboardInterrupt: 