In [1]:
import torch.optim as optim
from models.nn import RecurrentNeuralNetwork
from lib.loader import ProcessedCsvDataset, ManyDataset, Many2OneDataset, get_loader
from lib.utils import print_step

In [2]:
dst = ProcessedCsvDataset()

dst.make_val_from_test()
train_dataset = Many2OneDataset(dst.train_feature, dst.train_label)
test_dataset = Many2OneDataset(dst.test_feature, dst.test_label)
val_dataset = Many2OneDataset(dst.val_feature, dst.val_label)

train_loader = get_loader(train_dataset, seq_first=True, batch_size=128)
test_loader = get_loader(test_dataset, seq_first=True, batch_size=1024, shuffle=False)
val_loader = get_loader(val_dataset, seq_first=True, batch_size=1024, shuffle=False)

In [None]:
model = RecurrentNeuralNetwork(dst.num_features, 50)
    
# optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
adam = optim.Adam(model.parameters(), lr=1e-2)

model.fit(train_loader, optimizer, callback=print_step,
          val_loader=val_loader, scheduler=optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True))



[00 - 00000] 1.19439 0.85882
[00 - 01000] 0.52809 0.84351
[00 - 02000] 0.81401 0.84219
[00 - 03000] 0.92246 0.84166
[00 - 04000] 0.75788 0.84158
[00 - 05000] 0.76559 0.84131
[00 - 06000] 1.02900 0.84137
[00 - 07000] 0.68970 0.84117
[00 - 08000] 0.83326 0.84123
[00 - 09000] 0.40427 0.84137
[00 - 10000] 0.83581 0.84144
[01 - 00000] 0.53403 0.84135
[01 - 01000] 0.73312 0.84110
[01 - 02000] 0.92837 0.84107
[01 - 03000] 0.69749 0.84118
[01 - 04000] 0.71622 0.84138
[01 - 05000] 0.54107 0.84161
[01 - 06000] 0.74315 0.84147
[01 - 07000] 0.94082 0.84119
[01 - 08000] 1.16982 0.84152
[01 - 09000] 0.92298 0.84163
[01 - 10000] 0.80814 0.84124
[02 - 00000] 0.72532 0.84143
[02 - 01000] 0.78738 0.84139
[02 - 02000] 0.63868 0.84119
Epoch    24: reducing learning rate of group 0 to 1.0000e-04.


In [4]:
print(model.validate(test_loader))
print(model.validate(train_loader))

0.8294012955245592
0.8264085198218387


| No. | architcture | optimizer | test acc | train acc |
 -----|-------------|-----------|----------|-----------
 1 | 50 | sgd 1e-3 | 82.94 | 82.64
 2 | 50 | adam 1e-2