Skip to content

Commit

Permalink
CSV tricks and all tests moved to a package
Browse files Browse the repository at this point in the history
  • Loading branch information
jaor committed Apr 8, 2015
1 parent 6ac0b30 commit 018fba8
Show file tree
Hide file tree
Showing 75 changed files with 188 additions and 90 deletions.
7 changes: 3 additions & 4 deletions bigml/cluster.py
Expand Up @@ -42,7 +42,6 @@
LOGGER = logging.getLogger('BigML')

import sys
import csv
import math
import re

Expand All @@ -56,6 +55,7 @@
from bigml.model import STORAGE
from bigml.predicate import TM_TOKENS, TM_FULL_TERM
from bigml.modelfields import ModelFields
from bigml.io import UnicodeWriter


OPTIONAL_FIELDS = ['categorical', 'text']
Expand Down Expand Up @@ -298,7 +298,7 @@ def statistics_CSV(self, file_name=None):
row.append(result)
intercentroids = True
for measure, result in centroid.distance.items():
if measure in CSV_STATISTICS:
if measure in CSV_STATISTICS:
if not header_complete:
headers.append(u"Data %s" %
measure.lower().replace("_", " "))
Expand All @@ -310,8 +310,7 @@ def statistics_CSV(self, file_name=None):

if file_name is None:
return rows
with open(file_name, "w") as file_handler:
writer = csv.writer(file_handler)
with UnicodeWriter(file_name) as writer:
for row in rows:
writer.writerow([item if not isinstance(item, basestring)
else item.encode("utf-8") for item in row])
Expand Down
92 changes: 92 additions & 0 deletions bigml/io.py
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 BigML, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.


"""Python 2/3 compatibility for I/O functions.
:author: jao <jao@bigml.com>
:date: Wed Apr 08, 2015 17:52
"""

import sys, csv

# True when running on Python 3. Compare version_info rather than the
# sys.version string: lexicographic string comparison misorders two-digit
# majors (e.g. "10.0..." < "3"), while the tuple comparison is robust.
PY3 = sys.version_info[0] >= 3

class UnicodeReader:
    """Context-manager CSV reader yielding rows of unicode strings.

    Works on both Python 2 and 3: on Python 3 the file is opened in text
    mode with the requested encoding; on Python 2 it is opened in binary
    mode and every cell is decoded after parsing.
    """
    def __init__(self, filename, dialect=csv.excel,
                 encoding="utf-8", **kw):
        """Stores the reader configuration; the file is opened lazily
        in __enter__.

        :param filename: path of the CSV file to read
        :param dialect: csv dialect handed to the underlying csv.reader
        :param encoding: text encoding of the file
        :param kw: extra keyword arguments for csv.reader
        """
        self.filename = filename
        self.dialect = dialect
        self.encoding = encoding
        self.kw = kw

    def __enter__(self):
        # newline='' is required by the csv module on Python 3; binary
        # mode is what the csv module expects on Python 2. The version is
        # checked inline so the class does not depend on module globals.
        if sys.version_info[0] >= 3:
            self.f = open(self.filename, 'rt',
                          encoding=self.encoding, newline='')
        else:
            self.f = open(self.filename, 'rb')
        self.reader = csv.reader(self.f, dialect=self.dialect,
                                 **self.kw)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.f.close()

    def __next__(self):
        """Returns the next row as a list of unicode strings.

        Decodes with the configured encoding (the original hard-coded
        "utf-8" here, ignoring the ``encoding`` argument).
        """
        row = next(self.reader)
        if sys.version_info[0] >= 3:
            return row
        return [s.decode(self.encoding) for s in row]

    # Python 2's iterator protocol calls next(); keep it as an alias.
    # The original defined only next(), which made for-loop iteration
    # fail on Python 3 (no __next__).
    next = __next__

    def __iter__(self):
        return self

class UnicodeWriter:
    """Context-manager CSV writer producing encoded output on Python 2
    and text-mode output on Python 3.

    Can be used either as a context manager (``with UnicodeWriter(f) as
    w:``) or via the explicit open_writer()/close_writer() pair.
    """
    def __init__(self, filename, dialect=csv.excel,
                 encoding="utf-8", **kw):
        """Stores the writer configuration; the file is opened lazily.

        :param filename: path of the CSV file to write
        :param dialect: csv dialect handed to the underlying csv.writer
        :param encoding: text encoding used for the output
        :param kw: extra keyword arguments for csv.writer
        """
        self.filename = filename
        self.dialect = dialect
        self.encoding = encoding
        self.kw = kw

    def open_writer(self):
        """Opens the output file and builds the underlying csv.writer."""
        # newline='' stops the csv module from doubling line endings on
        # Python 3; binary mode is what the csv module expects on 2. The
        # version is checked inline so the class stands alone.
        if sys.version_info[0] >= 3:
            self.f = open(self.filename, 'wt',
                          encoding=self.encoding, newline='')
        else:
            self.f = open(self.filename, 'wb')
        self.writer = csv.writer(self.f, dialect=self.dialect, **self.kw)
        return self

    def close_writer(self):
        """Closes the output file."""
        self.f.close()

    def __enter__(self):
        return self.open_writer()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close_writer()

    def writerow(self, row):
        """Writes a single row, encoding unicode cells on Python 2.

        Only ``unicode`` cells are encoded: already-encoded str cells
        (callers such as cluster.statistics_CSV pass pre-encoded UTF-8
        bytes) and non-string values are passed through unchanged; the
        original encoded every cell, which double-encoded str cells and
        crashed on numbers.
        """
        if sys.version_info[0] < 3:
            row = [item.encode(self.encoding)
                   if isinstance(item, unicode) else item
                   for item in row]
        self.writer.writerow(row)

    def writerows(self, rows):
        """Writes a sequence of rows, one writerow() call per row."""
        for row in rows:
            self.writerow(row)
11 changes: 5 additions & 6 deletions bigml/model.py
Expand Up @@ -54,7 +54,6 @@
import sys
import locale
import json
import csv

from functools import partial

Expand All @@ -69,6 +68,7 @@
from bigml.basemodel import BaseModel, retrieve_resource, print_importance
from bigml.basemodel import ONLY_MODEL
from bigml.multivote import ws_confidence
from bigml.io import UnicodeWriter

# we use the atof conversion for integers to include integers written as
# 10.0
Expand Down Expand Up @@ -286,7 +286,7 @@ def predict(self, input_data, by_name=True,
(maximum number of categories to be returned), or the
literal 'all', that will cause the entire distribution
in the node to be returned.
"""
# Checks if this is a regression model, using PROPORTIONAL
# missing_strategy
Expand Down Expand Up @@ -321,8 +321,8 @@ def predict(self, input_data, by_name=True,
total_instances = float(prediction.count)
distribution = enumerate(prediction.distribution)
for index, [category, instances] in distribution:
if ((isinstance(multiple, basestring) and multiple == 'all') or
(isinstance(multiple, int) and index < multiple)):
if ((isinstance(multiple, basestring) and multiple == 'all') or
(isinstance(multiple, int) and index < multiple)):
prediction_dict = {
'prediction': category,
'confidence': ws_confidence(category,
Expand Down Expand Up @@ -878,8 +878,7 @@ def tree_CSV(self, file_name=None, leaves_only=False):
nodes_generator = self.get_nodes_info(headers_names,
leaves_only=leaves_only)
if file_name is not None:
with open(file_name, "w") as file_handler:
writer = csv.writer(file_handler)
with UnicodeWriter(file_name) as writer:
writer.writerow([header.encode("utf-8")
for header in headers_names])
for row in nodes_generator:
Expand Down
49 changes: 26 additions & 23 deletions bigml/multimodel.py
Expand Up @@ -40,14 +40,13 @@
LOGGER = logging.getLogger('BigML')


import csv
import ast
from bigml.model import Model
from bigml.model import LAST_PREDICTION
from bigml.util import get_predictions_file_name
from bigml.multivote import MultiVote
from bigml.multivote import PLURALITY_CODE

from bigml.io import UnicodeWriter, UnicodeReader

def read_votes(votes_files, to_prediction, data_locale=None):
"""Reads the votes found in the votes' files.
Expand All @@ -70,23 +69,24 @@ def read_votes(votes_files, to_prediction, data_locale=None):
for order in range(0, len(votes_files)):
votes_file = votes_files[order]
index = 0
for row in csv.reader(open(votes_file, "U"), lineterminator="\n"):
prediction = to_prediction(row[0], data_locale=data_locale)
if index > (len(votes) - 1):
votes.append(MultiVote([]))
distribution = None
instances = None
if len(row) > 2:
distribution = ast.literal_eval(row[2])
instances = int(row[3])
try:
confidence = float(row[1])
except ValueError:
confidence = 0
prediction_row = [prediction, confidence, order,
distribution, instances]
votes[index].append_row(prediction_row)
index += 1
with UnicodeReader(votes_file) as rdr:
for row in rdr:
prediction = to_prediction(row[0], data_locale=data_locale)
if index > (len(votes) - 1):
votes.append(MultiVote([]))
distribution = None
instances = None
if len(row) > 2:
distribution = ast.literal_eval(row[2])
instances = int(row[3])
try:
confidence = float(row[1])
except ValueError:
confidence = 0
prediction_row = [prediction, confidence, order,
distribution, instances]
votes[index].append_row(prediction_row)
index += 1
return votes


Expand Down Expand Up @@ -192,6 +192,7 @@ def batch_predict(self, input_data_list, output_file_path=None,

for model in self.models:
order += 1
out = None
if to_file:
output_file = get_predictions_file_name(model.resource_id,
output_file_path)
Expand All @@ -203,12 +204,13 @@ def batch_predict(self, input_data_list, output_file_path=None,
except IOError:
pass
try:
predictions_file = csv.writer(open(output_file, 'w', 0),
lineterminator="\n")
out = UnicodeWriter(output_file)
except IOError:
raise Exception("Cannot find %s directory." %
output_file_path)

if out:
out.open_writer()
for index, input_data in enumerate(input_data_list):
if add_headers:
input_data = dict(zip(headers, input_data))
Expand All @@ -219,15 +221,16 @@ def batch_predict(self, input_data_list, output_file_path=None,
if to_file:
if isinstance(prediction[0], basestring):
prediction[0] = prediction[0].encode("utf-8")
predictions_file.writerow(prediction)
out.writerow(prediction)
else:
prediction, confidence, distribution, instances = prediction
prediction_row = [prediction, confidence, order,
distribution, instances]
if len(votes) <= index:
votes.append(MultiVote([]))
votes[index].append_row(prediction_row)

if out:
out.close_writer()
if not to_file:
return votes

Expand Down
File renamed without changes.
File renamed without changes.
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
#!/usr/bin/env python
#
# Copyright 2012 BigML
# Copyright 2012, 2015 BigML
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
Expand All @@ -19,7 +19,7 @@
import json
import os
from datetime import datetime, timedelta
from world import world
from world import world, res_filename

from bigml.api import HTTP_CREATED
from bigml.api import HTTP_ACCEPTED
Expand All @@ -30,6 +30,7 @@

#@step(r'I create a MultiVote for the set of predictions in file (.*)$')
def i_create_a_multivote(step, predictions_file):
predictions_file = res_filename(predictions_file)
try:
with open(predictions_file, 'r') as predictions_file:
world.multivote = MultiVote(json.load(predictions_file))
Expand Down
File renamed without changes.
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
#!/usr/bin/env python
#
# Copyright 2012 BigML
# Copyright 2012, 2015 BigML
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
Expand All @@ -19,13 +19,15 @@
import json
import requests
import csv
import traceback
from datetime import datetime, timedelta
from world import world
from world import world, res_filename

from bigml.api import HTTP_CREATED
from bigml.api import FINISHED
from bigml.api import FAULTY
from bigml.api import get_status
from bigml.io import UnicodeReader

from read_batch_prediction_steps import (i_get_the_batch_prediction,
i_get_the_batch_centroid, i_get_the_batch_anomaly_score)
Expand Down Expand Up @@ -113,49 +115,44 @@ def the_batch_anomaly_score_is_finished_in_less_than(step, secs):
#@step(r'I download the created predictions file to "(.*)"')
def i_download_predictions_file(step, filename):
file_object = world.api.download_batch_prediction(
world.batch_prediction, filename=filename)
world.batch_prediction, filename=res_filename(filename))
assert file_object is not None
world.output = file_object

#@step(r'I download the created centroid file to "(.*)"')
def i_download_centroid_file(step, filename):
file_object = world.api.download_batch_centroid(
world.batch_centroid, filename=filename)
world.batch_centroid, filename=res_filename(filename))
assert file_object is not None
world.output = file_object

#@step(r'I download the created anomaly score file to "(.*)"')
def i_download_anomaly_score_file(step, filename):
file_object = world.api.download_batch_anomaly_score(
world.batch_anomaly_score, filename=filename)
world.batch_anomaly_score, filename=res_filename(filename))
assert file_object is not None
world.output = file_object

def check_rows(prediction_rows, test_rows):
for row in prediction_rows:
check_row = next(test_rows)
assert len(check_row) == len (row)
for index in range(len(row)):
dot = row[index].find(".")
if dot > 0:
try:
decs = min(len(row[index]), len(check_row[index])) - dot - 1
row[index] = round(float(row[index]), decs)
check_row[index] = round(float(check_row[index]), decs)
except ValueError:
pass
assert check_row[index] == row[index], ("%s/%s" % (row, check_row))

#@step(r'the batch prediction file is like "(.*)"')
def i_check_predictions(step, check_file):
predictions_file = world.output
try:
predictions_file = csv.reader(open(predictions_file, "U"), lineterminator="\n")
check_file = csv.reader(open(check_file, "U"), lineterminator="\n")
for row in predictions_file:
check_row = check_file.next()
if len(check_row) != len(row):
assert False
for index in range(len(row)):
dot = row[index].find(".")
if dot > 0:
try:
decimal_places = min(len(row[index]), len(check_row[index])) - dot - 1
row[index] = round(float(row[index]), decimal_places)
check_row[index] = round(float(check_row[index]), decimal_places)
except ValueError:
pass
if check_row[index] != row[index]:
print row, check_row
assert False
assert True
except Exception, exc:
assert False, str(exc)
with UnicodeReader(world.output) as prediction_rows:
with UnicodeReader(res_filename(check_file)) as test_rows:
check_rows(prediction_rows, test_rows)

#@step(r'the batch centroid file is like "(.*)"')
def i_check_batch_centroid(step, check_file):
Expand Down
File renamed without changes.

0 comments on commit 018fba8

Please sign in to comment.