In [1]:
# Import Dependencies

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#%matplotlib notebook

import sys
sys.path.append("../new_notebooks/ipynb/dlp_opendata_api")
from osf.image_api import image_reader_3d
from osf.particle_api import *
from osf.cluster_api import *

from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sparseconvnet as scn
import glob
import os.path as osp
import numpy as np

import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

Welcome to JupyROOT 6.14/04


In [2]:
use_cuda = torch.cuda.is_available()

In [3]:
use_cuda

True

In [4]:
class ClusteringAEData(Dataset):
    """
    A customized data loader for clustering.
    """
    def __init__(self, root, numPixels=192, filenames=None):
        """
        Initialize Clustering Dataset

        Inputs:
            - root: root directory of dataset
            - preload: if preload dataset into memory.
        """
        self.cluster_filenames = []
        self.energy_filenames = []
        self.root = root
        self.numPixels = str(numPixels)
        
        if filenames:
            self.energy_filenames = filenames[0]
            self.cluster_filenames = filenames[1]
            print(self.energy_filenames)

        self.energy_filenames.sort()
        self.cluster_filenames.sort()
        self.cluster_reader = cluster_reader(*self.cluster_filenames)
        self.energy_reader = image_reader_3d(*self.energy_filenames)
        self.len = self.energy_reader.entry_count()
        assert self.len == self.cluster_reader.entry_count()

    def __getitem__(self, index):
        """
        Get a sample from dataset.
        """
        voxel, label = self.cluster_reader.get_image(index)
        _, energy, _ = self.energy_reader.get_image(index)
        voxel, label = torch.from_numpy(voxel), torch.from_numpy(label)
        energy = torch.from_numpy(energy)
        energy = torch.unsqueeze(energy, dim=1)
        label = torch.unsqueeze(label, dim=1).type(torch.LongTensor)
        return (voxel, energy), label

    def __len__(self):
        """
        Total number of sampels in dataset.
        """
        return self.len

In [5]:
def ae_collate(batch):
    """
    Custom collate_fn for Autoencoder.
    """
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    return [data, target]

In [6]:
root = '/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10' #replace with your own path to root folder. 
trainset_cluster = [root + '/cluster/dlprod_cluster_192px_0{}.root'.format(i) for i in range(8)]
devset_cluster = [root + '/cluster/dlprod_cluster_192px_0{}.root'.format(8)]
#testset_cluster = [root + '/cluster/dlprod_cluster_192px_0{}.root'.format(9)]

trainset_energy = [root + '/dlprod_192px_0{}.root'.format(i) for i in range(8)]
devset_energy = [root + '/dlprod_192px_0{}.root'.format(8)]
#testset_energy = [root + '/dlprod_192px_0{}.root'.format(9)]

for i, f in enumerate(trainset_cluster):
    print(f)
    print(trainset_energy[i])
    
for i, f in enumerate(devset_cluster):
    print(f)
    print(devset_energy[i])
    
#for i, f in enumerate(testset_cluster):
#    print(f)
#    print(testset_energy[i])

trainset = ClusteringAEData(root, 192, filenames=[trainset_energy, trainset_cluster])
devset = ClusteringAEData(root, 192, filenames=[devset_energy, devset_cluster])
#testset = ClusteringAEData(root, 192, filenames=[testset_energy, testset_cluster])
print('Number of entries in training set: {}'.format(len(trainset)))
print('Number of entries in validation set: {}'.format(len(devset)))
#print('Number of entries in test set: {}'.format(len(testset)))

/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/cluster/dlprod_cluster_192px_00.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/dlprod_192px_00.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/cluster/dlprod_cluster_192px_01.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/dlprod_192px_01.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/cluster/dlprod_cluster_192px_02.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/dlprod_192px_02.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/cluster/dlprod_cluster_192px_03.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/dlprod_192px_03.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/cluster/dlprod_cluster_192px_04.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/dlprod_192px_04.root
/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/cluster/dlprod_cluster_192px_05.root
/gpfs/slac/staas/fs1/g/n

In [7]:
trainloader = DataLoader(trainset, batch_size=16, shuffle=True, collate_fn=ae_collate, num_workers=1, pin_memory=False)
devloader = DataLoader(devset, batch_size=16, shuffle=True, collate_fn=ae_collate, num_workers=1, pin_memory=False)

In [8]:
class UResNet(torch.nn.Module):
    def __init__(self, dim=3, size=192, nFeatures=16, depth=5, nClasses=5):
        import sparseconvnet as scn
        super(UResNet, self).__init__()
        #self._flags = flags
        dimension = dim
        reps = 2  # Conv block repetition factor
        kernel_size = 2  # Use input_spatial_size method for other values?
        m = nFeatures  # Unet number of features
        nPlanes = [i*m for i in range(1, depth+1)]  # UNet number of features per level
        # nPlanes = [(2**i) * m for i in range(1, num_strides+1)]  # UNet number of features per level
        nInputFeatures = 1
        self.sparseModel = scn.Sequential().add(
           scn.InputLayer(dimension, size, mode=3)).add(
           scn.SubmanifoldConvolution(dimension, nInputFeatures, m, 3, False)).add( # Kernel size 3, no bias
           scn.UNet(dimension, reps, nPlanes, residual_blocks=True, downsample=[kernel_size, 2])).add(  # downsample = [filter size, filter stride]
           scn.BatchNormReLU(m)).add(
           scn.OutputLayer(dimension))
        self.linear = torch.nn.Linear(m, nClasses)

    def forward(self, point_cloud):
        """
        point_cloud is a list of length minibatch size (assumes mbs = 1)
        point_cloud[0] has 3 spatial coordinates + 1 batch coordinate + 1 feature
        shape of point_cloud[0] = (N, 4)
        """
        coords = point_cloud[:, 0:-1].float()
        features = point_cloud[:, -1][:, None].float()
        x = self.sparseModel((coords, features))
        x = self.linear(x)
        return [x]

In [9]:
def get_unet(fname, dimension=3, size=192, nFeatures=16, depth=5, nClasses=5):
    model = UResNet(dim=dimension, size=size, nFeatures=nFeatures, depth=depth, nClasses=nClasses)
    model = nn.DataParallel(model)
    #print(model.state_dict().keys())
    checkpoint = torch.load(fname, map_location='cpu')
    #print()
    #print(checkpoint['state_dict'].keys())
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    # just return the pre-trained unet
    return model.module.sparseModel

In [10]:
fname = '/gpfs/slac/staas/fs1/g/neutrino/.scn_paper/new/sparse_is192_uns5_uf16_bs64/weights3/snapshot-29999.ckpt'
unet = get_unet(fname)
unet.eval()

Sequential(
  (0): InputLayer()
  (1): SubmanifoldConvolution 1->16 C3
  (2): Sequential(
    (0): ConcatTable(
      (0): Identity()
      (1): Sequential(
        (0): BatchNormReLU(16,eps=0.0001,momentum=0.9,affine=True)
        (1): SubmanifoldConvolution 16->16 C3
        (2): BatchNormReLU(16,eps=0.0001,momentum=0.9,affine=True)
        (3): SubmanifoldConvolution 16->16 C3
      )
    )
    (1): AddTable()
    (2): ConcatTable(
      (0): Identity()
      (1): Sequential(
        (0): BatchNormReLU(16,eps=0.0001,momentum=0.9,affine=True)
        (1): SubmanifoldConvolution 16->16 C3
        (2): BatchNormReLU(16,eps=0.0001,momentum=0.9,affine=True)
        (3): SubmanifoldConvolution 16->16 C3
      )
    )
    (3): AddTable()
    (4): ConcatTable(
      (0): Identity()
      (1): Sequential(
        (0): BatchNormReLU(16,eps=0.0001,momentum=0.9,affine=True)
        (1): Convolution 16->32 C2/2
        (2): Sequential(
          (0): ConcatTable(
            (0): Identity()
    

### See Data

In [11]:
trainiter = iter(trainloader)
data, labels = trainiter.next()

In [12]:
print("Batch Size = {}".format(len(data)))
for i in range(min(len(data), 5)):
    print("labels.shape = {}".format(labels[i].shape))
    print("coords.shape = {}".format(data[i][0].shape))
    print("energy.shape = {}".format(data[i][1].shape))
    print("-"*20)

Batch Size = 16
labels.shape = torch.Size([4860, 1])
coords.shape = torch.Size([4860, 3])
energy.shape = torch.Size([4860, 1])
--------------------
labels.shape = torch.Size([2676, 1])
coords.shape = torch.Size([2676, 3])
energy.shape = torch.Size([2676, 1])
--------------------
labels.shape = torch.Size([2096, 1])
coords.shape = torch.Size([2096, 3])
energy.shape = torch.Size([2096, 1])
--------------------
labels.shape = torch.Size([2516, 1])
coords.shape = torch.Size([2516, 3])
energy.shape = torch.Size([2516, 1])
--------------------
labels.shape = torch.Size([3247, 1])
coords.shape = torch.Size([3247, 3])
energy.shape = torch.Size([3247, 1])
--------------------


In [13]:
coord, energy = data[0]
label = labels[0]
coord, energy = coord.cuda(), energy.cuda()
out = unet((coord, energy))

In [14]:
out = out.cpu()

In [15]:
n_cluster = label.unique().size()
print("Number of Clusters = {}".format(n_cluster))
cluster_labels = label.unique(sorted=True)
print(list(cluster_labels.numpy()))
print(cluster_labels)

Number of Clusters = torch.Size([12])
[0, 3, 4, 5, 6, 7, 8, 9, 10, 14, 24, 48]
tensor([ 0,  3,  4,  5,  6,  7,  8,  9, 10, 14, 24, 48])


In [16]:
# Compute average of cluster
#(label == 0).nonzero().squeeze(1)
index = (label == 14).squeeze(1).nonzero()
index = index.squeeze(1)
mu_c = out[index].mean(0)
print("Cluster Mean = {}".format(mu_c))

Cluster Mean = tensor([0.0392, 0.0302, 0.0005, 0.0697, 0.5309, 0.0738, 0.0050, 0.0164, 6.2792,
        0.1473, 1.2774, 4.8091, 1.1092, 0.0260, 0.2112, 5.2653],
       grad_fn=<MeanBackward0>)


### Function for computing centroids

In [17]:
def find_cluster_means(features, label):
    '''
    For a given image, compute the mean clustering point mu_c for each
    cluster label in the feature dimension.
    '''
    n_clusters = label.unique().size()
    cluster_labels = list(label.unique(sorted=True).numpy())
    # Ordering of the cluster means are crucial.
    cluster_means = []
    for c in cluster_labels:
        index = (label == c).squeeze(1).nonzero()
        index = index.squeeze(1)
        mu_c = features[index].mean(0)
        cluster_means.append(mu_c)
    cluster_means = torch.stack(cluster_means)
    return cluster_means

In [18]:
c_means = find_cluster_means(out, label)
print(c_means)

tensor([[2.4266, 2.6478, 1.8963, 0.0000, 0.2725, 0.0150, 0.0118, 0.4292, 0.4624,
         0.0122, 1.6712, 0.0000, 2.5536, 0.0003, 2.2377, 2.9073],
        [2.0479, 2.0550, 1.1468, 0.0000, 0.3504, 0.0278, 0.0098, 0.3071, 1.1611,
         0.0341, 1.7235, 0.0450, 2.3116, 0.0007, 2.0758, 3.0324],
        [1.9780, 0.0322, 1.4784, 0.0000, 0.0001, 0.0200, 2.4804, 0.0051, 1.6261,
         1.8203, 0.0006, 0.0010, 0.3292, 0.8389, 3.1172, 0.0036],
        [1.8207, 0.9258, 1.3938, 0.0000, 0.3216, 0.0939, 0.0515, 0.9003, 0.2954,
         0.1475, 0.6243, 0.0000, 1.7631, 0.0030, 2.1722, 2.0616],
        [2.3648, 0.0000, 0.1677, 0.0000, 0.0000, 0.0230, 0.2310, 0.1347, 0.2688,
         0.4499, 0.0000, 0.0000, 1.0779, 0.0000, 2.6447, 0.0225],
        [2.5145, 0.0274, 0.1270, 0.0000, 0.0000, 0.0255, 0.1967, 0.2036, 0.2963,
         0.3715, 0.0000, 0.0000, 1.2444, 0.0000, 2.4407, 0.4039],
        [2.3238, 0.0190, 0.0825, 0.0000, 0.0000, 0.0043, 0.0598, 0.3208, 0.0742,
         0.1626, 0.0000, 0.0000, 1.81

### Regularization

In [22]:
def regularization(cluster_means, norm=2):
    reg = 0
    n_clusters, feature_dim = cluster_means.shape
    for i in range(n_clusters):
        #print(torch.norm(cluster_means[i, :], norm))
        reg += torch.norm(cluster_means[i, :], norm)
    #print(reg)
    reg /= n_clusters
    return reg

In [None]:
regularization(c_means)

### Variance Term

In [None]:
index = (label == 0).squeeze(1).nonzero()
index = index.squeeze(1)
#print(out[index])
#print(c_means[0])
#print((out[index] - c_means[0]).shape)
#print(torch.norm(out[index] - c_means[0], dim=1))
dists = torch.norm(out[index] - c_means[0], dim=1)
hinge = torch.clamp(dists - 1, 0)
print(hinge.shape)
l = torch.mean(torch.pow(hinge, 2))
print(l)

In [None]:
def variance_loss(features, label, cluster_means, margin=1):
    var_loss = 0
    n_clusters = len(cluster_means)
    cluster_labels = list(label.unique(sorted=True).numpy())
    for i, c in enumerate(cluster_labels):
        index = (label == c).squeeze(1).nonzero()
        index = index.squeeze(1)
        dists = torch.norm(features[index] - cluster_means[i], dim=1)
        hinge = torch.clamp(dists-1, min=0)
        l = torch.mean(torch.pow(hinge, 2))
        var_loss += l
    var_loss /= n_clusters
    return var_loss

In [None]:
variance_loss(out, label, c_means)

### Mean Distance Loss

In [None]:
c1 = c_means[0]
c2 = c_means[1]
#print("c1 = {}, c2 = {}".format(c1, c2))
print(c1 - c2)
print(torch.norm(c1 - c2))

In [None]:
def mean_distance_loss(cluster_means, margin=2):
    mean_loss = 0
    n_clusters = len(cluster_means)
    for i, c1 in enumerate(cluster_means):
        for j, c2 in enumerate(cluster_means):
            if i != j:
                dist = torch.norm(c1 - c2)
                hinge = torch.clamp(2.0 * margin - dist, min=0)
                mean_loss += torch.pow(hinge, 2)
    mean_loss /= (n_clusters - 1) * n_clusters
    return mean_loss

In [None]:
mean_distance_loss(c_means)

In [None]:
from torch.nn.modules.loss import _Loss
from torch.autograd import Variable

In [19]:
class DiscriminativeLoss(torch.nn.Module):
    
    def __init__(self, delta_var=0.5, delta_dist=1.5, norm=2, 
                 alpha=1.0, beta=1.0, gamma=0.001,
                 use_gpu=False):
        super(DiscriminativeLoss, self).__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.use_gpu = use_gpu
        
    def find_cluster_means(self, features, label):
        '''
        For a given image, compute the mean clustering point mu_c for each
        cluster label in the feature dimension.
        '''
        n_clusters = label.unique().size()
        cluster_labels = list(label.unique(sorted=True).numpy())
        # Ordering of the cluster means are crucial.
        cluster_means = []
        for c in cluster_labels:
            index = (label == c).squeeze(1).nonzero()
            index = index.squeeze(1)
            mu_c = features[index].mean(0)
            cluster_means.append(mu_c)
        cluster_means = torch.stack(cluster_means)
        print(cluster_means)
        return cluster_means
        
    def variance_loss(self, features, label, cluster_means, margin=1):
        var_loss = 0
        n_clusters = len(cluster_means)
        cluster_labels = list(label.unique(sorted=True).numpy())
        for i, c in enumerate(cluster_labels):
            index = (label == c).squeeze(1).nonzero()
            index = index.squeeze(1)
            dists = torch.norm(features[index] - cluster_means[i], dim=1)
            hinge = torch.clamp(dists-1, min=0)
            l = torch.mean(torch.pow(hinge, 2))
            var_loss += l
        var_loss /= n_clusters
        print(var_loss)
        return var_loss
    
    def mean_distance_loss(self, cluster_means, margin=2):
        mean_loss = 0
        n_clusters = len(cluster_means)
        for i, c1 in enumerate(cluster_means):
            for j, c2 in enumerate(cluster_means):
                if i != j:
                    dist = torch.norm(c1 - c2)
                    hinge = torch.clamp(2.0 * margin - dist, min=0)
                    mean_loss += torch.pow(hinge, 2)
        if n_clusters > 1:
            mean_loss /= (n_clusters - 1) * n_clusters
        print(mean_loss)
        return mean_loss
    
    def regularization(self, cluster_means, norm=2):
        reg = 0
        n_clusters, feature_dim = cluster_means.shape
        for i in range(n_clusters):
            #print(torch.norm(cluster_means[i, :], norm))
            reg += torch.norm(cluster_means[i, :], norm)
        #print(reg)
        reg /= n_clusters
        print(reg)
        return reg
    
    def combine(self, features, label):
        
        c_means = self.find_cluster_means(features, label)
        loss_dist = self.mean_distance_loss(c_means, margin=self.delta_dist)
        loss_var = self.variance_loss(features, label, c_means, margin=self.delta_var)
        loss_reg = self.regularization(c_means, norm=self.norm)
        
        loss = self.alpha * loss_var + self.beta * loss_dist + self.gamma * loss_reg
        
        return loss
    
    def forward(self, x, y):
        
        return self.combine(x, y)
        

In [20]:
loss = DiscriminativeLoss()

In [21]:
loss.forward(out, label)

tensor([[2.4266, 2.6478, 1.8963, 0.0000, 0.2725, 0.0150, 0.0118, 0.4292, 0.4624,
         0.0122, 1.6712, 0.0000, 2.5536, 0.0003, 2.2377, 2.9073],
        [2.0479, 2.0550, 1.1468, 0.0000, 0.3504, 0.0278, 0.0098, 0.3071, 1.1611,
         0.0341, 1.7235, 0.0450, 2.3116, 0.0007, 2.0758, 3.0324],
        [1.9780, 0.0322, 1.4784, 0.0000, 0.0001, 0.0200, 2.4804, 0.0051, 1.6261,
         1.8203, 0.0006, 0.0010, 0.3292, 0.8389, 3.1172, 0.0036],
        [1.8207, 0.9258, 1.3938, 0.0000, 0.3216, 0.0939, 0.0515, 0.9003, 0.2954,
         0.1475, 0.6243, 0.0000, 1.7631, 0.0030, 2.1722, 2.0616],
        [2.3648, 0.0000, 0.1677, 0.0000, 0.0000, 0.0230, 0.2310, 0.1347, 0.2688,
         0.4499, 0.0000, 0.0000, 1.0779, 0.0000, 2.6447, 0.0225],
        [2.5145, 0.0274, 0.1270, 0.0000, 0.0000, 0.0255, 0.1967, 0.2036, 0.2963,
         0.3715, 0.0000, 0.0000, 1.2444, 0.0000, 2.4407, 0.4039],
        [2.3238, 0.0190, 0.0825, 0.0000, 0.0000, 0.0043, 0.0598, 0.3208, 0.0742,
         0.1626, 0.0000, 0.0000, 1.81

tensor(1.8750, grad_fn=<AddBackward0>)

In [22]:
mean_distance_loss(c_means, margin=1.5)

NameError: name 'mean_distance_loss' is not defined