In [37]:
import json
from itertools import product as p

import dask.array as da
import numpy as np
import torch
from tqdm import tqdm

In [38]:
class A:
    """Empty attribute container used as a stand-in for a real `self` object."""


# Notebook-level `self`, so method bodies can be prototyped cell by cell.
self = A()

In [39]:
# Per-network registries, keyed by representation file name.
self.num_neurons_d = dict()      # fname -> int
self.representations_d = dict()  # fname -> tensor

# Placeholder names standing in for real representation files.
representation_files = ["foo", "bar", "baz"]
f1, f2, f3 = representation_files

In [40]:
# Initialize `num_neurons_d` and `representations_d` with fake data.
# A dedicated, seeded generator makes the fake activations identical after
# every kernel restart without modifying torch's global RNG state.
fake_data_rng = torch.Generator().manual_seed(0)

n1, n2, n3 = 100, 80, 70  # second dimension of each fake tensor
nword = 1000              # first dimension shared by all fake tensors
t1 = torch.randn(nword, n1, generator=fake_data_rng)
t2 = torch.randn(nword, n2, generator=fake_data_rng)
t3 = torch.randn(nword, n3, generator=fake_data_rng)

# Register the sizes and tensors under their file names.
self.num_neurons_d.update({f1: n1, f2: n2, f3: n3})
self.representations_d.update({f1: t1, f2: t2, f3: t3})

In [41]:
# extra params
# `limit` bounds how many rows of each representation are used:
# a float is treated as a fraction of the rows, an int as a row count,
# and None slices everything.
self.limit = None
# A CPU device makes the similarity cell take the dask code path.
self.device = torch.device('cpu')
# Chunk edge length for the dask path. It is read as `self.dask_chunk_size`
# by the similarity computation but was never set, which raised
# AttributeError. 5001 matches the chunking used in the later dask
# experiments in this notebook.
self.dask_chunk_size = 5001

# compute_correlations

In [11]:
def center_gram(G):
    """Return a centered copy of the square Gram matrix ``G``.

    The column means of ``G`` are reduced by half of their overall mean and
    the result is subtracted along both rows and columns. ``G`` itself is
    not modified.
    """
    column_means = G.mean(0)
    # Half of the grand mean is removed here because the shift is applied
    # twice below (once per axis).
    shift = column_means - column_means.mean() / 2
    return G - shift[None, :] - shift[:, None]

def gram_rbf(X, threshold=1.0):
    """Compute an RBF (Gaussian) Gram matrix over the rows of ``X``.

    Squared pairwise distances between rows are divided by
    ``2 * threshold**2 * m`` and negated before exponentiation, where ``m``
    is the median (50th percentile on the dask path) of all squared pairwise
    distances.

    Parameters
    ----------
    X : torch.Tensor or dask.array.Array
        Matrix whose rows are the examples; expected to be 2-D, since it is
        multiplied by its own transpose. Subclasses (e.g.
        ``torch.nn.Parameter``) are accepted.
    threshold : float, default 1.0
        Scale factor applied (squared) to the median-based bandwidth.

    Returns
    -------
    Same array family as ``X``, of shape ``(X.shape[0], X.shape[0])``.

    Raises
    ------
    ValueError
        If ``X`` is neither a torch tensor nor a dask array.
    """
    # isinstance (rather than an exact type comparison) so tensor/array
    # subclasses are routed to the right branch instead of being rejected.
    if isinstance(X, torch.Tensor):
        dot_products = X @ X.t()
        sq_norms = dot_products.diag()
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        sq_distances = -2*dot_products + sq_norms[:,None] + sq_norms[None,:]
        sq_median_distance = sq_distances.median()
        return torch.exp(-sq_distances / (2*threshold**2 * sq_median_distance))

    if isinstance(X, da.Array):
        dot_products = X @ X.T
        sq_norms = da.diag(dot_products)
        sq_distances = -2*dot_products + sq_norms[:,None] + sq_norms[None,:]
        # da.percentile works on 1-D input, hence the ravel; its result is
        # broadcast against the distance matrix in the division below.
        sq_median_distance = da.percentile(sq_distances.ravel(), 50)
        return da.exp((-sq_distances / (2*threshold**2 * sq_median_distance)))

    raise ValueError(
        "gram_rbf expects a torch.Tensor or dask.array.Array, got "
        f"{type(X).__name__}"
    )

In [42]:
# Set `limit`
# A float `self.limit` is a fraction of the available rows; anything else
# (an int row count, or None for "all rows") is used directly as the slice
# bound below.
n_words = next(iter(self.representations_d.values())).size()[0]
if isinstance(self.limit, float):
    limit = int(n_words * self.limit)
else:
    limit = self.limit

# Set `daskp`: CPU runs go through dask, any other device through torch.
daskp = self.device == torch.device('cpu')

# Set `self.similarities`
# {network: {other: rbfcka_similarity}}
self.similarities = {network: {} for network in self.representations_d}
for network, other_network in tqdm(p(self.representations_d,
                                     self.representations_d),
                                   desc='rbfcka',
                                   total=len(self.representations_d)**2):

    # Skip self-pairs, and pairs already filled in by the symmetric
    # assignment at the bottom of the loop.
    if network == other_network or other_network in self.similarities[network]:
        continue

    if daskp:
        c = self.dask_chunk_size
        X = da.from_array(np.asarray(self.representations_d[network][:limit]), chunks=(c, c))
        Y = da.from_array(np.asarray(self.representations_d[other_network][:limit]), chunks=(c, c))

        Gx = center_gram(gram_rbf(X))
        Gy = center_gram(gram_rbf(Y))

        scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
        norm_gx = da.sqrt(da.dot(Gx.ravel(), Gx.ravel()))
        norm_gy = da.sqrt(da.dot(Gy.ravel(), Gy.ravel()))

        # Convert to a plain Python float so both branches store the same
        # type (the torch branch already yields floats via .item()).
        sim = float((scaled_hsic / (norm_gx*norm_gy)).compute())
    else:
        device = self.device
        X = self.representations_d[network][:limit].to(device)
        Y = self.representations_d[other_network][:limit].to(device)

        # TO DO: random subset of data using limit?
        Gx = center_gram(gram_rbf(X))
        Gy = center_gram(gram_rbf(Y))

        scaled_hsic = torch.dot(Gx.view(-1), Gy.view(-1)).cpu().item()
        norm_gx = torch.norm(Gx, p="fro").cpu().item()
        norm_gy = torch.norm(Gy, p="fro").cpu().item()

        sim = scaled_hsic / (norm_gx*norm_gy)

    # The score is symmetric, so record it in both directions.
    self.similarities[network][other_network] = sim
    self.similarities[other_network][network] = sim

# Dask

In [6]:
import dask.array as da

In [7]:
from dask.distributed import Client, progress
# Start an in-process (threaded) dask cluster: one worker, four threads,
# with a 10GB memory limit.
client = Client(
    n_workers=1,
    threads_per_worker=4,
    processes=False,
    memory_limit='10GB',
)
client

Port 8787 is already in use. 
Perhaps you already have a cluster running?
Hosting the diagnostics dashboard on a random port instead.


0,1
Client  Scheduler: inproc://128.30.34.149/26967/1  Dashboard: http://localhost:40189/status,Cluster  Workers: 1  Cores: 4  Memory: 10.00 GB


In [25]:
import dask

In [8]:
# Pick one pair of networks to experiment on, with the configured row limit.
network, other_network = f1, f2
limit = self.limit

In [13]:
# Wrap (a copy of) each network's first `limit` rows as a dask array with
# (1000, 1000) chunks.
X, Y = (
    da.from_array(np.array(self.representations_d[name][:limit]), chunks=(1000, 1000))
    for name in (network, other_network)
)

In [13]:
# Centered RBF Gram matrix for each of the two networks.
Gx, Gy = (center_gram(gram_rbf(M)) for M in (X, Y))

In [14]:
# Flatten once, then build the (lazy) inner product and the two norms.
gx_flat = Gx.ravel()
gy_flat = Gy.ravel()
scaled_hsic = da.dot(gx_flat, gy_flat)
norm_gx = da.sqrt(da.dot(gx_flat, gx_flat))
norm_gy = da.sqrt(da.dot(gy_flat, gy_flat))

In [20]:
# %time sim = (scaled_hsic / (norm_gx*norm_gy)).compute() 
# 39:25 total CPU time (26:46 wall time) using (1000, 1000) chunks

  result = function(*args, **kwargs)


CPU times: user 11min 7s, sys: 28min 17s, total: 39min 25s
Wall time: 26min 46s










































In [9]:
# Same construction as before, but with (5001, 5001) chunks and without
# forcing a copy (np.asarray).
X, Y = (
    da.from_array(np.asarray(self.representations_d[name][:limit]), chunks=(5001, 5001))
    for name in (network, other_network)
)

In [13]:
# Centered RBF Gram matrices for the re-chunked arrays.
Gx, Gy = map(lambda M: center_gram(gram_rbf(M)), (X, Y))

In [14]:
# Lazy inner product of the two flattened Gram matrices ...
scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
# ... and the norm of each one (square root of its self inner product).
norm_gx, norm_gy = (da.sqrt(da.dot(G.ravel(), G.ravel())) for G in (Gx, Gy))

In [9]:
# End-to-end (still lazy) pipeline for one pair with (5001, 5001) chunks:
# dask arrays -> centered RBF Gram matrices -> inner product and norms.
X, Y = (
    da.from_array(np.asarray(self.representations_d[name][:limit]), chunks=(5001, 5001))
    for name in (network, other_network)
)

Gx, Gy = (center_gram(gram_rbf(M)) for M in (X, Y))

scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
norm_gx, norm_gy = (da.sqrt(da.dot(G.ravel(), G.ravel())) for G in (Gx, Gy))

# Full functions

In [49]:
def compute_correlations(self):
    """Fill ``self.similarities`` with a similarity score for every pair of
    distinct networks in ``self.representations_d``.

    For each pair, an RBF Gram matrix is built from each network's
    representations, both matrices are centered, and the score is their
    flattened inner product divided by the product of their norms (the
    progress bar labels this ``rbfcka``).

    Reads
    -----
    self.representations_d : dict mapping network name to a tensor.
    self.limit : float (fraction of rows), int (row count) or None (all rows).
    self.device : torch.device; CPU selects the dask path, anything else the
        torch path on that device.
    self.dask_chunk_size : chunk edge length, only read on the dask path.

    Writes
    ------
    self.similarities : ``{network: {other: float}}``, stored symmetrically.
    """
    def center_gram(G):
        """Center ``G`` by subtracting shifted column means along both axes."""
        means = G.mean(0)
        # Out-of-place on purpose; half the grand mean is removed because the
        # shift is applied once per axis below.
        means = means - means.mean() / 2
        return G - means[None, :] - means[:, None]

    def gram_rbf(X, threshold=1.0):
        """RBF Gram matrix of the rows of ``X`` with a median-based bandwidth."""
        # isinstance so tensor/array subclasses are accepted as well.
        if isinstance(X, torch.Tensor):
            dot_products = X @ X.t()
            sq_norms = dot_products.diag()
            sq_distances = -2*dot_products + sq_norms[:,None] + sq_norms[None,:]
            sq_median_distance = sq_distances.median()
            return torch.exp(-sq_distances / (2*threshold**2 * sq_median_distance))
        if isinstance(X, da.Array):
            dot_products = X @ X.T
            sq_norms = da.diag(dot_products)
            sq_distances = -2*dot_products + sq_norms[:,None] + sq_norms[None,:]
            sq_median_distance = da.percentile(sq_distances.ravel(), 50)
            return da.exp((-sq_distances / (2*threshold**2 * sq_median_distance)))
        raise ValueError(
            "gram_rbf expects a torch.Tensor or dask.array.Array, got "
            f"{type(X).__name__}"
        )

    # Set `limit`
    # Previously `limit` was never defined in this function and was picked up
    # from whatever global happened to exist. A float is a fraction of the
    # rows; an int or None is used directly as the slice bound.
    limit = self.limit
    if isinstance(limit, float) and self.representations_d:
        n_words = next(iter(self.representations_d.values())).size()[0]
        limit = int(n_words * limit)

    # Set `daskp`
    # Logic could become more complex
    daskp = self.device == torch.device('cpu')

    # Set `self.similarities`
    # {network: {other: rbfcka_similarity}}
    self.similarities = {network: {} for network in self.representations_d}
    for network, other_network in tqdm(p(self.representations_d,
                                         self.representations_d),
                                       desc='rbfcka',
                                       total=len(self.representations_d)**2):

        # Skip self-pairs, and pairs already filled in by the symmetric
        # assignment at the bottom of the loop.
        if network == other_network or other_network in self.similarities[network]:
            continue

        if daskp:
            c = self.dask_chunk_size
            X = da.from_array(np.asarray(self.representations_d[network][:limit]), chunks=(c, c))
            Y = da.from_array(np.asarray(self.representations_d[other_network][:limit]), chunks=(c, c))

            Gx = center_gram(gram_rbf(X))
            Gy = center_gram(gram_rbf(Y))

            scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
            norm_gx = da.sqrt(da.dot(Gx.ravel(), Gx.ravel()))
            norm_gy = da.sqrt(da.dot(Gy.ravel(), Gy.ravel()))

            # Plain Python float, matching what the torch branch stores.
            sim = float((scaled_hsic / (norm_gx*norm_gy)).compute())
        else:
            device = self.device
            X = self.representations_d[network][:limit].to(device)
            Y = self.representations_d[other_network][:limit].to(device)

            # TO DO: random subset of data using limit?
            Gx = center_gram(gram_rbf(X))
            Gy = center_gram(gram_rbf(Y))

            scaled_hsic = torch.dot(Gx.view(-1), Gy.view(-1)).cpu().item()
            norm_gx = torch.norm(Gx, p="fro").cpu().item()
            norm_gy = torch.norm(Gy, p="fro").cpu().item()

            sim = scaled_hsic / (norm_gx*norm_gy)

        # The score is symmetric, so record it in both directions.
        self.similarities[network][other_network] = sim
        self.similarities[other_network][network] = sim