Merge pull request #16 from bjherger/example
Examples
bjherger committed May 25, 2018
2 parents 4249823 + b9ec220 commit 82b46c9
Showing 5 changed files with 171 additions and 1 deletion.
33 changes: 33 additions & 0 deletions examples/mushrooms.py
@@ -0,0 +1,33 @@
import logging

from keras import Model
from keras.layers import Dense

from keras_pandas.Automater import Automater
from keras_pandas.lib import load_mushrooms


def main():
    logging.getLogger().setLevel(logging.DEBUG)

    observations = load_mushrooms()

    # Transform the data set, using keras_pandas
    auto = Automater(categorical_vars=observations.columns, response_var='class')
    X, y = auto.fit_transform(observations)

    # Create model
    x = auto.input_nub
    x = Dense(30)(x)
    x = auto.output_nub(x)

    model = Model(inputs=auto.input_layers, outputs=x)
    model.compile(optimizer='Adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Train model
    model.fit(X, y, epochs=10, validation_split=.5)


if __name__ == '__main__':
    main()
36 changes: 36 additions & 0 deletions examples/titanic.py
@@ -0,0 +1,36 @@
import logging

from keras import Model
from keras.layers import Dense

from keras_pandas.Automater import Automater
from keras_pandas.lib import load_titanic


def main():
    logging.getLogger().setLevel(logging.DEBUG)

    observations = load_titanic()

    # Transform the data set, using keras_pandas
    categorical_vars = ['pclass', 'sex', 'survived']
    numerical_vars = ['age', 'siblings_spouses_aboard', 'parents_children_aboard', 'fare']

    auto = Automater(categorical_vars=categorical_vars, numerical_vars=numerical_vars, response_var='survived')
    X, y = auto.fit_transform(observations)

    # Create model
    x = auto.input_nub
    x = Dense(30)(x)
    x = auto.output_nub(x)

    model = Model(inputs=auto.input_layers, outputs=x)
    model.compile(optimizer='Adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Train model
    model.fit(X, y, epochs=10, validation_split=.5)


if __name__ == '__main__':
    main()
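Assuming the pinned dependencies from requirements.txt (below) are installed, either example script can be run directly from the repository root, for instance:

python examples/titanic.py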
5 changes: 4 additions & 1 deletion keras_pandas/Automater.py
@@ -86,8 +86,11 @@ def fit(self, input_dataframe):

    def transform(self, dataframe):

        # Check if fitted yet
        if not self.fitted:
            raise ValueError('Cannot transform without being fitted first. Call fit() before transform()')

        # Check if we have a response variable, and if it is available
        if self.response_var is not None and self.response_var in dataframe.columns:
            y_available = True
        else:
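A minimal sketch of the new guard, assuming a toy DataFrame, the constructor arguments shown in the examples above, and that a freshly constructed Automater starts with fitted == False (as the check implies): calling transform() before fit() now fails fast with a ValueError instead of silently operating on an unfitted instance.

import pandas

from keras_pandas.Automater import Automater

# Toy single-column DataFrame, for illustration only
observations = pandas.DataFrame({'class': ['e', 'p', 'e']})

auto = Automater(categorical_vars=['class'], response_var='class')
try:
    # fit() has not been called yet, so this should raise
    auto.transform(observations)
except ValueError as err:
    print(err)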
71 changes: 71 additions & 0 deletions keras_pandas/lib.py
@@ -1,3 +1,10 @@
import logging
import os

import pandas
import requests


def check_variable_list_are_valid(variable_type_dict):
    """
@@ -19,3 +26,67 @@ def check_variable_list_are_valid(variable_type_dict):

    return True

def download_file(url, local_file_path, filename):
    """
    Download the file at `url` in chunks, to the location at `local_file_path/filename`

    :param url: URL to a file to be downloaded
    :type url: str
    :param local_file_path: Directory to download the file to
    :type local_file_path: str
    :param filename: Name to give the downloaded file
    :type filename: str
    :return: The path to the file on the local machine
    :rtype: str
    """
    logging.info('Downloading file from url: {}, to path: {}'.format(url, local_file_path))

    # Reference variables
    chunk_count = 0
    local_file_path = os.path.expanduser(local_file_path)
    if not os.path.exists(local_file_path):
        os.makedirs(local_file_path)

    local_file_path = os.path.join(local_file_path, filename)

    # Only download if the file is not already cached locally
    if not os.path.exists(local_file_path):

        # Create connection to the stream
        r = requests.get(url, stream=True)

        # Open output file
        with open(local_file_path, 'wb') as f:

            # Iterate through chunks of file
            for chunk in r.iter_content(chunk_size=1048576):
                logging.debug('Downloading chunk: {} for file: {}'.format(chunk_count, local_file_path))

                # If there is a chunk to write to file, write it
                if chunk:
                    f.write(chunk)

                # Increase chunk counter
                chunk_count = chunk_count + 1

        r.close()

    return local_file_path

def load_mushrooms():
    # Extract the data
    file_path = download_file(
        'https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data',
        '~/.keras-pandas/example_datasets/', filename='agaricus-lepiota.data')

    observations = pandas.read_csv(filepath_or_buffer=file_path,
                                   names=['class', 'cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',
                                          'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',
                                          'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',
                                          'stalk-surface-below-ring', 'stalk-color-above-ring',
                                          'stalk-color-below-ring', 'veil-type', 'veil-color', 'ring-number',
                                          'ring-type', 'spore-print-color', 'population', 'habitat'])
    return observations

def load_titanic():
    file_path = download_file('http://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv',
                              '~/.keras-pandas/example_datasets/', filename='titanic.csv')

    observations = pandas.read_csv(file_path)
    observations.columns = [col.lower().replace(' ', '_').replace('/', '_') for col in observations.columns]

    return observations
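A usage sketch for download_file (the URL and paths are copied from load_titanic above; the printed path is illustrative): because the helper skips the download when the target file already exists, repeated calls simply return the cached local path.

from keras_pandas.lib import download_file

# First call downloads the file; subsequent calls find it on disk and return immediately
path = download_file('http://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv',
                     '~/.keras-pandas/example_datasets/', filename='titanic.csv')
print(path)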
27 changes: 27 additions & 0 deletions requirements.txt
@@ -1,17 +1,44 @@
absl-py==0.2.0
astor==0.6.2
attrs==17.4.0
backports.weakref==1.0.post1
bleach==1.5.0
certifi==2018.4.16
chardet==3.0.4
click==6.7
enum34==1.1.6
Flask==0.12.2
funcsigs==1.0.2
futures==3.2.0
gast==0.2.0
grpcio==1.11.0
h5py==2.7.1
html5lib==0.9999999
idna==2.6
itsdangerous==0.24
Jinja2==2.10
Keras==2.1.5
Markdown==2.6.11
MarkupSafe==1.0
mock==2.0.0
more-itertools==4.1.0
numpy==1.14.2
pandas==0.22.0
pbr==4.0.2
pluggy==0.6.0
protobuf==3.5.2.post1
py==1.5.3
python-dateutil==2.7.2
pytz==2018.4
PyYAML==3.12
requests==2.18.4
scikit-learn==0.19.1
scipy==1.0.1
six==1.11.0
sklearn==0.0
sklearn-pandas==1.6.0
tensorboard==1.7.0
tensorflow==1.7.0
termcolor==1.1.0
urllib3==1.22
Werkzeug==0.14.1
