In [1]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch

import functions

In [2]:
# Root directory of the LamaH-CE dataset. Defaults to the original cluster path;
# set the LAMAH_CE_PATH environment variable to run the notebook elsewhere
# without editing it.
DATASET_PATH = os.environ.get("LAMAH_CE_PATH", "/scratch/kirschstein/LamaH-CE")
# Directory containing the trained "<architecture>_<orientation>_learned_<fold>.run"
# checkpoints (relative to the notebook's working directory).
CHECKPOINT_PATH = "./checkpoint/topology"

In [3]:
# Per-edge attributes of the river network, used below as reference signals
# for the learned edge weights. Column mapping:
#   dist_hdn   -> stream_length
#   elev_diff  -> elevation_difference
#   strm_slope -> average_slope
# (units are whatever adjacency.csv stores — not established here)
adj_df = pd.read_csv(f"{DATASET_PATH}/processed/adjacency.csv")
weight_cols = adj_df[["dist_hdn", "elev_diff", "strm_slope"]].to_numpy()
# Iterate over the columns (rows of the transpose) and convert each to a 1-D float tensor.
stream_length, elevation_difference, average_slope = (
    torch.tensor(column, dtype=torch.float) for column in weight_cols.T
)

In [4]:
# For each architecture / edge orientation, correlate the learned edge weights
# with the static per-edge attributes and summarise the resulting 4x4 Pearson
# correlation matrix as mean and std across the 6 folds.
# Row/column order of every matrix:
#   [learned weights, stream_length, elevation_difference, average_slope]
# The three attributes are the same in every fold, so their mutual
# correlations have zero std; only the first row/column varies across folds.
for architecture in ["ResGCN", "GCNII"]:
    for edge_orientation in ["downstream", "upstream", "bidirectional"]:
        print(architecture, edge_orientation)
        corrmats = []
        for fold in range(6):
            # map_location="cpu" lets the notebook run without the GPU the
            # checkpoint may have been saved from.
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            chkpt = torch.load(f"{CHECKPOINT_PATH}/{architecture}_{edge_orientation}_learned_{fold}.run",
                               map_location="cpu")
            # NaN entries become 0 (and ±inf the dtype's finite extremes) before correlating.
            learned_weights = chkpt["history"]["best_model_params"]["edge_weights"].nan_to_num().cpu()
            if learned_weights.shape != stream_length.shape:
                raise ValueError(
                    f"{architecture}/{edge_orientation}/fold {fold}: learned edge weights have shape "
                    f"{tuple(learned_weights.shape)}, adjacency attributes have shape {tuple(stream_length.shape)}"
                )
            corrmats.append(torch.corrcoef(torch.stack([learned_weights, stream_length, elevation_difference, average_slope])))
        corrmat_stack = torch.stack(corrmats)  # (n_folds, 4, 4)
        print("correlation matrix mean:")
        print(corrmat_stack.mean(dim=0))
        print("correlation matrix std:")
        print(corrmat_stack.std(dim=0))
        print()

ResGCN downstream
correlation matrix mean:
tensor([[ 1.0000, -0.2213,  0.1000,  0.1683],
        [-0.2213,  1.0000,  0.3331, -0.1741],
        [ 0.1000,  0.3331,  1.0000,  0.5844],
        [ 0.1683, -0.1741,  0.5844,  1.0000]])
correlation matrix std:
tensor([[6.1559e-08, 9.8565e-02, 2.1019e-02, 3.8168e-02],
        [9.8565e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.1019e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.8168e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00]])

ResGCN upstream
correlation matrix mean:
tensor([[ 1.0000,  0.0422, -0.3080, -0.2930],
        [ 0.0422,  1.0000,  0.3331, -0.1741],
        [-0.3080,  0.3331,  1.0000,  0.5844],
        [-0.2930, -0.1741,  0.5844,  1.0000]])
correlation matrix std:
tensor([[0.0000, 0.0075, 0.0145, 0.0092],
        [0.0075, 0.0000, 0.0000, 0.0000],
        [0.0145, 0.0000, 0.0000, 0.0000],
        [0.0092, 0.0000, 0.0000, 0.0000]])

ResGCN bidirectional
correlation matrix mean:
tensor([[ 1.0000, -0.0023, -0.2353, -0.2400],
  

In [5]:
# Summary statistics of the learned edge weights, reported as "mean ± std"
# across the 6 folds for every edge orientation / architecture combination.
# Columns are ordered by edge orientation, then architecture.
# Negative weights are clamped to 0 before summarising.
# NOTE(review): the correlation cell above uses the unclamped weights;
# presumably the clamp here mirrors how the model applies them — TODO confirm.
descriptors = ["mean", "std", "min", "25%", "median", "75%", "max"]
weight_stats = {}
for edge_orientation in ["downstream", "upstream", "bidirectional"]:
    for architecture in ["ResGCN", "GCNII"]:
        print(architecture, edge_orientation)
        fold_stats = []
        for fold in range(6):
            # map_location="cpu" lets the notebook run without the GPU the
            # checkpoint may have been saved from.
            chkpt = torch.load(f"{CHECKPOINT_PATH}/{architecture}_{edge_orientation}_learned_{fold}.run",
                               map_location="cpu")
            learned_weights = chkpt["history"]["best_model_params"]["edge_weights"].nan_to_num().cpu().clamp(min=0)
            fold_stats.append(torch.stack([learned_weights.mean(),
                                           learned_weights.std(),
                                           learned_weights.min(),
                                           learned_weights.quantile(0.25),
                                           # quantile(0.5) interpolates like the 25%/75% rows;
                                           # torch.median would return the lower middle value.
                                           learned_weights.quantile(0.5),
                                           learned_weights.quantile(0.75),
                                           learned_weights.max()]))
        fold_stats = torch.stack(fold_stats)  # (n_folds, n_descriptors)
        stat_means = fold_stats.mean(dim=0)
        stat_stds = fold_stats.std(dim=0)
        weight_stats[f"{edge_orientation}_{architecture}"] = {
            descriptor: f"{stat_mean:.3f} ± {stat_std:.3f}"
            for descriptor, stat_mean, stat_std in zip(descriptors, stat_means, stat_stds)
        }
# Build the table once; an explicit index keeps the row order fixed.
weight_stats_df = pd.DataFrame(weight_stats, index=descriptors)
weight_stats_df

ResGCN downstream
GCNII downstream
ResGCN upstream
GCNII upstream
ResGCN bidirectional
GCNII bidirectional


Unnamed: 0,downstream_ResGCN,downstream_GCNII,upstream_ResGCN,upstream_GCNII,bidirectional_ResGCN,bidirectional_GCNII
mean,0.989 ± 0.013,0.768 ± 0.002,0.666 ± 0.011,0.793 ± 0.008,0.917 ± 0.006,0.955 ± 0.008
std,0.511 ± 0.212,0.665 ± 0.025,0.537 ± 0.006,0.825 ± 0.022,0.635 ± 0.036,0.630 ± 0.026
min,0.109 ± 0.268,0.000 ± 0.000,0.000 ± 0.000,0.000 ± 0.000,0.000 ± 0.000,0.000 ± 0.000
25%,0.624 ± 0.160,0.279 ± 0.028,0.201 ± 0.022,0.227 ± 0.022,0.451 ± 0.032,0.473 ± 0.021
median,1.042 ± 0.019,0.599 ± 0.021,0.588 ± 0.026,0.570 ± 0.015,0.851 ± 0.024,0.919 ± 0.017
75%,1.365 ± 0.151,1.172 ± 0.031,1.049 ± 0.027,1.134 ± 0.037,1.298 ± 0.036,1.306 ± 0.027
max,3.257 ± 0.983,5.463 ± 0.895,2.217 ± 0.052,6.772 ± 0.489,3.197 ± 0.256,3.515 ± 0.286
