Commit

move best fit processing to new module, simpler to batch
joesilber committed Apr 25, 2020
1 parent 726257a commit 5a8ad4a
Showing 3 changed files with 221 additions and 261 deletions.
bin/fit_posparams: 23 additions & 255 deletions (278 changes)
@@ -46,33 +46,27 @@ parser.add_argument('-dw', '--data_window', type=int, default=100,
parser.add_argument('-pd', '--period_days', type=int, default=1,
help='spacing of datum dates at which to run the fit')
parser.add_argument('-np', '--n_processes_max', type=int, default=None,
help='max number of processors to use')
help='max number of processors to use. Note: can argue 1 to avoid multiprocessing pool (useful for debugging)')
parser.add_argument('-sq', '--static_quantile', type=float, default=0.05,
help='sort static param fits by fit error, take this fraction with least error, and median them to decide the single set of best values')
parser.add_argument('-p', '--profile', action='store_true',
help='profile timing of a single analysis job')
parser.add_argument('-q', '--quiet', action='store_true',
help='reduce verbosity of print outs at terminal')
parser.add_argument('-l', '--logging', action='store_true',
help='log terminal print outs to file')
args = parser.parse_args()

# proceed with the rest
import os
import sys
import cProfile, pstats
import math
import numpy as np
import time
import multiprocessing
from pkg_resources import resource_filename
from astropy.table import Table
from astropy.time import Time
import glob

# imports below require <path to desimeter>/py' to be added to system PYTHONPATH.
import desimeter.posparams.fitter as fitter
import desimeter.transform.ptl2fp as ptl2fp
import desimeter.posparams.fithandler as fithandler

# paths
desimeter_data_dir = resource_filename("desimeter", "data")
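The `--static_quantile` option above encodes the selection rule that this script applies after the first static-parameter pass (see the block removed further down in this file): sort the static fits by fit error, keep the fraction with the least error, and take the median of each parameter over that subset. Below is a minimal standalone sketch of that rule, assuming a results table with a FIT_ERROR column plus one column per static parameter; the function name and the example parameter columns are illustrative and not part of this commit.

import numpy as np
from astropy.table import Table

def best_static_params(results, static_keys, quantile=0.05):
    # Keep the fits whose FIT_ERROR lies in the lowest `quantile` fraction,
    # then take the median of each static parameter over that subset.
    errors = results['FIT_ERROR']
    cutoff = np.percentile(errors, quantile * 100)
    best = results[errors <= cutoff]
    return {key: float(np.median(best[key])) for key in static_keys}

# illustrative usage; table values and column names are hypothetical
results = Table({'FIT_ERROR': [0.02, 0.15, 0.03, 0.40],
                 'LENGTH_R1': [3.01, 3.30, 2.99, 3.50],
                 'LENGTH_R2': [3.00, 2.70, 3.02, 2.60]})
print(best_static_params(results, ['LENGTH_R1', 'LENGTH_R2'], quantile=0.5))
# -> {'LENGTH_R1': 3.0, 'LENGTH_R2': 3.01}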
@@ -81,42 +75,12 @@ for s in args.infiles:
these = glob.glob(s)
infiles.extend(these)
infiles = [os.path.realpath(p) for p in infiles]
save_dir = args.outdir
if not os.path.isdir(save_dir):
os.path.os.makedirs(save_dir)

# other options which (for now) are not made accessible at command line
params_to_keep_fixed = [] # see comments in pos_params_bestfit function
param_overrides = {} # any key:val pairs to override the default params in bestfit module

# renaming some of the args
profile_one_job = args.profile
static_params_quantile = args.static_quantile
static_params_group_func = np.median

# parameter overrides
param_nominals = fitter.default_values
param_nominals.update(param_overrides)

# log and output file handling
def now_str_for_data():
'''Returns a string version of current time with format like yyyy-mm-dd HH:MM:SS.fff.'''
return Time.now().iso

if not os.path.isdir(args.outdir):
os.path.os.makedirs(args.outdir)
def now_str_for_file():
'''Returns a string version of current time with format like yyyymmddTHHMMSS.'''
return Time.now().isot.replace('-','').replace(':','')[:-4]

file_timestamp_str = now_str_for_file()
log_name = file_timestamp_str + '_log.txt'
log_path = os.path.join(save_dir, log_name)
def printlog(s):
s = f'{now_str_for_data()} {s}'
if not args.quiet:
print(s)
if args.logging:
with open(log_path,'a') as file:
print(s, file=file)

# parse out the posids from filenames
csv_files = {}
@@ -136,14 +100,14 @@ if not csv_files:
tables = {}
required_keys = {'DATE', 'POS_P', 'POS_T', 'X_FP', 'Y_FP', 'CTRL_ENABLED',
'TOTAL_MOVE_SEQUENCES'}
printlog(f'Now reading {len(csv_files)} csv data files...')
print(f'Now reading {len(csv_files)} csv data files...')
for posid, path in csv_files.items():
table = Table.read(path)
table.sort('DATE')
key_search = [key in table.columns for key in required_keys]
if all(key_search):
tables[posid] = table
# printlog(f'{posid}: Read from {path}') # something buggy about this particular print, ghosts of it keep appearing in stdout later on
print(f'{posid}: Read from {path}')
posids = sorted(csv_files.keys())

def _drop_empty_tables():
@@ -152,7 +116,7 @@ def _drop_empty_tables():
if len(tables[posid]) == 0:
del posids[posids.index(posid)]
del tables[posid]
printlog(f'{posid}: dropped from analysis (no data)')
print(f'{posid}: dropped from analysis (no data)')
_drop_empty_tables()

# fix types, since astropy seems to kind of suck at this
@@ -171,7 +135,7 @@ for posid in tables:
z_fp=None)
tables[posid]['X_PTL'] = ptlXYZ[0]
tables[posid]['Y_PTL'] = ptlXYZ[1]
printlog(f'{posid}: (X_FP, Y_FP) converted to (X_PTL, Y_PTL)')
print(f'{posid}: (X_FP, Y_FP) converted to (X_PTL, Y_PTL)')

# filter functions for identifying bad rows
def _np_diff_with_prepend(array, prepend):
@@ -215,49 +179,14 @@ for posid, table in tables.items():
if len(table) == 0:
break
final_len = len(table)
printlog(f'{posid}: dropped {initial_len - final_len} of {initial_len} non-conforming data rows')
print(f'{posid}: dropped {initial_len - final_len} of {initial_len} non-conforming data rows')
_drop_empty_tables()

# add column for seconds-since-epoch version of date to data frames
for posid, table in tables.items():
dates = table['DATE']
table['DATE_SEC'] = Time(dates, format='iso').unix
printlog(f'{posid}: generated seconds-since-epoch column (\'DATE_SEC\')')

# data selectors
def select_by_index(table, start=0, final=-1):
'''Returns a subset of data formatted for the best-fitting function.'''
data = {}
subtable = table[start:final]
data['posintT'] = subtable['POS_T'].tolist()
data['posintP'] = subtable['POS_P'].tolist()
data['ptlX'] = subtable['X_PTL'].tolist()
data['ptlY'] = subtable['Y_PTL'].tolist()
return data, subtable

def _row_idx_for_time(presorted_table, t):
'''Returns index of first row in table which matches the argued time t.
Value t should be in same scale as 'DATE_SEC' column. I.e. the date in
seconds-since-epoch. If t is outside the time range of the table, the
function will return either index 0 or max index.
Note: For speed, this function assumes the provided table has already been
pre-sorted, ascending by time!
'''
if t < presorted_table['DATE_SEC'][0]:
return 0
if t > presorted_table['DATE_SEC'][-1]:
return len(presorted_table) - 1
return int(np.argwhere(presorted_table['DATE_SEC'] >= t)[0])

def dict_str(d):
'''Return a string displaying formatted values in dict d.'''
s = '{'
for key,val in d.items():
s += f"'{key}':{val:.3f}, "
s = s[:-2]
s += '}'
return s
print(f'{posid}: generated seconds-since-epoch column (\'DATE_SEC\')')

# find date range of all data and assign datum dates
period_sec = args.period_days * 24 * 60 * 60
@@ -268,179 +197,18 @@ for table in tables.values():
last = max(last, table['DATE_SEC'][-1])
datum_dates = np.arange(first, last, period_sec).tolist()

# function for multiprocessing of many analysis cases
case_enum = {'posid':0,
'start_idx':1,
'final_idx':2}
def process(mode='static', cases=[], nominals=param_nominals, savedir='.'):
'''Process many analysis cases.
Inputs:
mode ... 'static' or 'dynamic', see notes in fitter module
cases ... list of values matching sequence defined in case_enum
nominals ... if mode == 'static' --> single dict with keys = parameter names
if mode =='dynamic' --> dict with keys = posids and values = subdicts giving each positioner's particular set of parameter names/values
savedir ... where to save analysis results as they are generated
Outputs:
output_frame ... single pandas dataframe containing all results
output_path ... path to saved csv data file containing all results
'''
assert mode in {'static', 'dynamic'}
mp_results = {} # container for result objects from the multiprocessing pool
output_columns = ['ANALYSIS_DATE', 'POS_ID', 'NUM_POINTS', 'FIT_ERROR',
'DATA_START_DATE', 'DATA_END_DATE',
'DATA_START_DATE_SEC', 'DATA_END_DATE_SEC']
output_columns += fitter.all_keys
output = {col:[] for col in output_columns}
if __name__ == '__main__':
with multiprocessing.Pool(processes=args.n_processes_max) as pool:

# send out jobs
for case in cases:
posid = case[case_enum['posid']]
m = case[case_enum['start_idx']]
n = case[case_enum['final_idx']]
table = tables[posid]
xytp_data, subtable = select_by_index(table, start=m, final=n+1)
job_desc = f'{mode} params analysis of {posid} over data period:\n'
job_desc += f' start idx = {m:5d}, date = {subtable["DATE"][0]}\n'
job_desc += f' final idx = {n:5d}, date = {subtable["DATE"][-1]}\n'
job_desc += f' num points = {n-m+1:5d}'
printlog(f'job {cases.index(case)} of {len(cases)} added:\n{job_desc}\n')
if mode == 'static':
these_nominals = nominals
for posid, table in tables.items():
kwargs = {'table': table,
'datum_dates': datum_dates,
'data_window': args.data_window,
'savedir': args.outdir,
'static_quantile': args.static_quantile,
'verbose': not(args.quiet)
}
if args.n_processes_max == 1:
fithandler.run_best_fits(**kwargs)
else:
these_nominals = nominals[posid]
argsdict = xytp_data
argsdict.update({'mode':mode,
'nominals':these_nominals,
'keep_fixed':params_to_keep_fixed,
'description':job_desc})
if profile_one_job:
prof = cProfile.Profile()
params, fit_err, _ = prof.runcall(fitter.fit_params, **argsdict)
printlog(f'Results:\n{params}\nfit error = {fit_err}')
statsfile = 'stats'
prof.dump_stats(statsfile)
p = pstats.Stats(statsfile)
for sorter in ['cumtime', 'tottime']:
p.strip_dirs()
p.sort_stats(sorter)
p.print_stats(15)
os.remove(statsfile)
return {posid:params}, fit_err
mp_results[case] = pool.apply_async(fitter.fit_params, kwds=argsdict)

# retrieve results as they come in
while mp_results:
completed = set()
for case, result in mp_results.items():
if result.ready():
completed.add(case)
params, fit_err, job_desc = result.get()
posid = case[case_enum['posid']]
m = case[case_enum['start_idx']]
n = case[case_enum['final_idx']]
output['ANALYSIS_DATE'].append(now_str_for_data())
output['POS_ID'].append(posid)
for suffix in {'', '_SEC'}:
d = f'DATE{suffix}'
output[f'DATA_START_{d}'].append(tables[posid][d][m])
output[f'DATA_END_{d}'].append(tables[posid][d][n])
output['NUM_POINTS'].append(n - m + 1)
output['FIT_ERROR'].append(fit_err)
for key in params:
output[key].append(params[key])
job_desc += '\n'
job_desc += f'{posid}: best params = {dict_str(params)}\n'
job_desc += f'{posid}: fit error = {fit_err:.3f}'
printlog(f'job {cases.index(case)} completed:\n{job_desc}\n')
for case in completed:
del mp_results[case]
printlog(f'Jobs remaining: {len(mp_results)}')
time.sleep(0.05)
output_table = Table(output)
output_table.sort(['POS_ID','DATA_END_DATE'])
output_name = f'{file_timestamp_str}_results_{mode}.csv'
output_path = os.path.join(savedir, output_name)
output_table.write(output_path)
return output_table, output_path

if __name__ == '__main__':

# define analysis cases spanning the data history
cases = []
num_cases = {posid: 0 for posid in posids}
formatted_xytp_data = {}
for posid, table in tables.items():
widths = []
start_idxs = []
final_idxs = []
for j in range(1, len(datum_dates)):
start_date = datum_dates[j - 1]
final_date = datum_dates[j]
start_idxs.append(_row_idx_for_time(table, start_date))
final_idxs.append(_row_idx_for_time(table, final_date))
widths.append(final_idxs[-1] - start_idxs[-1])
for J in range(len(final_idxs)):
backwards_sum = lambda i: sum(w for w in widths[i:J+1])
satisfies_min_width = lambda i: backwards_sum(i) > args.data_window
should_expand_backwards = lambda i: not(satisfies_min_width(i)) and i > 0
I = J
while should_expand_backwards(I):
I -= 1
can_expand_forwards = J < len(widths) - 1
if not satisfies_min_width(I) and can_expand_forwards:
continue # skip ahead and try the next datum date
case = [None]*len(case_enum)
case[case_enum['posid']] = posid
case[case_enum['start_idx']] = start_idxs[I]
case[case_enum['final_idx']] = final_idxs[J]
cases.append(tuple(case)) # tuple is hashable, so can use case as a dict key later
num_cases[posid] += 1

printlog(f'{posid}: {num_cases[posid]:5d} analysis cases defined')
printlog(f'Total analysis cases: {len(cases)}')
cases = list(set(cases))
printlog(f'Reduced to: {len(cases)} unique cases')

# FIRST-PASS: STATIC PARAMETERS
static_out, static_file = process(mode='static',
cases=cases,
nominals=param_nominals.copy(),
savedir=save_dir)

# DECIDE ON BEST STATIC PARAMS
if profile_one_job:
best_static = static_out
else:
best_static = {}
for posid in posids:
best_static[posid] = param_nominals.copy()
subtable = static_out[static_out['POS_ID'] == posid]
errors = subtable['FIT_ERROR']
quantile = np.percentile(errors, static_params_quantile * 100)
selection = subtable[errors <= quantile]
these_best = {key:static_params_group_func(selection[key]) for key in fitter.static_keys}
best_static[posid].update(these_best)
printlog(f'{posid}: Selected best static params = {dict_str(best_static[posid])}')

# SECOND-PASS: DYNAMIC PARAMETERS
dynamic_out, dynamic_file = process(mode='dynamic',
cases=cases,
nominals=best_static,
savedir=save_dir)

if profile_one_job:
sys.exit() # in this case our work is done

# MERGED STATIC + DYNAMIC
# Note how merge here assumes both static and dynamic are provided pre-sorted into same order.
merged = static_out.copy()
for key in fitter.dynamic_keys:
merged[key] = dynamic_out[key]
for key in ['ANALYSIS_DATE', 'FIT_ERROR']:
merged.rename_column(key, key + '_STATIC')
merged[key + '_DYNAMIC'] = dynamic_out[key]
merged_file = static_file.split('static')[0] + 'merged.csv'
merged.write(merged_file)
pool.apply_async(fithandler.run_best_fits, kwds=kwargs)
print('Complete.')
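With this change, fit_posparams becomes a thin batching wrapper: the per-positioner tables are prepared as before, while the fitting itself moves into the new desimeter.posparams.fithandler module. The following condensed sketch of the dispatch pattern is pieced together from the added lines in this diff; the run_best_fits keyword arguments are the ones shown above, but the helper function, the result collection at the end, and any error handling are illustrative, since the full new file is not rendered on this page.

import multiprocessing
import desimeter.posparams.fithandler as fithandler

def batch_fit(tables, datum_dates, args):
    # One job per positioner table; run serially when --n_processes_max 1
    # (useful for debugging), otherwise fan out over a multiprocessing pool.
    def kwargs_for(table):
        return {'table': table,
                'datum_dates': datum_dates,
                'data_window': args.data_window,
                'savedir': args.outdir,
                'static_quantile': args.static_quantile,
                'verbose': not args.quiet}

    if args.n_processes_max == 1:
        for table in tables.values():
            fithandler.run_best_fits(**kwargs_for(table))
    else:
        with multiprocessing.Pool(processes=args.n_processes_max) as pool:
            jobs = [pool.apply_async(fithandler.run_best_fits, kwds=kwargs_for(table))
                    for table in tables.values()]
            for job in jobs:
                job.get()  # wait for every job before declaring completion
    print('Complete.')

Running serially first (--n_processes_max 1) lets exceptions inside run_best_fits surface directly, which is the debugging path the new help text for that option points at.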