In [24]:
import importlib
from functools import partial

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from decision_learning.data.shortest_path_grid import genData
from decision_learning.modeling.loss import SPOPlus
from decision_learning.modeling.models import LinearRegression
from decision_learning.modeling.train import train, calc_test_regret

In [23]:
# Only for on the fly reloading of changing modules
# import decision_learning.data.shortest_path_grid
# importlib.reload(decision_learning.data.shortest_path_grid)
# import decision_learning.modeling.val_metrics
# importlib.reload(decision_learning.modeling.val_metrics)

# import decision_learning.modeling.models
# importlib.reload(decision_learning.modeling.models)
# from decision_learning.modeling.models import LinearRegression
# import decision_learning.modeling.loss
# importlib.reload(decision_learning.modeling.loss)
# import decision_learning.modeling.train
# importlib.reload(decision_learning.modeling.train)
# from decision_learning.modeling.train import train, calc_test_regret

<module 'decision_learning.modeling.train' from '/home1/yongpeng/decision-focused-learning/src/decision_learning/modeling/train.py'>

# Optimization Solver

In [3]:
def shortest_path_solver(costs, size, sens = 1e-4):
    # Forward Pass
    starting_ind = 0
    starting_ind_c = 0
    samples = costs.shape[0]
    V_arr = torch.zeros(samples, size ** 2)
    for i in range(0, 2 * (size - 1)):
        num_nodes = min(i + 1, 9 - i)
        num_nodes_next = min(i + 2, 9 - i - 1)
        num_arcs = 2 * (max(num_nodes, num_nodes_next) - 1)
        V_1 = V_arr[:, starting_ind:starting_ind + num_nodes]
        layer_costs = costs[:, starting_ind_c:starting_ind_c + num_arcs]
        l_costs = layer_costs[:, 0::2]
        r_costs = layer_costs[:, 1::2]
        next_V_val_l = torch.ones(samples, num_nodes_next) * float('inf')
        next_V_val_r = torch.ones(samples, num_nodes_next) * float('inf')
        if num_nodes_next > num_nodes:
            next_V_val_l[:, :num_nodes_next - 1] = V_1 + l_costs
            next_V_val_r[:, 1:num_nodes_next] = V_1 + r_costs
        else:
            next_V_val_l = V_1[:, :num_nodes_next] + l_costs
            next_V_val_r = V_1[:, 1:num_nodes_next + 1] + r_costs
        next_V_val = torch.minimum(next_V_val_l, next_V_val_r)
        V_arr[:, starting_ind + num_nodes:starting_ind + num_nodes + num_nodes_next] = next_V_val

        starting_ind += num_nodes
        starting_ind_c += num_arcs

    # Backward Pass
    starting_ind = size ** 2
    starting_ind_c = costs.shape[1]
    prev_act = torch.ones(samples, 1)
    sol = torch.zeros(costs.shape)
    for i in range(2 * (size - 1), 0, -1):
        num_nodes = min(i + 1, 9 - i)
        num_nodes_next = min(i, 9 - i + 1)
        V_1 = V_arr[:, starting_ind - num_nodes:starting_ind]
        V_2 = V_arr[:, starting_ind - num_nodes - num_nodes_next:starting_ind - num_nodes]

        num_arcs = 2 * (max(num_nodes, num_nodes_next) - 1)
        layer_costs = costs[:, starting_ind_c - num_arcs: starting_ind_c]

        if num_nodes < num_nodes_next:
            l_cs_res = ((V_2[:, :num_nodes_next - 1] - V_1 + layer_costs[:, ::2]) < sens) * prev_act
            r_cs_res = ((V_2[:, 1:num_nodes_next] - V_1 + layer_costs[:, 1::2]) < sens) * prev_act
            prev_act = torch.zeros(V_2.shape)
            prev_act[:, :num_nodes_next - 1] += l_cs_res
            prev_act[:, 1:num_nodes_next] += r_cs_res
        else:
            l_cs_res = ((V_2 - V_1[:, :num_nodes - 1] + layer_costs[:, ::2]) < sens) * prev_act[:, :num_nodes - 1]
            r_cs_res = ((V_2 - V_1[:, 1:num_nodes] + layer_costs[:, 1::2]) < sens) * prev_act[:, 1:num_nodes]
            prev_act = torch.zeros(V_2.shape)
            prev_act += l_cs_res
            prev_act += r_cs_res
        cs = torch.zeros(layer_costs.shape)
        cs[:, ::2] = l_cs_res
        cs[:, 1::2] = r_cs_res
        sol[:, starting_ind_c - num_arcs: starting_ind_c] = cs

        starting_ind = starting_ind - num_nodes
        starting_ind_c = starting_ind_c - num_arcs
    # Dimension (samples, num edges)
    obj = torch.sum(sol * costs, axis=1)
    # Dimension (samples, 1)
    return sol, obj.reshape(-1,1)

# Initialize Data

data parameters

In [46]:
indices_arr = torch.randperm(100000)
indices_arr_test = torch.randperm(100000)

sim = 0
n_arr = [200, 400, 800, 1600]
ep_arr = ['unif', 'normal']
trials = 100

exp_arr = []
for n in n_arr:
    for ep in ep_arr:
        for t in range(trials):
            exp_arr.append([n, ep, t])


In [80]:
exp = [1600, 'unif', 10] #exp_arr[sim]
print(exp)

[1600, 'unif', 10]


In [81]:
ep_type = exp[1]
trial = exp[2]

# generate data
grid = (5, 5)  # grid size
num_data = exp[0]  # number of training data
num_feat = 5  # size of feature
deg = 6  # polynomial degree
e = .3  # noise width

# path planting for shortest path example
planted_good_pwl_params = {'slope0':0, 
                    'int0':2,
                    'slope1':0, 
                    'int1':2}
planted_bad_pwl_params = {'slope0':4, 
                    'int0':0,
                    'slope1':0, 
                    'int1':2.2}

plant_edge = True

In [82]:
generated_data = genData(num_data=num_data+200,
        num_features=num_feat, 
        grid=grid, 
        deg=deg, 
        noise_type=ep_type,
        noise_width=e,
        seed=indices_arr[trial],     
        plant_edges=plant_edge,
        planted_good_pwl_params=planted_good_pwl_params,
        planted_bad_pwl_params=planted_bad_pwl_params)
sol, obj = shortest_path_solver(costs=generated_data['cost'], size=5)
final_data = {'X':generated_data['feat'],
              'true_cost':generated_data['cost'],
              'true_sol':sol,
              'true_obj':obj}

# test data
generated_data_test = genData(num_data=10000,
        num_features=num_feat, 
        grid=grid, 
        deg=deg, 
        noise_type=ep_type,
        noise_width=e,
        seed=indices_arr[trial],     
        plant_edges=plant_edge,
        planted_good_pwl_params=planted_good_pwl_params,
        planted_bad_pwl_params=planted_bad_pwl_params)
sol_test, obj_test = shortest_path_solver(costs=generated_data_test['cost'], size=5)
final_data_test = {'X':generated_data_test['feat'],
              'true_cost':generated_data_test['cost'],
              'true_sol':sol_test,
              'true_obj':obj_test}

2024-10-30 16:10:50,544 - decision_learning.data.shortest_path_grid - DEBUG - good_bad_edges: [ 1  4  9 16 24 31 36 39  0  3  8 15 23 30 35 38], remain_edges: [ 2  5  6  7 10 11 12 13 14 17 18 19 20 21 22 25 26 27 28 29 32 33 34 37]
2024-10-30 16:10:50,544 - decision_learning.data.shortest_path_grid - DEBUG - good_bad_edges: [ 1  4  9 16 24 31 36 39  0  3  8 15 23 30 35 38], remain_edges: [ 2  5  6  7 10 11 12 13 14 17 18 19 20 21 22 25 26 27 28 29 32 33 34 37]
2024-10-30 16:10:50,544 - decision_learning.data.shortest_path_grid - DEBUG - chg_pt: 0.0
2024-10-30 16:10:50,544 - decision_learning.data.shortest_path_grid - DEBUG - chg_pt: 0.0
2024-10-30 16:10:50,545 - decision_learning.data.shortest_path_grid - DEBUG - chg_pt: 0.55
2024-10-30 16:10:50,545 - decision_learning.data.shortest_path_grid - DEBUG - chg_pt: 0.55
2024-10-30 16:10:50,578 - decision_learning.data.shortest_path_grid - DEBUG - good_bad_edges: [ 1  4  9 16 24 31 36 39  0  3  8 15 23 30 35 38], remain_edges: [ 2  5  6  7 

Split data into train and validation set

In [83]:
# Splitting each input in the same way
train_dict = {}
val_dict = {}

for key, value in final_data.items():
    train_data, val_data = train_test_split(value, test_size=0.2, random_state=42)
    train_dict[key] = train_data
    val_dict[key] = val_data


# double checking the splitting is done correctly
# Splitting indices
indices = np.arange(len(next(iter(final_data.values()))))
train_indices, val_indices = train_test_split(indices, test_size=0.2, random_state=42)

for key in final_data.keys():
    train_same = (final_data[key][train_indices] == train_dict[key]).all()
    val_same = (final_data[key][val_indices] == val_dict[key]).all()
    print(f'{key}: {train_same}, {val_same}')

X: True, True
true_cost: True, True
true_sol: True, True
true_obj: True, True


# Training Loop

Initialize Inputs to Trainining Loop

In [87]:
# Prediction Model
pred_model = LinearRegression(input_dim=train_dict['X'].shape[1],
                 output_dim=train_dict['true_cost'].shape[1])

# optimization solver
optmodel = partial(shortest_path_solver,size=5)

# loss function
loss_fn = SPOPlus(optmodel=optmodel)

# training, validation data
train_data_dict = train_dict
val_data_dict = val_dict

In [90]:
metrics, trained_model = train(pred_model=pred_model,
                optmodel=optmodel,
                loss_fn=loss_fn,
                train_data_dict=train_data_dict,
                val_data_dict=val_data_dict,
                num_epochs=100,
                lr=0.1,
                scheduler_params={'step_size': 10, 'gamma': 0.1},
                minimization=True)

100%|██████████| 45/45 [00:01<00:00, 30.58it/s]
100%|██████████| 12/12 [00:00<00:00, 91.10it/s]
2024-10-30 16:11:55,914 - decision_learning.modeling.train - INFO - epoch: 0, train_loss: 4.721681075625949, val_metric: 0.20636834088639172
2024-10-30 16:11:55,914 - decision_learning.modeling.train - INFO - epoch: 0, train_loss: 4.721681075625949, val_metric: 0.20636834088639172
100%|██████████| 45/45 [00:01<00:00, 30.77it/s]
100%|██████████| 12/12 [00:00<00:00, 92.64it/s]
2024-10-30 16:11:57,510 - decision_learning.modeling.train - INFO - epoch: 1, train_loss: 4.434849818547566, val_metric: 0.18795803610406175
2024-10-30 16:11:57,510 - decision_learning.modeling.train - INFO - epoch: 1, train_loss: 4.434849818547566, val_metric: 0.18795803610406175
100%|██████████| 45/45 [00:00<00:00, 587.72it/s]
100%|██████████| 12/12 [00:00<00:00, 3805.51it/s]
2024-10-30 16:11:57,593 - decision_learning.modeling.train - INFO - epoch: 2, train_loss: 4.436901330947876, val_metric: 0.1462575909642184
2024-

In [91]:
test_regret = calc_test_regret(pred_model=pred_model,
                               test_data_dict=final_data_test,
                               optmodel=optmodel)
print(test_regret)

0.06736003958783911
