In [2]:
import json
import os
import random
from tqdm import tqdm
from mlx_lm import load, generate

def load_model_and_tokenizer(model_path="mlx-community/gemma-2-27b-it-4bit"):
    """Load an MLX model and its tokenizer via ``mlx_lm.load``.

    Args:
        model_path: Hugging Face repo id or local path understood by
            ``mlx_lm.load``. Defaults to the 4-bit Gemma 2 27B instruct model
            this script was written for.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    print("Loading model and tokenizer...")
    # Unpack explicitly so callers always receive exactly a 2-tuple.
    model, tokenizer = load(model_path)
    return model, tokenizer

def load_sample_data(directory):
    """Load every ``*.json`` file directly inside *directory*.

    Files are visited in sorted filename order because ``os.listdir`` order is
    arbitrary; a stable order makes any later seeded random sampling from the
    returned list reproducible. Files are decoded as UTF-8 explicitly instead
    of relying on the platform's default encoding.

    Args:
        directory: folder containing the example JSON files (not searched
            recursively).

    Returns:
        A list with one parsed JSON value per matching file (empty if none).

    Raises:
        FileNotFoundError: if *directory* does not exist.
        json.JSONDecodeError: if a matching file is not valid JSON.
    """
    sample_data = []
    for filename in sorted(os.listdir(directory)):
        path = os.path.join(directory, filename)
        # A sub-directory whose name ends in ".json" is not a sample file.
        if filename.endswith('.json') and os.path.isfile(path):
            with open(path, 'r', encoding='utf-8') as f:
                sample_data.append(json.load(f))
    return sample_data

def _build_chat_prompt(tokenizer, user_message):
    """Wrap *user_message* as a single user turn using the tokenizer's chat template.

    The template shipped with the model is used instead of a hand-written
    wrapper, because the wrapper format is model-specific. Gemma, for example,
    expects ``<start_of_turn>user ... <end_of_turn>`` turns, not Llama/Mistral
    ``[INST]`` tags. A leading BOS token emitted by the template is removed
    because ``mlx_lm.generate`` encodes the prompt string itself and may add
    BOS again. Falls back to the bare message when no chat template exists.
    """
    if not getattr(tokenizer, "chat_template", None):
        return user_message
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_message}],
        tokenize=False,
        add_generation_prompt=True,
    )
    bos = getattr(tokenizer, "bos_token", None)
    if bos and prompt.startswith(bos):
        prompt = prompt[len(bos):]
    return prompt


def generate_bl_sample(model, tokenizer, sample_data, max_tokens=2048):
    """Ask the model for one new Bill of Lading (BL) and return it parsed.

    One randomly chosen element of *sample_data* is embedded in the prompt as
    a structural example. The first JSON object in the model's reply is
    decoded with ``JSONDecoder.raw_decode``. Decoding starts at the first
    ``{``, so trailing commentary or code fences are ignored, even when they
    contain braces.

    Args:
        model: model object from ``mlx_lm.load``.
        tokenizer: tokenizer from ``mlx_lm.load``.
        sample_data: non-empty list of example BL objects. ``random.choice``
            raises IndexError if it is empty.
        max_tokens: maximum number of tokens to generate.

    Returns:
        The decoded JSON object, or None if the reply contained no decodable
        object. Retrying is left to the caller.
    """
    example = json.dumps(random.choice(sample_data), indent=2)
    instruction = (
        "Generate a new JSON object for a Bill of Lading (BL) based on the "
        "following structure and examples:\n"
        f"{example}\n\n"
        "Ensure all fields are filled with realistic and varied data.\n"
        "Modify values, names, and details to create a unique BL while "
        "maintaining the overall structure.\n"
        "The output should be a valid JSON object.\n"
        # Ask for bare JSON explicitly: any surrounding prose costs tokens.
        "Respond with the JSON object only, without any explanation."
    )
    prompt = _build_chat_prompt(tokenizer, instruction)

    # NOTE(review): `temp`/`top_p` are accepted by the mlx_lm version this ran
    # with. Newer releases take `sampler=make_sampler(...)`
    # (mlx_lm.sample_utils) instead; confirm against the installed version.
    generated_text = generate(
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        temp=0.7,
        top_p=0.95,
        verbose=False,
    )

    # Defensive: drop the prompt in case it is echoed back. The example it
    # contains would otherwise be decoded as the "generated" object.
    generated_text = generated_text.replace(prompt, "").strip()

    start = generated_text.find('{')
    if start == -1:
        print("No valid JSON structure found. Retrying...")
        return None
    try:
        bl_data, _ = json.JSONDecoder().raw_decode(generated_text, start)
    except json.JSONDecodeError:
        # Covers output truncated at max_tokens as well as malformed JSON.
        print("Failed to parse JSON. Retrying...")
        return None
    return bl_data

def _write_json_atomically(path, data):
    """Write *data* as indented JSON to *path* through a temp file and ``os.replace``.

    An interrupted write therefore never leaves a truncated file at *path*.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, path)


def create_bl_samples(model, tokenizer, sample_data_dir, output_dir, num_samples=10000,
                      max_retries=5, skip_existing=False):
    """Generate *num_samples* BL files named ``bl_sample_<n>.json`` (n from 1).

    Args:
        model: model object from ``mlx_lm.load``.
        tokenizer: tokenizer from ``mlx_lm.load``.
        sample_data_dir: folder of example BL ``*.json`` files used as templates.
        output_dir: destination folder. It is created if missing.
        num_samples: number of sample indices to produce.
        max_retries: generation attempts per sample before it is skipped.
        skip_existing: when True, indices whose output file already exists are
            left untouched. An interrupted run can then be resumed instead of
            regenerating from sample 1. The default (False) keeps the original
            overwrite behavior.

    Returns:
        Number of sample files written by this call.

    Raises:
        ValueError: if *sample_data_dir* contains no ``.json`` examples. Without
            this check the failure would surface later as an IndexError inside
            the generation loop.
    """
    sample_data = load_sample_data(sample_data_dir)
    if not sample_data:
        raise ValueError(f"No .json sample files found in {sample_data_dir!r}")

    os.makedirs(output_dir, exist_ok=True)

    print(f"Generating {num_samples} BL samples...")
    written = 0
    for i in tqdm(range(num_samples), desc="Generating samples"):
        path = os.path.join(output_dir, f"bl_sample_{i+1}.json")
        if skip_existing and os.path.exists(path):
            continue
        for _ in range(max_retries):
            bl_data = generate_bl_sample(model, tokenizer, sample_data)
            if bl_data:
                _write_json_atomically(path, bl_data)
                written += 1
                break
        else:
            # Only reached when every attempt failed (no `break`).
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write(
                f"Failed to generate valid JSON for sample {i+1} after "
                f"{max_retries} attempts. Skipping..."
            )
    return written

if __name__ == "__main__":
    # Paths are relative to the current working directory.
    output_dir = "./bl_sample/"  # where generated bl_sample_<n>.json files go
    sample_data_dir = "./si/"  # example BL JSON files used as prompt templates

    model, tokenizer = load_model_and_tokenizer()
    create_bl_samples(
        model,
        tokenizer,
        sample_data_dir,
        output_dir,
        num_samples=10000,
    )
    print("Sample generation completed.")

Loading model and tokenizer...


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

Generating 10000 BL samples...


Generating samples:   6%|▌         | 584/10000 [15:36:41<265:35:01, 101.54s/it]