In [1]:
import os
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import polars
from sklearn.model_selection import train_test_split

In [2]:
def normalize_targets(y):
    """Min-max scale ``(azimuth, zenith)`` pairs into the unit interval.

    The scaling bounds are fixed: azimuth spans ``[0, 2*pi]`` and zenith spans
    ``[0, pi]``, so the last axis of ``y`` must hold the pair in that order.

    Args:
        y: Array-like of angles in radians whose last axis has length 2.

    Returns:
        ``y`` rescaled so that in-range angles fall inside ``[0, 1]``.
    """
    lower = np.array([0.0, 0.0], dtype=np.float32)
    upper = np.array([2 * np.pi, np.pi], dtype=np.float32)
    span = upper - lower
    shifted = y - lower
    return shifted / span

def denormalize_targets(y_norm):
    """Map unit-interval ``(azimuth, zenith)`` pairs back to radians.

    Uses the fixed bounds azimuth ``[0, 2*pi]`` and zenith ``[0, pi]``.

    Args:
        y_norm: Array-like whose last axis holds (azimuth, zenith) scaled to
            ``[0, 1]``.

    Returns:
        The angles in radians.
    """
    lower = np.array([0.0, 0.0], dtype=np.float32)
    upper = np.array([2 * np.pi, np.pi], dtype=np.float32)
    span = upper - lower
    return y_norm * span + lower

def normalize_features(df, features):
    """Return a copy of ``df`` with the given columns min-max scaled to [0, 1].

    Args:
        df: pandas DataFrame containing the columns to scale.
        features: Column name or list of column names to scale.

    Returns:
        A new DataFrame; the input frame is not modified. Columns that hold a
        single constant value are mapped to 0 rather than NaN.
    """
    df = df.copy()
    # Accept a single column name as well as a list of names.
    features = [features] if isinstance(features, str) else list(features)
    if not features:
        return df

    cols_min = df[features].min()
    cols_range = df[features].max() - cols_min
    # A constant column has zero range; dividing by it would fill the column
    # with NaN. Dividing by 1 instead leaves the (all-zero) shifted values.
    cols_range = cols_range.where(cols_range != 0, 1)
    df[features] = (df[features] - cols_min) / cols_range
    return df

In [3]:
def angular_dist_score(az_true, zen_true, az_pred, zen_pred):
    """Mean angular distance between true and predicted directions.

    Each (azimuth, zenith) pair is treated as a unit vector
    ``(sin(zen)*cos(az), sin(zen)*sin(az), cos(zen))``; the score is the
    average angle between the true and predicted vectors.

    Args:
        az_true: True azimuth values in radians.
        zen_true: True zenith values in radians.
        az_pred: Predicted azimuth values in radians.
        zen_pred: Predicted zenith values in radians.

    Returns:
        The mean angle in radians between the two sets of directions.

    Raises:
        ValueError: If any of the inputs contains a non-finite value.
    """
    angles = (az_true, zen_true, az_pred, zen_pred)
    if any(not np.all(np.isfinite(values)) for values in angles):
        raise ValueError("All arguments must be finite")

    # Trigonometric terms of the true direction ...
    sin_az_t, cos_az_t = np.sin(az_true), np.cos(az_true)
    sin_zen_t, cos_zen_t = np.sin(zen_true), np.cos(zen_true)

    # ... and of the predicted direction.
    sin_az_p, cos_az_p = np.sin(az_pred), np.cos(az_pred)
    sin_zen_p, cos_zen_p = np.sin(zen_pred), np.cos(zen_pred)

    # Dot product of the two Cartesian unit vectors
    # (x = sin(zen)*cos(az), y = sin(zen)*sin(az), z = cos(zen)).
    dot = sin_zen_t*sin_zen_p*(cos_az_t*cos_az_p + sin_az_t*sin_az_p) + (cos_zen_t*cos_zen_p)

    # The dot product of unit vectors lies in [-1, 1]; clip to guard against
    # floating-point error pushing it just outside arccos' domain.
    dot = np.clip(dot, -1, 1)

    # Convert back to an angle (in radians) and average over all events.
    per_event_angle = np.abs(np.arccos(dot))
    return np.average(per_event_angle)

In [4]:
def generating_some_features(batch_file, metadata, sensor_geometry):
    """Build a standardised per-event feature table for one batch file.

    The pulses in ``batch_file`` are joined with ``metadata`` (on ``event_id``)
    and ``sensor_geometry`` (on ``sensor_id``). Two feature sets are computed,
    one over all pulses and one over pulses with ``auxiliary == False``, and
    joined side by side. Each set contains:

    * per-event aggregates of position, time, charge and the auxiliary flag;
    * the x/y/z position of the pulses ranked 1st, 2nd and 3rd by time
      (ascending and descending) and by charge (ascending and descending).

    Args:
        batch_file: Path to the parquet file holding the pulses of one batch.
        metadata: Event metadata with an ``event_id`` column; only events
            present here end up in the result.
        sensor_geometry: Sensor table with ``sensor_id``, ``x``, ``y``, ``z``.

    Returns:
        pandas DataFrame indexed by ``event_id`` with every column z-scored
        over the events in ``metadata``. Columns with zero (or undefined)
        standard deviation come out as NaN; callers are expected to fill them.
    """
    rank_columns = ['time_rank_asc', 'time_rank_des', 'charge_rank_asc', 'charge_rank_des']

    def join_tables_all(df_meta, df_batch, df_sensor):
        # Attach pulses and sensor positions to each event and make pulse
        # times relative to the first pulse of the event.
        return df_meta.join(df_batch, on='event_id').join(df_sensor, on='sensor_id').with_columns([
            (polars.col('time') - polars.col('time').min()).over('event_id')
        ])

    def generate_features_grouped(dataf):
        # One row of summary statistics per event.
        # NOTE(review): newer polars releases spell this `group_by` - confirm
        # against the pinned polars version before upgrading.
        return dataf.groupby('event_id').agg([
            polars.col('x').mean().alias('x_mean'),
            polars.col('x').median().alias('x_median'),
            polars.col('y').mean().alias('y_mean'),
            polars.col('y').median().alias('y_median'),
            polars.col('z').mean().alias('z_mean'),
            polars.col('z').median().alias('z_median'),
            polars.col('time').mean().alias('event_mean_time'),
            polars.col('time').max().alias('event_max_time'),
            polars.col('charge').min().alias('event_min_charge'),
            polars.col('charge').mean().alias('event_mean_charge'),
            polars.col('charge').max().alias('event_max_charge'),
            polars.col('charge').count().alias('overall_count'),
            polars.col('auxiliary').sum().alias('overall_aux_sum'),
            polars.col('charge').sum().alias('sum_charge'),
            (polars.col('auxiliary').sum() / polars.col('auxiliary').count()).alias('aux_ratio'),
            polars.col('sensor_id').n_unique().alias('sensor_count'),
        ])

    def add_ranks(dataf):
        # Ordinal ranks are unique within an event, so each rank value selects
        # exactly one pulse per event.
        return dataf.with_columns([
            polars.col('time').rank('ordinal').over('event_id').alias('time_rank_asc'),
            polars.col('time').rank('ordinal', descending=True).over('event_id').alias('time_rank_des'),
            polars.col('charge').rank('ordinal').over('event_id').alias('charge_rank_asc'),
            # Fixed: this rank used to be ascending, duplicating charge_rank_asc.
            polars.col('charge').rank('ordinal', descending=True).over('event_id').alias('charge_rank_des'),
        ])

    def make_geometrical_features(dataf):
        # Position of the pulses ranked 1, 2 and 3 under each ordering.
        geometrical_features = dataf.select('event_id').unique()
        for direction in rank_columns:
            for direction_axis in ['x', 'y', 'z']:
                for rank in (1, 2, 3):
                    column_name = f'{direction_axis}_{direction}_{rank}'
                    # Fixed: ranks 2 and 3 used to filter on "time_rank_asc"
                    # regardless of `direction`.
                    ranked_position = dataf.filter(polars.col(direction) == rank).select([
                        polars.col('event_id'),
                        polars.col(direction_axis).alias(column_name),
                    ])
                    geometrical_features = geometrical_features.join(
                        ranked_position, on='event_id', how='left')
        # Events with fewer than three pulses have no row for the higher
        # ranks; mark those positions with the sentinel 1000.
        return geometrical_features.fill_null(1000)

    def build_feature_table(joined):
        # Aggregates plus ranked positions for one (lazy) pulse selection.
        grouped = joined.pipe(generate_features_grouped).collect()
        geometric = joined.pipe(add_ranks).collect().pipe(make_geometrical_features)
        return grouped.join(geometric, on='event_id', how='left')

    train_batch = polars.scan_parquet(batch_file)
    df_train_meta = polars.DataFrame(metadata).lazy()
    df_sensor_geometry = polars.DataFrame(sensor_geometry).with_columns(
        polars.col('sensor_id').cast(polars.Int16)).lazy()

    joined = df_train_meta.pipe(join_tables_all, train_batch, df_sensor_geometry)

    # Features over every pulse, ignoring the auxiliary flag.
    all_pulse_features = build_feature_table(joined)

    # Same features restricted to non-auxiliary pulses. The comparison must
    # stay `== False`: it builds a polars expression, not a Python boolean.
    main_pulse_features = build_feature_table(joined.filter(polars.col('auxiliary') == False))

    # Events without any non-auxiliary pulse get 0 for the second feature set.
    features = all_pulse_features.join(main_pulse_features, on='event_id', how='left').fill_null(0)
    features = features.to_pandas().set_index('event_id')

    # Z-score every column over the events of this call.
    return (features - features.mean()) / features.std()

In [5]:
class NeutrinoDataset(Dataset):
    """Per-event graph samples built from one batch file.

    Every item describes one event as a padded graph: nodes are the sensors
    hit in the event, and an undirected edge links the sensors of consecutive
    pulses (ordered by time).

    Args:
        metadata: DataFrame with one row per event. Needs ``event_id`` and, in
            ``"train"`` mode, ``azimuth`` and ``zenith``.
        sensor_geometry: DataFrame with ``sensor_id``, ``x``, ``y``, ``z``.
        batch_file: Path to the parquet file with the pulses of this batch.
        mode: ``"train"`` returns normalised targets as the last item element;
            any other value returns the event id instead.
        max_nodes: Maximum pulses kept per event, which is also the padded
            node count of the returned tensors.
    """

    def __init__(self, metadata, sensor_geometry, batch_file, mode="train", max_nodes=100):
        self.metadata = metadata
        # Sensor positions scaled to [0, 1] per axis.
        # NOTE(review): rows are later looked up by index label, which assumes
        # the frame's index equals sensor_id - confirm for other geometry files.
        self.sensor_geometry = normalize_features(sensor_geometry, ['x', 'y', 'z'])
        self.mode = mode
        self.max_nodes = max_nodes

        # Event-level features are computed from the raw (unscaled) geometry.
        self.batch_features = generating_some_features(batch_file, metadata, sensor_geometry)

        # Min-max normalise time and charge over the whole batch file.
        batch_df = pd.read_parquet(batch_file)
        for column in ('time', 'charge'):
            low, high = batch_df[column].min(), batch_df[column].max()
            batch_df[column] = (batch_df[column] - low) / (high - low)
        self.batch_data = batch_df

    def get_event_features(self, event_data, idx):
        """Return the event-level feature vector (float32) for row ``idx``.

        ``event_data`` is accepted for interface compatibility but is unused.
        """
        # Read the id from its own column: going through a full metadata row
        # would upcast it to float when the frame also has float columns.
        event_id = self.metadata['event_id'].iloc[idx]
        batch_features = self.batch_features.loc[event_id].fillna(0).values

        combined_features = np.hstack([batch_features])
        assert np.all(~np.isnan(combined_features))
        return combined_features.astype(np.float32)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Build the graph sample for the event at row ``idx`` of the metadata.

        Returns:
            ``(node_features, adjacency, event_features, last)`` where
            ``node_features`` is ``(max_nodes, n_node_features)`` float32,
            ``adjacency`` is ``(max_nodes, max_nodes)`` float32 and ``last``
            is the normalised ``(azimuth, zenith)`` tensor in ``"train"`` mode
            or the event id otherwise.
        """
        event = self.metadata.iloc[idx]
        event_id = self.metadata['event_id'].iloc[idx]

        # A list indexer always yields a DataFrame, even when the event has a
        # single pulse (a scalar label would return a Series in that case).
        event_df = self.batch_data.loc[[event_id]]
        # Cap the pulse count at max_nodes so the node count always fits the
        # padded tensors (this used to be hard-coded to 100).
        if len(event_df) > self.max_nodes:
            event_df = event_df.sample(self.max_nodes)

        # One node per hit sensor: pulse count plus time/charge/aux statistics.
        nodes_groupby = event_df.groupby('sensor_id').agg({
            'sensor_id': 'size',
            'time': ['min', 'max', 'mean'],
            'charge': ['min', 'max', 'mean'],
            'auxiliary': 'mean',
        })

        nodes = list(nodes_groupby.index)
        nodes_features = np.concatenate(
            (nodes_groupby.values, self.sensor_geometry.loc[nodes, ['x', 'y', 'z']].values),
            axis=1,
        )

        # Zero-pad the node features up to max_nodes rows.
        nodes_features_padded = np.zeros((self.max_nodes, nodes_features.shape[1]), dtype=np.float32)
        nodes_features_padded[:nodes_features.shape[0], :] = nodes_features
        nodes_features_padded = torch.tensor(nodes_features_padded, dtype=torch.float)

        # Map each sensor id to its node row.
        nodes_dict = {n: i for i, n in enumerate(nodes)}

        # Link the sensors of consecutive pulses (by time) with undirected
        # edges; repeated hits on one sensor produce a self-loop.
        seq = event_df.sort_values(by='time').sensor_id.values
        adjacency_matrix = torch.zeros((self.max_nodes, self.max_nodes))
        for previous_sensor, next_sensor in zip(seq[:-1], seq[1:]):
            i, j = nodes_dict[previous_sensor], nodes_dict[next_sensor]
            adjacency_matrix[i, j] = 1
            adjacency_matrix[j, i] = 1

        event_features = self.get_event_features(event_df, idx)

        if self.mode == "train":
            # Normalise in numpy, then convert explicitly, rather than relying
            # on implicit tensor/ndarray arithmetic inside normalize_targets.
            y = np.array([event['azimuth'], event['zenith']], dtype=np.float32)
            y_norm = torch.as_tensor(normalize_targets(y), dtype=torch.float32)
            return nodes_features_padded, adjacency_matrix, event_features, y_norm
        else:
            return nodes_features_padded, adjacency_matrix, event_features, event.event_id


In [6]:
class GraphConvolution(nn.Module):
    """Graph-convolution layer: degree-normalised propagation, then a linear map.

    Node features are propagated with ``D^-1/2 A D^-1/2`` (``D`` being the row
    sums of ``A`` plus a small epsilon) and then projected by ``nn.Linear``.
    The adjacency matrix is used as given; no self-loops are added here.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, adjacency_matrix, node_features):
        """Apply the layer.

        Args:
            adjacency_matrix: Tensor whose last two dimensions form a square
                adjacency matrix.
            node_features: Tensor whose second-to-last dimension indexes nodes
                and whose last dimension has ``in_features`` entries.

        Returns:
            Tensor shaped like ``node_features`` with ``out_features`` entries
            in the last dimension.
        """
        # The epsilon keeps the inverse square root finite for isolated nodes
        # (degree 0); their rows are still zeroed by the adjacency matrix.
        eps = 1e-8
        degrees = adjacency_matrix.sum(dim=-1) + eps
        inv_sqrt_degree = torch.diag_embed(degrees.pow(-0.5))
        propagation = inv_sqrt_degree @ adjacency_matrix @ inv_sqrt_degree

        # Aggregate neighbour features, then project them.
        aggregated = propagation @ node_features
        return self.linear(aggregated)

class GCN(nn.Module):
    """Stack of five graph-convolution layers.

    The first four layers are each followed by ReLU and dropout; the fifth
    produces the output node embeddings without an activation.

    Args:
        input_dim: Number of features per input node.
        hidden_dim: Width of the four hidden layers.
        output_dim: Number of features per output node.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.graph_conv1 = GraphConvolution(input_dim, hidden_dim)
        self.graph_conv2 = GraphConvolution(hidden_dim, hidden_dim)
        self.graph_conv3 = GraphConvolution(hidden_dim, hidden_dim)
        self.graph_conv4 = GraphConvolution(hidden_dim, hidden_dim)
        self.graph_conv5 = GraphConvolution(hidden_dim, output_dim)

    def forward(self, adjacency_matrix, node_features):
        """Return node embeddings for ``node_features`` on the given graph."""
        hidden_layers = (
            self.graph_conv1,
            self.graph_conv2,
            self.graph_conv3,
            self.graph_conv4,
        )
        hidden = node_features
        for conv in hidden_layers:
            hidden = F.relu(conv(adjacency_matrix, hidden))
            # Dropout is only active while the module is in training mode.
            hidden = F.dropout(hidden, training=self.training)
        return self.graph_conv5(adjacency_matrix, hidden)



In [7]:
class NeutrinoDirectionModel(pl.LightningModule):
    """GCN encoder followed by a fully connected head predicting direction.

    Node embeddings from the GCN are averaged per graph, concatenated with the
    event-level features and passed through a batch-normalised MLP whose final
    sigmoid keeps both outputs in ``[0, 1]`` (normalised azimuth and zenith).

    Args:
        input_features: Features per input node.
        hidden_features: Hidden width of the GCN.
        output_gnn_features: Features per node produced by the GCN.
        num_fc_layers: Number of linear layers in the head.
        input_fc_features: Input width of the head, i.e. the GCN output width
            plus the number of event-level features.
        hidden_fc_units: Width of the hidden linear layers.
        output_features: Number of outputs.
        learning_rate: Learning rate of the Adam optimizer.
    """

    def __init__(self, input_features, hidden_features, output_gnn_features, num_fc_layers, input_fc_features, hidden_fc_units, output_features, learning_rate=1e-3):
        super().__init__()
        self.gcn = GCN(input_features, hidden_features, output_gnn_features)

        # Head: BatchNorm, then Linear+ReLU blocks, ending in Linear+Sigmoid.
        head = [nn.BatchNorm1d(input_fc_features)]
        final_index = num_fc_layers - 1
        for depth in range(num_fc_layers):
            is_final = depth == final_index
            fan_in = hidden_fc_units if depth > 0 else input_fc_features
            fan_out = output_features if is_final else hidden_fc_units
            head.append(nn.Linear(fan_in, fan_out))
            head.append(nn.Sigmoid() if is_final else nn.ReLU())
        self.fc_layers = nn.Sequential(*head)

        self.lr = learning_rate

    def forward(self, node_features, adjacency_matrix, event_features):
        """Predict normalised targets for a batch of graphs."""
        node_embeddings = self.gcn(adjacency_matrix, node_features)
        # Average over dim 1, taken to be the node axis; this averages over
        # every row it is given, padded rows included.
        graph_embedding = torch.mean(node_embeddings, dim=1)
        fused = torch.cat((graph_embedding, event_features), 1)
        return self.fc_layers(fused)

    def training_step(self, batch, batch_idx):
        """Compute and log the L1 loss on normalised targets."""
        node_features, adjacency_matrix, event_features, targets = batch
        predictions = self(node_features, adjacency_matrix, event_features)
        loss = F.l1_loss(predictions, targets, reduction='mean')
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the L1 loss and the mean angular distance for one batch."""
        node_features, adjacency_matrix, event_features, targets_norm = batch
        predictions_norm = self(node_features, adjacency_matrix, event_features)
        loss = F.l1_loss(predictions_norm, targets_norm, reduction='mean')

        # Back to radians before scoring.
        targets = denormalize_targets(targets_norm.cpu().numpy())
        predictions = denormalize_targets(predictions_norm.detach().cpu().numpy())

        # Column 0 is azimuth, column 1 is zenith.
        ang_dist = angular_dist_score(
            targets[:, 0], targets[:, 1],
            predictions[:, 0], predictions[:, 1],
        )

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("angular_dist_score", ang_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [8]:
def train_one_epoch_on_batch(model, metadata, sensor_geometry, batch_file, val_split=0.2,
                             batch_size=64, num_workers=4, random_state=42):
    """Fit ``model`` for one epoch on the events of a single batch file.

    The batch id is parsed from the file name (``..._<id>.parquet``), the
    matching metadata rows are split into train and validation parts, and a
    fresh ``pl.Trainer`` runs one epoch over them.

    Args:
        model: The LightningModule to train; it is updated in place.
        metadata: Metadata for all events, with a ``batch_id`` column.
        sensor_geometry: Sensor table passed through to the datasets.
        batch_file: Path to the batch parquet file.
        val_split: Fraction of the batch's events held out for validation.
        batch_size: Batch size of both data loaders.
        num_workers: Worker processes per data loader (use 0 to load in the
            main process).
        random_state: Seed of the train/validation split.

    Returns:
        The same ``model`` object after training.

    Raises:
        ValueError: If ``metadata`` has no rows for the batch id of
            ``batch_file``.
    """
    batch_id = int(batch_file.split('_')[-1].split('.')[0])
    batch_metadata = metadata[metadata['batch_id'] == batch_id]
    if len(batch_metadata) == 0:
        raise ValueError(f"No metadata rows found for batch_id={batch_id} (from {batch_file!r})")

    train_metadata, val_metadata = train_test_split(
        batch_metadata, test_size=val_split, random_state=random_state)

    train_dataset = NeutrinoDataset(train_metadata, sensor_geometry, batch_file, mode="train")
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Validation also uses mode="train" because it needs the targets.
    val_dataset = NeutrinoDataset(val_metadata, sensor_geometry, batch_file, mode="train")
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, train_dataloader, val_dataloader)

    # Drop the references so the batch data can be garbage-collected before
    # the next file is loaded.
    del train_dataset, train_dataloader, val_dataset, val_dataloader

    return model

In [9]:
# Root of the competition data; every other path is derived from it with
# os.path.join so a trailing slash on data_dir cannot produce a '//' path.
data_dir = '/icecube-neutrinos-in-deep-ice/'
metadata_file = os.path.join(data_dir, 'train_meta.parquet')
batch_files = [os.path.join(data_dir, 'train', f'batch_{i}.parquet') for i in range(1, 2)]  # You can set up to 660 files

sensor_geometry = pd.read_csv(os.path.join(data_dir, 'sensor_geometry.csv'))
metadata = pd.read_parquet(metadata_file)

In [10]:
# Model hyper-parameters, passed to NeutrinoDirectionModel as keyword arguments.
config = dict(
    # Per-node inputs: 8 pulse statistics + 3 sensor coordinates.
    input_features=11,
    hidden_features=64,
    output_gnn_features=32,
    num_fc_layers=4,
    # Head input: 104 event-level features + 32 pooled GCN features.
    input_fc_features=104 + 32,
    hidden_fc_units=512,
    # Normalised azimuth and zenith.
    output_features=2,
    learning_rate=0.001,
)

model = NeutrinoDirectionModel(**config)

In [11]:
# Each epoch makes one pass over every selected batch file, training the same
# model object on one file at a time.
n_epochs = 1  # You can set any number of epochs but first use all of 660 files
for epoch in range(n_epochs):
    for batch_file in batch_files:
        model = train_one_epoch_on_batch(model, metadata, sensor_geometry, batch_file)

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]