In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
import random

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import Subset

from models.nn import FullyConnectedNetwork
from lib.loader import ProcessedCsvDataset, One2OneDataset, get_loader
from lib.utils import print_step

In [3]:
# Load data
# Configuration: values the experiments in the results table vary or depend on.
DATA_DIR = 'data'
BATCH_SIZE = 1024
SEED = 42

# Seed every RNG that could be involved (val-index sampling, loader shuffling,
# weight init in the next cell) so Restart & Run All reproduces the numbers.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dst = ProcessedCsvDataset(root_dir=DATA_DIR, normalized=True)

# Dataset and dataloader
train_dataset = One2OneDataset(dst.train_feature, dst.train_label)
test_dataset = One2OneDataset(dst.test_feature, dst.test_label)
train_loader = get_loader(train_dataset, batch_size=BATCH_SIZE)
test_loader = get_loader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Randomly sample val set
# FIXME(review): the validation indices are drawn from the *test* set, and
# val_loader is handed to model.fit together with a ReduceLROnPlateau scheduler
# (which presumably steps on the validation metric). Test data therefore
# influences training, and the reported test score is optimistic. Prefer carving
# the val set out of the training data (e.g. torch.utils.data.random_split) once
# the split size used by make_val_from_test is known.
val_index = dst.make_val_from_test()
val_dataset = Subset(test_dataset, val_index)
# val_loader is (presumably) only used for evaluation, so mirror test_loader and
# keep a fixed order instead of shuffling.
val_loader = get_loader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# Model
# Hyperparameters varied across the runs recorded in the results table.
HIDDEN_UNITS = [100, 50, 100]
OPTIMIZER_NAME = 'adam'  # set to 'sgd' to match the optimizer of the table rows
LEARNING_RATE = 1e-3
EPOCHS = 200

num_features = train_dataset.num_features
model = FullyConnectedNetwork(num_features, HIDDEN_UNITS)

# Explicit switch instead of a commented-out line, so both configurations stay
# runnable. The earlier SGD experiments optionally also used weight_decay=5e-4.
if OPTIMIZER_NAME == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
elif OPTIMIZER_NAME == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
else:
    raise ValueError(f'unknown optimizer: {OPTIMIZER_NAME!r}')

# ReduceLROnPlateau defaults to mode='min': it assumes the value model.fit passes
# to scheduler.step() is lower-is-better (presumably the validation loss) —
# confirm in model.fit.
# NOTE(review): the `verbose` argument of LR schedulers is deprecated since
# PyTorch 2.2; drop it and inspect scheduler.get_last_lr() when upgrading.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True)

model.fit(train_loader, optimizer, epochs=EPOCHS,
          callback=print_step, val_loader=val_loader, scheduler=scheduler)

[00 - 00000] 0.82362 0.84566
[00 - 01000] 1.03433 0.82801
[01 - 00000] 0.68971 0.82669
[01 - 01000] 0.89010 0.82553
[02 - 00000] 0.88886 0.82502
[02 - 01000] 0.81339 0.82426
[03 - 00000] 0.78082 0.82392
[03 - 01000] 0.81845 0.82336
[04 - 00000] 0.86356 0.82314
[04 - 01000] 0.87496 0.82267
[05 - 00000] 0.81694 0.82256
[05 - 01000] 0.77149 0.82240
[06 - 00000] 0.81346 0.82221
[06 - 01000] 0.86731 0.82198
[07 - 00000] 0.79525 0.82194
[07 - 01000] 0.79482 0.82184
[08 - 00000] 0.81787 0.82188
[08 - 01000] 0.92714 0.82152
[09 - 00000] 0.80755 0.82143
[09 - 01000] 0.77296 0.82131
[10 - 00000] 0.86402 0.82131
[10 - 01000] 0.87852 0.82119
[11 - 00000] 1.06065 0.82114
[11 - 01000] 0.78386 0.82102
[12 - 00000] 0.72453 0.82096
[12 - 01000] 0.77000 0.82084
[13 - 00000] 0.85768 0.82096
[13 - 01000] 0.76167 0.82076
[14 - 00000] 0.77982 0.82085
[14 - 01000] 0.68001 0.82079
[15 - 00000] 0.80666 0.82062
[15 - 01000] 0.80909 0.82056
[16 - 00000] 0.83390 0.82056
[16 - 01000] 0.70178 0.82058
[17 - 00000] 0

KeyboardInterrupt: 

In [None]:
# Final scores on the test and training sets. Each value is labelled with its
# split so the printout can be copied into the results table unambiguously.
for split_name, loader in (('test', test_loader), ('train', train_loader)):
    print(f'{split_name}: {model.validate(loader)}')

0.8096815317826419


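| hidden units | batch size | optimizer | test acc | train acc |
|--------------|------------|-----------|----------|-----------|
| 64           | 128        | sgd       | 0.8117   | 0.8106    |
| 100 (2/3)    | 128        | sgd       | 0.8079   | 0.8085    |
| 1024         | 128        | sgd       | 0.8342   | 1.2747    |
| 100          | 16         | sgd       | 0.8415   | 0.8546    |
| 100          | 1024       | sgd       | 0.8035   | 0.8005    |
| 100          | 8192       | sgd       | 0.8048   | 0.8018    |
| 100+200+100  | 1024       | sgd       | 0.8150   | 0.8117    |
| 100+50+100   | 1024       | sgd       | 0.8185   | 0.8153    |

> **Note (review):** the 1.2747 entry exceeds 1, so these columns cannot all be accuracies. The validation column in the training log also decreases over epochs, which suggests `model.validate` returns a loss. Confirm this and rename the columns. All rows were run with SGD; the current code uses Adam, so its runs are not recorded here yet.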