In [None]:
import json
import os
from pathlib import Path

import torch
import torch.nn as nn
from sklearn.metrics import mean_absolute_error, root_mean_squared_error, r2_score, mean_absolute_percentage_error

from models.vanilla_nn_bands_only import NNBandsOnly
from src.data_utils import LabelledTraitData

In [2]:
# Root of the project's data directory. Set the GEOFM_DATA_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
dpath = Path(os.environ.get('GEOFM_DATA_DIR', '/Users/campbelli/Documents/geofm-plant-traits/data'))
# Trait to model; used as the key into trait_stats and passed to LabelledTraitData.
var = 'N.Percent'

In [3]:
# Per-trait statistics; the entry for `var` is later handed to the model via set_stats.
# JSON text is UTF-8, so decode it explicitly rather than relying on the
# platform's locale-dependent default encoding.
metadata_path = dpath / 'metadata/trait_stats.json'
with metadata_path.open('r', encoding='utf-8') as f:
    trait_stats = json.load(f)

In [None]:
# Build the network (seeded with 42), load the labelled splits for the chosen
# trait, and register that trait's statistics on the model before any split is
# passed through configure_data.
model = NNBandsOnly(seed=42)
dataset = LabelledTraitData(dpath, var)
model.set_stats(trait_stats[var])

# Read the raw (features, labels) pair of every split first, in train/test/val
# order, then convert each pair into the model's expected input format.
(X_train, y_train), (X_test, y_test), (X_val, y_val) = (
    model.configure_data(features, labels)
    for features, labels in (
        (dataset.train_data, dataset.train_labels),
        (dataset.test_data, dataset.test_labels),
        (dataset.val_data, dataset.val_labels),
    )
)

In [7]:
# Train on the configured training split; the validation split is passed as well
# (presumably for the per-epoch monitoring printed below — confirm in NNBandsOnly.fit).
# NOTE(review): the logged loss is lowest around epoch 15 and fluctuates higher
# afterwards; if fit keeps the final-epoch weights rather than the best validation
# checkpoint, consider early stopping / restoring the best model.
model.fit(X_train, y_train, X_val, y_val)

Epoch 0, Loss: 0.2489585429430008, MAPE: 0.1954821394281944
Epoch 5, Loss: 0.136930912733078, MAPE: 0.13928760443183016
Epoch 10, Loss: 0.12038370221853256, MAPE: 0.1284997367096753
Epoch 15, Loss: 0.11500648781657219, MAPE: 0.1264718053622882
Epoch 20, Loss: 0.17681562155485153, MAPE: 0.15838245884798482
Epoch 25, Loss: 0.1285201497375965, MAPE: 0.13361563120004386
Epoch 30, Loss: 0.17632093280553818, MAPE: 0.15878916690372202
Epoch 35, Loss: 0.11975161358714104, MAPE: 0.12903667256621673
Epoch 40, Loss: 0.13223588466644287, MAPE: 0.1379974684679955
Epoch 45, Loss: 0.13649629056453705, MAPE: 0.1409771693639214
Epoch 50, Loss: 0.13162535429000854, MAPE: 0.13774876803060757
Epoch 55, Loss: 0.13980120420455933, MAPE: 0.1425987383728794
Epoch 60, Loss: 0.13768170028924942, MAPE: 0.14178685510381397
Epoch 65, Loss: 0.13893123716115952, MAPE: 0.14264014967487465
Epoch 70, Loss: 0.13910368829965591, MAPE: 0.14249677160583013
Epoch 75, Loss: 0.12980833649635315, MAPE: 0.13697901435493298
Epoc

In [8]:
def get_activations(model, X, layer_types=(nn.Linear,)):
    """Capture the outputs of selected layers during a single forward pass.

    Forward hooks are attached to the direct children of ``model.model`` whose
    type matches ``layer_types``. Each hook appends its layer's output (detached
    from the autograd graph) to the result list as the layer fires.

    Args:
        model: Wrapper object exposing the underlying ``nn.Module`` as ``model.model``.
        X: Input batch accepted by ``model.model``.
        layer_types: Module class, or tuple of classes, whose outputs are captured.
            Defaults to ``nn.Linear`` (pre-activation outputs of fully-connected
            layers); pass e.g. ``nn.ReLU`` to capture post-activation outputs.

    Returns:
        list[torch.Tensor]: One tensor per hooked layer call, in execution order.
    """
    net = model.model
    activations = []

    def _record(module, inputs, output):
        activations.append(output.detach())

    hooks = []
    was_training = net.training
    try:
        for layer in net.children():
            if isinstance(layer, layer_types):
                hooks.append(layer.register_forward_hook(_record))

        # Inference-time behaviour for dropout / batch-norm while inspecting the
        # network; the original mode is restored in `finally`.
        net.eval()
        with torch.no_grad():
            net(X)
    finally:
        # Always restore the mode and detach the hooks, even if the forward pass
        # raises, so stale hooks never fire on later calls.
        net.train(was_training)
        for hook in hooks:
            hook.remove()

    return activations

# Retrieve the outputs of every fully-connected (nn.Linear) layer on the test set.
# These are pre-activation values.
activations = get_activations(model, X_test)

# A unit is "dead" when it never fires on any test sample.
# NOTE(review): assumes each hidden nn.Linear is followed by a ReLU — confirm in
# NNBandsOnly. ReLU(z) == 0 exactly when z <= 0, so the test on pre-activations is
# `<= 0`; an exact `== 0` test would almost never be true and always report 0.
# The last layer is presumably the regression head (no ReLU), so its count is not
# meaningful.
for i, activation in enumerate(activations):
    dead_neurons = (activation <= 0).all(dim=0)
    print(f"Layer {i}: {dead_neurons.sum().item()} dead neurons")

Layer 0: 0 dead neurons
Layer 1: 0 dead neurons
Layer 2: 0 dead neurons
Layer 3: 0 dead neurons
Layer 4: 0 dead neurons
Layer 5: 0 dead neurons
Layer 6: 0 dead neurons
Layer 7: 0 dead neurons


In [9]:
# Evaluate on the held-out test split. Predictions and targets are presumably in
# the model's standardised space (set_stats / configure_data), so unstandardise
# maps both back before scoring. Results go to new names so `y_test` keeps its
# configured value and re-running this cell cannot unstandardise it twice.
y_pred_std = model.predict(X_test)
y_pred, y_true = model.unstandardise(y_pred_std, y_test)

metrics = {
    # sklearn metrics take (y_true, y_pred); MAPE is scaled to a percentage.
    "R_squared": r2_score(y_true, y_pred),
    "RMSE": root_mean_squared_error(y_true, y_pred),
    "MAE": mean_absolute_error(y_true, y_pred),
    "MAPE": mean_absolute_percentage_error(y_true, y_pred) * 100,
}

In [10]:
# Display the test-set metrics (computed after unstandardise; MAPE is in percent).
metrics

{'R_squared': 0.6098239421844482,
 'RMSE': 0.3210979998111725,
 'MAE': 0.25824713706970215,
 'MAPE': 13.816478848457336}