In [1]:
import numpy as np
import scipy as sp
import torch as th

import os, pdb, sys, json, glob, tqdm
import pandas as pd
from sklearn.manifold import MDS as skmds

import matplotlib.pyplot as plt
import seaborn as sns
# Seaborn defaults for every plot below: dark grid background, 'notebook' scaling.
sns.set_style('darkgrid')

sns.set_context('notebook')

# Interactive (nbagg) matplotlib backend.
# NOTE(review): `%matplotlib notebook` targets the classic Notebook UI; in
# JupyterLab `%matplotlib widget` (ipympl) is the usual replacement — confirm
# the intended frontend.
%matplotlib notebook
# Default figure: 5x5 inches at 200 dpi.
plt.rcParams['figure.figsize'] = [5,5]
plt.rcParams['figure.dpi'] = 200
# %config InlineBackend.figure_format ='retina'

# Device name string; not referenced by any visible cell — presumably meant for
# torch code elsewhere (TODO confirm or remove).
dev = 'cpu'

from utils import *

In [6]:
def embed_bhat(dd, fn='', ss=1, ne=3, load_w=False):
    """Embed models using a double-centered Bhattacharyya distance matrix.

    For rows i, j of ``dd`` the pairwise matrix is
    ``-sum_k log(sum_l sqrt(p_i[k, l] * p_j[k, l]))`` with ``p = exp(yh[::ss])``.
    The matrix is double-centered with ``J = I - 1/n``. The eigenpairs at both
    ends of its spectrum are kept.

    Parameters
    ----------
    dd : pandas.DataFrame
        One row per model. Needs columns seed, widen, numc, t, err, verr, favg,
        vfavg and yh. Each ``yh`` must support ``[::ss]``, ``.float()`` and
        ``.numpy()`` and be 2-D. Presumably it is a torch tensor of
        log-probabilities with shape (samples, classes) — TODO confirm.
    fn : str
        Tag for the files ``didx_<fn>.p``, ``w_<fn>.p`` and ``r_<fn>.p``. They
        are written to the current working directory.
    ss : int
        Stride used to subsample the first axis of every ``yh``.
    ne : int
        Eigen-window size. Indices ``[n-(ne+1), n-1]`` are kept from the top of
        the spectrum and ``[0, ne+1]`` from the bottom. If those windows would
        overlap (small ``n``), the full spectrum is returned instead.
    load_w : bool
        Unused; kept for backward compatibility. A cached ``w_<fn>.p`` is
        always reused when it exists.

    Returns
    -------
    dict
        ``dict(xp=..., w=..., e=..., v=...)``, the same dict saved to
        ``r_<fn>.p``. Eigenpairs are sorted by decreasing ``|e|`` and
        ``xp = v * sqrt(|e|)``.
    """
    dc = dd[['seed', 'widen', 'numc', 't', 'err', 'verr', 'favg', 'vfavg']]
    th.save(dc, 'didx_%s.p'%fn)
    n = len(dd)
    w_path = 'w_%s.p'%fn

    if not os.path.isfile(w_path):
        # Predictions are only materialized when the matrix has to be built.
        x = np.array([dd.iloc[i]['yh'][::ss].float().numpy() for i in range(n)])
        a = np.sqrt(np.exp(x))
        a = np.moveaxis(a, 0, 1)  # -> (yh axis 0, model, yh axis 1)
        w = np.zeros((n, n))
        nc = 50 if n < 4000 else 100
        print('chunks: ', nc)
        # array_split (not split): the subsampled length need not divide by nc.
        for aa in tqdm.tqdm(np.array_split(a, nc)):
            # Bhattacharyya coefficient per (k, i, j); log, then sum over the chunk.
            w += np.log(np.einsum('kil,kjl->kij', aa, aa, optimize=True)).sum(0)
        w = -w

        del a, x
        l = np.eye(n) - 1.0/n  # centering matrix J = I - 11^T / n
        w = l @ w @ l
        print('Saving w')
        th.save(w, w_path)
    else:
        print('Found: w_%s.p'%fn)
        try:
            # w is a pickled ndarray; torch>=2.6 defaults to weights_only=True,
            # which may refuse to unpickle it.
            w = th.load(w_path, weights_only=False)
        except TypeError:
            # torch<1.13 has no weights_only argument.
            w = th.load(w_path)

    print('Projecting')
    lo, hi = ne + 1, n - (ne + 1)
    if hi <= lo:
        # Top/bottom windows would overlap (or be out of range): take the
        # whole spectrum once instead of returning duplicated eigenpairs.
        e, v = sp.linalg.eigh(w, check_finite=False)
    else:
        e1, v1 = sp.linalg.eigh(w, driver='evx', check_finite=False,
                                subset_by_index=[hi, n-1])
        e2, v2 = sp.linalg.eigh(w, driver='evx', check_finite=False,
                                subset_by_index=[0, lo])
        e = np.concatenate((e1, e2))
        v = np.concatenate((v1, v2), axis=1)

    ii = np.argsort(np.abs(e))[::-1]
    e, v = e[ii], v[:, ii]
    xp = v*np.sqrt(np.abs(e))
    r = dict(xp=xp, w=w, e=e, v=v)
    th.save(r, 'r_%s.p'%fn)
    return r

def embed_skl(dd, fn='', ss=1, ne=3, load_w=False):
    """Embed models using a double-centered symmetrized-KL matrix.

    For rows i, j of ``dd`` the pairwise matrix is
    ``sum_{k,l} (p_i[k,l] - p_j[k,l]) * (log p_i[k,l] - log p_j[k,l])`` with
    ``log p = yh[::ss]``. This equals ``KL(p_i||p_j) + KL(p_j||p_i)`` summed
    over k when each ``p[k]`` is a normalized distribution. The matrix is
    double-centered with ``J = I - 1/n``. The eigenpairs at both ends of its
    spectrum are kept.

    Parameters
    ----------
    dd : pandas.DataFrame
        One row per model. Needs columns seed, widen, numc, t, err, verr, favg,
        vfavg and yh. Each ``yh`` must support ``[::ss]``, ``.float()`` and
        ``.numpy()`` and be 2-D. Presumably it is a torch tensor of
        log-probabilities with shape (samples, classes) — TODO confirm.
    fn : str
        Tag for the output files. ``'_skl'`` is appended, giving
        ``didx_<fn>_skl.p``, ``w_<fn>_skl.p`` and ``r_<fn>_skl.p`` in the
        current working directory.
    ss : int
        Stride used to subsample the first axis of every ``yh``.
    ne : int
        Eigen-window size. Indices ``[n-(ne+1), n-1]`` are kept from the top of
        the spectrum and ``[0, ne+1]`` from the bottom. If those windows would
        overlap (small ``n``), the full spectrum is returned instead.
    load_w : bool
        Unused; kept for backward compatibility. A cached ``w`` file is always
        reused when it exists.

    Returns
    -------
    dict
        ``dict(xp=..., w=..., e=..., v=...)``, the same dict saved to
        ``r_<fn>_skl.p``. Eigenpairs are sorted by decreasing ``|e|`` and
        ``xp = v * sqrt(|e|)``.
    """
    fn += '_skl'
    dc = dd[['seed', 'widen', 'numc', 't', 'err', 'verr', 'favg', 'vfavg']]
    th.save(dc, 'didx_%s.p'%fn)
    n = len(dd)
    w_path = 'w_%s.p'%fn

    if not os.path.isfile(w_path):
        # Predictions are only materialized when the matrix has to be built.
        x = np.array([dd.iloc[i]['yh'][::ss].float().numpy() for i in range(n)])
        loga = np.moveaxis(x, 0, 1)  # -> (yh axis 0, model, yh axis 1)
        a = np.exp(loga)
        w = np.zeros((n, n))
        nc = 50 if n < 4000 else 100
        print('chunks: ', nc)
        # Pairwise differences are formed per chunk. Building them for every
        # sample at once needs a (samples, n, n, classes) array, which defeats
        # the chunking. array_split (not split): length need not divide by nc.
        chunks = zip(np.array_split(a, nc), np.array_split(loga, nc))
        for aa, laa in tqdm.tqdm(chunks, total=nc):
            daa = aa[:, :, None, :] - aa[:, None, :, :]
            dlaa = laa[:, :, None, :] - laa[:, None, :, :]
            w += np.einsum('lijk->ij', daa*dlaa)

        del a, loga, x
        l = np.eye(n) - 1.0/n  # centering matrix J = I - 11^T / n
        w = l @ w @ l
        print('Saving w')
        th.save(w, w_path)
    else:
        print('Found: w_%s.p'%fn)
        try:
            # w is a pickled ndarray; torch>=2.6 defaults to weights_only=True,
            # which may refuse to unpickle it.
            w = th.load(w_path, weights_only=False)
        except TypeError:
            # torch<1.13 has no weights_only argument.
            w = th.load(w_path)

    print('Projecting')
    lo, hi = ne + 1, n - (ne + 1)
    if hi <= lo:
        # Top/bottom windows would overlap (or be out of range): take the
        # whole spectrum once instead of returning duplicated eigenpairs.
        e, v = sp.linalg.eigh(w, check_finite=False)
    else:
        e1, v1 = sp.linalg.eigh(w, driver='evx', check_finite=False,
                                subset_by_index=[hi, n-1])
        e2, v2 = sp.linalg.eigh(w, driver='evx', check_finite=False,
                                subset_by_index=[0, lo])
        e = np.concatenate((e1, e2))
        v = np.concatenate((v1, v2), axis=1)

    ii = np.argsort(np.abs(e))[::-1]
    e, v = e[ii], v[:, ii]
    xp = v*np.sqrt(np.abs(e))
    r = dict(xp=xp, w=w, e=e, v=v)
    th.save(r, 'r_%s.p'%fn)
    return r

In [7]:
# Compute (or reuse the cached) symmetrized-KL embedding of the dataframe `d`;
# embed_skl also persists its results to 'r_wnc_44_skl.p'.
# NOTE(review): `d` is not defined in any visible cell, so this raises NameError
# on a fresh kernel. Load it first (e.g. the commented `th.load` line below) —
# TODO confirm which file corresponds to fn='wnc_44'.
# r = embed_bhat(d, fn='wnc_28_44_48', ss=10)
# d = th.load('d_2_2.p')
r = embed_skl(d, fn='wnc_44', ss=10)

Found: w_wnc_44_skl.p
Projecting


In [90]:
# Toy input laid out like `a` inside embed_skl after its moveaxis:
# (yh axis 0, model, yh axis 1), presumably samples x models x classes, with 8
# models. The values are squared normals: positive but NOT normalized
# distributions. The generator is seeded so Restart & Run All is reproducible.
a = np.random.default_rng(0).standard_normal((100, 8, 10)) ** 2
loga = np.log(a)

In [66]:
# 3 points in 3-D with integer coordinates 5..13; used by the cdist sanity check.
a = np.arange(9).reshape(3, 3) + 5
loga = np.log(a)

In [94]:
# Pairwise differences between the rows of the second-to-last axis:
# da[..., i, j, k] = a[..., i, k] - a[..., j, k]. The ellipsis makes this work for
# a 2-D `a` (points x dims -> n x n x dims) as well as a 3-D `a`
# (samples x models x classes -> samples x n x n x classes). Hard-coded
# 4-index slicing raises IndexError when `a` is 2-D.
da = a[..., :, None, :] - a[..., None, :, :]
dloga = loga[..., :, None, :] - loga[..., None, :, :]

In [82]:
from scipy.spatial.distance import cdist
# Sanity check: summing squared pairwise differences over the last axis must
# reproduce scipy's squared-Euclidean distance matrix. Needs a 2-D `a`
# (rows = points) and `da` of shape (n, n, dims) built from it.
sq_einsum = np.einsum('ijk->ij', da**2)
sq_cdist = cdist(a, a, 'sqeuclidean')
print('einsum\n', sq_einsum)
print('cdist\n', sq_cdist)
np.testing.assert_allclose(sq_einsum, sq_cdist)

einsum
 [[  0  27 108]
 [ 27   0  27]
 [108  27   0]]
cdist
 [[  0.  27. 108.]
 [ 27.   0.  27.]
 [108.  27.   0.]]


In [95]:
# Sum of (a_i - a_j) * (log a_i - log a_j) over classes and samples, mirroring
# the per-chunk term in embed_skl. It equals KL(p_i||p_j) + KL(p_j||p_i) only
# when each class vector of `a` is a normalized distribution.
# Leading axes are flattened into a single sample axis so this also runs when
# `da`/`dloga` were built from a 2-D `a` (no sample axis).
dprod = da*dloga
dprod = dprod.reshape((-1,) + dprod.shape[-3:])
print('einsum\n', np.einsum('lijk->ij', dprod))

einsum
 [[   0.         3587.02592105 3978.31833713 3879.01073538 3697.90901108
  4038.56747022 3928.95397121 3983.73882022]
 [3587.02592105    0.         4075.74406537 4390.97779987 3890.27212038
  4697.32611235 4322.26641331 4168.79262413]
 [3978.31833713 4075.74406537    0.         3823.19059786 3952.23474922
  4302.71221748 3807.72074051 4071.9233486 ]
 [3879.01073538 4390.97779987 3823.19059786    0.         4460.13317407
  4322.45120402 3896.79093572 4174.8749563 ]
 [3697.90901108 3890.27212038 3952.23474922 4460.13317407    0.
  4198.38016764 3963.17386218 3880.06412094]
 [4038.56747022 4697.32611235 4302.71221748 4322.45120402 4198.38016764
     0.         4304.21161252 3965.88824807]
 [3928.95397121 4322.26641331 3807.72074051 3896.79093572 3963.17386218
  4304.21161252    0.         4258.00085438]
 [3983.73882022 4168.79262413 4071.9233486  4174.8749563  3880.06412094
  3965.88824807 4258.00085438    0.        ]]


In [99]:
# Two equal but independent integer vectors [1, 2, 3].
a = np.arange(1, 4)
b = a.copy()