In [81]:
import torch
import torch_geometric
import torch_geometric.transforms as T
from glob import glob
import xarray as xr
import pandas as pd

In [82]:
torch.cuda.is_available()

False

In [83]:
class ssp_data():
    def __init__(self, n=39) -> None:
        self.n = n
        self.init_edge_list(n)
        self.y_file = 'data\\tas_scenario_245\\tas_mon_mod_ssp245_192_000.nc'
        self.x_file_list = [item for item in glob('data\\tas_scenario_245\\tas_mon_mod_ssp245_192_*.nc') if item not in [self.y_file]][0 : self.n]
        self.create_df()
        self.x_tensor = self.create_tensors(self.x).T
        self.y_tensor = self.create_tensors(self.y)
        self.data = torch_geometric.data.Data(x=self.x_tensor, edge_index=self.edge_index.t().contiguous(), y=self.y_tensor)
        self.split_data()
        # self.mini_graphs()

    def init_edge_list(self, n):
        self.edge_index = []
        for i in range(n):
            for j in range(n):
                if i != j:
                    self.edge_index.append([i, j])
        self.edge_index = torch.tensor(self.edge_index, dtype=torch.long)

    def create_df(self):
        self.x = pd.DataFrame()
        i = 1
        for filename in self.x_file_list:
            print('Processing', filename)
            self.x[f'model_{i}'] = self.create_vector(filename)
            i += 1
        
        self.y = self.create_vector(self.y_file)

    def create_vector(self, filename):
        data = xr.open_dataset(filename)
        tas_vector = data.to_dataframe().reset_index()['tas'][0:5000]
        return tas_vector
    
    def get_device(self):
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
        else:
            device = torch.device('cpu') # don't have GPU 
        return device

    def create_tensors(self, df):
        device = self.get_device()
        return torch.from_numpy(df.values).float().to(device)
    
    def mini_graphs(self):
        df = self.x
        df['x_tensor'] = df.apply(lambda row: torch.tensor(row.values.flatten()), axis=1)
        df['y'] = self.y
        df['y_tensor'] = df['y'].apply(lambda y: torch.tensor(y))
        df['data_obj'] = df.apply(lambda row: torch_geometric.data.Data(x=df['x_tensor'], edge_index=self.edge_index.t().contiguous(), y=df['y_tensor']), axis=1)
        self.batch_graphs = df['data_obj']
    
    def split_data(self):
        transform = T.Compose([T.RandomNodeSplit(num_test=1000, num_val=1000)])
        self.data = transform(self.data) 


In [84]:
ssp_data = ssp_data(n=39)

Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_001.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_002.nc


In [None]:
ssp_data.data.validate()

True

In [None]:
ssp_data.x_tensor.T

tensor([[243.8454, 241.6632, 247.5146],
        [235.3529, 237.1550, 234.1314],
        [223.9142, 221.3933, 226.5480],
        ...,
        [217.1660, 219.4037, 223.5101],
        [217.9901, 214.1802, 225.3063],
        [220.2681, 217.7450, 218.0767]])

In [None]:
ssp_data.data

Data(x=[3, 5000], edge_index=[2, 6], y=[5000], train_mask=[3], val_mask=[3], test_mask=[3])

In [None]:
ssp_data.data.num_node_features

39

In [None]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # self.edge_weight = torch.nn.Parameter(torch.ones(ssp_data.data.num_edges))
        self.conv1 = GCNConv(ssp_data.data.num_node_features, 16)
        self.conv2 = GCNConv(16, 5000)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)#, torch.minimum(self.edge_weight.abs(),torch.ones(data.num_edges)))
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        # print(x, edge_index, torch.minimum(self.edge_weight.abs(),torch.ones(data.num_edges)))
        x = self.conv2(x, edge_index)#, torch.minimum(self.edge_weight.abs(),torch.ones(data.num_edges)))

        return x


In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = ssp_data.data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    out = model(data)
    loss = F.mse_loss(out, data.y)
    print(epoch, loss)
    loss.backward()
    optimizer.step()

  loss = F.mse_loss(out, data.y)


0 tensor(51451.1133, grad_fn=<MseLossBackward0>)
1 tensor(51244.1250, grad_fn=<MseLossBackward0>)
2 tensor(51248.2383, grad_fn=<MseLossBackward0>)
3 tensor(51229.3945, grad_fn=<MseLossBackward0>)
4 tensor(10014.2461, grad_fn=<MseLossBackward0>)
5 tensor(165138.4531, grad_fn=<MseLossBackward0>)
6 tensor(33700.7617, grad_fn=<MseLossBackward0>)
7 tensor(37168.6914, grad_fn=<MseLossBackward0>)
8 tensor(51220.0078, grad_fn=<MseLossBackward0>)
9 tensor(51217.5195, grad_fn=<MseLossBackward0>)
10 tensor(51214.7383, grad_fn=<MseLossBackward0>)
11 tensor(51211.7148, grad_fn=<MseLossBackward0>)
12 tensor(51208.5078, grad_fn=<MseLossBackward0>)
13 tensor(51205.1367, grad_fn=<MseLossBackward0>)
14 tensor(51201.6172, grad_fn=<MseLossBackward0>)
15 tensor(51197.9922, grad_fn=<MseLossBackward0>)
16 tensor(51194.2617, grad_fn=<MseLossBackward0>)
17 tensor(51190.4414, grad_fn=<MseLossBackward0>)
18 tensor(51186.5508, grad_fn=<MseLossBackward0>)
19 tensor(51182.5938, grad_fn=<MseLossBackward0>)
20 tensor

In [None]:
model.eval()
out = model(data)
mse = F.mse_loss(out, data.y)
print(f'MSE: {mse:.4f}')

MSE: 50824.9766


  mse = F.mse_loss(out, data.y)
