In [1]:
import pyprob
from pyprob.distributions import Normal
import pyprob.diagnostics
import torch
import pandas
import os

def parse_n_infect_from_output(file_name='/home/examples/output.txt'):
    """Load the first 12 entries of the ``value`` column as a torch tensor.

    The file is read as header-less, tab-separated text whose columns are
    labelled ``survey_number``, ``third_dimension``, ``measure`` and
    ``value``.

    NOTE(review): the default path points at ``/home/examples/`` while
    ``forward_before`` deletes ``output.txt`` under ``/home/examples/test/``
    -- confirm which location the simulator actually writes to.

    Args:
        file_name: path of the tab-separated output file to read.

    Returns:
        A 1-D tensor built from (at most) the first 12 ``value`` entries.
    """
    column_names = ["survey_number", "third_dimension", "measure", "value"]
    table = pandas.read_csv(file_name, sep='\t', names=column_names)
    leading_values = table['value'].iloc[:12]
    return torch.from_numpy(leading_values.to_numpy())

PATH = '/home/examples/test/'
def forward_before():
    """Delete ``output.txt`` under ``PATH`` if it is present.

    Passed to ``pyprob.RemoteModel`` as ``before_forward_func``. Prints a
    short notice whenever a file is actually removed; does nothing otherwise.
    """
    stale_output = PATH + 'output.txt'
    if not os.path.isfile(stale_output):
        return
    print('Removing output...')
    os.remove(stale_output)
    
def forward_after():
    """Register the parsed infection counts as an observation named 'n_infect'.

    Passed to ``pyprob.RemoteModel`` as ``after_forward_func``. Reads the
    counts with ``parse_n_infect_from_output()`` (default path), wraps them in
    a ``Normal`` centred on those counts with the second argument fixed at
    0.01, and hands that distribution to ``pyprob.observe``.
    """
    pyprob.observe(Normal(parse_n_infect_from_output(), 0.01), name='n_infect')

# Settings reused by the diagnostics call further down in the notebook.
file_name = 'trace'
address_dict_file_name = 'address_dict'
use_address_base = True

# Remote model reached over the IPC endpoint below; the two hooks are the
# forward_before / forward_after functions defined above.
model = pyprob.RemoteModel(
    'ipc://@openmalaria_probprog',
    before_forward_func=forward_before,
    after_forward_func=forward_after,
    address_dict_file_name=address_dict_file_name,
)

## Generating the Prior

In [None]:
# Request 5 traces from the model; later cells read the result through `prior`.
# NOTE: the output recorded for this cell shows 2 of 5 traces done after 0:54:42,
# so expect this call to run for a long time.
prior = model.prior_traces(num_traces=5)

Time spent  | Time remain.| Progress             | Trace | Traces/sec
ppx (Python): zmq.REQ socket connecting to server ipc://@openmalaria_probprog
ppx (Python): This system        : [32mpyprob 0.13.2.dev18[0m
ppx (Python): Connected to system: [32mpyprob_cpp 0.1.9 (master:7290be2)[0m
ppx (Python): Model name         : [1m[32mOpenMalaria probprog[0m
Removing output...
Removing output...0:54:42 | ########------------ | 2/5 | 0.00       


## Generating the Posterior: Conditioning on the Number of Infected Individuals

In [None]:
# NOTE(review): `ground_truth_prevalence` and `get_prevalence` are not defined in any
# cell visible in this notebook -- they must be defined before this cell can run.
# Condition on the observation named 'n_infect' and request 5 traces.
is_posterior_traces = model.posterior_traces(observe={'n_infect': ground_truth_prevalence}, num_traces=5)
# From every trace, pull the value of the variable named 'population' ...
is_posterior_population = is_posterior_traces.map(lambda trace: trace.named_variables['population'].value)
# ... and pass each of those values through get_prevalence.
is_posterior_prevalence = is_posterior_population.map(get_prevalence)

In [None]:
# NOTE(review): this cell depends on `opt`, `observation`, `initial_trace`, `dists` and
# `ground_truth_trace`, none of which are defined in the cells visible in this notebook,
# and on `create_path`, which is defined in a later cell. Define/run those first.
# NOTE(review): the observation key 'calorimeter_n_deposits' differs from the name
# 'n_infect' registered by forward_after -- confirm which name this model expects.

# pyprob selects the sampler through the InferenceEngine enum (random-walk
# Metropolis-Hastings here) rather than through a free-form string.
inference_engine = pyprob.InferenceEngine.RANDOM_WALK_METROPOLIS_HASTINGS
num_posterior_traces = 20

print('Saving posterior distribution with {} traces and inference engine {} to: {}'.format(num_posterior_traces, inference_engine.name, opt.output_file))
# Make sure the directory that will hold the output file exists before sampling.
create_path(opt.output_file)
posterior_dist = model.posterior_traces(num_traces=num_posterior_traces, inference_engine=inference_engine, observe={'calorimeter_n_deposits': observation}, initial_trace=initial_trace, file_name=opt.output_file, thinning_steps=opt.thinning_steps, likelihood_importance=opt.likelihood_importance)

# NOTE(review): `dists` is not assigned anywhere visible; it may have been meant to
# include `posterior_dist` -- confirm before running.
pyprob.diagnostics.address_histograms(dists, plot=True, plot_show=True, file_name=file_name, ground_truth_trace=ground_truth_trace, use_address_base=use_address_base)

In [None]:
# ground_truth_trace = next(model._trace_generator())
# NOTE(review): other cells read `named_variables` from a single trace
# (e.g. the lambda applied to posterior traces) while `prior` is indexed as a
# collection via `prior._values[0]` -- confirm `named_variables` exists on `prior` itself.
prior.named_variables['n_infect'].value

In [8]:
# Inspect the first prior trace: print every entry of its address-base dictionary
# together with a 1-based running index.
# NOTE: `_values` is a private attribute of the object returned by prior_traces.
trace = prior._values[0]
var = trace.variables
variables_by_address_base = trace.variables_dict_address_base
for count, k in enumerate(variables_by_address_base, start=1):
    print(" Address {2} : {0} \n Value : {1} \n\n".format(k, variables_by_address_base[k], count))
# Optional: pyprob.diagnostics.graph(prior, file_name="mpenny_10")
print(prior.save_metadata)

 Address 1 : [forward()+0x1f3; OM::Simulator::start(scnXml::Monitoring const&)+0x28a; OM::Population::createInitialHumans()+0x94; OM::Population::newHuman(OM::SimTime)+0x5c; OM::Host::Human::Human(OM::SimTime)+0x12b; OM::WithinHost::WHInterface::createWithinHostModel(double)+0x72; OM::WithinHost::CommonWithinHost::CommonWithinHost(double)+0x3a; OM::WithinHost::WHFalciparum::WHFalciparum(double)+0xe6; OM::util::random::gauss(double, double)+0xb4]__Normal 
 Value : Variable(name:None, control:True, replace:False, observable:False, observed:False, tagged:False, address:[forward()+0x1f3; OM::Simulator::start(scnXml::Monitoring const&)+0x28a; OM::Population::createInitialHumans()+0x94; OM::Population::newHuman(OM::SimTime)+0x5c; OM::Host::Human::Human(OM::SimTime)+0x12b; OM::WithinHost::WHInterface::createWithinHostModel(double)+0x72; OM::WithinHost::CommonWithinHost::CommonWithinHost(double)+0x3a; OM::WithinHost::WHFalciparum::WHFalciparum(double)+0xe6; OM::util::random::gauss(double, doub

In [5]:
def plot_distribution(dist, obs_mode, ground_truth_trace=None, file_name=None):
    """Render a multi-panel matplotlib summary of a distribution of traces.

    Nothing is drawn when ``dist.length`` is 0.

    Every trace is flattened by ``distribution_to_numpy`` and then sliced by
    fixed column offsets:

    * ``[:3]``    -> ``mother``; zipped with ('px', 'py', 'pz') when drawn
    * ``[3]``     -> ``channel``
    * ``[4:244]`` -> 30 rows x 8 columns of final-state entries; the columns
      are labelled ('px', 'py', 'pz', 'E', 'theta', 'phi', 'pid', 'visible')
    * ``[244:]``  -> reshaped to a (35, 35, 20) observation grid

    NOTE(review): ``distribution_to_numpy`` and ``trace_to_numpy`` are not
    defined in the visible part of this notebook; they must be in scope
    before this function is called with a non-empty ``dist``.

    Args:
        dist: object exposing ``length`` and ``name``; forwarded to
            ``distribution_to_numpy``.
        obs_mode: forwarded unchanged to ``distribution_to_numpy`` and
            ``trace_to_numpy``.
        ground_truth_trace: optional trace. When given it is converted with
            ``trace_to_numpy``, overlaid on the histograms as dashed lines /
            white markers, and shown in an extra 'Observed Calorimeter' panel.
        file_name: optional path stem; when given the figure is saved to
            ``file_name + '.pdf'``.

    Returns:
        None.
    """
    if dist.length > 0:
        # Imported here so the function does not depend on names leaking in
        # from other notebook cells.
        import numpy as np
        import matplotlib as mpl
        import matplotlib.pyplot as plt
        from matplotlib.ticker import MultipleLocator
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection

        dist_numpy = distribution_to_numpy(dist, obs_mode)
        num_traces = dist_numpy.shape[0]
        if ground_truth_trace is not None:
            ground = trace_to_numpy(ground_truth_trace, obs_mode)

        def nhad_nem_ninvis(pids,viz):
            # Counts per event: |pid| in {22, 11} -> EM, |pid| > 100 -> HAD,
            # plus the sums of the visibility flags and of their complement.
            n_em      = np.sum([1 if abs(p) in [22,11] else 0 for p in pids])
            n_had     = np.sum([1 if abs(p)>100 else 0 for p in pids])
            n_calovis = np.sum(viz)
            n_invis   = np.sum(1-viz)
            assert n_em + n_had + n_invis == len(pids)
            return [n_em,n_had,n_calovis,n_invis]


        mother           = dist_numpy[:,:3]
        channel          = dist_numpy[:,3]
        final            = dist_numpy[:,4:4+(30*8)].reshape(num_traces,30,8)
        # Keep only entries above -9999 (presumably a padding sentinel -- TODO confirm)
        # and sort the remaining rows by column 3 ('E') in descending order.
        finalfilt        = [np.array(sorted(f[f>-9999].reshape(-1,8), key=lambda x: -x[3])) for f in final]
        finalmult        = np.array([f.shape[0] for f in finalfilt])
        obs              = dist_numpy[:,4+(30*8):].reshape(num_traces,35,35,20)
        particle_types   = np.array([nhad_nem_ninvis(f[:,6],f[:,7]) for f in finalfilt])

        if ground_truth_trace is not None:
            # Same slicing as above, applied to the single ground-truth vector.
            g_mother         = ground[:3]
            g_channel        = ground[3]
            g_final          = ground[4:4+(30*8)].reshape(30,8)
            g_finalfilt      = np.array(sorted(g_final[g_final>-9999].reshape(-1,8), key=lambda x: -x[3]))
            g_finalmult      = g_finalfilt.shape[0]
            g_obs            = ground[4+(30*8):].reshape(35,35,20)
            g_particle_types = nhad_nem_ninvis(g_finalfilt[:,6],g_finalfilt[:,7])

        # |pid| -> (colour, line style, z-range) used when drawing trajectories.
        particle_styles = {
            211: ('k', 'solid',  [0,15]), # pi
            16:  ('b', 'dashed', [0,15]), # tau lepton
            22:  ('y', 'solid',  [0, 9]), # photon
            11:  ('y', 'solid',  [0, 9]), # electron
        }

        def trajectory(momentum, zrange = [4,10]):
            # Straight line along the momentum direction, parameterised so that
            # the z coordinate runs from zrange[0] to zrange[1].
            t0 = zrange[0]/momentum['pz']
            t1 = zrange[1]/momentum['pz']
            return [momentum['px']*np.linspace(t0,t1), momentum['py']*np.linspace(t0,t1), momentum['pz']*np.linspace(t0,t1)]


        def plot_surf(ax,surfz = 4):
            # Translucent green square spanning [-3, 3] x [-3, 3] at z = surfz.
            colors = [mpl.colors.hex2color(mpl.colors.cnames['green']) + (0.1,)]
            faces  = [
                [[-3,-3,surfz], [-3,3,surfz], [3,3,surfz], [3,-3,surfz]]
            ]

            for face,color in zip(faces,colors):
                ax.add_collection3d (Poly3DCollection ([face], facecolor = color))

        def extend_obs(observation):
            # Embed the (35, 35, 20) grid into a (35, 35, 27) grid: the first
            # 7 z-slices stay zero and the observation fills slices 7 onwards.
            minx = 4 - 11./(20.-1.) * 7
            mgrid = ix,iy,iz = np.mgrid[-3:3:35j,-3:3:35j,minx:15:27j]
            extended = np.zeros((35,35,27))
            extended[:,:,7:] = observation
            return mgrid, extended

        def plot(ax,ix,iy,iz,observation, trajectories_trace = None):
            ax.set_xlim(-3,3)
            ax.set_ylim(-3,3)
            ax.set_zlim(0,15)

            if trajectories_trace is not None:
                mother         = trajectories_trace[:3]
                final          = trajectories_trace[4:4+(30*8)].reshape(30,8)
                finalfilt      = np.array(sorted(final[final>-9999].reshape(-1,8), key=lambda x: -x[3]))
                mother_momentum = dict(zip(['px','py','pz'],mother))
                decay_particles = [dict(zip(['px', 'py', 'pz', 'E', 'theta', 'phi', 'pid', 'visible'],map(float,p))) for p in finalfilt]
                ax.plot(*trajectory(mother_momentum,[4,15]), c = 'r', linestyle = 'dotted')

                for p in decay_particles:
                    if abs(p['pid']) in particle_styles:
                        c, style, zlim = particle_styles[abs(p['pid'])]
                    else:
                        c, style, zlim = ('red', 'solid',  [0, 9])  # other (added by Gunes)
                    ax.plot(*trajectory(p,zlim), linewidth=2.5, c = c, linestyle = style)


            # Only cells above the fixed cutoff get a visible marker size.
            cutoff = 0.1
            sizes = np.zeros(shape = observation.shape)
            sizes[observation > cutoff] = 100
            ax.scatter(ix,iy,iz, c = observation.ravel(), alpha = 0.1, s = sizes)

            plot_surf(ax)

            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_zlabel('z')

            ax.view_init(10.,20.)


        f = plt.figure()
        f.set_size_inches(25,10)

        shape = (4,11)

        ax1 = plt.subplot2grid(shape, (0, 1))
        ax2 = plt.subplot2grid(shape, (0, 2))
        ax3 = plt.subplot2grid(shape, (0, 3))


        motheraxes = [ax1,ax2,ax3]


        emhadcomp   = plt.subplot2grid(shape, (0, 4))
        nfinalstate = plt.subplot2grid(shape, (1, 4))
        channelax   = plt.subplot2grid(shape, (1, 2), colspan=2)

        measimobsax = plt.subplot2grid(shape, (0, 5), rowspan=2, colspan=2, projection='3d')
        obsax       = plt.subplot2grid(shape, (0, 7), rowspan=2, colspan=2, projection='3d')

        finalstateaxes = {}

        # Lower-triangular 2x2 grid: cells (0,0), (1,0) and (1,1).
        for i in range(2):
            for j in range(2):
                if i < j:
                    continue
                ax = plt.subplot2grid(shape, (i,j))
                finalstateaxes.setdefault(i,{})[j] = ax



        # ## number of final state particles
        ax = nfinalstate

        h    = nfinalstate.hist(finalmult, bins = np.linspace(-0.5,10.5,12), density=True)
        ymax = np.max(h[0])*1.5
        if ground_truth_trace is not None:
            l    = ax.vlines(g_finalmult,0,ymax, linestyles='dashed')
        ax.set_ylim(0,ymax)
        ax.set_xlim(xmin=0)
        ax.set_xlabel('Number of Final State Particles')
        ax.set_title('Decay Products')
        ax.xaxis.set_major_locator(MultipleLocator(1))


        # ## event composition: EM vs HAD particle counts
        ax = emhadcomp
        em_had   = particle_types[:,:2]
        if ground_truth_trace is not None:
            g_em_had = g_particle_types[:2]

        nx, ny = 10,8
        # `density=True` matches the other histograms in this function; the older
        # `normed` keyword is no longer accepted by current Matplotlib releases.
        h =ax.hist2d(em_had[:,0],em_had[:,1], [np.linspace(-0.5,0.5+nx,nx+2), np.linspace(-0.5,0.5+ny,ny+2)], density=True, cmap = 'viridis')
        if ground_truth_trace is not None:
            l = ax.scatter(g_em_had[0],g_em_had[1],c = 'w', edgecolors='k')
        ax.set_xlabel('Number of EM Particles')
        ax.set_ylabel('Number of HAD Particles')
        ax.set_title('Event Composition')
        ax.xaxis.set_major_locator(MultipleLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(1))



        # ## energies of the two leading final-state particles
        axarr = finalstateaxes
        energies = np.array([f[:2] for f in finalfilt])
        try:
            for i in range(2):
                for j in range(2):
                    if i > j:
                        ax = axarr[i][j]
                        ax.hist2d(energies[:,j,3],energies[:,i,3], [np.linspace(0,40,11),np.linspace(0,40,11)], cmap='viridis')
                        if ground_truth_trace is not None:
                            l = ax.scatter(g_finalfilt[:3][j,3],g_finalfilt[:3][i,3], c = 'w', edgecolors='k')
                        ax.set_title('FSP Energy Joint')
                    elif i==j:
                        ax = axarr[i][j]
                        h = ax.hist(energies[:,j,3],bins = np.linspace(0,40,11), facecolor = 'grey', density=True)
                        if ground_truth_trace is not None:
                            l = ax.vlines(g_finalfilt[:3][j,3],0,1, linestyles='dashed')
                        ax.set_ylim(0,1.5*np.max(h[0]))
                        ax.set_xlim(0,40)
                        ax.set_title('FSP Energy {}'.format(i + 1))
        except Exception as e:
            # Best effort: these panels are skipped when they cannot be drawn, but
            # the cause is reported and KeyboardInterrupt is no longer swallowed.
            print('Error with plotting FSP')
            print(e)

        # ## tau momentum
        axarr = motheraxes
        colors = mpl.cm.inferno(np.linspace(0,1,5))[1:-1]

        titles = ['τ px','τ py','τ pz']
        limits = [[-3,3],[-3,3],[43,47]]

        for i,(lim,ax,c,t) in enumerate(zip(limits,axarr,colors,titles)):
            # lim + [11] expands to linspace(lo, hi, 11), i.e. 10 equal-width bins.
            n,_,h = ax.hist(mother[:,i], density=True, bins = np.linspace(*(lim+[11])), facecolor=c)
            ymax = 1.5*np.max(n)
            if ground_truth_trace is not None:
                l = ax.vlines(g_mother[i],0,ymax, linestyles='dashed')
            ax.set_xlim(*lim)
            ax.set_title(t)
            ax.set_ylim(0,ymax)

        # ## decay channel

        ax = channelax
        ch = ax.hist(channel, np.linspace(-0.5,35.5,37), density=True, facecolor = 'grey')
        if ground_truth_trace is not None:
            l = ax.vlines(g_channel,0,1.0, linestyles='dashed')
        ax.set_ylim(0.,1.0)
        ax.set_xlim(-1, 38)
        ax.xaxis.set_minor_locator(MultipleLocator(1))
        ax.set_title('Decay Channel')

        if ground_truth_trace is not None:
            (ix, iy, iz), obs_extended = extend_obs(g_obs)
            # Min-max normalise to [0, 1] before applying the fixed cutoff.
            obs_extended = (obs_extended - np.min(obs_extended)) / (np.max(obs_extended) - np.min(obs_extended))
            cutoff = 0.1
            obsax.set_title('Observed Calorimeter')
            # Two extra points with value 0 are appended to the selected cells
            # (presumably to pin the extent of the scatter -- TODO confirm).
            ixa = np.append(ix[obs_extended>cutoff],[3, -3.])
            iya = np.append(iy[obs_extended>cutoff],[3, -3.])
            iza = np.append(iz[obs_extended>cutoff],[0, 14])
            plot(obsax,ixa,iya,iza,np.append(obs_extended[obs_extended>cutoff],[0,0]), trajectories_trace=ground)

        # Mean observation grid over all traces, normalised the same way.
        avgobs = np.average(obs,0)
        (ix, iy, iz), avgobs_extended = extend_obs(avgobs)
        avgobs_extended = (avgobs_extended - np.min(avgobs_extended)) / (np.max(avgobs_extended) - np.min(avgobs_extended))
        cutoff = 0.1

        measimobsax.set_title('Simulated Calorimeter (Mean)')
        ixa = np.append(ix[avgobs_extended>cutoff],[3, -3.])
        iya = np.append(iy[avgobs_extended>cutoff],[3, -3.])
        iza = np.append(iz[avgobs_extended>cutoff],[0, 14])
        plot(measimobsax,ixa,iya,iza,np.append(avgobs_extended[avgobs_extended>cutoff],[0,0]), trajectories_trace=None)

        # A 'Simulated Calorimeter (Mode)' panel based on dist.mode is currently disabled.

        plt.suptitle(dist.name, x=0.0, y=.99, horizontalalignment='left', verticalalignment='top')
        plt.tight_layout(rect=[0, 0.03, 1, 0.98])

        if file_name is not None:
            plt.savefig(file_name + '.pdf', bbox_inches='tight')


def create_path(path, directory=False):
    """Make sure the directory needed for ``path`` exists.

    Args:
        path: a directory path when ``directory`` is True; otherwise a file
            path whose parent directory (``os.path.dirname(path)``) is the
            one that gets created.
        directory: whether ``path`` itself names the directory to create.

    Returns:
        None. Creation failures are reported with ``print`` and not raised.
    """
    target_dir = path if directory else os.path.dirname(path)
    if not target_dir:
        # A bare file name has no directory component (dirname is ''), so
        # there is nothing to create; os.makedirs('') would only fail.
        return
    if not os.path.exists(target_dir):
        print('{} does not exist, creating'.format(target_dir))
        try:
            # exist_ok: another process may create the directory between the
            # existence check above and this call.
            os.makedirs(target_dir, exist_ok=True)
        except Exception as e:
            print(e)
            print('Could not create path, potentiall created by another rank in multinode: {}'.format(path))


'/Users/bradley/Documents/Projects/OpenMalaria/code/notebooks'