Skip to content

Commit

Permalink
Merge pull request #216 from janezd/correlogram-pyqtgraph
Browse files Browse the repository at this point in the history
Correlogram, Periodogram: Reimplementation in pyqtgraph
  • Loading branch information
ajdapretnar committed Jul 22, 2022
2 parents 27c3b74 + 28ff7f2 commit adb33ef
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 227 deletions.
2 changes: 2 additions & 0 deletions orangecontrib/timeseries/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def periodogram_nonequispaced(times, x, *, freqs=None,
else:
periods = 2 * np.pi / freqs

if times.base is not None:
times = times.copy() # lombscargle is Pythonized and doesn't like views
pgram = lombscargle(times, x, freqs)
# Normalize -- I have no idea what I am doing; took this from
# https://jakevdp.github.io/blog/2015/06/13/lomb-scargle-in-python/#lomb-scargle-algorithms-in-python
Expand Down
161 changes: 50 additions & 111 deletions orangecontrib/timeseries/widgets/owcorrelogram.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,77 @@
from Orange.data import Table, ContinuousVariable
from Orange.widgets import widget, gui, settings
from Orange.widgets.utils.colorpalette import ColorPaletteGenerator
from Orange.widgets.widget import Input
import numpy as np

from AnyQt.QtCore import Qt
import pyqtgraph as pg

from orangecontrib.timeseries.widgets.owperiodbase import OWPeriodBase
from orangewidget.settings import Setting
from orangewidget.utils.widgetpreview import WidgetPreview

from Orange.widgets import gui

from orangecontrib.timeseries import (
Timeseries, autocorrelation, partial_autocorrelation)
from orangecontrib.timeseries.widgets.highcharts import Highchart

from AnyQt.QtWidgets import QListWidget


class OWCorrelogram(widget.OWWidget):
class OWCorrelogram(OWPeriodBase):
# TODO: allow computing cross-correlation of two distinct series
name = 'Correlogram'
description = "Visualize variables' auto-correlation."
icon = 'icons/Correlogram.svg'
priority = 110

class Inputs:
time_series = Input("Time series", Table)

attrs = settings.Setting([])
use_pacf = settings.Setting(False)
use_confint = settings.Setting(True)

graph_name = 'plot'
use_pacf = Setting(False)
use_confint = Setting(True)

class Error(widget.OWWidget.Error):
no_instances = widget.Msg("At least 2 data instances are required")
yrange = (-1, 1)

def __init__(self):
self.all_attrs = []
opts = gui.widgetBox(self.controlArea, 'Options')
gui.checkBox(opts, self, 'use_pacf',
label='Compute partial auto-correlation (PACF)',
callback=self.on_changed)
gui.checkBox(opts, self, 'use_confint',
super().__init__()
gui.separator(self.controlArea)
gui.checkBox(self.controlArea, self, 'use_pacf',
label='Compute partial auto-correlation',
callback=self.replot)
gui.checkBox(self.controlArea, self, 'use_confint',
label='Plot 95% significance interval',
callback=self.on_changed)
gui.listBox(self.controlArea, self, 'attrs',
labels='all_attrs',
box='Auto-correlated attribute(s)',
selectionMode=QListWidget.ExtendedSelection,
callback=self.on_changed)
plot = self.plot = Highchart(
self,
chart_zoomType='x',
plotOptions_line_marker_enabled=False,
plotOptions_column_borderWidth=0,
plotOptions_column_groupPadding=0,
plotOptions_series_pointWidth=3,
yAxis_min=-1.0,
yAxis_max=1.0,
xAxis_min=0,
xAxis_gridLineWidth=1,
yAxis_plotLines=[dict(value=0, color='#000', width=1, zIndex=2)],
yAxis_title_text='',
xAxis_title_text='period',
tooltip_headerFormat='Correlation at period: {point.key:.2f}<br/>',
tooltip_pointFormat='<span style="color:{point.color}">\u25CF</span> {point.y:.2f}<br/>',
)
self.mainArea.layout().addWidget(plot)
callback=self.replot)

def acf(self, attr, pacf, confint):
x = self.data.interp(attr).ravel()
func = partial_autocorrelation if pacf else autocorrelation
return func(x, alpha=.05 if confint else None)

@Inputs.time_series
def set_data(self, data):
self.Error.no_instances.clear()
self.data = data = None if data is None else \
Timeseries.from_data_table(data)
self.all_attrs = []
if data is None:
self.plot.clear()
return
if len(data) < 2:
self.Error.no_instances()
self.plot.clear()
return
self.all_attrs = [(var.name, gui.attributeIconDict[var])
for var in data.domain.variables
if (var is not data.time_variable and
isinstance(var, ContinuousVariable))]
self.attrs = [0]
self.on_changed()

def on_changed(self):
if not self.attrs or not self.all_attrs:
if attr not in self._cached:
x = self.data.interp(attr).ravel()
func = partial_autocorrelation if pacf else autocorrelation
self._cached[attr] = func(x, alpha=.05 if confint else None)
return self._cached[attr]

def replot(self):
self.plot.clear()
if not self.selection:
return

series = []
options = dict(series=series)
plotlines = []
for i, (attr, color) in enumerate(zip(self.attrs,
ColorPaletteGenerator(len(self.all_attrs))[self.attrs])):
attr_name = self.all_attrs[attr][0]
pac = self.acf(attr_name, self.use_pacf, False)
self.plot_widget.addItem(pg.InfiniteLine(0, 0, pen=pg.mkPen(0., width=2)))

palette = self.get_palette()
for i, attr in enumerate(self.selection):
color = palette.value_to_qcolor(i)
x, acf = np.array(self.acf(attr, self.use_pacf, False)).T
x = np.repeat(x, 2)
y = np.vstack((np.zeros(len(acf)), acf)).T.flatten()
item = pg.PlotCurveItem(
x=x, y=y, connect="pairs", antialias=True,
pen=pg.mkPen(color, width=5))
self.plot_widget.addItem(item)

if self.use_confint:
# Confidence intervals, from:
# https://www.mathworks.com/help/econ/autocorrelation-and-partial-autocorrelation.html
# https://www.mathworks.com/help/signal/ug/confidence-intervals-for-sample-autocorrelation.html
std = 1.96 * ((1 + 2 * (pac[:, 1]**2).sum()) / len(self.data))**.5 # = more precise than 1.96/sqrt(N)
color = '/**/ Highcharts.getOptions().colors[{}] /**/'.format(i)
line = dict(color=color, width=1.5, dashStyle='dash')
plotlines.append(dict(line, value=std))
plotlines.append(dict(line, value=-std))

series.append(dict(
# TODO: set units to something more readable than #periods (e.g. days)
data=pac,
type='column',
name=attr_name,
zIndex=2,
))

# TODO: give periods meaning (datetime names)
plotlines.append(dict(value=0, color='black', width=2, zIndex=3))
if series:
self.plot.chart(options, yAxis_plotLines=plotlines, xAxis_type='linear')
else:
self.plot.clear()
se = np.sqrt((1 + 2 * (acf ** 2).sum()) / len(self.data))
std = 1.96 * se
pen = pg.mkPen(color, width=2, style=Qt.DashLine)
self.plot_widget.addItem(pg.InfiniteLine(std, 0, pen=pen))
self.plot_widget.addItem(pg.InfiniteLine(-std, 0, pen=pen))


if __name__ == "__main__":
from AnyQt.QtWidgets import QApplication

a = QApplication([])
ow = OWCorrelogram()

data = Timeseries.from_file('airpassengers')
ow.set_data(data)

ow.show()
a.exec()
WidgetPreview(OWCorrelogram).run(
Timeseries.from_file("airpassengers")
)
130 changes: 130 additions & 0 deletions orangecontrib/timeseries/widgets/owperiodbase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from html import escape
from typing import List

from AnyQt.QtCore import QItemSelectionModel, QTimer, QItemSelection
from AnyQt.QtWidgets import QListView

import pyqtgraph as pg

from orangecontrib.timeseries import Timeseries
from orangewidget.settings import Setting

from Orange.data import Table, ContinuousVariable
from Orange.widgets.utils.itemmodels import VariableListModel
from Orange.widgets.widget import OWWidget, Input, Msg
from Orange.widgets.utils.colorpalettes import DefaultDiscretePalette, Glasbey
from Orange.widgets.visualize.owdistributions import LegendItem


class OWPeriodBase(OWWidget, openclass=True):
    """
    Shared base for widgets that plot something per selected variable.

    The base class provides the time-series input, a multi-selection list
    of the continuous (non-time) variables, a pyqtgraph plot with a fixed
    y range, and a legend.

    Subclasses are expected to provide:

    - ``yrange``: a ``(min, max)`` pair; it is unpacked into
      ``self.plot.setYRange`` in ``__init__`` and in ``_rescale_y``;
    - ``replot()``: called whenever the selection changes.

    ``self._cached`` is a dict that is emptied on every new input; it is
    available to subclasses for memoizing per-variable results.
    """

    class Inputs:
        time_series = Input("Time series", Table)

    # Selected attributes are stored as strings. They are always continuous,
    # so there's nothing to match, and storing as Variable would require
    # context handler.
    selection: List[str] = Setting([], schema_only=True)

    graph_name = 'plot'

    class Error(OWWidget.Error):
        no_instances = Msg("Data contains just a single instance")
        no_variables = Msg("Data doesn't contain any numeric variables")

    def __init__(self):
        """Build the variable list view, the plot and the (hidden) legend."""
        super().__init__()
        self.data = None
        self.model = VariableListModel()
        self._cached = {}
        # Last non-empty selection; set_data uses it to re-select variables
        # by name when new data arrives.
        self.persistent_selection = self.selection

        listbox = QListView(self)
        listbox.setModel(self.model)
        self.controlArea.layout().addWidget(listbox)

        self.selectionModel = listbox.selectionModel()
        self.selectionModel.selectionChanged.connect(self._selection_changed)
        listbox.setSelectionModel(self.selectionModel)
        listbox.setSelectionMode(QListView.ExtendedSelection)

        self.plot_widget = pg.PlotWidget(background="w")
        self.plot = self.plot_widget.getPlotItem()
        self.plot.showGrid(x=False, y=True)
        self.plot.setYRange(*self.yrange)
        self.plot.buttonsHidden = False
        # Mouse interaction is enabled along x only; y is additionally
        # forced back to `yrange` by _rescale_y.
        self.plot.vb.setMouseEnabled(x=True, y=False)
        self.mainArea.layout().addWidget(self.plot_widget)
        self.plot.sigYRangeChanged.connect(self._rescale_y)

        self.legend = self._create_legend(((1, 0), (1, 0)))

    def _create_legend(self, anchor):
        """Create a hidden legend parented to the plot's view box.

        `anchor` is passed unchanged to `LegendItem.restoreAnchor`.
        """
        legend = LegendItem()
        legend.setLabelTextSize("12pt")
        legend.setParentItem(self.plot.vb)
        legend.restoreAnchor(anchor)
        legend.hide()
        return legend

    def update_legend(self):
        """Rebuild the legend from the current selection.

        The legend is hidden when nothing is selected; otherwise it gets one
        entry per selected variable, paired in order with the colors of the
        palette returned by `get_palette`.
        """
        self.legend.clear()
        if not self.selection:
            self.legend.hide()
            return

        for name, color in zip(self.selection, self.get_palette()):
            # NOTE(review): pyqtgraph documents the marker keyword of
            # ScatterPlotItem as `symbol`; confirm that `shape` has the
            # intended effect.
            dot = pg.ScatterPlotItem(pen=color, brush=color, size=10, shape="s")
            # Variable names are escaped before being used as legend labels.
            self.legend.addItem(dot, escape(name))
        self.legend.show()

    def _rescale_y(self):
        """Reset the y range to `yrange` after it has been changed.

        The reset is scheduled through a single-shot timer rather than
        applied directly inside the signal handler.
        """
        QTimer.singleShot(1, lambda: self.plot.setYRange(*self.yrange))

    @Inputs.time_series
    def set_data(self, data):
        """Handle new input: reset state, fill the model, restore selection.

        If `data` is empty/None or has fewer than two rows, the widget is
        reset; `no_instances` is shown only when `data` is truthy. If there
        are no continuous non-time variables, `no_variables` is shown.
        Otherwise the variables named in the previous selection are selected
        again; if none of them exist, the first variable is selected.
        """
        self.plot.clear()
        self._cached.clear()
        self.Error.clear()

        # Remember the selection before the model (and with it, presumably,
        # the selection) is changed below.
        if self.selection:
            self.persistent_selection = self.selection[:]

        if not data or len(data) < 2:
            self.Error.no_instances(shown=bool(data))
            self.data = None
            self.model.clear()
            return

        self.data = Timeseries.from_data_table(data)
        self.model[:] = [
            var for var in self.data.domain.variables
            if isinstance(var, ContinuousVariable)
            and var is not self.data.time_variable]
        if not self.model:
            self.Error.no_variables()
            self.data = None
            return

        item_selection = QItemSelection()
        names = [attr.name for attr in self.model]
        selection = [
            names.index(name)
            for name in self.persistent_selection
            if name in names]
        for idx in selection or [0]:
            index = self.model.index(idx)
            item_selection.select(index, index)
        self.selectionModel.select(item_selection,
                                   QItemSelectionModel.ClearAndSelect)

    def _selection_changed(self):
        """Store the selected variable names, then refresh legend and plot."""
        self.selection = [
            self.model.data(index)
            for index in self.selectionModel.selectedIndexes()]
        self.update_legend()
        self.replot()

    def replot(self):
        """Redraw the plot for the current selection.

        Subclasses must override this method; the base class only calls it.
        """
        raise NotImplementedError

    def get_palette(self):
        """Return `Glasbey` if the selection has more entries than
        `DefaultDiscretePalette`, else `DefaultDiscretePalette`."""
        if len(self.selection) > len(DefaultDiscretePalette):
            return Glasbey
        return DefaultDiscretePalette

0 comments on commit adb33ef

Please sign in to comment.