In [None]:
import google.generativeai as genai
from dotenv import load_dotenv
import os
import torch
from captum.attr import IntegratedGradients
import random

# Load API key
# Reads GENAI_API_KEY from the .env file at ../../.env and configures the
# Gemini client with it.
# NOTE(review): if the variable is missing, `key` is None and genai.configure is
# still called with it — confirm how/when the client reports the missing key.
load_dotenv(dotenv_path="../../.env")
key = os.getenv("GENAI_API_KEY")

genai.configure(api_key=key)
# Shared model handle; generate_explanations sends its prompt through this object.
llm_model = genai.GenerativeModel("gemini-1.5-flash")

def format_prompt(results, players_id_list):
    """
    Build the LLM prompt covering every player in ``results``.

    Each entry of ``results`` is a mapping of feature -> importance score
    (the shape returned by ``get_top_features``), iterated in its stored
    order. A feature key is either a single column name (a "Main" term) or a
    tuple of column names (an "Interaction" term). ``players_id_list[i]``
    labels ``results[i]``.

    The per-player sections are joined with a ``---`` separator line and are
    followed by one instruction paragraph telling the LLM how to answer.

    Returns:
        The prompt string.

    Raises:
        IndexError: if ``players_id_list`` is shorter than ``results``.
    """
    separator = "\n---\n"  # Custom separator between player explanations
    player_prompts = []
    for player_idx, result in enumerate(results):
        feature_lines = [f"{players_id_list[player_idx]}\n"]
        # Feature numbering restarts at 1 for each player. It uses its own
        # name so the player index above is not shadowed.
        for rank, (feature, score) in enumerate(result.items(), start=1):
            if isinstance(feature, str):
                feature_name = feature
            elif isinstance(feature, (tuple, list)):
                # Interaction terms: show every participating column name.
                feature_name = ", ".join(map(str, feature))
            else:
                feature_name = str(feature)
            feature_type = "Interaction" if isinstance(feature, tuple) else "Main"
            feature_lines.append(
                f"{rank}. **Feature:** {feature_name}\n"
                f"   - **Type:** {feature_type}\n"
                f"   - **Importance Score:** {score}\n"
            )

        player_prompts.append(
            "**Prediction Explanation for Fantasy Score**\n\n"
            "**Top contributing features:**\n\n"
            + "".join(feature_lines)
        )

    # The trailing instruction is appended once, after the last player's
    # section. The team size of 11 is hard-coded.
    instructions = (
        "Give priority to the features with lower index since they have a "
        "bigger score, and output the individual text explanation for each "
        f"player, separated by a separator \"{separator}\", in just 2-3 lines, "
        "without mentioning the technical details related to the model (for "
        "example, don't directly mention the actual feature names, you might "
        "give the english understanding of it), focus only on a genuine "
        "cricket explanation. Also, there is no need for telling which row's "
        "collective response you give, if you output per row one by one, even "
        "with repetitions. Note that you have to compare the features with the "
        "other players and give positive explanation for first 11 players and "
        "negative explanation for the last 11 players."
    )
    return separator.join(player_prompts) + instructions

def generate_explainability_text(player_name, feature1, feature2, feature3, positivity):
    """
    Fill a randomly chosen sentence template about a player's top three features.

    ``positivity == 1`` draws from the "selected for the team" templates; any
    other value draws from the "left out of the team" templates. The draw uses
    the module-level ``random`` generator, so seeding ``random`` makes the
    choice reproducible.
    """
    if positivity == 1:
        pool = (
            "{p}'s high predicted fantasy score is attributed to their {a}, {b}, and {c}.",
            "Key factors like {a}, {b}, and {c} significantly contribute to {p}'s performance prediction.",
            "The model identifies {a}, {b}, and {c} as crucial metrics for {p}'s projected success.",
            "{p}'s standout performance metrics in {a}, {b}, and {c} drive their fantasy score prediction.",
            "With strengths in {a}, {b}, and {c}, {p} emerges as a top contender for high fantasy points.",
        )
    else:
        pool = (
            "{p}'s strong performance in {a}, {b}, and {c} was commendable, but competition was tough, leading to their exclusion.",
            "With high marks in {a}, {b}, and {c}, {p} was close to making the team but was ultimately not selected.",
            "{p} had great metrics in {a}, {b}, and {c}, but others outperformed in critical areas.",
            "Despite excelling in {a}, {b}, and {c}, {p} could not secure a spot on the team due to stiff competition.",
            "While {p} stood out for {a}, {b}, and {c}, their overall profile didn't meet the team's needs this time.",
        )
    template = random.choice(pool)
    return template.format(p=player_name, a=feature1, b=feature2, c=feature3)


# Human-readable descriptions of model feature columns, used by the
# template-based fallback explanations. Built once at import time instead of
# on every call.
_BACKUP_FEATURE_DESCRIPTIONS = {
    'Total Innings Played': "number of innings played across all matches",
    'last_10_matches_Fours_sum': "fours hit in the last 10 matches",
    'last_10_matches_Sixes_sum': "sixes scored in the last 10 matches",
    'last_10_matches_Outs_sum': "times dismissed in the last 10 matches",
    'last_10_matches_fantasy_points_sum': "fantasy points accumulated in the last 10 matches",
    'last_10_matches_Dot Balls_sum': "dot balls faced in the last 10 matches",
    'last_10_matches_Balls Faced_sum': "total balls faced in the last 10 matches",
    'last_10_matches_Innings Bowled_sum': "innings bowled in the last 10 matches",
    'last_10_matches_Balls Bowled_sum': "balls delivered in the last 10 matches",
    'last_10_matches_derived_Dot Ball%': "percentage of dot balls in the last 10 matches",
    'last_10_matches_derived_Batting Strike Rate': "strike rate in the last 10 matches",
    'last_10_matches_derived_Batting Avg': "batting average in the last 10 matches",
    'last_10_matches_derived_Mean Score': "average score across the last 10 matches",
    'last_10_matches_derived_Boundary%': "percentage of runs from boundaries in the last 10 matches",
    'last_10_matches_derived_Mean Balls Faced': "average balls faced per match in the last 10 matches",
    'last_10_matches_derived_Dismissal Rate': "rate of dismissals in the last 10 matches",
    'last_10_matches_derived_Bowling Dot Ball%': "percentage of dot balls bowled in the last 10 matches",
    'last_10_matches_derived_Boundary Given%': "percentage of runs conceded from boundaries in the last 10 matches",
    'last_10_matches_derived_Bowling Avg': "average runs conceded per wicket in the last 10 matches",
    'last_10_matches_derived_Bowling Strike Rate': "deliveries per wicket in the last 10 matches",
    'Opponent_total_matches_sum': "matches played against the opponent",
    'Venue_total_matches_sum': "matches played at the venue",
    'last_10_matches_Runsgiven_sum': "runs conceded in the last 10 matches",
    'last_10_matches_Dot Balls Bowled_sum': "dot balls bowled in the last 10 matches",
    'last_10_matches_Foursgiven_sum': "fours conceded in the last 10 matches",
    'last_10_matches_Sixesgiven_sum': "sixes conceded in the last 10 matches",
    'venue_avg_runs': "average runs scored at the venue",
    'venue_avg_wickets': "average wickets taken at the venue",
    'last_10_matches_Extras_sum': "extras conceded in the last 10 matches",
    'last_10_matches_centuries_sum': "centuries scored in the last 10 matches",
    'last_10_matches_half_centuries_sum': "half-centuries scored in the last 10 matches",
    'last_10_matches_opponent_Runs_sum': "runs scored against the opponent in the last 10 matches",
    'last_10_matches_venue_Runs_sum': "runs scored at the venue in the last 10 matches",
    'last_10_matches_Wickets_sum': "wickets taken in the last 10 matches",
    'last_10_matches_LBWs_sum': "LBWs in the last 10 matches",
    'last_10_matches_Maiden Overs_sum': "maiden overs bowled in the last 10 matches",
    'last_10_matches_Stumpings_sum': "stumpings in the last 10 matches",
    'last_10_matches_Catches_sum': "catches taken in the last 10 matches",
    'last_10_matches_direct run_outs_sum': "direct run-outs in the last 10 matches",
    'last_10_matches_indirect run_outs_sum': "indirect run-outs in the last 10 matches",
    'last_10_matches_match_type_Innings Batted_sum': "innings batted in the specific match type during the last 10 matches",
    'last_10_matches_match_type_Innings Bowled_sum': "innings bowled in the specific match type during the last 10 matches",
    'match_type_total_matches': "total matches played in the specific match type",
    'batting_fantasy_points': "batting fantasy points",
    'bowling_fantasy_points': "bowling fantasy points",
    'fielding_fantasy_points': "fielding fantasy points",
    'last_10_matches_venue_Wickets_sum': "wickets taken at the venue in the last 10 matches",
    'last_10_matches_Opposition_Innings Bowled_sum': "innings bowled against the opposition in the last 10 matches",
    'last_10_matches_match_type_Wickets_sum': "wickets taken in the specific match type during the last 10 matches",
    'last_10_matches_Opposition_Wickets_sum': "wickets taken against the opposition in the last 10 matches",
    'last_10_matches_derived_Economy Rate': "economy rate in the last 10 matches",
    'last_10_matches_Venue_Innings Bowled_sum': "innings bowled at the venue in the last 10 matches",
    'last_10_matches_lbw_bowled_sum': "LBWs and bowled dismissals in the last 10 matches",
    'last_10_matches_Bowleds_sum': "bowled dismissals in the last 10 matches"
}


def _describe_backup_feature(feature):
    """Return ``(text, found)`` for a feature key: its readable description and
    whether every part of it is in ``_BACKUP_FEATURE_DESCRIPTIONS``.

    Interaction terms (tuples of column names) are described part by part.
    Unknown names fall back to their raw string form.
    """
    if isinstance(feature, tuple):
        parts = [_describe_backup_feature(part) for part in feature]
        text = " combined with ".join(part_text for part_text, _ in parts)
        return text, all(found for _, found in parts)
    if feature in _BACKUP_FEATURE_DESCRIPTIONS:
        return _BACKUP_FEATURE_DESCRIPTIONS[feature], True
    return str(feature), False


def backup_explanations(results, players_id_list):
    """
    Build template-based (non-LLM) explanations, one per player.

    Args:
        results: one entry per player, ordered by importance within each entry.
            An entry is either a mapping of feature -> score (the shape
            returned by ``get_top_features``) or a sequence of
            ``(feature, score)`` pairs.
        players_id_list: labels; ``players_id_list[i]`` names ``results[i]``.

    Returns:
        A list with exactly one explanation per entry of ``results``, in the
        same order. Entries with index ``i < len(results) / 2`` get a positive
        template and the rest get a negative one. Features without a known
        description fall back to their raw names, and a warning is printed.

    Raises:
        ValueError: if an entry has fewer than three features.
        IndexError: if ``players_id_list`` is shorter than ``results``.
    """
    explanations = []
    half = len(results) / 2

    for i, result in enumerate(results):
        player_name = players_id_list[i]
        # Accept both the dict produced by get_top_features and a list of pairs.
        ranked = list(result.items()) if isinstance(result, dict) else list(result)
        top_features = [entry[0] for entry in ranked[:3]]
        if len(top_features) < 3:
            raise ValueError(
                f"{player_name}: need at least 3 ranked features, got {len(top_features)}"
            )

        described = [_describe_backup_feature(feature) for feature in top_features]
        if not all(found for _, found in described):
            # Keep this player and use the raw names. Skipping the player would
            # shift every later explanation onto the wrong player.
            print("Feature not found in feature_descriptions!")
        feature1, feature2, feature3 = (text for text, _ in described)

        positivity = 1 if i < half else 0
        explanations.append(
            generate_explainability_text(player_name, feature1, feature2, feature3, positivity)
        )
    return explanations
 
# Example inputs
# Smoke test of the template fallback using two hand-written players. Each row
# is a list of (feature_name, score) pairs, highest score first.
results = [
    [('last_10_matches_Fours_sum', 0.9), ('last_10_matches_Sixes_sum', 0.8), ('last_10_matches_Outs_sum', 0.7)],
    [('last_10_matches_Dot Balls_sum', 0.85), ('last_10_matches_Balls Faced_sum', 0.75), ('last_10_matches_derived_Bowling Avg', 0.65)]
]
players_id_list = ['Player 1', 'Player 2']

# Generate explanations
# With two rows, row 0 falls in the positive half and row 1 in the negative half.
explanations = backup_explanations(results, players_id_list)

# Output explanations
for explanation in explanations:
    print(explanation)



def generate_explanations(results, players_id_list):
    """
    Ask the Gemini model for explanations of all players in one request.

    The request text is a fixed system preamble, a newline, and the per-player
    prompt built by ``format_prompt``. The full request is printed before it
    is sent.

    Returns:
        The raw object returned by ``llm_model.generate_content``.
        ``explain_outputs`` reads its ``.text``.
    """
    # One string per preamble line, written out so that leading and trailing
    # whitespace is visible. The leading "" and trailing "    " reproduce the
    # newline after the opening quotes and the indent before the closing quotes
    # of the original triple-quoted literal.
    preamble_lines = [
        "",
        "    You are a cricket fantasy prediction explainability assistant. Your task is to explain the fantasy score prediction based on the most important features. ",
        "    These features may include main terms (individual player statistics) and interaction terms (combinations of player statistics that jointly affect the prediction). ",
        "    For each feature, provide an explanation that is relevant to the cricket match context, and prioritize features based on their importance scores. Note that there are **22** players. ",
        "    Your task includes giving positive explainable texts for the first 11, since they will be selected in the team, and negative explainable texts for the next 11, as they are left out from the team. ",
        "    Example of positive and negative explainable texts can be: **Player's high predicted fantasy score is attributed to their feature1, feature2, and feature3.** and **Player's strong performance in feature1, feature2, and feature3 was commendable, but competition was tough, leading to their exclusion.**  ",
        "    Just give the result no starting text.",
        "    ",
    ]
    system_prompt = "\n".join(preamble_lines)

    request = "\n".join((system_prompt, format_prompt(results, players_id_list)))
    print(request)
    return llm_model.generate_content(request)
# Function to get the feature names from the indices

def get_feature_names_from_indices(indices, main_features):
    """
    Translate feature indices into feature names.

    Each entry of ``indices`` is either a single index (a main term), mapped to
    ``main_features[index]``, or a tuple of indices (an interaction term),
    mapped to the tuple of the corresponding names.

    Returns:
        A list of names and name tuples, in the same order as ``indices``.
    """
    def _name_of(entry):
        if not isinstance(entry, tuple):
            return main_features[entry]
        return tuple(main_features[j] for j in entry)

    return [_name_of(entry) for entry in indices]

def get_top_features(explainability_scores, columns, k=5):
    """
    Return the ``k`` highest-scoring columns for each row of scores.

    Args:
        explainability_scores: tensor of per-feature scores with one row per
            player. A 1-D tensor is treated as a single row.
        columns: feature names; ``columns[j]`` names score column ``j``.
        k: number of features to keep per row. It is capped at the number of
            columns, because ``torch.topk`` raises when ``k`` exceeds the
            size of the dimension.

    Returns:
        A list with one dict per row, mapping column name -> score as a Python
        float, inserted in descending score order.

    Note:
        Ranking uses the signed score (largest first), not its magnitude.
    """
    scores = explainability_scores
    if scores.dim() == 1:
        scores = scores.unsqueeze(0)
    k = min(k, scores.shape[1])
    topk_values, topk_indices = torch.topk(scores, k=k, dim=1)

    players_data = []
    for row_values, row_indices in zip(topk_values.tolist(), topk_indices.tolist()):
        players_data.append(
            {columns[idx]: value for idx, value in zip(row_indices, row_values)}
        )
    return players_data

def explain_outputs(model, X, columns, players_id_list):
    """
    Attribute ``model``'s output to its input features with Integrated
    Gradients, then ask the LLM to explain each player's top features.

    Args:
        model: forward function handed to Captum's ``IntegratedGradients``.
        X: model input. Attributions are taken for output index ``target=0``.
            Presumably shaped (n_players, n_features); TODO confirm.
        columns: feature names, indexed by the column position of the
            attributions.
        players_id_list: player labels aligned with the rows of ``X``.

    Returns:
        The LLM reply split into one explanation per player. The reply is
        split on lines that consist only of dashes (``---``), each piece is
        whitespace-stripped, and empty pieces are dropped.
    """
    import re

    integrated_gradients = IntegratedGradients(model)
    attributions = integrated_gradients.attribute(X, target=0)

    results = get_top_features(attributions, columns, 5)
    response = generate_explanations(results, players_id_list)
    explanations = response.text
    # The prompt requests a "---" line between players. Matching any line made
    # only of dashes (with optional surrounding spaces) tolerates small
    # formatting drift. Stripping and dropping blank pieces avoids empty
    # entries from leading or trailing separators.
    pieces = re.split(r"^[ \t]*-{3,}[ \t]*$", explanations, flags=re.MULTILINE)
    return [piece.strip() for piece in pieces if piece.strip()]

The model identifies fours hit in the last 10 matches, sixes scored in the last 10 matches, and times dismissed in the last 10 matches as crucial metrics for Player 1's projected success.
Player 2's strong performance in dot balls faced in the last 10 matches, total balls faced in the last 10 matches, and average runs conceded per wicket in the last 10 matches was commendable, but competition was tough, leading to their exclusion.
