In [21]:
import numpy as np
from scipy.sparse import dia_matrix

def construct_multidiagonal_matrix(v1, v2, C):
    """
    Construct a multidiagonal sparse matrix where M[i,j] = abs(v1[i] - v2[j]) if abs(i-j) < C.

    Entries outside the band are structurally zero. The result uses scipy's DIA
    layout, in which ``data[r, j]`` holds ``M[j - offsets[r], j]``; slots of a
    stored diagonal that fall outside the matrix are filled with 0.

    Parameters:
    v1, v2 (array-like): Input vectors (must be same length); converted with np.asarray
    C (int): Bandwidth parameter; diagonals k with |k| < C are stored. Values
        C > n are capped at n (diagonals beyond the matrix hold no entries) and
        C <= 0 yields an all-zero matrix with no stored diagonals.

    Returns:
    dia_matrix: n x n sparse matrix in DIAgonal format whose offsets run from
        -min(C, n) + 1 to min(C, n) - 1

    Raises:
    AssertionError: if v1 and v2 have different lengths.
    """
    v1 = np.asarray(v1)
    v2 = np.asarray(v2)
    n = len(v1)
    assert len(v2) == n, "v1 and v2 must have the same length"

    # Diagonals with |k| >= n lie entirely outside an n x n matrix; including
    # them used to produce mismatched slices (broadcast error) once C > n + 1.
    half_width = max(0, min(C, n))
    offsets = np.arange(-half_width + 1, half_width)  # empty when half_width == 0

    # Same dtype that abs(v1 - v2) itself would produce (e.g. float64 for floats).
    dtype = np.abs(v1[:0] - v2[:0]).dtype
    data = np.zeros((len(offsets), n), dtype=dtype)

    for row, k in enumerate(offsets):
        # In DIA storage column j of diagonal k holds M[j - k, j]; the in-bounds
        # columns are j in [max(k, 0), n + min(k, 0)), all others stay 0.
        j_lo = max(k, 0)
        j_hi = n + min(k, 0)
        data[row, j_lo:j_hi] = np.abs(v1[j_lo - k:j_hi - k] - v2[j_lo:j_hi])

    return dia_matrix((data, offsets), shape=(n, n))

# Worked example: n = 6 with C = 2 keeps offsets -1, 0, 1 (a tridiagonal matrix).
v1 = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
v2 = np.array([1.5, 2.5, 3.5, 4.5, 5.5, 6.5])
C = 2

M_sparse = construct_multidiagonal_matrix(v1, v2, C)

# Cross-check the sparse build against a brute-force dense evaluation of the
# definition M[i, j] = |v1[i] - v2[j]| if |i - j| < C else 0.
positions = np.arange(len(v1))
in_band = np.abs(positions[:, None] - positions[None, :]) < C
M_dense_expected = np.where(in_band, np.abs(v1[:, None] - v2[None, :]), 0.0)
np.testing.assert_allclose(M_sparse.toarray(), M_dense_expected)

print("DIA Matrix data (diagonals):")
print(M_sparse.data)

print("\nDIA Matrix offsets:")
print(M_sparse.offsets)

print("\nDense matrix:")
print(M_sparse.toarray())

DIA Matrix data (diagonals):
[[0.5 0.5 0.5 0.5 0.5 0. ]
 [0.5 0.5 0.5 0.5 0.5 0.5]
 [0.  1.5 1.5 1.5 1.5 1.5]]

DIA Matrix offsets:
[-1  0  1]

Dense matrix:
[[0.5 1.5 0.  0.  0.  0. ]
 [0.5 0.5 1.5 0.  0.  0. ]
 [0.  0.5 0.5 1.5 0.  0. ]
 [0.  0.  0.5 0.5 1.5 0. ]
 [0.  0.  0.  0.5 0.5 1.5]
 [0.  0.  0.  0.  0.5 0.5]]


In [38]:
import importlib
import ot_sparse
importlib.reload(ot_sparse)  # pick up edits to the local module without restarting the kernel

# Problem size and seed in one place so the run is reproducible
N_POINTS = 1000
SEED = 0
rng = np.random.default_rng(SEED)

v1 = rng.random(N_POINTS)
v2 = rng.random(N_POINTS)
C = 5  # Bandwidth: diagonals with |i - j| < C are stored
M_sparse = construct_multidiagonal_matrix(v1, v2, C)

# Uniform marginals
a = np.full(N_POINTS, 1.0 / N_POINTS)
b = np.full(N_POINTS, 1.0 / N_POINTS)

# Parameters
reg = 0
reg_m1 = 0.1
reg_m2 = 0.1

# Initial guess in flat (unpadded) form: stored diagonal k contributes
# N_POINTS - |k| entries, concatenated in M_sparse.offsets order, all set to 0.1.
# NOTE(review): this layout is what ot_sparse is assumed to expect — confirm.
G0_flat = np.concatenate([np.full(N_POINTS - abs(k), 0.1) for k in M_sparse.offsets])

# Get the loss function
_func = ot_sparse.get_loss_unbalanced_sparse(a, b, None, M_sparse, reg, reg_m1, reg_m2)

print(_func(G0_flat))

(np.float64(477.2115842324547), array([0.22865932, 0.11946734, 0.65157539, ..., 0.23079947, 0.87111928,
       0.76936016], shape=(8980,)))
