Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
marcellacornia committed Aug 3, 2017
1 parent 07ed2c5 commit bba23cd
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from config import *
from utilities import preprocess_images, preprocess_maps, preprocess_fixmaps, postprocess_predictions
from models import sam_vgg, sam_resnet, schedule_vgg, schedule_resnet, kl_divergence, correlation_coefficient, nss
from models import sam_vgg, sam_resnet, kl_divergence, correlation_coefficient, nss


def generator(b_s, phase_gen='train'):
Expand Down Expand Up @@ -76,15 +76,13 @@ def generator_test(b_s, imgs_test_path):
m.fit_generator(generator(b_s=b_s), nb_imgs_train, nb_epoch=nb_epoch,
validation_data=generator(b_s=b_s, phase_gen='val'), nb_val_samples=nb_imgs_val,
callbacks=[EarlyStopping(patience=3),
ModelCheckpoint('weights.sam-vgg.{epoch:02d}-{val_loss:.4f}.pkl', save_best_only=True),
LearningRateScheduler(schedule=schedule_vgg)])
ModelCheckpoint('weights.sam-vgg.{epoch:02d}-{val_loss:.4f}.pkl', save_best_only=True)])
elif version == 1:
print("Training SAM-ResNet")
m.fit_generator(generator(b_s=b_s), nb_imgs_train, nb_epoch=nb_epoch,
validation_data=generator(b_s=b_s, phase_gen='val'), nb_val_samples=nb_imgs_val,
callbacks=[EarlyStopping(patience=3),
ModelCheckpoint('weights.sam-resnet.{epoch:02d}-{val_loss:.4f}.pkl', save_best_only=True),
LearningRateScheduler(schedule=schedule_resnet)])
ModelCheckpoint('weights.sam-resnet.{epoch:02d}-{val_loss:.4f}.pkl', save_best_only=True)])

elif phase == "test":
# Output Folder Path
Expand Down

0 comments on commit bba23cd

Please sign in to comment.