In [1]:
import numpy as np

In [2]:
# Toy example: number of rows to simulate in the indexing demo below
sample_size = 10

In [3]:
# Toy data: 3 standard-normal columns per row, plus a 1-based column choice in {1, 2, 3}.
# NOTE: unseeded, so re-running produces different numbers than the outputs shown below.
O1 = np.random.normal(loc=0, scale=1, size=(sample_size, 3))
A2 = np.random.randint(low=1, high=4, size=sample_size)

In [11]:
# Expect (sample_size, 3)
O1.shape

(10, 3)

In [15]:
# 1-based column index chosen for each row
A2

array([1, 1, 2, 3, 3, 1, 1, 2, 1, 1])

In [14]:
# Full toy matrix, for checking the fancy indexing below by eye
O1

array([[ 0.13827105, -1.13850132, -0.36779692],
       [-0.48521878, -2.77699117, -0.34939154],
       [-2.0250776 ,  0.60206749,  0.31531961],
       [-0.30072282, -2.18131407, -0.86849781],
       [-1.36589492, -0.97907894, -1.79591771],
       [ 1.54883484,  0.89622265, -2.0370887 ],
       [ 0.83768825, -0.18950545, -0.85639641],
       [-0.66255101, -0.74248416,  1.18513487],
       [-1.27920179, -0.5664314 , -0.97207485],
       [ 2.09161351, -0.94204533,  0.11204933]])

In [16]:

# For each row i, pick column A2[i] - 1 of O1 (A2 is 1-based) and square it
O1[np.arange(sample_size), A2 - 1]**2

# Selected values BEFORE squaring, for reference:
# array([ 0.13827105, -0.48521878,  0.60206749, -0.86849781, -1.79591771,
#         1.54883484,  0.83768825, -0.74248416, -1.27920179,  2.09161351])

array([0.01911888, 0.23543726, 0.36248526, 0.75428845, 3.22532042,
       2.39888936, 0.7017216 , 0.55128273, 1.63635722, 4.37484709])

In [19]:
# Hand spot-check of row 1 of the squared selection above
(-0.48521878)**2

0.2354372644646884

In [None]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
from scipy.stats import norm
from collections import Counter

import pdb
# # Set the seed for reproducibility
# seed = 12345
# torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# np.random.seed(seed)

import warnings

# Suppress all warnings
# NOTE(review): this also hides deprecation and numerical warnings from numpy/torch/rpy2;
# consider narrowing the filter by category or module.
warnings.filterwarnings('ignore')

# pip install rpy2

import rpy2.robjects as ro
from rpy2.robjects import numpy2ri

# Activate automatic conversion of numpy objects to R objects, so the numpy arrays
# passed to ro.globalenv[...] functions later in this cell reach R as R vectors/matrices.
# NOTE(review): newer rpy2 releases deprecate the global activate() in favour of a
# localconverter context -- confirm against the installed rpy2 version.
numpy2ri.activate()




def batches(N, batch_size, seed=0):
    """Yield shuffled mini-batches of row indices as ``torch.long`` tensors.

    Args:
        N: number of rows to split into batches.
        batch_size: maximum number of indices per batch; the last batch holds
            the remainder when ``N`` is not a multiple of ``batch_size``.
        seed: seed for the shuffle, so the batch order is reproducible.

    Yields:
        1-D ``torch.long`` tensors; together they cover ``range(N)`` exactly once.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised on first iteration).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    # A private RandomState yields the same permutation as seeding NumPy's global RNG,
    # but does not reset the global stream that other code may be relying on.
    rng = np.random.RandomState(seed)
    indices = np.arange(N)
    rng.shuffle(indices)
    for start_idx in range(0, N, batch_size):
        yield torch.tensor(indices[start_idx:start_idx + batch_size], dtype=torch.long)


def extract_and_prepare_data(A1_tensor_test, A2_tensor_test, A1, A2, d1_star, d2_star):
    """Bundle behavioral, predicted and optimal treatments into one result row.

    Each argument may be a PyTorch tensor (on any device, with or without grad)
    or a NumPy array; each is converted to a plain Python list.

    Returns:
        dict with keys 'Behavioral_A1', 'Behavioral_A2', 'Predicted_A1',
        'Predicted_A2', 'Optimal_A1', 'Optimal_A2'.

    Raises:
        TypeError: if any argument is neither a tensor nor an ndarray.
    """

    def to_numpy(data):
        """Convert a tensor or ndarray to an ndarray."""
        if isinstance(data, torch.Tensor):
            # detach() allows tensors that require grad; cpu() is a no-op on CPU tensors.
            return data.detach().cpu().numpy()
        if isinstance(data, np.ndarray):
            return data
        raise TypeError("The input must be a PyTorch tensor or a NumPy array")

    return {
        'Behavioral_A1': to_numpy(A1_tensor_test).tolist(),
        'Behavioral_A2': to_numpy(A2_tensor_test).tolist(),
        'Predicted_A1': to_numpy(A1).tolist(),
        'Predicted_A2': to_numpy(A2).tolist(),
        'Optimal_A1': to_numpy(d1_star).tolist(),
        'Optimal_A2': to_numpy(d2_star).tolist()
    }





class NNClass(nn.Module):
    """A bank of ``num_networks`` independent MLP heads applied to the same input.

    Every head is
    Linear(input_dim, hidden_dim) -> BatchNorm1d -> ELU(alpha=0.4) -> Dropout
    -> Linear(hidden_dim, hidden_dim) -> ELU(alpha=0.4) -> Dropout
    -> Linear(hidden_dim, output_dim) -> BatchNorm1d,
    and ``forward`` returns the list of per-head outputs in head order.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_networks, dropout_rate):
        super(NNClass, self).__init__()

        def make_head():
            # ELU is used instead of ReLU; best result 0.2, 0.13
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ELU(alpha=0.4),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ELU(alpha=0.4),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, output_dim),
                nn.BatchNorm1d(output_dim),
            )

        # Heads are built one after another, so parameter initialization consumes
        # the torch RNG in the same order as appending them in a loop.
        self.networks = nn.ModuleList([make_head() for _ in range(num_networks)])

    def forward(self, x):
        """Return ``[head(x) for each head]``."""
        return [head(x) for head in self.networks]

    def he_initializer(self):
        """Kaiming-normal weights (fan_out mode) and zero biases for every Linear layer.

        NOTE(review): the gain is computed for ReLU although the heads use ELU.
        """
        linear_layers = (layer for head in self.networks for layer in head if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(layer.bias, 0)

    def reset_weights(self):
        """Set every Linear weight to the constant 0.1 (best found) and every bias to 0."""
        for layer in (m for head in self.networks for m in head if isinstance(m, nn.Linear)):
            nn.init.constant_(layer.weight, 0.1)
            nn.init.constant_(layer.bias, 0.0)







## Training vs. validation loss plots

def plot_simulation_surLoss_losses_in_grid(selected_indices, losses_dict, cols=3, n_epoch=None):
    """Plot training and validation loss curves of selected simulations in a grid.

    Args:
        selected_indices: keys of ``losses_dict`` to plot, one subplot each.
        losses_dict: mapping ``idx -> (train_loss, val_loss)`` sequences.
        cols: number of subplot columns.
        n_epoch: epoch count shown in the figure title. When omitted, the
            notebook-level global ``n_epoch`` is used, if defined.

    Does nothing when ``selected_indices`` is empty.
    """
    selected_indices = list(selected_indices)
    if not selected_indices:
        # plt.subplots(0, ...) would raise; there is nothing to draw.
        return
    if n_epoch is None:
        n_epoch = globals().get('n_epoch')

    # Ceiling division: enough rows to hold every selected simulation.
    rows = (len(selected_indices) + cols - 1) // cols

    # squeeze=False always returns a 2-D array of Axes, so flatten() also works for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    fig.suptitle(f'Training and Validation Loss for Selected Simulations @ n_epoch = {n_epoch}')
    axes = axes.flatten()

    for ax, idx in zip(axes, selected_indices):
        train_loss, val_loss = losses_dict[idx]
        ax.plot(train_loss, label='Training')
        ax.plot(val_loss, label='Validation')
        ax.set_title(f'Simulation {idx}')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.legend()

    # Hide any unused subplots
    for ax in axes[len(selected_indices):]:
        ax.axis('off')

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
    plt.show()



def transform_Y(Y1, Y2):
    """Shift Y1 and Y2 by one common offset so that both become non-negative.

    If either input contains a negative value, both are shifted by
    ``1 - min(Y1.min(), Y2.min())`` so the overall minimum becomes exactly 1.
    Otherwise the inputs are returned unchanged (the same objects).
    """
    lowest = min(Y1.min(), Y2.min(), 0)
    if not lowest < 0:
        return Y1, Y2
    return Y1 - lowest + 1, Y2 - lowest + 1


def A_sim(matrix_pi, stage):
    """Sample one treatment per row from unnormalized, non-negative treatment weights.

    Args:
        matrix_pi: (N, K) array of weights; each row is normalized to probabilities.
        stage: 1 labels the probability columns ``pi_10, pi_11, ...``; any other
            value labels them ``pi_20, pi_21, ...``.

    Returns:
        dict with 'A': length-N int array of sampled treatments in 0..K-1 (0-based),
        and 'probs': DataFrame of the row-normalized probabilities.

    Raises:
        ValueError: if N or K is at most 1, any weight is negative, or a row's
            weights sum to zero (cannot be normalized).
    """
    N, K = matrix_pi.shape  # sample size and treatment options
    if N <= 1 or K <= 1:
        raise ValueError("Sample size or treatment options are insufficient!")
    if np.min(matrix_pi) < 0:
        raise ValueError("Treatment probabilities should not be negative!")

    pis = matrix_pi.sum(axis=1)
    if np.any(pis <= 0):
        # An all-zero row would divide by zero and produce NaN probabilities.
        raise ValueError("Each row of treatment probabilities must have a positive sum!")
    probs = np.divide(matrix_pi, pis[:, np.newaxis])
    # One np.random.choice call per row, in row order, keeps the draws identical
    # for a given global NumPy seed.
    A = np.array([np.random.choice(np.arange(K), p=probs[i,]) for i in range(N)])
    stage_label = 1 if stage == 1 else 2
    col_names = [f'pi_{stage_label}{k}' for k in range(K)]
    probs_df = pd.DataFrame(probs, columns=col_names)
    return {'A': A, 'probs': probs_df}




def generate_data(sample_size, setting, replication_seed):
    """Simulate one replication of the two-stage, three-treatment 'tao' DGP.

    Besides its arguments, this reads two notebook globals: ``noiseless`` (when
    truthy, the outcome noise Z1/Z2 is replaced by zeros) and ``device`` (where
    the returned tensors are placed).

    Args:
        sample_size: number of simulated subjects N.
        setting: label stored under ``'dgp'`` and printed; the simulation below
            does not branch on it.
        replication_seed: seed for NumPy's global RNG, making the replication
            reproducible (it also reseeds the global RNG for later NumPy draws).

    Returns:
        dict with ``'dgp'`` plus float32 tensors on ``device``:
        O1 (N, 5) covariates, A1/A2 treatments in {1, 2, 3}, Z1/Z2 noise,
        Y1/Y2 outcomes and their shifted copies Y1_trans/Y2_trans (see
        ``transform_Y``), row-normalized propensities pi_10..pi_12 and
        pi_20..pi_22, the optimal rules g1_opt/g2_opt, and an all-zero
        placeholder O2. Every per-subject entry has length N.
    """
    np.random.seed(replication_seed)

    print("DGP: ", setting)

    N = sample_size

    # Baseline covariates x1..x5 ~ N(0, 1), each of shape (N,)
    x1, x2, x3, x4, x5 = np.random.normal(size=(5, N))
    O1 = np.column_stack((x1, x2, x3, x4, x5))

    # Outcome noise is always drawn, so the RNG stream (and hence the treatments
    # sampled below) is the same with or without noise.
    Z1 = np.random.normal(0, 1, N)
    Z2 = np.random.normal(0, 1, N)
    if noiseless:
        # Zero vectors rather than scalar 0 keep 'Z1'/'Z2' (and g2_opt, which is
        # derived from Z1) shaped (N,) like every other per-subject entry.
        Z1 = np.zeros(N)
        Z2 = np.zeros(N)

    # Stage 1: behavioral treatment drawn with weights (1, exp(0.5 - 0.5*x3), exp(0.5*x4))
    matrix_pi1 = np.column_stack((np.ones(N), np.exp(0.5 - 0.5 * x3), np.exp(0.5 * x4)))
    result1 = A_sim(matrix_pi1, stage=1)
    A1 = result1['A'] + 1  # A_sim returns 0-based labels; the DGP uses 1..3
    probs1 = result1['probs']

    # Optimal stage-1 rule: 1 when x1 <= -1, otherwise 1 + [x2 > -0.5] + [x2 > 0.5]
    g1_opt = ((x1 > -1).astype(float) * ((x2 > -0.5).astype(float) + (x2 > 0.5).astype(float))).astype(float) + 1

    # Stage-1 outcome: largest (exp(1.5) + Z1) when A1 equals g1_opt
    Y1 = np.exp(1.5 - np.abs(1.5 * x1 + 2) * (A1 - g1_opt)**2) + Z1

    # Stage 2: behavioral treatment weights depend on the realized Y1
    matrix_pi2 = np.column_stack((np.ones(N), np.exp(0.2 * Y1 - 1), np.exp(0.5 * x4)))
    result2 = A_sim(matrix_pi2, stage=2)
    A2 = result2['A'] + 1
    probs2 = result2['probs']

    # Optimal stage-2 rule, evaluated at the optimal stage-1 outcome exp(1.5) + Z1
    Y1_opt = np.exp(1.5) + Z1
    g2_opt = (x3 > -1).astype(float) * ((Y1_opt > 0.5).astype(float) + (Y1_opt > 3).astype(float)) + 1

    # Stage-2 outcome: largest (exp(1.26) + Z2) when A2 equals g2_opt
    Y2 = np.exp(1.26 - np.abs(1.5 * x3 - 2) * (A2 - g2_opt)**2) + Z2

    O2 = np.zeros(N)  # placeholder: this DGP has no intermediate covariate
    Y1_trans, Y2_trans = transform_Y(Y1, Y2)

    arrays = {
        'O1': O1, 'A1': A1, 'Z1': Z1, 'Y1': Y1,
        'O2': O2, 'A2': A2, 'Z2': Z2, 'Y2': Y2,
        'pi_10': probs1['pi_10'], 'pi_11': probs1['pi_11'], 'pi_12': probs1['pi_12'],
        'pi_20': probs2['pi_20'], 'pi_21': probs2['pi_21'], 'pi_22': probs2['pi_22'],
        'Y1_trans': Y1_trans, 'Y2_trans': Y2_trans,
        'g1_opt': g1_opt, 'g2_opt': g2_opt,
    }
    return {
        'dgp': setting,
        **{key: torch.tensor(value, dtype=torch.float32, device=device) for key, value in arrays.items()},
    }


# Preprocess
def preprocess_data(generated_data, sample_size, setting, replication_seed, run='train', training_validation_prop = 0.8):
    """Turn a generate_data() dict into stage inputs, IPW-weighted totals and splits.

    Reads the notebook global ``tree_type`` to choose the optimal-rule family for
    g1_opt/g2_opt (tree-structured if truthy, linear-boundary otherwise).

    Args:
        generated_data: dict produced by ``generate_data``.
        sample_size: number of rows; used only to size the train/validation split.
        setting: 'tao' builds input_stage2 as [O1, A1, Y1]; any other value appends O2.
        replication_seed: not used in this function.
        run: 'test' starts from the raw outcomes Y1/Y2 and returns early; any other
            value starts from Y1_trans/Y2_trans and also returns train/validation splits.
        training_validation_prop: fraction of the leading rows (no shuffling)
            placed in the training split.

    Returns:
        run == 'test': (input_stage1, input_stage2, Ci_tensor).
        otherwise: (train_tuple, val_tuple, all_tuple), where the train/val tuples are
        (input_stage1, input_stage2, Ci, Y1, Y2, A1, A2) sliced at
        int(training_validation_prop * sample_size), and all_tuple also carries
        pi_tensor_stack, g1_opt and g2_opt.
    """

    # Extract data from the generated_data dictionary
    O1_tensor, O2_tensor = [generated_data[key] for key in ['O1', 'O2']]
    A1_tensor, A2_tensor = [generated_data[key] for key in ['A1', 'A2']]
    # NOTE: these g1_opt/g2_opt values are overwritten below in both tree_type branches.
    g1_opt, g2_opt = [generated_data[key] for key in ['g1_opt', 'g2_opt']]


    if run == 'test':
        Y1_tensor, Y2_tensor = [generated_data[key] for key in ['Y1', 'Y2']]
    else:
        Y1_tensor, Y2_tensor = [generated_data[key] for key in ['Y1_trans', 'Y2_trans']]

    # Optimal stage-1 outcome exp(1.5) + Z1, which the stage-2 optimal rules are evaluated at
    Y1_tensor_opt = torch.exp(torch.tensor(1.5)) + generated_data['Z1']
    
    if tree_type:
        g1_opt = ((O1_tensor[:, 0] > -1).float() * ((O1_tensor[:, 1] > -0.5).float() + (O1_tensor[:, 1] > 0.5).float())) + 1

        g2_opt = ((O1_tensor[:, 2] > -1).float() * ((Y1_tensor_opt > 0.5).float() + (Y1_tensor_opt > 3).float())) + 1
    else:

        g1_opt = ((O1_tensor[:, 0] > -0.5).float() * (1 + (O1_tensor[:, 0] - O1_tensor[:, 1] > 0).float())) + 1

        g2_opt = ((O1_tensor[:, 2] > 0).float() + ((O1_tensor[:, 2] + Y1_tensor_opt > 2.5).float())) + 1

    # Recompute Y2 under the g2_opt chosen above.
    # NOTE(review): this runs for every `run`, so for training the Y2_trans selected above is
    # discarded (Y2 ends up unshifted while Y1 stays shifted). Y1 is also not recomputed when
    # tree_type is False, although generate_data always builds Y1 with the tree rule.
    # Confirm both are intended.
    Y2_tensor = torch.exp(torch.tensor(1.26) - torch.abs(torch.tensor(1.5) * O1_tensor[:, 2] - 2) * (A2_tensor - g2_opt)**2) + generated_data["Z2"]


    # Probabilities: stacked rows pi_10, pi_11, pi_12, pi_20, pi_21, pi_22 (one column per subject)
    pi_tensors = [generated_data[key] for key in ['pi_10', 'pi_11', 'pi_12', 'pi_20', 'pi_21', 'pi_22']]
    pi_tensor_stack = torch.stack(pi_tensors)

    # Adjusting A1 and A2 indices
    A1_indices = (A1_tensor - 1).long().unsqueeze(0)  # A1 actions, Subtract 1 to match index values (0, 1, 2)
    A2_indices = (A2_tensor - 1 + 3).long().unsqueeze(0)   # A2 actions, Add +3 to match index values (3, 4, 5) for A2, with added dimension

    # Gathering probabilities based on actions
    # Since we're gathering along the first dimension, we need to use the dim=0 parameter in torch.gather
    P_A1_given_H1_tensor = torch.gather(pi_tensor_stack, dim=0, index=A1_indices).squeeze(0)  # Remove the added dimension after gathering
    P_A2_given_H2_tensor = torch.gather(pi_tensor_stack, dim=0, index=A2_indices).squeeze(0)  # Remove the added dimension after gathering

    # Inverse-propensity weighted total outcome: (Y1 + Y2) / (P(A1|H1) * P(A2|H2))
    Ci_tensor = (Y1_tensor + Y2_tensor) / (P_A1_given_H1_tensor * P_A2_given_H2_tensor)

    # Input preparation
    input_stage1 = O1_tensor

    if setting == 'tao':
        input_stage2 = torch.cat([O1_tensor, A1_tensor.unsqueeze(1), Y1_tensor.unsqueeze(1)], dim=1)
    else:
        input_stage2 = torch.cat([O1_tensor, A1_tensor.unsqueeze(1), Y1_tensor.unsqueeze(1), O2_tensor.unsqueeze(1)], dim=1)



    if run == 'test':
        return input_stage1, input_stage2, Ci_tensor

    # print("input_stage2: ", input_stage2[:5, :])

    # Split by row order: the first train_size rows train, the rest validate (no shuffling here)
    train_size = int(training_validation_prop * sample_size)
    train_tensors = [tensor[:train_size] for tensor in [input_stage1, input_stage2, Ci_tensor, Y1_tensor, Y2_tensor, A1_tensor, A2_tensor]]
    val_tensors = [tensor[train_size:] for tensor in [input_stage1, input_stage2, Ci_tensor, Y1_tensor, Y2_tensor, A1_tensor, A2_tensor]]

    return tuple(train_tensors), tuple(val_tensors), tuple([input_stage1, input_stage2, Ci_tensor, Y1_tensor, Y2_tensor, A1_tensor, A2_tensor, pi_tensor_stack, g1_opt, g2_opt])


def adaptive_contrast_tao(all_data, contrast):
    """Fit the ACWL regime in R on the full (unsplit) data and return its selection accuracies.

    Reads the notebook globals ``setting`` (which R script and argument list to use)
    and ``f_model`` (forwarded to R as ``method``). Sources ``ACWL_tao.R`` or
    ``ACWL_linear.R`` and calls its ``train_ACWL``.

    Args:
        all_data: the third tuple returned by ``preprocess_data`` (training run).
        contrast: forwarded unchanged to R's ``train_ACWL``.

    Returns:
        (select2, select1, selects): first elements of the corresponding R result fields.

    Raises:
        ValueError: if ``setting`` is neither 'tao' nor 'linear'.
    """
    (train_input_stage1, train_input_stage2, train_Ci, train_Y1, train_Y2,
     train_A1, train_A2, pi_tensor_stack, g1_opt, g2_opt) = all_data

    def to_np(tensor):
        # .cpu() is required when the data were generated on a CUDA device.
        return tensor.detach().cpu().numpy()

    A1 = to_np(train_A1)
    probs1 = to_np(pi_tensor_stack.T[:, :3])  # stage-1 propensities, one row per subject

    A2 = to_np(train_A2)
    probs2 = to_np(pi_tensor_stack.T[:, 3:])  # stage-2 propensities

    R1 = to_np(train_Y1)
    R2 = to_np(train_Y2)

    g1_opt = to_np(g1_opt)
    g2_opt = to_np(g2_opt)

    if setting == 'tao':
        train_input_np = to_np(train_input_stage1)
        x1 = train_input_np[:, 0]
        x2 = train_input_np[:, 1]
        x3 = train_input_np[:, 2]
        x4 = train_input_np[:, 3]
        x5 = train_input_np[:, 4]

        ro.r('source("ACWL_tao.R")')
        results = ro.globalenv['train_ACWL'](x1, x2, x3, x4, x5, A1, probs1, A2, probs2, R1, R2, g1_opt, g2_opt, contrast, method= f_model)

    elif setting == 'linear':
        # Stage-2 input is [x1, x2, A1, Y1, O2] in this setting, so O2 is column 4.
        train_input_np = to_np(train_input_stage2)
        x1 = train_input_np[:, 0]
        x2 = train_input_np[:, 1]
        O2 = train_input_np[:, 4]

        ro.r('source("ACWL_linear.R")')
        results = ro.globalenv['train_ACWL'](x1, x2, O2, A1, probs1, A2, probs2, R1, R2, g1_opt, g2_opt, contrast, method= f_model)

    else:
        raise ValueError(f"Unsupported setting for adaptive_contrast_tao: {setting!r}")

    select2 = results.rx2('select2')[0]
    select1 = results.rx2('select1')[0]
    selects = results.rx2('selects')[0]

    return select2, select1, selects



def prepare_stage2_test_input(generated_test_data, A1):
    """Build the stage-2 test input and the stage-1 outcome implied by treatments A1.

    Reads the notebook globals ``tree_type`` (which optimal stage-1 rule to use) and
    ``noiseless`` (drop Z1). Logs min/max/mean of the predicted Y1 via ``tqdm.write``.

    Args:
        generated_test_data: dict from ``generate_data``; uses 'O1' and 'Z1'.
        A1: 1-based stage-1 treatments, one per test subject.

    Returns:
        (test_input_stage2, Y1_pred): the column-wise concatenation [O1, A1, Y1_pred]
        and the simulated stage-1 outcome Y1_pred.
    """
    O1_tensor_test = generated_test_data['O1']
    Z1_tensor_test = generated_test_data['Z1']
    # A1 may come back from R as a CPU tensor; keep it on the covariates' device so the
    # arithmetic and torch.cat below also work when O1 lives on a GPU.
    A1 = A1.to(O1_tensor_test.device)

    # Optimal stage-1 rule
    if tree_type:
        g1_opt_conditions = ((O1_tensor_test[:, 0] > -1).float() * ((O1_tensor_test[:, 1] > -0.5).float() + (O1_tensor_test[:, 1] > 0.5).float())) + 1
    else:
        g1_opt_conditions = ((O1_tensor_test[:, 0] > -0.5).float() * (1 + (O1_tensor_test[:, 0] - O1_tensor_test[:, 1] > 0).float())) + 1

    if noiseless:
        Z1_tensor_test = 0

    # Stage-1 outcome model from generate_data, evaluated at the given treatments A1
    Y1_pred = torch.exp(1.5 - torch.abs(1.5 * O1_tensor_test[:, 0] + 2) * (A1 - g1_opt_conditions)**2) + Z1_tensor_test

    test_input_stage2 = torch.cat([O1_tensor_test, A1.unsqueeze(1), Y1_pred.unsqueeze(1)], dim=1)

    Y1_stats = [torch.min(Y1_pred), torch.max(Y1_pred), torch.mean(Y1_pred)]
    tqdm.write(f"Y1_pred [min, max, mean]: {Y1_stats}")

    return test_input_stage2, Y1_pred



def prepare_Y2_pred(generated_test_data, A1, A2):
    """Simulate the stage-2 outcome obtained by following treatments A2 on the test data.

    The optimal stage-2 rule is evaluated at an assumed optimal stage-1 outcome:
    exp(1.5) + Z1 when ``tree_type`` is truthy, plain exp(1.5) otherwise.
    Reads the notebook globals ``tree_type`` and ``noiseless`` (drop Z2), and logs
    min/max/mean of the result via ``tqdm.write``. ``A1`` is accepted but unused.

    NOTE(review): the non-tree branch leaves out Z1 (it was commented out), unlike
    compute_optimal_policy and preprocess_data -- confirm this is intended.
    """
    O1 = generated_test_data['O1']
    noise1 = generated_test_data['Z1']
    noise2 = generated_test_data['Z2']
    x3 = O1[:, 2]

    base_y1 = torch.exp(torch.tensor(1.5))
    if tree_type:
        y1_star = base_y1 + noise1
        g2_rule = ((x3 > -1).float() * ((y1_star > 0.5).float() + (y1_star > 3).float())) + 1
    else:
        y1_star = base_y1
        g2_rule = ((x3 > 0).float() + ((x3 + y1_star > 2.5).float())) + 1

    if noiseless:
        noise2 = 0

    Y2_pred = torch.exp(1.26 - torch.abs(1.5 * x3 - 2) * (A2 - g2_rule)**2) + noise2

    tqdm.write(f"Y2_pred [min, max, mean]: [{torch.min(Y2_pred)}, {torch.max(Y2_pred)}, {torch.mean(Y2_pred)}]")
    return Y2_pred



def compute_optimal_policy(generated_test_data):
    """Return the true optimal treatments (d1_star, d2_star) for every test subject.

    Reads the notebook global ``tree_type`` to choose the rule family. The stage-2
    rule is evaluated at the optimal stage-1 outcome exp(1.5) + Z1.
    """
    O1 = generated_test_data['O1']
    y1_star = torch.exp(torch.tensor(1.5)) + generated_test_data['Z1']
    x1, x2, x3 = O1[:, 0], O1[:, 1], O1[:, 2]

    if not tree_type:
        d1 = ((x1 > -0.5).float() * (1 + (x1 - x2 > 0).float())) + 1
        d2 = ((x3 > 0).float() + ((x3 + y1_star > 2.5).float())) + 1
        return d1, d2

    d1 = ((x1 > -1).float() * ((x2 > -0.5).float() + (x2 > 0.5).float())) + 1
    d2 = ((x3 > -1).float() * ((y1_star > 0.5).float() + (y1_star > 3).float())) + 1
    return d1, d2



def calculate_optimal_policy_values(d1_star, d2_star, generated_test_data):
    """Simulate the outcomes obtained by following treatments (d1_star, d2_star).

    Uses the outcome models of generate_data with the test sample's covariates and
    noise. Reads the notebook global ``tree_type`` to pick the optimal-rule family.

    Args:
        d1_star, d2_star: 1-based stage-1 / stage-2 treatments per test subject.
        generated_test_data: dict from ``generate_data``; uses 'O1', 'Z1' and 'Z2'.

    Returns:
        (Y1_test_opt, Y2_test_opt, values_opt) with
        values_opt = (d1_star, O1, Z1, d2_star, Z2).
        When d1_star/d2_star equal the true optimal rules, the outcomes reduce to
        exp(1.5) + Z1 and exp(1.26) + Z2.
    """
    O1_tensor_test = generated_test_data['O1']
    Z1_tensor_test = generated_test_data['Z1']
    Z2_tensor_test = generated_test_data['Z2']
    x1, x2, x3 = O1_tensor_test[:, 0], O1_tensor_test[:, 1], O1_tensor_test[:, 2]

    if tree_type:
        g1_opt_conditions = ((x1 > -1).float() * ((x2 > -0.5).float() + (x2 > 0.5).float())) + 1
    else:
        g1_opt_conditions = ((x1 > -0.5).float() * (1 + (x1 - x2 > 0).float())) + 1

    Y1_test_opt = torch.exp(torch.tensor(1.5) - torch.abs(1.5 * x1 + 2) * (d1_star - g1_opt_conditions)**2) + Z1_tensor_test

    # The stage-2 optimal rule depends on the realized stage-1 outcome.
    if tree_type:
        g2_opt_conditions = ((x3 > -1).float() * ((Y1_test_opt > 0.5).float() + (Y1_test_opt > 3).float())) + 1
    else:
        g2_opt_conditions = ((x3 > 0).float() + ((x3 + Y1_test_opt > 2.5).float())) + 1

    Y2_test_opt = torch.exp(torch.tensor(1.26) - torch.abs(1.5 * x3 - 2) * (d2_star - g2_opt_conditions)**2) + Z2_tensor_test

    values_opt = (d1_star, O1_tensor_test, Z1_tensor_test, d2_star, Z2_tensor_test)
    return Y1_test_opt, Y2_test_opt, values_opt



def calculate_policy_values(d1_star, d2_star, generated_test_data, Y1_pred, Y2_pred, V_replications):
    """Append optimal, behavioral and estimated-policy value estimates to V_replications.

    Each value is the mean over test subjects of Y1 + Y2, stored as a Python float:
    - 'V_replications_M1_optimal': outcomes under (d1_star, d2_star);
    - 'V_replications_M1_behavioral': the simulated Y1 + Y2 of the test data;
    - 'V_replications_M1_pred': Y1_pred + Y2_pred of the estimated regime.

    Returns:
        (V_replications, values_opt) -- the same dict, mutated in place, and the
        values_opt tuple from calculate_optimal_policy_values.
    """
    Y1_test_opt, Y2_test_opt, values_opt = calculate_optimal_policy_values(d1_star, d2_star, generated_test_data)

    optimal_total = Y1_test_opt + Y2_test_opt
    V_replications["V_replications_M1_optimal"].append(torch.mean(optimal_total).cpu().item())

    behavioral_total = generated_test_data['Y1'] + generated_test_data['Y2']
    V_replications["V_replications_M1_behavioral"].append(torch.mean(behavioral_total).cpu().item())

    V_replications["V_replications_M1_pred"].append(torch.mean(Y1_pred + Y2_pred).item())

    return V_replications, values_opt





def eval_DTR(sample_size, V_replications, num_replications, nn_stage1, nn_stage2, df, params):
    """Evaluate the R-fitted ACWL regime on a simulated sample and record value estimates.

    Reads the notebook globals ``noiseless`` and ``f_model``. Sources ``ACWL_tao.R``
    and calls its ``test_ACWL``.

    Args:
        sample_size: number of subjects to simulate.
        V_replications: dict of lists; optimal/behavioral/predicted values are appended.
        num_replications: used as the seed of the simulated sample (the caller passes
            the current replication index).
        nn_stage1, nn_stage2: unused; kept for interface compatibility.
        df: DataFrame of per-replication treatment assignments; one row is appended.
        params: dict; only params['setting'] is read here.

    Returns:
        (V_replications, df, values_opt).

    NOTE(review): simulations() generates its training data with
    replication_seed=replication and passes the same replication index here, so
    this "test" sample is identical to the training sample. Confirm whether a
    separate seed is intended.
    """
    # Simulate the evaluation sample and build its stage-1 input
    generated_test_data = generate_data(sample_size, params['setting'], replication_seed = num_replications)
    test_input_stage1, test_input_stage2, Ci_tensor = preprocess_data(generated_test_data, sample_size, params['setting'], replication_seed=num_replications, run='test')

    A1_tensor_test, A2_tensor_test = [generated_test_data[key] for key in ['A1', 'A2']]

    # True optimal treatments for this sample
    d1_star, d2_star = compute_optimal_policy(generated_test_data)

    # Covariates for R; .cpu() is required before .numpy() when the tensors live on a GPU.
    test_input_np = test_input_stage1.cpu().numpy()
    x1 = test_input_np[:, 0]
    x2 = test_input_np[:, 1]
    x3 = test_input_np[:, 2]
    x4 = test_input_np[:, 3]
    x5 = test_input_np[:, 4]

    ro.r('source("ACWL_tao.R")')
    results = ro.globalenv['test_ACWL'](x1, x2, x3, x4, x5, d1_star.cpu().numpy(), d2_star.cpu().numpy(), noiseless, method= f_model)

    select2_test = results.rx2('select2')[0]
    select1_test = results.rx2('select1')[0]
    selects_test = results.rx2('selects')[0]

    # TODO: FIX THESE TO GET EXACTLY SAME ACCURACY AS WE GET IN PYTHON
    print(f"TEST: Select1: {select1_test}, Select2: {select2_test}, Selects: {selects_test}")

    # Outcomes reported by R (presumably under its estimated regime); used for logging only
    Y1_pred_R = torch.tensor(np.array(results.rx2('R1.a1')), dtype=torch.float32)
    Y2_pred_R = torch.tensor(np.array(results.rx2('R2.a1')), dtype=torch.float32)

    # TODO: FIX THESE TO GET EXACTLY SAME ACCURACY AS WE GET IN PYTHON
    Y1_stats_R = [torch.min(Y1_pred_R), torch.max(Y1_pred_R), torch.mean(Y1_pred_R)]
    tqdm.write(f"Y1_pred_R [min, max, mean]: {Y1_stats_R}")
    tqdm.write(f"Y2_pred_R [min, max, mean]: [{torch.min(Y2_pred_R)}, {torch.max(Y2_pred_R)}, {torch.mean(Y2_pred_R)}]")
    tqdm.write(f'torch.mean(Y1_pred_R + Y2_pred_R): {torch.mean(Y1_pred_R + Y2_pred_R)} \n')

    # Treatments chosen by the estimated regime, placed on the same device as the test data
    # so the outcome helpers below can combine them with the simulated covariates.
    test_device = A1_tensor_test.device
    A1 = torch.tensor(np.array(results.rx2('g1.a1')), dtype=torch.float32, device=test_device)
    A2 = torch.tensor(np.array(results.rx2('g2.a1')), dtype=torch.float32, device=test_device)

    # Simulate the outcomes those treatments would obtain (the helpers also log summaries)
    test_input_stage2, Y1_pred = prepare_stage2_test_input(generated_test_data, A1)
    Y2_pred = prepare_Y2_pred(generated_test_data, A1, A2)

    new_row = extract_and_prepare_data(A1_tensor_test, A2_tensor_test, A1, A2, d1_star, d2_star)
    df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)

    tqdm.write(f'torch.mean(Y1_pred + Y2_pred): {torch.mean(Y1_pred + Y2_pred)} \n\n')

    V_replications, values_opt = calculate_policy_values(d1_star, d2_star, generated_test_data, Y1_pred, Y2_pred, V_replications)

    return V_replications, df, values_opt


def simulations(surrogate_num, sample_size, num_replications, V_replications, params):
    """Run ``num_replications`` fit-then-evaluate cycles of the R ACWL pipeline.

    Args:
        surrogate_num: unused here; kept for interface compatibility.
        sample_size: subjects simulated per replication.
        num_replications: number of replications; replication r uses seed r.
        V_replications: dict of lists, extended by eval_DTR.
        params: dict; 'setting' and 'contrast' are read.

    Returns:
        (V_replications, df, values_opt, losses_dict, epoch_num_model_lst).
        values_opt comes from the last replication (None when num_replications is 0).
        losses_dict and epoch_num_model_lst are returned empty; nothing here fills them.
    """
    columns = ['Behavioral_A1', 'Behavioral_A2', 'Predicted_A1', 'Predicted_A2', 'Optimal_A1', 'Optimal_A2']
    df = pd.DataFrame(columns=columns)
    losses_dict = {}
    epoch_num_model_lst = []
    values_opt = None  # stays defined even if the loop body never runs

    for replication in tqdm(range(num_replications), desc="Replications_M1"):
        generated_data = generate_data(sample_size, params['setting'], replication_seed=replication)

        # Only the full (unsplit) tuple is used; the train/validation splits are ignored here.
        _tuple_train, _tuple_val, all_data = preprocess_data(generated_data, sample_size, params['setting'], replication_seed=replication)

        # Fit the regime in R; its training selection accuracies are currently discarded.
        adaptive_contrast_tao(all_data, params["contrast"])

        V_replications, df, values_opt = eval_DTR(sample_size, V_replications, replication, None, None, df, params)

    return V_replications, df, values_opt, losses_dict, epoch_num_model_lst




from tqdm.notebook import tqdm


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
sample_size = 1000  # 500, 1000 are the cases to check
batch_prop = 0.2  # 0.07, 0.2
if sample_size < 500:
    batch_prop = 0.5

training_validation_prop = 0.5  # 0.95, 0.01

# Number of replications
num_replications = 4

# Data-generating setting: 'linear', 'tao', 'scheme_i'
setting = 'tao'

noiseless = True  # True: outcome noise Z1/Z2 is replaced by zeros

# preprocess_data and the evaluation helpers read `tree_type` for every setting, so it is
# always defined: True selects the tree-structured optimal rules, False the linear-boundary
# ones. Override with an explicit True/False to choose the rule family by hand.
tree_type = setting == 'tao'

# Model type: 'linear', 'tao', 'DQlearning', 'surr_opt'. 'tao' => adaptive_contrast_tao;
# the linear model is run from separate R code.
f_model = 'tao'

contrast = 1

surrogate_num = 1  # 1: old multiplicative surrogate, 2: new one

option_sur = 1  # surrogate_num = 1: options 1-5; surrogate_num = 2: 1 -> asymmetric, 2 -> symmetric


# Constants from scheme_i
C1 = C2 = 3
beta = 1


# Best so far for tao's 15000.
# This ACWL run only reads 'setting' and 'contrast'; the remaining entries configure
# the neural-network variants.
network_parameters_surogate = {
  'setting': setting,
  'n_epoch': 60,  # 250
  'num_networks': 2,
  'input_dim_stage1': 2,
  'output_dim_stage1': 1,
  'input_dim_stage2': 5,
  'output_dim_stage2': 1,
  'optimizer_betas': (0.9, 0.999),
  'optimizer_eps': 1e-08,
  'scheduler_step_size': 30,
  'scheduler_gamma': 0.8,
  'hidden_dim_stage1': 10,  # 20
  'hidden_dim_stage2': 10,  # 20
  'dropout_rate': 0.0,  # 0.3, 0.43
  'optimizer_lr': 0.07,  # 0.07, 0.007
  'optimizer_weight_decay': 0.001,
  'batch_size': math.ceil(batch_prop*sample_size),  # int(0.038*sample_size)
  'f_model': 'surr_opt',
  'option_sur': option_sur,  # surrogate_num = 1: 5 options; surrogate_num = 2: 1 -> asymmetric, 2 -> symmetric
  'contrast': contrast

}


# input_stage2 = [O1, A1, Y1, O2]; in 'tao' there is no O2, so 5 covariates + A1 + Y1 = 7
if setting == 'tao':
    network_parameters_surogate['input_dim_stage1'] = 5
    network_parameters_surogate['input_dim_stage2'] = 7


network_parameters_surogate['option_sur'] = 2  # symmetric
n_epoch = network_parameters_surogate['n_epoch']

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Lists to store replication results
V_replications_M1_behavioral = []
V_replications_M1_pred = []
V_replications_M1_optimal = []

V_replications = {"V_replications_M1_behavioral": V_replications_M1_behavioral, "V_replications_M1_pred": V_replications_M1_pred, "V_replications_M1_optimal": V_replications_M1_optimal}


print('Setting: ' , setting)
print("f_model: ", f_model)


network_parameters_surogate['f_model'] = 'tao'
V_replications, df, values_opt, losses_dict, epoch_num_model_lst = simulations(surrogate_num,
                                                                           sample_size,
                                                                           num_replications,
                                                                           V_replications,
                                                                           network_parameters_surogate)






In [1]:
import os
import pickle


In [2]:
# Show the working directory; the relative data/ paths in the cells below resolve against it.
pwd

'/Users/nilson/Desktop/DTR project/0.DirectSearchApplication'

In [5]:
folder = "data/20240802130249"  # alternative: "data/None"
df_path = os.path.join(folder, 'simulation_results.pkl')  # older runs: 'simulation_data.pkl'

# Loads a results dict keyed by a JSON string of hyper-parameters (see output below).
# pickle.load can execute arbitrary code -- only open files produced by this project.
with open(df_path, 'rb') as f:
    global_df = pickle.load(f)


In [6]:
# NOTE: despite its name, global_df here holds the results dict loaded from simulation_results.pkl
global_df

{'{"activation_function": "elu", "batch_size": 8000, "learning_rate": 0.07, "num_layers": 2}': {'DQL': {"Method's Value fn.": 7.881669173638026,
   'Behavioral Value fn.': 3.68256147702535},
  'DS': {"Method's Value fn.": 7.829710245132446,
   'Behavioral Value fn.': 3.68256147702535}},
 '{"activation_function": "elu", "batch_size": 7000, "learning_rate": 0.07, "num_layers": 2}': {'DQL': {"Method's Value fn.": 7.874249408642451,
   'Behavioral Value fn.': 3.6805241107940674},
  'DS': {"Method's Value fn.": 7.81242357691129,
   'Behavioral Value fn.': 3.6805241107940674}}}

In [3]:
def extract_unique_treatment_values(df, columns_to_process):
    """Collect the distinct treatment labels stored in list-valued columns.

    Args:
        df: DataFrame (or mapping) where ``df[col]`` iterates over per-row sequences,
            e.g. one list of chosen actions per replication.
        columns_to_process: Mapping ``group_name -> list of column names`` to scan.

    Returns:
        Nested dict ``{group_name: {column_name: set_of_values}}``.
    """
    return {
        group: {
            column: {label for row in df[column] for label in row}
            for column in column_names
        }
        for group, column_names in columns_to_process.items()
    }

In [4]:
# Run folder of an earlier experiment whose DS simulation table is inspected below.
# folder = "data/None"
folder = "data/20240710213120"

# Define paths to the files
# df_path = os.path.join(folder, 'simulation_data.pkl')
df_path = os.path.join(folder, 'simulation_data_DS.pkl')

# Load the pickled simulation table (only unpickle files produced by this project).
# global_df = pd.read_csv(df_path)
with open(df_path, 'rb') as f:
    global_df = pickle.load(f)

# Extract and process unique values
# Sanity check: which treatment labels were predicted at each stage (presumably by the DS
# method, judging by the file name)?
columns_to_process = {
    'Predicted': ['Predicted_A1', 'Predicted_A2'],
}
unique_values = extract_unique_treatment_values(global_df, columns_to_process)
unique_values

{'Predicted': {'Predicted_A1': {1, 2, 3}, 'Predicted_A2': {1, 2, 3}}}

In [5]:
# Print the unique-value summary twice: first one group per line, then the raw nested dict.
# unique_values_ = {"key1": "value1", "key2": "value2"}  # Example dictionary
log_message = "\nUnique values:\n" + "\n".join(f"{k}: {v}" for k, v in unique_values.items()) + "\n"
print(log_message)
print(f"\nUnique_values:  {unique_values} ")


Unique values:
Predicted: {'Predicted_A1': {1, 2, 3}, 'Predicted_A2': {1, 2, 3}}


Unique_values:  {'Predicted': {'Predicted_A1': {1, 2, 3}, 'Predicted_A2': {1, 2, 3}}} 


In [6]:
# Define paths to the files for both DQL and DS
# Per-configuration value-function results saved in the same run folder; the output below
# shows a dict keyed by a JSON string of hyper-parameters, with 'DQL' and 'DS' entries.

results_path = os.path.join(folder, 'simulation_results.pkl')


# Only unpickle files produced by this project (pickle can execute arbitrary code).
with open(results_path, 'rb') as f:
    results = pickle.load(f)

In [7]:
results.keys()

dict_keys(['{"activation_function": "relu", "batch_size": 3072, "learning_rate": 0.007, "num_layers": 4}'])

In [8]:
results

{'{"activation_function": "relu", "batch_size": 3072, "learning_rate": 0.007, "num_layers": 4}': {'DQL': {"Method's Value fn.": 7.834487060705821,
   'Behavioral Value fn.': 3.681306759516398},
  'DS': {"Method's Value fn.": 6.1342741052309675,
   'Behavioral Value fn.': 3.681306759516398}}}

In [9]:
results['{"activation_function": "relu", "batch_size": 3072, "learning_rate": 0.007, "num_layers": 4}']["DQL"]

{"Method's Value fn.": 7.834487060705821,
 'Behavioral Value fn.': 3.681306759516398}

In [10]:
results['{"activation_function": "relu", "batch_size": 3072, "learning_rate": 0.007, "num_layers": 4}']["DS"]

{"Method's Value fn.": 6.1342741052309675,
 'Behavioral Value fn.': 3.681306759516398}

In [13]:
# Print the per-method value-function summary of every hyper-parameter configuration.
for config_key, performance in results.items():
    print(performance)


{'DQL': {"Method's Value fn.": 7.834487060705821, 'Behavioral Value fn.': 3.681306759516398}, 'DS': {"Method's Value fn.": 6.1342741052309675, 'Behavioral Value fn.': 3.681306759516398}}


In [14]:
results

{'{"activation_function": "relu", "batch_size": 3072, "learning_rate": 0.007, "num_layers": 4}': {'DQL': {"Method's Value fn.": 7.834487060705821,
   'Behavioral Value fn.': 3.681306759516398},
  'DS': {"Method's Value fn.": 6.1342741052309675,
   'Behavioral Value fn.': 3.681306759516398}}}

In [15]:
results.keys()

dict_keys(['{"activation_function": "relu", "batch_size": 3072, "learning_rate": 0.007, "num_layers": 4}'])

In [93]:
import torch
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder
from sklearn.exceptions import ConvergenceWarning

import warnings



def A_sim(matrix_pi, stage):
    """Sample one treatment per row from unnormalised treatment weights (torch version).

    Parameters
    ----------
    matrix_pi : torch.Tensor of shape (N, K)
        Non-negative weights; row ``i`` is normalised into subject ``i``'s treatment
        distribution.
    stage : int
        ``1`` names the probabilities ``pi_1k``; any other value names them ``pi_2k``.

    Returns
    -------
    dict
        ``'A'``: tensor of shape (N,) with 0-based sampled treatment indices.
        ``'probs'``: dict mapping ``pi_<stage><k>`` to the (N,) column of normalised
        probabilities for treatment ``k`` (one entry per column of ``matrix_pi``).

    Raises
    ------
    ValueError
        If N <= 1 or K <= 1, if any weight is negative, or if a row has zero total weight.
    """
    import logging

    # The original called a module-level ``logger`` this cell never defines, so every error
    # branch raised NameError instead of the intended ValueError.
    logger = logging.getLogger(__name__)

    N, K = matrix_pi.shape  # sample size and treatment options
    if N <= 1 or K <= 1:
        logger.error("Sample size or treatment options are insufficient! N: %d, K: %d", N, K)
        raise ValueError("Sample size or treatment options are insufficient!")
    if torch.any(matrix_pi < 0):
        logger.error("Treatment probabilities should not be negative!")
        raise ValueError("Treatment probabilities should not be negative!")

    row_sums = matrix_pi.sum(dim=1, keepdim=True)
    if torch.any(row_sums <= 0):
        # An all-zero row would normalise to NaNs and make torch.multinomial fail obscurely.
        logger.error("Each row of treatment probabilities must have a positive sum!")
        raise ValueError("Each row of treatment probabilities must have a positive sum!")

    # Normalize probabilities to add up to 1 and simulate treatment A for each row
    probs = matrix_pi / row_sums
    # squeeze(1) keeps A one-dimensional (N,) regardless of N.
    A = torch.multinomial(probs, 1).squeeze(1)

    # One probability entry per treatment column (identical names to before when K == 3).
    prefix = 1 if stage == 1 else 2
    probs_dict = {f'pi_{prefix}{idx}': probs[:, idx] for idx in range(K)}

    return {'A': A, 'probs': probs_dict}

def M_propen(A, Xs, stage):
    """Estimate treatment propensity scores with a multinomial logistic regression.

    Parameters
    ----------
    A : array-like of shape (N,) or (N, 1)
        Observed treatment labels of a single stage (anything ``np.asarray`` accepts,
        e.g. a CPU torch tensor).
    Xs : array-like of shape (N, p)
        Covariates used to predict the treatment.
    stage : int
        ``1`` names the output columns ``pi_1k``; any other value names them ``pi_2k``.

    Returns
    -------
    pandas.DataFrame
        One row per observation and one column per observed treatment class, in the order
        of the fitted model's ``classes_`` (three classes give exactly the columns
        ``pi_<stage>0 .. pi_<stage>2`` as before).

    Raises
    ------
    ValueError
        If ``A`` has more than one column, its length differs from ``Xs``, or it contains
        fewer than two distinct treatments.
    """
    A = np.asarray(A)
    # Check the column count BEFORE flattening: after reshape(-1, 1) the array always has
    # exactly one column, so this check could never fire in the original ordering.
    if A.ndim > 2 or (A.ndim == 2 and A.shape[1] != 1):
        raise ValueError("Cannot handle multiple stages of treatments together!")
    A = A.reshape(-1, 1)
    if A.shape[0] != Xs.shape[0]:
        print("A.shape, Xs.shape: ", A.shape, Xs.shape)
        raise ValueError("A and Xs do not match in dimension!")
    if len(np.unique(A)) <= 1:
        raise ValueError("Treatment options are insufficient!")

    # NOTE: `multi_class` is deprecated in recent scikit-learn releases (lbfgs already fits a
    # multinomial model for >2 classes); it is kept so binary inputs keep the same fit.
    model = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)

    # Suppressing warnings from the solver, if not converged
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        model.fit(Xs, A.ravel())

    # Column j of predict_proba corresponds to model.classes_[j].
    s_p = model.predict_proba(Xs)

    # Name one column per fitted class; a hard-coded 3-name list crashed whenever fewer
    # than three treatments were observed.
    prefix = 1 if stage == 1 else 2
    col_names = [f'pi_{prefix}{k}' for k in range(s_p.shape[1])]
    probs_df = pd.DataFrame(s_p, columns=col_names)

    return probs_df




In [102]:
# Scratch check: compare multinomial-logit propensity estimates (M_propen) with the true
# normalised probabilities from A_sim for the stage-1 'tao' treatment design.
# NOTE: with the seed commented out, each run draws fresh data, so the outputs below won't reproduce.
# torch.manual_seed(replication_seed)
sample_size = 10
device = 'cpu'

# Simulate baseline covariates
# O1 is laid out (5, N): one row per covariate x1..x5.
O1 = torch.randn(5, sample_size, device=device)
# Z1 / Z2 are not used below, but drawing them still advances the torch RNG.
Z1 = torch.randn(sample_size, device=device)
Z2 = torch.randn(sample_size, device=device)


# Stage 1 data simulation
x1, x2, x3, x4, x5 = O1[0], O1[1], O1[2], O1[3], O1[4]
pi_10 = torch.ones(sample_size, device=device)
pi_11 = torch.exp(0.5 - 0.5 * x3)
pi_12 = torch.exp(0.5 * x4)
# Stack the three weight vectors and transpose -> (N, 3), one row per subject.
matrix_pi1 = torch.stack((pi_10, pi_11, pi_12), dim=0).t()

result1 = A_sim(matrix_pi1, stage=1)

#     A1, probs1 = result1['A'], result1['probs']
# A1: sampled 0-based treatments; r1: true normalised treatment probabilities.
A1, r1 = result1['A'], result1['probs']
# Propensity stage 1
# O1[[2, 3]] selects the x3, x4 rows -> (2, N); .t() gives the (N, 2) design matrix.
probs1 = M_propen(A1, O1[[2, 3]].t(), stage=1)  # multinomial logistic regression with X3, X4


In [103]:
probs1

Unnamed: 0,pi_10,pi_11,pi_12
0,0.397228,0.494092,0.10868
1,0.333437,0.388866,0.277697
2,0.07877,0.641926,0.279303
3,0.478386,0.410891,0.110723
4,0.077475,0.437941,0.484584
5,0.058862,0.733756,0.207381
6,0.378171,0.525913,0.095916
7,0.447531,0.445077,0.107393
8,0.473132,0.343207,0.183661
9,0.276969,0.578345,0.144686


In [104]:
pd.DataFrame(r1)

Unnamed: 0,pi_10,pi_11,pi_12
0,0.321233,0.455503,0.223264
1,0.233064,0.334597,0.43234
2,0.179841,0.566244,0.253915
3,0.330334,0.404783,0.264883
4,0.147126,0.410865,0.442009
5,0.17503,0.647368,0.177602
6,0.325254,0.479938,0.194808
7,0.329396,0.427222,0.243382
8,0.279626,0.326921,0.393454
9,0.280332,0.494279,0.225388


In [86]:
A1

tensor([2, 2, 0, 1, 1, 0, 2, 2, 2, 0])

In [87]:
O1[[2, 3]]

tensor([[-0.2479, -0.0342, -0.5946,  0.1033,  0.4231, -0.3447, -0.4391, -0.9040,
         -0.1877,  0.3113],
        [-0.9426,  0.0173, -1.1182,  1.5023, -0.8820, -0.3289, -0.8081, -1.2399,
          1.4100,  0.2322]])

In [88]:
O1

tensor([[ 4.5906e-01, -8.7782e-01, -6.0655e-01,  1.6869e-01, -7.3726e-01,
         -1.7103e+00,  1.7288e-01, -6.4159e-01, -4.9704e-01, -1.5568e+00],
        [-1.5049e+00, -8.8236e-01,  1.6431e-03, -7.6861e-01,  3.0450e-01,
          5.0576e-01, -8.9674e-01,  3.1831e-01, -1.5424e+00, -1.5294e+00],
        [-2.4786e-01, -3.4243e-02, -5.9459e-01,  1.0329e-01,  4.2314e-01,
         -3.4472e-01, -4.3906e-01, -9.0396e-01, -1.8769e-01,  3.1130e-01],
        [-9.4263e-01,  1.7326e-02, -1.1182e+00,  1.5023e+00, -8.8204e-01,
         -3.2889e-01, -8.0814e-01, -1.2399e+00,  1.4100e+00,  2.3216e-01],
        [ 5.3532e-01, -8.3709e-02,  1.0224e-02, -1.4863e+00,  1.6502e+00,
          9.7107e-01,  1.0964e-01,  7.8011e-01,  2.1684e+00, -4.5745e-02]])

In [89]:
O1[:,[2,3]].shape

torch.Size([5, 2])

In [90]:
O1[[2, 3]].shape

torch.Size([2, 10])

In [91]:
O1[[2, 3]]

tensor([[-0.2479, -0.0342, -0.5946,  0.1033,  0.4231, -0.3447, -0.4391, -0.9040,
         -0.1877,  0.3113],
        [-0.9426,  0.0173, -1.1182,  1.5023, -0.8820, -0.3289, -0.8081, -1.2399,
          1.4100,  0.2322]])

In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
from scipy.stats import norm
from collections import Counter

import pdb
# # Set the seed for reproducibility
# seed = 12345
# torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# np.random.seed(seed)

import warnings

# Suppress all warnings
warnings.filterwarnings('ignore')

# pip install rpy2

import rpy2.robjects as ro
from rpy2.robjects import numpy2ri

# Activate automatic conversion of numpy objects to R objects
numpy2ri.activate()




def batches(N, batch_size, seed=0):
    """Yield shuffled mini-batches of indices that together cover ``range(N)`` exactly once.

    Args:
        N: Number of samples to index.
        batch_size: Maximum batch length; the final batch may be shorter.
        seed: Value passed to ``np.random.seed`` before shuffling.

    Yields:
        ``torch.long`` tensors of sample indices.

    Note:
        This is a generator, so nothing runs until the first batch is requested. At that
        point NumPy's *global* legacy RNG is reseeded with ``seed``.
    """
    np.random.seed(seed)
    shuffled = np.arange(N)
    np.random.shuffle(shuffled)
    yield from (
        torch.tensor(shuffled[lo:lo + batch_size], dtype=torch.long)
        for lo in range(0, N, batch_size)
    )


def extract_and_prepare_data(A1_tensor_test, A2_tensor_test, A1, A2, d1_star, d2_star):
    """Package behavioral, predicted and optimal treatments of one replication as lists.

    Every argument may be a PyTorch tensor (on any device, with or without autograd
    history) or a NumPy array.

    Returns:
        dict with keys ``Behavioral_A1``, ``Behavioral_A2``, ``Predicted_A1``,
        ``Predicted_A2``, ``Optimal_A1`` and ``Optimal_A2``, each a plain Python list.

    Raises:
        TypeError: if any argument is neither a tensor nor an ndarray.
    """

    def to_numpy(data):
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        # cpu() is a no-op for CPU tensors and copies CUDA tensors to host memory.
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        if isinstance(data, np.ndarray):
            return data
        raise TypeError("The input must be a PyTorch tensor or a NumPy array")

    new_row = {
        'Behavioral_A1': to_numpy(A1_tensor_test).tolist(),
        'Behavioral_A2': to_numpy(A2_tensor_test).tolist(),
        'Predicted_A1': to_numpy(A1).tolist(),
        'Predicted_A2': to_numpy(A2).tolist(),
        'Optimal_A1': to_numpy(d1_star).tolist(),
        'Optimal_A2': to_numpy(d2_star).tolist()
    }

    return new_row





class NNClass(nn.Module):
    """Ensemble of ``num_networks`` independent MLP heads that all see the same input.

    Each head is ``Linear -> BatchNorm1d -> ELU(0.4) -> Dropout -> Linear -> ELU(0.4)
    -> Dropout -> Linear -> BatchNorm1d``; ``forward`` returns a list with one output per head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_networks, dropout_rate):
        super().__init__()
        # Heads are built in order and their layer order is fixed, so state_dict keys
        # ('networks.<head>.<layer>.*') and the RNG draws for default init match checkpoints.
        self.networks = nn.ModuleList(
            self._build_head(input_dim, hidden_dim, output_dim, dropout_rate)
            for _ in range(num_networks)
        )

    @staticmethod
    def _build_head(input_dim, hidden_dim, output_dim, dropout_rate):
        """Return one freshly initialised MLP head."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ELU(alpha=0.4),  # Using ELU instead of ReLU; best result 0.2, 0.13
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(alpha=0.4),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim),
        )

    def forward(self, x):
        """Apply every head to ``x`` and return their outputs as a list."""
        return [head(x) for head in self.networks]

    def _linear_layers(self):
        """Yield every nn.Linear of every head, in construction order."""
        for head in self.networks:
            yield from (layer for layer in head if isinstance(layer, nn.Linear))

    def he_initializer(self):
        """Kaiming-normal weights (fan_out, ReLU gain) and zero biases for all Linear layers."""
        for linear in self._linear_layers():
            nn.init.kaiming_normal_(linear.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(linear.bias, 0)

    def reset_weights(self):
        """Deterministic restart: every Linear weight set to 0.1 (best found) and bias to 0."""
        for linear in self._linear_layers():
            nn.init.constant_(linear.weight, 0.1)
            nn.init.constant_(linear.bias, 0.0)







## train V validation loss

def plot_simulation_surLoss_losses_in_grid(selected_indices, losses_dict, cols=3, n_epoch=None):
    """Plot training vs. validation loss curves for several simulations in a grid.

    Args:
        selected_indices: Keys of ``losses_dict`` to plot, one subplot each.
        losses_dict: Mapping ``idx -> (train_losses, val_losses)``.
        cols: Number of subplot columns.
        n_epoch: Epoch count shown in the figure title. When omitted, falls back to a
            notebook-global ``n_epoch`` (the previous behaviour); if none exists, the title
            simply omits it instead of raising NameError.
    """
    selected_indices = list(selected_indices)
    if not selected_indices:
        # Nothing to draw; plt.subplots(0, cols) would raise here.
        return

    if n_epoch is None:
        n_epoch = globals().get('n_epoch')

    # Ceiling division: enough rows to hold every selected simulation.
    rows = -(-len(selected_indices) // cols)

    # squeeze=False always returns a 2-D array of Axes, so flatten() also works for a 1x1
    # grid (where plt.subplots would otherwise return a bare Axes object).
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    title = 'Training and Validation Loss for Selected Simulations'
    if n_epoch is not None:
        title += f' @ n_epoch = {n_epoch}'
    fig.suptitle(title)
    axes = axes.flatten()

    for ax, idx in zip(axes, selected_indices):
        train_loss, val_loss = losses_dict[idx]
        ax.plot(train_loss, label='Training')
        ax.plot(val_loss, label='Validation')
        ax.set_title(f'Simulation {idx}')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.legend()

    # Hide any unused subplots in the last row.
    for ax in axes[len(selected_indices):]:
        ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave room for the suptitle
    plt.show()



def transform_Y(Y1, Y2):
    """Shift both outcome arrays so that neither contains negative values.

    If the smallest entry across ``Y1`` and ``Y2`` (compared against 0) is negative, both
    arrays are shifted by the same amount, ``- minimum + 1``, so that the minimum maps to 1.
    Otherwise the inputs are returned unchanged (the very same objects).
    """
    floor = min(Y1.min(), 0)
    floor = min(floor, min(Y2.min(), 0))

    if not floor < 0:
        return Y1, Y2

    # Subtract first, then add 1 — same operation order as before, for identical floats.
    return Y1 - floor + 1, Y2 - floor + 1


def A_sim(matrix_pi, stage):
    """Sample one treatment per row from unnormalised treatment weights (NumPy version).

    Parameters
    ----------
    matrix_pi : numpy.ndarray of shape (N, K)
        Non-negative weights; each row is normalised into that subject's treatment
        distribution.
    stage : int
        ``1`` labels the probability columns ``pi_1k``; any other value uses ``pi_2k``.

    Returns
    -------
    dict
        ``'A'``: int array of shape (N,) with 0-based treatments drawn from the global
        ``np.random`` state; ``'probs'``: DataFrame of normalised probabilities with one
        column per treatment (``pi_<stage>0`` .. ``pi_<stage><K-1>``).

    Raises
    ------
    ValueError
        If N <= 1 or K <= 1, any weight is negative, or a row has zero total weight.
    """
    N, K = matrix_pi.shape  # sample size and treatment options
    if N <= 1 or K <= 1:
        raise ValueError("Sample size or treatment options are insufficient!")
    if np.min(matrix_pi) < 0:
        raise ValueError("Treatment probabilities should not be negative!")

    # Normalize probabilities to add up to 1 and simulate treatment A for each row
    pis = matrix_pi.sum(axis=1)
    if np.any(pis <= 0):
        # An all-zero row would normalise to NaNs and fail obscurely inside np.random.choice.
        raise ValueError("Each row of treatment probabilities must have a positive sum!")
    probs = np.divide(matrix_pi, pis[:, np.newaxis])
    # One draw per row, in row order, so results match previous runs for the same seed.
    A = np.array([np.random.choice(np.arange(K), p=probs[i]) for i in range(N)])

    # One column per treatment; a hard-coded 3-name list broke for any K != 3.
    prefix = 1 if stage == 1 else 2
    col_names = [f'pi_{prefix}{k}' for k in range(K)]
    probs_df = pd.DataFrame(probs, columns=col_names)
    return {'A': A, 'probs': probs_df}




def generate_data(sample_size, setting, replication_seed):
    """Simulate one two-stage dataset for the requested data-generating process (DGP).

    Parameters
    ----------
    sample_size : int
        Number of subjects N.
    setting : str
        DGP name: ``'linear'``, ``'scheme_i'`` or ``'tao'``.
    replication_seed : int
        Seed for NumPy's global legacy RNG (``np.random.seed``) so a replication is
        reproducible.

    Returns
    -------
    dict
        ``'dgp'`` plus float32 tensors (on the notebook-global ``device``) for covariates
        ``O1``/``O2``, treatments ``A1``/``A2`` (values 1..3), noise ``Z1``/``Z2``, outcomes
        ``Y1``/``Y2`` and their shifted versions ``Y1_trans``/``Y2_trans`` (from transform_Y),
        propensities ``pi_10`` .. ``pi_22`` and optimal rules ``g1_opt``/``g2_opt``.

    Raises
    ------
    ValueError
        If ``setting`` is not one of the supported DGPs.

    Notes
    -----
    Reads the notebook globals ``device`` and, for ``'tao'``, ``noiseless``.
    """
    if setting not in ('linear', 'scheme_i', 'tao'):
        # An unknown setting used to fall through every branch and fail later with a
        # confusing UnboundLocalError on Y1.
        raise ValueError(f"Unknown setting: {setting!r} (expected 'linear', 'scheme_i' or 'tao')")

    np.random.seed(replication_seed)

    print("DGP: ", setting)


    if setting == 'linear':

        # Generate data
        O1, O2 = np.random.normal(0, 1, size=(sample_size, 2)), np.random.normal(0, 1, sample_size) # Combined Stage 1 Inputs O1 with 2 features


        A1, A2 = np.random.randint(1, 4, size=sample_size), np.random.randint(1, 4, size=sample_size)
        Z1, Z2 = np.random.normal(0, 0.5, sample_size), np.random.normal(0, 0.5, sample_size)

        # Compute Y1 and Y2
        Y1 = 15 + A1 + O1.sum(axis=1) + np.prod(O1, axis=1) + Z1
        Y2 = 15 + O2 + A2 * (1 - O2 + A1 + O1.sum(axis=1)) + Z2

        pi_value = np.full(sample_size, 1 / 3)  # Probability value when there are 3 treatments
        pi_10 = pi_11 = pi_12 = pi_20 = pi_21 = pi_22 = pi_value

        # Compute optimal policy decisions for 'linear', updated for combined O1 tensor
        g1_opt = np.full(sample_size, 3)  # Assuming the optimal policy for stage 1 is to always choose action 3
        g2_opt = np.where( (1 - O2 + g1_opt + O1.sum(axis=1)).astype(float) > 0, 3, 1)



    elif setting == 'scheme_i':
        # NOTE(review): this branch is incomplete. g, C1, f_i, beta and C2 are not defined in
        # this function (they would have to be notebook globals), and g1_opt / g2_opt are
        # never assigned, so building the return dict raises UnboundLocalError.
        # Generate data
        O1 = np.random.normal(0, 1, (sample_size, 3))
        Z1, Z2, A1, A2, O2 = [np.random.normal(0, 1, sample_size) for _ in range(2)] + \
                             [np.random.randint(1, 4, sample_size) for _ in range(2)] + \
                             [np.random.normal(0, 1, sample_size)]


        # Compute Y1 using g(O1) and A1
        Y1 = A1 * g(O1) + C1 + Z1

        # Compute Y2 as a weighted sum of f_i(O1, A1) based on the value of A2
        Y2 = sum(f_i(O1, A1, i) * (A2 == i) for i in range(1, 4)) + O2 * beta + C2 + Z2

        # Probabilities for treatments, here assuming it's the same as linear case
        pi_value = np.full(sample_size, 1 / 3)  # Equal probability for 3 treatments
        pi_10 = pi_11 = pi_12 = pi_20 = pi_21 = pi_22 = pi_value


    elif setting == 'tao':


        # Simulate baseline covariates
        N = sample_size
        x1, x2, x3, x4, x5 = np.random.normal(size=(5, N))
        O1 = np.column_stack((x1, x2, x3, x4, x5))

        Z1 = np.random.normal(0, 1, sample_size)
        Z2 = np.random.normal(0, 1, sample_size)

        if noiseless:
            Z1 = 0
            Z2 = 0

        # Stage 1 data simulation
        pi_10 = np.ones(N)
        pi_11 = np.exp(0.5 - 0.5 * x3)
        pi_12 = np.exp(0.5 * x4)
        matrix_pi1 = np.column_stack((pi_10, pi_11, pi_12))
        result1 = A_sim(matrix_pi1, stage=1)
        A1, probs1 = result1['A'], result1['probs']
        # A_sim returns 0-based treatments; shift to 1..3.
        A1 +=1


        # Optimal g1.opt corrected
        g1_opt = ((x1 > -1).astype(float) * ((x2 > -0.5).astype(float) + (x2 > 0.5).astype(float))).astype(float) + 1


        # Stage 1 outcome R1

        Y1 = np.exp(1.5 - np.abs(1.5 * x1 + 2) * (A1 - g1_opt)**2) + Z1

        # Stage 2 data simulation
        pi_20 = np.ones(N)
        pi_21 = np.exp(0.2 * Y1 - 1)
        pi_22 = np.exp(0.5 * x4)
        matrix_pi2 = np.column_stack((pi_20, pi_21, pi_22))
        result2 = A_sim(matrix_pi2, stage=2)
        A2, probs2 = result2['A'], result2['probs']
        A2 +=1

        # Stage-1 outcome under the optimal stage-1 action (A1 == g1_opt), used by the stage-2 rule.
        Y1_opt = np.exp(1.5) + Z1

        g2_opt = (x3 > -1).astype(float) * ((Y1_opt > 0.5).astype(float) + (Y1_opt > 3).astype(float)) + 1



        # Stage 2 outcome R2
        Y2 = np.exp(1.26 - np.abs(1.5 * x3 - 2) * (A2 - g2_opt)**2) + Z2


        # Extract propensity for both stages
        pi_10, pi_11, pi_12 = probs1['pi_10'], probs1['pi_11'], probs1['pi_12']
        pi_20, pi_21, pi_22 = probs2['pi_20'], probs2['pi_21'], probs2['pi_22']

        # Dummy O2
        O2 = np.zeros(sample_size)

    Y1_trans, Y2_trans = transform_Y(Y1, Y2)



    return {
        'dgp': setting,
        'O1': torch.tensor(O1, dtype=torch.float32, device=device),
        'A1': torch.tensor(A1, dtype=torch.float32, device=device),
        'Z1': torch.tensor(Z1, dtype=torch.float32, device=device),
        'Y1': torch.tensor(Y1, dtype=torch.float32, device=device),
        'O2': torch.tensor(O2, dtype=torch.float32, device=device),
        'A2': torch.tensor(A2, dtype=torch.float32, device=device),
        'Z2': torch.tensor(Z2, dtype=torch.float32, device=device),
        'Y2': torch.tensor(Y2, dtype=torch.float32, device=device),
        'pi_10': torch.tensor(pi_10, dtype=torch.float32, device=device),
        'pi_11': torch.tensor(pi_11, dtype=torch.float32, device=device),
        'pi_12': torch.tensor(pi_12, dtype=torch.float32, device=device),
        'pi_20': torch.tensor(pi_20, dtype=torch.float32, device=device),
        'pi_21': torch.tensor(pi_21, dtype=torch.float32, device=device),
        'pi_22': torch.tensor(pi_22, dtype=torch.float32, device=device),
        'Y1_trans': torch.tensor(Y1_trans, dtype=torch.float32, device=device),
        'Y2_trans': torch.tensor(Y2_trans, dtype=torch.float32, device=device),
        'g1_opt': torch.tensor(g1_opt, dtype=torch.float32, device=device),
        'g2_opt': torch.tensor(g2_opt, dtype=torch.float32, device=device),
    }


# Preprocess
def preprocess_data(generated_data, sample_size, setting, replication_seed, run='train', training_validation_prop = 0.8):
    """Turn a generate_data() dict into network inputs, IPW outcomes and train/val splits.

    Parameters
    ----------
    generated_data : dict
        Output of generate_data(): tensors keyed 'O1', 'O2', 'A1', 'A2', 'Y1', 'Y2',
        'Y1_trans', 'Y2_trans', 'Z1', 'Z2', 'pi_10' .. 'pi_22', 'g1_opt', 'g2_opt'.
    sample_size : int
        Number of rows; sizes the train/validation split.
    setting : str
        DGP name. 'tao' recomputes the optimal rules and Y2 here and leaves O2 out of the
        stage-2 input.
    replication_seed : int
        Not used in this function body.
    run : str
        'test' uses the raw outcomes Y1/Y2 and returns early; any other value uses the
        shifted Y1_trans/Y2_trans and returns train/validation splits.
    training_validation_prop : float
        Fraction of leading rows assigned to the training split.

    Returns
    -------
    If run == 'test': (input_stage1, input_stage2, Ci_tensor).
    Otherwise: (train_tensors, val_tensors, all_tensors), where the first two are 7-tuples
    (input_stage1, input_stage2, Ci, Y1, Y2, A1, A2) and the third additionally appends
    (pi_tensor_stack, g1_opt, g2_opt).

    Notes
    -----
    Reads the notebook global ``tree_type`` when setting == 'tao'.
    """

    # Extract data from the generated_data dictionary
    O1_tensor, O2_tensor = [generated_data[key] for key in ['O1', 'O2']]
    A1_tensor, A2_tensor = [generated_data[key] for key in ['A1', 'A2']]
    g1_opt, g2_opt = [generated_data[key] for key in ['g1_opt', 'g2_opt']]


    # Test runs evaluate on raw outcomes; training uses the non-negative shifted outcomes.
    if run == 'test':
        Y1_tensor, Y2_tensor = [generated_data[key] for key in ['Y1', 'Y2']]
    else:
        Y1_tensor, Y2_tensor = [generated_data[key] for key in ['Y1_trans', 'Y2_trans']]

    if setting == 'tao':
        # Stage-1 outcome under the optimal stage-1 action; the stage-2 rule depends on it.
        Y1_tensor_opt = torch.exp(torch.tensor(1.5)) + generated_data['Z1']
        # Optimal rules are recomputed here (overriding the ones from generated_data),
        # choosing the tree-shaped or non-tree variant via the global flag.
        if tree_type:
            g1_opt = ((O1_tensor[:, 0] > -1).float() * ((O1_tensor[:, 1] > -0.5).float() + (O1_tensor[:, 1] > 0.5).float())) + 1

            g2_opt = ((O1_tensor[:, 2] > -1).float() * ((Y1_tensor_opt > 0.5).float() + (Y1_tensor_opt > 3).float())) + 1
        else:

            g1_opt = ((O1_tensor[:, 0] > -0.5).float() * (1 + (O1_tensor[:, 0] - O1_tensor[:, 1] > 0).float())) + 1

            g2_opt = ((O1_tensor[:, 2] > 0).float() + ((O1_tensor[:, 2] + Y1_tensor_opt > 2.5).float())) + 1

        # NOTE(review): Y2 is recomputed from the raw outcome formula (plus Z2) for every run,
        # which overrides Y2_trans when run != 'test' while Y1 stays shifted — confirm this
        # asymmetry is intended.
        Y2_tensor = torch.exp(torch.tensor(1.26) - torch.abs(torch.tensor(1.5) * O1_tensor[:, 2] - 2) * (A2_tensor - g2_opt)**2) + generated_data["Z2"]


    # Probabilities
    # Rows 0-2 hold stage-1 propensities, rows 3-5 stage-2 ones; (6, N) when each pi_* is a
    # length-N vector.
    pi_tensors = [generated_data[key] for key in ['pi_10', 'pi_11', 'pi_12', 'pi_20', 'pi_21', 'pi_22']]
    pi_tensor_stack = torch.stack(pi_tensors)

    # Adjusting A1 and A2 indices
    A1_indices = (A1_tensor - 1).long().unsqueeze(0)  # A1 actions, Subtract 1 to match index values (0, 1, 2)
    A2_indices = (A2_tensor - 1 + 3).long().unsqueeze(0)   # A2 actions, Add +3 to match index values (3, 4, 5) for A2, with added dimension

    # Gathering probabilities based on actions
    # Since we're gathering along the first dimension, we need to use the dim=0 parameter in torch.gather
    P_A1_given_H1_tensor = torch.gather(pi_tensor_stack, dim=0, index=A1_indices).squeeze(0)  # Remove the added dimension after gathering
    P_A2_given_H2_tensor = torch.gather(pi_tensor_stack, dim=0, index=A2_indices).squeeze(0)  # Remove the added dimension after gathering

    # Calculate Ci tensor
    # Inverse-probability-weighted total outcome: (Y1 + Y2) / (P(A1|H1) * P(A2|H2)).
    Ci_tensor = (Y1_tensor + Y2_tensor) / (P_A1_given_H1_tensor * P_A2_given_H2_tensor)

    # Input preparation
    input_stage1 = O1_tensor

    # Stage-2 history: baseline covariates, stage-1 action and outcome (+ O2 unless 'tao',
    # whose O2 is a dummy zero vector in generate_data).
    if setting == 'tao':
        input_stage2 = torch.cat([O1_tensor, A1_tensor.unsqueeze(1), Y1_tensor.unsqueeze(1)], dim=1)
    else:
        input_stage2 = torch.cat([O1_tensor, A1_tensor.unsqueeze(1), Y1_tensor.unsqueeze(1), O2_tensor.unsqueeze(1)], dim=1)



    if run == 'test':
        return input_stage1, input_stage2, Ci_tensor

    # print("input_stage2: ", input_stage2[:5, :])

    # Splitting data into training and validation sets
    # The first train_size rows form the training split and the rest validation; no shuffling here.
    train_size = int(training_validation_prop * sample_size)
    train_tensors = [tensor[:train_size] for tensor in [input_stage1, input_stage2, Ci_tensor, Y1_tensor, Y2_tensor, A1_tensor, A2_tensor]]
    val_tensors = [tensor[train_size:] for tensor in [input_stage1, input_stage2, Ci_tensor, Y1_tensor, Y2_tensor, A1_tensor, A2_tensor]]

    return tuple(train_tensors), tuple(val_tensors), tuple([input_stage1, input_stage2, Ci_tensor, Y1_tensor, Y2_tensor, A1_tensor, A2_tensor, pi_tensor_stack, g1_opt, g2_opt])


def adaptive_contrast_tao(all_data, contrast):
    """Fit the ACWL baseline (adaptive-contrast weighted learning, per the R script names) in R.

    Parameters
    ----------
    all_data : tuple
        The 10-tuple returned as the third element of preprocess_data():
        (input_stage1, input_stage2, Ci, Y1, Y2, A1, A2, pi_tensor_stack, g1_opt, g2_opt).
    contrast :
        Passed through unchanged to the R function ``train_ACWL``.

    Returns
    -------
    (select2, select1, selects)
        The first element of each same-named entry of the R result.

    Raises
    ------
    ValueError
        If the notebook-global ``setting`` is neither 'tao' nor 'linear' (no ACWL script
        exists for other settings; previously this failed with UnboundLocalError).

    Notes
    -----
    Reads the notebook globals ``setting`` and ``f_model`` and sources ``ACWL_tao.R`` /
    ``ACWL_linear.R`` from the working directory on every call.
    """
    if setting not in ('tao', 'linear'):
        raise ValueError(f"adaptive_contrast_tao has no ACWL script for setting {setting!r}")

    train_input_stage1, train_input_stage2, train_Ci, train_Y1, train_Y2, train_A1, train_A2, pi_tensor_stack, g1_opt, g2_opt = all_data

    def to_numpy(tensor):
        # Tensor.numpy() raises for CUDA tensors (and the driver picks cuda when available);
        # detach().cpu() is a no-op for plain CPU tensors.
        return tensor.detach().cpu().numpy()

    A1 = to_numpy(train_A1)
    probs1 = to_numpy(pi_tensor_stack.T[:, :3])  # (N, 3) stage-1 propensities

    A2 = to_numpy(train_A2)
    probs2 = to_numpy(pi_tensor_stack.T[:, 3:])  # (N, 3) stage-2 propensities

    R1 = to_numpy(train_Y1)
    R2 = to_numpy(train_Y2)

    g1_opt = to_numpy(g1_opt)
    g2_opt = to_numpy(g2_opt)

    if setting == 'tao':
        train_input_np = to_numpy(train_input_stage1)

        x1 = train_input_np[:, 0]
        x2 = train_input_np[:, 1]
        x3 = train_input_np[:, 2]
        x4 = train_input_np[:, 3]
        x5 = train_input_np[:, 4]

        # Load the R script containing the function
        ro.r('source("ACWL_tao.R")')

        # Call the R function
        results = ro.globalenv['train_ACWL'](x1, x2, x3, x4, x5, A1, probs1, A2, probs2, R1, R2, g1_opt, g2_opt, contrast, method= f_model)

    else:  # setting == 'linear'
        # Stage-2 input columns are [O1 (2 features), A1, Y1, O2], so column 4 is O2.
        train_input_np = to_numpy(train_input_stage2)

        x1 = train_input_np[:, 0]
        x2 = train_input_np[:, 1]
        O2 = train_input_np[:, 4]

        # Load the R script containing the function
        ro.r('source("ACWL_linear.R")')

        # Call the R function
        results = ro.globalenv['train_ACWL'](x1, x2, O2, A1, probs1, A2, probs2, R1, R2, g1_opt, g2_opt, contrast, method= f_model)

    # Extract results
    select2 = results.rx2('select2')[0]
    select1 = results.rx2('select1')[0]
    selects = results.rx2('selects')[0]

    # TODO: FIX THESE TO GET EXACTLY SAME ACCURACY AS WE GET IN PYTHON
    #print(f"Train: Select1: {select1}, Select2: {select2}, Selects: {selects}")

    return select2, select1, selects



def prepare_stage2_test_input(generated_test_data, A1):
    """Build the stage-2 network input for test data under a given stage-1 treatment.

    Re-simulates the stage-1 outcome ``Y1_pred`` that treatment ``A1`` would produce under
    the DGP named in ``generated_test_data['dgp']``, then concatenates it with the baseline
    covariates (plus ``O2`` for non-'tao' DGPs). The min/max/mean of ``Y1_pred`` are logged
    with ``tqdm.write``.

    Parameters
    ----------
    generated_test_data : dict
        Output of generate_data() for the test sample.
    A1 : torch.Tensor of shape (N,)
        Stage-1 treatments to evaluate (1-based, like generate_data's A1).

    Returns
    -------
    (test_input_stage2, Y1_pred)

    Raises
    ------
    ValueError
        If the DGP is not 'linear', 'scheme_i' or 'tao' (previously an UnboundLocalError).

    Notes
    -----
    'tao' reads the notebook globals ``tree_type`` and ``noiseless``; 'scheme_i' needs the
    globals ``g`` and ``C1``.
    """
    dgp = generated_test_data['dgp']
    O1_tensor_test, O2_tensor_test = [generated_test_data[key] for key in ['O1', 'O2']]
    Z1_tensor_test = generated_test_data['Z1']

    if dgp == 'linear':
        # Compute Y1_pred using the given formula, updated for combined O1 tensor
        Y1_pred = 15 + A1 + O1_tensor_test.sum(axis=1) + torch.prod(O1_tensor_test, dim=1) + Z1_tensor_test

    elif dgp == 'scheme_i':
        # Compute Y1_pred for scheme_i using g(O1) and given A1
        Y1_pred = A1 * g(O1_tensor_test) + C1 + Z1_tensor_test

    elif dgp == 'tao':
        # Optimal stage-1 rule: tree-shaped or non-tree variant, selected by the global flag.
        if tree_type:
            g1_opt_conditions = ((O1_tensor_test[:, 0] > -1).float() * ((O1_tensor_test[:, 1] > -0.5).float() + (O1_tensor_test[:, 1] > 0.5).float())) + 1
        else:
            g1_opt_conditions = ((O1_tensor_test[:, 0] > -0.5).float() * (1 + (O1_tensor_test[:, 0] - O1_tensor_test[:, 1] > 0).float())) + 1

        if noiseless:
            Z1_tensor_test = 0

        # Outcome peaks at exp(1.5) when A1 matches the optimal rule.
        Y1_pred = torch.exp(1.5 - torch.abs(1.5 * O1_tensor_test[:, 0] + 2) * (A1 - g1_opt_conditions)**2) + Z1_tensor_test

    else:
        raise ValueError(f"Unknown dgp: {dgp!r}")

    # Form the test input for stage 2 by concatenating the necessary tensors.
    # 'tao' omits O2, which generate_data fills with a dummy zero vector.
    if dgp == 'tao':
        test_input_stage2 = torch.cat([O1_tensor_test, A1.unsqueeze(1), Y1_pred.unsqueeze(1)], dim=1)
    else:
        test_input_stage2 = torch.cat([O1_tensor_test, A1.unsqueeze(1), Y1_pred.unsqueeze(1), O2_tensor_test.unsqueeze(1)], dim=1)

    # tqdm.write keeps the message from corrupting an active progress bar.
    Y1_stats = [torch.min(Y1_pred), torch.max(Y1_pred), torch.mean(Y1_pred)]
    stats_message = f"Y1_pred [min, max, mean]: {Y1_stats}"
    tqdm.write(stats_message)

    return test_input_stage2, Y1_pred



def prepare_Y2_pred(generated_test_data, A1, A2):
    """Simulate the stage-2 outcome ``Y2_pred`` for test data under treatments (A1, A2).

    The outcome formula depends on ``generated_test_data['dgp']``. The min/max/mean of the
    result are logged with ``tqdm.write``.

    Parameters
    ----------
    generated_test_data : dict
        Output of generate_data() for the test sample.
    A1, A2 : torch.Tensor of shape (N,)
        Stage-1 / stage-2 treatments to evaluate (1-based).

    Returns
    -------
    torch.Tensor
        ``Y2_pred``.

    Raises
    ------
    ValueError
        If the DGP is not 'linear', 'scheme_i' or 'tao' (previously an UnboundLocalError).

    Notes
    -----
    'tao' reads the notebook globals ``tree_type`` and ``noiseless``; 'scheme_i' needs the
    globals ``f_i``, ``beta`` and ``C2``.
    """
    dgp = generated_test_data['dgp']
    O1_tensor_test = generated_test_data['O1']
    Z1_tensor_test = generated_test_data['Z1']
    Z2_tensor_test = generated_test_data['Z2']

    if dgp == 'linear':
        O2_tensor_test = generated_test_data['O2']
        # Compute Y2_pred for the 'linear' setting, updated for combined O1 tensor
        Y2_pred = 15 + O2_tensor_test + A2 * (1 - O2_tensor_test + A1 + O1_tensor_test.sum(axis=1)) + Z2_tensor_test

    elif dgp == 'scheme_i':
        # Compute Y2_pred for scheme_i
        Y2_pred = sum(f_i(O1_tensor_test, A1, i) * (A2 == i) for i in range(1, 4)) + \
                  generated_test_data['O2'] * beta + C2 + Z2_tensor_test

    elif dgp == 'tao':
        # The optimal stage-2 rule is evaluated at the stage-1 outcome obtained under the
        # optimal stage-1 action, exp(1.5).
        if tree_type:
            Y1_pred = torch.exp(torch.tensor(1.5)) + Z1_tensor_test

            g2_opt_conditions = ((O1_tensor_test[:, 2] > -1).float() * ((Y1_pred > 0.5).float() + (Y1_pred > 3).float())) + 1

        else:
            # NOTE(review): unlike the tree branch above and preprocess_data, Z1 is not added
            # here — confirm this is intended.
            Y1_pred = torch.exp(torch.tensor(1.5))

            g2_opt_conditions = ((O1_tensor_test[:, 2] > 0).float() + ((O1_tensor_test[:, 2] + Y1_pred > 2.5).float())) + 1

        if noiseless:
            Z2_tensor_test = 0

        # Outcome peaks at exp(1.26) when A2 matches the optimal stage-2 rule.
        Y2_pred = torch.exp(1.26 - torch.abs(1.5 * O1_tensor_test[:, 2] - 2) * (A2 - g2_opt_conditions)**2) + Z2_tensor_test

    else:
        raise ValueError(f"Unknown dgp: {dgp!r}")

    stats_message = f"Y2_pred [min, max, mean]: [{torch.min(Y2_pred)}, {torch.max(Y2_pred)}, {torch.mean(Y2_pred)}]"
    tqdm.write(stats_message)

    return Y2_pred



def compute_optimal_policy(generated_test_data):
    """Compute the true optimal stage-1 and stage-2 actions under the simulation DGP.

    The DGP is selected by ``generated_test_data['dgp']``:

    * ``'linear'``: stage 1 always picks action 3; stage 2 picks 3 when
      ``1 - O2 + d1_star + sum(O1)`` is positive, else 1 (the stage-2 outcome is
      linear in A2 with exactly that slope).
    * ``'scheme_i'``: stage 1 maximizes ``max_i f_i(O1, j, i) + j * g(O1)`` over
      j in {1, 2, 3}; stage 2 maximizes ``f_i(O1, A1, i)`` over i, using the
      observed ``A1``. Relies on the module-level ``f_i`` and ``g``.
    * ``'tao'``: threshold rules on O1 (and, for stage 2, the stage-1 outcome
      under the optimal action), chosen by the module-level ``tree_type`` flag.

    Args:
        generated_test_data: dict with at least 'dgp', 'O1', 'Z1'; also 'O2'
            for 'linear' and 'A1' for 'scheme_i'.

    Returns:
        (d1_star, d2_star): tensors of 1-indexed optimal actions.

    Raises:
        ValueError: if the 'dgp' value is not 'linear', 'scheme_i' or 'tao'.
    """
    O1_tensor_test = generated_test_data['O1']
    Z1_tensor_test = generated_test_data['Z1']
    dgp = generated_test_data['dgp']

    if dgp == 'linear':
        O2_tensor_test = generated_test_data['O2']
        # Stage 1: action 3 is always optimal in this DGP
        d1_star = torch.full_like(Z1_tensor_test, 3, dtype=torch.float)
        # Stage 2: largest action when the A2 slope is positive, smallest otherwise
        d2_star = torch.where(1 - O2_tensor_test + d1_star + O1_tensor_test.sum(axis=1) > 0, torch.tensor(3.0), torch.tensor(1.0))

    elif dgp == 'scheme_i':
        # Candidate stage-1 actions 1..3, on the same device as the data
        actions = torch.arange(1, 4, dtype=torch.float32, device=O1_tensor_test.device)

        # For each stage-1 action j, the best attainable f_i over i in {1, 2, 3}
        f_max_for_each_j = torch.stack([
            torch.stack([f_i(O1_tensor_test, j, i) for i in range(1, 4)], dim=1).max(dim=1)[0]
            for j in range(1, 4)
        ], dim=1)

        # h_j = best stage-2 payoff after j + immediate stage-1 payoff j * g(O1)
        h_j_values = f_max_for_each_j + actions * g(O1_tensor_test).unsqueeze(1)
        d1_star = torch.argmax(h_j_values, dim=1) + 1  # +1 because actions are 1-indexed

        # Stage 2: best f_i given the observed stage-1 action
        f_i_values = torch.stack([f_i(O1_tensor_test, generated_test_data['A1'], i) for i in range(1, 4)], dim=1)
        d2_star = torch.argmax(f_i_values, dim=1) + 1  # +1 because actions are 1-indexed

    elif dgp == 'tao':
        # Stage-1 outcome under the optimal stage-1 action: the penalty term in
        # exp(1.5 - |1.5*O1_0 + 2| * (d1 - g1_opt)^2) vanishes, leaving exp(1.5) + Z1
        Y1_pred = torch.exp(torch.tensor(1.5)) + Z1_tensor_test
        if tree_type:
            d1_star = ((O1_tensor_test[:, 0] > -1).float() * ((O1_tensor_test[:, 1] > -0.5).float() + (O1_tensor_test[:, 1] > 0.5).float())) + 1
            d2_star = ((O1_tensor_test[:, 2] > -1).float() * ((Y1_pred > 0.5).float() + (Y1_pred > 3).float())) + 1
        else:
            d1_star = ((O1_tensor_test[:, 0] > -0.5).float() * (1 + (O1_tensor_test[:, 0] - O1_tensor_test[:, 1] > 0).float())) + 1
            d2_star = ((O1_tensor_test[:, 2] > 0).float() + ((O1_tensor_test[:, 2] + Y1_pred > 2.5).float())) + 1

    else:
        raise ValueError(f"Unknown dgp {dgp!r}; expected 'linear', 'scheme_i' or 'tao'")

    return d1_star, d2_star



def calculate_optimal_policy_values(d1_star, d2_star, generated_test_data):
    """Recompute the stage-1 and stage-2 outcomes under the given (optimal) actions.

    Args:
        d1_star, d2_star: per-subject 1-indexed actions (normally from
            ``compute_optimal_policy``).
        generated_test_data: dict with 'dgp', 'O1', 'Z1', 'Z2' and, for
            'linear'/'scheme_i', 'O2'. 'scheme_i' also relies on the
            module-level ``g``, ``f_i``, ``C1``, ``C2`` and ``beta``; 'tao'
            relies on the module-level ``tree_type``.

    Returns:
        (Y1_test_opt, Y2_test_opt, values_opt) where values_opt is
        ``(d1_star, O1, Z1, d2_star, Z2)``.

    Raises:
        ValueError: if the 'dgp' value is not 'linear', 'scheme_i' or 'tao'.
    """
    O1_tensor_test = generated_test_data['O1']
    # O2 is only used by 'linear' and 'scheme_i'; .get keeps DGPs without O2 (e.g. 'tao') working
    O2_tensor_test = generated_test_data.get('O2')
    Z1_tensor_test = generated_test_data['Z1']
    Z2_tensor_test = generated_test_data['Z2']
    dgp = generated_test_data['dgp']

    if dgp == 'linear':
        Y1_test_opt = 15 + d1_star + O1_tensor_test.sum(axis=1) + torch.prod(O1_tensor_test, dim=1) + Z1_tensor_test
        Y2_test_opt = 15 + O2_tensor_test + d2_star * (1 - O2_tensor_test + d1_star + O1_tensor_test.sum(axis=1)) + Z2_tensor_test

    elif dgp == 'scheme_i':
        # Y1 = d1 * g(O1) + C1 + Z1
        Y1_test_opt = d1_star * g(O1_tensor_test) + C1 + Z1_tensor_test

        # f_i expects a scalar feature index i. Passing the d2_star tensor would
        # select those columns for every row and broadcast to an (N, N) matrix,
        # so evaluate all three candidates and take each row's own d2_star entry.
        f_candidates = torch.stack([f_i(O1_tensor_test, d1_star, i) for i in range(1, 4)], dim=1)
        row_idx = torch.arange(f_candidates.shape[0], device=f_candidates.device)
        f_at_d2 = f_candidates[row_idx, d2_star.long() - 1]

        # Y2 = f_{d2}(O1, d1) + beta * O2 + C2 + Z2
        Y2_test_opt = f_at_d2 + O2_tensor_test * beta + C2 + Z2_tensor_test

    elif dgp == 'tao':
        # Optimal stage-1 rule; Y1 is penalized by the squared distance of d1 from it
        if tree_type:
            g1_opt_conditions = ((O1_tensor_test[:, 0] > -1).float() * ((O1_tensor_test[:, 1] > -0.5).float() + (O1_tensor_test[:, 1] > 0.5).float())) + 1
        else:
            g1_opt_conditions = ((O1_tensor_test[:, 0] > -0.5).float() * (1 + (O1_tensor_test[:, 0] - O1_tensor_test[:, 1] > 0).float())) + 1

        Y1_test_opt = torch.exp(torch.tensor(1.5) - torch.abs(1.5 * O1_tensor_test[:, 0] + 2) * (d1_star - g1_opt_conditions)**2) + Z1_tensor_test

        # Optimal stage-2 rule depends on the realized stage-1 outcome
        if tree_type:
            g2_opt_conditions = ((O1_tensor_test[:, 2] > -1).float() * ((Y1_test_opt > 0.5).float() + (Y1_test_opt > 3).float())) + 1
        else:
            g2_opt_conditions = ((O1_tensor_test[:, 2] > 0).float() + ((O1_tensor_test[:, 2] + Y1_test_opt > 2.5).float())) + 1

        Y2_test_opt = torch.exp(torch.tensor(1.26) - torch.abs(1.5 * O1_tensor_test[:, 2] - 2) * (d2_star - g2_opt_conditions)**2) + Z2_tensor_test

    else:
        raise ValueError(f"Unknown dgp {dgp!r}; expected 'linear', 'scheme_i' or 'tao'")

    values_opt = (d1_star, O1_tensor_test, Z1_tensor_test, d2_star, Z2_tensor_test)

    return Y1_test_opt, Y2_test_opt, values_opt



def calculate_policy_values(d1_star, d2_star, generated_test_data, Y1_pred, Y2_pred, V_replications):
    """Record the optimal, behavioral and predicted policy values for one replication.

    Each value is the test-set mean of the summed stage-1 and stage-2 outcomes:

    * optimal: outcomes recomputed under (d1_star, d2_star);
    * behavioral: the observed ``Y1 + Y2`` stored in the test data;
    * predicted: the learned regime's ``Y1_pred + Y2_pred``.

    The three Python floats are appended, in that order, to the matching lists
    in ``V_replications``, which is mutated in place and also returned.

    Returns:
        (V_replications, values_opt); values_opt is passed through from
        ``calculate_optimal_policy_values``.
    """
    opt_stage1, opt_stage2, values_opt = calculate_optimal_policy_values(d1_star, d2_star, generated_test_data)
    V_replications["V_replications_M1_optimal"].append(torch.mean(opt_stage1 + opt_stage2).cpu().item())

    observed_total = generated_test_data['Y1'] + generated_test_data['Y2']
    V_replications["V_replications_M1_behavioral"].append(torch.mean(observed_total).cpu().item())

    predicted_total = Y1_pred + Y2_pred
    V_replications["V_replications_M1_pred"].append(torch.mean(predicted_total).item())

    return V_replications, values_opt


def compute_test_outputs(nn, test_input, A_tensor, device, params, is_stage1=True):
    """Score the three candidate actions with a trained network and pick the best one.

    Args:
        nn: trained network (callable). Inside this function the name shadows
            the ``torch.nn`` module.
        test_input: stage-specific feature tensor; rows are subjects.
        A_tensor: observed-action tensor. Only its shape/dtype are used, to
            build a constant candidate-action column in the non-surrogate branch.
        device: device the Q-learning inputs and the returned actions are moved to.
        params: configuration dict. ``params['f_model'] == 'surr_opt'`` selects
            surrogate scoring; anything else uses Q-learning-style scoring.
        is_stage1: unused; kept for call-site compatibility.

    Returns:
        (optimal_actions, test_outputs): a 1-D tensor of 1-indexed actions in
        {1, 2, 3} (one per row, also for a single-row batch) and the stacked
        per-action scores.
    """
    with torch.no_grad():
        if params['f_model'] == "surr_opt":
            test_outputs_i = nn(test_input)

            # Only drop the trailing singleton dim, so a batch of size 1 keeps its batch axis
            # (a bare .squeeze() would collapse (1, 2, 1) to (2,) and break the [:, k] indexing).
            surrogate_scores = torch.stack(test_outputs_i[:2], dim=1).squeeze(-1)

            # Scores relative to action 1: action 1 -> 0, actions 2 and 3 -> negated network outputs
            test_outputs = torch.stack([
                torch.zeros_like(surrogate_scores[:, 0]),
                -surrogate_scores[:, 0],
                -surrogate_scores[:, 1]
            ], dim=1)
        else:
            # Append a constant candidate-action column (1, 2, 3) and score each version
            input_tests = [
                torch.cat((test_input, torch.full_like(A_tensor, i).unsqueeze(-1)), dim=1).to(device)
                for i in range(1, 4)  # Assuming there are 3 actions
            ]

            test_outputs = torch.stack([
                nn(input_stage)[0] for input_stage in input_tests
            ], dim=1)

    # argmax over the action axis; +1 because actions are 1-indexed.
    # reshape(-1) (instead of squeeze) always yields one action per row, even when N == 1.
    optimal_actions = torch.argmax(test_outputs, dim=1) + 1
    return optimal_actions.reshape(-1).to(device), test_outputs






def eval_DTR(sample_size, V_replications, num_replications, nn_stage1, nn_stage2, df, params):
    """Evaluate a fitted two-stage treatment regime on freshly generated test data.

    Generates test data, computes the true optimal actions, obtains the
    method's recommended actions A1/A2 and outcomes Y1_pred/Y2_pred, appends a
    row of action lists to ``df`` and the value estimates to ``V_replications``.

    Args:
        sample_size: number of test subjects to generate.
        V_replications: dict of result lists, updated via ``calculate_policy_values``.
        num_replications: the replication index (as passed by ``simulations``),
            used as the data-generation seed.
        nn_stage1, nn_stage2: not used for prediction. For 'surr_opt' and
            'DQlearning' the best checkpoints ``best_model_stage_*_{sample_size}.pt``
            are loaded into freshly initialized networks instead.
        df: DataFrame of per-replication action lists; a new row is appended.
        params: run configuration; reads 'setting' and 'f_model'.

    Returns:
        (V_replications, df, values_opt)

    Raises:
        ValueError: on the R/ACWL path (``params['f_model'] == 'tao'``) when
            ``params['setting']`` is neither 'tao' nor 'linear'.

    Note:
        Also reads the module-level ``device``, ``noiseless`` and ``f_model``,
        and sources ``ACWL_tao.R`` / ``ACWL_linear.R`` from the working directory.
    """
    # Test data is seeded by the replication index
    generated_test_data = generate_data(sample_size, params['setting'], replication_seed=num_replications)
    test_input_stage1, test_input_stage2, Ci_tensor = preprocess_data(generated_test_data, sample_size, params['setting'], replication_seed=num_replications, run='test')

    A1_tensor_test, A2_tensor_test = [generated_test_data[key] for key in ['A1', 'A2']]

    # True optimal actions under the known DGP
    d1_star, d2_star = compute_optimal_policy(generated_test_data)

    if params['f_model'] != "tao":
        # The passed-in networks are replaced by the best checkpoints saved during validation
        nn_stage1 = initialize_nn(params, 1)
        nn_stage2 = initialize_nn(params, 2)

        if params['f_model'] == "surr_opt":
            nn_stage1.load_state_dict(torch.load(f'best_model_stage_surr_1_{sample_size}.pt'))
            nn_stage2.load_state_dict(torch.load(f'best_model_stage_surr_2_{sample_size}.pt'))
        elif params['f_model'] == "DQlearning":
            nn_stage1.load_state_dict(torch.load(f'best_model_stage_Q_1_{sample_size}.pt'))
            nn_stage2.load_state_dict(torch.load(f'best_model_stage_Q_2_{sample_size}.pt'))

        nn_stage1.eval()
        nn_stage2.eval()

        # Stage 1 actions, then stage-2 inputs built from them
        A1, test_outputs_stage1 = compute_test_outputs(nn_stage1, test_input_stage1, A1_tensor_test, device, params, is_stage1=True)
        test_input_stage2, Y1_pred = prepare_stage2_test_input(generated_test_data, A1)

        A2, test_outputs_stage2 = compute_test_outputs(nn_stage2, test_input_stage2, A2_tensor_test, device, params, is_stage1=False)
        Y2_pred = prepare_Y2_pred(generated_test_data, A1, A2)
    else:
        # ACWL is evaluated in R. Tensors may live on the GPU, so move them to CPU
        # before converting to numpy for rpy2.
        eval_setting = params['setting']

        if eval_setting == 'tao':
            test_input_np = test_input_stage1.cpu().numpy()
            x1 = test_input_np[:, 0]
            x2 = test_input_np[:, 1]
            x3 = test_input_np[:, 2]
            x4 = test_input_np[:, 3]
            x5 = test_input_np[:, 4]

            ro.r('source("ACWL_tao.R")')

            # method uses the module-level f_model (params['f_model'] is forced to 'tao' on this path)
            results = ro.globalenv['test_ACWL'](x1, x2, x3, x4, x5, d1_star.cpu().numpy(), d2_star.cpu().numpy(), noiseless, method= f_model)

        elif eval_setting == 'linear':
            test_input_np = test_input_stage2.cpu().numpy()

            x1 = test_input_np[:, 0]
            x2 = test_input_np[:, 1]
            O2 = test_input_np[:, 4]

            ro.r('source("ACWL_linear.R")')

            results = ro.globalenv['test_ACWL'](x1, x2, O2, d1_star.cpu().numpy(), d2_star.cpu().numpy(), noiseless, method= f_model)

        else:
            raise ValueError(f"ACWL evaluation is only implemented for settings 'tao' and 'linear', got {eval_setting!r}")

        select2_test = results.rx2('select2')[0]
        select1_test = results.rx2('select1')[0]
        selects_test = results.rx2('selects')[0]

        # TODO: FIX THESE TO GET EXACTLY SAME ACCURACY AS WE GET IN PYTHON
        print(f"TEST: Select1: {select1_test}, Select2: {select2_test}, Selects: {selects_test}")

        # Outcomes reported by R for the recommended actions
        Y1_pred_R = torch.tensor(np.array(results.rx2('R1.a1')), dtype=torch.float32)
        Y2_pred_R = torch.tensor(np.array(results.rx2('R2.a1')), dtype=torch.float32)

        # TODO: FIX THESE TO GET EXACTLY SAME ACCURACY AS WE GET IN PYTHON
        Y1_stats_R = [torch.min(Y1_pred_R), torch.max(Y1_pred_R), torch.mean(Y1_pred_R)]
        message = f"Y1_pred_R [min, max, mean]: {Y1_stats_R}"
        tqdm.write(message)
        message = f"Y2_pred_R [min, max, mean]: [{torch.min(Y2_pred_R)}, {torch.max(Y2_pred_R)}, {torch.mean(Y2_pred_R)}]"
        tqdm.write(message)

        message = f'torch.mean(Y1_pred_R + Y2_pred_R): {torch.mean(Y1_pred_R + Y2_pred_R)} \n'
        tqdm.write(message)

        # Recommended actions from R; outcomes are recomputed in Python for comparability
        A1 = torch.tensor(np.array(results.rx2('g1.a1')), dtype=torch.float32)
        A2 = torch.tensor(np.array(results.rx2('g2.a1')), dtype=torch.float32)

        test_input_stage2, Y1_pred = prepare_stage2_test_input(generated_test_data, A1)
        Y2_pred =  prepare_Y2_pred(generated_test_data, A1, A2)

    new_row = {
        'Behavioral_A1': A1_tensor_test.cpu().numpy().tolist(),
        'Behavioral_A2': A2_tensor_test.cpu().numpy().tolist(),
        'Predicted_A1': A1.cpu().numpy().tolist(),
        'Predicted_A2':  A2.cpu().numpy().tolist(),
        'Optimal_A1': d1_star.cpu().numpy().tolist(),
        'Optimal_A2': d2_star.cpu().numpy().tolist()
        }

    df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)

    message = f'torch.mean(Y1_pred + Y2_pred): {torch.mean(Y1_pred + Y2_pred)} \n\n'
    tqdm.write(message)

    V_replications, values_opt = calculate_policy_values(d1_star, d2_star, generated_test_data, Y1_pred, Y2_pred, V_replications)

    return V_replications, df, values_opt


def simulations(surrogate_num, sample_size, num_replications, V_replications, params):
    """Run ``num_replications`` train/evaluate cycles for the model in ``params['f_model']``.

    Each replication generates training data seeded by the replication index,
    fits the regime ('tao' -> ``adaptive_contrast_tao``, 'DQlearning' ->
    ``DQlearning``, otherwise ``surr_opt``) and evaluates it with ``eval_DTR``.

    Args:
        surrogate_num: surrogate variant forwarded to ``surr_opt``.
        sample_size: training (and test) sample size.
        num_replications: number of replications to run.
        V_replications: dict of value lists, updated by ``eval_DTR``.
        params: run configuration (reads 'setting', 'f_model', 'contrast').

    Returns:
        (V_replications, df, values_opt, losses_dict, epoch_num_model_lst).
        ``values_opt`` is from the last replication, or ``None`` when
        ``num_replications`` is 0.
    """
    columns = ['Behavioral_A1', 'Behavioral_A2', 'Predicted_A1', 'Predicted_A2', 'Optimal_A1', 'Optimal_A2']
    df = pd.DataFrame(columns=columns)
    losses_dict = {}
    epoch_num_model_lst = []
    values_opt = None  # stays None if the loop body never runs

    for replication in tqdm(range(num_replications), desc="Replications_M1"):
        generated_data = generate_data(sample_size, params['setting'], replication_seed=replication)
        tuple_train, tuple_Val, all_data = preprocess_data(generated_data, sample_size, params['setting'], replication_seed=replication)

        if params['f_model'] == 'tao':
            # ACWL has no PyTorch networks; the selection results are not used further here
            (_select2, _select1, _selects) = adaptive_contrast_tao(all_data, params["contrast"])
            nn_stage1 = nn_stage2 = None
        elif params['f_model'] == 'DQlearning':
            (nn_stage1, nn_stage2, trn_val_loss_tpl, epoch_num_model_1, epoch_num_model_2) = DQlearning(sample_size, tuple_train, tuple_Val, params)
            epoch_num_model_lst.append([epoch_num_model_1, epoch_num_model_2])
            losses_dict[replication] = trn_val_loss_tpl
        else:
            # surr_opt
            nn_stage1, nn_stage2, trn_val_loss_tpl, epoch_num_model = surr_opt(sample_size, tuple_train, tuple_Val, surrogate_num, params)
            epoch_num_model_lst.append(epoch_num_model)
            losses_dict[replication] = trn_val_loss_tpl

        V_replications, df, values_opt = eval_DTR(sample_size, V_replications, replication, nn_stage1, nn_stage2, df, params)

    return V_replications, df, values_opt, losses_dict, epoch_num_model_lst




# from tqdm.notebook import tqdm


# sample_size = 1000  # 500, 1000 are the cases to check
# batch_prop = 0.2 #0.07, 0.2
# if sample_size < 500:
#     batch_prop = 0.5

# training_validation_prop = 0.5 #0.95 #0.01


# # Prompt user for the number of replications
# num_replications = 4

# # Prompt user for the setting
# setting = 'tao' # 'linear', 'tao', 'scheme_i'

# noiseless = True # True False



# if setting == 'tao':
#     tree_type =  True # True False

# # Prompt user for the model type
# f_model = 'tao' # (linear, 'tao', 'DQlearning', 'surr_opt'): " tao => adaptive_contrast_tao) # Note for linear linear run separate R code

# contrast = 1

# surrogate_num = 1 #1- old multiplicative one  2- new one

# option_sur = 1 # if surrogate_num = 1 then from 1-5 options, if surrogate_num = 2 then 1-> assymetric, 2 -> symmetric


# # Constants from scheme_i
# C1 = C2 = 3
# beta = 1





# # #BEST so far for tao's 15000

# network_parameters_surogate = {
#   'setting': setting,
#   'n_epoch': 60, #250
#   'num_networks': 2,
#   'input_dim_stage1': 2,
#   'output_dim_stage1': 1,
#   'input_dim_stage2': 5,
#   'output_dim_stage2': 1,
#   'optimizer_betas': (0.9, 0.999),
#   'optimizer_eps': 1e-08,
#   'scheduler_step_size': 30,
#   'scheduler_gamma': 0.8,
#   'hidden_dim_stage1': 10, #20
#   'hidden_dim_stage2': 10, #20
#   'dropout_rate': 0.0, #0.3, 0.43
#   'optimizer_lr': 0.07, # 0.07, 0.007
#   'optimizer_weight_decay': 0.001,
#   'batch_size': math.ceil(batch_prop*sample_size), #int(0.038*sample_size),
#   'f_model': 'surr_opt',
#   'option_sur': option_sur, # if surrogate_num = 1 then 5 options, if surrogate_num = 2 then 1-> assymetric, 2 -> symmetric
#   'contrast': contrast

# }







# # input_stage2 = [O1, A1, Y1, O2]

# if setting =='tao':
#     network_parameters_surogate['input_dim_stage1'] = 5
#     network_parameters_surogate['input_dim_stage2'] = 7


# network_parameters_surogate['option_sur'] = 2 # symmetric
# n_epoch = network_parameters_surogate['n_epoch']

# # Set the device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Lists to store replication results
# V_replications_M1_behavioral = []
# V_replications_M1_pred = []
# V_replications_M1_optimal = []

# V_replications = {"V_replications_M1_behavioral": V_replications_M1_behavioral, "V_replications_M1_pred": V_replications_M1_pred, "V_replications_M1_optimal": V_replications_M1_optimal}


# print('Setting: ' , setting)
# print("f_model: ", f_model)


# network_parameters_surogate['f_model'] = 'tao'
# V_replications, df, values_opt, losses_dict, epoch_num_model_lst = simulations(surrogate_num,
#                                                                            sample_size,
#                                                                            num_replications,
#                                                                            V_replications,
#                                                                            network_parameters_surogate)







from tqdm.notebook import tqdm


# ---- Experiment configuration ----------------------------------------------
sample_size = 1000  # 500, 1000 are the cases to check
batch_prop = 0.5 if sample_size < 500 else 0.2  # 0.07, 0.2 also tried

training_validation_prop = 0.5  # 0.95, 0.01 also tried

# Number of simulation replications
num_replications = 9

# Data-generating setting: 'linear', 'tao', 'scheme_i'
setting = 'tao'

noiseless = True  # True False

if setting == 'tao':
    # Selects which of the two 'tao' optimal-rule variants is used downstream
    tree_type = True  # True False
else:
    # Scheme number for 'scheme_i' - (1, 2, or 3)
    scheme = 1

# Model type: 'tao' (adaptive_contrast_tao, evaluated through R), 'DQlearning', 'surr_opt'.
# Note: for the linear ACWL variant run the separate R code.
f_model = 'tao'

contrast = 1

surrogate_num = 1  # 1 - old multiplicative one, 2 - new one

# if surrogate_num = 1 then options 1-5; if surrogate_num = 2 then 1 -> asymmetric, 2 -> symmetric
option_sur = 1

# Constants from scheme_i
C1, C2 = 3, 3
beta = 1






# network_parameters_surogate = {
#   'setting': setting,
#   'n_epoch': 220,
#   'num_networks': 2,
#   'input_dim_stage1': 2,
#   'output_dim_stage1': 1,
#   'input_dim_stage2': 5,
#   'output_dim_stage2': 1,
#   'optimizer_betas': (0.9, 0.999),
#   'optimizer_eps': 1e-08,
#   'scheduler_step_size': 30,
#   'scheduler_gamma': 0.8,
#   'hidden_dim_stage1': 10, #20
#   'hidden_dim_stage2': 10, #20
#   'dropout_rate': 0.0, #0.3, 0.43
#   'optimizer_lr': 0.1, # 0.07
#   'optimizer_weight_decay': 0.001,
#   'batch_size': math.ceil(batch_prop*sample_size), #int(0.038*sample_size),
#   'f_model': 'surr_opt',
#   'option_sur': option_sur, # if surrogate_num = 1 then 5 options, if surrogate_num = 2 then 1-> assymetric, 2 -> symmetric
#   'contrast': contrast

# }


# #BEST so far for tao's 15000

# Surrogate-optimization hyperparameters (input dims are adjusted per setting later)
network_parameters_surogate = dict(
    setting=setting,
    n_epoch=60,  # 250 also tried
    num_networks=2,
    input_dim_stage1=2,
    output_dim_stage1=1,
    input_dim_stage2=5,
    output_dim_stage2=1,
    optimizer_betas=(0.9, 0.999),
    optimizer_eps=1e-08,
    scheduler_step_size=30,
    scheduler_gamma=0.8,
    hidden_dim_stage1=10,  # 20 also tried
    hidden_dim_stage2=10,  # 20 also tried
    dropout_rate=0.0,  # 0.3, 0.43 also tried
    optimizer_lr=0.07,  # 0.07, 0.007
    optimizer_weight_decay=0.001,
    batch_size=math.ceil(batch_prop*sample_size),
    f_model='surr_opt',
    # if surrogate_num = 1 then options 1-5; if surrogate_num = 2 then 1 -> asymmetric, 2 -> symmetric
    option_sur=option_sur,
    contrast=contrast,
)


# # overfitting-prone configuration (kept for reference)
# network_parameters_surogate = {
#   'setting': setting,
#   'n_epoch': 250,
#   'num_networks': 2,
#   'input_dim_stage1': 2,
#   'output_dim_stage1': 1,
#   'input_dim_stage2': 5,
#   'output_dim_stage2': 1,
#   'optimizer_betas': (0.9, 0.999),
#   'optimizer_eps': 1e-08,
#   'scheduler_step_size': 30,
#   'scheduler_gamma': 0.8,
#   'hidden_dim_stage1': 5,
#   'hidden_dim_stage2': 5,
#   'dropout_rate': 0.00, #0.3
#   'optimizer_lr': 0.01,
#   'optimizer_weight_decay': 0.001,
#   'batch_size': math.ceil(batch_prop*sample_size), #int(0.038*sample_size),
#   'f_model': 'surr_opt',
#   'option_sur': 1 # if surrogate_num = 1 then 5 options, if surrogate_num = 2 then 1-> assymetric, 2 -> symmetric

# }




# network_parameters_surogate = {
#   'setting': setting,
#   'n_epoch': 250,
#   'num_networks': 2,
#   'input_dim_stage1': 2,
#   'output_dim_stage1': 1,
#   'input_dim_stage2': 5,
#   'output_dim_stage2': 1,
#   'optimizer_betas': (0.9, 0.999),
#   'optimizer_eps': 1e-08,
#   'scheduler_step_size': 30,
#   'scheduler_gamma': 0.8,
#   'hidden_dim_stage1': 20,
#   'hidden_dim_stage2': 20,
#   'dropout_rate': 0.43, #0.3, 0.43
#   'optimizer_lr': 0.01,
#   'optimizer_weight_decay': 0.001,
#   'batch_size': math.ceil(batch_prop*sample_size), #int(0.038*sample_size),
#   'f_model': 'surr_opt',
#   'option_sur': option_sur, # if surrogate_num = 1 then 5 options, if surrogate_num = 2 then 1-> assymetric, 2 -> symmetric
#   'contrast': contrast

# }



# # #BEST so far for linear's 15000

# network_parameters_surogate = {
#   'setting': setting,
#   'n_epoch': 120, #60, 120
#   'num_networks': 2,
#   'input_dim_stage1': 2,
#   'output_dim_stage1': 1,
#   'input_dim_stage2': 5,
#   'output_dim_stage2': 1,
#   'optimizer_betas': (0.9, 0.999),
#   'optimizer_eps': 1e-08,
#   'scheduler_step_size': 30,
#   'scheduler_gamma': 0.8,
#   'hidden_dim_stage1': 40, #3,
#   'hidden_dim_stage2': 40, #5,80 increasing to 80 made overall good value func. stablity
#   'dropout_rate': 0.5, # 0.43, 0.5
#   'optimizer_lr': 0.1, # 0.006
#   'optimizer_weight_decay': 0.03, #0.003
#   'batch_size': math.ceil(batch_prop*sample_size), #int(0.038*sample_size),
#   'f_model': 'surr_opt',
#   'option_sur': 1, # if surrogate_num = 1 then from 1-5 options, if surrogate_num = 2 then 1-> assymetric, 2 -> symmetric
#   'contrast': contrast
# }

# Alternative (unused) Q-learning preset, kept for reference:
# network_parameters_qlearning = {
#   'setting': setting,
#   'n_epoch': 120,
#   'num_networks': 1,
#   'input_dim_stage1': 3,
#   'output_dim_stage1': 1,
#   'input_dim_stage2': 6,
#   'output_dim_stage2': 1,
#   'optimizer_betas': (0.9, 0.999),
#   'optimizer_eps': 1e-08,
#   'scheduler_step_size': 30,
#   'scheduler_gamma': 0.8,
#   'hidden_dim_stage1': 40,
#   'hidden_dim_stage2': 40,
#   'dropout_rate': 0.43,
#   'optimizer_lr': 0.01,
#   'optimizer_weight_decay': 0.03,
#   'batch_size': math.ceil(batch_prop*sample_size),
#   'f_model': "DQlearning"
# }


# Standard Q-learning params (defined once; the 'tao' and 'scheme_i'
# settings overwrite the input dims further below).
network_parameters_qlearning = {
  'setting': setting,
  'n_epoch': 240,
  'num_networks': 1,
  'input_dim_stage1': 3,
  'output_dim_stage1': 1,
  'input_dim_stage2': 6,
  'output_dim_stage2': 1,
  'optimizer_betas': (0.9, 0.999),
  'optimizer_eps': 1e-08,
  'scheduler_step_size': 30,
  'scheduler_gamma': 0.8,
  'hidden_dim_stage1': 5,
  'hidden_dim_stage2': 5,
  'dropout_rate': 0,
  'optimizer_lr': 0.07,
  'optimizer_weight_decay': 0.03,
  'batch_size': math.ceil(batch_prop*sample_size),
  'f_model': "DQlearning"
}



# input_stage2 = [O1, A1, Y1, O2]

# Per-setting input dimensions. The Q-learning networks take one extra input
# column: the candidate action appended in compute_test_outputs.
if setting == 'tao':
    network_parameters_surogate['input_dim_stage1'] = 5
    network_parameters_surogate['input_dim_stage2'] = 7
    network_parameters_qlearning['input_dim_stage1'] = 6
    network_parameters_qlearning['input_dim_stage2'] = 8


if setting == 'scheme_i':
    network_parameters_surogate['input_dim_stage1'] = 3
    network_parameters_surogate['input_dim_stage2'] = 6
    network_parameters_qlearning['input_dim_stage1'] = 4
    network_parameters_qlearning['input_dim_stage2'] = 7

    # Stage-1 term: g(O1) = 1.5 * 1{O1[:, 2] > 0}
    g = lambda O1: 1.5 * (O1[:, 2] > 0)

    # Stage-2 terms f_i(O1, A1, i), with i a 1-based feature index into O1
    if scheme == 1:
        f_i = lambda O1, A1, i: 1.5 * A1 * O1[:, i-1] + A1 * (O1[:, i-1] ** 2) / 2
    elif scheme == 2:
        f_i = lambda O1, A1, i: 1.5 * (O1[:, i-1] > 0).float() * A1 + A1 * (O1[:, i-1] ** 2) / 8
    elif scheme == 3:
        # torch.sin (not np.sin) works for CUDA tensors and mirrors scheme 2's .float() mask
        f_i = lambda O1, A1, i: 1.5 * (torch.sin(O1[:, i-1]) > 0).float() * A1 + A1 * (O1[:, i-1] ** 2) / 8
    else:
        raise ValueError("Invalid scheme number. Please choose 1, 2, or 3.")


# Symmetric surrogate variant
network_parameters_surogate['option_sur'] = 2
n_epoch = network_parameters_qlearning['n_epoch']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Per-replication value estimates, filled in by simulations()
V_replications_M1_behavioral, V_replications_M1_pred, V_replications_M1_optimal = [], [], []
V_replications = {
    "V_replications_M1_behavioral": V_replications_M1_behavioral,
    "V_replications_M1_pred": V_replications_M1_pred,
    "V_replications_M1_optimal": V_replications_M1_optimal,
}


print('Setting: ', setting)
print("f_model: ", f_model)

# Pick the hyperparameter set for the chosen model; any other model name runs
# the R-based 'tao' (ACWL) path with the surrogate dict.
if f_model == "DQlearning":
    run_params = network_parameters_qlearning
elif f_model == "surr_opt":
    run_params = network_parameters_surogate
else:
    network_parameters_surogate['f_model'] = 'tao'
    run_params = network_parameters_surogate

V_replications, df, values_opt, losses_dict, epoch_num_model_lst = simulations(
    surrogate_num,
    sample_size,
    num_replications,
    V_replications,
    run_params,
)





Setting:  tao
f_model:  tao


Replications_M1:   0%|          | 0/9 [00:00<?, ?it/s]

R[write to console]: Loading required package: nnet



DGP:  tao
Train model:  tao 


R[write to console]: 



RRuntimeError: 

In [None]:
# Per-replication value estimates of the learned regime (test-set mean of Y1_pred + Y2_pred)
V_replications["V_replications_M1_pred"]

In [None]:
# Calculate the average V and standard deviation for test data
V_replications_M1_pred = V_replications["V_replications_M1_pred"]

avg_V_test_M1 = np.mean(V_replications_M1_pred)
std_dev_test_M1 = np.std(V_replications_M1_pred)
print(f"Average V for Test Data, M1: {avg_V_test_M1}")
print(f"Standard Deviation for Test Data, M1: {std_dev_test_M1}\n\n")

# Plot only the replications that actually completed: a failed run leaves
# fewer entries than num_replications, which would make the x/y lengths differ.
n_completed = len(V_replications_M1_pred)
plt.plot(range(1, n_completed + 1), V_replications_M1_pred, 'o-')
plt.xlabel('Replication')
plt.ylabel('V Value')
plt.title(f'V values for {n_completed} Test Replications (Test Sample Size: {sample_size})')
plt.show()


In [None]:
# Agreement of behavioral and predicted actions with the true optimal actions,
# per replication (one df row each) and pooled over all replications.

def _count_matches(actions, reference):
    """Number of positions at which two action sequences agree."""
    return sum(a == r for a, r in zip(actions, reference))


def _safe_ratio(numerator, denominator):
    """numerator / denominator, or NaN when the denominator is 0 (e.g. df has no rows)."""
    return numerator / denominator if denominator else float('nan')


# Running totals for overall accuracy
correct_behavioral_A1 = 0
correct_behavioral_A2 = 0
correct_predicted_A1 = 0
correct_predicted_A2 = 0
total_A1 = 0
total_A2 = 0

# Per-simulation accuracies of the predicted actions
accuracies = {
    'Accuracy_A1': [],
    'Accuracy_A2': []
}

for index, row in df.iterrows():
    behavioral_A1 = row['Behavioral_A1']
    behavioral_A2 = row['Behavioral_A2']
    predicted_A1 = row['Predicted_A1']
    predicted_A2 = row['Predicted_A2']
    optimal_A1 = row['Optimal_A1']
    optimal_A2 = row['Optimal_A2']

    row_correct_behavioral_A1 = _count_matches(behavioral_A1, optimal_A1)
    row_correct_behavioral_A2 = _count_matches(behavioral_A2, optimal_A2)
    row_correct_predicted_A1 = _count_matches(predicted_A1, optimal_A1)
    row_correct_predicted_A2 = _count_matches(predicted_A2, optimal_A2)

    accuracies['Accuracy_A1'].append(_safe_ratio(row_correct_predicted_A1, len(predicted_A1)))
    accuracies['Accuracy_A2'].append(_safe_ratio(row_correct_predicted_A2, len(predicted_A2)))

    correct_behavioral_A1 += row_correct_behavioral_A1
    correct_behavioral_A2 += row_correct_behavioral_A2
    correct_predicted_A1 += row_correct_predicted_A1
    correct_predicted_A2 += row_correct_predicted_A2
    total_A1 += len(predicted_A1)
    total_A2 += len(predicted_A2)

accuracy_df = pd.DataFrame(accuracies)

# Overall accuracies (NaN instead of ZeroDivisionError when nothing was evaluated)
overall_accuracy_behavioral_A1 = _safe_ratio(correct_behavioral_A1, total_A1)
overall_accuracy_behavioral_A2 = _safe_ratio(correct_behavioral_A2, total_A2)
overall_accuracy_predicted_A1 = _safe_ratio(correct_predicted_A1, total_A1)
overall_accuracy_predicted_A2 = _safe_ratio(correct_predicted_A2, total_A2)

print("Overall Accuracy for Behavioral A1:", overall_accuracy_behavioral_A1)
print("Overall Accuracy for Behavioral A2:", overall_accuracy_behavioral_A2)
print("\n")
print("Overall Accuracy for predicted A1:", overall_accuracy_predicted_A1)
print("Overall Accuracy for predicted A2:", overall_accuracy_predicted_A2)

accuracy_df["Value function"] = V_replications["V_replications_M1_pred"]
accuracy_df["Optimal Value function"] = V_replications["V_replications_M1_optimal"]

# Print the DataFrame with per-simulation accuracies
print("\n ", accuracy_df)

