In [1]:
import numpy as np
# from scipy import stats, linalg
from folktables import ACSDataSource, ACSMobility
from matplotlib import pyplot as plt
import torch
from scipy import stats
import pandas as pd

In [2]:
p = 20
r = 3

In [3]:
%load_ext memory_profiler

In [4]:
def m_dist2(x1, x2, K):
    return (x2 - x1).T @ K @ (x2 - x1)

def generate_synthetic_data(n, r, p, data, S=None):
    X = data.sample(n, axis=0)
    X = X.loc[:, X.var(axis=0) > 0]
    X = X.sample(p, axis=1)
    X = np.array(stats.zscore(np.array(X), axis=0))
   
    Astar = np.random.normal(0, 1, size=(p, r)) / np.sqrt(p)
    Kstar = Astar @ Astar.T

    if not S:
        Sbar = []
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    if i != j and j != k and k != i:
                        Sbar.append((i, j, k))
        Sbar = np.array(Sbar)
    
        choices = np.random.choice([False, True], size=len(Sbar), replace=True, p=[0.95, 0.05])
        S = Sbar[choices, :]

    M = []
    for t in S:
        i, j, k = t
        Mt = np.outer(X[i], X[k]) + np.outer(X[k], X[i]) - np.outer(X[i], X[j]) - np.outer(X[j], X[i]) + np.outer(X[j], X[j]) - np.outer(X[k], X[k])
        M.append(Mt)
    M = np.array(M)

    probs = 1 / (1 + np.exp(-np.einsum('ijj->i', np.einsum('ijk,kl->ijl', M,  Kstar))))

    y = 2 * np.array(probs > np.random.random(np.shape(probs)), dtype=int) - 1
    
    # y = []
    # for t in S:
    #     i, j, k = t
    #     prob = 1 / (1 + np.exp(m_dist2(X[i], X[k], Kstar) - m_dist2(X[i], X[j], Kstar)))
    #     yt = np.random.choice([1, -1], size=1, p=[prob, 1 - prob])
    #     y.append(yt)
    # y = np.array(y)
    
    return X, S, y, Astar, Kstar

def initialization(n, p, S, X, y):
    transition_matrices = [np.zeros((n - 1, n - 1)) for _ in range(n)]
    for t in range(len(S)):
        def gap(l, i):
            if l < i:
                return l
            else:
                return l - 1
        i, j, k = S[t]
        i = i.get()
        j = j.get()
        k = k.get()
        transition_matrices[i][gap(j, i), gap(k, i)] = 0.99 if y[t] == 1 else 0.01
        transition_matrices[i][gap(k, i), gap(j, i)] = 0.99 if y[t] == -1 else 0.01

    for i in range(n):
        d = np.max(np.sum(transition_matrices[i], axis=1))
        transition_matrices[i] = transition_matrices[i] / d

        self_loops = np.diag(1 - np.sum(transition_matrices[i], axis=1))

        transition_matrices[i] += self_loops

    dists_from_i = []
    for i in range(n):
        eigenvalues, eigenvectors = np.linalg.eigh(transition_matrices[i])
        leading_index = np.where(np.isclose(eigenvalues, 1))[0]
        leading_eigenvector = np.abs(eigenvectors[:, leading_index].real)
        dists_wo_i = np.log(leading_eigenvector)
        if not np.all(np.isfinite(dists_wo_i)):
            print(i)
        dists = np.zeros(n)
        np.put(dists, list(range(i)) + list(range(i+1, n)), dists_wo_i)
        dists_from_i.append(dists)
    D = np.stack(dists_from_i)

    J = np.identity(n) - (np.ones((n, 1)) @ np.ones((1, n))) / n
    H = - J @ D @ J / 2
    Xprime = J @ X
    
    XtHX = Xprime.T @ H @ Xprime / (n**2)
    Sigma = Xprime.T @ Xprime / n
    eigenvalues, eigenvectors = lobpcg(A=XtHX, B=Sigma, X=np.random.normal(size=(p,r)))
    U = eigenvectors[:, np.argsort(eigenvalues)[-r:]]
    Lambda = np.sort(eigenvalues)[-r:]
    AAt = U @ np.diag(Lambda) @ U.T
    
    V, s, W = np.linalg.svd(AAt)
    
    top_r_indicies = np.argsort(s)[:r]
    V_r = V[:, top_r_indicies].reshape((p, r))
    s_r = s[top_r_indicies]
    Ahat = V_r @ np.sqrt(np.diag(s_r))

    return Ahat

def L(A, triplets): 
    yMtAs = torch.einsum('bij,jk->bik', triplets, A)
    yMtAATs = torch.einsum('bij,jk->bik', yMtAs, torch.transpose(A, 0, 1))
    TryMtAATs = torch.einsum('bii->b', yMtAATs)
    losses = torch.log(1 + torch.exp(-TryMtAATs))
    
    return torch.mean(losses)

In [34]:
n = 400
p = 20
r = 3

In [35]:
%memit synthetic_data = pd.DataFrame(np.random.normal(size=(1000, 50)))

peak memory: 1310.99 MiB, increment: 0.00 MiB


In [36]:
%memit X = synthetic_data.sample(1000, axis=0)
%memit X = X.loc[:, X.var(axis=0) > 0]
%memit X = X.sample(p, axis=1)
%memit X = np.array(stats.zscore(np.array(X), axis=0))

peak memory: 1310.99 MiB, increment: 0.00 MiB
peak memory: 1310.99 MiB, increment: 0.00 MiB
peak memory: 1310.99 MiB, increment: 0.00 MiB
peak memory: 1310.99 MiB, increment: 0.00 MiB


In [37]:
%memit Astar = np.random.normal(0, 1, size=(p, r)) / np.sqrt(p)
%memit Kstar = Astar @ Astar.T

peak memory: 1310.99 MiB, increment: 0.00 MiB
peak memory: 1311.00 MiB, increment: 0.00 MiB


In [38]:
def generate_sbar():
    S = []
    for i in range(n):
        for j in range(n):
            for k in range(n):
                if i != j and j != k and k != i:
                    S.append((i, j, k))
    S = np.array(S)
    choices = np.random.choice([False, True], size=len(S), replace=True, p=[0.95, 0.05])
    S = S[choices, :]
    return np.array(S)

%memit S = generate_sbar()


peak memory: 8336.18 MiB, increment: 7025.18 MiB


peak memory: 1034.69 MiB, increment: 105.43 MiB
peak memory: 929.26 MiB, increment: 0.00 MiB


In [None]:
M = []
for t in S:
    i, j, k = t
    Mt = np.outer(X[i], X[k]) + np.outer(X[k], X[i]) - np.outer(X[i], X[j]) - np.outer(X[j], X[i]) + np.outer(X[j], X[j]) - np.outer(X[k], X[k])
    M.append(Mt)
M = np.array(M)

probs = 1 / (1 + np.exp(-np.einsum('ijj->i', np.einsum('ijk,kl->ijl', M,  Kstar))))

y = 2 * np.array(probs > np.random.random(np.shape(probs)), dtype=int) - 1

In [8]:
r = 3
p = 20
%memit X, S, y, Astar, Kstar = generate_synthetic_data(10, r, p, synthetic_data)

peak memory: 7841.45 MiB, increment: 0.19 MiB


In [15]:
for n in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]: # [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]:
    

    
    print("Generating synthetic data...")   
    X, S, y, Astar, Kstar = generate_synthetic_data(n, r, p, synthetic_data)
    print(np.shape(X))
    
    np.save("Astar.npy", Astar.get())
    
    print("Initializing...")
    A0 = initialization(n, p, S, y)
    
    print(n)
    print(np.linalg.norm(Astar - A0))

Generating synthetic data...
(100, 20)
Initializing...
100
1.8842892682261403
Generating synthetic data...


KeyboardInterrupt: 

In [82]:
len(S)

48251