-
Notifications
You must be signed in to change notification settings - Fork 33
/
hierarchical.py
1102 lines (874 loc) · 39 KB
/
hierarchical.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/python
from copy import copy
import pickle
import sys
import numpy as np
from scipy.optimize import minimize, basinhopping
from collections import OrderedDict, defaultdict
import pandas as pd
import pymc as pm
import warnings
from kabuki.utils import flatten
from . import analyze
class LnProb(object):
    """Adapter exposing the model's log-probability as a plain callable.

    Used as the likelihood function for samplers (e.g. emcee) that
    propose flat vectors of parameter values.
    """

    def __init__(self, model):
        self.model = model

    def lnprob(self, vals):
        """Assign each proposed value to its stochastic and return the model logp.

        Returns -inf for zero-probability proposals.
        """
        try:
            for proposed, (_, stoch) in zip(vals, self.model.iter_stochastics()):
                stoch['node'].set_value(proposed)
            return self.model.mc.logp
        except pm.ZeroProbability:
            return -np.inf

    def __call__(self, *args, **kwargs):
        return self.lnprob(*args, **kwargs)
class Knode(object):
    """A template for a family of pymc nodes.

    A Knode describes one model parameter. When :meth:`create` is called
    it expands into one pymc node per unique combination of the data
    columns listed in ``depends`` (plus ``subj_idx`` when ``subj`` is
    True), and records every created node in ``self.nodes_db``.
    """

    def __init__(self, pymc_node, name, depends=(), col_name='',
                 subj=False, hidden=False, pass_dataframe=True, **kwargs):
        # pymc distribution/class to instantiate for each data group
        self.pymc_node = pymc_node
        self.name = name
        # kwargs are forwarded to the pymc node; values that are Knodes
        # are treated as parents and resolved per data group.
        self.kwargs = kwargs
        self.subj = subj
        if isinstance(col_name, str):
            col_name = [col_name]
        self.col_name = col_name
        # maps unique data elements (tuples) -> created pymc node
        self.nodes = OrderedDict()
        self.hidden = hidden
        self.pass_dataframe = pass_dataframe

        # create self.parents from the Knode-valued kwargs
        self.parents = {}
        for (name, value) in self.kwargs.items():
            if isinstance(value, Knode):
                self.parents[name] = value

        # Create depends set and update based on parents' depends
        depends = set(depends)
        if self.subj:
            depends.add('subj_idx')
        depends.update(self.get_parent_depends())
        self.depends = sorted(list(depends))

        self.observed = 'observed' in kwargs

    def __repr__(self):
        return self.name

    def set_data(self, data):
        self.data = data

    def get_parent_depends(self):
        """returns the depends of the parents"""
        union_parent_depends = set()
        for name, parent in self.parents.items():
            union_parent_depends.update(set(parent.depends))
        return union_parent_depends

    def init_nodes_db(self):
        """Create the (empty) DataFrame that will describe all created nodes."""
        data_col_names = list(self.data.columns)
        node_descriptors = ['knode_name', 'stochastic', 'observed', 'subj', 'node', 'tag', 'depends', 'hidden']
        stats = ['mean', 'std', '2.5q', '25q', '50q', '75q', '97.5q', 'mc err']
        columns = node_descriptors + data_col_names + stats
        # create central dataframe
        self.nodes_db = pd.DataFrame(columns=columns)

    def append_node_to_db(self, node, uniq_elem):
        """Add a row describing a freshly created pymc node to nodes_db."""
        # create db entry for knode
        row = {}
        row['knode_name'] = self.name
        row['observed'] = self.observed
        row['stochastic'] = isinstance(node, pm.Stochastic) and not self.observed
        row['subj'] = self.subj
        row['node'] = node
        row['tag'] = self.create_tag_and_subj_idx(self.depends, uniq_elem)[0]
        row['depends'] = self.depends
        row['hidden'] = self.hidden

        row = pd.DataFrame(data=[row], columns=self.nodes_db.columns, index=[node.__name__])
        for dep, elem in zip(self.depends, uniq_elem):
            row[dep] = elem
        # DataFrame.append() was removed in pandas 2.0; pd.concat is the
        # supported equivalent (and already used elsewhere in this file).
        self.nodes_db = pd.concat([self.nodes_db, row])

    def create(self):
        """create the pymc nodes"""
        self.init_nodes_db()

        # group data
        if len(self.depends) == 0:
            grouped = [((), self.data)]
        else:
            grouped = self.data.groupby(self.depends)

        # create all the pymc nodes
        for uniq_elem, grouped_data in grouped:
            # single-column groupby yields scalar keys; normalize to tuples
            if not isinstance(uniq_elem, tuple):
                uniq_elem = (uniq_elem,)

            # create new kwargs to pass to the new pymc node
            kwargs = self.kwargs.copy()

            # update kwarg with the right parent
            for name, parent in self.parents.items():
                kwargs[name] = parent.get_node(self.depends, uniq_elem)

            # get node name
            tag, subj_idx = self.create_tag_and_subj_idx(self.depends, uniq_elem)
            node_name = self.create_node_name(tag, subj_idx=subj_idx)

            # get value for observed node
            if self.observed:
                if self.pass_dataframe:
                    kwargs['value'] = grouped_data[self.col_name]  # .to_records(index=False)
                else:
                    kwargs['value'] = grouped_data[self.col_name].values  # .to_records(index=False)

            # Deterministic nodes require a parent argument that is a
            # dict mapping parent names to parent nodes. Knode wraps
            # this; so here we have to fish out the parent nodes from
            # kwargs, put them into a parent dict and put that back
            # into kwargs, which will make pm.Deterministic() get a
            # parent dict as an argument.
            if self.pymc_node is pm.Deterministic:
                parents_dict = {}
                for name, parent in self.parents.items():
                    parents_dict[name] = parent.get_node(self.depends, uniq_elem)
                    kwargs.pop(name)
                kwargs['parents'] = parents_dict

                if self.observed:
                    kwargs['parents']['value'] = kwargs['value']

            # Deterministic nodes require a doc kwarg, we don't really
            # need that so if its not supplied, just use the name
            if self.pymc_node is pm.Deterministic and 'doc' not in kwargs:
                kwargs['doc'] = node_name

            node = self.create_node(node_name, kwargs, grouped_data)

            if node is not None:
                self.nodes[uniq_elem] = node
                self.append_node_to_db(node, uniq_elem)

    def create_node(self, node_name, kwargs, data):
        """actually create the node; subclasses may override (return None to skip)"""
        return self.pymc_node(name=node_name, **kwargs)

    def create_tag_and_subj_idx(self, cols, uniq_elem):
        """Split a unique data element into (tag tuple, subj_idx-or-None)."""
        uniq_elem = pd.Series(uniq_elem, index=cols)

        if 'subj_idx' in cols:
            subj_idx = uniq_elem['subj_idx']
            tag = uniq_elem.drop(['subj_idx']).values
        else:
            tag = uniq_elem.values
            subj_idx = None

        return tuple(tag), subj_idx

    def create_node_name(self, tag, subj_idx=None):
        """Build the node name, e.g. ``v(c1.c2).3``."""
        # construct string that will become the node name
        s = self.name
        if len(tag) > 0:
            elems_str = '.'.join([str(elem) for elem in tag])
            s += "({elems})".format(elems=elems_str)
        if subj_idx is not None:
            s += ".{subj_idx}".format(subj_idx=subj_idx)

        return s

    def get_node(self, cols, elems):
        """Return the node that depends on the same elements.
        Called by the child to receive specific parent node.

        :Arguments:
            cols : sequence
                Column names the child depends on.
            elems : sequence
                The child's unique elements, aligned with ``cols``.
        """
        col_to_elem = {}
        for col, elem in zip(cols, elems):
            col_to_elem[col] = elem

        # Find the column names that overlap with the ones we have
        overlapping_cols = intersect(cols, self.depends)

        # Create new tag for the specific elements we are looking for (that overlap)
        deps_on_elems = tuple([col_to_elem[col] for col in overlapping_cols])

        return self.nodes[deps_on_elems]
def intersect(t1, t2):
    """Return the elements of t2 that also occur in t1, keeping t2's order."""
    return tuple(elem for elem in t2 if elem in t1)
def test_subset_tuple():
    """Sanity checks for intersect(): result order follows the second tuple."""
    cases = [
        (('a', 'b', 'c'), ('a',), ('a',)),
        (('a', 'b', 'c'), ('a', 'b'), ('a', 'b')),
        (('a', 'b', 'c'), ('a', 'c'), ('a', 'c')),
        (('a', 'b', 'c'), ('b', 'c'), ('b', 'c')),
        (('c', 'b', 'a'), ('b', 'c'), ('b', 'c')),
    ]
    for t1, t2, expected in cases:
        assert intersect(t1, t2) == expected
class Hierarchical(object):
"""Creation of hierarchical Bayesian models in which each subject
has a set of parameters that are constrained by a group distribution.
:Arguments:
data : numpy.recarray
Input data with a row for each trial.
Must contain the following columns:
* 'rt': Reaction time of trial in seconds.
* 'response': Binary response (e.g. 0->error, 1->correct)
May contain:
* 'subj_idx': A unique ID (int) of the subject.
* Other user-defined columns that can be used in depends_on
keyword.
:Optional:
is_group_model : bool
If True, this results in a hierarchical
model with separate parameter distributions for each
subject. The subject parameter distributions are
themselves distributed according to a group parameter
distribution.
depends_on : dict
Specifies which parameter depends on data
of a column in data. For each unique element in that
column, a separate set of parameter distributions will be
created and applied. Multiple columns can be specified in
a sequential container (e.g. list)
:Example:
>>> depends_on={'param1':['column1']}
Suppose column1 has the elements 'element1' and
'element2', then parameters 'param1('element1',)' and
'param1('element2',)' will be created and the
corresponding parameter distribution and data will be
provided to the user-specified method get_likelihood().
trace_subjs : bool
Save trace for subjs (needed for many
statistics so probably a good idea.)
plot_var : bool
Plot group variability parameters
In addition, the variable self.params must be defined as a
list of Parameter().
"""
def __init__(self, data, is_group_model=None, depends_on=None, trace_subjs=True,
             plot_subjs=False, plot_var=False, group_only_nodes=()):
    """Initialize the hierarchical model from trial data.

    See the class docstring for the meaning of the arguments.
    NOTE(review): trace_subjs and plot_var are accepted but not stored
    here — presumably consumed by subclasses; confirm before removing.
    """
    # Init
    self.plot_subjs = plot_subjs
    self.depends_on = depends_on
    self.mc = None
    self.data = pd.DataFrame(data)
    self.group_only_nodes = group_only_nodes

    if not depends_on:
        depends_on = {}
    else:
        assert isinstance(depends_on, dict), "depends_on must be a dictionary."
        # Support for supplying columns as a single string
        # -> transform to list
        for key in depends_on:
            if isinstance(depends_on[key], str):
                depends_on[key] = [depends_on[key]]
        # Check if column names exist in data
        for depend_on in depends_on.values():
            for elem in depend_on:
                if elem not in self.data.columns:
                    raise KeyError("Column named %s not found in data." % elem)

    # Unknown parameters depend on nothing (empty tuple) by default.
    self.depends = defaultdict(lambda: ())
    for key, value in depends_on.items():
        self.depends[key] = value

    # Determine if group model: infer from presence of >1 subject when
    # the caller did not decide explicitly.
    if is_group_model is None:
        if 'subj_idx' in self.data.columns:
            if len(np.unique(data['subj_idx'])) != 1:
                self.is_group_model = True
            else:
                self.is_group_model = False
        else:
            self.is_group_model = False
    else:
        if is_group_model:
            if 'subj_idx' not in data.columns:
                raise ValueError("Group models require 'subj_idx' column in input data.")
        self.is_group_model = is_group_model

    # Should the model incorporate multiple subjects
    if self.is_group_model:
        self._subjs = np.unique(data['subj_idx'])
        self._num_subjs = self._subjs.shape[0]
    else:
        self._num_subjs = 1
    self.num_subjs = self._num_subjs

    self.sampled = False       # becomes True after .sample()
    self.dbname = 'ram'        # default pymc trace backend
    self.db = None
    self._setup_model()
def _setup_model(self):
    """Build knodes from the subclass spec, attach data, then create pymc nodes."""
    # create knodes (does not build according pymc nodes)
    self.knodes = self.create_knodes()

    # hand every knode the full dataset
    for k in self.knodes:
        k.set_data(self.data)

    # constructs pymc nodes etc and connects them appropriately
    self.create_model()
def __getstate__(self):
    """Prepare the instance dict for pickling.

    Unpicklable pymc objects (the nodes and the MCMC sampler) are
    stripped; instead the trace database backend name and its on-disk
    location are recorded so __setstate__ can reload the traces.
    """
    from copy import deepcopy
    d = copy(self.__dict__)
    # drop live pymc node objects; the table structure itself is kept
    d['nodes_db'] = deepcopy(d['nodes_db'].drop('node', axis=1))
    # defaultdict with a lambda factory cannot be pickled
    d['depends'] = dict(d['depends'])
    #d['model_type'] = self.__class__
    if self.sampled:
        d['db'] = self.mc.db.__name__
        dbname = d['mc'].db.__name__
        if (dbname == 'ram'):
            # in-memory traces cannot be reloaded after unpickling
            raise ValueError("db is 'ram'. Saving a model requires a database on disk.")
        elif (dbname == 'pickle'):
            d['dbname'] = d['mc'].db.filename
        elif (dbname == 'txt'):
            d['dbname'] = d['mc'].db._directory
        else: # hdf5, sqlite
            d['dbname'] = d['mc'].db.dbname
    # sampler and knodes are rebuilt by __setstate__ via _setup_model()
    del d['mc']
    del d['knodes']

    return d
def __setstate__(self, d):
    """Restore a pickled model: rebuild knodes/pymc nodes, then reload traces.

    :Arguments:
        d : dict
            State dict produced by __getstate__.
    """
    self.__dict__.update(d)
    # _setup_model() already ends by calling create_model(); the
    # original code called create_model() a second time here, which
    # redid all node creation for no effect.
    self._setup_model()

    # backwards compat
    if not hasattr(self, 'sampled'):
        self.sampled = True

    if self.sampled:
        # reattach the on-disk trace database recorded by __getstate__
        self.load_db(d['dbname'], db=d['db'])
        self.gen_stats()
    else:
        self.mcmc()
def save(self, fname):
    """Save model to file.

    :Arguments:
        fname : str
            filename to save to

    :Notes:
        * Load models using kabuki.utils.load(fname).
        * You have to save traces to db, not RAM.
        * Uses the pickle protocol internally.
    """
    # Context manager guarantees the file handle is closed even if
    # pickling raises (the original leaked the handle).
    with open(fname, 'wb') as fd:
        pickle.dump(self, fd)
def create_knodes(self):
    """Return the list of Knodes defining the model structure.

    Abstract hook: subclasses must override this.
    """
    raise NotImplementedError("create_knodes has to be overwritten")
def create_model(self, max_retries=8):
    """Set group level distributions. One distribution for each
    parameter.

    :Arguments:
        max_retries : int
            How often to retry when model creation
            failed (due to bad starting values).
    """
    def _create():
        # (Re)create the pymc nodes of every knode; may raise when the
        # random starting values have zero probability.
        for knode in self.knodes:
            knode.create()

    for tries in range(max_retries):
        try:
            _create()
        except (pm.ZeroProbability, ValueError):
            continue
        break
    else:
        # All attempts failed: report the real retry count (the original
        # printed the loop index with %f) and run once more so the
        # underlying exception propagates to the caller.
        print("After %d retries, still no good fit found." % max_retries)
        _create()

    # create node container
    self.create_nodes_db()

    # Check whether all user specified column names (via depends_on) were used.
    assert set(flatten(list(self.depends.values()))).issubset(set(flatten(self.nodes_db.depends))), "One of the column names specified via depends_on was not picked up. Check whether you specified the correct parameter value."
def create_nodes_db(self):
    """Collect every knode's node table into one central DataFrame."""
    tables = [knode.nodes_db for knode in self.knodes]
    self.nodes_db = pd.concat(tables)
def draw_from_prior(self, update=False):
    """Draw node values from the prior, rejecting zero-probability draws.

    When ``update`` is False the model's original values are restored
    afterwards; the drawn values are returned either way.
    """
    if not update:
        saved_values = self.values

    while True:
        try:
            self.mc.draw_from_prior()
            self.mc.logp  # raises ZeroProbability for impossible draws
            draw = copy(self.values)
            break
        except pm.ZeroProbability:
            continue

    if not update:
        # restore original values
        self.set_values(saved_values)

    return draw
def map(self, runs=2, warn_crit=5, method='fmin_powell', **kwargs):
    """
    Find MAP and set optimized values to nodes.

    :Arguments:
        runs : int
            How many runs to make with different starting values
        warn_crit: float
            How far must the two best fitting values be apart in order to print a warning message

    :Returns:
        pymc.MAP object of model.

    :Note:
        Forwards additional keyword arguments to pymc.MAP().
    """
    from operator import attrgetter

    # I.S: when using MAP with Hierarchical model the subjects nodes should be
    # integrated out before the computation of the MAP (see Pinheiro JC, Bates DM., 1995, 2000).
    # since we are not integrating we get a point estimation for each
    # subject which is not what we want.
    if self.is_group_model:
        raise NotImplementedError("""Sorry, This method is not yet implemented for group models.
        you might consider using the approximate_map method""")

    maps = []
    for i in range(runs):
        # (re)draw starting values on every run but the first
        self.mc = pm.MAP(self.nodes_db.node)
        if i != 0:
            self.draw_from_prior()
        self.mc.fit(method, **kwargs)
        print(self.mc.logp)
        maps.append(self.mc)
    self.mc = None

    # We want to use values of the best fitting model
    sorted_maps = sorted(maps, key=attrgetter('logp'))
    max_map = sorted_maps[-1]

    # If maximum logp values are not in the same range, there
    # could be a problem with the model.
    if runs >= 2:
        abs_err = np.abs(sorted_maps[-1].logp - sorted_maps[-2].logp)
        if abs_err > warn_crit:
            print("Warning! Two best fitting MAP estimates are %f apart. Consider using more runs to avoid local minima." % abs_err)

    # Set values of nodes. (.loc replaces the pandas .ix indexer, which
    # was removed in pandas 1.0.)
    for max_node in max_map.stochastics:
        self.nodes_db.node.loc[max_node.__name__].set_value(max_node.value)

    return max_map
def mcmc(self, assign_step_methods=True, *args, **kwargs):
    """Create and return a pymc.MCMC object for this model.

    Input:
        assign_step_methods <bool> : assign the step methods in params to the nodes
            NOTE(review): this flag is currently unused in this body —
            presumably consumed by an override or a later hook; confirm.

    The rest of the arguments are forwarded to pymc.MCMC
    """
    self.mc = pm.MCMC(self.nodes_db.node.values, *args, **kwargs)
    # hook for subclasses (e.g. to assign custom step methods)
    self.pre_sample()

    return self.mc
def pre_sample(self):
    """Hook called right after the pymc.MCMC object is created.

    Subclasses can override this, e.g. to assign custom step methods.
    """
    pass
def sample_emcee(self, nwalkers=500, samples=10, dispersion=.1, burn=5, thin=1, stretch_width=2., anneal_stretch=True, pool=None):
    """Sample the model with the emcee ensemble sampler and write the
    resulting chains back into the pymc traces.

    Input:
        nwalkers <int> - number of emcee walkers
        samples <int> - iterations kept after burn-in
        dispersion <float> - scale of the Gaussian ball around the current
            values used to initialize the walkers
        burn <int> - burn-in iterations
            NOTE(review): burn must be > 0 — `pos` is only bound inside
            the burn-in loop; confirm callers never pass 0.
        thin <int> - thinning interval during sampling
        stretch_width <float> - initial emcee stretch-move parameter `a`
        anneal_stretch <bool> - linearly anneal `a` down to 2 over burn-in
        pool - optional process pool forwarded to emcee
    """
    import emcee
    import pymc.progressbar as pbar

    # This is the likelihood function for emcee
    lnprob = LnProb(self)

    # init
    self.mcmc()

    # get current values
    stochs = self.get_stochastics()
    start = [node_descr['node'].value for name, node_descr in stochs.iterrows()]
    ndim = len(start)

    def init_from_priors():
        # alternative initializer (currently unused, see below):
        # repeatedly draw from the prior until nwalkers valid draws found
        p0 = np.empty((nwalkers, ndim))
        i = 0
        while i != nwalkers:
            self.mc.draw_from_prior()
            try:
                self.mc.logp
                p0[i, :] = [node_descr['node'].value for name, node_descr in stochs.iterrows()]
                i += 1
            except pm.ZeroProbability:
                continue
        return p0

    # per-parameter dispersion scales, if the subclass provides them
    if hasattr(self, 'emcee_dispersions'):
        scale = np.empty_like(start)
        for i, (name, node_descr) in enumerate(stochs.iterrows()):
            knode_name = node_descr['knode_name'].replace('_subj', '')
            scale[i] = self.emcee_dispersions.get(knode_name, 0.1)
    else:
        scale = 0.1

    # initialize walkers in a Gaussian ball around the current values
    p0 = np.random.randn(ndim * nwalkers).reshape((nwalkers, ndim)) * scale * dispersion + start
    #p0 = init_from_priors()

    # instantiate sampler passing in the pymc likelihood function
    sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, a=stretch_width, pool=pool)

    bar = pbar.progress_bar(burn + samples)
    i = 0
    annealing = np.linspace(stretch_width, 2, burn)
    sys.stdout.flush()
    # burn-in, optionally annealing the stretch parameter each iteration
    for pos, prob, state in sampler.sample(p0, iterations=burn):
        if anneal_stretch:
            sampler.a = annealing[i]
        i += 1
        bar.update(i)
    #print("\nMean acceptance fraction during burn-in: {}".format(np.mean(sampler.acceptance_fraction)))
    sampler.reset()

    # sample
    # NOTE(review): `lnprob` is rebound by this loop target, shadowing the
    # LnProb instance above — harmless here, but easy to trip over.
    try:
        for p, lnprob, lnlike in sampler.sample(pos,
                                                iterations=samples,
                                                thin=thin):
            i += 1
            bar.update(i)
    except KeyboardInterrupt:
        pass
    finally:
        print(("\nMean acceptance fraction during sampling: {}".format(np.mean(sampler.acceptance_fraction))))

    # restore state
    for val, (name, node_descr) in zip(start, stochs.iterrows()):
        node_descr['node'].set_value(val)

    # Save samples back to pymc model
    self.mc.sample(1, progress_bar=False) # This call is to set up the chains
    for pos, (name, node) in enumerate(stochs.iterrows()):
        node['node'].trace._trace[0] = sampler.flatchain[:, pos]

    return sampler
def sample(self, *args, **kwargs):
    """Sample from posterior.

    :Note:
        Forwards arguments to pymc.MCMC.sample().
    """
    # Fetch out arguments for db backend
    db = kwargs.pop('db', 'ram')
    dbname = kwargs.pop('dbname', None)

    # init mc if needed ("is None" identity check instead of "== None",
    # which could invoke an overloaded equality operator)
    if self.mc is None:
        self.mcmc(db=db, dbname=dbname)

    # suppress annoying warnings
    if ('hdf5' in dir(pm.database)) and \
       isinstance(self.mc.db, pm.database.hdf5.Database):
        warnings.simplefilter('ignore', pm.database.hdf5.tables.NaturalNameWarning)

    # sample
    self.mc.sample(*args, **kwargs)
    self.sampled = True
    self.gen_stats()

    return self.mc
@property
def logp(self):
    """Log-probability of the model (requires an existing mcmc instance)."""
    mc = self.mc
    if mc is None:
        raise AttributeError('self.mc not set. Call mcmc().')
    return mc.logp
@property
def dic_info(self):
    """returns information about the model DIC."""
    try:
        dic = self.mc.DIC
        deviance = np.mean(self.mc.db.trace('deviance')(), axis=0)
        return {'DIC': dic, 'deviance': deviance, 'pD': dic - deviance}
    except pm.ZeroProbability:
        # DIC undefined for zero-probability states
        return {'DIC': np.nan, 'deviance': np.nan, 'pD': np.nan}
@property
def dic(self):
    """Deviance Information Criterion."""
    info = self.dic_info
    return info['DIC']
@property
def aic(self):
    """Akaike Information Criterion."""
    if self.is_group_model:
        raise NotImplementedError('AIC can only be computed for non-hierarchical models. See dic.')
    num_params = len(self.get_stochastics())
    total_logp = sum(obs.logp for obs in self.get_observeds()['node'])
    return 2 * num_params - 2 * total_logp
@property
def bic(self):
    """Bayesian Information Criterion."""
    if self.is_group_model:
        raise NotImplementedError('BIC can only be computed for non-hierarchical models. See dic.')
    num_params = len(self.get_stochastics())
    num_obs = len(self.data)
    total_logp = sum(obs.logp for obs in self.get_observeds()['node'])
    return -2 * total_logp + num_params * np.log(num_obs)
def _output_stats(self, stats_str, fname=None):
    """
    used by print_stats and print_group_stats to print the stats to the screen
    or to file
    """
    info = self.dic_info
    if fname is None:
        # screen output
        print(stats_str)
        print("DIC: %f" % info['DIC'])
        print("deviance: %f" % info['deviance'])
        print("pD: %f" % info['pD'])
        return
    # file output
    with open(fname, 'w') as fd:
        fd.write(stats_str)
        fd.write("\nDIC: %f\n" % info['DIC'])
        fd.write("deviance: %f\n" % info['deviance'])
        fd.write("pD: %f" % info['pD'])
def gen_stats(self, fname=None, print_hidden=False, **kwargs):
    """Generate and return a DataFrame of summary statistics for all
    stochastic, non-observed nodes.

    Input (optional)
        fname <string> - accepted for interface compatibility with
            print_stats; not used by this method
        print_hidden <bool> - also include statistics of hidden nodes
        kwargs - extra column==value filters applied to the node table
    """
    # refresh the stats columns of nodes_db from the MCMC traces
    self.append_stats_to_nodes_db()
    sliced_db = self.nodes_db.copy()

    # only keep stats of non-observed (and, by default, non-hidden) nodes
    if not print_hidden:
        sliced_db = sliced_db[(sliced_db['observed'] == False) & (sliced_db['hidden'] == False)]
    else:
        sliced_db = sliced_db[(sliced_db['observed'] == False)]

    stat_cols = ['mean', 'std', '2.5q', '25q', '50q', '75q', '97.5q', 'mc err']

    # apply user-supplied node-property filters
    for node_property, value in kwargs.items():
        sliced_db = sliced_db[sliced_db[node_property] == value]

    sliced_db = sliced_db[stat_cols]

    return sliced_db
def print_stats(self, fname=None, print_hidden=False, **kwargs):
    """print statistics of all variables
    Input (optional)
        fname <string> - the output will be written to a file named fname
        print_hidden <bool> - print statistics of hidden nodes
    """
    stats_table = self.gen_stats(fname=fname, print_hidden=print_hidden, **kwargs)
    self._output_stats(stats_table.to_string(), fname)
def append_stats_to_nodes_db(self, *args, **kwargs):
    """
    smart call of MCMC.stats() for the model

    Computes pymc's per-node statistics and writes them into the stats
    columns of self.nodes_db (hidden nodes are skipped).
    """
    try:
        nchains = self.mc.db.chains
    except AttributeError:
        raise ValueError("No model found.")

    # check which chain is going to be "stat"
    if 'chain' in kwargs:
        i_chain = kwargs['chain']
    else:
        i_chain = nchains

    # update self._stats
    self._stats = self.mc.stats(*args, **kwargs)
    self._stats_chain = i_chain

    # add/overwrite stats to nodes_db
    for name, i_stats in self._stats.items():
        if self.nodes_db.loc[name, 'hidden']:
            continue
        self.nodes_db.loc[name, 'mean'] = i_stats['mean']
        self.nodes_db.loc[name, 'std'] = i_stats['standard deviation']
        self.nodes_db.loc[name, '2.5q'] = i_stats['quantiles'][2.5]
        self.nodes_db.loc[name, '25q'] = i_stats['quantiles'][25]
        self.nodes_db.loc[name, '50q'] = i_stats['quantiles'][50]
        self.nodes_db.loc[name, '75q'] = i_stats['quantiles'][75]
        self.nodes_db.loc[name, '97.5q'] = i_stats['quantiles'][97.5]
        self.nodes_db.loc[name, 'mc err'] = i_stats['mc error']
def load_db(self, dbname, verbose=0, db='sqlite'):
    """Load samples from a database created by an earlier model
    run (e.g. by calling .mcmc(dbname='test'))

    :Arguments:
        dbname : str
            File name of database
        verbose : int <default=0>
            Verbosity level
        db : str <default='sqlite'>
            Which database backend to use, can be
            sqlite, pickle, hdf5, txt.
    """
    if db == 'sqlite':
        db_loader = pm.database.sqlite.load
    elif db == 'pickle':
        db_loader = pm.database.pickle.load
    elif db == 'hdf5':
        db_loader = pm.database.hdf5.load
    elif db == 'txt':
        db_loader = pm.database.txt.load
    else:
        # Fail fast with a clear message; previously an unknown backend
        # fell through to a NameError on db_loader below.
        raise ValueError("Unknown database backend: %s" % db)

    # Ignore annoying sqlite warnings
    warnings.simplefilter('ignore', UserWarning)

    # Open database
    db = db_loader(dbname)

    # Create mcmc instance reading from the opened database
    self.mc = pm.MCMC(self.nodes_db.node, db=db, verbose=verbose)

    # Not sure if this does anything useful, but calling for good luck
    self.mc.restore_sampler_state()

    return self
def plot_posteriors(self, params=None, plot_subjs=False, save=False, **kwargs):
    """
    plot the nodes posteriors
    Input:
        params (optional) - a list of parameters to plot.
        plot_subjs (optional) - also plot subject-level nodes
        save (optional) - forwarded to pm.Matplot.plot as `last` (saves figures)
        kwargs (optional) - optional keywords to pass to pm.Matplot.plot
    TODO: add attributes plot_subjs and plot_var to kabuki
    which will change the plot attribute in the relevant nodes
    """
    # `last` is controlled via `save`; drop a caller-supplied duplicate
    kwargs.pop('last', None)
    if isinstance(params, str):
        params = [params]

    # loop over nodes and decide per node whether to plot it
    for (name, node) in self.iter_non_observeds():
        if (params is None) or (node['knode_name'] in params): # plot params if its name was mentioned
            if not node['hidden']: # plot it if it is not hidden
                plot_value = node['node'].plot  # remember, so it can be restored below
                if (plot_subjs and node['subj']): # plot if it is a subj node and plot_subjs==True
                    node['node'].plot = True
                if (params is not None) and (node['knode_name'] in params): # plot if it was specifically mentioned
                    node['node'].plot = True
                pm.Matplot.plot(node['node'], last=save, **kwargs)
                node['node'].plot = plot_value
def plot_posteriors_conditions(self, *args, **kwargs):
    """
    Plot all group posteriors listed in depends_on on individual graphs.
    Forwards arguments to kabuki.analyze.plot_posterior_nodes.
    """
    group_nodes = self.get_group_nodes()
    for dep in self.depends_on.keys():
        # .loc replaces the pandas .ix indexer (removed in pandas 1.0)
        nodes = group_nodes.loc[group_nodes.knode_name == dep]
        # skip parameters whose nodes are all hidden
        if all(nodes.hidden == True):
            continue
        analyze.plot_posterior_nodes(nodes['node'], *args, **kwargs)
def get_observeds(self):
    """Return the rows of nodes_db holding observed nodes."""
    mask = self.nodes_db.observed == True
    return self.nodes_db[mask]
def iter_observeds(self):
    """Iterate (name, row) pairs over observed nodes."""
    yield from self.get_observeds().iterrows()
def get_non_observeds(self):
    """Return the rows of nodes_db holding non-observed nodes."""
    mask = self.nodes_db.observed == False
    return self.nodes_db[mask]
def iter_non_observeds(self):
    """Iterate (name, row) pairs over non-observed nodes."""
    yield from self.get_non_observeds().iterrows()
def iter_stochastics(self):
    """Iterate (name, row) pairs over stochastic (non-observed) nodes."""
    yield from self.get_stochastics().iterrows()
def get_stochastics(self):
    """Return the rows of nodes_db holding stochastic nodes."""
    mask = self.nodes_db.stochastic == True
    return self.nodes_db[mask]
def get_subj_nodes(self, stochastic=True):
    """Return subject-level nodes, filtered by the stochastic flag."""
    is_subj = self.nodes_db['subj'] == True
    matches_kind = self.nodes_db['stochastic'] == stochastic
    return self.nodes_db[is_subj & matches_kind]
def iter_subj_nodes(self, **kwargs):
    """Iterate (name, row) pairs over subject-level nodes."""
    yield from self.get_subj_nodes(**kwargs).iterrows()
def get_group_nodes(self, stochastic=True):
    """Return group-level (non-subject) nodes, filtered by the stochastic flag."""
    is_group = self.nodes_db['subj'] == False
    matches_kind = self.nodes_db['stochastic'] == stochastic
    return self.nodes_db[is_group & matches_kind]
def iter_group_nodes(self, **kwargs):
    """Iterate (name, row) pairs over group-level nodes."""
    yield from self.get_group_nodes(**kwargs).iterrows()
def get_group_traces(self):
    """Returns a DataFrame containing traces of all stochastic
    group nodes in the model.
    """
    traces = {}
    for node in self.get_group_nodes().node:
        traces[node.__name__] = node.trace()
    return pd.DataFrame(traces)
def get_traces(self):
    """Returns a DataFrame containing traces of all stochastic
    nodes in the model.

    :Note: It is quite easy to then save this trace to csv by
        calling model.get_traces().to_csv('samples.csv')
    """
    traces = {}
    for node in self.get_stochastics().node:
        traces[node.__name__] = node.trace()
    return pd.DataFrame(traces)
def get_data_nodes(self, idx):
    """Return the single observed node whose data index covers all of idx.

    Raises NotImplementedError unless exactly one observed node matches.
    """
    wanted = set(idx)
    matching = []
    for name, node_descr in self.iter_observeds():
        node = node_descr['node']
        if wanted.issubset(set(node.value.index)):
            matching.append(node)
    if len(matching) != 1:
        raise NotImplementedError("Supply a grouping so that at most 1 observed node codes for each group.")
    return matching[0]
def __getitem__(self, name):
    """Look up a pymc node by its name in the central node table."""
    # .loc replaces the pandas .ix indexer, which was removed in pandas 1.0.
    return self.nodes_db.loc[name]['node']
@property
def values(self):
    """OrderedDict of current values of all scalar, non-observed nodes.

    Nodes with non-scalar values are skipped.
    """
    return OrderedDict(
        (name, descr['node'].value[()])
        for name, descr in self.iter_non_observeds()
        if descr['node'].value.shape == ()
    )
def set_values(self, new_values):
    """
    set values of nodes according to new_values
    Input:
        new_values <dict> - dictionary of the format {'node_name1': new_value1, ...}
    """
    for (name, value) in new_values.items():
        # .loc replaces the pandas .ix indexer (removed in pandas 1.0)
        self.nodes_db.loc[name]['node'].set_value(value)
def find_starting_values(self, *args, **kwargs):
    """Find good starting values for the different parameters by
    optimization.

    For more options see approximate_map and map. Arguments are forwarded.
    """
    optimizer = self.approximate_map if self.is_group_model else self.map
    optimizer(*args, **kwargs)
def _partial_optimize(self, optimize_nodes, evaluate_nodes, fall_to_simplex=True, minimizer='Powell', use_basin=False, debug=False, minimizer_kwargs=None, basin_kwargs=None):
"""Optimize part of the model.
:Arguments:
nodes : iterable
list nodes to optimize.
"""
if minimizer_kwargs is None:
minimizer_kwargs = {}
if basin_kwargs is None:
basin_kwargs = {}
non_observeds = [x for x in optimize_nodes if not x.observed]
init_vals = [node.value for node in non_observeds]
# define function to be optimized
def opt(values):
if debug: print(values)
for value, node in zip(values, optimize_nodes):
node.set_value(value)