# Step by step demonstration

### 1. Load packages & define model

In [21]:
import nltk
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

# define model & tokenizer
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name).cuda().half()
tokenizer = AutoTokenizer.from_pretrained(model_name)

### 2. Define article and question
- 問題是 ChatGPT 生的

In [22]:
# define article & question, target sentence: "[April 3] Researchers ..."
with open("long_knowledge.txt", "r") as file:
    article = file.read().replace('\\n', '').replace('. ', '.')
question = "What new attack methods used by the APT41 subgroup Earth Freybug were discovered, and how do they target endpoint protection systems?"

### 3. Split Sentence

In [23]:
# split article into sentences using nltk
nltk.download('punkt_tab')
sentences = nltk.tokenize.sent_tokenize(article)

for i, sentence in enumerate(sentences):
    print(f"------ sentence {i + 1}, length: {len(sentence)} ------")
    print(sentence[:200], "...")

------ sentence 1, length: 3543 ------
This week, the most significant news in the IT and cybersecurity world is the accidental discovery of a backdoor planted in XZ/liblzma.Red Hat also announced a related vulnerability, CVE-2024--3094, w ...
------ sentence 2, length: 783 ------
[April 1] A supply chain attack that had been lurking for three years was discovered, with the XZ Utils library recently implanted with a backdoor.The revelation of this supply chain attack over the w ...
------ sentence 3, length: 837 ------
[April 2] CISA reported to the US government about an Ivanti system intrusion incident.At the beginning of this year, Ivanti announced a series of Connect Secure and Policy Secure vulnerabilities, onc ...
------ sentence 4, length: 688 ------
[April 3] Researchers discovered that a group under the Chinese hacker organization APT41 is using more covert methods to evade detection.In recent years, attack activities by the Chinese hacker organ ...


[nltk_data] Downloading package punkt_tab to /home/nycu/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


### 4. Get prompt and apply template

In [24]:
# get prompt and apply template
prompt = f"""\
Below is an article, read the article and answer my question after the article.
Now the article begins:
{article}
Now the article ends.
Select several sentences from the article to answer my question.
Question: {question}
"""
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
sentence_ids_list = [tokenizer(sentence, return_tensors="pt").input_ids.cuda() for sentence in sentences]

print("input_ids shape:", input_ids.shape)
print("sentence_ids shape:", [sentence_ids.shape for sentence_ids in sentence_ids_list])

input_ids shape: torch.Size([1, 1186])
sentence_ids shape: [torch.Size([1, 667]), torch.Size([1, 154]), torch.Size([1, 144]), torch.Size([1, 128])]


### 5. Constrained Sentence Prefix Decoding
- 找到可以區分不同 sentence 為止，所以第一個找到 "This" 就停，其他的需要再找到 "[April 1/2/3"

In [25]:
def calculate_probs_list(input_ids, sentence_ids_list):
    prefix_map = {}
    for i, sentence_ids in enumerate(sentence_ids_list):
        prefix = sentence_ids[0][0].item()
        if prefix not in prefix_map:
            prefix_map[prefix] = []
        prefix_map[prefix].append(i)
    prefixes = list(prefix_map.keys())
    
    if len(prefix_map) == 1:
        probs_list = [[1] for _ in sentence_ids_list]    
    else:
        probs_list = [None for _ in sentence_ids_list]  
        with torch.no_grad():
            outputs = model(input_ids, use_cache=True)
        
        probs = torch.softmax(outputs.logits[0, -1, prefixes], dim=-1)
        for prefix, prob in zip(prefixes, probs):
            for idx in prefix_map[prefix]:
                probs_list[idx] = [prob.item()]
    
    for prefix, indices in prefix_map.items():
        if len(indices) == 1:
            continue
        
        next_input_ids = torch.cat([input_ids, torch.tensor(prefix).reshape(1, 1).cuda()], dim=-1)
        next_sentence_ids_list = [sentence_ids_list[idx][:, 1:] for idx in indices]
        next_probs_list = calculate_probs_list(next_input_ids, next_sentence_ids_list)
        
        for idx, probs in zip(indices, next_probs_list):
            probs_list[idx].extend(probs)
    
    return probs_list

In [26]:
def constrained_sentence_prefix_decoding(input_ids, sentence_ids_list, top_k=2):
    probs_list = calculate_probs_list(input_ids, sentence_ids_list)
    scores = [np.mean(np.log(probs)) for probs in probs_list]
    indices = np.argsort(scores)[::-1][:top_k]
    
    for i in range(len(sentence_ids_list)):
        print(f"------ sentence {i + 1} ------")
        print("prefix:", tokenizer.decode(sentence_ids_list[i][0, :len(probs_list[i])]))
        print("probs:", probs_list[i])
        print("score:", scores[i])
    
    return [sentence_ids_list[idx] for idx in indices]

sentence_ids_list = constrained_sentence_prefix_decoding(input_ids, sentence_ids_list, top_k=2)

------ sentence 1 ------
prefix: This
probs: [0.09808349609375]
score: -2.321936162101722
------ sentence 2 ------
prefix: [April 1
probs: [0.90185546875, 1, 1, 0.016510009765625]
score: -1.0517723588782142
------ sentence 3 ------
prefix: [April 2
probs: [0.90185546875, 1, 1, 0.00856781005859375]
score: -1.2157610301851964
------ sentence 4 ------
prefix: [April 3
probs: [0.90185546875, 1, 1, 0.97509765625]
score: -0.03212966467593541


### 6. Skip Decoding

In [32]:
def find_subarray_index(a, b):
    n, m = len(a), len(b)
    for i in range(n - m + 1):
        if a[i : i + m] == b:
            return i
    return -1

def skip_decoding(input_ids, sentence_ids_list, max_length=256):
    intervals = []
    for i, sentence_ids in enumerate(sentence_ids_list):
        sentence = tokenizer.decode(sentence_ids[0], skip_special_token=False)
        idx = find_subarray_index(article, sentence)
        sub_article_ids = tokenizer(article[idx:], return_tensors="pt", max_length=max_length, truncation=True).input_ids.cuda()
        
        sentence_input_ids = torch.cat([input_ids, sub_article_ids], dim=-1)
        with torch.no_grad():
            outputs = model(sentence_input_ids, use_cache=True)
        
        probs = torch.softmax(outputs.logits[0, -len(sub_article_ids[0]):, :], dim=-1)
        eos_idx = torch.argmax(probs[:, tokenizer.eos_token_id])
        evidence_sentence = tokenizer.decode(sub_article_ids[0, :eos_idx], skip_special_token=False)        
        intervals.append((idx, idx + len(evidence_sentence)))
        
        print(f"------ evidence sentence {i + 1} ------")
        print("content:", sentence[:100], "...")
        print("idx (in article):", idx)
        print("length of sub_article (# tokens):", len(sub_article_ids[0]))
        print("eos_idx (maximum eos prob):", eos_idx.item())
        
    return intervals

intervals = skip_decoding(input_ids, sentence_ids_list, max_length=256)
print('-' * 30)
print("intervals:", intervals)


------ evidence sentence 1 ------
content: [April 3] Researchers discovered that a group under the Chinese hacker organization APT41 is using m ...
idx (in article): 5163
length of sub_article (# tokens): 128
eos_idx (maximum eos prob): 127
------ evidence sentence 2 ------
content: [April 1] A supply chain attack that had been lurking for three years was discovered, with the XZ Ut ...
idx (in article): 3543
length of sub_article (# tokens): 256
eos_idx (maximum eos prob): 81
------------------------------
intervals: [(5163, 5850), (3543, 3975)]


### 7. Merge evidence sentences
- 這個 case 剛好沒有 intersect 所以結果不變
- 下面放一些會 merge 的 case

In [33]:
def merge_interval(intervals):
    if not intervals or len(intervals) == 1:
        return intervals
    
    intervals.sort(key=lambda x: x[0])
    merged = [intervals[0]]

    for current in intervals[1:]:
        prev_start, prev_end = merged[-1]
        curr_start, curr_end = current

        if curr_start <= prev_end:
            merged[-1] = (prev_start, max(prev_end, curr_end))
        else:
            merged.append(current)

    return merged

intervals = merge_interval(intervals)
print("intervals:", intervals)

intervals: [(3543, 3975), (5163, 5850)]


In [34]:
test_intervals = [(1, 3), (2, 5), (7, 11)]
print(merge_interval(test_intervals))

test_intervals = [(1, 4), (5, 6), (6, 7)]
print(merge_interval(test_intervals))

test_intervals = [(1, 2), (3, 4), (5, 6)]
print(merge_interval(test_intervals))

[(1, 5), (7, 11)]
[(1, 4), (5, 7)]
[(1, 2), (3, 4), (5, 6)]


### 8. Show result

In [35]:
# show result
for i, (start, end) in enumerate(intervals):
    print(f"------- high-quality knowledge {i + 1} -------")
    print(article[start : end])

------- high-quality knowledge 1 -------
[April 1] A supply chain attack that had been lurking for three years was discovered, with the XZ Utils library recently implanted with a backdoor.The revelation of this supply chain attack over the weekend shocked the entire cybersecurity community and kept the IT world busy with patching.Attackers targeted the XZ Utils data compression library to implant a backdoor, which would allow attackers to bypass the SSHD authentication
------- high-quality knowledge 2 -------
[April 3] Researchers discovered that a group under the Chinese hacker organization APT41 is using more covert methods to evade detection.In recent years, attack activities by the Chinese hacker organization APT41 have been reported from time to time, but few related incidents were disclosed at the beginning of this year.Recently, researchers discovered the attack methods of a group under this organization.Cybersecurity firm Trend Micro discovered that the group, Earth Freybug, sp

In [36]:
# 一些測試用問題，都是 ChatGPT 生的，但我記得有些抓出來的 high-quality knowlege 會有錯就是了，分別對應 target sentence 第一二三四個
question_1 = "What is the significance of the backdoor discovered in XZ/liblzma, and what implications does it have for open-source software security?"
question_2 = "What specific vulnerabilities did the backdoor in the XZ Utils library exploit, and how did researchers identify its presence?"
question_3 = "How did the Ivanti system vulnerabilities lead to the CISA intrusion, and what were the critical impacts on its systems and operations?"
question_4 = "What new attack methods used by the APT41 subgroup Earth Freybug were discovered, and how do they target endpoint protection systems?"