In [1]:
from collections import namedtuple
import numpy as np


In [2]:
# A named block of the flat design vector; ``shape`` is what its slice is reshaped to.
Var = namedtuple('Var', 'name shape')
# A model component: reads ``inputs`` (tuple of Var), declares ``outputs`` (tuple of Var)
# and is evaluated as ``function(*input_arrays)``.
Subproblem = namedtuple('Subproblem', 'name inputs outputs function')

In [3]:
x1 = Var('x1', (1,))
x2 = Var('x2', (10, 4))
x3 = Var('x3', (1, 1))
x4 = Var('x4', (5,))
# NOTE(review): f1 declares a single output x1 of shape (1,), but x2*x3 has shape (10, 4);
# confirm which output shape is intended.
f1 = Subproblem('f1', (x3, x2), (x1,), lambda x3, x2: x2 * x3)
# x1 arrives with shape (1,) and x3 with shape (1, 1), so the terms mix shapes (1,) and (1, 1).
# np.array([...]) over them is a ragged array (ValueError on NumPy >= 1.24, object array
# before), so ravel each term and concatenate into a float array of shape (5,), matching x4.
f2 = Subproblem(
    'f2', (x1, x3), (x4,),
    lambda x1, x3: np.concatenate(
        [np.ravel(term) for term in (x1 - 1, x1 * x3, x1 - x3, x3 + 1, x1 + 2 * x3)]
    ),
)

In [4]:
def get_index_ranges(variables):
    """Lay ``variables`` out back to back in one flat vector.

    Parameters
    ----------
    variables : iterable of objects with a ``shape`` attribute (e.g. ``Var``)

    Returns
    -------
    list of (start_idx, end_idx)
        Half-open ranges, one per variable, in the order given.
    """
    index_ranges = []
    start_idx = 0
    for var in variables:
        # int(): np.prod(()) is the float 1.0, which would turn this and every later
        # offset into a float and break slicing for scalar (shape ()) variables.
        end_idx = start_idx + int(np.prod(var.shape))
        index_ranges.append((start_idx, end_idx))
        start_idx = end_idx
    return index_ranges

def subset_index_ranges(all_variables, selected_subset):
    """Look up the flat-vector ``(start_idx, end_idx)`` range of each selected variable.

    Ranges come from ``get_index_ranges(all_variables)`` and are returned in
    ``selected_subset`` order. A variable that is not in ``all_variables`` raises
    ``ValueError`` (from ``.index``).
    """
    ranges_by_position = get_index_ranges(all_variables)
    picked = []
    for var in selected_subset:
        position = all_variables.index(var)
        picked.append(ranges_by_position[position])
    return picked

def select_subset(flat_vector, subset_index_ranges):
    """Gather the listed ``(start_idx, end_idx)`` slices of ``flat_vector`` into one 1-D array.

    Slices are concatenated in the order given, not sorted by position, so the result
    follows the order of the variable subset the ranges were built from.

    Returns
    -------
    ndarray
        The concatenated slices. If ``subset_index_ranges`` is empty, this is an empty
        array with ``flat_vector``'s dtype.
    """
    pieces = [flat_vector[start_idx:end_idx] for start_idx, end_idx in subset_index_ranges]
    if not pieces:
        # np.concatenate([]) raises "need at least one array to concatenate".
        return flat_vector[:0].copy()
    return np.concatenate(pieces)

In [34]:
all_variables = [x1, x2, x3, x4]

# For each subproblem: the flat-vector ranges of its inputs, in the subproblem's input order.
projection = {sub: subset_index_ranges(all_variables, sub.inputs) for sub in (f1, f2)}

In [35]:
# Flat-vector index ranges of f1's inputs, listed in f1.inputs order (x3, then x2).
projection[f1]

[(41, 42), (1, 41)]

In [7]:
# Random starting point for the whole flat vector. It is seeded so that Restart & Run All
# reproduces the outputs below. int() keeps the length an integer even for shape () variables.
x0 = np.random.RandomState(0).rand(sum(int(np.prod(v.shape)) for v in all_variables))

In [22]:
def get_precomputed_info(variables):
    """Lay ``variables`` out back to back and record each one's slice and shape.

    Parameters
    ----------
    variables : iterable of objects with a ``shape`` attribute (e.g. ``Var``)

    Returns
    -------
    list of (start_idx, end_idx, shape)
        One entry per variable, in the order given, with offsets counted from 0.
        This is the format ``split_vector`` consumes.
    """
    precomputed = []
    start_idx = 0
    for var in variables:
        # int(): np.prod(()) is the float 1.0, which would turn this and every later
        # offset into a float and break slicing for scalar (shape ()) variables.
        end_idx = start_idx + int(np.prod(var.shape))
        precomputed.append((start_idx, end_idx, var.shape))
        start_idx = end_idx
    return precomputed
    
def split_vector(flat_vector, precomputed_info):
    """Cut ``flat_vector`` into pieces and reshape each one.

    ``precomputed_info`` is a sequence of ``(start_idx, end_idx, shape)`` entries,
    such as those built by ``get_precomputed_info``. The result is a list with one array
    per entry, ``flat_vector[start_idx:end_idx].reshape(shape)``, in the same order.
    """
    return [flat_vector[lo:hi].reshape(shape) for lo, hi, shape in precomputed_info]

def get_precomputed_indices(all_variables, selected_subset):
    """Return ``(start_idx, end_idx, shape)`` for each variable in ``selected_subset``.

    Offsets refer to the flat layout of ``all_variables`` (from ``get_precomputed_info``).
    The result is ordered like ``selected_subset``, not like ``all_variables``. This keeps
    it lined up with the ``input_arrays`` that ``set_subset`` zips it with, and matches
    the ordering of ``subset_index_ranges``.

    Raises
    ------
    ValueError
        If a selected variable is not in ``all_variables``. Such a variable would
        otherwise be silently dropped, shifting every later value into the wrong slot.
    """
    all_info = get_precomputed_info(all_variables)
    ordered_variables = list(all_variables)
    return [all_info[ordered_variables.index(var)] for var in selected_subset]

def set_subset(flat_vector, precomputed_indices, input_arrays):
    """Return a copy of ``flat_vector`` with selected slots overwritten.

    Parameters
    ----------
    flat_vector : ndarray
        Full flat vector. It is copied and never modified.
    precomputed_indices : iterable of (start_idx, end_idx, shape)
        Target slots, e.g. from ``get_precomputed_indices``.
    input_arrays : iterable
        One value per slot, in the same order. A slot of size 1 accepts any size-1
        value (Python/NumPy scalar, 0-d or one-element array). Larger slots need an
        array-like of exactly ``shape``.

    Raises
    ------
    ValueError
        If the number of values differs from the number of slots, or a value does not
        fit its slot. Silently skipping it would leave stale data in the vector.
    """
    slots = list(precomputed_indices)
    values = list(input_arrays)
    if len(slots) != len(values):
        raise ValueError(
            "got {} values for {} slots".format(len(values), len(slots))
        )
    modified_vector = flat_vector.copy()
    for (start_idx, end_idx, shape), value in zip(slots, values):
        # asarray lets lists, scalars and 0-d arrays go through the same checks.
        arr = np.asarray(value)
        if int(np.prod(shape)) == 1 and arr.size == 1:
            modified_vector[start_idx:end_idx] = arr.ravel()
        elif arr.shape == tuple(shape):
            modified_vector[start_idx:end_idx] = arr.ravel()
        else:
            raise ValueError(
                "value of shape {} does not fit slot of shape {}".format(arr.shape, tuple(shape))
            )
    return modified_vector

In [36]:
# Per-subproblem (start, end, shape) layout of its gathered input vector. Offsets are
# relative to the output of select_subset(...), not to the full flat vector.
split_precomputed = {sub: get_precomputed_info(sub.inputs) for sub in (f1, f2)}

In [37]:
# Copy of x0 with x3's single entry set to 1000; set_subset copies, so x0 is unchanged.
x0new = set_subset(
    flat_vector=x0,
    precomputed_indices=get_precomputed_indices(all_variables=all_variables, selected_subset=[x3]),
    input_arrays=[1000],
)

In [38]:
# f1's gathered inputs, in f1.inputs order: x3's entry followed by x2's entries.
f1in = select_subset(flat_vector=x0new, subset_index_ranges=projection[f1])

In [39]:
# Reshape the gathered vector back into f1's inputs: [x3 as (1, 1), x2 as (10, 4)].
split_vector(flat_vector=f1in, precomputed_info=split_precomputed[f1])

[array([[1000.]]),
 array([[0.34662245, 0.89445433, 0.13437457, 0.91916477],
        [0.69804307, 0.20646939, 0.55856588, 0.35467334],
        [0.3438828 , 0.8616317 , 0.2082174 , 0.66286191],
        [0.8373436 , 0.46778506, 0.64034099, 0.79067727],
        [0.05517509, 0.00981833, 0.58492999, 0.24797695],
        [0.26053618, 0.37102363, 0.77855374, 0.62454899],
        [0.70186809, 0.03846233, 0.40762003, 0.2791195 ],
        [0.52012894, 0.14159655, 0.48445472, 0.73137283],
        [0.7294046 , 0.07803234, 0.14590698, 0.94180338],
        [0.33029304, 0.33554624, 0.52750539, 0.31489753]])]

In [42]:
def flatten_output(f):
    """Wrap subproblem ``f`` as a function from the full flat vector to a flat output vector.

    The returned callable takes the full flat vector ``x``. It gathers ``f``'s inputs
    with ``projection[f]``, reshapes them with ``split_precomputed[f]``, calls
    ``f.function`` and concatenates everything returned into one 1-D array:

    * tuple/list result: each element is one output and is flattened;
    * array result with ndim >= 1: iterated along its first axis, each item flattened;
    * scalar or 0-d result: treated as a single value. Iterating it would raise TypeError.

    An empty result gives an empty float array instead of an ``np.concatenate`` error.
    """
    def wrapped(x):
        # NOTE(review): projection / split_precomputed are read from the notebook's
        # globals at call time, so redefining them later changes this wrapper's behavior.
        inputs = split_vector(select_subset(x, projection[f]), split_precomputed[f])
        result = f.function(*inputs)
        if np.isscalar(result) or getattr(result, "ndim", None) == 0:
            parts = [result]
        else:
            parts = result
        # np.ravel turns scalars, arrays and nested lists alike into 1-D pieces.
        pieces = [np.ravel(part) for part in parts]
        if not pieces:
            return np.empty(0)
        return np.concatenate(pieces)

    wrapped.__name__ = f.name  # readable name in reprs and tracebacks
    return wrapped

In [46]:
# Flat-vector -> flat-output wrappers for each subproblem.
newf1, newf2 = map(flatten_output, (f1, f2))

In [47]:
# Evaluate f2 on the full vector (with x3 overwritten to 1000); the wrapper returns a 1-D array.
newf2(x0new)

array([-9.47593515e-01,  5.24064848e+01, -9.99947594e+02,  1.00100000e+03,
        2.00005241e+03])