In [45]:
from py2neo import Graph
import pandas as pd

pd.set_option('display.max_colwidth', -1)

In [72]:
graph = Graph("bolt://localhost", auth=("neo4j", "neo"))

For this section we need to install the [ml-models](https://github.com/neo4j-graph-analytics/ml-models) procedures library. You can find installation instructions on the [releases pages](https://github.com/neo4j-graph-analytics/ml-models/releases/tag/1.0.1). 

Once you've done that, run the following code to check that the library is installed:

In [47]:
query = """
CALL dbms.procedures() 
YIELD name WHERE name STARTS WITH 'regression' 
RETURN *
"""

graph.run(query).to_data_frame()

Unnamed: 0,name
0,regression.linear.add
1,regression.linear.addM
2,regression.linear.clear
3,regression.linear.copy
4,regression.linear.create
5,regression.linear.delete
6,regression.linear.info
7,regression.linear.load
8,regression.linear.remove
9,regression.linear.removeM


Let's split training and test data:

In [97]:
split_data_train_query = """
MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood) 
WHERE exists(list.bedrooms) AND exists(list.bathrooms)
AND exists(list.price) 
AND (:Review)-[:REVIEWS]->(list) 
WITH regression.linear.split(collect(id(list)), 0.75) AS trainingIDs
MATCH (list:Listing) WHERE id(list) in trainingIDs 
SET list:Train
"""

split_data_test_query = """
MATCH (list:Listing)-[n:IN_NEIGHBORHOOD]->(:Neighborhood)
WHERE exists(list.bedrooms) AND exists(list.bathrooms)
AND exists(list.price) 
AND (:Review)-[:REVIEWS]->(list) 
AND NOT list:Train 
SET list:Test
"""

graph.run(split_data_train_query).summary().counters
graph.run(split_data_test_query).summary().counters

{'labels_added': 9815}

In [98]:
correlation_query = """
MATCH (list) 
WHERE list:Test OR list:Train
WITH collect(size((list)<-[:REVIEWS]-()) * 1.0) AS reviews,
     collect(list.bedrooms + list.bathrooms) as rooms
RETURN regression.linear.correlation(reviews, rooms)
"""

graph.run(correlation_query).to_data_frame()

Unnamed: 0,"regression.linear.correlation(reviews, rooms)"
0,0.026142


In [51]:
model_name = "rental-prices-4"

In [52]:
init_query = """
CALL regression.linear.create($modelName, 'Multiple', true, 2)
"""

graph.run(init_query, {"modelName": model_name}).summary().counters

{}

In [53]:
add_training_data_query = """
MATCH (list:Train)
WHERE NOT list:Seen 
CALL regression.linear.add($modelName, 
  [list.bedrooms + list.bathrooms, size((list)<-[:REVIEWS]-()) * 1.0], 
  list.price
) 
SET list:Seen 
RETURN count(list)
"""

graph.run(add_training_data_query, {"modelName": model_name}).summary().counters

{'labels_added': 36790}

In [54]:
train_model_query = """
CALL regression.linear.train($modelName)
"""

graph.run(train_model_query, {"modelName": model_name})

<py2neo.database.Cursor at 0x1156d6a90>

In [55]:
add_test_data_query = """
MATCH (list:Test) 
WHERE NOT list:Seen
CALL regression.linear.add($modelName, 
  [list.bedrooms + list.bathrooms, size((list)<-[:REVIEWS]-()) * 1.0],  
  list.price, 
  'test'
) 
SET list:Seen 
RETURN count(list)
"""

graph.run(add_test_data_query, {"modelName": model_name})

In [56]:
test_model_query = """
CALL regression.linear.test($modelName)
"""

graph.run(test_model_query, {"modelName": model_name}).to_data_frame()

Unnamed: 0,framework,hasConstant,model,nTest,nTrain,numVars,state,testInfo,trainInfo
0,Multiple,True,rental-prices-4,0,36790,2,ready,"{'adjRSquared': nan, 'RSquared': nan, 'SSE': 0.0, 'SST': 0.0, 'MSE': -0.0}","{'RSquared': 0.22357460363988557, 'SSR': 88897254.99853414, 'SSE': 308720602.9837708, 'SST': 397617857.98230493, 'adjRSquared': 0.22353239169564654, 'parameters std error': [1.2824171688063868, 0.5109515890197222, 0.011533589063208016], 'parameters': [16.223217245130986, 52.5881409596996, -0.034282521063706], 'MSE': 8392.111424790572}"


In [57]:
info_query = """
CALL regression.linear.info($modelName) 
"""

graph.run(info_query, {"modelName": model_name}).to_data_frame()

Unnamed: 0,framework,hasConstant,model,nTest,nTrain,numVars,state,testInfo,trainInfo
0,Multiple,True,rental-prices-4,0,36790,2,ready,"{'adjRSquared': nan, 'RSquared': nan, 'SSE': 0.0, 'SST': 0.0, 'MSE': -0.0}","{'RSquared': 0.22357460363988557, 'SSR': 88897254.99853414, 'SSE': 308720602.9837708, 'SST': 397617857.98230493, 'adjRSquared': 0.22353239169564654, 'parameters std error': [1.2824171688063868, 0.5109515890197222, 0.011533589063208016], 'parameters': [16.223217245130986, 52.5881409596996, -0.034282521063706], 'MSE': 8392.111424790572}"


Now let's add the neighborhood to the list of independent variables that we feed to our regression model. Neighborhoods are categorical variables so we'll need to create a [one hot encoding](https://hackernoon.com/what-is-one-hot-encoding-why-and-when-do-you-have-to-use-it-e3c6186d008f).

We can use the `algo.ml.oneHotEncoding` function to help us out. 

In [101]:
model_name = "rental-prices-8"

In [102]:
nh_count_query = """
MATCH (:Neighborhood)
RETURN count(*) AS count
"""

nh_count =  graph.run(nh_count_query).to_table()[0][0]


init_query = """
CALL regression.linear.create($modelName, 'Multiple', true, $numberOfVariables)
"""

graph.run(init_query, {"modelName": model_name, "numberOfVariables": 2 + nh_count}).summary().counters

{}

Before we create our new model let's remove the `Seen` label from our nodes so that we can process them again:

In [103]:
clear_seen_query = """
MATCH (s:Seen)
REMOVE s:Seen
"""

graph.run(clear_seen_query)

<py2neo.database.Cursor at 0x1156f6780>

In [104]:
add_training_data_query = """
MATCH (nh:Neighborhood)
WITH collect(nh) AS neighborhoods
MATCH (list:Train)-[:IN_NEIGHBORHOOD]->(nh)
WHERE NOT list:Seen 
CALL regression.linear.add($modelName, 
  apoc.coll.flatten([
    [list.bedrooms + list.bathrooms, size((list)<-[:REVIEWS]-()) * 1.0],
    algo.ml.oneHotEncoding(neighborhoods, [nh])
  ]), 
  list.price
) 
SET list:Seen 
RETURN count(list)
"""

graph.run(add_training_data_query, {"modelName": model_name}).summary().counters

{'labels_added': 29443}

In [105]:
train_model_query = """
CALL regression.linear.train($modelName)
"""

graph.run(train_model_query, {"modelName": model_name}).to_data_frame()

Unnamed: 0,framework,hasConstant,model,nTest,nTrain,numVars,state,testInfo,trainInfo
0,Multiple,True,rental-prices-8,0,29443,226,testing,{},"{'RSquared': 0.43841853768199124, 'SSR': 139502942.45113206, 'SSE': 178692869.22396898, 'SST': 318195811.67510104, 'adjRSquared': 0.4340744313538193, 'parameters std error': [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, ...], 'parameters': [172932567497899.8, 55.28662085533142, -0.008346721529960632, -172932567497909.66, -172932567497947.7, -172932567497792.88, -172932567497890.0, -172932567497886.94, -172932567497850.25, -172932567497832.72, -172932567497824.16, -172932567497854.88, -172932567497774.34, -172932567497865.72, -172932567497899.06, -172932567497802.28, -172932567497911.25, -172932567497880.22, -172932567497806.88, -172932567497929.94, -172932567497885.0, -172932567497934.6, -172932567497922.84, -172932567497929.6, -172932567497898.38, -172932567497949.16, -172932567497938.47, -172932567497861.88, -172932567497926.78, -172932567497914.22, -172932567497892.5, -172932567497948.56, -172932567497823.5, -172932567497784.9, -172932567497913.06, -172932567497935.3, -172932567497854.16, -172932567497936.78, -172932567497839.44, -172932567497813.34, -172932567497874.22, -172932567497936.06, -172932567497872.88, -172932567497863.03, -172932567497923.75, -172932567497868.34, -172932567497944.8, -172932567497933.94, -172932567497949.1, -172932567497928.97, -172932567497855.94, -172932567497949.84, -172932567497953.06, -172932567497843.2, -172932567497958.38, -172932567497913.7, -172932567497929.2, -172932567497925.5, -172932567497771.3, -172932567497929.0, -172932567497924.9, -172932567497869.12, -172932567497940.84, -172932567497962.22, -172932567497828.47, -172932567497925.06, -172932567497940.3, -172932567497926.97, -172932567497828.06, -172932567497931.2, -172932567497755.5, -172932567497959.3, -172932567497907.6, -172932567497916.9, -172932567497819.84, -172932567497959.1, -172932567497945.56, -172932567497942.53, -172932567497959.28, -172932567497789.38, -172932567497944.84, -172932567497947.84, -172932567497921.94, -172932567497918.62, -172932567497948.25, -172932567497943.78, -172932567497947.2, -172932567497938.47, -172932567497953.47, -172932567497934.22, -172932567497948.12, -172932567497865.62, -172932567497949.53, -172932567497947.9, -172932567497932.78, -172932567497960.88, -172932567497936.53, -172932567497973.8, -172932567497926.72, -172932567497934.12, ...], 'MSE': 6116.2674296265395}"


In [110]:
add_test_data_query = """
MATCH (nh:Neighborhood)
WITH collect(nh) AS neighborhoods
MATCH (list:Test)-[:IN_NEIGHBORHOOD]->(nh)
WHERE NOT list:Seen 
CALL regression.linear.add($modelName, 
  apoc.coll.flatten([
    [list.bedrooms + list.bathrooms, size((list)<-[:REVIEWS]-()) * 1.0],
    algo.ml.oneHotEncoding(neighborhoods, [nh])
  ]), 
  list.price,
    'test'
) 
SET list:Seen 
RETURN count(list)
"""

graph.run(add_test_data_query, {"modelName": model_name})

<py2neo.database.Cursor at 0x115502b00>

In [111]:
test_model_query = """
CALL regression.linear.test($modelName)
"""

graph.run(test_model_query, {"modelName": model_name}).to_data_frame()

Unnamed: 0,framework,hasConstant,model,nTest,nTrain,numVars,state,testInfo,trainInfo
0,Multiple,True,rental-prices-8,9815,39258,226,ready,"{'adjRSquared': 0.4256676041949198, 'RSquared': 0.43883499659925473, 'SSE': 58728138.86140442, 'SST': 104653958.29302071, 'MSE': 6124.532157827137}","{'RSquared': 0.43768862075928705, 'SSR': 185076677.3689768, 'SSE': 237773423.34397343, 'SST': 422850100.7129502, 'adjRSquared': 0.4344326864581315, 'parameters std error': [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, ...], 'parameters': [21313743924091.84, 54.929301381111145, -0.010722295381128788, -21313743924102.37, -21313743924133.79, -21313743923982.83, -21313743924081.098, -21313743924080.22, -21313743924040.02, -21313743924023.973, -21313743924017.617, -21313743924044.508, -21313743923959.984, -21313743924057.574, -21313743924091.406, -21313743923986.9, -21313743924100.23, -21313743924071.67, -21313743923996.27, -21313743924120.305, -21313743924072.508, -21313743924125.117, -21313743924114.094, -21313743924120.953, -21313743924089.555, -21313743924139.277, -21313743924131.28, -21313743924052.04, -21313743924118.688, -21313743924102.992, -21313743924086.918, -21313743924134.836, -21313743924019.715, -21313743923980.855, -21313743924086.17, -21313743924127.004, -21313743924043.42, -21313743924125.215, -21313743924030.867, -21313743924003.15, -21313743924068.508, -21313743924118.945, -21313743924065.47, -21313743924055.97, -21313743924125.664, -21313743924062.816, -21313743924135.957, -21313743924128.58, -21313743924141.62, -21313743924121.844, -21313743924045.152, -21313743924146.242, -21313743924142.094, -21313743924028.52, -21313743924149.426, -21313743924103.883, -21313743924122.76, -21313743924104.8, -21313743923961.254, -21313743924118.152, -21313743924113.867, -21313743924050.145, -21313743924134.473, -21313743924148.816, -21313743924025.4, -21313743924114.16, -21313743924132.883, -21313743924120.24, -21313743924009.633, -21313743924124.68, -21313743923949.258, -21313743924149.03, -21313743924095.957, -21313743924110.68, -21313743924011.645, -21313743924145.668, -21313743924127.902, -21313743924134.008, -21313743924147.562, -21313743923989.414, -21313743924135.43, -21313743924135.168, -21313743924113.49, -21313743924109.793, -21313743924134.8, -21313743924113.363, -21313743924139.535, -21313743924138.137, -21313743924143.684, -21313743924127.867, -21313743924141.5, -21313743924068.84, -21313743924140.754, -21313743924141.508, -21313743924123.227, -21313743924151.29, -21313743924127.203, -21313743924161.805, -21313743924119.777, -21313743924125.17, ...], 'MSE': 6091.912155567969}"
