In [6]:
import torch
import torch_geometric
import torch_geometric.transforms as T
from glob import glob
import xarray as xr
import pandas as pd
import matplotlib as plt
import cftime

In [101]:
class ssp_data():
    def __init__(self, n=39) -> None:
        self.n = n
        self.raw = {}
        self.init_edge_list(n)
        self.y_file = 'data\\tas_scenario_245\\tas_mon_mod_ssp245_192_000.nc'
        self.x_file_list = [item for item in glob('data\\tas_scenario_245\\tas_mon_mod_ssp245_192_*.nc') if item not in [self.y_file]][0 : self.n]
        self.create_df()
        self.x = self.x.drop(columns=['time', 'lat', 'lon', 'tas_9', 'tas_8'])
        self.x_tensor = self.create_tensors(self.x).T
        self.y_tensor = self.create_tensors(self.y)
        self.data = torch_geometric.data.Data(x=self.x_tensor, edge_index=self.edge_index.t().contiguous(), y=self.y_tensor)
        # self.split_data()
        # self.mini_graphs()

    def init_edge_list(self, n):
        self.edge_index = []
        for i in range(8):
            for j in range(8):
                if i != j:
                    self.edge_index.append([i, j])
        self.edge_index = torch.tensor(self.edge_index, dtype=torch.long)

    def create_df(self):
        self.x = pd.DataFrame()
        i = 1
        for filename in self.x_file_list:
            print('Processing', filename)
            if self.x.empty:
                self.x = self.create_vector(filename)
            else:
                self.x[f'tas_{i}'] = self.create_vector(filename)['tas']
                # self.x = self.x.merge(self.create_vector(filename), how='inner', on=['time', 'lat', 'lon'], suffixes=(None, f'_{i}'))
            
            # print(self.x)
            i += 1
        
        self.y = self.create_vector(self.y_file)['tas']

    def create_vector(self, filename):
        data = xr.open_dataset(filename)
        try:
            datetimeindex = data.indexes['time'].to_datetimeindex()
            data['time'] = datetimeindex
        except AttributeError:
            pass
        self.raw[filename] = data
        df = self.raw[filename].to_dataframe().reset_index()
        # for col in ['lat', 'lon', 'tas']:
            # df[col] = df[col].round(2)
        # self.raw_df[filename] = df
        # cftime_1960 = cftime.DatetimeNoLeap(1960, 1, 1, 12, 0, 0, 0, has_year_zero=True)
        # cftime_1970 = cftime.DatetimeNoLeap(1970, 12, 30, 12, 0, 0, 0, has_year_zero=True)
        # print(df)
        df = df.query('lat >= -44 & lat <= -12 & lon >= 288 & lon <= 336')
        ret = df.loc[(df['time'].dt.year > 1960) & (df['time'].dt.year < 1970), ['time', 'lat', 'lon', 'tas']]

        # print(ret, ret.dtypes)
        return ret

    
    def get_device(self):
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
        else:
            device = torch.device('cpu') # don't have GPU 
        return device

    def create_tensors(self, df):
        device = self.get_device()
        return torch.from_numpy(df.values).float().to(device)
    
    def mini_graphs(self):
        df = self.x
        df['x_tensor'] = df.apply(lambda row: torch.tensor(row.values.flatten()), axis=1)
        df['y'] = self.y
        df['y_tensor'] = df['y'].apply(lambda y: torch.tensor(y))
        df['data_obj'] = df.apply(lambda row: torch_geometric.data.Data(x=df['x_tensor'], edge_index=self.edge_index.t().contiguous(), y=df['y_tensor']), axis=1)
        self.batch_graphs = df['data_obj']
    
    def split_data(self):
        transform = T.Compose([T.RandomNodeSplit(num_test=1000, num_val=1000)])
        self.data = transform(self.data) 


In [102]:
ssp_data = ssp_data(n=10)

Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_001.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_002.nc
Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_003.nc


  datetimeindex = data.indexes['time'].to_datetimeindex()


Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_004.nc


  datetimeindex = data.indexes['time'].to_datetimeindex()


Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_005.nc


  datetimeindex = data.indexes['time'].to_datetimeindex()


Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_006.nc


  datetimeindex = data.indexes['time'].to_datetimeindex()


Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_007.nc


  datetimeindex = data.indexes['time'].to_datetimeindex()


Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_008.nc


  datetimeindex = data.indexes['time'].to_datetimeindex()


Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_009.nc


  datetimeindex = data.indexes['time'].to_datetimeindex()


Processing data\tas_scenario_245\tas_mon_mod_ssp245_192_010.nc


  datetimeindex = data.indexes['time'].to_datetimeindex()


In [103]:
ssp_data.x

Unnamed: 0,tas,tas_2,tas_3,tas_4,tas_5,tas_6,tas_7,tas_10
21862428,286.400269,280.862762,286.771332,285.407745,283.856293,284.691101,284.222473,282.871368
21862429,285.905518,281.284271,281.414459,284.798828,283.674713,284.747620,284.870758,283.852448
21862430,284.728638,283.605560,281.929413,282.059021,281.492584,282.372467,282.833374,283.768555
21862431,280.983887,276.310425,279.923767,279.065002,278.729950,278.684601,278.864471,278.012451
21862432,279.186432,272.765076,278.642639,275.679047,275.733490,275.873016,275.867615,278.417572
...,...,...,...,...,...,...,...,...
35817127,295.364685,295.758026,297.652039,296.523560,295.708649,295.954987,296.025940,295.984528
35817128,295.254761,295.588348,297.691101,296.417755,295.422699,295.788300,295.935730,295.922882
35817129,295.446686,296.078278,298.125854,297.322083,295.799286,296.264893,296.387390,296.601593
35817130,296.196991,296.965240,298.843842,298.587280,296.752350,297.104401,297.196533,297.200256


In [110]:
ssp_data.data

Data(x=[8, 70200], edge_index=[2, 56], y=[70200])

In [88]:
for col in ssp_data.x.columns:
    print(ssp_data.x[col].isna().sum())

0
0
0
0
0
0
0
0
0
0
70200
70200
0


In [57]:
ssp_data.x[ssp_data.x.isna().any(axis=1)]

Unnamed: 0,time,lat,lon,tas,tas_2,tas_3,tas_4,tas_5,tas_6,tas_7
37,1964-02-15 12:00:00,-43.125,288.750,286.610809,279.847229,,,,,
85,1968-02-15 12:00:00,-43.125,288.750,286.647003,281.439758,,,,,
145,1964-02-15 12:00:00,-43.125,290.625,289.057434,284.443146,,,,,
193,1968-02-15 12:00:00,-43.125,290.625,289.101013,286.768890,,,,,
253,1964-02-15 12:00:00,-43.125,292.500,292.429688,288.975159,,,,,
...,...,...,...,...,...,...,...,...,...,...
69961,1968-02-15 12:00:00,-13.125,331.875,297.861481,300.113892,,,,,
70021,1964-02-15 12:00:00,-13.125,333.750,297.736603,299.800446,,,,,
70069,1968-02-15 12:00:00,-13.125,333.750,297.684784,300.062622,,,,,
70129,1964-02-15 12:00:00,-13.125,335.625,297.670563,299.548645,,,,,


In [106]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.edge_weight = torch.nn.Parameter(torch.ones(ssp_data.data.num_edges))
        self.conv1 = GCNConv(ssp_data.data.num_node_features, 16)
        self.conv2 = GCNConv(16, ssp_data.data.num_node_features)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index, torch.minimum(self.edge_weight.abs(),torch.ones(data.num_edges)))
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        # print(x, edge_index, torch.minimum(self.edge_weight.abs(),torch.ones(data.num_edges)))
        x = self.conv2(x, edge_index, torch.minimum(self.edge_weight.abs(),torch.ones(data.num_edges)))

        return x


In [111]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = ssp_data.data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(1000):
    optimizer.zero_grad()
    out = model(data)
    loss = F.mse_loss(out, data.y)
    print(epoch, loss)
    loss.backward()
    optimizer.step()

  loss = F.mse_loss(out, data.y)


0 tensor(85091.0156, grad_fn=<MseLossBackward0>)
1 tensor(7053266.5000, grad_fn=<MseLossBackward0>)
2 tensor(344167.1875, grad_fn=<MseLossBackward0>)
3 tensor(752990.7500, grad_fn=<MseLossBackward0>)
4 tensor(59927.9883, grad_fn=<MseLossBackward0>)
5 tensor(85049.5469, grad_fn=<MseLossBackward0>)
6 tensor(85049.6875, grad_fn=<MseLossBackward0>)
7 tensor(85049.4141, grad_fn=<MseLossBackward0>)
8 tensor(85048.8047, grad_fn=<MseLossBackward0>)
9 tensor(85047.9141, grad_fn=<MseLossBackward0>)
10 tensor(85046.7734, grad_fn=<MseLossBackward0>)
11 tensor(85045.3984, grad_fn=<MseLossBackward0>)
12 tensor(85043.8438, grad_fn=<MseLossBackward0>)
13 tensor(85042.1016, grad_fn=<MseLossBackward0>)
14 tensor(85040.2031, grad_fn=<MseLossBackward0>)
15 tensor(85038.1641, grad_fn=<MseLossBackward0>)
16 tensor(85036.0078, grad_fn=<MseLossBackward0>)
17 tensor(85033.7188, grad_fn=<MseLossBackward0>)
18 tensor(85031.3281, grad_fn=<MseLossBackward0>)
19 tensor(85028.8281, grad_fn=<MseLossBackward0>)
20 ten

In [112]:
model.edge_weight

Parameter containing:
tensor([1.0473, 1.0473, 1.0473, 1.0473, 1.0473, 1.0473, 1.0473, 0.6880, 0.6878,
        0.6856, 0.6853, 0.6861, 0.6859, 0.6874, 0.5059, 0.5015, 0.5029, 0.4975,
        0.5038, 0.4981, 0.5012, 0.0572, 0.0477, 0.0594, 0.0574, 0.0490, 0.0608,
        0.0670, 0.2679, 0.2581, 0.2678, 0.2587, 0.2601, 0.2630, 0.2696, 0.0978,
        0.0873, 0.0991, 0.0876, 0.0894, 0.0926, 0.0994, 0.9416, 0.9417, 0.9417,
        0.9416, 0.9416, 0.9417, 0.9415, 0.9291, 0.9291, 0.9292, 0.9291, 0.9291,
        0.9291, 0.9290], requires_grad=True)

In [108]:
model.eval()
out = model(data)
mse = F.mse_loss(out, data.y)
print(f'MSE: {mse:.4f}')

MSE: 84940.0234


  mse = F.mse_loss(out, data.y)


In [120]:
import networkx as nx
import

edge_index = ssp_data.data.edge_index
x = ssp_data.data.x

data = torch_geometric.data.Data(x=x, edge_index=edge_index)
g = torch_geometric.utils.to_networkx(data, to_undirected=True)
nx.draw(g)


TypeError: '_AxesStack' object is not callable

<Figure size 640x480 with 0 Axes>

In [117]:
ssp_data.data.edge_index

tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,
         3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
         6, 7, 7, 7, 7, 7, 7, 7],
        [1, 2, 3, 4, 5, 6, 7, 0, 2, 3, 4, 5, 6, 7, 0, 1, 3, 4, 5, 6, 7, 0, 1, 2,
         4, 5, 6, 7, 0, 1, 2, 3, 5, 6, 7, 0, 1, 2, 3, 4, 6, 7, 0, 1, 2, 3, 4, 5,
         7, 0, 1, 2, 3, 4, 5, 6]])