# Tutorial for the core API

In [12]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from copy import deepcopy

## Setup parameters

In [104]:
from rashomon import hasse

There are 3 features
- Feature 1 takes on four values, {0, 1, 2, 3}
- Feature 2 takes on three values, {0, 1, 2}
- Feature 3 takes on three values, {0, 1, 2}

In [14]:
M = 3
R = np.array([4, 3, 3])

First, we find all the profiles corresponding to this setup. For the profiles, only the number of features matters.

In [109]:
num_profiles = 2**M
profiles, profile_map = hasse.enumerate_profiles(M)

print("Profiles")
print(profiles)

print("\nMap from each profile tuple to its index in `profiles` list")
print(profile_map)

Profiles
[(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)]

Map from each profile tuple to its index in `profiles` list
{(0, 0, 0): 0, (0, 0, 1): 1, (0, 1, 0): 2, (0, 1, 1): 3, (1, 0, 0): 4, (1, 0, 1): 5, (1, 1, 0): 6, (1, 1, 1): 7}


Next, we find all the possible feature combinations (i.e., policies) in our example.

In [None]:
all_policies = hasse.enumerate_policies(M, R)
num_policies = len(all_policies)

print(f"All {num_policies} policies")
print(all_policies)

All 36 policies
[(np.int64(0), np.int64(0), np.int64(0)), (np.int64(0), np.int64(0), np.int64(1)), (np.int64(0), np.int64(0), np.int64(2)), (np.int64(0), np.int64(1), np.int64(0)), (np.int64(0), np.int64(1), np.int64(1)), (np.int64(0), np.int64(1), np.int64(2)), (np.int64(0), np.int64(2), np.int64(0)), (np.int64(0), np.int64(2), np.int64(1)), (np.int64(0), np.int64(2), np.int64(2)), (np.int64(1), np.int64(0), np.int64(0)), (np.int64(1), np.int64(0), np.int64(1)), (np.int64(1), np.int64(0), np.int64(2)), (np.int64(1), np.int64(1), np.int64(0)), (np.int64(1), np.int64(1), np.int64(1)), (np.int64(1), np.int64(1), np.int64(2)), (np.int64(1), np.int64(2), np.int64(0)), (np.int64(1), np.int64(2), np.int64(1)), (np.int64(1), np.int64(2), np.int64(2)), (np.int64(2), np.int64(0), np.int64(0)), (np.int64(2), np.int64(0), np.int64(1)), (np.int64(2), np.int64(0), np.int64(2)), (np.int64(2), np.int64(1), np.int64(0)), (np.int64(2), np.int64(1), np.int64(1)), (np.int64(2), np.int64(1), np.int64(2)),

## Partition matrix

In [17]:
from rashomon import extract_pools

We will look only at the (1, 1, 1) profile for the purpose of illustration.

In [18]:
policies_111 = [x for x in all_policies if x[0] > 0 and x[1] > 0 and x[2] > 0]
policies_111

[(np.int64(1), np.int64(1), np.int64(1)),
 (np.int64(1), np.int64(1), np.int64(2)),
 (np.int64(1), np.int64(2), np.int64(1)),
 (np.int64(1), np.int64(2), np.int64(2)),
 (np.int64(2), np.int64(1), np.int64(1)),
 (np.int64(2), np.int64(1), np.int64(2)),
 (np.int64(2), np.int64(2), np.int64(1)),
 (np.int64(2), np.int64(2), np.int64(2)),
 (np.int64(3), np.int64(1), np.int64(1)),
 (np.int64(3), np.int64(1), np.int64(2)),
 (np.int64(3), np.int64(2), np.int64(1)),
 (np.int64(3), np.int64(2), np.int64(2))]

Let us say that the partition is as follows:
- $\pi_1$ = {(1, 1, 1), (1, 2, 1)}
- $\pi_2$ = {(1, 1, 2), (1, 2, 2)}
- $\pi_3$ = {(2, 1, 1), (2, 2, 1), (3, 1, 1), (3, 2, 1)}
- $\pi_4$ = {(2, 1, 2), (2, 2, 2), (3, 1, 2), (3, 2, 2)}

This corresponds to the following $\Sigma$ matrix. The `np.inf` implies that that feature does not take those factor levels.

In [19]:
sigma_111 = np.array([[0, 1],
                  [1, np.inf],
                  [0, np.inf]])
sigma_111

array([[ 0.,  1.],
       [ 1., inf],
       [ 0., inf]])

This is how we extract the pools from the matrix.

In [20]:
pi_pools, pi_policies = extract_pools.extract_pools(policies_111, sigma_111)

`pi_pools` is a dictionary that maps each pool index to a list of _indices_ of feature combinations in that pool

In [21]:
pi_pools

{0: [0, 2], 1: [1, 3], 2: [4, 8, 10, 6], 3: [5, 9, 11, 7]}

`pi_policies` is a dictionary that maps each feature combination (through its index) to the index of the pool it belongs to

In [22]:
pi_policies

{0: 0, 2: 0, 1: 1, 3: 1, 4: 2, 8: 2, 10: 2, 6: 2, 5: 3, 9: 3, 11: 3, 7: 3}

`extract_pools` also has an optional argument `lattice_edges` where you provide the edges in the Hasse. If you call `extract_pools` on the same Hasse very often, it is more efficient to pre-compute the lattice edges once and pass in this argument

In [23]:
hasse_edges = extract_pools.lattice_edges(policies_111)

pi_pools, pi_policies = extract_pools.extract_pools(policies_111, sigma_111, lattice_edges=hasse_edges)

## Generate data

Since there are 4 pools, we only need to select 4 distributions for the outcome. For simplicity, say the outcomes come from $N(\mu_{\pi}, \sigma_{\pi}^2)$ with the following parameters for each pool

In [24]:
mu_111 = np.array([0, 2, 4, -2])
var_111 = np.array([1, 1, 1, 1])

Fix 50 samples per feature and generate the data

In [25]:
np.random.seed(3)

num_samples_per_feature = 50
num_data = len(policies_111) * num_samples_per_feature

X = np.zeros(shape=(num_data, M))
D = np.zeros(shape=(num_data, 1), dtype='int_')
y = np.zeros(shape=(num_data, 1))

idx_ctr = 0
for k, feature in enumerate(policies_111):
    # policy_idx = [i for i, x in enumerate(all_policies) if x == policy]

    pool_id = pi_policies[k]
    mu_i = mu_111[pool_id]
    var_i = var_111[pool_id]
    y_i = np.random.normal(mu_i, var_i, size=(num_samples_per_feature, 1))

    start_idx = idx_ctr * num_samples_per_feature
    end_idx = (idx_ctr + 1) * num_samples_per_feature

    X[start_idx:end_idx, ] = feature
    D[start_idx:end_idx, ] = k
    y[start_idx:end_idx, ] = y_i

    idx_ctr += 1

`X` is the feature matrix.

`D` tells us the feature indices i.e., `D[i, 0]` is the feature index of `X[i, ]`.

`y` is the outcome vector

In [26]:
print(X[:10,])

print(D[:10])

print(y[:10])

[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
[[0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]]
[[ 1.78862847]
 [ 0.43650985]
 [ 0.09649747]
 [-1.8634927 ]
 [-0.2773882 ]
 [-0.35475898]
 [-0.08274148]
 [-0.62700068]
 [-0.04381817]
 [-0.47721803]]


We can calculate the mean outcome of each feature through the following object. The first column contains the sums of outcomes for each feature. The second column is the count. So dividing the first column by the second lets us find the average. We keep the sums and counts separately for internal computation purposes

In [27]:
from rashomon import loss

policy_means_111 = loss.compute_policy_means(D, y, len(policies_111))
print(policy_means_111)

[[ -14.696341     50.        ]
 [ 103.83263356   50.        ]
 [   1.45837656   50.        ]
 [  95.24720387   50.        ]
 [ 221.50472005   50.        ]
 [ -92.80830557   50.        ]
 [ 189.57330566   50.        ]
 [-107.31751574   50.        ]
 [ 209.98734846   50.        ]
 [-103.75982477   50.        ]
 [ 200.56828936   50.        ]
 [ -99.03931665   50.        ]]


## Finding the Rashomon set for profile (1, 1, 1)

In [28]:
from rashomon import aggregate

%load_ext autoreload
%autoreload 2

Let us set the maximum number of pools to be $H = \infty$ and the Rashomon threshold to be $\theta = 8$ and regularization $\lambda = 1$

In [29]:
H = np.inf
theta = 8
lamb = 1

RPS_111 = aggregate.RAggregate_profile(M, R, H, D, y, theta, profile=(1, 1, 1), reg=lamb)

The output of `RAggregate_profile` is an object of type `RashomonSet`. Here are some useful things we can do with this. Observe that the true partition `sigma_111` is the second partition in the RPS and has the least loss.

In [30]:
# Show the partition matrix for each member of the RPS
for sig in RPS_111.sigma:
    print(sig)

# Count the number of pools in each Rashomon partition
print(RPS_111.pools)

[]


In [31]:
# Calculate the loss
# This needs to be done only when calling `RAggregate_profile`
# When calling the main function `RAggregate`, loss is automatically calculated
RPS_111.calculate_loss(D, y, policies_111, policy_means_111, reg=lamb)

# Print the loss for each member in the RPS
print(RPS_111.loss)

[]


Additionally, there is an internal function that manually checks every single partition to see if it belongs to the RPS.

In [32]:
RPS_111_brute_force = aggregate._brute_RAggregate_profile(M, R, H, D, y, theta, profile=(1, 1, 1), reg=lamb)

# Verify that the brute force computation matches the branch-and-bound algorithm
RPS_111_brute_force.P_hash == RPS_111.P_hash

True

## For all profiles

Fix the partition matrices and outcome parameters for all other profiles

In [33]:
# Profile (0, 0, 0)
sigma_000 = None
mu_000 = np.array([0])
var_000 = np.array([1])

# Profile (0, 0, 1)
sigma_001 = np.array([[1]])
mu_001 = np.array([-2])
var_001 = np.array([1])

# Profile (0, 1, 0)
sigma_010 = np.array([[1]])
mu_010 = np.array([1])
var_010 = np.array([1])

# Profile (0, 1, 1)
sigma_011 = np.array([[1], [0]])
mu_011 = np.array([1, -2])
var_011 = np.array([1, 1])

# Profile (1, 0, 0)
sigma_100 = np.array([[0, 1]])
mu_100 = np.array([0, 2])
var_100 = np.array([1, 1])

# Profile (1, 0, 1)
sigma_101 = np.array([[0, 1], [0, np.inf]])
mu_101 = np.array([0, 2, 1, -2])
var_101 = np.array([1, 1, 1, 1])

# Profile (1, 1, 0)
sigma_110 = np.array([[0, 1], [1, np.inf]])
mu_110 = np.array([0, -2])
var_110 = np.array([1, 1])

sigma = [sigma_000, sigma_001, sigma_010, sigma_011, sigma_100, sigma_101, sigma_110, sigma_111]
mu = [mu_000, mu_001, mu_010, mu_011, mu_100, mu_101, mu_110, mu_111]
var = [var_000, var_001, var_010, var_011, var_100, var_101, var_110, var_111]

### Find all pools for each profile

This code block does what we did previously for a single profile. Since `extract_pools` only works for indexing within a Hasse, we need to carefully map the universal indexing of features across all profiles to its corresponding index within the profile that it belongs to. This is why this code chunk appears more complicated than it actually is.

In [34]:
policies_profiles = {}
policies_profiles_masked = {}
policies_ids_profiles = {}
pi_policies = {}
pi_pools = {}
for k, profile in enumerate(profiles):

    policies_temp = [(i, x) for i, x in enumerate(all_policies) if hasse.policy_to_profile(x) == profile]
    unzipped_temp = list(zip(*policies_temp))
    policies_ids_k = list(unzipped_temp[0])
    policies_k = list(unzipped_temp[1])
    policies_profiles[k] = deepcopy(policies_k)
    policies_ids_profiles[k] = policies_ids_k

    profile_mask = list(map(bool, profile))

    # Mask the empty arms
    for idx, pol in enumerate(policies_k):
        policies_k[idx] = tuple([pol[i] for i in range(M) if profile_mask[i]])
    policies_profiles_masked[k] = policies_k

    if np.sum(profile) > 0:
        pi_pools_k, pi_policies_k = extract_pools.extract_pools(policies_k, sigma[k])
        if len(pi_pools_k.keys()) != mu[k].shape[0]:
            print(f"Profile {k}. Expected {len(pi_pools_k.keys())} pools. Received {mu[k].shape[0]} means.")
        pi_policies[k] = pi_policies_k
        # pi_pools_k has indicies that match with policies_profiles[k]
        # Need to map those indices back to all_policies
        pi_pools[k] = {}
        for x, y in pi_pools_k.items():
            y_full = [policies_profiles[k][i] for i in y]
            y_agg = [all_policies.index(i) for i in y_full]
            pi_pools[k][x] = y_agg
    else:
        pi_policies[k] = {0: 0}
        pi_pools[k] = {0: [0]}

### Generate data

Again, this repeats what we did for a single profile for all profiles.

In [35]:
def generate_data(mu, var, n_per_pol, all_policies, pi_policies, M):
    num_data = num_policies * n_per_pol
    X = np.zeros(shape=(num_data, M))
    D = np.zeros(shape=(num_data, 1), dtype='int_')
    y = np.zeros(shape=(num_data, 1))

    idx_ctr = 0
    for k, profile in enumerate(profiles):
        policies_k = policies_profiles[k]

        for idx, policy in enumerate(policies_k):
            policy_idx = [i for i, x in enumerate(all_policies) if x == policy]

            pool_id = pi_policies[k][idx]
            mu_i = mu[k][pool_id]
            var_i = var[k][pool_id]
            y_i = np.random.normal(mu_i, var_i, size=(n_per_pol, 1))

            start_idx = idx_ctr * n_per_pol
            end_idx = (idx_ctr + 1) * n_per_pol

            X[start_idx:end_idx, ] = policy
            D[start_idx:end_idx, ] = policy_idx[0]
            y[start_idx:end_idx, ] = y_i

            idx_ctr += 1

    return X, D, y

In [36]:
num_samples_per_feature = 50000

X, D, y = generate_data(mu, var, num_samples_per_feature, all_policies, pi_policies, M)
policy_means = loss.compute_policy_means(D, y, num_policies)

### Finding the Rashomon Set

In [46]:
H = np.inf
theta = 13
lamb = 1

R_set, R_profiles = aggregate.RAggregate(M, R, H, D, y, theta, reg=lamb, verbose=True)

(0, 0, 0) 12.030462425103009
Profile (0, 0, 0) has 1 objects in Rashomon set
(0, 0, 1) 12.057744867578162
Adaptive
Profile (0, 0, 1) took 0.039968013763427734 s adaptively
Profile (0, 0, 1) has 2 objects in Rashomon set
(0, 1, 0) 12.0582758451157
Adaptive
Profile (0, 1, 0) took 0.04004096984863281 s adaptively
Profile (0, 1, 0) has 2 objects in Rashomon set
(0, 1, 1) 12.114337623086165
Adaptive
Profile (0, 1, 1) took 0.21412301063537598 s adaptively
Profile (0, 1, 1) has 4 objects in Rashomon set
(1, 0, 0) 12.085844227048296
Adaptive
Profile (1, 0, 0) took 0.1397690773010254 s adaptively
Profile (1, 0, 0) has 4 objects in Rashomon set
(1, 0, 1) 12.168943514770383
Adaptive
Profile (1, 0, 1) took 0.7037448883056641 s adaptively
Profile (1, 0, 1) has 8 objects in Rashomon set
(1, 1, 0) 12.168701235927644
Adaptive
Profile (1, 1, 0) took 0.7059371471405029 s adaptively
Profile (1, 1, 0) has 8 objects in Rashomon set
(1, 1, 1) 12.335497837770925
Adaptive
Profile (1, 1, 1) took 3.606073141098

The output of `RAggregate` is different from that of `RAggregate_profile`. For starters, the output is a tuple.

The first item `R_set` is a list. Each item in `R_set` is a list itself. The length of this list is the number of profiles. Each item in `R_set[i]` gives an index for a partition of that profile. So `R_set[i][k]` is the partition of the k-th profile in the i-th Rashomon partition in the RPS.

The second item `R_profiles` is a list whose length is the number of profiles. Each item is the `RashomonSet` object that we saw earlier. The indices in `R_set` correspond to the partitions in `R_profiles`. So the actual partition of `R_set[i][k]` is retrieved by accessing `R_profiles[k].sigma[R_set[i][k]]`.

In [67]:
R_profiles[7]

[array([[ 1.,  1.],
       [ 1., inf],
       [ 1., inf]]), array([[ 1.,  1.],
       [ 1., inf],
       [ 0., inf]]), array([[ 0.,  1.],
       [ 1., inf],
       [ 0., inf]]), array([[ 0.,  1.],
       [ 1., inf],
       [ 1., inf]]), array([[ 1.,  0.],
       [ 1., inf],
       [ 1., inf]]), array([[ 1.,  1.],
       [ 0., inf],
       [ 1., inf]])]

Now, let us see how to access these.

In [49]:
i = 3

RPS_partitions_i = R_set[i]

total_loss = 0
total_pools = 0
for k, profile in enumerate(profiles):
    print("Profile", profile)

    R_partition_i_k = R_profiles[k].sigma[RPS_partitions_i[k]]
    print("Partition")
    print(R_partition_i_k)

    # Notice that unlike the per-profile case, the loss of this partition is already pre-computed
    loss_i_k = R_profiles[k].loss[RPS_partitions_i[k]]
    print(f"Loss = {loss_i_k}")
    
    pools_i_k = R_profiles[k].pools[RPS_partitions_i[k]]
    print(f"Number of pools = {pools_i_k}")

    total_loss += loss_i_k
    total_pools += pools_i_k

    print("---")

print(f"Total loss = {total_loss}, Total number of pools = {total_pools}")

Profile (0, 0, 0)
Partition
None
Loss = 1.0276327713315392
Number of pools = 1
---
Profile (0, 0, 1)
Partition
[[1.]]
Loss = 1.0549152190530173
Number of pools = 1.0
---
Profile (0, 1, 0)
Partition
[[1.]]
Loss = 1.0554461937366306
Number of pools = 1.0
---
Profile (0, 1, 1)
Partition
[[1.]
 [1.]]
Loss = 1.3612916033997204
Number of pools = 1.0
---
Profile (1, 0, 0)
Partition
[[1. 1.]]
Loss = 1.157395257884124
Number of pools = 1.0
---
Profile (1, 0, 1)
Partition
[[ 1.  1.]
 [ 1. inf]]
Loss = 1.5540328652364819
Number of pools = 1.0
---
Profile (1, 1, 0)
Partition
[[ 1.  1.]
 [ 1. inf]]
Loss = 1.3142552442601412
Number of pools = 1.0
---
Profile (1, 1, 1)
Partition
[[ 0.  1.]
 [ 1. inf]
 [ 1. inf]]
Loss = 4.444784097194952
Number of pools = 2.0
---
Total loss = 12.969753252096606, Total number of pools = 9.0


By default `RAggregate` uses only one process. But we can parallelize finding Rashomon sets for each profile by changing the `num_workers` argument

In [50]:
import time

In [51]:
# num_workers = 1

start = time.time()
R_set1, R_profiles1 = aggregate.RAggregate(M, R, H, D, y, theta, reg=lamb, num_workers=1)
end = time.time()
elapsed = end - start

print(f"With 1 worker, RAggregate took {elapsed} s.")

With 1 worker, RAggregate took 9.680739164352417 s.


In [52]:
# num_workers = 2

start = time.time()
R_set2, R_profiles2 = aggregate.RAggregate(M, R, H, D, y, theta, reg=lamb, num_workers=2)
end = time.time()
elapsed = end - start

print(f"With 2 workers, RAggregate took {elapsed} s.")

Process SpawnPoolWorker-7:
Process SpawnPoolWorker-8:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/fanyiyang/miniconda3/envs/scGPT/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/fanyiyang/miniconda3/envs/scGPT/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/fanyiyang/miniconda3/envs/scGPT/lib/python3.11/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "/Users/fanyiyang/miniconda3/envs/scGPT/lib/python3.11/multiprocessing/pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/fanyiyang/Desktop/Research/RHS-X/understanding RPS paper revised/rashomon-partition-sets/Code/rashomon/aggregate/raggregate.py", line 154, in parallel_worker_RAggregat_pr

KeyboardInterrupt: 

In [None]:
# Check whether the results are the same
print(R_set1 == R_set2)

True


The difference of 1 second seems negligible but the gains will be more substantial when there are more features.

In [320]:
R_profiles[0].P_qe

[None]

### First step: find neighbourhood of a partition

In [258]:
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, List, Tuple, Optional, Sequence
import numpy as np
import math, random

# --- bring your repo primitives (edit path as needed) ---
# compute_Q(D, y, sigma, policies, policy_means, reg=1, normalize=0, lattice_edges=None) -> float
from rashomon.loss import compute_Q
from rashomon.aggregate.raggregate import RAggregate          # (variant A)
from rashomon import hasse

###
# 1) State + anchor materialization
###

class ProfilePart:
    cov_ids: List[int]        # active covariate indices for THIS profile
    B: np.ndarray  

    def __init__(self, cov_ids, B):
        self.cov_ids = cov_ids
        self.B = B

State = List[ProfilePart]  # length P; each item is an (M,R) int matrix

def build_anchor_states(R_set, R_profiles, M: int, R: int) -> List[State]:
    """
    R_set[g]: list[int] of length P (which candidate per profile)
    R_profiles[p][i]: (M,R) matrix for profile p, candidate i
    Returns: list of states; each state is [Sigma_0, ..., Sigma_{P-1}], Sigma_p ∈ ℤ^{M×R}.
    """
    P = len(R_profiles)
    anchors: List[State] = []
    profiles, profile_map = hasse.enumerate_profiles(int(math.log2(P)))
    for idx_vec in R_set:
        if len(idx_vec) != P:
            raise ValueError(f"R_set entry length {len(idx_vec)} != P={P}")
        state = []
        for p, sel in enumerate(idx_vec):
            cand = R_profiles[p].sigma[int(sel)]
            state.append(ProfilePart(cov_ids=profiles[p], B=cand))
        anchors.append(state)
        # # dedup by bytes of concatenated matrices
        # key = b"".join(S.tobytes() for S in state)
        # if key not in seen:
        #     anchors.append([S.copy() for S in state])
        #     seen.add(key)
    return anchors

# ---------- One-boundary shifts (UBS) inside a profile ----------
# def _row_segments_from_boundaries(b: np.ndarray):
#     R = len(b)
#     start = 0
#     end = R-1
#     segs=[]
#     for i in range(R):
#         if b[i] == 1: 
#             segs.append((start, i))
#             start = i+1
#             if (start == R): 
#                 segs.append((R,R))
#                 return segs
#         if i == R-1:
#             segs.append((start, R))
#             return segs
        

def _unit_shift_row_neighbors(b: np.ndarray) -> List[np.ndarray]:
    """
    Given a boundary row b (0/1, length R), return all neighbors formed by
    moving a single 1 one position left or right IF the destination is 0.
    No min-segment logic; exactly the 'check-adjacent-zero then move' rule.
    """
    b = np.asarray(b, dtype=int).ravel()
    R = b.size
    neigh = []
    ones = np.where(b == 1)[0]
    for j in ones:
        # move left
        if j - 1 >= 0 and b[j - 1] == 0:
            nb = b.copy()
            nb[j] = 0
            nb[j - 1] = 1
            neigh.append(nb)
        # move right
        if j + 1 < R and b[j + 1] == 0:
            nb = b.copy()
            nb[j] = 0
            nb[j + 1] = 1
            neigh.append(nb)
    # dedup in case two different moves produce same result (unlikely here)
    uniq = {}
    for r in neigh:
        uniq[r.tobytes()] = r
    return list(uniq.values())

def profile_part_neighbors_ubs(part: ProfilePart, min_len=1) -> List[ProfilePart]:
    C, R = part.B.shape
    out=[]
    for r in range(C):
        b = part.B[r,:]
        while math.isinf(b[-1]): # get rid of inf to compute boundary
            b = b[:-1]
        if len(b)<=1: continue
        
        for nb_row in _unit_shift_row_neighbors(b):
            if nb_row is None: continue
            nbB = part.B.copy()
            while (len(nbB[r,:]) > len(nb_row)): # pad inf back to keep format
                nb_row.append(float('inf'))
            nbB[r,:] = nb_row
            out.append(ProfilePart(cov_ids=list(part.cov_ids), B=nbB))
    return out

def state_neighbors_ubs(state: List[ProfilePart], min_len=1) -> List[State]: # state: list of matrix in the order of profiles
    neigh=[]
    y = state
        
    for p, profile_state in enumerate(state):
        if p == 0:
            continue
        part = profile_state # the profile and partition we are focusing on
        for nb in profile_part_neighbors_ubs(part, min_len=min_len): # enumerate all neighbor of this profile and partition
            y_new = y
            y_new[p]=nb
            neigh.append(y_new)
    return neigh

In [316]:
anchor = build_anchor_states(R_set, R_profiles, M, R)

state_neighbors_ubs(anchor[2])[0][7].B
R_set

[array([0, 0, 0, 0, 0, 0, 0, 0]),
 array([0, 0, 0, 0, 0, 0, 0, 1]),
 array([0, 0, 0, 0, 0, 0, 0, 2]),
 array([0, 0, 0, 0, 0, 0, 0, 3]),
 array([0, 0, 0, 0, 0, 0, 0, 4]),
 array([0, 0, 0, 0, 0, 0, 0, 5]),
 array([0, 0, 0, 0, 0, 0, 1, 0]),
 array([0, 0, 0, 0, 0, 0, 1, 1]),
 array([0, 0, 0, 0, 0, 0, 2, 0]),
 array([0, 0, 0, 0, 0, 0, 3, 0]),
 array([0, 0, 0, 0, 0, 1, 0, 0]),
 array([0, 0, 0, 0, 0, 1, 0, 1]),
 array([0, 0, 0, 0, 0, 2, 0, 0]),
 array([0, 0, 0, 0, 0, 2, 0, 1]),
 array([0, 0, 0, 0, 0, 3, 0, 0]),
 array([0, 0, 0, 0, 1, 0, 0, 0]),
 array([0, 0, 0, 0, 1, 0, 0, 1]),
 array([0, 0, 0, 0, 2, 0, 0, 0]),
 array([0, 0, 0, 1, 0, 0, 0, 0]),
 array([0, 0, 0, 1, 0, 0, 0, 1]),
 array([0, 0, 0, 2, 0, 0, 0, 0]),
 array([0, 0, 1, 0, 0, 0, 0, 0]),
 array([0, 1, 0, 0, 0, 0, 0, 0])]

### Second step: AIS, proposal

In [None]:
# ---------- q0, MH, AIS on this state space ----------
def states_equal(a: State, b: State) -> bool:
    if len(a)!=len(b): return False
    for ap,bp in zip(a,b):
        if ap.cov_ids != bp.cov_ids: return False
        if not np.array_equal(ap.B, bp.B): return False
    return True

def Kh_logpdf_state_stay_or_step(x: State, a: State, p_stay=0.25, min_len=1):
    if states_equal(x,a): return math.log(p_stay)
    N = state_neighbors_ubs(a, min_len=min_len)
    for n in N:
        if states_equal(x,n):
            return math.log(1.0-p_stay) - math.log(len(N))
    return float("-inf")

def Kh_sample_state_stay_or_step(a: State, rng: np.random.Generator, p_stay=0.25, min_len=1):
    if rng.random() < p_stay:  # stay
        return [ProfilePart(cov_ids=p.cov_ids, B=(p.B.copy() if p.B is not None else None)) for p in a]
    N = state_neighbors_ubs(a, min_len=min_len)
    if not N:
        return [ProfilePart(cov_ids=p.cov_ids, B=(p.B.copy() if p.B is not None else None)) for p in a]
    return N[rng.integers(0,len(N))]

from dataclasses import dataclass

@dataclass
class AnchorMixture:
    anchors: List[State]
    log_alpha: List[float]
    p_stay: float = 0.25
    min_len: int = 1
    def __post_init__(self):
        m = max(self.log_alpha) if self.log_alpha else 0.0
        w = [math.exp(z-m) for z in self.log_alpha] or [1.0]
        Z = sum(w) or 1.0
        self.alpha = [wi/Z for wi in w]
    def log_q0(self, x: State) -> float:
        terms=[]
        for a,A in zip(self.alpha, self.anchors):
            if a==0: continue
            lk = Kh_logpdf_state_stay_or_step(x, A, p_stay=self.p_stay, min_len=self.min_len)
            if lk==float("-inf"): continue
            terms.append(math.log(a)+lk)
        if not terms: return float("-inf")
        m=max(terms); return m+math.log(sum(math.exp(z-m) for z in terms))
    def sample_q0(self, rng: np.random.Generator) -> State:
        idx = rng.choice(len(self.anchors), p=np.array(self.alpha))
        return Kh_sample_state_stay_or_step(self.anchors[idx], rng, p_stay=self.p_stay, min_len=self.min_len)

def make_ladder(K:int, gamma:float=4.0):
    if K<2: return np.array([1.0], float)
    g=np.linspace(0.0,1.0,K); return g**gamma

@dataclass
class AISConfig:
    n_paths:int=600
    n_levels:int=40
    moves_per_level:int=12
    min_len:int=1
    seed:Optional[int]=2

@dataclass
class AISOutput:
    terminals: List[State]
    logw: np.ndarray
    normw: np.ndarray
    ladder: np.ndarray

def mh_step_state_uniform_neighbors(x: State, t: float, log_q0, score_s, min_len=1) -> State:
    N_cur = state_neighbors_ubs(x, min_len=min_len)
    if not N_cur: return x
    prop = random.choice(N_cur) # random proposal from the neighbourhood
    N_prop = state_neighbors_ubs(prop, min_len=min_len) # neighbourhood of the proposed state
    def logpi_t(z: State): # target distribution at this specifc t
        lq0 = log_q0(z)
        if lq0==float("-inf"): return float("-inf")
        return (1.0-t)*lq0 + t*score_s(z) # formula for target dist at ladder step t
    lcur, lprop = logpi_t(x), logpi_t(prop) # pi_t of x and proposed 
    if lprop == float("-inf"): return x
    logr = (lprop-lcur) + math.log(max(1,len(N_cur))) - math.log(max(1,len(N_prop))) # acceptance ratio
    if math.log(random.random()+1e-300) < min(0.0, logr): return prop
    return x

def run_ais_state(anchors: List[State], score_s: Callable[[State], float], cfg: AISConfig=AISConfig()) -> AISOutput:
    if cfg.seed is not None:
        np.random.seed(cfg.seed); random.seed(cfg.seed)
    rng=np.random.default_rng(cfg.seed)
    log_alpha=[score_s(A) for A in anchors]
    mix=AnchorMixture(anchors, log_alpha, p_stay=0.25, min_len=cfg.min_len)
    ladder = make_ladder(cfg.n_levels)
    terminals=[]; logw=np.zeros(cfg.n_paths,float)
    for p in range(cfg.n_paths):
        x = mix.sample_q0(rng); lw=0.0 # start with initial sample from anchor
        for t_prev, t_cur in zip(ladder[:-1], ladder[1:]):
            lq0_x = mix.log_q0(x); s_x = score_s(x) # calculate relevant posterior for specific x
            if lq0_x==float("-inf"):
                x=mix.sample_q0(rng); lq0_x=mix.log_q0(x); s_x=score_s(x)
            lw += (t_cur - t_prev) * (s_x - lq0_x)
            for _ in range(cfg.moves_per_level):
                x = mh_step_state_uniform_neighbors(x, t_cur, mix.log_q0, score_s, min_len=cfg.min_len)
        terminals.append([ProfilePart(cov_ids=pp.cov_ids, B=(pp.B.copy() if pp.B is not None else None)) for pp in x])
        logw[p]=lw
    m=float(np.max(logw)); w=np.exp(logw-m); return AISOutput(terminals, logw, w/w.sum(), ladder)


In [337]:
import numpy as np
from typing import Any, Sequence, Optional, Union
from rashomon.loss import compute_Q as _compute_Q

# ---------- 1) Q when sigma is None: one-pool predictor (mean of y_k) ----------
def _compute_Q_none_mean(D_k: np.ndarray, y_k: np.ndarray, *, reg: float, normalize: int) -> float:
    yv = y_k.ravel()
    if yv.size:
        mu = float(np.mean(yv))
        mse = float(np.mean((yv - mu) ** 2))
        if normalize:
            mse = mse * yv.size / float(normalize)
    else:
        mse = 0.0
    h = 1  # one pool
    return mse + reg * h

# ---------- 2) Global loss: sum per-profile losses; DO NOT touch non-None sigma ----------
RLike = Union[int, Sequence[int], np.ndarray]

def global_loss_raw(
    state: Sequence[Any],          # per profile: entry has .B OR is the sigma object itself OR None
    D: np.ndarray,                 # shape (N,1): [policy_id]
    y: np.ndarray,                 # shape (N,) or (N,1)
    policies: Sequence[Any],       # length P
    policy_means: np.ndarray,      # shape (P,2)
    reg: float = 1.0,
    normalize: int = 0,
    lattice_edges=None,
) -> float:
    total = 0.0
    P = len(state)
    for k in range(P):
        sigma_k = getattr(state[k], "B", state[k])  # DO NOT sanitize; pass through as-is
        mask = (D[:, 0].astype(int) == int(k))
        if not np.any(mask):
            # No rows for this profile -> treat as one pool (sigma=None logic)
            Q_k = _compute_Q_none_mean(D_k=D[:0], y_k=y[:0], reg=reg, normalize=normalize)
            total += Q_k
            continue

        D_k = D[mask].copy()
        y_k = y[mask].copy()
        D_k[:, 0] = 0  # localize to single-policy view

        policies_k = [policies[k]]
        pm_k = np.asarray(policy_means[k]).reshape(1, 2)

        if sigma_k is None:
            Q_k = _compute_Q_none_mean(D_k=D_k, y_k=y_k, reg=reg, normalize=normalize)
        else:
            # Pass sigma_k EXACTLY as provided (0/1/inf allowed)
            Q_k = float(_compute_Q(
                D=D_k, y=y_k, sigma=sigma_k,
                policies=policies_k, policy_means=pm_k,
                reg=reg, normalize=normalize, lattice_edges=lattice_edges
            ))
        total += Q_k
    return total

# ---------- 3) AIS score: exp(-beta * global_loss_raw(state)) ----------
def make_score_s_expneg_raw(
    *,
    D,
    y,
    policies,
    policy_means,
    reg: float = 1.0,
    normalize: int = 0,
    lattice_edges=None,
    beta: float = 1.0,                 # set 1.0 for exp(-loss)
    prior_logprob=lambda state: 0.0,   # optional log-prior
):
    def score_s(state):
        Q = global_loss_raw(
            state=state,
            D=D, y=y,
            policies=policies,
            policy_means=policy_means,
            reg=reg, normalize=normalize,
            lattice_edges=lattice_edges,
        )
        return float(np.exp(prior_logprob(state) - beta * Q))
    return score_s


In [341]:
# anchors: list[ProfilePart] per profile (your state format)
anchors = build_anchor_states(R_set, R_profiles, M, R)

# build the scorer for AIS
score_s = make_score_s_expneg_raw(
    D=D,
    y=y,
    policies=all_policies,
    policy_means=policy_means,
    reg=lamb,
    lattice_edges=None,   # or your edges
    beta=1.0,             # exp(-loss)
    prior_logprob=lambda state: 0.0
)

# then pass score_s into your AIS routine
# ais_out = run_ais_state(anchors, score_s=score_s, cfg=cfg)


cfg = AISConfig(n_paths=600, n_levels=40, moves_per_level=12, min_len=1, seed=3)
ais_out = run_ais_state(anchors, score_s=score_s, cfg=cfg)


In [347]:
max(ais_out.normw)

np.float64(0.002010050345678765)

In [349]:
import pickle
import gzip

def save_ais_pickle(path: str, particles: List[List[ProfilePart]], weights, meta: dict | None = None):
    """
    path: e.g. 'ais_run.pkl' or 'ais_run.pkl.gz' (gzip if endswith .gz)
    particles: list of states; each state = list[ProfilePart]
    weights: array-like of floats (same length as particles)
    meta: optional dict with run info (seed, ladder, etc.)
    """
    payload = {
        "particles": particles,
        "weights": np.asarray(weights, float),
        "meta": meta or {}
    }
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "wb") as f:
        # Use highest protocol for speed/size
        pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_ais_pickle(path: str):
    """Returns (particles, weights, meta)."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        obj = pickle.load(f)
    return obj["particles"], obj["weights"], obj.get("meta", {})

In [355]:
save_ais_pickle("ais_particles.pkl.gz", ais_out.terminals, ais_out.normw,
                meta={"run_id":"first_trial","n_paths":600,"n_levels":40,"moves_per_level":12,"reg":1,"seed":3,"min_length":1})

particles2, weights2, meta2 = load_ais_pickle("ais_particles.pkl.gz")


In [356]:
meta2

{'run_id': 'first_trial',
 'n_paths': 600,
 'n_levels': 40,
 'moves_per_level': 12,
 'reg': 1,
 'seed': 3,
 'min_length': 1}