# Import packages

In [1]:
# Reading/Writing Data
import os
import glob
import numpy as np
import math 
import matplotlib.pyplot as plt
import datetime
import pandas as pd
from io import StringIO
import re
import pickle
import json

# Pytorch
import torch 
import torch.nn as nn
from torch.autograd import gradcheck
from torch.utils.data import DataLoader, random_split

# Self-Defined Package
from LSMDataset import LSMDataset
from LSMTransformer import LSMLSTM

# 忽略 ParserWarning 警告
import warnings
warnings.filterwarnings("ignore", category=pd.errors.ParserWarning)

import wandb

# Configurations
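`constant_config` holds the fixed hyper-parameters for a single training run, including the data root and the directory where the trained model is saved; `sweep_config` (next cell) defines the Weights & Biases sweep search space used by the agent.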

In [2]:
# Fixed hyper-parameters for a single (non-sweep) training run, plus data/output paths.
constant_config = dict(
    # reproducibility and data splitting
    seed=11611801,
    test_ratio=0.2,
    valid_ratio=0.2,
    # training loop and batch sizes
    n_epochs=1000,
    train_batch_size=128,
    valid_batch_size=128,
    test_batch_size=256,
    # optimiser / LR scheduler / early stopping / model width
    learning_rate=35e-4,
    step_size=20,
    gamma=0.2,
    weight_decay=0.0025,
    warm_step=11,
    early_stop=50,
    hidden_size=64,
    loss_decrease_threshold=1e-4,
    # data selection: currently only the CMFD points of the near sites are trained
    soil_layer_num=8,
    near_sites=[1, 2, 3, 4, 5, 6, 7, 8, 9],  # site folders holding LSM run output and weather-station data
    var_param_list=['MAXSMC'],
    seq_length=365,
    # NOTE: key spelling 'peroid' is kept as-is because it is read by name elsewhere (presumably LSMDataset)
    peroid=[2010, 2015],  # half-open year interval: [2010, 2015)
    output=0,
    root='YOUR_ROOT\\Transformer_TEST',
    model_save_dir='YOUR_ROOT\\Transformer_TEST\\SURROGATE',
)

In [3]:
# Metric the sweep optimises. NOTE(review): this name must match a key that the
# training function passes to wandb.log, otherwise the sweep has nothing to rank on.
metric = dict(name='loss', goal='minimize')

# W&B sweep definition: random search over the hyper-parameters below.
sweep_config = {
    'method': 'random',
    'metric': metric,
    'parameters': {
        # Constant hyper-parameters: the same value in every trial.
        'valid_batch_size': {'value': 900},
        'test_batch_size': {'value': 720},
        'seed': {'value': 11611801},
        'test_ratio': {'value': 0.2},
        'valid_ratio': {'value': 0.2},
        'n_epochs': {'value': 2000},
        'soil_layer_num': {'value': 8},
        'seq_length': {'value': 365},
        'output': {'value': 0},
        'near_sites': {'value': [1, 2, 3, 4, 5, 6, 7, 8, 9]},  # site folders holding LSM run output and weather-station data
        'var_param_list': {'value': ['MAXSMC']},
        'peroid': {'value': [2010, 2015]},  # half-open year interval: [2010, 2015)
        'root': {'value': 'YOUR_ROOT\\Transformer_TEST'},
        'model_save_dir': {'value': 'YOUR_ROOT\\Transformer_TEST\\SURROGATE'},

        # Discrete hyper-parameters: one value drawn from each list per trial.
        'step_size': {'values': [10, 20, 40]},
        'train_batch_size': {'values': [64, 128, 256]},
        'early_stop': {'values': [50]},
        'hidden_size': {'values': [32, 64]},

        # Continuous hyper-parameters: drawn uniformly from [min, max] per trial.
        'learning_rate': {'distribution': 'uniform', 'min': 1e-6, 'max': 4e-3},
        'weight_decay': {'distribution': 'uniform', 'min': 1e-4, 'max': 1e-2},
        'gamma': {'distribution': 'uniform', 'min': 5e-1, 'max': 8e-1},
        'loss_decrease_threshold': {'distribution': 'uniform', 'min': 1e-4, 'max': 1e-3},
    },
}

# Some Utility Functions

In [4]:
def same_seed(seed):
    '''Fix every random number generator used in this notebook for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and, when
    available, all CUDA devices), and forces deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.
    '''
    import random  # local import: `random` is not part of the notebook's import cell

    # Deterministic cuDNN algorithms; disabling the autotuner avoids run-to-run kernel choice.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Covers the current device as well, so a separate torch.cuda.manual_seed is redundant.
        torch.cuda.manual_seed_all(seed)

# Dataloader

In [5]:
def create_dataloader(config):
    '''Build ``LSMDataset(config)`` and split it into train / valid / test loaders.

    The dataset is first split into (train+valid, test) using ``config['test_ratio']``,
    then the train+valid part is split using ``config['valid_ratio']``. When ``config``
    provides ``'seed'``, both splits draw from a dedicated generator seeded with it, so
    every run (e.g. every sweep trial) sees the same partition and validation losses are
    comparable across runs.

    Args:
        config: mapping (dict or ``wandb.config``) providing at least ``test_ratio``,
            ``valid_ratio``, ``train_batch_size``, ``valid_batch_size`` and
            ``test_batch_size``, plus whatever ``LSMDataset`` reads.

    Returns:
        ``(dataset, train_loader, valid_loader, test_loader)``.
    '''
    dataset = LSMDataset(config)

    try:
        seed = config['seed']
    except KeyError:
        seed = None
    # A private generator keeps the split independent of any other use of the global RNG.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(int(seed))}

    # Split off the test set.
    train_len = int(len(dataset) * (1 - config['test_ratio']))
    test_len = len(dataset) - train_len
    train_dataset, test_dataset = random_split(dataset, [train_len, test_len], **split_kwargs)

    # Split the remainder into train and validation.
    valid_len = int(train_len * config['valid_ratio'])
    train_len = train_len - valid_len
    train_dataset, valid_dataset = random_split(train_dataset, [train_len, valid_len], **split_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=config['train_batch_size'], drop_last=True, shuffle=True, pin_memory=False)
    # No shuffling: the epoch's validation loss is a mean of per-batch means, which would
    # otherwise change from epoch to epoch whenever the last batch is smaller than the rest.
    valid_loader = DataLoader(valid_dataset, batch_size=config['valid_batch_size'], drop_last=False, shuffle=False, pin_memory=False)
    test_loader = DataLoader(test_dataset, batch_size=config['test_batch_size'], drop_last=False, shuffle=False)

    print(len(dataset))
    print('training size', len(train_dataset))
    print('validing size', len(valid_dataset))
    print('testing size', len(test_dataset))

    return dataset, train_loader, valid_loader, test_loader

In [6]:
# create_dataloader(config)

# Trainer and Tester

In [7]:
def trainer(train_loader, valid_loader, model, wandb, device):
    '''Train ``model`` with Adam + ReduceLROnPlateau and early stopping on validation loss.

    Args:
        train_loader: iterable of batches used for training; only the first two items
            of each batch (inputs, targets) are used. Must yield at least one batch.
        valid_loader: iterable of batches in the same format, used for the validation
            loss that drives LR scheduling and early stopping. Must yield at least one batch.
        model: ``torch.nn.Module``; it is converted to float64 in place.
        wandb: object exposing ``config`` (mapping with ``learning_rate``,
            ``weight_decay``, ``step_size``, ``gamma``, ``loss_decrease_threshold``,
            ``n_epochs``, ``early_stop``) and ``log(dict)``; normally the ``wandb``
            module after ``wandb.init``.
        device: device the batches are moved to.

    Returns:
        ``(model, best_loss)``: the model with the weights of the epoch that achieved
        ``best_loss`` restored, and that best mean validation loss. If no epoch ever
        improved (e.g. every validation loss was NaN), the final weights are returned
        with ``best_loss == math.inf``.

    Raises:
        ValueError: if either loader has zero batches.
    '''
    config = wandb.config
    criterion = nn.MSELoss()

    if len(train_loader) == 0 or len(valid_loader) == 0:
        raise ValueError('trainer needs non-empty train and valid loaders '
                         '(check dataset size, split ratios, batch sizes and drop_last)')

    # Convert to float64 *before* building the optimizer so it is guaranteed to hold
    # the very parameters that are trained.
    model = model.double()

    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        patience=config['step_size'],
        factor=config['gamma'],
        min_lr=1e-8,
        threshold=config['loss_decrease_threshold'],
    )

    min_delta = 1e-4  # minimum absolute drop in validation loss that counts as an improvement
    best_loss, best_state, early_stop_count = math.inf, None, 0

    for epoch in range(config['n_epochs']):
        # ---- training pass ----
        model.train()
        train_loss_total = 0.0
        for data_pkg in train_loader:
            x = data_pkg[0].double().to(device)
            y = data_pkg[1].double().to(device)

            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

            train_loss_total += loss.item()

        # ---- validation pass ----
        model.eval()
        valid_loss_total = 0.0
        with torch.no_grad():
            for data_pkg in valid_loader:
                x = data_pkg[0].double().to(device)
                y = data_pkg[1].double().to(device)
                valid_loss_total += criterion(model(x), y).item()

        # LR used during this epoch (read before the scheduler may lower it).
        current_lr = optimizer.param_groups[0]['lr']
        train_indicator = train_loss_total / len(train_loader)
        valid_indicator = valid_loss_total / len(valid_loader)
        scheduler.step(valid_indicator)

        # Log before any early exit so the final epoch also shows up in wandb.
        wandb.log({'epoch': epoch,
                   'lr': current_lr,
                   'train_indicator': train_indicator,
                   'valid_indicator': valid_indicator,
                   })

        if best_loss - valid_indicator > min_delta:
            # Sufficient improvement: remember these weights as the best so far.
            best_loss = valid_indicator
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            early_stop_count = 0
        else:
            # Loss rose, or fell by less than min_delta (a NaN loss also lands here).
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            break
        if epoch > 100 and (valid_indicator - train_indicator) > 0.5:  # stop on clear over-fitting
            break
        if epoch > 100 and train_indicator > 1:  # stop runs that converge too slowly
            break

    # Return the weights that produced best_loss, not whatever the last epoch left behind.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, best_loss

# Training process

In [8]:
def training():
    '''Run one W&B trial: seed, build loaders and model, train, then log and optionally save.

    Intended as the function passed to ``wandb.agent``: every hyper-parameter is read
    from ``wandb.config`` after ``wandb.init``. The best validation loss returned by
    ``trainer`` is logged under ``'loss'``, the metric name used in ``sweep_config``.
    When that loss is below ``save_loss_threshold`` the model weights are written to
    ``config['model_save_dir']`` and the notebook plus helper modules are uploaded as a
    W&B code artifact. The W&B run is always finished, even if training raises.
    '''
    save_loss_threshold = 0.5  # only keep checkpoints whose best validation loss beats this

    nowtime = datetime.datetime.now().strftime('%Y_%m_%d_%H%M%S')
    wandb.init(
        project='YOUR_PROJECT',
        name=nowtime,
    )
    config = wandb.config

    try:
        # Apply the configured seed so runs are reproducible.
        same_seed(config['seed'])

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        _dataset, train_loader, valid_loader, _test_loader = create_dataloader(config)

        os.makedirs(config['model_save_dir'], exist_ok=True)

        # NOTE(review): 12 is presumably the number of input features per time step
        # expected by LSMLSTM — confirm against LSMDataset / LSMTransformer.
        model = LSMLSTM(config['seq_length'], 12, config['hidden_size'], config['soil_layer_num']).to(device)

        best_model, best_loss = trainer(train_loader, valid_loader, model, wandb, device)

        # sweep_config's metric is named 'loss', while trainer only logs per-epoch keys;
        # report the best validation loss under that name so the sweep can rank trials.
        wandb.log({'loss': best_loss})

        if best_loss < save_loss_threshold:
            save_name = os.path.join(config['model_save_dir'], nowtime + '_STC.ckpt')
            torch.save(best_model.state_dict(), save_name)

            arti_code = wandb.Artifact('ipynb', type='code')
            for file_name in ('SURROGATE_TRAINING_WANDB.ipynb', 'LSMDataset.py', 'LSMLoss.py', 'LSMTransformer.py'):
                arti_code.add_file(os.path.join(config['root'], file_name))
            wandb.log_artifact(arti_code)
    finally:
        # Close the run even on failure so the next sweep trial starts from a clean state.
        wandb.finish()

In [9]:
# training(constant_config)

In [None]:
# Authenticate this kernel with Weights & Biases before creating or running the sweep.
wandb.login()

In [None]:
# Register the sweep defined by `sweep_config` in the W&B project and print its id;
# the id is what `wandb.agent` (next cell) uses to pull trials from this sweep.
sweep_id = wandb.sweep(sweep_config, project='YOUR_PROJECT')
print(sweep_id)

In [None]:
# Alternative (disabled): run a single trial of a previously created sweep by its id.
# wandb.agent(project='YOUR_PROJECT', sweep_id='l3kbhtzo', function=training, count=1)
# Run `training` as the sweep function for the sweep created above; count=100 caps
# how many trials this agent executes.
wandb.agent(sweep_id, training, count=100)

# 