# In [0]:
from abc import ABC, abstractmethod
import time
import mlflow
from contextlib import contextmanager
import json
import pandas as pd
import datetime
import re
import openai
from openai import OpenAI
import gspread
import random
import logging
import os
from unidecode import unidecode
from pathlib import Path
import pyspark
from pyspark.sql.functions import *
from functools import reduce
from typing import *
import tiktoken

#from rouge_score import rouge_scorer
#from bert_score import score as bert_score



from general_config import *
EXPERIMENT_NAME = "/Users/krista@jamcity.com/centralized_loc_translation_run"

#from config import EXPERIMENT_NAME

# Template of one localization request: every field starts out as an empty string.
request_example = dict.fromkeys(
    (
        "RowFingerprint",
        "Timestamp",
        "SubmitterEmail",
        "DueDate",
        "LocType",
        "Game",
        "TargetLanguages",
        "URL",
        "QAFlag",
    ),
    "",
)

# Template of a per-LocType configuration.
cfg_example = {
    "input": {
        "required_tabs": ["ios", "android"],
        "ios_header_rows": 3,
        "android_header_rows": 3,
    },
    "char_limit_policy": "strict",
}

# Same template with the additional list of output sheet names.
aso_cfg_example = {
    "input": {
        "required_tabs": ["ios", "android"],
        "ios_header_rows": 3,
        "android_header_rows": 3,
    },
    "char_limit_policy": "strict",
    "output_sheets": ["formatted_ios", "formatted_android", "long_results"],
}

class MLTracker:
    """Small convenience wrapper around the MLflow fluent API.

    One tracker represents one MLflow run: the request-level parent run when
    ``language`` is None, or a per-language child run otherwise. Every logging
    helper delegates to the module-level ``mlflow.*`` functions, so it writes
    to whichever run is currently active.
    """

    def __init__(self, request, language=None, experiment_name=EXPERIMENT_NAME):
        """Store the request and run settings; no MLflow call is made here.

        request: dict-like request row; ``start()`` reads Game, RowFingerprint,
            RequestID, LocType, URL, TargetLanguages and QAFlag from it.
        language: target language for a child run, or None for the parent run.
        experiment_name: MLflow experiment the run is created in.
        """
        self.request = request
        self.language = language
        self.experiment_name = experiment_name
        self.run = None
        self._events = []

    @staticmethod
    def _parse_langs(raw_langs: str | None):
        """Split a ';'- or ','-separated language string into a clean list."""
        if not raw_langs:
            return []
        parts = [x.strip() for x in raw_langs.replace(";", ",").split(",")]
        return [x for x in parts if x]

    def start(self, nested: bool = False):
        """Start the MLflow run, set its tags/params and return its run id.

        nested: pass True to create the run as a child of the active run.
        """
        mlflow.set_experiment(self.experiment_name)

        # Run name: game + request row id + optional language suffix
        base_name = self.request.get("Game")
        run_id = self.request.get("RowFingerprint") or self.request.get("RequestID", "run")
        base_name = f"{base_name}:{run_id}"
        run_name = f"{base_name}:{self.language}" if self.language else base_name

        self.run = mlflow.start_run(run_name=run_name, nested=nested)

        # Tags (request-level + language-level)
        tags = {
            "RowFingerprint": self.request.get("RowFingerprint", ""),
            "LocType": self.request.get("LocType", ""),
            "Game": self.request.get("Game", ""),
            "InputSheetURL": self.request.get("URL", ""),
            "status": "RUNNING",
        }
        if self.language:
            tags["Language"] = self.language
        mlflow.set_tags(tags)

        # Params
        if not self.language:
            # parent: log the list of languages once
            langs = self._parse_langs(self.request.get("TargetLanguages"))
            mlflow.log_params({
                "TargetLanguages": ",".join(langs),
                "QAFlag": str(self.request.get("QAFlag", False)),
            })
        else:
            # child run: log just this language
            mlflow.log_param("Language", self.language)

        return self.run.info.run_id

    def end(self, succeeded: bool, err_text: str | None = None):
        """Flush buffered events / error text and close the active run.

        succeeded: whether the tracked work finished without error.
        err_text: optional error description stored as ``logs/error.txt``.
        """
        try:
            mlflow.set_tag("status", "SUCCEEDED" if succeeded else "FAILED")
            if self._events:
                mlflow.log_text("\n".join(self._events), "logs/events.txt")
            if err_text:
                mlflow.log_text(err_text, "logs/error.txt")
        finally:
            # Always close the run, even when logging the artifacts fails, and
            # record the outcome in MLflow's own run status as well as the tag.
            mlflow.end_run(status="FINISHED" if succeeded else "FAILED")

    def event(self, msg: str):
        """Record a timestamped event and expose it as the ``last_event`` tag."""
        ts = time.strftime("%Y-%m-%d %H:%M:%S")
        mlflow.set_tag("last_event", msg)
        self._events.append(f"[{ts}] {msg}")

    # NOTE: this method keeps the name ``dict`` because callers use
    # ``tracker.dict(...)``. Inside the class body that name shadows the
    # builtin for every annotation evaluated after this point
    # (``dict | None`` would raise TypeError at class creation), so the
    # methods below annotate with ``Dict[...]`` instead.
    def dict(self, obj, path: str):
        """Log ``obj`` as a JSON/YAML artifact at ``path``."""
        mlflow.log_dict(obj, path)

    def text(self, txt: str, path: str):
        """Log ``txt`` as a text artifact at ``path``."""
        mlflow.log_text(txt, path)

    def metrics(self, d: Dict[str, Any], step: int | None = None):
        """Log every value of ``d`` as a float metric, optionally at ``step``."""
        mlflow.log_metrics({k: float(v) for k, v in d.items()}, step=step)

    def params(self, d: Dict[str, Any]):
        """Log every value of ``d`` as a string parameter."""
        mlflow.log_params({k: str(v) for k, v in d.items()})

    @contextmanager
    def step(self, name: str):
        """Time the enclosed block and log it as ``duration_sec.<name>``.

        The duration metric is logged even when the block raises.
        """
        start = time.time()
        self.event(f"Step start: {name}")
        try:
            yield
        finally:
            dur = time.time() - start
            self.metrics({f"duration_sec.{name}": dur})

    @contextmanager
    def nested(self, name: str, tags: Dict[str, Any] | None = None,
               params: Dict[str, Any] | None = None):
        """Run the enclosed block inside a nested MLflow run named ``name``."""
        # You can still do sub-steps inside a language if you want (e.g., per-batch)
        with mlflow.start_run(run_name=name, nested=True):
            if tags:
                mlflow.set_tags(tags)
            if params:
                self.params(params)
            yield


class LocalizationRun(ABC):
    """Template-method base class for processing one localization request.

    ``run()`` drives a fixed pipeline (validate -> load -> preprocess ->
    build prompts -> translate -> postprocess -> write) and records timings,
    metrics and failures through an ``MLTracker``. Subclasses provide the
    LocType-specific steps.
    """

    def __init__(self,
                 request,
                 gsheet_client=None,
                 gpt_client=None,
                 cfg=None,
                 tracker: MLTracker | None = None):
        """
        request: dict with fields like RequestID, LocType, Game, TargetLanguages, etc.
        gsheet_client: injected dependency for Google Sheets
        gpt_client: injected dependency for GPT API
        cfg: loaded YAML/JSON config for this LocType
        tracker: optional pre-built parent tracker; a new request-level
            MLTracker is created when omitted
        """
        self.request = request
        self.gc = gsheet_client
        self.gpt = gpt_client
        # Expose the clients under their constructor-argument names as well,
        # so either spelling works.
        self.gsheet_client = gsheet_client
        self.gpt_client = gpt_client
        self.cfg = cfg or {}
        self.artifacts = {}
        # Honour an injected tracker instead of silently replacing it.
        self.tracker = tracker if tracker is not None else MLTracker(request, language=None)
        self.lang_trackers = {}

    def run(self):
        """Execute the whole pipeline inside a parent MLflow run.

        Returns ``{"status": "SUCCEEDED", "run_id": <parent run id>}``.
        Any exception marks the parent run as failed, stores the traceback as
        an artifact and is re-raised.
        """
        parent_run_id = self.tracker.start()  # parent request-level run
        try:
            with self.tracker.step("validate_inputs"):
                self.validate_inputs()

            with self.tracker.step("load_inputs"):
                data = self.load_inputs()
                self.tracker.dict({"preview": str(data)[:2000]}, "snapshots/input_preview.json")

            with self.tracker.step("preprocess"):
                prepped = self.preprocess(data)
                self.tracker.metrics({"rows.prepped": len(prepped)})

            with self.tracker.step("build_prompts"):
                prompts = self.build_prompts(prepped)
                self.tracker.metrics({"prompts.count": len(prompts)})

            # Build per-language trackers (child runs) using request languages
            langs = MLTracker._parse_langs(self.request.get("TargetLanguages"))
            # TODO: this could be remedied by using self.languages instead of
            # parsing self.request.get("TargetLanguages")
            for lang in langs:
                self.lang_trackers[lang] = MLTracker(
                    request=self.request,
                    language=lang,
                    experiment_name=self.tracker.experiment_name  # same experiment
                )

            outputs = self.translate(prompts)  # tracked per language below

            with self.tracker.step("postprocess"):
                final_rows = self.postprocess(outputs)
                self.tracker.metrics({"rows.final": len(final_rows)})

            with self.tracker.step("write_outputs"):
                self.write_outputs(final_rows)

            self.tracker.end(succeeded=True)
            return {"status": "SUCCEEDED", "run_id": parent_run_id}
        except Exception:
            import traceback  # local import: only needed on the failure path

            # Keep the full traceback, not just the message, for debugging.
            self.tracker.end(succeeded=False, err_text=traceback.format_exc())
            raise

    @staticmethod
    def _token_count(usage, key):
        """Read one token counter from a usage payload.

        Accepts None, a dict, or an object exposing the counter as an
        attribute; returns 0 when the counter is missing.
        """
        if not usage:
            return 0
        if isinstance(usage, dict):
            value = usage.get(key, 0)
        else:
            value = getattr(usage, key, 0)
        return int(value or 0)

    # Shared, tracked translate that spins one child run per language
    def translate(self, groups):
        """Call the model once per group and track each call in a child run.

        groups: mapping of group name (language) -> prompt payload handed to
            ``_call_model_batch``.
        Returns the list of model outputs in iteration order of ``groups``.
        """
        parent = self.tracker

        total_prompt_tokens = 0
        total_completion_tokens = 0
        results = []

        with parent.step("translate"):
            # The prompts arrive already grouped by an earlier pipeline step.
            parent.metrics({"translate.groups": len(groups)})

            for group_name, prompts in groups.items():
                # pick the language tracker matching group_name; fallback to a generic child tracker
                lang_tracker = self.lang_trackers.get(group_name) or MLTracker(
                    request=self.request,
                    language=group_name,
                    experiment_name=parent.experiment_name
                )
                # Start the child run as nested=True
                lang_tracker.start(nested=True)
                try:
                    with lang_tracker.step("api_batch"):
                        out, usage = self._call_model_batch(prompts)  # subclass hook
                        results.append(out)

                    lang_total_p = self._token_count(usage, "prompt_tokens")
                    lang_total_c = self._token_count(usage, "completion_tokens")

                    # per-language metric
                    lang_tracker.metrics({
                        "items.total": len(prompts),
                        "tokens.prompt.total": lang_total_p,
                        "tokens.completion.total": lang_total_c,
                    })

                    # accumulate into parent
                    total_prompt_tokens += lang_total_p
                    total_completion_tokens += lang_total_c
                except Exception as e:
                    lang_tracker.end(succeeded=False, err_text=str(e))
                    raise
                else:
                    # Closed outside the try so a failure while ending the run
                    # does not trigger a second end() on the same child run.
                    lang_tracker.end(succeeded=True)

            # parent rollup
            parent.metrics({
                "tokens.prompt.total": total_prompt_tokens,
                "tokens.completion.total": total_completion_tokens,
            })

        return results

    @abstractmethod
    def _call_model_batch(self, prompt):
        """Send one group's prompt to the model.

        Must return ``(output, usage)``; ``usage`` may be None or carry
        ``prompt_tokens`` / ``completion_tokens`` counters.
        """
        pass

    @abstractmethod
    def validate_inputs(self):
        """Check the request's inputs before any data is loaded."""
        pass

    @abstractmethod
    def load_inputs(self):
        """Load the raw input data; the result is passed to ``preprocess``."""
        pass

    @abstractmethod
    def preprocess(self, data):
        """Prepare loaded data; the result must support ``len()``."""
        pass

    @abstractmethod
    def build_prompts(self, prepped):
        """Build the prompts handed to ``translate``; must support ``len()``."""
        pass

    @abstractmethod
    def postprocess(self, outputs):
        """Turn model outputs into final rows; must support ``len()``."""
        pass

    @abstractmethod
    def write_outputs(self, post):
        """Persist the postprocessed rows."""
        pass


class ASOLocalizer(LocalizationRun):
    """Localization run for ASO sheets (presumably App Store Optimization).

    Work in progress: input validation and loading are implemented, while the
    wide-to-long reshape, preprocessing, prompt building, model call,
    postprocessing and output writing are still placeholders.
    """

    # Minimum number of records a tab must hold to be treated as populated.
    # NOTE(review): presumably the header rows plus one data row — confirm
    # against ios_header_rows / android_header_rows.
    _MIN_RECORDS = 4

    def __init__(self,
                 request,
                 gsheet_client=None,
                 gpt_client=None,
                 cfg=None,
                 tracker: MLTracker | None = None):
        super().__init__(request, gsheet_client, gpt_client, cfg, tracker)

        # Now subclass-specific initialization
        input_cfg = self.cfg.get("input", {})
        self.required_tabs = input_cfg.get("required_tabs", [])
        self.char_limit_policy = self.cfg.get("char_limit_policy", "strict")
        # e.g., store header row counts for ios/android
        self.ios_header_rows = input_cfg.get("ios_header_rows", 3)
        self.android_header_rows = input_cfg.get("android_header_rows", 3)
        self.sh = None
        self.ios_wksht = None
        self.android_wksht = None
        self.ios_long_df = None
        self.ios_wide_df = None
        self.android_long_df = None
        self.android_wide_df = None
        self.joined_long = None

    def _worksheet_if_populated(self, wksht):
        """Return ``wksht`` when it holds enough records, otherwise None."""
        if wksht is None:
            return None
        if len(wksht.get_all_records()) >= self._MIN_RECORDS:
            return wksht
        return None

    def validate_inputs(self):
        """Open the request's spreadsheet and locate the ios/android tabs.

        Raises Exception when the spreadsheet cannot be opened or a required
        tab is missing. A tab that is absent or has too few records leaves the
        corresponding worksheet attribute as None.
        """
        # The request row uses the "URL" key; accept the lowercase spelling too.
        url = self.request.get("URL") or self.request.get("url")

        # Open url (the base class stores the Sheets client as ``self.gc``)
        try:
            sh = self.gc.open_by_url(url)
        except Exception as e:
            raise Exception(f"Invalid spreadsheet URL: {e}") from e

        tabs = {w.title: w for w in sh.worksheets()}
        for tab in self.required_tabs:
            if tab not in tabs:
                raise Exception(f"Required tab '{tab}' not found in spreadsheet '{url}'")
        self.sh = sh

        # Validate formatting - if no data, leave the worksheet object as None
        self.ios_wksht = self._worksheet_if_populated(tabs.get('ios'))
        self.android_wksht = self._worksheet_if_populated(tabs.get('android'))

        return

    #TODO: convert to wide
    def _convert_wide_to_long_inputs(self, df, type):
        """Placeholder for the wide -> long reshape; currently returns None.

        df: pandas DataFrame read from one store tab.
        type: "ios" or "android".
        """
        # NOTE(review): relies on a ``spark`` session that is not defined in
        # this file (presumably provided by the notebook runtime — confirm).
        # The converted frame is not used yet.
        df = spark.createDataFrame(df)
        # TODO: implement the per-store reshape for "ios" and "android".
        return None

    def load_inputs(self):
        """Read the validated tabs into DataFrames and return the joined frame.

        Returns the row-wise concatenation of the available long frames, or
        an empty DataFrame when no long frame is available.
        """
        if self.ios_wksht is not None:
            self.ios_wide_df = pd.DataFrame(self.ios_wksht.get_all_records())
            #TODO:
            self.ios_long_df = self._convert_wide_to_long_inputs(self.ios_wide_df, 'ios')
        if self.android_wksht is not None:
            # Read the android tab itself (not the ios worksheet).
            self.android_wide_df = pd.DataFrame(self.android_wksht.get_all_records())
            self.android_long_df = self._convert_wide_to_long_inputs(self.android_wide_df, 'android')

        # Join the two dataframes; pd.concat raises when every input is None,
        # so only concatenate the frames that exist.
        frames = [f for f in (self.ios_long_df, self.android_long_df) if f is not None]
        self.joined_long = pd.concat(frames, axis=0) if frames else pd.DataFrame()

        return self.joined_long

    def _group_prompts_for_translation(self, prompts):
        """Placeholder: group prompts by language for ``translate``."""
        #[{'lang':"","prompts":[]}]
        raise NotImplementedError("ASOLocalizer._group_prompts_for_translation is not implemented yet")

    def preprocess(self, data):
        """Placeholder: build the per-language list of items to translate."""
        raise NotImplementedError("ASOLocalizer.preprocess is not implemented yet")

    def build_prompts(self, prepped):
        """Placeholder: prompt construction is not implemented; returns None."""
        # ASO Guidelines
        # Language Specific
        # Game Specific
        # Add dump of prepped for each language
        return

    def _call_model_batch(self, prompt):
        """Placeholder so the class satisfies the abstract base interface."""
        raise NotImplementedError("ASOLocalizer._call_model_batch is not implemented yet")

    def postprocess(self, outputs):
        """Placeholder: returns None."""
        return

    def write_outputs(self, post):
        """Placeholder: returns None."""
        return

class MarketingLocalizer(LocalizationRun):
    """Scaffold for marketing localization; every pipeline step is a stub."""

    def __init__(self,
                 request,
                 gsheet_client=None,
                 gpt_client=None,
                 cfg=None,
                 tracker: MLTracker | None = None):

        super().__init__(request, gsheet_client, gpt_client, cfg, tracker)

        # Now subclass-specific initialization
        self.required_tabs = self.cfg.get("input", {}).get("required_tabs", [])
        #self.input_headers = []
        #self.char_limit_policy = self.cfg.get("char_limit_policy", "")

    def validate_inputs(self):
        """Stub: returns None."""
        return

    def load_inputs(self):
        """Stub: returns None."""
        return

    def preprocess(self, data):
        """Stub: returns None."""
        return

    def build_prompts(self, prepped):
        """Stub: returns None."""
        return

    def translate(self, prompt_batch):
        """Stub overriding the tracked base implementation: returns None."""
        return

    def _call_model_batch(self, prompt):
        """Required by the abstract base class; without it the class cannot
        be instantiated. Not used while ``translate`` is overridden."""
        raise NotImplementedError("MarketingLocalizer._call_model_batch is not implemented yet")

    def postprocess(self, outputs):
        """Stub: returns None."""
        return

    def write_outputs(self, post):
        """Stub: returns None."""
        return

class CSLocalizer(LocalizationRun):
    """Scaffold for CS localization; every pipeline step is a stub."""

    def __init__(self,
                 request,
                 gsheet_client=None,
                 gpt_client=None,
                 cfg=None,
                 tracker: MLTracker | None = None):
        super().__init__(request, gsheet_client, gpt_client, cfg, tracker)

    def validate_inputs(self):
        """Stub: returns None."""
        return

    def load_inputs(self):
        """Stub: returns None."""
        return

    def preprocess(self, data):
        """Stub: returns None."""
        return

    def build_prompts(self, prepped):
        """Stub: returns None."""
        return

    def translate(self, prompt_batch):
        """Stub overriding the tracked base implementation: returns None."""
        return

    def _call_model_batch(self, prompt):
        """Required by the abstract base class; without it the class cannot
        be instantiated. Not used while ``translate`` is overridden."""
        raise NotImplementedError("CSLocalizer._call_model_batch is not implemented yet")

    def postprocess(self, outputs):
        """Stub: returns None."""
        return

    def write_outputs(self, post):
        """Stub: returns None."""
        return


##############################################

from InGame_Config import * 

class InGameLocalizer(LocalizationRun):
    """Localization run for in-game strings.

    Reads token rows from the "input" tab of the request's spreadsheet, asks
    the chat model for one translation set per language of the game's language
    map, merges the results on ``token`` and writes them to the "output" tab.
    """

    # Size used when a missing required tab has to be created.
    # NOTE(review): arbitrary defaults — adjust to the expected template size.
    _NEW_TAB_ROWS = 1000
    _NEW_TAB_COLS = 26

    def __init__(self,
                 request,
                 gsheet_client=None,
                 gpt_client=None,
                 cfg=None,
                 tracker: MLTracker | None = None):

        super().__init__(request, gsheet_client, gpt_client, cfg, tracker)

        self.required_tabs = self.cfg.get("input", {}).get("required_tabs", [])
        self.char_limit_policy = self.cfg.get("char_limit_policy", "")

        # Populated by validate_inputs / load_inputs.
        self.sh = None
        self.wksht = None
        self.input_headers = []
        self.data = []
        self.df = None

    def validate_inputs(self):
        """Open the spreadsheet and its "input" tab; create missing required tabs.

        Raises Exception when the spreadsheet or the input tab cannot be opened.
        """
        # The request row uses the "URL" key; accept the lowercase spelling too.
        url = self.request.get("URL") or self.request.get("url")

        # Open Sheet (the base class stores the Sheets client as ``self.gc``)
        try:
            self.sh = self.gc.open_by_url(url)
        except Exception as e:
            raise Exception(f"Error opening google sheet: {e}") from e

        #Open Tab
        try:
            self.wksht = self.sh.worksheet("input")
        except Exception as e:
            raise Exception(f"Error opening input tab: {e}") from e

        # Check if all required tabs are present
        existing = {wksht.title for wksht in self.sh.worksheets()}
        for tab in self.required_tabs:
            if tab not in existing:
                # add_worksheet requires explicit dimensions.
                self.sh.add_worksheet(title=tab, rows=self._NEW_TAB_ROWS, cols=self._NEW_TAB_COLS)
                existing.add(tab)
                # TODO: create the tab with the expected header row

        return

    def load_inputs(self):
        """Read the input tab; the first row is the header.

        Returns the data rows (list of lists) and stores them as a DataFrame
        in ``self.df``. Raises ValueError when the tab has no rows at all.
        """
        data = self.wksht.get_all_values()
        if not data:
            raise ValueError("Input tab is empty: expected a header row followed by data rows")

        self.input_headers = data.pop(0)
        self.data = data
        self.df = pd.DataFrame(data, columns=self.input_headers) #pandas DF

        return self.data

    def preprocess(self, data: List[List[str]]) -> str:
        """Serialize the loaded rows (``self.df``) to a JSON string of records."""
        prepped = json.dumps(self.df.to_dict(orient='records'))

        return prepped

    def _get_game_context(self):
        """ Helper function to get relevant context for in game localization for particular game """
        game = self.request.get('Game')
        self.game = game
        self.lang_specific_guidelines = GENERAL_LANG_SPECIFIC_GUIDELINES
        # NOTE(review): this binds the *language* guidelines to the game
        # guidelines attribute and then indexes it by game name below;
        # presumably a game-specific constant from InGame_Config was
        # intended — confirm its name before relying on game_description.
        self.general_game_specific_guidelines = GENERAL_LANG_SPECIFIC_GUIDELINES

        # Specifics for games (language map + game specific prompt inputs)
        if game == "Panda Pop":
            self.lang_map = PP_LANG_MAP
            self.ex_input = PP_EX_INPUT
            self.context_infer = PP_CONTEXT_INFER
            self.token_infer = PP_TOKEN_INFER
        elif game == "Cookie Jam Blast":
            self.lang_map = CJB_LANG_MAP
            self.ex_input = CJB_EX_INPUT
            self.context_infer = CJB_CONTEXT_INFER
            self.token_infer = CJB_TOKEN_INFER
        elif game == "Genies & Gems":
            self.lang_map = GG_LANG_MAP
            self.ex_input = GG_EX_INPUT
            self.context_infer = GG_CONTEXT_INFER
            self.token_infer = GG_TOKEN_INFER
        else:
            # Fail with a clear message instead of a later AttributeError.
            raise ValueError(f"Unsupported game for in-game localization: {game!r}")

        self.game_description = self.general_game_specific_guidelines[game]

        self.languages = list(self.lang_map.keys())
        self.lang_cds = list(self.lang_map.values())

    def _generate_prompt_helper(self,
                                language:str,
                                game:str,
                                prepped:str)->List[Dict[str, Any]]:
        """Build the chat messages (system + user) for one target language.

        language: key of ``self.lang_map`` / ``self.lang_specific_guidelines``.
        game: unused; the prompt reads ``self.game`` set by ``_get_game_context``.
        prepped: JSON string of the rows to translate, sent as the user message.
        """
        base = f""" 
            You are a professional game localizer translating for a popular mobile puzzle game called {self.game} by Jam City which is described as:
            {self.game_description}
            Please translate the in-game phrases provided below from English into {language}.
                •   Keep the translations natural, playful, and appropriate for a casual mobile gaming tone.
                •   Avoid overly formal or mechanical language.
                •   There is no strict character limit, but translations should not be egregiously longer than the original English text.
                •   {self.token_infer}
                •   {self.context_infer}
            If present, use the context to guide tone, word choice, or brevity — especially when the English phrase is vague or could be interpreted multiple ways.
            Example Inputs as a json string:
            json
                {self.ex_input}
            
            """
        base += f"""
            You MUST follow these language specific guidelines:
            {self.lang_specific_guidelines[language]} 
            """

        lang_cd = self.lang_map[language]

        base += f"""
            Respond in **JSON format**, one object per row:
            json
            [
            {{ "token": "token_name_1", "{lang_cd}": "translated phrase 1" }},
            {{ "token": "token_name_2", "{lang_cd}": "translated phrase 2" }},
            ...
            ]\n\n
            """

        return [
            {"role": "system", "content": base},
            {"role": "user",   "content": prepped}
        ]

    def build_prompts(self, prepped: str) -> Dict[str, List[Dict[str, Any]]]:
        """Build one prompt per language; returns ``{language: messages}``."""
        self._get_game_context()
        self.prepped = prepped

        # TODO: Make sure "languages" is appropriately passed here
        prompts = [
            self._generate_prompt_helper(lang, self.game, prepped)
            for lang in self.languages
        ]

        self.prompts = prompts
        self.groups = dict(zip(self.languages, prompts))

        return self.groups

    def _call_model_batch(self, prompt: List[Dict[str, Any]]):
        """
        Must return (outputs, usage_dict) where usage_dict includes:
          {'prompt_tokens': int, 'completion_tokens': int}
        """
        MODEL = "gpt-4o"
        temperature = 0.05  # adjust for creativity vs. stability

        # The base class stores the GPT client as ``self.gpt``.
        response = self.gpt.chat.completions.create(
                model=MODEL,
                messages=prompt,
                temperature=temperature
        )

        output = response.choices[0].message.content

        # Convert the usage object into the plain dict promised above so the
        # caller can read the counters with ``.get``.
        usage = response.usage
        usage_dict = {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
            "completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
        }
        return (output, usage_dict)

    def _parse_model_json_block(self, output: str):
        """
        Cleans and parses a JSON-like string from a model output wrapped in markdown code block.

        Args:
            output (str): The raw output string, e.g., from GPT, wrapped with ```json ... ```

        Returns:
            list[dict]: Parsed JSON content as Python list of dictionaries.

        Raises:
            ValueError: If the cleaned string cannot be parsed as valid JSON.
        """
        # Strip markdown-style code block markers and leading/trailing whitespace
        cleaned = re.sub(r"^```(?:json)?|```$", "", (output or "").strip(), flags=re.IGNORECASE).strip()

        try:
            # Parse the text as-is first so valid escapes inside the
            # translated strings are preserved.
            loaded = json.loads(cleaned)
        except json.JSONDecodeError:
            # Fallback: drop escaped and raw newlines, then retry.
            flattened = cleaned.replace("\\n", "").replace("\n", "").strip()
            try:
                loaded = json.loads(flattened)
            except json.JSONDecodeError as e:
                raise ValueError(f"Could not parse JSON: {e}") from e

        # The model sometimes returns JSON encoded as a JSON string; decode once more.
        if isinstance(loaded, str):
            try:
                return json.loads(loaded)
            except json.JSONDecodeError as e:
                raise ValueError(f"Could not parse JSON: {e}") from e
        return loaded

    def postprocess(self,
                    outputs):
        """Parse each raw model output into a DataFrame (one per language)."""
        postprocessed_dfs = []
        for output in outputs:
            parsed = self._parse_model_json_block(output)
            returned_df = pd.DataFrame(parsed)
            postprocessed_dfs.append(returned_df)

        self.postprocessed_dfs = postprocessed_dfs

        return self.postprocessed_dfs

    def _merge_outputs_by_language(self,
                                   post: List[pd.DataFrame]):
        """Left-join every per-language frame onto ``self.df`` by ``token``."""
        for i in post:
            self.df = self.df.merge(i, on=['token'],how='left')

        return self.df

    def write_outputs(self, post:List[pd.DataFrame]):
        """Merge the translations and write them to the "output" tab from A2."""
        results = self._merge_outputs_by_language(post)

        #TODO: May want to update this later so I'm not removing superfluous context - more about the formatting of the template
        # drop() returns a new frame, so the result has to be reassigned.
        results = results.drop(columns=['context'], errors='ignore')

        # Missing translations come back from the left merge as NaN, which is
        # not valid JSON for the Sheets API; write empty cells instead.
        results = results.astype(object).where(results.notna(), "")

        out_data = results.values.tolist()
        if not out_data:
            # Nothing to write; an "A2:<col>1" range would be invalid.
            return "Done!"

        wksht = self.sh.worksheet("output")

        # Size the range from the actual data instead of a fixed column.
        last_cell = gspread.utils.rowcol_to_a1(len(out_data) + 1, results.shape[1])
        data_range = f"A2:{last_cell}"

        wksht.batch_update([{'range':data_range, 'values':out_data}])

        return "Done!"

