In [1]:
import itertools

import numpy as np
import pandas as pd
import xarray as xr

from scipy.stats import zscore, pearsonr
from scipy.spatial.distance import cdist, pdist, squareform, cosine as cosdist
def cossim(x, y):
    """Return the cosine similarity of the 1-D vectors ``x`` and ``y``.

    Computed as ``1 - cosdist(x, y)``, where ``cosdist`` is
    ``scipy.spatial.distance.cosine`` as imported above.
    """
    return 1 - cosdist(x, y)
from scipy.sparse import dok_matrix, coo_matrix, csr_matrix
from scipy.optimize import differential_evolution, minimize_scalar

from sklearn.base import clone
from sklearn.utils import check_random_state, Bunch
from sklearn.model_selection import cross_val_score, KFold
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA

from gemmr.generative_model import GEMMR, JointCovarianceModelCCA
from gemmr.estimators import SVDCCA, SVDPLS
# from gemmr.estimators.r_estimators import SparseCCA
from gemmr.estimators.helpers import pearson_transform_scorer
from gemmr.sample_analysis.macros import analyze_subsampled_and_resampled
from gemmr.model_selection import max_min_detector

import matplotlib
import matplotlib.pyplot as plt
# Base style sheet: the seaborn-derived "paper" style.  matplotlib 3.6 renamed
# it from 'seaborn-paper' to 'seaborn-v0_8-paper' and the old name was
# subsequently removed, so use whichever name this installation provides.
# If neither is listed, fall back to the original name so that the failure is
# reported exactly as it was before.
_paper_style = next(
    (name for name in ('seaborn-v0_8-paper', 'seaborn-paper')
     if name in plt.style.available),
    'seaborn-paper',
)
plt.style.use(_paper_style)

# Notebook-wide figure defaults, applied on top of the style sheet above.
plt.rcParams.update({
    # small, high-resolution panels
    'figure.figsize': [1.7, 1.7],
    'figure.dpi': 300,
    # hide the right and top axes spines
    'axes.spines.right': False,
    'axes.spines.top': False,
    # fonts and text sizes
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
    'font.size': 8,
    'axes.titlesize': 9,
    'axes.labelsize': 8,
    'legend.fontsize': 7.5,
    'xtick.labelsize': 7.5,
    'ytick.labelsize': 7.5,
    'figure.titlesize': 10,
    # compact, frameless legends
    'legend.frameon': False,
    'legend.handlelength': 1.,
    'legend.handletextpad': .5,
    'legend.borderaxespad': .25,
})

# Colors of the active property cycle, kept as a list for later reuse.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

from tqdm.notebook import tqdm, trange

from gemmr.estimators.exotic import RegularizedCCA, FSCCA, BsFSCCA, SparsePCA_CCA
from gemmr.estimators.combination import XProjectionCCA, IterativeCCA, C3A, IntegrativeCCA, MoreXCCA, XYProjectionCCA, LinRegFillInCCA
from gemmr.estimators.annotated import AnnotatedMultiviewEstimator
from gemmr.generative_model.other import PopulationCCA, SubPopulationCCA
from gemmr.generative_model.integrative import RandomFtrIntegrativeGenerativeModel, ZeroFtrCorrIntegrativeGenerativeModel, RandomJointCovarianceModel, DualGEMMR
from gemmr.sample_analysis.annotated import analyze_model, analyze_model_i, analyze_model_parameters

In [2]:
# Ground-truth estimator shared by the sweeps below; it is passed to
# analyze_model_parameters as the `common_truth` argument.
common_truth = SubPopulationCCA(normalize_weights=True)

In [5]:
# Estimators compared in the sweeps below.  Each spec is the triple of
# positional arguments handed to AnnotatedMultiviewEstimator: a short tag,
# the estimator instance, and a second estimator object.
# NOTE(review): the role of the third argument is not visible in this
# notebook -- confirm against AnnotatedMultiviewEstimator.
# A 'c3a_weighted' variant of C3A was sketched here but is currently disabled.
ialgs = [
    AnnotatedMultiviewEstimator(tag, estimator, companion)
    for tag, estimator, companion in (
        ('cca', SVDCCA(), SubPopulationCCA()),
        ('c3a', C3A(), C3A()),
    )
]

In [None]:
# Sweep over the number of subjects per feature in subject set 1, with
# subject set 2 held at 32 subjects per feature.
ds = analyze_model_parameters(
    algs=ialgs,
    GM=RandomJointCovarianceModel,
    params={
        'px': [16],    # dimensions for data set x
        'pya': [16],   # dimensions for data set y_a
        'pyb': [16],   # dimensions for data set y_b
        # NOTE(review): meaning of ax / aya / ayb is not documented in this
        # notebook; ax=[-1, 0] was tried as an alternative.
        'ax': [-1],
        'aya': [-1],
        'ayb': [-1],
        # similarity of weight vectors for subject set 1 x X, subject set 2 x X,
        # used in generative model
        'exa_mix': [.9],
        'mix_component': [-1],
        'n_pc_skip': [0],
        # set random seed for generating data matrices
        'random_state': np.arange(10),
        'n_components_a': [15],
    },
    # rxa is true correlation for subjects in set 1, rxb is corresponding for set 2
    fixed_params={'rxa': .3, 'rxb': .3},
    # algorithm to use as ground truth to compare C3A to
    common_truth=common_truth,
    # number of subjects per feature (X + Ya) in subject set 1
    # (larger values were tried previously, e.g. 1024, 4096*4)
    n_per_ftrs=[4, 16, 64, 256],
    # number of subjects per feature (X + Yb) in data set 2
    n_per_ftr2s=[32],
    # number of times to generate dataset and run CCA - 10 - 100 times
    n_rep=2,
)

# Take mode 0, then average out the listed dimensions one at a time, in order.
ds_mean = ds.sel(mode=0)
for _dim in ('random_state', 'rep', 'px', 'pya', 'pyb',
             'ax', 'aya', 'ayb', 'n_pc_skip', 'n_components_a'):
    ds_mean = ds_mean.mean(_dim)

In [7]:
# Single-configuration run: 16 subjects per feature in both subject sets,
# one repetition per random seed.
ds2 = analyze_model_parameters(
    algs=ialgs,
    GM=RandomJointCovarianceModel,
    params={
        'px': [16],    # dimensions for data set x
        'pya': [16],   # dimensions for data set y_a
        'pyb': [16],   # dimensions for data set y_b
        # NOTE(review): meaning of ax / aya / ayb is not documented in this
        # notebook; ax=[-1, 0] was tried as an alternative.
        'ax': [-1],
        'aya': [-1],
        'ayb': [-1],
        # similarity of weight vectors for subject set 1 x X, subject set 2 x X,
        # used in generative model
        'exa_mix': [.9],
        'mix_component': [-1],
        'n_pc_skip': [0],
        # set random seed for generating data matrices
        'random_state': np.arange(10),
        'n_components_a': [15],
    },
    # rxa is true correlation for subjects in set 1, rxb is corresponding for set 2
    fixed_params={'rxa': .3, 'rxb': .3},
    # algorithm to use as ground truth to compare C3A to
    common_truth=common_truth,
    # number of subjects per feature (X + Ya) in subject set 1
    n_per_ftrs=[16],
    # number of subjects per feature (X + Yb) in data set 2
    n_per_ftr2s=[16],
    # number of times to generate dataset and run CCA - 10 - 100 times
    n_rep=1,
)

n_components_a:   0%|          | 0/1 [00:00<?, ?it/s]
[A

[A[A


[A[A[A



[A[A[A[A




[A[A[A[A[A





[A[A[A[A[A[A






[A[A[A[A[A[A[A







[A[A[A[A[A[A[A[A








[A[A[A[A[A[A[A[A[A

loss: 15.898991500663637
loss: 3.021983831822557
loss: 15.900376714592614
loss: 2.819853307847261
loss: 1.5124454190311272
loss: -0.24137455403055683
loss: 3.5606044156332626
loss: -0.17013931437185661
loss: -0.4885612079115001
loss: 15.903448453500685
loss: 8.725123986404931
loss: 0.4359446158211504
loss: 2.7728632050376176
loss: -0.061986422044599955
loss: 15.891715401880695
loss: 3.771052234961569
loss: -0.22793575119506224
loss: -0.08611966300395747
loss: -0.25913885961829863
loss: 14.441287045871013
loss: -0.08859888143853722
loss: -0.08390131808994991
loss: -0.27996809089355956
loss: 15.903669490752952
loss: -0.10682057182265793
loss: 4.793529967678175
loss: -0.08705700312869075
loss: -0.13444101335468372
loss: -0.09815896779797333
loss: -0.5697660602075116
loss: 10.03934463728398
loss: -0.5475599882692145
loss: -0.26316568274151025
loss: -0.08729201561784271
loss: -0.1652289572337274
loss: 4.738241115176826
loss: 15.881291653004448
loss: 11.496167660644824
loss: 11.9270617356238











[A[A[A[A[A[A[A[A[A[A










[A[A[A[A[A[A[A[A[A[A[A











[A[A[A[A[A[A[A[A[A[A[A[A

loss: 0.8972959917603512
loss: 2.946142089850206
loss: -0.20010044456963735
loss: -0.2038991108963829
loss: 9.769356145397985
loss: -0.3053587298406953
loss: 6.840233554352504
loss: 1.0412286180865027
loss: -0.515087588061142
loss: -0.23017476983724566
loss: 17.717149414059254
loss: 8.768925734092521
loss: 1.0625698532022776
loss: -0.6009965943527464
loss: 8.423552588329178
loss: 10.387365912847242
loss: 8.597542535738695
loss: 17.013122062018365
loss: 0.519266491512491
loss: 9.779618124140217
loss: -0.7168994916693247
loss: 1.0377311750800589
loss: 7.720096907033888
loss: -0.3677219140038156
loss: 5.709973997033087
loss: 1.9874415821020135
loss: -0.31630386248351955
loss: 6.829984953832848
loss: -0.3731233556003255
loss: -0.2920300172169756
loss: 11.7727867608768
loss: 0.6928242092503106
loss: 8.630177362473262
loss: -0.5736665281979321
loss: 0.6404619827093543
loss: -0.4114835371691839
loss: 9.287017798666943
loss: 17.736979933731437
loss: 1.8199796553268404
loss: 11.964160921542415














[A[A[A[A[A[A[A[A[A[A[A[A











[A[A[A[A[A[A[A[A[A[A[A[A










[A[A[A[A[A[A[A[A[A[A[A










[A[A[A[A[A[A[A[A[A[A[A









[A[A[A[A[A[A[A[A[A[A









[A[A[A[A[A[A[A[A[A[A








[A[A[A[A[A[A[A[A[A








[A[A[A[A[A[A[A[A[A







[A[A[A[A[A[A[A[A







[A[A[A[A[A[A[A[A






[A[A[A[A[A[A[A






[A[A[A[A[A[A[A





[A[A[A[A[A[A





[A[A[A[A[A[A




[A[A[A[A[A




[A[A[A[A[A



[A[A[A[A



[A[A[A[A


[A[A[A


[A[A[A

[A[A

[A[A
[A
[A
[A

[A[A


[A[A[A



[A[A[A[A




[A[A[A[A[A





[A[A[A[A[A[A






[A[A[A[A[A[A[A







[A[A[A[A[A[A[A[A








[A[A[A[A[A[A[A[A[A

loss: -0.46926696941336177
loss: -0.09469959730854353
loss: -0.09659732813774528
loss: -0.08780807102686164
loss: -0.2480382875884095
loss: -0.10379844045114144
loss: 13.73793794831596
loss: 9.246821814321075
loss: 0.6902061707929705
loss: -0.5624250747633373
loss: 13.052165367599097
loss: 10.451005781637818
loss: 12.236815185274086
loss: -0.489540508765855
loss: -0.10403868586134948
loss: 0.09430799600869755
loss: 6.219445608975182
loss: 15.616663128314746
loss: 8.351672386828282
loss: 1.2769213181705046
loss: -0.46686980251430577
loss: 0.08477248456476569
loss: 6.988603976843471
loss: 14.199623504343492
loss: -0.5682993112683136
loss: 0.07338519842338459
loss: 10.354658705305912
loss: -0.09187001943376656
loss: 0.5140991269180775
loss: -0.22558680470726059
loss: 15.852952379209448
loss: 10.376224195066847
loss: -0.0679914160242427
loss: 3.5004276809262027
loss: -0.1682973048564452
loss: 0.7461971460404525
loss: 0.7833680801826344
loss: -0.3102690050619461
loss: -0.08926416037640104
l










[A[A[A[A[A[A[A[A[A








                                          

loss: -0.5684750581731841
loss: -0.5859666723861167
loss: -0.5683767908060888
loss: -0.5858967245007829
loss: -0.585922687620525
loss: -0.5858400568474574
loss: -0.5857931717886103
loss: -0.5687420196199433
loss: -0.5859629228734329
loss: -0.5859704940862155
loss: -0.5858073602735837
loss: -0.5858739788526456
loss: -0.5687301380235896
loss: -0.5859619166114293
loss: -0.5859704960021804
loss: -0.5858109720757246
loss: -0.5678194477490993
loss: -0.585910011497147
loss: -0.5859704184690598
loss: -0.5859703697142906
loss: -0.5859705445654451
loss: -0.5859704665043101
loss: -0.5859704970188614
loss: -0.5859512256431421
loss: -0.5859705229799201
loss: -0.5859567177137657
loss: -0.5859705130781434
loss: -0.5859705449508702
loss: -0.5685339004257426
loss: -0.5859602180182945
loss: -0.5859174521818248
loss: -0.5859698606622122
loss: -0.5859105233245061
loss: -0.5859705517106506
loss: -0.5681140395536203
loss: -0.5859678782663134
loss: -0.5859705220176561
loss: -0.5859703797106615
loss: -0.58597

[A[A[A[A[A[A[A[A[A






[A[A[A[A[A[A[A





[A[A[A[A[A[A




[A[A[A[A[A



[A[A[A[A


[A[A[A

[A[A
                                                     

KeyboardInterrupt: 

In [23]:
# Take mode 0 of the ds2 results, then average out the listed dimensions
# one at a time, in order.
ds_mean = ds2.sel(mode=0)
for _dim in ('random_state', 'rep', 'px', 'pya', 'pyb',
             'ax', 'aya', 'ayb', 'n_pc_skip', 'n_components_a'):
    ds_mean = ds_mean.mean(_dim)

In [24]:
# Show the averaged results as the cell output.
ds_mean