In [32]:
%%writefile ../src/salary_predict/data_loader.py

import pandas as pd
import os

def get_project_root():
    """Return the directory two levels above the folder holding this module.

    For a module stored at ``<root>/src/salary_predict/data_loader.py`` the
    result is ``<root>``.
    """
    module_path = os.path.abspath(__file__)
    package_dir = os.path.dirname(module_path)
    src_dir = os.path.dirname(package_dir)
    return os.path.dirname(src_dir)

def load_data(inflated=False):
    """Load the processed salary dataset and normalise its salary-cap columns.

    Reads ``data/processed/final_salary_data_with_yos_and_inflated_cap_2000_on.csv``
    below the project root.

    Args:
        inflated: When True, 'Salary Cap' is overwritten with the values of
            'Salary_Cap_Inflated'; otherwise 'Salary_Cap_Inflated' is
            overwritten with 'Salary Cap'. Either way both columns end up
            holding the same series, so downstream code may read either name.

    Returns:
        pandas.DataFrame with the file's contents. If 'Season' is stored as
        text, it is reduced to its first four characters as an int.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        KeyError: If the 'Salary' or 'Salary Cap' column is missing.
    """
    root_dir = get_project_root()
    file_name = 'final_salary_data_with_yos_and_inflated_cap_2000_on.csv'
    file_path = os.path.join(root_dir, 'data', 'processed', file_name)

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The file {file_path} does not exist. Please check the file path.")

    df = pd.read_csv(file_path)
    if 'Salary' not in df.columns:
        raise KeyError("The 'Salary' column is missing in the dataset.")

    # Text seasons are assumed to start with a 4-digit year (e.g. "2019-20").
    # The presence check lets callers such as load_predictions report a
    # missing 'Season' column with their own descriptive error instead of a
    # bare KeyError raised here.
    if 'Season' in df.columns and df['Season'].dtype == 'object':
        df['Season'] = df['Season'].str[:4].astype(int)

    # Ensure both 'Salary Cap' and 'Salary_Cap_Inflated' columns are present
    if 'Salary Cap' not in df.columns:
        raise KeyError("The 'Salary Cap' column is missing in the dataset.")
    if 'Salary_Cap_Inflated' not in df.columns:
        df['Salary_Cap_Inflated'] = df['Salary Cap']  # Use non-inflated as fallback

    # Make both cap columns carry the series selected by `inflated`.
    if inflated:
        df['Salary Cap'] = df['Salary_Cap_Inflated']
    else:
        df['Salary_Cap_Inflated'] = df['Salary Cap']

    return df

def _rename_without_collisions(df, mapping):
    """Rename columns per *mapping*, first dropping any column a rename would duplicate."""
    # A target that is itself renamed away is not a collision.
    to_drop = [
        dst for src, dst in mapping.items()
        if src != dst and src in df.columns and dst in df.columns and dst not in mapping
    ]
    return df.drop(columns=to_drop).rename(columns=mapping)


def load_predictions(inflated=False, team=None):
    """Join the saved salary predictions with the actual dataset.

    Reads ``data/predictions/salary_predictions[_inflated].csv`` and
    left-joins it with ``load_data(inflated)`` on Player and Season. The
    predictions' 'Predicted_Season' is used as Season for the join. Columns
    are then renamed to the names the app expects. Only a fixed list of
    display columns is kept; absent ones are reported and skipped.

    Args:
        inflated: Use the inflation-adjusted data and predictions file.
        team: If given, keep only rows whose 'Team' equals this value.

    Returns:
        pandas.DataFrame with one row per prediction row (left join).

    Raises:
        FileNotFoundError: If the predictions file does not exist.
        KeyError: If required join columns are missing, or *team* is given
            but the merged frame has no 'Team' column.
    """
    # Load the actual data
    df_actual = load_data(inflated)

    # Load predictions
    root_dir = get_project_root()
    predictions_file = 'salary_predictions_inflated.csv' if inflated else 'salary_predictions.csv'
    predictions_path = os.path.join(root_dir, 'data', 'predictions', predictions_file)

    if not os.path.exists(predictions_path):
        raise FileNotFoundError(f"The predictions file {predictions_path} does not exist.")

    df_predictions = pd.read_csv(predictions_path)

    # Check for required columns
    required_columns_predictions = ['Player', 'Predicted_Season']
    required_columns_actual = ['Player', 'Season']

    missing_columns_predictions = [col for col in required_columns_predictions if col not in df_predictions.columns]
    missing_columns_actual = [col for col in required_columns_actual if col not in df_actual.columns]

    if missing_columns_predictions:
        raise KeyError(f"The following required columns are missing in the predictions dataframe: {', '.join(missing_columns_predictions)}")
    if missing_columns_actual:
        raise KeyError(f"The following required columns are missing in the actual data dataframe: {', '.join(missing_columns_actual)}")

    # 'Predicted_Season' becomes the join key; a pre-existing 'Season' column
    # in the predictions file would otherwise be duplicated.
    df_predictions = _rename_without_collisions(df_predictions, {'Predicted_Season': 'Season'})

    # Overlapping columns: predictions' copy gets '_pred', actual data keeps the plain name.
    df_merged = pd.merge(df_predictions, df_actual, on=['Player', 'Season'], suffixes=('_pred', ''), how='left')

    # Rename to the expected names. Where a target name already exists (e.g.
    # the actual-data 'Age' when 'Age_pred' is present), the existing column
    # is dropped so the renamed one replaces it instead of duplicating it.
    # NOTE(review): 'Salary' here is the actual-data salary for the row whose
    # Season equals the predicted season; confirm it really is the
    # previous-season salary the new name claims.
    df_merged = _rename_without_collisions(df_merged, {
        'Season': 'Predicted_Season',
        'Salary': 'Previous_Season_Salary',
        'Predicted_Salary_Pct': 'SalaryPct',
        'Age_pred': 'Age',  # Use predicted age
    })

    # Select relevant columns
    relevant_columns = ['Player', 'Predicted_Season', 'Team', 'Age', 'Position', 'Previous_Season_Salary',
                        'Predicted_Salary', 'Salary_Change', 'SalaryPct', 'GP', 'MP', 'PTS', 'TRB', 'AST',
                        'FG%', '3P%', 'FT%', 'PER', 'WS', 'VORP']

    missing_columns = [col for col in relevant_columns if col not in df_merged.columns]
    if missing_columns:
        print(f"Warning: The following columns are missing and will be excluded: {', '.join(missing_columns)}")
        relevant_columns = [col for col in relevant_columns if col in df_merged.columns]

    df_merged = df_merged[relevant_columns]

    # Filter by team if specified
    if team:
        if 'Team' not in df_merged.columns:
            raise KeyError("The 'Team' column is missing in the merged dataframe.")
        df_merged = df_merged[df_merged['Team'] == team]

    return df_merged

def merge_predictions_with_original(predictions, original_data):
    """Attach a 'Position' to every prediction row and rename the salary column.

    Args:
        predictions: DataFrame with a 'Player' column, optionally 'Position'
            and 'Previous_Season_Salary'.
        original_data: DataFrame with 'Player' and 'Position' columns. It may
            contain several rows per player (e.g. one per season).

    Returns:
        A new DataFrame with exactly the rows of *predictions*, in order, with
        a fresh 0..n-1 index. Missing positions are filled from
        *original_data*, using the latest 'Season' when that column exists and
        otherwise the last occurrence; any still missing become 'Unknown'.
        'Previous_Season_Salary' is renamed to 'Salary'.
    """
    source = original_data
    if 'Season' in source.columns:
        # Stable sort so the last row per player is the most recent season.
        source = source.sort_values('Season', kind='mergesort')
    # One position per player: merging on all player-season rows would
    # multiply the prediction rows.
    positions = (
        source[['Player', 'Position']]
        .dropna(subset=['Position'])
        .drop_duplicates(subset='Player', keep='last')
        .set_index('Player')['Position']
    )

    merged = predictions.reset_index(drop=True)
    known = merged['Player'].map(positions)
    if 'Position' in merged.columns:
        # A merge here would create Position_x / Position_y; fill the gaps instead.
        merged['Position'] = merged['Position'].fillna(known)
    else:
        merged['Position'] = known
    merged['Position'] = merged['Position'].fillna('Unknown')
    return merged.rename(columns={'Previous_Season_Salary': 'Salary'})

if __name__ == "__main__":
    # Example usage: load the processed dataset and the saved predictions,
    # printing quick diagnostics for each.
    print("Loading data...")
    df = load_data(inflated=False)
    print("\nDataframe shape:", df.shape)
    print("\nColumns:", df.columns)
    print("\nFirst few rows:")
    print(df.head())
    print("\nNaN values:")
    print(df.isna().sum())

    print("\nLoading predictions...")
    predictions = load_predictions(inflated=False)
    print("\nPredictions shape:", predictions.shape)
    print("\nPredictions columns:", predictions.columns)
    print("\nFirst few rows of predictions:")
    print(predictions.head())
    print("\nNaN values in predictions:")
    print(predictions.isna().sum())

Overwriting ../src/salary_predict/data_loader.py


In [33]:
%%writefile ../src/salary_predict/data_preprocessor.py

# data_preprocessor.py
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.cluster import KMeans

def handle_missing_values(df):
    """Return a copy of *df* with all-NaN columns dropped and numeric NaNs mean-imputed.

    Non-numeric columns are kept but not imputed. Diagnostic counts and
    shapes are printed along the way.

    An empty frame is returned unchanged. A frame with no numeric column
    left after dropping all-NaN columns is returned without imputation.
    Both cases are guarded because SimpleImputer cannot be fitted on zero
    samples or zero features.
    """
    df = df.copy()
    if df.empty:
        return df

    numeric_columns = df.select_dtypes(include=[np.number]).columns
    print(f"Number of numeric columns: {len(numeric_columns)}")
    print(f"Numeric columns: {numeric_columns}")

    # Remove columns with all NaN values (SimpleImputer would drop them and
    # the assignment below would no longer line up).
    df = df.dropna(axis=1, how='all')
    numeric_columns = df.select_dtypes(include=[np.number]).columns
    print(f"Number of numeric columns after dropping all-NaN columns: {len(numeric_columns)}")
    if len(numeric_columns) == 0:
        return df

    imputer = SimpleImputer(strategy='mean')
    imputed_data = imputer.fit_transform(df[numeric_columns])
    print(f"Shape of imputed data: {imputed_data.shape}")
    print(f"Shape of original numeric data: {df[numeric_columns].shape}")

    df[numeric_columns] = imputed_data
    return df

def feature_engineering(df, use_inflated_data=False, games_per_season=82):
    """Return a copy of *df* with derived per-game, team and salary features.

    Adds, only when not already present:
        PPG/APG/RPG/SPG/BPG/TOPG: PTS/AST/TRB/STL/BLK/TOV divided by GP.
        WinPct: Wins / (Wins + Losses).
        Availability: GP / *games_per_season*.
    Always (re)computes SalaryPct = Salary / salary-cap column.

    Zero denominators (GP == 0, or no wins and no losses) yield NaN rather
    than +/-inf. The imputers and ``dropna`` used downstream handle NaN, but
    scikit-learn estimators reject infinite inputs.

    Args:
        df: Input frame; it is not modified.
        use_inflated_data: Divide Salary by 'Salary_Cap_Inflated' instead of
            'Salary Cap'.
        games_per_season: Denominator for Availability (default 82).

    Raises:
        KeyError: If the selected salary-cap column is missing.
    """
    df = df.copy()

    per_game = {'PPG': 'PTS', 'APG': 'AST', 'RPG': 'TRB', 'SPG': 'STL', 'BPG': 'BLK', 'TOPG': 'TOV'}
    missing_per_game = [col for col in per_game if col not in df.columns]
    if missing_per_game:
        games = df['GP'].replace(0, np.nan)
        for col in missing_per_game:
            df[col] = df[per_game[col]] / games

    if 'WinPct' not in df.columns:
        decisions = (df['Wins'] + df['Losses']).replace(0, np.nan)
        df['WinPct'] = df['Wins'] / decisions

    if 'Availability' not in df.columns:
        df['Availability'] = df['GP'] / games_per_season

    # Calculate SalaryPct using the correct Salary Cap column
    salary_cap_column = 'Salary_Cap_Inflated' if use_inflated_data else 'Salary Cap'
    if salary_cap_column not in df.columns:
        raise KeyError(f"The '{salary_cap_column}' column is missing in the dataset.")
    df['SalaryPct'] = df['Salary'] / df[salary_cap_column]

    return df

def calculate_vorp_salary_ratio(df):
    """Add 'Salary_M' and, when possible, 'VORP_Salary_Ratio' to *df* in place.

    'Salary_M' is Salary in millions. 'VORP_Salary_Ratio' is VORP per
    million of salary. Rows with a zero salary get NaN instead of +/-inf.
    If there is no 'VORP' column a warning is printed and only 'Salary_M' is
    added.

    Returns:
        The same DataFrame object that was passed in.
    """
    df['Salary_M'] = df['Salary'] / 1e6
    if 'VORP' not in df.columns:
        print("Warning: 'VORP' column not found. VORP/Salary ratio cannot be calculated.")
        return df
    df['VORP_Salary_Ratio'] = df['VORP'] / df['Salary_M'].replace(0, np.nan)
    return df

def cluster_career_trajectories(df):
    """Assign every row of *df* to one of five KMeans clusters, in place.

    The eight career features are mean-imputed and standardised before
    clustering. Adds 'Cluster' (integer label) and 'Cluster_Definition'
    (descriptive name), then returns the same DataFrame.

    NOTE(review): KMeans numbers its clusters arbitrarily, so the fixed
    label-to-name table below is not guaranteed to describe what each cluster
    actually contains; check the names against the cluster centres.
    """
    feature_cols = ['Age', 'Years of Service', 'PTS', 'TRB', 'AST', 'PER', 'WS', 'VORP']
    label_names = {
        0: "Young Bench Players",
        1: "Rising Role Players",
        2: "Star Players",
        3: "Superstars",
        4: "Veteran Players",
    }

    filled = SimpleImputer(strategy='mean').fit_transform(df[feature_cols])
    standardized = StandardScaler().fit_transform(filled)
    clusterer = KMeans(n_clusters=5, random_state=42)

    df['Cluster'] = clusterer.fit_predict(standardized)
    df['Cluster_Definition'] = df['Cluster'].map(label_names)
    return df

if __name__ == "__main__":
    # Stand-alone smoke test of the preprocessing steps. load_data lives in
    # the sibling module (imported the same way predictor.py does); without
    # this import the script stops with a NameError at the first call.
    from data_loader import load_data

    print("Loading data...")
    df = load_data(inflated=False)

    print("\nHandling missing values...")
    df = handle_missing_values(df)
    print("NaN values after handling:")
    print(df.isna().sum())

    print("\nPerforming feature engineering...")
    df = feature_engineering(df)
    print("New columns after feature engineering:", df.columns)

    print("\nCalculating VORP salary ratio...")
    df = calculate_vorp_salary_ratio(df)
    # The ratio column only exists when the data has a 'VORP' column.
    if 'VORP_Salary_Ratio' in df.columns:
        print("VORP salary ratio stats:")
        print(df['VORP_Salary_Ratio'].describe())

    print("\nClustering career trajectories...")
    df = cluster_career_trajectories(df)
    print("Cluster distribution:")
    print(df['Cluster_Definition'].value_counts())

Overwriting ../src/salary_predict/data_preprocessor.py


In [34]:
%%writefile ../src/salary_predict/model_trainer.py

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import Ridge, ElasticNet
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.feature_selection import RFE
from sklearn.impute import SimpleImputer
import joblib
from sklearn.inspection import permutation_importance
from data_loader import get_project_root, load_data
import os

def retrain_models(X, y, model_params):
    """Grid-search each estimator on a standard-scaled 80/20 split.

    Args:
        X: Feature DataFrame.
        y: Target aligned with *X*.
        model_params: Mapping ``name -> (estimator, param_grid)``.

    Returns:
        ``(models, X_train_scaled, X_test_scaled, y_train, y_test, scaler)``
        where *models* maps each name to its best estimator (5-fold CV, MSE
        scoring). The scaled frames keep the split's original index, so they
        stay row-aligned with *y_train* / *y_test*.
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    scaler = StandardScaler()
    # Without index=..., the scaled frames would get a fresh RangeIndex that no
    # longer matches y_train / y_test for any index-based join by the caller.
    X_train_scaled = pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns, index=X_train.index)
    X_test_scaled = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns, index=X_test.index)

    models = {}
    for name, (model, params) in model_params.items():
        grid_search = GridSearchCV(estimator=model, param_grid=params, cv=5, n_jobs=-1, scoring='neg_mean_squared_error')
        grid_search.fit(X_train_scaled, y_train)
        models[name] = grid_search.best_estimator_

    return models, X_train_scaled, X_test_scaled, y_train, y_test, scaler

def save_models(models, scaler, selected_features, inflated=False):
    """Persist models, the scaler and the feature list under ``<root>/data/models``.

    Each model is written to
    ``<Name>_salary_prediction_model[_inflated].joblib``, where spaces in the
    dict key are replaced with underscores. Keys may therefore be written
    either 'Random Forest' or 'Random_Forest'. The directory is created if
    missing.
    """
    models_dir = os.path.join(get_project_root(), 'data', 'models')
    os.makedirs(models_dir, exist_ok=True)
    suffix = '_inflated' if inflated else ''
    for name, model in models.items():
        # Covers the known display names ('Random Forest' -> 'Random_Forest')
        # and passes underscore names through, instead of raising KeyError.
        formatted_name = name.replace(' ', '_')
        joblib.dump(model, os.path.join(models_dir, f'{formatted_name}_salary_prediction_model{suffix}.joblib'))
    joblib.dump(scaler, os.path.join(models_dir, f'scaler{suffix}.joblib'))
    joblib.dump(selected_features, os.path.join(models_dir, f'selected_features{suffix}.joblib'))

def evaluate_models(models, X_test, y_test):
    """Score every fitted model on the held-out data.

    Returns a dict mapping each model name to ``{"MSE": ..., "R²": ...}``.
    """
    def _score(estimator):
        predicted = estimator.predict(X_test)
        return {
            "MSE": mean_squared_error(y_test, predicted),
            "R²": r2_score(y_test, predicted),
        }

    return {label: _score(estimator) for label, estimator in models.items()}

def retrain_and_save_models(use_inflated_data):
    """Retrain every salary model from the processed dataset and save the artifacts.

    Steps:
        1. Load the data and drop unused cap/tax columns.
        2. Mean-impute the numeric columns, run feature engineering, and
           one-hot encode the categorical columns.
        3. Pick 10 features with RFE (random forest) and split 80/20.
        4. Standard-scale, then grid-search six model families (5-fold CV,
           MSE scoring) and report test metrics and feature importances.
        5. Write each model, the scaler, the selected-feature list and the
           name of the lowest-test-MSE model to ``<root>/data/models``.
           File names get an ``_inflated`` suffix when *use_inflated_data*
           is True.

    Args:
        use_inflated_data: Train against the inflation-adjusted salary cap.

    Returns:
        ``(best_model_name, best_model, evaluations, selected_features,
        scaler, max_salary_cap)``.
        *evaluations* maps model name to ``{"MSE": ..., "R²": ...}`` on the
        test split. *max_salary_cap* is the largest value of the salary-cap
        column in use.
    """
    # This module's top-level imports do not provide feature_engineering;
    # import it from the sibling module the same way predictor.py does.
    from data_preprocessor import feature_engineering

    # Needed for the return value in both modes.
    salary_cap_column = 'Salary_Cap_Inflated' if use_inflated_data else 'Salary Cap'
    suffix = '_inflated' if use_inflated_data else ''
    models_dir = os.path.join(get_project_root(), 'data', 'models')
    os.makedirs(models_dir, exist_ok=True)

    # Load the appropriate data
    data = load_data(use_inflated_data)

    # Drop unnecessary columns
    columns_to_drop = ['2022 Dollars', 'Luxury Tax', '1st Apron', '2nd Apron', 'BAE', 'Standard /Non-Taxpayer', 'Taxpayer', 'Team Room /Under Cap']
    data = data.drop(columns=[col for col in columns_to_drop if col in data.columns])

    # Convert 'Season' to an integer if necessary
    if data['Season'].dtype == 'object':
        data['Season'] = data['Season'].str[:4].astype(int)

    # Handle missing values for numerical columns
    numerical_cols = data.select_dtypes(include=['float64', 'int64']).columns
    imputer = SimpleImputer(strategy='mean')
    data[numerical_cols] = imputer.fit_transform(data[numerical_cols])

    # Feature engineering
    data = feature_engineering(data, use_inflated_data)

    # Identify categorical and numerical columns
    categorical_cols = ['Player', 'Season', 'Position', 'Team']
    target_and_caps = ['Salary', 'SalaryPct', 'Salary Cap', 'Salary_Cap_Inflated']
    numerical_cols = data.columns.difference(categorical_cols + target_and_caps)

    # NOTE(review): 'Player' and 'Season' dummies are offered to RFE as
    # candidate features. make_predictions (predictor.py) and
    # load_selected_model (app.py) only run feature_engineering, so any
    # encoded column RFE keeps will be missing there -- confirm intended.
    try:
        # scikit-learn >= 1.2 names this `sparse_output`; `sparse` was removed in 1.4.
        encoder = OneHotEncoder(drop='first', sparse_output=False)
    except TypeError:
        encoder = OneHotEncoder(drop='first', sparse=False)
    encoded_cats = pd.DataFrame(
        encoder.fit_transform(data[categorical_cols]),
        columns=encoder.get_feature_names_out(categorical_cols),
        index=data.index,  # keep rows aligned for the concat below
    )

    # Combine the numerical and encoded categorical data
    data = pd.concat([data[numerical_cols], encoded_cats, data[['Player', 'Season'] + target_and_caps]], axis=1)

    # Candidate features, then drop rows with any missing value
    initial_features = ['Age', 'Years of Service', 'GP', 'PPG', 'APG', 'RPG', 'SPG', 'BPG', 'TOPG', 'FG%', '3P%', 'FT%', 'PER', 'WS', 'VORP', 'Availability'] + list(encoded_cats.columns)
    data_cleaned = data[initial_features + ['SalaryPct']].dropna()

    X = data_cleaned[initial_features]
    y = data_cleaned['SalaryPct']

    # Perform feature selection
    rfe = RFE(estimator=RandomForestRegressor(n_estimators=100, random_state=42), n_features_to_select=10)
    rfe = rfe.fit(X, y)
    selected_features = [feature for feature, keep in zip(initial_features, rfe.support_) if keep]

    print("Selected features by RFE:", selected_features)

    X = data_cleaned[selected_features]

    # Split and scale
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    models = {
        'Random_Forest': RandomForestRegressor(random_state=42),
        'Gradient_Boosting': GradientBoostingRegressor(random_state=42),
        'Ridge_Regression': Ridge(),
        'ElasticNet': ElasticNet(max_iter=10000),
        'SVR': SVR(),
        'Decision_Tree': DecisionTreeRegressor(random_state=42)
    }

    param_grids = {
        'Random_Forest': {
            'n_estimators': [50, 100, 200],
            'max_features': ['sqrt', 'log2'],
            'max_depth': [8, 10, 12],
            'min_samples_split': [5, 10, 15],
            'min_samples_leaf': [1, 2, 4]
        },
        'Gradient_Boosting': {
            'n_estimators': [100, 200, 300],
            'learning_rate': [0.01, 0.05, 0.1],
            'max_depth': [3, 4, 5],
            'min_samples_split': [2, 5, 10],
            'min_samples_leaf': [1, 2, 4],
            'subsample': [0.8, 0.9, 1.0]
        },
        'Ridge_Regression': {'alpha': [0.1, 1.0, 10.0, 100.0]},
        'ElasticNet': {'alpha': [0.1, 1.0, 10.0], 'l1_ratio': [0.1, 0.5, 0.9]},
        'SVR': {'C': [0.1, 1, 10], 'epsilon': [0.1, 0.2, 0.5]},
        'Decision_Tree': {'max_depth': [6, 8, 10], 'min_samples_split': [2, 5, 10], 'min_samples_leaf': [1, 2, 4]}
    }

    best_models = {}
    evaluations = {}
    for name, model in models.items():
        print(f"Training {name}...")
        grid_search = GridSearchCV(estimator=model, param_grid=param_grids[name], cv=5, n_jobs=-1, scoring='neg_mean_squared_error')
        grid_search.fit(X_train_scaled, y_train)
        best = grid_search.best_estimator_
        best_models[name] = best

        # Cross-validation
        cv_scores = cross_val_score(best, X_train_scaled, y_train, cv=5, scoring='neg_mean_squared_error')
        print(f"{name} - Best params: {grid_search.best_params_}")
        print(f"{name} - Cross-validation MSE: {-cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")

        # Test set performance
        y_pred = best.predict(X_test_scaled)
        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        print(f"{name} - Test MSE: {mse:.4f}, R²: {r2:.4f}")
        evaluations[name] = {"MSE": mse, "R²": r2}

        # Feature importance: built-in for tree models, permutation otherwise
        if name in ['Random_Forest', 'Gradient_Boosting', 'Decision_Tree']:
            importances = best.feature_importances_
            header = f"\n{name} - Top 5 important features:"
        else:
            perm_importance = permutation_importance(best, X_test_scaled, y_test, n_repeats=10, random_state=42)
            importances = perm_importance.importances_mean
            header = f"\n{name} - Top 5 important features (Permutation Importance):"
        feature_importance = pd.DataFrame({'feature': selected_features, 'importance': importances})
        feature_importance = feature_importance.sort_values('importance', ascending=False)
        print(header)
        print(feature_importance.head())

        # Save the model
        model_filename = os.path.join(models_dir, f'{name}_salary_prediction_model{suffix}.joblib')
        joblib.dump(best, model_filename)
        print(f"{name} model saved to '{model_filename}'")

    # Identify the best overall model (lowest test MSE)
    best_model_name = min(evaluations, key=lambda model_key: evaluations[model_key]['MSE'])
    best_model = best_models[best_model_name]

    print(f"Best overall model: {best_model_name}")

    scaler_filename = os.path.join(models_dir, f'scaler{suffix}.joblib')
    joblib.dump(scaler, scaler_filename)
    print(f"Scaler saved to '{scaler_filename}'")

    selected_features_filename = os.path.join(models_dir, f'selected_features{suffix}.joblib')
    joblib.dump(selected_features, selected_features_filename)
    print(f"Selected features saved to '{selected_features_filename}'")

    with open(os.path.join(models_dir, f'best_model_name{suffix}.txt'), 'w', encoding='utf-8') as f:
        f.write(best_model_name)

    return best_model_name, best_model, evaluations, selected_features, scaler, data[salary_cap_column].max()

if __name__ == "__main__":
    # Retrain on the non-inflated data and print a summary of the results.
    print("Retraining models...")
    results = retrain_and_save_models(use_inflated_data=False)
    best_model_name, best_model, evaluations, selected_features, scaler, max_salary_cap = results

    print(f"\nBest model: {best_model_name}")
    print("\nModel evaluations:")
    for model_label, scores in evaluations.items():
        print(f"{model_label}:")
        print(f"  MSE: {scores['MSE']:.4f}")
        print(f"  R²: {scores['R²']:.4f}")

    print("\nSelected features:")
    print(selected_features)

    print(f"\nMax salary cap: ${max_salary_cap:,.2f}")

Overwriting ../src/salary_predict/model_trainer.py


In [35]:
%%writefile ../src/salary_predict/predictor.py

import joblib
from data_loader import get_project_root
from data_preprocessor import feature_engineering
from sklearn.impute import SimpleImputer
import os

def load_model_and_scaler(model_name, inflated=False):
    """Load a saved model together with its scaler and selected-feature list.

    Spaces in *model_name* are replaced with underscores. The model is looked
    up as ``<name>_salary_prediction_model[_inflated].joblib`` in
    ``<root>/data/models``, trying the name as given first and then its
    lower-case, upper-case and capitalised spellings.

    Returns:
        ``(model, scaler, selected_features)``.

    Raises:
        FileNotFoundError: If no candidate model file exists; the message
            lists every path tried.
    """
    models_dir = os.path.join(get_project_root(), 'data', 'models')
    suffix = '_inflated' if inflated else ''

    # Convert model name to a consistent format
    model_name = model_name.replace(' ', '_')

    # As-given spelling first, then case variants; duplicates are skipped so
    # each path is probed (and reported) once.
    candidate_paths = []
    for variant in (model_name, model_name.lower(), model_name.upper(), model_name.capitalize()):
        path = os.path.join(models_dir, f'{variant}_salary_prediction_model{suffix}.joblib')
        if path not in candidate_paths:
            candidate_paths.append(path)

    model_path = next((path for path in candidate_paths if os.path.exists(path)), None)
    if model_path is None:
        raise FileNotFoundError(f"The model file for '{model_name}' does not exist. Tried the following paths:\n"
                                + "\n".join(f"- {path}" for path in candidate_paths))

    model = joblib.load(model_path)
    scaler = joblib.load(os.path.join(models_dir, f'scaler{suffix}.joblib'))
    selected_features = joblib.load(os.path.join(models_dir, f'selected_features{suffix}.joblib'))
    return model, scaler, selected_features

def make_predictions(df, model, scaler, selected_features, season, use_inflated_data, max_salary_cap):
    """Predict next-season salaries from the rows of *season*.

    The rows of *season* are feature-engineered, then Age and Season are
    incremented by one. The selected features are mean-imputed and scaled
    with *scaler*. Three columns are added:
        Predicted_Salary_Pct: the model output.
        Predicted_Salary: the prediction times the salary-cap column.
        Salary_Change: Predicted_Salary minus Salary.

    Args:
        max_salary_cap: Accepted for interface compatibility; not used.

    Returns:
        A new DataFrame; *df* is not modified.

    Raises:
        ValueError: If *season* has no rows, selected features are missing,
            or the salary-cap column is missing.
    """
    df = df[df['Season'] == season].copy()
    if df.empty:
        # SimpleImputer would otherwise fail with an opaque zero-sample error.
        raise ValueError(f"No rows found for season {season}")

    df = feature_engineering(df, use_inflated_data)
    # Shift to the season being predicted.
    df['Age'] += 1
    df['Season'] += 1

    missing_features = [f for f in selected_features if f not in df.columns]
    if missing_features:
        raise ValueError(f"Missing features in dataframe: {missing_features}")

    X = df[selected_features]

    # NOTE(review): the imputation means come from this season's rows, not the
    # training data, and a feature that is entirely NaN here would be dropped
    # by SimpleImputer, breaking scaler.transform.
    imputer = SimpleImputer(strategy='mean')
    X_imputed = imputer.fit_transform(X)
    X_scaled = scaler.transform(X_imputed)

    df['Predicted_Salary_Pct'] = model.predict(X_scaled)

    salary_cap_column = 'Salary_Cap_Inflated' if use_inflated_data else 'Salary Cap'
    if salary_cap_column not in df.columns:
        raise ValueError(f"Salary cap column '{salary_cap_column}' not found in dataframe")

    # NOTE(review): this is the cap of the season the stats come from (rows
    # were filtered before the Season shift), not the predicted season's cap.
    df['Predicted_Salary'] = df['Predicted_Salary_Pct'] * df[salary_cap_column]
    df['Salary_Change'] = df['Predicted_Salary'] - df['Salary']

    return df


if __name__ == "__main__":
    # Smoke test: predict the season after the latest one in the data.
    from data_loader import load_data

    print("Loading model...")
    model, scaler, selected_features = load_model_and_scaler('Random_Forest', inflated=False)

    print("\nLoading data...")
    df = load_data(inflated=False)

    print("\nMaking predictions...")
    latest_season = df['Season'].max()
    cap_ceiling = df['Salary Cap'].max()
    predictions = make_predictions(
        df, model, scaler, selected_features, latest_season,
        use_inflated_data=False, max_salary_cap=cap_ceiling,
    )

    preview_columns = ['Player', 'Salary', 'Predicted_Salary', 'Salary_Change']
    print("\nPredictions shape:", predictions.shape)
    print("\nFirst few rows of predictions:")
    print(predictions[preview_columns].head())

    print("\nNaN values in predictions:")
    print(predictions.isna().sum())

Overwriting ../src/salary_predict/predictor.py


In [36]:
%%writefile ../src/salary_predict/app.py

import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from data_loader import load_data, load_predictions, get_project_root
from data_preprocessor import handle_missing_values, feature_engineering, calculate_vorp_salary_ratio, cluster_career_trajectories
from predictor import load_model_and_scaler, make_predictions
from model_trainer import retrain_and_save_models
from sklearn.metrics import mean_squared_error, r2_score
import os
import joblib
from champ_percentile_ranks import calculate_percentiles, analyze_team_percentiles, get_champions
from data_loader import load_predictions, get_project_root

def filter_by_position(df, selected_positions):
    """Keep rows whose hyphen-separated 'Position' includes any selected position.

    For example 'SG-SF' matches a selection of ['SF']. An empty or None
    selection returns *df* unchanged. Rows whose Position is not a string
    (e.g. NaN) never match instead of raising AttributeError.
    """
    if not selected_positions:
        return df
    wanted = set(selected_positions)

    def _matches(position):
        return isinstance(position, str) and not wanted.isdisjoint(position.split('-'))

    # astype(bool) keeps the mask boolean even when df has no rows.
    mask = df['Position'].map(_matches).astype(bool)
    return df[mask]

def format_salary_df(df):
    """Return a display copy of *df* with salary columns rendered in millions.

    'Salary', 'Predicted_Salary' and 'Salary_Change' (when present) become
    strings like '$1.50M'. Negative amounts are shown as '-$1.25M' and
    missing values as 'N/A'. The result keeps only the display columns
    Player, Position, Age, Salary, Predicted_Salary and Salary_Change that
    exist in *df*, in that order.
    """
    def _as_millions(amount):
        if pd.isna(amount):
            return "N/A"
        sign = "-" if amount < 0 else ""
        return f"{sign}${abs(amount) / 1e6:.2f}M"

    formatted_df = df.copy()
    for col in ('Salary', 'Predicted_Salary', 'Salary_Change'):
        if col in formatted_df.columns:
            formatted_df[col] = formatted_df[col].apply(_as_millions)

    display_columns = ['Player', 'Position', 'Age', 'Salary', 'Predicted_Salary', 'Salary_Change']
    return formatted_df[[col for col in display_columns if col in formatted_df.columns]]

def load_selected_model(model_name, use_inflated_data):
    """Load *model_name* with its scaler and score it on the processed dataset.

    The full dataset is feature-engineered and imputed, then MSE and R² of
    the model's SalaryPct predictions are computed against it.
    NOTE(review): that dataset includes the rows the model was trained on,
    so these metrics are in-sample and optimistic.

    Returns:
        ``(model_name, model, mse, r2, selected_features, scaler,
        max_salary_cap)``. *max_salary_cap* is the maximum of the salary-cap
        column selected by *use_inflated_data*.

    Raises:
        Any exception from loading or scoring. It is first shown in the
        Streamlit UI via ``st.error``.
    """
    try:
        model, scaler, selected_features = load_model_and_scaler(model_name, use_inflated_data)
        df = load_data(use_inflated_data)
        df = feature_engineering(df, use_inflated_data)
        # Per-game ratios can be +/-inf (e.g. zero games played). The mean
        # imputer only fills NaN and the scaler rejects infinities, so
        # normalise them to NaN first.
        df = df.replace([float('inf'), float('-inf')], float('nan'))
        df = handle_missing_values(df)

        X = df[selected_features]
        y = df['SalaryPct']
        X_scaled = scaler.transform(X)

        y_pred = model.predict(X_scaled)
        mse = mean_squared_error(y, y_pred)
        r2 = r2_score(y, y_pred)

        salary_cap_column = 'Salary_Cap_Inflated' if use_inflated_data else 'Salary Cap'
        max_salary_cap = df[salary_cap_column].max()

        return model_name, model, mse, r2, selected_features, scaler, max_salary_cap
    except Exception as e:
        st.error(f"Error in load_selected_model: {str(e)}")
        raise

def find_best_model(use_inflated_data):
    """Load and score the model named in ``data/models/best_model_name[_inflated].txt``.

    That file is written by model_trainer.retrain_and_save_models. Returns
    the same tuple as load_selected_model.
    """
    models_dir = os.path.join(get_project_root(), 'data', 'models')
    name_suffix = '_inflated' if use_inflated_data else ''
    name_file = os.path.join(models_dir, f'best_model_name{name_suffix}.txt')

    with open(name_file, 'r') as handle:
        chosen_name = handle.read().strip()

    return load_selected_model(chosen_name, use_inflated_data)



def load_champions_data():
    """Read ``data/processed/nba_champions.csv`` below the project root into a DataFrame."""
    champions_path = os.path.join(
        get_project_root(), 'data', 'processed', 'nba_champions.csv'
    )
    return pd.read_csv(champions_path)

RELEVANT_STATS = ['PTS', 'TRB', 'AST', 'FG%', '3P%', 'FT%', 'PER', 'WS', 'VORP']

def calculate_team_percentiles(team_players):
    """Summarise each RELEVANT_STATS column present in *team_players*.

    For every such stat the result holds a dict with:
        min, max, mean: computed over the non-NaN values.
        std: population standard deviation (ddof=0) over the non-NaN values.
        above_average: number of non-NaN values strictly above the mean.
        total_players: number of rows in *team_players*.

    A stat with no non-NaN value (including an empty frame) is omitted.
    This avoids np.min/np.max raising on empty input; plot_trade_impact
    already treats a missing stat as absent.
    """
    team_percentiles = {}
    total_players = len(team_players)
    for stat in RELEVANT_STATS:
        if stat not in team_players.columns:
            continue
        values = team_players[stat].dropna().to_numpy(dtype=float)
        if values.size == 0:
            continue
        mean = np.mean(values)
        team_percentiles[stat] = {
            'min': np.min(values),
            'max': np.max(values),
            'mean': mean,
            'std': np.std(values),
            'above_average': int(np.sum(values > mean)),
            'total_players': total_players
        }
    return team_percentiles

def analyze_trade(players1, players2, predictions_df):
    """Summarise the two sides of a proposed trade.

    Each side is the subset of *predictions_df* whose 'Player' is in the
    given list. Returns ``{'group1': ..., 'group2': ...}``; each value holds:
        players: the matching rows.
        percentiles: the calculate_team_percentiles summary.
        salary_before: summed 'Previous_Season_Salary'.
        salary_after: summed 'Predicted_Salary'.
    """
    sides = {
        'group1': predictions_df[predictions_df['Player'].isin(players1)],
        'group2': predictions_df[predictions_df['Player'].isin(players2)],
    }
    summaries = {key: calculate_team_percentiles(rows) for key, rows in sides.items()}

    return {
        key: {
            'players': rows,
            'percentiles': summaries[key],
            'salary_before': rows['Previous_Season_Salary'].sum(),
            'salary_after': rows['Predicted_Salary'].sum(),
        }
        for key, rows in sides.items()
    }


def plot_salary_distribution(df):
    """Histogram (with KDE) of 'Salary_M' beside a per-Position boxplot of it.

    Expects 'Salary_M' and 'Position' columns. Returns the matplotlib Figure.
    """
    fig, (hist_ax, box_ax) = plt.subplots(1, 2, figsize=(15, 5))

    sns.histplot(df['Salary_M'], bins=30, kde=True, ax=hist_ax)
    hist_ax.set_title('Distribution of NBA Player Salaries (in Millions)')
    hist_ax.set_xlabel('Salary (in Millions)')

    sns.boxplot(y='Salary_M', x='Position', data=df, ax=box_ax)
    box_ax.set_title('NBA Player Salaries by Position (in Millions)')
    box_ax.set_xlabel('Position')
    box_ax.set_ylabel('Salary (in Millions)')

    # Acts on pyplot's current axes (presumably the boxplot, the last subplot created).
    plt.xticks(rotation=45)
    return fig

def plot_age_vs_salary(df):
    """Scatter plot of Age against 'Salary_M', coloured by Position; returns the Figure."""
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.scatterplot(data=df, x='Age', y='Salary_M', hue='Position', ax=ax)
    ax.set(title='Age vs Salary (in Millions)', xlabel='Age', ylabel='Salary (in Millions)')
    return fig

def plot_vorp_vs_salary(df):
    """Scatter plot of VORP against 'Salary_M' (hue: Position, size: Age); returns the Figure."""
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.scatterplot(data=df, x='VORP', y='Salary_M', hue='Position', size='Age', ax=ax)
    ax.set(title='VORP vs Salary', xlabel='VORP', ylabel='Salary (in Millions)')
    return fig

def plot_career_clusters(df):
    """Scatter plot of Age against 'Salary_M', coloured by cluster name and styled by Position.

    Expects the 'Cluster_Definition' column added by cluster_career_trajectories.
    Returns the Figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.scatterplot(data=df, x='Age', y='Salary_M', hue='Cluster_Definition', style='Position', ax=ax)
    ax.set(title='Career Clusters: Age vs Salary', xlabel='Age', ylabel='Salary (in Millions)')
    return fig

def plot_salary_change_distribution(filtered_df):
    """Histogram (with KDE) of 'Salary_Change' converted to millions; returns the Figure."""
    fig, ax = plt.subplots(figsize=(12, 6))
    change_in_millions = filtered_df['Salary_Change'] / 1e6
    sns.histplot(change_in_millions, bins=30, kde=True, ax=ax)
    ax.set(title='Distribution of Predicted Salary Changes',
           xlabel='Salary Change (in Millions)',
           ylabel='Count')
    return fig

def plot_player_comparison(comparison_df):
    """Bar chart of each player's 'Predicted_Salary' in millions; returns the Figure.

    NOTE(review): this writes a 'Salary_M' column into *comparison_df* itself,
    so the caller's frame is modified; confirm callers expect that.
    """
    comparison_df['Salary_M'] = comparison_df['Predicted_Salary'] / 1e6

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.barplot(data=comparison_df, x='Player', y='Salary_M', ax=ax)
    ax.set(title='Predicted Salaries for Selected Players',
           xlabel='Player',
           ylabel='Predicted Salary (in Millions)')
    plt.xticks(rotation=45, ha='right')
    return fig

def plot_performance_metrics_comparison(df, selected_players):
    """Compare *selected_players* on six metrics, one bar chart per metric in a 2x3 grid.

    Metrics: PTS, TRB, AST, PER, WS, VORP. Returns the Figure.
    """
    metrics = ['PTS', 'TRB', 'AST', 'PER', 'WS', 'VORP']
    is_selected = df['Player'].isin(selected_players)
    metrics_df = df.loc[is_selected, ['Player'] + metrics]

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    # axes.flat walks the grid row by row, matching the metric order.
    for ax, metric in zip(axes.flat, metrics):
        sns.barplot(x='Player', y=metric, data=metrics_df, ax=ax)
        ax.set_title(f'{metric} Comparison')
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    plt.tight_layout()
    return fig

def plot_salary_difference_distribution(filtered_df):
    """Histogram (with KDE) of 'Salary_Difference' converted to millions; returns the Figure."""
    fig, ax = plt.subplots(figsize=(12, 6))
    difference_in_millions = filtered_df['Salary_Difference'] / 1e6
    sns.histplot(difference_in_millions, bins=30, kde=True, ax=ax)
    ax.set(title='Distribution of Salary Differences',
           xlabel='Salary Difference (in Millions)',
           ylabel='Count')
    return fig

def plot_category_analysis(avg_predictions, category):
    """Grouped bars of average actual vs predicted salary, one group per row of *avg_predictions*.

    *avg_predictions* must have 'Salary' and 'Predicted_Salary' columns;
    *category* only appears in the title. Returns the Figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    salary_pair = avg_predictions[['Salary', 'Predicted_Salary']]
    salary_pair.plot(kind='bar', ax=ax)
    ax.set_title(f'Average Actual vs Predicted Salary by {category}')
    ax.set_ylabel('Salary')
    plt.xticks(rotation=45)
    return fig

def plot_model_evaluation(df, y_pred, model_choice):
    """Scatter actual 'SalaryPct' against *y_pred*, with a dashed red identity line.

    Points on the line were predicted exactly. Returns the Figure.
    """
    actual = df['SalaryPct']
    low, high = actual.min(), actual.max()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(actual, y_pred, alpha=0.5)
    ax.plot([low, high], [low, high], 'r--', lw=2)
    ax.set_xlabel("Actual Salary Percentage")
    ax.set_ylabel("Predicted Salary Percentage")
    ax.set_title(f"Actual vs Predicted Salary Percentage - {model_choice}")
    return fig

def plot_feature_importance(feature_importance, model_choice):
    """Bar-chart feature importances for a fitted model.

    Args:
        feature_importance: DataFrame with 'feature' and 'importance' columns,
            plotted in the row order given.
        model_choice: Model name, used only in the chart title.

    Returns:
        The matplotlib Figure containing the bar chart.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    feature_importance.plot.bar(x='feature', y='importance', ax=ax)
    ax.set_title(f"Feature Importances - {model_choice}")
    ax.set_xlabel("Features")
    ax.set_ylabel("Importance")
    plt.xticks(rotation=45, ha='right')
    return fig

def plot_trade_impact(trade_analysis, team1, team2):
    """Draw side-by-side bars of each trade group's mean value per stat.

    Args:
        trade_analysis: Mapping with 'group1' and 'group2' entries, each
            holding a 'percentiles' mapping of stat -> {'mean': value}.
            Stats missing from either mapping are plotted as 0.
        team1: Legend label for group1.
        team2: Legend label for group2.

    Returns:
        The matplotlib Figure containing the grouped bar chart.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    positions = range(len(RELEVANT_STATS))
    bar_width = 0.35

    def stat_means(group_key):
        percentiles = trade_analysis[group_key]['percentiles']
        return [percentiles.get(stat, {}).get('mean', 0) for stat in RELEVANT_STATS]

    group1_means = stat_means('group1')
    group2_means = stat_means('group2')

    # Offset the two series by half a bar width on either side of each tick.
    ax.bar([pos - bar_width / 2 for pos in positions], group1_means, bar_width, label=team1)
    ax.bar([pos + bar_width / 2 for pos in positions], group2_means, bar_width, label=team2)

    ax.set_ylabel('Value')
    ax.set_title('Trade Impact on Team Stats')
    ax.set_xticks(positions)
    ax.set_xticklabels(RELEVANT_STATS, rotation=45, ha='right')
    ax.legend()

    return fig

# Session-state key used to carry a retraining summary across st.rerun().
_RETRAIN_SUMMARY_KEY = "retrain_summary"


def _show_figure(fig):
    """Render a matplotlib figure in Streamlit, then close it.

    st.pyplot does not close an explicitly passed figure. Without the close,
    every script rerun would leave more figures registered in pyplot's
    global figure manager.
    """
    st.pyplot(fig)
    plt.close(fig)


def _filter_predictions_by_position(predictions):
    """Add the sidebar position filter and return the matching prediction rows.

    Combined positions such as 'PG-SG' are split on '-' so each individual
    position is offered once; all positions are selected by default.
    """
    st.sidebar.subheader("Filter by Position")
    unique_positions = sorted({pos for combo in predictions['Position'].str.split('-') for pos in combo})
    selected_positions = st.sidebar.multiselect("Select positions", unique_positions, default=unique_positions)
    return filter_by_position(predictions, selected_positions)


def _render_introduction():
    """Write the static project overview shown on the "Introduction" page."""
    st.title("Enhanced NBA Player Salary Analysis")
    st.write("Welcome to the NBA Salary Analysis and Prediction App! This project aims to provide comprehensive insights into NBA player salaries, advanced metrics, and future salary predictions based on historical data. Here's a detailed breakdown of the steps involved in creating this app:")

    st.subheader("Data Collection")

    st.write("### Salary Data")
    st.write("- **Sources**:")
    st.write("  - [Basketball Reference Salary Cap History](https://www.basketball-reference.com/contracts/salary-cap-history.html)")
    st.write("- **Description**: Data on the NBA salary cap from various seasons, along with maximum salary details for players based on years of service.")

    st.write("### Advanced Metrics")
    st.write("- **Source**: [Basketball Reference](https://www.basketball-reference.com)")
    st.write("- **Description**: Advanced player metrics such as Player Efficiency Rating (PER), True Shooting Percentage (TS%), and Value Over Replacement Player (VORP) were scraped using BeautifulSoup.")

    st.write("### Player Salaries and Team Data")
    st.write("- **Source**: [Hoopshype](https://hoopshype.com)")
    st.write("- **Description**: Player salary data was scraped for multiple seasons, with detailed information on individual player earnings and team salaries.")

    st.subheader("Data Processing")

    st.write("### Inflation Adjustment")
    st.write("- **Source**: [Adjusting for Inflation in Python](https://medium.com/analytics-vidhya/adjusting-for-inflation-when-analysing-historical-data-with-python-9d69a8dcbc27)")
    st.write("- **Description**: Adjusted historical salary data for inflation to provide a consistent basis for comparison.")

    st.write("### Data Aggregation")
    st.write("- Steps:")
    st.write("  1. Loaded salary data and combined it with team standings and advanced metrics.")
    st.write("  2. Merged multiple data sources to create a comprehensive dataset containing player performance, salaries, and advanced metrics.")

    st.subheader("Model Training and Prediction")

    st.write("### Data Preprocessing")
    st.write("- Implemented functions to handle missing values, perform feature engineering, and calculate key metrics such as points per game (PPG), assists per game (APG), and salary growth.")

    st.write("### Model Selection")
    st.write("- Utilized various machine learning models including Random Forest, Gradient Boosting, Ridge Regression, and others to predict future player salaries.")
    st.write("- Employed grid search for hyperparameter tuning and selected the best-performing models based on evaluation metrics like Mean Squared Error (MSE) and R² score.")

    st.write("### Feature Importance and Clustering")
    st.write("- Analyzed feature importance to understand the key factors influencing player salaries.")
    st.write("- Clustered players into categories based on career trajectories, providing insights into player development and value.")

    st.subheader("App Development")

    st.write("### Streamlit App")
    st.write("- Built an interactive app using Streamlit to visualize data, perform exploratory data analysis, and make salary predictions.")
    st.write("- **Features**:")
    st.write("  - **Data Overview**: Display raw and processed data.")
    st.write("  - **Exploratory Data Analysis**: Visualize salary distributions, age vs. salary, and other key metrics.")
    st.write("  - **Advanced Analytics**: Analyze VORP to salary ratio, career trajectory clusters, and other advanced metrics.")
    st.write("  - **Salary Predictions**: Predict future salaries and compare actual vs. predicted values.")
    st.write("  - **Player Comparisons**: Compare selected players based on predicted salaries and performance metrics.")
    st.write("  - **Model Evaluation**: Evaluate different models and display their performance metrics and feature importance.")

    st.write("### Data Files")
    st.write("- Stored processed data and model files in a structured format to facilitate easy loading and analysis within the app.")

    st.subheader("Improvements:")

    st.write("### Add Injury Data:")
    st.write("- **Source**: [Kaggle NBA Injury Stats 1951-2023](https://www.kaggle.com/datasets/loganlauton/nba-injury-stats-1951-2023/data)")
    st.write("- **Description**: This dataset provides detailed statistics on NBA injuries from 1951 to 2023, allowing for analysis of player availability and its impact on performance and salaries.")

    st.subheader("Conclusion")

    st.write("This app provides a robust platform for analyzing NBA player salaries, understanding the factors influencing earnings, and predicting future salaries based on historical data and advanced metrics. Explore the app to gain insights into player performance, salary trends, and much more.")


def _render_retrain_summary(best_model_name, evaluations):
    """Show the outcome of the most recent retraining run.

    Args:
        best_model_name: Name of the best model reported by retraining.
        evaluations: Mapping of model name -> {'MSE': float, 'R²': float}.
    """
    st.success("Retraining completed successfully!")
    st.write(f"Best model: {best_model_name}")
    st.write("Model performance:")
    for evaluated_name, metrics in evaluations.items():
        st.write(f"{evaluated_name}:")
        st.write(f"  MSE: {metrics['MSE']:.4f}")
        st.write(f"  R²: {metrics['R²']:.4f}")
    st.write("All models have been retrained and saved. The best model will be used for future predictions.")


def main():
    """Build the Streamlit NBA salary analysis app.

    The sidebar selects a page, a model, whether to use the inflation-adjusted
    salary cap, and a season. Each script run loads and prepares the data
    once. When a usable model (model, scaler and a non-empty feature list)
    is available, it also computes the selected season's predictions. The
    chosen page is then rendered.
    """
    model_options = ['Random_Forest', 'Gradient_Boosting', 'Ridge_Regression', 'ElasticNet', 'SVR', 'Decision_Tree']
    no_model_message = "No model found. Please select a valid model or retrain the models."

    st.sidebar.title("Navigation")
    sections = ["Introduction", "Data Overview", "Exploratory Data Analysis",
                "Advanced Analytics", "Salary Predictions", "Player Comparisons",
                "Salary Comparison", "Analysis by Categories", "Model Selection and Evaluation",
                "Model Retraining", "Trade Analysis"]
    choice = st.sidebar.radio("Go to", sections)

    selected_model = st.sidebar.selectbox("Select Model", model_options)

    use_inflated_data = st.sidebar.checkbox("Use Inflation Adjusted Salary Cap Data")
    st.sidebar.markdown("### All Salaries in Millions")

    # Load the selected model and show its stored metrics in the sidebar.
    model_name, model, mse, r2, selected_features, scaler, max_salary_cap = load_selected_model(selected_model, use_inflated_data)

    st.sidebar.markdown(f"### Selected Model: {model_name}")
    st.sidebar.write(f"MSE: {mse:.4f}")
    st.sidebar.write(f"R²: {r2:.4f}")

    df = load_data(use_inflated_data)
    df = feature_engineering(df)
    df = handle_missing_values(df)

    seasons = df['Season'].unique()
    selected_season = st.sidebar.selectbox("Select Season", seasons)

    df = calculate_vorp_salary_ratio(df)
    df = cluster_career_trajectories(df)

    # Use explicit None/length checks rather than truthiness. This also works
    # if selected_features is a pandas Index or numpy array, whose truth
    # value is ambiguous.
    has_model = (
        model is not None
        and scaler is not None
        and selected_features is not None
        and len(selected_features) > 0
    )
    if has_model:
        predictions = make_predictions(df, model, scaler, selected_features, selected_season, use_inflated_data, max_salary_cap)
    else:
        predictions = None

    if choice == "Introduction":
        _render_introduction()

    elif choice == "Data Overview":
        st.header("Data Overview")
        st.write("First few rows of the current season's dataset:")
        st.write(df[['Player', 'Season', 'Salary', 'GP', 'PTS', 'TRB', 'AST', 'Injured', 'Injury_Periods', 'Position', 'Age', 'Team', 'Years of Service', 'PER', 'WS', 'VORP', 'Salary Cap', 'Salary_Cap_Inflated']].head())
        st.write("\nFirst few rows of the predictions dataset:")
        if predictions is not None:
            st.write(predictions.head())
        else:
            st.warning(no_model_message)

        if use_inflated_data:
            st.write("\nNote: This data uses inflated salary cap projections.")
        else:
            st.write("\nNote: This data uses the standard salary cap.")

    elif choice == "Exploratory Data Analysis":
        st.header("Exploratory Data Analysis")

        st.subheader("Salary Distribution")
        df['Salary_M'] = df['Salary'] / 1e6
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        sns.histplot(df['Salary_M'], bins=30, kde=True, ax=ax1)
        ax1.set_title('Distribution of NBA Player Salaries (in Millions)')
        ax1.set_xlabel('Salary (in Millions)')
        sns.boxplot(y='Salary_M', x='Position', data=df, ax=ax2)
        ax2.set_title('NBA Player Salaries by Position (in Millions)')
        ax2.set_xlabel('Position')
        ax2.set_ylabel('Salary (in Millions)')
        plt.xticks(rotation=45)
        _show_figure(fig)

        st.subheader("Age vs Salary")
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.scatterplot(x='Age', y='Salary_M', hue='Position', data=df, ax=ax)
        ax.set_title('Age vs Salary (in Millions)')
        ax.set_xlabel('Age')
        ax.set_ylabel('Salary (in Millions)')
        _show_figure(fig)

    elif choice == "Advanced Analytics":
        st.header("Advanced Analytics")

        # Within main(), 'Salary_M' is otherwise assigned only on the EDA page.
        # Create it here unless the preprocessing helpers already provide it.
        if 'Salary_M' not in df.columns:
            df['Salary_M'] = df['Salary'] / 1e6

        st.subheader("VORP to Salary Ratio")
        fig, ax = plt.subplots(figsize=(12, 6))
        sns.scatterplot(x='VORP', y='Salary_M', hue='Position', size='Age', data=df, ax=ax)
        ax.set_title('VORP vs Salary')
        ax.set_xlabel('VORP')
        ax.set_ylabel('Salary (in Millions)')
        _show_figure(fig)

        top_value_players = df.nlargest(10, 'VORP_Salary_Ratio')
        st.write("Top 10 Value Players (Highest VORP to Salary Ratio):")
        st.write(top_value_players[['Player', 'Position', 'Age', 'Salary_M', 'VORP', 'VORP_Salary_Ratio']])

        st.subheader("Career Trajectory Clusters")
        fig, ax = plt.subplots(figsize=(12, 6))
        sns.scatterplot(x='Age', y='Salary_M', hue='Cluster_Definition', style='Position', data=df, ax=ax)
        ax.set_title('Career Clusters: Age vs Salary')
        ax.set_xlabel('Age')
        ax.set_ylabel('Salary (in Millions)')
        _show_figure(fig)

        st.write("Average Metrics by Cluster:")
        cluster_averages = df.groupby('Cluster_Definition')[['Age', 'Salary_M', 'PTS', 'TRB', 'AST', 'PER', 'WS', 'VORP']].mean()
        st.write(cluster_averages)

    elif choice == "Salary Predictions":
        st.header("Salary Predictions")

        # Reuse the predictions computed above instead of running the model twice.
        if predictions is None:
            st.warning(no_model_message)
        else:
            filtered_df = _filter_predictions_by_position(predictions)

            st.write("### Top 10 Highest Predicted Salaries")
            st.write(format_salary_df(filtered_df.nlargest(10, 'Predicted_Salary')))

            st.subheader("Salary Change Distribution")
            fig, ax = plt.subplots(figsize=(12, 6))
            sns.histplot(filtered_df['Salary_Change'] / 1e6, bins=30, kde=True, ax=ax)
            ax.set_title('Distribution of Predicted Salary Changes')
            ax.set_xlabel('Salary Change (in Millions)')
            ax.set_ylabel('Count')
            _show_figure(fig)

            if use_inflated_data:
                st.write("\nNote: These predictions are based on inflated salary cap projections.")
            else:
                st.write("\nNote: These predictions are based on the standard salary cap.")

    elif choice == "Player Comparisons":
        st.header("Player Comparisons")

        if predictions is None:
            st.warning(no_model_message)
        else:
            players = sorted(predictions['Player'].unique())
            selected_players = st.multiselect("Select players to compare", players)

            if selected_players:
                comparison_df = predictions[predictions['Player'].isin(selected_players)]
                st.write(format_salary_df(comparison_df))

                # Add the plotting column on a new frame rather than on the
                # filtered slice, which would raise SettingWithCopyWarning.
                chart_df = comparison_df.assign(Salary_M=comparison_df['Predicted_Salary'] / 1e6)
                fig, ax = plt.subplots(figsize=(12, 6))
                sns.barplot(x='Player', y='Salary_M', data=chart_df, ax=ax)
                ax.set_title('Predicted Salaries for Selected Players')
                ax.set_xlabel('Player')
                ax.set_ylabel('Predicted Salary (in Millions)')
                plt.xticks(rotation=45, ha='right')
                _show_figure(fig)

                st.subheader("Performance Metrics Comparison")
                _show_figure(plot_performance_metrics_comparison(df, selected_players))

    elif choice == "Salary Comparison":
        st.header("Salary Comparison")

        if predictions is None:
            st.warning(no_model_message)
        else:
            filtered_df = _filter_predictions_by_position(predictions)
            # Positive difference = paid more than predicted ("overpaid").
            filtered_df = filtered_df.assign(
                Salary_Difference=filtered_df['Salary'] - filtered_df['Predicted_Salary']
            )

            top_overpaid_count = st.sidebar.slider("Number of Top Overpaid Players to Display", min_value=1, max_value=50, value=10)
            top_underpaid_count = st.sidebar.slider("Number of Top Underpaid Players to Display", min_value=1, max_value=50, value=10)

            st.subheader("Overpaid vs Underpaid Players")
            st.write("### Top Overpaid Players")
            st.write(format_salary_df(filtered_df.nlargest(top_overpaid_count, 'Salary_Difference')))

            st.write("### Top Underpaid Players")
            st.write(format_salary_df(filtered_df.nsmallest(top_underpaid_count, 'Salary_Difference')))

            st.subheader("Salary Difference Distribution")
            _show_figure(plot_salary_difference_distribution(filtered_df))

    elif choice == "Analysis by Categories":
        st.header("Analysis by Categories")

        if predictions is None:
            st.warning(no_model_message)
        else:
            category = st.selectbox("Select Category", ['Position', 'Age', 'Team'])

            category_df = predictions
            if category == 'Age':
                # Bucket ages on a new frame rather than mutating predictions.
                category_df = predictions.assign(
                    Age_Group=pd.cut(predictions['Age'], bins=[0, 25, 30, 35, 100], labels=['Under 25', '25-30', '30-35', 'Over 35'])
                )
                category = 'Age_Group'

            # observed=False keeps every age bucket (even empty ones). It matches
            # the current pandas default without its categorical FutureWarning.
            avg_predictions = category_df.groupby(category, observed=False)[['Salary', 'Predicted_Salary', 'Salary_Change']].mean()

            st.write(f"Average Salaries by {category}")
            st.write(avg_predictions)

            _show_figure(plot_category_analysis(avg_predictions, category))

    elif choice == "Model Selection and Evaluation":
        st.header("Model Selection and Evaluation")

        model_choice = st.selectbox("Select Model to Evaluate", model_options)

        if model_choice:
            try:
                # Separate names so the sidebar-selected model/scaler/features are not shadowed.
                eval_model, eval_scaler, eval_features = load_model_and_scaler(model_choice, use_inflated_data)

                X_scaled = eval_scaler.transform(df[eval_features])
                y_pred = eval_model.predict(X_scaled)

                # NOTE(review): scored on the full loaded dataset (all seasons),
                # which presumably overlaps the training data -- not a held-out estimate.
                eval_mse = mean_squared_error(df['SalaryPct'], y_pred)
                eval_r2 = r2_score(df['SalaryPct'], y_pred)

                st.write(f"### Evaluation Metrics for {model_choice}")
                st.write(f"Mean Squared Error (MSE): {eval_mse:.4f}")
                st.write(f"R² Score: {eval_r2:.4f}")

                _show_figure(plot_model_evaluation(df, y_pred, model_choice))

                # Only tree-based models expose feature_importances_.
                if model_choice in ['Random_Forest', 'Gradient_Boosting', 'Decision_Tree']:
                    feature_importance = pd.DataFrame({
                        'feature': eval_features,
                        'importance': eval_model.feature_importances_
                    }).sort_values('importance', ascending=False)

                    st.write("### Feature Importances")
                    st.write(feature_importance)

                    _show_figure(plot_feature_importance(feature_importance, model_choice))

            except FileNotFoundError as e:
                st.error("Error: Model file not found")
                st.error(str(e))
                st.error("Please make sure the model file exists and the name is correct.")
            except Exception as e:
                st.error(f"An unexpected error occurred: {str(e)}")
                st.error("Please check the logs for more details and ensure all required files are present.")

    elif choice == "Model Retraining":
        st.header("Model Retraining")

        if st.button("Retrain Models"):
            try:
                with st.spinner("Retraining models... This may take a while."):
                    best_model_name, _best_model, evaluations, _features, _scaler, _cap = retrain_and_save_models(use_inflated_data)
            except Exception as e:
                st.error(f"An error occurred during model retraining: {str(e)}")
                st.error("Please check the logs for more details.")
            else:
                # Output written before st.rerun() is discarded, so stash the
                # summary and render it on the refreshed run instead.
                st.session_state[_RETRAIN_SUMMARY_KEY] = (best_model_name, evaluations)
                # Refresh the app so the newly saved models are loaded.
                st.rerun()

        retrain_summary = st.session_state.pop(_RETRAIN_SUMMARY_KEY, None)
        if retrain_summary is not None:
            _render_retrain_summary(*retrain_summary)

    elif choice == "Trade Analysis":
        st.header("Trade Analysis")

        try:
            use_inflated_data_trade = st.checkbox("Use Inflation Adjusted Salary Cap Data", key="trade_analysis_inflated_data")

            trade_predictions = load_predictions(use_inflated_data_trade)

            if 'Team' not in trade_predictions.columns:
                st.error("The 'Team' column is missing from the predictions data. Please check your data loading process.")
            else:
                all_teams = sorted(trade_predictions['Team'].unique())
                team1 = st.selectbox("Select Team 1", all_teams, key="trade_analysis_team1")
                # Default Team 2 to the second team when one exists; index=1 is
                # out of range (and rejected) when there is only a single team.
                team2 = st.selectbox("Select Team 2", all_teams, index=1 if len(all_teams) > 1 else 0, key="trade_analysis_team2")

                predictions1 = trade_predictions[trade_predictions['Team'] == team1]
                predictions2 = trade_predictions[trade_predictions['Team'] == team2]

                st.subheader(f"Available Players for {team1}")
                st.write(predictions1[['Player', 'Age', 'Position', 'Previous_Season_Salary', 'Predicted_Salary', 'PTS', 'TRB', 'AST']])

                st.subheader(f"Available Players for {team2}")
                st.write(predictions2[['Player', 'Age', 'Position', 'Previous_Season_Salary', 'Predicted_Salary', 'PTS', 'TRB', 'AST']])

                players1 = st.multiselect(f"Select players from {team1}", predictions1['Player'].unique(), key="trade_analysis_players1")
                players2 = st.multiselect(f"Select players from {team2}", predictions2['Player'].unique(), key="trade_analysis_players2")

                if st.button("Analyze Trade", key="trade_analysis_button"):
                    if not players1 or not players2:
                        st.warning("Please select players from both teams.")
                    else:
                        combined_predictions = pd.concat([predictions1, predictions2])
                        trade_analysis = analyze_trade(players1, players2, combined_predictions)

                        st.subheader("Trade Impact")

                        for group, data in trade_analysis.items():
                            st.write(f"\n{group.upper()} Analysis:")
                            st.write(f"Total Salary Before: ${data['salary_before']/1e6:.2f}M")
                            st.write(f"Total Salary After: ${data['salary_after']/1e6:.2f}M")
                            st.write(f"Salary Change: ${(data['salary_after'] - data['salary_before'])/1e6:.2f}M")

                            st.write("\nPlayer Details:")
                            st.write(data['players'][['Player', 'Age', 'Position', 'Previous_Season_Salary', 'Predicted_Salary', 'Salary_Change', 'PTS', 'TRB', 'AST', 'PER', 'WS', 'VORP']])

                            st.write("\nTeam Percentiles:")
                            for stat in RELEVANT_STATS:
                                if stat in data['percentiles']:
                                    st.write(f"{stat}: {data['percentiles'][stat]['mean']:.2f}")

                        st.subheader("Salary Comparison")
                        group1_trade_salary = trade_analysis['group1']['salary_after']
                        group2_trade_salary = trade_analysis['group2']['salary_after']
                        salary_difference = abs(group1_trade_salary - group2_trade_salary)

                        st.write(f"{team1} is trading ${group1_trade_salary/1e6:.2f}M in salary")
                        st.write(f"{team2} is trading ${group2_trade_salary/1e6:.2f}M in salary")
                        st.write(f"Salary difference: ${salary_difference/1e6:.2f}M")

                        if salary_difference > 5e6:  # Assumed 5 million threshold for salary matching
                            st.warning("The salaries in this trade are not well-matched. This may not be a valid trade under NBA rules.")
                        else:
                            st.success("The salaries in this trade are well-matched.")

                        _show_figure(plot_trade_impact(trade_analysis, team1, team2))

        except FileNotFoundError as e:
            st.error(f"Error: {str(e)}")
            st.error("Please make sure the predictions file exists in the correct location.")
        except KeyError as e:
            st.error(f"Error: {str(e)}")
            st.error("Please check your data files and ensure they contain all required columns.")
        except Exception as e:
            st.error(f"An unexpected error occurred: {str(e)}")
            st.error("Please check the data and try again.")

# Build the app when this module is executed as a script rather than imported.
if __name__ == "__main__":
    main()

Overwriting ../src/salary_predict/app.py
