In [41]:
import pandas as pd

In [42]:
from scipy import stats
import numpy as np

def calculate_d_prime(hits, misses, false_alarms, correct_rejections):
    """
    Compute the signal-detection sensitivity index d' = z(H) - z(FA).

    H is hits / (hits + misses) and FA is
    false_alarms / (false_alarms + correct_rejections); z is the inverse of
    the standard normal CDF. A rate of exactly 0 or 1 is replaced by
    1/(2N) or 1 - 1/(2N) respectively, where N is the number of trials of
    that kind, before it is converted to a z-score.

    Parameters:
    -----------
    hits: int
        Signal trials answered correctly
    misses: int
        Signal trials that were missed
    false_alarms: int
        Noise trials answered as if a signal were present
    correct_rejections: int
        Noise trials correctly rejected

    Returns:
    --------
    float: d-prime value, or NaN when there are no signal trials or no
        noise trials
    """
    signal_trials = hits + misses
    noise_trials = false_alarms + correct_rejections

    # A rate cannot be formed without at least one trial of each kind
    if signal_trials == 0 or noise_trials == 0:
        return np.nan

    def _bounded_rate(count, total):
        # Proportion count/total, moved off the extremes 1 and 0 by 1/(2*total)
        rate = count / total
        if rate == 1:
            rate = 1 - 1/(2*total)
        if rate == 0:
            rate = 1/(2*total)
        return rate

    # z-transform both rates; their difference is d'
    z_hit = stats.norm.ppf(_bounded_rate(hits, signal_trials))
    z_fa = stats.norm.ppf(_bounded_rate(false_alarms, noise_trials))
    return z_hit - z_fa

In [43]:
# Read the all-subjects task CSV, using its first column as the index
alltasks = pd.read_csv(
    '/BICNAS2/tuominen/ANM2_SCZ/code/wmtask_allsubjects.csv',
    index_col=0,
)
# Count the distinct values in the Subject column
len(alltasks['Subject'].unique())

56

In [44]:
# Total the outcome counts (hits, misses, false alarms, correct rejections)
# within every Subject x run x BlockCondition x group combination.
group_keys = ['Subject', 'run', 'BlockCondition', 'group']
count_columns = ['n_hit', 'n_miss', 'n_fa', 'n_cr']

condition_totals = alltasks.groupby(group_keys).agg(dict.fromkeys(count_columns, 'sum'))
aggregated = condition_totals.reset_index()

# d-prime for each aggregated row, computed from its summed counts
aggregated['d_prime'] = aggregated.apply(
    lambda totals: calculate_d_prime(
        totals['n_hit'],
        totals['n_miss'],
        totals['n_fa'],
        totals['n_cr'],
    ),
    axis='columns',
)

In [45]:
# Average across runs: one row per Subject x BlockCondition holding the mean
# of each numeric column.
# NOTE(review): numeric_only=True drops non-numeric columns, so `group` is lost
# here if it is stored as text, and a numeric `run` column would be averaged
# as well -- confirm both are intended.
by_subject_condition = aggregated.groupby(['Subject', 'BlockCondition'])
condition_means = by_subject_condition.mean(numeric_only=True)
avg_agg = condition_means.reset_index()

In [46]:
# Keep only the rows whose BlockCondition is exactly '2-back'
is_two_back = avg_agg['BlockCondition'].eq('2-back')
twoback_avg_agg = avg_agg.loc[is_two_back]

In [47]:
# List the distinct subjects that have a 2-back row
twoback_avg_agg['Subject'].unique()

array([  2,   4,   5,   6,   8,  10,  13,  17,  20,  21,  22,  23,  24,
        25,  26,  27,  29,  30,  32,  33,  36,  40,  41,  42,  43,  44,
        46,  48,  49,  51,  54,  56,  57,  61, 502, 504, 507, 508, 512,
       513, 515, 517, 523, 527, 529, 530, 531, 532, 533, 534, 535, 537,
       538, 539, 543, 545])

In [48]:
# Write the 2-back per-subject averages (including d_prime) to CSV
twoback_avg_agg.to_csv(
    path_or_buf='/BICNAS2/tuominen/ANM2_SCZ/code/2back_dprime_allsubjects.csv',
)