In [None]:
from easydict import EasyDict

from model.models import TuneModel
from keras.applications.resnet50 import ResNet50
from keras.applications.densenet import DenseNet121

In [None]:
# Base-model configuration: every tunable value for the backbone lives here.
HEIGHT, WIDTH, CHANNELS = 224, 224, 3  # input image dimensions and channel count
INCLUDE_TOP = False                    # flag for the backbone's top layers
WEIGHTS = 'imagenet'                   # identifier of the weights to load
CLASSES = 2                            # number of target classes

# Registry mapping a short backbone name to its Keras application constructor;
# the model cell looks an entry up by name (`base_models[model_name]`).
base_models = {'resnet50': ResNet50, 'densenet121': DenseNet121}

In [None]:
# Key into `base_models` selecting which backbone to wrap.
model_name = 'resnet50'

# Build the tunable model from the constants in the parameters cell.
# INCLUDE_TOP / WEIGHTS are now passed from those constants instead of repeating
# their literal values, so the configuration has a single source of truth.
# NOTE(review): the upper-case keyword names are kept exactly as before because
# TuneModel's signature is not visible here — confirm they match what it expects.
model = TuneModel(
    base_models[model_name],
    name=model_name,
    height=HEIGHT,
    width=WIDTH,
    channels=CHANNELS,
    INCLUDE_TOP=INCLUDE_TOP,
    WEIGHTS=WEIGHTS,
    classes=CLASSES,
)

In [None]:
# Invoke TuneModel.build() on the instance created above.
# NOTE(review): presumably assembles/compiles the underlying network — confirm in model/models.py.
model.build()

In [12]:
# Batch size shared by the training and validation generators.
BATCH_SIZE = 64

# Data-pipeline settings for training, split into a 'train' and a 'val' section.
# Each section has a 'datagen' dict (pixel rescaling / augmentation options) and a
# 'generator' dict (directory, target size, batching options).
# NOTE(review): the key names match Keras' ImageDataGenerator(...) and
# flow_from_directory(...) arguments; presumably TuneModel.train forwards them
# there — confirm in model/models.py.
# target_size is derived from HEIGHT / WIDTH so the generators cannot drift from
# the input shape the model was constructed with.
gen_params = EasyDict(
    {
        'train': {
            'datagen': {
                'rescale': 1. / 255,
                # Augmentation is applied to the training split only.
                'rotation_range': 45,
                'width_shift_range': 0.2,
                'height_shift_range': 0.2,
                'zoom_range': 0.2,
                'horizontal_flip': True,
            },
            'generator': {
                'directory': 'data/train',
                'shuffle': True,
                'target_size': (HEIGHT, WIDTH),
                'class_mode': 'binary',
                'batch_size': BATCH_SIZE,
            },
        },
        'val': {
            'datagen': {
                'rescale': 1. / 255,
            },
            'generator': {
                'directory': 'data/val',
                'shuffle': True,
                'target_size': (HEIGHT, WIDTH),
                'class_mode': 'binary',
                'batch_size': BATCH_SIZE,
            },
        },
    }
)


# Training-loop options handed to model.train() below.
# NOTE(review): the key is 'epoch' (singular), whereas Keras' fit APIs use `epochs`;
# verify that TuneModel.train reads 'epoch' before renaming anything.
train_params = EasyDict({'epoch': 2, 'verbose': 1})


In [None]:
# Run training with the loop options and the train/val generator settings defined above.
model.train(train_params=train_params, gen_params=gen_params)

In [14]:
# Evaluation data settings: rescale only, read from the validation directory,
# with shuffling disabled.
# target_size is derived from HEIGHT / WIDTH so evaluation images always match
# the input shape the model was constructed with.
# NOTE(review): no 'batch_size' is given here, so whatever default the consumer
# applies is used — confirm that is intended.
eval_params = EasyDict(
    {
        'datagen': {
            'rescale': 1. / 255,
        },
        'generator': {
            'directory': 'data/val',
            'shuffle': False,
            'target_size': (HEIGHT, WIDTH),
            'class_mode': 'binary',
        },
    }
)

In [None]:
# Evaluate with eval_params; the call's result is unpacked into three values:
# overall metrics, metrics grouped by encounter, and metrics grouped by study type.
# NOTE(review): the grouping semantics are inferred from the variable names — confirm in TuneModel.eval.
metrics, metrics_by_encounter, metrics_by_study_type = model.eval(eval_params)