Skip to content
Permalink
Browse files

Whitespace and adding pydot for graph visualization

  • Loading branch information...
ReeceStevens committed Feb 5, 2018
1 parent e24225e commit fd37561d470b75789216e993c35b82dc0d73a968
Showing with 23 additions and 26 deletions.
  1. +0 −3 cam_animation.py
  2. +23 −23 train.py
@@ -12,9 +12,6 @@
from train import DataGenerator
from visualize import plot_row_item

# Debug purposes only
# from pympler import muppy, summary

def get_model_predictions_for_npz(model, data_generator, character_name, npz_name):
npz_file_path = os.path.join(data_generator.data_path, character_name, npz_name)
pixels = np.load(npz_file_path)['pixels']
@@ -1,60 +1,60 @@
'''Builds a model, organizes and loads data, and runs model training.'''
import argparse
from collections import defaultdict

from keras.layers import Input
from keras.layers.core import Dense, Flatten, Dropout
from keras.layers.merge import Concatenate
from keras.layers.normalization import BatchNormalization

import keras
import numpy as np
import os
import glob
import random

from keras.layers.pooling import GlobalAveragePooling2D
import keras
import numpy as np

from keras.layers import Input, Average
from keras.layers.core import Dense, Flatten, Dropout
from keras.layers.merge import Concatenate
from keras.layers.normalization import BatchNormalization
from keras.layers.pooling import GlobalAveragePooling2D, GlobalMaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model

def get_model(pretrained_model, all_character_names):
if pretrained_model == 'inception':
pretrained_model = keras.applications.inception_v3.InceptionV3(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
model_base = keras.applications.inception_v3.InceptionV3(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
elif pretrained_model == 'xception':
pretrained_model = keras.applications.xception.Xception(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
model_base = keras.applications.xception.Xception(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
elif pretrained_model == 'resnet50':
pretrained_model = keras.applications.resnet50.ResNet50(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
model_base = keras.applications.resnet50.ResNet50(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
elif pretrained_model == 'vgg19':
pretrained_model = keras.applications.vgg19.VGG19(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
model_base = keras.applications.vgg19.VGG19(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
elif pretrained_model == 'all':
input = Input(shape=(*IMG_SIZE, 3))
inception_model = keras.applications.inception_v3.InceptionV3(include_top=False, input_tensor=input, weights='imagenet')
xception_model = keras.applications.xception.Xception(include_top=False, input_tensor=input, weights='imagenet')
resnet_model = keras.applications.resnet50.ResNet50(include_top=False, input_tensor=input, weights='imagenet')

flattened_outputs = [Flatten()(inception_model.output),
Flatten()(xception_model.output),
Flatten()(resnet_model.output)]
output = Concatenate()(flattened_outputs)
pretrained_model = Model(input, output)

print(pretrained_model.output.shape.ndims)

if pretrained_model.output.shape.ndims > 2:
output = Flatten()(pretrained_model.output)
else:
output = pretrained_model.output
model_base = Model(input, output)

output = model_base.output
output = BatchNormalization()(output)
output = Dropout(0.5)(output)
output = Dense(128, activation='relu')(output)
output = BatchNormalization()(output)
output = Dropout(0.5)(output)
output = Dense(len(all_character_names), activation='softmax')(output)
model = Model(pretrained_model.input, output)
for layer in pretrained_model.layers:
model = Model(model_base.input, output)
for layer in model_base.layers:
layer.trainable = False
model.summary(line_length=200)

# Generate a plot of a model
import pydot
pydot.find_graphviz = lambda: True
from keras.utils import plot_model
plot_model(model, show_shapes=True, to_file='../model_pdfs/{}.pdf'.format(pretrained_model))

model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])

0 comments on commit fd37561

Please sign in to comment.
You can’t perform that action at this time.