In [1]:
import numpy as np
import argparse
import time
import datetime
import datautils
from utils import init_dl_program,dict2class
from infots import InfoTS as MetaInfoTS
# from baseline import InfoTS as baseInfoTS
from models.augclass import *

In [2]:
all_augs = [jitter(), scaling(), time_warp(), window_slice(), window_warp(),cutout(),subsequence()]

paras = {
    'dataset':'ETTh1', #electricity
    'archive':'forecast_csv_univar',
    'gpu':0,
    'seed':42,
    'max_threads':12,
    'log_file':'forecast_csv',
    'eval':True,
    'batch_size':128,
    'lr':0.001,
    'beta':0.5,
    'repr_dims':320,
    'max_train_length':2048,
    'iters':4000,
    'epochs':400,
    'dropout':0.1,
    'split_number':8,
    'label_ratio':1.0,
    'meta_beta':0.1,
    'aug':None,
    'aug_p1':0.7,
    'meta_lr':0.03
}

parser = argparse.ArgumentParser()
args = dict2class(**paras)


device = init_dl_program(args.gpu, seed=args.seed, max_threads=args.max_threads)

if args.archive == 'forecast_csv':
    task_type = 'forecasting'
    data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = datautils.load_forecast_csv(args.dataset)
    train_data = data[:, train_slice]
elif args.archive == 'forecast_csv_univar':
    task_type = 'forecasting'
    data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = datautils.load_forecast_csv(args.dataset, univar=True)
    train_data = data[:, train_slice]

valid_dataset = (data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols)

if train_data.shape[0] == 1:
    train_slice_number = int(train_data.shape[1] / args.max_train_length)
    if train_slice_number < args.batch_size:
        args.batch_size = train_slice_number
else:
    if train_data.shape[0] < args.batch_size:
        args.batch_size = train_data.shape[0]
print("Arguments:", str(args))

config = dict(
    batch_size=args.batch_size,
    lr=args.lr,
    meta_lr = args.meta_lr,
    output_dims=args.repr_dims,
    max_train_length=args.max_train_length,
    input_dims=train_data.shape[-1],
    device=device,
    num_cls =  args.batch_size,
    dropout = args.dropout,
)

t = time.time()


Arguments: <utils.dict2class object at 0x7f2d2ac3d040>


  dt.weekofyear.to_numpy(),


In [3]:
model = MetaInfoTS(
    aug_p1= args.aug_p1,
    eval_every_epoch =1,
    **config
)

In [4]:
res = model.fit(train_data,
     task_type = task_type,
     meta_beta=args.meta_beta,
     n_epochs=args.epochs,
     n_iters=args.iters,
     beta = args.beta,
     verbose=False,
     miverbose=True,
     split_number=args.split_number,
     valid_dataset = valid_dataset,
     train_labels= None
    )

v,f, mse, mae = res
mi_info = 'v %.5f ,f %.5f,mse %.5f  mae%.5f' % (v,f,mse[-1], mae[-1])

print(mi_info)

t = time.time() - t
print(f"\nTraining time: {datetime.timedelta(seconds=t)}\n")
print("Finished.")

train_data_label [[(tensor([[ 0.0000, -1.6613,  0.4913,  ...,  0.1344, -0.0335,  1.4606],
        [ 0.0000, -1.5169,  0.4913,  ...,  0.1344, -0.0335,  1.1615],
        [ 0.0000, -1.3724,  0.4913,  ...,  0.1344, -0.0335,  1.1615],
        ...,
        [ 0.0000,  1.3724, -0.5080,  ...,  0.7106,  0.8279, -0.2490],
        [ 0.0000,  1.5169, -0.5080,  ...,  0.7106,  0.8279, -0.2030],
        [ 0.0000,  1.6613, -0.5080,  ...,  0.7106,  0.8279, -0.2490]]),), (tensor(0, device='cuda:0'),)], [(tensor([[ 0.0000, -1.6613, -0.0083,  ...,  0.7106,  0.8279, -0.3640],
        [ 0.0000, -1.5169, -0.0083,  ...,  0.7106,  0.8279, -0.7396],
        [ 0.0000, -1.3724, -0.0083,  ...,  0.7106,  0.8279, -0.7626],
        ...,
        [ 0.0000,  1.3724, -1.0077,  ...,  1.5749,  1.6893, -1.0617],
        [ 0.0000,  1.5169, -1.0077,  ...,  1.5749,  1.6893, -0.9543],
        [ 0.0000,  1.6613, -1.0077,  ...,  1.5749,  1.6893, -0.9926]]),), (tensor(1, device='cuda:0'),)], [(tensor([[ 0.0000, -1.6613, -0.5080,  .

  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


[2022-11-24 09:26:53] [32mIntermediate result: 0.19142497797103258  (Index 0)[0m
{24: {'norm': {'MSE': 0.040736637257200975, 'MAE': 0.1506883407138316}}}
epoch_time 206.82215690612793


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


[2022-11-24 09:26:55] [32mIntermediate result: 0.19088859946156023  (Index 1)[0m
{24: {'norm': {'MSE': 0.04048341726338092, 'MAE': 0.1504051821981793}}}
epoch_time 131.40344619750977


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


[2022-11-24 09:26:56] [32mIntermediate result: 0.19409785891755754  (Index 2)[0m
{24: {'norm': {'MSE': 0.04169175831676438, 'MAE': 0.15240610060079315}}}
epoch_time 206.06708526611328


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


[2022-11-24 09:26:57] [32mIntermediate result: 0.1935836115578598  (Index 3)[0m
{24: {'norm': {'MSE': 0.041481061583559145, 'MAE': 0.15210254997430064}}}
epoch_time 136.63887977600098


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


[2022-11-24 09:26:58] [32mIntermediate result: 0.19327526186452815  (Index 4)[0m
{24: {'norm': {'MSE': 0.04134852214534293, 'MAE': 0.15192673971918522}}}
epoch_time 204.31780815124512


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


[2022-11-24 09:27:00] [32mIntermediate result: 0.19308651346479172  (Index 5)[0m
{24: {'norm': {'MSE': 0.04126016471612341, 'MAE': 0.15182634874866832}}}
epoch_time 138.31758499145508


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


[2022-11-24 09:27:01] [32mIntermediate result: 0.1930743682925463  (Index 6)[0m
{24: {'norm': {'MSE': 0.041231722545433634, 'MAE': 0.15184264574711265}}}
epoch_time 209.79547500610352


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


[2022-11-24 09:27:03] [32mIntermediate result: 0.19317335046479567  (Index 7)[0m
{24: {'norm': {'MSE': 0.0412410949869268, 'MAE': 0.15193225547786887}}}
epoch_time 144.36769485473633
[2022-11-24 09:27:04] [32mIntermediate result: 0.19952891722577243  (Index 8)[0m
{24: {'norm': {'MSE': 0.04353545400023141, 'MAE': 0.15599346322554103}}}
epoch_time 212.8610610961914
[2022-11-24 09:27:06] [32mIntermediate result: 0.1998212248088919  (Index 9)[0m
{24: {'norm': {'MSE': 0.043608031561139314, 'MAE': 0.15621319324775257}}}
epoch_time 144.36578750610352
[2022-11-24 09:27:07] [32mIntermediate result: 0.2001300536277684  (Index 10)[0m
{24: {'norm': {'MSE': 0.04368522815590506, 'MAE': 0.15644482547186334}}}
epoch_time 206.3009738922119
[2022-11-24 09:27:09] [32mIntermediate result: 0.20045916161260063  (Index 11)[0m
{24: {'norm': {'MSE': 0.04376808499192859, 'MAE': 0.15669107662067205}}}
epoch_time 136.13414764404297
[2022-11-24 09:27:10] [32mIntermediate result: 0.200756313869082  (Inde

[2022-11-24 09:27:59] [32mIntermediate result: 0.19601099828683785  (Index 51)[0m
{24: {'norm': {'MSE': 0.041696430981165815, 'MAE': 0.15431456730567203}}}
epoch_time 133.5000991821289
[2022-11-24 09:28:00] [32mIntermediate result: 0.19614035931352342  (Index 52)[0m
{24: {'norm': {'MSE': 0.04173362384147594, 'MAE': 0.15440673547204747}}}
epoch_time 193.19820404052734
[2022-11-24 09:28:02] [32mIntermediate result: 0.19626721446747658  (Index 53)[0m
{24: {'norm': {'MSE': 0.041770021045855796, 'MAE': 0.15449719342162077}}}
epoch_time 133.28027725219727
[2022-11-24 09:28:03] [32mIntermediate result: 0.19639259603911055  (Index 54)[0m
{24: {'norm': {'MSE': 0.04180601951880835, 'MAE': 0.1545865765203022}}}
epoch_time 193.67361068725586
[2022-11-24 09:28:04] [32mIntermediate result: 0.19651381802729853  (Index 55)[0m
{24: {'norm': {'MSE': 0.041840717573304216, 'MAE': 0.15467310045399432}}}
epoch_time 133.45813751220703
[2022-11-24 09:28:06] [32mIntermediate result: 0.19663163237699

{24: {'norm': {'MSE': 0.03980658812181716, 'MAE': 0.15111485893860171}}}
epoch_time 142.00735092163086
[2022-11-24 09:28:57] [32mIntermediate result: 0.19098613912130108  (Index 96)[0m
{24: {'norm': {'MSE': 0.03982586494026709, 'MAE': 0.15116027418103398}}}
epoch_time 211.28082275390625
[2022-11-24 09:28:59] [32mIntermediate result: 0.1910498169171501  (Index 97)[0m
{24: {'norm': {'MSE': 0.03984483448495355, 'MAE': 0.15120498243219654}}}
epoch_time 131.5004825592041
[2022-11-24 09:29:00] [32mIntermediate result: 0.19111211883142376  (Index 98)[0m
{24: {'norm': {'MSE': 0.03986336892234471, 'MAE': 0.15124874990907905}}}
epoch_time 204.62274551391602
[2022-11-24 09:29:01] [32mIntermediate result: 0.1911730942240945  (Index 99)[0m
{24: {'norm': {'MSE': 0.03988151348083012, 'MAE': 0.1512915807432644}}}
epoch_time 143.5525417327881
[2022-11-24 09:29:02] [32mIntermediate result: 0.19123369671862617  (Index 100)[0m
{24: {'norm': {'MSE': 0.03989958950927779, 'MAE': 0.15133410720934837

{24: {'norm': {'MSE': 0.04076666486714311, 'MAE': 0.15327746250757357}}}
epoch_time 122.4982738494873
[2022-11-24 09:29:54] [32mIntermediate result: 0.19412563775062494  (Index 140)[0m
{24: {'norm': {'MSE': 0.04079147883002485, 'MAE': 0.1533341589206001}}}
epoch_time 200.29711723327637
[2022-11-24 09:29:55] [32mIntermediate result: 0.19420702316350194  (Index 141)[0m
{24: {'norm': {'MSE': 0.04081616629119649, 'MAE': 0.15339085687230544}}}
epoch_time 133.02135467529297
[2022-11-24 09:29:56] [32mIntermediate result: 0.19428547639717705  (Index 142)[0m
{24: {'norm': {'MSE': 0.040840105814425175, 'MAE': 0.15344537058275187}}}
epoch_time 204.1325569152832
[2022-11-24 09:29:58] [32mIntermediate result: 0.1943605966660148  (Index 143)[0m
{24: {'norm': {'MSE': 0.04086314216804755, 'MAE': 0.15349745449796726}}}
epoch_time 148.43320846557617
[2022-11-24 09:29:59] [32mIntermediate result: 0.19443216486369624  (Index 144)[0m
{24: {'norm': {'MSE': 0.04088522374883444, 'MAE': 0.15354694111

{24: {'norm': {'MSE': 0.0414872044418687, 'MAE': 0.15489192294914705}}}
epoch_time 128.96490097045898
[2022-11-24 09:30:51] [32mIntermediate result: 0.19643198720520097  (Index 184)[0m
{24: {'norm': {'MSE': 0.04150391594899266, 'MAE': 0.1549280712562083}}}
epoch_time 194.83160972595215
[2022-11-24 09:30:52] [32mIntermediate result: 0.19648643355432382  (Index 185)[0m
{24: {'norm': {'MSE': 0.04152113895885855, 'MAE': 0.15496529459546526}}}
epoch_time 133.2550048828125
[2022-11-24 09:30:53] [32mIntermediate result: 0.19654243848683506  (Index 186)[0m
{24: {'norm': {'MSE': 0.04153888612567612, 'MAE': 0.15500355236115892}}}
epoch_time 203.23896408081055
[2022-11-24 09:30:54] [32mIntermediate result: 0.19659703865402003  (Index 187)[0m
{24: {'norm': {'MSE': 0.04155616262646437, 'MAE': 0.15504087602755567}}}
epoch_time 142.0588493347168
[2022-11-24 09:30:56] [32mIntermediate result: 0.19664981735063045  (Index 188)[0m
{24: {'norm': {'MSE': 0.0415728178860146, 'MAE': 0.1550769994646

{24: {'norm': {'MSE': 0.03998618989707565, 'MAE': 0.1517401195978757}}}
epoch_time 121.42181396484375
[2022-11-24 09:31:46] [32mIntermediate result: 0.19176607111381344  (Index 228)[0m
{24: {'norm': {'MSE': 0.039999424704477184, 'MAE': 0.15176664640933626}}}
epoch_time 196.98071479797363
[2022-11-24 09:31:48] [32mIntermediate result: 0.19180520242217916  (Index 229)[0m
{24: {'norm': {'MSE': 0.04001252430249736, 'MAE': 0.1517926781196818}}}
epoch_time 122.21646308898926
[2022-11-24 09:31:49] [32mIntermediate result: 0.19184313793904906  (Index 230)[0m
{24: {'norm': {'MSE': 0.04002533383622731, 'MAE': 0.15181780410282175}}}
epoch_time 194.90742683410645
[2022-11-24 09:31:50] [32mIntermediate result: 0.19188057817359672  (Index 231)[0m
{24: {'norm': {'MSE': 0.04003802793063364, 'MAE': 0.15184255024296306}}}
epoch_time 122.10345268249512
[2022-11-24 09:31:51] [32mIntermediate result: 0.19191696817832607  (Index 232)[0m
{24: {'norm': {'MSE': 0.04005045441454849, 'MAE': 0.151866513

{24: {'norm': {'MSE': 0.04059040334334551, 'MAE': 0.15295263919191204}}}
epoch_time 129.8542022705078
[2022-11-24 09:32:43] [32mIntermediate result: 0.19359635443344422  (Index 272)[0m
{24: {'norm': {'MSE': 0.040607249749759676, 'MAE': 0.15298910468368454}}}
epoch_time 216.01128578186035
[2022-11-24 09:32:44] [32mIntermediate result: 0.1936498593042724  (Index 273)[0m
{24: {'norm': {'MSE': 0.04062417170555138, 'MAE': 0.153025687598721}}}
epoch_time 138.6723518371582
[2022-11-24 09:32:45] [32mIntermediate result: 0.19370324249361195  (Index 274)[0m
{24: {'norm': {'MSE': 0.04064106757446903, 'MAE': 0.15306217491914292}}}
epoch_time 192.98458099365234
[2022-11-24 09:32:47] [32mIntermediate result: 0.19375677884026246  (Index 275)[0m
{24: {'norm': {'MSE': 0.040658046282265596, 'MAE': 0.15309873255799686}}}
epoch_time 135.63060760498047
[2022-11-24 09:32:48] [32mIntermediate result: 0.1938119066274101  (Index 276)[0m
{24: {'norm': {'MSE': 0.040675574335884944, 'MAE': 0.15313633229

{24: {'norm': {'MSE': 0.04129273116965588, 'MAE': 0.15441890594293084}}}
epoch_time 130.31339645385742
[2022-11-24 09:33:43] [32mIntermediate result: 0.1957542443828064  (Index 316)[0m
{24: {'norm': {'MSE': 0.04130640190139725, 'MAE': 0.15444784248140914}}}
epoch_time 183.69269371032715
[2022-11-24 09:33:44] [32mIntermediate result: 0.19579726229161032  (Index 317)[0m
{24: {'norm': {'MSE': 0.041320192132807826, 'MAE': 0.1544770701588025}}}
epoch_time 112.39790916442871
[2022-11-24 09:33:45] [32mIntermediate result: 0.19584067857018686  (Index 318)[0m
{24: {'norm': {'MSE': 0.04133409807730736, 'MAE': 0.1545065804928795}}}
epoch_time 207.8261375427246
[2022-11-24 09:33:47] [32mIntermediate result: 0.19588439927855317  (Index 319)[0m
{24: {'norm': {'MSE': 0.04134811284420846, 'MAE': 0.15453628643434472}}}
epoch_time 125.83041191101074
[2022-11-24 09:33:48] [32mIntermediate result: 0.19592821153800757  (Index 320)[0m
{24: {'norm': {'MSE': 0.04136212239648824, 'MAE': 0.15456608914

{24: {'norm': {'MSE': 0.04184768406350048, 'MAE': 0.15550792168720215}}}
epoch_time 127.44784355163574
[2022-11-24 09:34:40] [32mIntermediate result: 0.1974061182161577  (Index 360)[0m
{24: {'norm': {'MSE': 0.04186526646779241, 'MAE': 0.1555408517483653}}}
epoch_time 203.22465896606445
[2022-11-24 09:34:41] [32mIntermediate result: 0.19745747979965445  (Index 361)[0m
{24: {'norm': {'MSE': 0.041883120556835615, 'MAE': 0.15557435924281884}}}
epoch_time 128.37958335876465
[2022-11-24 09:34:42] [32mIntermediate result: 0.19750918362052028  (Index 362)[0m
{24: {'norm': {'MSE': 0.041901111533015706, 'MAE': 0.15560807208750457}}}
epoch_time 192.9178237915039
[2022-11-24 09:34:44] [32mIntermediate result: 0.19756137738216076  (Index 363)[0m
{24: {'norm': {'MSE': 0.041919286540746885, 'MAE': 0.15564209084141387}}}
epoch_time 136.0476016998291
[2022-11-24 09:34:45] [32mIntermediate result: 0.19761356526997836  (Index 364)[0m
{24: {'norm': {'MSE': 0.04193745985496046, 'MAE': 0.155676105

ValueError: not enough values to unpack (expected 4, got 2)