In [1]:
!pip install -U datasets huggingface-hub

Collecting datasets
  Obtaining dependency information for datasets from https://files.pythonhosted.org/packages/e2/cf/db41e572d7ed958e8679018f8190438ef700aeb501b62da9e1eed9e4d69a/datasets-2.15.0-py3-none-any.whl.metadata
  Downloading datasets-2.15.0-py3-none-any.whl.metadata (20 kB)
Collecting huggingface-hub
  Obtaining dependency information for huggingface-hub from https://files.pythonhosted.org/packages/05/09/1945ca6ba3ad8ad6e2872ba682ce8d68c5e63c8e55458ed8ab4885709f1d/huggingface_hub-0.19.4-py3-none-any.whl.metadata
  Downloading huggingface_hub-0.19.4-py3-none-any.whl.metadata (14 kB)
Collecting pyarrow-hotfix (from datasets)
  Obtaining dependency information for pyarrow-hotfix from https://files.pythonhosted.org/packages/e4/f4/9ec2222f5f5f8ea04f66f184caafd991a39c8782e31f5b0266f101cb68ca/pyarrow_hotfix-0.6-py3-none-any.whl.metadata
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K   [90m━━━━━━━━━━━━━

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from huggingface_hub import notebook_login
import torch
import json
import re



In [3]:
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
mistral_checkpoint = "/kaggle/input/mistral/pytorch/7b-v0.1-hf/1"
mistral_tokenizer = AutoTokenizer.from_pretrained(mistral_checkpoint)
mistral_model = AutoModelForCausalLM.from_pretrained(
        mistral_checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
mistral_device = next(mistral_model.parameters()).device

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
t5_checkpoint = "t5-base"
t5_tokenizer = AutoTokenizer.from_pretrained(t5_checkpoint, model_max_length = 1024)

In [7]:
t5_translator = pipeline("translation_en_to_de"
                         , model = t5_checkpoint
                         , clean_up_tokenization_spaces = True)

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [8]:
def get_gen_text(prompt, model, tokenizer, device):
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    
    generated_ids = model.generate(**model_inputs
                                   , max_length = 1024
                                   , pad_token_id = tokenizer.eos_token_id
                                   , do_sample=True)
    gen_text = tokenizer.batch_decode(generated_ids)[0]
    gen_list = re.split(r'[.!?]', gen_text.replace(prompt,"").replace("<s>","").replace("\n",""))[:-1]
    
    return [x.strip() for x in gen_list if len(x.strip())>3]


def push_dataset(file_path, dataset_config, repo):
    dataset = load_dataset(dataset_config, data_files=file_path)
    
    dataset.push_to_hub(repo)

In [9]:
output_path = "/kaggle/working/output.jsonl"
dataset_repo = "jaymanvirk/synthetic_text_en_de"
dataset_config = "json"
tmp = load_dataset(dataset_repo
                     , download_mode = "force_redownload")
last_index = tmp["train"].num_rows

with open(output_path, 'w', encoding='utf-8') as json_file:
    for x in tmp['train']:
        json_line = json.dumps(x, ensure_ascii=False)
        json_file.write(json_line + '\n')

del tmp

Downloading readme:   0%|          | 0.00/372 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/110k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/2011 [00:00<?, ? examples/s]

In [11]:
batch_size = 75
num_iters = 200
threshold = last_index + 500
end = last_index+num_iters
rng = range(last_index, end)

for i in rng:
    prompt = f'''
                Write {batch_size} different short sentences.
                '''
    gen_text = get_gen_text(prompt
                             , mistral_model
                             , mistral_tokenizer
                             , mistral_device)

    translation = [t5_translator(f"translate English to German: {x}")[0]['translation_text'] for x in gen_text]
    
    with open(output_path, "a") as f:
        for j in range(len(gen_text)):
            tmp = {'id': (j+last_index), 'translation': {'en': gen_text[j], 'de': translation[j]}}
            f.write(json.dumps(tmp) + "\n")
    
    last_index += len(gen_text)
    
    if last_index >= threshold:
        print(f"uploading {last_index} records to HF")
        threshold += 500
        push_dataset(output_path, dataset_config, dataset_repo)
    
    print(f"iteration: {i+1}/{end} | completed: {last_index}")


iteration: 2012/2211 | completed: 2070
iteration: 2013/2211 | completed: 2132
iteration: 2014/2211 | completed: 2163
iteration: 2015/2211 | completed: 2197
iteration: 2016/2211 | completed: 2277
iteration: 2017/2211 | completed: 2352
iteration: 2018/2211 | completed: 2361
iteration: 2019/2211 | completed: 2396
iteration: 2020/2211 | completed: 2406
iteration: 2021/2211 | completed: 2452
iteration: 2022/2211 | completed: 2494
uploading 2515 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2023/2211 | completed: 2515
iteration: 2024/2211 | completed: 2525
iteration: 2025/2211 | completed: 2581
iteration: 2026/2211 | completed: 2597
iteration: 2027/2211 | completed: 2605
iteration: 2028/2211 | completed: 2613
iteration: 2029/2211 | completed: 2633
iteration: 2030/2211 | completed: 2696
iteration: 2031/2211 | completed: 2753
iteration: 2032/2211 | completed: 2761
iteration: 2033/2211 | completed: 2772
iteration: 2034/2211 | completed: 2801
iteration: 2035/2211 | completed: 2812
iteration: 2036/2211 | completed: 2877
iteration: 2037/2211 | completed: 2884
iteration: 2038/2211 | completed: 2908
iteration: 2039/2211 | completed: 2979
iteration: 2040/2211 | completed: 3004
uploading 3035 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2041/2211 | completed: 3035
iteration: 2042/2211 | completed: 3054
iteration: 2043/2211 | completed: 3088
iteration: 2044/2211 | completed: 3110
iteration: 2045/2211 | completed: 3124
iteration: 2046/2211 | completed: 3136
iteration: 2047/2211 | completed: 3160
iteration: 2048/2211 | completed: 3178
iteration: 2049/2211 | completed: 3185
iteration: 2050/2211 | completed: 3244
iteration: 2051/2211 | completed: 3271
iteration: 2052/2211 | completed: 3292
iteration: 2053/2211 | completed: 3299
iteration: 2054/2211 | completed: 3369
iteration: 2055/2211 | completed: 3393
iteration: 2056/2211 | completed: 3430
iteration: 2057/2211 | completed: 3434
iteration: 2058/2211 | completed: 3451
iteration: 2059/2211 | completed: 3464
uploading 3535 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2060/2211 | completed: 3535
iteration: 2061/2211 | completed: 3566
iteration: 2062/2211 | completed: 3576


Your input_length: 500 is bigger than 0.9 * max_length: 300. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)


iteration: 2063/2211 | completed: 3581
iteration: 2064/2211 | completed: 3584
iteration: 2065/2211 | completed: 3618
iteration: 2066/2211 | completed: 3660
iteration: 2067/2211 | completed: 3729
iteration: 2068/2211 | completed: 3736
iteration: 2069/2211 | completed: 3775
iteration: 2070/2211 | completed: 3793
iteration: 2071/2211 | completed: 3802
iteration: 2072/2211 | completed: 3822
iteration: 2073/2211 | completed: 3829
iteration: 2074/2211 | completed: 3882
iteration: 2075/2211 | completed: 3911
iteration: 2076/2211 | completed: 3936
uploading 4011 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2077/2211 | completed: 4011
iteration: 2078/2211 | completed: 4053
iteration: 2079/2211 | completed: 4061
iteration: 2080/2211 | completed: 4069
iteration: 2081/2211 | completed: 4154
iteration: 2082/2211 | completed: 4161
iteration: 2083/2211 | completed: 4175
iteration: 2084/2211 | completed: 4218
iteration: 2085/2211 | completed: 4281
iteration: 2086/2211 | completed: 4293
iteration: 2087/2211 | completed: 4318
iteration: 2088/2211 | completed: 4353
iteration: 2089/2211 | completed: 4416
iteration: 2090/2211 | completed: 4480
iteration: 2091/2211 | completed: 4489
uploading 4558 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2092/2211 | completed: 4558
iteration: 2093/2211 | completed: 4583
iteration: 2094/2211 | completed: 4593
iteration: 2095/2211 | completed: 4629
iteration: 2096/2211 | completed: 4695
iteration: 2097/2211 | completed: 4726
iteration: 2098/2211 | completed: 4738
iteration: 2099/2211 | completed: 4757
iteration: 2100/2211 | completed: 4767
iteration: 2101/2211 | completed: 4847
iteration: 2102/2211 | completed: 4856
iteration: 2103/2211 | completed: 4879
iteration: 2104/2211 | completed: 4961
iteration: 2105/2211 | completed: 4969
iteration: 2106/2211 | completed: 5005
uploading 5014 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2107/2211 | completed: 5014
iteration: 2108/2211 | completed: 5022
iteration: 2109/2211 | completed: 5035
iteration: 2110/2211 | completed: 5084
iteration: 2111/2211 | completed: 5089
iteration: 2112/2211 | completed: 5101
iteration: 2113/2211 | completed: 5107
iteration: 2114/2211 | completed: 5156
iteration: 2115/2211 | completed: 5161
iteration: 2116/2211 | completed: 5164
iteration: 2117/2211 | completed: 5289
iteration: 2118/2211 | completed: 5374
iteration: 2119/2211 | completed: 5421
iteration: 2120/2211 | completed: 5480
iteration: 2121/2211 | completed: 5489
iteration: 2122/2211 | completed: 5495
iteration: 2123/2211 | completed: 5504
uploading 5520 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2124/2211 | completed: 5520
iteration: 2125/2211 | completed: 5535
iteration: 2126/2211 | completed: 5542
iteration: 2127/2211 | completed: 5603
iteration: 2128/2211 | completed: 5651
iteration: 2129/2211 | completed: 5660
iteration: 2130/2211 | completed: 5666
iteration: 2131/2211 | completed: 5672
iteration: 2132/2211 | completed: 5726
iteration: 2133/2211 | completed: 5778
iteration: 2134/2211 | completed: 5790
iteration: 2135/2211 | completed: 5868
iteration: 2136/2211 | completed: 5898
iteration: 2137/2211 | completed: 5983
iteration: 2138/2211 | completed: 6001
uploading 6058 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2139/2211 | completed: 6058
iteration: 2140/2211 | completed: 6117
iteration: 2141/2211 | completed: 6137
iteration: 2142/2211 | completed: 6185
iteration: 2143/2211 | completed: 6186
iteration: 2144/2211 | completed: 6228
iteration: 2145/2211 | completed: 6260
iteration: 2146/2211 | completed: 6296
iteration: 2147/2211 | completed: 6305
iteration: 2148/2211 | completed: 6331
iteration: 2149/2211 | completed: 6437
iteration: 2150/2211 | completed: 6463
iteration: 2151/2211 | completed: 6477
iteration: 2152/2211 | completed: 6481
uploading 6557 records to HF


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/372 [00:00<?, ?B/s]

iteration: 2153/2211 | completed: 6557
iteration: 2154/2211 | completed: 6597
iteration: 2155/2211 | completed: 6647
iteration: 2156/2211 | completed: 6655
iteration: 2157/2211 | completed: 6715
iteration: 2158/2211 | completed: 6743
iteration: 2159/2211 | completed: 6763
iteration: 2160/2211 | completed: 6838
iteration: 2161/2211 | completed: 6872


KeyboardInterrupt: 