In [1]:
from torch.utils.data import Dataset as TorchDataset
import h5py
import torch
from torch_geometric.data import Dataset as PygDataset, Data
from torch_geometric.loader import DataLoader
import numpy as np
import os.path as osp
import os
import multiprocessing as mp
import torch.nn.functional as F
from torch_geometric.nn import MLP, DynamicEdgeConv, global_max_pool, Linear
import torch_geometric.transforms as T
from tqdm import tqdm
from pytorch_metric_learning.losses import NTXentLoss
cpu_count = mp.cpu_count()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
quark_gluon_path = "../Data/hdf5/processed/quark-gluon-dataset.hdf5"

In [3]:
def subset_dataset(raw_path, processed_path, subset_len = 6000, starter = 0):
    with h5py.File(raw_path, 'r') as f, h5py.File(processed_path, 'w') as p:
        keys = list(f.keys())
        total_events = f[keys[1]].shape[0]
        for key in keys:
            shape = (subset_len,)
            if len(f[key].shape) > 1:
                shape = (subset_len, 125, 125, 3)
            p.create_dataset(key, shape=shape)
        quark_count = 0
        gluon_count = 0
        idx = 0
        for i in range(starter, starter + subset_len):
            if quark_count < subset_len // 2:
                for key in keys:
                    p[key][idx] = f[key][i]
                quark_count += 1
                idx += 1
            elif gluon_count < subset_len // 2:
                for key in keys:
                    p[key][idx] = f[key][i]
                gluon_count += 1
                idx+=1


In [4]:
train_path = "../Data/hdf5/processed/train.hdf5"
val_path = "../Data/hdf5/processed/val.hdf5"
test_path = "../Data/hdf5/processed/test.hdf5"

In [5]:
subset_dataset(quark_gluon_path, train_path, 6000)
subset_dataset(quark_gluon_path, val_path, 1200, 6000)
subset_dataset(quark_gluon_path, test_path, 1200, 7200)

In [6]:
def get_pillow(x):
    return x.transpose((2,1,0))
def get_k_nearest(indices, k = 10):
    edges = None
    for i in range(indices.shape[0]):
        k_nearest = np.sum((indices - indices[i])**2, axis=1).argsort()
        k_nearest_edges = np.array([[i, j] for j in k_nearest[1:k]])
        if edges is None:
            edges = k_nearest_edges
        else:
            edges = np.vstack((edges, k_nearest_edges))
    return edges
def create_graph(idx,quark_gluon_path ,outpath ):
    data = Data()
    with h5py.File(quark_gluon_path, 'r') as f:
        y = f['y'][idx]
        x = f['X_jets'][idx]
        non_zero_indices = np.argwhere(np.sum(x, axis=2))
        non_zero_fetures = x[non_zero_indices[:, 0], non_zero_indices[:, 1]]
        data.x = torch.from_numpy(non_zero_fetures)
        edges = get_k_nearest(non_zero_indices)
        data.edge_index = torch.from_numpy(edges).t().contiguous().to(torch.float64)
        data.y = torch.from_numpy(np.asarray([y]))
        data.pos = torch.from_numpy(non_zero_indices).to(torch.float64)
        torch.save(data, osp.join(outpath, f"{idx}.pt"))

In [7]:
def grapher(root_dir = "../Data/hdf5/processed"):
    files = ["train.hdf5", "val.hdf5", "test.hdf5"]
    for file in files:
        path = osp.join(root_dir , file)
        with h5py.File(path, 'r') as f:
            event_count = len(f["X_jets"])
        data = file.split(".")[0]
        if len(os.listdir("../Data/Graphs/{}/raw".format(data))) < 1:
            for i in range(event_count):
                # print(data)
                create_graph(i, path , "../Data/Graphs/{}/raw".format(data))

In [8]:
grapher()

In [11]:
class QuarkGluonGraphs(PygDataset):
    def __init__(self, root = None, transform = None, pre_transform = None, pre_filter = None, log = True):
        super().__init__(root, transform, pre_transform, pre_filter, log)
        
    @property
    def raw_file_names(self):
        return os.listdir(osp.join(self.root, "raw"))
    
    @property
    def processed_file_names(self):
        return os.listdir(osp.join(self.root, "raw"))
    def download(self):
        pass

    def process(self):
        for raw_path in self.raw_file_names:
            data = torch.load(osp.join(self.raw_dir, raw_path))
            data.y = F.one_hot(data.y.to(torch.int64), 2)
            torch.save(data, osp.join(self.processed_dir, raw_path))
    def len(self):
        return len(self.processed_file_names)
    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f"{idx}.pt"))
        if self.transform is not None:
            data = self.transform(data)
        return data


In [12]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = "cpu"
augmentation = T.Compose([
    T.NormalizeFeatures(),
    T.RandomJitter(0.03),
    T.RandomFlip(1),
    T.RandomShear(0.2),
    T.ToDevice(device),
])
train_data = QuarkGluonGraphs("../Data/Graphs/train/")
val_data = QuarkGluonGraphs("../Data/Graphs/val/")
test_data = QuarkGluonGraphs("../Data/Graphs/test/")

Processing...
Done!
Processing...
Done!
Processing...
Done!


In [18]:
class Model(torch.nn.Module):
    def __init__(self, k = 2, aggr = "max"):
        super().__init__()
        num_classes = 2
        embedding_size = 1024

        # Feature extraction
        self.conv1 = DynamicEdgeConv(MLP([2*3, 64, 64]), k, aggr)
        self.conv2 = DynamicEdgeConv(MLP([2*64, 128]), k, aggr)

        # Encoder head
        self.lin1 = Linear(128+64, 128)

        # Projection head 
        self.mlp = MLP([128, 256,32], norm=None)

    def forward(self, data, train = True):
        if train:
            # get 2 augmentations of the batch
            augm_1 = augmentation(data)
            augm_2 = augmentation(data)
            
            # extract properties
            pos_1, batch_1 = augm_1.pos, augm_1.batch
            pos_2, batch_2 = augm_2.pos, augm_2.batch

            # Get representation for the first augmented view
            x1 = self.conv1(pos_1, batch_1)
            x2 = self.conv2(x1, batch_1)
            
            h_points_1 = self.lin1(torch.cat([x1, x2], dim=1))

            # Get representation for the second augmented view
            x1 = self.conv1(pos_2, batch_2)
            x2 = self.conv2(x1, batch_2)
            h_points_2 = self.lin1(torch.cat([x1, x2], dim=1))
            

            # Global representation
            h_1 = global_max_pool(h_points_1, batch_1)
            h_2 = global_max_pool(h_points_2, batch_2)

        else:
            x1 = self.conv1(data.pos, data.batch)
            x2 = self.conv2(x1, data.batch)
            h_points = self.lin1(torch.cat([x1, x2], dim=1))
            return global_max_pool(h_points)
        

        # Transormation for loss function
        compact_h_1 = self.mlp(h_1)
        compact_h_2 = self.mlp(h_2)
        return (h_1, h_2, compact_h_1, compact_h_2)


In [19]:
loss_func = NTXentLoss(temperature=0.10)

In [20]:
model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

data_loader = DataLoader(train_data, batch_size=32, shuffle=True)

In [21]:
import tqdm

def train():
    model.train()
    total_loss = 0
    for _, data in enumerate(tqdm.tqdm(data_loader)):
        data = data.to(device)
        optimizer.zero_grad()
        # Get data representations
        h_1, h_2, compact_h_1, compact_h_2 = model(data)
        
        # Prepare for loss
        embeddings = torch.cat((compact_h_1, compact_h_2))
        # The same index corresponds to a positive pair
        indices = torch.arange(0, compact_h_1.size(0), device=compact_h_2.device)
        labels = torch.cat((indices, indices))
        loss = loss_func(embeddings, labels)
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_data)


In [22]:

for epoch in range(1, 4):
    loss = train()
    print(f'Epoch {epoch:03d}, Loss: {loss:.4f}')
    scheduler.step()

  0%|          | 0/188 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 must have the same dtype

In [24]:
for batch in train_data:
    data = batch
    break

In [25]:
batch

Data(x=[884, 3], edge_index=[2, 7956], y=[1, 2], pos=[884, 2])