In [1]:
from model import ChoiceModel
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import pandas as pd
from tqdm import tqdm
import numpy as np
import random  

# Pin both global random number generators (stdlib `random` and NumPy's legacy
# global state) to the same fixed seed so they start from a known state on a
# fresh kernel.
random.seed(42)
np.random.seed(42)

In [2]:
def process_row(choice_model, row, additional_condition=None):
    """Run one context-free inference for a single row of the test table.

    Parameters
    ----------
    choice_model : object
        Any object exposing ``infer_without_context(profile=..., mode=...,
        in_parallel=..., additional_condition=...)``.
    row : pandas.Series
        One row of the test data. It must contain the labels ``person_id``,
        ``age``, ``individual_income``, ``household_size``,
        ``family_structure``, ``vehicles`` and ``cypher``; any other labels
        are ignored.
    additional_condition : str or None, optional
        Extra scenario text forwarded unchanged to the model, for example
        ``'The weather is sunny now'`` or ``'The weather is rainy now'``.
        Defaults to ``None``, which is the value this function previously
        hard-coded, so existing two-argument calls behave as before.

    Returns
    -------
    object
        Whatever ``choice_model.infer_without_context`` returns.
    """
    profile_fields = ['person_id', 'age', 'individual_income',
                      'household_size', 'family_structure', 'vehicles', 'cypher']
    # Only the profile fields are handed to the model, as a plain dict.
    profile = row[profile_fields].to_dict()
    return choice_model.infer_without_context(
        profile=profile, mode='experiment', in_parallel=True,
        additional_condition=additional_condition)

def run_experiments(desire,num_sample,num_threads = 5,data_path=None):
    """Run ``process_row`` over every test row for one desire and save the results.

    Reads ``data/test/{desire}.csv``, builds a ``ChoiceModel`` for that desire,
    submits one ``process_row`` task per row to a thread pool, and writes the
    collected responses to a CSV file.

    Parameters
    ----------
    desire : str
        Name used to locate the test CSV and passed to ``ChoiceModel``.
    num_sample : int
        Passed to ``ChoiceModel`` as ``sample_num``.
    num_threads : int, optional
        ``max_workers`` for the thread pool. Defaults to 5.
    data_path : str or path-like or None, optional
        Where to write the CSV. When ``None``, ``choice_model.log_data_path``
        is used.

    Returns
    -------
    None

    Notes
    -----
    * Rows are written in the order of the input table, not in the order the
      worker threads finish, so repeated runs produce the same row order.
    * A row whose task raises is reported and skipped; a row whose response is
      ``None`` is skipped silently. If any task raised, the number of failed
      rows is printed once after the progress bar.
    * The parent directory of the output path is created if it is missing.
    * The CSV is written with the DataFrame index included (``to_csv`` default).
    """
    import os  # local import: only needed to prepare the output directory

    print("get test data")
    test_data_path = f'data/test/{desire}.csv'
    test_df = pd.read_csv(test_data_path, index_col=False)
    print(f"start experiemtns (desire={desire},sample_num={num_sample})")
    choice_model = ChoiceModel(
        data_dir='data', desire=desire, sample_num=num_sample, skip_init=True)

    failures = 0
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Map each future to its submission position so that responses can be
        # put back in input order even though they complete out of order.
        positions = {
            executor.submit(process_row, choice_model, row): pos
            for pos, (_, row) in enumerate(test_df.iterrows())
        }
        responses = [None] * len(positions)
        for future in tqdm(as_completed(positions), total=len(positions)):
            try:
                responses[positions[future]] = future.result()
            except Exception as e:
                # Best effort: one failing row must not abort the whole run.
                failures += 1
                print(f"Error get response: {e}")

    # Failed rows and rows that legitimately returned None are both dropped.
    data = [response for response in responses if response is not None]
    if failures:
        print(f"{failures} of {len(responses)} rows failed and were skipped")

    data_df = pd.DataFrame(data, columns=['person_id', 'profile', 'top_k', 'desire', 'city', 'cypher',
                                      'amenity_recommendation', 'amenity_llm_choice', 'amenity_final_choice',
                                      'mode_recommendation', 'mode_llm_choice', 'mode_final_choice'])
    if data_path is None:
        data_path = choice_model.log_data_path

    # to_csv does not create missing directories; make sure the target exists
    # so a long run is not lost at the final write.
    if isinstance(data_path, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(data_path))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    data_df.to_csv(data_path)
    print(f'data saved to {data_path}')
    print("=="*20)

In [3]:
# Launch the 'Shop' experiment with num_sample=0 on six worker threads.
run_experiments(
    desire='Shop',
    num_sample=0,
    num_threads=6,
)

get test data
start experiemtns (desire=Shop,sample_num=0)


100%|██████████| 1000/1000 [17:08<00:00,  1.03s/it]

data saved to data/logs/0/Shop.csv



