In [1]:
!pip install -U datasets huggingface-hub

Collecting datasets
  Obtaining dependency information for datasets from https://files.pythonhosted.org/packages/e2/cf/db41e572d7ed958e8679018f8190438ef700aeb501b62da9e1eed9e4d69a/datasets-2.15.0-py3-none-any.whl.metadata
  Downloading datasets-2.15.0-py3-none-any.whl.metadata (20 kB)
Collecting huggingface-hub
  Obtaining dependency information for huggingface-hub from https://files.pythonhosted.org/packages/05/09/1945ca6ba3ad8ad6e2872ba682ce8d68c5e63c8e55458ed8ab4885709f1d/huggingface_hub-0.19.4-py3-none-any.whl.metadata
  Downloading huggingface_hub-0.19.4-py3-none-any.whl.metadata (14 kB)
Collecting pyarrow-hotfix (from datasets)
  Obtaining dependency information for pyarrow-hotfix from https://files.pythonhosted.org/packages/e4/f4/9ec2222f5f5f8ea04f66f184caafd991a39c8782e31f5b0266f101cb68ca/pyarrow_hotfix-0.6-py3-none-any.whl.metadata
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K   [90m━━━━━━━━━━━━━

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from huggingface_hub import notebook_login
import torch
import json
import re



In [3]:
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
mistral_checkpoint = "/kaggle/input/mistral/pytorch/7b-v0.1-hf/1"
mistral_tokenizer = AutoTokenizer.from_pretrained(mistral_checkpoint)
mistral_model = AutoModelForCausalLM.from_pretrained(
        mistral_checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
mistral_device = next(mistral_model.parameters()).device

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
t5_checkpoint = "t5-base"
t5_tokenizer = AutoTokenizer.from_pretrained(t5_checkpoint)

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [6]:
t5_translator = pipeline("translation_en_to_de"
                         , model = t5_checkpoint
                         , clean_up_tokenization_spaces = True)

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [7]:
def get_gen_text(prompt, model, tokenizer, device):
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    
    generated_ids = model.generate(**model_inputs
                                   , max_length=1000
                                   , pad_token_id = tokenizer.eos_token_id
                                   , do_sample=True)
    gen_text = tokenizer.batch_decode(generated_ids)[0]
    gen_list = re.split(r'[.!?]', gen_text.replace(prompt,"").replace("<s>","").replace("\n",""))[:-1]
    
    return [x.strip() for x in gen_list]


def push_dataset(file_path, dataset_config, repo):
    dataset = load_dataset(dataset_config, data_files=file_path)
    
    dataset.push_to_hub(repo)

In [8]:
output_path = "/kaggle/working/output.jsonl"
dataset_repo = "jaymanvirk/synthetic_text_en_de"
dataset_config = "json"
tmp = load_dataset(dataset_repo
                     , download_mode = "force_redownload")
last_index = tmp["train"].num_rows

with open(output_path, 'w', encoding='utf-8') as json_file:
    for x in tmp['train']:
        json_line = json.dumps(x, ensure_ascii=False)
        json_file.write(json_line + '\n')

del tmp

Downloading readme:   0%|          | 0.00/368 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/25.2k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/436 [00:00<?, ? examples/s]

In [9]:
batch_size = 50
num_iters = 200
threshold = last_index + 500
end = last_index+num_iters
rng = range(last_index, end)

for i in rng:
    prompt = f'''
                Write {batch_size} different short sentences.
                '''
    gen_text = get_gen_text(prompt
                             , mistral_model
                             , mistral_tokenizer
                             , mistral_device)

    translation = [t5_translator(f"translate English to German: {x}")[0]['translation_text'] for x in gen_text]
    
    with open(output_path, "a") as f:
        for j in range(len(gen_text)):
            tmp = {'id': (j+last_index), 'translation': {'en': gen_text[j], 'de': translation[j]}}
            f.write(json.dumps(tmp) + "\n")
    
    last_index += len(gen_text)
    
    if last_index >= threshold:
        print(f"uploading {last_index} records to HF")
        threshold += threshold
        push_dataset(output_path, dataset_config, dataset_repo)
    
    print(f"iteration: {i+1}/{end} | completed: {last_index}")


iteration: 437/636 | completed: 489
iteration: 438/636 | completed: 579
iteration: 439/636 | completed: 620
iteration: 440/636 | completed: 665
iteration: 441/636 | completed: 721
iteration: 442/636 | completed: 790
iteration: 443/636 | completed: 875
uploading 949 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/368 [00:00<?, ?B/s]

iteration: 444/636 | completed: 949
iteration: 445/636 | completed: 963
iteration: 446/636 | completed: 1063
iteration: 447/636 | completed: 1181
iteration: 448/636 | completed: 1201
iteration: 449/636 | completed: 1302
iteration: 450/636 | completed: 1317
iteration: 451/636 | completed: 1329
iteration: 452/636 | completed: 1424
iteration: 453/636 | completed: 1524
iteration: 454/636 | completed: 1548
iteration: 455/636 | completed: 1560
iteration: 456/636 | completed: 1624
iteration: 457/636 | completed: 1627
iteration: 458/636 | completed: 1664
iteration: 459/636 | completed: 1674
iteration: 460/636 | completed: 1721
iteration: 461/636 | completed: 1730
iteration: 462/636 | completed: 1820
uploading 1920 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/368 [00:00<?, ?B/s]

iteration: 463/636 | completed: 1920
iteration: 464/636 | completed: 1928


KeyboardInterrupt: 