In [None]:
# See E. Peterfreund, O. Lindenbaum, F. Dietrich, T. Bertalan, M. Gavish, I.G. Kevrekidis and R.R. Coifman,
# "LOCA: LOcal Conformal Autoencoder for standardized data coordinates",
# https://www.pnas.org/doi/full/10.1073/pnas.2014627117
#
#
# -----------------------------------------------------------------------------
# Author: Erez Peterfreund , Ofir Lindenbaum
#         erezpeter@gmail.com  , ofir.lindenbaum@yale.edu , 2020
# 
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------

In [2]:
# import necessary packages 
import sys
import os
from utils import *
from Loca import *
from mpl_toolkits import mplot3d
import scipy

import matplotlib.pyplot as plt

# Generate data and divide into train and validation

In [3]:
import math, random
def fibonacci_sphere(amount_samples=1, minAngle=0, rng=None):
    """Place points on the unit sphere along a Fibonacci spiral.

    The heights z are equally spaced and every successive point is turned by
    the golden angle around the z-axis. The starting turn is drawn at random,
    so two calls generally return differently rotated spirals.

    Input:
          amount_samples - non-negative integer. The number of points to generate.
          minAngle - float, meant to lie in [0, 1]. NOTE: despite its name the
                     value is not used as an angle. It shrinks the spacing of
                     the heights by the factor (1 - minAngle), so the heights
                     stay inside [-1, 1 - 2*minAngle] and a cap around the
                     north pole is left empty. 0 covers the whole sphere.
          rng - optional object with a random() method (for example
                random.Random(seed)) used to draw the starting turn.
                Defaults to the global random module.

    Output:
          points - an amount_samples-by-3 array. Row i holds the Cartesian
                   coordinates (x, y, z) of the i-th point on the sphere.
    """
    points = np.zeros((amount_samples, 3))
    if amount_samples == 0:
        # Nothing to place; returning here also avoids the division by zero below.
        return points

    source = random if rng is None else rng
    rnd = source.random() * amount_samples

    offset = 2. / amount_samples * (1 - minAngle)   # spacing between consecutive heights
    increment = math.pi * (3. - math.sqrt(5.))      # golden angle, in radians

    for i in range(amount_samples):
        z = ((i * offset) - 1) + (offset / 2)
        r = math.sqrt(1 - pow(z, 2))                # radius of the horizontal circle at height z
        phi = ((i + rnd) % amount_samples) * increment
        points[i, :] = (math.cos(phi) * r, math.sin(phi) * r, z)

    return points

In [4]:
def generate_burst(datapoint, M, n_std):
    """Generate a burst (a small cloud) of M points on the unit sphere around a datapoint.

    A. White isotropic Gaussian noise is drawn in angular coordinates around
       (azimuth, polar) = (0, pi/2), i.e. around the Cartesian point (1, 0, 0),
       and converted to Cartesian coordinates with fixed radius 1.
    B. The cloud is transformed by an orthogonal matrix that sends (1, 0, 0) to
       datapoint / ||datapoint||, so that the burst is centered at the datapoint.

    Input:
              datapoint - non-zero vector with 3 entries. The point around which
                          the burst is generated.
              M - integer. The number of points in the burst.
              n_std - float. The standard deviation of the angular noise.

    Output:
              burst - M-by-3 array. The Cartesian coordinates of the burst.
              burst_angle - M-by-2 array. Column 0 is the azimuth arctan2(y, x)
                            and column 1 is the polar angle arccos(z) of each
                            burst point.
    """
    # For a single-row matrix the SVD reads datapoint = U * D * Vh[0, :], hence
    # the first row of U * Vh is the unit vector along datapoint. Multiplying a
    # row vector by U * Vh therefore sends (1, 0, 0) exactly to that direction,
    # whatever orthonormal completion the SVD routine picks for the other rows.
    # (Using the transpose of Vh here is only correct when Vh is symmetric.)
    U, D, Vh = np.linalg.svd(datapoint.reshape((1, -1)), full_matrices=True)
    rotationMat = U * Vh

    angleNoise = np.random.randn(M, 2) * n_std
    angleNoise[:, 1] += np.pi / 2

    burst = np.zeros((M, 3))
    burst_angle = np.zeros((M, 2))

    burst[:, 0] = np.cos(angleNoise[:, 0]) * np.sin(angleNoise[:, 1])
    burst[:, 1] = np.sin(angleNoise[:, 0]) * np.sin(angleNoise[:, 1])
    burst[:, 2] = np.cos(angleNoise[:, 1])
    burst = np.dot(burst, rotationMat)

    burst_angle[:, 0] = np.arctan2(burst[:, 1], burst[:, 0])
    # Round-off can push z marginally outside [-1, 1]; clip so arccos never returns NaN.
    burst_angle[:, 1] = np.arccos(np.clip(burst[:, 2], -1., 1.))
    return burst, burst_angle

def non_linear_function(data):
    """Apply the non-linear observation function to latent points.

    Every point (x0, x1, x2), read along the last axis, is mapped to
    (-2*x0 / (x2 - 1), -2*x1 / (x2 - 1)), a stereographic-type projection
    onto the plane. Points with x2 == 1 are divided by zero and produce
    inf/nan values.

    Input:
             data - array whose last axis holds at least 3 components, e.g. an
                    n-by-3 matrix of points or an n-by-m-by-3 tensor of bursts.
                    Any number of leading axes is accepted.

    Output:
             new_data - float array with the same leading axes as data and a
                        last axis of length 2.
    """
    new_data = np.zeros(data.shape[:-1] + (2,))
    # The ellipsis addresses the last axis regardless of the number of leading axes.
    new_data[..., :] = -2 * data[..., :2] / (data[..., 2:3] - 1)
    return new_data


# Experiment configuration
N = 800            # points requested on the full sphere; replaced below by the number of points kept
M = 400            # number of points in every burst
cloud_std = 0.01   # standard deviation of the angular noise of a burst


# Sample the sphere in the latent space (denoted by x) and split it by polar angle:
# points with polar angle in (pi/3, 5*pi/6) form the main set, points beyond 5*pi/6
# are kept aside (suffix _inter).
x = fibonacci_sphere(N, minAngle=0.)
_polar = np.arccos(x[:, 2])
x_inter = x[_polar > np.pi*5/6, :]
x = x[(_polar > np.pi*1/3) & (_polar < np.pi*5/6), :]
N, N_inter = x.shape[0], x_inter.shape[0]

# Angular coordinates (denoted by theta): column 0 is the azimuth, column 1 the polar angle
thetas = np.column_stack((np.arctan2(x[:, 1], x[:, 0]), np.arccos(x[:, 2])))
thetas_inter = np.column_stack((np.arctan2(x_inter[:, 1], x_inter[:, 0]),
                                np.arccos(x_inter[:, 2])))


# Define the bursts in the latent space: one burst of M points around every point
additionalDataX, additionalDataThetas = np.zeros((N, M, 3)), np.zeros((N, M, 2))
additionalDataX_inter, additionalDataThetas_inter = np.zeros((N_inter, M, 3)), np.zeros((N_inter, M, 2))

for i, center in enumerate(x):
    additionalDataX[i], additionalDataThetas[i] = generate_burst(center, M, cloud_std)

for i, center in enumerate(x_inter):
    additionalDataX_inter[i], additionalDataThetas_inter[i] = generate_burst(center, M, cloud_std)


# Define the non linear transformation of the points and of their bursts
y = non_linear_function(x)
additionalDataY = non_linear_function(additionalDataX)

y_inter = non_linear_function(x_inter)
additionalDataY_inter = non_linear_function(additionalDataX_inter)

In [5]:
%matplotlib notebook

# Quick look at the first 10 points of every burst: observed plane (left panel)
# and latent sphere (right panel), both colored by the log of the squared norm
# of the observed point.
_preview_y = additionalDataY[:, :10, :]
_preview_x = additionalDataX[:, :10, :]
_log_sq_norm = np.log(np.sum(_preview_y**2, axis=2))

fig = plt.figure()
ax = fig.add_subplot(121)
ax.scatter(_preview_y[:, :, 0], _preview_y[:, :, 1], c=_log_sq_norm)
ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(_preview_x[:, :, 0].reshape((-1, 1)),
            _preview_x[:, :, 1].reshape((-1, 1)),
            _preview_x[:, :, 2].reshape((-1, 1)),
            c=_log_sq_norm.reshape((-1)))


<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7f16e5491828>

In [6]:
# Randomly assign 90% of the bursts to the training set and the remainder to validation
indexes = np.random.permutation(N)
_n_train = N*9//10
indexes_train = indexes[:_n_train]
indexes_val = indexes[_n_train:]

data_train = additionalDataY[indexes_train]
data_val = additionalDataY[indexes_val]

# Define and train the neural net

In [10]:
amount_epochs = 10000

# Keyword arguments forwarded to the Loca constructor.
# Activation options: 'relu'- Relu,   'l_relu'- Leaky Relu,    'sigmoid'-sigmoid,   'tanh'- tanh, 'none'- none
params = {
    'clouds_var': cloud_std**2,              # variance of the bursts
    'activation_enc': 'tanh',                # The activation function defined in the encoder
    'activation_dec': 'tanh',                # The activation function defined in the decoder
    'encoder_layers': [2, 100, 100, 3, 3],   # The amount of neurons in each layer of the encoder
    'decoder_layers': [2, 100, 100, 2, 2],   # The amount of neurons in each layer of the decoder
}

model = Loca(**params)


In [11]:
# Shape of the training tensor: (number of training bursts, points per burst, observed dimension)
data_train.shape

(491, 400, 2)

In [None]:
batch_size=55
# Learning rates, used one after the other from the largest to the smallest
lrs= [1e-3,3e-4,1e-4,3e-5,1e-5]


# A new model is created here, so re-running this cell starts from the instance built by Loca(**params)
model = Loca(**params) 

# NOTE(review): x_dist and x_dist_inter are not used in this cell; both are computed again in the
# distance-comparison cell near the end of the notebook.
x_dist = scipy.spatial.distance.pdist(x)
x_dist_inter = scipy.spatial.distance.pdist(x_inter)
# train() is called once per learning rate on the same model object; amount_epochs is passed to
# every call (presumably the number of epochs per call - confirm against Loca.train).
for lr in lrs:
    print(lr)
    model.train( data_train,amount_epochs,lr=lr, batch_size=batch_size, data_val=data_val, evaluate_every=100,verbose=True)
    

0.001
Epoch: 0100 Train : white= 2.63895 rec=2.23882      Val: : white= 2.63438 rec=2.03211
99
Epoch: 0200 Train : white= 2.51675 rec=2.00643      Val: : white= 2.45218 rec=1.80882
199
Epoch: 0300 Train : white= 2.31947 rec=1.44029      Val: : white= 2.33012 rec=1.31225
299
Epoch: 0400 Train : white= 2.04769 rec=0.55908      Val: : white= 2.02441 rec=0.54105
399
Epoch: 0500 Train : white= 1.71824 rec=0.39295      Val: : white= 1.69516 rec=0.38051
499
Epoch: 0600 Train : white= 1.66393 rec=0.30861      Val: : white= 1.61887 rec=0.29707
599
Epoch: 0700 Train : white= 1.63652 rec=0.24232      Val: : white= 1.56252 rec=0.22904
699
Epoch: 0800 Train : white= 1.61854 rec=0.20168      Val: : white= 1.52964 rec=0.18822
799
Epoch: 0900 Train : white= 1.61255 rec=0.17898      Val: : white= 1.51818 rec=0.16597
899
Epoch: 1000 Train : white= 1.60652 rec=0.15227      Val: : white= 1.52511 rec=0.13724
999
Epoch: 1100 Train : white= 1.60553 rec=0.12758      Val: : white= 1.52514 rec=0.11367
1099
Epoc

In [None]:
# Run the trained model on all the bursts of the main set; the two returned values are
# presumably the embedding and the reconstruction (named after this assignment - confirm against Loca.test)
embedding, recon = model.test(additionalDataY)

# Fig 6b

In [None]:
def set_axes_radius(ax, data):
    """Give a 3-D axes the same half-width on x, y and z, centered on the data mean.

    Input:
          ax - an object exposing set_xlim3d, set_ylim3d and set_zlim3d
               (e.g. a Matplotlib 3-D axes).
          data - n-by-3 array of the plotted points.

    The limits of axis k are set to [mean_k - radius, mean_k + radius], where
    radius is the largest absolute deviation of any coordinate from its mean.
    """
    center = np.mean(data, axis=0, keepdims=True)
    half_width = np.max(np.abs(data - center))
    for axis_idx, setter_name in enumerate(('set_xlim3d', 'set_ylim3d', 'set_zlim3d')):
        middle = center[0, axis_idx]
        getattr(ax, setter_name)([middle - half_width, middle + half_width])


In [None]:
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
# Fig 6b: all the burst points on the latent sphere, colored by 1 - polar angle.
f= plt.figure(figsize=(10,10))
ax = f.add_subplot(111,projection='3d' )

plt.locator_params(nbins=4)


# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
ax.scatter(additionalDataX[:,:,0], additionalDataX[:,:,1],\
           additionalDataX[:,:,2], c= 1-additionalDataThetas[:,:,1].flatten(),s=30, \
           alpha=.8,cmap=plt.get_cmap('RdYlBu'),\
          vmin=np.min(1-additionalDataThetas[:,:,1]), vmax=np.max(1-additionalDataThetas[:,:,1]))

ax.set_xlabel('$x[1]$', labelpad=25, fontsize=35)
ax.set_ylabel('$x[2]$', labelpad=25, fontsize=35)
ax.set_zlabel('$x[3]$', labelpad=25,fontsize=35)

plt.tick_params(labelsize=25)
ax.set_aspect('equal')

# Same half-width on the three axes, computed from the first point of every burst
set_axes_radius(ax, additionalDataX[:,0,:])


# Transparent background panes
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False


ax.view_init(elev=15., azim=45)


plt.tight_layout() 


# Fig 6c

In [None]:
%matplotlib notebook

# Fig 6c: all the burst points in the observed plane, colored by 1 - polar angle of the latent point.
f,ax = plt.subplots(1,1,figsize=(10,10))
plt.locator_params(nbins=6)

plt.tick_params(labelsize=25)
plt.axis('equal')
# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
ax.scatter(additionalDataY[:,:,0],additionalDataY[:,:,1],\
           c= 1-additionalDataThetas[:,:,1],alpha=0.8,s=60,\
           cmap=plt.get_cmap('RdYlBu'),\
          vmin=np.min(1-additionalDataThetas[:,:,1]), vmax=np.max(1-additionalDataThetas[:,:,1]))
ax.set_xlabel('$y[1]$',fontsize=35)
ax.set_ylabel('$y[2]$',fontsize=35)

plt.tight_layout()


# Fig 6d

In [None]:
# Embed the bursts of the main set and of the held-out (_inter) set with the trained model;
# the second value returned by test() is discarded.
loca_embedding,_ = model.test(additionalDataY)
loca_embedding_inter,_ = model.test(additionalDataY_inter)


%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
# Fig 6d: calibrated LOCA embedding of the main bursts and of the held-out (_inter) bursts,
# colored by 1 - polar angle of the latent point.
f= plt.figure(figsize=(10,10))
ax = f.add_subplot(111,projection='3d' )

plt.locator_params(nbins=4)

ax.set_aspect('equal')


# NOTE(review): the other calibrate_data_b calls in this notebook pass the data to be transformed
# first and the reference second (e.g. calibrate_data_b(DM, x, ...)). Here the reference
# additionalDataX comes first, although R and bias are applied to loca_embedding - confirm the
# expected argument order against utils.calibrate_data_b.
R, bias = calibrate_data_b(additionalDataX[:,0,:], loca_embedding[:,0,:])
Loca_calibrated= np.matmul(loca_embedding, R)+bias
Loca_inter_calibrated= np.matmul(loca_embedding_inter, R)+bias

#Loca_calibrated= calibrate_data(loca_embedding.reshape((-1,3)),additionalDataX.reshape((-1,3)), scaling=False )
#Loca_inter_calibrated= calibrate_data2_based_calibration_data(loca_embedding_inter.reshape((-1,3)),loca_embedding.reshape((-1,3)),\
#                                      additionalDataX.reshape((-1,3)), scaling=False )


# One color range shared by both scatter plots
vmin = np.min(1-np.append(additionalDataThetas[:,:,1],additionalDataThetas_inter[:,:,1],axis=0 ))
vmax = np.max(1-np.append(additionalDataThetas[:,:,1],additionalDataThetas_inter[:,:,1],axis=0 ))

# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
ax.scatter(Loca_calibrated[:,:,0], Loca_calibrated[:,:,1],\
           Loca_calibrated[:,:,2], c= 1-additionalDataThetas[:,:,1].reshape((-1)),s=30, \
           alpha=.8,cmap=plt.get_cmap('RdYlBu'),\
          vmin=vmin, vmax=vmax)

ax.scatter(Loca_inter_calibrated[:,:,0], Loca_inter_calibrated[:,:,1],\
           Loca_inter_calibrated[:,:,2], c= 1-additionalDataThetas_inter[:,:,1].reshape((-1)),s=30, \
           alpha=.8,cmap=plt.get_cmap('RdYlBu'),\
          vmin=vmin, vmax=vmax)


ax.set_xlabel(r'${\rho}[1]$', labelpad=25, fontsize=35)
ax.set_ylabel(r'${\rho}[2]$', labelpad=25, fontsize=35)
ax.set_zlabel(r'${\rho}[3]$', labelpad=25,fontsize=35)

plt.tick_params(labelsize=25)
ax.set_aspect('equal')

# Same half-width on the three axes, computed from the first point of every calibrated burst
set_axes_radius(ax, Loca_calibrated[:,0,:])


ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False

ax.view_init(elev=15., azim=45)

plt.tight_layout() 


# Fig. S2a

In [None]:
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
# Fig. S2a: the burst centers on the latent sphere, colored by 1 - polar angle.
f= plt.figure(figsize=(10,10))
ax = f.add_subplot(111,projection='3d' )

plt.locator_params(nbins=4)

ax.set_aspect('equal')

# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
# The color range is taken from all the burst points, as in the other figures.
ax.scatter(x[:,0], x[:,1],\
           x[:,2], c= 1-thetas[:,1].flatten(),s=120, \
           alpha=.8,cmap=plt.get_cmap('RdYlBu'),edgecolor='black',\
          vmin=np.min(1-additionalDataThetas[:,:,1]), vmax=np.max(1-additionalDataThetas[:,:,1]))

ax.set_xlabel('$x[1]$', labelpad=25, fontsize=35)
ax.set_ylabel('$x[2]$', labelpad=25, fontsize=35)
ax.set_zlabel('$x[3]$', labelpad=25,fontsize=35)

plt.tick_params(labelsize=25)
ax.set_aspect('equal')

set_axes_radius(ax, additionalDataX[:,0,:])


ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False

ax.view_init(elev=15., azim=45)

plt.tight_layout() 


# Fig S2b

In [None]:
# Diffusion-maps style embedding of the observed burst centers, built from squared Euclidean distances.
euclid_dist = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(y[:, :]))**2

P_dmaps = dm_from_dist(euclid_dist, fac=1)
dm = np.linalg.svd(P_dmaps)
_left_vectors, _singular_values = dm[0], dm[1]
# Scale every left singular vector by its singular value, then divide each row by its entry
# in the first singular vector (broadcast over the columns).
rep_DM = _singular_values*_left_vectors / _left_vectors[:, [0]]


# Keep the 3 coordinates that follow the first one
DM = rep_DM[:, 1:4]

In [None]:
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
# Fig S2b: the DM coordinates calibrated to the latent points x (with scaling),
# colored by 1 - polar angle.
f= plt.figure(figsize=(10,10))
ax = f.add_subplot(111,projection='3d' )

plt.locator_params(nbins=4)

ax.set_aspect('equal')


DM_R,DM_bias, DM_multFactor= calibrate_data_b(DM,x,scaling=True )
DM_calibrated= DM_multFactor*np.matmul(DM,DM_R)+DM_bias

print('mult factor= '+str(DM_multFactor))

# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
ax.scatter(DM_calibrated[:,0], DM_calibrated[:,1],\
           DM_calibrated[:,2], c= 1-thetas[:,1],s=120, \
           alpha=.8,cmap=plt.get_cmap('RdYlBu'),edgecolor='black',\
          vmin=np.min(1-additionalDataThetas[:,:,1]), vmax=np.max(1-additionalDataThetas[:,:,1]))

ax.set_xlabel(r'$\widetilde{\psi}[1]$', labelpad=25, fontsize=35)
ax.set_ylabel(r'$\widetilde{\psi}[2]$', labelpad=25, fontsize=35)
ax.set_zlabel(r'$\widetilde{\psi}[3]$', labelpad=25,fontsize=35)

plt.tick_params(labelsize=25)
ax.set_aspect('equal')

set_axes_radius(ax, DM_calibrated)


ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False

ax.view_init(elev=15., azim=45)

plt.tight_layout() 

# Fig S2c

In [None]:
epsilon = 1e-12   # small multiple of the identity added to every burst covariance before inversion

# Row i: squared Mahalanobis distances from observed center i to all the observed centers,
# measured with the (pseudo-)inverse covariance of burst i.
mahal_dist = np.zeros((N, N))
for i in range(N):
    cov_i = np.cov(additionalDataY[i, :, :].T) + np.eye(2)*epsilon
    inv_conv = np.linalg.pinv(cov_i)
    mahal_dist[i, :] = scipy.spatial.distance.cdist(
        y[[i], :], y, 'mahalanobis', VI=inv_conv)**2

# Symmetrize by averaging with the transpose
mahal_dist = (mahal_dist + mahal_dist.T)/2

p_mahal = dm_from_dist(mahal_dist, fac=50)
dm_mahal = np.linalg.svd(p_mahal)

# Scale every left singular vector by its singular value, divide each row by its entry in the
# first singular vector, and keep the 3 coordinates that follow the first one.
A_DM = dm_mahal[1]*dm_mahal[0] / dm_mahal[0][:, [0]]
A_DM = A_DM[:, 1:4]

In [None]:
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
# Fig S2c: the Mahalanobis-based DM coordinates (A_DM) calibrated to the latent points x
# (with scaling), colored by 1 - polar angle.
f= plt.figure(figsize=(10,10))
ax = f.add_subplot(111,projection='3d' )

plt.locator_params(nbins=4)

ax.set_aspect('equal')

ADM_R,ADM_bias, A_DM_factor= calibrate_data_b(A_DM,x,scaling=True )
ADM_calibrated= A_DM_factor*np.matmul(A_DM,ADM_R)+ADM_bias


print('factor is: '+str(A_DM_factor))

# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
ax.scatter(ADM_calibrated[:,0], ADM_calibrated[:,1],\
           ADM_calibrated[:,2], c= 1-thetas[:,1],s=120, \
           alpha=.8,cmap=plt.get_cmap('RdYlBu'),edgecolor='black',\
          vmin=np.min(1-additionalDataThetas[:,:,1]), vmax=np.max(1-additionalDataThetas[:,:,1]))

ax.set_xlabel(r'$\widetilde{\phi}[1]$', labelpad=25, fontsize=35)
ax.set_ylabel(r'$\widetilde{\phi}[2]$', labelpad=25, fontsize=35)
ax.set_zlabel(r'$\widetilde{\phi}[3]$', labelpad=25,fontsize=35)

plt.tick_params(labelsize=25)
ax.set_aspect('equal')

set_axes_radius(ax, ADM_calibrated)


ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False

ax.view_init(elev=15., azim=45)

plt.tight_layout() 


# Fig S2d

In [None]:
# Embed only the observed burst centers (y holds one observed point per burst).
# This overwrites the loca_embedding computed from the full bursts in the Fig 6d cell.
loca_embedding,loca_recon = model.test(y)

%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
# Fig S2d: LOCA embedding of the burst centers, calibrated with the R and bias computed in the
# Fig 6d cell, colored by 1 - polar angle.
f= plt.figure(figsize=(10,10))
ax = f.add_subplot(111,projection='3d' )

plt.locator_params(nbins=4)

ax.set_aspect('equal')

Loca_calibrated= np.matmul(loca_embedding, R)+bias


# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
ax.scatter(Loca_calibrated[:,0], Loca_calibrated[:,1],\
           Loca_calibrated[:,2], c= 1-thetas[:,1],s=120, \
           alpha=.8,cmap=plt.get_cmap('RdYlBu'),edgecolor='black',\
          vmin=np.min(1-additionalDataThetas[:,:,1]), vmax=np.max(1-additionalDataThetas[:,:,1]))

ax.set_xlabel(r'${\rho}[1]$', labelpad=25, fontsize=35)
ax.set_ylabel(r'${\rho}[2]$', labelpad=25, fontsize=35)
ax.set_zlabel(r'${\rho}[3]$', labelpad=25,fontsize=35)

plt.tick_params(labelsize=25)
ax.set_aspect('equal')

set_axes_radius(ax, Loca_calibrated)


ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False

ax.view_init(elev=15., azim=45)

plt.tight_layout() 


In [None]:
# Pairwise distances between the latent burst centers (the reference on the horizontal axis)
x_dist=scipy.spatial.distance.pdist(x)


# Pairwise distances in the two diffusion-maps embeddings
DM_dist= scipy.spatial.distance.pdist(DM)
A_DM_dist= scipy.spatial.distance.pdist(A_DM)

# Pairwise distances in the LOCA embedding of the burst centers
loca_embedding,loca_recon = model.test(y)
loca_dist_frame=  scipy.spatial.distance.pdist(loca_embedding)


# Same for the held-out (_inter) centers
loca_embedding_inter,_ = model.test(y_inter)
x_dist_inter = scipy.spatial.distance.pdist(x_inter)
loca_dist_inter=  scipy.spatial.distance.pdist(loca_embedding_inter)



scatterSize= 60

f,ax = plt.subplots(1,1,figsize=(10,10))
# DM and A-DM distances are multiplied by a calibration factor; the LOCA distances are plotted as they are
plt.scatter(x_dist,(DM_dist*get_calibrate_factor_for_distsA(DM_dist,x_dist)),s=scatterSize,alpha=0.4, label='DM')

plt.scatter(x_dist,(A_DM_dist*get_calibrate_factor_for_distsA(A_DM_dist,x_dist)),s=scatterSize,alpha=0.4, label='A-DM')

plt.scatter(x_dist,loca_dist_frame,s=scatterSize,alpha=0.4, label='LOCA frame')

plt.scatter(x_dist_inter,loca_dist_inter,s=scatterSize,alpha=0.4, label='LOCA interpolation')

# Identity line: embedding distance equal to latent distance
plt.plot([np.min(x_dist),np.max(x_dist)],[np.min(x_dist),np.max(x_dist)],'--',linewidth=7, c='black')

plt.xlabel('Dist. in $X$',fontsize=35)
plt.ylabel('Dist. in Emb. Space',fontsize=35)
plt.tick_params(labelsize=25)


leg = plt.legend(fontsize=25,markerscale=2.,)
# Legend.legendHandles was renamed to Legend.legend_handles and the old name is absent from
# recent Matplotlib releases; use the new name when available and fall back to the old one.
_legend_handles = getattr(leg, 'legend_handles', None)
if _legend_handles is None:
    _legend_handles = leg.legendHandles
# Draw the legend markers fully opaque although the scatter points are translucent
for lh in _legend_handles:
    lh.set_alpha(1)

plt.tight_layout()



In [None]:
# Mean squared difference between latent distances and embedding distances, printed in the order
# DM, A-DM, LOCA (burst centers), LOCA (held-out centers).
# NOTE(review): the LOCA center distances are rescaled by the calibration factor here, whereas the
# plot above shows them unscaled and the held-out distances below are not rescaled - confirm intent.
for _emb_dist in (DM_dist, A_DM_dist, loca_dist_frame):
    _factor = get_calibrate_factor_for_distsA(_emb_dist, x_dist)
    print(np.mean((x_dist - _emb_dist*_factor)**2))
print(np.mean((x_dist_inter - loca_dist_inter)**2))