In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import DataLoader
from datasets import traffic_dataset
from utils import *
import argparse
import yaml
import time
from maml import STMAML
from tqdm import tqdm

## Set the predict target dataset

In [None]:
predict_target_dataset = 'pems-bay'

# the predict_target_dataset can be 1 of the 6 datasets mentioned in our report
#'metr-la', 'pems-bay', 'shenzhen', 'chengdu_m', 'pems04', 'pems08'

# if you want to change the reference source datasets, you have to change the 'data_keys' in config.yaml 2nd row

In [None]:
parser = argparse.ArgumentParser(description='MAML-based')
parser.add_argument('--config_filename', default='config.yaml', type=str,
                        help='Configuration filename for restoring the model.')
parser.add_argument('--test_dataset', default='metr-la', type=str)
parser.add_argument('--source_epochs', default=200, type=int)
parser.add_argument('--source_lr', default=1e-2, type=float)
parser.add_argument('--target_epochs', default=120, type=int)
parser.add_argument('--target_lr', default=1e-2, type=float)
parser.add_argument('--batch_size', default=8, type=int)
parser.add_argument('--meta_dim', default=16, type=int)
parser.add_argument('--target_days', default=3, type=int)
parser.add_argument('--model', default='GRU', type=str)
parser.add_argument('--loss_lambda', default=1.5, type=float)
parser.add_argument('--memo', default='revise', type=str)

# 手動設定 args
args = parser.parse_args(args=['--test_dataset', predict_target_dataset,  '--model', 'GRU'])

print(time.strftime('%Y-%m-%d %H:%M:%S'), "meta_dim = ", args.meta_dim,"target_days = ", args.target_days)


## Main Process

In [2]:
if torch.cuda.is_available():
    args.device = torch.device('cuda')
    print("INFO: GPU")
else:
    args.device = torch.device('cpu')
    print("INFO: CPU")

with open(args.config_filename) as f:
    config = yaml.load(f,Loader=yaml.FullLoader)

torch.manual_seed(7)

data_args, task_args, model_args = config['data'], config['task'], config['model']
    
model_args['meta_dim'] = args.meta_dim
model_args['loss_lambda'] = args.loss_lambda
    
source_dataset = traffic_dataset(data_args, task_args, "source", test_data=args.test_dataset)

model = STMAML(data_args, task_args, model_args, model=args.model).to(device=args.device)
    
optimizer = torch.optim.Adam(model.parameters(), lr=args.source_lr)
loss_criterion = nn.MSELoss()

source_training_losses, target_training_losses = [], []
best_result = ''
min_MAE = 10000000

for epoch in tqdm(range(args.source_epochs)):
    # Meta-Train
    start_time = time.time()
    spt_task_data, spt_task_A, qry_task_data, qry_task_A = source_dataset.get_maml_task_batch(task_args['task_num'])
    loss = model.meta_train_revise(spt_task_data, spt_task_A, qry_task_data, qry_task_A)

    # loss = model.meta_train(spt_task_data, spt_task_A, qry_task_data, qry_task_A)
    end_time = time.time()
    if epoch % 20 == 0:
        print("[Source Train] epoch #{}/{}: loss is {}, training time is {}".format(epoch+1, args.source_epochs, loss, end_time-start_time))

print("Source dataset meta-train finish.")

target_dataset = traffic_dataset(data_args, task_args, "target", test_data=args.test_dataset, target_days=args.target_days)
target_dataloader = DataLoader(target_dataset, batch_size=task_args['batch_size'], shuffle=True, num_workers=8, pin_memory=True)
test_dataset = traffic_dataset(data_args, task_args, "test", test_data=args.test_dataset)
test_dataloader = DataLoader(test_dataset, batch_size=task_args['test_batch_size'], shuffle=True, num_workers=8, pin_memory=True)

model.finetuning(target_dataloader, args.target_epochs)

2023-12-04 17:21:09 meta_dim =  16 target_days =  3
INFO: GPU
[INFO] source dataset: ['metr-la']
[INFO] Dataset init finished!
loss_lambda =  1.5
tp is True.
sp is True.
MAML Model: GRU
model params:  17598


  0%|          | 1/200 [00:01<03:27,  1.04s/it]

[Source Train] epoch #1/200: loss is 4357.0009765625, training time is 1.0448005199432373


 10%|█         | 21/200 [00:10<01:27,  2.04it/s]

[Source Train] epoch #21/200: loss is 17.17517852783203, training time is 0.48419809341430664


 20%|██        | 41/200 [00:20<01:15,  2.09it/s]

[Source Train] epoch #41/200: loss is 25.741512298583984, training time is 0.48328304290771484


 30%|███       | 61/200 [00:29<01:04,  2.14it/s]

[Source Train] epoch #61/200: loss is 6.981149673461914, training time is 0.4395253658294678


 40%|████      | 81/200 [00:38<00:53,  2.22it/s]

[Source Train] epoch #81/200: loss is 41.792945861816406, training time is 0.4430272579193115


 50%|█████     | 101/200 [00:47<00:44,  2.22it/s]

[Source Train] epoch #101/200: loss is 9.547203063964844, training time is 0.4718587398529053


 60%|██████    | 121/200 [00:56<00:37,  2.13it/s]

[Source Train] epoch #121/200: loss is 6.5052947998046875, training time is 0.49070167541503906


 70%|███████   | 141/200 [01:05<00:26,  2.23it/s]

[Source Train] epoch #141/200: loss is 32.62919998168945, training time is 0.44253039360046387


 80%|████████  | 161/200 [01:15<00:17,  2.22it/s]

[Source Train] epoch #161/200: loss is 7.119198799133301, training time is 0.4568650722503662


 90%|█████████ | 181/200 [01:24<00:08,  2.16it/s]

[Source Train] epoch #181/200: loss is 51.25939178466797, training time is 0.454970121383667


100%|██████████| 200/200 [01:33<00:00,  2.15it/s]


Source dataset meta-train finish.
[INFO] target dataset: ['pems-bay']
[INFO] Dataset init finished!
[INFO] test dataset: ['pems-bay']




[INFO] Dataset init finished!


  1%|          | 1/120 [00:12<25:40, 12.95s/it]

[Target Fine-tune] epoch #1/120: loss is 8.487161307475146, fine-tuning time is 12.945801496505737


  9%|▉         | 11/120 [02:12<21:43, 11.96s/it]

[Target Fine-tune] epoch #11/120: loss is 6.7193237441427565, fine-tuning time is 11.895766735076904


 18%|█▊        | 21/120 [04:11<19:33, 11.85s/it]

[Target Fine-tune] epoch #21/120: loss is 6.677469818732318, fine-tuning time is 11.648711204528809


 26%|██▌       | 31/120 [06:09<17:32, 11.82s/it]

[Target Fine-tune] epoch #31/120: loss is 6.8497141890666065, fine-tuning time is 11.679858684539795


 34%|███▍      | 41/120 [08:11<16:00, 12.16s/it]

[Target Fine-tune] epoch #41/120: loss is 6.6977438751389, fine-tuning time is 12.056032180786133


 42%|████▎     | 51/120 [10:12<13:57, 12.13s/it]

[Target Fine-tune] epoch #51/120: loss is 6.897759929825278, fine-tuning time is 12.258360147476196


 51%|█████     | 61/120 [12:16<12:01, 12.22s/it]

[Target Fine-tune] epoch #61/120: loss is 6.715371214642245, fine-tuning time is 12.063392877578735


 59%|█████▉    | 71/120 [14:17<09:57, 12.18s/it]

[Target Fine-tune] epoch #71/120: loss is 6.570450925125796, fine-tuning time is 12.217651605606079


 68%|██████▊   | 81/120 [16:18<07:52, 12.12s/it]

[Target Fine-tune] epoch #81/120: loss is 9.778453651245902, fine-tuning time is 12.179498672485352


 76%|███████▌  | 91/120 [18:19<05:47, 11.99s/it]

[Target Fine-tune] epoch #91/120: loss is 6.561017151089276, fine-tuning time is 11.620487928390503


 84%|████████▍ | 101/120 [20:18<03:45, 11.88s/it]

[Target Fine-tune] epoch #101/120: loss is 8.590373269950643, fine-tuning time is 11.76212215423584


 92%|█████████▎| 111/120 [22:17<01:46, 11.81s/it]

[Target Fine-tune] epoch #111/120: loss is 6.409476311417187, fine-tuning time is 11.709235429763794


100%|██████████| 120/120 [24:03<00:00, 12.03s/it]


## Forward Predict

In [3]:
from utils import *

outputs, y_label = model.evaluate(test_dataloader)

step 0 outputs shape = torch.Size([128, 325, 6])
step 1 outputs shape = torch.Size([256, 325, 6])
step 2 outputs shape = torch.Size([384, 325, 6])
step 3 outputs shape = torch.Size([512, 325, 6])
step 4 outputs shape = torch.Size([640, 325, 6])
step 5 outputs shape = torch.Size([768, 325, 6])
step 6 outputs shape = torch.Size([896, 325, 6])
step 7 outputs shape = torch.Size([1024, 325, 6])
step 8 outputs shape = torch.Size([1152, 325, 6])
step 9 outputs shape = torch.Size([1280, 325, 6])
step 10 outputs shape = torch.Size([1408, 325, 6])
step 11 outputs shape = torch.Size([1536, 325, 6])
step 12 outputs shape = torch.Size([1664, 325, 6])
step 13 outputs shape = torch.Size([1792, 325, 6])
step 14 outputs shape = torch.Size([1920, 325, 6])
step 15 outputs shape = torch.Size([2048, 325, 6])
step 16 outputs shape = torch.Size([2176, 325, 6])
step 17 outputs shape = torch.Size([2304, 325, 6])
step 18 outputs shape = torch.Size([2432, 325, 6])
step 19 outputs shape = torch.Size([2560, 325, 6

## Evaluate Predict Result

In [4]:
from utils import *
result = metric_func(pred=outputs, y=y_label, times=6)

result_print(result, info_name='Evaluate')

print(args.memo)

Help
metric | pred shape: (10407, 6, 325)  y shape: (10407, 6, 325)
 MAE: 1.311/ 1.535/ 1.776/ 1.906/ 2.025/ 2.156
MAPE: 2.565/ 3.089/ 3.640/ 4.017/ 4.373/ 4.741
RMSE: 2.023/ 2.732/ 3.385/ 3.905/ 4.349/ 4.741
---------------------------------------
revise


## Previous Experoment Result