TODOs:
- Generate data as a set of intents
- Run through the pipeline to get a set of (initial) risk prioritizations
- Create an LLM-as-a-Judge to reprioritize each risk
- For each risk, run the policy distillation -> obtain a policy for when this risk should be reprioritized

In [1]:
from risk_policy_distillation.llms.rits_component import RITSComponent
# from risk_policy_distillation.llms.ollama_component import OllamaComponent
from dotenv import load_dotenv
from langextract.resolver import ResolverParsingError

import json
import requests
import ast
import math
import logging
import datetime

import numpy as np
import os
import pandas as pd
import langextract as lx
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama.llms import OllamaLLM
# load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# # setting up logging
# logger = logging.getLogger('logger')
# logger.setLevel(logging.INFO)
# fh = logging.FileHandler(f'logs/{datetime.datetime.now().strftime("%m_%d__%H_%M")}.log')
# fh.setLevel(logging.INFO)
# logger.addHandler(fh)

In [3]:
class model_type_timeout(lx.inference.OllamaLanguageModel):
    """Ollama language model whose queries carry a default ``timeout`` of 120.

    The value is presumably in seconds; it is forwarded untouched to the
    parent ``_ollama_query``.
    """

    def _ollama_query(self, *args, **kwargs):
        # Respect a timeout chosen by the caller; only fill in the default.
        if 'timeout' not in kwargs:
            kwargs['timeout'] = 120
        return super()._ollama_query(*args, **kwargs)

  return inference.__getattr__(name)


In [4]:
# Creating a synthetic dataset of intents
# Use an LLM to generate a set of intents as scenarios where AI can be used
def generate_intent_data(llm_component, n_samples=10):
    """Ask an LLM for a list of synthetic AI use-case scenarios ("intents").

    Args:
        llm_component: Object exposing ``send_request(context, prompt)`` and
            returning the model's reply as text.
        n_samples: Number of scenarios requested in the prompt.

    Returns:
        list: The scenarios parsed from the model's reply.

    Raises:
        ValueError: If the reply parses to something other than a list/tuple,
            or is not a valid Python literal.
        SyntaxError: If the reply is not syntactically valid Python.
    """
    context = """
                You are a helpful AI assistant that can generate diverse AI usecase scenarios.
              """

    prompt = f"""
                Generate a list of {n_samples} scenarios in which AI could be used to automate or replaces specific tasks. 
                Include information about how this tool will be used or monitored.

                Here are some examples of scenarios:
                    - AI medical chatbot for internal use. Ouptuts monitored by medical professionals. Not used for diagnostics or patient care.
                    - AI-enhanced interior design tool that helps users design their homes. Used by people decorating their homes. Verified for hallucinatory behavior.
                    - Storytelling app for kids where you can choose characters and different plot features and AI generates a story. Guardrails for toxic and harmful ouptut are implemented.

                Return ONLY the list of scenarios formatted as a Python list
                Do not include any additional information or explanations.
            """

    reply = llm_component.send_request(context, prompt)

    # Despite the instructions, models often wrap the list in a markdown fence
    # or add a preamble. Keep only the outermost [...] span so such replies
    # still parse; a reply without brackets is parsed as-is.
    start, end = reply.find('['), reply.rfind(']')
    if start != -1 and end > start:
        reply = reply[start:end + 1]

    # literal_eval only accepts Python literals, so the reply is never executed.
    intents = ast.literal_eval(reply.strip())
    if not isinstance(intents, (list, tuple)):
        raise ValueError(
            f'Expected a list of scenarios, got {type(intents).__name__}'
        )
    intents = list(intents)

    print(len(intents))

    return intents

In [5]:
# Extracts structured information about the intent such as AI tasks, domain, stakeholders
def extract_intent_features(input_text):
    """Extract structured attributes (domain, environment, AI techniques, ...) from an intent.

    The intent text is first sent to a local ``llama3:8b`` model to answer the
    "which techniques are used" question; that answer is appended to the text,
    and the combined text is then passed to ``lx.extract`` together with the
    extraction rules and one worked example.

    Args:
        input_text: Free-form description of the AI use case.

    Returns:
        The object returned by ``lx.extract``, or ``None`` if all 10 attempts
        raised ``ResolverParsingError``.
    """
    # 1. Define the extraction rules. They live under their own name so the QA
    #    prompt template built further down cannot shadow them: lx.extract must
    #    receive these rules as its prompt_description.
    extraction_prompt = """
        Get the answers only to the following questions. Make sure to provide the mentioned attributes for all questions.
        Classify the confidence attribute as "Direct answer from the input text" or "Likely answer from the  input text" or "Unable to answer from input text". 
    
        1. "What domain does your use request fall under amongst the following options?
        Customer service/support, Technical, Information retrieval, Strategy, Code/software engineering, Communications, IT/business automation, Writing assistant, Financial,
        Talent and Organization including HR, Product, Marketing, Cybersecurity, Healthcare, User Research, Sales, Risk and Compliance, Design, Other", attribute: domain. 
        There must be a direct or likely answer for this question. 
        2. In which environment is the system used?, attribute: environment
        3. Is there private data for this use-case?, attribute: private data
        4. What techniques are utilised in the system? Multi-modal: {Document Question/Answering, Image and text to text, Video and text to text, visual question answering}, 
                                                       Natural language processing: {feature extraction, fill mask, question answering, sentence similarity, summarization, table question answering, text classification, text generation, token classification, translation, zero shot classification},
                                                       computer vision: {image classification, image segmentation, text to image, object detection}, audio:{audio classification, audio to audio, text to speech}, tabular: {tabular classification, tabular regression}, 
                                                       reinforcement learning                     
            attribute: ai_techniques. There must be a direct or likely answer for this question. 
        5. Who is the subject as per the intent?
        """

    # The same question text labels the example extraction and is asked of the
    # QA chain, so it is defined once.
    techniques_question = "What techniques are utilised in the system from the given options? Multi-modal: {Document Question/Answering, Image and text to text, Video and text to text, visual question answering}, Natural language processing: {feature extraction, fill mask, question answering, sentence similarity, summarization, table question answering, text classification, text generation, token classification, translation, zero shot classification}, computer vision: {image classification, image segmentation, text to image, object detection}, audio:{audio classification, audio to audio, text to speech}, tabular: {tabular classification, tabular regression}, reinforcement learning"

    # 2. Provide a high-quality example to guide the model
    examples = [
        lx.data.ExampleData(
            text="Generate personalized, relevant responses, recommendations, and summaries of claims for customers to support agents to enhance their interactions with customers. LLM has been robustly tested and shows high accuracy with no bias or fairness issues. There are guardrais for privacy and security. The main subjects are customers/claimants (whose data and preferences the LLM uses), support agents (who receive the personalized content), and the organization’s stakeholders (compliance, security, and tech teams) who oversee the system’s integrity",
            extractions=[
                lx.data.Extraction(
                    extraction_class=techniques_question,
                    extraction_text="1. Generate personalized, relevant responses 2. summaries of claims ",
                    attributes={"confidence": "Likely answer from the  input text", "ai_techniques": "Natural language processing: {Text classification, sentence similarity, question answering, summarization, text generation} "}
                ),
                lx.data.Extraction(
                    extraction_class="What domain does your use request fall under? Customer service/support, Technical, Information retrieval, Strategy, Code/software engineering, Communications, IT/business automation, Writing assistant, Financial, Talent and Organization including HR, Product, Marketing, Cybersecurity, Healthcare, User Research, Sales, Risk and Compliance, Design, Other",
                    extraction_text="summaries of claims for customers to support agents to enhance their interactions with customers",
                    attributes={"confidence": "Directly from the  input text", "domain": "Customer service/support"}
                ),
                lx.data.Extraction(
                    extraction_class="In which environment is the system used?",
                    extraction_text=" ",
                    attributes={"confidence": "Likely answer from the  input text", "environment": "digital environment, specifically designed for online interactions between customers and support agents"}
                ),
                lx.data.Extraction(
                    extraction_class="Is there private data for this use-case?",
                    extraction_text=" ",
                    attributes={"confidence": "Unable to answer from  input text", "private_data": "Unable to answer from the text" }
                ),
                lx.data.Extraction(
                    extraction_class="Who is the subject as per the intent?",
                    extraction_text="The main subjects are customers/claimants (whose data and preferences the LLM uses), support agents (who receive the personalized content), and the organization’s stakeholders (compliance, security, and tech teams) who oversee the system’s integrity",
                    attributes={"confidence": "Likely answer from the  input text", "subject_data": "customers/claimants, support agents, compliance, security, and tech teams" }
                )
            ]
        )
    ]

    # 3. Ask a QA chain which techniques are used and append its answer to the
    #    text so the extraction step sees it.
    template = """Use the following pieces of context to answer the question at the end. Quantify the statement as most likely if it is not in the context. If no answer is possible from the context, state that as it is not possible to give that from the context. Succinct answer and appended to the context \n\n{context}\n\nQuestion: {question}\nHelpful Answer:"""
    qa_prompt = ChatPromptTemplate.from_template(template)

    model = OllamaLLM(model="llama3:8b")
    chain = qa_prompt | model

    techniques = chain.invoke({"question": techniques_question, "context": input_text})
    input_text = input_text + techniques

    # 4. Run the extraction, retrying only when the model output cannot be
    #    parsed; any other exception propagates.
    for _ in range(10):
        try:
            return lx.extract(
                text_or_documents=input_text,
                prompt_description=extraction_prompt,
                examples=examples,
                language_model_type=model_type_timeout,  # lx.inference.OllamaLanguageModel,
                model_id="llama3:8b",  # or any Ollama model
                # NOTE(review): port 11435 differs from the endpoint OllamaLLM
                # above falls back to by default — confirm this is intended.
                model_url="http://localhost:11435",
                fence_output=False,
                use_schema_constraints=False)
        except ResolverParsingError:
            continue

    return None

In [6]:
def get_ai_task_names():
    """Return the ``name`` of every entry under ``tasks`` in the risk-matrix JSON, in file order."""
    with open("risk_matrix/ai_tasks_risk_matrix.json") as matrix_file:
        risk_matrix = json.load(matrix_file)

    return [entry["name"] for entry in risk_matrix["tasks"]]

In [7]:
def map_extraction_to_ai_risks(intent_ai_tasks, ai_tasks):
    """Map free-form AI-technique text onto indices of known AI task names.

    A local ``llama3:8b`` model is asked which entries of ``ai_tasks`` are
    present in ``intent_ai_tasks``; its reply is parsed as a Python list and
    every item that exactly matches a known task name (after stripping
    surrounding whitespace) is translated to that name's index.

    Args:
        intent_ai_tasks: Free-form description of the techniques in the intent.
        ai_tasks: List of canonical task names; returned indices point into it.

    Returns:
        list[int]: Indices into ``ai_tasks``, in the order the model listed them.

    Raises:
        ValueError: If the model's reply is not a valid Python literal.
        SyntaxError: If the model's reply is not syntactically valid Python.
    """
    # Extracts AI tasks from the intent text
    template = """You are a ai task analysis assistant. Given the following freeform text, extract which ai tasks are PRESENT from this list: {ai_tasks} - If an ai task is explicitly mentioned as NOT present (e.g., 'no regression'), exclude it. - If close enough words are used (e.g., 'doc qnA' → 'Document Question Answering'), map them to the correct ai task. - Only return the ai tasks that are clearly PRESENT. Answer only with the ai tasks in list format. DO NOT INCLUDE any preamble or explanations \n\n{intent_ai_tasks}\n:"""

    prompt = ChatPromptTemplate.from_template(template)
    model = OllamaLLM(model="llama3:8b")
    chain = prompt | model

    reply = chain.invoke({"ai_tasks": ai_tasks, "intent_ai_tasks": intent_ai_tasks})

    # Tolerate a preamble or markdown fence around the list by keeping only
    # the outermost [...] span; a reply without brackets is parsed as-is.
    start, end = reply.find("["), reply.rfind("]")
    if start != -1 and end > start:
        reply = reply[start:end + 1]

    reported_tasks = ast.literal_eval(reply.strip())
    if isinstance(reported_tasks, str):
        # A lone string would otherwise be iterated character by character.
        reported_tasks = [reported_tasks]

    # First index of each task name, matching what list.index would return.
    first_index = {}
    for position, name in enumerate(ai_tasks):
        first_index.setdefault(name, position)

    return [
        first_index[task.strip()]
        for task in reported_tasks
        if isinstance(task, str) and task.strip() in first_index
    ]

In [8]:
def consolidate_risks(dict_list):
    """Merge per-task risk dictionaries, keeping the highest priority seen per risk.

    Args:
        dict_list: Sequence of ``{risk: priority}`` mappings, where each
            priority is one of "Low", "Medium", "High" or "Critical".

    Returns:
        dict: One entry per risk found in any input mapping, holding the
        highest priority seen for it. Empty when ``dict_list`` is empty.

    Raises:
        KeyError: If a priority value is not one of the four known levels.
    """
    priority = {"Low": 1, "Medium": 2, "High": 3, "Critical": 4}

    result = {}
    for d in dict_list:
        for key, value in d.items():
            # Every risk starts at "Low", including risks that only appear in
            # later dictionaries.
            current = result.setdefault(key, "Low")
            if priority[value] > priority[current]:
                result[key] = value
    return result

In [9]:
def get_baseline_risks(ai_task_indices):
    """Return the ``baseline_risks`` entry of each selected task from the risk-matrix JSON.

    ``ai_task_indices`` are positions into the JSON's ``tasks`` list; the
    result follows the order of the indices given.
    """
    with open("risk_matrix/ai_tasks_risk_matrix.json") as matrix_file:
        risk_matrix = json.load(matrix_file)

    selected = []
    for position in ai_task_indices:
        selected.append(risk_matrix["tasks"][position]["baseline_risks"])
    return selected

In [10]:
def generate_dataset(n_samples=100):
    """Load, or build and cache, the intent / risk / baseline-priority dataset.

    If ``datasets/intents.csv`` exists it is returned as-is. Otherwise intents
    are generated with an LLM, their AI techniques are extracted and mapped to
    known AI tasks, the baseline risks of those tasks are consolidated, and
    one row per (intent, risk) is written to ``datasets/intents.csv``.

    Args:
        n_samples: Number of intents to request when the dataset is built.

    Returns:
        pandas.DataFrame: Columns ``intent``, ``risk`` and ``priority``.
    """
    dataset_path = 'datasets/intents.csv'

    try:
        return pd.read_csv(dataset_path, header=0)
    except FileNotFoundError:
        pass  # no cached dataset yet — build it below

    data = []

    llm_component = RITSComponent('llama-3-3-70b-instruct', 'meta-llama/llama-3-3-70b-instruct')
    intent_data = generate_intent_data(llm_component, n_samples=n_samples)
    ai_tasks_names = get_ai_task_names()

    for intent_text in intent_data:
        # extract features of the intent like stakeholders, AI tasks, domains etc
        features = extract_intent_features(intent_text)

        # extract_intent_features returns None when every attempt failed, and
        # a successful run may still contain no extractions.
        if features is None or not features.extractions:
            continue

        # The model does not guarantee the order of its extractions, so use
        # the first one that actually carries the ai_techniques attribute.
        ai_techniques = next(
            (
                extraction.attributes["ai_techniques"]
                for extraction in features.extractions
                if extraction.attributes and "ai_techniques" in extraction.attributes
            ),
            None,
        )
        if ai_techniques is None:
            continue

        # extract AI tasks present in the scenario
        ai_task_indices = map_extraction_to_ai_risks(ai_techniques, ai_tasks_names)
        baseline_risks = get_baseline_risks(ai_task_indices)

        # No known task matched: there is nothing to consolidate for this intent.
        if not baseline_risks:
            continue

        final_risk_matrix = consolidate_risks(baseline_risks)

        for risk, priority in final_risk_matrix.items():
            data.append((intent_text, risk, priority))

    df = pd.DataFrame(data, columns=['intent', 'risk', 'priority'])

    # The cache directory may not exist on a first run.
    os.makedirs('datasets', exist_ok=True)
    df.to_csv(dataset_path, index=False)

    return df

In [11]:
# def generate_dataset(n_samples=100):
#     'Generates an intent-risk-baseline priority dataset we can use to reprioritize risks'
#     try:
#         df = pd.read_csv('datasets/intents.csv', header=0)
#     except FileNotFoundError:
#         data = []
        
#         llm_component = OllamaComponent('llama-3-3-70b-instruct', 'meta-llama/llama-3-3-70b-instruct')    
#         intent_data = generate_intent_data(llm_component, n_samples=n_samples)
#         ai_tasks_names = get_ai_task_names()
    
#         for intent_text in intent_data:
#             # extract features of the intent like stakeholders, AI tasks, domains etc
#             extraction = extract_intent_features(intent_text).extractions[0]

#             if extraction.attributes is None or "ai_techniques" not in extraction.attributes.keys():
#                 continue
    
#             # extract AI tasks present in the scenario
#             ai_task_indices = map_extraction_to_ai_risks(extraction.attributes["ai_techniques"], ai_tasks_names)
#             baseline_risks = get_baseline_risks(ai_task_indices)
    
#             final_risk_matrix = consolidate_risks(baseline_risks)
    
#             for risk, priority in final_risk_matrix.items():
#                 data.append((intent_text, risk, priority))

#         df = pd.DataFrame(data, columns=['intent', 'risk', 'priority'])
#         df.to_csv('datasets/intents.csv', index=False)

#     return df

In [None]:
from risk_policy_distillation.models.guardians.judge import Judge
from risk_policy_distillation.utils.rits_util import post_rits_req

class RiskReprioritizationJudge(Judge):
    """Simple LLM-as-a-Judge model that decides whether a risk should be reprioritized.

    The judge sends a few-shot system prompt plus one message per query to a
    chat-completions endpoint and reads either the answer text
    (``ask_guardian``) or the first-token probabilities (``predict_proba``).
    """

    # Upper bound on attempts per request, and per-attempt timeout in seconds
    # so a stalled connection cannot hang the run.
    _MAX_ATTEMPTS = 100
    _REQUEST_TIMEOUT = 120

    def __init__(self, model_name, model_served_name, config):
        """Store the model identifiers, the system prompt and the endpoint URL.

        Args:
            model_name: Model name used as a path segment of the endpoint URL.
            model_served_name: Value sent as ``model`` in the request body.
            config: Judge configuration forwarded to ``Judge.__init__``.
        """
        super().__init__(config)
        self.model_name = model_name
        self.model_served_name = model_served_name
        self.context = """
                        Given a task, AI risk, and a baseline estimated impact of this risk on the task decide if the impact should be changed.
                        Answer with one of the possible risk impacts: Low, Medium, High or Critical. 

                        Here are some examples:
                        
                        AI Task: Generate personalized, relevant responses, recommendations, and summaries of claims for customers to support agents to enhance their interactions with customers. 
                                 LLM has been thoroughly tested and shows high accuracy with no bias or fairness issues. 
                                 There are guardrais for privacy and security. 
                                 The main subjects are customers/claimants (whose data and preferences the LLM uses), support agents (who receive the personalized content), 
                                 and the organization’s stakeholders (compliance, security, and tech teams) who oversee the system’s integrity. 
                        Risk: Fairness
                        Risk impact: Medium
                        New impact: Low
                        Supporting text: high accuracy with no bias or fairness issues

                        Answer ONLY with one word indicating new estimated impact for this risk (i.e. "Low"/"Medium"/"High"/"Critical").
                        Do not include any other explanation or information.
                        """

        self.rits_url = f'https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{self.model_name}/v1/chat/completions'

    def _chat_completion(self, messages):
        """POST ``messages`` to the endpoint, retrying until it answers with HTTP 200.

        Args:
            messages: Chat messages in the format expected by the endpoint.

        Returns:
            dict: The decoded JSON body of the successful response.

        Raises:
            RuntimeError: If no attempt succeeded within ``_MAX_ATTEMPTS``.
        """
        headers = {
            "RITS_API_KEY": os.getenv('RITS_API_KEY'),
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_served_name,
            "messages": messages,
            "temperature": 0.7,
            "logprobs": True,
            "top_logprobs": 10,
        }

        last_failure = None
        for _ in range(self._MAX_ATTEMPTS):
            try:
                response = requests.post(
                    self.rits_url,
                    headers=headers,
                    json=payload,
                    timeout=self._REQUEST_TIMEOUT,
                )
            except requests.RequestException as err:
                # Network errors and timeouts count as a failed attempt.
                last_failure = err
                continue

            if response.status_code == 200:
                return response.json()
            last_failure = f'HTTP {response.status_code}'

        # Fail with the real cause instead of indexing into an error payload.
        raise RuntimeError(
            f'Request to {self.rits_url} failed after {self._MAX_ATTEMPTS} attempts: {last_failure}'
        )

    def ask_guardian(self, message):
        """Return the judge's textual answer for one (intent, risk, priority) triple.

        Args:
            message: Tuple ``(prompt, risk, priority)``.

        Returns:
            The ``content`` of the first choice in the endpoint's response.
        """
        prompt, risk, priority = message
        message = f''' Message:
                        AI task: {prompt}
                        Risk: {risk}
                        Risk impact: {priority}'''

        body = self._chat_completion([
            {"role": "system", "content": self.context, "name": "test"},
            {"role": "user", "content": message, "name": "test"},
        ])

        return body['choices'][0]['message']['content']

    def predict_proba(self, inputs):
        """Return, per input, the probability of each label as the first generated token.

        Args:
            inputs: Iterable of user-message strings.

        Returns:
            numpy.ndarray: One row per input and one column per entry of
            ``self.label_names``; a label absent from the top log-probs gets 0.
        """
        # self.label_names is not assigned in this class — presumably set by
        # Judge.__init__ from the config; verify against the base class.
        probabilities = []

        for input_text in inputs:
            body = self._chat_completion([
                {"role": "system", "content": self.context},
                {"role": "user", "content": input_text},
            ])

            # Top candidates for the FIRST generated token only.
            # NOTE(review): a label must equal that token exactly; a label the
            # tokenizer splits into several tokens or emits with leading
            # whitespace gets probability 0 — confirm against the served model.
            top_logprobs = body['choices'][0]['logprobs']['content'][0]['top_logprobs']

            logprob_by_token = {}
            for candidate in top_logprobs:
                # Keep the first occurrence if a token is listed more than once.
                logprob_by_token.setdefault(candidate['token'], candidate['logprob'])

            probabilities.append([
                math.exp(logprob_by_token[label]) if label in logprob_by_token else 0
                for label in self.label_names
            ])

        return np.array(probabilities)

In [13]:
import os
# print(os.getenv('RITS_API_KEY'))

In [14]:
# Configuration passed to the judge; its criterion fields and label names are
# also read when the explainer and pipeline are built.
judge_config = dict(
    task="risk reprioritization",
    criterion="risk impact reevaluation",
    criterion_definition="Given a task, AI risk, and a baseline estimated impact of this risk on the task decide if the impact should be changed.",
    labels=[0, 1, 2, 3],
    label_names=["Low", "Medium", "High", "Critical"],
    output_labels=["Low", "Medium", "High", "Critical"],
)

# Judge backed by the llama-3-3-70b-instruct endpoint.
judge = RiskReprioritizationJudge(
    model_name='llama-3-3-70b-instruct',
    model_served_name='meta-llama/llama-3-3-70b-instruct',
    config=judge_config,
)

In [15]:
from risk_policy_distillation.datasets.prompt_dataset import PromptDataset

# Dataset configuration handed to IntentDataset.
# NOTE(review): index_col names an "Index" column, but generate_dataset writes
# the CSV with only intent/risk/priority (index=False) — confirm how the
# dataset class treats a missing index column.
data_config = dict(
    general=dict(
        location="",
        dataset_name="intent_dataset",
    ),
    data=dict(
        type="prompt",
        index_col="Index",
        prompt_col="intent",
        label_col="priority",
        flip_labels=False,
    ),
    split=dict(
        split=False,
        subset="test",
        sample_ratio=1,
    ),
)

In [16]:
from risk_policy_distillation.datasets.abs_dataset import AbstractDataset

class IntentDataset(AbstractDataset):
    """Dataset whose messages are (intent text, risk, baseline priority) triples.

    NOTE(review): the recorded run of ``pipeline.run`` on this dataset failed
    with ``AttributeError: 'IntentDataset' object has no attribute
    'response_col'`` — confirm what the pipeline expects from the dataset.
    """

    def __init__(self, config, dataframe=None):
        super().__init__(config, dataframe)

        # Columns that make up one message, and the column explanations run on.
        self.message_labels = [self.prompt_col, 'risk', 'priority']
        self.expl_input = self.prompt_col

    def extract_message(self, row):
        """Split a row into (message triple, text to explain, true label)."""
        intent_text = row[self.prompt_col]
        message = (intent_text, row['risk'], row['priority'])

        # The intent text doubles as the local-explanation input; the true
        # label is read from the configured label column.
        return message, intent_text, row[self.label_col]

    def build_message_format(self, message, decision):
        """Render a message with ``decision`` in place of its baseline priority."""
        prompt, risk, _baseline = message
        return f''' Message:
                        AI Task: {prompt}
                        Risk: {risk}
                        Risk impact: {decision}'''

In [17]:
# Read the cached intent dataset and wrap its first ten rows for this run.
df = pd.read_csv('datasets/intents.csv', header=0)
dataset = IntentDataset(data_config, df.iloc[:10])

In [18]:
from risk_policy_distillation.models.explainers.local_explainers.lime import LIME
from risk_policy_distillation.pipeline.pipeline import Pipeline
from risk_policy_distillation.pipeline.clusterer import Clusterer
from risk_policy_distillation.pipeline.concept_extractor import Extractor

# LLM used by the extractor and clusterer, and the LIME explainer over the
# 'intent' column.
llm_component = RITSComponent(
    'llama-3-3-70b-instruct',
    'meta-llama/llama-3-3-70b-instruct',
)
local_explainer = LIME(
    'intent',
    judge_config['label_names'],
    n_samples=500,
    n_words=6,
)

# Assemble the distillation pipeline from its extractor and clusterer stages.
pipeline = Pipeline(
    extractor=Extractor(
        judge,
        llm_component,
        judge_config['criterion'],
        judge_config['criterion_definition'],
        local_explainer,
    ),
    clusterer=Clusterer(
        llm_component,
        judge_config['criterion_definition'],
        judge_config['label_names'],
        n_iter=20,
    ),
    lime=True,
    fr=True,
    verbose=1,
)

# Directory handed to pipeline.run as `path`.
path = 'results/intent_dataset'


INFO 10-01 11:00:41 [__init__.py:216] Automatically detected platform cpu.


In [19]:
# from risk_policy_distillation.models.explainers.local_explainers.lime import LIME
# from risk_policy_distillation.pipeline.pipeline import Pipeline
# from risk_policy_distillation.pipeline.clusterer import Clusterer
# from risk_policy_distillation.pipeline.concept_extractor import Extractor

# llm_component = OllamaComponent('llama3.2:3b-instruct-fp16', 'llama3.2:latest')
# local_explainer = LIME('intent', judge_config['label_names'], n_samples=500, n_words=6)

# pipeline = Pipeline(extractor=Extractor(judge, llm_component, judge_config['criterion'],
#                                         judge_config['criterion_definition'], local_explainer),
#                     clusterer=Clusterer(llm_component, judge_config['criterion_definition'],
#                                         judge_config['label_names'], n_iter=20),
#                     lime=True,
#                     fr=True,
#                     verbose=1)

# # evaluate the global explanation
# path = f'/Users/seshu/Documents/2025/policy-distillation-ran-extension/examples/notebooks/results/intent_dataset'


In [21]:
# Run the pipeline on the intent dataset, passing the results directory as `path`.
# NOTE(review): the recorded output below shows this call aborting with
# AttributeError: 'IntentDataset' object has no attribute 'response_col'.
expl = pipeline.run(dataset, path=path)

Path:  results/intent_dataset
results/intent_dataset/intent_dataset/local/expl.csv results/intent_dataset/intent_dataset/global/global_expl.pkl
results/intent_dataset/intent_dataset/local/expl.csv


0it [00:00, ?it/s]'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 56ed98e0-b6af-43ba-978f-bcc1c0bead85)')' thrown while requesting HEAD https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/./modules.json
Retrying in 1s [Retry 1/5].
0it [07:50, ?it/s]


AttributeError: 'IntentDataset' object has no attribute 'response_col'

In [None]:
# Print every rule as "<prediction> IF <rule> DESPITE <despite>".
for idx, rule in enumerate(expl.rules):
    prediction = expl.predictions[idx]
    despite = expl.despites[idx]
    print(f'{prediction} IF {rule} DESPITE {despite}')