In [None]:
cd /home/yuchen/pulse2percept

In [None]:
import pulse2percept as p2p
from pulse2percept.implants import ArgusII
import shapes
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import string
import skimage
from skimage.measure import label, regionprops, regionprops_table
from pulse2percept.viz import plot_argus_phosphenes
import math
from statistics import mean
from sklearn.linear_model import LinearRegression

In [None]:
# Load drawings for the three subjects, then keep only two-electrode trials
# that are labelled 'MultiElectrode'.
data = shapes.load_shapes(
    "/home/yuchen/shapes/data/shapes.h5",
    subjects=['12-005', '51-009', '52-001'],
    stim_class=None,
)
is_multi = (data['electrode2'] != '') & (data['stim_class'] == 'MultiElectrode')
data = data[is_multi].reset_index(drop=True)

In [None]:
# Stimulus classes remaining after filtering.
data.stim_class.unique()

In [None]:
# Re-derive the stimulus class from the electrode columns: rows without a second
# electrode are treated as single-electrode trials regardless of their label.
data['stim_class'] = data['stim_class'].mask(data['electrode2'] == str(), 'SingleElectrode')

# One row per unique stimulation condition.
cate = data[['subject', 'amp1', 'freq', 'electrode1', 'electrode2', 'stim_class']].drop_duplicates().reset_index(drop=True)


def _mean_centered_label(images, label_value):
    """Average, across drawings, the centred mask of connected component ``label_value``."""
    return np.mean([p2p.utils.images.center_image(skimage.measure.label(image) == label_value)
                    for image in images], axis=0)


def _mean_region_centroid(images, region_index):
    """Mean (row, col) centroid of the ``region_index``-th foreground region.

    Drawings with fewer than ``region_index + 1`` foreground regions are
    skipped. Returns None when no drawing has such a region.
    """
    centroids = []
    for image in images:
        regions = regionprops(skimage.measure.label(image > 0))  # computed once per drawing
        if len(regions) > region_index:
            centroids.append(regions[region_index].centroid)
    return np.mean(centroids, axis=0) if centroids else None


test = []  # averaged images of the two-electrode conditions only
img = []   # one averaged image per row of `cate`, in the same order
for i in range(len(cate)):
    # All trials of this condition. `freq` is part of the condition key, so it
    # must be matched too; otherwise conditions differing only in frequency
    # would be averaged together and share one image.
    df = data[(data['subject'] == cate['subject'][i]) &
              (data['amp1'] == cate['amp1'][i]) &
              (data['freq'] == cate['freq'][i]) &
              (data['electrode1'] == cate['electrode1'][i]) &
              (data['electrode2'] == cate['electrode2'][i]) &
              (data['stim_class'] == cate['stim_class'][i])
             ].reset_index(drop=True)
    if df.empty:
        img.append(np.zeros((384, 512)))
        continue
    if df.electrode2[0] == str():
        # Single-electrode conditions are not composited here; append a blank
        # image so that `img` stays aligned with the rows of `cate`.
        img.append(np.zeros(np.shape(df.image[0])))
        continue

    # Average each blob separately after centring it, then move each average
    # back to the mean location at which that blob was drawn.
    phosphene1 = _mean_centered_label(df.image, 1)
    phosphene2 = _mean_centered_label(df.image, 2)
    y_center, x_center = phosphene1.shape[0] / 2, phosphene1.shape[1] / 2

    phosphene1_centroid = _mean_region_centroid(df.image, 0)
    if phosphene1_centroid is not None:
        phosphene1 = p2p.utils.images.shift_image(phosphene1, phosphene1_centroid[1] - x_center, phosphene1_centroid[0] - y_center)

    # Only place a second blob if one exists and at least one drawing has a
    # second foreground region to locate it (avoids dividing by zero).
    phosphene2_centroid = _mean_region_centroid(df.image, 1) if np.sum(phosphene2) != 0 else None
    if phosphene2_centroid is not None:
        phosphene2 = p2p.utils.images.shift_image(phosphene2, phosphene2_centroid[1] - x_center, phosphene2_centroid[0] - y_center)
        image_temp = phosphene1 + phosphene2
    else:
        image_temp = phosphene1
    test.append(image_temp)
    img.append(image_temp)

cate['avg_image'] = img

In [None]:
# Conditions for subject 12-005. Reset the index so that loops of the form
# `cate_temp[col][i]` for i in range(len(cate_temp)) address every row, and so
# that later cells add columns to an independent frame rather than a slice.
cate_temp = cate[cate['subject'] == '12-005'].reset_index(drop=True)

In [None]:
# Per-electrode measurements for subject 12-005, one value per electrode in the
# order A1..A10, B1..B10, ..., F1..F10 produced by the loop below.
# NOTE(review): units/definitions are not stated in this notebook, and the
# trailing -1 looks like a "missing" placeholder for F10 — TODO confirm.
distance = [0,0,0,2,7,7,0,0,5,0,
            0,0,4,11,13,15,15,3,7,4,
            0,0,15,16,17,17,19,16,9,4,
            0,0,16,19,15,17,22,25,13,10,
            0,0,8,15,14,13,17,23,14,2,
            0,0,10,15,13,12,9,11,5,-1]

threshold = [20,24,52,81,145,129,153,97,81,65,
             24,40,65,125,233,169,161,169,65,24,
             36,0,145,218,177,145,117,97,0,36,
             20,0,169,185,169,145,97,81,81,40,
             28,36,93,169,169,169,129,93,36,44,
             0,32,73,93,125,129,85,48,56,36]

# Not used in this cell.
thickness = [19,18,23,19,17,18,18,24,20,22,
             22,24,19,15,16,14,14,20,24,21,
             24,26,16,13,16,16,16,15,19,23,
             23,27,15,17,17,17,12,9,18,19,
             25,28,22,17,19,20,15,9,18,24,
             23,27,18,17,16,21,22,22,26,-1]


lst_d = []  # Euclidean distance of each electrode's (x, y) from the origin
lst_e = []  # electrode names
lst_s = []  # subject labels
subject_list = ['12-005']
# `data` is intentionally not filtered here: its subject column still holds raw
# IDs such as '12-005', so filtering on "S2" would leave it empty.

for subj in range(len(subject_list)):
    s2 = shapes.subject_params[subject_list[subj]]
    implant,model = shapes.model_from_params(s2)
    for i in string.ascii_uppercase[0:6]:
        for j in range(1,11):
            electrode = i + str(j)
            lst_d.append(math.sqrt(implant[electrode].x**2 +implant[electrode].y**2 ))
            lst_e.append(electrode)
            # Label scheme assumes subject_list == ['12-005'] -> 'S2'.
            lst_s.append('S' + str(subj+2))

df = pd.DataFrame({'electrode':lst_e,
                   'subject':lst_s,
                   'distance':distance,
                   'threshold':threshold,
                  'distance_to_fovea':lst_d })

In [None]:
# First ten conditions for subject 12-005.
cate_temp.head(10)

In [None]:
def _electrode_value(table, electrode, column):
    """Return ``column`` for the first row of ``table`` whose 'electrode' matches."""
    return table.loc[table['electrode'] == electrode, column].tolist()[0]


lst_dti = []  # mean 'distance' of the two stimulated electrodes
lst_th = []   # mean 'threshold' of the two stimulated electrodes
lst_dtf = []  # mean 'distance_to_fovea' of the two stimulated electrodes

# Iterate over the column values directly so the loop does not depend on the
# row labels of `cate_temp` (it may be a filtered slice of `cate`).
for e1, e2 in zip(cate_temp['electrode1'], cate_temp['electrode2']):
    lst_dti.append(mean([_electrode_value(df, e1, 'distance'),
                         _electrode_value(df, e2, 'distance')]))
    lst_th.append(mean([_electrode_value(df, e1, 'threshold'),
                        _electrode_value(df, e2, 'threshold')]))
    lst_dtf.append(mean([_electrode_value(df, e1, 'distance_to_fovea'),
                         _electrode_value(df, e2, 'distance_to_fovea')]))

# `assign` returns a new frame (no chained assignment on a slice); the averaged
# threshold is stored alongside the other two features.
cate_temp = cate_temp.assign(lst_dti=lst_dti, lst_th=lst_th, lst_dtf=lst_dtf)

In [None]:
# Number of conditions per (electrode pair, amplitude).
group_keys = ['electrode1', 'electrode2', 'amp1']
cate_temp.groupby(group_keys).count()

In [None]:
# One grayscale figure per averaged image, ordered by electrode pair then amplitude.
cate_temp_temp = (
    cate_temp.copy()
    .sort_values(by=['electrode1', 'electrode2', 'amp1'])
    .reset_index(drop=True)
)
for e_first, e_second, amp_value, avg in zip(cate_temp_temp['electrode1'].to_numpy(),
                                             cate_temp_temp['electrode2'].to_numpy(),
                                             cate_temp_temp['amp1'].to_numpy(),
                                             cate_temp_temp['avg_image'].to_numpy()):
    fig_pair, ax_pair = plt.subplots()
    ax_pair.set_title(' '.join([str(e_first), str(e_second), str(amp_value)]))
    ax_pair.imshow(avg, cmap='gray')

In [None]:
# Amplitude-2 pairs whose mean distance_to_fovea lies in [1650, 3250],
# shown in order of their mean 'distance' (lst_dti).
in_band = (cate_temp['amp1'] == 2) & cate_temp['lst_dtf'].between(1650, 3250)
cate_temp_temp = cate_temp[in_band].sort_values(by=['amp1', 'lst_dti']).reset_index(drop=True)
for e_first, e_second, amp_value, dti, avg in zip(cate_temp_temp['electrode1'].to_numpy(),
                                                  cate_temp_temp['electrode2'].to_numpy(),
                                                  cate_temp_temp['amp1'].to_numpy(),
                                                  cate_temp_temp['lst_dti'].to_numpy(),
                                                  cate_temp_temp['avg_image'].to_numpy()):
    fig_pair, ax_pair = plt.subplots()
    ax_pair.set_title(' '.join([str(e_first), str(e_second), str(amp_value), str(dti)]))
    ax_pair.imshow(avg, cmap='gray')

In [None]:
# Conditions for subject 51-009, annotated with the mean distance_to_fovea of
# each electrode pair.
cate_temp = cate[cate['subject'] == '51-009'].reset_index(drop=True)

subject_list = ['51-009']
lst_d, lst_e, lst_s = [], [], []
for subject_id in subject_list:
    s2 = shapes.subject_params[subject_id]
    implant, model = shapes.model_from_params(s2)
    for row_letter in string.ascii_uppercase[:6]:
        for col_number in range(1, 11):
            electrode = f'{row_letter}{col_number}'
            lst_d.append(math.sqrt(implant[electrode].x ** 2 + implant[electrode].y ** 2))
            lst_e.append(electrode)
            lst_s.append('S3')
df = pd.DataFrame({'electrode': lst_e, 'subject': lst_s, 'distance_to_fovea': lst_d})


def _fovea_distance(electrode_name):
    """distance_to_fovea of the first row in ``df`` for ``electrode_name``."""
    return df[df['electrode'] == electrode_name]['distance_to_fovea'].tolist()[0]


lst_dtf = [mean([_fovea_distance(first), _fovea_distance(second)])
           for first, second in zip(cate_temp['electrode1'], cate_temp['electrode2'])]
cate_temp['lst_dtf'] = lst_dtf

In [None]:
# Amplitude-2 pairs for subject 51-009, ordered by mean distance_to_fovea.
cate_temp_temp = (
    cate_temp[cate_temp['amp1'] == 2]
    .sort_values(by=['amp1', 'lst_dtf'])
    .reset_index(drop=True)
)
for e_first, e_second, amp_value, dtf, avg in zip(cate_temp_temp['electrode1'].to_numpy(),
                                                  cate_temp_temp['electrode2'].to_numpy(),
                                                  cate_temp_temp['amp1'].to_numpy(),
                                                  cate_temp_temp['lst_dtf'].to_numpy(),
                                                  cate_temp_temp['avg_image'].to_numpy()):
    fig_pair, ax_pair = plt.subplots()
    ax_pair.set_title(' '.join([str(e_first), str(e_second), str(amp_value), str(round(dtf, 3))]))
    ax_pair.imshow(avg, cmap='gray')

In [None]:
data = shapes.load_shapes("/home/yuchen/shapes/data/shapes.h5", subjects=['12-005','51-009','52-001'],stim_class=None)
data = data[data['electrode2'] == str()].reset_index(drop=True)  # single-electrode trials only

# Per-subject drawing-canvas extent, looked up by subject ID. Assigning
# hard-coded per-subject row counts positionally would silently mislabel rows
# if the counts change or rows are not grouped in subject order; an unknown
# subject raises KeyError instead.
# NOTE(review): units presumably degrees of visual angle — TODO confirm.
xrange_by_subject = {'12-005': (-30, 30), '51-009': (-32.5, 32.5), '52-001': (-32, 32)}
yrange_by_subject = {'12-005': (-22.5, 22.5), '51-009': (-24.4, 24.4), '52-001': (-24, 24)}
data['xrange'] = [xrange_by_subject[s] for s in data['subject']]
data['yrange'] = [yrange_by_subject[s] for s in data['subject']]
data = data.replace({'12-005':'S2', '51-009':'S3', '52-001':'S4'})
data = data.rename(columns={"electrode1": "electrode",'amp1':'amp'})
data = data[['subject','freq','electrode','amp','xrange','yrange','image']]

# Per-electrode measurements for subject 12-005 (S2), ordered A1..A10, ..., F1..F10
# to match the electrode loop below. NOTE(review): units not stated here; the
# trailing -1 looks like a placeholder for F10 — TODO confirm.
distance = [0,0,0,2,7,7,0,0,5,0,
            0,0,4,11,13,15,15,3,7,4,
            0,0,15,16,17,17,19,16,9,4,
            0,0,16,19,15,17,22,25,13,10,
            0,0,8,15,14,13,17,23,14,2,
            0,0,10,15,13,12,9,11,5,-1]

threshold = [20,24,52,81,145,129,153,97,81,65,
             24,40,65,125,233,169,161,169,65,24,
             36,0,145,218,177,145,117,97,0,36,
             20,0,169,185,169,145,97,81,81,40,
             28,36,93,169,169,169,129,93,36,44,
             0,32,73,93,125,129,85,48,56,36]

# Not used in this cell.
thickness = [19,18,23,19,17,18,18,24,20,22,
             22,24,19,15,16,14,14,20,24,21,
             24,26,16,13,16,16,16,15,19,23,
             23,27,15,17,17,17,12,9,18,19,
             25,28,22,17,19,20,15,9,18,24,
             23,27,18,17,16,21,22,22,26,-1]


lst_d = []  # Euclidean distance of each electrode's (x, y) from the origin
lst_e = []  # electrode names
lst_s = []  # subject labels
subject_list = ['12-005']
data = data[data.subject == "S2"].reset_index(drop=True)

for subj in range(len(subject_list)):
    s2 = shapes.subject_params[subject_list[subj]]
    implant,model = shapes.model_from_params(s2)
    for i in string.ascii_uppercase[0:6]:
        for j in range(1,11):
            electrode = i + str(j)
            lst_d.append(math.sqrt(implant[electrode].x**2 +implant[electrode].y**2 ))
            lst_e.append(electrode)
            # Label scheme assumes subject_list == ['12-005'] -> 'S2'.
            lst_s.append('S' + str(subj+2))

df = pd.DataFrame({'electrode':lst_e,
                   'subject':lst_s,
                   'distance':distance,
                   'threshold':threshold,
                  'distance_to_fovea':lst_d })

df = df.merge(data, how = 'inner', on = ['electrode','subject'])
argus = ArgusII(x=-1896, y =-542, rot=-44, eye='RE')

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(34, 16))

# Hide tick marks on every panel before drawing.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

# (panel position, frequency in Hz, amplitude as a multiple of threshold, title)
panel_specs = [
    ((0, 0), 20, 1.5, 'Amplitude = 1.5xTh'),
    ((0, 1), 20, 2, 'Amplitude = 2xTh'),
    ((0, 2), 20, 5, 'Amplitude = 5xTh'),
    ((1, 0), 6, 1.5, 'Frequency = 6Hz'),
    ((1, 1), 40, 1.5, 'Frequency = 40Hz'),
    ((1, 2), 60, 1.5, 'Frequency = 60Hz'),
]
for (row, col), freq, amp, title in panel_specs:
    df_temp = df[(df.freq == freq) & (df.amp == amp)]
    plot_argus_phosphenes(df_temp, argus, scale=1, ax=axes[row, col])
    axes[row, col].set_title(title, size=22)


In [None]:
# Export the S2 (12-005) panel figure as a transparent PDF.
figure_path = '/home/yuchen/paper/6a_new. S12-005 Drawings.pdf'
fig.savefig(figure_path, transparent=True)

In [None]:
data = shapes.load_shapes("/home/yuchen/shapes/data/shapes.h5", subjects=['12-005','51-009','52-001'],stim_class=None)
data = data[data['electrode2'] == str()].reset_index(drop=True)  # single-electrode trials only

# Per-subject drawing-canvas extent, looked up by subject ID rather than
# assigned positionally from hard-coded per-subject row counts (which assumed
# exact counts and subject-ordered rows). Unknown subjects raise KeyError.
# NOTE(review): units presumably degrees of visual angle — TODO confirm.
xrange_by_subject = {'12-005': (-30, 30), '51-009': (-32.5, 32.5), '52-001': (-32, 32)}
yrange_by_subject = {'12-005': (-22.5, 22.5), '51-009': (-24.4, 24.4), '52-001': (-24, 24)}
data['xrange'] = [xrange_by_subject[s] for s in data['subject']]
data['yrange'] = [yrange_by_subject[s] for s in data['subject']]
data = data.replace({'12-005':'S2', '51-009':'S3', '52-001':'S4'})
data = data.rename(columns={"electrode1": "electrode",'amp1':'amp'})
data = data[['subject','freq','electrode','amp','xrange','yrange','image']]

lst_d = []     # Euclidean distance of each electrode's (x, y) from the origin
lst_e = []     # electrode names
lst_s = []     # subject labels
distance = []  # filled with 0 for every electrode
subject_list = ['51-009']
data = data[data.subject == "S3"].reset_index(drop=True)

for subj in range(len(subject_list)):
    subject_id = subject_list[subj]
    s2 = shapes.subject_params[subject_id]
    implant,model = shapes.model_from_params(s2)
    for i in string.ascii_uppercase[0:6]:
        for j in range(1,11):
            electrode = i + str(j)
            distance.append(0)
            lst_d.append(math.sqrt(implant[electrode].x**2 +implant[electrode].y**2 ))
            lst_e.append(electrode)
            lst_s.append('S3')
    # Electrodes to plot for this subject. Branch on the subject ID string; the
    # integer loop index `subj` can never equal these IDs.
    if subject_id == '12-005':
        active_electrodes = lst_e[:-1]
    elif subject_id == '52-001':
        active_electrodes = ['F1','F2','F4','F5','F6','F7','F8','F9','F10',
               'E1','E2','E3','E4','E5','E6','E7','E8','E9','E10',
               'D1','D2','D3','D4','D5','D6','D7','D8','D9','D10',
               'C3','C4','C5','C6','C7','C8','C9','C10',
               'B1','B3','B4','B5','B6','B7','B8','B9','B10',
               'A1','A2','A3','A4','A5','A6','A7','A8','A9','A10']
    else:
        active_electrodes = ['F1','F2','F3','F4','F5','F6','F7','F8','F9','F10',
               'E1','E2','E3','E4','E5','E6','E7','E8','E9','E10',
               'D1','D2','D3','D4','D5','D6','D7','D8','D9','D10',
               'C2','C3','C4','C5','C6','C7','C8','C9','C10',
               'B3','B4','B5','B6','B7','B8','B9','B10',
               'A3','A4','A5','A6','A7','A8','A9','A10']

df = pd.DataFrame({'electrode':lst_e,
                   'subject':lst_s,
                   'distance':distance,
                  'distance_to_fovea':lst_d })

df = df[df['electrode'].isin(active_electrodes)]
df = df.merge(data, how = 'inner', on = ['electrode','subject'])
argus = ArgusII(x=-1203, y =280, rot=-35, eye='RE')

In [None]:
# Frequencies and amplitudes available for this subject.
freq_values = df['freq'].unique().tolist()
amp_values = df['amp'].unique().tolist()
freq_values, amp_values

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(34, 16))

# Hide tick marks on every panel before drawing.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

# (panel position, frequency in Hz, amplitude as a multiple of threshold, title)
panel_specs = [
    ((0, 0), 20, 1.5, 'Amplitude = 1.5xTh'),
    ((0, 1), 20, 2, 'Amplitude = 2xTh'),
    ((0, 2), 20, 5, 'Amplitude = 5xTh'),
    ((1, 0), 6, 1.5, 'Frequency = 6Hz'),
    ((1, 1), 40, 1.5, 'Frequency = 40Hz'),
    ((1, 2), 60, 1.5, 'Frequency = 60Hz'),
]
for (row, col), freq, amp, title in panel_specs:
    df_temp = df[(df.freq == freq) & (df.amp == amp)]
    plot_argus_phosphenes(df_temp, argus, scale=1.5, ax=axes[row, col])
    axes[row, col].set_title(title, size=22)

In [None]:
# Export the S3 (51-009) panel figure as a transparent PDF.
figure_path = '/home/yuchen/paper/6b_new. S51-009 Drawings.pdf'
fig.savefig(figure_path, transparent=True)

In [None]:
data = shapes.load_shapes("/home/yuchen/shapes/data/shapes.h5", subjects=['12-005','51-009','52-001'],stim_class=None)
data = data[data['electrode2'] == str()].reset_index(drop=True)  # single-electrode trials only

# Per-subject drawing-canvas extent, looked up by subject ID rather than
# assigned positionally from hard-coded per-subject row counts (which assumed
# exact counts and subject-ordered rows). Unknown subjects raise KeyError.
# NOTE(review): units presumably degrees of visual angle — TODO confirm.
xrange_by_subject = {'12-005': (-30, 30), '51-009': (-32.5, 32.5), '52-001': (-32, 32)}
yrange_by_subject = {'12-005': (-22.5, 22.5), '51-009': (-24.4, 24.4), '52-001': (-24, 24)}
data['xrange'] = [xrange_by_subject[s] for s in data['subject']]
data['yrange'] = [yrange_by_subject[s] for s in data['subject']]
data = data.replace({'12-005':'S2', '51-009':'S3', '52-001':'S4'})
data = data.rename(columns={"electrode1": "electrode",'amp1':'amp'})
data = data[['subject','freq','electrode','amp','xrange','yrange','image']]

lst_d = []     # Euclidean distance of each electrode's (x, y) from the origin
lst_e = []     # electrode names
lst_s = []     # subject labels
distance = []  # filled with 0 for every electrode
subject_list = ['52-001']
data = data[data.subject == "S4"].reset_index(drop=True)

for subj in range(len(subject_list)):
    subject_id = subject_list[subj]
    s2 = shapes.subject_params[subject_id]
    implant,model = shapes.model_from_params(s2)
    for i in string.ascii_uppercase[0:6]:
        for j in range(1,11):
            electrode = i + str(j)
            distance.append(0)
            lst_d.append(math.sqrt(implant[electrode].x**2 +implant[electrode].y**2 ))
            lst_e.append(electrode)
            lst_s.append('S4')
    # Electrodes to plot for this subject. Branch on the subject ID string; the
    # integer loop index `subj` can never equal these IDs, which previously sent
    # 52-001 to the fallback list instead of its own.
    if subject_id == '12-005':
        active_electrodes = lst_e[:-1]
    elif subject_id == '52-001':
        active_electrodes = ['F1','F2','F4','F5','F6','F7','F8','F9','F10',
               'E1','E2','E3','E4','E5','E6','E7','E8','E9','E10',
               'D1','D2','D3','D4','D5','D6','D7','D8','D9','D10',
               'C3','C4','C5','C6','C7','C8','C9','C10',
               'B1','B3','B4','B5','B6','B7','B8','B9','B10',
               'A1','A2','A3','A4','A5','A6','A7','A8','A9','A10']
    else:
        active_electrodes = ['F1','F2','F3','F4','F5','F6','F7','F8','F9','F10',
               'E1','E2','E3','E4','E5','E6','E7','E8','E9','E10',
               'D1','D2','D3','D4','D5','D6','D7','D8','D9','D10',
               'C2','C3','C4','C5','C6','C7','C8','C9','C10',
               'B3','B4','B5','B6','B7','B8','B9','B10',
               'A3','A4','A5','A6','A7','A8','A9','A10']

df = pd.DataFrame({'electrode':lst_e,
                   'subject':lst_s,
                   'distance':distance,
                  'distance_to_fovea':lst_d })

df = df[df['electrode'].isin(active_electrodes)]
df = df.merge(data, how = 'inner', on = ['electrode','subject'])
argus = ArgusII(x=-1945, y =469, rot=-34, eye='RE')

In [None]:
# Frequencies and amplitudes available for this subject.
freq_values = df['freq'].unique().tolist()
amp_values = df['amp'].unique().tolist()
freq_values, amp_values

In [None]:
# Amplitudes tested at 40 Hz.
df.loc[df['freq'] == 40.0, 'amp'].unique()

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(34, 16))

# Hide tick marks on every panel before drawing.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

# (panel position, frequency in Hz, amplitude as a multiple of threshold, title)
# The frequency row uses 1.25xTh for this subject.
panel_specs = [
    ((0, 0), 20, 1.5, 'Amplitude = 1.5xTh'),
    ((0, 1), 20, 2, 'Amplitude = 2xTh'),
    ((0, 2), 20, 5, 'Amplitude = 5xTh'),
    ((1, 0), 6, 1.25, 'Frequency = 6Hz'),
    ((1, 1), 40, 1.25, 'Frequency = 40Hz'),
    ((1, 2), 60, 1.25, 'Frequency = 60Hz'),
]
for (row, col), freq, amp, title in panel_specs:
    df_temp = df[(df.freq == freq) & (df.amp == amp)]
    plot_argus_phosphenes(df_temp, argus, scale=0.7, ax=axes[row, col])
    axes[row, col].set_title(title, size=22)

In [None]:
# Export the S4 (52-001) panel figure as a transparent PDF.
figure_path = '/home/yuchen/paper/6c_new. S52-001 Drawings.pdf'
fig.savefig(figure_path, transparent=True)