-
-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #216 from janezd/correlogram-pyqtgraph
Correlogram, Periodogram: Reimplementation in pyqtgraph
- Loading branch information
Showing
6 changed files
with
304 additions
and
227 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,138 +1,77 @@ | ||
from Orange.data import Table, ContinuousVariable | ||
from Orange.widgets import widget, gui, settings | ||
from Orange.widgets.utils.colorpalette import ColorPaletteGenerator | ||
from Orange.widgets.widget import Input | ||
import numpy as np | ||
|
||
from AnyQt.QtCore import Qt | ||
import pyqtgraph as pg | ||
|
||
from orangecontrib.timeseries.widgets.owperiodbase import OWPeriodBase | ||
from orangewidget.settings import Setting | ||
from orangewidget.utils.widgetpreview import WidgetPreview | ||
|
||
from Orange.widgets import gui | ||
|
||
from orangecontrib.timeseries import ( | ||
Timeseries, autocorrelation, partial_autocorrelation) | ||
from orangecontrib.timeseries.widgets.highcharts import Highchart | ||
|
||
from AnyQt.QtWidgets import QListWidget | ||
|
||
|
||
class OWCorrelogram(widget.OWWidget): | ||
class OWCorrelogram(OWPeriodBase): | ||
# TODO: allow computing cross-correlation of two distinct series | ||
name = 'Correlogram' | ||
description = "Visualize variables' auto-correlation." | ||
icon = 'icons/Correlogram.svg' | ||
priority = 110 | ||
|
||
class Inputs: | ||
time_series = Input("Time series", Table) | ||
|
||
attrs = settings.Setting([]) | ||
use_pacf = settings.Setting(False) | ||
use_confint = settings.Setting(True) | ||
|
||
graph_name = 'plot' | ||
use_pacf = Setting(False) | ||
use_confint = Setting(True) | ||
|
||
class Error(widget.OWWidget.Error): | ||
no_instances = widget.Msg("At least 2 data instances are required") | ||
yrange = (-1, 1) | ||
|
||
def __init__(self): | ||
self.all_attrs = [] | ||
opts = gui.widgetBox(self.controlArea, 'Options') | ||
gui.checkBox(opts, self, 'use_pacf', | ||
label='Compute partial auto-correlation (PACF)', | ||
callback=self.on_changed) | ||
gui.checkBox(opts, self, 'use_confint', | ||
super().__init__() | ||
gui.separator(self.controlArea) | ||
gui.checkBox(self.controlArea, self, 'use_pacf', | ||
label='Compute partial auto-correlation', | ||
callback=self.replot) | ||
gui.checkBox(self.controlArea, self, 'use_confint', | ||
label='Plot 95% significance interval', | ||
callback=self.on_changed) | ||
gui.listBox(self.controlArea, self, 'attrs', | ||
labels='all_attrs', | ||
box='Auto-correlated attribute(s)', | ||
selectionMode=QListWidget.ExtendedSelection, | ||
callback=self.on_changed) | ||
plot = self.plot = Highchart( | ||
self, | ||
chart_zoomType='x', | ||
plotOptions_line_marker_enabled=False, | ||
plotOptions_column_borderWidth=0, | ||
plotOptions_column_groupPadding=0, | ||
plotOptions_series_pointWidth=3, | ||
yAxis_min=-1.0, | ||
yAxis_max=1.0, | ||
xAxis_min=0, | ||
xAxis_gridLineWidth=1, | ||
yAxis_plotLines=[dict(value=0, color='#000', width=1, zIndex=2)], | ||
yAxis_title_text='', | ||
xAxis_title_text='period', | ||
tooltip_headerFormat='Correlation at period: {point.key:.2f}<br/>', | ||
tooltip_pointFormat='<span style="color:{point.color}">\u25CF</span> {point.y:.2f}<br/>', | ||
) | ||
self.mainArea.layout().addWidget(plot) | ||
callback=self.replot) | ||
|
||
def acf(self, attr, pacf, confint): | ||
x = self.data.interp(attr).ravel() | ||
func = partial_autocorrelation if pacf else autocorrelation | ||
return func(x, alpha=.05 if confint else None) | ||
|
||
@Inputs.time_series | ||
def set_data(self, data): | ||
self.Error.no_instances.clear() | ||
self.data = data = None if data is None else \ | ||
Timeseries.from_data_table(data) | ||
self.all_attrs = [] | ||
if data is None: | ||
self.plot.clear() | ||
return | ||
if len(data) < 2: | ||
self.Error.no_instances() | ||
self.plot.clear() | ||
return | ||
self.all_attrs = [(var.name, gui.attributeIconDict[var]) | ||
for var in data.domain.variables | ||
if (var is not data.time_variable and | ||
isinstance(var, ContinuousVariable))] | ||
self.attrs = [0] | ||
self.on_changed() | ||
|
||
def on_changed(self): | ||
if not self.attrs or not self.all_attrs: | ||
if attr not in self._cached: | ||
x = self.data.interp(attr).ravel() | ||
func = partial_autocorrelation if pacf else autocorrelation | ||
self._cached[attr] = func(x, alpha=.05 if confint else None) | ||
return self._cached[attr] | ||
|
||
def replot(self): | ||
self.plot.clear() | ||
if not self.selection: | ||
return | ||
|
||
series = [] | ||
options = dict(series=series) | ||
plotlines = [] | ||
for i, (attr, color) in enumerate(zip(self.attrs, | ||
ColorPaletteGenerator(len(self.all_attrs))[self.attrs])): | ||
attr_name = self.all_attrs[attr][0] | ||
pac = self.acf(attr_name, self.use_pacf, False) | ||
self.plot_widget.addItem(pg.InfiniteLine(0, 0, pen=pg.mkPen(0., width=2))) | ||
|
||
palette = self.get_palette() | ||
for i, attr in enumerate(self.selection): | ||
color = palette.value_to_qcolor(i) | ||
x, acf = np.array(self.acf(attr, self.use_pacf, False)).T | ||
x = np.repeat(x, 2) | ||
y = np.vstack((np.zeros(len(acf)), acf)).T.flatten() | ||
item = pg.PlotCurveItem( | ||
x=x, y=y, connect="pairs", antialias=True, | ||
pen=pg.mkPen(color, width=5)) | ||
self.plot_widget.addItem(item) | ||
|
||
if self.use_confint: | ||
# Confidence intervals, from: | ||
# https://www.mathworks.com/help/econ/autocorrelation-and-partial-autocorrelation.html | ||
# https://www.mathworks.com/help/signal/ug/confidence-intervals-for-sample-autocorrelation.html | ||
std = 1.96 * ((1 + 2 * (pac[:, 1]**2).sum()) / len(self.data))**.5 # = more precise than 1.96/sqrt(N) | ||
color = '/**/ Highcharts.getOptions().colors[{}] /**/'.format(i) | ||
line = dict(color=color, width=1.5, dashStyle='dash') | ||
plotlines.append(dict(line, value=std)) | ||
plotlines.append(dict(line, value=-std)) | ||
|
||
series.append(dict( | ||
# TODO: set units to something more readable than #periods (e.g. days) | ||
data=pac, | ||
type='column', | ||
name=attr_name, | ||
zIndex=2, | ||
)) | ||
|
||
# TODO: give periods meaning (datetime names) | ||
plotlines.append(dict(value=0, color='black', width=2, zIndex=3)) | ||
if series: | ||
self.plot.chart(options, yAxis_plotLines=plotlines, xAxis_type='linear') | ||
else: | ||
self.plot.clear() | ||
se = np.sqrt((1 + 2 * (acf ** 2).sum()) / len(self.data)) | ||
std = 1.96 * se | ||
pen = pg.mkPen(color, width=2, style=Qt.DashLine) | ||
self.plot_widget.addItem(pg.InfiniteLine(std, 0, pen=pen)) | ||
self.plot_widget.addItem(pg.InfiniteLine(-std, 0, pen=pen)) | ||
|
||
|
||
if __name__ == "__main__": | ||
from AnyQt.QtWidgets import QApplication | ||
|
||
a = QApplication([]) | ||
ow = OWCorrelogram() | ||
|
||
data = Timeseries.from_file('airpassengers') | ||
ow.set_data(data) | ||
|
||
ow.show() | ||
a.exec() | ||
WidgetPreview(OWCorrelogram).run( | ||
Timeseries.from_file("airpassengers") | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
from html import escape | ||
from typing import List | ||
|
||
from AnyQt.QtCore import QItemSelectionModel, QTimer, QItemSelection | ||
from AnyQt.QtWidgets import QListView | ||
|
||
import pyqtgraph as pg | ||
|
||
from orangecontrib.timeseries import Timeseries | ||
from orangewidget.settings import Setting | ||
|
||
from Orange.data import Table, ContinuousVariable | ||
from Orange.widgets.utils.itemmodels import VariableListModel | ||
from Orange.widgets.widget import OWWidget, Input, Msg | ||
from Orange.widgets.utils.colorpalettes import DefaultDiscretePalette, Glasbey | ||
from Orange.widgets.visualize.owdistributions import LegendItem | ||
|
||
|
||
class OWPeriodBase(OWWidget, openclass=True): | ||
class Inputs: | ||
time_series = Input("Time series", Table) | ||
|
||
# Selected attributes are stored as strings. They are always continuous, | ||
# so there's nothing to match, and storing as Variable would require | ||
# context handler. | ||
selection: List[str] = Setting([], schema_only=True) | ||
|
||
graph_name = 'plot' | ||
|
||
class Error(OWWidget.Error): | ||
no_instances = Msg("Data contains just a single instance") | ||
no_variables = Msg("Data doesn't contain any numeric variables") | ||
|
||
def __init__(self): | ||
self.data = None | ||
self.model = VariableListModel() | ||
self._cached = {} | ||
self.persistent_selection = self.selection | ||
|
||
listbox = QListView(self) | ||
listbox.setModel(self.model) | ||
self.controlArea.layout().addWidget(listbox) | ||
|
||
self.selectionModel = listbox.selectionModel() | ||
self.selectionModel.selectionChanged.connect(self._selection_changed) | ||
listbox.setSelectionModel(self.selectionModel) | ||
listbox.setSelectionMode(QListView.ExtendedSelection) | ||
|
||
self.plot_widget = pg.PlotWidget(background="w") | ||
self.plot = self.plot_widget.getPlotItem() | ||
self.plot.showGrid(x=False, y=True) | ||
self.plot.setYRange(*self.yrange) | ||
self.plot.buttonsHidden = False | ||
self.plot.vb.setMouseEnabled(x=True, y=False) | ||
self.mainArea.layout().addWidget(self.plot_widget) | ||
self.plot.sigYRangeChanged.connect(self._rescale_y) | ||
|
||
self.legend = self._create_legend(((1, 0), (1, 0))) | ||
|
||
def _create_legend(self, anchor): | ||
legend = LegendItem() | ||
legend.setLabelTextSize("12pt") | ||
legend.setParentItem(self.plot.vb) | ||
legend.restoreAnchor(anchor) | ||
legend.hide() | ||
return legend | ||
|
||
def update_legend(self): | ||
self.legend.clear() | ||
if not self.selection: | ||
self.legend.hide() | ||
return | ||
|
||
for name, color in zip(self.selection, self.get_palette()): | ||
dot = pg.ScatterPlotItem(pen=color, brush=color, size=10, shape="s") | ||
self.legend.addItem(dot, escape(name)) | ||
self.legend.show() | ||
|
||
|
||
def _rescale_y(self): | ||
QTimer.singleShot(1, lambda: self.plot.setYRange(*self.yrange)) | ||
|
||
@Inputs.time_series | ||
def set_data(self, data): | ||
self.plot.clear() | ||
self._cached.clear() | ||
self.Error.clear() | ||
|
||
if self.selection: | ||
self.persistent_selection = self.selection[:] | ||
|
||
if not data or len(data) < 2: | ||
self.Error.no_instances(shown=bool(data)) | ||
self.data = None | ||
self.model.clear() | ||
return | ||
|
||
self.data = Timeseries.from_data_table(data) | ||
self.model[:] = [ | ||
var for var in self.data.domain.variables | ||
if isinstance(var, ContinuousVariable) | ||
and var is not self.data.time_variable] | ||
if not self.model: | ||
self.Error.no_variables() | ||
self.data = None | ||
return | ||
|
||
item_selection = QItemSelection() | ||
names = [attr.name for attr in self.model] | ||
selection = [ | ||
names.index(name) | ||
for name in self.persistent_selection | ||
if name in names] | ||
for idx in selection or [0]: | ||
index = self.model.index(idx) | ||
item_selection.select(index, index) | ||
self.selectionModel.select(item_selection, | ||
QItemSelectionModel.ClearAndSelect) | ||
|
||
def _selection_changed(self): | ||
self.selection = [ | ||
self.model.data(index) | ||
for index in self.selectionModel.selectedIndexes()] | ||
self.update_legend() | ||
self.replot() | ||
|
||
def get_palette(self): | ||
if len(self.selection) > len(DefaultDiscretePalette): | ||
return Glasbey | ||
return DefaultDiscretePalette |
Oops, something went wrong.