In [1]:
from pulse2percept.models import BiphasicAxonMapModel, AxonMapModel, BiphasicAxonMapSpatial
from pulse2percept.model_selection import ParticleSwarmOptimizer
from pulse2percept.implants import ArgusII
import shapes

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime
import math

In [2]:
# --- Run configuration (read as globals by update_implant / update_params) ---

# Shape descriptors the estimators are scored on.
# Alternative that was tried: mse_params = ['moments_central']
mse_params = ['area', 'eccentricity', 'orientation']
# Threshold handed to the estimators: 1 / sqrt(e)
threshold = 1 / np.sqrt(np.exp(1))
loss_fn = 'r2'
scale_features = False
verbose = True
# Global (lower, upper) bounds for each searchable model parameter.
search_params = {
    'rho': (10, 1000),
    'axlambda': (10, 2500),
}
# Maximum number of implant/parameter alternations per subject.
n_epochs = 20


In [3]:
# Two functions to alternate between
def update_implant(implant, subject, current_pos, model, df, max_iter=5):
    """Grid-search implant placements around ``current_pos`` and keep the best.

    Every candidate placement is ``current_pos`` plus one combination of the
    (dx, dy, drot) steps below. For each candidate a new implant is built with
    ``shapes.model_from_params`` and scored with the estimator; a lower score
    is treated as better.

    Parameters
    ----------
    implant :
        Implant used to construct the estimator and to compute the baseline
        score. Callers should pass the implant that corresponds to
        ``current_pos`` so that the baseline is meaningful.
    subject :
        Subject id; selects rows of ``df`` and the entry in
        ``shapes.subject_params``.
    current_pos : array-like of length 3
        Current placement offset; the third component is the one stepped by
        ``drot`` (presumably a rotation -- confirm against
        ``shapes.model_from_params``).
    model :
        Model handed to the estimator. A ``BiphasicAxonMapModel`` selects the
        biphasic estimator, anything else the plain axon-map estimator.
    df : pandas.DataFrame
        Data with at least the columns ``subject``, ``image``, ``amp1``,
        ``freq``, ``pdur`` and ``electrode1``.
    max_iter : int
        Unused; kept so existing calls remain valid.

    Returns
    -------
    (score, offset, implant)
        On improvement: the best candidate's score, its offset (already
        expressed in the same frame as ``current_pos``) and its implant.
        Otherwise: the baseline score, ``current_pos`` and an implant rebuilt
        at ``current_pos``.
    """
    group_cols = ['amp1', 'freq', 'pdur', 'electrode1']

    # Work on a copy: columns are added below and must not be written into a
    # filtered view of the caller's frame.
    data = df[df['subject'] == subject].copy()

    estimator_kwargs = dict(implant=implant, model=model, mse_params=mse_params,
                            threshold=threshold, loss_fn=loss_fn,
                            scale_features=scale_features, verbose=verbose)
    if isinstance(model, BiphasicAxonMapModel):
        estimator = shapes.BiphasicAxonMapEstimator(**estimator_kwargs)
        estimator.fit_size_model(data['amp1'], data['image'])
    else:
        estimator = shapes.AxonMapEstimator(**estimator_kwargs)

    # Targets: per-stimulus mean of the moments computed from the drawings.
    data[estimator._mse_params] = estimator.compute_moments(data['image'])
    y_averaged = data.groupby(group_cols)[estimator._mse_params].mean()
    # NOTE(review): groupby sorts its keys while drop_duplicates keeps
    # first-seen order -- confirm that estimator.score does not rely on x and
    # y_averaged being row-aligned by position.
    x = data[group_cols].drop_duplicates()
    initial_score = estimator.score(x, y_averaged)

    current_pos = np.asarray(current_pos)
    xy_steps = (-25, -10, -5, -2, 2, 5, 10, 25)
    rot_steps = (-5, -3, -1, 1, 3, 5)
    # Candidates are absolute offsets (current position + step).
    candidates = [current_pos + np.array([dx, dy, drot])
                  for dx in xy_steps
                  for dy in xy_steps
                  for drot in rot_steps]

    best_score = float('inf')
    best_offset = None
    best_implant = None
    for candidate in candidates:
        candidate_implant, _ = shapes.model_from_params(shapes.subject_params[subject], offset=candidate)
        estimator.implant = candidate_implant
        score = estimator.score(x, y_averaged)
        if score < best_score:
            best_score = score
            best_offset = candidate
            best_implant = candidate_implant

    if best_score < initial_score:
        if verbose:
            print("Better model found: {}".format(best_offset))
        # best_offset already contains current_pos; adding it again would
        # report a position that does not match best_implant.
        return best_score, best_offset, best_implant

    print("Couldn't find a better offset")
    fallback_implant, _ = shapes.model_from_params(shapes.subject_params[subject], offset=current_pos)
    return initial_score, current_pos, fallback_implant
    

def update_params(implant, subject, model, df, max_iter=5):
    """Locally re-optimise the model parameters for a fixed implant placement.

    Each parameter named in the global ``search_params`` is searched in a
    window of +/-30 around its current value on ``model`` (the lower bound is
    never below 10) using ``ParticleSwarmOptimizer``. A lower estimator score
    is treated as better.

    Parameters
    ----------
    implant :
        Implant handed to the estimator.
    subject :
        Subject id; selects rows of ``df``.
    model :
        Model whose parameters are tuned in place. A ``BiphasicAxonMapModel``
        selects the biphasic estimator, anything else the plain axon-map
        estimator.
    df : pandas.DataFrame
        Data with at least the columns ``subject``, ``image``, ``amp1``,
        ``freq``, ``pdur`` and ``electrode1``.
    max_iter : int
        Maximum number of optimizer restarts before giving up.

    Returns
    -------
    (score, params, model)
        ``params`` maps parameter name to value. On improvement it holds the
        values now set on ``model``; otherwise ``model`` is reset to its
        initial values and those initial values are returned together with
        the baseline score.
    """
    group_cols = ['amp1', 'freq', 'pdur', 'electrode1']

    # Don't search the entire range, just the area right around each param.
    # A distinct local name is required: rebinding ``search_params`` here
    # would make it function-local and raise UnboundLocalError on this line.
    local_bounds = {
        param: (max(getattr(model, param) - 30, 10), getattr(model, param) + 30)
        for param in search_params
    }

    # Work on a copy: columns are added below and must not be written into a
    # filtered view of the caller's frame.
    data = df[df['subject'] == subject].copy()
    initial_vals = {param: getattr(model, param) for param in local_bounds}

    estimator_kwargs = dict(implant=implant, model=model, mse_params=mse_params,
                            threshold=threshold, loss_fn=loss_fn,
                            scale_features=scale_features, verbose=verbose)
    if isinstance(model, BiphasicAxonMapModel):
        estimator = shapes.BiphasicAxonMapEstimator(**estimator_kwargs)
        # Remember a5/a6 before fit_size_model runs so they can be restored.
        initial_vals['a5'] = estimator.a5
        initial_vals['a6'] = estimator.a6
        estimator.fit_size_model(data['amp1'], data['image'])
    else:
        estimator = shapes.AxonMapEstimator(**estimator_kwargs)

    # Targets: per-stimulus mean of the moments computed from the drawings.
    data[estimator._mse_params] = estimator.compute_moments(data['image'])
    y_averaged = data.groupby(group_cols)[estimator._mse_params].mean()
    # NOTE(review): groupby sorts its keys while drop_duplicates keeps
    # first-seen order -- confirm that estimator.score does not rely on x and
    # y_averaged being row-aligned by position.
    x = data[group_cols].drop_duplicates()
    initial_score = estimator.score(x, y_averaged)

    for iteration in range(max_iter):
        opt = ParticleSwarmOptimizer(estimator, local_bounds, max_iter=5,
                                     swarm_size=1 * len(local_bounds),
                                     has_loss_function=True)
        opt.fit(x, y_averaged)
        score = estimator.score(x, y_averaged)
        if verbose:
            print("Iteration {} score: {}".format(iteration, score))
        if score < initial_score:
            # rho and lambda have already been changed in the model
            new_vals = {param: getattr(model, param) for param in initial_vals}
            if verbose:
                print("Better model found: {}".format(new_vals))
            return score, new_vals, model

    print("Couldn't find a better model")
    # Undo whatever the optimizer left on the model.
    for param, val in initial_vals.items():
        setattr(model, param, val)
    return initial_score, initial_vals, model

In [4]:
# Read the drawings via shapes.load_shapes, requesting the ArgusII implant and
# the SingleElectrode stimulus class.
df = shapes.load_shapes(
    "../data/shapes.h5",
    implant="ArgusII",
    stim_class='SingleElectrode',
)

In [None]:
subjects = ['12-005', '51-009', '52-001']

# Starting values for the model parameters, keyed by subject id so that they
# cannot drift out of sync with the order of `subjects`.
initial_params = {
    '12-005': {'rho': 181.4, 'axlambda': 617.6},
    '51-009': {'rho': 51, 'axlambda': 1191.5},
    '52-001': {'rho': 198.1, 'axlambda': 1995.7},
}

# Build each subject's implant/model pair with a single call.
implants, models = [], []
for s in subjects:
    subject_implant, subject_model = shapes.model_from_params(shapes.subject_params[s], biphasic=False)
    for param, value in initial_params[s].items():
        setattr(subject_model, param, value)
    implants.append(subject_implant)
    models.append(subject_model)

out_csv = "../results/implant_location/" + datetime.now().strftime("%m%d_%H%M") + "_" + str(mse_params) + "_" + str(threshold) + "_" + ".csv"
# Note: the 'dz' column stores the third offset component (stepped as `drot`
# inside update_implant).
out_df = pd.DataFrame(columns=['subject', 'epoch', 'score', 'rho', 'axlambda', 'dx', 'dy', 'dz', 'final'])
for subject, starting_implant, model in zip(subjects, implants, models):
    print("Subject: {}".format(subject))
    # Placeholder row with the starting parameters; 999999 marks "not scored".
    out_df.loc[len(out_df)] = [subject, 0, 999999, model.rho, model.axlambda, 0, 0, 0, False]
    out_df.to_csv(out_csv)

    offset = np.array([0, 0, 0])
    # Track the implant for the current offset so update_implant scores its
    # baseline at the current placement rather than at the starting one.
    implant = starting_implant
    for epoch in range(n_epochs):
        print("Epoch: {}".format(epoch))
        # first update the implant

        print("Updating Implant Location")
        score, new_offset, implant = update_implant(implant, subject, offset, model, df)
        print("New offset: {}, score:{}".format(new_offset, score))
        if np.all(new_offset == offset):
            print("Converged\n")
            out_df.loc[len(out_df)] = [subject, epoch, score, model.rho, model.axlambda, offset[0], offset[1], offset[2], True]
            out_df.to_csv(out_csv)
            break

        offset = new_offset
        print("Updating model params")
        score, params, model = update_params(implant, subject, model, df)
        print("New params: {}, score:{}".format(params, score))
        print()

        # Checkpoint after every epoch so partial results survive an interrupt.
        out_df.loc[len(out_df)] = [subject, epoch, score, params['rho'], params['axlambda'], offset[0], offset[1], offset[2], False]
        out_df.to_csv(out_csv)

Subject: 12-005
Epoch: 0
Updating Implant Location
score:7.280, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:2.2', 'orientation:3.4']
score:8.719, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:3.8', 'orientation:3.3']
score:7.997, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:2.9', 'orientation:3.4']
score:7.376, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:2.3', 'orientation:3.4']
score:7.532, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:2.3', 'orientation:3.5']
score:7.225, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:2.4', 'orientation:3.1']
score:7.255, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:2.4', 'orientation:3.2']
score:8.726, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:3.7', 'orientation:3.3']
score:8.330, rho:181.4, lambda:617.6, empty:0, scores:['area:1.7', 'eccentricity:3.3', 'orientation:3

In [None]:
# Display the results table (one row per logged epoch).
out_df