In [1]:
import os
import time
from brainscore_vision import score
from brainscore_vision.benchmarks.majajhong2015 import *
from brainscore_vision.benchmarks.majajhong2015.benchmark import load_assembly
from brainscore_vision import benchmark_registry

from brainscore_vision.metrics.internal_consistency.ceiling import (
    _SplitHalvesConsistency, SpearmanBrownCorrection, XarrayDefaults, CrossValidationSingle
)

# Display the benchmark registry to see which benchmark identifiers are registered
benchmark_registry

  class Score(DataAssembly):


{'MajajHong2015.V4-pls': <function brainscore_vision.benchmarks.majajhong2015.benchmark.DicarloMajajHong2015V4PLS()>,
 'MajajHong2015.IT-pls': <function brainscore_vision.benchmarks.majajhong2015.benchmark.DicarloMajajHong2015ITPLS()>,
 'MajajHong2015public.V4-pls': <function brainscore_vision.benchmarks.majajhong2015.benchmark.MajajHongV4PublicBenchmark()>,
 'MajajHong2015public.IT-pls': <function brainscore_vision.benchmarks.majajhong2015.benchmark.MajajHongITPublicBenchmark()>,
 'MajajHong2015public.V4-temporal-pls': <function brainscore_vision.benchmarks.majajhong2015.<lambda>()>,
 'MajajHong2015public.IT-temporal-pls': <function brainscore_vision.benchmarks.majajhong2015.<lambda>()>}

In [2]:
import cupy as cp
from scipy.stats import t as t_dist
from pdb import set_trace

def pearsonr_cupy(x, y):
    """
    Pearson correlation coefficient and two-sided p-value, computed with CuPy.

    Parameters
    ----------
    x, y : array-like
        Samples to correlate. Either an xarray-style container (anything exposing
        ``dims``; its backing array is read from ``.data``) or a plain NumPy / CuPy
        array or sequence. Both inputs are flattened and must hold the same number
        of elements.

    Returns
    -------
    r : numpy.ndarray
        0-d array holding the correlation coefficient, clipped to [-1, 1].
        NaN when either input has zero variance.
    p_value : scalar
        Two-sided p-value from a t-distribution with ``n - 2`` degrees of freedom.
        For ``n == 2`` the p-value is 1.0, matching ``scipy.stats.pearsonr``.

    Raises
    ------
    ValueError
        If the inputs differ in size or hold fewer than two observations.
    """
    def _as_flat_device_array(values):
        # Only xarray-style containers keep their array behind `.data`. On a raw
        # cupy.ndarray `.data` is a MemoryPointer (and on numpy a memoryview), so
        # unwrapping unconditionally breaks plain array inputs.
        backing = values.data if hasattr(values, "dims") else values
        return cp.asarray(backing).ravel()

    x = _as_flat_device_array(x)
    y = _as_flat_device_array(y)

    # Check that the lengths match
    if x.size != y.size:
        raise ValueError("Input arrays must have the same length.")

    # Number of observations
    n = x.size
    if n < 2:
        raise ValueError("Input arrays must contain at least 2 observations.")

    # Deviations from the mean
    x_dev = x - cp.mean(x)
    y_dev = y - cp.mean(y)

    # Numerator (covariance) and denominator (standard deviations); the 1/n factors cancel
    cov = cp.sum(x_dev * y_dev)
    std_x = cp.sqrt(cp.sum(x_dev ** 2))
    std_y = cp.sqrt(cp.sum(y_dev ** 2))

    # Floating-point rounding can push |r| marginally above 1, which would make
    # 1 - r**2 negative and the t-statistic NaN; clip to the valid range (NaN stays NaN).
    r = cp.clip(cov / (std_x * std_y), -1.0, 1.0)

    if n == 2:
        # Zero degrees of freedom: the t-distribution is undefined, report p = 1.
        return cp.asnumpy(r), 1.0

    # t-statistic; the small epsilon keeps the division finite when |r| == 1
    df = n - 2  # degrees of freedom
    t_stat = r * cp.sqrt(df / ((1 - r ** 2) + 1e-9))

    # SciPy evaluates the t-distribution CDF on the host, hence the conversion to NumPy
    p_value = 2 * t_dist.cdf(-abs(cp.asnumpy(t_stat)), df)

    return cp.asnumpy(r), p_value

In [3]:
import scipy
import numpy as np 
import pandas as pd

from brainscore_core.metrics import Metric, Score
from brainio.assemblies import walk_coords, DataAssembly
from brainscore_vision.metric_helpers.xarray_utils import XarrayCorrelation

from brainscore_vision.metrics.internal_consistency.ceiling import (
    _SplitHalvesConsistency, SpearmanBrownCorrection, XarrayDefaults, CrossValidationSingle
)

class PearsonCorrelation:
    """
    Consistency metric that correlates two halves of an assembly with Pearson r.

    The pairwise correlation itself is evaluated with ``cp.corrcoef`` and handed to
    ``XarrayCorrelation``, which is configured with the stimulus and neuroid coordinates.
    """

    def __init__(self, stimulus_coord=XarrayDefaults.stimulus_coord,
                 neuroid_dim=XarrayDefaults.neuroid_dim, neuroid_coord=XarrayDefaults.neuroid_coord):
        def corrcoef_pair(x, y):
            # Returns an (r, p) pair like scipy.stats.pearsonr would (the alternative
            # this replaces), but no p-value is computed, so the second slot is None.
            coefficient_matrix = cp.corrcoef(x.data, y.data)
            return coefficient_matrix[0, 1].item(), None

        self._neuroid_dim = neuroid_dim
        self._correlation = XarrayCorrelation(corrcoef_pair, correlation_coord=stimulus_coord,
                                              neuroid_coord=neuroid_coord)

    def __call__(self, half1, half2):
        """Correlate the two halves using the configured correlation object."""
        return self._correlation(half1, half2)

    def aggregate(self, scores):
        """Collapse the scores to their median along the neuroid dimension."""
        return scores.median(dim=self._neuroid_dim)
    
class SplitHalfWrapper:
    """
    Callable that scores the consistency between two halves of an assembly.

    Each half is first averaged over the split coordinate, the consistency metric is
    applied to the two averages, and the result is passed through the correction
    with ``n=2``.
    """

    def __init__(self, split_coord, consistency_metric: Metric, correction):
        self._split_coord = split_coord
        self._consistency_metric = consistency_metric
        self._correction = correction

    def __call__(self, half1, half2):
        """Return the corrected consistency between ``half1`` and ``half2``."""
        averaged_first = self._average_repetitions(half1)
        averaged_second = self._average_repetitions(half2)
        raw_consistency = self._consistency_metric(averaged_first, averaged_second)
        return self._correction(raw_consistency, n=2)

    def _average_repetitions(self, assembly):
        """Group by every other coordinate on the split coordinate's dims and take the mean."""
        split_dims = assembly[self._split_coord].dims
        grouping_coords = []
        for name, coord_dims, _ in walk_coords(assembly):
            # keep coordinates that live on the same dims as the split coordinate,
            # excluding the split coordinate itself
            if name != self._split_coord and coord_dims == split_dims:
                grouping_coords.append(name)
        return assembly.multi_groupby(grouping_coords).mean(dim=split_dims)
        
def _average_repetitions(assembly, split_coord):
    """
    Average ``assembly`` over ``split_coord``.

    Groups by every other coordinate that shares the dims of ``split_coord`` and
    takes the mean over those dims.
    """
    split_dims = assembly[split_coord].dims
    grouping_coords = []
    for name, coord_dims, _values in walk_coords(assembly):
        # skip the split coordinate itself and anything on different dims
        if name == split_coord or coord_dims != split_dims:
            continue
        grouping_coords.append(name)
    return assembly.multi_groupby(grouping_coords).mean(dim=split_dims)

In [4]:
from scipy.stats import pearsonr

def xarray_correlation(prediction, target, neuroid_coord, correlation_coord, correlation=pearsonr):
    """
    Correlate ``prediction`` and ``target`` separately for every neuroid.

    Both inputs are sorted by ``correlation_coord`` and ``neuroid_coord`` and asserted to
    line up on both. ``correlation`` is then called once per neuroid with the xarray slices
    ``(target_slice, prediction_slice)`` and must return an ``(r, p)`` pair; only ``r`` is kept.

    Returns a ``Score`` along the neuroid dimension that carries the target's coordinates
    for that dimension.
    """
    # bring both assemblies into the same order
    prediction = prediction.sortby([correlation_coord, neuroid_coord])
    target = target.sortby([correlation_coord, neuroid_coord])

    # verify that the two assemblies are aligned on both coordinates
    for coord in (correlation_coord, neuroid_coord):
        assert np.array(prediction[coord].values == target[coord].values).all()

    # the neuroid coordinate must live on exactly one dimension
    neuroid_dims = target[neuroid_coord].dims
    assert len(neuroid_dims) == 1
    neuroid_dim = neuroid_dims[0]

    correlations = []
    for index in range(len(target[neuroid_coord].values)):
        # positional `isel` is about 10x faster than label-based `sel`
        target_slice = target.isel(**{neuroid_dim: index})
        prediction_slice = prediction.isel(**{neuroid_dim: index})
        r, _ = correlation(target_slice, prediction_slice)
        correlations.append(r)

    # wrap the per-neuroid values together with the target's neuroid coordinates
    neuroid_coords = {name: (dims, values)
                      for name, dims, values in walk_coords(target) if dims == neuroid_dims}
    return Score(correlations, coords=neuroid_coords, dims=neuroid_dims)

def xarray_correlation_cp(prediction, target, neuroid_coord, correlation_coord, correlation=pearsonr):
    """
    Per-neuroid correlation of ``prediction`` and ``target`` on the raw backing arrays.

    Same procedure as ``xarray_correlation``, except that ``correlation`` receives the
    ``.data`` of each per-neuroid slice instead of the xarray slice itself, so a function
    that works on plain arrays (e.g. one built on ``cp.corrcoef``) can be plugged in.
    ``correlation`` must return an ``(r, p)`` pair; only ``r`` is kept.

    Returns a ``Score`` along the neuroid dimension that carries the target's coordinates
    for that dimension.
    """
    # sort both inputs identically, then check that they line up
    prediction = prediction.sortby([correlation_coord, neuroid_coord])
    target = target.sortby([correlation_coord, neuroid_coord])
    stimuli_match = np.array(prediction[correlation_coord].values == target[correlation_coord].values)
    assert stimuli_match.all()
    neuroids_match = np.array(prediction[neuroid_coord].values == target[neuroid_coord].values)
    assert neuroids_match.all()

    # exactly one dimension may carry the neuroid coordinate
    neuroid_dims = target[neuroid_coord].dims
    assert len(neuroid_dims) == 1
    dim_name = neuroid_dims[0]

    correlations = []
    position = 0
    neuroid_count = len(target[neuroid_coord].values)
    while position < neuroid_count:
        # `isel` is about 10x faster than `sel`
        target_values = target.isel(**{dim_name: position}).data
        prediction_values = prediction.isel(**{dim_name: position}).data
        r, _ = correlation(target_values, prediction_values)
        correlations.append(r)
        position += 1

    # package the result with the neuroid-level coordinates of the target
    score_coords = {}
    for name, dims, values in walk_coords(target):
        if dims == neuroid_dims:
            score_coords[name] = (dims, values)
    return Score(correlations, coords=score_coords, dims=neuroid_dims)

In [5]:
# Load a subset of the data: the public IT assembly, with repetitions kept separate
# (average_repetitions=False) and no time interval specified.
region = "IT"
access = "public"
time_interval = None
assembly_repetition = load_assembly(average_repetitions=False, 
                                    region=region, 
                                    access=access, time_interval=time_interval)

# Looks like 148480 trials across 168 neural sites
assembly_repetition.shape

(148480, 168)

In [6]:
import cupy as cp 
import xarray as xr

In [7]:
# Convert the underlying data to CuPy
# NOTE(review): judging by the outputs of the next two cells (`.device` resolves and
# `type(data_cupy.data)` is a cupy MemoryPointer), this appears to yield a bare CuPy array
# rather than an xarray object, i.e. the coordinates are not carried along — confirm.
data_cupy = assembly_repetition.map_blocks(cp.asarray)
data_cupy.shape

(148480, 168)

In [8]:
# Check which device holds the converted array
data_cupy.device

<CUDA Device 0>

In [9]:
# Inspect the type of the `.data` attribute after the conversion
print(f"DataArray type after conversion: {type(data_cupy.data)}")

DataArray type after conversion: <class 'cupy.cuda.memory.MemoryPointer'>


In [10]:
# Convert the underlying data to CuPy while keeping the Xarray structure:
# swap only the backing array of a copy, leaving coordinates and dims in place.
data_cupy = assembly_repetition.copy(deep=True)  # deep copy so `assembly_repetition` itself stays numpy-backed
data_cupy.data = cp.asarray(data_cupy.data)  # replace the backing array with its CuPy counterpart

# Now, data_cupy is still an Xarray DataArray but using CuPy arrays under the hood;
# the two checks below confirm both properties.
print(f"DataArray type after conversion: {type(data_cupy.data)}")
print(f"Is it still an Xarray DataArray? {isinstance(data_cupy, xr.DataArray)}")

DataArray type after conversion: <class 'cupy.ndarray'>
Is it still an Xarray DataArray? True


In [11]:
# Display the CuPy-backed assembly
data_cupy

In [12]:
# Names of the coordinates / dimension the consistency computation operates on.
split_coord, stimulus_coord = 'repetition', 'stimulus_id'
neuroid_dim, neuroid_coord = 'neuroid', 'neuroid_id'
cross_validation_kwargs = None  # no overrides; the defaults defined below apply

# Building blocks: Pearson consistency between halves plus the Spearman-Brown correction,
# combined by the split-half wrapper.
consistency_metric = PearsonCorrelation(
    stimulus_coord=stimulus_coord, neuroid_dim=neuroid_dim, neuroid_coord=neuroid_coord)
correction = SpearmanBrownCorrection()
_consistency = SplitHalfWrapper(
    split_coord=split_coord, consistency_metric=consistency_metric, correction=correction)

# Both names refer to the metric's aggregation method.
aggregate = _aggregate = consistency_metric.aggregate

# Cross-validation settings: split over `split_coord`, half of the values per side.
cross_validation_defaults = {
    'train_size': 0.5,
    'split_coord': split_coord,
    'stratification_coord': None,
    'unique_split_values': True,
}
# Overrides (if any) win over the defaults.
cross_validation_kwargs = dict(cross_validation_defaults, **(cross_validation_kwargs or {}))
_cross_validation = CrossValidationSingle(**cross_validation_kwargs)

In [13]:
# assembly_repetition.values, data_cupy.data.device

In [14]:
# result = _cross_validation(assembly_repetition, apply=_consistency, aggregate=_aggregate)
# result

In [15]:
# Run the split-half consistency on the CuPy-backed assembly: `_consistency` is applied
# per split and `_aggregate` is handed over as the aggregation function.
result = _cross_validation(data_cupy, apply=_consistency, aggregate=_aggregate)
result

cross-validation: 100%|██████████| 10/10 [00:37<00:00,  3.72s/it]


In [24]:
# let's manually step through the CrossValidationSingle step
# NOTE: `_split` is a private attribute of the cross-validation object; it is accessed here
# only to inspect the split values and the (train, test) index pairs it produces.
cross_validation_values, splits = _cross_validation._split.build_splits(data_cupy)
cross_validation_values.shape, splits

((51,),
 [(array([ 4, 34, 14, 31, 10, 28, 45, 35, 18, 20, 25,  6,  7, 48,  1, 16,  0,
          15,  5, 11,  9,  8, 12, 43, 37]),
   array([22,  2, 49, 26, 33, 44, 30, 50, 32, 27,  3, 29, 47, 41, 39, 21, 40,
          38, 19, 24, 13, 42, 17, 46, 36, 23])),
  (array([ 5, 27, 40, 37, 36,  3,  9, 22, 39, 47,  7, 23, 34, 25,  4,  0, 15,
          18, 41, 45, 16, 26, 12, 43, 49]),
   array([29, 20, 32, 28, 14, 17, 35, 11,  1, 13, 46, 33, 21, 24, 48, 31, 30,
          50, 42, 10,  2, 44,  8,  6, 19, 38])),
  (array([35, 17, 28, 23, 10, 41, 32, 18,  4, 42, 30, 20, 46, 43,  8, 13,  7,
          11, 49, 27, 12,  2, 50, 21,  6]),
   array([34,  3, 45,  5, 33, 40, 31, 39, 14, 44, 48, 29, 16, 26, 15, 22, 47,
          19,  9,  1, 38,  0, 37, 36, 24, 25])),
  (array([ 9, 30, 15, 27, 24, 34,  6, 43, 45, 14, 40, 19, 41, 21, 25, 32, 10,
          49, 33, 13, 37, 22, 48, 20, 42]),
   array([44, 12,  1, 35, 26, 11, 38, 36, 23,  3, 50,  0, 29, 16, 46, 18, 28,
          17, 31,  8, 47,  4,  5, 39,  7,  2]

In [127]:
from tqdm import tqdm
from brainio.transform import subset
from brainscore_vision.benchmark_helpers.neural_common import average_repetition
from brainscore_vision.metrics import Score
from scipy.stats import pearsonr

# Assembly to step through by hand. The numpy-backed assembly is used here; assign
# `data_cupy` instead to repeat the walkthrough on the CuPy-backed copy.
assembly = assembly_repetition
split_scores = []  # would collect one score per split in the full loop (not run here)

# Only the first split is needed for the manual walkthrough, so index it directly
# instead of entering the cross-validation loop and breaking out of it.
split_iterator = 0
train_indices, test_indices = splits[split_iterator]
train_values, test_values = cross_validation_values[train_indices], cross_validation_values[test_indices]

# Restrict the assembly to the split values of each half.
train = subset(assembly, train_values, dims_must_match=False)
test = subset(assembly, test_values, dims_must_match=False)

# For reference, the full loop would continue for every split with:
#     split_score = _consistency(train, test)
#     split_score = split_score.expand_dims('split')
#     split_score['split'] = [split_iterator]
#     split_scores.append(split_score)
# and finish with: split_scores = Score.merge(*split_scores)

cross-validation:   0%|          | 0/10 [00:00<?, ?it/s]


In [128]:
# Average each half over the split coordinate with the module-level helper
half1 = _average_repetitions(train, split_coord)
half2 = _average_repetitions(test, split_coord)

In [40]:
# Re-create the consistency metric (same arguments as in the setup cell earlier in the notebook)
consistency_metric = PearsonCorrelation(stimulus_coord=stimulus_coord, 
                                        neuroid_dim=neuroid_dim,
                                        neuroid_coord=neuroid_coord)

In [None]:
import time



In [129]:
# Time the scipy-based per-neuroid correlation. `time.perf_counter()` is a monotonic,
# high-resolution clock, which makes it the right timer for sub-second intervals
# (`time.time()` is a wall clock and can jump or have coarse resolution).
start = time.perf_counter()
result = xarray_correlation(half1, half2, neuroid_coord=neuroid_coord, correlation_coord=stimulus_coord,
                            correlation=pearsonr)
dur = time.perf_counter() - start
print(dur)
result

0.41633081436157227


In [130]:
# Rough speed-up factor: ~0.4 s for the scipy run vs. ~0.15 s for the CuPy run
# (rounded from the durations printed by the two timing cells)
.4 / .15

2.666666666666667

In [126]:
# Time the CuPy-based per-neuroid correlation. `time.perf_counter()` is a monotonic,
# high-resolution clock, which makes it the right timer for sub-second intervals
# (`time.time()` is a wall clock and can jump or have coarse resolution).
start = time.perf_counter()
result = xarray_correlation_cp(half1, half2, neuroid_coord=neuroid_coord, correlation_coord=stimulus_coord,
                               # (r, p) pair; no p-value is computed, hence None
                               correlation=lambda x, y: (cp.corrcoef(x, y)[0, 1].item(), None))
dur = time.perf_counter() - start
print(dur)
result

0.15602660179138184


In [103]:
# consistency_metric(half1, half2)

In [123]:
# Spot check: correlation between column 0 of the two halves, computed directly with CuPy
cp.corrcoef(half1.data[:,0], half2.data[:,0])[0,1].item()

0.5560557118368319

In [92]:
# Shapes of the two halves after transposing their backing arrays
half1.data.transpose().shape, half2.data.transpose().shape

((168, 3200), (168, 3200))

In [87]:
# Peek at the transposed backing array of the first half
half1.data.transpose()

array([[-0.13490714,  0.1940739 ,  0.15228647, ...,  0.4535077 ,
         0.2627261 , -0.5260907 ],
       [-0.43063596,  0.31669247, -0.66462725, ...,  0.366535  ,
        -0.07539596, -0.106332  ],
       [-0.45261225,  0.5510013 , -0.35180023, ...,  0.612802  ,
        -0.6328119 , -0.42061844],
       ...,
       [ 0.240606  ,  0.39146066,  0.5492965 , ...,  0.41977802,
         0.11960585,  0.11060322],
       [ 0.13254511, -0.25623715,  0.54298735, ...,  0.50503   ,
         0.3500974 , -0.24538483],
       [ 0.17550433,  0.22261089,  0.01614855, ...,  0.31118324,
        -0.04281639, -0.15389812]], dtype=float32)