# In [None]:
import sys
import os
import pickle
import random
from collections import Counter, defaultdict
import pandas as pd

# Add Tokenizer path to sys.path
sys.path.append("../Tokenizer")
from bpe_tokenizer import BPETokenizer

# ==========================================================
# CONSTANTS
# ==========================================================
# START: padding token placed twice in front of every training sequence so the
#        first real tokens have a full trigram context.
# EOT:   end-of-text marker; must be identical to the special token emitted by
#        the preprocessing step.
START, EOT = "<START>", "\uFFF2"

# ==========================================================
# LOAD TRAINED TOKENIZER
# ==========================================================
# Using the fixed BPETokenizer class
tokenizer = BPETokenizer.load("../Tokenizer/bpe_tokenizer.pkl")

print("Tokenizer loaded successfully.")
print("Vocab size:", len(tokenizer.vocab))


def ensure_special_token(tok, token):
    """Register ``token`` in ``tok``'s vocabulary if it is missing.

    A new token gets the id one past the current maximum vocab id (0 for an
    empty vocab). It is also written into ``tok.id_to_token`` so the id can be
    mapped back to the token.

    Args:
        tok: Tokenizer exposing ``vocab`` (token -> id dict) and
            ``id_to_token`` (id -> token mapping).
        token: The special token string to ensure.

    Returns:
        The id of ``token`` in ``tok.vocab``.
    """
    if token not in tok.vocab:
        new_id = max(tok.vocab.values(), default=-1) + 1
        tok.vocab[token] = new_id
        tok.id_to_token[new_id] = token
    return tok.vocab[token]


# Ensure START and EOT tokens exist in vocab. START is added first, so ids are
# assigned in the same order as before.
# NOTE(review): these ids exist only in this in-memory tokenizer (it is not
# re-saved). Any consumer of the saved model must register them identically.
for _special_token in (START, EOT):
    ensure_special_token(tokenizer, _special_token)

# ==========================================================
# TRIGRAM LANGUAGE MODEL
# ==========================================================
class TrigramLanguageModel:
    """Trigram language model over token ids with linear-interpolation smoothing.

    ``P(w3 | w1, w2) = l1 * P(w3) + l2 * P(w3 | w2) + l3 * P(w3 | w1, w2)``,
    where each component is a maximum-likelihood estimate from raw counts and
    the weights ``l1, l2, l3`` are rescaled to sum to 1.
    """

    def __init__(self, lambda1=0.1, lambda2=0.3, lambda3=0.6):
        """Store the interpolation weights (rescaled to sum to 1) and empty counts.

        Args:
            lambda1: Weight of the unigram estimate.
            lambda2: Weight of the bigram estimate.
            lambda3: Weight of the trigram estimate.

        Raises:
            ValueError: If any weight is negative or all weights are zero.
        """
        weights = (lambda1, lambda2, lambda3)
        if any(w < 0 for w in weights):
            raise ValueError("interpolation weights must be non-negative")
        total = sum(weights)
        if total == 0:
            raise ValueError("at least one interpolation weight must be positive")
        # Rescale so the three estimates form a convex combination.
        self.lambda1 = lambda1 / total
        self.lambda2 = lambda2 / total
        self.lambda3 = lambda3 / total

        self.unigram = Counter()   # N(w3) over predicted tokens
        self.bigram = Counter()    # N(w2, w3)
        self.trigram = Counter()   # N(w1, w2, w3)

        # Context totals used as MLE denominators.
        self.bigram_totals = Counter()   # N(w1, w2, *)
        self.unigram_totals = Counter()  # N(w1, *)

        self.total_tokens = 0  # number of predicted (non-padding) tokens
        self.vocab = set()     # token ids that have been predicted in training

    def train(self, corpus, start_id):
        """Accumulate n-gram counts from ``corpus``.

        Each sequence is left-padded with two ``start_id`` tokens so the first
        real tokens get a full trigram context. The padding is context only:
        it is never counted as a predicted token. It therefore does not inflate
        ``total_tokens`` or the unigram/bigram distributions, and ``generate``
        can never sample it.

        Args:
            corpus: Iterable of token-id sequences.
            start_id: Id of the START padding token.
        """
        print("Training model...")
        for tokens in corpus:
            padded = [start_id, start_id, *tokens]
            # One (w1, w2, w3) window per real token; w3 is the predicted token.
            for w1, w2, w3 in zip(padded, padded[1:], padded[2:]):
                self.unigram[w3] += 1
                self.vocab.add(w3)

                self.bigram[(w2, w3)] += 1
                self.unigram_totals[w2] += 1

                self.trigram[(w1, w2, w3)] += 1
                self.bigram_totals[(w1, w2)] += 1
            self.total_tokens += len(padded) - 2

        print("Training completed.")
        print("Model vocabulary size:", len(self.vocab))
        print("Total tokens processed:", self.total_tokens)

    def unigram_prob(self, w):
        """Return the MLE unigram probability of ``w`` (0 before training)."""
        return self.unigram[w] / self.total_tokens if self.total_tokens else 0

    def bigram_prob(self, w1, w2):
        """Return MLE ``P(w2 | w1)``, or 0 if ``w1`` never appeared as a context."""
        denom = self.unigram_totals[w1]
        return self.bigram[(w1, w2)] / denom if denom else 0

    def trigram_prob(self, w1, w2, w3):
        """Return MLE ``P(w3 | w1, w2)``, or 0 if ``(w1, w2)`` never appeared as a context."""
        denom = self.bigram_totals[(w1, w2)]
        return self.trigram[(w1, w2, w3)] / denom if denom else 0

    def interpolated_prob(self, w1, w2, w3):
        """Return the interpolated probability of ``w3`` following ``(w1, w2)``."""
        return (
            self.lambda1 * self.unigram_prob(w3) +
            self.lambda2 * self.bigram_prob(w2, w3) +
            self.lambda3 * self.trigram_prob(w1, w2, w3)
        )

    def generate(self, tokenizer, max_length=150, rng=None):
        """Sample a token sequence and return it decoded as text.

        Sampling starts from two START tokens and draws one token at a time
        from the interpolated distribution. It stops when EOT is drawn, after
        ``max_length`` tokens, or when no candidate has non-zero probability.

        Args:
            tokenizer: Object with a ``vocab`` dict (token -> id) containing
                the START and EOT tokens, and a ``decode`` method.
            max_length: Maximum number of tokens to sample.
            rng: Optional ``random.Random`` instance for reproducible sampling.
                Defaults to the global ``random`` module.

        Returns:
            The decoded text with EOT characters removed and surrounding
            whitespace stripped.
        """
        sampler = random if rng is None else rng
        start_id = tokenizer.vocab[START]
        eot_id = tokenizer.vocab[EOT]

        # START is context-only padding. Exclude it from the candidates, which
        # also protects models pickled before train() stopped predicting it.
        candidates = [w for w in self.vocab if w != start_id]

        result_tokens = [start_id, start_id]
        for _ in range(max_length):
            w1, w2 = result_tokens[-2], result_tokens[-1]
            weights = [self.interpolated_prob(w1, w2, w3) for w3 in candidates]
            # random.choices normalizes the weights itself but cannot sample
            # from an all-zero distribution (also covers no candidates).
            if sum(weights) == 0:
                break
            next_token = sampler.choices(candidates, weights=weights, k=1)[0]
            result_tokens.append(next_token)

            if next_token == eot_id:
                break

        # Decode excluding the two leading START tokens.
        text = tokenizer.decode(result_tokens[2:])
        return text.replace(EOT, "").strip()

# ==========================================================
# LOAD PREPROCESSED DATA
# ==========================================================
# A missing CSV is reported but not fatal: downstream cells see an empty
# `stories` list and skip encoding/training.
try:
    df = pd.read_csv("../PreProcessing/urdu_stories_processed.csv", encoding="utf-8")
except FileNotFoundError:
    print("Error: Preprocessed data file not found.")
    stories = []
else:
    # Drop missing rows and coerce every remaining entry to str.
    stories = df["content"].dropna().astype(str).tolist()
    print("Stories loaded:", len(stories))

# ==========================================================
# ENCODE CORPUS
# ==========================================================
print("Encoding corpus...")
tokenized_corpus = []
for story in stories:
    encoded = tokenizer.encode(story)
    # Stories that encode to nothing are left out of the corpus.
    if not encoded:
        continue
    # Terminate every story with EOT (unless already present) so generation
    # has a learned stopping token.
    if encoded[-1] != tokenizer.vocab[EOT]:
        encoded.append(tokenizer.vocab[EOT])
    tokenized_corpus.append(encoded)

print("Stories encoded:", len(tokenized_corpus))

# ==========================================================
# TRAIN TRIGRAM MODEL
# ==========================================================
# The trigram estimate carries most of the weight; the lower orders act as
# smoothing for unseen contexts.
model = TrigramLanguageModel(
    lambda1=0.05,  # unigram weight
    lambda2=0.15,  # bigram weight
    lambda3=0.8,   # trigram weight
)
# Training only happens when there is something to train on.
if tokenized_corpus:
    model.train(tokenized_corpus, tokenizer.vocab[START])

# ==========================================================
# SAVE MODEL
# ==========================================================
# Only persist a trained model. With no data (e.g. the CSV was missing),
# writing would silently overwrite a previously saved trigram_model.pkl with
# an empty model.
# NOTE: pickle records the model's class by module-qualified name (here most
# likely `__main__`), so loading requires TrigramLanguageModel to be importable
# under that same name.
if tokenized_corpus:
    with open("trigram_model.pkl", "wb") as f:
        pickle.dump(model, f)
    print("Trigram model saved successfully ✅")
else:
    print("No training data; trigram_model.pkl was not written.")

# ==========================================================
# GENERATE SAMPLE STORY
# ==========================================================
# Show one sampled story, but only if a model was actually trained above.
if tokenized_corpus:
    print("\nGenerated Story:", end="\n\n")
    sample_story = model.generate(tokenizer)
    print(sample_story)
