In [1]:
import os
import sqlite3
from contextlib import closing
from pprint import pprint
from typing import List, Optional

import numpy as np
import pandas as pd
import shap
from pyod.models.iforest import IForest, check_array
from pyod.models.lof import LOF
from pyod.models.ocsvm import OCSVM
from sklearn.ensemble import IsolationForest
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm


# Display options: never hide columns; truncate frames longer than 100 rows.
pd.options.display.max_columns = None
pd.options.display.max_rows = 100

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


# Utils


In [2]:
# Random seed shared by train_test_split and IsolationForest so runs are reproducible.
SEED = 42

def cast_to_0_1(preds):
    """
    Convert outlier predictions from the sklearn convention
    (-1 for outliers, 1 for inliers) to 0 for inliers and 1 for outliers.

    Args:
        preds: array-like supporting element-wise ``==`` and ``.astype``
            (e.g. the ndarray returned by ``IsolationForest.predict``).

    Returns:
        Integer array of the same shape: 1 where ``preds == -1``, else 0.
    """
    is_outlier = preds == -1
    return is_outlier.astype(int)

def explain_outlier(shap_value, columns, top_k=5):
    """
    Pick the features with the lowest SHAP values for a single sample.

    Args:
        shap_value: one row of a SHAP ``Explanation``; only its ``.values``
            (one number per feature) is used.
        columns: feature names, aligned with ``shap_value.values``.
        top_k: how many features to return.

    Returns:
        dict {feature name: SHAP value} for the ``top_k`` smallest (most
        negative) values, inserted in ascending order of value.
        NOTE(review): presumably the most negative contributions are the ones
        pushing the sample toward "outlier" — confirm the sign convention for
        the explained model.
    """
    contributions = shap_value.values
    ranked = np.argsort(contributions)
    return {columns[i]: contributions[i] for i in ranked[:top_k]}

# Get Data

In [5]:
class AnomalyСalculation:
    '''
    Anomaly detection over player box scores of past games.

    Trains an IsolationForest on standardized box-score stats and explains
    each detected anomaly with SHAP values.

    NOTE: the "С" in this class name is a Cyrillic letter (U+0421). A Latin
    alias ``AnomalyCalculation`` is defined right after the class.
    '''

    def __init__(self):
        # numeric box-score features the anomaly model is trained on
        self.calculus_col = ['min_sec', 'FGM', 'FGA', 'FG3M', 'FG3A', 'FTM', 'FTA',
                             'REB', 'AST', 'STL', 'BLK', 'TO', 'PF', 'PTS', 'PLUS_MINUS']

        # identifying columns returned as extra context for each anomaly
        self.labels_col = ['GAME_ID', 'TEAM_ID', 'TEAM_ABBREVIATION',
                           'TEAM_CITY', 'PLAYER_ID', 'PLAYER_NAME', 'NICKNAME']

    def get_data(self, limit=10000000000, db_path: str = r'../../data/basnya.db'):
        """
        Load player box scores from the local SQLite database.

        Args:
            limit: maximum number of rows read from ``boxscoretraditionalv2_0``.
                The query has no ORDER BY, so these are NOT guaranteed to be
                the most recent rows. ``None`` reads every row.
            db_path: path to the local SQLite database file.

        Returns:
            DataFrame with the box-score columns (minus ``index``) plus:
              * ``min_sec``   - minutes played as a float, parsed from ``MIN`` ("mm:ss"),
              * ``GAME_DATE`` - date from ``GAMES.GAME_DATE_EST``; games missing
                from ``GAMES`` get 1900-01-01.
            All remaining NaNs are filled with 0.

        Raises:
            FileNotFoundError: if ``db_path`` does not exist (sqlite3 would
                otherwise silently create an empty database file there).
        """
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"SQLite database not found: {db_path}")

        # SQLite treats a negative LIMIT as "no limit".
        sql_limit = -1 if limit is None else int(limit)

        # closing() releases the connection even if a query fails; sqlite3's own
        # context manager only commits/rolls back and does not close.
        with closing(sqlite3.connect(db_path)) as conn:
            game2date = (
                pd.read_sql_query("SELECT * FROM GAMES", conn)
                .set_index('GAME_ID')['GAME_DATE_EST']
                .to_dict()
            )
            df = (
                pd.read_sql_query("SELECT * FROM boxscoretraditionalv2_0 LIMIT ?",
                                  conn, params=(sql_limit,))
                .drop('index', axis=1)
            )

        # MIN is split on ':'; missing values become 0 minutes and 0 seconds.
        df[['_min', '_sec']] = df['MIN'].str.split(':', expand=True).fillna(0)
        df['min_sec'] = df._min.astype(float) + df._sec.astype(int) / 60
        # Games without a known date get a sentinel far in the past.
        df['GAME_DATE'] = (
            pd.to_datetime(df['GAME_ID'].map(game2date))
            .fillna(pd.to_datetime('1900-01-01'))
        )
        return df.fillna(0)

    def get_anomalous_records(self, date: Optional[str], contamination: float = 0.01,
                              max_records: Optional[int] = 10,
                              top_k: int = 5) -> List[dict]:
        """
        Detect anomalous player box scores and explain them with SHAP.

        An IsolationForest is fitted on standardized ``self.calculus_col``
        features of the training rows and then applied to the test rows.

        Args:
            date: cutoff date (e.g. '2021-11-10'). Games on or before it are
                used for training; games after it are checked for anomalies.
                Games without a known date (stored as 1900-01-01 by
                ``get_data``) always land in the training set. If ``None``, a
                random 80/20 split seeded with ``SEED`` is used instead.
            contamination: passed to ``IsolationForest(contamination=...)``.
            max_records: maximum number of anomalies returned (``None`` = all).
            top_k: number of features reported per anomaly.

        Returns:
            List of dicts, one per anomalous test row in test-set order:
              * 'shape_value' - {feature: SHAP value} for the ``top_k`` features
                with the lowest SHAP values,
              * 'context' - the row's identifying columns (``self.labels_col``).
            Empty if there are no test rows or no anomalies.

        Raises:
            ValueError: if the training split is empty.
        """
        df = self.get_data()

        if date is None:
            df_train, df_test = train_test_split(df, test_size=0.2, random_state=SEED)
        else:
            cutoff = pd.to_datetime(date)
            df_train = df.loc[df['GAME_DATE'] <= cutoff]
            df_test = df.loc[df['GAME_DATE'] > cutoff]

        if df_train.empty:
            raise ValueError(f"No training rows on or before {date!r}")
        if df_test.empty:
            return []

        scaler = StandardScaler()
        X_train = scaler.fit_transform(df_train[self.calculus_col])
        X_test = scaler.transform(df_test[self.calculus_col])

        clf = IsolationForest(contamination=contamination, random_state=SEED)
        clf.fit(X_train)
        is_anomaly = cast_to_0_1(clf.predict(X_test)) == 1

        X_anomalous = X_test[is_anomaly][:max_records]
        if len(X_anomalous) == 0:
            return []
        context_rows = df_test.loc[is_anomaly, self.labels_col].iloc[:max_records]

        # Only the reported rows are explained: SHAP values are computed per row,
        # so this matches explaining the whole test set and slicing afterwards.
        explainer = shap.TreeExplainer(clf, feature_names=self.calculus_col)
        shap_values = explainer(X_anomalous)

        # NOTE: the key is spelled 'shape_value' (sic); kept for existing consumers.
        return [
            {'shape_value': explain_outlier(example, columns=self.calculus_col, top_k=top_k),
             'context': context.to_dict()}
            for example, (_, context) in zip(shap_values, context_rows.iterrows())
        ]


# Latin-spelled alias: the class name above contains a Cyrillic "С" (U+0421).
AnomalyCalculation = AnomalyСalculation
        

        

In [4]:
# ------------------- TESTING -------------------------------
# Smoke test: train on games up to 2021-10-10 and list anomalies found in later games.
# NOTE(review): this cell is In [4] but the class-definition cell above is In [5],
# so the saved output below came from an earlier version of the class —
# Restart & Run All to refresh it.
anomaly = AnomalyСalculation()
anomaly_records = anomaly.get_anomalous_records('2021-10-10')
anomaly_records 

[{'shape_value': {'FG3M': -1.2603159403805317,
   'FG3A': -1.2392756779069771,
   'FGA': -0.8698809703459265,
   'FGM': -0.7540991175091406,
   'PTS': -0.7273360235039464},
  'context': {'GAME_ID': 12100057,
   'TEAM_ID': 1610612743,
   'TEAM_ABBREVIATION': 'DEN',
   'TEAM_CITY': 'Denver',
   'PLAYER_ID': 1630210,
   'PLAYER_NAME': 'Markus Howard',
   'NICKNAME': 'Markus'}},
 {'shape_value': {'FG3M': -1.0183068931302433,
   'FG3A': -0.8316612084503043,
   'PTS': -0.7513910825956549,
   'FGM': -0.6424492650283425,
   'FGA': -0.560458106046341},
  'context': {'GAME_ID': 12100066,
   'TEAM_ID': 1610612744,
   'TEAM_ABBREVIATION': 'GSW',
   'TEAM_CITY': 'Golden State',
   'PLAYER_ID': 201939,
   'PLAYER_NAME': 'Stephen Curry',
   'NICKNAME': 'Stephen'}},
 {'shape_value': {'BLK': -1.6297115780368925,
   'FTA': -0.9235368574167693,
   'STL': -0.8826940032221455,
   'FTM': -0.739388425227038,
   'min_sec': -0.56178943550655},
  'context': {'GAME_ID': 22100005,
   'TEAM_ID': 1610612738,
   'TE

In [9]:
# Same records, pretty-printed (pprint sorts dict keys, so 'context' comes before 'shape_value').
pprint(anomaly_records)

[{'context': {'GAME_ID': 12100057,
              'NICKNAME': 'Markus',
              'PLAYER_ID': 1630210,
              'PLAYER_NAME': 'Markus Howard',
              'TEAM_ABBREVIATION': 'DEN',
              'TEAM_CITY': 'Denver',
              'TEAM_ID': 1610612743},
  'shape_value': {'FG3A': -1.2392756779069771,
                  'FG3M': -1.2603159403805317,
                  'FGA': -0.8698809703459265,
                  'FGM': -0.7540991175091406,
                  'PTS': -0.7273360235039464}},
 {'context': {'GAME_ID': 12100066,
              'NICKNAME': 'Stephen',
              'PLAYER_ID': 201939,
              'PLAYER_NAME': 'Stephen Curry',
              'TEAM_ABBREVIATION': 'GSW',
              'TEAM_CITY': 'Golden State',
              'TEAM_ID': 1610612744},
  'shape_value': {'FG3A': -0.8316612084503043,
                  'FG3M': -1.0183068931302433,
                  'FGA': -0.560458106046341,
                  'FGM': -0.6424492650283425,
                  'PTS': -0.7513910