In [1]:
!pip install torch torchvision torchaudio
!pip install torch-geometric

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
from torch.nn import LayerNorm
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split


seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


df = pd.read_csv("surface_n2o_compilation.csv")
required_cols = ["latitude", "longitude", "depth", "atmPressure", "temperature", "salinity", "month", "dn2o_ppb"]
df = df.dropna(subset=required_cols).reset_index(drop=True)

# feature engineering
df["temp_sal"] = df["temperature"] * df["salinity"]
df["lat_lon"] = df["latitude"] * df["longitude"]
df["month_sin"] = np.sin(2 * np.pi * df["month"] / 12)
df["month_cos"] = np.cos(2 * np.pi * df["month"] / 12)

features = ["latitude", "longitude", "depth", "atmPressure", "temperature",
            "salinity", "month_sin", "month_cos", "temp_sal", "lat_lon"]
X = df[features].values
y = df["dn2o_ppb"].values.reshape(-1, 1)

# scaling
feature_scaler = StandardScaler()
target_scaler = StandardScaler()
X_scaled = feature_scaler.fit_transform(X)
y_scaled = target_scaler.fit_transform(y)

#build graph edges
coords = df[["latitude", "longitude", "depth"]].values

def build_edge_index(k=8):
    nbrs = NearestNeighbors(n_neighbors=k+1, algorithm="auto").fit(coords)
    _, indices = nbrs.kneighbors(coords)
    edges = [[i, j] for i in range(len(df)) for j in indices[i] if i != j]
    return torch.tensor(edges, dtype=torch.long).t().contiguous()

edge_index = build_edge_index(k=8).to(device)

x = torch.tensor(X_scaled, dtype=torch.float)
y = torch.tensor(y_scaled, dtype=torch.float)

data = Data(x=x.to(device), edge_index=edge_index, y=y.to(device))

# train/test masks
idx = np.arange(len(df))
train_idx, test_idx = train_test_split(idx, test_size=0.15, random_state=seed)
train_idx, val_idx = train_test_split(train_idx, test_size=0.1765, random_state=seed)

masks = {
    "train": torch.zeros(len(df), dtype=torch.bool),
    "val": torch.zeros(len(df), dtype=torch.bool),
    "test": torch.zeros(len(df), dtype=torch.bool)
}
masks["train"][train_idx] = True
masks["val"][val_idx] = True
masks["test"][test_idx] = True

data.train_mask = (masks["train"] | masks["val"]).to(device)
data.test_mask = masks["test"].to(device)


class GAT(torch.nn.Module):
    def __init__(self, in_channels, heads=4):
        super().__init__()
        hidden_dim = 32
        self.dropout = 0.1

        self.att1 = GATv2Conv(in_channels, hidden_dim, heads=heads, dropout=self.dropout)
        self.norm1 = LayerNorm(hidden_dim * heads)

        self.att2 = GATv2Conv(hidden_dim * heads, hidden_dim, heads=1, dropout=self.dropout)
        self.norm2 = LayerNorm(hidden_dim)

        self.res_proj = torch.nn.Linear(hidden_dim * heads, hidden_dim)

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 16),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(16, 1)
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = F.dropout(F.leaky_relu(self.norm1(self.att1(x, edge_index))), p=self.dropout, training=self.training)
        x2 = self.att2(x1, edge_index)
        x2 = F.dropout(F.leaky_relu(self.norm2(x2 + self.res_proj(x1))), p=self.dropout, training=self.training)
        return self.mlp(x2)


model = GAT(in_channels=x.shape[1]).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=150)
criterion = torch.nn.MSELoss()

best_loss = float("inf")
counter, patience = 30, 0
epochs = 250
grad_clip = 1.0

for epoch in range(1, epochs + 1):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    scheduler.step()

    if loss.item() < best_loss:
        best_loss = loss.item()
        torch.save(model.state_dict(), "best_gat_model.pth")
        patience = 0
    else:
        patience += 1

    if epoch % 20 == 0 or patience == 0:
        print(f"Epoch {epoch} | Loss: {loss.item():.5f} | LR: {scheduler.get_last_lr()[0]:.6f}")
    if patience > counter:
        print("Early stopping.")
        break


model.load_state_dict(torch.load("best_gat_model.pth"))
model.eval()
with torch.no_grad():
    preds = model(data)[data.test_mask].cpu().numpy()
    y_true = data.y[data.test_mask].cpu().numpy()

preds_rescaled = target_scaler.inverse_transform(preds)
y_true_rescaled = target_scaler.inverse_transform(y_true)

rmse = np.sqrt(np.mean((preds_rescaled - y_true_rescaled) ** 2))
r2 = 1 - np.sum((preds_rescaled - y_true_rescaled) ** 2) / np.sum((y_true_rescaled - np.mean(y_true_rescaled)) ** 2)

print(f"\nFinal GAT Test RMSE: {rmse:.2f} ppb")
print(f"Final GAT Test R²: {r2:.4f}")


Epoch 1 | Loss: 1.20407 | LR: 0.049995
Epoch 3 | Loss: 1.05458 | LR: 0.049951
Epoch 6 | Loss: 1.01242 | LR: 0.049803
Epoch 8 | Loss: 1.01169 | LR: 0.049650
Epoch 9 | Loss: 0.97567 | LR: 0.049557
Epoch 10 | Loss: 0.94394 | LR: 0.049454
Epoch 11 | Loss: 0.93142 | LR: 0.049339
Epoch 12 | Loss: 0.88465 | LR: 0.049215
Epoch 13 | Loss: 0.81876 | LR: 0.049079
Epoch 14 | Loss: 0.76753 | LR: 0.048933
Epoch 15 | Loss: 0.75441 | LR: 0.048776
Epoch 17 | Loss: 0.72551 | LR: 0.048432
Epoch 18 | Loss: 0.66936 | LR: 0.048244
Epoch 19 | Loss: 0.64938 | LR: 0.048047
Epoch 20 | Loss: 0.62819 | LR: 0.047839
Epoch 21 | Loss: 0.61100 | LR: 0.047621
Epoch 22 | Loss: 0.55049 | LR: 0.047393
Epoch 23 | Loss: 0.54308 | LR: 0.047155
Epoch 25 | Loss: 0.50350 | LR: 0.046651
Epoch 26 | Loss: 0.48120 | LR: 0.046384
Epoch 27 | Loss: 0.44181 | LR: 0.046108
Epoch 29 | Loss: 0.40784 | LR: 0.045529
Epoch 30 | Loss: 0.39779 | LR: 0.045225
Epoch 32 | Loss: 0.37302 | LR: 0.044592
Epoch 34 | Loss: 0.35675 | LR: 0.043925
Epoch