In [1]:
%load_ext autoreload
%matplotlib inline
%autoreload 2

In [2]:
import sys
import os
import h5py
import glob
import copy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

In [3]:
if '../WatChMaL' not in sys.path:
    sys.path.append('../WatChMaL')

from watchmal.dataset.cnn_mpmt.cnn_mpmt_dataset import CNNmPMTDataset
#from analysis.gan_plot_utils import disp_gan_learn_hist

In [4]:
def channel_to_position(channel):
    """Map PMT channel index(es) to a (row, column) offset inside an mPMT.

    Channels are reduced modulo 19. Channels 0-11 sit on an outer ring of
    radius 0.4 (12 evenly spaced angles), 12-17 on an inner ring of radius
    0.2 (6 evenly spaced angles), and 18 at the centre (radius 0).

    Parameters
    ----------
    channel : int or np.ndarray of int
        Channel number(s).

    Returns
    -------
    list
        ``[row_offset, col_offset]`` -- note the order is [y, x] -- each with
        the same shape as ``channel``.
    """
    ch = channel % 19
    outer_ring = ch < 12
    inner_ring = (ch >= 12) & (ch < 18)
    # The boolean masks zero out the angle term of the ring a channel is not on.
    angle = outer_ring*2*np.pi*ch/12 + inner_ring*2*np.pi*(ch - 12)/6
    r = 0.2*(ch < 18) + 0.2*outer_ring
    y_offset = r*np.cos(angle)
    x_offset = r*np.sin(angle)
    return [y_offset, x_offset]

## Testing with short tank data

In [None]:
# Short-tank IWCD mPMT dataset (constructed with is_distributed=False).
dataset_short = CNNmPMTDataset(
    h5file='/fast_scratch/WatChMaL/data/IWCD_mPMT_Short_emg_E0to1000MeV_digihits.h5',
    mpmt_positions_file='/data/WatChMaL/data/IWCDshort_mPMT_image_positions.npz',
    is_distributed=False,
)

In [None]:
data_short = dataset_short[8]['data']  # image of event 8 from the short-tank dataset

In [None]:
# Shape of the event image; shown as the cell's rich output rather than printed.
data_short.shape

In [None]:
def plot_event(data, mpmt_pos, old_convention=False, **plot_args):
    """Scatter-plot every element of one event image on top of the mPMT layout.

    Parameters
    ----------
    data : 3-D array
        Event image indexed as ``data[channel, row, col]``. The first axis is
        converted to an offset inside each mPMT by ``channel_to_position``;
        the last two axes give the mPMT's row/column in the image.
    mpmt_pos : np.ndarray
        mPMT positions with the row in column 0 and the column in column 1.
        These are drawn as hollow circles underneath the PMT points.
    old_convention : bool, optional
        If True, mirror the PMT column coordinate about the largest mPMT
        column (``max(col) - col``).
    **plot_args
        Extra keyword arguments forwarded to ``Axes.scatter`` for the PMT
        points. ``cmap`` overrides the default ``jet`` colormap, and any other
        default (``s``, ``vmin``) can be overridden as well. When ``norm`` is
        given, the default ``vmin`` is not passed, because matplotlib rejects
        ``norm`` combined with ``vmin``/``vmax``.
    """
    fig, ax = plt.subplots(figsize=(20, 12))
    ax.set_facecolor('black')

    # Outline of every mPMT.
    ax.scatter(mpmt_pos[:, 1], mpmt_pos[:, 0], s=380, facecolors='none', edgecolors='0.2')

    # One point per array element: integer (row, col) of its mPMT plus the
    # channel's sub-mPMT offset.
    indices = np.indices(data.shape)
    channels = indices[0].flatten()
    positions = indices[1:].reshape(2, -1).astype(np.float64)
    positions += channel_to_position(channels)

    if old_convention:
        positions[1] = mpmt_pos[:, 1].max() - positions[1]

    # Copy so set_under() does not modify the shared colormap object. Values
    # below vmin (e.g. zero-valued PMTs) are drawn black, matching the background.
    cmap = copy.copy(plot_args.pop('cmap', plt.cm.jet))
    cmap.set_under(color='black')

    # Merge defaults with the caller's arguments. Passing a keyword both
    # explicitly and through **plot_args would raise a duplicate-keyword TypeError.
    scatter_args = {'s': 3}
    if 'norm' not in plot_args:
        scatter_args['vmin'] = 1e-7
    scatter_args.update(plot_args)

    pmts = ax.scatter(positions[1], positions[0], c=data.flatten(), cmap=cmap, **scatter_args)
    fig.colorbar(pmts, ax=ax)

In [None]:
# Read per-event arrays (energies, angles, positions, labels) from the same
# short-tank HDF5 file that backs `dataset_short`.
data_path = "/fast_scratch/WatChMaL/data/IWCD_mPMT_Short_emg_E0to1000MeV_digihits.h5"

# The context manager closes the HDF5 handle once the arrays have been copied
# into memory, instead of leaving it open for the life of the kernel.
with h5py.File(data_path, "r") as data_file:
    print(data_file.keys())

    energies   = np.array(data_file['energies'])
    angles     = np.array(data_file['angles'])
    positions  = np.array(data_file['positions'])
    labels     = np.array(data_file['labels'])

In [None]:
# Stored angles of event 8, for comparison with the image plotted below.
# NOTE(review): assumes dataset index 8 corresponds to row 8 of the HDF5 arrays -- confirm.
angles[8]

In [None]:
# Event display for the short-tank event, using plot_event's default colormap.
plot_event(data_short, mpmt_pos=dataset_short.mpmt_positions)

In [None]:
## Long Tank Plotting

In [None]:
# Full (long) tank IWCD mPMT dataset (constructed with is_distributed=False).
dataset_long = CNNmPMTDataset(
    h5file='/fast_scratch/WatChMaL/data/IWCDmPMT_4pi_full_tank_pointnet.h5',
    mpmt_positions_file='/data/WatChMaL/data/IWCD_mPMT_image_positions.npz',
    is_distributed=False,
)

In [None]:
data_long = dataset_long[8]['data']  # image of event 8 from the long-tank dataset

In [None]:
# Shape of the long-tank event image; shown as the cell's rich output rather than printed.
data_long.shape

In [None]:
# Event display for the long-tank event with a reversed heat colormap.
plot_event(
    data_long,
    mpmt_pos=dataset_long.mpmt_positions,
    cmap=plt.cm.gist_heat_r,
)

In [None]:
## GAN data loading

In [None]:
# Output directory of the 2021-02-28 "gan" run.
path = '/'.join(['/home/jtindall/WatChMaL/outputs', '2021-02-28', 'gan', 'outputs'])

In [None]:
# `disp_gan_learn_hist` is not defined anywhere else in this notebook (its
# top-level import is commented out), so import it where it is used.
# NOTE(review): assumes analysis.gan_plot_utils is importable from ../WatChMaL -- confirm.
from analysis.gan_plot_utils import disp_gan_learn_hist

fig = disp_gan_learn_hist(path)

In [None]:
# Load the generated-image array from every saved file under <path>/imgs.
# glob() returns files in arbitrary filesystem order, so sort the names to make
# `image_batches[-1]` deterministic.
# NOTE(review): lexicographic order matches save order only if the file names
# are zero-padded -- confirm.
# allow_pickle=True allows arbitrary code execution on load; use it only on
# trusted, self-generated files.
image_batches = []
for fname in sorted(glob.glob(os.path.join(path, 'imgs/*'))):
    # The context manager closes the .npz archive once the array is read out.
    with np.load(fname, allow_pickle=True) as npz:
        image_batches.append(npz['gen_imgs'])
print(len(image_batches))

test_batch = image_batches[-1]
print(test_batch.shape)

test_image = test_batch[5]
print(test_image.shape)

In [None]:
## GAN event plotting

In [None]:
# Sanity check: number of loaded batches and number of images in the first one.
print(len(image_batches), 'batches loaded')
if image_batches:  # avoid IndexError when the glob matched no files
    print(len(image_batches[0]), 'images in the first batch')

In [None]:
# Plot every generated image in one saved batch with a diverging colormap
# centred on zero.
batch_idx = 5
batch = image_batches[batch_idx]

# `DivergingNorm` was renamed `TwoSlopeNorm` in matplotlib 3.2 and the old
# name was later removed; use whichever this installation provides.
diverging_norm = getattr(mcolors, 'TwoSlopeNorm', None) or mcolors.DivergingNorm

# Iterate over the batch being plotted, so every image in it is drawn.
for image_data in batch:
    # NOTE(review): the norm requires vmin < 0 < vmax; an image whose values all
    # share one sign raises ValueError here.
    norm = diverging_norm(vmin=image_data.min(), vcenter=0, vmax=image_data.max())
    plot_event(image_data, dataset_long.mpmt_positions, cmap=plt.cm.BrBG, norm=norm)
    # Render and release each figure now rather than holding all of them open
    # until the cell finishes.
    plt.show()

In [None]:
## Wasserstein GAN data loading

In [None]:
# Output directory of the 2021-03-01 "gan_wasserstein" run.
path = '/'.join(['/home/jtindall/WatChMaL/outputs', '2021-03-01', 'gan_wasserstein', 'outputs'])

In [None]:
# `disp_gan_learn_hist` is not defined anywhere else in this notebook (its
# top-level import is commented out), so import it where it is used.
# NOTE(review): assumes analysis.gan_plot_utils is importable from ../WatChMaL -- confirm.
from analysis.gan_plot_utils import disp_gan_learn_hist

fig = disp_gan_learn_hist(path)

In [None]:
# Load the generated-image array from every saved file under <path>/imgs.
# glob() returns files in arbitrary filesystem order, so sort the names to make
# `image_batches[-1]` deterministic.
# NOTE(review): lexicographic order matches save order only if the file names
# are zero-padded -- confirm.
# allow_pickle=True allows arbitrary code execution on load; use it only on
# trusted, self-generated files.
image_batches = []
for fname in sorted(glob.glob(os.path.join(path, 'imgs/*'))):
    # The context manager closes the .npz archive once the array is read out.
    with np.load(fname, allow_pickle=True) as npz:
        image_batches.append(npz['gen_imgs'])
print(len(image_batches))

test_batch = image_batches[-1]
print(test_batch.shape)

test_image = test_batch[5]
print(test_image.shape)

In [None]:
## Wasserstein GAN event plotting

In [None]:
# Sanity check: number of loaded batches and number of images in the first one.
print(len(image_batches), 'batches loaded')
if image_batches:  # avoid IndexError when the glob matched no files
    print(len(image_batches[0]), 'images in the first batch')

In [None]:
# Plot every generated image in one saved batch with a diverging colormap
# centred on zero.
batch_idx = 5
batch = image_batches[batch_idx]

# `DivergingNorm` was renamed `TwoSlopeNorm` in matplotlib 3.2 and the old
# name was later removed; use whichever this installation provides.
diverging_norm = getattr(mcolors, 'TwoSlopeNorm', None) or mcolors.DivergingNorm

# Iterate over the batch being plotted, so every image in it is drawn.
for image_data in batch:
    # NOTE(review): the norm requires vmin < 0 < vmax; an image whose values all
    # share one sign raises ValueError here.
    norm = diverging_norm(vmin=image_data.min(), vcenter=0, vmax=image_data.max())
    plot_event(image_data, dataset_long.mpmt_positions, cmap=plt.cm.BrBG, norm=norm)
    # Render and release each figure now rather than holding all of them open
    # until the cell finishes.
    plt.show()