In [1]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        '''
        nn.Linear(n,m) is a module that creates single layer feed forward network with n inputs and m output. 
        Mathematically, this module is designed to calculate the linear equation Ax = b where x is input, b is output, A is weight. 
        This is where the name 'Linear' came from. Essentially calculates weights for the model
        '''
        self.bias = Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

X is the feature vector i.e training data. Each node in the graph has a feature vector and 

In [None]:
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)
        self.linear1 = torch.nn.Linear(100,1)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.linear1(x) # This probably isnt needed just return x
        return x

In [10]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import xarray as xr
import pandas as pd
import glob

data = xr.open_dataset('data\\tas_scenario_245\\tas_mon_mod_ssp245_192_000.nc')
data_df = data.to_dataframe().reset_index()
data_df
# plot = data_df.loc[(data_df['lat'] == -89.375) & (data_df['lon'] == 0), ['tas']]

# data_list = [Data(plot)]
# single = DataLoader(data_list)


Unnamed: 0,lat,lon,time,height,tas
0,-89.375,0.000,1850-01-16 12:00:00,2.0,252.504318
1,-89.375,0.000,1850-02-15 00:00:00,2.0,236.253250
2,-89.375,0.000,1850-03-16 12:00:00,2.0,217.723312
3,-89.375,0.000,1850-04-16 00:00:00,2.0,213.946365
4,-89.375,0.000,1850-05-16 12:00:00,2.0,210.434448
...,...,...,...,...,...
83275771,89.375,358.125,2100-08-16 12:00:00,2.0,274.949463
83275772,89.375,358.125,2100-09-16 00:00:00,2.0,274.228302
83275773,89.375,358.125,2100-10-16 12:00:00,2.0,272.529755
83275774,89.375,358.125,2100-11-16 00:00:00,2.0,271.059570


In [9]:
single

<torch_geometric.loader.dataloader.DataLoader at 0x1e29e027be0>

In [15]:
import torch
from torch_geometric.data import Data

edge_index = []
x = []

for i in range(40):
    for j in range(40):
        if i != j:
            edge_index.append([i, j])

edge_index = torch.tensor(edge_index, dtype=torch.long)

for filename in glob.glob('data\\tas_scenario_245\\tas_mon_mod_ssp245_192_*.nc'):
    print('Processing', filename)
    data = xr.open_dataset(filename)
    tas_vector = data.to_dataframe().reset_index()['tas']
    x.append(tas_vector)

x = torch.tensor(x, dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())

Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_000.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_001.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_002.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_003.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_004.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_005.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_006.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_007.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_008.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_009.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_010.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_011.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_012.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_013.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_014.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192

In [17]:
data.validate(raise_on_error=True)

ValueError: 'edge_index' contains larger indices than the number of nodes (39) in 'Data' (found 39)