In [204]:
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from typing import List
import pandas as pd
import numpy as np
import json

In [205]:
llm = ChatOpenAI(base_url="http://THOTH.local:1234/v1/", api_key="na", temperature=0.5)

In [206]:
class DatasetCol(BaseModel):
    name: str = Field(description="The name of the column.")
    type: str = Field(description="The type of the column. One of 'int', 'float', 'str', 'bool'.")
    description: str = Field(description="A description of the column.")
    dependent: bool = Field(description="Whether this column is dependent on other columns.")
    function: str = Field(description="A python function that generates the data for this column. The function takes parameters and returns the value for the column in that row. The function takes the row as a paramter with each column accessbile by name (Eg: row[\"name\"]). Independent variables cannot use any other variables in the dataset in their function including themselves.")

class DatasetGen(BaseModel):
    columns: List[DatasetCol] = Field(description="The columns of the dataset.")

# Set up a parser + inject instructions into the prompt template.
parser = JsonOutputParser(pydantic_object=DatasetGen)

In [207]:
def parser_to_prompt_schema(parser):
    schema = {k: v for k, v in parser._get_schema(parser.pydantic_object).items()}

    # Remove extraneous fields.
    reduced_schema = schema
    if "title" in reduced_schema:
        del reduced_schema["title"]
    if "type" in reduced_schema:
        del reduced_schema["type"]
    # Ensure json in context is well-formed with double quotes.
    schema_str = json.dumps(reduced_schema)

    return schema_str

In [208]:
def try_eval(fn_str, *params):
    try:
        return eval(fn_str)(*params)
    except:
        return None
def create_dataframe(data_template, n):
    columns = data_template['columns']
    data = {}
    independent_vars = {}
    # Generate values for independent variables
    for column in columns:
        if not column['dependent']:
            col_name = column['name']
            col_type = column['type']
            col_function = column['function']
            values = [try_eval(col_function, i) for i in range(n)]
            data[col_name] = values
    # Calculate values for dependent variables
    for column in columns:
        if column['dependent']:
            col_name = column['name']
            col_type = column['type']
            col_function = column['function']
            values = [eval(col_function)(pd.DataFrame(data).to_dict('records')[i]) for i in range(n)]
            data[col_name] = values

    df = pd.DataFrame(data)
    return df


In [209]:
template = '''
You are a helpful assistant that answers in JSON.
Generate an artificial medical research dataset containing 20-30 columns each.
The subject of this dataset is: {subject}.
For each column provide a python function using numpy to generate the data.

Here's the json schema you must adhere to:
<schema>
{schema}
</schema>
'''
prompt = PromptTemplate(
    template=template,
    input_variables=["subject"],
    partial_variables={"schema": parser_to_prompt_schema(parser)},
)


In [210]:
chain = prompt | llm | parser
data_template = chain.invoke({"subject": "blindness in diabetes"})

pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', None)
display(pd.DataFrame(data_template["columns"]))

Unnamed: 0,name,type,description,dependent,function
0,patient_id,str,Unique identifier for each patient.,False,lambda row: str(row['patient_id'])
1,age,int,Age of the patient in years.,False,"lambda row: np.random.randint(18, 100)"
2,gender,str,Gender of the patient (Male or Female).,False,lambda row: 'Male' if np.random.rand() > 0.5 else 'Female'
3,diabetes_duration,int,Duration of diabetes in years.,False,"lambda row: np.random.randint(1, 20)"
4,blood_pressure,float,Blood pressure of the patient (mmHg).,False,"lambda row: np.random.uniform(low=80.0, high=140.0)"
5,fasting_blood_glucose,float,Fasting blood glucose level (mg/dL).,False,"lambda row: np.random.uniform(low=70.0, high=120.0)"
6,body_mass_index,float,Body mass index of the patient (kg/m^2).,False,"lambda row: np.random.uniform(low=18.5, high=30.0)"
7,smoker,bool,Whether the patient is a smoker or not.,False,lambda row: bool(np.random.randint(2))
8,blindness,bool,Whether the patient has blindness due to diabetes or not.,True,lambda row: (row['fasting_blood_glucose'] > 126.0) and (row['blood_pressure'] > 135.0)
9,retinopathy,bool,Whether the patient has retinopathy due to diabetes or not.,True,lambda row: (row['fasting_blood_glucose'] > 126.0) and (row['blood_pressure'] > 135.0)


In [211]:
create_dataframe(data_template, 100)

Unnamed: 0,patient_id,age,gender,diabetes_duration,blood_pressure,fasting_blood_glucose,body_mass_index,smoker,hypertension,hyperlipidemia,...,family_history_diabetes,physical_activity,diet_control,medication_adherence,health_insurance,blindness,retinopathy,neuropathy,nephropathy,foot_problems
0,,47,Male,7,89.016368,104.4436,28.403275,False,False,True,...,False,False,False,True,True,False,False,False,False,False
1,,85,Male,7,105.487086,73.908716,22.080301,True,True,False,...,False,False,True,True,True,False,False,False,False,False
2,,84,Female,11,119.003098,107.711623,19.373628,False,False,False,...,True,False,False,False,False,False,False,False,False,False
3,,68,Male,1,125.072254,96.350197,20.190769,True,False,False,...,True,False,True,False,False,False,False,False,False,False
4,,31,Female,5,88.188785,111.851569,27.448994,True,True,False,...,False,False,True,True,True,False,False,False,False,False
5,,87,Female,9,135.371076,84.067005,24.513361,False,True,True,...,False,False,True,True,True,False,False,False,False,False
6,,88,Male,11,85.849524,75.605231,24.413076,False,False,False,...,False,True,True,False,False,False,False,False,False,False
7,,35,Female,3,126.680059,90.648062,26.430279,True,True,True,...,False,False,False,True,True,False,False,False,False,False
8,,68,Male,18,100.672366,113.065373,23.243536,True,False,True,...,False,True,True,True,False,False,False,False,False,False
9,,68,Female,12,81.33626,79.368728,28.056396,True,False,True,...,False,False,False,False,True,False,False,False,False,False
