In [245]:
from rec_room.recommenders import BaseRS
from rec_room.datasets import load_oms_courses_dataset
from rec_room import ROOT_DIR

In [246]:
from os.path import join
from typing import Any, Dict, List, Optional

import joblib

In [3]:
# Directory under the package ROOT_DIR that holds recommender artifacts;
# OMSCoursesRS appends its own sub-directory (META['path']) to this path.
RS_PATH = join(ROOT_DIR, 'recommenders')

In [244]:
class OMSCoursesRS(BaseRS):
    """
    Recommender system for OMS courses.

    Loads a persisted classifier and label encoder from the RS directory
    and ranks courses by the classifier's predicted probability for a
    user-supplied (rating, workload, difficulty) instance.
    """

    # File names of the persisted artifacts and the RS sub-directory name.
    META = dict(
        dataset = 'oms_courses.csv',
        model = 'oms_courses_model.sav',
        labels = 'oms_courses_labels.sav',
        path = 'oms_courses_rs'
    )

    # Feature columns in the order the model consumes them. The same order
    # is used for the per-course stats, for training rows and for the
    # instance built in `recommend`, so the three can never drift apart.
    FEATURES = ('rating', 'workload', 'difficulty')

    def __init__(self, rs_args: Optional[Dict[str, Any]] = None, rs_dir: Optional[str] = None) -> None:
        """
        Instantiate the OMS Courses Recommender System.

        Parameters
        ----------
        rs_args : Dict, optional
            User specified arguments for preferences or constraints.
            Used by `recommend` when it is called without arguments.

        rs_dir : str, optional
            Directory where the RS artifacts are located. Defaults to
            ``join(RS_PATH, META['path'])``.
        """
        if rs_dir is None:
            rs_dir = join(RS_PATH, self.META['path'])

        super().__init__(args=rs_args, path=rs_dir)

        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only point `rs_dir` at trusted artifacts.
        self._model = joblib.load(join(self.path, self.META['model']))
        self._labels = joblib.load(join(self.path, self.META['labels']))

        # Per-course mean of every feature; one row per (course_id, course_name).
        # Selecting the feature columns explicitly keeps the aggregation
        # independent of any other column the dataset may carry.
        self.stats = (
            self.dataset
            .groupby(['course_id', 'course_name'])[list(self.FEATURES)]
            .mean()
            .reset_index()
        )

    @property
    def dataset(self) -> 'pd.DataFrame':
        """
        The OMS courses dataset.

        Calls `load_oms_courses_dataset()` on every access; nothing is
        cached on the instance.
        """
        return load_oms_courses_dataset()

    @property
    def model(self) -> 'sklearn.linear_model.LogisticRegression':
        """
        The classifier loaded from ``META['model']``.

        LogisticRegression(C=1.0, class_weight=None, dual=False, 
                           fit_intercept=True, intercept_scaling=1, 
                           l1_ratio=None, max_iter=5000, multi_class='auto', 
                           n_jobs=None, penalty='l2', random_state=0, 
                           solver='saga', tol=0.0001, verbose=0, warm_start=False)
        """
        return self._model

    @property
    def labels(self) -> 'sklearn.preprocessing.LabelEncoder':
        """The label encoder loaded from ``META['labels']``."""
        return self._labels

    def recommend(self, args: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """
        Make a Recommendation for OMS Courses.

        This function calls the pre-trained `self.model` and `self.labels`
        to rank the courses, most probable first, for the provided
        instance arguments.

        Parameters
        ----------
        args : Dict, optional
            User specified arguments for preferences or constraints.
            Falls back to the arguments given at construction time.
            Recognised keys:

            - ``'metrics'`` (required): mapping with the keys
              ``'rating'``, ``'workload'`` and ``'difficulty'``.
            - ``'choices'`` (optional): list of course ids. Only its
              length is used, as the default for ``'top'``; results are
              NOT filtered by it.
            - ``'top'`` (optional): number of records to return. When the
              key is absent it defaults to ``len(choices)``, or to all
              courses if there are no choices.

        Returns
        -------
        List[Dict]
            One record per recommended course (course name, course id and
            the mean of each feature), best match first.

        Raises
        ------
        AssertionError
            If no arguments are available or ``'metrics'`` is missing.
        KeyError
            If ``'metrics'`` lacks one of the feature keys.
        """
        if args is None:
            args = self.args

        assert args is not None, "no arguments given and none stored on the instance"

        metrics = args.get('metrics', None)
        assert metrics is not None, "args must contain a 'metrics' mapping"

        choices = args.get('choices') or []
        if 'top' in args:
            top = args['top']
        else:
            # Without an explicit `top`, fall back to the number of choices.
            # With no choices either, return every course (slice [:None])
            # rather than an always-empty result.
            top = len(choices) or None

        inst = [metrics[name] for name in self.FEATURES]

        probs = self.model.predict_proba([inst])
        # Column indices sorted by descending probability (most likely first).
        order = probs.argsort()[0][::-1]

        # The column indices are passed to the encoder as label codes. This
        # relies on the model having been fit on the encoder's codes, as
        # `_train` does.
        courses = self.labels.inverse_transform(order)

        # reset_index() without drop=True: course_name has to come back as a
        # column so it appears in the returned records.
        recs = self.stats.set_index('course_name').loc[courses].reset_index()[:top]

        return recs.to_dict('records')

    def render(self) -> 'HTML':
        """Render the recommendation. Not implemented yet; returns None."""
        return None

    def _train(self) -> None:
        """
        Train, or re-train, the OMS Courses RS.

        Refits the label encoder on the course names and the model on the
        feature columns of the dataset. The refit objects are only updated
        in memory; nothing is written back to disk here.
        """
        df = self.dataset
        y = df['course_name']
        # Explicit column selection keeps the feature order identical to the
        # instance built in `recommend`.
        X = df[list(self.FEATURES)].values
        labels = self.labels.fit_transform(y)
        self.model.fit(X, labels)
    
# Example request: four candidate course ids, the desired metrics, and a cap
# of five returned records.
args = {
    'choices': ['CS-6210', 'CS-6515', 'CS-6601', 'CS-7642'],
    'metrics': {'rating': 5, 'workload': 20, 'difficulty': 5},
    'top': 5,
}
rs = OMSCoursesRS(rs_args=args)
# Last expression of the cell, so the recommendation list is displayed.
rs.recommend()

[{'course_name': 'Artificial Intelligence',
  'course_id': 'CS-6601',
  'rating': 4.235294117647059,
  'workload': 22.852941176470587,
  'difficulty': 4.279411764705882},
 {'course_name': 'Intro to Graduate Algorithms',
  'course_id': 'CS-6515',
  'rating': 3.8098591549295775,
  'workload': 21.070422535211268,
  'difficulty': 4.288732394366197},
 {'course_name': 'Machine Learning',
  'course_id': 'CS-7641',
  'rating': 3.711864406779661,
  'workload': 21.376949152542373,
  'difficulty': 4.146892655367232},
 {'course_name': 'Computer Vision',
  'course_id': 'CS-6476',
  'rating': 4.46,
  'workload': 21.0,
  'difficulty': 4.02},
 {'course_name': 'Reinforcement Learning',
  'course_id': 'CS-7642',
  'rating': 3.9022556390977443,
  'workload': 22.360902255639097,
  'difficulty': 4.120300751879699}]

In [247]:
# Display the estimator held by the recommender (exposed via the `model` property).
rs.model

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=5000,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=0, solver='saga', tol=0.0001, verbose=0,
                   warm_start=False)

In [163]:
# NOTE(review): `a` is not defined in any visible cell (hidden kernel state) —
# confirm where it comes from before relying on this cell.
# Prints each element of `a` on its own line.
for b in a:
    print(b)

[['CS-6601', 'Artificial Intelligence', 4.235294117647059, 22.852941176470587, 4.279411764705882]]
[['CS-6515', 'Intro to Graduate Algorithms', 3.8098591549295775, 21.070422535211268, 4.288732394366197]]
[['CS-7642', 'Reinforcement Learning', 3.9022556390977443, 22.360902255639097, 4.120300751879699]]
[['CS-6210', 'Advanced Operating Systems', 4.4423076923076925, 17.326923076923077, 4.211538461538462]]


In [134]:
# NOTE(review): `a` comes from hidden kernel state; presumably a per-course
# stats DataFrame — confirm.
# Keep only the rows whose course_name is one of the two listed courses.
a[a.loc[:, 'course_name'].isin(['Intro to Information Security', 'Reinforcement Learning'])]

Unnamed: 0,course_id,course_name,rating,workload,difficulty
0,CS-6035,Intro to Information Security,3.655039,9.106395,2.410853
26,CS-7642,Reinforcement Learning,3.902256,22.360902,4.120301


In [136]:
# NOTE(review): `a` comes from hidden kernel state — confirm its origin.
# Keep only the rows whose course_id is among the requested `choices`.
a[a.loc[:, 'course_id'].isin(args['choices'])]

Unnamed: 0,course_id,course_name,rating,workload,difficulty
2,CS-6210,Advanced Operating Systems,4.442308,17.326923,4.211538
19,CS-6515,Intro to Graduate Algorithms,3.809859,21.070423,4.288732
20,CS-6601,Artificial Intelligence,4.235294,22.852941,4.279412
26,CS-7642,Reinforcement Learning,3.902256,22.360902,4.120301
