In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import torch

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on CUDA machines or plain CPUs.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # float16 halves memory on GPU/MPS; many CPU kernels lack half-precision
    # support, so fall back to float32 when running on the CPU.
    torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16,
    device_map=DEVICE,
)

def _first_json_object(text: str) -> dict:
    """Parse and return the first JSON object embedded in ``text``.

    Decoding stops at the end of that object, so any text the model keeps
    generating after its answer (e.g. a new ``<|user|>`` turn) is ignored.

    Raises:
        ValueError: if ``text`` contains no ``{`` or the JSON starting there is
            malformed (``json.JSONDecodeError`` is a subclass of ``ValueError``).
    """
    start = text.find("{")
    if start == -1:
        raise ValueError("no JSON object found in model output")
    obj, _end = json.JSONDecoder().raw_decode(text, start)
    return obj


def generate_mongo_query(query: str, *, llm=None, tok=None, max_new_tokens: int = 128):
    """Translate a natural-language request into a MongoDB shell ``find`` call.

    A few-shot prompt asks the chat model to answer with JSON of the form
    ``{"collection": ..., "query": ...}``. That JSON is parsed and rendered as
    ``db.<collection>.find(<query>)``, with the filter pretty-printed.

    Args:
        query: The natural-language request, e.g. "Find users older than 30".
        llm: Causal LM to use; defaults to the notebook-level ``model``.
        tok: Tokenizer matching ``llm``; defaults to the notebook-level ``tokenizer``.
        max_new_tokens: Upper bound on tokens generated after the prompt.

    Returns:
        The shell command as a string, or ``None`` if the model output could not
        be parsed (the error and the raw generated text are printed).
    """
    # Only touch the notebook globals when no override is supplied.
    llm = model if llm is None else llm
    tok = tokenizer if tok is None else tok

    prompt = """<|system|>You convert natural language to MongoDB queries.
<|user|>Convert: Find users who are 25 years old
<|assistant|>{ "collection": "users", "query": { "age": 25 } }
<|user|>Convert: Find products with price over $100
<|assistant|>{ "collection": "products", "query": { "price": { "$gt": 100 } } }
<|user|>Convert: """ + query + """
<|assistant|>"""

    # Follow the model's own device rather than hard-coding "mps".
    inputs = tok(prompt, return_tensors="pt").to(llm.device)
    prompt_len = inputs["input_ids"].shape[-1]

    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    outputs = llm.generate(
        **inputs,                       # passes attention_mask along with input_ids
        max_new_tokens=max_new_tokens,  # max_length would also count the prompt tokens
        do_sample=False,                # temperature only applies when sampling; decode greedily
        pad_token_id=pad_id,
    )

    # Decode only the continuation, so the few-shot examples in the prompt can
    # never be mistaken for the answer.
    generated_text = tok.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    try:
        parsed_query = _first_json_object(generated_text)
        return f"db.{parsed_query['collection']}.find({json.dumps(parsed_query['query'], indent=2)})"
    except (ValueError, KeyError, TypeError) as e:
        # Best effort: report the unparseable output instead of raising.
        print(f"Error: {str(e)}\nGenerated: {generated_text}")
        return None

# Smoke test: one "less than" condition on a numeric field.
query = "Find users where age is less than 35"
print(generate_mongo_query(query=query))



db.users.find({
  "age": {
    "$lt": 35
  }
})


In [3]:
# Two bounds on the same field. NOTE(review): the captured output below uses
# inclusive $lte/$gte, although "less than" / "more than" call for $lt/$gt.
query = "Find users where age is less than 35 and more than 25"
print(generate_mongo_query(query=query))

db.users.find({
  "age": {
    "$lte": 35,
    "$gte": 25
  }
})


In [4]:
# The prompt supplies no schema, so the model invents the field name
# (compare "hoursWorked" here with "hours_worked" in the In [5] output).
query = "Find user where the number of hours worked is more than 50"
print(generate_mongo_query(query=query))

db.users.find({
  "hoursWorked": {
    "$gt": 50
  }
})


In [5]:
# NOTE(review): "most hours" needs a sort + limit (or an aggregation), which the
# find-only output format cannot express. The captured answer below uses `$max`,
# which is not a MongoDB query-filter operator, so it is not a valid query.
query = "Find user who has worked the most number of hours"
print(generate_mongo_query(query=query))

db.users.find({
  "hours_worked": {
    "$max": "$hours_worked"
  }
})


In [6]:
# Same query as the In [4] cell; the captured output below matches it exactly.
query = "Find user where the number of hours worked is more than 50"
print(generate_mongo_query(query=query))

db.users.find({
  "hoursWorked": {
    "$gt": 50
  }
})
