In [5]:
%run config.py
%run dataset.py
%run models.py

In [7]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from models import GNN
from dataset import MyGraphDataset, load_wave_data, normalise_data
from config import Config
import torch
import math
import pickle
from scipy.io import savemat

cfg = Config(
    snapshots=3,
    data_dir="", 
    neighbourhood_size=5,
    normalization="norm_01",
    train_period=(0, 1000),
    test_period=(1000, 1400),
    convolution_kernels = (64, 128),
    node_var_observ= ['hs', 'ub_bot', 'wlen', 'pwave_bot', 'tpeak', 'dirm'],
    node_var_target=['hs', 'ub_bot', 'wlen', 'pwave_bot', 'tpeak', 'dirm'],
    train_batch_size=4,
    forward_time=1,
    max_epoches=100,
    test_batch_size=3,
)

def save_data(data, filename):
    with open(filename, 'wb') as file:
        pickle.dump(data, file)   
    
def main(config:Config):
    trn_dl, test_dl = prepare_train_test_dataloaders(cfg)
    # setup the model and optimiser
    model = GNN(cfg)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
    loss_fn = torch.nn.MSELoss()
    
    r2_scores = []
    mse =[]
 
    # Create lists to store training and testing data
    #train_data = []
    test_data = []
    
    for epoch in range(cfg.max_epoches):
        for i, (data_x, y) in enumerate(trn_dl):
            optimizer.zero_grad()            #clar gradients from previous iteration

            # compute the loss for this batch
            # TODO: to wrap in a proper loss function
            tmp = model(data_x.x, data_x.edge_index)
            batch_pred = unbatch(tmp, data_x.batch)
            
            loss = 0
            for pi, yi in zip(batch_pred, y):
                loss += loss_fn(pi, yi)

            
            loss.backward()                   #gradients of the model's parameters are computed with respect to the loss using backpropagation
            optimizer.step()                  #The optimizer updates the model's parameters based on the computed gradients using the chosen optimization algorithm (e.g., Adam)
            
            # Calculate R-squared
            y_true_int = y.detach().numpy()
            y_true = y_true_int.reshape(-1, 6)
            y_pred = torch.cat(batch_pred).detach().numpy()
            ss_res = np.sum((y_true - y_pred)**2)
            ss_tot = np.sum((y_true - np.mean(y_true))**2)
            r2 = 1 - (ss_res / ss_tot)
            r2_scores.append(r2)
            
    
            print(f"epoch {epoch}, batch {i}: loss {loss.item():.5f}, R2 {r2:.5f}")
    
    # Test the model using test data
    model.eval()  # Set model to evaluation mode
    test_r2 = 0
    test_loss = 0
    for i, (data_x, y) in enumerate(test_dl):
        test_data_out = model(data_x.x, data_x.edge_index)
        batch_pred = unbatch(test_data_out, data_x.batch)
        
        y_true_int = y.detach().numpy()
        y_true = y_true_int.reshape(-1, 6)
        y_pred = torch.cat(batch_pred).detach().numpy()
        ss_res = np.sum((y_true - y_pred)**2)
        ss_tot = np.sum((y_true - np.mean(y_true))**2)
        r2 = 1 - (ss_res / ss_tot)
        test_r2 += r2
        mse = np.mean((y_true - y_pred)**2)
        test_loss += mse
        
         # Create dictionaries for lon, lat, time, and variable data
        data = load_wave_data(Config)
        lon = data['lon']
        lat = data['lat']
        time= data['t']
        lon_lat_time_data = {
            'lon': lon,  # Use data_x.node_index as an index to extract lon values
            'lat': lat,  # Use data_x.node_index as an index to extract lat values
            'time': time,  # Use data_x.time as an index to extract time values
        }
        
        variable_pred_data = {
            'hs': y_pred[:, 0],  # Adjust the index as needed
            'ub_bot': y_pred[:, 1],  # Adjust the index as needed
            'wlen': y_pred[:, 2],  # Adjust the index as needed
            'pwave_bot': y_pred[:, 3],  # Adjust the index as needed
            'tpeak': y_pred[:, 4],  # Adjust the index as needed
            'dirm' : y_pred[:, 5],
        }
        
        variable_true_data = {
            'hs': y_true[:, 0],  # Adjust the index as needed
            'ub_bot': y_true[:, 1],  # Adjust the index as needed
            'wlen': y_true[:, 2],  # Adjust the index as needed
            'pwave_bot': y_true[:, 3],  # Adjust the index as needed
            'tpeak': y_true[:, 4],  # Adjust the index as needed
            'dirm' : y_true[:, 5],
        }
        
        test_data.append({
            'lon_lat_time_data': lon_lat_time_data,
            'variable_pred_data': variable_pred_data,
            'variable_true_data': variable_true_data
        })
                
    test_r2 /= len(test_dl)    
    # Calculate average test loss (mean squared error)
    average_mse = test_loss / len(test_dl)
    # Calculate RMSE from average MSE
    rmse = math.sqrt(average_mse)

    print(f"Test R2: {test_r2:.5f}, loss {rmse:.5f}")     
    
    # Convert the lists to tensors if needed
    test_data = np.array(test_data)
    
    # Save training and testing data to separate files
    save_data(test_data, 'test.pkl')

    print("testing data saved.")
  
if __name__ == "__main__":
    main(cfg)



epoch 0, batch 0: loss 0.31040, R2 -0.33919
epoch 0, batch 1: loss 0.28469, R2 -0.20515
epoch 0, batch 2: loss 0.13278, R2 -0.14074
epoch 1, batch 0: loss 0.24465, R2 -0.03911
epoch 1, batch 1: loss 0.21357, R2 0.07718
epoch 1, batch 2: loss 0.11397, R2 0.03028
epoch 2, batch 0: loss 0.21774, R2 0.11463
epoch 2, batch 1: loss 0.15349, R2 0.32378
epoch 2, batch 2: loss 0.06740, R2 0.39455
epoch 3, batch 0: loss 0.13557, R2 0.40690
epoch 3, batch 1: loss 0.12073, R2 0.48895
epoch 3, batch 2: loss 0.07551, R2 0.36839
epoch 4, batch 0: loss 0.08789, R2 0.60945
epoch 4, batch 1: loss 0.11446, R2 0.52712
epoch 4, batch 2: loss 0.04817, R2 0.58897
epoch 5, batch 0: loss 0.07902, R2 0.65593
epoch 5, batch 1: loss 0.06749, R2 0.70473
epoch 5, batch 2: loss 0.05639, R2 0.55257
epoch 6, batch 0: loss 0.08315, R2 0.64410
epoch 6, batch 1: loss 0.06944, R2 0.69771
epoch 6, batch 2: loss 0.03500, R2 0.71087
epoch 7, batch 0: loss 0.06986, R2 0.70528
epoch 7, batch 1: loss 0.06552, R2 0.71231
epoch 7

epoch 9, batch 149: loss 0.03064, R2 0.92531
epoch 9, batch 150: loss 0.04074, R2 0.89079
epoch 9, batch 151: loss 0.04196, R2 0.89657
epoch 9, batch 152: loss 0.08150, R2 0.80253
epoch 9, batch 153: loss 0.08225, R2 0.84428
epoch 9, batch 154: loss 0.10709, R2 0.77475
epoch 9, batch 155: loss 0.05314, R2 0.85694
epoch 9, batch 156: loss 0.07940, R2 0.82284
epoch 9, batch 157: loss 0.04347, R2 0.89351
epoch 9, batch 158: loss 0.03977, R2 0.90049
epoch 9, batch 159: loss 0.03787, R2 0.90833
epoch 9, batch 160: loss 0.02090, R2 0.93696
epoch 9, batch 161: loss 0.07874, R2 0.80416
epoch 9, batch 162: loss 0.06430, R2 0.84322
epoch 9, batch 163: loss 0.04545, R2 0.88557
epoch 9, batch 164: loss 0.03405, R2 0.89709
epoch 9, batch 165: loss 0.02060, R2 0.94029
epoch 9, batch 166: loss 0.03426, R2 0.92610
epoch 9, batch 167: loss 0.04926, R2 0.89180
epoch 9, batch 168: loss 0.04607, R2 0.87959
epoch 9, batch 169: loss 0.04586, R2 0.88853
epoch 9, batch 170: loss 0.04610, R2 0.86293
epoch 9, b

epoch 10, batch 82: loss 0.06474, R2 0.82870
epoch 10, batch 83: loss 0.03837, R2 0.91325
epoch 10, batch 84: loss 0.07331, R2 0.86077
epoch 10, batch 85: loss 0.06936, R2 0.85607
epoch 10, batch 86: loss 0.01662, R2 0.95832
epoch 10, batch 87: loss 0.09331, R2 0.80457
epoch 10, batch 88: loss 0.04868, R2 0.88214
epoch 10, batch 89: loss 0.02303, R2 0.93099
epoch 10, batch 90: loss 0.02457, R2 0.91466
epoch 10, batch 91: loss 0.04049, R2 0.89730
epoch 10, batch 92: loss 0.07729, R2 0.82915
epoch 10, batch 93: loss 0.09579, R2 0.78506
epoch 10, batch 94: loss 0.04308, R2 0.87098
epoch 10, batch 95: loss 0.09551, R2 0.80184
epoch 10, batch 96: loss 0.04177, R2 0.88829
epoch 10, batch 97: loss 0.02115, R2 0.93381
epoch 10, batch 98: loss 0.04612, R2 0.89142
epoch 10, batch 99: loss 0.03908, R2 0.90397
epoch 10, batch 100: loss 0.02538, R2 0.93582
epoch 10, batch 101: loss 0.02688, R2 0.93399
epoch 10, batch 102: loss 0.06229, R2 0.82226
epoch 10, batch 103: loss 0.05452, R2 0.88122
epoch 

epoch 11, batch 11: loss 0.04884, R2 0.84530
epoch 11, batch 12: loss 0.05878, R2 0.86553
epoch 11, batch 13: loss 0.07667, R2 0.83382
epoch 11, batch 14: loss 0.04006, R2 0.88130
epoch 11, batch 15: loss 0.02955, R2 0.92197
epoch 11, batch 16: loss 0.03383, R2 0.92609
epoch 11, batch 17: loss 0.03070, R2 0.90542
epoch 11, batch 18: loss 0.02267, R2 0.91237
epoch 11, batch 19: loss 0.10257, R2 0.79141
epoch 11, batch 20: loss 0.05721, R2 0.83196
epoch 11, batch 21: loss 0.07436, R2 0.83231
epoch 11, batch 22: loss 0.10273, R2 0.81011
epoch 11, batch 23: loss 0.06862, R2 0.85613
epoch 11, batch 24: loss 0.04983, R2 0.88062
epoch 11, batch 25: loss 0.06566, R2 0.84194
epoch 11, batch 26: loss 0.10635, R2 0.76756
epoch 11, batch 27: loss 0.03600, R2 0.90047
epoch 11, batch 28: loss 0.07905, R2 0.84566
epoch 11, batch 29: loss 0.04256, R2 0.89460
epoch 11, batch 30: loss 0.05874, R2 0.87703
epoch 11, batch 31: loss 0.02541, R2 0.92228
epoch 11, batch 32: loss 0.02378, R2 0.91518
epoch 11, 

epoch 11, batch 192: loss 0.02737, R2 0.91727
epoch 11, batch 193: loss 0.04528, R2 0.89307
epoch 11, batch 194: loss 0.05013, R2 0.85957
epoch 11, batch 195: loss 0.01458, R2 0.95365
epoch 11, batch 196: loss 0.05424, R2 0.87957
epoch 11, batch 197: loss 0.04550, R2 0.90558
epoch 11, batch 198: loss 0.02275, R2 0.93266
epoch 11, batch 199: loss 0.02116, R2 0.92927
epoch 11, batch 200: loss 0.01998, R2 0.93961
epoch 11, batch 201: loss 0.07530, R2 0.82190
epoch 11, batch 202: loss 0.05704, R2 0.87181
epoch 11, batch 203: loss 0.06060, R2 0.86640
epoch 11, batch 204: loss 0.03649, R2 0.89561
epoch 11, batch 205: loss 0.12215, R2 0.71730
epoch 11, batch 206: loss 0.04164, R2 0.90702
epoch 11, batch 207: loss 0.02590, R2 0.93694
epoch 11, batch 208: loss 0.03691, R2 0.89567
epoch 11, batch 209: loss 0.03195, R2 0.91501
epoch 11, batch 210: loss 0.06710, R2 0.79597
epoch 11, batch 211: loss 0.08000, R2 0.81451
epoch 11, batch 212: loss 0.02413, R2 0.92362
epoch 11, batch 213: loss 0.03267,

epoch 12, batch 123: loss 0.06129, R2 0.86687
epoch 12, batch 124: loss 0.04097, R2 0.91190
epoch 12, batch 125: loss 0.05239, R2 0.87671
epoch 12, batch 126: loss 0.06304, R2 0.84294
epoch 12, batch 127: loss 0.04136, R2 0.90240
epoch 12, batch 128: loss 0.06450, R2 0.84641
epoch 12, batch 129: loss 0.05088, R2 0.88561
epoch 12, batch 130: loss 0.03045, R2 0.89666
epoch 12, batch 131: loss 0.03345, R2 0.93077
epoch 12, batch 132: loss 0.02940, R2 0.92018
epoch 12, batch 133: loss 0.04341, R2 0.89179
epoch 12, batch 134: loss 0.05958, R2 0.85619
epoch 12, batch 135: loss 0.04045, R2 0.89591
epoch 12, batch 136: loss 0.02224, R2 0.94019
epoch 12, batch 137: loss 0.03156, R2 0.91402
epoch 12, batch 138: loss 0.06401, R2 0.85140
epoch 12, batch 139: loss 0.01718, R2 0.95040
epoch 12, batch 140: loss 0.07265, R2 0.84124
epoch 12, batch 141: loss 0.03238, R2 0.89742
epoch 12, batch 142: loss 0.08862, R2 0.81039
epoch 12, batch 143: loss 0.03888, R2 0.90332
epoch 12, batch 144: loss 0.02763,

epoch 13, batch 53: loss 0.04212, R2 0.89705
epoch 13, batch 54: loss 0.08793, R2 0.83635
epoch 13, batch 55: loss 0.04839, R2 0.87831
epoch 13, batch 56: loss 0.02865, R2 0.92810
epoch 13, batch 57: loss 0.07113, R2 0.83357
epoch 13, batch 58: loss 0.03602, R2 0.89049
epoch 13, batch 59: loss 0.01633, R2 0.94550
epoch 13, batch 60: loss 0.04472, R2 0.87734
epoch 13, batch 61: loss 0.04058, R2 0.90125
epoch 13, batch 62: loss 0.03139, R2 0.88152
epoch 13, batch 63: loss 0.02195, R2 0.93942
epoch 13, batch 64: loss 0.02155, R2 0.93965
epoch 13, batch 65: loss 0.02727, R2 0.92441
epoch 13, batch 66: loss 0.03527, R2 0.90605
epoch 13, batch 67: loss 0.02419, R2 0.94373
epoch 13, batch 68: loss 0.04260, R2 0.89996
epoch 13, batch 69: loss 0.02405, R2 0.93838
epoch 13, batch 70: loss 0.12455, R2 0.73890
epoch 13, batch 71: loss 0.12795, R2 0.74445
epoch 13, batch 72: loss 0.04236, R2 0.90414
epoch 13, batch 73: loss 0.05534, R2 0.88447
epoch 13, batch 74: loss 0.09640, R2 0.78426
epoch 13, 

epoch 13, batch 233: loss 0.07227, R2 0.84289
epoch 13, batch 234: loss 0.06269, R2 0.86752
epoch 13, batch 235: loss 0.04506, R2 0.88352
epoch 13, batch 236: loss 0.09885, R2 0.78734
epoch 13, batch 237: loss 0.08574, R2 0.75443
epoch 13, batch 238: loss 0.08356, R2 0.80648
epoch 13, batch 239: loss 0.06305, R2 0.82028
epoch 13, batch 240: loss 0.02512, R2 0.92139
epoch 13, batch 241: loss 0.03303, R2 0.90724
epoch 13, batch 242: loss 0.05342, R2 0.88933
epoch 13, batch 243: loss 0.06439, R2 0.82166
epoch 13, batch 244: loss 0.02693, R2 0.92718
epoch 13, batch 245: loss 0.04538, R2 0.90051
epoch 13, batch 246: loss 0.06447, R2 0.85333
epoch 13, batch 247: loss 0.04450, R2 0.89783
epoch 13, batch 248: loss 0.04871, R2 0.86439
epoch 13, batch 249: loss 0.02592, R2 0.93115
epoch 14, batch 0: loss 0.04786, R2 0.87187
epoch 14, batch 1: loss 0.02663, R2 0.94319
epoch 14, batch 2: loss 0.02977, R2 0.91937
epoch 14, batch 3: loss 0.04526, R2 0.90085
epoch 14, batch 4: loss 0.03257, R2 0.9204

epoch 14, batch 164: loss 0.03437, R2 0.89640
epoch 14, batch 165: loss 0.07148, R2 0.83682
epoch 14, batch 166: loss 0.05038, R2 0.87553
epoch 14, batch 167: loss 0.05685, R2 0.88813
epoch 14, batch 168: loss 0.04860, R2 0.89819
epoch 14, batch 169: loss 0.08768, R2 0.81105
epoch 14, batch 170: loss 0.05752, R2 0.87925
epoch 14, batch 171: loss 0.05794, R2 0.83995
epoch 14, batch 172: loss 0.05037, R2 0.87792
epoch 14, batch 173: loss 0.11572, R2 0.76019
epoch 14, batch 174: loss 0.04752, R2 0.90234
epoch 14, batch 175: loss 0.04557, R2 0.87469
epoch 14, batch 176: loss 0.06780, R2 0.86581
epoch 14, batch 177: loss 0.02679, R2 0.93124
epoch 14, batch 178: loss 0.05447, R2 0.87935
epoch 14, batch 179: loss 0.03823, R2 0.90068
epoch 14, batch 180: loss 0.06347, R2 0.84041
epoch 14, batch 181: loss 0.02616, R2 0.94008
epoch 14, batch 182: loss 0.04922, R2 0.89336
epoch 14, batch 183: loss 0.06537, R2 0.85159
epoch 14, batch 184: loss 0.02896, R2 0.93365
epoch 14, batch 185: loss 0.10437,

epoch 15, batch 95: loss 0.05608, R2 0.85861
epoch 15, batch 96: loss 0.02412, R2 0.94470
epoch 15, batch 97: loss 0.09746, R2 0.79725
epoch 15, batch 98: loss 0.03431, R2 0.91113
epoch 15, batch 99: loss 0.05425, R2 0.88318
epoch 15, batch 100: loss 0.04949, R2 0.86084
epoch 15, batch 101: loss 0.04718, R2 0.89409
epoch 15, batch 102: loss 0.04210, R2 0.90588
epoch 15, batch 103: loss 0.04503, R2 0.88945
epoch 15, batch 104: loss 0.02662, R2 0.93210
epoch 15, batch 105: loss 0.03050, R2 0.91593
epoch 15, batch 106: loss 0.04540, R2 0.88738
epoch 15, batch 107: loss 0.07320, R2 0.81642
epoch 15, batch 108: loss 0.05570, R2 0.87840
epoch 15, batch 109: loss 0.03023, R2 0.91258
epoch 15, batch 110: loss 0.08543, R2 0.75744
epoch 15, batch 111: loss 0.04978, R2 0.88106
epoch 15, batch 112: loss 0.04919, R2 0.90040
epoch 15, batch 113: loss 0.03446, R2 0.92246
epoch 15, batch 114: loss 0.09232, R2 0.84776
epoch 15, batch 115: loss 0.04383, R2 0.89192
epoch 15, batch 116: loss 0.07623, R2 0

epoch 16, batch 24: loss 0.10809, R2 0.74935
epoch 16, batch 25: loss 0.05428, R2 0.85642
epoch 16, batch 26: loss 0.05976, R2 0.85736
epoch 16, batch 27: loss 0.05862, R2 0.86209
epoch 16, batch 28: loss 0.03781, R2 0.91155
epoch 16, batch 29: loss 0.07676, R2 0.81953
epoch 16, batch 30: loss 0.06866, R2 0.81684
epoch 16, batch 31: loss 0.01152, R2 0.95117
epoch 16, batch 32: loss 0.04964, R2 0.90108
epoch 16, batch 33: loss 0.04045, R2 0.87634
epoch 16, batch 34: loss 0.04758, R2 0.90310
epoch 16, batch 35: loss 0.01457, R2 0.94726
epoch 16, batch 36: loss 0.07439, R2 0.84066
epoch 16, batch 37: loss 0.04409, R2 0.89548
epoch 16, batch 38: loss 0.07717, R2 0.82282
epoch 16, batch 39: loss 0.02478, R2 0.91309
epoch 16, batch 40: loss 0.03334, R2 0.91331
epoch 16, batch 41: loss 0.07158, R2 0.83031
epoch 16, batch 42: loss 0.03729, R2 0.87985
epoch 16, batch 43: loss 0.07254, R2 0.81229
epoch 16, batch 44: loss 0.02074, R2 0.93620
epoch 16, batch 45: loss 0.06818, R2 0.86982
epoch 16, 

epoch 16, batch 204: loss 0.06340, R2 0.87013
epoch 16, batch 205: loss 0.05504, R2 0.86364
epoch 16, batch 206: loss 0.07064, R2 0.83546
epoch 16, batch 207: loss 0.05487, R2 0.88464
epoch 16, batch 208: loss 0.04580, R2 0.90247
epoch 16, batch 209: loss 0.01850, R2 0.93878
epoch 16, batch 210: loss 0.03168, R2 0.92278
epoch 16, batch 211: loss 0.05056, R2 0.89699
epoch 16, batch 212: loss 0.04505, R2 0.89136
epoch 16, batch 213: loss 0.02386, R2 0.92577
epoch 16, batch 214: loss 0.09948, R2 0.79540
epoch 16, batch 215: loss 0.05582, R2 0.85746
epoch 16, batch 216: loss 0.01727, R2 0.94871
epoch 16, batch 217: loss 0.02168, R2 0.94777
epoch 16, batch 218: loss 0.02380, R2 0.94430
epoch 16, batch 219: loss 0.03531, R2 0.92961
epoch 16, batch 220: loss 0.04178, R2 0.90645
epoch 16, batch 221: loss 0.04193, R2 0.89874
epoch 16, batch 222: loss 0.04441, R2 0.87348
epoch 16, batch 223: loss 0.03046, R2 0.89864
epoch 16, batch 224: loss 0.06723, R2 0.83443
epoch 16, batch 225: loss 0.04072,

epoch 17, batch 135: loss 0.05594, R2 0.88959
epoch 17, batch 136: loss 0.09950, R2 0.79009
epoch 17, batch 137: loss 0.04248, R2 0.88283
epoch 17, batch 138: loss 0.07193, R2 0.80801
epoch 17, batch 139: loss 0.04656, R2 0.90141
epoch 17, batch 140: loss 0.06592, R2 0.87191
epoch 17, batch 141: loss 0.04836, R2 0.88818
epoch 17, batch 142: loss 0.02379, R2 0.93222
epoch 17, batch 143: loss 0.03523, R2 0.90542
epoch 17, batch 144: loss 0.05072, R2 0.88211
epoch 17, batch 145: loss 0.02144, R2 0.91886
epoch 17, batch 146: loss 0.05433, R2 0.87816
epoch 17, batch 147: loss 0.04141, R2 0.88891
epoch 17, batch 148: loss 0.02264, R2 0.93291
epoch 17, batch 149: loss 0.04689, R2 0.86955
epoch 17, batch 150: loss 0.05365, R2 0.87615
epoch 17, batch 151: loss 0.04171, R2 0.88240
epoch 17, batch 152: loss 0.03007, R2 0.91406
epoch 17, batch 153: loss 0.03141, R2 0.91596
epoch 17, batch 154: loss 0.08669, R2 0.81374
epoch 17, batch 155: loss 0.03852, R2 0.87826
epoch 17, batch 156: loss 0.03394,

epoch 18, batch 65: loss 0.02525, R2 0.92943
epoch 18, batch 66: loss 0.04266, R2 0.90126
epoch 18, batch 67: loss 0.05552, R2 0.85146
epoch 18, batch 68: loss 0.06236, R2 0.82756
epoch 18, batch 69: loss 0.05917, R2 0.86126
epoch 18, batch 70: loss 0.09166, R2 0.76200
epoch 18, batch 71: loss 0.04167, R2 0.89249
epoch 18, batch 72: loss 0.02437, R2 0.92984
epoch 18, batch 73: loss 0.03971, R2 0.90838
epoch 18, batch 74: loss 0.03927, R2 0.88326
epoch 18, batch 75: loss 0.02728, R2 0.93010
epoch 18, batch 76: loss 0.02849, R2 0.93765
epoch 18, batch 77: loss 0.02653, R2 0.92596
epoch 18, batch 78: loss 0.06323, R2 0.85478
epoch 18, batch 79: loss 0.08372, R2 0.84000
epoch 18, batch 80: loss 0.03697, R2 0.90198
epoch 18, batch 81: loss 0.03623, R2 0.90810
epoch 18, batch 82: loss 0.02678, R2 0.92462
epoch 18, batch 83: loss 0.02667, R2 0.93696
epoch 18, batch 84: loss 0.02610, R2 0.94077
epoch 18, batch 85: loss 0.02866, R2 0.93125
epoch 18, batch 86: loss 0.03264, R2 0.91488
epoch 18, 