In [1]:
# Import necessary libraries
import os
import random
import math
from math import radians, cos, sin, asin, sqrt

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import Parameter

# Set seaborn style for better aesthetics
sns.set(style="whitegrid")

In [2]:
# Ensure reproducibility
RANDOM_SEED = 123
def seed_torch(seed=RANDOM_SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [3]:
# Define hyperparameters
num_nodes = 7
num_features = 5  # ['new_confirmed', 'new_deceased', 'newAdmissions', 'hospitalCases', 'covidOccupiedMVBeds']
num_timesteps_input = 14
num_timesteps_output = 7
k = 8
hidA = 64
hidR = 40
hidP = 1
n_layer = 2
dropout = 0.5
learning_rate = 0.001
num_epochs = 100
batch_size = 32
threshold_distance = 300  # in km

# %%
# Reference dataset for correction
REFERENCE_COORDINATES = {
    "East of England": (52.1766, 0.425889),
    "Midlands": (52.7269, -1.458210),
    "London": (51.4923, -0.308660),
    "South East": (51.4341, -0.969570),
    "South West": (50.8112, -3.633430),
    "North West": (53.8981, -2.657550),
    "North East and Yorkshire": (54.5378, -2.180390),
}

In [4]:
# Define hyperparameters
num_nodes = 7
num_features = 5  # ['new_confirmed', 'new_deceased', 'newAdmissions', 'hospitalCases', 'covidOccupiedMVBeds']
num_timesteps_input = 14
num_timesteps_output = 7
k = 8
hidA = 64
hidR = 40
hidP = 1
n_layer = 2
dropout = 0.5
learning_rate = 0.001
num_epochs = 100
batch_size = 32
threshold_distance = 300  # in km

In [5]:
# Reference dataset for correction
REFERENCE_COORDINATES = {
    "East of England": (52.1766, 0.425889),
    "Midlands": (52.7269, -1.458210),
    "London": (51.4923, -0.308660),
    "South East": (51.4341, -0.969570),
    "South West": (50.8112, -3.633430),
    "North West": (53.8981, -2.657550),
    "North East and Yorkshire": (54.5378, -2.180390),
}

In [6]:
# Load and preprocess data
def load_and_correct_data(data, reference_coordinates):
    # Correct latitude and longitude based on reference coordinates
    for region, coords in reference_coordinates.items():
        data.loc[data['areaName'] == region, ['latitude', 'longitude']] = coords

    # Display unique latitudes and longitudes for verification
    unique_latitudes = data['latitude'].unique()
    unique_longitudes = data['longitude'].unique()
    print("Unique latitudes:", unique_latitudes)
    print("Unique longitudes:", unique_longitudes)

    return data

In [7]:
# NHSRegionDataset class
class NHSRegionDataset(Dataset):
    def __init__(self, data, num_timesteps_input, num_timesteps_output, transform=None):
        self.data = data.copy()
        self.num_timesteps_input = num_timesteps_input
        self.num_timesteps_output = num_timesteps_output
        self.transform = transform

        # Sort data by region and date
        self.data.sort_values(['areaName', 'date'], inplace=True)
        self.regions = self.data['areaName'].unique()
        self.num_nodes = len(self.regions)
        self.region_to_idx = {region: idx for idx, region in enumerate(self.regions)}
        self.data['region_idx'] = self.data['areaName'].map(self.region_to_idx)

        # Features to include
        self.features = ['new_confirmed', 'new_deceased', 'newAdmissions', 'hospitalCases', 'covidOccupiedMVBeds']

        # Pivot data to create a time-series matrix
        # This will create a multi-level column index: (feature, region_idx)
        self.pivot = self.data.pivot(index='date', columns='region_idx', values=self.features)

        # Fill missing values
        self.pivot.ffill(inplace=True)
        self.pivot.fillna(0, inplace=True)

        # Convert to numpy array and reshape
        # The pivot will have shape (num_dates, num_nodes * num_features)
        self.feature_array = self.pivot.values
        self.num_features = len(self.features)
        self.num_nodes = len(self.regions)
        self.num_dates = self.feature_array.shape[0]
        self.feature_array = self.feature_array.reshape(self.num_dates, self.num_nodes, self.num_features)

        # Validate population consistency across each region
        populations = self.data.groupby('areaName')['population'].unique()
        inconsistent_populations = populations[populations.apply(len) > 1]
        if not inconsistent_populations.empty:
            raise ValueError(f"Inconsistent population values found in regions: {inconsistent_populations.index.tolist()}")

    def __len__(self):
        return self.num_dates - self.num_timesteps_input - self.num_timesteps_output + 1

    def __getitem__(self, idx):
        X = self.feature_array[idx:idx + self.num_timesteps_input]  # Input sequence
        Y = self.feature_array[idx + self.num_timesteps_input:idx + self.num_timesteps_input + self.num_timesteps_output, :, 4]  # Target variable: 'covidOccupiedMVBeds'

        if self.transform:
            X = self.transform(X)
            Y = self.transform(Y)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(Y, dtype=torch.float32)


In [8]:
# Define adjacency matrix computation
def compute_geographic_adjacency(regions, latitudes, longitudes, threshold=threshold_distance):
    """
    Computes the geographic adjacency matrix based on the haversine distance between regions.
    
    Args:
        regions (list): List of region names.
        latitudes (list): List of latitudes corresponding to regions.
        longitudes (list): List of longitudes corresponding to regions.
        threshold (float): Distance threshold in kilometers to consider adjacency.

    Returns:
        torch.Tensor: Adjacency matrix of shape (num_nodes, num_nodes).
    """
    def haversine(lat1, lon1, lat2, lon2):
        # Convert decimal degrees to radians
        lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
        # Haversine formula
        dlat = lat2 - lat1 
        dlon = lon2 - lon1 
        a = sin(dlat/2)**2 + cos(lat1) * cos(lat2) * sin(dlon/2)**2
        c = 2 * asin(sqrt(a)) 
        r = 6371  # Radius of earth in kilometers
        return c * r

    num_nodes = len(regions)
    adj_matrix = np.zeros((num_nodes, num_nodes))

    for i in range(num_nodes):
        for j in range(num_nodes):
            if i == j:
                adj_matrix[i][j] = 1  # Self-loop
            elif adj_matrix[i][j] == 0:
                distance = haversine(latitudes[i], longitudes[i], latitudes[j], longitudes[j])
                if distance <= threshold:
                    adj_matrix[i][j] = 1
                    adj_matrix[j][i] = 1  # Symmetry

    return torch.tensor(adj_matrix, dtype=torch.float32)

# %%
# Define Laplace Matrix computation
def getLaplaceMat(batch_size, m, adj):
    """
    Computes the Laplacian matrix for the graph convolution.

    Args:
        batch_size (int): Number of graphs in the batch.
        m (int): Number of nodes in each graph.
        adj (torch.Tensor): Adjacency matrix of shape (batch_size, m, m).

    Returns:
        torch.Tensor: Laplacian matrix of shape (batch_size, m, m).
    """
    i_mat = torch.eye(m).to(adj.device).unsqueeze(0).expand(batch_size, m, m)
    o_mat = torch.ones(m).to(adj.device).unsqueeze(0).expand(batch_size, m, m)
    adj = torch.where(adj > 0, o_mat, adj)

    d_mat_in = torch.sum(adj, dim=1)
    d_mat_out = torch.sum(adj, dim=2)
    d_mat = d_mat_out.unsqueeze(2) + 1e-12
    d_mat = torch.pow(d_mat, -1)
    d_mat = i_mat * d_mat

    laplace_mat = torch.bmm(d_mat, adj)
    return laplace_mat

In [9]:
# Define model components

class GraphConvLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.act = nn.ELU()
        nn.init.xavier_uniform_(self.weight)

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            stdv = 1.0 / math.sqrt(self.bias.size(0))
            self.bias.data.uniform_(-stdv, stdv)
        else:
            self.register_parameter('bias', None)

    def forward(self, feature, adj):
        support = torch.matmul(feature, self.weight)
        output = torch.matmul(adj, support)

        if self.bias is not None:
            return self.act(output + self.bias)
        else:
            return self.act(output)

class GraphLearner(nn.Module):
    def __init__(self, hidden_dim, tanhalpha=1):
        super(GraphLearner, self).__init__()
        self.hid = hidden_dim
        self.linear1 = nn.Linear(self.hid, self.hid)
        self.linear2 = nn.Linear(self.hid, self.hid)
        self.alpha = tanhalpha

    def forward(self, embedding):
        """
        Learns the adjacency matrix based on node embeddings.

        Args:
            embedding (torch.Tensor): Node embeddings of shape (batch_size, num_nodes, hidden_dim).

        Returns:
            torch.Tensor: Learned adjacency matrix of shape (batch_size, num_nodes, num_nodes).
        """
        nodevec1 = self.linear1(embedding)
        nodevec2 = self.linear2(embedding)
        nodevec1 = self.alpha * nodevec1
        nodevec2 = self.alpha * nodevec2
        nodevec1 = torch.tanh(nodevec1)
        nodevec2 = torch.tanh(nodevec2)

        adj = torch.bmm(nodevec1, nodevec2.transpose(1, 2)) - torch.bmm(nodevec2, nodevec1.transpose(1, 2))
        adj = self.alpha * adj
        adj = torch.relu(torch.tanh(adj))
        return adj

class ConvBranch(nn.Module):
    def __init__(self, m, in_channels, out_channels, kernel_size, dilation_factor=2, hidP=1, isPool=True):
        super(ConvBranch, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), dilation=(dilation_factor, 1))
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.isPool = isPool
        if self.isPool and hidP is not None:
            self.pooling = nn.AdaptiveMaxPool2d((hidP, m))
        self.activate = nn.Tanh()

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv(x)
        x = self.batchnorm(x)
        if self.isPool and hasattr(self, 'pooling'):
            x = self.pooling(x)
        x = x.view(batch_size, -1, x.size(-1))
        x = self.activate(x)
        return x

class RegionAwareConv(nn.Module):
    def __init__(self, nfeat, P, m, k, hidP, dilation_factor=2):
        super(RegionAwareConv, self).__init__()
        self.conv_l1 = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=3, dilation_factor=1, hidP=hidP)
        self.conv_l2 = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=5, dilation_factor=1, hidP=hidP)
        self.conv_p1 = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=3, dilation_factor=dilation_factor, hidP=hidP)
        self.conv_p2 = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=5, dilation_factor=dilation_factor, hidP=hidP)
        self.conv_g = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=P, dilation_factor=1, hidP=None, isPool=False)
        self.activate = nn.Tanh()

    def forward(self, x):
        """
        Applies multiple convolution branches to extract local, periodic, and global features.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_features, P, m).

        Returns:
            torch.Tensor: Output tensor after convolution and activation, shape (batch_size, k*4, m).
        """
        x_l1 = self.conv_l1(x)
        x_l2 = self.conv_l2(x)
        x_local = torch.cat([x_l1, x_l2], dim=1)

        x_p1 = self.conv_p1(x)
        x_p2 = self.conv_p2(x)
        x_period = torch.cat([x_p1, x_p2], dim=1)

        x_global = self.conv_g(x)

        x = torch.cat([x_local, x_period, x_global], dim=1)
        if x.size(1) > 0:
            x = x.permute(0, 2, 1)
        x = self.activate(x)
        return x

class EpiGNN(nn.Module):
    """
    Adapted Epidemiological Graph Neural Network (EpiGNN) for Hospitalization Prediction.

    Parameters
    ----------
    num_nodes : int
        Number of nodes in the graph (e.g., 7 NHS regions).
    num_features : int
        Number of features per node per timestep.
    num_timesteps_input : int
        Number of timesteps considered for each input sample.
    num_timesteps_output : int
        Number of output timesteps to predict.
    k : int, optional
        Number of local neighborhoods to consider in the graph learning layer. Default: 8.
    hidA : int, optional
        Dimension of attention in the model. Default: 64.
    hidR : int, optional
        Dimension of hidden layers in the recurrent neural network part. Default: 40.
    hidP : int, optional
        Dimension of positional encoding in the model. Default: 1.
    n_layer : int, optional
        Number of layers in the graph neural network. Default: 2.
    dropout : float, optional
        Dropout rate for regularization during training to prevent overfitting. Default: 0.5.
    device : str, optional
        The device (cpu or gpu) on which the model will be run. Default: 'cpu'.

    Returns
    -------
    torch.Tensor
        A tensor of shape (batch_size, num_timesteps_output, num_nodes), representing the predicted ICU bed usage for each node over future timesteps.
    """
    def __init__(self, 
                num_nodes, 
                num_features, 
                num_timesteps_input,
                num_timesteps_output, 
                k=8, 
                hidA=64, 
                hidR=40, 
                hidP=1, 
                n_layer=2, 
                dropout=0.5, 
                device='cpu'):
        super(EpiGNN, self).__init__()
        self.device = device
        self.nfeat = num_features
        self.m = num_nodes
        self.w = num_timesteps_input
        self.droprate = dropout
        self.hidR = hidR
        self.hidA = hidA
        self.hidP = hidP
        self.k = k
        self.n = n_layer
        self.dropout_layer = nn.Dropout(self.droprate)

        # Feature embedding
        self.backbone = RegionAwareConv(nfeat=num_features, P=self.w, m=self.m, k=self.k, hidP=self.hidP)

        # Global transmission risk encoding
        self.WQ = nn.Linear(self.hidR, self.hidA)
        self.WK = nn.Linear(self.hidR, self.hidA)
        self.leakyrelu = nn.LeakyReLU(inplace=True)
        self.t_enc = nn.Linear(1, self.hidR)

        # Local transmission risk encoding
        self.s_enc = nn.Linear(1, self.hidR)

        # External resources (if any, optional)
        self.external_parameter = nn.Parameter(torch.FloatTensor(self.m, self.m), requires_grad=True)
        nn.init.xavier_uniform_(self.external_parameter)

        # Graph Generator and GCN
        self.d_gate = nn.Parameter(torch.FloatTensor(self.m, self.m), requires_grad=True)
        nn.init.xavier_uniform_(self.d_gate)
        self.graphGen = GraphLearner(self.hidR)
        self.GNNBlocks = nn.ModuleList([GraphConvLayer(in_features=self.hidR, out_features=self.hidR) for _ in range(self.n)])

        # Prediction layer
        self.output = nn.Linear(self.hidR * 2, num_timesteps_output)

        self.init_weights()

    def init_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)  # Best practice
            else:
                stdv = 1.0 / math.sqrt(p.size(0))
                p.data.uniform_(-stdv, stdv)

    def forward(self, X, adj, states=None, dynamic_adj=None, index=None):
        """
        Forward pass of the adapted EpiGNN model for Hospitalization Prediction.

        Parameters
        ----------
        X : torch.Tensor
            Input features tensor with shape (batch_size, num_timesteps_input, num_nodes, num_features).
        adj : torch.Tensor
            Static adjacency matrix with shape (num_nodes, num_nodes) or (batch_size, num_nodes, num_nodes).
        states : torch.Tensor, optional
            Current state variables tensor (if applicable). Default: None.
        dynamic_adj : torch.Tensor, optional
            Dynamic adjacency matrix (if applicable). Default: None.
        index : torch.Tensor, optional
            Indices for external resources (if applicable). Default: None.

        Returns
        -------
        torch.Tensor
            The output tensor of shape (batch_size, num_timesteps_output, num_nodes),
            representing the predicted ICU bed usage for each node over future timesteps.
        """
        adj = adj.bool().float()
        batch_size = X.size(0)  # batch_size, T, N, F

        # Ensure adj has batch dimension
        if adj.dim() == 2:
            adj = adj.unsqueeze(0).repeat(batch_size, 1, 1)  # Shape: (batch_size, m, m)

        # Step 1: Use multi-scale convolution to extract feature embedding (RegionAwareConv)
        # Reshape X to (batch_size, num_features, P, m) where P = num_timesteps_input, m = num_nodes
        X_reshaped = X.permute(0, 3, 1, 2)  # (batch_size, F, T, N)
        temp_emb = self.backbone(X_reshaped)  # Shape: (batch_size, hidR, m)

        # Step 2: Generate global transmission risk encoding
        query = self.WQ(temp_emb)  # Shape: (batch_size, m, hidA)
        query = self.dropout_layer(query)
        key = self.WK(temp_emb)    # Shape: (batch_size, m, hidA)
        key = self.dropout_layer(key)
        attn = torch.bmm(query, key.transpose(1, 2))  # Shape: (batch_size, m, m)
        attn = F.normalize(attn, dim=-1, p=2, eps=1e-12)  # Normalize
        attn = torch.sum(attn, dim=-1, keepdim=True)      # Shape: (batch_size, m, 1)
        
        # Optional: Verify the shape
        print(f"attn shape before t_enc: {attn.shape}")  # Should output: torch.Size([32, 7, 1])
        
        t_enc = self.t_enc(attn)                         # Shape: (batch_size, m, hidR)
        t_enc = self.dropout_layer(t_enc)

        # Step 3: Generate local transmission risk encoding
        d = torch.sum(adj, dim=1).unsqueeze(1)            # Shape: (batch_size, 1, m)
        s_enc = self.s_enc(d)                             # Shape: (batch_size, m, hidR)
        s_enc = self.dropout_layer(s_enc)

        # Step 4: Three embedding fusion
        feat_emb = temp_emb + t_enc + s_enc  

        # Step 5: Region-Aware Graph Learner
        # Load external resource if available (optional)
        if self.external_parameter is not None and index is not None:
            extra_adj_list = []
            zeros_mt = torch.zeros((self.m, self.m)).to(adj.device)
            for i in range(batch_size):
                offset = 20
                if i - offset >= 0:
                    idx = i - offset
                    extra_adj_list.append(self.external_parameter[index[i], :, :].unsqueeze(0))
                else:
                    extra_adj_list.append(zeros_mt.unsqueeze(0))
            extra_info = torch.cat(extra_adj_list, dim=0)  # Shape: (batch_size, m, m)
            external_info = torch.mul(self.external_parameter, extra_info)
            external_info = F.relu(external_info)
        else:
            external_info = 0

        # Apply Graph Learner to generate a graph
        # Adjusted to handle batched adjacency matrices
        d_mat = torch.bmm(torch.sum(adj, dim=1).unsqueeze(2), torch.sum(adj, dim=1).unsqueeze(1))  # Shape: (batch_size, 1, 1)
        d_mat = torch.mul(self.d_gate, d_mat)                                                # Shape: (batch_size, m, m)
        d_mat = torch.sigmoid(d_mat)                                                         # Shape: (batch_size, m, m)
        spatial_adj = torch.mul(d_mat, adj)                                                  # Shape: (batch_size, m, m)
        learned_adj = self.graphGen(feat_emb)                                                # Shape: (batch_size, m, m)

        # If additional information, fuse
        if external_info != 0:
            adj = learned_adj + spatial_adj + external_info
        else:
            adj = learned_adj + spatial_adj

        # Get Laplace adjacency matrix
        laplace_adj = getLaplaceMat(batch_size, self.m, adj)

        # Step 6: Graph Convolution Network
        node_state = feat_emb                                       # Shape: (batch_size, m, hidR)
        node_state_list = []
        for layer in self.GNNBlocks:
            node_state = layer(node_state, laplace_adj)           # Shape: (batch_size, m, hidR)
            node_state = self.dropout_layer(node_state)
            node_state_list.append(node_state)
        
        # Concatenate node states from all GNN layers
        node_state = torch.cat(node_state_list, dim=-1)           # Shape: (batch_size, m, hidR * n_layer)

        # Concatenate initial features and GNN outputs
        node_state = torch.cat([node_state, feat_emb], dim=-1)    # Shape: (batch_size, m, hidR * n_layer + hidR)

        # Step 7: Prediction
        res = self.output(node_state)                              # Shape: (batch_size, m, num_timesteps_output)
        res = res.transpose(1, 2)                                   # Shape: (batch_size, num_timesteps_output, m)
        
        print(f"Shape of X: {X.shape}")
        print(f"Shape of X_reshaped: {X_reshaped.shape}")
        print(f"Shape of temp_emb: {temp_emb.shape}")
        print(f"Shape of attn: {attn.shape}")


        return res  # Predicted covidOccupiedMVBeds

In [10]:
# Load data
csv_path = '../data/merged_nhs_covid_data.csv'  # Adjust this path if necessary
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"The specified CSV file does not exist: {csv_path}")

data = pd.read_csv(csv_path, parse_dates=['date'])
data = load_and_correct_data(data, REFERENCE_COORDINATES)

Unique latitudes: [52.1766 51.4923 52.7269 54.5378 53.8981 51.4341 50.8112]
Unique longitudes: [ 0.425889 -0.30866  -1.45821  -2.18039  -2.65755  -0.96957  -3.63343 ]


In [11]:
dataset = NHSRegionDataset(data, num_timesteps_input=num_timesteps_input, num_timesteps_output=num_timesteps_output)
print(f"Total samples in dataset: {len(dataset)}")


Total samples in dataset: 875


In [12]:
# Compute adjacency matrix
regions = dataset.regions.tolist()
latitudes = [data[data['areaName'] == region]['latitude'].iloc[0] for region in regions]
longitudes = [data[data['areaName'] == region]['longitude'].iloc[0] for region in regions]
adj = compute_geographic_adjacency(regions, latitudes, longitudes).to(device)
print("Adjacency matrix:")
print(adj)

Adjacency matrix:
tensor([[1., 1., 1., 0., 1., 1., 0.],
        [1., 1., 1., 0., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 1., 1., 0.],
        [1., 1., 1., 0., 1., 1., 1.],
        [0., 1., 1., 0., 0., 1., 1.]], device='cuda:0')


In [13]:
# Initialize the EpiGNN model
model = EpiGNN(
    num_nodes=num_nodes,
    num_features=num_features,
    num_timesteps_input=num_timesteps_input,
    num_timesteps_output=num_timesteps_output,
    k=k,
    hidA=hidA,
    hidR=hidR,
    hidP=hidP,
    n_layer=n_layer,
    dropout=dropout,
    device=device
).to(device)

In [14]:
# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

In [15]:
# Split dataset into training, validation, and test sets
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(RANDOM_SEED)
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")


Training samples: 612
Validation samples: 131
Test samples: 132


In [16]:
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)


In [17]:
# Training loop with validation and checkpointing
best_val_loss = float('inf')
early_stopping_patience = 10
patience_counter = 0
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    # Training Phase
    model.train()
    epoch_train_loss = 0
    for batch_X, batch_Y in train_loader:
        batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
        optimizer.zero_grad()
        
        # Expand adjacency matrix to include batch dimension
        batch_size_current = batch_X.size(0)
        batch_adj = adj.unsqueeze(0).repeat(batch_size_current, 1, 1)  # Shape: (batch_size, m, m)
        
        pred = model(batch_X, batch_adj)
        loss = criterion(pred, batch_Y)
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()

    avg_train_loss = epoch_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation Phase
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for batch_X, batch_Y in val_loader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            
            # Expand adjacency matrix to include batch dimension
            batch_size_current = batch_X.size(0)
            batch_adj = adj.unsqueeze(0).repeat(batch_size_current, 1, 1)  # Shape: (batch_size, m, m)
            
            pred = model(batch_X, batch_adj)
            loss = criterion(pred, batch_Y)
            epoch_val_loss += loss.item()

    avg_val_loss = epoch_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # Checkpointing
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_epignn_adapted_model.pth')
        print("Model checkpoint saved.")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered.")
            break

attn shape before t_enc: torch.Size([32, 7, 1])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x7 and 1x40)

In [1]:
import os
import random
import math
from math import radians, cos, sin, asin, sqrt

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import Parameter
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, r2_score

# Ensure reproducibility across runs
RANDOM_SEED = 123
def seed_torch(seed=RANDOM_SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch()

# Select device for computations
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

########################################
# Hyperparameters
########################################
num_nodes = 7
num_features = 5  # features: [new_confirmed, new_deceased, newAdmissions, hospitalCases, covidOccupiedMVBeds]
num_timesteps_input = 14
num_timesteps_output = 7
k = 8
hidA = 64
hidR = 40
hidP = 1
n_layer = 2
dropout = 0.5
learning_rate = 0.001
num_epochs = 100
batch_size = 32
threshold_distance = 300  # km threshold for considering adjacency

########################################
# Reference coordinates for correction
########################################
REFERENCE_COORDINATES = {
    "East of England": (52.1766, 0.425889),
    "Midlands": (52.7269, -1.458210),
    "London": (51.4923, -0.308660),
    "South East": (51.4341, -0.969570),
    "South West": (50.8112, -3.633430),
    "North West": (53.8981, -2.657550),
    "North East and Yorkshire": (54.5378, -2.180390),
}

########################################
# Data Loading and Preprocessing
########################################
def load_and_correct_data(data, reference_coordinates):
    # Correct geographic coordinates to ensure consistency
    for region, coords in reference_coordinates.items():
        data.loc[data['areaName'] == region, ['latitude', 'longitude']] = coords

    # Diagnostic printouts for verification
    print("Unique latitudes:", data['latitude'].unique())
    print("Unique longitudes:", data['longitude'].unique())

    return data

class NHSRegionDataset(Dataset):
    def __init__(self, data, num_timesteps_input, num_timesteps_output, transform=None):
        self.data = data.copy()
        self.num_timesteps_input = num_timesteps_input
        self.num_timesteps_output = num_timesteps_output
        self.transform = transform

        # Sort data chronologically by region
        self.data.sort_values(['areaName', 'date'], inplace=True)
        self.regions = self.data['areaName'].unique()
        self.num_nodes = len(self.regions)
        self.region_to_idx = {region: idx for idx, region in enumerate(self.regions)}
        self.data['region_idx'] = self.data['areaName'].map(self.region_to_idx)

        self.features = ['new_confirmed', 'new_deceased', 'newAdmissions', 'hospitalCases', 'covidOccupiedMVBeds']

        # Pivot to create time-series: index=date, columns=(feature, region_idx)
        self.pivot = self.data.pivot(index='date', columns='region_idx', values=self.features)
        self.pivot.ffill(inplace=True)
        self.pivot.fillna(0, inplace=True)

        # Reshape into (num_dates, num_nodes, num_features)
        self.feature_array = self.pivot.values
        self.num_features = len(self.features)
        self.num_dates = self.feature_array.shape[0]
        self.feature_array = self.feature_array.reshape(self.num_dates, self.num_nodes, self.num_features)

        # Validate population consistency
        populations = self.data.groupby('areaName')['population'].unique()
        inconsistent_pop = populations[populations.apply(len) > 1]
        if not inconsistent_pop.empty:
            raise ValueError(f"Inconsistent population values in regions: {inconsistent_pop.index.tolist()}")

    def __len__(self):
        return self.num_dates - self.num_timesteps_input - self.num_timesteps_output + 1

    def __getitem__(self, idx):
        X = self.feature_array[idx:idx + self.num_timesteps_input]  # (num_timesteps_input, num_nodes, num_features)
        Y = self.feature_array[idx + self.num_timesteps_input:idx + self.num_timesteps_input + self.num_timesteps_output, :, 4]

        if self.transform:
            X = self.transform(X)
            Y = self.transform(Y)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(Y, dtype=torch.float32)

def compute_geographic_adjacency(regions, latitudes, longitudes, threshold=threshold_distance):
    # Compute adjacency based on geographical proximity using haversine distance
    def haversine(lat1, lon1, lat2, lon2):
        lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        a = sin(dlat/2)**2 + cos(lat1)*cos(lat2)*sin(dlon/2)**2
        c = 2 * asin(sqrt(a))
        r = 6371  # Earth radius in km
        return c * r

    num_nodes = len(regions)
    adj_matrix = np.zeros((num_nodes, num_nodes))
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i == j:
                adj_matrix[i][j] = 1
            elif adj_matrix[i][j] == 0:
                distance = haversine(latitudes[i], longitudes[i], latitudes[j], longitudes[j])
                if distance <= threshold:
                    adj_matrix[i][j] = 1
                    adj_matrix[j][i] = 1
    return torch.tensor(adj_matrix, dtype=torch.float32)

def getLaplaceMat(batch_size, m, adj):
    # Compute Laplacian for graph convolution
    i_mat = torch.eye(m).to(adj.device).unsqueeze(0).expand(batch_size, m, m)
    o_mat = torch.ones(m).to(adj.device).unsqueeze(0).expand(batch_size, m, m)
    adj = torch.where(adj > 0, o_mat, adj)

    d_mat_out = torch.sum(adj, dim=2)
    d_mat = d_mat_out.unsqueeze(2) + 1e-12
    d_mat = torch.pow(d_mat, -1)
    d_mat = i_mat * d_mat

    laplace_mat = torch.bmm(d_mat, adj)
    return laplace_mat

########################################
# Model Definition
########################################
class GraphConvLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvLayer, self).__init__()
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.act = nn.ELU()
        nn.init.xavier_uniform_(self.weight)
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            stdv = 1.0 / math.sqrt(out_features)
            self.bias.data.uniform_(-stdv, stdv)
        else:
            self.register_parameter('bias', None)

    def forward(self, feature, adj):
        # feature: (batch_size, m, in_features)
        support = torch.matmul(feature, self.weight)      # (batch_size, m, out_features)
        output = torch.matmul(adj, support)               # (batch_size, m, out_features)
        if self.bias is not None:
            return self.act(output + self.bias)
        else:
            return self.act(output)

class GraphLearner(nn.Module):
    def __init__(self, hidden_dim, tanhalpha=1):
        super(GraphLearner, self).__init__()
        self.hid = hidden_dim
        self.linear1 = nn.Linear(self.hid, self.hid)
        self.linear2 = nn.Linear(self.hid, self.hid)
        self.alpha = tanhalpha

    def forward(self, embedding):
        # embedding: (batch_size, m, hidR)
        nodevec1 = torch.tanh(self.alpha * self.linear1(embedding))
        nodevec2 = torch.tanh(self.alpha * self.linear2(embedding))

        # Learn adjacency structure
        adj = torch.bmm(nodevec1, nodevec2.transpose(1, 2)) - torch.bmm(nodevec2, nodevec1.transpose(1, 2))
        adj = self.alpha * adj
        adj = torch.relu(torch.tanh(adj))
        return adj

class ConvBranch(nn.Module):
    def __init__(self, m, in_channels, out_channels, kernel_size, dilation_factor=2, hidP=1, isPool=True):
        super(ConvBranch, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), dilation=(dilation_factor, 1))
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.isPool = isPool
        if self.isPool and hidP is not None:
            self.pooling = nn.AdaptiveMaxPool2d((hidP, m))
        self.activate = nn.Tanh()

    def forward(self, x):
        # x: (batch_size, in_channels, T, m)
        x = self.conv(x)
        x = self.batchnorm(x)
        if self.isPool and hasattr(self, 'pooling'):
            x = self.pooling(x)
        batch_size = x.size(0)
        x = x.view(batch_size, -1, x.size(-1))  # (batch_size, out_channels * hidP?, m)
        x = self.activate(x)
        return x

class RegionAwareConv(nn.Module):
    def __init__(self, nfeat, P, m, k, hidP, dilation_factor=2):
        super(RegionAwareConv, self).__init__()
        self.conv_l1 = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=3, dilation_factor=1, hidP=hidP)
        self.conv_l2 = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=5, dilation_factor=1, hidP=hidP)
        self.conv_p1 = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=3, dilation_factor=dilation_factor, hidP=hidP)
        self.conv_p2 = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=5, dilation_factor=dilation_factor, hidP=hidP)
        self.conv_g = ConvBranch(m=m, in_channels=nfeat, out_channels=k, kernel_size=P, dilation_factor=1, hidP=None, isPool=False)
        self.activate = nn.Tanh()

    def forward(self, x):
        # x: (batch_size, num_features, P, m)
        # Extract local features
        x_l1 = self.conv_l1(x)
        x_l2 = self.conv_l2(x)
        x_local = torch.cat([x_l1, x_l2], dim=1)

        # Extract periodic features
        x_p1 = self.conv_p1(x)
        x_p2 = self.conv_p2(x)
        x_period = torch.cat([x_p1, x_p2], dim=1)

        # Extract global features
        x_global = self.conv_g(x)

        # Concatenate all features
        x = torch.cat([x_local, x_period, x_global], dim=1)
        x = self.activate(x)
        return x

class EpiGNN(nn.Module):
    """
    EpiGNN: A GNN-based model for epidemiological forecasting of hospital bed usage.
    """
    def __init__(self, 
                 num_nodes, 
                 num_features, 
                 num_timesteps_input,
                 num_timesteps_output, 
                 k=8, 
                 hidA=64, 
                 hidR=40, 
                 hidP=1, 
                 n_layer=2, 
                 dropout=0.5, 
                 device='cpu'):
        super(EpiGNN, self).__init__()
        self.device = device
        self.m = num_nodes
        self.w = num_timesteps_input
        self.hidR = hidR
        self.hidA = hidA
        self.hidP = hidP
        self.k = k
        self.n = n_layer
        self.dropout_layer = nn.Dropout(dropout)

        # Backbone for feature embedding
        self.backbone = RegionAwareConv(nfeat=num_features, P=self.w, m=self.m, k=self.k, hidP=self.hidP)

        # Global transmission risk encoding
        self.WQ = nn.Linear(self.hidR, self.hidA)
        self.WK = nn.Linear(self.hidR, self.hidA)
        self.t_enc = nn.Linear(1, self.hidR)

        # Local transmission risk encoding
        # s_enc expects (batch_size, m, 1) as input -> produce (batch_size, m, hidR)
        self.s_enc = nn.Linear(1, self.hidR)

        # External parameters (optional)
        self.external_parameter = nn.Parameter(torch.FloatTensor(self.m, self.m), requires_grad=True)
        nn.init.xavier_uniform_(self.external_parameter)

        # Graph Generation and GCN
        self.d_gate = nn.Parameter(torch.FloatTensor(self.m, self.m), requires_grad=True)
        nn.init.xavier_uniform_(self.d_gate)
        self.graphGen = GraphLearner(self.hidR)
        self.GNNBlocks = nn.ModuleList([GraphConvLayer(in_features=self.hidR, out_features=self.hidR) for _ in range(self.n)])

        # Prediction layer
        self.output = nn.Linear(self.hidR * 2, num_timesteps_output)
        self.init_weights()

    def init_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                stdv = 1.0 / math.sqrt(p.size(0))
                p.data.uniform_(-stdv, stdv)

    def forward(self, X, adj, states=None, dynamic_adj=None, index=None):
        # X: (batch_size, T, m, F)
        # Permute X to (batch_size, F, T, m) for backbone
        adj = adj.bool().float()
        batch_size = X.size(0)

        if adj.dim() == 2:
            adj = adj.unsqueeze(0).expand(batch_size, self.m, self.m)

        X_reshaped = X.permute(0, 3, 1, 2)  # (batch_size, F, T, m)
        temp_emb = self.backbone(X_reshaped)  # (batch_size, k*..., m) after merging, final shape: (batch_size, hidR, m)

        # Ensure temp_emb is (batch_size, m, hidR)
        temp_emb = temp_emb.permute(0, 2, 1)  # (batch_size, m, hidR)

        # Global transmission encoding
        query = self.dropout_layer(self.WQ(temp_emb))  # (batch_size, m, hidA)
        key = self.dropout_layer(self.WK(temp_emb))    # (batch_size, m, hidA)

        attn = torch.bmm(query, key.transpose(1, 2))   # (batch_size, m, m)
        attn = F.normalize(attn, dim=-1, p=2, eps=1e-12)
        attn = torch.sum(attn, dim=-1, keepdim=True)   # (batch_size, m, 1)
        t_enc = self.dropout_layer(self.t_enc(attn))   # (batch_size, m, hidR)

        # Local transmission risk encoding
        # Ensure d is (batch_size, m, 1)
        d = torch.sum(adj, dim=1).unsqueeze(2)  # (batch_size, m, 1)
        s_enc = self.dropout_layer(self.s_enc(d))  # (batch_size, m, hidR)

        # Fusion of embeddings
        feat_emb = temp_emb + t_enc + s_enc  # (batch_size, m, hidR)

        # Optional external resource integration
        if self.external_parameter is not None and index is not None:
            batch_ext = []
            zeros_mt = torch.zeros((self.m, self.m)).to(adj.device)
            for i in range(batch_size):
                offset = 20
                if i - offset >= 0:
                    idx = i - offset
                    batch_ext.append(self.external_parameter[index[i], :, :].unsqueeze(0))
                else:
                    batch_ext.append(zeros_mt.unsqueeze(0))
            extra_info = torch.cat(batch_ext, dim=0)
            external_info = F.relu(torch.mul(self.external_parameter, extra_info))
        else:
            external_info = 0

        # Graph learning
        d_mat = torch.bmm(torch.sum(adj, dim=1).unsqueeze(2), torch.sum(adj, dim=1).unsqueeze(1)) # (batch_size, m, m)
        d_mat = torch.sigmoid(torch.mul(self.d_gate, d_mat))
        spatial_adj = torch.mul(d_mat, adj)
        learned_adj = self.graphGen(feat_emb)

        if external_info != 0:
            adj = learned_adj + spatial_adj + external_info
        else:
            adj = learned_adj + spatial_adj

        laplace_adj = getLaplaceMat(batch_size, self.m, adj)

        # GNN layers
        node_state = feat_emb
        node_state_list = []
        for layer in self.GNNBlocks:
            node_state = self.dropout_layer(layer(node_state, laplace_adj))
            node_state_list.append(node_state)

        node_state = torch.cat(node_state_list, dim=-1) # (batch_size, m, hidR * n_layer)
        node_state = torch.cat([node_state, feat_emb], dim=-1) # (batch_size, m, hidR*(n_layer+1))

        # Final prediction
        res = self.output(node_state) # (batch_size, m, num_timesteps_output)
        res = res.transpose(1, 2)    # (batch_size, num_timesteps_output, m)
        return res

########################################
# Data Loading and Training
########################################
csv_path = '../data/merged_nhs_covid_data.csv'
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"The specified CSV file does not exist: {csv_path}")

data = pd.read_csv(csv_path, parse_dates=['date'])
data = load_and_correct_data(data, REFERENCE_COORDINATES)

dataset = NHSRegionDataset(data, num_timesteps_input=num_timesteps_input, num_timesteps_output=num_timesteps_output)
print(f"Total samples in dataset: {len(dataset)}")

regions = dataset.regions.tolist()
latitudes = [data[data['areaName'] == region]['latitude'].iloc[0] for region in regions]
longitudes = [data[data['areaName'] == region]['longitude'].iloc[0] for region in regions]

adj = compute_geographic_adjacency(regions, latitudes, longitudes).to(device)
print("Adjacency matrix:")
print(adj)

model = EpiGNN(
    num_nodes=num_nodes,
    num_features=num_features,
    num_timesteps_input=num_timesteps_input,
    num_timesteps_output=num_timesteps_output,
    k=k,
    hidA=hidA,
    hidR=hidR,
    hidP=hidP,
    n_layer=n_layer,
    dropout=dropout,
    device=device
).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(RANDOM_SEED)
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

best_val_loss = float('inf')
early_stopping_patience = 10
patience_counter = 0
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    model.train()
    epoch_train_loss = 0
    for batch_X, batch_Y in train_loader:
        batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
        optimizer.zero_grad()

        batch_size_current = batch_X.size(0)
        batch_adj = adj.unsqueeze(0).repeat(batch_size_current, 1, 1)

        pred = model(batch_X, batch_adj)
        loss = criterion(pred, batch_Y)
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()

    avg_train_loss = epoch_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for batch_X, batch_Y in val_loader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            batch_size_current = batch_X.size(0)
            batch_adj = adj.unsqueeze(0).repeat(batch_size_current, 1, 1)

            pred = model(batch_X, batch_adj)
            loss = criterion(pred, batch_Y)
            epoch_val_loss += loss.item()

    avg_val_loss = epoch_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_epignn_adapted_model.pth')
        print("Model checkpoint saved.")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered.")
            break

# Plot training and validation loss
plt.figure(figsize=(10,6))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training and Validation Loss Curves')
plt.legend()
plt.show()

# Load the best model for evaluation
model.load_state_dict(torch.load('best_epignn_adapted_model.pth'))
model.eval()

test_loss = 0
all_preds = []
all_actuals = []

with torch.no_grad():
    for batch_X, batch_Y in test_loader:
        batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
        batch_size_current = batch_X.size(0)
        batch_adj = adj.unsqueeze(0).repeat(batch_size_current, 1, 1)

        pred = model(batch_X, batch_adj)
        loss = criterion(pred, batch_Y)
        test_loss += loss.item()
        all_preds.append(pred.cpu())
        all_actuals.append(batch_Y.cpu())

avg_test_loss = test_loss / len(test_loader)
print(f"Test Loss: {avg_test_loss:.4f}")

# Concatenate predictions and actual values
all_preds = torch.cat(all_preds, dim=0)
all_actuals = torch.cat(all_actuals, dim=0)

# Visualize predictions vs. actuals for a few samples
num_plots = 3
for i in range(min(num_plots, all_preds.size(0))):
    sample_pred = all_preds[i].numpy()
    sample_actual = all_actuals[i].numpy()

    plt.figure(figsize=(12,8))
    for node_idx, region in enumerate(regions):
        plt.plot(range(num_timesteps_output), sample_actual[:, node_idx], label=f'Actual - {region}')
        plt.plot(range(num_timesteps_output), sample_pred[:, node_idx], '--', label=f'Predicted - {region}')

    plt.xlabel('Future Timestep')
    plt.ylabel('COVID Occupied MV Beds')
    plt.title(f'Sample {i+1}: Predictions vs Actual')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Save final model
torch.save(model.state_dict(), 'epignn_adapted_model_final.pth')
print("Final model saved as 'epignn_adapted_model_final.pth'.")

# Compute additional metrics (MAE, R²)
preds_flat = all_preds.view(-1, num_nodes).numpy()
actuals_flat = all_actuals.view(-1, num_nodes).numpy()

mae_per_node = mean_absolute_error(actuals_flat, preds_flat, multioutput='raw_values')
r2_per_node = r2_score(actuals_flat, preds_flat, multioutput='raw_values')

for idx, region in enumerate(regions):
    print(f"Region: {region}, MAE: {mae_per_node[idx]:.4f}, R2 Score: {r2_per_node[idx]:.4f}")

# Visualize learned adjacency matrix for a test sample
example_X, _ = next(iter(test_loader))
example_X = example_X.to(device)
with torch.no_grad():
    batch_size_current = example_X.size(0)
    batch_adj = adj.unsqueeze(0).repeat(batch_size_current, 1, 1)
    example_pred = model(example_X, batch_adj)

# Compute learned adjacency from the backbone embeddings
learned_adj = model.graphGen(model.backbone(example_X.permute(0,3,1,2))).cpu().numpy()[0]

plt.figure(figsize=(8,6))
sns.heatmap(learned_adj, annot=True, fmt=".2f", cmap='viridis', xticklabels=regions, yticklabels=regions)
plt.title('Learned Adjacency Matrix (Test Sample)')
plt.xlabel('Regions')
plt.ylabel('Regions')
plt.show()


Using device: cuda
Unique latitudes: [52.1766 51.4923 52.7269 54.5378 53.8981 51.4341 50.8112]
Unique longitudes: [ 0.425889 -0.30866  -1.45821  -2.18039  -2.65755  -0.96957  -3.63343 ]
Total samples in dataset: 875
Adjacency matrix:
tensor([[1., 1., 1., 0., 1., 1., 0.],
        [1., 1., 1., 0., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 1., 1., 0.],
        [1., 1., 1., 0., 1., 1., 1.],
        [0., 1., 1., 0., 0., 1., 1.]], device='cuda:0')
Training samples: 612
Validation samples: 131
Test samples: 132


RuntimeError: mat1 and mat2 shapes cannot be multiplied (224x120 and 80x7)