In [57]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

import scipy

import sys
from pathlib import Path
sys.path.insert(0, str(Path().resolve().parents[1]))
from model import left_right_task as lrt, network_model, util, plot_style

# Load the fiducial network via the project's `util` module. The helper returns
# six values, unpacked below; `Wji` is later iterated as a sequence of square
# matrices. NOTE(review): meaning of the positional `True` flag is defined in
# model.util — confirm before changing it.
(
    Wji,
    pset,
    amp,
    dur,
    l_kernel,
    r_kernel,
) = util.load_fiducial_network(True)

In [66]:
# Summary statistics of each Wji matrix over its off-diagonal entries only
# (the diagonal is skipped). The selection is built once per matrix and reused.
Wji_offdiag = [w[~np.eye(w.shape[0], dtype=bool)] for w in Wji]
Wji_means = [np.mean(vals) for vals in Wji_offdiag]
Wji_stds = [np.std(vals) for vals in Wji_offdiag]
# coefficient of variation (std / mean) per matrix
Wji_cvs = np.array(Wji_stds) / np.array(Wji_means)

# print the means, stds and CVs
print("Wji means:", Wji_means)
print("Wji stds:", Wji_stds)
print('CVs', Wji_cvs)

Wji means: [0.07055, 0.04329000000000001, 0.0037600000000000003, 0.00266]
Wji stds: [0.001534536221080462, 0.0012065114114338733, 9.510277250660907e-05, 7.013053548784869e-05]
CVs [0.02175104 0.02787044 0.02529329 0.02636486]


In [25]:
# CV of each whole matrix (diagonal zeros included), for comparison with the
# off-diagonal CVs. Numerator and denominator use the same entries, and the
# cell is self-contained: it does not depend on `Wji_means` from another cell.
np.array([np.std(w) / np.mean(w) for w in Wji])

array([-2.15788606, -2.04256772,  1.99686518,  1.99939067])

In [16]:
# Preview five log-spaced values running from 10**-2 up to 2 (endpoint included).
print(np.logspace(start=-2, stop=np.log10(2), num=5))

[0.01       0.03760603 0.14142136 0.53182959 2.        ]


In [None]:
def new_make_Wji(rng: np.random.Generator, numPairs: int, mean: float, std: float,
                 mean_tol=None, std_tol=None, timeout=np.inf, timeout_limit=10,
                 verbose=False):
    """Build a square weight matrix with a zero diagonal and target
    off-diagonal mean and standard deviation.

    Off-diagonal magnitudes are log-normal with moments matched to ``|mean|``
    and ``std``. Rows are drawn by ``_log_normal_helper`` so that each row's
    mean lies within ``mean_tol`` of ``|mean|``. Whole matrices are then
    redrawn until the off-diagonal std lies within ``std_tol`` of ``std``.
    The sign of ``mean`` is applied at the end.

    Parameters
    ----------
    rng : np.random.Generator
        Source of randomness, e.g. ``np.random.default_rng(seed)``.
    numPairs : int
        Matrix size (the result is ``numPairs x numPairs``).
    mean : float
        Target off-diagonal mean. Must be non-zero.
    std : float
        Target off-diagonal standard deviation. Must be >= 0; ``std == 0``
        returns a constant off-diagonal matrix.
    mean_tol, std_tol : float, optional
        Absolute tolerances. They default to 5% of ``|mean|`` and of ``std``.
    timeout : float
        Per-row wall-clock limit in seconds, forwarded to ``_log_normal_helper``.
    timeout_limit : int
        Number of helper timeouts tolerated during the std-matching retries.
    verbose : bool
        Print the closest std error every 1000 retries.

    Returns
    -------
    np.ndarray
        Shape ``(numPairs, numPairs)`` with zeros on the diagonal.

    Raises
    ------
    AssertionError
        If ``mean == 0``, ``std < 0`` or a tolerance is negative.
    TimeoutError
        If the first draw times out, or more than ``timeout_limit`` retries do.

    Notes
    -----
    The number of std-matching retries is not capped, so an unreachable
    ``std_tol`` may loop indefinitely.
    """
    assert mean != 0, "Mean cannot be zero"
    assert std  >= 0, "Standard deviation must be non-negative"

    sign = -1 if mean < 0 else 1
    abs_mean = np.abs(mean)

    # Fill in defaults *before* validating them: `None >= 0` raises TypeError.
    # The default mean tolerance is relative to |mean|. Sampling targets
    # |mean|, so a tolerance derived from a negative mean would itself be
    # negative and could never be met.
    if mean_tol is None:
        mean_tol = 0.05 * abs_mean
    if std_tol is None:
        std_tol = 0.05 * std
    assert mean_tol >= 0, "Mean tolerance must be non-negative"
    assert std_tol >= 0, "Standard deviation tolerance must be non-negative"

    if std == 0:
        Wji = np.ones((numPairs, numPairs)) * mean
        np.fill_diagonal(Wji, 0)
        return Wji

    def lognormal_params(target_std):
        # Log-space (mu, sigma) of a log-normal with mean abs_mean and std target_std.
        mu_log = np.log(abs_mean**2 / np.sqrt(target_std**2 + abs_mean**2))
        sigma_log = np.sqrt(np.log(1 + (target_std**2 / abs_mean**2)))
        return mu_log, sigma_log

    # The helper always returns numPairs x numPairs, so build the mask once.
    off_diag = ~np.eye(numPairs, dtype=bool)

    def off_diag_std(W):
        return np.std(W[off_diag])

    mu_log, sigma_log = lognormal_params(std)
    Wji = _log_normal_helper(rng, numPairs, mu_log, sigma_log, abs_mean, mean_tol, timeout)

    # Now enforce the std by redrawing whole matrices.
    timeouts = 0
    tries = 0
    closest = np.inf
    while np.abs(off_diag_std(Wji) - std) > std_tol:
        try:
            # Small, growing boost to the target std to help convergence. It
            # helps a lot for higher stds, where tail samples are needed to
            # match the mean and the std simultaneously.
            std_boost = tries * 1e-6 * std
            mu_log, sigma_log = lognormal_params(std + std_boost)
            Wji = _log_normal_helper(rng, numPairs, mu_log, sigma_log, abs_mean, mean_tol, timeout)
            tries += 1
            closest = min(closest, np.abs(off_diag_std(Wji) - std))
            if verbose and tries % 1000 == 0:
                print(f"Retry {tries}: closest = {closest:.3f}, target = {std:.3f} ", end='\r')
        except TimeoutError:
            timeouts += 1
            if timeouts > timeout_limit:
                raise TimeoutError(f"Timeout count exceeded while generating Wji with mean {mean} and std {std}")
    if verbose:
        print()
    return Wji * sign

def _log_normal_helper(rng, numPairs, mu, sigma, mean, mean_tol, timeout=np.inf):
    """Generates an array where each row has the desired mean"""
    Wji = np.zeros((numPairs, numPairs))
    for i in range(numPairs):
        start_time = time.time()
        curr_time = start_time
        values = rng.lognormal(mu, sigma, numPairs-1)
        current_mean = np.mean(values)
        while abs(current_mean - mean) > mean_tol and \
                abs(curr_time - start_time) < timeout:
            values = rng.lognormal(mu, sigma, numPairs-1)
            current_mean = np.mean(values)
            curr_time = time.time()
        if np.abs(current_mean - mean) <= mean_tol:
            Wji[i] = np.insert(values, i, 0)  # Insert zero on the diagonal
        if abs(curr_time - start_time) >= timeout:
            raise TimeoutError(f"Timeout exceeded while generating Wji with mean {mean} and std {sigma}")
    return Wji


In [137]:
# Sanity-check new_make_Wji on a log-spaced grid of target stds (mean fixed at
# 1): the off-diagonal mean and std must land within the requested tolerances.
rng = np.random.default_rng(0)  # seeded once so the check is reproducible
for target_std in np.logspace(-1, 0, 5):
    print(target_std)
    # verbose progress only for targets above 1 (none in this grid)
    Wji_test = new_make_Wji(rng, 5, 1, target_std, mean_tol=0.05,
                            std_tol=0.05 * target_std, verbose=target_std > 1)
    off_diag = Wji_test[~np.eye(Wji_test.shape[0], dtype=bool)]
    print("Mean of Wji_test:", np.mean(off_diag))
    print("Std of Wji_test:", np.std(off_diag))
    # Every row mean is within mean_tol, so the overall off-diagonal mean is
    # too. The std is enforced directly by new_make_Wji's retry loop.
    assert abs(np.mean(off_diag) - 1) <= 0.05 + 1e-12
    assert abs(np.std(off_diag) - target_std) <= 0.05 * target_std + 1e-12

0.1
Mean of Wji_test: 1.010392567814888
Std of Wji_test: 0.10357593514682097
0.1778279410038923
Mean of Wji_test: 1.0111319585402656
Std of Wji_test: 0.1746121856133018
0.31622776601683794
Mean of Wji_test: 0.9983498109952766
Std of Wji_test: 0.32306818976836166
0.5623413251903491
Mean of Wji_test: 0.9938524277732469
Std of Wji_test: 0.5436597410182112
1.0
Mean of Wji_test: 0.9865949584538551
Std of Wji_test: 0.9652266707000106
