# Import packages

In [1]:
# Standard library
import copy
import datetime
import glob
import json
import math
import os
import pickle
import re
import warnings
from io import StringIO

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

# Pytorch
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from torch.utils.data import DataLoader, random_split

# Self-Defined Package
from PGDataset import PGDataset
from PGNetwork import DynamicFCNetwork
from LSMTransformer import LSMLSTM

# Ignore pandas ParserWarning
warnings.filterwarnings("ignore", category=pd.errors.ParserWarning)

# Configurations
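`constant_config` contains the fixed hyper-parameters for a single training run and the paths used to save your model; `sweep_config` defines the hyper-parameter search space for the wandb sweep.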

In [2]:
# Fixed hyper-parameters and paths for a single (non-sweep) run.
constant_config = dict(
    # Reproducibility, data split and run length
    seed=11611801,
    test_ratio=0,
    valid_ratio=0.2,
    n_epochs=1000,
    train_batch_size=128,
    valid_batch_size=128,

    # Optimizer, LR schedule and early stopping
    learning_rate=35e-4,
    step_size=20,
    gamma=0.2,
    weight_decay=0.0025,
    warm_step=11,
    early_stop=50,
    hidden_size=64,
    loss_decrease_threshold=1e-6,
    param_factor=0.001,

    # Currently only the CMFD points at near sites are trained
    standardization=True,
    soil_layer_num=[0, 4, 5, 6, 7],
    evaluate_layer=[0, 4, 5],
    var_param_list=['MAXSMC'],
    seq_length=365,
    output=1,
    root='YOUR_ROOT\\Parameter_Generator',
    model_save_dir='YOUR_ROOT\\Parameter_Generator\\GENERATOR',
)

In [3]:
# Search strategy: every run samples its hyper-parameters independently at random.
sweep_config = {
    'method': 'random'
    }

# The metric name must match a key that the training loop passes to `wandb.log`.
# The trainer logs 'train_indicator' and 'valid_indicator' (never 'loss'), so the
# sweep tracks the validation loss.
metric = {
    'name': 'valid_indicator',
    'goal': 'minimize'
    }

sweep_config['metric'] = metric

sweep_config['parameters'] = {}

# Constant hyper-parameters (same value in every run)
sweep_config['parameters'].update({
    'valid_batch_size': {'value': 128},
    'seed': {'value': 11611801},
    'valid_ratio': {'value': 0.2},
    'n_epochs': {'value': 1000},
    'soil_layer_num': {'value': [0, 4, 5, 6, 7]},
    'evaluate_layer': {'value': [0, 4, 5]},
    'seq_length': {'value': 365},
    'output': {'value': 1},
    'var_param_list': {'value': ['MAXSMC']},
    'root': {'value': 'YOUR_ROOT\\Parameter_Generator'},
    'model_save_dir': {'value': 'YOUR_ROOT\\Parameter_Generator\\GENERATOR'},
})

# Discrete hyper-parameters (one of the listed values is picked per run)
sweep_config['parameters'].update({
    'standardization': {
        'values': [True, False]
    },
    'step_size': {
        'values': [10, 20, 40]
    },
    'train_batch_size': {
        'values': [32, 64, 128, 256]
    },
    'early_stop': {
        'values': [50, 100, 150, 200]
    },
    'start_hidden_size': {
        'values': [32, 64, 128, 256]
    },
    'end_hidden_size': {
        'values': [8, 16, 32]
    },
})

# Continuous hyper-parameters (sampled uniformly from [min, max])
sweep_config['parameters'].update({
    'learning_rate': {
        'distribution': 'uniform',
        'min': 1e-6,
        'max': 5e-2
      },
    'weight_decay': {
        'distribution': 'uniform',
        'min': 1e-4,
        'max': 1e-1
      },
    'gamma': {
        'distribution': 'uniform',
        'min': 1e-1,
        'max': 8e-1,
      },
    'loss_decrease_threshold': {
        'distribution': 'uniform',
        'min': 1e-6,
        'max': 1e-4,
    },
    'param_factor': {
        'distribution': 'uniform',
        'min': 1e-6,
        'max': 1,
    }
})

# Some Utility Functions

In [4]:
def same_seed(seed):
    '''Seed the NumPy and PyTorch RNGs and force deterministic cuDNN behaviour.'''
    # CPU-side generators first.
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators only exist when a GPU is present.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Dataloader

In [5]:
def collate_fn(batch):
    """Collate 4-field samples into four stacked double-precision tensors.

    Each sample in ``batch`` is unpacked as
    ``(gen_pkg, surr_pkg, label_pkg, meta_pkg)``; the returned tuple holds the
    stacked fields in that same order.
    """
    gen_items, surr_items, label_items, meta_items = [], [], [], []
    for gen_pkg, surr_pkg, label_pkg, meta_pkg in batch:
        gen_items.append(gen_pkg)
        surr_items.append(surr_pkg)
        label_items.append(label_pkg)
        meta_items.append(meta_pkg)

    # Go through NumPy first so each field becomes one contiguous array.
    stacked = [np.asarray(items) for items in (gen_items, surr_items, label_items, meta_items)]
    return tuple(torch.tensor(array, dtype=torch.double) for array in stacked)

In [6]:
def create_dataloader(config):
    """Build the dataset and its train/validation data loaders.

    Args:
        config: Mapping that provides at least ``'valid_ratio'``,
            ``'train_batch_size'``, ``'valid_batch_size'`` and ``'seed'``.
            It is also passed unchanged to ``PGDataset``.

    Returns:
        ``(dataset, train_loader, valid_loader)``.
    """
    # Build the full dataset
    dataset = PGDataset(config)

    # Sizes of the training and validation subsets
    train_len = int(len(dataset) * (1 - config['valid_ratio']))
    valid_len = len(dataset) - train_len

    # Seed the split with its own generator so the same samples land in the
    # validation set on every run, independent of the global RNG state.
    split_generator = torch.Generator().manual_seed(config['seed'])
    train_dataset, valid_dataset = random_split(
        dataset, [train_len, valid_len], generator=split_generator
    )

    # Create the data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['train_batch_size'],
        drop_last=True,
        shuffle=True,
        pin_memory=False,
        collate_fn=collate_fn,
    )
    # The validation loader is not shuffled: batch composition stays fixed, so
    # the averaged validation loss is comparable from one epoch to the next.
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=config['valid_batch_size'],
        drop_last=False,
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    print(len(dataset))
    print('training size', len(train_dataset))
    print('validing size', len(valid_dataset))

    return dataset, train_loader, valid_loader

In [7]:
# Build the loaders once with the constant config and show the shapes of the
# first three fields of the first sample as a sanity check.
dataset, train_loader, valid_loader = create_dataloader(constant_config)
for field_index in range(3):
    print(dataset[0][field_index].shape)

612
training size 489
validing size 123
torch.Size([24])
torch.Size([365, 12])
torch.Size([365, 3])


# Trainer and Tester

In [8]:
def _generator_forward_loss(generator, surrogate, criterion, data_pkg, device, param_factor):
    """Run one batch through generator + surrogate and return ``(loss, param)``."""
    gen_input = data_pkg[0].double().to(device)
    surr_input = data_pkg[1].double().to(device)
    label = data_pkg[2].double().to(device)

    param = generator(gen_input) * param_factor
    # Add the generated parameter to feature column 1 of the surrogate input.
    # NOTE(review): column 1 is presumably the MAXSMC feature - confirm against PGDataset.
    mask = torch.zeros_like(surr_input)
    mask[:, :, 1] += param
    surr_input = torch.add(surr_input, mask)

    pred = surrogate(surr_input)
    # Only the first three output channels are compared with the label.
    loss = criterion(pred[:, :, :3], label)
    return loss, param


def trainer(train_loader, valid_loader, surrogate, generator, device, config=None, wandb=None):
    """Train ``generator`` through ``surrogate`` and keep the best validation snapshot.

    Only ``generator``'s parameters are handed to the optimizer; ``surrogate`` is
    used to turn the generated parameter into a prediction for the MSE loss.

    Args:
        train_loader: Iterable of batches; each batch is indexable, with
            ``[0]`` the generator input, ``[1]`` the 3-D surrogate input and
            ``[2]`` the label.
        valid_loader: Same batch layout, used for validation.
        surrogate: Model mapping the (modified) surrogate input to predictions.
        generator: Model being trained.
        device: Device the batches are moved to.
        config: Mapping with ``learning_rate``, ``weight_decay``, ``step_size``,
            ``gamma``, ``loss_decrease_threshold``, ``n_epochs``,
            ``param_factor`` and ``early_stop``. When omitted, ``wandb.config``
            is used.
        wandb: Optional wandb module/run used for logging. When ``None``,
            nothing is logged.

    Returns:
        ``(best_generator, generator, best_loss, valid_indicator)``:
        a snapshot of the generator at its lowest validation loss, the
        generator after the last epoch, that lowest validation loss, and the
        validation loss of the last epoch.

    Raises:
        ValueError: If neither ``config`` nor ``wandb`` is provided.
    """
    # An explicitly passed config wins; fall back to the wandb run config otherwise.
    if config is None:
        if wandb is None:
            raise ValueError("trainer requires either `config` or `wandb`")
        config = wandb.config

    criterion = nn.MSELoss()

    # Optimizer and learning-rate scheduler
    optimizer = torch.optim.Adam(generator.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    schedulerRP = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=config['step_size'], factor=config['gamma'], min_lr=1e-8, threshold=config['loss_decrease_threshold'])

    n_epochs, best_loss, early_stop_count = config['n_epochs'], math.inf, 0
    surrogate = surrogate.double()
    generator = generator.double()

    # Minimum decrease of the validation loss that counts as an improvement.
    min_delta = 1e-4

    # Defined up front so the return values exist even if no epoch improves
    # (or no epoch runs at all).
    best_generator = copy.deepcopy(generator)
    valid_indicator = math.inf

    param_table = wandb.Table(columns=['epoch', 'param']) if wandb is not None else None

    for epoch in range(n_epochs):
        # ---- training pass ----
        generator.train()
        train_loss_total = 0.0
        for data_pkg in train_loader:
            optimizer.zero_grad()
            loss, _ = _generator_forward_loss(generator, surrogate, criterion, data_pkg, device, config['param_factor'])
            loss.backward()
            optimizer.step()
            train_loss_total += loss.item()

        # ---- validation pass ----
        generator.eval()
        valid_loss_total = 0.0
        with torch.no_grad():
            for data_pkg in valid_loader:
                loss, param = _generator_forward_loss(generator, surrogate, criterion, data_pkg, device, config['param_factor'])
                if param_table is not None:
                    # One table row per validation sample and epoch.
                    for value in param.detach().reshape(-1).tolist():
                        param_table.add_data(epoch, value)
                valid_loss_total += loss.item()

        current_lr = (optimizer.param_groups[0])['lr']
        train_indicator = train_loss_total / len(train_loader)
        valid_indicator = valid_loss_total / len(valid_loader)
        schedulerRP.step(valid_indicator)

        # Track the best validation loss. `generator` keeps changing in later
        # epochs, so the best model has to be a copy, not a reference.
        if valid_indicator < best_loss - min_delta:
            best_loss = valid_indicator
            early_stop_count = 0
            best_generator = copy.deepcopy(generator)
        else:
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            break
        if epoch > 50 and (valid_indicator - train_indicator) > 0.15:  # guard against over-fitting
            break
        if epoch > 50 and train_indicator > 1.5:  # guard against too-slow convergence
            break

        # wandb logging
        if wandb is not None:
            wandb.log({'epoch': epoch,
                       'lr': current_lr,
                       'train_indicator': train_indicator,
                       'valid_indicator': valid_indicator,
                       'param_prediction': param_table
                      })

    return best_generator, generator, best_loss, valid_indicator

# Training process

In [9]:
def training(config=None):
    """Run one complete training job, logged to wandb.

    Args:
        config: Optional mapping of hyper-parameters. When ``None`` (the case
            when called by a wandb sweep agent), the configuration is taken
            from ``wandb.config`` after ``wandb.init``. When given, it is also
            recorded as the config of the wandb run.

    The best and last generators are saved under ``config['model_save_dir']``
    only when the best validation loss is below 0.015.
    """
    # The timestamp names both the wandb run and the checkpoint files, so it is
    # needed whether or not a config was passed in.
    nowtime = datetime.datetime.now().strftime('%Y_%m_%d_%H%M%S')

    # Always start a run: the trainer logs to wandb and wandb.finish() is
    # called below in both call styles.
    wandb.init(
      project='YOUR_PROJECT',
      name=nowtime,
      config=config,
      )
    if config is None:
        config = wandb.config

    # Seed every RNG so the data split and weight init are reproducible
    same_seed(config['seed'])

    # Select the device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    # Load the dataset
    dataset, train_loader, valid_loader = create_dataloader(config)

    # Create the model output directory
    os.makedirs(config['model_save_dir'], exist_ok=True)

    # Build the surrogate and load its pretrained weights.
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    surrogate = LSMLSTM(config['seq_length'], 12, 64, len(config['soil_layer_num'])).to(device)
    surrogate.load_state_dict(torch.load('YOUR_SURROGATE.ckpt', map_location=device))

    # Hidden layer widths halve from start_hidden_size down to end_hidden_size.
    # Configs that only define a single 'hidden_size' get one hidden layer of that width.
    try:
        start_hidden_size = config['start_hidden_size']
        end_hidden_size = config['end_hidden_size']
    except KeyError:
        start_hidden_size = end_hidden_size = config['hidden_size']
    n_hidden_layers = 1 + int(math.log2(start_hidden_size / end_hidden_size))
    hidden_sizes = [start_hidden_size // (2 ** i) for i in range(n_hidden_layers)]
    print(hidden_sizes)
    generator = DynamicFCNetwork(dataset[0][0].shape[0], hidden_sizes, 1).to(device)
    print(generator)

    # Train the model
    best_generator, last_generator, best_loss, last_loss = trainer(train_loader, valid_loader, surrogate, generator, device, wandb=wandb, config=config)

    # Save the model (only when the best validation loss is below 0.015)
    if best_loss < 0.015:
        best_save_name = os.path.join(config['model_save_dir'], nowtime + '_best.ckpt')
        last_save_name = os.path.join(config['model_save_dir'], nowtime + '_last.ckpt')

        torch.save(best_generator.state_dict(), best_save_name)
        torch.save(last_generator.state_dict(), last_save_name)

        # Attach the notebook and the project modules to the run as a code artifact
        arti_code = wandb.Artifact('ipynb', type='code')
        arti_code.add_file(os.path.join(config['root'], 'PG_TRAINING_UWC_WANDB.ipynb'))
        arti_code.add_file(os.path.join(config['root'], 'PGDataset.py'))
        arti_code.add_file(os.path.join(config['root'], 'PGNetwork.py'))

        wandb.log_artifact(arti_code)
    wandb.finish()

In [10]:
# training(constant_config)

In [None]:
# Authenticate with Weights & Biases before creating or running a sweep.
wandb.login()

In [None]:
# Register the sweep described by `sweep_config` and print the id it was assigned.
sweep_id = wandb.sweep(sweep_config, project='YOUR_PROJECT')
print(sweep_id)

In [None]:
# Alternative: attach to an already existing sweep by passing its id explicitly, e.g.
# wandb.agent(project='YOUR_PROJECT', sweep_id='l3kbhtzo', function=training, count=1)
# Start an agent that calls `training` for up to 100 runs of the sweep `sweep_id`.
wandb.agent(sweep_id, training, count=100)