Skip to content

Commit

Permalink
Completed DB integration with the filter
Browse files Browse the repository at this point in the history
  • Loading branch information
Casey Barnette committed Mar 28, 2016
1 parent 073aa69 commit 8faca0b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 26 deletions.
15 changes: 10 additions & 5 deletions classrank/filters/collabfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
from classrank.filters.datawrapper import DataWrapper
class CollaborativeFilter:
#This takes in a matrix
def __init__(self, data, numRecommendations=2):
self.dataset = DataWrapper(data)
def __init__(self, data=dict(), numRecommendations=1, db=None, metric="rating", school="gatech"):
self.dataset = DataWrapper(instances=data, db=db, school=school, metric=metric)
self.updated = False
self.sparsedata = None
self.sparseifyData()
self.svd = TruncatedSVD()
self.model = self.svd.inverse_transform(self.svd.fit_transform(self.sparsedata))

try:
self.svd = TruncatedSVD(n_components=numRecommendations)
self.model = self.svd.inverse_transform(self.svd.fit_transform(self.sparsedata))
except ValueError:
self.svd = None
self.model = None
raise ValueError("Not enough ratings for predictions")

def getRecommendation(self, instances):
if(self.updated):
self.sparseifyData()
Expand Down
42 changes: 22 additions & 20 deletions classrank/filters/datawrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import classrank.database.wrapper as db
from classrank.database.wrapper import Query
class DataWrapper:
def __init__(self, instances=dict(), db=None, school="gatech", metric="rating"):
self.db = db
Expand Down Expand Up @@ -63,22 +63,24 @@ def getColumn(self, feature):
return self.featureLookup[feature]

def queryDB(self):
query = wrapper.Query(self.db)
for student in query.query(self.db.Student).filter(self.db.Student==self.school).all():
results = query.query(self.db.Rating, self.db.Section, self.db.Course).filter(self.db.Rating.student_id == student.uid).\
filter(self.db.Rating.section_id==self.db.Course.section_id).all() #a tuple of lists
results = zip(*results) #a list of tuples
instance = {}
for result in results:
courseName = query.query(self.db.Course).filter(self.db.Course.uid==result[1].course_id).first()
courseName = courseName.name
if metric == "rating":
rating = result[0][0].rating
elif metric == "grade":
rating = result[0][0].grade
elif metric == "workload":
rating = result[0][0].workload
elif metric == "difficulty":
rating = result[0][0].difficulty
instance[courseName] = rating
self.instances[student.uid] = instance
with Query(self.db) as query:
for student in query.query(self.db.student).filter(self.db.school.abbreviation==self.school).all():
results = query.query(self.db.rating, self.db.section).filter(self.db.rating.student_id == student.uid).\
filter(self.db.rating.section_id==self.db.section.uid).all() #a tuple of lists
#results = list(zip(*results)) #a list of tuples
#pprint.pprint(results)
instance = {}
for result in results:
courseName = query.query(self.db.course).filter(self.db.course.uid==result[1].course_id).first()
courseName = courseName.name
rating = result[0].__getattribute__(self.metric)
#if self.metric == "rating":
# rating = result[0][0].rating
#elif self.metric == "grade":
# rating = result[0][0].grade
#elif self.metric == "workload":
# rating = result[0][0].workload
#elif self.metric == "difficulty":
# rating = result[0][0].difficulty
instance[courseName] = rating
self.dataDict[student.uid] = instance
2 changes: 1 addition & 1 deletion test/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest

from unittest.mock import Mock, MagicMock, patch
from classrank.filters.collabfilter import CollaborativeFilter
import numpy as np
from scipy import sparse
Expand Down
39 changes: 39 additions & 0 deletions test/test_filter_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest
from unittest.mock import Mock, MagicMock, patch
from classrank.filters.collabfilter import CollaborativeFilter
import numpy as np
import os
from classrank.database.wrapper import Database, Query

class TestDatabaseFilter(unittest.TestCase):
def setUp(self):
self.conn = Database(engine=os.environ.get("CONNECTION", "sqlite:///:memory:"))
school = self.conn.school(name="Georgia Tech", abbreviation="gatech")
course = self.conn.course(school=school, name="Intro Java", number="1331", subject="CS")
course2 = self.conn.course(school=school, name="Stuff", number="1332", subject="CS")
section1 = self.conn.section(course=course, semester="fall", year=2016, name="A1")
section2 = self.conn.section(course=course, semester="fall", year=2016, name="A2")
self.section3 = self.conn.section(course=course2, semester="spring",year=2015, name="A")
account = self.conn.account(username="test", email_address="test@test.com", password_hash=b"t", password_salt=b"t")
student = self.conn.student(account=account, school=school)
account2 = self.conn.account(username="test2", email_address="test2@test.com", password_hash=b"t", password_salt=b"t")
self.student2 = self.conn.student(account=account2, school=school)
with Query(self.conn) as q:
q.add(school)
q.add(course)
q.add(section1)
q.add(section2)
q.add(course2)
q.add(self.section3)
q.add(account)
q.add(student)
q.add(self.student2)
q.add(self.conn.rating(student=student, section=section1, rating=5))
q.add(self.conn.rating(student=self.student2, section=section2, rating=3))
def test_filter_query(self):
with self.assertRaises(ValueError):
cf = CollaborativeFilter(db=self.conn)
with Query(self.conn) as q:
q.add(self.conn.rating(student=self.student2, section=self.section3, rating=4))
cf = CollaborativeFilter(db=self.conn)
self.assertIsInstance(cf.getData(), type([]))

0 comments on commit 8faca0b

Please sign in to comment.