In [31]:
import json, openai
import os, sys
from openai import OpenAI
from tqdm.notebook import tqdm
import time
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)
from config.keys import OPENAI_API_KEY

In [32]:
class Generator:
    def __init__(self):
        self.client = openai
        self.json_formatter = "Return the response in RFC8259 compliant JSON according to the ResponseFormat schema with no other text."
        self.message = [{
            "role": "system",
            "content": "You are a cyber-security programmer that can detect line numbers from the contract based on the instruction."
        }]
        self.output_formatter = """
This should instruct the model to output exactly the vulnerability lines, ensuring it doesn't output extra lines or large ranges that contain unrelated code.

Response Schema:
 [
    {
      "start_line": <exact_start_line_number>,
      "end_line": <exact_end_line_number>,
      "code": [
        "vulnerable line 1",
        "vulnerable line 2",
        "... (and so on)"
      ]
    }
 ]
 
** Do not use ```json or any other extra texts in the output. Include only the list of detected lines as the schema.
"""
        self.user_prefix = """You are given a record from a dataset containing smart contract vulnerability analyses. Each record contains an "id", "prompt", and "completion". The "prompt" includes both the instructions and the smart contract code snippet, while the "completion" provides the vulnerability analysis.

Your task is to extract the exact lines of code that are vulnerable from the smart contract code. If the contract is vulnerable, output the minimal range of lines (with exact start and end line numbers) and the code lines themselves in the JSON Response Schema format provided below. If the contract is safe, do not output any JSON (i.e. return an empty result).

Instructions:
- Extract the smart contract code from the prompt.
- If the completion analysis indicates that the contract is vulnerable (i.e. it does not state that the contract is safe), locate the vulnerable lines of code precisely.
- Do not include any extra commentary or unrelated code; only include the vulnerable lines.
"""
    def get_user_message(self, dataset_output, contract):
        self.user_content = f"""
{self.user_prefix}

This is the helping document to find the lines of vulnerable codes.
Guideline:
{dataset_output}

Smart contract code:
{contract}

Additional Note:
- Only output the minimal range of code lines that are directly vulnerable.
- Do not include any commentary or unrelated code.
- Follow the JSON Response Schema exactly:
[
    {{
      "start_line": <exact_start_line_number>,
      "end_line": <exact_end_line_number>,
      "code": [
        "vulnerable line 1",
        "vulnerable line 2",
        "... (and so on)"
      ]
    }}
]
"""
        self.user_message = {"role": "user", "content": self.user_content}
    def create_prompt(self, dataset_output, contract):
        self.get_user_message(dataset_output, contract)
        self.message.append(self.user_message)
        
        def generate(self):
        
        completion = self.client.chat.completions.create(
          model="gpt-4o-mini",
          messages = self.message,
          temperature=0.1,
          max_tokens=3200,
          top_p=1.,
          frequency_penalty=0,
          presence_penalty=0,
          stop=None
        )
        answer = completion.choices[0].message.content
        return answer, completion

In [33]:
dataset_name = "train_TrustLLM"
output_dir = os.path.join("data", "processed_data", dataset_name)
os.makedirs(output_dir, exist_ok=True)

contracts_dir = os.path.join(output_dir, "contracts")
os.makedirs(contracts_dir, exist_ok=True)

locs_dir = os.path.join(output_dir, "LOCs")
os.makedirs(locs_dir, exist_ok=True)

raw_dir = os.path.join("data", "dataset", "raw")
dataset_path = os.path.join(raw_dir, f"{dataset_name}.json")

if not os.path.exists(dataset_path):
    raise FileNotFoundError(f"Dataset file not found at: {dataset_path}")

In [34]:
with open(dataset_path, "r", encoding="utf-8") as file:
    data = json.load(file)
print("Total records:", len(data))
#print(data[1]["completion"])

Total records: 20885


In [35]:
ids = [record.get("id") for record in data]
if len(ids) == len(set(ids)):
    print("All record ids are unique.")
else:
    print("There are duplicate record ids.")

There are duplicate record ids.


In [36]:
end_contract = 20
vul_idx = 0  # Counter for vulnerable smart contracts

for idx, record in enumerate(tqdm(data[:end_contract], desc="Processing Records")):
    prompt_text = record.get("prompt", "")
    completion_text = record.get("completion", "")
    
    # Extract smart contract code from the prompt.
    # This regex now matches any code block enclosed in triple backticks regardless of the language tag.
    code_match = re.search(r"```[^\n]*\n(.*?)\n```", prompt_text, re.DOTALL)
    if code_match:
        contract_code = code_match.group(1).strip()
    else:
        print(f"Record {idx}: No smart contract code found in prompt.")
        continue
    
    # Skip safe contracts
    if "appears to be safe" in completion_text.lower():
        print(f"Record {idx} is marked as safe. Skipping vulnerability extraction.")
        continue
    # Otherwise, if it does not mention "the issue", assume there is no clear vulnerability description.
    if "the issue" not in completion_text.lower():
        print(f"Record {idx} does not clearly indicate a vulnerability. Skipping vulnerability extraction.")
        continue

    # Save the vulnerable smart contract code to a .sol file using vul_idx for numbering
    sol_filepath = os.path.join(contracts_dir, f"{vul_idx}.sol")
    with open(sol_filepath, 'w', encoding='utf-8') as f:
        f.write(contract_code)
    
    # Add line numbers to the contract code for vulnerability mapping
    lines = contract_code.split("\n")
    numbered_lines = [f"{i+1}: {line}" for i, line in enumerate(lines)]
    numbered_contract = "\n".join(numbered_lines)
    
    print(f"Extracting vulnerability for Record {idx} (vulnerable contract index: {vul_idx})...")
    generator = Generator()
    # Use the 'completion' as the guideline for vulnerability extraction
    generator.create_prompt(completion_text, numbered_contract)
    start_time = time.time()
    answer, _ = generator.generate()
    print(f"Vulnerability extraction for Record {idx} completed in {round(time.time() - start_time, 2)} seconds")
    
    # Parse the answer as JSON and save the vulnerability details if extraction was successful.
    try:
        vulnerability_data = json.loads(answer)
    except Exception as e:
        print(f"Error parsing vulnerability JSON for Record {idx}: {e}")
        continue
    
    json_filepath = os.path.join(locs_dir, f"{vul_idx}.json")
    with open(json_filepath, "w", encoding="utf-8") as json_file:
        json.dump(vulnerability_data, json_file, indent=4)
    print(f"Saved vulnerability data for Record {idx} to {json_filepath}")
    
    vul_idx += 1

print("Process completed!")

Processing Records:   0%|          | 0/20 [00:00<?, ?it/s]

Record 0 is marked as safe. Skipping vulnerability extraction.
Record 1 is marked as safe. Skipping vulnerability extraction.
Record 2 is marked as safe. Skipping vulnerability extraction.
Record 3 is marked as safe. Skipping vulnerability extraction.
Record 4 is marked as safe. Skipping vulnerability extraction.
Extracting vulnerability for Record 5 (vulnerable contract index: 0)...


APIRemovedInV1: 

You tried to access openai.ChatCompletion, but this is no longer supported in openai>=1.0.0 - see the README at https://github.com/openai/openai-python for the API.

You can run `openai migrate` to automatically upgrade your codebase to use the 1.0.0 interface. 

Alternatively, you can pin your installation to the old version, e.g. `pip install openai==0.28`

A detailed migration guide is available here: https://github.com/openai/openai-python/discussions/742
