In [None]:
# Fetch the project's I/O helpers. `-O` overwrites on re-run; plain `wget` would
# save io_functions.py.1 and the notebook would keep importing the stale copy.
!wget -q -O io_functions.py https://raw.githubusercontent.com/mwbaj/MachineLearning/ZPS_2023_winter/WAWTPC/io_functions.py
# %pip installs into the running kernel's environment (unlike !pip).
# TODO: pin a known-good uproot version for reproducibility.
%pip install -q uproot

In [None]:
import numpy as np
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from google.colab import drive
import tensorflow as tf
from tensorflow.python.ops.numpy_ops import np_config
import tensorflow_datasets as tfds
import io_functions as io
from tensorflow.data import Dataset, TFRecordDataset
from tensorflow.io import TFRecordWriter, TFRecordOptions
from tensorflow.train import BytesList, FloatList, Int64List
from tensorflow.train import Example, Features, Feature
from multiprocessing import Process, Queue
from os.path import isfile


# Let tf.Tensor objects expose NumPy-style methods/attributes (e.g. `.T`, `.reshape`).
np_config.enable_numpy_behavior()
# Mount Google Drive; the input ROOT file and the TFRecord outputs are read/written under dataPath.
drive.mount('/content/drive')
# Relative path; presumably resolved against the Colab working directory (/content) so it lands inside the mount.
dataPath = 'drive/MyDrive/ZPS/simulated_data/'

In [None]:
def _bytes_feature(value):
  """Wrap one bytes/str value in a tf.train.Feature backed by a BytesList.

  An eager tensor is first unwrapped to its NumPy payload, because BytesList
  cannot unpack a string directly from an EagerTensor.
  """
  eager_tensor_type = type(tf.constant(0))
  payload = value.numpy() if isinstance(value, eager_tensor_type) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def serialize(charge_array, target):
  """Encode one (charge_array, target) pair as a serialized tf.train.Example.

  Each tensor is turned into bytes with tf.io.serialize_tensor and stored as a
  bytes feature under the keys 'myChargeArray' and 'target'.
  """
  encoded = {
      'myChargeArray': tf.io.serialize_tensor(charge_array),
      'target': tf.io.serialize_tensor(target),
  }
  features = {key: _bytes_feature(blob) for key, blob in encoded.items()}
  proto = tf.train.Example(features=tf.train.Features(feature=features))
  return proto.SerializeToString()

def XYZtoUVWT(data):
    """Map points from XYZ coordinates to (u, v, w, t) readout coordinates.

    Args:
        data: 2-D array/tensor whose columns 0, 1, 2 are read as x, y, z.

    Returns:
        tf.Tensor of shape (N, 4) with columns (u, v, w, t), N = len(data).
        u, v and w are divided by the strip pitch (1.5); t is z / f + 256
        with f = 6.46 / 25.

    NOTE(review): the reference point, the 99.75 / 98.75 offsets and the 256
    time offset are presumably detector-geometry constants; they are used
    as-is and not derived here.
    """
    referencePoint = tf.constant([-138.9971, 98.25])
    phi = np.pi/6.0
    stripPitch = 1.5
    f = 1.0/25*6.46
    # Positions relative to the reference point, shared by the v and w strips.
    x_rel = data[:, 0] - referencePoint[0]
    y_rel = data[:, 1] - referencePoint[1]
    u = -(data[:, 1]-99.75)
    v = x_rel * np.cos(phi) - y_rel * np.sin(phi)
    w = x_rel * np.cos(-phi) - y_rel * np.sin(-phi) + 98.75
    t = data[:, 2]/f + 256
    # Stack the 1-D columns side by side -> (N, 4). Using axis=1 instead of
    # `stack(axis=0).T` avoids depending on np_config.enable_numpy_behavior().
    return tf.stack([u/stripPitch, v/stripPitch, w/stripPitch, t], axis=1)

def conversion(filename, queue):
    """Worker: drain batches from `queue` and write them to a GZIP TFRecord file.

    Each queue item is a `(myChargeArray, target)` pair; a `None` item is the
    stop sentinel. For every event in a batch one serialized tf.train.Example
    (see `serialize`) is written, holding the transposed charge array and a
    (n_projections, 6) array of target points.

    Args:
        filename: output TFRecord path (created by the writer).
        queue: queue shared with the producer; must eventually yield `None`.
    """
    options = TFRecordOptions(compression_type='GZIP')
    scale = 100  # NOTE(review): presumably a unit conversion for `target` — confirm
    n_projections = 3  # one per XYZtoUVWT strip column: u, v, w
    # The context manager closes the writer, which finalizes the GZIP stream even
    # if a batch raises; otherwise finalization relies on garbage collection.
    with TFRecordWriter(filename, options=options) as writer:
        while True:
            item = queue.get()
            if item is None:
                break
            myChargeArray, target = item
            charge_array = io.proc_features(myChargeArray)
            charge_array = tf.transpose(charge_array, perm=[0, 3, 1, 2])
            # `target` carries three 3-D points per event in columns 0:3, 3:6, 6:9.
            uvwt_1 = XYZtoUVWT(scale*target[:, 0:3])
            uvwt_2 = XYZtoUVWT(scale*target[:, 3:6])
            uvwt_3 = XYZtoUVWT(scale*target[:, 6:9])

            # For projection i, each event gets [t1, c1_i, t2, c2_i, t3, c3_i],
            # where column 3 of XYZtoUVWT is t and column i is the strip coordinate.
            points = []
            for i in range(n_projections):
                points.append([
                    uvwt_1[:, 3], uvwt_1[:, i],
                    uvwt_2[:, 3], uvwt_2[:, i],
                    uvwt_3[:, 3], uvwt_3[:, i],
                ])
            # Stacked shape is (6, n_projections, batch); .T reverses it to
            # (batch, n_projections, 6) so points[index] is one event.
            points = np.stack(points, axis=1).T
            for index in range(points.shape[0]):
                writer.write(serialize(charge_array[index], points[index]))


In [None]:
def process_file(output_files, datasetGenerator):
    """Fan batches from `datasetGenerator` out to one writer process per output file.

    One `conversion` worker is started per path in `output_files`. All workers
    consume from a single bounded queue (2 slots per worker), so each batch ends
    up in exactly one of the output files.

    Args:
        output_files: non-empty list of output TFRecord paths; none may exist yet.
        datasetGenerator: iterable of items understood by `conversion`.

    Raises:
        ValueError: if `output_files` is empty.
        FileExistsError: if any output path already exists.
        RuntimeError: if any worker process exits with a non-zero code.
    """
    nFiles = len(output_files)
    if nFiles == 0:
        # Queue(0) would be unbounded with no consumers: every batch would be
        # buffered in memory and then discarded.
        raise ValueError('output_files must contain at least one path')
    for file in output_files:
        if isfile(file):
            raise FileExistsError(f'output file already exists: {file}')
    # In an IPython/Jupyter kernel __name__ is '__main__'. The guard presumably
    # stops a 'spawn'-started child that re-imports this code from launching workers.
    if __name__ == '__main__':
        processes = []
        q = Queue(2*nFiles)
        for name in output_files:
            p = Process(target=conversion, args=(name, q))
            processes.append(p)
            p.start()
            print(p.name + ' started')

        try:
            counter = 0
            for item in datasetGenerator:
                q.put(item)
                counter += 1
                if counter % 100 == 0:
                    print(f'read {counter} batches')
        finally:
            # Always send one stop sentinel per worker, even if the generator
            # raised, so no worker blocks forever in queue.get().
            for _ in range(nFiles):
                q.put(None)
            for p in processes:
                p.join()
                print(p.name + ' done')

        failed = [p.name for p in processes if p.exitcode != 0]
        if failed:
            raise RuntimeError(f'conversion worker(s) failed: {failed}')

In [None]:
input_files = [dataPath+'out_random_sigma-001.root:TPCData']
batchSize = 5
nFiles = 5
output_files = [dataPath + 'test/' + f"out_random_sigma-001-part-{i}.tfrecord" for i in range(nFiles)]
datasetGenerator = io.minimal_generator(files=input_files, batchSize=batchSize)
# process_file refuses to overwrite existing outputs. Skip the conversion when
# every part is already on disk so "Restart & Run All" keeps working. A partial
# set still reaches process_file and raises, which is the safe behaviour.
if all(isfile(name) for name in output_files):
    print('all output files already exist, skipping conversion')
else:
    process_file(output_files, datasetGenerator)

# Read the converted TFRecord files back

In [None]:
# Read back the part files written by the conversion cell.
# NOTE(review): point this at a different list to inspect another, previously
# converted dataset.
filenames = list(output_files)
train_dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP', num_parallel_reads=len(filenames))
# Create a description of the features (both stored as serialized-tensor bytes).
feature_description = {
    'myChargeArray': tf.io.FixedLenFeature([], tf.string),
    'target': tf.io.FixedLenFeature([], tf.string),
}

def _parse_function(example_proto):
    """Decode one serialized Example into a (charge, target) tensor pair.

    NOTE(review): assumes both tensors were serialized as float64; parse_tensor
    fails if the stored dtype differs — confirm against io.proc_features output.
    """
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)
    charge, target = parsed_features['myChargeArray'], parsed_features['target']
    # Decode from bytes back into tensors.
    charge = tf.io.parse_tensor(charge, tf.float64)
    target = tf.io.parse_tensor(target, tf.float64)
    return charge, target


train_dataset = train_dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
for image, target in train_dataset.take(1):
    print(image.shape)
    # As written by `conversion`, target is (3 projections, 6) = [t1, c1, t2, c2, t3, c3]
    # per projection; reshape to (projection, point, [t, strip coordinate]).
    # tf.reshape does not depend on np_config.enable_numpy_behavior().
    points = tf.reshape(target, (3, 3, 2))
    fig, ax = plt.subplots()
    ax.imshow(image[2, :, :])
    ax.scatter(points[2, :, 0], points[2, :, 1])
    ax.set(title='First event: image[2] with projection-2 (w) target points',
           xlabel='t', ylabel='w')
    plt.show()