Skip to content

Commit

Permalink
Merge pull request #19 from PrimozGodec/gradient-descent
Browse files Browse the repository at this point in the history
Gradient descent small updates and fixes
  • Loading branch information
ajdapretnar committed Sep 6, 2016
2 parents 3c51358 + 32e53d0 commit 3f82e7e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 16 deletions.
61 changes: 46 additions & 15 deletions orangecontrib/educational/widgets/owgradientdescent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import operator

from Orange.canvas import report
from os import path
import time
Expand All @@ -17,6 +19,8 @@
from Orange.preprocess.preprocess import Normalize
from scipy.interpolate import splprep, splev

from orangecontrib.educational.widgets.utils.color_transform import (
rgb_to_hex, hex_to_rgb)
from orangecontrib.educational.widgets.utils.linear_regression import \
LinearRegression
from orangecontrib.educational.widgets.utils.logistic_regression \
Expand Down Expand Up @@ -155,7 +159,7 @@ class OWGradientDescent(OWWidget):
("Coefficients", Table),
("Data", Table)]

graph_name = "Gradient descent graph"
graph_name = "scatter"

# selected attributes in chart
attr_x = settings.Setting('')
Expand All @@ -176,10 +180,13 @@ class OWGradientDescent(OWWidget):
cost_grid = None
grid_size = 15
contour_color = "#aaaaaa"
default_background_color = "#00BFFF"
line_colors = ["#00BFFF", "#ff0000", "#33cc33"]
min_x = None
max_x = None
min_y = None
max_y = None
current_gradient_color = None

# data
data = None
Expand Down Expand Up @@ -212,7 +219,8 @@ def __init__(self):

# info box
self.info_box = gui.widgetBox(self.controlArea, "Info")
self.learner_label = gui.label(widget=self.info_box, master=self, label="")
self.learner_label = gui.label(
widget=self.info_box, master=self, label="")

# options box
policy = QSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
Expand Down Expand Up @@ -284,10 +292,10 @@ def __init__(self):
yAxis_gridLineWidth=0,
title_text='',
tooltip_shared=False,
debug=True,
debug=False,
legend_symbolWidth=0,
legend_symbolHeight=0)
# TODO: set false when end of development

gui.rubber(self.controlArea)

# Just render an empty chart so it shows a nice 'No data to display'
Expand Down Expand Up @@ -481,7 +489,7 @@ def change_theta(self, x, y):
self.learner.j(np.array([x, y]))))],
showInLegend=False,
type="scatter", lineWidth=1,
color="#ff0000",
color=self.line_color(),
marker=dict(
enabled=True, radius=2),
tooltip=dict(
Expand Down Expand Up @@ -545,7 +553,8 @@ def plot_last_point(self, x, y):
dict(
x=x, y=y, dataLabels=dict(
enabled=True,
format=' {0:.2f} '.format(self.learner.j(np.array([x, y]))),
format=' {0:.2f} '.format(
self.learner.j(np.array([x, y]))),
useHTML=True,
verticalAlign='middle',
align="left" if self.label_right() else "right",
Expand All @@ -559,6 +568,20 @@ def label_right(self):
l = self.learner
return l.step_no == 0 or l.history[l.step_no - 1][0][0] < l.theta[0]

def gradient_color(self):
if not self.is_logistic:
return self.default_background_color
else:
target_class_idx = self.data.domain.class_var.values.\
index(self.target_class)
color = self.data.domain.class_var.colors[target_class_idx]
return rgb_to_hex(tuple(color))

def line_color(self):
rgb_tuple = hex_to_rgb(self.current_gradient_color)
max_index, _ = max(enumerate(rgb_tuple), key=operator.itemgetter(1))
return self.line_colors[max_index]

def replot(self):
"""
This function performs complete replot of the graph
Expand All @@ -579,13 +602,16 @@ def replot(self):
options['series'] += self.plot_gradient_and_contour(
self.min_x, self.max_x, self.min_y, self.max_y)

# select gradient color
self.current_gradient_color = self.gradient_color()

# highcharts parameters
kwargs = dict(
xAxis_title_text="<p>&theta;<sub>{attr}</sub></p>"
.format(attr=self.attr_x if self.is_logistic else 0),
xAxis_title_text="<p>&theta;<sub>{attr}</sub></p>".format(
attr=self.attr_x if self.is_logistic else 0),
xAxis_title_useHTML=True,
yAxis_title_text="&theta;<sub>{attr}</sub>".
format(attr=self.attr_y if self.is_logistic else self.attr_x),
yAxis_title_text="&theta;<sub>{attr}</sub>".format(
attr=self.attr_y if self.is_logistic else self.attr_x),
yAxis_title_useHTML=True,
xAxis_min=self.min_x,
xAxis_max=self.max_x,
Expand All @@ -596,7 +622,7 @@ def replot(self):
yAxis_startOnTick=False,
yAxis_endOnTick=False,
colorAxis=dict(
minColor="#ffffff", maxColor="#00BFFF",
minColor="#ffffff", maxColor=self.current_gradient_color,
endOnTick=False, startOnTick=False),
plotOptions_contour_colsize=(self.max_y - self.min_y) / 1000,
plotOptions_contour_rowsize=(self.max_x - self.min_x) / 1000,
Expand Down Expand Up @@ -837,8 +863,13 @@ def is_logistic(self):
def send_report(self):
if self.data is None:
return
caption = report.render_items_vert((
("Stochastic", str(self.stochastic)),
))
caption_items = (
("Target class", self.target_class),
("Learning rate", self.alpha),
("Stochastic", str(self.stochastic))
)
if self.stochastic:
caption_items += (("Stochastic step size", self.step_size),)
caption = report.render_items_vert(caption_items)
self.report_plot(self.scatter)
self.report_caption(caption)
self.report_caption(caption)
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,11 @@ def test_send_report(self):
w.report_button.click()

# when no data
# when everything fine
self.send_signal("Data", None)

w.report_button.click()

# for stochastic
self.send_signal("Data", self.iris)
w.stochastic_checkbox.click()
w.report_button.click()

0 comments on commit 3f82e7e

Please sign in to comment.