In [6]:
import json
import os, sys
from openai import OpenAI

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)
from config.keys import OPENAI_API_KEY

In [7]:
import os
import requests

class Generator:
    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)


        self.json_formatter = "Return the response in RFC8259 compliant JSON according to the ResponseFormat schema with no other text."
        self.message = [{
            "role": "system",
            "content": 
                "You are a cyber-security programmer that can detect line numbers from the contract based on the instruction."
        }]
        self.output_formatter = """
Response Schema:
[
  {"start_line": <los1>, 
   "end_line": <loe1>, 
   "code": [
        "vulnerable line 1",
        "vulnerable line 2",
        "... (and so on)"
      ]},
  {"start_line": <los2>, 
   "end_line": <loe2>, 
   "code": [
        "vulnerable line 1",
        "vulnerable line 2",
        "... (and so on)"
      ]},
...
]

** Do not use ```json or any other extra texts in the output. Include only the list of detected lines as the schema.
"""
        self.user_prefix = """You are given a smart contract code snippet and an explanation document on how to detect vulnerabilities. Your task is to identify and extract all of the exact lines of code where a vulnerability occurs—only the specific lines that are vulnerable, not any extra context or surrounding code. You should follow the instructions and report all vulnerable line as output. 
What I want is to have the exact lines of vulnerable code, the start and the end lines of vulnerable pieces of code for the given code.
Instructions:
1. Input Data:
    - Explanation: A detailed document containing guidelines for detecting vulnerabilities under the <This is the helping document to find the lines of vulnerable codes.> tag.
    - Smart Contract Code: The smart contract code is provided under the <Smart contract code> tag.

2. Task Requirements:
    - Use the explanation guidelines to precisely locate all of the vulnerabilities in the code.
    - Extract if and only if the exact lines of code that are vulnerable.
    - Do not provide a broad range of line numbers that include additional non-vulnerable lines. Be precise and be limited to return only vulnerable lines.
    - If multiple vulnerabilities exist within a single function (or code block), list each vulnerability as a separate entry with its own start and end line numbers. Do not merge them into one broad range.

3. Output Requirements:
    - Return your output as RFC8259 compliant JSON with no additional text.
    - The output should include:
        -- The exact start line number of the vulnerable code segment.
        -- The exact end line number of the vulnerable code segment.
        -- An array containing each exact line of vulnerable code.
"""

    def get_user_message(self, dataset_output, contract):
        self.user_content = f"""
{self.user_prefix}

Additional Note:

Be precise: if the vulnerability is only on a few lines (for example, lines 215 to 218), only output those lines. Avoid outputting large ranges that include non-vulnerable lines.
Do not include any commentary or extraneous information outside of the JSON output.

This is the helping document to find the lines of vulnerable codes.
{dataset_output}

Smart contract code:
{contract}

---
{self.output_formatter}

###
"""
        self.user_message = {"role": "user", "content": self.user_content}

    def create_prompt(self, dataset_output, contract):
        self.get_user_message(dataset_output, contract)
        self.message.append(self.user_message)

    def generate(self):
        
        completion = self.client.chat.completions.create(
          model="gpt-4o-mini",
          messages = self.message,
          temperature=0.5,
          max_tokens=3200,
          top_p=1.,
          frequency_penalty=0,
          presence_penalty=0,
          stop=None
        )
        answer = completion.choices[0].message.content
        return answer, completion


In [14]:
dataset_name = "smartbugs_reentrancy"
output_dir = f"../../data/processed_data/{dataset_name}/"
os.makedirs(output_dir, exist_ok=True)
locs_dir = os.path.join(output_dir, "loc")
os.makedirs(locs_dir, exist_ok=True)
contracts_dir = os.path.join(output_dir, "contracts")
os.makedirs(contracts_dir, exist_ok=True)
raw_dir = "../../data/dataset/raw"

In [17]:
start_contract = 0
end_contract =  1

with open(f"{raw_dir}/{dataset_name}.json", "r",  encoding="utf-8") as file:
    data = json.load(file)
len(data)

1635

In [9]:

for contract_index, record in enumerate(data[start_contract: end_contarct]):
    if contract_index < start_contract:
        continue
        
    contract = record['input']
    with open(f"{contract_index}.sol", 'w') as f:
        f.write(contract)
    # lines = contract.split("\n")
    # numbered_lines = [f"{i+1}: {line}" for i, line in enumerate(lines)]
    # record['input'] = "\n".join(numbered_lines)
    # print(f"Checking contract {contract_index}...")

    # if record["output"][0] == "1":
    #     generator = Generator()
    #     generator.create_prompt( record['output'], record['input'])

    #     print(f"Generating vulnerability lines for Contract {contract_index}...")
    #     answer, completion = generator.generate()
    #     print(answer)
    #     json_answer = json.loads(answer)
    #     json_filename = os.path.join(output_dir, f"{contract_index}.json")  # Use vulnerable contract count for naming
    #     with open(json_filename, "w", encoding="utf-8") as json_file:
    #         json.dump(json_answer, json_file, indent=4)

    #     print(f"Saved: {json_filename}")

    # if contract_index >= end_contract:
    #     break

Checking contract 1...
Generating vulnerability lines for Contract 1...
[
  {
    "start_line": 268,
    "end_line": 272,
    "code": [
      "if(!playerTempAddress[myid].send(playerTempBetValue[myid])){",
      "LogResult(serialNumberOfResult, playerBetId[myid], playerTempAddress[myid], playerNumber[myid], playerDieResult[myid], playerTempBetValue[myid], 4, proof);",
      "playerPendingWithdrawals[playerTempAddress[myid]] = safeAdd(playerPendingWithdrawals[playerTempAddress[myid]], playerTempBetValue[myid]);"
    ]
  },
  {
    "start_line": 304,
    "end_line": 308,
    "code": [
      "if(!playerTempAddress[myid].send(playerTempReward[myid])){",
      "LogResult(serialNumberOfResult, playerBetId[myid], playerTempAddress[myid], playerNumber[myid], playerDieResult[myid], playerTempReward[myid], 2, proof);",
      "playerPendingWithdrawals[playerTempAddress[myid]] = safeAdd(playerPendingWithdrawals[playerTempAddress[myid]], playerTempReward[myid]);"
    ]
  },
  {
    "start_line": 33

In [10]:
print(answer)

[
  {
    "start_line": 268,
    "end_line": 272,
    "code": [
      "if(!playerTempAddress[myid].send(playerTempBetValue[myid])){",
      "LogResult(serialNumberOfResult, playerBetId[myid], playerTempAddress[myid], playerNumber[myid], playerDieResult[myid], playerTempBetValue[myid], 4, proof);",
      "playerPendingWithdrawals[playerTempAddress[myid]] = safeAdd(playerPendingWithdrawals[playerTempAddress[myid]], playerTempBetValue[myid]);"
    ]
  },
  {
    "start_line": 304,
    "end_line": 308,
    "code": [
      "if(!playerTempAddress[myid].send(playerTempReward[myid])){",
      "LogResult(serialNumberOfResult, playerBetId[myid], playerTempAddress[myid], playerNumber[myid], playerDieResult[myid], playerTempReward[myid], 2, proof);",
      "playerPendingWithdrawals[playerTempAddress[myid]] = safeAdd(playerPendingWithdrawals[playerTempAddress[myid]], playerTempReward[myid]);"
    ]
  },
  {
    "start_line": 336,
    "end_line": 339,
    "code": [
      "if(!playerTempAddress[myid]

In [11]:
len(json_answer)

6