# Model: Pairwise DGCNN approach

Use DGCNN to classify edges from the pairwise trackster dataset:

- https://github.com/WangYueFt/dgcnn
- https://github.com/antao97/dgcnn.pytorch
- https://github.com/hqucms/ParticleNet


Classify cut edges based on the node features; the node features are enriched via message passing.

Run the network, then derive a loss from the cut edges?

Or should we do this in point-cloud-segmentation style and check how many nodes in both tracksters are assigned the same class?

Either
- node based - point cloud segmentation
    - label per node, or a more advanced pooling scheme
- edge based - link prediction
    - label per edge, or a custom loss function that focuses only on selected edges

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import sklearn.metrics as metrics

from torch.utils.data import DataLoader, random_split


import networkx as nx

import matplotlib.pyplot as plt
from reco.dataset import PointCloudPairs
from reco.plotting import plot_graph_3D

# Dataset location. Alternative paths on the ceph mount, kept for reference:
#   data_root = "/mnt/ceph/users/ecuba/processed"
#   raw_dir = "/mnt/ceph/users/ecuba/multiparticle_10/"
ds_name = "CloseByTwoPion"
data_root = "data"
raw_dir = f"/Users/ecuba/data/{ds_name}"
file_name = f"{raw_dir}/new_ntuples_15101852_191.root"

In [None]:
# Pairwise point-cloud dataset: positional args are (name, processed root,
# raw ntuple directory); the keyword options are passed through unchanged.
ds = PointCloudPairs(
    ds_name, data_root, raw_dir,
    N_FILES=10, balanced=True, MAX_DISTANCE=10, ENERGY_THRESHOLD=10,
)
ds  # display the dataset summary in the notebook


In [None]:
# for the DGCNN, we need to get the points into shape:
# -> (batch_size, num_features, num_points)
# NOTE(review): the reference DGCNN semseg model presumably uses 9 features
# per point; DGCNN_semseg.conv1 here expects 8 = 2 * 4 input channels, i.e.
# presumably 4 features per point (first 3 = coordinates) — confirm against ds[i].

# TODO: need to fix the labeling here
# 1 if nodes are from the same particle
# 0 if not

x, y = ds[2]   # presumably (point features, per-point labels) — verify
x.shape

In [None]:
class IOStream:
    """Echo text to stdout and append it to a log file.

    The file is opened in append mode on construction and flushed after
    every write, so the on-disk log stays current while training runs.
    Supports the context-manager protocol so the file handle is always
    closed::

        with IOStream("run.log") as io:
            io.cprint("hello")
    """

    def __init__(self, path):
        # explicit encoding: do not depend on the platform's default locale
        self.f = open(path, 'a', encoding='utf-8')

    def cprint(self, text):
        """Print ``text`` and append it, plus a newline, to the log file.

        Non-string values are converted with ``str`` (as ``print`` does).
        """
        text = str(text)
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow exceptions

In [None]:
def knn(x, k):
    """Return the indices of the ``k`` nearest neighbours of every point.

    ``x`` is indexed as (batch, features, points). Similarity is the negative
    squared Euclidean distance, expanded as
    -|a - b|^2 = 2 a.b - |a|^2 - |b|^2, so the largest scores are the closest
    points. A point's distance to itself is 0 (up to rounding), so it
    normally appears first in its own list.

    Returns:
        Index tensor of shape (batch, points, k), closest first.
    """
    gram = torch.matmul(x.transpose(2, 1), x)          # (BS, NP, NP) dot products
    sq_norm = torch.sum(x**2, dim=1, keepdim=True)     # (BS, 1, NP)
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    _, nearest = neg_sq_dist.topk(k=k, dim=-1)
    return nearest

def get_graph_feature(x, device, k=20, idx=None, dim3=False):
    """
    Build EdgeConv features on a dynamic k-nearest-neighbour graph.

    Args:
        x: Tensor viewed as (batch, num_dims, num_points).
        device: Kept for backward compatibility but no longer used. The
            per-sample index offsets are created on the device of the
            neighbour indices, so ``device=None`` or an ``nn.DataParallel``
            replica running on another GPU cannot cause a device mismatch.
        k: Number of neighbours. Clamped to ``num_points`` so that point
            clouds with fewer than ``k`` points do not make ``topk`` fail;
            the last output dimension is therefore ``min(k, num_points)``.
        idx: Optional precomputed neighbour indices of shape
            (batch, num_points, k'); when given, ``k'`` is used.
        dim3: If True, build the graph from the first 3 feature channels
            only (the coordinates) while still gathering all channels.

    Returns:
        Tensor of shape (batch, 2 * num_dims, num_points, k) holding, for
        every point i and each neighbour j, the concatenation [x_j - x_i, x_i].
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.reshape(batch_size, -1, num_points)
    if idx is None:
        k = min(k, num_points)
        # in the first iteration, only use the xyz coordinates for knn
        coords = x[:, :3] if dim3 else x
        idx = knn(coords, k=k)                          # (BS, NP, k)
    k = idx.size(-1)

    # Offset every sample's indices so they address rows of the flattened
    # (BS * NP, num_dims) feature matrix.
    idx_base = torch.arange(batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    num_dims = x.size(1)

    x = x.transpose(2, 1).contiguous()                  # (BS, NP, num_dims)
    neighbours = x.view(batch_size * num_points, -1)[idx, :]
    neighbours = neighbours.view(batch_size, num_points, k, num_dims)
    centre = x.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)

    feature = torch.cat((neighbours - centre, centre), dim=3)
    # for each point in the batch, we got their features combined for k neighbours
    return feature.permute(0, 3, 1, 2).contiguous()     # (BS, 2*num_dims, NP, k)


class DGCNN_semseg(nn.Module):
    """DGCNN segmentation-style network producing one raw score per point.

    Three EdgeConv stages build dynamic k-NN graphs (via ``get_graph_feature``)
    and max-aggregate over neighbours. Their per-point outputs are combined
    with a global, max-pooled embedding and reduced by 1x1 convolutions to a
    single output channel per point. No sigmoid is applied.

    Args:
        k: Number of nearest neighbours used in every EdgeConv stage.
        emb_dims: Width of the global embedding (output of ``conv6``).
        dropout: Dropout probability before the output layer.
        device: Stored as ``self.device`` for backward compatibility. The
            forward pass uses the input tensor's device instead, so that
            ``nn.DataParallel`` replicas on other GPUs work.
        in_features: Number of input channels per point. The first 3 are
            used as coordinates for the first k-NN graph. Defaults to 4,
            which matches the previously hard-coded ``Conv2d(8, ...)``.
    """

    def __init__(self, k=3, emb_dims=1024, dropout=0.5, device=None, in_features=4):
        super().__init__()
        self.k = k

        self.device = device

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm1d(emb_dims)
        self.bn7 = nn.BatchNorm1d(512)
        self.bn8 = nn.BatchNorm1d(256)

        self.conv1 = nn.Sequential(
            # edge features are [x_j - x_i, x_i] -> twice the input width
            nn.Conv2d(2 * in_features, 64, kernel_size=1, bias=False),
            self.bn1,
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, bias=False),
            self.bn2,
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
            self.bn3,
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, bias=False),
            self.bn4,
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
            self.bn5,
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.conv6 = nn.Sequential(
            nn.Conv1d(64*3, emb_dims, kernel_size=1, bias=False),
            self.bn6,
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.conv7 = nn.Sequential(
            # global embedding concatenated with the three local feature maps;
            # previously hard-coded as 202, which only matched emb_dims=10
            nn.Conv1d(emb_dims + 64*3, 512, kernel_size=1, bias=False),
            self.bn7,
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.conv8 = nn.Sequential(
            nn.Conv1d(512, 256, kernel_size=1, bias=False),
            self.bn8,
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.dp1 = nn.Dropout(p=dropout)

        # output layer (BS, 1, NP) - num output features per node
        self.conv9 = nn.Conv1d(256, 1, kernel_size=1, bias=False)

    def forward(self, x):
        """
        Propagate the data through the network.

            Input dimensions: (BS, C, NP)
                BS: batch size
                C:  features per point (in_features)
                NP: number of points
            Output dimensions: (BS, 1, NP)
        """
        num_points = x.size(2)
        k = self.k
        # Use the input's device, not self.device: under nn.DataParallel every
        # replica receives its inputs on its own GPU.
        dev = x.device

        x = get_graph_feature(x, dev, k=k, dim3=True)   # (BS, C, NP) -> (BS, C*2, NP, k)
        x = self.conv1(x)                               # (BS, C*2, NP, k) -> (BS, 64, NP, k)
        x = self.conv2(x)                               # (BS, 64, NP, k) -> (BS, 64, NP, k)
        x1 = x.max(dim=-1, keepdim=False)[0]            # (BS, 64, NP, k) -> (BS, 64, NP)

        x = get_graph_feature(x1, dev, k=k)             # (BS, 64, NP) -> (BS, 64*2, NP, k)
        x = self.conv3(x)                               # (BS, 64*2, NP, k) -> (BS, 64, NP, k)
        x = self.conv4(x)                               # (BS, 64, NP, k) -> (BS, 64, NP, k)
        x2 = x.max(dim=-1, keepdim=False)[0]            # (BS, 64, NP, k) -> (BS, 64, NP)

        x = get_graph_feature(x2, dev, k=k)             # (BS, 64, NP) -> (BS, 64*2, NP, k)
        x = self.conv5(x)                               # (BS, 64*2, NP, k) -> (BS, 64, NP, k)
        x3 = x.max(dim=-1, keepdim=False)[0]            # (BS, 64, NP, k) -> (BS, 64, NP)

        x = torch.cat((x1, x2, x3), dim=1)              # (BS, 64*3, NP)

        x = self.conv6(x)                               # (BS, 64*3, NP) -> (BS, emb_dims, NP)
        x = x.max(dim=-1, keepdim=True)[0]              # (BS, emb_dims, NP) -> (BS, emb_dims, 1)

        x = x.repeat(1, 1, num_points)                  # (BS, emb_dims, NP)
        x = torch.cat((x, x1, x2, x3), dim=1)           # (BS, emb_dims+64*3, NP)

        x = self.conv7(x)                               # (BS, emb_dims+64*3, NP) -> (BS, 512, NP)
        x = self.conv8(x)                               # (BS, 512, NP) -> (BS, 256, NP)
        x = self.dp1(x)
        x = self.conv9(x)                               # (BS, 256, NP) -> (BS, 1, NP)

        return x

In [None]:
ds_size = len(ds)
test_set_size = ds_size // 10              # hold out ~10% for evaluation
train_set_size = ds_size - test_set_size
# Seed the split so reruns of the notebook evaluate on the same held-out
# samples (an unseeded split reshuffles train/test on every run).
train_set, test_set = random_split(
    ds,
    [train_set_size, test_set_size],
    generator=torch.Generator().manual_seed(42),
)
print(f"Train samples: {len(train_set)}, Test samples: {len(test_set)}")

train_dl = DataLoader(train_set, batch_size=1, shuffle=True)
# Evaluation metrics do not depend on order; keep the test pass deterministic.
test_dl = DataLoader(test_set, batch_size=1, shuffle=False)

In [None]:
def _run_epoch(model, loader, device, opt=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (zero_grad / backward / step through ``opt``) when ``opt`` is
    given; otherwise evaluates with autograd disabled. The model outputs are
    treated as logits: the loss is BCE-with-logits, and a point is predicted
    as class 1 when its logit is > 0 (i.e. sigmoid(logit) > 0.5).
    Returns ``(nan, nan)`` for an empty loader instead of dividing by zero.
    """
    training = opt is not None
    model.train(training)

    total_loss = 0.0
    count = 0
    true_chunks = []
    pred_chunks = []
    with torch.set_grad_enabled(training):
        for data, seg in loader:
            data, seg = data.to(device), seg.to(device)
            batch_size = data.size(0)

            if training:
                opt.zero_grad()
            seg_pred = model(data)
            # BCE needs float targets; .float() is a no-op for float32 labels.
            loss = F.binary_cross_entropy_with_logits(
                seg_pred.view(-1, 1), seg.view(-1, 1).float()
            )
            if training:
                loss.backward()
                opt.step()

            count += batch_size
            total_loss += loss.item() * batch_size
            true_chunks.append(seg.cpu().numpy().reshape(-1))
            pred_chunks.append(seg_pred.detach().cpu().numpy().reshape(-1))

    if count == 0:
        return float("nan"), float("nan")

    true_cls = np.concatenate(true_chunks).astype(int)
    # logits: threshold at 0, not 0.5 (0.5 would correspond to sigmoid ~0.62)
    pred_cls = (np.concatenate(pred_chunks) > 0).astype(int)
    return total_loss / count, metrics.accuracy_score(true_cls, pred_cls)


def train(device, io, epochs=10, train_loader=None, test_loader=None):
    """Train a DGCNN_semseg per-point classifier and log metrics every epoch.

    Args:
        device: torch device the model and batches are moved to.
        io: Logger with a ``cprint(text)`` method (e.g. IOStream).
        epochs: Number of training epochs (default 10, as before).
        train_loader, test_loader: Loaders yielding ``(data, seg)`` batches;
            default to the notebook-level ``train_dl`` / ``test_dl``.

    After every epoch the state dict of the ``nn.DataParallel``-wrapped
    model is written to ``model.t7``, overwriting the previous checkpoint.
    """
    if train_loader is None:
        train_loader = train_dl
    if test_loader is None:
        test_loader = test_dl

    model = DGCNN_semseg(k=3, emb_dims=10, dropout=0.1, device=device).to(device)
    model = nn.DataParallel(model)
    print("Using", torch.cuda.device_count(), "GPUs")

    opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    # T_max=epochs counts scheduler steps, so step once per epoch (not per batch).
    # NOTE(review): eta_min equals the base lr, so the learning rate stays
    # constant at 1e-3; lower eta_min (or raise lr) for an actual cosine schedule.
    scheduler = CosineAnnealingLR(opt, epochs, eta_min=1e-3)

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, device, opt)
        scheduler.step()
        io.cprint(f'Train {epoch}:\tloss: {train_loss:.6f}, train acc: {train_acc:.6f}')

        test_loss, test_acc = _run_epoch(model, test_loader, device)
        io.cprint(f'Test  {epoch}:\tloss: {test_loss:.6f}, test acc: {test_acc:.6f}')

        torch.save(model.state_dict(), 'model.t7')


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
io = IOStream('run.log')

print(f"Using device: {device}")
try:
    train(device, io)
finally:
    # release the log file handle even if training raises or is interrupted
    io.close()