In [14]:
import torch
import json
from tqdm import tqdm
from itertools import product as p
import numpy as np
import dask.array as da

In [6]:
class A:
    """Empty attribute holder used as a stand-in for `self` in this notebook."""


self = A()

In [7]:
# File names that key the per-network dictionaries.
f1, f2, f3 = "foo", "bar", "baz"
representation_files = [f1, f2, f3]

# Per-file containers; both start empty.
self.num_neurons_d = dict()       # {fname: int}
self.representations_d = dict()   # {fname: tensor}

In [8]:
# Populate `num_neurons_d` and `representations_d` with random stand-in data:
# one standard-normal tensor of shape (nword, n_i) per file name.
nword = 10_000
n1, n2, n3 = 100, 80, 70
# The generator is consumed left to right, so the tensors are drawn in the
# order t1, t2, t3.
t1, t2, t3 = (torch.randn(nword, width) for width in (n1, n2, n3))
self.num_neurons_d.update({f1: n1, f2: n2, f3: n3})
self.representations_d.update({f1: t1, f2: t2, f3: t3})

In [12]:
# extra params
# `limit` is read by the manual experiment cells, which slice rows as
# `[:limit]`; None leaves every row in place.
self.limit = None
# The correlation code takes its dask branch when this equals the CPU device.
self.device = torch.device('cpu')
# Edge length `c` used for `chunks=(c, c)` when wrapping arrays with dask.
self.dask_chunk_size = 5_000

# compute_correlations

In [10]:
def center_gram(G):
    """Center a square Gram matrix using its column means.

    Each entry becomes ``G[i, j] - m[i] - m[j] + mean(m)`` where ``m`` holds
    the column means of `G`. When `G` is symmetric this is the usual double
    centering (row and column means removed, grand mean added back).

    `G` itself is not modified; a new array of the same shape is returned.
    """
    col_means = G.mean(0)
    # Subtracting half of the grand mean from each offset adds the grand mean
    # back exactly once after the two subtractions below.
    offsets = col_means - col_means.mean() / 2
    return G - offsets[None, :] - offsets[:, None]

def gram_rbf(X, threshold=1.0):
    """Return the RBF (Gaussian) kernel Gram matrix between the rows of `X`.

    Squared pairwise distances ``d`` are turned into
    ``exp(-d / (2 * threshold**2 * median(d)))``, i.e. the bandwidth is set
    relative to the median squared distance of the data.

    Parameters
    ----------
    X : torch.Tensor or dask.array.Array
        2-D array with one example per row. Subclasses of either type
        (for example ``torch.nn.Parameter``) are accepted.
    threshold : float, default 1.0
        Multiplier on the median-based bandwidth.

    Returns
    -------
    Array of the same family as `X`, shape ``(n_rows, n_rows)``.

    Raises
    ------
    ValueError
        If `X` is neither a torch tensor nor a dask array.
    """
    # isinstance (not an exact type comparison) so tensor/array subclasses
    # are dispatched to the right branch instead of being rejected.
    if isinstance(X, torch.Tensor):
        dot_products = X @ X.t()
        sq_norms = dot_products.diag()
        # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
        sq_distances = -2*dot_products + sq_norms[:, None] + sq_norms[None, :]
        sq_median_distance = sq_distances.median()
        return torch.exp(-sq_distances / (2*threshold**2 * sq_median_distance))

    if isinstance(X, da.Array):
        dot_products = X @ X.T
        sq_norms = da.diag(dot_products)
        sq_distances = -2*dot_products + sq_norms[:, None] + sq_norms[None, :]
        # NOTE(review): da.percentile yields a length-1 array (it broadcasts
        # below) and may only approximate the median on multi-chunk input,
        # so this branch can differ slightly from the torch one -- confirm.
        sq_median_distance = da.percentile(sq_distances.ravel(), 50)
        return da.exp(-sq_distances / (2*threshold**2 * sq_median_distance))

    raise ValueError(
        "gram_rbf expects a torch.Tensor or a dask array, got "
        f"{type(X).__name__}"
    )

In [15]:
# `daskp` is True exactly when the configured device is the CPU; it selects
# the dask implementation below.
daskp = self.device == torch.device('cpu')

# `self.similarities` maps {network: {other: rbfcka_similarity}}
self.similarities = {network: {} for network in self.representations_d}
for network, other_network in tqdm(p(self.representations_d,
                                     self.representations_d),
                                   desc='rbfcka',
                                   total=len(self.representations_d)**2):

    # Nothing to do for a network paired with itself, or for a pair that was
    # already filled in when it was visited in the opposite order.
    if network == other_network or other_network in self.similarities[network]:
        continue

    if not daskp:
        # torch path: move both representations to the target device.
        device = self.device
        X = self.representations_d[network].to(device)
        Y = self.representations_d[other_network].to(device)

        Gx = center_gram(gram_rbf(X))
        Gy = center_gram(gram_rbf(Y))

        # Inner product of the flattened Gram matrices and their Frobenius
        # norms, pulled back to Python floats.
        scaled_hsic = torch.dot(Gx.view(-1), Gy.view(-1)).cpu().item()
        norm_gx = torch.norm(Gx, p="fro").cpu().item()
        norm_gy = torch.norm(Gy, p="fro").cpu().item()

        sim = scaled_hsic / (norm_gx*norm_gy)
    else:
        # dask path: wrap the tensors as chunked arrays and evaluate the
        # whole expression in a single `.compute()` call.
        c = self.dask_chunk_size
        X = da.from_array(np.asarray(self.representations_d[network]), chunks=(c, c))
        Y = da.from_array(np.asarray(self.representations_d[other_network]), chunks=(c, c))

        Gx = center_gram(gram_rbf(X))
        Gy = center_gram(gram_rbf(Y))

        scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
        norm_gx = da.sqrt(da.dot(Gx.ravel(), Gx.ravel()))
        norm_gy = da.sqrt(da.dot(Gy.ravel(), Gy.ravel()))

        sim = (scaled_hsic / (norm_gx*norm_gy)).compute()

    # The measure is symmetric, so record it under both orderings.
    self.similarities[network][other_network] = sim
    self.similarities[other_network][network] = sim

  result = function(*args, **kwargs)
  result = function(*args, **kwargs)
  result = function(*args, **kwargs)
rbfcka: 100%|██████████████████████████████████| 9/9 [00:11<00:00,  1.26s/it]


# Dask

In [3]:
import dask.array as da

In [2]:
from dask.distributed import Client, progress
# Start an in-process dask.distributed client: one worker, four threads per
# worker, and a 10GB memory limit.
# NOTE(review): `local_dir` is an absolute, machine-specific path.
client = Client(
    processes=False,
    n_workers=1,
    threads_per_worker=4,
    memory_limit='10GB',
    local_dir="/data/sls/temp/johnmwu/dask-worker-space",
)
client

0,1
Client  Scheduler: inproc://128.30.34.149/21271/6  Dashboard: http://localhost:43134/status,Cluster  Workers: 1  Cores: 4  Memory: 10.00 GB


In [28]:
# Display the scheduler's report of the cluster (workers, threads, memory limit, metrics).
client.scheduler_info()

{'type': 'Scheduler',
 'id': 'Scheduler-de457f70-9d74-4ef1-8df2-e26696a7461d',
 'address': 'inproc://128.30.34.149/13269/73',
 'services': {'dashboard': 43111},
 'workers': {'inproc://128.30.34.149/13269/74': {'type': 'Worker',
   'id': 0,
   'host': '128.30.34.149',
   'resources': {},
   'local_directory': '/data/sls/temp/johsdfnmwu/worker-pdlocskn',
   'name': 0,
   'nthreads': 4,
   'memory_limit': 10000000000,
   'last_seen': 1563909624.7895453,
   'services': {},
   'metrics': {'cpu': 35.7,
    'memory': 277630976,
    'time': 1563909624.2893617,
    'read_bytes': 64331.1177675439,
    'write_bytes': 1055677.628679231,
    'num_fds': 165,
    'executing': 0,
    'in_memory': 0,
    'ready': 0,
    'in_flight': 0,
    'bandwidth': 100000000},
   'nanny': None}}}

In [11]:
# Exploratory: list the attribute names available on the client object.
dir(client)

['__aenter__',
 '__aexit__',
 '__await__',
 '__class__',
 '__del__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__enter__',
 '__eq__',
 '__exit__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_asynchronous',
 '_cancel',
 '_close',
 '_connecting_to_scheduler',
 '_dec_ref',
 '_deserializers',
 '_ensure_connected',
 '_expand_key',
 '_expand_resources',
 '_expand_retries',
 '_gather',
 '_gather_future',
 '_gather_keys',
 '_gather_remote',
 '_gather_semaphore',
 '_get_dataset',
 '_get_futures_error',
 '_get_task_stream',
 '_graph_to_futures',
 '_handle_cancelled_key',
 '_handle_error',
 '_handle_key_in_memory',
 '_handle_lost_data',
 '_handle_report',
 '_handle_restart',
 '_handle_retried_key',
 '_handle_scheduler_coroutine',
 '_hand

In [14]:
client.set_metadata?

In [13]:
# Exploratory: list the attribute names available on `client.scheduler`.
dir(client.scheduler)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__enter__',
 '__eq__',
 '__exit__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'addr',
 'address',
 'close_rpc',
 'deserializers',
 'pool',
 'serializers']

In [5]:
Client?

In [8]:
import dask

In [9]:
# Display the active dask configuration dictionary.
dask.config.config

{'temporary-directory': None,
 'array': {'svg': {'size': 120},
  'chunk-size': '128MiB',
  'rechunk-threshold': 4},
 'distributed': {'version': 2,
  'scheduler': {'allowed-failures': 3,
   'bandwidth': 100000000,
   'blocked-handlers': [],
   'default-data-size': 1000,
   'events-cleanup-delay': '1h',
   'idle-timeout': None,
   'transition-log-length': 100000,
   'work-stealing': True,
   'worker-ttl': None,
   'preload': [],
   'preload-argv': [],
   'dashboard': {'status': {'task-stream-length': 1000},
    'tasks': {'task-stream-length': 100000},
    'tls': {'ca-file': None, 'key': None, 'cert': None}}},
  'worker': {'blocked-handlers': [],
   'multiprocessing-method': 'forkserver',
   'use-file-locking': True,
   'connections': {'outgoing': 50, 'incoming': 10},
   'preload': [],
   'preload-argv': [],
   'daemon': True,
   'profile': {'interval': '10ms', 'cycle': '1000ms', 'low-level': False},
   'memory': {'target': 0.6, 'spill': 0.7, 'pause': 0.8, 'terminate': 0.95}},
  'client':

In [37]:
# Pick a single pair of networks for the manual dask timing experiments.
network = f1
other_network = f2
# Row cap applied as `[:limit]` in the following cells; None keeps all rows.
limit = self.limit

In [38]:
# Wrap the two row-limited representations as dask arrays with chunks of (1000, 1000).
X = da.from_array(np.array(self.representations_d[network][:limit]), chunks=(1000, 1000))
Y = da.from_array(np.array(self.representations_d[other_network][:limit]), chunks=(1000, 1000))

In [39]:
# Build the centered RBF Gram matrix for each of the two networks.
Gx = center_gram(gram_rbf(X))
Gy = center_gram(gram_rbf(Y))

  concatenate=True,


In [40]:
# Inner product of the flattened Gram matrices, plus the norm of each one
# (sqrt of the flattened matrix dotted with itself).
scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
norm_gx = da.sqrt(da.dot(Gx.ravel(), Gx.ravel()))
norm_gy = da.sqrt(da.dot(Gy.ravel(), Gy.ravel()))

In [41]:
%time sim = (scaled_hsic / (norm_gx*norm_gy)).compute() 

  result = function(*args, **kwargs)


KeyboardInterrupt: 

In [20]:
# %time sim = (scaled_hsic / (norm_gx*norm_gy)).compute() 
# 39:25 total CPU time (26:46 wall time) using (1000, 1000) chunks

  result = function(*args, **kwargs)


CPU times: user 11min 7s, sys: 28min 17s, total: 39min 25s
Wall time: 26min 46s










































In [9]:
# Same wrapping as the (1000, 1000) experiment, but with chunks of (5001, 5001)
# and `np.asarray` instead of `np.array`.
X = da.from_array(np.asarray(self.representations_d[network][:limit]), chunks=(5001, 5001))
Y = da.from_array(np.asarray(self.representations_d[other_network][:limit]), chunks=(5001, 5001))

In [13]:
# Build the centered RBF Gram matrix for each of the two networks.
Gx = center_gram(gram_rbf(X))
Gy = center_gram(gram_rbf(Y))

In [14]:
# Inner product of the flattened Gram matrices, plus the norm of each one
# (sqrt of the flattened matrix dotted with itself).
scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
norm_gx = da.sqrt(da.dot(Gx.ravel(), Gx.ravel()))
norm_gy = da.sqrt(da.dot(Gy.ravel(), Gy.ravel()))

In [9]:
# The (5001, 5001)-chunk experiment in a single cell: wrap the inputs,
# build the centered Gram matrices, then form the dot product and norms.
X = da.from_array(np.asarray(self.representations_d[network][:limit]), chunks=(5001, 5001))
Y = da.from_array(np.asarray(self.representations_d[other_network][:limit]), chunks=(5001, 5001))

# Centered RBF Gram matrix per network.
Gx = center_gram(gram_rbf(X))
Gy = center_gram(gram_rbf(Y))

# Inner product of the flattened Gram matrices and the norm of each.
scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
norm_gx = da.sqrt(da.dot(Gx.ravel(), Gx.ravel()))
norm_gy = da.sqrt(da.dot(Gy.ravel(), Gy.ravel()))

# Full functions

In [19]:
def compute_correlations(self):
    """Fill `self.similarities` with pairwise RBF-kernel CKA similarities.

    For every pair of distinct keys in `self.representations_d`, the centered
    RBF Gram matrices of the two representations are compared through their
    normalized inner product ``<Gx, Gy> / (||Gx|| * ||Gy||)``. Each value is
    stored under both orderings, so
    ``self.similarities[a][b] == self.similarities[b][a]``.

    When `self.device` is the CPU the computation is expressed with dask
    arrays chunked by `self.dask_chunk_size`; otherwise the tensors are moved
    to `self.device` and the computation is done with torch.

    Reads: `self.representations_d`, `self.device`, `self.dask_chunk_size`.
    Writes: `self.similarities` (replaced on every call).
    Returns: None.
    """
    def center_gram(G):
        # Remove column/row means and add the grand mean back once.
        means = G.mean(0)
        means -= means.mean() / 2
        return G - means[None, :] - means[:, None]

    def gram_rbf(X, threshold=1.0):
        # RBF kernel with bandwidth tied to the median squared distance.
        # isinstance (not an exact type comparison) so tensor/array
        # subclasses are dispatched instead of being rejected.
        if isinstance(X, torch.Tensor):
            dot_products = X @ X.t()
            sq_norms = dot_products.diag()
            sq_distances = -2*dot_products + sq_norms[:, None] + sq_norms[None, :]
            sq_median_distance = sq_distances.median()
            return torch.exp(-sq_distances / (2*threshold**2 * sq_median_distance))
        if isinstance(X, da.Array):
            dot_products = X @ X.T
            sq_norms = da.diag(dot_products)
            sq_distances = -2*dot_products + sq_norms[:, None] + sq_norms[None, :]
            # NOTE(review): da.percentile may only approximate the median on
            # multi-chunk input -- confirm against the dask documentation.
            sq_median_distance = da.percentile(sq_distances.ravel(), 50)
            return da.exp(-sq_distances / (2*threshold**2 * sq_median_distance))
        raise ValueError(
            "gram_rbf expects a torch.Tensor or a dask array, got "
            f"{type(X).__name__}"
        )

    # The dask implementation is used exactly when running on the CPU.
    daskp = self.device == torch.device('cpu')

    # Set `self.similarities`
    # {network: {other: rbfcka_similarity}}
    self.similarities = {network: {} for network in self.representations_d}
    for network, other_network in tqdm(p(self.representations_d,
                                         self.representations_d),
                                       desc='rbfcka',
                                       total=len(self.representations_d)**2):

        # Skip self-pairs and pairs already stored from the mirrored order.
        if network == other_network:
            continue
        if other_network in self.similarities[network]:
            continue

        if daskp:
            c = self.dask_chunk_size
            X = da.from_array(np.asarray(self.representations_d[network]), chunks=(c, c))
            Y = da.from_array(np.asarray(self.representations_d[other_network]), chunks=(c, c))

            Gx = center_gram(gram_rbf(X))
            Gy = center_gram(gram_rbf(Y))

            scaled_hsic = da.dot(Gx.ravel(), Gy.ravel())
            norm_gx = da.sqrt(da.dot(Gx.ravel(), Gx.ravel()))
            norm_gy = da.sqrt(da.dot(Gy.ravel(), Gy.ravel()))

            # One `.compute()` evaluates the whole lazy expression.
            sim = (scaled_hsic / (norm_gx*norm_gy)).compute()
        else:
            device = self.device
            X = self.representations_d[network].to(device)
            Y = self.representations_d[other_network].to(device)

            Gx = center_gram(gram_rbf(X))
            Gy = center_gram(gram_rbf(Y))

            scaled_hsic = torch.dot(Gx.view(-1), Gy.view(-1)).cpu().item()
            norm_gx = torch.norm(Gx, p="fro").cpu().item()
            norm_gy = torch.norm(Gy, p="fro").cpu().item()

            sim = scaled_hsic / (norm_gx*norm_gy)

        # Symmetric measure: record it for both orderings.
        self.similarities[network][other_network] = sim
        self.similarities[other_network][network] = sim

In [20]:
# Run the full computation on the stand-in object; results are stored in `self.similarities`.
compute_correlations(self)

rbfcka:   0%|                                          | 0/9 [00:00<?, ?it/s]

True


  result = function(*args, **kwargs)
  result = function(*args, **kwargs)
  result = function(*args, **kwargs)
rbfcka: 100%|██████████████████████████████████| 9/9 [00:11<00:00,  1.27s/it]
