In [None]:
# ==========================================
# Installing all the necessary dependencies
# ==========================================
print("Installing dependencies...")
!pip install torch torchvision torchaudio --quiet
!pip install transformers --quiet
!pip install shap --quiet
!pip install numpy --quiet
!pip install matplotlib --quiet
!pip install ipython --quiet
# Use bitsandbytes, not auto-gptq
!pip install --upgrade transformers bitsandbytes --quiet
!pip install indic-nlp-library jieba --quiet
!pip install huggingface_hub --quiet
!pip install gradio deep-translator langdetect --quiet
!pip install rich --quiet # For console formatting

In [None]:
# =========================
# Imports
# =========================
import torch
import shap
import numpy as np
import re
import jieba
from indicnlp.tokenize import indic_tokenize
# Import BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, BitsAndBytesConfig
from deep_translator import GoogleTranslator
from langdetect import detect
import warnings
from rich.console import Console
from rich.table import Table
import gradio as gr
import os
import pandas as pd

# Keep notebook output readable: ignore library warnings and quiet jieba's
# dictionary-loading log messages (WARNING and above only).
warnings.simplefilter("ignore")
jieba.setLogLevel(jieba.logging.WARNING)

In [None]:
# =========================
# All the necessary models being declared
# =========================
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Single-turn input classifier (carried over from the previous script).
classification_model_name = "Yashaswini21/multilingual_new1_continued"

# Chat LLM: full-precision HF checkpoint, quantized to 4-bit with bitsandbytes at load
# time, paired with the base Vicuna tokenizer.
vicuna_model_name = "TheBloke/Wizard-Vicuna-7B-Uncensored-HF"
vicuna_tokenizer_name = "lmsys/vicuna-7b-v1.5"

# The output-filter and multi-turn classifiers live in subfolders of one Hub repo.
output_filter_repo = "khrshtt/XJailGuard"
output_filter_subfolder = "output_filter"
multi_turn_subfolder = "multi_turn"

# Extra keyword arguments forwarded to every classifier `from_pretrained` call (none by default).
use_auth = dict()

In [None]:
# =======================================================
# Load models (Only classifiers here, LLM is lazy-loaded)
# =======================================================
def _load_classifier(repo, **extra):
    """Load a (tokenizer, eval-mode model on `device`) pair for a sequence classifier in `repo`."""
    tok = AutoTokenizer.from_pretrained(repo, **extra, **use_auth)
    mdl = AutoModelForSequenceClassification.from_pretrained(repo, **extra, **use_auth).to(device).eval()
    return tok, mdl


print("Loading classifier and filter models...")
cls_tokenizer, cls_model = _load_classifier(classification_model_name)
output_filter_tokenizer, output_filter_model = _load_classifier(
    output_filter_repo, subfolder=output_filter_subfolder
)
multiturn_tokenizer, multiturn_model = _load_classifier(
    output_filter_repo, subfolder=multi_turn_subfolder
)
print("Classifier models loaded.")

# The chat LLM and its tokenizer start unset; generate_llm_response fills them on first use.
vicuna_tokenizer = None
vicuna_model = None

In [None]:
# =========================
# Helper functions
# =========================
# Tokenization, classifier wrappers, LLM generation, and the filter adapters used by the agent.

#This function takes a string (or a list of strings) and splits it into individual words or tokens based on the language detected.
def word_tokenizer(text):
    """Split text into word tokens using a script-appropriate tokenizer.

    - Text containing Devanagari (U+0900-U+097F) -> indic-nlp trivial tokenizer (lang="hi").
    - Text containing CJK Unified Ideographs (U+4E00-U+9FFF) -> jieba segmentation,
      falling back to individual characters if jieba yields nothing.
    - Anything else -> lower-cased ``\\b\\w+\\b`` regex matches.

    Args:
        text: a string, or a list/tuple of strings (tokenized element-wise).

    Returns:
        A list of tokens for a string, or a list of token lists for a list/tuple.
        Falsy input (``None``, ``""``, ``[]``) and whitespace-only strings yield ``[""]``.

    Raises:
        ValueError: if ``text`` is neither a string nor a list/tuple.
    """
    if not text:
        return [""]
    if isinstance(text, str):
        if not text.strip():
            return [""]
        if any(0x0900 <= ord(c) <= 0x097F for c in text):
            return indic_tokenize.trivial_tokenize(text, lang="hi")
        if any(0x4E00 <= ord(c) <= 0x9FFF for c in text):
            tokens = list(jieba.cut(text))
            return tokens or list(text)
        return re.findall(r'\b\w+\b', text.lower())
    # Sequences are checked only after the str branch so `.strip()` is never called on
    # a list (which would raise AttributeError instead of tokenizing element-wise).
    if isinstance(text, (list, tuple)):
        return [word_tokenizer(t) for t in text]
    raise ValueError("Invalid input type")

#This function takes text input and uses the multilingual classifier model (cls_model) to predict the probability of the input being "benign" or "jailbreak".
def predict_proba(texts):
    """Return class probabilities from the single-turn classifier (cls_model).

    Also used as the SHAP model function, so the output has exactly one row per input text.

    Args:
        texts: a string, a sequence of strings (e.g. the numpy array SHAP passes), or a
            sequence of token lists (each list is space-joined into one string).

    Returns:
        np.ndarray of shape (len(texts), num_labels), with at least one row. Rows for empty
        or whitespace-only texts are all zeros; an entirely empty input yields one zero row.
    """
    if isinstance(texts, str):
        texts = [texts]
    elif isinstance(texts, (list, tuple, np.ndarray)) and all(isinstance(item, list) for item in texts):
        texts = [" ".join(str(tok) for tok in tokens) for tokens in texts]
    texts = ["" if t is None else str(t) for t in texts]

    num_labels = cls_model.config.num_labels
    probs_out = np.zeros((max(len(texts), 1), num_labels), dtype=np.float32)
    # Classify only non-empty texts but write results back at their original positions:
    # dropping empties would shift rows and break the one-output-per-input pairing.
    keep = [i for i, t in enumerate(texts) if t.strip()]
    if not keep:
        return probs_out

    batch = [texts[i] for i in keep]
    inputs = cls_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(cls_model(**inputs).logits, dim=-1)
    probs_out[keep] = probs.cpu().numpy()
    return probs_out

#This function takes text input and uses the multi-turn classifier model (multiturn_model) to predict the probability of the input being safe or unsafe.
def multiturn_predict_proba(texts):
    """Return class probabilities from the multi-turn classifier (multiturn_model).

    Also used as the SHAP model function, so rows stay aligned with the inputs.

    Args:
        texts: a string or a sequence of strings (e.g. the numpy array SHAP passes).

    Returns:
        np.ndarray of shape (len(texts), num_labels); rows for empty or whitespace-only texts
        are all zeros. If every text is empty, returns shape (0, num_labels) so callers can
        check ``shape[0] > 0``.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = ["" if t is None else str(t) for t in texts]

    num_labels = multiturn_model.config.num_labels
    keep = [i for i, t in enumerate(texts) if t.strip()]
    if not keep:
        return np.zeros((0, num_labels))  # Return shape (0, N)

    # Write results back at their original positions so each output row matches its input.
    probs_out = np.zeros((len(texts), num_labels), dtype=np.float32)
    batch = [texts[i] for i in keep]
    inputs = multiturn_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(multiturn_model(**inputs).logits, dim=-1)
    probs_out[keep] = probs.cpu().numpy()
    return probs_out

#This function checks if the output text is safe using the output filter model (output_filter_model).
def is_output_safe(text):
    """Check text with the output filter model (output_filter_model).

    Args:
        text: a string, or a sequence of strings. With several texts the batch is unsafe
            if any text is predicted non-benign.

    Returns:
        (is_safe, label, probs):
          - is_safe: True when the reported prediction is class index 0 (treated as 'benign').
          - label: output_filter_model.config.id2label for that prediction.
          - probs: probability vector of the first non-benign text, or of the first text
            when all are benign.
        Empty input (None, "", or only blank strings) returns (True, "benign", np.array([])).
    """
    if text is None:
        texts = []
    elif isinstance(text, str):
        texts = [text]
    else:
        texts = list(text)
    texts = [str(t) for t in texts if t is not None and str(t).strip()]
    if not texts:
        return True, "benign", np.array([])

    inputs = output_filter_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(output_filter_model(**inputs).logits, dim=-1).cpu().numpy()
    preds = probs.argmax(axis=-1)

    # Report the first non-benign row (index 0 == 'benign') if there is one, otherwise row 0.
    # Per-row selection avoids `.item()` on a multi-element tensor.
    row = next((i for i, p in enumerate(preds) if p != 0), 0)
    pred = int(preds[row])
    label = output_filter_model.config.id2label[pred]
    return pred == 0, label, probs[row]

# Generates a chat reply with the LLM, lazily loading the 4-bit model and its tokenizer on first use.
def generate_llm_response(prompt):
    """Generate an assistant reply for `prompt` with the 4-bit Vicuna model.

    On the first call, loads `vicuna_model_name` with a bitsandbytes NF4 config and
    ``device_map="auto"``, plus the tokenizer from `vicuna_tokenizer_name`. Both are cached in
    the module-level globals `vicuna_model` / `vicuna_tokenizer`.

    Args:
        prompt: the (English) user message.

    Returns:
        The generated reply text, stripped (may be empty).
    """
    global vicuna_model, vicuna_tokenizer
    if vicuna_model is None or vicuna_tokenizer is None:
        print("\nLoading 4-bit Vicuna model with bitsandbytes... (This may take a moment)")
        gr.Info("Loading 4-bit LLM with bitsandbytes... UI might be slow.")  # Gradio info
        # 1. 4-bit NF4 quantization with double quantization; compute in float16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True
        )
        # 2. Load the -HF checkpoint with that config. Llama-architecture checkpoints are
        #    supported natively by transformers, so remote code execution is not enabled.
        vicuna_model = AutoModelForCausalLM.from_pretrained(
            vicuna_model_name,
            quantization_config=bnb_config,
            device_map="auto",
        )
        # 3. Load the base Vicuna tokenizer.
        vicuna_tokenizer = AutoTokenizer.from_pretrained(vicuna_tokenizer_name)
        vicuna_model.eval()
        print("✅ Vicuna model loaded successfully.\n")
        gr.Info("LLM loaded successfully.")

    chat_prompt = f"### Human: {prompt.strip()}\n### Assistant:"
    # With device_map="auto", placement is decided at load time; send inputs to wherever the
    # model actually lives rather than the notebook-wide `device`.
    inputs = vicuna_tokenizer(chat_prompt, return_tensors="pt").to(vicuna_model.device)
    # The tokenizer may define no pad token; fall back to EOS so generate() gets a valid id.
    pad_id = vicuna_tokenizer.pad_token_id
    if pad_id is None:
        pad_id = vicuna_tokenizer.eos_token_id
    with torch.no_grad():
        output_ids = vicuna_model.generate(
            **inputs, max_new_tokens=300, temperature=0.8, top_p=0.95, do_sample=True,
            repetition_penalty=1.15, eos_token_id=vicuna_tokenizer.eos_token_id,
            pad_token_id=pad_id
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    reply = vicuna_tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    # The model may continue the transcript with an invented next turn; keep only its first reply.
    return reply.split("### Human:")[0].strip()

# Adapters that wrap the input and output classifiers in the {"label", "probs", "pred_idx"} dict format the agent consumes.
def input_filter_func(texts):
    """Summarize the single-turn classifier's verdict on the first text.

    Returns a dict with "label" ("jailbreak" when class index 1 wins, otherwise "benign"),
    "probs" (that text's probability row) and "pred_idx" (the winning class index).
    """
    first_row = predict_proba(texts)[0]
    winner = np.argmax(first_row)
    verdict = "jailbreak" if winner == 1 else "benign"
    return {"label": verdict, "probs": first_row, "pred_idx": winner}

def output_filter_func(texts):
    """Run the output filter and report the verdict in the agent's filter-dict format.

    Args:
        texts: a string (checked whole) or a list of strings (only the first entry is checked).

    Returns:
        dict with "label" ("benign" or "jailbreak"), "probs" (from is_output_safe), and
        "pred_idx" (argmax of probs, or 0 when the filter returned no probabilities, i.e.
        for empty input).
    """
    # A bare string must not be indexed: texts[0] would be its first character.
    if isinstance(texts, str):
        text = texts
    elif texts is None or len(texts) == 0:
        text = ""
    else:
        text = texts[0]
    safe, _, probs = is_output_safe(text)
    pred_idx = int(np.argmax(probs)) if len(probs) else 0
    return {"label": "benign" if safe else "jailbreak", "probs": probs, "pred_idx": pred_idx}

# =========================
# SHAP explainers
# =========================
# Partition explainers over text. The single-turn explainer attributes predict_proba to
# cls_tokenizer tokens; the multi-turn explainer attributes multiturn_predict_proba to
# multiturn_tokenizer tokens. Masked tokens are replaced with "...".
print("Initializing SHAP explainers...")
single_turn_masker = shap.maskers.Text(tokenizer=cls_tokenizer, mask_token="...")
multiturn_masker = shap.maskers.Text(tokenizer=multiturn_tokenizer, mask_token="...")

single_turn_explainer = shap.Explainer(predict_proba, masker=single_turn_masker, algorithm="partition")
multiturn_explainer = shap.Explainer(multiturn_predict_proba, masker=multiturn_masker, algorithm="partition")
print("SHAP explainers ready.")


class ConversationSafeAgent:
    """Chat agent that wraps an LLM with input, multi-turn, and output safety checks.

    Each user turn (see classify_and_generate) goes through:
      1. the single-turn input filter on the raw prompt,
      2. language detection and translation of the prompt to English,
      3. the multi-turn classifier on recent memory plus the translated prompt,
      4. LLM generation on the translated prompt,
      5. the output filter on the English response,
      6. translation of the response back to the user's language.
    Blocked turns return SHAP token attributions explaining the decision.
    """

    # langdetect codes that deep_translator's GoogleTranslator spells differently.
    _LANGDETECT_TO_GOOGLE = {"zh-cn": "zh-CN", "zh-tw": "zh-TW", "he": "iw"}

    def __init__(self, input_filter, output_filter, single_turn_explainer, multiturn_explainer, model_func, max_context_turns=2):
        """
        Args:
            input_filter: callable(list[str]) -> dict with "label" ("jailbreak" blocks) and
                optionally "pred_idx".
            output_filter: callable(list[str]) -> dict with "label" ("benign" means safe).
            single_turn_explainer: SHAP explainer used for blocked inputs and outputs.
            multiturn_explainer: SHAP explainer used for blocked conversations.
            model_func: callable(str) -> str producing the English response.
            max_context_turns: past user/assistant pairs included in the multi-turn check.
        """
        self.memory = []  # Alternating translated user prompts and English responses
        self.max_context_turns = max_context_turns
        self.input_filter = input_filter
        self.output_filter = output_filter
        self.single_turn_explainer = single_turn_explainer
        self.multiturn_explainer = multiturn_explainer
        self.model_func = model_func

    def clear_memory(self):
        """Clears the agent's internal conversation memory."""
        self.memory = []

    @classmethod
    def _to_translator_lang(cls, lang_code):
        """Map a langdetect language code to the code GoogleTranslator expects."""
        return cls._LANGDETECT_TO_GOOGLE.get(lang_code, lang_code)

    # Returns list of lists for Gradio DataFrame
    def explain_shap_values(self, explainer, text, pred_index=None, top_k=7, original_text=None):
        """Compute SHAP attributions for `text` and return the top_k most impactful tokens.

        Args:
            explainer: callable taking a list of strings and returning explanations whose
                first element has `.data` (tokens) and `.values` (attributions).
            text: the text that was actually classified.
            pred_index: output column to explain for multi-output values; when None or out
                of range, values are averaged across outputs.
            top_k: maximum number of rows returned.
            original_text: optional text whose word tokens are displayed instead of the
                explainer's tokens. Used only when it yields exactly as many tokens as there
                are SHAP values; otherwise scores would be attached to the wrong words.

        Returns:
            Rows of [token, "+0.12345"] sorted by absolute impact (descending); on failure a
            single ["Error", message] row.
        """
        try:
            shap_vals_list = explainer([text])
            if not shap_vals_list:
                return [["Error", "No SHAP output"]]

            shap_vals = shap_vals_list[0]
            tokens = shap_vals.data
            values = shap_vals.values

            # If multi-output, select the requested column (or average them)
            if values.ndim != 1 and pred_index is not None and values.shape[1] > pred_index:
                values = values[:, pred_index]
            elif values.ndim != 1:
                values = values.mean(1)

            display_tokens = tokens
            if original_text is not None:
                try:
                    orig_tokens = word_tokenizer(original_text)
                    # Flatten if nested list
                    if any(isinstance(t, list) for t in orig_tokens):
                        orig_tokens = [t for sub in orig_tokens for t in sub]
                except Exception as tok_err:
                    print(f"Original-text tokenization failed, using SHAP tokens: {tok_err}")
                    orig_tokens = None
                # Different tokenizers (or a translated vs. original text) generally produce
                # different token sequences; only substitute when counts line up one-to-one.
                if orig_tokens is not None and len(orig_tokens) == len(values):
                    display_tokens = orig_tokens

            valid_pairs = [(str(t), float(v)) for t, v in zip(display_tokens, values) if str(t).strip()]
            pairs = sorted(valid_pairs, key=lambda x: abs(x[1]), reverse=True)[:top_k]
            return [[token, f"{impact:+.5f}"] for token, impact in pairs]

        except Exception as e:
            print(f"SHAP explanation error: {e}")
            return [["Error", str(e)]]

    # Takes chat_history, returns (response, label, explanation_data)
    def classify_and_generate(self, user_prompt, chat_history):
        """Run one guarded chat turn.

        Args:
            user_prompt: raw user message in any language.
            chat_history: unused; context comes from self.memory. Kept for interface
                compatibility.

        Returns:
            (response, label, explanation):
              - (None, 'Invalid Input', None) for empty or whitespace-only prompts;
              - (None, 'Unsafe Input' | 'Unsafe Conversation' | 'Unsafe Output', rows) when
                a check blocks the turn (rows from explain_shap_values);
              - (text, 'Safe', None) otherwise; self.memory is extended only on this path.
        """
        if not user_prompt or not user_prompt.strip():
            return None, 'Invalid Input', None

        # 1. Single-turn check. SHAP runs on this same text, so its own tokens are displayed.
        input_pred = self.input_filter([user_prompt])
        if input_pred['label'] == 'jailbreak':
            exp_data = self.explain_shap_values(
                self.single_turn_explainer,
                user_prompt,
                input_pred.get('pred_idx', 1),
            )
            return None, 'Unsafe Input', exp_data

        # 2. Language detection + translation to English
        try:
            user_lang = detect(user_prompt)
        except Exception:
            user_lang = "en"
        try:
            translated_prompt = GoogleTranslator(source='auto', target='en').translate(user_prompt) if user_lang != "en" else user_prompt
        except Exception as e:
            print(f"Input translation failed: {e}")
            translated_prompt = user_prompt
            user_lang = "en"
        if not translated_prompt:
            translated_prompt = user_prompt

        # 3. Multi-turn check on the last max_context_turns pairs. Guard the slice because
        #    memory[-0:] would be the whole memory rather than none of it.
        window = max(self.max_context_turns, 0) * 2
        recent_memory = self.memory[-window:] if window else []
        full_conversation_text = "\n".join(recent_memory + [translated_prompt])

        if full_conversation_text.strip():
            multi_turn_probs_list = multiturn_predict_proba(full_conversation_text)
            if multi_turn_probs_list.shape[0] > 0:
                multi_turn_pred_idx = np.argmax(multi_turn_probs_list[0])
                if multi_turn_pred_idx == 1:  # 1 is 'jailbreak'
                    exp_data = self.explain_shap_values(
                        self.multiturn_explainer,
                        full_conversation_text,
                        1,
                        original_text="\n".join(recent_memory + [user_prompt])
                    )
                    return None, 'Unsafe Conversation', exp_data

        # 4. Generate response
        response_en = self.model_func(translated_prompt)
        if not response_en or not response_en.strip():
            print("Warning: LLM generated empty response.")
            final_response = "I couldn't generate a response for that prompt."
            self.memory.extend([translated_prompt, "Failed to generate response."])
            return final_response, 'Safe', None

        # 5. Output filter (the injected one)
        output_pred = self.output_filter([response_en])
        if output_pred['label'] != 'benign':
            # No explainer is built for the output-filter model here; the single-turn
            # explainer is applied to the response text itself, so its own tokens are shown.
            exp_data = self.explain_shap_values(
                self.single_turn_explainer,
                response_en,
                1,
            )
            return None, 'Unsafe Output', exp_data

        # 6. Translate back to the user's language
        try:
            final_response = GoogleTranslator(source='en', target=self._to_translator_lang(user_lang)).translate(response_en) if user_lang != "en" else response_en
        except Exception as e:
            print(f"Return translation failed: {e}")
            final_response = response_en
        if not final_response:
            final_response = response_en

        # --- Update agent memory on success ---
        self.memory.extend([translated_prompt, response_en])

        return final_response, 'Safe', None

# =========================
# Initialize agent
# =========================
# Wire the classifier adapters, SHAP explainers and LLM generator into one agent that
# considers the last 2 user/assistant pairs in its multi-turn check.
print("Initializing ConversationSafeAgent...")
agent = ConversationSafeAgent(
    model_func=generate_llm_response,
    input_filter=input_filter_func,
    output_filter=output_filter_func,
    single_turn_explainer=single_turn_explainer,
    multiturn_explainer=multiturn_explainer,
    max_context_turns=2,
)
print("Agent is ready.")

# =========================
# Gradio App
# =========================
def handle_chat(user_input, history):
    """Gradio callback: run one agent turn; return (display history, explanation table).

    The agent keeps its own conversation memory, so `history` is only the display list of
    (user, bot) tuples. When a turn is blocked as unsafe, the display history and the agent
    memory are both reset before the block notice is appended.
    """
    shown = history or []

    # classify_and_generate reads context from agent.memory, so no chat history is passed in.
    response, label, explanation_data = agent.classify_and_generate(user_input, [])
    table = pd.DataFrame(columns=["Token", "Impact"])

    if response:
        shown.append((user_input, response))
        return shown, table

    notice = f"⚠ Response Blocked: {label}"
    if label in ('Unsafe Input', 'Unsafe Conversation', 'Unsafe Output'):
        print(f"Block detected ({label}). Clearing Gradio history AND agent memory.")
        shown = []
        agent.clear_memory()
        notice += " (Chat history and agent memory cleared)"

    # The notice goes in after any reset so it remains visible.
    shown.append((user_input, notice))

    if explanation_data:
        table = pd.DataFrame(explanation_data, columns=["Token", "Impact"])
    return shown, table

# =========================
# Gradio UI Layout
# =========================
print("\nBuilding Gradio Interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Secure Multilingual AI Assistant")
    with gr.Row():
        # Left column: the conversation and the input controls.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation", height=500)
            msg = gr.Textbox(label="Enter prompt", placeholder="Type your message here...", lines=3)
            submit_btn = gr.Button("Send")
        # Right column: SHAP token attributions returned for blocked turns.
        with gr.Column(scale=1):
            gr.Markdown("## Analysis of Unsafe Content")
            gr.Markdown("The following tokens were identified as contributing to a potential jailbreak or policy violation:")
            explanation_df_output = gr.DataFrame(headers=["Token", "Impact"], datatype=["str", "str"], label="")

    # Enter in the textbox and the Send button are wired identically: one listener runs the
    # chat turn, a second one clears the textbox.
    for trigger in (msg.submit, submit_btn.click):
        trigger(handle_chat, [msg, chatbot], [chatbot, explanation_df_output])
        trigger(lambda: "", None, msg)

print("Launching Gradio Interface...")
# share=True creates a public link, which is how the app is reached from Colab.
demo.launch(debug=True, share=True)