# Saving and Loading IceCube Data as TFRecord 🚀

In [6]:
import glob
import math
import os
import random
from contextlib import ExitStack
from datetime import datetime

import multiprocess
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm

In [7]:
# Competition inputs: per-event metadata (labels + pulse index ranges), the
# sensor positions, and one batch of pulses loaded here only to peek at its layout.
meta, sensor_geometry, batch = (
    pd.read_parquet('/kaggle/input/icecube-neutrinos-in-deep-ice/train_meta.parquet'),
    pd.read_csv("/kaggle/input/icecube-neutrinos-in-deep-ice/sensor_geometry.csv"),
    pd.read_parquet("/kaggle/input/icecube-neutrinos-in-deep-ice/train/batch_1.parquet"),
)
batch.head()

Unnamed: 0_level_0,sensor_id,time,charge,auxiliary
event_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
24,3918,5928,1.325,True
24,4157,6115,1.175,True
24,3520,6492,0.925,True
24,5041,6665,0.225,True
24,2948,8054,1.575,True


# Dataset Creation

In [8]:
def create_example_protobuff(event_pulses, azimuth, zenith):
    '''Pack one event into a tf.train.Example.

    The pulse matrix is cast to float32 and stored as a single serialized
    tensor (bytes) under 'event_pulses'; azimuth and zenith are stored as
    one-element float lists.
    '''
    def _float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    serialized_pulses = tf.io.serialize_tensor(tf.cast(event_pulses, tf.float32)).numpy()
    feature_map = {
        'event_pulses': tf.train.Feature(bytes_list=tf.train.BytesList(value=[serialized_pulses])),
        'azimuth': _float_feature(azimuth),
        'zenith': _float_feature(zenith),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

def write_tfrecords(filename, dataset):
    '''Serialize (event_pulses, (azimuth, zenith)) pairs into one TFRecord file.

    Args:
        filename: path of the TFRecord file to write.
        dataset: iterable of (event_pulses, (azimuth, zenith)) pairs, where
            event_pulses is anything tf.cast accepts (e.g. a NumPy array).
    '''
    with tf.io.TFRecordWriter(filename) as writer:
        for event, (azimuth, zenith) in dataset:
            # No tf.Variable wrapper: create_example_protobuff casts the array to a
            # float32 tensor itself, so allocating a Variable per event was pure overhead.
            example = create_example_protobuff(event, azimuth, zenith)
            writer.write(example.SerializeToString())
            
def save_to_tfrecord(X, Y, name):
    '''Write paired events X and labels Y to ``tfrecords/<name>``.

    Args:
        X: sequence of per-event pulse arrays.
        Y: sequence of matching (azimuth, zenith) labels. X and Y are paired
            with zip, so extra items in the longer sequence are ignored.
        name: file name inside the ``tfrecords`` directory.
    '''
    out_dir = "tfrecords"
    # Nothing else in this notebook creates the output directory; without this the
    # writer fails on a fresh kernel / clean working directory.
    os.makedirs(out_dir, exist_ok=True)

    # write Dataset to files
    write_tfrecords(f"{out_dir}/{name}", zip(X, Y))

In [9]:
NUM_BATCHES_TO_CONVERT = 10

# Distinct batch ids present in the metadata.
train_batches = meta["batch_id"].unique()
# Metadata rows (one per event) restricted to the first NUM_BATCHES_TO_CONVERT batches.
train_meta_baseline = meta.loc[meta["batch_id"].isin(train_batches[:NUM_BATCHES_TO_CONVERT])]
print(len(train_meta_baseline))
train_meta_baseline.head()

2000000


Unnamed: 0,batch_id,event_id,first_pulse_index,last_pulse_index,azimuth,zenith
0,1,24,0,60,5.029555,2.087498
1,1,41,61,111,0.417742,1.549686
2,1,59,112,147,1.160466,2.401942
3,1,67,148,289,5.845952,0.759054
4,1,72,290,351,0.653719,0.939117


### Convert each batch parquet file into a TFRecord file, one batch at a time (each parquet is loaded into memory before its TFRecord is written)

In [10]:
# batch_id -> pulse DataFrame of the batch currently cached in this process
open_batch_dict = dict()

def get_event_data(event_idx, batch_meta_df, t_valid_length=6200):
    '''Build the (pulses, label) pair for one event.

    Args:
        event_idx: positional row index into batch_meta_df.
        batch_meta_df: metadata rows with batch_id, first_pulse_index,
            last_pulse_index, azimuth and zenith columns (one row per event).
        t_valid_length: half-width, in units of the 'time' column, of the window
            around the highest-charge pulse inside which pulses rank higher.

    Returns:
        x: float32 array with columns x, y, z, time, charge, auxiliary, one row
           per pulse, sorted ascending by (rank, charge) so the most important
           pulses come last.
        y: [azimuth, zenith] as floats.
    '''
    # each meta row corresponds to an event
    row = batch_meta_df.iloc[event_idx]
    # Only the pulse-index columns are integers; the angles must stay float
    # (casting the whole row to int truncated the labels).
    batch_id = int(row["batch_id"])
    first_pulse_index = int(row["first_pulse_index"])
    last_pulse_index = int(row["last_pulse_index"])
    azimuth = float(row["azimuth"])
    zenith = float(row["zenith"])

    # keep only the current batch cached (drop any other batch, not just batch_id - 1)
    for stale_id in [cached_id for cached_id in open_batch_dict if cached_id != batch_id]:
        del open_batch_dict[stale_id]

    # open current batch df
    if batch_id not in open_batch_dict:
        open_batch_dict[batch_id] = pd.read_parquet(f"/kaggle/input/icecube-neutrinos-in-deep-ice/train/batch_{batch_id}.parquet")

    train_batch = open_batch_dict[batch_id]
    # last_pulse_index is inclusive
    event_data = train_batch.iloc[first_pulse_index:last_pulse_index + 1]

    event_data_with_pos = pd.merge(event_data, sensor_geometry, on='sensor_id', how='left')

    # valid time window around the highest-charge pulse
    times = event_data_with_pos["time"]
    t_peak = times.iloc[event_data_with_pos["charge"].to_numpy().argmax()]
    t_valid = (times > t_peak - t_valid_length) & (times < t_peak + t_valid_length)

    # rank: +2 for non-auxiliary pulses, +1 for pulses inside the valid window
    event_data_with_pos["rank"] = 2 * (1 - event_data_with_pos["auxiliary"]) + t_valid

    # sort by rank and charge (ascending order of importance)
    event_data_with_pos = event_data_with_pos.sort_values(['rank', 'charge'], ascending=[True, True])

    final_pulses = event_data_with_pos[['x', 'y', 'z', 'time', 'charge', 'auxiliary']]
    # float32 (not float16): the records are serialized as float32 anyway, and
    # float16 rounds times above 8192 to multiples of 8+ and overflows above 65504.
    x = final_pulses.to_numpy().astype(np.float32)

    # get label
    y = [azimuth, zenith]

    return x, y

for batch_id in train_batches[:NUM_BATCHES_TO_CONVERT]:
    # Metadata rows (one per event) belonging to this batch only.
    batch_meta_df = train_meta_baseline[train_meta_baseline["batch_id"] == batch_id]

    def read_event(event_idx):
        # Closure over this iteration's batch_meta_df, shipped to the worker processes.
        return get_event_data(event_idx, batch_meta_df)

    # Extract every event of the batch in parallel; pool.map returns results in input order.
    with multiprocess.Pool() as pool:
        event_results = pool.map(read_event, range(len(batch_meta_df)))

    X = [event_x for event_x, _ in event_results]
    Y = [event_y for _, event_y in event_results]

    save_to_tfrecord(X, Y, f"batch_{batch_id}.tfrecord")
    print(f"Batch {batch_id} done and saved")
    

    

Batch 1 done and saved
Batch 2 done and saved
Batch 3 done and saved
Batch 4 done and saved
Batch 5 done and saved
Batch 6 done and saved
Batch 7 done and saved
Batch 8 done and saved
Batch 9 done and saved
Batch 10 done and saved


# Dataset Loading

In [11]:
def cartesian_to_spherical(x, y, z):
    '''Convert a direction vector (x, y, z) to (azimuth, zenith) in radians.

    The vector need not be normalised. azimuth comes from math.atan2 and so
    lies in [-pi, pi] (see adjust_spherical to move it into [0, 2*pi]);
    zenith lies in [0, pi].

    Raises:
        ZeroDivisionError: for the zero vector, whose direction is undefined.
    '''
    r = math.sqrt(x**2 + y**2 + z**2)
    # Clamp: floating-point rounding could push z / r marginally outside
    # [-1, 1], which would make math.acos raise ValueError.
    cos_zenith = max(-1.0, min(1.0, z / r))
    zenith = math.acos(cos_zenith)
    azimuth = math.atan2(y, x)
    return azimuth, zenith

def adjust_spherical(azimuth, zenith):
    '''Move scalar angles back into their conventional ranges.

    azimuth is wrapped into [0, 2*pi] (any number of periods); a negative
    zenith is shifted by +pi. The two corrections are independent — the
    previous ``elif`` skipped the zenith fix whenever azimuth was also negative.
    '''
    azimuth = azimuth % (2 * math.pi)
    if zenith < 0:
        # NOTE(review): geometrically a zenith of -t points along (azimuth + pi, t);
        # adding pi yields pi - t, a different direction. Kept as-is — confirm intent.
        zenith += math.pi
    return azimuth, zenith

def spherical_to_cartesian(azimuth, zenith):
    '''Return the unit direction [x, y, z] for angles given in radians.

    Works element-wise on NumPy arrays as well as on scalars; the result is
    always a 3-element list.
    '''
    sin_zenith = np.sin(zenith)
    return [
        np.cos(azimuth) * sin_zenith,
        np.sin(azimuth) * sin_zenith,
        np.cos(zenith),
    ]

def spherical_to_cartesian_tf(azimuth, zenith):
    '''TensorFlow counterpart of spherical_to_cartesian.

    Returns a 3-element Python list [x, y, z] of tensors computed from the
    azimuth and zenith tensors (radians).
    '''
    sin_zenith = tf.math.sin(zenith)
    return [
        tf.math.cos(azimuth) * sin_zenith,
        tf.math.sin(azimuth) * sin_zenith,
        tf.math.cos(zenith),
    ]

In [15]:
max_pulse_count = 96  # default number of pulses kept per event when ragged=False


class TFDatasetManager:
    '''
    Create a tf.data pipeline from TFRecord filepaths.

    Every record is parsed into ``(event_pulses, label)``:

    * ``event_pulses`` -- float32 tensor with ``num_features`` (6) columns per
      pulse. With ``ragged=False`` it is truncated/zero-padded to a fixed row
      count; with ``ragged=True`` the row count varies per event.
    * ``label`` -- ``[x, y, z]`` from ``spherical_to_cartesian_tf`` when
      ``cartesian_labels`` is truthy, otherwise ``(azimuth, zenith)``.

    Args:
        filepaths: non-empty list of TFRecord file paths.
        shuffle: shuffle serialized records before parsing.
        shuffle_buffer_size: buffer size used when shuffling.
        ragged: keep each event's variable pulse count instead of truncating/padding.
        cartesian_labels: return [x, y, z] labels instead of (azimuth, zenith).
        batch_size: batch size; -1 disables batching (needed with ragged events).
        sort_by_time: reorder each event's pulses by ascending time (column 3).
        max_pulses: fixed row count used when ragged=False; None falls back to
            the module-level ``max_pulse_count``.
    '''
    def __init__(self, filepaths, shuffle=True, shuffle_buffer_size=1000, ragged=False, cartesian_labels=True, batch_size=128, sort_by_time=False, max_pulses=None):
        self.filepaths = filepaths
        self.shuffle_buffer_size = shuffle_buffer_size
        self.batch_size = batch_size
        self.ragged = ragged
        self.cartesian_labels = cartesian_labels
        self.num_features = 6
        self.sort_by_time = sort_by_time
        self.shuffle = shuffle
        self.max_pulses = max_pulses

    # parse serialized Example protobuf
    def preprocess(self, tfrecord):
        '''Parse one serialized tf.train.Example into ``(event_pulses, label)``.'''
        # to parse we need the feature description of the protobuf
        feature_description = {
            'event_pulses': tf.io.RaggedFeature(value_key='event_pulses', dtype=tf.string),
            'azimuth': tf.io.FixedLenFeature([], tf.float32, default_value=0),
            'zenith': tf.io.FixedLenFeature([], tf.float32, default_value=0)
        }
        parsed_example = tf.io.parse_single_example(tfrecord, feature_description)
        # the pulse matrix is stored as one serialized float32 tensor
        event_pulses = tf.io.parse_tensor(tf.squeeze(parsed_example['event_pulses']), out_type=tf.float32)

        if self.ragged:
            event_pulses = tf.ensure_shape(event_pulses, (None, self.num_features))
        else:
            # Truncate or pad to a fixed number of rows
            end_size = max_pulse_count if self.max_pulses is None else self.max_pulses
            current_size = tf.shape(event_pulses)[0]
            if current_size > end_size:
                # pulses are stored in ascending order of importance, so keep the last ones
                event_pulses = event_pulses[-end_size:]
            elif current_size < end_size:
                # append zero rows after the real pulses
                zeros = tf.zeros((end_size - current_size, self.num_features))
                event_pulses = tf.concat((event_pulses, zeros), axis=0)

            # fixes the static (end_size, num_features) shape
            event_pulses = tf.reshape(event_pulses, (end_size, self.num_features))

        if self.sort_by_time:
            time_index = 3
            # argsort handles a dynamic row count (with ragged=True the static length is
            # None, which broke top_k(k=shape[0])). stable=True keeps rows with equal
            # time (e.g. zero padding) in their existing order — the same tie order
            # top_k on the negated times produced.
            order = tf.argsort(event_pulses[:, time_index], direction='ASCENDING', stable=True)
            event_pulses = tf.gather(event_pulses, order)

        azimuth = parsed_example['azimuth']
        zenith = parsed_example['zenith']

        # plain truthiness: `== True` / `== False` let non-bool values fall through and return None
        if self.cartesian_labels:
            return event_pulses, spherical_to_cartesian_tf(azimuth, zenith)
        return event_pulses, (azimuth, zenith)

    def create_dataset(self):
        '''Return the parsed (optionally shuffled and batched), prefetched dataset.

        Raises:
            ValueError: if ``filepaths`` is empty (e.g. the glob matched nothing).
        '''
        if len(self.filepaths) == 0:
            # fail early with a clear message rather than inside TFRecordDataset
            raise ValueError("TFDatasetManager: no TFRecord filepaths were given")

        # reading all filepaths in parallel
        dataset = tf.data.TFRecordDataset(self.filepaths, num_parallel_reads=len(self.filepaths))

        if self.shuffle:
            dataset = dataset.shuffle(self.shuffle_buffer_size)
        # parse serialized Dataset
        dataset = dataset.map(self.preprocess, num_parallel_calls=tf.data.AUTOTUNE)

        # batch size of -1 indicates that we want no batching (e.g. because we are using ragged tensors)
        if self.batch_size != -1:
            dataset = dataset.batch(self.batch_size, drop_remainder=True)
        # be 1 batch ahead
        return dataset.prefetch(tf.data.AUTOTUNE)
    

In [16]:
batch_size = 1024

# glob order depends on the filesystem; sort so the file list handed to the
# reader (and therefore the interleaved, unshuffled element order) is reproducible.
filepaths = sorted(glob.glob("tfrecords/*"))
neutrino_dataset = TFDatasetManager(filepaths,
                                    shuffle=False,
                                    batch_size=batch_size,
                                    sort_by_time=True,
                                    cartesian_labels=True,
                                    ragged=False
                                   )

full_set = neutrino_dataset.create_dataset()

In [None]:
# Show the pipeline object returned by create_dataset() as the cell output (quick sanity check).
full_set