# D3 integration in Jupyter notebook

### Goal: 

The main goal is to integrate D3 into a Jupyter notebook and achieve two-way communication between the front-end JavaScript and the back-end Python kernel. This lets us read data produced by user interaction back into Python.


### Libraries:

The main component we will be using is ipywidgets. You can learn more about it at http://ipywidgets.readthedocs.io/en/latest/.
With widgets we can build interactive visualizations in notebooks.

Apart from ipywidgets, the other libraries we will be using are pandas for data manipulation and `IPython.display` to display the widgets. The traits shared between Python and JavaScript are declared with traitlets.

In [1]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from traitlets import Unicode, validate, List
import pandas as pd

We will be following the example on building a custom widget from http://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Custom.html

#### Import d3 using require

In [2]:
%%javascript
require.config({
    paths: {
        d3: 'https://d3js.org/d3.v4.min'
    }
});

<IPython.core.display.Javascript object>

## Example 1: Bar Chart - Selection
We will be using the soccer player data set for all of our examples.

The data contains the top 50 players and attributes such as `overall_rating`, `crossing`, and `finishing` over the years 2007-16. We will mostly use the `overall_rating` attribute.

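The first chart shows the overall rating, averaged over the years, of five players as a bar chart. Note that the code below takes the first five rows of the grouped table with `head(5)`, and that table is ordered by player name rather than by score. If the user clicks on a bar, we then show that player's progress over the years.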

In [3]:
players_data = pd.read_csv("player_data.csv")
data_average = players_data.groupby('player_name')['overall_rating'].mean().reset_index(name='average_score')
data_average_top5 = data_average.head(5)
data_average_top5

|   | player_name | average_score |
|---|---|---|
| 0 | Alexis Sanchez | 80.9625 |
| 1 | Andres Iniesta | 87.638889 |
| 2 | Angel Di Maria | 81.95 |
| 3 | Arjen Robben | 87.52 |
| 4 | Bastian Schweinsteiger | 84.633333 |
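
The rows above are the first five in group order, which is what `head(5)` returns. If the intent is to rank players by their averaged rating, one option is `nlargest`. The following is a minimal sketch on made-up data (the player names `A`, `B`, `C` and their ratings are hypothetical and only illustrate the difference):

```python
import pandas as pd

# Hypothetical ratings, used only to illustrate the ranking step.
ratings = pd.DataFrame({
    "player_name": ["A", "A", "B", "B", "C", "C"],
    "overall_rating": [80, 82, 90, 92, 85, 87],
})

average = (
    ratings.groupby("player_name")["overall_rating"]
    .mean()
    .reset_index(name="average_score")
)

# head(2) keeps the first rows in group (alphabetical) order ...
first_two = average.head(2)
# ... while nlargest(2, ...) ranks by the averaged score.
top_two = average.nlargest(2, "average_score")

print(first_two["player_name"].tolist())
print(top_two["player_name"].tolist())
```

In the notebook, the equivalent change would be `data_average.nlargest(5, 'average_score')` in place of `data_average.head(5)`.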


First, create a back-end Python class for our widget.
In the following class, the `value` trait stores the data we derived in the cell above; `.tag(sync=True)` makes it available to the front-end JavaScript.
The `player_name` trait stores the name of the player selected in the front end. The user selects a player by clicking on that player's bar.

In [4]:
class BarWidget(widgets.DOMWidget):
    _view_name = Unicode('BarView').tag(sync=True)
    _view_module = Unicode('barChart').tag(sync=True)
    _view_module_version = Unicode('0.1.0').tag(sync=True)
    value = List([]).tag(sync=True)
    player_name = Unicode('').tag(sync=True)

Next, we will create the front-end view. To accomplish this, the widget framework uses Backbone.js. Any single WidgetView is bound to a single cell. Multiple WidgetViews can be linked to a single WidgetModel. The name under which the view is exported must match the `_view_name` trait of the back-end class, and the module name passed to `define` must match `_view_module`. This is how the widget framework links the view to the corresponding model.

We can access the model created above through `this.model` in JavaScript. To read a trait, use `this.model.get('value')`.

The `render` method is the entry point of the view. Here it delegates to `value_changed`, which draws the D3 chart, and registers `value_changed` as the handler for later changes to `value`.

In [5]:
%%javascript
require.undef('barChart');

define('barChart', ["@jupyter-widgets/base", "d3"], function(widgets, d3) {

    var BarView = widgets.DOMWidgetView.extend({

        render: function() {
            this.value_changed();
            // set change event handler for variable 'value' of bar widget
            this.model.on('change:value', this.value_changed, this);
        },
        value_changed: function() {
            // get the data from back-end model 
            var data = this.model.get('value');
            var self = this;
            var margin = {top: 20, right: 20, bottom: 30, left: 40},
            width = 400 - margin.left - margin.right,
            height = 300 - margin.top - margin.bottom;

            // set the ranges
            var x = d3.scaleBand()
                      .range([0, width])
                      .padding(0.1);
            var y = d3.scaleLinear()
                      .range([height, 0]);
            $("#barChart").remove();
            this.$el.append("<div id='barChart'></div>");
            $("#barChart").width("400px");
            $("#barChart").height("300px");
            var svg = d3.select("#barChart").append("svg")
                        .attr("width", width + margin.left + margin.right)
                        .attr("height", height + margin.top + margin.bottom)
                        .append("g")
                        .attr("transform", 
                              "translate(" + margin.left + "," + margin.top + ")");
            
            x.domain(data.map(function(d) { return d.player_name; }));
            y.domain([0, d3.max(data, function(d) { return d.average_score; })]);
            // append the rectangles for the bar chart
            svg.selectAll(".bar")
                .data(data)
                .enter().append("rect")
                .attr("class", "bar")
                .attr("x", function(d) { return x(d.player_name); })
                .attr("width", x.bandwidth())
                .attr("y", function(d) { return y(d.average_score); })
                .attr("height", function(d) { return height - y(d.average_score); })
                .style("fill","steelblue")
                .on("click", function(d){
                    d3.selectAll('.bar').style("fill","steelblue");
                    d3.select(this).style("fill","green");
                    self.model.set('player_name', d.player_name);
                    self.model.save_changes();
                    self.touch();
                });

            // add the x Axis
            svg.append("g")
            .attr("transform", "translate(0," + height + ")")
            .call(d3.axisBottom(x));

            // add the y Axis
            svg.append("g")
            .call(d3.axisLeft(y));
            
        },
    });

    return {
        BarView : BarView
    };
});

<IPython.core.display.Javascript object>

Let's create the widget for the line chart, which shows the progress of the player whose bar is selected.

For this we again need a back-end Python model. Here the `value` trait contains the data of one player (selected from the bar chart) over the years. The `selected_years` trait will be used for the range selection below (Example 3).

In [6]:
class LineWidget(widgets.DOMWidget):
    _view_name = Unicode('LineView').tag(sync=True)
    _view_module = Unicode('lineChart').tag(sync=True)
    _view_module_version = Unicode('0.1.0').tag(sync=True)
    value = List([]).tag(sync=True)
    selected_years = List([]).tag(sync=True)

Now, the front-end view for the line chart:

In [7]:
%%javascript
require.undef('lineChart');
define('lineChart', ["@jupyter-widgets/base", "d3"], function(widgets, d3) {

    var LineView = widgets.DOMWidgetView.extend({

        render: function() {
            this.value_changed();
            // set change event handler for variable 'value' of line widget
            this.listenTo(this.model, 'change:value', this.value_changed, this);
        },

        value_changed: function() {
            // get data from back-end variable 'value' (copied so the sort
            // below does not reorder the model's own array)
            var player = this.model.get('value').slice();
            var that = this;
            var yearValues = [];
            var attribValues = [];
            // each record is an object, so sort by its 'year' property
            player.sort(function(x, y){
                return d3.ascending(x.year, y.year);
            });
            player.forEach(function(d){
                yearValues.push(d["year"]);
                attribValues.push(d["overall_rating"]);
            });
            var svgHeight = 400;
            var svgWidth = 1000;
            // create the container for the chart
            $("#chart1").remove();
            this.$el.append("<div id='chart1'></div>");
            $("#chart1").width("960px");
            $("#chart1").height("400px");
            var margin = {top: 20, right: 20, bottom: 30, left: 40};
            var width = 880 - margin.left - margin.right;
            var height = 500 - margin.top - margin.bottom;
            var svg = d3.select("#chart1").append("svg")
                .style("position", "relative")
                .style("max-width", "960px")
                .attr("width", width + "px")
                .attr("height", (height + 50) + "px");
            svg.append('g')
                .attr("id", "xAxis");
            svg.append('g')
                .attr("id", "yAxis");
            let yScale = d3.scaleLinear()
                .domain([d3.min(attribValues, d => d), d3.max(attribValues, d => d)])
                .range([svgHeight - margin.top - margin.bottom, 0]);

            let yAxis = d3.axisLeft();
            // assign the scale to the axis
            yAxis.scale(yScale);
            var yAxisG = d3.select("#yAxis")
                .attr("transform", "translate("+margin.left+"," + margin.top +")");
        
            yAxisG.transition(3000).call(yAxis);
            let xScale = d3.scaleLinear()
                .domain([d3.min(yearValues), d3.max(yearValues)])
                .range([0, 600]);

            let xAxis = d3.axisBottom();
            // assign the scale to the axis
            xAxis.scale(xScale);
            var xAxisG = d3.select("#xAxis")
                .attr("transform", "translate("+(margin.left+ 10)+"," + (svgHeight - margin.bottom) +")");

            xAxisG.transition(3000).call(xAxis);
            svg.selectAll(".playerPath").remove();
            svg.selectAll(".playerNode").remove();

            var lineCoords = []
            for(var k=0; k<yearValues.length; k++){
                lineCoords.push([xScale(yearValues[k]), yScale(attribValues[k])]);
            }
            var lineGenerator = d3.line();
            var pathString = lineGenerator(lineCoords);
            svg.append('path')
                .attr('d', pathString)
                .attr("transform", "translate("+(margin.left+ 10)+"," + (margin.top) +")")
                .attr("style", "fill : none;")
                .attr("class", "playerPath")
                .style("stroke", "steelblue")
                .style("stroke-width", 3)
                .style('opacity', 0.5);
            
            lineCoords.forEach(function(point){
                svg.append('circle').attr('cx', point[0])
                    .attr("cy", point[1])
                    .attr("r", 5)
                    .attr("transform", "translate("+(margin.left+ 10)+"," + (margin.top) +")")
                    .attr("class", "playerNode");
            });
            d3.selectAll(".brush").remove();
            var brush = d3.brushX().extent([[margin.left,svgHeight-margin.bottom-20],[svgWidth,svgHeight-10]]).on("end", brushed);
            svg.append("g").attr("class", "brush").call(brush);
            function brushed() {
                var sel = d3.event.selection;
                if (sel === null) {
                    return;
                }
                // keep the years whose x position falls inside the brushed range
                var yearValuesBrushed = yearValues.filter((d) => xScale(d) + margin.left + 10 >= sel[0] && xScale(d) + margin.left + 10 <= sel[1]);
                // set data to back-end variable 'selected_years'
                that.model.set('selected_years', yearValuesBrushed);
                that.model.save_changes();
                that.touch();
            }
        },
    });

    return {
        LineView : LineView
    };
});

<IPython.core.display.Javascript object>

We have now created the views for our bar chart and line chart. To use them, we first have to instantiate the widgets.
Let us create an instance of the bar chart class and then display it. The `updateBar` function updates its `value` trait.

In [8]:
barWidget = BarWidget()
def updateBar():
    barWidget.value = []
    barWidget.value = data_average_top5.to_dict(orient='records')
updateBar()

After creating `barWidget`, we call `updateBar` once to send the initial state to the widget. Running the cell below then displays the bar chart.

In [9]:
display(barWidget)
updateBar()

BarWidget(value=[{'player_name': 'Alexis Sanchez', 'average_score': 80.9625}, {'player_name': 'Andres Iniesta'…
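
The `updateBar` helper assigns an empty list before assigning the real data. The notebook does not say why; a plausible reason (an assumption on my part) is that traitlets only notifies observers when the new value compares unequal to the old one, so re-assigning identical data would not trigger a redraw. The sketch below illustrates that behaviour with a plain `HasTraits` class standing in for the widget; `Model` and the sample record are hypothetical.

```python
from traitlets import HasTraits, List


class Model(HasTraits):
    # Stand-in for the synced `value` trait of BarWidget.
    value = List([])


model = Model()
notifications = []
model.observe(lambda change: notifications.append(change["new"]), names="value")

data = [{"player_name": "A", "average_score": 81.0}]

model.value = data        # first assignment: the value changed, observers run
model.value = list(data)  # equal content: no notification is sent
count_without_reset = len(notifications)

model.value = []          # reset ...
model.value = list(data)  # ... so the same content counts as a change again
count_with_reset = len(notifications)

print(count_without_reset, count_with_reset)
```

With the reset, the second round produces two notifications (one for the empty list, one for the data), which is what lets the front-end `change:value` handler run again.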

To make sure our selection works, click on any bar in the above chart and run the following cell to display the player name.

In [13]:
barWidget.player_name

'Andres Iniesta'

We need to create the line widget before using it. Once the user has clicked on a bar in the chart above, the `updateLineChart` function in the cell below filters the data by player name and updates the `value` that is sent to the line chart. The data here is the rating of the selected player over the years.

In [14]:
lineWidget = LineWidget()
def updateLineChart(name):
    filterByName = players_data[players_data["player_name"]==name]
    jsonValue = filterByName[["player_name", "overall_rating", "year"]]
    lineWidget.value = []
    lineWidget.value = jsonValue.to_dict(orient='records')
updateLineChart(barWidget.player_name)

After selecting a player in the bar chart above, run the following cell to show that player's progress over the years.
To see the line chart of a different player, select another bar in the bar chart and re-run the cell below.

In [15]:
display(lineWidget)
updateLineChart(barWidget.player_name)

LineWidget(value=[{'player_name': 'Andres Iniesta', 'overall_rating': 83.5, 'year': 2007}, {'player_name': 'An…
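
Re-running the cell after every click can be avoided by observing the `player_name` trait on the Python side. The sketch below shows the pattern with plain `HasTraits` classes standing in for the two widgets so that it runs outside a notebook; `BarModel`, `LineModel`, `rows`, and `on_player_selected` are hypothetical names, not part of the notebook.

```python
from traitlets import HasTraits, List, Unicode


class BarModel(HasTraits):
    # Stand-in for BarWidget.player_name
    player_name = Unicode("")


class LineModel(HasTraits):
    # Stand-in for LineWidget.value
    value = List([])


# Hypothetical rows standing in for players_data.
rows = [
    {"player_name": "A", "overall_rating": 80.0, "year": 2007},
    {"player_name": "B", "overall_rating": 90.0, "year": 2007},
    {"player_name": "A", "overall_rating": 82.0, "year": 2008},
]

bar = BarModel()
line = LineModel()


def on_player_selected(change):
    # `change` carries the trait name plus its old and new values.
    name = change["new"]
    line.value = [row for row in rows if row["player_name"] == name]


bar.observe(on_player_selected, names="player_name")

# In the notebook, the click handler sets this trait from JavaScript.
bar.player_name = "A"
print(line.value)
```

With the real widgets, the same idea would be `barWidget.observe(lambda change: updateLineChart(change["new"]), names="player_name")`, run once after both widgets exist.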

## Example 2: Bar chart - Filter

This example shows the implementation of filtering.

The bar chart below shows the average score of all 50 players. A text box accepts a number, and the plot is updated to show the players whose average score is greater than or equal to that number.

In [16]:
barChart2 = BarWidget()
def updateBarChart2(score):
    # `score` is the change dictionary passed by `observe`; 'new' holds the entered number
    player_with_score = data_average[data_average['average_score'] >= score['new']]
    barChart2.value = []
    barChart2.value = player_with_score.to_dict(orient='records')
updateBarChart2({'new':0})
def createScoreTextBox():
    score_text_box = widgets.FloatText(
        value=0,
        description='Show player with score more than or equal to:',
        disabled=False
    )
    score_text_box.observe(updateBarChart2, names='value')
    display(score_text_box)
createScoreTextBox()
display(barChart2)
updateBarChart2({'new':0})

FloatText(value=0.0, description='Show player with score more than or equal to:')

BarWidget(value=[{'player_name': 'Alexis Sanchez', 'average_score': 80.9625}, {'player_name': 'Andres Iniesta'…

## Example 3: Range selection using brush

The goal is to pass the brush selection made in the D3 plot to the Python side. It is achieved in the same way as above: the brush selection is stored in the `selected_years` trait of the `LineWidget` class.

To set a model trait from the view, call `this.model.set('selected_years', value)` followed by `this.model.save_changes()`. Inside the brush handler, `this` no longer refers to the view, which is why the line chart view goes through the `that` alias.
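
To make the mapping from brush extent to years concrete, here is a Python sketch of the same filter the view applies in JavaScript. It assumes the layout used by the line chart view (an x axis 600 pixels wide, shifted right by `margin.left + 10 = 50` pixels); `linear_scale` and `years_in_brush` are hypothetical helper names written for this illustration.

```python
def linear_scale(domain, output_range):
    """Return a function mapping `domain` linearly onto `output_range`."""
    d0, d1 = domain
    r0, r1 = output_range

    def scale(value):
        if d1 == d0:
            # Degenerate domain: fall back to the middle of the range
            # (a choice made for this sketch).
            return (r0 + r1) / 2
        return r0 + (value - d0) * (r1 - r0) / (d1 - d0)

    return scale


def years_in_brush(years, selection, x_offset=50, axis_width=600):
    """Keep the years whose pixel position lies inside the brush extent."""
    x_scale = linear_scale((min(years), max(years)), (0, axis_width))
    lo, hi = selection
    return [y for y in years if lo <= x_scale(y) + x_offset <= hi]


years = list(range(2007, 2017))
# A brush dragged from pixel 100 to pixel 260 on the chart.
selected = years_in_brush(years, selection=(100, 260))
print(selected)
```

The list computed this way is what the view writes into `selected_years`, so the Python side only ever sees years, never pixel coordinates.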

In [17]:
lineWidget2 = LineWidget()
def showProgressOfPlayer(name):
    filterByName = players_data[players_data["player_name"]==name]
    returnData = filterByName[["player_name", "overall_rating", "year"]]
    lineWidget2.value = []
    lineWidget2.value = returnData.to_dict(orient='records')
showProgressOfPlayer("Lionel Messi")

We can see the line chart by running the cell below. It shows the progress of Lionel Messi. Drag the brush over the x-axis to select a range of years. Then we print the selected years to make sure that we receive the range.

In [None]:
display(lineWidget2)
showProgressOfPlayer("Lionel Messi")

In [None]:
lineWidget2.selected_years

## Conclusion

ipywidgets provides a clean way to establish communication between the front end and the back-end Python kernel. The framework uses the 'comm' API for this purpose. It allows the programmer to send JSON-able blobs between the front end and the back end. The comm API hides the complexity of the web server, ZMQ, and WebSockets.
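
As a small illustration of what "JSON-able" means for the data exchanged in this notebook, the sketch below round-trips a list of records through the standard `json` module and shows a value that cannot be serialized. The sample records are hypothetical, and the snippet only demonstrates the constraint; it is not how ipywidgets serializes traits internally.

```python
import json

# Records shaped like the output of DataFrame.to_dict(orient='records').
records = [
    {"player_name": "A", "average_score": 81.0},
    {"player_name": "B", "average_score": 91.0},
]

# Lists, dicts, strings, and numbers survive a JSON round trip unchanged.
payload = json.dumps(records)
restored = json.loads(payload)

# A Python set has no JSON representation, so it cannot be sent as is.
try:
    json.dumps({"years": {2007, 2008}})
    set_is_jsonable = True
except TypeError:
    set_is_jsonable = False

print(restored == records, set_is_jsonable)
```

This is why the widgets above keep their state in `List` and `Unicode` traits holding plain records.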

I hope this simple project provides a starting point for achieving bidirectional communication in a Jupyter notebook.