# Preprocessing SODA and Training GNNs

(Simplified, without Comments)

by Ding

For exploratory steps and comments, please see [this notebook](https://github.com/ding05/GNN_CNN_MHW_Forecasting_EEs/blob/main/preprocessing_c.ipynb).

In [1]:
!pip install geopandas

import numpy as np
from netCDF4 import Dataset
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
from geopandas import GeoDataFrame
from shapely.geometry import Point

Collecting geopandas
  Downloading geopandas-0.10.2-py2.py3-none-any.whl (1.0 MB)
[?25l[K     |▎                               | 10 kB 40.8 MB/s eta 0:00:01[K     |▋                               | 20 kB 18.4 MB/s eta 0:00:01[K     |█                               | 30 kB 15.1 MB/s eta 0:00:01[K     |█▎                              | 40 kB 14.0 MB/s eta 0:00:01[K     |█▋                              | 51 kB 7.6 MB/s eta 0:00:01[K     |██                              | 61 kB 7.6 MB/s eta 0:00:01[K     |██▎                             | 71 kB 8.4 MB/s eta 0:00:01[K     |██▌                             | 81 kB 9.3 MB/s eta 0:00:01[K     |██▉                             | 92 kB 9.4 MB/s eta 0:00:01[K     |███▏                            | 102 kB 7.5 MB/s eta 0:00:01[K     |███▌                            | 112 kB 7.5 MB/s eta 0:00:01[K     |███▉                            | 122 kB 7.5 MB/s eta 0:00:01[K     |████▏                           | 133 kB 7.5 MB/s eta 0

In [2]:
from google.colab import drive
drive.mount('/gdrive', force_remount=True)

Mounted at /gdrive


In [3]:
%%bash
cp -a '/gdrive/MyDrive/soda_331_pt_l5.nc' '/content/'
cp -a '/gdrive/MyDrive/sst_anomaly.nc' '/content/'

In [4]:
soda = xr.open_dataset('soda_331_pt_l5.nc', decode_times=False)
soda

In [5]:
soda_array = soda.to_array(dim='temp')
soda_smaller = soda_array[:,:,:,::5,::5].to_dataset(dim="temp")
soda_smaller

In [6]:
start_year = 1980
end_year = 2016

In [7]:
start_month = (start_year - 1980) * 12
end_month = (end_year - 1980) * 12

soda_sst = np.zeros((end_month-start_month,1,66,144))
soda_sst[:,:,:,:] = soda_smaller.variables['temp'][0:end_month-start_month,:,:,:]

In [8]:
soda_sst = np.squeeze(soda_sst, axis=1)

soda_sst_list = soda_sst.tolist()

months = list(range(0, 432))
monthly_average_all = []

for i in range(12):
  individual_month = months[i + start_month : end_month : 12]
  average = np.zeros((66,144))
  for j in range(len(individual_month)):
    average += soda_sst[individual_month[j]]
    # average_map += np.array(individual_month[j])
  monthly_average = average / len(individual_month)
  monthly_average_all.append(monthly_average)
  print("Month " + str(i+1) + " is appended.")

Month 1 is appended.
Month 2 is appended.
Month 3 is appended.
Month 4 is appended.
Month 5 is appended.
Month 6 is appended.
Month 7 is appended.
Month 8 is appended.
Month 9 is appended.
Month 10 is appended.
Month 11 is appended.
Month 12 is appended.


In [9]:
monthly_average_all_432 = []
monthly_average_all_432 = monthly_average_all
print(len(monthly_average_all))
print(len(monthly_average_all_432))

for i in range(432 - 12):
  monthly_average_all_432.append(monthly_average_all_432[i])

print(len(monthly_average_all_432))

soda_sst_anomaly_list = []

for i in range(432):
  soda_sst_anomaly_list.append(soda_sst[i] - monthly_average_all_432[i])

12
12
432


In [10]:
soda_sst_anomaly = np.array(soda_sst_anomaly_list)

soda_sst_anomaly.shape

(432, 66, 144)

------------------

In [11]:
soda_sst_anomaly_transposed = soda_sst_anomaly.transpose(1,2,0)
soda_sst_anomaly_flattened = soda_sst_anomaly_transposed.reshape(soda_sst_anomaly.shape[1] * soda_sst_anomaly.shape[2],432)
soda_sst_anomaly_flattened.shape

(9504, 432)

In [12]:
def dropna(arr, *args, **kwarg):
    assert isinstance(arr, np.ndarray)
    dropped=pd.DataFrame(arr).dropna(*args, **kwarg).values
    if arr.ndim==1:
        dropped=dropped.flatten()
    return dropped

soda_sst_anomaly_ocean_flattened = dropna(soda_sst_anomaly_flattened)
soda_sst_anomaly_ocean_flattened.shape

(6924, 432)

In [13]:
feature_matrix = soda_sst_anomaly_ocean_flattened

In [14]:
lons, lats = np.meshgrid(soda_smaller.longitude.values, soda_smaller.latitude.values)

soda_time_1 = soda_smaller.temp.isel(depth=0,time=240)

soda_time_1_lons, soda_time_1_lats = np.meshgrid(soda_time_1.longitude.values, soda_time_1.latitude.values)

soda_masked = soda_time_1.where(abs(soda_time_1_lons) + abs(soda_time_1_lats) > 0)
soda_masked

In [15]:
soda_masked.values.flatten()[soda_masked.notnull().values.flatten()]

len(soda_masked.values.flatten()[soda_masked.notnull().values.flatten()])

6924

In [16]:
print(soda_time_1_lons.flatten()[soda_masked.notnull().values.flatten()])
print(soda_time_1_lats.flatten()[soda_masked.notnull().values.flatten()])

[162.75 165.25 167.75 ... 352.75 355.25 357.75]
[-74.75 -74.75 -74.75 ...  87.75  87.75  87.75]


In [17]:
from sklearn.metrics.pairwise import haversine_distances

lons_ocean = soda_time_1_lons.flatten()[soda_masked.notnull().values.flatten()]
lons_ocean = lons_ocean[::]

lats_ocean = soda_time_1_lats.flatten()[soda_masked.notnull().values.flatten()]
lats_ocean = lats_ocean[::]

lons_ocean *= np.pi/180
lats_ocean *= np.pi/180

points_ocean = np.concatenate([np.expand_dims(lats_ocean.flatten(),-1), np.expand_dims(lons_ocean.flatten(),-1)],-1)

distance_ocean = 6371*haversine_distances(points_ocean)

In [18]:
distance_ocean_diag = distance_ocean
distance_ocean_diag[distance_ocean_diag==0] = 1

distance_ocean_recip = np.reciprocal(distance_ocean_diag)

distance_ocean_recip.shape

(6924, 6924)

In [19]:
adjacency_matrix = distance_ocean_recip

In [20]:
lead_month = 0

feature_matrix = feature_matrix[:,:len(feature_matrix[0])-lead_month:]
adjacency_matrix = adjacency_matrix

In [21]:
import os
import json
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.[0m
cuda:0


In [22]:
node_feats = torch.tensor(np.expand_dims(feature_matrix, axis=0)).float()
adj_matrix = torch.tensor(np.expand_dims(adjacency_matrix, axis=0)).float()

print(node_feats.shape)
print(adj_matrix.shape)

torch.Size([1, 6924, 432])
torch.Size([1, 6924, 6924])


In [23]:
class GCNLayer(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        """
        Inputs:
            node_feats - Tensor with node features of shape [batch_size, num_nodes, c_in]
            adj_matrix - Batch of adjacency matrices of the graph. If there is an edge from i to j, adj_matrix[b,i,j]=1 else 0.
                         Supports directed edges by non-symmetric matrices. Assumes to already have added the identity connections.
                         Shape: [batch_size, num_nodes, num_nodes]
        """
        # Num neighbours = number of incoming edges
        num_neighbours = adj_matrix.sum(dim=-1, keepdims=True)
        node_feats = self.projection(node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats)
        node_feats = node_feats / num_neighbours
        return node_feats

In [24]:
print("Node features:\n", node_feats)
print("\nAdjacency matrix:\n", adj_matrix)

Node features:
 tensor([[[-7.9552e-02, -7.8261e-02,  4.1625e-02,  ..., -3.3746e-03,
           8.1046e-03, -4.8462e-02],
         [ 1.1985e+00,  3.5482e-01,  1.2477e-01,  ..., -1.4804e-02,
          -9.4055e-03, -1.1167e-01],
         [ 1.7171e+00,  1.1768e+00,  1.3464e-01,  ..., -1.2377e-04,
           1.9625e-03, -8.1020e-02],
         ...,
         [-9.9387e-02, -9.5504e-02, -9.8592e-02,  ...,  8.6997e-02,
           6.8804e-02,  6.4023e-02],
         [-1.0047e-01, -9.3609e-02, -9.5575e-02,  ...,  8.7748e-02,
           7.1659e-02,  6.7239e-02],
         [-9.4762e-02, -8.9556e-02, -9.3895e-02,  ...,  8.8551e-02,
           7.4358e-02,  7.0678e-02]]])

Adjacency matrix:
 tensor([[[1.0000e+00, 1.3677e-02, 6.8402e-03,  ..., 5.3864e-05,
          5.3872e-05, 5.3880e-05],
         [1.3677e-02, 1.0000e+00, 1.3677e-02,  ..., 5.3859e-05,
          5.3864e-05, 5.3872e-05],
         [6.8402e-03, 1.3677e-02, 1.0000e+00,  ..., 5.3855e-05,
          5.3859e-05, 5.3864e-05],
         ...,
       

In [25]:
temp_list = []
for i in range(0, feature_matrix.shape[1]):
  element = [0] * feature_matrix.shape[1]
  element[i] = 1
  element = [float(item) for item in element] # Convert the type into float.
  temp_list.append(element)

layer = GCNLayer(c_in=feature_matrix.shape[1], c_out=feature_matrix.shape[1])
layer.projection.weight.data = torch.tensor(temp_list)

layer.projection.bias.data = torch.Tensor([0] * feature_matrix.shape[1])

In [26]:
with torch.no_grad():
    out_feats = layer(node_feats, adj_matrix)

print("Adjacency matrix")
print(np.round(np.array(adj_matrix), decimals=3))
print("Input features")
print(np.round(np.array(node_feats), decimals=3))
print("Output features")
print(np.round(np.array(out_feats), decimals=3))

Adjacency matrix
[[[1.    0.014 0.007 ... 0.    0.    0.   ]
  [0.014 1.    0.014 ... 0.    0.    0.   ]
  [0.007 0.014 1.    ... 0.    0.    0.   ]
  ...
  [0.    0.    0.    ... 1.    0.092 0.046]
  [0.    0.    0.    ... 0.092 1.    0.092]
  [0.    0.    0.    ... 0.046 0.092 1.   ]]]
Input features
[[[-0.08  -0.078  0.042 ... -0.003  0.008 -0.048]
  [ 1.198  0.355  0.125 ... -0.015 -0.009 -0.112]
  [ 1.717  1.177  0.135 ... -0.     0.002 -0.081]
  ...
  [-0.099 -0.096 -0.099 ...  0.087  0.069  0.064]
  [-0.1   -0.094 -0.096 ...  0.088  0.072  0.067]
  [-0.095 -0.09  -0.094 ...  0.089  0.074  0.071]]]
Output features
[[[ 0.296  0.161  0.13  ...  0.097  0.094  0.011]
  [ 0.9    0.37   0.17  ...  0.091  0.085 -0.019]
  [ 1.143  0.759  0.176 ...  0.097  0.09  -0.006]
  ...
  [-0.085 -0.065 -0.066 ...  0.17   0.134  0.142]
  [-0.085 -0.065 -0.065 ...  0.17   0.135  0.143]
  [-0.084 -0.064 -0.065 ...  0.171  0.136  0.144]]]


In [27]:
print(soda_time_1_lons.flatten()[soda_masked.notnull().values.flatten()])
print(soda_time_1_lats.flatten()[soda_masked.notnull().values.flatten()])

lons_smaller = soda_time_1_lons.flatten()[soda_masked.notnull().values.flatten()]
lats_smaller = soda_time_1_lats.flatten()[soda_masked.notnull().values.flatten()]

for i in range(len(lats_smaller)):
  if lats_smaller[i] > -39 and lats_smaller[i] < -34:
    print("The position: " + str(i) + "; the latitude: " + str(lats_smaller[i]) + "; the longitude: " + str(lons_smaller[i]))

[162.75 165.25 167.75 ... 352.75 355.25 357.75]
[-74.75 -74.75 -74.75 ...  87.75  87.75  87.75]
The position: 1880; the latitude: -37.25; the longitude: 0.25
The position: 1881; the latitude: -37.25; the longitude: 2.75
The position: 1882; the latitude: -37.25; the longitude: 5.25
The position: 1883; the latitude: -37.25; the longitude: 7.75
The position: 1884; the latitude: -37.25; the longitude: 10.25
The position: 1885; the latitude: -37.25; the longitude: 12.75
The position: 1886; the latitude: -37.25; the longitude: 15.25
The position: 1887; the latitude: -37.25; the longitude: 17.75
The position: 1888; the latitude: -37.25; the longitude: 20.25
The position: 1889; the latitude: -37.25; the longitude: 22.75
The position: 1890; the latitude: -37.25; the longitude: 25.25
The position: 1891; the latitude: -37.25; the longitude: 27.75
The position: 1892; the latitude: -37.25; the longitude: 30.25
The position: 1893; the latitude: -37.25; the longitude: 32.75
The position: 1894; the la

In [28]:
print("The position: 2079; the latitude: -34.75; the longitude: 177.75, which one point in Bay of Plenty")

The position: 2079; the latitude: -34.75; the longitude: 177.75, which one point in Bay of Plenty


In [50]:
gnn_output = np.round(np.array(out_feats), decimals=2)
gnn_input = np.round(np.array(node_feats), decimals=2)

In [51]:
gnn_output_bop = gnn_output[0][2079]
gnn_input_bop = gnn_input[0][2079]

In [52]:
from sklearn.metrics import mean_squared_error

gnn_mse_bop = mean_squared_error(gnn_output_bop, gnn_input_bop)
print(gnn_mse_bop)

0.08179954


------------------

In [53]:
soda_sst_anomaly.shape

(432, 66, 144)

In [54]:
gnn_input_bop

array([-0.15,  0.17, -0.49, -1.13, -0.54, -0.69, -1.48, -1.35, -0.81,
       -0.41, -0.27, -0.43,  0.64,  0.6 ,  0.45,  0.58, -0.02, -0.23,
        0.03,  0.27,  0.12,  0.18, -0.69,  0.12,  0.26, -0.65, -0.17,
       -0.24, -0.9 , -0.57, -0.02, -0.02, -0.1 , -0.5 , -0.32, -0.36,
       -0.72, -0.97,  0.17,  0.36, -0.25, -0.35, -0.07, -0.2 , -0.17,
       -0.14, -0.19, -1.02, -0.88, -0.76,  0.07,  0.67,  0.08,  0.28,
        0.81,  0.43,  0.78,  0.25,  1.07,  1.29,  1.11, -0.17,  0.12,
       -0.21,  0.07, -0.17,  0.44,  0.49,  0.28,  0.35, -0.16, -0.2 ,
        0.62,  0.27,  0.53,  0.23,  0.24, -0.18, -0.32, -0.25, -0.2 ,
        0.1 , -0.21, -0.83, -0.75, -0.17, -0.37, -0.73, -0.47,  0.16,
        0.38,  0.46,  0.34,  0.37,  0.31,  0.21, -0.2 ,  0.78,  0.6 ,
        0.07,  0.16,  0.76,  0.9 ,  0.6 ,  0.74,  0.74,  0.53,  1.12,
        0.65, -0.25,  0.02,  0.11,  0.56,  0.43,  0.66,  0.7 ,  0.7 ,
        0.49,  0.83,  0.03,  0.35,  0.96,  1.29,  0.58,  0.36,  0.1 ,
       -0.09,  0.24,

In [34]:
soda_smaller

In [78]:
soda_bop_sst = np.zeros((end_month-start_month,1))

soda_bop_sst[:,:] = soda_smaller.variables['temp'][0:end_month-start_month,:,15,71]

In [81]:
soda_bop_sst = np.squeeze(soda_bop_sst)

In [82]:
soda_bop_sst_monthly_average = []

for i in range(12):
  monthly_sst_anomaly = soda_bop_sst[i::12]
  soda_bop_sst_monthly_average.append(sum(monthly_sst_anomaly)/len(monthly_sst_anomaly))

soda_bop_sst_anomaly = []

for i in range(len(soda_bop_sst)):
  j = i % 12
  soda_bop_sst_anomaly.append(soda_bop_sst[i] - soda_bop_sst_monthly_average[j])

In [84]:
soda_bop_sst_anomaly = np.array(soda_bop_sst_anomaly)

In [85]:
np.round(soda_bop_sst_anomaly, decimals=2)

array([-0.35, -0.3 , -1.51, -1.31, -0.81, -0.92, -0.76, -0.86, -0.47,
       -0.1 , -0.22, -0.46,  0.61,  0.34,  0.63,  0.58,  0.38,  0.76,
        0.11, -0.38, -0.45, -0.6 , -0.94,  0.34,  0.2 , -0.77,  0.3 ,
        0.09, -0.28, -0.15, -0.33, -0.72,  0.19, -0.11, -0.38, -0.76,
       -1.86, -1.7 , -0.21,  0.18, -0.45, -0.96, -0.88, -0.61, -0.82,
       -0.43, -0.11, -0.48, -0.93, -1.14, -0.35,  0.27,  0.31,  0.85,
        0.6 ,  0.73,  0.33,  0.28,  0.68,  1.2 ,  1.18,  0.07,  0.26,
        0.06, -0.19, -0.19,  0.42,  0.21, -0.2 ,  0.6 , -0.3 ,  0.18,
        0.55, -0.1 ,  0.39, -0.45, -0.35, -0.17, -0.9 , -1.24, -1.09,
       -0.19, -0.51, -0.86, -0.46,  0.07, -0.58, -0.74, -1.37, -0.94,
       -0.89, -0.8 , -0.5 , -0.08,  0.3 , -0.13, -0.14,  0.26, -0.61,
       -1.21, -0.55, -0.03, -0.77, -1.01, -0.46, -0.37, -0.2 ,  0.85,
        0.79,  0.1 ,  0.48, -0.46,  0.11, -0.16, -0.21,  0.56,  0.9 ,
        0.53,  0.69, -0.32, -0.12,  1.11,  0.87,  0.33, -0.19,  0.25,
        0.1 ,  0.5 ,

In [108]:
soda_sst_anomaly.shape

soda_sst_anomaly = np.nan_to_num(soda_sst_anomaly, nan=0)

soda_sst_anomaly_CNN = np.expand_dims(soda_sst_anomaly, axis=1)

print(soda_sst_anomaly_CNN.shape)
print(soda_bop_sst_anomaly.shape)

(432, 1, 66, 144)
(432,)


In [110]:
train_data = []

for i in range(len(soda_sst_anomaly_CNN)):
  train_data.append((soda_sst_anomaly_CNN[i], soda_bop_sst_anomaly[i]))

In [114]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as Data
from torch.autograd import Variable

EPOCH = 20
BATCH_SIZE = 8
LR = 0.00001

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=False)
train_all_loader = Data.DataLoader(dataset=train_data, batch_size=len(train_data), shuffle=False)

In [112]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, 3, 1, 1),
            nn.Tanh(),
        )
        self.out = nn.Linear(38016, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output, x

cnn = CNN().double()

print(cnn)

CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Tanh()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Tanh()
  )
  (out): Linear(in_features=38016, out_features=1, bias=True)
)


In [113]:
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.MSELoss() 

In [115]:
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        b_x = Variable(x)
        b_y = Variable(y)

        output = cnn(b_x)[0]
        output = output.reshape(-1) # To avoid different sizes
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for step, (x, y) in enumerate(train_all_loader):
        c_x = Variable(x)
        c_y = Variable(y)

    #for step, (x, y) in enumerate(test_loader):
        #d_x = Variable(x)
        #d_y = Variable(y)

        if step % 100 == 0:
            
            pred_train_y, last_layer = cnn(c_x)
            train_mse = loss_func(pred_train_y, c_y)

            #pred_test_y, last_layer = cnn(d_x)
            #test_mse = loss_func(pred_test_y, d_y)

            print('Epoch: ', epoch, '| trainig loss: %.4f' % loss.data, '| training MSE: %.4f' % train_mse)

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch:  0 | trainig loss: 0.0486 | training MSE: 0.4247
Epoch:  1 | trainig loss: 0.0627 | training MSE: 0.4700
Epoch:  2 | trainig loss: 0.0694 | training MSE: 0.5052
Epoch:  3 | trainig loss: 0.0698 | training MSE: 0.5256
Epoch:  4 | trainig loss: 0.0667 | training MSE: 0.5384
Epoch:  5 | trainig loss: 0.0625 | training MSE: 0.5490
Epoch:  6 | trainig loss: 0.0584 | training MSE: 0.5596
Epoch:  7 | trainig loss: 0.0546 | training MSE: 0.5703
Epoch:  8 | trainig loss: 0.0514 | training MSE: 0.5812
Epoch:  9 | trainig loss: 0.0487 | training MSE: 0.5921
Epoch:  10 | trainig loss: 0.0464 | training MSE: 0.6028
Epoch:  11 | trainig loss: 0.0444 | training MSE: 0.6131
Epoch:  12 | trainig loss: 0.0427 | training MSE: 0.6229
Epoch:  13 | trainig loss: 0.0411 | training MSE: 0.6321
Epoch:  14 | trainig loss: 0.0395 | training MSE: 0.6407
Epoch:  15 | trainig loss: 0.0380 | training MSE: 0.6487
Epoch:  16 | trainig loss: 0.0364 | training MSE: 0.6561
Epoch:  17 | trainig loss: 0.0348 | train