<a href="https://colab.research.google.com/github/harishk30/CamelsHetroGNN/blob/main/AstridLHTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Loading Data

In [None]:
pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl.metadata (64 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/64.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.2/64.2 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3


In [None]:
import numpy as np
import h5py

In [None]:
# Colab-only step: mount Google Drive so the '/content/drive/...' paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Path of the single group catalogue used by the exploratory cells.
catalog = '/content/drive/MyDrive/groups_090.hdf5'
# Opened read-only; the next cell reads datasets through this handle.
# NOTE(review): the handle is never closed in the visible cells.
f = h5py.File(catalog, 'r')

In [None]:
# Per-subhalo arrays read from the open catalogue handle `f`.
# NOTE(review): these four globals are not referenced again in the visible cells.
M_star = f['Subhalo/SubhaloMassType'][:,4]*1e10  # column 4 of the mass-type table, scaled by 1e10
pos  = f['Subhalo/SubhaloPos'][:]/1e3            # positions divided by 1e3
vel = f['Subhalo/SubhaloVel'][:]
met = f['Subhalo/SubhaloStarMetallicity'][:]

In [None]:
def load_and_filter_data(file, mass_threshold=2e8):
    """Read one subhalo catalogue and keep the subhalos above a stellar-mass cut.

    Parameters
    ----------
    file : path handed to ``h5py.File`` (opened read-only).
    mass_threshold : a subhalo is kept only if its stellar mass is strictly
        greater than this value.

    Returns
    -------
    tuple
        ``(positions, vel, metallicities, masses, omega_m)`` where the first
        four entries are the filtered arrays and ``omega_m`` is the ``Omega0``
        attribute of the catalogue header.
    """
    with h5py.File(file, 'r') as catalogue:
        omega_m = catalogue['Header'].attrs['Omega0']
        masses = catalogue['Subhalo/SubhaloMassType'][:,4]*1e10  # Stellar mass
        positions = catalogue['Subhalo/SubhaloPos'][:]/1e3  # Convert to Mpc/h
        vel = catalogue['Subhalo/SubhaloVel'][:]
        metallicities = catalogue['Subhalo/SubhaloStarMetallicity'][:]

    # One boolean selection shared by every per-subhalo array
    keep = masses > mass_threshold

    return positions[keep], vel[keep], metallicities[keep], masses[keep], omega_m

In [None]:
def apply_periodic_boundary_conditions(positions, box_size):
    """Fold coordinates back into the periodic box.

    Returns ``positions % box_size``, so values at or beyond ``box_size`` and
    negative values wrap around.
    """
    return positions % box_size

In [None]:
def minimum_image_distance(pos1, pos2, box_size):
    """Distance between points in a periodic box using the nearest image.

    Along each axis the separation is replaced by ``box_size - separation``
    whenever it exceeds half the box; the Euclidean norm is then taken over
    the last axis, so the inputs may broadcast against each other.
    """
    separation = np.abs(pos1 - pos2)
    # For a separation beyond half the box the wrapped-around value is the
    # smaller of the two, so an element-wise minimum picks the nearest image.
    nearest = np.minimum(separation, box_size - separation)
    return np.sqrt(np.sum(np.square(nearest), axis=-1))

In [None]:
from scipy.spatial import KDTree

def distance(point1, point2):
    """Return the norm of ``point1 - point2`` as computed by ``np.linalg.norm``."""
    offset = point1 - point2
    return np.linalg.norm(offset)

def create_edges_knn(points, k=6):
    """Build a directed k-nearest-neighbour edge list (non-periodic).

    Parameters
    ----------
    points : array-like of shape (n_points, n_dims).
    k : maximum number of neighbours linked from each point.

    Returns
    -------
    list
        ``[edges, edge_value]`` where ``edges`` is a list of ``[i, j]`` index
        pairs (``i`` is the query point, ``j`` one of its neighbours) and
        ``edge_value`` holds the matching Euclidean distances, in the same
        order.

    Notes
    -----
    * Self-matches are removed by index, not by position in the result, so
      coincident points do not produce self-loops.
    * When fewer than ``k + 1`` points exist, each point is linked to every
      other point instead of raising ``IndexError`` on the padding index that
      the tree returns for missing neighbours.
    """
    n_points = len(points)
    if n_points == 0 or k < 1:
        return [[], []]

    # Create a KDTree for efficient nearest neighbor search
    point_tree = KDTree(points)

    # One batched query; k + 1 because each point is normally its own nearest
    # neighbour. Never ask for more neighbours than there are points.
    n_query = min(k + 1, n_points)
    dists, nbrs = point_tree.query(points, k=n_query)
    # With n_query == 1 scipy returns 1-D results; normalise to (n_points, n_query)
    dists = np.asarray(dists).reshape(n_points, -1)
    nbrs = np.asarray(nbrs).reshape(n_points, -1)

    edges = []
    edge_value = []
    for i in range(n_points):
        kept = 0
        for j, dist in zip(nbrs[i], dists[i]):
            # Skip the point itself and any padding entry for a missing neighbour
            if j == i or j >= n_points:
                continue
            edges.append([i, int(j)])
            # The tree already computed the Euclidean distance for this pair
            edge_value.append(float(dist))
            kept += 1
            if kept == k:
                break

    return [edges, edge_value]

In [None]:
from tqdm import tqdm
def min_distance(positions, box_size = 25):
    """Print and return the smallest and largest pairwise periodic distance.

    Every unordered pair of rows in ``positions`` is measured with
    ``minimum_image_distance``. The extremes are printed and also returned as
    ``(smallest, largest)`` so callers can use them programmatically. With
    fewer than two points no pair exists and the initial values ``inf`` and
    ``0`` are reported.

    Parameters
    ----------
    positions : array-like of shape (n_points, n_dims).
    box_size : side length of the periodic box.
    """
    coords = np.asarray(positions)
    n_points = len(coords)

    # Distinct local names: the originals shadowed the function's own name
    smallest = np.inf
    largest = 0

    # For each point, measure the distance to all later points in one vectorised
    # call; the last point has no later partner, hence n_points - 1 iterations.
    for i in tqdm(range(max(n_points - 1, 0))):
        row = minimum_image_distance(coords[i], coords[i + 1:], box_size)
        smallest = min(smallest, row.min())
        largest = max(largest, row.max())

    # Print the results
    print(f"Minimum distance: {smallest} Mpc/h")
    print(f"Maximum distance: {largest} Mpc/h")
    return smallest, largest

In [None]:
def minimum_image_distance_vectorized(positions, box_size = 25):
    """Return the full matrix of pairwise minimum-image distances.

    Parameters
    ----------
    positions : array-like of shape (n_points, n_dims); lists are accepted.
    box_size : side length of the periodic box.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_points, n_points) whose entry ``[i, j]`` is the
        periodic distance between points ``i`` and ``j``.

    Notes
    -----
    The intermediate difference array has shape (n_points, n_points, n_dims),
    so memory grows quadratically with the number of points.
    """
    coords = np.asarray(positions)

    # Pairwise absolute differences along each dimension via broadcasting
    diff = np.abs(coords[:, np.newaxis, :] - coords[np.newaxis, :, :])

    # Apply periodic boundary conditions: beyond half the box the wrapped
    # separation is the shorter one
    diff = np.where(diff > 0.5 * box_size, box_size - diff, diff)

    # Euclidean norm over the coordinate axis
    return np.sqrt(np.sum(diff ** 2, axis=-1))

In [None]:
from tqdm import tqdm
from scipy.spatial import cKDTree
def create_edges_knn_pbc(points, box_size = 25, k=6):
    """Build a directed k-nearest-neighbour edge list in a periodic box.

    Parameters
    ----------
    points : array-like of shape (n_points, n_dims).
    box_size : side length of the periodic box.
    k : maximum number of neighbours linked from each point.

    Returns
    -------
    (edges, edge_values)
        ``edges`` is an integer array of shape (n_edges, 2) with rows
        ``[i, j]`` (``i`` the query point, ``j`` a neighbour), ordered by
        ``i`` and then by increasing distance. ``edge_values`` is a float
        array of shape (n_edges,) with the periodic distance of each edge.
        The smallest and largest edge length are also printed
        (``inf`` and ``0`` when there are no edges).

    Notes
    -----
    * Coordinates are wrapped into ``[0, box_size)`` before the tree is built;
      a periodic KDTree raises ``ValueError`` for data outside that interval.
    * The distances returned by the periodic tree query are used directly
      instead of being recomputed pair by pair.
    * Self-matches are dropped by index and padding entries for missing
      neighbours (fewer than ``k + 1`` points) are ignored.
    """
    coords = np.asarray(points, dtype=float)
    n_points = coords.shape[0]
    if n_points == 0 or k < 1:
        print(np.inf, 0)
        return np.empty((0, 2), dtype=np.int64), np.empty(0, dtype=float)

    # Wrap into the box; floating-point rounding of a tiny negative value can
    # land exactly on box_size, which the tree rejects, so fold that onto 0.
    wrapped = np.mod(coords, box_size)
    wrapped[wrapped >= box_size] = 0.0

    tree = KDTree(wrapped, boxsize=box_size)

    # One batched query; k + 1 because each point is normally its own nearest
    # neighbour. Missing neighbours come back with index n_points.
    distances, neighbors = tree.query(wrapped, k=k + 1)

    sources = np.broadcast_to(np.arange(n_points)[:, np.newaxis], neighbors.shape)
    valid = (neighbors != sources) & (neighbors < n_points)
    # Keep at most k neighbours per point (matters when coincident points push
    # the point itself out of its own k + 1 nearest results)
    valid &= np.cumsum(valid, axis=1) <= k

    # Boolean indexing walks row by row, i.e. grouped by source point
    edges = np.stack([sources[valid], neighbors[valid]], axis=1)
    edge_values = distances[valid]

    min_distance = edge_values.min() if edge_values.size else np.inf
    max_distance = edge_values.max() if edge_values.size else 0
    print(min_distance, max_distance)
    return edges, edge_values


In [None]:
def create_points(positions, masses, vel, met):
    """Assemble one feature row per point: its coordinates, then mass, then metallicity.

    ``vel`` is accepted but not used; velocities are currently left out of the
    node features.

    Returns a list of lists, one inner list per entry of ``positions``.
    """
    return [
        [*pos, masses[i], met[i]]
        for i, pos in enumerate(positions)
    ]

In [None]:
from torch_geometric.data import Data
import torch
def create_graph(file_path, k_val=6, box_size=25):
    """Turn one catalogue file into the tensors that describe its galaxy graph.

    Parameters
    ----------
    file_path : catalogue path, forwarded to ``load_and_filter_data``.
    k_val : number of nearest neighbours linked from each galaxy.
    box_size : side length of the periodic box used for the neighbour search.
        Defaults to 25, the value that used to be hard-coded here.

    Returns
    -------
    list
        ``[point_values, edge_index, edge_value, omega_m]`` where
        ``point_values`` is a float tensor of node features, ``edge_index`` a
        long tensor of shape (2, n_edges), ``edge_value`` a float tensor of
        edge lengths and ``omega_m`` the value read from the catalogue header.
    """
    positions, velocity, metallicities, masses, omega_m = load_and_filter_data(file_path)
    edges, edge_values = create_edges_knn_pbc(positions, box_size, k_val)
    point_values = create_points(positions, masses, velocity, metallicities)

    # Edges arrive as (n_edges, 2); transpose to the (2, n_edges) layout
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    point_values = torch.tensor(point_values, dtype=torch.float)
    edge_value = torch.tensor(edge_values, dtype=torch.float)

    return [point_values, edge_index, edge_value, omega_m]

In [None]:
def turn_data(graph):
    """Wrap the sequence produced by ``create_graph`` in a ``Data`` object.

    ``graph`` is indexed positionally: 0 -> node features (``x``),
    1 -> ``edge_index``, 2 -> ``edge_attr``, 3 -> label (``y``).
    """
    node_features = graph[0]
    connectivity = graph[1]
    edge_features = graph[2]
    label = graph[3]
    return Data(x=node_features, edge_index=connectivity, edge_attr=edge_features, y=label)

In [None]:
def create_data(file_path, k_val=6):
    """Build the ``Data`` graph for one catalogue file.

    Chains ``create_graph`` (file -> tensors) and ``turn_data``
    (tensors -> ``Data``).
    """
    return turn_data(create_graph(file_path, k_val))

In [None]:
# Smoke test: build the graph for a single catalogue and display the resulting Data object.
create_data('/content/drive/MyDrive/groups_090.hdf5', 6)

Processing points: 100%|██████████| 709/709 [00:00<00:00, 3142.77it/s]

0.016440034 7.6689534





Data(x=[709, 5], edge_index=[2, 4254], edge_attr=[4254], y=0.3862)

In [None]:
def calculate_normalization_params(data_list):
    """Compute mean/std statistics pooled over every graph in ``data_list``.

    Node features (``.x``) and edge attributes (``.edge_attr``) of all graphs
    are concatenated along dim 0, then reduced along dim 0.

    Returns
    -------
    ((x_mean, x_std), (edge_attr_mean, edge_attr_std))
    """
    def _pooled_stats(tensors):
        # Stack all rows from all graphs, then reduce over the row axis
        pooled = torch.cat(tensors, dim=0)
        return pooled.mean(dim=0), pooled.std(dim=0)

    node_stats = _pooled_stats([graph.x for graph in data_list])
    edge_stats = _pooled_stats([graph.edge_attr for graph in data_list])
    return node_stats, edge_stats

def normalize_dataset(data_list, x_params, edge_attr_params):
    """Return standardised copies of the graphs in ``data_list``.

    Parameters
    ----------
    data_list : iterable of graphs exposing ``x``, ``edge_index``,
        ``edge_attr`` and ``y``.
    x_params : ``(mean, std)`` applied to node features.
    edge_attr_params : ``(mean, std)`` applied to edge attributes.

    Returns
    -------
    list
        New ``Data`` objects; ``edge_index`` and ``y`` are carried over
        unchanged, the inputs are not modified.
    """
    x_mean, x_std = x_params
    edge_attr_mean, edge_attr_std = edge_attr_params

    # The small offset keeps the division finite for zero-variance features
    x_scale = x_std + 1e-8
    edge_attr_scale = edge_attr_std + 1e-8

    return [
        Data(x=(graph.x - x_mean) / x_scale,
             edge_index=graph.edge_index,
             edge_attr=(graph.edge_attr - edge_attr_mean) / edge_attr_scale,
             y=graph.y)  # label is left in its original units
        for graph in data_list
    ]

In [None]:
import torch
from torch.utils.data import random_split
from torch_geometric.data import DataLoader
directory = ''  # TODO: set to the folder that holds the HDF5 group catalogues
def load_all_graphs(directory, k_val=6, box_size=25):
    """Build one graph per catalogue file found directly inside ``directory``.

    Parameters
    ----------
    directory : folder whose files are each passed to ``create_data``.
    k_val : number of nearest neighbours per galaxy, forwarded to ``create_data``.
    box_size : NOTE(review): accepted but not used; ``create_data`` takes no
        box-size argument, so the value cannot be forwarded from here.

    Returns
    -------
    list
        One graph per file, in sorted file-name order.
    """
    # Local import: `os` is not imported by any visible cell, so the bare
    # module-level name would raise NameError on a fresh kernel.
    import os

    # os.listdir order is arbitrary; sorting keeps the dataset order (and any
    # seeded split made from it) identical across machines.
    file_list = sorted(os.listdir(directory))
    data_list = []
    for file_name in tqdm(file_list, desc="Loading HDF5 files"):
        file_path = os.path.join(directory, file_name)
        # Skip sub-directories (e.g. checkpoint folders), which cannot be opened as catalogues
        if not os.path.isfile(file_path):
            continue
        data_list.append(create_data(file_path, k_val))
    return data_list

# Load all graphs
data_list = load_all_graphs(directory)

# Calculate the lengths for the 70-15-15 split
total_len = len(data_list)
train_len = int(0.7 * total_len)
val_len = int(0.15 * total_len)
test_len = total_len - train_len - val_len  # Ensures all data is used

# Perform the split with a seeded generator so the partition is reproducible
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_split, val_split, test_split = random_split(
    data_list, [train_len, val_len, test_len], generator=split_generator
)

# Fit the normalization statistics on the training graphs only, so that no
# information from the validation/test graphs leaks into preprocessing
train_graphs = [data_list[i] for i in train_split.indices]
if not train_graphs:
    raise ValueError("Training split is empty; cannot compute normalization parameters")
x_params, edge_attr_params = calculate_normalization_params(train_graphs)

# Normalize every graph with the training-set parameters
normalized_data_list = normalize_dataset(data_list, x_params, edge_attr_params)

# Materialise each split from the normalized graphs using the split indices
train_data = [normalized_data_list[i] for i in train_split.indices]
val_data = [normalized_data_list[i] for i in val_split.indices]
test_data = [normalized_data_list[i] for i in test_split.indices]

# Create DataLoaders for each split
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Training Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool

class ComplexGAT(torch.nn.Module):
    """Graph-level regressor: stacked GAT layers, mean pooling, two-layer head.

    Parameters
    ----------
    in_channels : number of input features per node.
    hidden_channels : per-head width of every GAT layer and width of the head.
    num_layers : total number of GAT layers; the first and last are always
        present and ``num_layers - 2`` layers are placed between them.
    heads : attention heads of every GAT layer except the last, which uses one.
    dropout_rate : dropout probability applied after every GAT layer but the last.

    Every GAT layer is built with ``edge_dim=1``, i.e. one scalar attribute
    per edge.
    """

    def __init__(self, in_channels, hidden_channels, num_layers=4, heads=4, dropout_rate=0.1):
        super().__init__()
        self.dropout_rate = dropout_rate

        # Head outputs are concatenated, so this layer emits hidden_channels * heads features
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True, edge_dim=1)

        self.convs = nn.ModuleList([
            GATConv(hidden_channels * heads, hidden_channels, heads=heads, concat=True, edge_dim=1)
            for _ in range(num_layers - 2)
        ])

        # Single head without concatenation brings the width back to hidden_channels
        self.conv_last = GATConv(hidden_channels * heads, hidden_channels, heads=1, concat=False, edge_dim=1)

        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, 1)  # Output a single value for omega_m

    def forward(self, data):
        """Return one prediction per graph in the batch, as a tensor of shape [num_graphs]."""
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        x = F.relu(self.conv1(x, edge_index, edge_attr=edge_attr))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)

        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_attr=edge_attr))
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        x = self.conv_last(x, edge_index, edge_attr=edge_attr)

        # Collapse node embeddings to one vector per graph
        x = global_mean_pool(x, batch)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        # Drop only the trailing feature axis: a bare squeeze() would also drop
        # the batch axis for a single-graph batch and return a 0-d tensor
        return x.squeeze(-1)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np

def train_model(model, train_loader, val_loader, device, num_epochs=100, lr=0.001):
    """Train ``model`` with Adam and MSE loss, keeping the best validation weights.

    Parameters
    ----------
    model : module called as ``model(batch)``, returning one value per graph.
    train_loader, val_loader : loaders yielding batches with a ``y`` label.
    device : device on which the model and batches are placed.
    num_epochs : number of passes over ``train_loader``.
    lr : Adam learning rate.

    Returns
    -------
    The same model, loaded with the weights of the epoch that had the lowest
    mean validation loss (left as-is if no epoch improved on ``inf``).
    """
    import copy  # local import: needed to snapshot the best weights

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val_loss = float('inf')
    best_model = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            # Flatten so a single-graph batch still has shape [1]
            out = model(data).reshape(-1)
            # Match the label's dtype/shape to the prediction; labels stored as
            # NumPy float64 scalars would otherwise clash with float32 outputs
            target = torch.as_tensor(data.y, dtype=out.dtype, device=out.device).reshape(-1)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        model.eval()
        val_loss = 0
        val_predictions = []
        val_true = []
        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                out = model(data).reshape(-1)
                target = torch.as_tensor(data.y, dtype=out.dtype, device=out.device).reshape(-1)
                val_loss += criterion(out, target).item()
                val_predictions.extend(out.cpu().numpy())
                val_true.extend(target.cpu().numpy())

        # Losses are averaged per batch, not per graph
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        val_mse = mean_squared_error(val_true, val_predictions)
        val_r2 = r2_score(val_true, val_predictions)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val MSE: {val_mse:.4f}, Val R2: {val_r2:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameters; without a
            # deep copy later optimizer steps would overwrite the "best" weights
            best_model = copy.deepcopy(model.state_dict())

    # No snapshot exists if no epoch ran or the validation loss was never finite
    if best_model is not None:
        model.load_state_dict(best_model)
    return model

def evaluate_model(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and report MSE and R2.

    Parameters
    ----------
    model : module called as ``model(batch)``; it is not moved to ``device`` here.
    test_loader : loader yielding batches with a ``y`` label.
    device : device the batches are moved to.

    Returns
    -------
    (test_predictions, test_true)
        Two flat lists with one entry per graph, in loader order.
    """
    model.eval()
    test_predictions = []
    test_true = []
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            # Flatten both sides so a single-graph batch yields a length-1
            # array instead of a 0-d array that extend() cannot iterate
            out = model(data).reshape(-1)
            target = torch.as_tensor(data.y, device=out.device).reshape(-1)
            test_predictions.extend(out.cpu().numpy())
            test_true.extend(target.cpu().numpy())

    test_mse = mean_squared_error(test_true, test_predictions)
    test_r2 = r2_score(test_true, test_predictions)

    print(f'Test MSE: {test_mse:.4f}, Test R2: {test_r2:.4f}')
    return test_predictions, test_true

In [None]:
import torch
from torch_geometric.loader import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# Initialize model
# Seed PyTorch so weight initialisation and dropout are repeatable across runs
torch.manual_seed(42)

# Node-feature width is read from the first normalized graph.
# (The list built by the data-loading cell is `normalized_data_list`; the
# previously used name `normalized_dataset` is not defined anywhere.)
in_channels = normalized_data_list[0].num_node_features
hidden_channels = 64
num_layers = 4
heads = 4
dropout_rate = 0.1

model = ComplexGAT(in_channels, hidden_channels, num_layers, heads, dropout_rate)

# Train model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trained_model = train_model(model, train_loader, val_loader, device, num_epochs=100, lr=0.001)

# Evaluating Model

In [None]:
# Evaluate model
test_predictions, test_true = evaluate_model(trained_model, test_loader, device)

# Plot results: predictions against ground truth, with the y = x line as reference
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(test_true, test_predictions, alpha=0.5)
true_min, true_max = min(test_true), max(test_true)
ax.plot([true_min, true_max], [true_min, true_max], 'r--', lw=2)
# Raw strings: '\O' in a normal string literal is an invalid escape sequence
ax.set_xlabel(r'True $\Omega_{m}$')
ax.set_ylabel(r'Predicted $\Omega_{m}$')
ax.set_title(r'True vs Predicted $\Omega_{m}$')
plt.show()