In [1]:
import json
import os, sys
from openai import OpenAI
from tqdm.notebook import tqdm
import time
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)
from config.keys import OPENAI_API_KEY

In [2]:
import os
import requests

class Generator:
    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)


        self.json_formatter = "Return the response in RFC8259 compliant JSON according to the ResponseFormat schema with no other text."
        self.message = [{
            "role": "system",
            "content": 
                "You are a cyber-security programmer that can detect line numbers from the contract based on the instruction."
        }]
        self.output_formatter = """

This should instruct the model to output exactly the vulnerability lines, ensuring it doesn't output extra lines or large ranges that contain unrelated code.

Response Schema:
 [
    {
      "start_line": <exact_start_line_number>,
      "end_line": <exact_end_line_number>,
      "code": [
        "vulnerable line 1",
        "vulnerable line 2",
        "... (and so on)"
      ]
    }
  ]


** Do not use ```json or any other extra texts in the output. Include only the list of detected lines as the schema.
"""
        self.user_prefix = """You are given a smart contract code snippet and an explanation document on how to detect vulnerabilities. Your task is to identify and extract the exact lines of code where a vulnerability occurs—only the specific lines that are vulnerable, not any extra context or surrounding code.
Instructions:
1. Input Data:
    - Explanation: A detailed document containing guidelines for detecting vulnerabilities.
    - Smart Contract Code: The smart contract code is provided under the <Smart contract> tag.

2. Task Requirements:
    - Use the explanation guidelines to precisely locate all of the vulnerabilities in the code.
    - Extract only the exact lines of code that are vulnerable.
    - Do not provide a broad range of line numbers that include additional non-vulnerable lines. Instead, Be precise for the vulnerability lines and pinpoint the start and end lines where the vulnerability occurs, ensuring the extraction is minimal and exact.

3. Output Requirements:
    - Return your output as RFC8259 compliant JSON with no additional text.
    - The output should include:
        -- The exact start line number of the vulnerable code segment.
        -- The exact end line number of the vulnerable code segment.
        -- An array containing each exact line of vulnerable code.
    ** Do not include any lines of code that are not directly related to the vulnerability. **

"""

    def get_user_message(self, dataset_output, contract):
        self.user_content = f"""
{self.user_prefix}


This is the helping document to find the lines of vulnerable codes.
Guideline:
{dataset_output}

Smart contract code:
{contract}

Additional Note:

Be precise: Be noticed that most of the vulnerabilities occur in few lines. if the vulnerability is only on a few lines (for example, lines 215 to 218), only output those lines. Avoid outputting large ranges that include non-vulnerable lines.
Do not include any commentary or extraneous information outside of the JSON output.
Do not return the entire function or big code snippet. Specifically return small snippets with vulnerability.
---
{self.output_formatter}

###
"""
        self.user_message = {"role": "user", "content": self.user_content}

    def create_prompt(self, dataset_output, contract):
        self.get_user_message(dataset_output, contract)
        self.message.append(self.user_message)

    def generate(self):
        
        completion = self.client.chat.completions.create(
          model="gpt-4o-mini",
          messages = self.message,
          temperature=0.1,
          max_tokens=3200,
          top_p=1.,
          frequency_penalty=0,
          presence_penalty=0,
          stop=None
        )
        answer = completion.choices[0].message.content
        return answer, completion


In [3]:
dataset_name = "ESC_timestamp"
source_dir = f"../../data/processed_data/{dataset_name}/"
locs_dir = os.path.join(source_dir, "LOCs")
contracts_dir = os.path.join(source_dir, "contracts")


In [4]:
loc_files = os.listdir(locs_dir)

In [21]:
for loc_fname in loc_files[:4]:
    with open(os.path.join(locs_dir, loc_fname), 'r') as f:
        loc_data = json.load(f)
    sol_fname = loc_fname[:-5]+".sol"
    print(loc_fname)
    print(sol_fname)
    # with open 

98.json
98.sol
41.json
41.sol
16.json
16.sol
57.json
57.sol


In [23]:
with open(os.path.join(contracts_dir, sol_fname), 'r', encoding="utf-8") as file:
    solidity_code = file.read()

In [24]:
contracts_dir

'../../data/processed_data/ESC_timestamp/contracts'

In [25]:
sol_fname

'57.sol'

In [27]:
print(solidity_code)

pragma solidity ^0.4.19;


 
contract Ownable {
  address public owner;


  event OwnershipTransferred(address indexed previousOwner, address indexed newOwner);


   
  function Ownable() public {
    owner = msg.sender;
  }

   
  modifier onlyOwner() {
    require(msg.sender == owner);
    _;
  }

   
  function transferOwnership(address newOwner) public onlyOwner {
    require(newOwner != address(0));
    OwnershipTransferred(owner, newOwner);
    owner = newOwner;
  }

}











contract ERC20Basic {
    function totalSupply() public view returns (uint256);
    function balanceOf(address who) public view returns (uint256);
    function transfer(address to, uint256 value) public returns (bool);
    event Transfer(address indexed from, address indexed to, uint256 value);
}










 
contract BasicToken is ERC20Basic {
  using SafeMath for uint256;

  mapping(address => uint256) balances;

  uint256 totalSupply_;

   
  function totalSupply() public view returns (uint256) {
    r