# Piecewise-Linear Regression

### `SMC` approach

In [1]:
# import libraries
import import_ipynb
import toolbox_sccf as sccf
import toolbox_SMC_backend as smc

## basic imports 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
import warnings
warnings.filterwarnings('ignore')
from IPython.display import clear_output
import seaborn as sns
import cvxpy as cp
import time
clear_output()
np.random.seed(120824)

In [2]:
## problem dimensions

# input dimension (number of coordinates of each sample omega)
w = int(2)
# number of affine pieces in the first max term (rows of mat1)
B1 = int(5) 
# number of affine pieces in the second max term (rows of mat2)
B2 = int(5) 
# number of synthetic samples
N = int(5e3)
L = 100.0 # bound on the inf-norm of each weight row (bias entry excluded), i.e. search-space size

# samples are drawn uniformly from the ball of this radius
radius_data = 10 

## data representation
# 'regular': linear features + bias; 'full': linear + pairwise products + bias;
# 'truncated': a random 2/3 of the 'full' non-bias features + bias
mode = 'full' # others include 'regular','truncated'
stand = False # set True to standardize every feature (zero mean, unit std)

### synthetic data generation

In [3]:
from scipy.stats import ortho_group

def gt(omegas):
    """Evaluate the synthetic ground-truth function on every sample.

    Each sample ``omega`` is an indexable sequence of coordinates; the value is
    ``0.1*omega[-1]**3 + sin(sum(omega)) + 4 + omega[0]*omega[1] - 0.2*sum(omega)**2``.
    Returns a 1-D NumPy array with one value per sample.
    """
    def _value_at(omega):
        # formula kept verbatim so floating-point evaluation order is unchanged
        return 1/10*omega[-1]**3 + 1*np.sin(sum(omega)) + 4+omega[0]*omega[1]-2/10*np.sum(omega)**2

    return np.array(list(map(_value_at, omegas)))

In [4]:
'''
generate "num_points" random points in "dimension" that have uniform
 probability over the unit ball scaled by "radius" (length of points
 are in range [0, "radius"]).
'''

# https://stackoverflow.com/questions/54544971/how-to-generate-uniform-random-points-inside-d-dimension-ball-sphere 

def random_ball(num_points, dimension, radius=1):
    """Draw ``num_points`` points uniformly from the ``dimension``-ball of the given radius.

    Returns an array of shape ``(num_points, dimension)``. Uses NumPy's global
    RNG: first one normal draw of shape ``(dimension, num_points)``, then one
    uniform draw of length ``num_points``.
    """
    # normalising Gaussian vectors gives directions uniform on the unit sphere (one column per point)
    gaussian = np.random.normal(size=(dimension,num_points))
    directions = gaussian / np.linalg.norm(gaussian, axis=0)
    # U**(1/d) radii make the density uniform in volume, not in radius
    radii = np.random.random(num_points) ** (1/dimension)
    scaled = directions * radii
    return radius * scaled.T


# draw N samples uniformly in the w-dimensional ball and evaluate the ground truth on them
omegas = random_ball(num_points=N, dimension=w, radius=radius_data)
taus = gt(omegas=omegas)

In [5]:
# Evaluation grid over the square [-radius_data, radius_data]^2, used for plotting.
omega1 = np.linspace(-radius_data, radius_data, int(2e2))
omega2 = np.linspace(-radius_data, radius_data, int(2e2))

OMEGA1, OMEGA2 = np.meshgrid(omega1, omega2)
# Flatten the grid in the SAME row-major order as the meshgrid, so that any
# per-point values reshaped back to a 2-D grid line up with (OMEGA1, OMEGA2).
# Note: looping over omega1 in the outer loop would produce the transpose of
# this layout and the contour/surface plots would show the axes swapped.
list_omegas = np.column_stack((OMEGA1.ravel(), OMEGA2.ravel()))

### utils

In [6]:
'''
utils (global)
'''    

def extended_omegas(omegas,augmented=False):
    """Build the feature matrix for a batch of samples, with a trailing bias column of ones.

    Without ``augmented`` the features are the raw coordinates. With
    ``augmented`` the linear coordinates are followed by all pairwise products
    omega_i * omega_{i+k}, grouped by offset k = 0..d-1 (the successive
    diagonals of the outer product). That gives d + d(d+1)/2 features before
    the bias.
    """
    bias = np.ones((len(omegas),1))
    if not augmented:
        return np.hstack((omegas.copy(),bias))

    n_samples, dim = omegas.shape
    # index pairs (i, i+k) ordered by offset k, then by i, matching np.diagonal(outer, k)
    left = [i for k in range(dim) for i in range(dim-k)]
    right = [i+k for k in range(dim) for i in range(dim-k)]
    embedding = np.zeros((n_samples,int(dim+dim*(dim+1)/2)))
    embedding[:, :dim] = omegas
    embedding[:, dim:] = omegas[:, left] * omegas[:, right]
    return np.hstack((embedding,bias))

def rescale(data):
    """Standardize each column of a 2-D array to zero mean and unit std.

    Columns with zero standard deviation are left unchanged. Returns
    ``(scaled_copy, means, stds)``, where ``means``/``stds`` are lists with one
    entry per column. The input is not modified.
    """
    n_rows, n_cols = data.shape
    scaled = data.copy()
    means = [np.mean(data[:, j]) for j in range(n_cols)]
    stds = [np.std(data[:, j]) for j in range(n_cols)]
    for j, (mu, sigma) in enumerate(zip(means, stds)):
        if sigma > 0:
            scaled[:, j] = (data[:, j]-mu)/sigma
    return scaled,means,stds

def coordinates(sigma_,B1_=B1,B2_=B2):
    """Split a flat cell index into its (mat1 row, mat2 row) pair.

    The flat index is ``sigma_ = e1*B2_ + e2`` with ``0 <= e2 < B2_``, the same
    layout ``selector`` produces. Valid inputs are ``0 .. B1_*B2_-1``.
    Returns ``(e1, e2)`` as Python ints.
    """
    assert sigma_>=0 and sigma_<=B1_*B2_-1, 'value range error (0 -> H-1)'
    # exact integer split; works for every B2_ >= 1 (including B2_ == 1)
    e1, e2 = divmod(sigma_, B2_)
    return (int(e1),int(e2))

def selector(coordinates,B1_=B1,B2_=B2):
    """Flatten an ``(e1, e2)`` pair into the cell index ``e1*B2_ + e2``.

    ``B1_`` is accepted for signature symmetry with the other helpers but is
    not used.
    """
    row1, row2 = coordinates[0], coordinates[1]
    return row2 + row1*B2_

def recast(x,w_ext,B1_=B1,B2_=B2):
    """Split a flat parameter vector into the two weight matrices.

    ``x`` must have ``(B1_+B2_)*w_ext`` entries. It is read row-major as a
    ``(B1_+B2_, w_ext)`` matrix. The first ``B1_`` rows form ``mat1`` and the
    remaining ``B2_`` rows form ``mat2``. ``unfold`` performs the reverse.
    """
    n_rows = B1_+B2_
    assert len(x)==n_rows*w_ext, 'dimension error'
    stacked = x.reshape((n_rows,w_ext))
    upper, lower = stacked[:B1_], stacked[B1_:]
    return upper, lower

def mix(mat1,mat2):
    """Return every pairwise row sum ``mat1[i] + mat2[j]``, stacked with i outer and j inner.

    Row ``i*len(mat2) + j`` of the result is ``mat1[i] + mat2[j]``.
    """
    assert mat1.shape[1]==mat2.shape[1],'dimension error'
    return np.array([list(row1+row2) for row1 in mat1 for row2 in mat2])

def unfold(mat1,mat2):
    """Stack ``mat1`` on top of ``mat2`` and flatten row-major into a new 1-D vector (reverse of ``recast``)."""
    stacked = np.vstack([mat1, mat2])
    return stacked.reshape(-1)

In [7]:
# Build the feature map for the samples (ext_omegas) and the plotting grid
# (ext_list_omegas). Every variant ends with the bias column of ones that
# extended_omegas appends last.
if mode=='regular':
    ext_omegas = extended_omegas(omegas)
    ext_list_omegas = extended_omegas(list_omegas)
elif mode in ('full','truncated'):
    ext_omegas_full = extended_omegas(omegas,True)
    ext_list_omegas_full = extended_omegas(list_omegas,True)
    if mode=='full':
        ext_omegas = ext_omegas_full.copy()
        ext_list_omegas = ext_list_omegas_full.copy()
    else:
        # keep a random 2/3 of the non-bias columns and always keep the bias (last) column
        max_dim = ext_omegas_full.shape[1]-1
        select = np.concatenate((np.random.choice(np.arange(max_dim),replace=False,size=int(max_dim*2/3)),[max_dim]))
        ext_omegas = ext_omegas_full[:,select]
        ext_list_omegas = ext_list_omegas_full[:,select]
else:
    # fail loudly rather than silently treating a typo as 'truncated'
    raise ValueError("unknown mode "+repr(mode)+"; expected 'regular', 'full' or 'truncated'")

# Per-feature statistics: exactly one entry per COLUMN of ext_omegas.
n_features = ext_omegas.shape[1]
if stand:
    polished_omegas,means,stds = rescale(ext_omegas)
else:
    # zero stds make the loop below pass every grid column through unchanged
    polished_omegas,means,stds = ext_omegas.copy(),np.zeros(n_features),np.zeros(n_features)

# Apply the sample statistics to the grid features (constant columns are left as-is).
polished_list_omegas = []
for idcol,col in enumerate(ext_list_omegas.T):
    if stds[idcol]>0:
        polished_list_omegas.append((col-means[idcol])/stds[idcol])
    else:
        polished_list_omegas.append(col)
polished_list_omegas = np.array(polished_list_omegas).T

# width of one feature vector (including the bias column)
w_ext = polished_omegas.shape[1]

In [8]:
# random training subset: at most 1500 samples, drawn without replacement
N_train = min(len(omegas),1500)
id_train = np.random.choice(np.arange(len(omegas)),replace=False,size=N_train)

omegas_train, polished_omegas_train, taus_train = (
    arr[id_train] for arr in (omegas, polished_omegas, taus)
)

In [9]:
from sklearn.cluster import KMeans
# One k-means clustering per max term: labels of kmeans1 pick a mat1 row and
# labels of kmeans2 pick a mat2 row for each training sample (see init_partition).
# NOTE(review): both runs use the same data and, with B1 == B2, the same cluster
# count, so they tend to give the same partition up to relabelling.
kmeans1, kmeans2 = (KMeans(n_clusters=int(n_clusters),n_init=20).fit(omegas_train)
                    for n_clusters in (B1, B2))

In [10]:
'''
functions
'''

def loss(x,target,data,B1_=B1,B2_=B2):
    """Mean absolute error of the difference-of-max-affine model.

    ``x`` is the flat parameter vector (see ``recast``). The prediction for
    column sample ``d`` of ``data`` is ``max(mat1 @ d) - max(mat2 @ d)``.
    """
    upper, lower = recast(x,data.shape[1],B1_,B2_)
    preds = (upper@data.T).max(axis=0) - (lower@data.T).max(axis=0)
    residuals = np.abs(target-preds)
    return np.mean(residuals)

def h_vals(x,target,data,B1_=B1,B2_=B2):
    """Per-sample, per-cell values ``h_bar[s] - (mat1[e1] + mat2[e2]) . data[s]``.

    The result has one row per sample and one column per (e1, e2) pair, in the
    order produced by ``mix``. ``h_bar`` is
    ``max(t + m2, m1) + max(m1 - t, m2)``, where ``m1``/``m2`` are the row-wise
    maxima of ``mat1 @ d`` and ``mat2 @ d``.
    """
    upper, lower = recast(x,data.shape[1],B1_,B2_)
    max_up = (upper@data.T).max(axis=0)
    max_low = (lower@data.T).max(axis=0)
    h_bar = np.maximum(target+max_low,max_up)+np.maximum(-target+max_up,max_low)
    pair_rows = mix(upper,lower)
    # broadcasting h_bar down the columns reproduces the outer product with a ones vector
    return h_bar[:, None]-(pair_rows@data.T).T

In [11]:
"""
recovering of init affectation
"""
# Initial assignment: pair each training sample's kmeans1 label with its kmeans2 label.
init_partition = list(zip(kmeans1.labels_,kmeans2.labels_))

In [12]:
init_partition

[(0, 0),
 (3, 3),
 (1, 1),
 (4, 2),
 (3, 3),
 (4, 2),
 (2, 4),
 (3, 3),
 (2, 4),
 (0, 0),
 (1, 1),
 (4, 2),
 (1, 1),
 (0, 0),
 (1, 1),
 (3, 3),
 (4, 2),
 (4, 2),
 (0, 0),
 (2, 4),
 (3, 3),
 (4, 2),
 (0, 0),
 (0, 0),
 (2, 4),
 (2, 4),
 (4, 2),
 (3, 3),
 (2, 4),
 (0, 0),
 (2, 4),
 (4, 2),
 (0, 0),
 (0, 0),
 (1, 1),
 (3, 3),
 (0, 0),
 (3, 3),
 (4, 2),
 (3, 3),
 (0, 0),
 (2, 4),
 (1, 1),
 (2, 4),
 (4, 2),
 (3, 3),
 (0, 0),
 (3, 3),
 (0, 0),
 (4, 2),
 (1, 1),
 (0, 0),
 (3, 3),
 (3, 3),
 (3, 3),
 (1, 1),
 (1, 1),
 (0, 0),
 (2, 4),
 (0, 0),
 (4, 2),
 (0, 0),
 (4, 2),
 (3, 3),
 (1, 1),
 (2, 4),
 (3, 3),
 (2, 4),
 (4, 2),
 (0, 0),
 (2, 4),
 (1, 1),
 (3, 3),
 (1, 1),
 (3, 3),
 (0, 0),
 (0, 0),
 (2, 4),
 (2, 4),
 (4, 2),
 (4, 2),
 (0, 0),
 (1, 1),
 (2, 4),
 (2, 4),
 (3, 3),
 (1, 1),
 (1, 1),
 (1, 1),
 (0, 0),
 (0, 0),
 (0, 0),
 (1, 1),
 (1, 1),
 (1, 1),
 (1, 1),
 (3, 3),
 (0, 0),
 (0, 0),
 (2, 4),
 (3, 3),
 (1, 1),
 (4, 2),
 (1, 1),
 (3, 3),
 (2, 4),
 (2, 4),
 (4, 2),
 (0, 0),
 (3, 3),
 (2, 4),
 

In [13]:
def part2weights(partition,B1_=B1,B2_=B2):
    """Convert a hard partition into one-hot weights over the ``B1_*B2_`` cells.

    ``partition`` is a sequence of ``(e1, e2)`` pairs, one per sample. Row ``s``
    of the returned matrix holds a 1 in column ``selector((e1, e2))`` and 0
    elsewhere. The matrix is wrapped in a one-element list.
    """
    N_loc = len(partition)
    buf = np.zeros((N_loc,B1_*B2_))
    for idtup,tup in enumerate(partition):
        # forward the grid sizes so the column index matches the buffer width
        buf[idtup,selector(tup,B1_,B2_)] += 1
    return [buf]

### package instance

In [14]:
# param
# minimum gap enforced between consecutive row sums in the ordering constraints below
margin = .0

# variables
# rows of mat1 / mat2 are the affine pieces of the two max terms (last column multiplies the bias feature)
mat1_cvx = cp.Variable((B1,w_ext))
mat2_cvx = cp.Variable((B2,w_ext))

# constraints
# inf-norm bound L on every row, excluding the last (bias) entry
list_cstr_cvx = [cp.norm(mat1_cvx[e1][:-1],'inf')<=L for e1 in range(B1)] + [cp.norm(mat2_cvx[e2][:-1],'inf')<=L for e2 in range(B2)]

# prior-knowledge encoding (-> symmetry breaking)
# rows are ordered by increasing sum, which removes the row-permutation symmetry
if B1>1:
    for e1 in range(B1-1):
        list_cstr_cvx += [cp.sum(mat1_cvx[e1])+margin<=cp.sum(mat1_cvx[e1+1])]
if B2>1:
    for e2 in range(B2-1):
        list_cstr_cvx += [cp.sum(mat2_cvx[e2])+margin<=cp.sum(mat2_cvx[e2+1])]

# objective function implementation
data = polished_omegas_train.copy()
target = taus_train.copy()
N_t = len(data)

# m1, m2: per-sample maxima of the two families of affine pieces
val1_cvx,val2_cvx = mat1_cvx@data.T,mat2_cvx@data.T
mval1_cvx,mval2_cvx = cp.max(val1_cvx,0),cp.max(val2_cvx,0)
# algebraic identity: |t - (m1 - m2)| = max(t+m2, m1) + max(-t+m1, m2) - (m1 + m2);
# h_bar is the convex first part, averaged into main_term
h_bar_cvx = cp.maximum(target+mval2_cvx,mval1_cvx)+cp.maximum(-target+mval1_cvx,mval2_cvx)
main_term = 1/N_t * cp.sum(h_bar_cvx)
# -(m1 + m2) = min over (e1, e2) of -(mat1[e1] + mat2[e2]) . x: one linear expression per cell
list_expr = []
for l in range(B1*B2):
    e1_sel,e2_sel = coordinates(l)
    list_expr.append(-1/N_t * data@(mat1_cvx[e1_sel]+mat2_cvx[e2_sel]))

# presumably smc.MinExpr takes the per-sample min across list_expr (toolbox internals not shown here)
objective_smc = smc.SumMinExpr(list_min_exprs=[smc.MinExpr(list_expr)],main_fun=main_term)

prob_smc = smc.Problem(objective_smc,list_cstr_cvx)

In [15]:
'''NEW -> parametric speed-up'''

# cvxpy Parameters holding the aggregated linear-term coefficients for mat1 / mat2;
# presumably smc.Problem fills them from w2p's output (see custom_param_expand) — confirm in the toolbox
param_mat1_cvx = cp.Parameter((B1,w_ext))
param_mat2_cvx = cp.Parameter((B2,w_ext))

def w2p(weights):
    """Fold per-cell sample weights into coefficient matrices for mat1 and mat2.

    Only ``weights[0]`` is used, an ``(N_t, B1*B2)`` matrix. For each cell ``l``
    with coordinates ``(e1, e2)``, the vector ``weights[0][:, l] @ data`` is
    added to row ``e1`` of the first result and to row ``e2`` of the second.
    Returns ``[param1, param2]`` with shapes ``(B1, w_ext)`` and ``(B2, w_ext)``.
    Relies on the notebook globals ``data``, ``B1``, ``B2`` and ``w_ext``.
    """
    weight = weights[0].copy()
    param1, param2 = np.zeros((B1,w_ext)), np.zeros((B2,w_ext))
    n_cells = B1*B2
    for cell in range(n_cells):
        row1, row2 = coordinates(cell)
        contribution = weight[:,cell]@data
        param1[row1] += contribution
        param2[row2] += contribution
    return [param1,param2]

# Parametric surrogate: main_term minus the weighted linear terms, with the
# per-cell weights pre-aggregated into the parameter matrices (see w2p).
objective_param_speed_smc = main_term - (1/N_t)*(
    cp.sum(cp.multiply(param_mat1_cvx, mat1_cvx))
    + cp.sum(cp.multiply(param_mat2_cvx, mat2_cvx))
)

In [16]:
'''NEW -> parametric sped-up smc.Problem'''
# custom_param_expand = [parametric objective, its cvxpy Parameters, weights -> parameter-values map];
# presumably lets the solver update parameters instead of rebuilding the objective each iteration (toolbox not shown)
prob_smc_speed = smc.Problem(objective_smc,list_cstr_cvx,custom_param_expand=[objective_param_speed_smc,[param_mat1_cvx,param_mat2_cvx],w2p])

In [17]:
# warm-start weights derived from the k-means initial partition
wsw = part2weights(partition=init_partition)

In [18]:
'''
# way too slow ! ;) 
tic = time.time()
prob_smc.solve(method='am',maxIters=int(100),verb_=True,extra_verb_=False,tol=1e-9)
toc = time.time()
print(' ')
print('solved in '+str(np.round(toc-tic,4))+' [s]')
'''

"\n# way too slow ! ;) \ntic = time.time()\nprob_smc.solve(method='am',maxIters=int(100),verb_=True,extra_verb_=False,tol=1e-9)\ntoc = time.time()\nprint(' ')\nprint('solved in '+str(np.round(toc-tic,4))+' [s]')\n"

In [19]:
# random initial weights from the SMC toolbox; presumably same list-of-matrix format as part2weights — confirm
rinit = prob_smc_speed.weights_setup('random')

In [20]:
# timed 'vandessel' solve from the random initial weights
tic = time.time()
prob_smc_speed.solve(
    method='vandessel',
    min_decr=1e-5,
    maxIters=150,
    verb_=True,
    extra_verb_=False,
    tol=1e-9,
    warm_start_weights=rinit.copy(),
)
toc = time.time()
print(' ')
print('solved in ' + str(np.round(toc - tic, 4)) + ' [s]')

iter. 0001 | Fval. 9.4377e+00 | BICval. 9.4377e+00 | decr.  inf | temp. -1.0000e+00
iter. 0002 | Fval. 1.5086e+00 | BICval. 2.2170e+00 | decr. 7.2207e+00 | temp. -inf
iter. 0003 | Fval. 9.6329e-01 | BICval. 1.1502e+00 | decr. 1.0669e+00 | temp. -3.2768e+04
iter. 0004 | Fval. 8.3896e-01 | BICval. 8.8049e-01 | decr. 2.6969e-01 | temp. -2.1845e+04
iter. 0005 | Fval. 7.8423e-01 | BICval. 8.1857e-01 | decr. 6.1916e-02 | temp. -1.4564e+04
iter. 0006 | Fval. 7.4093e-01 | BICval. 7.6924e-01 | decr. 4.9331e-02 | temp. -9.7090e+03
iter. 0007 | Fval. 7.1755e-01 | BICval. 7.4582e-01 | decr. 2.3420e-02 | temp. -6.4727e+03
iter. 0008 | Fval. 7.0797e-01 | BICval. 7.2275e-01 | decr. 2.3067e-02 | temp. -8.6303e+03
iter. 0009 | Fval. 7.0386e-01 | BICval. 7.1219e-01 | decr. 1.0560e-02 | temp. -1.1507e+04
iter. 0010 | Fval. 7.0028e-01 | BICval. 7.0581e-01 | decr. 6.3801e-03 | temp. -1.5343e+04
iter. 0011 | Fval. 6.9732e-01 | BICval. 7.0116e-01 | decr. 4.6576e-03 | temp. -2.0457e+04
iter. 0012 | Fval. 6.94

In [21]:
# timed 'vandessel' solve warm-started from the k-means partition weights
tic = time.time()
prob_smc_speed.solve(
    method='vandessel',
    min_decr=1e-5,
    maxIters=150,
    verb_=True,
    extra_verb_=False,
    tol=1e-9,
    warm_start_weights=wsw.copy(),
)
toc = time.time()
print(' ')
print('solved in ' + str(np.round(toc - tic, 4)) + ' [s]')

iter. 0001 | Fval. 1.2713e+00 | BICval. 1.4440e+00 | decr.  inf | temp. -1.0000e+00
iter. 0002 | Fval. 1.1846e+00 | BICval. 1.2972e+00 | decr. 1.4681e-01 | temp. -1.0240e+03
iter. 0003 | Fval. 1.0998e+00 | BICval. 1.1893e+00 | decr. 1.0792e-01 | temp. -1.3653e+03
iter. 0004 | Fval. 1.0500e+00 | BICval. 1.1880e+00 | decr. 1.3236e-03 | temp. -9.1022e+02
iter. 0005 | Fval. 1.0358e+00 | BICval. 1.1286e+00 | decr. 5.9385e-02 | temp. -1.2136e+03
iter. 0006 | Fval. 1.0319e+00 | BICval. 1.0948e+00 | decr. 3.3831e-02 | temp. -1.6182e+03
iter. 0007 | Fval. 1.0275e+00 | BICval. 1.0701e+00 | decr. 2.4615e-02 | temp. -2.1576e+03
iter. 0008 | Fval. 1.0232e+00 | BICval. 1.0523e+00 | decr. 1.7826e-02 | temp. -2.8768e+03
iter. 0009 | Fval. 1.0198e+00 | BICval. 1.0395e+00 | decr. 1.2818e-02 | temp. -3.8357e+03
iter. 0010 | Fval. 1.0166e+00 | BICval. 1.0299e+00 | decr. 9.5880e-03 | temp. -5.1142e+03
iter. 0011 | Fval. 1.0108e+00 | BICval. 1.0209e+00 | decr. 8.9708e-03 | temp. -6.8190e+03
iter. 0012 | Fva

In [22]:
# snapshot the variable values left by the most recent solve
mat1_vds, mat2_vds = (var.value for var in (mat1_cvx, mat2_cvx))

In [23]:
# timed 'am' solve from the random initial weights
tic = time.time()
prob_smc_speed.solve(
    method='am',
    maxIters=100,
    verb_=True,
    extra_verb_=False,
    tol=1e-9,
    warm_start_weights=rinit.copy(),
)
toc = time.time()
print(' ')
print('solved in ' + str(np.round(toc - tic, 4)) + ' [s]')

iter. 0001 | Fval. 9.4377e+00 | BICval. 9.4377e+00 | 
iter. 0002 | Fval. 1.5086e+00 | BICval. 2.2170e+00 | 
iter. 0003 | Fval. 9.7093e-01 | BICval. 1.1454e+00 | 
iter. 0004 | Fval. 8.4496e-01 | BICval. 8.8385e-01 | 
iter. 0005 | Fval. 7.8641e-01 | BICval. 8.1668e-01 | 
iter. 0006 | Fval. 7.4019e-01 | BICval. 7.5624e-01 | 
iter. 0007 | Fval. 7.1762e-01 | BICval. 7.2526e-01 | 
iter. 0008 | Fval. 7.0968e-01 | BICval. 7.1287e-01 | 
iter. 0009 | Fval. 7.0546e-01 | BICval. 7.0699e-01 | 
iter. 0010 | Fval. 7.0279e-01 | BICval. 7.0410e-01 | 
iter. 0011 | Fval. 7.0089e-01 | BICval. 7.0187e-01 | 
iter. 0012 | Fval. 6.9965e-01 | BICval. 7.0007e-01 | 
iter. 0013 | Fval. 6.9782e-01 | BICval. 6.9874e-01 | 
iter. 0014 | Fval. 6.9444e-01 | BICval. 6.9616e-01 | 
iter. 0015 | Fval. 6.9079e-01 | BICval. 6.9275e-01 | 
iter. 0016 | Fval. 6.8819e-01 | BICval. 6.8953e-01 | 
iter. 0017 | Fval. 6.8755e-01 | BICval. 6.8766e-01 | 
iter. 0018 | Fval. 6.8747e-01 | BICval. 6.8747e-01 | 
iter. 0019 | Fval. 6.8746e-0

In [24]:
# timed 'am' solve warm-started from the k-means partition weights
tic = time.time()
prob_smc_speed.solve(
    method='am',
    maxIters=100,
    verb_=True,
    extra_verb_=False,
    tol=1e-9,
    warm_start_weights=wsw.copy(),
)
toc = time.time()
print(' ')
print('solved in ' + str(np.round(toc - tic, 4)) + ' [s]')

iter. 0001 | Fval. 1.2713e+00 | BICval. 1.4440e+00 | 
iter. 0002 | Fval. 1.1730e+00 | BICval. 1.2351e+00 | 
iter. 0003 | Fval. 1.0879e+00 | BICval. 1.1249e+00 | 
iter. 0004 | Fval. 1.0303e+00 | BICval. 1.0529e+00 | 
iter. 0005 | Fval. 1.0052e+00 | BICval. 1.0163e+00 | 
iter. 0006 | Fval. 9.9415e-01 | BICval. 9.9492e-01 | 
iter. 0007 | Fval. 9.9205e-01 | BICval. 9.9229e-01 | 
iter. 0008 | Fval. 9.9204e-01 | BICval. 9.9204e-01 | 
iter. 0009 | Fval. 9.9203e-01 | BICval. 9.9203e-01 | 
iter. 0010 | Fval. 9.9203e-01 | BICval. 9.9203e-01 | 
-> terminated (stopping condition satisfied)
 
solved in 5.7329 [s]


In [25]:
# snapshot the variable values left by the most recent solve
mat1_am, mat2_am = (var.value for var in (mat1_cvx, mat2_cvx))

In [None]:
# timed 'boyd' solve from the random initial weights
tic = time.time()
# NOTE: the author flags a bug with the 'boyd' method on prob_smc_speed; in the
# logged run the objective stays flat for the first iterations.
prob_smc_speed.solve(
    method='boyd',
    maxIters=150,
    verb_=True,
    extra_verb_=False,
    warm_start_weights=rinit.copy(),
    tol=0,
)
toc = time.time()
print(' ')
print('solved in ' + str(np.round(toc - tic, 4)) + ' [s]')

iter. 0001 | Fval. 9.4377e+00 | BICval. 9.4377e+00 
iter. 0002 | Fval. 9.4377e+00 | BICval. 9.4377e+00 
iter. 0003 | Fval. 9.4377e+00 | BICval. 9.4377e+00 
iter. 0004 | Fval. 3.2602e+00 | BICval. 8.3556e+00 
iter. 0005 | Fval. 1.8382e+00 | BICval. 5.4447e+00 
iter. 0006 | Fval. 1.3503e+00 | BICval. 2.2979e+00 
iter. 0007 | Fval. 1.2275e+00 | BICval. 1.7815e+00 
iter. 0008 | Fval. 1.1407e+00 | BICval. 1.5358e+00 
iter. 0009 | Fval. 1.0940e+00 | BICval. 1.3799e+00 
iter. 0010 | Fval. 1.0416e+00 | BICval. 1.2679e+00 
iter. 0011 | Fval. 1.0053e+00 | BICval. 1.1759e+00 
iter. 0012 | Fval. 9.6552e-01 | BICval. 1.1018e+00 
iter. 0013 | Fval. 9.3028e-01 | BICval. 1.0504e+00 
iter. 0014 | Fval. 9.0102e-01 | BICval. 1.0065e+00 
iter. 0015 | Fval. 8.8725e-01 | BICval. 9.7545e-01 
iter. 0016 | Fval. 8.6908e-01 | BICval. 9.4835e-01 
iter. 0017 | Fval. 8.5591e-01 | BICval. 9.2496e-01 
iter. 0018 | Fval. 8.3825e-01 | BICval. 9.0400e-01 
iter. 0019 | Fval. 8.2866e-01 | BICval. 8.8268e-01 
iter. 0020 |

In [None]:
# timed 'boyd' solve warm-started from the k-means partition weights
tic = time.time()
# NOTE: the author flags a bug with the 'boyd' method on prob_smc_speed
prob_smc_speed.solve(
    method='boyd',
    maxIters=150,
    verb_=True,
    extra_verb_=False,
    warm_start_weights=wsw.copy(),
    tol=1e-9,
)
toc = time.time()
print(' ')
print('solved in ' + str(np.round(toc - tic, 4)) + ' [s]')

In [None]:
# snapshot the variable values left by the most recent solve
mat1_boyd, mat2_boyd = (var.value for var in (mat1_cvx, mat2_cvx))

In [None]:
loss(x=unfold(mat1_vds, mat2_vds), target=taus_train, data=polished_omegas_train)  # training MAE of the vandessel snapshot

In [None]:
'''
discrete-neighbourhood utils
'''
def active_span(list_of_vals,tol=5e-2):
    """For each value list, return the indices whose value is within ``tol`` of the minimum.

    Closeness is measured relative to the list's spread ``max - min``. The
    spread is floored at ``MIN_SLACK`` so that near-constant lists do not
    divide by ~0. ``tol`` must lie in ``[0, 1]``.
    """
    MIN_SLACK = 1e-2
    assert 0 <= tol <= 1, 'tol should be a float in [0,1]'

    def _near_minimum(raw_vals):
        vals = np.array(raw_vals)
        lowest, highest = min(vals), max(vals)
        spread = max(MIN_SLACK, highest-lowest)
        return list(np.where((vals-lowest)/spread <= tol)[0])

    return [_near_minimum(raw_vals) for raw_vals in list_of_vals]

'''
others...
'''

import random
import itertools


def random_selection(list_of_indices,num=int(1)):
    """Pick one random element from each index list, ``num`` times.

    With ``num == 1`` a single list of picks is returned. Otherwise a list of
    ``num`` such lists is returned. Uses the stdlib ``random`` module's global
    generator.
    """
    assert num>=1, 'num should be an integer'
    draws = int(num)

    def _draw_once():
        return [random.choice(indices) for indices in list_of_indices]

    if draws != 1:
        return [_draw_once() for _ in range(draws)]
    return _draw_once()
    
def reachable_size(list_of_indices):
    """Number of distinct selections: the product of the index-list lengths (1.0 for no lists)."""
    lengths = list(map(len, list_of_indices))
    return np.prod(lengths)
    
def possible_selections(list_of_indices,max_num=int(1e3)):
    """Enumerate selections from the Cartesian product of the index lists.

    Stops as soon as ``max_num`` selections have been collected. At least one
    selection is returned whenever the product is non-empty.
    """
    collected = []
    for count, combo in enumerate(itertools.product(*list_of_indices), start=1):
        collected.append(combo)
        if count >= max_num:
            break
    return collected

In [None]:
if w==2:

    from numpy import ma
    from matplotlib import cm, ticker
    from matplotlib.ticker import LinearLocator

    # Ground truth and fitted model evaluated on the plotting grid.
    tau_disp = gt(np.array(list_omegas))
    # Vectorised model prediction max(mat1 @ x) - max(mat2 @ x) for all grid points at once.
    tau_pred = np.max(polished_list_omegas@mat1_vds.T,axis=1)-np.max(polished_list_omegas@mat2_vds.T,axis=1)
    TAU = tau_disp.reshape((len(omega1),len(omega2)))
    TAU_pred = tau_pred.reshape((len(omega1),len(omega2)))

    # rcParams only affect figures created afterwards, so set the size first.
    plt.rcParams['figure.figsize'] = [7, 7]
    fig,ax = plt.subplots()
    plt.grid()
    plt.title('sampling centers for $w = $'+str(w))
    ax.set_aspect('equal', adjustable='box')
    plt.xlabel('$\\omega_1$')
    plt.ylabel('$\\omega_2$')
    cs = ax.contourf(OMEGA1, OMEGA2, TAU, cmap=cm.PuBu_r)
    # raw string: '\{', '\}' and '\i' are invalid escape sequences in a normal literal
    plt.scatter(omegas[:,0],omegas[:,1],s=taus-np.min(taus)+1,color='red',label=r'$\{\omega^{(s)}\}_{s\in[N]}$')
    centers1,centers2 = kmeans1.cluster_centers_,kmeans2.cluster_centers_
    plt.scatter(centers1[:,0],centers1[:,1],label='split - 1',color='orange')
    plt.scatter(centers2[:,0],centers2[:,1],label='split - 2',color='purple')
    cbar = fig.colorbar(cs)
    plt.legend()
    plt.show()

    # 3-D surface of the fitted model (figure size already set above).
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

    # Plot the surface.
    surf = ax.plot_surface(OMEGA1, OMEGA2, TAU_pred, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)

    plt.show()

In [None]:
if w==2:
    # import locally so this cell does not depend on the previous plotting cell for `cm`
    # (TAU itself is still computed in that cell)
    from matplotlib import cm

    # set the size before creating the figure; assigning rcParams afterwards does not resize it
    plt.rcParams['figure.figsize'] = [7, 7]
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

    # Plot the ground-truth surface.
    surf = ax.plot_surface(OMEGA1, OMEGA2, TAU, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)

    plt.show()

In [None]:
# flat parameter vector of the vandessel snapshot, then a readable dump of both matrices
x_vds = unfold(mat1_vds,mat2_vds)
for item in (' mat 1', ' ', mat1_vds, ' ', ' --- ', ' ', ' mat 2 ', ' ', mat2_vds):
    print(item)

In [None]:
loss(x=x_vds, target=taus, data=polished_omegas)  # MAE over all N samples (training subset included)