In [3]:
import math
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import pulse2percept as p2p
from pulse2percept.implants import ArgusII
from pulse2percept.model_selection import ParticleSwarmOptimizer
from pulse2percept.models import BiphasicAxonMapModel, AxonMapModel, BiphasicAxonMapSpatial

import shapes

In [4]:
# Features the estimators compute from each drawing/percept and compare.
# (Alternative previously tried here: mse_params = ['moments_central'])
mse_params = ['area', "major_axis_length", "minor_axis_length", 'orientation']
# Threshold handed to the estimators (presumably for binarizing images -- confirm
# in shapes.*Estimator); numerically 1/sqrt(e) = e^(-1/2) ~= 0.607.
threshold = 1 / np.sqrt(np.exp(1))
loss_fn = 'r2'
scale_features = False
verbose = False
# Model parameters being fitted, with their global (low, high) ranges.
# update_implant copies these parameters from the current model onto every
# candidate model it scores.
search_params = {
    'rho': (10, 1000),
    'axlambda': (10, 2500),
}
# Maximum number of alternating (implant location, model params) rounds per subject.
n_epochs = 20


In [5]:
# Two functions to alternate between
def update_implant(subject, current_pos, model, df, max_iter=5,
                   xy_deltas=(-50, -15, -5, 0, 5, 15, 50),
                   rot_deltas=(-2, -0.75, -0.25, 0, 0.25, 0.75, 2)):
    """Grid-search the implant offset around ``current_pos`` for one subject.

    Every candidate offset ``current_pos + (dx, dy, drot)`` with ``dx, dy`` in
    ``xy_deltas`` and ``drot`` in ``rot_deltas`` is scored in parallel. The
    model's current values of the ``search_params`` parameters are held fixed.

    Parameters
    ----------
    subject : str
        Subject id. It selects the rows of ``df`` and ``shapes.subject_params[subject]``.
    current_pos : array-like of length 3
        Current offset ``(dx, dy, drot)`` passed to ``shapes.model_from_params``.
    model : AxonMapModel or BiphasicAxonMapModel
        Model whose ``search_params`` values are copied onto each candidate model.
        A ``BiphasicAxonMapModel`` selects the biphasic estimator and implant.
    df : pandas.DataFrame
        Drawing data. It must contain 'subject', 'image', 'amp1', 'freq',
        'pdur' and 'electrode1' columns. It is not modified.
    max_iter : int
        Unused. Kept so the signature parallels ``update_params``.
    xy_deltas, rot_deltas : sequence of float
        Grid of position and rotation steps to try around ``current_pos``.

    Returns
    -------
    tuple (score, offset, implant)
        The best score, its offset, and a freshly built implant at that offset.
        If no candidate beats the starting score, the tuple is the starting
        score, ``current_pos`` and a fresh implant at ``current_pos``.
    """
    biphasic = type(model) == BiphasicAxonMapModel
    estimator_cls = shapes.BiphasicAxonMapEstimator if biphasic else shapes.AxonMapEstimator
    stim_cols = ['amp1', 'freq', 'pdur', 'electrode1']

    def make_estimator(implant_, model_):
        # A fresh estimator per call, so parallel workers never share one.
        return estimator_cls(implant=implant_, model=model_, mse_params=mse_params, threshold=threshold,
                             loss_fn=loss_fn, scale_features=scale_features, verbose=verbose)

    # .copy() so adding the feature columns below never writes into a view of df.
    data = df[df['subject'] == subject].copy()
    implant, _ = shapes.model_from_params(shapes.subject_params[subject], offset=current_pos, biphasic=biphasic)
    estimator = make_estimator(implant, model)

    y = estimator.compute_moments(data['image'])
    data[estimator._mse_params] = y
    y_averaged = data.groupby(stim_cols)[estimator._mse_params].mean()
    # Build x from y_averaged's (sorted) group index so that row i of x is exactly
    # the stimulus whose drawings were averaged into row i of y_averaged.
    # drop_duplicates() would keep first-appearance order instead.
    # NOTE(review): assumes estimator.score pairs x and y positionally -- confirm.
    x = y_averaged.index.to_frame(index=False)
    initial_score = estimator.score(x, y_averaged)
    print("starting score: {}, rho:{}, axlambda:{}".format(initial_score, model.rho, model.axlambda))

    offsets = [np.array([current_pos[0] + dx, current_pos[1] + dy, current_pos[2] + drot])
               for dx in xy_deltas
               for dy in xy_deltas
               for drot in rot_deltas]

    # Read once here; the workers below only read these values.
    target_params = estimator._mse_params
    yshape = estimator.yshape

    def score_offset(offset):
        """Score one candidate offset with a freshly built implant, model and estimator."""
        candidate_implant, candidate_model = shapes.model_from_params(
            shapes.subject_params[subject], offset=offset, biphasic=biphasic)
        candidate_model.xystep = 0.5
        candidate_model.ignore_pickle = True
        for param in search_params:
            setattr(candidate_model, param, getattr(model, param))
        candidate_estimator = make_estimator(candidate_implant, candidate_model)
        candidate_estimator._mse_params = target_params
        candidate_estimator.yshape = yshape
        score = candidate_estimator.score(x, y_averaged)
        print(".", end="")
        return score

    # score_offset is a closure, so this relies on parfor's default thread-based
    # scheduler. A process pool could not pickle it.
    scores = p2p.utils.parfor(score_offset, offsets)
    print()

    # NaN scores never compare smaller, so they are skipped.
    best_score = np.inf
    best_idx = None
    for idx, score in enumerate(scores):
        if score < best_score:
            best_score, best_idx = score, idx

    if best_idx is not None and best_score < initial_score:
        best_offset = offsets[best_idx]
        print("Better model found: {}".format(best_offset))
        best_implant, _ = shapes.model_from_params(shapes.subject_params[subject], offset=best_offset,
                                                   biphasic=biphasic)
        return best_score, best_offset, best_implant
    print("Couldn't find a better offset")
    # Rebuild rather than return the implant used above, in case scoring mutated it.
    fallback_implant, _ = shapes.model_from_params(shapes.subject_params[subject], offset=current_pos,
                                                   biphasic=biphasic)
    return initial_score, current_pos, fallback_implant
    

def update_params(implant, subject, model, df, max_iter=5, window=30):
    """Refine ``rho`` and ``axlambda`` with a local particle swarm search.

    Each parameter is searched in ``[max(value - window, 10), value + window]``
    around the model's current value, rather than its whole global range.
    Up to ``max_iter`` independent ParticleSwarmOptimizer runs are tried. The
    first run that beats the starting score is kept: the optimizer is expected
    to leave the found values set on ``model``. If no run improves the score,
    the starting values are written back onto ``model``.

    For a ``BiphasicAxonMapModel``, the estimator's size model is re-fit first
    with ``fit_size_model``. ``a5``/``a6`` are then tracked alongside the
    searched parameters.

    Parameters
    ----------
    implant : implant object
        Implant at the current offset, e.g. as returned by ``update_implant``.
    subject : str
        Subject id used to select the rows of ``df``.
    model : AxonMapModel or BiphasicAxonMapModel
        Model whose parameters are searched. It is mutated in place.
    df : pandas.DataFrame
        Drawing data with 'subject', 'image', 'amp1', 'freq', 'pdur' and
        'electrode1' columns. It is not modified.
    max_iter : int
        Maximum number of optimizer restarts.
    window : float
        Half-width of the local search box around each current value.

    Returns
    -------
    tuple (score, values, model)
        The accepted score, a dict of parameter values, and ``model``.
    """
    # Global (low, high) range per parameter. Its lower bound clips the local window.
    full_range = {'rho': (10, 1000), 'axlambda': (10, 2500)}
    local_range = {param: (max(getattr(model, param) - window, lo), getattr(model, param) + window)
                   for param, (lo, _hi) in full_range.items()}
    stim_cols = ['amp1', 'freq', 'pdur', 'electrode1']

    # .copy() so adding the feature columns below never writes into a view of df.
    data = df[df['subject'] == subject].copy()
    initial_vals = {param: getattr(model, param) for param in local_range}

    if type(model) == BiphasicAxonMapModel:
        estimator = shapes.BiphasicAxonMapEstimator(implant=implant, model=model, mse_params=mse_params, threshold=threshold, loss_fn=loss_fn, scale_features=scale_features, verbose=verbose)
        initial_vals['a5'] = estimator.a5
        initial_vals['a6'] = estimator.a6
        estimator.fit_size_model(data['amp1'], data['image'])
    else:
        estimator = shapes.AxonMapEstimator(implant=implant, model=model, mse_params=mse_params, threshold=threshold, loss_fn=loss_fn, scale_features=scale_features, verbose=verbose)

    y = estimator.compute_moments(data['image'])
    data[estimator._mse_params] = y
    y_averaged = data.groupby(stim_cols)[estimator._mse_params].mean()
    # Build x from y_averaged's (sorted) group index so the rows of x and y line up.
    # drop_duplicates() would keep first-appearance order instead.
    # NOTE(review): assumes estimator.score pairs x and y positionally -- confirm.
    x = y_averaged.index.to_frame(index=False)
    initial_score = estimator.score(x, y_averaged)
    print("starting score: {}".format(initial_score))

    for iteration in range(max_iter):
        opt = ParticleSwarmOptimizer(estimator, local_range, max_iter=75, swarm_size=40 * len(local_range), has_loss_function=True)
        opt.fit(x, y_averaged)
        score = estimator.score(x, y_averaged)
        if verbose:
            print("Iteration {} score: {}".format(iteration, score))
        if score < initial_score:
            # The optimizer has already written the best rho/axlambda into the model.
            new_vals = {param: getattr(model, param) for param in initial_vals}
            print("Better model found: {}".format(new_vals))
            return score, new_vals, model

    print("Couldn't find a better model")
    # NOTE(review): a5/a6 were read from the estimator but are restored onto the
    # model. initial_score was also computed *after* fit_size_model, so for the
    # biphasic model the returned score may not match the restored a5/a6 -- confirm.
    for param, val in initial_vals.items():
        setattr(model, param, val)
    return initial_score, initial_vals, model

In [6]:
# Load the drawing dataset, filtered by shapes.load_shapes to ArgusII
# single-electrode stimulation trials.
df = shapes.load_shapes(
    "../data/shapes.h5",
    implant="ArgusII",
    stim_class='SingleElectrode',
)

In [None]:
subjects = ['12-005', '51-009', '52-001']
# Starting (rho, axlambda) for each subject. The alternating search below refines them.
initial_params = {'12-005': (130, 500), '51-009': (29, 750), '52-001': (200, 1500)}

# One model_from_params call per subject yields both its implant and its model.
implants, models = [], []
for s in subjects:
    implant_s, model_s = shapes.model_from_params(shapes.subject_params[s], biphasic=True)
    model_s.rho, model_s.axlambda = initial_params[s]
    model_s.xystep = 0.5
    implants.append(implant_s)
    models.append(model_s)

out_dir = "../results/implant_location/"
os.makedirs(out_dir, exist_ok=True)  # a fresh checkout may not have the results directory yet
out_csv = out_dir + datetime.now().strftime("%m%d_%H%M") + "_" + str(mse_params) + "_" + str(threshold) + "_" + ".csv"
out_df = pd.DataFrame(columns=['subject', 'epoch', 'score', 'rho', 'axlambda', 'dx', 'dy', 'drot', 'final'])


def record_row(*row):
    """Append one row to out_df and checkpoint the whole table to out_csv."""
    out_df.loc[len(out_df)] = list(row)
    out_df.to_csv(out_csv)


for subject, model in zip(subjects, models):
    print("Subject: {}".format(subject))
    # Placeholder starting row: score 999999 means "not yet evaluated".
    record_row(subject, 0, 999999, model.rho, model.axlambda, 0, 0, 0, False)

    offset = np.array([0, 0, 0])
    for epoch in tqdm(range(n_epochs)):
        # Alternate: first move the implant with the model fixed...
        print("Updating Implant Location")
        score, new_offset, implant = update_implant(subject, offset, model, df)
        if np.all(new_offset == offset):
            print("Converged\n")
            record_row(subject, epoch, score, model.rho, model.axlambda, *offset, True)
            break
        print("New offset: {}, score:{}".format(new_offset, score))
        offset = new_offset

        # ...then refit the model parameters with the implant fixed.
        print("Updating model params")
        score, params, model = update_params(implant, subject, model, df)
        print("New params: {}, score:{}".format(params, score))
        print()

        record_row(subject, epoch, score, params['rho'], params['axlambda'], *offset, False)
    else:
        # Ran out of epochs without converging. Still mark the last state as final,
        # so the results cell (which keeps only final rows) does not drop this subject.
        if n_epochs > 0:
            print("Did not converge within {} epochs\n".format(n_epochs))
            record_row(subject, epoch, score, model.rho, model.axlambda, *offset, True)

Subject: 12-005


  0%|          | 0/20 [00:00<?, ?it/s]

Updating Implant Location
starting score: 20.759077789678237, rho:130, axlambda:500


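## Fitted offsets

Final implant offsets (`dx`, `dy`, `drot`) and model parameters (`rho`, `axlambda`) for each subject, taken from the rows marked `final` in every results CSV.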

In [17]:
# Collect the row that each optimization run marked as final, from every results CSV.
# os and pandas are imported in the top import cell.
results_dir = "../results/implant_location"
csv_paths = sorted(os.path.join(results_dir, name)
                   for name in os.listdir(results_dir)
                   if name.endswith(".csv"))
frames = [pd.read_csv(path) for path in csv_paths]
final_rows = [frame[frame['final'].eq(True)] for frame in frames]
assert final_rows, "no results CSVs found in {}".format(results_dir)
results = pd.concat(final_rows)
results[['subject', 'rho', 'axlambda', 'dx', 'dy', 'drot', 'score']].reset_index(drop=True)

Unnamed: 0,subject,rho,axlambda,dx,dy,drot,score
0,12-005,167.82224,651.671555,120.0,45.0,0.75,6.717848
1,52-001,196.603664,1931.91022,-20.0,5.0,2.75,6.6861
2,51-009,49.040328,1180.477531,0.0,-15.0,-2.0,12.260011
