In [64]:
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from typing import List
import pandas as pd
import numpy as np
import json

In [65]:
# Chat model client talking to the OpenAI-style endpoint at THOTH.local:1234.
# NOTE(review): api_key="na" looks like a placeholder for a server that does
# not check keys — confirm before pointing this at a hosted endpoint.
llm = ChatOpenAI(
    base_url="http://THOTH.local:1234/v1/",
    api_key="na",
    temperature=0.5,
)

In [66]:
# Specification of a single generated column, as the LLM is asked to return it.
# Documented with `#` comments instead of a class docstring on purpose:
# pydantic copies a model docstring into the generated JSON schema as
# "description", which would change the schema text placed in the prompt.
class DatasetCol(BaseModel):
    # Column label; create_dataframe uses it as the DataFrame column name.
    name: str = Field(description="The name of the column.")
    # Declared kind of the values. The allowed spellings are stated only in the
    # description text; the annotation itself is a plain str and enforces nothing.
    type: str = Field(description="The type of the column. One of 'int', 'float', 'str', 'bool'.")
    # Human-readable explanation of the column; not read by create_dataframe.
    description: str = Field(description="A description of the column.")
    # Source text of the value generator. create_dataframe passes this string to
    # eval() and calls the result with the row index, so it has to be a single
    # expression (for example a lambda) rather than a `def` statement.
    function: str = Field(description="A python function that generates the data for this column. The function should take a single argument, the row number, and return the value for that row.")

# Top-level object the LLM must return: the list of column specifications.
# This is the model handed to JsonOutputParser as `pydantic_object`.
# (No class docstring on purpose: it would be added to the JSON schema.)
class DatasetGen(BaseModel):
    columns: List[DatasetCol] = Field(description="The columns of the dataset.")

# Parser for the LLM's JSON reply. It also carries the DatasetGen model, whose
# JSON schema parser_to_prompt_schema() extracts when the prompt is built.
parser = JsonOutputParser(pydantic_object=DatasetGen)

In [67]:
def parser_to_prompt_schema(parser):
    """Return the parser's JSON schema as a JSON string for use in a prompt.

    The schema comes from ``parser._get_schema(parser.pydantic_object)``.
    Its top-level ``"title"`` and ``"type"`` entries are left out; every other
    entry is kept in its original order. Only the top level is filtered, so
    nested ``"title"`` / ``"type"`` keys are preserved.

    Args:
        parser: Object exposing ``pydantic_object`` and a
            ``_get_schema(pydantic_object)`` method returning a mapping.

    Returns:
        str: ``json.dumps`` of the reduced schema, i.e. well-formed JSON with
        double-quoted keys and strings.
    """
    full_schema = parser._get_schema(parser.pydantic_object)

    # Build a fresh dict while skipping the extraneous top-level keys, so the
    # mapping returned by the parser is never modified.
    prompt_schema = {
        key: value
        for key, value in full_schema.items()
        if key not in ("title", "type")
    }

    return json.dumps(prompt_schema)

In [68]:
def create_dataframe(data_template, n):
    """Build a DataFrame with ``n`` rows from an LLM-produced column template.

    Args:
        data_template: Mapping with a ``"columns"`` list. Each entry is a
            mapping with at least ``"name"`` (the column label) and
            ``"function"`` (source text of a Python expression that evaluates
            to a one-argument callable, e.g. a ``lambda``). A ``"type"`` entry
            may be present but is informational only: values are stored
            exactly as the generator returns them, with no casting, and a
            column is kept whatever its declared type is.
        n: Number of rows. Each generator is called with 0, 1, ..., n - 1.

    Returns:
        pandas.DataFrame: one column per template entry, in template order.

    Raises:
        KeyError: if ``"columns"``, ``"name"`` or ``"function"`` is missing.
        Exception: whatever evaluating or calling a generator raises
            (e.g. ``SyntaxError`` when the text is a ``def`` statement
            rather than an expression).
    """
    data = {}
    for column in data_template['columns']:
        # SECURITY: eval() executes arbitrary code. The function text comes
        # from model output, so only run this against templates you trust or
        # inside a sandboxed kernel.
        # The expression is evaluated once per column rather than once per
        # row: it avoids re-parsing the same text n times and keeps any state
        # held by the resulting callable alive across rows.
        generate = eval(column['function'])

        # Every column is kept. Previously a column whose declared type was not
        # exactly 'int', 'float', 'str' or 'bool' was silently dropped even
        # though its values had already been generated.
        data[column['name']] = [generate(i) for i in range(n)]

    return pd.DataFrame(data)

In [69]:
# Prompt text sent to the LLM. {subject} is supplied on each invoke call;
# {schema} is pre-filled below with the JSON schema of the expected reply.
template = '''
You are a helpful assistant that answers in JSON.
Generate an artificial medical dataset containing 10-20 columns each.
The subject of this dataset is: {subject}.
For each column provide a python function using numpy to generate the data.

Here's the json schema you must adhere to:
<schema>
{schema}
</schema>
'''
# `schema` is bound once, here, through partial_variables (computed from the
# module-level `parser`), so users of `prompt` only have to pass `subject`.
prompt = PromptTemplate(
    template=template,
    input_variables=["subject"],
    partial_variables={"schema": parser_to_prompt_schema(parser)},
)


In [73]:
# Pipeline: render the prompt, send it to the model, parse the JSON reply.
chain = prompt | llm | parser

# One call to the model; the result depends on the model's reply, so re-running
# this cell can produce a different template.
data_template = chain.invoke({"subject": "Risk of blindness in diabetes"})
# Last expression of the cell, so the notebook displays the 10-row DataFrame.
# NOTE: create_dataframe eval()s the function strings returned by the model.
create_dataframe(data_template, 10)

Unnamed: 0,PatientID,Age,Gender,DiabetesDuration,HbA1c,BloodPressure,Smoker,RetinopathyStatus,RiskOfBlindness,TreatmentType
0,1253,19,Male,5,5.808856,132,False,No,3,Insulin
1,3847,88,Female,17,11.995033,99,True,No,5,Oral Medication
2,4359,48,Female,15,6.575593,102,False,No,4,Insulin
3,5278,60,Male,14,9.933632,112,True,No,2,Insulin
4,3043,87,Male,6,6.608064,129,False,No,2,Oral Medication
5,3361,57,Male,9,6.530478,125,True,No,5,Insulin
6,7888,67,Male,2,10.752873,134,False,No,4,Diet Control
7,5252,55,Female,12,10.090467,114,True,No,2,Oral Medication
8,9926,80,Female,16,9.78596,121,True,No,1,Oral Medication
9,8558,30,Male,5,11.643938,128,True,No,3,Insulin
