Skip to content

Commit

Permalink
fixed error
Browse files Browse the repository at this point in the history
  • Loading branch information
densechen committed Jul 22, 2020
1 parent aafdaba commit 9ee0ead
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
7 changes: 4 additions & 3 deletions main.py
Expand Up @@ -109,7 +109,7 @@ def test(model, dataloader):
return correct


def forward_epoch(model, optimizer, state_keeper, time, epochs):
def forward_epoch(model, train_dataloader, test_dataloader, optimizer, state_keeper, time, epochs):
for epoch in range(1, epochs + 1):
loss_dict = train(model, optimizer, train_dataloader)
with torch.no_grad():
Expand All @@ -131,11 +131,12 @@ def forward_epoch(model, optimizer, state_keeper, time, epochs):
for time in range(args.times):
model = utils.get_model(args)
optimizer = utils.get_optimizer(args.optim, args.lr, model)
forward_epoch(model, optimizer, state_keeper, time, args.epochs)
forward_epoch(model, train_dataloader, test_dataloader,
optimizer, state_keeper, time, args.epochs)
if args.exname == "TransferLearning":
optimizer_aux = utils.get_optimizer(
args.optim, args.lr_aux, model)
forward_epoch(model, optimizer_aux, state_keeper_aux,
forward_epoch(model, train_dataloader_aux, test_dataloader_aux, optimizer_aux, state_keeper_aux,
time, args.epochs_aux)

state_keeper.save()
Expand Down
8 changes: 6 additions & 2 deletions train.sh
Expand Up @@ -9,5 +9,9 @@ python main.py --batch_size 128 --lr 1e-4 --epochs 20 --times 5 --data_root data

# Transfer Learning
# MNIST -> SVHN
python main.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 20 --epochs_aux 200 --times 5 --data_root data --dataset MNIST --dataset_aux SVHN --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning
python main.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 20 --epochs_aux 200 --times 5 --data_root data --dataset SVHN --dataset_aux MNIST --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning
python main.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 5 --epochs_aux 100 --times 5 --data_root data --dataset MNIST --dataset_aux SVHN --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning
python main.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 5 --epochs_aux 100 --times 5 --data_root data --dataset SVHN --dataset_aux MNIST --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning


python main.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 10 --epochs_aux 100 --times 5 --data_root data --dataset MNIST --dataset_aux SVHN --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning
python main.py --batch_size 128 --lr 1e-2 --lr_aux 1e-5 --epochs 10 --epochs_aux 100 --times 5 --data_root data --dataset SVHN --dataset_aux MNIST --num_workers 2 --net ConvMNIST --af all --optim SGD --exname TransferLearning

0 comments on commit 9ee0ead

Please sign in to comment.