In [1]:
import numpy as np
import random
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import time
from time import gmtime, strftime
import os
from data.data_split import cyclic_split
from data.dataset import get_dataset_class,CustomTensorDataset_GBA_seq_gap, CustomTensorDataset_GBA_seq
from data.transforms import ClassifyByThresholds
from trainer import NIMSTrainer_Germnay_Two
from model.swinunet_model import SwinUnet_CAM_Two
from model.conv_lstm import ConvLSTM,ConvLSTM_Two
from losses import *
from utils import *
import torch.optim as optim
import warnings
import sys
import datetime
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)

import wandb



In [2]:
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def main(args, wandb):
    device = set_device(args)
    fix_seed(args.seed)
    g = torch.Generator()
    g.manual_seed(args.seed)
    

    # Set experiment name and use it as process name if possible
    experiment_name = get_experiment_name(args)
    current_time = strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
    args.experiment_name = experiment_name+ "_" + current_time
    experiment_name = args.experiment_name
    
    # print('Running Experiment'.center(30).center(80, "="))
    # print(experiment_name)

    save_path = '/home/jianer/PostRainBench/3_GBA_wandb_ConvLSTM/GBA_dataset/experiment/'
    trn_x_1 = np.load(save_path + 'X_train_period1_time_filtered.npy', mmap_mode='r')
    trn_x_2= np.load(save_path + 'X_train_period2_time_filtered.npy', mmap_mode='r')
    trn_y_1 = np.load(save_path + 'y_train_period1_time_filtered.npy')
    trn_y_2 = np.load(save_path + 'y_train_period2_time_filtered.npy')

    tst_x = np.load(save_path + 'X_test_period_time_filtered.npy', mmap_mode='r')
    tst_y = np.load(save_path + 'y_test_period_time_filtered.npy')
    vld_x = np.load(save_path + 'X_valid_period_time_filtered.npy', mmap_mode='r')
    vld_y = np.load(save_path + 'y_valid_period_time_filtered.npy')

    print('Load datasets in CPU memory successfully!')
    print("#" * 80)

    batch_size = args.batch_size
    train_dataset = CustomTensorDataset_GBA_seq_gap(torch.from_numpy(trn_x_1),torch.from_numpy(trn_x_2), \
                                                    torch.from_numpy(trn_y_1), torch.from_numpy(trn_y_2), \
                                                    args.rain_thresholds, sequence_length=args.seq_length, downscaling_t=4)
    val_dataset = CustomTensorDataset_GBA_seq(torch.tensor(vld_x),torch.tensor(vld_y), args.rain_thresholds, sequence_length=args.seq_length, downscaling_t=4)
    test_dataset = CustomTensorDataset_GBA_seq(torch.tensor(tst_x),torch.tensor(tst_y), args.rain_thresholds, sequence_length=args.seq_length, downscaling_t=4)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
#     for x, y, z in train_dataset:
#         print(f'x shape: {x.shape}')
#         print(f'y shape: {y.shape}')
#         print(f'z shape: {z.shape}')
#         break  # 打印一次后跳出循环
#     for x, y, z in train_loader:
#         print(f'x shape: {x.shape}')
#         print(f'y shape: {y.shape}')
#         print(f'z shape: {z.shape}')
#         break  # 打印一次后跳出循环
        
    nwp_sample = torch.rand(1, 1, 84, 64, 64)
    # set model
    model = ConvLSTM_Two(input_data=args.input_data,
                            window_size=args.window_size,
                            input_dim=nwp_sample.shape[2],
                            hidden_dim=args.hidden_dim,
                            kernel_size=(args.kernel_size,args.kernel_size),  # hotfix: only supports single tuple of size 2
                            num_layers=args.num_layers,
                            num_classes=args.num_classes,
                            batch_first=True,
                            bias=True,
                            return_all_layers=False)
    
    criterion = CrossEntropyLoss_Two(args=args,
                                    device=device,
                                    num_classes=args.num_classes,
                                    experiment_name=experiment_name)
#     if wandb.config['finetune']:
#         checkpoint = torch.load(model_path)
#         model.load_state_dict(checkpoint['model'], strict=True)
        
    dice_criterion = None
    normalization = None
    
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                              weight_decay=args.wd, nesterov=args.nesterov)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr,
                                  alpha=0.9, eps=1e-6)
    elif args.optimizer == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # scheduler = optim.lr_scheduler.StepLR(optimizer, args.wd_ep)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=wandb.config['lr_decay_rate'], patience=10, threshold=0.0001)
    
    nims_trainer = NIMSTrainer_Germnay_Two(wandb, model, criterion, dice_criterion, optimizer, scheduler, device,
                                train_loader, valid_loader, test_loader, experiment_name,
                                args, normalization=normalization)
    # Train model
    nims_trainer.train()

In [3]:
sweep_config = {
    'method': 'random'
    }

metric = {
    'name': 'loss',
    'goal': 'minimize'   
    }

sweep_config['metric'] = metric

sweep_config['parameters'] = {}

# 常量型超参数
sweep_config['parameters'].update({
    'seed': {'value': 11611801},      
    'n_epochs': {'value': 1000},
})
    
# 离散型超参数
sweep_config['parameters'].update({
    'train_batch_size': {
        'values': [256]
    },
    'early_stop': {
        'values': [50]
    },
    'tolerate_loss':{
        'values':[1e-3]
    },
    'CELWeight': {
        'values': [0,3,4]
    },
    'seq_length':{
         'values': [1, 2]
    },
    'SFLoss': {
        'values': [0.4, 52]
    },
        
})

    
# 连续型超参数
sweep_config['parameters'].update({
    'learning_rate': {
        'distribution': 'uniform',
        'min': 1e-6,
        'max': 1e-1
      },
    'alpha': {
        'distribution': 'uniform',
        'min': 1e-1,
        'max': 1e2,
      },
    'lr_decay_rate': {
        'distribution': 'uniform',
        'min': 1e-1,
        'max': 8e-1,
    },
})

In [4]:
def training():
     # 初始化wandb
    ############################################################################################
    nowtime = datetime.datetime.now().strftime('%Y_%m_%d_%H%M%S')
    wandb.init(
      project='PRBenchTest_ConvLSTM_GBA_Tansfer_2015_2022', 
      name=nowtime, 
      )
    config = wandb.config
    ############################################################################################

    # 构建命令行参数
    sys.argv = [
    '--model', 'convlstm',
    '--device', '0',
    '--seed', str(config['seed']),
    '--input_data', 'gdaps_kim',
    '--num_epochs', str(config['n_epochs']),
    '--rain_thresholds', '0.4', '52.0', '100.0',
    '--log_dir', 'logs/logs_1106_China',
    '--batch_size', str(config['train_batch_size']),
    '--lr', str(config['learning_rate']),
    '--use_two',
    '--seq_length', str(config['seq_length']),
    '--loss', 'ce+mse',
    '--SFLoss', str(config['SFLoss']),
    '--alpha', str(config['alpha']),
    '--kernel_size', '3',
    '--weight_version', str(config['CELWeight']),
    '--wd_ep', '100',
    '--custom_name', 'PRBenchTest_ConvLSTM_GBA_TSF_2015_2022'
    ]
    # 模型训练
    args = parse_args(sys.argv)
    main(args, wandb)
    # best_model, best_loss = trainer(train_loader, valid_loader, model, wandb, device)
    
    # 保存模型
    # if best_loss<0.3:
    #     save_name = os.path.join(config['model_save_dir'], nowtime + '.ckpt')
    #     torch.save(best_model.state_dict(), save_name)
    #     arti_code = wandb.Artifact('ipynb', type='code')
    #     arti_code.add_file(os.path.join(config['root'], 'SURROGATE_TRAINING_WANDB.ipynb'))
    #     arti_code.add_file(os.path.join(config['root'], 'LSMDataset.py'))
    #     arti_code.add_file(os.path.join(config['root'], 'LSMLoss.py'))
    #     arti_code.add_file(os.path.join(config['root'], 'LSMTransformer.py'))
                                              
    #     # arti_model = wandb.Artifact('model', type='model')
    #     # arti_model.add_file(save_name)
    #     wandb.log_artifact(arti_code)
    #     wandb.log_artifact(arti_model)
    wandb.finish()

In [5]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpanj1018[0m ([33mpanj1018-hong-kong-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
sweep_id = wandb.sweep(sweep_config, project='PRBenchTest_ConvLSTM_GBA_Transfer_2015_2022')
print(sweep_id)

Create sweep with ID: jne96bg4
Sweep URL: https://wandb.ai/panj1018-hong-kong-university-of-science-and-technology/PRBenchTest_ConvLSTM_GBA_Transfer_2015_2022/sweeps/jne96bg4
jne96bg4


In [7]:
# wandb.agent(project='PRBenchTest_ConvLSTM_GBA_2015_2022', sweep_id='mli9ek0t', function=training, count=50)
wandb.agent(sweep_id, training, count=1)

[34m[1mwandb[0m: Agent Starting Run: p6xplvu3 with config:
[34m[1mwandb[0m: 	CELWeight: 3
[34m[1mwandb[0m: 	SFLoss: 0.4
[34m[1mwandb[0m: 	alpha: 74.6301536585969
[34m[1mwandb[0m: 	early_stop: 50
[34m[1mwandb[0m: 	learning_rate: 0.08271984100560713
[34m[1mwandb[0m: 	lr_decay_rate: 0.4592144601780157
[34m[1mwandb[0m: 	n_epochs: 1000
[34m[1mwandb[0m: 	seed: 11611801
[34m[1mwandb[0m: 	seq_length: 1
[34m[1mwandb[0m: 	tolerate_loss: 0.001
[34m[1mwandb[0m: 	train_batch_size: 256


Seed set to 11611801
  train_dataset = CustomTensorDataset_GBA_seq_gap(torch.from_numpy(trn_x_1),torch.from_numpy(trn_x_2), \


Load datasets in CPU memory successfully!
################################################################################
Omitting intermediate evaluation on test set


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
