# Generative AI Unlearning: Korean Legal Knowledge Editing

**Modification and Deletion of Knowledge in Korean Legal Domain | KT-Korea University Joint Research**


### Create a conda environment

```bash
conda create -n unlearning python==3.10.0 -y
conda activate unlearning
pip install -r requirements.txt
```


### Environment Configuration

Make sure you have a `.env` file in the project root directory containing your OpenAI API key:

```plaintext
OPENAI_API_KEY=your_api_key_here
```


## Setup


In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import json
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from typing import List, Dict
import time
import openai
from openai.error import APIError
import tqdm
from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [3]:
# Directory that is scanned for the input `.json` files to evaluate.
input_dir: str = "./법령지식"
# Directory receiving the per-file evaluation results and their `avg_`-prefixed averages.
output_dir: str = "./results"
# Directory receiving the metric-sorted rankings and the top/bottom slices.
final_dir: str = "./final_results"

In [4]:
# Hugging Face identifier of the model whose legal knowledge is evaluated.
model_id: str = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token (presumably the tokenizer ships without a
# dedicated pad token - TODO confirm); tokenizer calls below use padding=True.
tokenizer.pad_token = tokenizer.eos_token

# Initialize model with automatic device mapping, loading the weights in bfloat16
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Get the device of the first model parameter; tokenized inputs are moved there
# before generation.
model_device = next(model.parameters()).device

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.35s/it]


## 1. Dataset Extraction


In [5]:
def extract_data(input_file: str) -> dict:
    """Extract ``label -> fullText`` pairs from a single JSON file.

    The file is expected to hold a mapping whose values are objects with
    ``"label"`` and ``"fullText"`` fields. An entry is skipped when:

    * it contains a ``"comment"`` key,
    * its text is both shorter than 50 characters and shorter than three
      times the label length (too little content to evaluate), or
    * it is malformed (missing fields or an unexpected type).

    Args:
        input_file: Path to a ``.json`` file.

    Returns:
        A dict mapping each label to its full text. If a label occurs more
        than once, the last occurrence wins.

    Raises:
        ValueError: If ``input_file`` does not end with ``.json``.
    """
    if not input_file.endswith(".json"):
        raise ValueError(f"Expected a .json file, got: {input_file}")

    # The files contain Korean text, so do not rely on the platform's
    # default encoding.
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    result = {}
    count = 0
    for key in data:
        try:
            entry = data[key]
            # Skip entries with 'comment'
            if "comment" in entry:
                continue

            label = entry["label"]
            full_text = entry["fullText"]
            if len(full_text) < 3 * len(label) and len(full_text) < 50:
                continue
            result[label] = full_text
        except (KeyError, IndexError, TypeError):
            # Malformed entry: best-effort extraction, move on to the next one.
            continue
        count += 1

    print(f"{count} out of {len(data)} were successfully extracted from {input_file}")
    return result

## 2. Prompt Creation


The two main types of queries we can ask are:

1. Give the law name and ask for explanation
2. Give the explanation and ask for law name

For each type, we can try giving system prompts in different languages and different complexity:

- Simple Korean
- Detailed Korean
- Simple English
- Detailed English

Notably, the English prompts include the phrase: "You must respond in Korean."

This results in 8 system prompts in total (2 query types × 2 complexity levels × 2 languages).


In [6]:
# Create a dictionary of system prompts
# Layout: system_prompts[query_type][complexity][language] -> list holding one prompt.
#   type1: the law name is given and the provision text is requested.
#   type2: the provision text is given and the law name is requested.
system_prompts = dict(
    type1=dict(
        simple=dict(
            korean=["다음 법령의 조항을 말해주세요."],
            english=["Please state the provisions of the following law. You must respond in Korean."],
        ),
        detailed=dict(
            korean=["다음은 대한민국의 법령입니다. 법령의 조항을 말해주세요."],
            english=["The following is a law of the Republic of Korea. Please state the provisions of the law. You must respond in Korean."],
        ),
    ),
    type2=dict(
        simple=dict(
            korean=["다음 법령 조항을 읽고 법률의 이름을 알려주세요."],
            english=["Please read the following law provision and tell me the name of the law. You must respond in Korean."],
        ),
        detailed=dict(
            korean=["다음은 대한민국의 법령입니다. 법령 조항을 읽고 법률의 이름을 알려주세요."],
            english=["The following is a law of the Republic of Korea. Please read the law provision and tell me the name of the law. You must respond in Korean."],
        ),
    ),
)

We will also provide a single in-context example (one-shot).


In [7]:
# One-shot demonstration pair: a law name and the matching provision text.
shot = dict(
    name="119긴급신고법 제 18조의 제1항",
    provision="① 소방청장은 「전파법」 제9조제1항제1호에 따라 소방업무용으로 할당된 무선통신 주파수를 효율적으로 운영하여야 한다. ② 제1항에 따른 소방업무용 주파수의 운영에 필요한 사항은 행정안전부령으로 정한다.",
)

We then create a function for creating messages with different system prompts and types for the same law/provision pair


In [8]:
def create_messages(system_prompt: dict, shot: dict, label: str, full_text: str) -> dict:
    """Build one chat-message list per system-prompt variant.

    Args:
        system_prompt: Nested dict ``{type: {complexity: {language: [prompt, ...]}}}``;
            only the first prompt of each list is used.
        shot: Demonstration pair with ``"name"`` and ``"provision"`` keys.
        label: Law name of the entry being queried.
        full_text: Provision text of the entry being queried.

    Returns:
        A dict with the same nesting as ``system_prompt`` whose leaves are
        four-message conversations (system, demo user, demo assistant, user).
        Under the ``"type1"`` key the name is the question and the provision
        the answer; under any other key the roles are swapped.
    """

    def build_conversation(prompt_variants, name_to_provision):
        # Pick which side of the pair is the question and which the answer.
        if name_to_provision:
            demo_question, demo_answer, query = shot["name"], shot["provision"], label
        else:
            demo_question, demo_answer, query = shot["provision"], shot["name"], full_text
        return [
            {"role": "system", "content": prompt_variants[0]},
            {"role": "user", "content": demo_question},
            {"role": "assistant", "content": demo_answer},
            {"role": "user", "content": query},
        ]

    return {
        type_key: {
            complexity: {
                lang: build_conversation(variants, type_key == "type1")
                for lang, variants in by_language.items()
            }
            for complexity, by_language in by_complexity.items()
        }
        for type_key, by_complexity in system_prompt.items()
    }

Let's test `create_messages()` with a sample law/provision pair


In [9]:
sample_name = "자동차손해배상 보장법 제45조의2 제1항"
sample_provision = "제45조의2 (정보의 제공 및 관리)  ① 제45조제3항에 따라 업무를 위탁받은 보험요율산출기관은 같은 조 제1항에 따라 업무를 위탁받은 자의 요청이 있는 경우 제공할 정보의 내용 등 대통령령으로 정하는 범위에서 가입관리전산망에서 관리되는 정보를 제공할 수 있다."

In [10]:
sample_messages_dict = create_messages(system_prompts, shot, sample_name, sample_provision)

In [11]:
sample_messages_dict["type1"]["simple"]["english"]

[{'role': 'system',
  'content': 'Please state the provisions of the following law. You must respond in Korean.'},
 {'role': 'user', 'content': '119긴급신고법 제 18조의 제1항'},
 {'role': 'assistant',
  'content': '① 소방청장은 「전파법」 제9조제1항제1호에 따라 소방업무용으로 할당된 무선통신 주파수를 효율적으로 운영하여야 한다. ② 제1항에 따른 소방업무용 주파수의 운영에 필요한 사항은 행정안전부령으로 정한다.'},
 {'role': 'user', 'content': '자동차손해배상 보장법 제45조의2 제1항'}]

In [12]:
sample_messages_dict["type2"]["detailed"]["korean"]

[{'role': 'system', 'content': '다음은 대한민국의 법령입니다. 법령 조항을 읽고 법률의 이름을 알려주세요.'},
 {'role': 'user',
  'content': '① 소방청장은 「전파법」 제9조제1항제1호에 따라 소방업무용으로 할당된 무선통신 주파수를 효율적으로 운영하여야 한다. ② 제1항에 따른 소방업무용 주파수의 운영에 필요한 사항은 행정안전부령으로 정한다.'},
 {'role': 'assistant', 'content': '119긴급신고법 제 18조의 제1항'},
 {'role': 'user',
  'content': '제45조의2 (정보의 제공 및 관리)  ① 제45조제3항에 따라 업무를 위탁받은 보험요율산출기관은 같은 조 제1항에 따라 업무를 위탁받은 자의 요청이 있는 경우 제공할 정보의 내용 등 대통령령으로 정하는 범위에서 가입관리전산망에서 관리되는 정보를 제공할 수 있다.'}]

## 3. Inference


In [13]:
def generate_response(messages):
    """Generate a model completion for a list of chat messages.

    The messages are flattened into an ``Instructions:/Input:/Output:`` text
    prompt, the global ``model`` samples up to 256 new tokens, and anything
    the model writes after starting a new ``Input:`` turn is discarded.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts. Roles other
            than ``system``, ``user`` and ``assistant`` are ignored.

    Returns:
        A ``(prompt, generated_response)`` tuple: the exact prompt text sent to
        the model and the stripped completion.
    """

    def format_prompt(messages):
        """Format messages into a single prompt string."""
        templates = {
            "system": "Instructions: {}\n\n",
            "user": "Input: {}\n",
            "assistant": "Output: {}\n\n",
        }
        parts = []
        for message in messages:
            template = templates.get(message["role"])
            if template is not None:
                parts.append(template.format(message["content"]))
        parts.append("Output:")  # Indicates where the model should generate
        return "".join(parts)

    # Format the input messages into a single string
    prompt = format_prompt(messages)

    # Tokenize the input
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(model_device)

    # Generate output with refined parameters
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id
    )

    # Drop the prompt by token count rather than by len(prompt) characters:
    # decoding the prompt tokens is not guaranteed to reproduce the prompt
    # string exactly, so a character offset can cut in the wrong place.
    prompt_token_count = inputs["input_ids"].shape[1]
    generated_response = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()

    # Keep only the first answer if the model starts another "Input:" turn.
    generated_response = generated_response.split('\n\nInput:')[0]

    return prompt, generated_response


Let's try generating responses for the first item of `층간소음법령.json`


In [14]:
# Smoke test: run every prompt variant for the first extracted entry.
sample_data = extract_data(f"{input_dir}/층간소음법령.json")

for label, full_text in list(sample_data.items())[:1]:
    messages_dict = create_messages(system_prompts, shot, label, full_text)

    for type_key in messages_dict:
        for complexity_key in messages_dict[type_key]:
            for lang_key in messages_dict[type_key][complexity_key]:
                messages = messages_dict[type_key][complexity_key][lang_key]

                print(f"\nType: {type_key}, Complexity: {complexity_key}, Language: {lang_key}")
                # generate_response returns (prompt, response); show only the
                # response instead of printing the whole tuple.
                _, response = generate_response(messages)
                print(f"Generated Response: {response}")

## 4. Evaluation


We'll use the BLEU-1, BLEU-4, ROUGE-1, and ROUGE-L scores, and also use GPT-4o to evaluate the responses


In [15]:
def calculate_bleu(reference, hypothesis):
    """Return ``(BLEU-1, BLEU-4)`` for whitespace-tokenized texts.

    Both scores are sentence-level and use NLTK's ``method1`` smoothing.
    """
    smoothing = SmoothingFunction().method1
    references = [reference.split()]
    candidate = hypothesis.split()

    unigram_weights = (1, 0, 0, 0)
    four_gram_weights = (0.25, 0.25, 0.25, 0.25)
    return tuple(
        sentence_bleu(references, candidate, weights=weights, smoothing_function=smoothing)
        for weights in (unigram_weights, four_gram_weights)
    )

In [16]:
def calculate_rouge(reference, hypothesis):
    """Calculate ROUGE-1 (unigram overlap) and ROUGE-L (longest common subsequence).

    Both texts are tokenized on whitespace.

    Args:
        reference: Ground-truth text.
        hypothesis: Generated text to score.

    Returns:
        ``{"rouge_1": {...}, "rouge_l": {...}}`` where each inner dict holds
        ``precision``, ``recall`` and ``f1``. A score is 0 whenever its
        denominator would be zero (e.g. an empty hypothesis).
    """
    from collections import Counter

    def lcs_length(ref_tokens, hyp_tokens):
        """Length of the longest common subsequence, keeping only two DP rows."""
        previous = [0] * (len(hyp_tokens) + 1)
        for ref_token in ref_tokens:
            current = [0]
            for j, hyp_token in enumerate(hyp_tokens, start=1):
                if ref_token == hyp_token:
                    current.append(previous[j - 1] + 1)
                else:
                    current.append(max(previous[j], current[j - 1]))
            previous = current
        return previous[-1]

    def precision_recall_f1(overlap, hyp_len, ref_len):
        """Turn an overlap count into a precision/recall/F1 dict."""
        precision = overlap / hyp_len if hyp_len > 0 else 0.0
        recall = overlap / ref_len if ref_len > 0 else 0.0
        total = precision + recall
        f1 = 2 * precision * recall / total if total > 0 else 0.0
        return {"precision": precision, "recall": recall, "f1": f1}

    # Tokenize the reference and hypothesis
    ref_tokens = reference.split()
    hyp_tokens = hypothesis.split()

    # --- ROUGE-1 ---
    # Clipped (multiset) unigram overlap. A plain set intersection counts each
    # distinct token once while the denominators count repeats, so identical
    # texts containing repeated tokens would score below 1.0.
    unigram_overlap = sum((Counter(ref_tokens) & Counter(hyp_tokens)).values())

    # --- ROUGE-L ---
    lcs_len = lcs_length(ref_tokens, hyp_tokens)

    return {
        "rouge_1": precision_recall_f1(unigram_overlap, len(hyp_tokens), len(ref_tokens)),
        "rouge_l": precision_recall_f1(lcs_len, len(hyp_tokens), len(ref_tokens)),
    }

We'll use GPT-4o through the OpenAI API and instruct it to give a score of 0, 1, or 2, where 2 is the highest score


In [17]:
openai.api_key = os.getenv("OPENAI_API_KEY")

def calculate_gpt(response: str, reference: str, retries=3, delay=5) -> int:
    """Score a response against a reference with GPT-4o.

    Args:
        response: Generated text to evaluate.
        reference: Ground-truth text to compare against.
        retries: Maximum number of API attempts; only HTTP 500 errors are retried.
        delay: Seconds to wait between attempts.

    Returns:
        0, 1 or 2 (2 is the most similar), or the sentinel ``-1`` when no valid
        score could be obtained (API failure or unparsable reply).
    """
    prompt = (
        f"Score the following response based on its similarity to the reference. "
        f"The score should be either 0, 1, or 2 where 2 is the highest score.\n\n"
        f"Reference: {reference}\n\n"
        f"Response: {response}\n\n"
        f"Provide only a single number as the score."
    )

    for attempt in range(retries):
        try:
            # Call OpenAI API with refined prompt
            api_response = openai.ChatCompletion.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are an evaluator that scores responses based on their similarity to a reference."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=1,  # the expected answer is a single digit
                temperature=0,
            )

            # Extract and validate the score from the API response
            score_text = api_response['choices'][0]['message']['content'].strip()
            score = int(score_text)  # Try converting to an integer
            if score in [0, 1, 2]:   # Ensure it's within expected range
                return score
            raise ValueError(f"Invalid score received: {score_text}")

        except APIError as e:
            if e.http_status != 500:
                # Not a transient server error: give up right away instead of
                # falling through to the "max retries" message below.
                print(f"API Error: {e}")
                return -1
            print(f"Server error (500), retrying in {delay} seconds... (Attempt {attempt + 1}/{retries})")
            # No point in waiting after the final attempt.
            if attempt + 1 < retries:
                time.sleep(delay)
        except (ValueError, KeyError, IndexError) as e:
            # Malformed payload or a reply that is not 0/1/2.
            print(f"Error parsing GPT score: {e}")
            return -1  # Return a default or error value in case of failure

    print("Max retries reached. Returning default score -1.")
    return -1

Let's try getting the scores for the first item of the sample_data


In [18]:
# Smoke test: generate and score every prompt variant for the first entry.
sample_data = extract_data(f"{input_dir}/층간소음법령.json")

for label, full_text in list(sample_data.items())[:1]:
    messages_dict = create_messages(system_prompts, shot, label, full_text)

    for type_key in messages_dict:
        # type1 asks for the provision text, every other type asks for the law
        # name, so the expected answer depends on the query type.
        reference = full_text if type_key == "type1" else label

        for complexity_key in messages_dict[type_key]:
            for lang_key in messages_dict[type_key][complexity_key]:
                messages = messages_dict[type_key][complexity_key][lang_key]

                print(f"\nType: {type_key}, Complexity: {complexity_key}, Language: {lang_key}")
                prompt, response = generate_response(messages)
                print(f"Prompt: {prompt}")
                print(f"Generated Response: {response}")

                # Calculate BLEU score
                bleu_score = calculate_bleu(reference, response)
                print(f"BLEU-1 score: {bleu_score[0]:.4f}")
                print(f"BLEU-4 score: {bleu_score[1]:.4f}")

                # Calculate ROUGE scores
                rouge_scores = calculate_rouge(reference, response)
                print(f"ROUGE-1: {rouge_scores['rouge_1']['f1']:.4f}")
                print(f"ROUGE-L: {rouge_scores['rouge_l']['f1']:.4f}")

                # Score the response
                score = calculate_gpt(response, reference)
                print(f"GPT-4o Score: {score}")
                print("////////////////////////////////////////////////////////////////////////////////////////\n")

Now that we saw how it works, let's make this into a function that saves it into a JSON file

In [19]:
def eval_response(dataset: dict, output_file: str):
    """Generate and score a response for every prompt variant of every entry.

    Each ``label -> full_text`` pair is expanded with ``create_messages`` using
    the global ``system_prompts`` and ``shot``. Every variant is answered by the
    model and scored with BLEU, ROUGE and GPT-4o against the expected answer:
    the provision text for ``type1`` prompts (law name given) and the law name
    for all other prompts (provision given).

    Args:
        dataset: Mapping of law label to provision text.
        output_file: Path of the JSON file receiving the list of result
            records. It is rewritten after every response so that an
            interrupted run keeps everything produced so far.
    """
    responses = []

    # Derive the number of variants per entry from the prompt table instead of
    # hard-coding it, so the progress total stays right if prompts change.
    variants_per_entry = sum(
        len(by_language)
        for by_complexity in system_prompts.values()
        for by_language in by_complexity.values()
    )

    with tqdm.tqdm(
        total=len(dataset) * variants_per_entry,
        desc="Generating Responses",
        unit="entry",
    ) as bar:
        for entry_id, (label, full_text) in enumerate(dataset.items()):
            messages_dict = create_messages(system_prompts, shot, label, full_text)

            for type_key in messages_dict:
                # Expected answer depends on which direction the query goes.
                reference = full_text if type_key == "type1" else label

                for complexity_key in messages_dict[type_key]:
                    for lang_key in messages_dict[type_key][complexity_key]:
                        message_list = messages_dict[type_key][complexity_key][lang_key]

                        prompt, response = generate_response(message_list)

                        # Calculate BLEU and ROUGE scores
                        bleu_score = calculate_bleu(reference, response)
                        rouge_scores = calculate_rouge(reference, response)

                        # GPT-based scoring; returns -1 when scoring failed
                        gpt_score = calculate_gpt(response, reference)

                        responses.append({
                            "id": entry_id,
                            "label": label,
                            "full_text": full_text,
                            "type": type_key,
                            "complexity": complexity_key,
                            "language": lang_key,
                            "prompt": prompt,
                            "response": response,
                            "bleu_1": bleu_score[0],
                            "bleu_4": bleu_score[1],
                            "rouge_1": rouge_scores["rouge_1"]["f1"],
                            "rouge_l": rouge_scores["rouge_l"]["f1"],
                            "gpt_score": gpt_score
                        })

                        # Save after every response to prevent data loss.
                        # UTF-8 is explicit because ensure_ascii=False writes
                        # raw Korean text.
                        with open(output_file, "w", encoding="utf-8") as f:
                            json.dump(responses, f, indent=2, ensure_ascii=False)
                        bar.update(1)

This may take a while...


In [20]:
# Evaluate every JSON file in input_dir and write one result file per input.
# Create the results directory up front: opening a file for writing inside a
# missing directory would fail after the first (expensive) generation.
os.makedirs(output_dir, exist_ok=True)

for file in os.listdir(input_dir):
    if file.endswith(".json"):
        data = extract_data(os.path.join(input_dir, file))
        output_file = os.path.join(output_dir, file)
        eval_response(data, output_file)
        print(f"Responses saved to {output_file}")

We can then average the scores over all entries that share the same label (i.e., across the prompt variants of one law/provision pair)


In [4]:
def avg_scores(responses: list, output_file: str):
    """Average the metric scores of all responses sharing the same label.

    Args:
        responses: List of result records, each with ``label``, ``full_text``,
            ``bleu_1``, ``bleu_4``, ``rouge_1``, ``rouge_l`` and ``gpt_score``.
            Records missing a required key are reported and skipped entirely.
        output_file: Path of the JSON file receiving the averaged entries.

    Returns:
        A list with one dict per label (in first-seen order) holding
        ``label``, ``full_text`` and the ``avg_*`` value of every metric.
        Negative ``gpt_score`` values are failure sentinels and are excluded
        from ``avg_gpt_score``; it is ``-1.0`` when no valid GPT score exists.
    """
    plain_metrics = ("bleu_1", "bleu_4", "rouge_1", "rouge_l")
    grouped = {}

    for response in responses:
        # Read every required field before touching the accumulators, so that
        # an incomplete record cannot leave a half-updated group behind.
        try:
            label = response["label"]
            values = {name: response[name] for name in plain_metrics}
            gpt_score = response["gpt_score"]
            bucket = grouped.get(label)
            if bucket is None:
                bucket = {
                    "label": label,
                    "full_text": response["full_text"],
                    "count": 0,
                    "sums": dict.fromkeys(plain_metrics, 0.0),
                    "gpt_sum": 0.0,
                    "gpt_count": 0,
                }
        except KeyError as e:
            print(f"Missing key in response: {e}")
            continue

        grouped[label] = bucket
        bucket["count"] += 1
        for name, value in values.items():
            bucket["sums"][name] += value
        # A negative score marks a failed GPT evaluation, not a real score.
        if gpt_score >= 0:
            bucket["gpt_sum"] += gpt_score
            bucket["gpt_count"] += 1

    # Calculate averages and format output
    result = []
    for bucket in grouped.values():
        entry = {"label": bucket["label"], "full_text": bucket["full_text"]}
        for name in plain_metrics:
            entry[f"avg_{name}"] = bucket["sums"][name] / bucket["count"]
        if bucket["gpt_count"] > 0:
            entry["avg_gpt_score"] = bucket["gpt_sum"] / bucket["gpt_count"]
        else:
            entry["avg_gpt_score"] = -1.0
        result.append(entry)

    # Save as JSON array
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, ensure_ascii=False)

    return result

In [5]:
# Write an `avg_<name>.json` next to every raw result file in output_dir.
for file in os.listdir(output_dir):
    # The averages are written into the same directory, so skip files that
    # are themselves averages; otherwise a re-run would produce avg_avg_*.json.
    if file.endswith(".json") and not file.startswith("avg_"):
        input_path = os.path.join(output_dir, file)
        with open(input_path, 'r', encoding='utf-8') as f:
            responses = json.load(f)
        output_file = os.path.join(output_dir, f"avg_{file}")
        avg_scores(responses, output_file)

## 5. Ranking Results


Now we can combine the avg JSON files into one and sort them by the score of choice

In [12]:
def merge_and_sort_scores(input_file, output_file, metric):
    """
    Combine the entries of one or more JSON files, sort them by a metric in
    ascending order, and save the result.

    Each input file may contain a single dict or a list of dicts. A file that
    is not valid JSON is reported and skipped.

    Args:
        input_file (str | list[str]): Path to one input JSON file, or a
            list/tuple of such paths to merge.
        output_file (str): Path to save the sorted combined JSON
        metric (str): Metric to sort by (e.g. 'avg_bleu_4')

    Returns:
        list: The combined entries sorted by ``metric``, lowest value first.

    Raises:
        KeyError: If an entry does not contain ``metric``.
    """
    # Accept a single path (original behaviour) as well as several paths.
    if isinstance(input_file, (str, os.PathLike)):
        input_paths = [input_file]
    else:
        input_paths = list(input_file)

    # List to store combined data from all files
    combined_data = []

    for path in input_paths:
        with open(path, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                print(f"Error reading {path} - invalid JSON")
                continue
        # Handle both single dict and list of dicts
        if isinstance(data, dict):
            combined_data.append(data)
        elif isinstance(data, list):
            combined_data.extend(data)

    # Sort combined data based on metric (ascending)
    sorted_data = sorted(combined_data, key=lambda entry: entry[metric])

    # Save sorted data to output file
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(sorted_data, f, ensure_ascii=False, indent=4)

    return sorted_data

Let's sort the files using all metrics!


In [14]:
metrics = ["bleu_1", "bleu_4", "rouge_1", "rouge_l", "gpt_score"]

os.makedirs(final_dir, exist_ok=True)

# Gather the entries of every averaged file first. Sorting each file straight
# into final_dir/avg_<metric>.json would overwrite the ranking once per input
# file, leaving only the last file's entries.
combined_entries = []
for file in sorted(os.listdir(output_dir)):
    if file.startswith("avg_") and file.endswith(".json"):
        with open(os.path.join(output_dir, file), "r", encoding="utf-8") as f:
            combined_entries.extend(json.load(f))

combined_file = os.path.join(final_dir, "combined_avg.json")
with open(combined_file, "w", encoding="utf-8") as f:
    json.dump(combined_entries, f, ensure_ascii=False, indent=4)

# One ranking file per metric, each covering all input files.
for metric in metrics:
    output_file = os.path.join(final_dir, f"avg_{metric}.json")
    merge_and_sort_scores(combined_file, output_file, metric=f"avg_{metric}")

Let's see the first five entries for each metric (the files are sorted in ascending order, so these are the lowest-scoring entries)


In [16]:
# Print the first five entries of every ranking file.
# NOTE(review): the ranking files are sorted in ascending order, so the first
# entries are the lowest-scoring ones - confirm that this is the intended "top".
for metric in metrics:
    print(f"Top five entries by {metric}:\n")
    with open(os.path.join(final_dir, f"avg_{metric}.json"), "r", encoding="utf-8") as f:
        data = json.load(f)
    # Slice instead of indexing 0..4 so that short files do not raise IndexError.
    for entry in data[:5]:
        print(entry)
    print("\n")

Top five entries by bleu_1:

{'label': '도로교통법 시행규칙 제19조 제3항', 'full_text': '③ 경찰청장 또는 지방경찰청장이 법 제17조제2항에 따라 구역 또는 구간을 지정하여 자동차등의 속도를 제한하려는 경우에는 「도로의 구조ㆍ시설기준에 관한 규칙」 제8조에 따른 설계속도, 실제 주행속도, 교통사고 발생 위험성, 도로주변 여건 등을 고려하여야 한다.<신설 2010.7.9.>', 'avg_bleu_1': 0.0, 'avg_bleu_4': 0.0, 'avg_rouge_1': 0.0, 'avg_rouge_l': 0.0, 'avg_gpt_score': 0.0}
{'label': '자동차손해배상 보장법 시행령 제12조의3 제2항', 'full_text': '② 보험회사등이 법 제14조제4항 전단에 따라 교통사고 관련 조사기록의 열람을 청구하는 경우에는 열람예정일 7일 전까지 열람청구서에 열람사유서를 첨부하여 경찰관서에 제출하여야 한다. 다만, 긴급하거나 부득이한 사유가 있음을 소명하는 경우에는 그러하지 아니하다.', 'avg_bleu_1': 0.0, 'avg_bleu_4': 0.0, 'avg_rouge_1': 0.0, 'avg_rouge_l': 0.0, 'avg_gpt_score': 0.0}
{'label': '도로교통법 시행규칙 제37조의3 제2항', 'full_text': '② 제1항에 따라 정보의 제공을 요청받은 경찰서장은 법 제53조의4에 따라 어린이 교육시설을 감독하는 주무기관의 장이 요청한 정보를 해당 주무기관의 장에게 제공할 수 있다. [본조신설 2014.12.31.]', 'avg_bleu_1': 0.0, 'avg_bleu_4': 0.0, 'avg_rouge_1': 0.0, 'avg_rouge_l': 0.0, 'avg_gpt_score': 0.0}
{'label': '도로교통법 제74조 제2항', 'full_text': '② 지방경찰청장은 교통안전교육을 하기 위하여 다음 각 호의 어느 하나에 해당하는 기관이나 시

Now we can save the top 10% and bottom 10% for each metric into a new JSON file

In [17]:
# Save the first and last `percent` of every ranking file.
# NOTE(review): the ranking files are sorted in ascending order, so "top_*" holds
# the lowest-scoring entries and "bottom_*" the highest-scoring ones - confirm
# that this naming is intended.
percent = 0.1
percent_label = round(percent * 100)  # round() avoids float truncation such as 28.999 -> 28

for metric in metrics:
    with open(os.path.join(final_dir, f"avg_{metric}.json"), "r", encoding="utf-8") as avg_file:
        data = json.load(avg_file)

    slice_size = int(len(data) * percent)
    head = data[:slice_size]
    # data[-0:] would return the whole list, so handle an empty slice explicitly.
    tail = data[-slice_size:] if slice_size > 0 else []

    with open(os.path.join(final_dir, f"top_{percent_label}_{metric}.json"), "w", encoding="utf-8") as f:
        json.dump(head, f, indent=2, ensure_ascii=False)
    with open(os.path.join(final_dir, f"bottom_{percent_label}_{metric}.json"), "w", encoding="utf-8") as f:
        json.dump(tail, f, indent=2, ensure_ascii=False)