# Scaled Robust Sigmoid normalization (SRS)

In this script, we normalize expression data using Ben Fulcher's scaled robust sigmoid normalization (SRS). We load data from all donors in the Allen Human Brain Atlas and normalize them via SRS.


In [1]:
""" Initially, import needed libraries and define some enviromnent variables. """

import pandas as pd
import numpy as np
import math
import seaborn as sns
import matplotlib.pyplot as plt

import sys
import os
sys.path.insert(0, os.path.abspath('/home/mike/projects/PyGEST'))

# Explicitly specify directories up front.
base_dir = '.'
ge_dir = '/data'

# Get Call data from all subjects and combine it all.
subs = ['H03511009', 'H03511012', 'H03511015', 'H03511016', 'H03512001', 'H03512002', ]

sub_samples = {}
srs_list = []


In [34]:
import pygest as ge
data = ge.Data(ge_dir)

2019-12-05 19:35:20 [INFO] | PyGEST has initialized logging, and is running on host 'cardano'
2019-12-05 19:35:20 [INFO] | Found 9 donors in /data/sourcedata/participants.tsv


In [2]:
""" Define a useful formatter for time-stamping. """

import datetime

def now_string():
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

In [3]:
call_list = []
expr_list = []
for sub in subs:
    # Load data for each sub
    print("Reading {}'s data.".format(sub))
    annot = pd.read_csv(os.path.join(ge_dir, "sourcedata/sub-{}/expr/SampleAnnot.csv".format(sub)))
    calls = pd.read_csv(os.path.join(ge_dir, "sourcedata/sub-{}/expr/PACall.csv".format(sub)),
                        header=None, index_col=0)
    expr = pd.read_csv(os.path.join(ge_dir, "sourcedata/sub-{}/expr/MicroarrayExpression.csv".format(sub)),
                       header=None, index_col=0)
    
    # Store a reference to each dataframe
    sub_samples[sub] = list(annot['well_id'])
    print(annot.groupby('slab_type')['slab_type'].count())
    calls.columns = sub_samples[sub]
    call_list.append(calls)
    # call_list.append(calls.loc[:, annot[annot['slab_type'] == 'CX']['well_id']])
    expr.columns = sub_samples[sub]
    print("  {:,} probes by {:,} samples".format(expr.shape[0], expr.shape[1]))
    expr_list.append(expr)
    
    # Determine quartile values, intra-quartile range
    ravelled = expr.values.ravel()
    q75, q25 = np.percentile(ravelled, [75, 25])
    print("  {} expression has mean of {:0.3f}, sd of {:0.3f}, iqr of ({:0.2f} to {:0.2f}) = {:0.3f}".format(
        sub, np.mean(ravelled), np.std(ravelled), q25, q75, q75 - q25
    ))


Reading H03511009's data.
slab_type
BS     26
CB     42
CX    295
Name: slab_type, dtype: int64
  58,692 probes by 363 samples
  H03511009 expression has mean of 5.234, sd of 2.960, iqr of (2.40 to 7.32) = 4.924
Reading H03511012's data.
slab_type
BS     80
CB     48
CX    401
Name: slab_type, dtype: int64
  58,692 probes by 529 samples
  H03511012 expression has mean of 5.248, sd of 3.049, iqr of (2.43 to 7.44) = 5.008
Reading H03511015's data.
slab_type
BS     79
CB     62
CX    329
Name: slab_type, dtype: int64
  58,692 probes by 470 samples
  H03511015 expression has mean of 5.237, sd of 3.157, iqr of (2.38 to 7.49) = 5.108
Reading H03511016's data.
slab_type
BS     59
CB     80
CX    362
Name: slab_type, dtype: int64
  58,692 probes by 501 samples
  H03511016 expression has mean of 5.228, sd of 3.109, iqr of (2.56 to 7.42) = 4.863
Reading H03512001's data.
slab_type
BS    154
CB     53
CX    739
Name: slab_type, dtype: int64
  58,692 probes by 946 samples
  H03512001 expression ha

In [4]:
""" Concat call and expr data together into a single dataframe for each data type. """

calls = pd.concat(call_list, axis=1)
call_stats = pd.DataFrame(index=calls.index, columns=['num_called', ],
    data = calls.apply(sum, axis=1),
)
call_stats['pct_called'] = call_stats['num_called'] / len(calls.columns)

expression = pd.concat(expr_list, axis=1)

print("Calls is {:,} x {:,} 'CX' samples. Expression is all {:,} x {:,}".format(
    calls.shape[0], calls.shape[1], expression.shape[0], expression.shape[1],
))

Calls is 58,692 x 3,702 'CX' samples. Expression is all 58,692 x 3,702


In [5]:
""" Report on distribution of expression values. """

ravelled = expression.loc[:, calls.columns].values.ravel()
q75, q25 = np.percentile(ravelled, [75, 25])
print("Cortical expression has mean of {:0.3f}, sd of {:0.3f}, iqr of ({:0.2f} to {:0.2f}) = {:0.3f}".format(
    np.mean(ravelled), np.std(ravelled), q25, q75, q75 - q25
))

ravelled = expression.values.ravel()
q75, q25 = np.percentile(ravelled, [75, 25])
print("Overall expression has mean of {:0.3f}, sd of {:0.3f}, iqr of ({:0.2f} to {:0.2f}) = {:0.3f}".format(
    np.mean(ravelled), np.std(ravelled), q25, q75, q75 - q25
))


Cortical expression has mean of 5.232, sd of 3.009, iqr of (2.47 to 7.35) = 4.883
Overall expression has mean of 5.232, sd of 3.009, iqr of (2.47 to 7.35) = 4.883


In [6]:
""" Define some useful functions for later. """

def easy_chars(desc, x):
    """ Return a string describing the distribution of x. """
    
    q75, q25 = np.percentile(x, [75, 25])
    return "{}: min {:0.3f} - 25% @ {:0.3f} - mean {:0.3f}, median {:0.3f} - 75% @ {:0.3f} - max {:0.3f}".format(
        desc, min(x), q25, np.mean(x), np.median(x), q75, max(x)
    )

def dist_trio_plot(raw, transformed, scaled):
    """ Return a figure and three-axes tuple of distplots for three matrices """
    
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 6))
    fig.tight_layout()

    sns.distplot(raw, ax=ax1)
    ax1.set_ylabel("Raw")
    ax1.set_title(easy_chars("Raw", raw))
    
    sns.distplot(transformed, ax=ax2)
    ax2.set_ylabel("RS")
    ax2.set_title(easy_chars("RS", transformed))

    sns.distplot(scaled, ax=ax3)
    ax3.set_ylabel("SRS")
    ax3.set_title(easy_chars("SRS", scaled))
    
    return fig, (ax1, ax2, ax3)


def srs_reg_plot(raw, transformed, scaled, max_points=2**16):
    """ Return a figure with transformed and scaled regressed against raw. """
    
    # Max out at 'max_points' points per vector. It takes way too long with tens of millions of points to plot.
    smallest_length = min(len(raw), len(transformed), len(scaled))
    idx = np.random.choice(range(smallest_length), min(max_points, smallest_length))
    
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    
    sns.regplot(raw[idx], srs_ravelled[idx], color='gray', scatter_kws={'s': 1}, ax=ax)
    sns.regplot(raw[idx], scaled_srs_ravelled[idx], color='blue', scatter_kws={'s': 1}, ax=ax)
    
    return fig, ax
                           

In [7]:
# Now, lets adjust everything
sns.set_style("white")
sns.despine(left=True)
srs_list = []
for i, sub in enumerate(subs):
    print("Starting SRS-transform of {} @ {}".format(sub, now_string()))

    # Pre-compute some necessary stats about our current dataset
    ravelled = expr_list[i].values.ravel()
    min_x = min(ravelled)
    max_x = max(ravelled)
    med_x = np.median(ravelled)
    q75, q25 = np.percentile(ravelled, [75, 25])
    iqr = q75 - q25
    
    # The Scalable Robust Sigmoid normalization
    # Transform it first
    srs_expr = expr_list[i].applymap(lambda x: (1 / (1 + math.exp((-1.35 * (x - med_x)) / iqr))))
    srs_ravelled = srs_expr.values.ravel()
    min_srs = min(srs_ravelled)
    max_srs = max(srs_ravelled)
    
    # Then scale it to fit between 0 and 1
    srs_expr = srs_expr.applymap(lambda x: (x - min_srs) / (max_srs - min_srs))
    srs_list.append(srs_expr)
    scaled_srs_ravelled = srs_expr.values.ravel()
    
    print("  Originally, {} has [{:,} x {:,}] values ({:,} edges). SRS has [{:,} x {:,}] ({:,})".format(
        sub,
        expr_list[i].shape[0], expr_list[i].shape[1], len(ravelled),
        srs_expr.shape[0], srs_expr.shape[1], len(scaled_srs_ravelled),
    ))
    
    # Plot the distributions for comparisons
    fig, axes = dist_trio_plot(ravelled, srs_ravelled, scaled_srs_ravelled)
    print("  " + easy_chars("Raw", ravelled))
    print("  " + easy_chars(" RS", srs_ravelled))
    print("  " + easy_chars("SRS", scaled_srs_ravelled))
    fig.suptitle(sub)
    fig.savefig(os.path.join(".", "{}_dists.png".format(sub)))
    plt.close(fig)
    print("      {}: wrote {}_dists.png".format(now_string(), sub))
    
    # Plot the raw vs srs data, but only a sample of a million points each (out of tens of millions)
    fig, ax = srs_reg_plot(ravelled, srs_ravelled, scaled_srs_ravelled)
    fig.suptitle(sub)
    fig.savefig(os.path.join(".", "{}_regress.png".format(sub)))
    plt.close(fig)
    print("      {}: wrote {}_regress.png".format(now_string(), sub))

    # Save the data for later use
    # srs_expr.to_pickle(os.path.join(ge_dir, "cache/{}-exprsrs.df".format(sub)))
    # srs_expr.to_csv(os.path.join(ge_dir, "sourcedata/sub-{}/expr/ExprSRS.csv".format(sub)), header=False)

expression_srs = pd.concat(srs_list, axis=1)
expression_srs.to_pickle(os.path.join(ge_dir, "cache/expression-srs.df"))
print("Overall, EXPR has [{:,} x {:,}] values ({:,} edges). SRS has [{:,} x {:,}] ({:,})".format(
    expression.shape[0], expression.shape[1], len(expression.values.ravel()),
    expression_srs.shape[0], expression_srs.shape[1], len(expression_srs.values.ravel()),
))


Starting SRS-transform of H03511009 @ 2019-12-05 14:48:21
  Originally, H03511009 has [58,692 x 363] values (21,305,196 edges). SRS has [58,692 x 363] (21,305,196)
  Raw: min 1.408 - 25% @ 2.401 - mean 5.234, median 4.986 - 75% @ 7.325 - max 18.586
   RS: min 0.273 - 25% @ 0.330 - mean 0.511, median 0.500 - 75% @ 0.655 - max 0.977
  SRS: min 0.000 - 25% @ 0.081 - mean 0.338, median 0.323 - 75% @ 0.543 - max 1.000
      2019-12-05 14:49:29: wrote H03511009_dists.png
      2019-12-05 14:49:37: wrote H03511009_regress.png
Starting SRS-transform of H03511012 @ 2019-12-05 14:49:37
  Originally, H03511012 has [58,692 x 529] values (31,048,068 edges). SRS has [58,692 x 529] (31,048,068)
  Raw: min 1.101 - 25% @ 2.429 - mean 5.248, median 5.103 - 75% @ 7.438 - max 18.134
   RS: min 0.254 - 25% @ 0.327 - mean 0.505, median 0.500 - 75% @ 0.652 - max 0.971
  SRS: min 0.000 - 25% @ 0.102 - mean 0.350, median 0.343 - 75% @ 0.556 - max 1.000
      2019-12-05 14:51:13: wrote H03511012_dists.png
     

<Figure size 432x288 with 0 Axes>

In [8]:
for i, x in enumerate(srs_list):
    print("{}. {}".format(i, x.shape))

0. (58692, 363)
1. (58692, 529)
2. (58692, 470)
3. (58692, 501)
4. (58692, 946)
5. (58692, 893)


In [9]:
print("Overall, EXPR has [{:,} x {:,}] values ({:,} edges). SRS has [{:,} x {:,}] ({:,})".format(
    expression.shape[0], expression.shape[1], len(expression.values.ravel()),
    expression_srs.shape[0], expression_srs.shape[1], len(expression_srs.values.ravel()),
))


Overall, EXPR has [58,692 x 3,702] values (217,277,784 edges). SRS has [58,692 x 3,702] (217,277,784)


In [10]:
print(expression.shape)
expression

(58692, 3702)


Unnamed: 0_level_0,11281,11305,11289,11335,11319,11263,11326,11318,11310,11302,...,7082,7090,7098,7011,7009,7017,7025,7033,7041,7057
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1058685,3.501077,4.154173,4.103324,3.962034,3.607175,3.397071,4.238925,3.220545,3.347516,4.094927,...,2.938440,2.912970,2.062344,3.655761,2.715837,4.274668,3.759001,4.903928,4.258150,4.079339
1058684,1.757962,1.775427,1.858134,2.655145,1.896092,1.865290,1.659177,1.867601,1.843686,1.618922,...,1.336530,1.338208,1.317621,1.577733,1.677115,1.279049,1.279049,1.279049,1.279049,1.410152
1058683,1.832017,1.975964,2.030700,1.862528,2.190454,2.049008,1.812967,2.142825,2.101910,1.892518,...,1.408085,1.430103,1.416651,1.691274,1.794972,1.338397,1.279049,1.279049,1.279049,1.561965
1058682,4.682878,5.480787,5.136348,5.477912,5.413002,5.105746,4.917350,5.037718,4.954490,5.064677,...,5.375030,5.398929,4.305211,4.235528,4.032939,4.471348,4.053165,4.226554,4.287815,4.940261
1058681,6.198028,6.838466,6.646096,6.745124,6.366537,6.382501,6.277766,6.840333,7.023353,6.435251,...,6.660846,6.435792,6.539584,6.131654,6.471048,6.130729,6.492402,5.759537,5.760315,6.530532
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1071207,5.655763,5.823679,5.300496,5.352596,5.223130,5.374505,5.507863,5.577939,5.118923,5.437017,...,4.938741,5.338355,5.238756,4.274668,4.756083,5.166071,5.537803,5.302679,4.821446,5.491417
1071208,6.084677,5.616215,5.807900,6.074535,4.197354,5.639088,6.039145,5.714244,4.597356,5.314180,...,6.517703,5.998020,7.099504,6.539101,6.328269,6.361749,6.912770,6.973036,6.913132,6.656428
1071209,1.825304,2.078818,2.543738,2.614573,2.089824,1.770973,2.365111,2.042799,1.917748,1.612720,...,1.279049,1.448611,1.279049,1.541985,1.696660,1.557243,1.279049,1.279049,2.076494,1.361727
1071210,5.072246,5.317859,4.854512,4.704538,4.577459,5.316336,5.184707,5.546673,5.204318,5.271943,...,5.757539,2.485872,6.081343,3.880230,5.429549,5.313957,7.088099,7.256872,7.133733,7.068572


In [11]:
print(expression_srs.shape)
expression_srs

(58692, 3702)


Unnamed: 0_level_0,11281,11305,11289,11335,11319,11263,11326,11318,11310,11302,...,7082,7090,7098,7011,7009,7017,7025,7033,7041,7057
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1058685,0.180336,0.242325,0.237441,0.223914,0.190280,0.170643,0.250483,0.154330,0.166046,0.236635,...,0.137263,0.134981,0.061698,0.203297,0.117475,0.262387,0.213042,0.323578,0.260792,0.243579
1058684,0.027639,0.029047,0.035754,0.103407,0.038853,0.036337,0.019726,0.036526,0.034578,0.016527,...,0.004322,0.004449,0.002896,0.022820,0.030609,0.000000,0.000000,0.000000,0.000000,0.009906
1058683,0.033630,0.045416,0.049946,0.036112,0.063317,0.051468,0.032084,0.059308,0.055879,0.038561,...,0.009749,0.011429,0.010402,0.031727,0.039971,0.004463,0.000000,0.000000,0.000000,0.021593
1058682,0.293484,0.371100,0.337628,0.370822,0.364525,0.334648,0.316302,0.328024,0.319919,0.330649,...,0.369564,0.371892,0.265338,0.258609,0.239130,0.281435,0.241069,0.257743,0.263657,0.327126
1058681,0.439950,0.499614,0.481926,0.491059,0.455853,0.457352,0.447491,0.499784,0.516397,0.462299,...,0.492435,0.471430,0.481152,0.442635,0.474739,0.442547,0.476740,0.406896,0.406971,0.480307
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1071207,0.388032,0.404211,0.353597,0.358660,0.346073,0.360788,0.373724,0.380509,0.335931,0.366856,...,0.326978,0.365990,0.356275,0.262387,0.309149,0.349180,0.385404,0.362511,0.315526,0.380895
1071208,0.429184,0.384211,0.402694,0.428218,0.246479,0.386421,0.424845,0.393675,0.285173,0.354927,...,0.479107,0.429857,0.532491,0.481107,0.461300,0.464460,0.515594,0.521074,0.515627,0.492026
1071209,0.033084,0.053951,0.093641,0.099839,0.054869,0.028688,0.078182,0.050951,0.040627,0.016036,...,0.000000,0.012845,0.000000,0.020042,0.032152,0.021226,0.000000,0.000000,0.062866,0.006227
1071210,0.331386,0.355285,0.310183,0.295590,0.283241,0.355137,0.342335,0.377483,0.344243,0.350821,...,0.406703,0.097436,0.437833,0.224548,0.374874,0.363611,0.531466,0.546538,0.535562,0.529709


In [13]:
""" Compare entirety of raw expression to SRS-normalized values. """

print("Starting regression plot of all data at {}".format(now_string()))
n = min(len(expression.values.ravel()), 2**21)
idx = np.random.choice(range(len(expression.values.ravel())), n)
x_vals = expression.values.ravel()[idx]

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
sns.regplot(x=x_vals, y=expression_srs.values.ravel()[idx], scatter_kws={'s': 1}, ax=ax)
ax.set_xlabel("Expression")
ax.set_ylabel("SRS-Normalized")
fig.suptitle("All Expression")
fig.savefig(os.path.join(".", "raw_vs_srs.png"))
plt.close(fig)
print("      wrote raw_vs_srs.png at {}".format(now_string()))


Starting regression plot of all data at 2019-12-05 16:10:20
      wrote raw_vs_srs.png at 2019-12-05 16:14:10


# Adapting split halves

Each split half is loaded directly from file rather than extracted from full expression dataframe. To use split halves, we'll open each one, figure out its contents, and make a comparable file using SRS-normalized data.

In [64]:
""" Functions copied from PyGEST to split the same way, but with pre-determined split lists. """

import csv
from pygest.rawdata.glasser import glasser_parcel_map

def average_expr_per_parcel(wellid_expression, parcel_map):
    """ Average expression values over all wellids in each parcel_map-defined parcel. """

    parcels = pd.DataFrame(
        data={'parcel': [parcel_map[x] for x in wellid_expression.columns]},
        index=wellid_expression.columns
    )
    parcel_means = {}
    for parcel in sorted(list(set(parcel_map.values()))):
        parcel_idx = parcels[parcels['parcel'] == parcel].index
        if len(parcel_idx) > 0:
            parcel_means[parcel] = wellid_expression.loc[:, parcel_idx].mean(axis=1)

    return pd.DataFrame(data=parcel_means)


def split_file_name(d, ext):
    """ Return the filename for a split-half dataframe.
        :param dict d: A dictionary containing the parts of the split-half filename.
        :param str ext: 'csv' for list of wellids, 'df' for dataframes
    """
    if ext == 'df':
        return "parcelby-{parby}_splitby-{splby}.df".format_map(d)
    elif ext == 'csv':
        return "{parby}s_splitby-{splby}.csv".format_map(d)
    else:
        raise KeyError("Split file names only handle 'csv' or 'df' files.")

        
def write_subset(phase, splitby, seed, df_wellid, df_parcel):
    """ Write out three files for samples, expression, and each parcellated expression. """

    d = {'splby': splitby, 'phase': phase, 'seed': seed}
    base_path = os.path.join(
        ge_dir, 'splits', "sub-all_hem-A_samp-glasser_prob-fornito",
        "batch-{}{:05}".format(phase, seed),
    )
    os.makedirs(os.path.abspath(base_path), exist_ok=True)

    def write_a_split(df, parcelby):
        df.to_pickle(os.path.join(base_path, "parcelby-{}_splitby-{}.srs.df".format(parcelby, splitby)))
        print("  final {}-split {} set is {} probes x {} {}s, {} {} wellids.".format(
            splitby, phase, df.shape[0], df.shape[1], parcelby,
            'from' if splitby == 'wellid' else 'comprising', len(df.columns)
        ))

    write_a_split(df_wellid, 'wellid')
    write_a_split(df_parcel, 'glasser')


In [84]:
""" Cycle through split halves. """

import os
import pandas as pd

fornito_probes = sorted(list(data.probes("fornito").index))

for split_seed in range(200, 216):
    for phase in ["train", "test", ]:
        split_path = os.path.join(ge_dir, "splits", "sub-all_hem-A_samp-glasser_prob-fornito", "batch-{}00{}".format(phase, split_seed))
        for split in ["wellid", "glasser", ]:
            print("Split {}, {} phase, by {}".format(split_seed, phase, split))
            wellid_file = "wellids_splitby-{}.csv".format(split)
            wellids = sorted(list(pd.read_csv(os.path.join(split_path, wellid_file), header=None).values.ravel()))
            print("  {:,} wellids".format(len(wellids)))
            expr = expression_srs.loc[fornito_probes, wellids]
            expr.index.rename("probe_id", inplace=True)
            expr_parcellated = average_expr_per_parcel(expr, glasser_parcel_map)
            write_subset(phase, split, split_seed, expr, expr_parcellated)


Split 200, train phase, by wellid
  640 wellids
  final wellid-split train set is 15745 probes x 640 wellids, from 640 wellids.
  final wellid-split train set is 15745 probes x 162 glassers, from 162 wellids.
Split 200, train phase, by glasser
  724 wellids
  final glasser-split train set is 15745 probes x 724 wellids, comprising 724 wellids.
  final glasser-split train set is 15745 probes x 88 glassers, comprising 88 wellids.
Split 200, test phase, by wellid
  640 wellids
  final wellid-split test set is 15745 probes x 640 wellids, from 640 wellids.
  final wellid-split test set is 15745 probes x 160 glassers, from 160 wellids.
Split 200, test phase, by glasser
  556 wellids
  final glasser-split test set is 15745 probes x 556 wellids, comprising 556 wellids.
  final glasser-split test set is 15745 probes x 89 glassers, comprising 89 wellids.
Split 201, train phase, by wellid
  640 wellids
  final wellid-split train set is 15745 probes x 640 wellids, from 640 wellids.
  final wellid-s

In [85]:
""" Plot a few of the files, old vs new, to ensure new files appear as expected. """

import pickle

dfa = None
dfb = None

test_path = "/data/splits/sub-all_hem-A_samp-glasser_prob-fornito/batch-train00207"
for p in ["wellid", "glasser" ]:
    for s in ["wellid", "glasser", ]:
        with open(os.path.join(test_path, "parcelby-{}_splitby-{}.df".format(p, s)), "rb") as f:
            dfa = pickle.load(f)
            dfa = dfa.sort_index()
            dfa = dfa.loc[:, sorted(list(dfa.columns))]
        with open(os.path.join(test_path, "parcelby-{}_splitby-{}.srs.df".format(p, s)), "rb") as f:
            dfb = pickle.load(f)
            dfb = dfb.sort_index()
            dfb = dfb.loc[:, sorted(list(dfb.columns))]
        idx = np.random.choice(range(min(len(dfa), len(dfb))), min(len(dfa), len(dfb), 2**20))
        fig, ax = plt.subplots()
        sns.regplot(dfa.values.ravel()[idx], dfb.values.ravel()[idx], ax=ax)
        fig.savefig(os.path.join(test_path, "p-{}_s-{}.png".format(p, s)))
        plt.close(fig)
