In [None]:
import torch, itertools, random
from typing import List, Tuple


def compute_v(num_cri:int, num_interval:int, cri_crisp:dict):
    """Compute piecewise-linear interval features for every criterion.

    Each criterion's observed range [min, max] is split into ``num_interval``
    equal-width intervals with breakpoints x_0 < x_1 < ... < x_n.

    For an alternative with value g, the feature of interval t (1-based) is:
      * 1.0 if g > x_t,
      * (g - x_{t-1}) / (x_t - x_{t-1}) if x_{t-1} <= g <= x_t,
      * 0.0 otherwise.

    A criterion whose interval width is <= 1e-12 (constant column) gets
    all-ones rows.

    Args:
        num_cri: expected number of criteria; must equal ``len(cri_crisp)``.
        num_interval: number of intervals per criterion.
        cri_crisp: mapping criterion -> iterable of per-alternative values.
            The dict's iteration order defines the criterion order.

    Returns:
        float64 tensor of shape [num_cri, num_alter, num_interval].

    Raises:
        AssertionError: if the criterion count or the per-criterion number
            of alternatives is inconsistent.
    """
    tol = 1e-12

    def _interval_share(g, lo_edge, hi_edge):
        # Fraction of the interval [lo_edge, hi_edge] lying below g.
        if g > hi_edge:
            return 1.0
        if lo_edge <= g <= hi_edge:
            width = hi_edge - lo_edge
            return 1.0 if width <= tol else (g - lo_edge) / width
        return 0.0

    columns = [torch.as_tensor(list(col), dtype=torch.float64) for col in cri_crisp.values()]
    assert len(columns) == num_cri, f"num_cri mismatch: got {len(columns)} vs {num_cri}"

    n_alt = columns[0].numel()
    assert all(col.numel() == n_alt for col in columns[1:]), "Inconsistent number of alternatives across criteria"

    features = []
    for idx in range(num_cri):
        col = columns[idx]
        lo = torch.min(col).item()
        hi = torch.max(col).item()
        step = (hi - lo) / num_interval if num_interval > 0 else 0.0
        knots = [lo + k * step for k in range(num_interval + 1)]
        degenerate = step <= tol

        rows = []
        for g in col.tolist():
            if degenerate:
                # Constant criterion: every interval counts as fully reached.
                rows.append([1.0] * num_interval)
            else:
                rows.append([_interval_share(g, left, right)
                             for left, right in zip(knots[:-1], knots[1:])])
        features.append(rows)

    # Shape: [num_cri, num_alter, num_interval]
    return torch.tensor(features, dtype=torch.float64)


def V_INT(v: torch.Tensor):
    """Assemble per-alternative feature vectors including pairwise interaction terms.

    For alternative ``a`` the feature vector is the concatenation of:
      * the flattened marginal features ``v[:, a, :]`` (criterion-major,
        index ``i * num_interval + k``), then
      * for every criterion pair (i, j), i < j, in ``itertools.combinations``
        order: the flattened outer product ``v[i, a] x v[j, a]`` (index
        ``k * num_interval + l``), emitted twice: first the "plus" copy,
        then the "minus" copy.

    Both copies hold identical values; they differ only in the weights that
    are later applied to them.

    Args:
        v: 3-D tensor [num_cri, num_alter, num_interval] (e.g. from compute_v).

    Returns:
        (V_int, V_plus, V_minus):
          * V_int[a] is a 1-D tensor of length
            num_cri*num_interval + 2*num_interval**2*num_pairs.
          * V_plus[a] / V_minus[a] are lists of the per-pair blocks.
            V_minus blocks are independent clones of the V_plus blocks.
    """
    num_cri, num_alter, _ = v.shape
    pair_list = [(i, j) for i in range(num_cri) for j in range(i + 1, num_cri)]

    V_plus, V_minus, V_int = [], [], []
    for a in range(num_alter):
        marginal = v[:, a, :].reshape(-1)
        plus_a = [torch.outer(v[i, a, :], v[j, a, :]).reshape(-1) for i, j in pair_list]
        minus_a = [blk.clone() for blk in plus_a]

        pieces = [marginal]
        for p_blk, m_blk in zip(plus_a, minus_a):
            pieces.extend((p_blk, m_blk))
        # With no criterion pairs the marginal vector is used as-is.
        V_int.append(torch.cat(pieces, dim=0) if len(pieces) > 1 else marginal)

        V_plus.append(plus_a)
        V_minus.append(minus_a)

    return V_int, V_plus, V_minus



def interaction(u_int_new: torch.Tensor, num_cri:int, num_interval:int):
    """Randomly attach pairwise interaction weights to a weight vector.

    Layout (same as the feature vectors built by V_INT):
      * the first ``num_cri * num_interval`` entries are marginal weights;
      * after that, each criterion pair (in ``itertools.combinations`` order)
        owns ``2 * num_interval**2`` entries: a "plus" half followed by a
        "minus" half.

    Procedure:
      1. Shuffle criteria 1..num_cri (Python ``random``) and pair consecutive
         entries, so each criterion is in at most one pair.
      2. Zero the whole interaction region.
      3. For each chosen pair, draw a tag in {0, 1, 2} (torch RNG):
         0 -> no interaction;
         1 -> plus half = U(0,1) * min(u_ik, u_jl);
         2 -> minus half = -(U(0,1) * min(u_ik, u_jl)).
      4. Divide everything by the total sum (+1e-12).

    The input tensor is not modified; a new tensor is returned.

    Args:
        u_int_new: 1-D weight vector with the layout above.
        num_cri: number of criteria.
        num_interval: number of intervals per criterion.

    Returns:
        New 1-D tensor with the same dtype/device as ``u_int_new``.
    """
    device = u_int_new.device
    dtype = u_int_new.dtype
    # Work on a copy so the caller's tensor is never overwritten.
    u_int_new = u_int_new.clone()

    numbers = list(range(1, num_cri + 1))
    random.shuffle(numbers)
    # Disjoint random pairs, 1-based and ascending within each pair.
    # With an odd num_cri the last shuffled criterion stays unpaired.
    pairs = [sorted(numbers[i:i + 2]) for i in range(0, len(numbers) - 1, 2)]
    if pairs:
        z = torch.randint(0, 3, (len(pairs),), device=device)

    head = num_cri * num_interval
    u_int_new[head:] = 0.0

    # Position of every pair's block, in the same order V_INT emits them.
    pair_pos = {p: k for k, p in enumerate(itertools.combinations(range(1, num_cri + 1), 2))}
    block_len = num_interval * num_interval

    for idx, (c1, c2) in enumerate(pairs):
        start = head + pair_pos[(c1, c2)] * (block_len * 2)
        mid = start + block_len
        end = mid + block_len

        u_i = u_int_new[(c1 - 1) * num_interval:c1 * num_interval]
        u_j = u_int_new[(c2 - 1) * num_interval:c2 * num_interval]
        # upper[k * num_interval + l] = min(u_i[k], u_j[l])
        upper = torch.minimum(u_i[:, None], u_j[None, :]).reshape(-1)
        # Drawn for every pair (including tag 0) so that RNG consumption
        # does not depend on the tags.
        rand_block = torch.rand(block_len, dtype=dtype, device=device) * upper

        tag = int(z[idx].item())
        if tag == 1:
            u_int_new[start:mid] = rand_block
        elif tag == 2:
            u_int_new[mid:end] = -rand_block
        # tag 0 leaves both halves zero. For tags 1/2 the other half is
        # already zero from the reset above.

    # NOTE(review): the sum can be <= 0 when negative blocks outweigh the
    # positive mass; the division then flips signs or blows up. Confirm
    # whether that case should be rejected or resampled instead.
    total = u_int_new.sum()
    return u_int_new / (total + torch.as_tensor(1e-12, dtype=dtype, device=device))


def _loglike_logistic(V_list: List[torch.Tensor], u: torch.Tensor,
                      Q: List[Tuple[int,int]], omega: float):

    dots = torch.stack([torch.dot(Va.to(u.dtype), u) for Va in V_list])  # 先批量算内积
    loglike = torch.tensor(0.0, dtype=u.dtype, device=u.device)
    for (a, b) in Q:
        x = omega * (dots[a] - dots[b])
        loglike = loglike + torch.nn.functional.logsigmoid(x)  # log σ(x)
    return loglike


def calculate_acceptance_probability(u_old: torch.Tensor, u_new: torch.Tensor,
                                     Q: List[Tuple[int,int]], alpha: torch.Tensor,
                                     V_list: List[torch.Tensor], omega: float,
                                     q_alpha: torch.Tensor):
    """Return the Metropolis-Hastings acceptance probability min(1, r).

    log r = [loglik(u_new) - loglik(u_old)]
          + [log Dir(u_new; alpha)   - log Dir(u_old; alpha)]
          + [log Dir(u_old; q_alpha) - log Dir(u_new; q_alpha)]

    This is the likelihood ratio times the prior ratio times the
    reverse/forward proposal ratio, for a proposal density that does not
    depend on the current state. All inputs are cast to float64, and
    ``alpha`` / ``q_alpha`` are moved to ``u_old``'s device. ``log r`` is
    clamped at 50 before exponentiating.

    NOTE(review): torch.distributions validates ``log_prob`` arguments by
    default, so a weight vector with negative entries (which ``interaction``
    can produce) would raise ValueError here unless validation is disabled.
    Confirm.

    Returns:
        0-d float64 tensor in [0, 1] (NaN propagates if log r is NaN).
    """
    f64 = torch.float64
    dev = u_old.device
    u_old, u_new = u_old.to(f64), u_new.to(f64)
    alpha = alpha.to(f64).to(dev)
    q_alpha = q_alpha.to(f64).to(dev)

    log_lik_gain = _loglike_logistic(V_list, u_new, Q, omega) - _loglike_logistic(V_list, u_old, Q, omega)

    prior = torch.distributions.Dirichlet(alpha)
    proposal = torch.distributions.Dirichlet(q_alpha)
    log_prior_gain = prior.log_prob(u_new) - prior.log_prob(u_old)
    log_proposal_corr = proposal.log_prob(u_old) - proposal.log_prob(u_new)

    log_ratio = log_lik_gain + log_prior_gain + log_proposal_corr

    # The clamp keeps exp() finite; the cap at 1 gives min(1, r).
    ratio = torch.exp(torch.clamp(log_ratio, max=50.0))
    return torch.minimum(ratio, torch.tensor(1.0, dtype=f64, device=dev))


# ===============================
def metropolis_hastings(M:int, N:int, num_chains:int, Q:List[Tuple[int,int]], data:dict,
                        num_cri:int, num_alter:int, num_interval:int, omega:float,
                        alpha_val:float=1.0):
    """Run independent Metropolis-Hastings chains over utility weight vectors.

    Each chain starts from ``interaction`` applied to a Dir(alpha) draw.
    Every step then:
      * proposes ``interaction`` applied to a fresh Dir(q_alpha = 1) draw,
        independent of the current state;
      * accepts it with ``calculate_acceptance_probability``.

    After ``M`` burn-in steps the chain records the state and the
    per-alternative utilities ``<u, V_int[a]>``.

    NOTE(review): the acceptance ratio scores the *transformed* vectors with
    Dirichlet densities, not the raw Dirichlet draws. Confirm this is the
    intended proposal/prior correction.

    Args:
        M: burn-in steps per chain (not recorded).
        N: recorded steps per chain.
        num_chains: number of independent chains.
        Q: (a, b) alternative-index pairs. The likelihood rewards U[a] > U[b].
        data: criterion -> per-alternative values, passed to compute_v.
        num_cri: number of criteria.
        num_alter: number of alternatives; must match ``data``.
        num_interval: intervals per criterion.
        omega: logistic scale of the likelihood.
        alpha_val: concentration of the symmetric Dirichlet prior.

    Returns:
        (all_samples, all_U, V_int):
          * all_samples[c] is the list of N recorded weight tensors of chain c;
          * all_U[c] is the list of N lists of ``num_alter`` 0-d utility tensors;
          * V_int is the list of feature vectors from V_INT.
    """
    # Dimensions: marginal weights plus a plus/minus block per criterion pair.
    n_pairs = num_cri * (num_cri - 1) // 2
    num_para = num_cri * num_interval + num_interval * num_interval * n_pairs * 2

    alpha = torch.full((num_para,), float(alpha_val), dtype=torch.float64)
    q_alpha = torch.ones_like(alpha)

    v = compute_v(num_cri, num_interval, data)     # [num_cri, num_alter, num_interval]
    V_int, _, _ = V_INT(v)                         # list of length num_alter
    assert len(V_int) == num_alter and V_int[0].numel() == num_para

    all_samples, all_U = [], []
    for _chain in range(num_chains):
        prior = torch.distributions.Dirichlet(alpha)
        proposal = torch.distributions.Dirichlet(q_alpha)
        current = interaction(prior.sample(), num_cri, num_interval)

        kept, kept_U = [], []
        for step in range(M + N):
            candidate = interaction(proposal.sample(), num_cri, num_interval)
            accept = calculate_acceptance_probability(current, candidate, Q, alpha, V_int, omega, q_alpha)
            if torch.rand(()) < accept.item():
                current = candidate

            if step < M:
                continue  # still in burn-in

            kept.append(current.clone())
            cur64 = current.to(torch.float64)
            kept_U.append([torch.dot(cur64, V_int[k].to(cur64.dtype)) for k in range(num_alter)])

        all_samples.append(kept)
        all_U.append(kept_U)

    return all_samples, all_U, V_int
