In [1]:
import re
from glob import glob
from pathlib import Path

import numpy as np
import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

from transformercvn.options import Options
from transformercvn.network.trainers.neutrino_full_dense_trainer import NeutrinoFullDenseTrainer, sparse_to_dense



In [5]:
# --- Export configuration -------------------------------------------------

# Training-run directory; options.json and checkpoints/ are read from here.
BASE_DIRECTORY = "./tutorial_dense/version_0/"

# Checkpoint to export. When set to None, the checkpoint with the largest
# "step=" number under BASE_DIRECTORY/checkpoints is used instead.
CHECKPOINT_PATH = "./tutorial_dense/version_0/checkpoints/epoch=0-step=28500.ckpt"

# Class used to rebuild the network from the saved options.
NETWORK = NeutrinoFullDenseTrainer

# Substituted for "training" in options.training_file to locate the example data.
TESTING_SOURCE = "training"

# Prefix used for the exported TorchScript file names.
OUTPUT_PREFIX = "tutorial_dense"

# When CUDA is True the loaded network is moved to this device index.
CUDA, CUDA_DEVICE = False, 0

In [6]:
# Load the saved training options, point the testing file at TESTING_SOURCE,
# and restore the network weights from the selected checkpoint in eval mode.
options = Options.load(f"{BASE_DIRECTORY}/options.json")
# NOTE(review): str.replace substitutes *every* "training" in the path
# (directory names included), not only the file name -- confirm intended.
options.testing_file = options.training_file.replace("training", TESTING_SOURCE)
options.num_dataloader_workers = 0  # presumably consumed by the trainer's dataloaders
options.transformer_norm_first = bool(options.transformer_norm_first)  # coerce stored value to a real bool

if CHECKPOINT_PATH is None:
    # Choose the checkpoint with the largest step count, e.g. "epoch=3-step=28500.ckpt".
    # The dot is escaped and the step restricted to digits so unrelated names cannot match.
    step_pattern = re.compile(r"step=(\d+)\.ckpt$")
    checkpoints = [
        path
        for path in glob(f"{BASE_DIRECTORY}/checkpoints/epoch*.ckpt")
        if step_pattern.search(path)
    ]
    if not checkpoints:
        raise FileNotFoundError(
            f"No 'epoch*step=<N>.ckpt' checkpoints found in {BASE_DIRECTORY}/checkpoints"
        )
    checkpoint_path = max(checkpoints, key=lambda path: int(step_pattern.search(path)[1]))
else:
    checkpoint_path = CHECKPOINT_PATH

# torch.load unpickles the file (arbitrary code execution): only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict = checkpoint["state_dict"]
print(f"Loading from: {checkpoint_path}")

network = NETWORK(options)
network.load_state_dict(state_dict)

# Inference only: eval mode and no gradients on any parameter.
network = network.eval()
for parameter in network.parameters():
    parameter.requires_grad_(False)

if CUDA:
    network = network.cuda(CUDA_DEVICE)

Loading from: ./tutorial_dense/version_0/checkpoints/epoch=0-step=28500.ckpt


In [7]:
# Pull one event (batch_size=1) from the testing set to use as example input.
testing_loader = DataLoader(
    network.testing_dataset,
    batch_size=1,
    collate_fn=network.dataloader_options["collate_fn"],
)
example_batch = next(iter(testing_loader))
(
    features,
    extra,
    event_coordinates,
    event_pixel_values,
    event_masks,
    prong_coordinates,
    prong_pixel_values,
    prong_masks,
    event_targets,
    prong_targets,
) = example_batch

# Trim the prong axis to the largest per-event count given by prong_masks.sum(1).
# NOTE(review): prong_coordinates / prong_pixel_values are not trimmed here --
# confirm the sparse prong representation does not need the same treatment.
max_prongs_in_batch = prong_masks.sum(1).max()
features, prong_masks, prong_targets = (
    tensor[:, :max_prongs_in_batch].contiguous()
    for tensor in (features, prong_masks, prong_targets)
)

In [8]:
# Densify the sparse event / prong images and scale them by 255, then stack them
# as [event, prong_1, ..., prong_n] -- the layout the exported wrappers expect.
# NOTE(review): "255 *" undoes preprocessing only if preprocess_pixels divides
# by 255; confirm this is still right when options.log_pixels is enabled.
pixel_shape = network.training_dataset.pixel_shape
event_pixels, prong_pixels = (
    255 * network.preprocess_pixels(coordinates, values, pixel_shape).to_dense()
    for coordinates, values in (
        (event_coordinates, event_pixel_values),
        (prong_coordinates, prong_pixel_values),
    )
)
pixels = torch.cat((event_pixels, prong_pixels), dim=0)

In [9]:
# Reference forward pass through the full (eager) trainer on the example batch.
network_inputs = (
    features,
    extra,
    event_coordinates,
    event_pixel_values,
    event_masks,
    prong_coordinates,
    prong_pixel_values,
    prong_masks,
)
event_preds, prong_preds = network(*network_inputs)

[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.


In [10]:
class DynamicSimplifedNetwork(nn.Module):
    """Scriptable wrapper returning event and prong class probabilities.

    The input is a stack of raw images reshaped to
    (-1, pixel_features, pixel_width, pixel_height): entry 0 is the event image,
    the remaining entries are prong images. The non-image inputs (features,
    extra) are replaced by zeros and every image is marked valid in the masks.
    """

    __constants__ = ["pixel_features", "pixel_width", "pixel_height", "num_features", "num_extra"]

    def __init__(self, network):
        super().__init__()

        dataset = network.training_dataset

        # Inner model invoked by forward().
        self.network = network.network

        # Sizes of the zero-filled placeholder feature / extra tensors.
        self.num_features = dataset.num_features
        self.num_extra = dataset.extra.shape[1]

        # Image geometry used to reshape the incoming stack.
        self.pixel_features = dataset.pixel_features
        self.pixel_width, self.pixel_height = dataset.pixel_shape[0], dataset.pixel_shape[1]

        # Normalisation statistics copied from the trainer (not read by forward()).
        self.extra_mean, self.extra_std = network.extra_mean, network.extra_std
        self.mean, self.std = network.mean, network.std

        self.log_pixels = network.options.log_pixels

    def forward(self, pixels):
        # Either log(x + 1) or linear division by 255 of the raw pixel values.
        images = torch.log(pixels.float() + 1) if self.log_pixels else pixels.float() / 255
        images = images.reshape(-1, self.pixel_features, self.pixel_width, self.pixel_height)

        image_count = images.shape[0]
        prong_count = image_count - 1

        # Placeholder non-image inputs and an all-valid mask.
        all_valid = torch.ones(image_count, device=images.device, dtype=torch.bool)
        zero_features = torch.zeros(1, prong_count, self.num_features, device=images.device, dtype=images.dtype)
        zero_extra = torch.zeros(1, self.num_extra, device=images.device, dtype=images.dtype)

        event_logits, prong_logits = self.network(
            zero_features,
            zero_extra,
            images[:1],
            all_valid[:1].unsqueeze(0),
            images[1:],
            all_valid[1:].unsqueeze(0),
        )

        event_probs = torch.softmax(event_logits[0], 0)
        prong_probs = torch.softmax(prong_logits[0], 1)

        # Fold a wider event output into 4 entries: sum of [0:4], sum of [4:8], [8], [9].
        # NOTE(review): assumes exactly 10 event classes whenever there are more than 4.
        if event_probs.shape[-1] > 4:
            event_probs = torch.stack([
                event_probs[:4].sum(),
                event_probs[4:8].sum(),
                event_probs[8],
                event_probs[9],
            ], dim=0)

        return event_probs, prong_probs

In [11]:
class DynamicEmbeddingNetwork(nn.Module):
    """Scriptable wrapper returning the encoder's hidden vectors.

    Runs the inner network's ``prong_embedding`` and ``encoder`` on a stack of
    raw images (entry 0 = event image, the rest = prong images) with zeroed
    non-image inputs, and returns the encoder output at position 0 (event) and
    positions 1.. (prongs) for batch entry 0.
    NOTE(review): indexing assumes the encoder output is (sequence, batch, hidden).
    """

    __constants__ = ["pixel_features", "pixel_width", "pixel_height", "num_features", "num_extra"]

    def __init__(self, network):
        super().__init__()

        dataset = network.training_dataset

        # Inner model whose embedding / encoder stages are used in forward().
        self.network = network.network

        # Sizes of the zero-filled placeholder feature / extra tensors.
        self.num_features = dataset.num_features
        self.num_extra = dataset.extra.shape[1]

        # Image geometry used to reshape the incoming stack.
        self.pixel_features = dataset.pixel_features
        self.pixel_width, self.pixel_height = dataset.pixel_shape[0], dataset.pixel_shape[1]

        # Normalisation statistics copied from the trainer (not read by forward()).
        self.extra_mean, self.extra_std = network.extra_mean, network.extra_std
        self.mean, self.std = network.mean, network.std

        self.log_pixels = network.options.log_pixels

    def forward(self, pixels):
        scaled = pixels.float()
        if self.log_pixels:
            scaled = torch.log(scaled + 1)
        else:
            scaled = scaled / 255

        images = scaled.reshape(-1, self.pixel_features, self.pixel_width, self.pixel_height)
        total = images.shape[0]

        # Placeholder non-image inputs and an all-valid mask.
        all_valid = torch.ones(total, device=images.device, dtype=torch.bool)
        zero_features = torch.zeros(1, total - 1, self.num_features, device=images.device, dtype=images.dtype)
        zero_extra = torch.zeros(1, self.num_extra, device=images.device, dtype=images.dtype)

        embeddings, embedding_mask = self.network.prong_embedding(
            zero_features,
            zero_extra,
            images[:1],
            all_valid[:1].unsqueeze(0),
            images[1:],
            all_valid[1:].unsqueeze(0),
        )

        encoded, _, _ = self.network.encoder(embeddings, embedding_mask)

        return encoded[0, 0], encoded[1:, 0]

In [12]:
class DynamicCombinedNetwork(nn.Module):
    """Scriptable wrapper returning class probabilities and hidden vectors in one pass.

    The input is a stack of raw images reshaped to
    (-1, pixel_features, pixel_width, pixel_height): entry 0 is the event image,
    the remaining entries are prong images. Non-image inputs (features, extra)
    are zero placeholders and every image is marked valid.

    Returns:
        (event_probs, prong_probs, event_features, prong_features), where the
        probabilities are softmaxed decoder outputs (event folded to 4 entries
        when wider than 4) and the features are the encoder hidden vectors that
        were fed to the decoders, for batch entry 0.
    """

    __constants__ = ["pixel_features", "pixel_width", "pixel_height", "num_features", "num_extra"]

    def __init__(self, network):
        super().__init__()

        # Inner model whose embedding / encoder / decoder stages are used in forward().
        self.network = network.network

        # Sizes of the zero-filled placeholder feature / extra tensors.
        self.num_features = network.training_dataset.num_features
        self.num_extra = network.training_dataset.extra.shape[1]

        # Image geometry used to reshape the incoming stack.
        self.pixel_features = network.training_dataset.pixel_features
        self.pixel_width = network.training_dataset.pixel_shape[0]
        self.pixel_height = network.training_dataset.pixel_shape[1]

        # Normalisation statistics copied from the trainer (not read by forward()).
        self.extra_mean = network.extra_mean
        self.extra_std = network.extra_std

        self.mean = network.mean
        self.std = network.std

        self.log_pixels = network.options.log_pixels

    def forward(self, pixels):
        # Either log(x + 1) or linear division by 255 of the raw pixel values.
        if self.log_pixels:
            pixels = torch.log(pixels.float() + 1)
        else:
            pixels = pixels.float() / 255

        pixels = pixels.reshape(
            -1,
            self.pixel_features,
            self.pixel_width,
            self.pixel_height
        )

        num_images = pixels.shape[0]

        # Artificial non-image inputs: zeroed features / extra and an all-valid mask.
        mask = torch.ones(num_images, device=pixels.device, dtype=torch.bool)
        features = torch.zeros(1, num_images - 1, self.num_features, device=pixels.device, dtype=pixels.dtype)
        extra = torch.zeros(1, self.num_extra, device=pixels.device, dtype=pixels.dtype)

        event_pixels, prong_pixels = pixels[:1], pixels[1:]
        event_mask, prong_mask = mask[:1], mask[1:]

        combined_embeddings, combined_mask = self.network.prong_embedding(
            features,
            extra,
            event_pixels,
            event_mask.unsqueeze(0),
            prong_pixels,
            prong_mask.unsqueeze(0)
        )

        # Only the hidden states are needed; the two mask outputs are discarded.
        hidden_features, _, _ = self.network.encoder(combined_embeddings, combined_mask)

        # Position 0 is the event token, positions 1.. are the prongs.
        # NOTE(review): assumes (sequence, batch, hidden) layout, matching the
        # transpose(0, 1) applied to the prong decoder output below.
        event_hidden, prong_hidden = hidden_features[0], hidden_features[1:]

        event = self.network.event_decoder(event_hidden)
        prongs = self.network.prong_decoder(prong_hidden).transpose(0, 1)

        # Keep batch entry 0 only (the single artificial event).
        event_features, prong_features = event_hidden[0], prong_hidden[:, 0]
        event = torch.softmax(event[0], 0)
        prongs = torch.softmax(prongs[0], 1)

        # Fold a wider event output into 4 entries: sum of [0:4], sum of [4:8], [8], [9].
        # NOTE(review): assumes exactly 10 event classes whenever there are more than 4.
        if event.shape[-1] > 4:
            event = torch.stack((
                event[:4].sum(),
                event[4:8].sum(),
                event[8],
                event[9],
            ), dim=0)

        return event, prongs, event_features, prong_features

In [13]:
# Compile each wrapper around the loaded network with TorchScript.
wrapper_classes = (DynamicSimplifedNetwork, DynamicEmbeddingNetwork, DynamicCombinedNetwork)
dynamic_simplified, dynamic_embeddings, dynamic_combined = [
    torch.jit.script(wrapper_class(network)) for wrapper_class in wrapper_classes
]



In [14]:
# Check that the scripted (torch.jit.script) models run on the example pixels

In [15]:
# This model will output two tensors:
#  1. The event classification probabilities
#  2. The prong classification probabilities for each input prong image

outs = dynamic_simplified(pixels)
event_probabilities, prong_probabilities = outs
print(pixels.shape, '->', event_probabilities.shape, ',', prong_probabilities.shape)

# Verify the outputs instead of only printing them: one softmax row per prong
# image (every image after the first), each summing to one, and event entries
# that are valid probabilities.
assert prong_probabilities.shape[0] == pixels.shape[0] - 1
assert torch.allclose(prong_probabilities.sum(1), torch.ones_like(prong_probabilities[:, 0]))
assert bool(((event_probabilities >= 0) & (event_probabilities <= 1)).all())

torch.Size([7, 3, 400, 280]) -> torch.Size([4]) , torch.Size([6, 8])


In [16]:
# Embedding model: returns the encoder hidden vectors, i.e. the features fed to
# the final classification layers (128 wide for this checkpoint, see output below).
# Two tensors come back:
#  1. The vector representing the event image
#  2. The vectors representing each input prong image, one row per prong
outs = dynamic_embeddings(pixels)
event_embedding, prong_embeddings = outs
print(pixels.shape, '->', event_embedding.shape, ',', prong_embeddings.shape)

torch.Size([7, 3, 400, 280]) -> torch.Size([128]) , torch.Size([6, 128])


In [17]:
# Combined model: a single pass returning all four tensors from the models above:
#  1. The event classification probabilities
#  2. The prong classification probabilities for each input prong image
#  3. The vector representing the event image
#  4. The vectors representing each input prong image
outs = dynamic_combined(pixels)
output_shapes = ' , '.join(str(tensor.shape) for tensor in outs)
print(pixels.shape, '->', output_shapes)

torch.Size([7, 3, 400, 280]) -> torch.Size([4]) , torch.Size([6, 8]) , torch.Size([128]) , torch.Size([6, 128])


In [18]:
# Save each scripted wrapper as
# "<BASE_DIRECTORY>/<OUTPUT_PREFIX>_<checkpoint file stem>_<suffix>.torchscript".
# Path.stem handles OS-specific separators (glob returns "\" on Windows) and
# keeps dots inside the file name, which split('/') / split('.') did not.
checkpoint_stem = Path(checkpoint_path).stem
for suffix, scripted_model in (
    ("pid", dynamic_simplified),
    ("embeddings", dynamic_embeddings),
    ("combined", dynamic_combined),
):
    scripted_model.save(f"{BASE_DIRECTORY}/{OUTPUT_PREFIX}_{checkpoint_stem}_{suffix}.torchscript")