# Plot all Nakazato 2013 models together

In [1]:
from os import listdir
from os.path import isfile

import pandas as pd
import plotly.express as px

from sspike.targets import Target
from sspike.beer import sort_channels

### Get file paths to tabulated results

In [2]:
# Root of the sspike snowball output tree; each model has its own sub-folder.
snowball_dir = '/Users/joe/src/gitjoe/sspike/snowballs/'

# Models to plot. The same strings are used as sub-folder names below.
sn_names = [
    'N13-13-04-100', 'N13-13-04-200', 'N13-13-04-300',
    'N13-13-20-100', 'N13-13-20-200', 'N13-13-20-300',
    'N13-20-04-100', 'N13-20-04-200', 'N13-20-04-300',
    'N13-20-20-100', 'N13-20-20-200', 'N13-20-20-300',
    'N13-30-20-100', 'N13-30-20-200', 'N13-30-20-300',
    'N13-50-04-100', 'N13-50-04-200', 'N13-50-04-300',
    'N13-50-20-100', 'N13-50-20-200', 'N13-50-20-300',
]

distance = 10.
# Fragment looked for in result folder names. Because `distance` is a float
# this renders as '10.0kpc'.
distance_tag = f"{distance}kpc"

# Paths to the totals files tabulated using sspike.beer.tab(), in the order
# they are discovered while walking `sn_names`.
beer_tabs = []

for model in sn_names:
    model_dir = f"{snowball_dir}{model}/"
    for entry in listdir(model_dir):
        # Only folders for the distance of interest are considered.
        if distance_tag not in entry:
            continue
        totals_path = f"{model_dir}{entry}/totals.txt"
        if isfile(totals_path):
            beer_tabs.append(totals_path)
            continue
        # A matching folder without a totals file is reported, not fatal.
        print(f'File not found!\n{totals_path}')

### Set target for channel lists.

In [3]:
# Detector whose channel lists drive the grouping below.
target = Target('kamland')

# Group the target's channels with sort_channels, once for the snow channel
# list and once for the nc channel list, then merge both groupings into a
# single mapping (nc entries win if a key appears in both).
snow_combos = sort_channels(target.snow_channels)
nc_combos = sort_channels(target.nc_channels)
combos = {**snow_combos, **nc_combos}

### Read each file into a data frame

In [4]:
# Build a long-format table with one row per (simulation, channel group)
# holding the summed event counts read from each model's totals file.
# Columns for simulation type and events per channel.
# TODO: add other simulation properties for sorting.
column_names = ['sn_name', 'Channel', 'Mass',
                'Metal', 'Revival [ms]', 'Counts']

# Number of non-blank lines to ignore after each section header.
# These lines are for checking, but not really of interest here.
# TODO: make this check target channel lengths.
skip_after_header = {'snow-smeared': 1,
                     'snow-unsmeared': 23,
                     'sspike-basic': 5,
                     'sspike-nc': 21}

# Rows are gathered in a plain list and converted to a frame once at the end:
# DataFrame.append was removed in pandas 2.0 and copied the frame on each call.
rows = []
for sn in sn_names:
    # Find this model's totals file by name rather than by list position.
    # beer_tabs gets no entry for a missing file and several entries when
    # more than one folder matches, so positions may not line up with sn_names.
    matches = [tab for tab in beer_tabs if f"/{sn}/" in tab]
    if not matches:
        print(f'No totals file listed for {sn}; skipping.')
        continue
    if len(matches) > 1:
        print(f'Multiple totals files listed for {sn}; using {matches[0]}')

    # Names look like 'N13-20-20-300'; the fields after the first hyphen are
    # used for the 'Mass', 'Metal' and 'Revival [ms]' columns respectively.
    vals = sn.split('-')
    mass = vals[1]
    z = vals[2]
    t = vals[3]

    # Metallicity code '20' is labelled 'Solar'; anything else 'SMC'.
    Z = 'Solar' if z == '20' else 'SMC'

    row = {'sn_name': sn, 'Mass': mass, 'Metal': Z, 'Revival [ms]': t}

    # Counters for combining flavors by target.
    counts = dict.fromkeys(combos, 0.)

    # Sort by data type/processing method.
    skip = 0
    data_type = None
    nc_flav_count = 0

    # Get results from sspike.beer.tab() for this model, line by line.
    with open(matches[0], 'r') as f:
        for line in f:
            # Skip blank line separators.
            if not line.strip():
                continue

            # Skip lines that follow a header but are not of interest.
            if skip:
                skip -= 1
                continue

            # Section headers (e.g. snow-smeared, sspike-nc) are recognised
            # as lines containing exactly one hyphen.
            # NOTE(review): a data line whose count is written like '1e-05'
            # would also match this test and be dropped -- confirm the
            # totals.txt number format.
            if len(line.split('-')) == 2:
                data_type = line.strip()
                skip = skip_after_header.get(data_type, 0)
                continue

            # Get channel name and number of events ('<channel>:<count>').
            try:
                channel, count = line.split(':')
            except ValueError:
                # Not exactly one ':' on the line; report it and move on.
                print(f"ERROR: {line}")
                continue

            # Add into the counter of the first group listing this channel.
            for column, channels in combos.items():
                if channel in channels:
                    counts[column] += float(count)
                    if column == 'p-nc':
                        nc_flav_count += 1
                    break

            # Only looking at 200 keV energy cut for now (skipping 300 keV
            # cut): stop once four 'p-nc' entries have been added.
            if nc_flav_count == 4:
                break

    # One output row per channel group; copy so rows do not share a dict.
    for channel in counts:
        rows.append({**row, 'Channel': channel, 'Counts': counts[channel]})

events = pd.DataFrame(rows, columns=column_names)

In [5]:
# Grouped bar chart of counts per channel group: revival time on the x axis,
# one facet column per mass and one facet row per metallicity label.
facet_order = {"Mass": ["13", "20", "30", "50"],
               "Metal": ["Solar", "SMC"]}
bars = px.bar(
    events,
    x='Revival [ms]',
    y='Counts',
    color='Channel',
    barmode='group',
    facet_row='Metal',
    facet_col='Mass',
    category_orders=facet_order,
    log_y=True,
)

# Layout tweaks: spacing between bars in a group and the global font.
# The figure title is left unset.
bars.layout.bargroupgap = 0.2
bars.layout.font = {"size": 18, "family": "Times New Roman"}

# Last expression of the cell so the notebook renders the figure.
bars

In [10]:
# Save the bar chart to ../plots/nakazato.png with the given export size
# and scale factor.
bars.write_image(
    '../plots/nakazato.png',
    width=1100,
    height=400,
    scale=3,
)