# Example Starting Point for New Data Formats

In [5]:
# Are you using a special reservation for a workshop?
# If so, set it here:
nersc_reservation = "toast3"

# Load common tools for all lessons
import sys
sys.path.insert(0, "../lessons")
from lesson_tools import (
    check_nersc,
)
nersc_host, nersc_repo, nersc_resv = check_nersc(reservation=nersc_reservation)

# Capture C++ output in the jupyter cells
%reload_ext wurlitzer


Running on NERSC machine 'cori'
  with access to repos: mp107
Using default repo mp107
Reservation 'toast3' valid from 2019-10-16T09:00:00 to 2019-10-16T17:00:00
Current time is 2019-10-16T11:08:10.418371
Selecting reservation 'toast3'


In [3]:
import h5py
import numpy as np
# Write a small example file with one dataset per detector, named after the
# detector.  Each dataset is a 1-D timestream with one value per sample
# (100 samples here).  Mode "w" creates (or truncates) the file; recent h5py
# versions open files read-only when no mode is given.  The context manager
# closes the file even if writing fails.
with h5py.File("test.h5", "w") as f:
    f.create_dataset("channel_one", data=np.ones(100))

## TOD Class

This is a stub of a TOD class that reads the data for one observation from a custom data format.

In [20]:
import os

import h5py
import numpy as np

import toast
import toast.qarray as qa
from toast.mpi import MPI
from toast.tod import TOD

class NewTOD(TOD):
    """This class contains the timestream data.

    This loads data from a custom data format.  Add more documentation here
    about what it is doing...

    Add more constructor arguments to get all the info you need to be
    able to read the data.

    In this example the file at ``path`` is an HDF5 file with one dataset per
    detector, named after the detector, holding that detector's signal.

    Args:
        path (str):  The path to an observation file.
        detquats (dict):  Dictionary of detector names and quaternion
            offsets from the boresight.
        mpicomm (mpi4py.MPI.Comm): the MPI communicator over which this
            observation data is distributed.
        detranks (int):  The dimension of the process grid in the detector
            direction.  The MPI communicator size must be evenly divisible
            by this number.

    """

    # You can override the default names of cache keys here.  The defaults
    # are defined in the toast.tod.TOD base class.
    BORESIGHT_NAME = "boresight"
    BORESIGHT_AZEL_NAME = "boresight_azel"

    def __init__(self, path, detquats, mpicomm=None, detranks=1):
        self._path = path
        # Keep a private copy so later changes to the caller's dictionary do
        # not change this TOD's detector offsets.
        self._detquats = dict(detquats)

        # Figure out how many samples there are in this observation.  Also,
        # if there are any kind of "sub chunks" in the observation that should
        # not be split up between processes (e.g. left and right azimuth
        # scans), then compute them here.  Placeholder: a fixed sample count.
        # A real reader must know this before calling the base class
        # constructor below.
        nsamp = 100

        # This is just a list of one element (the whole observation).  You
        # could specify the chunks in samples that should never be split up
        # between processes.
        sampsizes = [nsamp]

        # Here we assign unique IDs to every detector.  This is used for
        # reproducible simulations.  You can decide how to assign these for
        # your project.  Here they are just assigned based on the sorted list
        # of detector names.
        detnames = sorted(self._detquats)
        detindx = {name: idx for idx, name in enumerate(detnames)}

        # Call base class constructor to distribute data
        super().__init__(
            mpicomm, detnames, nsamp,
            detindx=detindx, detranks=detranks,
            sampsizes=sampsizes, meta=dict()
        )

        # If we are caching some data (e.g. boresight pointing, auxiliary
        # files needed by any read operation, etc) then do it here.  Each
        # process only caches the detectors and the sample range that the
        # base class assigned to it.  NOTE(review): this assumes the
        # _get*/_put* methods below receive sample offsets relative to the
        # local range (toast 2.x convention) — confirm for your version.
        offset, nlocal = self.local_samples
        with h5py.File(path, "r") as f:
            for det in self.local_dets:
                # Slicing reads the values into memory while the file is
                # still open, and uses the same cache key as _get / _put.
                self.cache.put(
                    "{}_{}".format(self.SIGNAL_NAME, det),
                    f[det][offset:offset + nlocal],
                )
        return

    def detoffset(self):
        """Return a copy of the detector quaternion offsets, keyed by name."""
        return dict(self._detquats)

    # The methods below assume that the data was cached during construction.
    # If not, then you can read the different data products inside each
    # method.  You can customize the cache key names through the class
    # attributes above.

    def _get_boresight(self, start, n):
        """Return rows [start, start + n) of the cached boresight pointing."""
        # This assumes you cached the boresight pointing in RA/DEC
        # in the constructor.
        ref = self.cache.reference(self.BORESIGHT_NAME)[start:start+n, :]
        return ref

    def _put_boresight(self, start, data):
        """Write ``data`` into the cached boresight pointing at ``start``."""
        ref = self.cache.reference(self.BORESIGHT_NAME)
        ref[start:(start+data.shape[0]), :] = data
        del ref
        return

    # Uncomment these if you also cache the Az/El boresight pointing.
    # def _get_boresight_azel(self, start, n):
    #     ref = self.cache.reference(self.BORESIGHT_AZEL_NAME)[start:start+n, :]
    #     return ref
    #
    # def _put_boresight_azel(self, start, data):
    #     ref = self.cache.reference(self.BORESIGHT_AZEL_NAME)
    #     ref[start:(start+data.shape[0]), :] = data
    #     del ref
    #     return

    def _get(self, detector, start, n):
        """Return samples [start, start + n) of the detector's cached signal."""
        name = "{}_{}".format(self.SIGNAL_NAME, detector)
        ref = self.cache.reference(name)[start:start+n]
        return ref

    def _put(self, detector, start, data):
        """Write ``data`` into the detector's cached signal at ``start``."""
        name = "{}_{}".format(self.SIGNAL_NAME, detector)
        ref = self.cache.reference(name)
        ref[start:(start+data.shape[0])] = data
        del ref
        return

    def _get_flags(self, detector, start, n):
        """Return samples [start, start + n) of the detector's cached flags."""
        name = "{}_{}".format(self.FLAG_NAME, detector)
        ref = self.cache.reference(name)[start:start+n]
        return ref

    def _put_flags(self, detector, start, flags):
        """Write ``flags`` into the detector's cached flags at ``start``."""
        name = "{}_{}".format(self.FLAG_NAME, detector)
        ref = self.cache.reference(name)
        ref[start:(start+flags.shape[0])] = flags
        del ref
        return

    def _get_common_flags(self, start, n):
        """Return samples [start, start + n) of the cached common flags."""
        ref = self.cache.reference(self.COMMON_FLAG_NAME)[start:start+n]
        return ref

    def _put_common_flags(self, start, flags):
        """Write ``flags`` into the cached common flags at ``start``."""
        ref = self.cache.reference(self.COMMON_FLAG_NAME)
        ref[start:(start+flags.shape[0])] = flags
        del ref
        return

    def _get_hwp_angle(self, start, n):
        """Return the cached HWP angle slice, or None if none is cached."""
        if self.cache.exists(self.HWP_ANGLE_NAME):
            hwpang = self.cache.reference(self.HWP_ANGLE_NAME)[start:start+n]
        else:
            hwpang = None
        return hwpang

    def _put_hwp_angle(self, start, hwpang):
        """Write ``hwpang`` into the cached HWP angle at ``start``."""
        ref = self.cache.reference(self.HWP_ANGLE_NAME)
        ref[start:(start + hwpang.shape[0])] = hwpang
        del ref
        return

    def _get_times(self, start, n):
        """Return timestamps in seconds.

        The cache stores them as integer nanoseconds (see _put_times).
        """
        ref = self.cache.reference(self.TIMESTAMP_NAME)[start:start+n]
        tm = 1.0e-9 * ref.astype(np.float64)
        del ref
        return tm

    def _put_times(self, start, stamps):
        """Store timestamps given in seconds as int64 nanoseconds."""
        ref = self.cache.reference(self.TIMESTAMP_NAME)
        ref[start:(start+stamps.shape[0])] = np.array(1.0e9 * stamps,
                                                      dtype=np.int64)
        del ref
        return

    def _get_pntg(self, detector, start, n):
        """Return detector pointing: boresight times the detector offset."""
        # Get boresight pointing (from disk or cache)
        bore = self._get_boresight(start, n)
        # Apply detector quaternion and return
        return qa.mult(bore, self._detquats[detector])

    def _put_pntg(self, detector, start, data):
        """Detector pointing is derived from the boresight, so it is read-only."""
        raise RuntimeError("This class computes detector pointing on the fly")

    def _get_position(self, start, n):
        """Return rows [start, start + n) of the cached position."""
        ref = self.cache.reference(self.POSITION_NAME)[start:start+n, :]
        return ref

    def _put_position(self, start, pos):
        """Write ``pos`` into the cached position at ``start``."""
        ref = self.cache.reference(self.POSITION_NAME)
        ref[start:(start+pos.shape[0]), :] = pos
        del ref
        return

    def _get_velocity(self, start, n):
        """Return rows [start, start + n) of the cached velocity."""
        ref = self.cache.reference(self.VELOCITY_NAME)[start:start+n, :]
        return ref

    def _put_velocity(self, start, vel):
        """Write ``vel`` into the cached velocity at ``start``."""
        ref = self.cache.reference(self.VELOCITY_NAME)
        ref[start:(start+vel.shape[0]), :] = vel
        del ref
        return

In [26]:
comm = toast.Comm()

# toast quaternions are ordered [x, y, z, w], so [0, 0, 0, 1] is the identity
# rotation: a detector looking straight along the boresight.
tod = NewTOD(
    "test.h5",
    detquats={"channel_one": np.array([0.0, 0.0, 0.0, 1.0])},
    mpicomm=comm.comm_group,
    detranks=1,
)
tod

<NewTOD
  1 total detectors and 100 total samples
  Using MPI communicator None
    In grid dimensions 1 sample ranks x 1 detranks
  Process at (0, 0) in grid has data for:
    Samples 0 - 99 (inclusive)
    Detectors:
      channel_one
    Cache contains 800 bytes
>

## Loading a Single Observation

This function creates one observation: a dictionary holding the TOD object and any other metadata.

In [None]:
def load_observation(path, mpicomm=None, detranks=1, detquats=None, **kwargs):
    """Create an observation.

    Extra keyword args are passed to the TOD constructor.

    Args:
        path (str):  The path to the observation.
        mpicomm (mpi4py.MPI.Comm): the MPI communicator over which this
            observation data is distributed.
        detranks (int):  The dimension of the process grid in the detector
            direction.  The MPI communicator size must be evenly divisible
            by this number.
        detquats (dict):  Dictionary of detector names and quaternion
            offsets from the boresight.  Only the value on rank zero is used;
            it is broadcast to the other processes.  If it is None, rank zero
            is expected to read it from the data (see the placeholder below).

    Returns:
        (dict):  The observation dictionary.

    Raises:
        ValueError:  If no detector quaternions are available after the
            broadcast.

    """
    rank = 0
    if mpicomm is not None:
        rank = mpicomm.rank

    obs = dict()

    if rank == 0:
        # Rank zero should open up any files to get things needed to construct
        # the TOD, for example the detector quaternions when the caller did
        # not pass them in.
        pass

    # Share whatever rank zero has with the rest of the group.
    if mpicomm is not None:
        detquats = mpicomm.bcast(detquats, root=0)

    if detquats is None:
        # After the broadcast every process holds the same value, so all of
        # them take this branch together.
        raise ValueError(
            "No detector quaternions available for observation {}".format(path)
        )

    obs["tod"] = NewTOD(path, detquats, mpicomm=mpicomm, detranks=detranks, **kwargs)
    return obs

## Load Balancing Observations

This function computes a "weight" for each observation from the same information that will be given to the TOD constructor.  Here we just return the same weight (1.0) for every observation; replacing it with something that scales with the cost of loading the observation (e.g. its number of samples) gives the approximate load balancing used below.

In [None]:
def obsweight(path):
    """Return the relative load-balancing weight of one observation.

    This placeholder gives every observation the same weight.  Replace the
    body with something that scales with the cost of loading ``path`` (for
    example its number of samples).

    Args:
        path (str):  Path to the observation

    Returns:
        (float):  Relative weight

    """
    equal_share = 1.0
    return equal_share

## Loading a Dataset (Multiple Observations)

This function takes some selection parameters and distributes the observations among the process groups.  Each group then creates the observations assigned to it.

In [None]:
import os
import sys
import traceback

import toast
from toast.dist import Data, distribute_discrete

def load_data(dir, obs=None, comm=None, **kwargs):
    """Loads data.

    This should take options for selecting observations based on some criteria.

    Additional keyword args are passed to the load_observation function.

    Args:
        dir (str):  Top directory of data.
        obs (list):  The list of observations to load.  If None, all
            observations found are loaded.
        comm (toast.Comm): the toast Comm class for distributing the data.
            If None, a default toast.Comm() is created.

    Returns:
        (toast.Data):  The distributed data object.

    Raises:
        Exception:  When running without MPI, any error raised while loading
            an observation is re-raised after printing its traceback.  With
            MPI the whole job is aborted instead.

    """
    if comm is None:
        comm = toast.Comm()

    # the global communicator
    cworld = comm.comm_world
    # the communicator within the group
    cgroup = comm.comm_group

    worldrank = 0
    if cworld is not None:
        worldrank = cworld.rank

    # One process gets the list of observations and their weights.
    obslist = list()
    weight = dict()

    if worldrank == 0:
        # Get the list of observation names.  What you do here depends on how
        # your data is organized.  For example, to treat every subdirectory of
        # ``dir`` as one observation:
        #
        #     for root, dirs, files in os.walk(dir, topdown=True):
        #         obslist.extend(dirs)
        #         break
        #
        # This placeholder just uses a fixed list of names.
        obslist = sorted(["foo", "bar", "blat", "obs_to_cut"])
        # Filter by the requested obs, keeping the sorted order.
        if obs is not None:
            wanted = set(obs)
            obslist = [ob for ob in obslist if ob in wanted]
        # Every selected observation needs a weight for the distribution below.
        weight = {ob: obsweight(os.path.join(dir, ob)) for ob in obslist}

    # Communicate what observations we are using.
    if cworld is not None:
        obslist = cworld.bcast(obslist, root=0)
        weight = cworld.bcast(weight, root=0)

    # Distributed data
    data = Data(comm)
    if not obslist:
        # Nothing selected; return the empty data object instead of asking
        # distribute_discrete to split an empty list.
        return data

    # Distribute observations based on the relative weight.
    dweight = [weight[x] for x in obslist]
    distobs = distribute_discrete(dweight, comm.ngroups)

    # Now every group adds its observations to the list
    firstobs, nobs = distobs[comm.group]
    for obname in obslist[firstobs:firstobs + nobs]:
        opath = os.path.join(dir, obname)
        print("Loading {}".format(opath), flush=True)
        # In case something goes wrong on one process, make sure the job
        # is killed.
        try:
            data.obs.append(
                load_observation(opath, mpicomm=cgroup, **kwargs)
            )
        except Exception:
            lines = traceback.format_exception(*sys.exc_info())
            lines = ["Proc {}: {}".format(worldrank, x) for x in lines]
            print("".join(lines), flush=True)
            if cworld is not None:
                cworld.Abort()
            # Without MPI there is nothing to abort: propagate the error
            # instead of silently skipping the observation.
            raise

    return data

In [None]:
# %%writefile data_formats_mpi.py
# ^ Uncomment the line above to write this cell to a script for MPI runs.
#   A cell magic only works on the very first line of a cell.  The script
#   contains only this cell, so load_data and the helpers it calls (defined
#   in earlier cells) must also be copied into it or imported from a module.

import toast
from toast.mpi import MPI

comm = toast.Comm()

data = load_data("data/directory", obs=["foo", "bar", "blat"], comm=comm)

print(data)


In [None]:
import subprocess as sp

command = "python data_formats_mpi.py"

if nersc_host is None:
    # Not at NERSC: launch locally through mpirun.
    runstr = "mpirun -np 4"
else:
    # srun on 2 Haswell nodes: 32 MPI ranks, 4 cores and 4 OpenMP threads each.
    runstr = "export OMP_NUM_THREADS=4; srun -N 2 -C haswell -n 32 -c 4 --cpu_bind=cores -t 00:05:00"
    if nersc_resv is not None:
        runstr = " ".join([runstr, "--reservation", str(nersc_resv)])

runcom = " ".join([runstr, command])
print(runcom, flush=True)

# Uncomment this line to actually submit the job
# sp.check_call(runcom, stderr=sp.STDOUT, shell=True)
