In [1]:
%%capture
from lmi import lmi
import time
import os

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
from tqdm.notebook import tqdm
import torch
from scipy.stats import kendalltau

# Configure JAX/XLA GPU memory behaviour through environment variables:
# no up-front preallocation, memory fraction ".10", platform allocator.
# They are set before the bmi estimators are imported below.
# NOTE(review): `lmi` and `torch` are imported earlier in this cell; if either
# pulls in JAX and initialises a backend, these settings may come too late — confirm.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".10"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
from bmi.estimators import MINEEstimator as MINE
from bmi.estimators import InfoNCEEstimator as InfoNCE

# Seed the global torch and NumPy RNGs; the NumPy global RNG drives the
# synthetic data drawn by generate_gaussian_dataset.
# NOTE(review): bmi's JAX-based estimators may manage their own PRNG seed,
# which these calls would not control — confirm.
torch.manual_seed(2121)
np.random.seed(2121)

In [2]:
###############################
#### experiment parameters ####
###############################

# Sweep size: len(true_mis) repeats x len(sample_numbers) x len(ambient_dimensions)
# = 10 x 10 x 5 = 500 estimates per estimator per experiment;
# 4 estimators x 3 experiments -> 6,000 estimates in total.

# intrinsic dimensionality of the low-rank ("type 2") experiment;
# the other experiments derive their own value from the ambient dimension
intrinsic = 1

# ambient (observed) dimensionalities: [1, 13, 25, 37, 50]
ambient_dimensions = np.linspace(1, 50, 5, dtype=int)

# sample sizes: 1000, 2000, ..., 10000
sample_numbers = np.linspace(10**3, 10**4, 10, dtype=int)

# target MI (bits); each entry is one repeat of the full sweep
n_repeats = 10
true_mis = [1]*n_repeats

In [3]:
def generate_gaussian_dataset(ambient, intrinsic, nuisance, antidiag, samples=10**3):
    """
    Sample paired Gaussian data (X, Y) whose dependence lives in a few dimensions.

    ``intrinsic`` independent draws of a bivariate normal with covariance
    ``[[6, antidiag], [antidiag, 3.5]]`` are taken; the first coordinate of
    each draw becomes a column of X and the second a column of Y. Each of
    X and Y is then padded to ``ambient`` columns with:

    * ``ambient - intrinsic - nuisance`` redundant columns, each a copy of a
      randomly chosen (with replacement) intrinsic column, and
    * ``nuisance`` columns of standard-normal noise, drawn independently for
      X and Y.

    Parameters
    ----------
    ambient : int
        Number of columns in each returned array.
    intrinsic : int
        Number of correlated (x, y) coordinate pairs; must be >= 1.
    nuisance : int
        Number of pure-noise columns in each returned array; must be >= 0.
    antidiag : float
        Cross-covariance of each intrinsic pair. Must satisfy
        ``|antidiag| <= sqrt(6 * 3.5)`` for the covariance to be valid
        (``rho_from_mi`` produces such a value from a target MI).
    samples : int, default 1000
        Number of rows.

    Returns
    -------
    Xs, Ys : ndarray of shape (samples, ambient)
        Columns ordered as [intrinsic | redundant | nuisance].

    Raises
    ------
    AssertionError
        If the dimensions do not add up or a parameter is out of range.
    """
    assert intrinsic >= 1, "Need at least one intrinsic dimension"
    assert nuisance >= 0, "Nuisance dimensions cannot be negative"
    assert intrinsic+nuisance <= ambient, "Dimensionality not adding up"
    # a larger cross-covariance makes the 2x2 covariance indefinite;
    # multivariate_normal would then only warn and return meaningless samples
    assert abs(antidiag) <= np.sqrt(6*3.5), "antidiag exceeds sqrt(6*3.5)"

    # draw order below is kept fixed so seeded runs reproduce the same data
    X_nuisance = np.random.normal(size=(samples, nuisance))
    Y_nuisance = np.random.normal(size=(samples, nuisance))

    cov = np.array([[6, antidiag], [antidiag, 3.5]])

    # one independent bivariate draw per intrinsic pair -> (samples, 2*intrinsic),
    # columns laid out as x_0, y_0, x_1, y_1, ...
    pts = np.hstack([np.random.multivariate_normal([0, 0], cov, size=samples)
                     for _ in range(intrinsic)])

    x_cols = list(range(0, 2*intrinsic, 2))
    y_cols = list(range(1, 2*intrinsic, 2))
    n_redundant = ambient - (intrinsic + nuisance)

    # one copy of the intrinsic dimensions
    Xs = pts[:, x_cols]
    Ys = pts[:, y_cols]

    # then randomly sample them to make up the rest of the dimensions
    X_redundant = pts[:, np.random.choice(x_cols, size=n_redundant)]
    Y_redundant = pts[:, np.random.choice(y_cols, size=n_redundant)]

    Xs = np.hstack((Xs, X_redundant, X_nuisance))
    Ys = np.hstack((Ys, Y_redundant, Y_nuisance))

    return Xs, Ys

def mi_from_rho(rho, intrinsic):
    """Closed-form mutual information, in bits, for ``intrinsic`` independent Gaussian pairs.

    Each pair has variances 6 and 3.5 (the covariance used by
    generate_gaussian_dataset) and cross-covariance ``rho``, so its
    correlation is ``rho / sqrt(6 * 3.5)``. A pair contributes
    ``-0.5 * log2(1 - corr**2)`` bits, and the independent pairs add up.
    Inverse of ``rho_from_mi``.
    """
    corr = rho / np.sqrt(6 * 3.5)
    bits_per_pair = -0.5 * np.log2(1 - corr ** 2)
    return bits_per_pair * intrinsic

def rho_from_mi(mi, intrinsic):
    """Cross-covariance giving a total MI of ``mi`` bits spread over ``intrinsic`` pairs.

    Inverse of ``mi_from_rho``: each pair must carry ``mi / intrinsic`` bits,
    which needs correlation ``sqrt(1 - 2**(-2 * mi / intrinsic))``. Scaling by
    ``sqrt(6 * 3.5)`` turns that correlation into the off-diagonal entry of
    the covariance used by generate_gaussian_dataset.
    """
    scale = np.sqrt(6 * 3.5)
    exponent = -2 * mi / intrinsic
    return scale * np.sqrt(1 - 2 ** exponent)

# Type 2: low rank — the dependence structure lives in a single intrinsic dimension

In [4]:
# Experiment "type 2" (low rank): the dependence lives in `intrinsic`
# (set in the parameter cell) coordinate pairs; the remaining ambient
# dimensions are redundant copies of them or independent nuisance noise.
res_d = {
    "Estimator" : [],
    "Sample number" : [],
    "True MI" : [],
    "Estimate" : [],
    "Ambient dimensions" : [],
    "Intrinsic dimensions" : []
}


for true_mi in tqdm(true_mis, position=3, desc='anti_diag'):

    # depends only on true_mi and the fixed intrinsic dimension, so it is
    # computed once per repeat rather than in the innermost loop
    antidiag = rho_from_mi(true_mi, intrinsic)

    for s in tqdm(sample_numbers, desc='sample number', leave=False):

        for ambient in tqdm(ambient_dimensions, desc='dimensions', leave=False):

            # (ambient - intrinsic) // 2 nuisance columns; the generator fills
            # the remaining non-intrinsic columns with redundant copies
            Xs, Ys = generate_gaussian_dataset(ambient, intrinsic,
                                               (ambient-intrinsic) // 2,
                                               antidiag,
                                               samples=s)
            # standardise each column; nan_to_num maps 0/0 from constant columns to 0
            Xs = np.nan_to_num((Xs - Xs.mean(axis=0)) / Xs.std(axis=0))
            Ys = np.nan_to_num((Ys - Ys.mean(axis=0)) / Ys.std(axis=0))

            batch_size = 256 # default
            # NOTE(review): with the current sample_numbers (min 1000) this never triggers
            if len(Xs) < 512: # adjustment for small sample size
                batch_size = 32

            # dividing by ln 2 converts the bmi estimates (presumably nats)
            # to bits, the unit used by mi_from_rho / rho_from_mi
            infonce = InfoNCE(verbose=False, batch_size=batch_size)
            infonce_mi = infonce.estimate(Xs, Ys)/np.log(2)

            mine = MINE(verbose=False, batch_size=batch_size)
            mine_mi = mine.estimate(Xs, Ys)/np.log(2)

            ksg_mi = np.mean(lmi.ksg.mi(Xs, Ys))

            # when ambient < 50, tile the columns so LMI sees more than 50 features
            # NOTE(review): presumably LMI needs a minimum input width — confirm
            if ambient < 50:
                Xs = np.tile(Xs, (1, 1 + 50//ambient))
                Ys = np.tile(Ys, (1, 1 + 50//ambient))
            simi_mi = np.nanmean(lmi.lmi(Xs, Ys)[0])

            # one row per estimator, appended in the order InfoNCE, LMI, MINE, KSG
            estimates = {'InfoNCE': infonce_mi, 'LMI': simi_mi,
                         'MINE': mine_mi, 'KSG': ksg_mi}
            for estimator, estimate in estimates.items():
                res_d['Estimator'].append(estimator)
                res_d['Sample number'].append(s)
                res_d['True MI'].append(true_mi)
                res_d['Estimate'].append(estimate)
                res_d['Ambient dimensions'].append(ambient)
                res_d['Intrinsic dimensions'].append(intrinsic)

anti_diag:   0%|          | 0/10 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

In [5]:
# persist the low-rank sweep; create the output directory on a fresh checkout
os.makedirs('../results', exist_ok=True)
res_df = pd.DataFrame(res_d)
res_df.to_csv('../results/B_Gaussian_sample_complexity_low_rank.csv')

# Type 1: full rank — no underlying low-dimensional structure (intrinsic = ambient)

In [6]:
# Experiment "type 1" (full rank): every coordinate pair is correlated,
# so the intrinsic dimension equals the ambient dimension.
res_d = {
    "Estimator" : [],
    "Sample number" : [],
    "True MI" : [],
    "Estimate" : [],
    "Ambient dimensions" : [],
    "Intrinsic dimensions" : []
}

for true_mi in tqdm(true_mis, position=3, desc='anti_diag'):

    for s in tqdm(sample_numbers, desc='sample number', leave=False):

        for ambient in tqdm(ambient_dimensions, desc='dimensions', leave=False):

            # no low-dimensional structure: intrinsic == ambient.
            # A separate name is used so the global `intrinsic` from the
            # parameter cell (read by the low-rank experiment) is not overwritten.
            intrinsic_dims = ambient

            antidiag = rho_from_mi(true_mi, intrinsic_dims)

            Xs, Ys = generate_gaussian_dataset(ambient, intrinsic_dims,
                                               (ambient-intrinsic_dims) // 2,
                                               antidiag,
                                               samples=s)

            # standardise each column; nan_to_num maps 0/0 from constant columns to 0
            Xs = np.nan_to_num((Xs - Xs.mean(axis=0)) / Xs.std(axis=0))
            Ys = np.nan_to_num((Ys - Ys.mean(axis=0)) / Ys.std(axis=0))

            batch_size = 256 # default
            # NOTE(review): with the current sample_numbers (min 1000) this never triggers
            if len(Xs) < 512: # adjustment for small sample size
                batch_size = 32

            # dividing by ln 2 converts the bmi estimates (presumably nats)
            # to bits, the unit used by mi_from_rho / rho_from_mi
            infonce = InfoNCE(verbose=False, batch_size=batch_size)
            infonce_mi = infonce.estimate(Xs, Ys)/np.log(2)

            mine = MINE(verbose=False, batch_size=batch_size)
            mine_mi = mine.estimate(Xs, Ys)/np.log(2)

            ksg_mi = np.mean(lmi.ksg.mi(Xs, Ys))

            # when ambient < 50, tile the columns so LMI sees more than 50 features
            # NOTE(review): presumably LMI needs a minimum input width — confirm
            if ambient < 50:
                Xs = np.tile(Xs, (1, 1 + 50//ambient))
                Ys = np.tile(Ys, (1, 1 + 50//ambient))
            simi_mi = np.nanmean(lmi.lmi(Xs, Ys)[0])

            # one row per estimator; the 'LMI' label matches the other
            # experiments exactly (no trailing space) so results group together
            estimates = {'InfoNCE': infonce_mi, 'LMI': simi_mi,
                         'MINE': mine_mi, 'KSG': ksg_mi}
            for estimator, estimate in estimates.items():
                res_d['Estimator'].append(estimator)
                res_d['Sample number'].append(s)
                res_d['True MI'].append(true_mi)
                res_d['Estimate'].append(estimate)
                res_d['Ambient dimensions'].append(ambient)
                res_d['Intrinsic dimensions'].append(intrinsic_dims)

anti_diag:   0%|          | 0/10 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

In [7]:
# persist the full-rank sweep; create the output directory on a fresh checkout
os.makedirs('../results', exist_ok=True)
res_df = pd.DataFrame(res_d)
res_df.to_csv('../results/B_Gaussian_sample_complexity_high_rank.csv')

# Type 3: intermediate rank — intrinsic dimension is 10% of ambient (at least 1)

In [8]:
# Experiment "type 3" (intermediate rank): the intrinsic dimension is 10%
# of the ambient dimension, with a floor of 1.
res_d = {
    "Estimator" : [],
    "Sample number" : [],
    "True MI" : [],
    "Estimate" : [],
    "Ambient dimensions" : [],
    "Intrinsic dimensions" : []
}

for true_mi in tqdm(true_mis, position=3, desc='anti_diag'):

    for s in tqdm(sample_numbers, desc='sample number', leave=False):

        for ambient in tqdm(ambient_dimensions, desc='dimensions', leave=False):

            # 10% low-dimensional structure (at least one intrinsic pair).
            # A separate name is used so the global `intrinsic` from the
            # parameter cell (read by the low-rank experiment) is not overwritten.
            intrinsic_dims = max(1, int(0.1*ambient))

            antidiag = rho_from_mi(true_mi, intrinsic_dims)

            Xs, Ys = generate_gaussian_dataset(ambient, intrinsic_dims,
                                               (ambient-intrinsic_dims) // 2,
                                               antidiag,
                                               samples=s)

            # standardise each column; nan_to_num maps 0/0 from constant columns to 0
            Xs = np.nan_to_num((Xs - Xs.mean(axis=0)) / Xs.std(axis=0))
            Ys = np.nan_to_num((Ys - Ys.mean(axis=0)) / Ys.std(axis=0))

            batch_size = 256 # default
            # NOTE(review): with the current sample_numbers (min 1000) this never triggers
            if len(Xs) < 512: # adjustment for small sample size
                batch_size = 32

            # dividing by ln 2 converts the bmi estimates (presumably nats)
            # to bits, the unit used by mi_from_rho / rho_from_mi
            infonce = InfoNCE(verbose=False, batch_size=batch_size)
            infonce_mi = infonce.estimate(Xs, Ys)/np.log(2)

            mine = MINE(verbose=False, batch_size=batch_size)
            mine_mi = mine.estimate(Xs, Ys)/np.log(2)

            ksg_mi = np.mean(lmi.ksg.mi(Xs, Ys))

            # when ambient < 50, tile the columns so LMI sees more than 50 features
            # NOTE(review): presumably LMI needs a minimum input width — confirm
            if ambient < 50:
                Xs = np.tile(Xs, (1, 1 + 50//ambient))
                Ys = np.tile(Ys, (1, 1 + 50//ambient))
            simi_mi = np.nanmean(lmi.lmi(Xs, Ys)[0])

            # one row per estimator, appended in the order InfoNCE, LMI, MINE, KSG
            estimates = {'InfoNCE': infonce_mi, 'LMI': simi_mi,
                         'MINE': mine_mi, 'KSG': ksg_mi}
            for estimator, estimate in estimates.items():
                res_d['Estimator'].append(estimator)
                res_d['Sample number'].append(s)
                res_d['True MI'].append(true_mi)
                res_d['Estimate'].append(estimate)
                res_d['Ambient dimensions'].append(ambient)
                res_d['Intrinsic dimensions'].append(intrinsic_dims)

anti_diag:   0%|          | 0/10 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

sample number:   0%|          | 0/10 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

dimensions:   0%|          | 0/5 [00:00<?, ?it/s]

In [9]:
# persist the 10%-intrinsic sweep; create the output directory on a fresh checkout
os.makedirs('../results', exist_ok=True)
res_df = pd.DataFrame(res_d)
res_df.to_csv('../results/B_Gaussian_sample_complexity_medium_rank.csv')