<a href="https://colab.research.google.com/github/jfpva/cmr_strain_synchrony_analysis/blob/main/cmr_strain_synchrony_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CMR Strain Synchrony Analysis 

Perform synchrony analysis on regional strain curves exported in scientific reports from Circle cvi42.

Based on methods described in 
> Balasubramanian S, Harrild DM, Kerur B, Marcus E, del Nido P, Geva T, Powell AJ. Impact of surgical pulmonary valve replacement on ventricular strain and synchrony in patients with repaired tetralogy of Fallot: a cardiovascular magnetic resonance feature tracking study. J Cardiovasc Magn Reson. 2018;20(1):1–11. DOI: [10.1186/S12968-018-0460-0](https://doi.org/10.1186/S12968-018-0460-0)

Code repository: [github.com/jfpva/cmr_strain_synchrony_analysis](https://github.com/jfpva/cmr_strain_synchrony_analysis)

## Upload Data

***user input required***

In [None]:
# based on example in https://colab.research.google.com/notebooks/io.ipynb#scrollTo=vz-jH8T_Uk2c

from google.colab import files

# Prompt the user for files; the result is indexed by file name below and in
# later cells (per the linked Colab example it maps file name -> file content).
uploaded = files.upload()

# Echo each uploaded file name with the length of its content as a sanity check.
for fn, content in uploaded.items():
  print('Uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(content)))

## Import Dependencies

In [None]:
import numpy as np
import csv
import codecs

## Define Local Functions

### Parse Data

In [None]:
def parse_data(filename,tablename,ventriclename='Left Ventricle'):
  """Parse one table of regional strain data from a Circle cvi42 scientific report.

    The report is read as tab-delimited text. A table is located by its header
    row, whose first cell starts with `ventriclename` and whose second cell
    starts with `tablename`. Relative to that header row, the row two below
    holds the time values and the sixteen rows after that hold one segment
    each (first cell = segment number, remaining cells = strain values).

    Parameters:
      filename (str): name of .txt file with source data
      tablename (str): name of table of regional strain values to parse
      ventriclename (str): either 'Left Ventricle' or 'Right Ventricle'

    Returns:
      data (dict): dictionary of parsed data with following keys
        'Filename' (str): name of source data file
        'Description' (str): name of table of regional strain
        'Time' (list of int): timing of strain curves, in milliseconds
        'Strain' (list of list of float): strain curves for each regional
          segment, as percent change, ordered by segment number

      If no header row matches, only 'Filename' is present. If several header
      rows match, the values from the last one in the file are returned.

    Raises:
      ValueError: if a time or strain cell cannot be converted to a number
      IndexError: if the file ends before the table is complete
  """
  nsegments = 16  # number of segment rows read from every table

  # Read the file as latin1 text and strip NUL characters before handing the
  # lines to the csv parser. The context manager guarantees the file is closed,
  # and the plain 'r' mode avoids the 'U' flag that Python 3.11 removed.
  with open(filename, 'r', encoding='latin1', newline='') as doc:
    reader = csv.reader((line.replace('\x00', '') for line in doc), delimiter="\t")
    rows = [row for row in reader if row]  # drop empty rows

  # filename of source data
  data = {'Filename': filename}

  # search through raw data and extract table of interest
  for i, header in enumerate(rows):
    if len(header) < 2:
      continue
    if not (header[0].startswith(ventriclename) and header[1].startswith(tablename)):
      continue

    # description of table
    data['Description'] = header[0] + header[1]

    # timing data, in milliseconds (first cell of the row is a label)
    data['Time'] = [int(value) for value in rows[i+2][1:]]

    # strain data, as percent change; rows are placed by their segment number
    first = i + 3
    ncols = len(rows[first]) - 1
    strain = [[0] * ncols for _ in range(nsegments)]
    for j in range(first, first + nsegments):
      segment_no = int(rows[j][0])
      for k in range(1, len(rows[j])):
        strain[segment_no-1][k-1] = float(rows[j][k])
    data['Strain'] = strain

  return data

### Print Strain Table 

In [None]:
def print_table(title,time,strain):
  """Print a formatted table of regional strain data.

      The title is followed by a header row of time values and then one row
      per segment, numbered from 1 in the order given.

      Parameters:
        title (str): description of data
        time (list of int): timing of strain curves
        strain (list of list of float): strain curves for each regional segment

      Returns:
        none
  """
  print(title)
  print('')

  # header row: label followed by right-aligned time values
  time_cells = ''.join(f"{value:>7}" + '  ' for value in time)
  print(f"{'time (ms)':<8}" + ' ' + time_cells)
  print('')

  # one row per segment: label, 1-based segment number, strain to 3 decimals
  for segmentnum, curve in enumerate(strain, start=1):
    strain_cells = ''.join(f"{value:>7.3f}" + '  ' for value in curve)
    print(f"{'seg':<3}" + ' ' + f"{segmentnum:>2}" + '    ' + strain_cells)

  print('')

### Temporally Align Strain Curves

In [None]:
def temporal_align(strain):
  """Temporally align regional strain data so that zero strain values occur at
     time zero.

      For every curve that contains an exact zero, the index of its first zero
      is recorded. All curves are then circularly shifted by the most common
      of those indices, so the same shift is applied to every segment.

      Parameters:
        strain (list of list of float): strain curves for each regional segment

      Returns:
        strain (list of list of float): temporally aligned strain curves

        The input list is modified in place and also returned. If no curve
        contains a zero value there is nothing to align on, and the curves are
        returned unshifted.
  """

  from statistics import mode

  # index of the first zero value in each curve that has one
  offsets = []
  for curve in strain:
    first_zero = next((k for k, value in enumerate(curve) if value == 0), None)
    if first_zero is not None:
      offsets.append(first_zero)

  # mode() raises on an empty sequence, so bail out when no zero was found
  if not offsets:
    return strain

  shift = -mode(offsets)  # shift by most common zero value index

  for segment, curve in enumerate(strain):
    strain[segment] = np.roll(curve, shift).tolist()  # apply circular shift

  return strain

### Find Peak Strain Time

In [None]:
def time_to_peak(time,strain):
  """Find the time at which peak strain occurs.

      Peak strain is the sample with the largest absolute value; when several
      samples tie, the earliest one is used.

      Parameters:
        time (list of int): timing of strain curve
        strain (list of float): strain curve

      Returns:
        tp (int): time of peak strain
        ip (int): index of peak strain

        tp and ip are both float nan if strain contains missing (nan) values
  """
  # a curve with any missing value has no defined peak
  if np.isnan(strain).any():
    return float('nan'), float('nan')

  # index of the first sample with the largest magnitude
  ip = max(range(len(strain)), key=lambda k: abs(strain[k]))
  return time[ip], ip

### Calculate Normalized Cross-Correlation

In [None]:
def normalized_cross_correlation(t,s1,s2,rr):
  """Calculate normalized cross-correlation between two strain curves.

      s2 is circularly shifted over a range of offsets centred on zero and
      correlated with s1 at each offset. Both curves are mean-centred and
      scaled by their standard deviation (s1 additionally by its length)
      before correlating.

      Parameters:
        t (list of int): timing of strain curves
        s1 (list of float): strain curve
        s2 (list of float): strain curve
        rr (int): R-R interval

        t, s1 and s2 are assumed to be the same length

      Returns:
        ncc (list of float): normalized cross-correlation values
        lag (list of int): offset applied to s2

        each element in lag corresponds to elements in ncc
  """
  # normalize signals
  ref = (s1 - np.mean(s1)) / (np.std(s1) * len(s1))
  moving = (s2 - np.mean(s2)) / (np.std(s2))

  # range of sample shifts, centred on zero
  span = len(t) - 1
  first_shift = -int(np.floor(span/2))
  last_shift = int(np.ceil(span/2))

  ncc = []  # normalized cross-correlation
  lag = []  # offset lag applied to s2
  for shift in range(first_shift, last_shift+1):
    # a negative shift indexes t from its end; subtracting rr turns that
    # time into the equivalent negative time shift
    lag.append(t[shift] if shift >= 0 else t[shift] - rr)
    ncc.append(np.correlate(ref, np.roll(moving, shift), 'valid')[0])
  return ncc, lag

### Calculate Cross-Correlation Delay

In [None]:
def cross_correlation_delay(t,s1,s2,rr):
  """Calculate cross-correlation delay between two strain curves.

     Cross-correlation delay is the lag at which the normalized
     cross-correlation is largest; the earliest such lag is used on ties.

      Parameters:
        t (list of int): timing of strain curves
        s1 (list of float): strain curve
        s2 (list of float): strain curve
        rr (int): R-R interval

        t, s1 and s2 are assumed to be the same length

      Returns:
        ccd (float): cross-correlation delay time, nan if any correlation
          value is nan
        ncc (list of float): normalized cross-correlation values
        lag (list of int): offset applied to s2

        each element in lag corresponds to elements in ncc
  """
  ncc, lag = normalized_cross_correlation(t,s1,s2,rr)

  # no delay can be chosen when any correlation value is undefined
  if np.isnan(ncc).any():
    return float('nan'), ncc, lag

  # TODO: verify if this should select on max(np.abs(ncc)) instead of max(ncc)
  best = max(range(len(ncc)), key=ncc.__getitem__)
  return lag[best], ncc, lag

### Plot Strain Curves

In [None]:
def plot_strain(data):
  """Generate figure showing strain curves, one subplot per segment.

      Subplots are laid out six per row in segment order. Each shows the
      strain curve and, when a peak strain time is available, a red marker
      at the peak.

      Parameters:
        data (dict): dictionary of data with following keys
          'Filename' (str): name of source data file
          'Description' (str): name of table of regional strain
          'Time' (list of int): timing of strain curves
          'Strain' (list of list of floats): strain curves for each regional segment
          'Peak Strain Time' (list of int): time that peak strain occurs for each segment (nan if undefined)
          'Peak Strain Index' (list of int): index of peak strain within each segment's curve

      Returns:
        none
  """

  import matplotlib.pyplot as plt

  nseg = len(data['Strain'])

  ncol = 6  # six segments at basal and mid ventricular levels
  nrow = int(np.ceil(nseg/ncol))

  # squeeze=False keeps axs two-dimensional even for a single row, so the
  # axs[row,col] indexing below works for any number of segments
  fig, axs = plt.subplots(nrow, ncol, figsize=(12, 8), sharex=True, sharey=True, squeeze=False)

  for seg in range(nseg):
    row, col = divmod(seg, ncol)
    ax = axs[row,col]
    tpeak = data['Peak Strain Time'][seg]
    if not np.isnan(tpeak):
      # mark the peak only when one was found for this segment
      ax.scatter(tpeak,data['Strain'][seg][data['Peak Strain Index'][seg]],label='Peak Strain',marker="o",color='red')
    ax.plot(data['Time'],data['Strain'][seg],label='Strain')
    ax.set_title('seg '+str(seg+1))
    if row == nrow-1:
      ax.set_xlabel('time (ms)')
    if col == 0:
      ax.set_ylabel('strain (%)')
    if row == 0 and col == 0:
      ax.legend()  # one legend is enough, all subplots share the same styling
    ax.grid(True)

  fig.suptitle(data['Filename']+'\n'+data['Description'],y=1.05)

  fig.tight_layout()
  plt.show()

  return

### Print Results

In [None]:
def print_results(data):
  """Print results of synchrony analysis.

      Output is the table description, a table of peak strain time per
      segment, a table of cross-correlation delay per segment pair, and three
      summary lines.

      Parameters:
        data (dict): dictionary of data with following keys
          'Filename' (str): name of source data file
          'Description' (str): name of table of regional strain
          'Time' (list of int:): timing of strain curves
          'Strain' (list of list of floats): strain curves for each regional segment
          'Peak Strain Time' (list of int): time that peak strain occurs for each segment
          'Maximum Difference in Peak Strain Times' (int): as described
          'Standard Deviation of Peak Strain Times' (float): as described
          'Cross-Correlation Delay Pairs' (list of str): description of pairs of opposing segments
          'Cross-Correlation Delays' (list of float): cross-correlation delay for pairs of opposing segments
          'Maximum Absolute Cross-Correlation Delay' (float): as described

      Returns:
        none
  """
  def emit_row(label, label_width, cells, cell_width):
    # left-aligned label, then right-aligned cells each followed by a space
    body = ''.join(f"{cell:>{cell_width}}" + ' ' for cell in cells)
    print(f"{label:<{label_width}}" + '  ' + body)

  peak_times = data['Peak Strain Time']

  print(data['Description'])
  print('')

  # peak strain time per segment (segments numbered from 1)
  emit_row('Segment Number', 21, range(1, len(peak_times)+1), 5)
  emit_row('Peak Strain Time (ms)', 21, peak_times, 5)
  print('')

  # cross-correlation delay per pair of opposing segments
  emit_row('Segment Pairs', 28, data['Cross-Correlation Delay Pairs'], 8)
  emit_row('Cross-Correlation Delay (ms)', 28, data['Cross-Correlation Delays'], 8)
  print('')

  # summary metrics
  max_diff = data['Maximum Difference in Peak Strain Times']
  std_dev = data['Standard Deviation of Peak Strain Times']
  max_ccd = data['Maximum Absolute Cross-Correlation Delay']
  print(f"{'Maximum Difference in Peak Strain Times':<40}" + ' = ' + f"{max_diff:<}" + ' ms')
  print(f"{'Standard Deviation of Peak Strain Times':<40}" + ' = ' + f"{std_dev:<.2f}" + ' ms')
  print(f"{'Maximum Absolute Cross-Correlation Delay':<40}" + ' = ' + f"{max_ccd:<}" + ' ms')

## Unit Tests for Cross-Correlation Functions

In [None]:
import unittest

class TestNormalizedCrossCorrelation(unittest.TestCase):
  """Checks of normalized_cross_correlation on a 21-sample synthetic curve.

  Correlation values are floating point, so they are compared with
  assertAlmostEqual rather than exact equality.
  """

  # shared fixture: unit time steps, so lag values equal sample shifts
  t  = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
  rr = 21
  s  = [0,1,2,5,9,14,23,31,37,39,36,26,11,11,8,3,1,1,2,2,1]

  def test_autocorrelation(self):
    '''A curve correlated with itself gives 1 at zero lag'''
    ncc, lag = normalized_cross_correlation(self.t,self.s,self.s,self.rr)
    result = ncc[lag.index(0)]
    expected = 1
    self.assertAlmostEqual(result, expected)

  def test_normalizedcorrelation(self):
    '''Scaling one curve does not change the normalized correlation'''
    s2 = [x*2 for x in self.s]
    ncc, lag = normalized_cross_correlation(self.t,self.s,s2,self.rr)
    result = ncc[lag.index(0)]
    expected = 1
    self.assertAlmostEqual(result, expected)

  def test_lag(self):
    '''A curve delayed by one sample correlates fully at lag -1'''
    s2 = [1,0,1,2,5,9,14,23,31,37,39,36,26,11,11,8,3,1,1,2,2]  # s rotated right by one sample
    ncc, lag = normalized_cross_correlation(self.t,self.s,s2,self.rr)
    result = ncc[lag.index(-1)]
    expected = 1
    self.assertAlmostEqual(result, expected)

# Run the tests defined above inside the notebook:
#   argv=['']    supplies only a placeholder program name, so unittest does not
#                try to interpret the host process's command-line arguments
#   verbosity=2  reports each test by name
#   exit=False   returns after the run instead of calling sys.exit()
unittest.main(argv=[''], verbosity=2, exit=False)

## Parse Data

In [None]:
data = []  # list of dictionaries, one per (file, table) in the order parsed

ventriclename = 'Left Ventricle'

# strain tables extracted from every uploaded report, in output order
tablenames = [
  'AHA Diagram Data - 2D Short Axis Results -  Radial Strain (%)',
  'AHA Diagram Data - 2D Short Axis Results -  Circumferential Strain (%)',
  'AHA Diagram Data - 2D Long Axis Results -  Radial Strain (%)',
  'AHA Diagram Data - 2D Long Axis Results -  Longitudinal Strain (%)',
]

for filename in list(uploaded.keys()):
  for tablename in tablenames:

    print(filename)
    print('')

    data.append( parse_data(filename,tablename,ventriclename) )
    print_table(data[-1]['Description'],data[-1]['Time'],data[-1]['Strain'])

    # separator after every table except the last table of the last file
    if tablename != tablenames[-1] or filename != list(uploaded.keys())[-1]:
      print('--------------------------------------------------------------------------------')
      print('')

## Process Data



Temporally align strain curves

In [None]:
# replace each table's strain curves with their temporally aligned version
for entry in data:
  entry['Strain'] = temporal_align(entry['Strain'])

Calculate peak strain time for each strain curve

In [None]:
for entry in data:
  # (time, index) of peak strain for every segment of this table
  peaks = [time_to_peak(entry['Time'], curve) for curve in entry['Strain']]
  entry['Peak Strain Time'] = [tp for tp, _ in peaks]
  entry['Peak Strain Index'] = [ip for _, ip in peaks]

Calculate synchrony metrics

1.   maximum difference in time-to-peak strain among any two segments
2.   standard deviation of the time-to-peak strain values for all segments 
3.   max cross-correlation delay between pairs of opposing segments
        - basal ventricle: 1v4, 2v5, 3v6
        - mid ventricle: 7v10, 8v11, 9v12
        - apical ventricle: 13v15, 14v16
 

In [None]:
for i in range(len(data)):

  # Clear Variables so values from a previous table cannot carry over
  t, dt, rr, ccd01v04, ccd02v05, ccd03v06, ccd07v10, ccd08v11, ccd09v12, ccd13v15, ccd14v16 = (None,)*11

  # Timing Values
  t  = data[i]['Time']
  dt = np.mean(np.diff(t))  # mean time interval between cardiac phases
  rr = int(np.round(t[-1]+dt))  # estimated R-R interval: last phase time plus one interval
  data[i]['R-R Interval'] = rr

  # Calculate Maximum Difference in Peak Strain Times (segments without a peak time are ignored)
  data[i]['Maximum Difference in Peak Strain Times'] = np.nanmax(data[i]['Peak Strain Time']) - np.nanmin(data[i]['Peak Strain Time'])

  # Calculate Standard Deviation of Peak Strain Times (segments without a peak time are ignored)
  data[i]['Standard Deviation of Peak Strain Times'] = np.nanstd(data[i]['Peak Strain Time'])

  # Strain is indexed from 0, so segment N is data[i]['Strain'][N-1]

  # Calculate Cross-Correlation Delay for Basal LV Opposing Pairs
  ccd01v04, ncc01v04, lag01v04 = cross_correlation_delay(t,data[i]['Strain'][0],data[i]['Strain'][3],rr)
  ccd02v05, ncc02v05, lag02v05 = cross_correlation_delay(t,data[i]['Strain'][1],data[i]['Strain'][4],rr)
  ccd03v06, ncc03v06, lag03v06 = cross_correlation_delay(t,data[i]['Strain'][2],data[i]['Strain'][5],rr)

  # Calculate Cross-Correlation Delay for Mid LV Opposing Pairs
  ccd07v10, ncc07v10, lag07v10 = cross_correlation_delay(t,data[i]['Strain'][6],data[i]['Strain'][9],rr)
  ccd08v11, ncc08v11, lag08v11 = cross_correlation_delay(t,data[i]['Strain'][7],data[i]['Strain'][10],rr)
  ccd09v12, ncc09v12, lag09v12 = cross_correlation_delay(t,data[i]['Strain'][8],data[i]['Strain'][11],rr)

  # Calculate Cross-Correlation Delay for Apical LV Opposing Pairs
  ccd13v15, ncc13v15, lag13v15 = cross_correlation_delay(t,data[i]['Strain'][12],data[i]['Strain'][14],rr)
  ccd14v16, ncc14v16, lag14v16 = cross_correlation_delay(t,data[i]['Strain'][13],data[i]['Strain'][15],rr)

  # Calculate Maximum Cross-Correlation Delay
  # (the labels must stay in the same order as the delays; the third pair is segments 3 and 6)
  data[i]['Cross-Correlation Delays'] = [ccd01v04, ccd02v05, ccd03v06, ccd07v10, ccd08v11, ccd09v12, ccd13v15, ccd14v16]
  data[i]['Cross-Correlation Delay Pairs'] = ['01v04', '02v05', '03v06', '07v10', '08v11', '09v12', '13v15', '14v16']
  if np.all(np.isnan(data[i]['Cross-Correlation Delays'])):
    # every pair was undefined, so there is no maximum to report
    data[i]['Maximum Absolute Cross-Correlation Delay'] = float('nan')
  else:
    data[i]['Maximum Absolute Cross-Correlation Delay'] = np.nanmax(np.abs(data[i]['Cross-Correlation Delays']))

## Summarize Data and Results

In [None]:
# for every parsed table: identify it, plot its strain curves, print its metrics
for idx, entry in enumerate(data):

  print(entry['Filename'])
  print('')

  print(entry['Description'])
  print('')

  plot_strain(entry)
  print('')

  print_results(entry)
  print('')

  # separator between tables, omitted after the last one
  if idx != len(data)-1:
    print('--------------------------------------------------------------------------------')
    print('')