In [None]:
# Run the notebook as if it's in the PROJECT directory.
# Bookmark the project root as PROJ_ROOT and cd into it, so relative paths used
# later in the notebook (e.g. the 'fastdata' dataset directory) resolve against it.
# NOTE(review): absolute, site-specific path -- adjust when running on another machine.
%bookmark PROJ_ROOT /reg/data/ana03/scratch/cwang31/spi
%cd -b PROJ_ROOT

In [None]:
# Sanity check: the working directory should now be the bookmarked project root.
!pwd

In [None]:
# Load paths for using psana: export the SIT_* environment variables in the kernel.
# NOTE(review): site-specific absolute paths; they must be set before anything that
# needs psana is imported.
%env SIT_ROOT=/reg/g/psdm/
%env SIT_DATA=/cds/group/psdm/data/
%env SIT_PSDM_DATA=/cds/data/psdm/

In [None]:
# Plotting setup: render matplotlib figures inline in the notebook output.
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import os
import logging
import torch
import socket
import pickle
import tqdm
import random

from speckleNN.datasets.lite    import SPIDataset         , SPIOnlineDataset, MultiwayQueryset
from speckleNN.model            import SiameseModelCompare, ConfigSiameseModel
from speckleNN.validator        import MultiwayQueryValidator, ConfigValidator
from speckleNN.encoders.convnet import Hirotaka0122       , ConfigEncoder
from speckleNN.utils            import EpochManager       , MetaLog, init_logger, split_dataset, set_seed, ConfusionMatrix
from datetime import datetime
from image_preprocess import DatasetPreprocess
# from image_preprocess_faulty import DatasetPreprocess
# from image_preprocess_half import DatasetPreprocess
# from image_preprocess_one_four import DatasetPreprocess
# from image_preprocess_three_four import DatasetPreprocess


# [[[ SEED ]]]
seed = 0
set_seed(seed)

# [[[ CONFIG ]]]
# Checkpoint to evaluate; also used to name the log file via init_logger.
# timestamp = "2022_1203_1807_16"
timestamp = "2022_1130_2316_55"
frac_train = 0.5
frac_validate = 0.5
num_max_support = 5    # Max number of support examples kept per hit label

lr = 1e-3

size_sample_test = 1000
size_sample_per_class = None
size_batch = 100
online_shuffle = True
trans = None
# NOTE(review): in this notebook lr, size_batch and online_shuffle are only reported
# in `comments` below; nothing else reads them (the query set is encoded in one batch).


# Configure the location to run the job...
drc_cwd = os.getcwd()

init_logger(log_name = 'validate.query.test', timestamp = timestamp, returns_timestamp = False)
# NOTE(review): assumes init_logger attaches its handler to the root logger (or an
# ancestor of __name__) so records from this logger reach the log file -- confirm.
logger = logging.getLogger(__name__)


# Clarify the purpose of this experiment...
hostname = socket.gethostname()
comments = f"""
            Hostname: {hostname}.

            Online training.

            Sample size (test)     : {size_sample_test}
            Sample size (per class) : {size_sample_per_class}
            Batch  size             : {size_batch}
            Online shuffle          : {online_shuffle}
            lr                      : {lr}

            """
# Record the run description in the log so the configuration travels with the results.
logger.info(comments)


# [[[ DATASET ]]]
# Set up parameters for an experiment...
drc_dataset   = 'fastdata'
fl_dataset    = '0000.fastdata'    # Raw, just give it a try
path_dataset  = os.path.join(drc_dataset, fl_dataset)

# Load raw data...
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(path_dataset, 'rb') as fh:
    dataset_list = pickle.load(fh)

# Split data: data_train becomes the support pool, data_test the query pool.
# NOTE(review): seed=None presumably falls back to the RNG state from set_seed above.
# Confirm this reproduces the split used when `timestamp` was trained; otherwise test
# examples may overlap the model's training data.
data_train   , data_val_and_test = split_dataset(dataset_list     , frac_train   , seed = None)
data_validate, data_test         = split_dataset(data_val_and_test, frac_validate, seed = None)

In [None]:
# Define the test set: an online dataset over the held-out `data_test` split, used as
# the query set below. On a fresh run `trans` is still None here; the preprocessing
# transform is attached to `dataset_query.trans` once it has been configured.
dataset_query = SPIOnlineDataset(
    dataset_list          = data_test,
    trans                 = trans,
    seed                  = None,
    size_sample           = size_sample_test,
    size_sample_per_class = size_sample_per_class,
)

In [None]:
# [[[ Form a support set ]]]
# Fetch all hit labels present in the training split. They are sorted so the label
# order does not depend on set iteration order. That order matters twice: it fixes the
# order in which `random.sample` consumes the RNG below, and it fixes the label <-> enum
# mapping used later for predictions.
hit_list = sorted({ hit for _, hit, _ in data_train })

# Form support set: map each hit label to the indices of its examples in data_train...
support_hit_to_idx_dict = { hit : [] for hit in hit_list }
for enum_data, (_, hit, _) in enumerate(data_train):
    support_hit_to_idx_dict[hit].append(enum_data)

# ...then keep at most num_max_support randomly chosen examples per label.
# (Only values are replaced while iterating, so the dict's size never changes.)
for hit, idx_list in support_hit_to_idx_dict.items():
    if len(idx_list) > num_max_support:
        support_hit_to_idx_dict[hit] = random.sample(idx_list, k = num_max_support)

In [None]:
# Preprocess dataset...
# The preprocessing pipeline is defined in image_preprocess.py (DatasetPreprocess).
# Detach any transform left over from a previous run of this cell, so the preprocessor
# is always configured from the untransformed image. This keeps the cell idempotent.
dataset_query.trans = None
img_orig            = dataset_query[0][0][0]   # sample 0 -> its image -> first slice (presumably channel 0)
dataset_preproc     = DatasetPreprocess(img_orig)
trans               = dataset_preproc.config_trans()
dataset_query.trans = trans
img_trans           = dataset_query[0][0][0]   # same sample, now passed through `trans`

In [None]:
# [[[ IMAGE ENCODER ]]]
# Config the encoder. The input size is read from a transformed query image so it
# matches what the preprocessing pipeline actually produces.
# NOTE(review): dim_emb must match the checkpoint restored below.
dim_emb        = 128
size_y, size_x = img_trans.shape[-2:]
config_encoder = ConfigEncoder( dim_emb = dim_emb,
                                size_y  = size_y,
                                size_x  = size_x,
                                isbias  = True )
encoder = Hirotaka0122(config_encoder)

# Set up the model and restore the weights saved under `timestamp`...
config_siamese = ConfigSiameseModel( encoder = encoder, )
model = SiameseModelCompare(config_siamese)
model.init_params(from_timestamp = timestamp)


# Set up the right device for the computation...
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
model.to(device = device)

# This notebook only runs inference. Switch to eval mode so layers such as dropout or
# batch-norm (if the encoder has any) use inference behaviour instead of training behaviour.
# NOTE(review): assumes SiameseModelCompare is a torch.nn.Module (it exposes .to()).
model.eval()

In [None]:
# [[[ EMBEDDING (SUPPORT) ]]]
# Encode the support examples, one batch per hit label. Each image goes through
# `trans`, gets a leading axis, and is standardized to zero mean / unit std.
# NOTE(review): query images below are used exactly as returned by dataset_query,
# with no explicit standardization here. Presumably the online dataset normalizes
# internally; confirm that support and query inputs are on the same scale.
support_batch_emb_dict = {}
for hit in hit_list:
    support_img_list = []
    for support_idx in support_hit_to_idx_dict[hit]:
        img = trans(data_train[support_idx][0])
        img = img[None,]                               # add a leading (channel) axis
        img = (img - img.mean()) / img.std()
        support_img_list.append(torch.as_tensor(img, dtype = torch.float32))

    with torch.no_grad():
        batch_img = torch.stack(support_img_list).to(device = device)
        support_batch_emb_dict[hit] = model.encoder.encode(batch_img)

# [[[ EMBEDDING (QUERY) ]]]
# Fetch each query sample exactly once and keep its image and label together. This
# runs the transform once per sample and guarantees labels line up with embeddings,
# even if dataset_query.__getitem__ is not deterministic.
num_test       = len(dataset_query)
query_idx_list = range(num_test)
assert num_test > 0, "The query set is empty; nothing to evaluate."

query_img_list = []
real_hit_list  = []
for i in query_idx_list:
    query_data = dataset_query[i]
    query_img_list.append(torch.as_tensor(query_data[0], dtype = torch.float32))
    real_hit_list.append(query_data[1])

with torch.no_grad():
    batch_img = torch.stack(query_img_list).to(device = device)
    query_batch_emb = model.encoder.encode(batch_img)


# [[[ METRIC ]]]
# Squared Euclidean distance from every query embedding to every support embedding,
# per hit label.
#   Q: number of query examples, S: number of support examples, E: embedding dimension
hit_to_dist_dict = {}
for hit in hit_list:
    diff = query_batch_emb[:, None] - support_batch_emb_dict[hit]   # (Q, 1, E) - (S, E) -> (Q, S, E)
    hit_to_dist_dict[hit] = torch.sum(diff * diff, dim = -1)         # Q x S

# Use enumeration as an intermediate to obtain predicted hits: enum i <-> hit_list[i].
enum_to_hit_dict = dict(enumerate(hit_list))

# Score each query against each hit label by its mean distance to that label's support.
# mean_support_tensor: H x Q (H = number of hit labels), kept on the CPU.
mean_support_tensor = torch.stack([ hit_to_dist_dict[hit].mean(dim = -1) for hit in hit_list ]).cpu()

# Predicted hit = label with the smallest mean distance. On a tie, argmin returns the
# first minimum, i.e. the earlier label in hit_list.
pred_hit_as_enum_list = mean_support_tensor.argmin(dim = 0)
pred_hit_list = [ enum_to_hit_dict[enum] for enum in pred_hit_as_enum_list.tolist() ]

# Validation result (thus res_dict): res_dict[pred_hit][real_hit] gets one entry per
# query predicted as pred_hit whose true label is real_hit.
# NOTE(review): a query label absent from data_train (not in hit_list) raises KeyError here.
res_dict = { pred : { real : [] for real in hit_list } for pred in hit_list }
for pred_hit, real_hit in zip(pred_hit_list, real_hit_list):
    res_dict[pred_hit][real_hit].append( None )


In [None]:
# Get macro metrics...
cm = ConfusionMatrix(res_dict)

# Formatting purpose: row/column labels. All are 10 characters wide, so the "|"
# separators line up. Unknown labels fall back to their numeric value.
disp_dict = { 0 : "not-sample",
              1 : "single-hit",
              2 : " multi-hit",
              9 : "background",
            }

# Report multiway classification: one row per predicted label. The count columns are
# the true labels, followed by the per-label metrics from cm. Every count, metric and
# header column is right-aligned to width 12; a missing metric prints as "None" in
# that same width.
# NOTE(review): assumes cm.get_metrics returns values in the header order below.
msgs = []
for label_pred in sorted(hit_list):
    msg = f"{disp_dict.get(label_pred, str(label_pred))}  |"
    for label_real in sorted(hit_list):
        msg += f"{len(res_dict[label_pred][label_real]):>12d}"

    for metric in cm.get_metrics(label_pred):
        msg += f"{metric:>12.2f}" if metric is not None else f"{'None':>12s}"
    msgs.append(msg)

msg_header = " " * (msgs[0].find("|") + 1)
for label in sorted(hit_list):
    msg_header += f"{disp_dict.get(label, str(label)):>12s}"

for header in [ "accuracy", "precision", "recall", "specificity", "f1" ]:
    msg_header += f"{header:>12s}"
print(msg_header)

msg_headerbar = "-" * len(msgs[0])
print(msg_headerbar)
for msg in msgs:
    print(msg)

#### Equivalent binary confusion matrix (single-hit vs. everything else; multi-hit is folded into class 0)

In [None]:
# Binary view of the results: class 1 = single-hit; class 0 also absorbs multi-hit (2).
# res_mod_dict[pred][real] holds one entry per query, for pred/real in hit_mod_list.
hit_mod_list = (0, 1)
res_mod_dict = { pred : { real : [] for real in hit_mod_list } for pred in hit_mod_list }

In [None]:
def collapse_hit(hit):
    """Map a multiway hit label onto the binary table: multi-hit (2) -> 0, others unchanged.

    NOTE(review): labels other than 0/1/2 (e.g. 9, listed as "background" in disp_dict)
    pass through unchanged and would raise KeyError below.
    """
    return 0 if hit == 2 else hit

# Start from an empty table every time, so re-running this cell does not double-count.
res_mod_dict = { pred : { real : [] for real in hit_mod_list } for pred in hit_mod_list }
for pred_hit, real_hit in zip(pred_hit_list, real_hit_list):
    res_mod_dict[collapse_hit(pred_hit)][collapse_hit(real_hit)].append( None )

In [None]:
# Get macro metrics...
cm = ConfusionMatrix(res_mod_dict)

# Formatting purpose: class 0 of the collapsed table contains both not-sample (0) and
# multi-hit (2), so it is labelled "non-single". Labels stay 10 characters wide so the
# "|" separators line up.
disp_mod_dict = { 0 : "non-single",
                  1 : "single-hit",
                }

# Report the binary classification: one row per predicted class. The count columns are
# the true classes, followed by the per-class metrics from cm. Every column is
# right-aligned to width 12; a missing metric prints as "None" in that same width.
# NOTE(review): assumes cm.get_metrics returns values in the header order below.
msgs = []
for label_pred in sorted(hit_mod_list):
    msg = f"{disp_mod_dict.get(label_pred, str(label_pred))}  |"
    for label_real in sorted(hit_mod_list):
        msg += f"{len(res_mod_dict[label_pred][label_real]):>12d}"

    for metric in cm.get_metrics(label_pred):
        msg += f"{metric:>12.2f}" if metric is not None else f"{'None':>12s}"
    msgs.append(msg)

msg_header = " " * (msgs[0].find("|") + 1)
for label in sorted(hit_mod_list):
    msg_header += f"{disp_mod_dict.get(label, str(label)):>12s}"

for header in [ "accuracy", "precision", "recall", "specificity", "f1" ]:
    msg_header += f"{header:>12s}"
print(msg_header)

msg_headerbar = "-" * len(msgs[0])
print(msg_headerbar)
for msg in msgs:
    print(msg)

#### Visual check of a preprocessed query image

In [None]:
# Plotting setup for the visual check. This repeats the matplotlib import and
# inline-backend setup already done near the top of the notebook; re-running it is harmless.
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Plot the first slice of query sample 0 three times, with the color window spanning
# [mean, mean + 1 std] of the image.
# NOTE(review): index 0 is fetched on every iteration. If SPIOnlineDataset.__getitem__
# is deterministic, the three panels are identical -- confirm this is intended.
num_repeat = 3
for _ in range(num_repeat):
    img_show = dataset_query[0][0][0]
    img_mean = img_show.mean()
    img_upper = img_mean + 1 * img_show.std()

    fig, ax = plt.subplots(figsize = (5, 3))
    im = ax.imshow(img_show, vmin = img_mean, vmax = img_upper)
    fig.colorbar(im, ax = ax)