### 1 Initialization

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torch.utils.data import ConcatDataset
from torchinfo import summary

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, median_absolute_error
import numpy as np
import pandas as pd
from tqdm import tqdm

import scipy.special
import scipy

import time
import pickle
import warnings
import os
import glob
import random
from itertools import chain, combinations

cm = 1/2.54

In [2]:
import my_model as mm
import my_data_preprocess as mdp

In [3]:
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True

### 2 Data processing

In [4]:
valid_site, labels, durations, lon_lat = mdp.get_valid_site(threshold = 1)
var_name = mdp.get_var_name()

In [5]:
maxlen = 1
for key in durations:
    if len(durations[key])>maxlen:
        maxlen = len(durations[key])
        name = key

print(name, maxlen)

US-PFa 6805


In [6]:
durations['BE-Vie']

Unnamed: 0,TIMESTAMP
199,1996-07-18
200,1996-07-19
201,1996-07-20
202,1996-07-21
203,1996-07-22
...,...
6934,2014-12-26
6935,2014-12-27
6936,2014-12-28
6937,2014-12-29


#### Extract era5 data

In [7]:
mean_log, std_log = mdp.get_statistics()

In [8]:
restart = True
while restart:
    restart = False
    for site_name in valid_site:
        t000=time.time()
        # print(site_name, len(labels[site_name]))
        try:
            mdp.get_nc_one_site(site_name, grid = 10)
            # print('***site %s saved, consuming : %.3f mins****************' % (site_name, (time.time() - t000) / 60))
            # print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime()))
        except Exception as e:
            print("failed saving", site_name)
            time.sleep(10) 
            restart = True
            break
print('Completed')

Completed


#### CCM calculation

In [9]:
cell_states = mdp.get_ccm()

#### Torch dataset & dataloader

In [10]:
### one site test
site_name = valid_site[0]
x, y = mdp.data_read(site_name)
train_dataset = mdp.TimeseriesDataset(x, y, seq_len=15, pre_hor = 7)
train_loader = mdp.DataLoader(train_dataset, batch_size = 128, shuffle = True)

for i, [inp_data, out_fea] in enumerate(train_loader):
    print(i, inp_data.shape, out_fea.shape)
    break

0 torch.Size([128, 15, 12, 10, 10]) torch.Size([128])


In [11]:
### all site set
### here we extract the earliest 70% observations of all sites as train set with shuffle
### sebsequently, we extract the next 15% observations as validation set and the last 15% observations as test set
### seperating train, validatation, test set one site by one site and grouping them together.

trainset_list = []
valid_list = []
testset_list = []
for i, site_name in enumerate(valid_site):
    x, y = mdp.data_read(site_name) 
    cell_state = cell_states[site_name]
    train_len = int(len(y) * 0.7)

    val_len = int(len(y) * 0.15)
    
    site_dataset = mdp.TimeseriesDataset(x, y, cell_state, seq_len=15, pre_hor = 7)

    trainset_list.append( Subset(site_dataset, range(train_len)))
    valid_list.append( Subset(site_dataset, range(train_len, train_len+val_len)))
    testset_list.append( Subset(site_dataset, range(train_len+val_len, len(y))))
    

train_set =  ConcatDataset(trainset_list)
valid_set =  ConcatDataset(valid_list)
test_set =  ConcatDataset(testset_list)
all_set = ConcatDataset([train_set, valid_set, test_set])


In [12]:
len(train_set)

260489

In [13]:
train_loader = DataLoader(train_set, batch_size = 512, shuffle = True, drop_last = False)
valid_loader = DataLoader(valid_set, batch_size = 512, shuffle = True, drop_last = False)
test_loader = DataLoader(test_set, batch_size = 512, shuffle = True, drop_last = False)

### 3 Model performance

In [None]:
def my_train(model, learn_rate=5e-4, batch_size=64, r=1, max_epoch = 50):
    from torch.utils.tensorboard import SummaryWriter
    import numpy as np

    criteria = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

    r = f'{type(model).__name__}-{str(r)}'
    writer = SummaryWriter(f'log/{r}')

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, drop_last=False)

    t0 = time.time()
    print(f'Training starts {time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}, model = {r}')


    best_valid_loss = np.inf
    patience_counter = 0
    best_model_state = None

    for e in range(max_epoch):  # Set a high maximum epoch count
        model.train()
        loss_log = 0
        for i, [x, cell_state, y] in enumerate(train_loader):
            # Move data to device
            x = x.to(device)
            cell_state = cell_state.to(device)
            y = y.to(device)

            # Clear gradients
            optimizer.zero_grad()

            # Forward pass and compute loss
            out = model(x, cell_state).squeeze()
            loss = criteria(out, y)
            loss_log += loss.item()

            # Backward pass and update parameters
            loss.backward()
            optimizer.step()

        # Validation phase
        valid_loss_log = 0
        model.eval()
        with torch.no_grad():
            for j, [x, cell_state, y] in enumerate(valid_loader):
                valid_out = model(x.to(device), cell_state.to(device)).squeeze()
                valid_loss = criteria(valid_out, y.to(device))
                valid_loss_log += valid_loss.item()

        avg_train_loss = loss_log / len(train_loader)
        avg_valid_loss = valid_loss_log / len(valid_loader)

        writer.add_scalar("Loss/valid", avg_valid_loss, e + 1)
        writer.add_scalar("Loss/train", avg_train_loss, e + 1)

        # Early stopping check
        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            patience_counter = 0
            # Save the best model
            best_model_state = model.state_dict()
        else:
            patience_counter += 1


        # Print progress
        elapsed_time = time.time() - t0
        remaining_time = elapsed_time / (e + 1) * (1000 - (e + 1))  # Estimate remaining time
        print(f'epoch {e + 1:02d} [{int(elapsed_time // 60):02d}:{int(elapsed_time % 60):02d} < {int(remaining_time // 60):02d}:{int(remaining_time % 60):02d}, {elapsed_time / (e + 1):.2f} s/it]')

    # Save the final model
    ckpt = f'checkpoint/{r}.ckpt'
    torch.save(best_model_state, ckpt)
    writer.flush()
    writer.close()

In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [19]:
torch.cuda.empty_cache()

In [None]:
batch_size = 512 
learn_rate = 1e-4 

In [None]:
for m in range(10):
    model = mm.ResRec().to(device)
    # model.load_state_dict(torch.load('checkpoint/ResAttCauRec512-1e-4.ckpt'))
    my_train(model = model, learn_rate = learn_rate, batch_size = batch_size, r = 'retest-'+str(m), max_epoch = 50, )

Training starts 2025-04-10 17:03:11, model = ResRec-retest-0
epoch 01 [02:32 < 124:31, 152.47 s/it]
epoch 02 [05:04 < 121:44, 152.17 s/it]
epoch 03 [07:36 < 119:10, 152.13 s/it]
epoch 04 [10:08 < 116:33, 152.03 s/it]
epoch 05 [12:40 < 114:00, 152.01 s/it]
epoch 06 [15:11 < 111:26, 151.96 s/it]
epoch 07 [17:43 < 108:54, 151.96 s/it]
epoch 08 [20:15 < 106:21, 151.94 s/it]
epoch 09 [22:47 < 103:48, 151.91 s/it]
epoch 10 [25:18 < 101:14, 151.87 s/it]
epoch 11 [27:50 < 98:41, 151.83 s/it]
epoch 12 [30:21 < 96:08, 151.80 s/it]
epoch 13 [32:53 < 93:36, 151.79 s/it]
epoch 14 [35:25 < 91:05, 151.81 s/it]
epoch 15 [37:57 < 88:35, 151.86 s/it]
epoch 16 [40:30 < 86:03, 151.88 s/it]
epoch 17 [43:00 < 83:29, 151.81 s/it]
epoch 18 [45:31 < 80:56, 151.75 s/it]
epoch 19 [48:02 < 78:22, 151.70 s/it]
epoch 20 [50:32 < 75:49, 151.64 s/it]
epoch 21 [53:03 < 73:16, 151.60 s/it]
epoch 22 [55:34 < 70:43, 151.55 s/it]
epoch 23 [58:04 < 68:10, 151.49 s/it]
epoch 24 [60:34 < 65:37, 151.45 s/it]
epoch 25 [63:05 <

In [None]:
model.cell_state_in(cell_state.to(device))
summary(model, input_data = b.to(device), depth = 1)

Layer (type:depth-idx)                                  Output Shape              Param #
ResAttCauRec                                            [512, 1]                  --
├─ResAttNet: 1-1                                        [7680, 12]                693,720
├─CauRecNet: 1-2                                        [512, 1]                  146,241
Total params: 839,961
Trainable params: 839,961
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 327.43
Input size (MB): 36.86
Forward/backward pass size (MB): 5980.27
Params size (MB): 3.36
Estimated Total Size (MB): 6020.49