In [1]:
import matplotlib
matplotlib.use('Agg')
import os
import datetime
import numpy as np
import dill as pickle
import random
import sys
# Seed both random number generators used in this notebook so that the random
# point selections are repeatable across runs.
np.random.seed(0)
random.seed(0)
# Timestamp of this run; used further down to name the figure output folder.
now = datetime.datetime.now().strftime("%B_%d_%Y_%H_%M_%S")
# Resolve the repository root through git. check_output waits for the command,
# closes the pipe, and raises CalledProcessError when git fails, instead of
# yielding an empty path that only breaks later in os.chdir.
import subprocess
workingdirectory = subprocess.check_output(
    ['git', 'rev-parse', '--show-toplevel'], universal_newlines=True).rstrip('\n')
# Make the project-local `codes` package importable and resolve paths from the repo root.
sys.path.append(workingdirectory)
os.chdir(workingdirectory)
#print(os.getcwd())
from codes.experimentclasses.SwissRoll49 import SwissRoll49
from codes.otherfunctions.multirun import get_coeffs_reps
from codes.otherfunctions.multirun import get_grads_reps_noshape_tangent
from codes.otherfunctions.multiplot import plot_betas_customcolors2
import copy
from codes.geometer.RiemannianManifold import RiemannianManifold
from codes.geometer.ShapeSpace import ShapeSpace
from codes.geometer.TangentBundle import TangentBundle

# ---- run configuration ----
# number of data points to simulate
n = 10 ** 5
# number of points to analyze with lasso
nsel = 25
# maximum iterations per lasso run
itermax = 1000
# convergence criterion for lasso
tol = 1e-10

# Lambda grids tried previously (kept for reference):
#lambdas = np.asarray([5,10,15,20,25,50,75,100], dtype = np.float16)
#lambdas = np.asarray([0,2.95339658e-06, 5.90679317e-06, 8.86018975e-06, 1.18135863e-05,
#       1.47669829e-05, 2.95339658e-05, 4.43009487e-05, 5.90679317e-05])
#lambdas = np.asarray([0,.0001,.001,.01,.1,1,10], dtype = np.float64)
# Current grid: 0 followed by ten log-spaced values from 1e-3 to 1e1.
# The grid is stored in half precision, so the values are rounded to float16.
lambdas = np.concatenate(([0], np.logspace(-3, 1, 10))).astype(np.float16)

# number of neighbors in megaman
n_neighbors = 1000
# number of embedding dimensions (diffusion maps)
n_components = 2

# Tuning notes for the (currently disabled) diffusion time, which controls the
# gaussian kernel radius per the gradients paper:
#   10000 points with dt = .17 was best so far;
#   100000 points with dt = .05 is better;
#   both observations were made with 50 neighbors.
#diffusion_time = 0.05
#diffusion_time = 0.1
dim = 2
cores = 16

/Users/samsonkoelle/Downloads/manigrad-100818/mani-samk-gradients


In [2]:
# Build the SwissRoll49 experiment and simulate n points; the generated object is stored as experiment.M.
experiment = SwissRoll49(xvar = 0.1,cores = cores, noise = False)
experiment.M = experiment.generate_data(n = n,theta = np.pi/4) #if noise == False then the noise parameters are overridden
experiment.q = n_components
# Scale passed as the `radius` argument to the tangent-space helpers defined below.
radius = .5
# Per-run figure directory, named by the start timestamp stored in `now`.
folder = workingdirectory + '/Figures/swissroll/' + now
def compute_nbr_wts_fast(A, sample, radius):
    """Find the neighbours of each sampled point and their weights.

    Parameters
    ----------
    A : array-like of shape (n_points, n_features)
        Data matrix; distances are measured between its rows.
    sample : sequence of int
        Row indices of ``A`` to process.
    radius : float
        Scale parameter. Squared Euclidean distances are divided by ``radius``
        and a row counts as a neighbour when the scaled value is below
        ``3 * radius``.

    Returns
    -------
    (Ps, nbrs) : tuple of two lists, one entry per element of ``sample``
        ``nbrs[k]`` is the array of neighbour row indices of ``sample[k]`` and
        ``Ps[k]`` holds the weights of exactly those rows.
    """
    # Use the data passed in by the caller instead of reaching for the
    # notebook-global `experiment`, so the function works on any data matrix.
    A = np.asarray(A)
    Ps = list()
    nbrs = list()
    for center in sample:
        dists_sq = np.linalg.norm(A - A[center], axis=1) ** 2
        dists_sc = dists_sq / radius
        nbr_idx = np.where(dists_sc < 3 * radius)[0]
        # NOTE(review): the scaled distances of the neighbours are zeroed, so every
        # neighbour gets the same unnormalised weight exp(0) = 1, and the
        # normalisation below runs over all rows (the returned weights need not
        # sum to 1). Behaviour is kept as-is; confirm this is the intended kernel.
        dists_sc[nbr_idx] = 0
        w = np.exp(-dists_sc)
        p = w / np.sum(w)
        nbrs.append(nbr_idx)
        Ps.append(p[nbr_idx])
    return (Ps, nbrs)

def get_wlpca_tangent_sel_fast(M, selectedpoints,  radius, dim = None):
    """Estimate a tangent basis at each selected point by weighted local PCA.

    Parameters
    ----------
    M : object with a ``data`` attribute
        ``M.data`` is indexed as a 2-D array of shape (n_points, d).
    selectedpoints : sequence of int
        Row indices of ``M.data`` at which to estimate a basis.
    radius : float
        Neighbourhood scale, forwarded to ``compute_nbr_wts_fast``.
    dim : int, optional
        Number of basis vectors to keep per point. ``None`` keeps all ``d``
        eigenvectors (previously ``None`` raised inside ``np.zeros``).

    Returns
    -------
    tangent_bases : ndarray of shape (nsel, d, dim)
        For each selected point, the eigenvectors of ``Z.T @ Z`` ordered by
        decreasing eigenvalue, truncated to the first ``dim`` columns.
    """
    data = M.data
    nsel = len(selectedpoints)
    d = data.shape[1]
    if dim is None:
        dim = d
    (PS, nbrs) = compute_nbr_wts_fast(data, selectedpoints, radius)
    print(nsel,d,dim)
    tangent_bases = np.zeros((nsel, d, dim))
    for i in range(nsel):
        p = PS[i]
        nbr = nbrs[i]
        # Centre the neighbours on their p-weighted combination, then scale each row by its weight.
        Z = (data[nbr, :] - np.dot(p, data[nbr, :])) * p[:, np.newaxis]
        sig = np.dot(Z.transpose(), Z)
        e_vals, e_vecs = np.linalg.eigh(sig)
        # eigh returns eigenvalues in ascending order; reverse to get largest first.
        order = e_vals.argsort()[::-1]
        tangent_bases[i, :, :] = e_vecs[:, order][:, :dim]
    return (tangent_bases)
def get_grads_reps_noshape_tangent_fast(experiment, nreps, nsel, cores,radius, dim):
    """Create ``nreps`` replicates of ``experiment`` with fresh random point selections.

    For every replicate, ``nsel`` distinct indices out of ``experiment.n`` are
    drawn, tangent bases are estimated at those points, and the attributes
    ``selected_points``, ``df_M``, ``dg_x``, ``dg_x_norm`` and ``dg_M`` are set
    on the replicate.

    Parameters
    ----------
    experiment : experiment object
        Template; must provide ``M`` (with ``data``), ``n``, ``dim`` and the
        methods ``get_dx_g_full``, ``normalize`` and ``project``.
    nreps : int
        Number of replicates.
    nsel : int
        Number of points selected per replicate.
    cores : int
        Not used by this function; kept for signature compatibility.
    radius : float
        Neighbourhood scale forwarded to ``get_wlpca_tangent_sel_fast``.
    dim : int
        Kept for signature compatibility. As before, the value actually used is
        ``experiment.dim`` (which ``get_coeffs`` also relies on).

    Returns
    -------
    dict
        Maps replicate index (0 .. nreps-1) to the replicate object.
    """
    dim = experiment.dim
    experiments = {}
    for rep in range(nreps):
        replicate = copy.copy(experiment)
        # copy.copy is shallow, so without this every replicate (and the template)
        # would share one M object and each draw would overwrite the previous
        # replicate's M.selected_points. The copy of M is shallow too, so the
        # underlying data array is still shared, not duplicated.
        replicate.M = copy.copy(experiment.M)
        # Passing the population size draws the same indices as passing list(range(n)).
        selected = np.random.choice(experiment.n, nsel, replace=False)
        replicate.M.selected_points = selected
        replicate.selected_points = selected
        points = replicate.M.data[selected]
        tangent_bases = get_wlpca_tangent_sel_fast(replicate.M, selected, radius, dim)
        # NOTE(review): subM is neither stored nor returned; it is kept only in case
        # the constructors are relied on for validation. Confirm and remove.
        subM = RiemannianManifold(points, dim)
        subM.tb = TangentBundle(subM, tangent_bases)
        replicate.df_M = np.asarray([np.identity(dim) for _ in range(nsel)])
        replicate.dg_x = replicate.get_dx_g_full(points)
        replicate.dg_x_norm = replicate.normalize(replicate.dg_x)
        replicate.dg_M = replicate.project(tangent_bases, replicate.dg_x_norm)
        experiments[rep] = replicate
    return(experiments)

# Draw the analysis points for the template experiment. Passing the population
# size draws the same indices as passing list(range(n)) without building the list.
experiment.M.selected_points = np.random.choice(n, nsel, replace=False)
experiment.selected_points = experiment.M.selected_points
nreps = 5
print('pregrad',datetime.datetime.now().strftime("%B_%d_%Y_%H_%M_%S"))
import matplotlib.pyplot as plt
# makedirs also creates missing parent directories (e.g. Figures/swissroll) and
# does not fail when the folder already exists, unlike os.mkdir.
os.makedirs(folder, exist_ok=True)
# Build the replicates and compute their tangent-projected gradients.
experiments = get_grads_reps_noshape_tangent_fast(experiment, nreps, nsel,cores, radius, dim)


def get_betas_spam2(xs, ys, groups, lambdas, n, q, itermax, tol):
    """Solve the group-lasso problem with ``spams.fistaFlat`` for each lambda.

    Parameters
    ----------
    xs : 2-D array
        Design matrix; converted to a Fortran-ordered float32 array.
    ys : 1-D array
        Response; a trailing axis is added and it is converted like ``xs``.
    groups : array of int
        Group id of each column of ``xs``. The ids are shifted by +1 before the
        call (presumably because spams reserves 0 -- TODO confirm).
    lambdas : sequence of float
        Regularisation strengths; one solve per value.
    n, q : int
        Together with the number of distinct groups ``p`` these give the shape
        ``(q, n, p)`` into which each solution vector is reshaped.
    itermax : int
        Passed to the solver as ``max_it``.
    tol : float
        Passed to the solver as ``tol``.

    Returns
    -------
    ndarray of shape (len(lambdas), q, n, p)
    """
    n_groups = len(np.unique(groups))
    lambda_grid = np.asarray(lambdas, dtype=np.float64)
    response = np.expand_dims(ys, 1)
    group_ids = np.asarray(groups, dtype=np.int32) + 1

    # Solver inputs: float32, Fortran order, zero starting point.
    design_f = np.asfortranarray(xs, dtype=np.float32)
    response_f = np.asfortranarray(response, dtype=np.float32)
    start = np.zeros((xs.shape[1], response.shape[1]), dtype=np.float32)

    # Options that are identical for every lambda.
    solver_opts = dict(
        groups=group_ids, numThreads=-1, verbose=True,
        it0=100, max_it=itermax, L0=0.5, tol=tol, intercept=False,
        pos=False, compute_gram=True, loss='square', regul='group-lasso-l2',
        ista=False, subgrad=False, a=0.1, b=1000,
    )

    coeffs = np.zeros((len(lambda_grid), q, n, n_groups))
    for idx, lam in enumerate(lambda_grid):
        result = spams.fistaFlat(response_f, design_f, start, True, lambda1=lam, **solver_opts)
        # The first element of the result holds the weights; the rest is not used here.
        coeffs[idx, :, :, :] = np.reshape(result[0], (q, n, n_groups))
    return (coeffs)

def get_coeffs(experiment, lambdas, itermax, nsel, tol):
    """Build the regression problem for one experiment and fit it for every lambda.

    Sets ``xtrain``, ``groups``, ``ytrain`` and ``coeffs`` on ``experiment`` and
    returns that same object.
    """
    # Design matrix and group labels from the tangent-projected gradients.
    design, group_ids = experiment.construct_X(experiment.dg_M)
    experiment.xtrain = design
    experiment.groups = group_ids

    # Response built from df_M over point positions 0 .. nsel-1.
    point_ids = list(range(nsel))
    experiment.ytrain = experiment.construct_Y(experiment.df_M, point_ids)

    experiment.coeffs = get_betas_spam2(
        xs=experiment.xtrain,
        ys=experiment.ytrain,
        groups=experiment.groups,
        lambdas=lambdas,
        n=nsel,
        q=experiment.dim,
        itermax=itermax,
        tol=tol,
    )
    return experiment


def get_coeffs_parallel(experiments, nreps, lambdas, itermax, nsel, tol, cores):
    """Run ``get_coeffs`` on replicates 0 .. nreps-1 in a pool of ``cores`` workers.

    Parameters
    ----------
    experiments : mapping of int to experiment
        Replicates, indexed by 0 .. nreps-1.
    nreps : int
        Number of replicates to process.
    lambdas, itermax, nsel, tol
        Forwarded unchanged to ``get_coeffs``.
    cores : int
        Number of worker processes.

    Returns
    -------
    dict
        Maps each replicate index to the experiment object returned by the
        worker (the fitted copy that comes back from the worker process).
    """
    pool = Pool(cores)
    try:
        results = pool.map(
            lambda i: get_coeffs(experiment=experiments[i], lambdas=lambdas, itermax=itermax, nsel=nsel, tol=tol),
            range(nreps))
    finally:
        # Release the worker processes even when a worker raises. clear() also
        # drops the pool from the pathos cache so a later Pool(cores) starts fresh
        # instead of being handed this closed pool.
        pool.close()
        pool.join()
        pool.clear()
    return dict(enumerate(results))


from pathos.multiprocessing import ProcessingPool as Pool
import spams




# Name and subfolder of the pickle written below. An earlier assignment of
# 'swissroll_020220_highp01' was overwritten before it was ever used and is dropped.
savename = 'swissroll_020520_samosa_highp01'
savefolder = 'swissroll'

# Fit the group lasso on every replicate.
experiments = get_coeffs_parallel(experiments, nreps, lambdas, itermax, nsel, tol, cores)

# Persist the fitted replicates; make sure the target directory exists first.
embedding_dir = workingdirectory + '/untracked_data/embeddings/' + savefolder
os.makedirs(embedding_dir, exist_ok=True)
with open(embedding_dir + '/' + savename + 's.pkl', 'wb') as output:
    pickle.dump(experiments, output, pickle.HIGHEST_PROTOCOL)

# NOTE(review): the factor 2 inside the square root presumably equals `dim`; confirm.
xaxis = np.asarray(lambdas, dtype = np.float64) * np.sqrt(nsel * 2)
title ='Swiss Roll'
gnames = np.asarray(list(range(experiment.p)), dtype = str)
filename = folder + '/betas_highp1'
print('preplot',datetime.datetime.now().strftime("%B_%d_%Y_%H_%M_%S"))
# NOTE(review): 2 'red' + 49 'black' entries, presumably one colour per function in
# the order given by color_labels (manifold first, then ambient); confirm against experiment.p.
colors = np.hstack([np.repeat('red',2), np.repeat('black',49)])
legtitle = 'Function type'
color_labels = np.asarray(['Manifold coordinates','Ambient coordinates'])
plot_betas_customcolors2(experiments, xaxis, title,filename, gnames,nsel,colors, legtitle, color_labels)

pregrad February_07_2020_00_41_19
25 49 2
25 49 2
25 49 2
25 49 2
25 49 2
preplot February_07_2020_00_42_27
[[8.76537795 0.56263888 0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]
 [7.51974411 0.25164269 0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]
 [7.35013041 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]
 [8.2510707  0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]
 [8.35967352 0.55642025 0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]]
[[ 51.58526701 130.78036368 133.96728319 130.4658862  120.29814998
   80.77819061   0.           0.           0.           0.
    0.        ]
 [ 53.79755142 143.80611863 143.15789686 136.6310268  127.28842139
   85.74787818   0.           0.           0.           0.
    0.        ]
 [ 50.71329555 13

[[15.04294585 17.01481474 12.43580324  9.0905618   0.          0.
   0.          0.          0.          0.          0.        ]
 [11.51309042  2.0265239   0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [13.38835931 23.40505354 25.43096349 20.85454571  5.46313465  0.
   0.          0.          0.          0.          0.        ]
 [12.52847266 14.2995749   9.58186849  1.83142936  0.          0.
   0.          0.          0.          0.          0.        ]
 [13.44141733 12.52828867 12.31101424  8.71687686  0.          0.
   0.          0.          0.          0.          0.        ]]
[[11.04288189  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 9.13187226  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [13.22822716  0.08545035  0.          0.          0.          0.
   0.          0.          0. 

[[10.45252326  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 7.44583433  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 7.54980538  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 6.62764808  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 7.03481671  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]]
[[ 9.82911202  1.83225581  0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 9.688356    0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [11.85758863  0.          0.          0.          0.          0.
   0.          0.          0. 

[[ 9.50617242  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [10.34050883  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 8.53435146  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 7.93362581  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 8.27628494  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]]
[[17.66032296 18.37316789  5.0904224   0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [12.94158436  0.02508583  0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [17.87199134  0.82278412  0.          0.          0.          0.
   0.          0.          0. 

[[ 8.96952117 27.03201608 32.1874161  30.32656614 25.87256063 14.83883217
   0.          0.          0.          0.          0.        ]
 [ 7.76102893 28.08751163 29.22303982 25.99271889 18.52287552 12.03052612
   0.          0.          0.          0.          0.        ]
 [10.02724711 38.99835992 48.61725279 43.41476011 34.90247456 17.92237481
   0.          0.          0.          0.          0.        ]
 [ 8.50483563 32.02951388 37.37145492 34.44470478 26.09617487 16.28332082
   0.          0.          0.          0.          0.        ]
 [ 8.16615197 19.57738831 22.01697722 20.35512527 16.76826754 14.17218713
   0.          0.          0.          0.          0.        ]]
[[44.32929209 40.15571577 37.63189235 35.66844979 27.65640739  4.50464259
   0.          0.          0.          0.          0.        ]
 [22.59302468 25.056708   25.5433401  22.244236   11.30241843  0.
   0.          0.          0.          0.          0.        ]
 [40.14911016 40.08126549 38.70754355 35.8502825

[[17.58436328 22.50646776 15.3809207  11.97403327  0.          0.
   0.          0.          0.          0.          0.        ]
 [14.261535    4.2267824   0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [18.77102132 38.99167679 41.13031552 33.87998684  9.89930571  0.
   0.          0.          0.          0.          0.        ]
 [17.25378261 21.01526745 14.37902417  2.73554972  0.          0.
   0.          0.          0.          0.          0.        ]
 [20.33903862 18.34429782 17.33482094 12.74400867  0.          0.
   0.          0.          0.          0.          0.        ]]
[[ 7.78111231  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 8.50396127  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [14.51160582  0.12873832  0.          0.          0.          0.
   0.          0.          0. 

[[5.11611428 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]
 [6.61874792 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]
 [5.23492304 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]
 [8.41691087 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]
 [6.80263666 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.        ]]
[[12.09997887  4.752642    0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 4.82072883  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [11.21781937  0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [ 7.91609839  0.     

[[1.47129161e+01 1.32347896e+01 3.21141344e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [1.31425394e+01 1.26775578e-02 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [1.32583399e+01 4.82483951e-01 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [1.13857199e+01 1.06639256e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [1.35910437e+01 3.06517421e-01 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00]]
[[16.31457378  1.62955624  0.          0.          0.          0.
   0.          0.          0.          0.          0.        ]
 [19.27637