# TODO
---
1. Convert codes from list to string
---
2. Contract line numbers


In [79]:
import json
import os, sys
from openai import OpenAI
from pydantic import BaseModel, Field

from tqdm.notebook import tqdm
import time
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..','..'))
sys.path.append(project_root)
from config.keys import OPENAI_API_KEY

In [80]:
def get_numbered_contract(contract):
        lines = contract.split("\n")
        numbered_lines = [f"{i+1}: {line}" for i, line in enumerate(lines)]
        return "\n".join(numbered_lines)
    
class Generator:
    def __init__(self, vulnerability, schema):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.vulnerability = vulnerability
        self.schema = schema
        self.message = []
        if "Timestamp Dependency" in vulnerability:
            documentation_file = "TD_instruction.txt"
        elif "Integer overflow/underflow" in vulnerability:
            documentation_file = "IoU_instruction.txt"
        elif "Reentrancy" in vulnerability:
            documentation_file = "RE_instruction.txt"
        with open(os.path.join("documentation", documentation_file), "r") as f:
            self.documentation = f.read()
        self.system_message = {"role":"system","content":"You are a cyber-security programmer that can detect vulnerable lines of the contract based on the instruction."}
        self.user_prefix = f"""In the code below, detect {vulnerability} vulnerabilities and provide the risk, fixed code and the risk the vulnerability may cause regarding the vulnerable code snippet"""
        self.output_formatter = """[
   {"vulnerableLines": "l1-l2",
    "vulnerableCode": "<a list containing lines of vulnerable piece of code>"
    "vulnerabilityReason": "<Reasons of the ulnerability of lines l1 to l2>",
    "potentialSecurityRisk": "<Potential risks the vulneraility causes>",
    "fixedCode": "<the healthy code snippet>"
  }
  ...
  ]"""
        self.formatter = f"Return the response in RFC8259 compliant JSON according to the ResponseFormat schema with no other text. Follow the example format:\n{self.output_formatter}"

    
    def set_target_vulnerability(self, vulnerability):
        self.vulnerability = vulnerability
        
    def update_message(self, new_message):
        self.message.append(new_message)
        
    def get_user_message(self, code, instruction, helper):
        self.user_content = f"""
{self.user_prefix}


Here is the vulnerable lines and corresponding vulnerable code:
{helper}
-----------------
Smart Contract Code:
{code}

-----------------
Return the result in the below JSON format.
{self.output_formatter}
###
        """
# -----------------
# {self.formatter}
        user_message = {"role": "user", "content":self.user_content}
        return user_message

    def get_example_message(self, example_data):
        train_code, instruction, train_response, helper = map_example(example_data)
        numbered_train_code = get_numbered_contract(train_code)

        train_user_message = self.get_user_message(numbered_train_code, instruction, helper)
        train_assistant_message = {"role": "assistant", "content": str(train_response)}
        return [train_user_message, train_assistant_message]
        

    def create_prompt(self, train_data, code, instruction, helper):
        self.message = []
        self.message.append(self.system_message)
        self.message.append({"role": "user", "content":self.documentation})
        for example_data in train_data:
            self.message.extend(self.get_example_message(example_data))        # 
        self.message.append(self.get_user_message(get_numbered_contract(code), instruction, helper))
        
    def generate(self):
        done = False
        i=0
        while not done:
            try:
                if i==5:
                    done=True
                completion = self.client.beta.chat.completions.parse(
                              model="gpt-4o-mini",
                              messages = self.message,
                              response_format=self.schema
                            )
                answer = json.loads(completion.choices[0].message.content)
                done = True
                i+=1
            except Exception as e:
                print(e)
                print("Rate limit exceeded. Paused for 120 seconds!")
        return answer, completion


def read_json_files(fewshot_dir, loc_dir):
    all_data = []
    for filename in os.listdir(fewshot_dir):
        if filename.endswith(".json"):
            file_path = os.path.join(fewshot_dir, filename)
            with open(file_path, 'r') as f:
                data = json.load(f)
            with open(os.path.join(loc_dir, filename), 'r') as f:
                fewshot_helper = json.load(f)
                # print("FEWSHOT HELPER")
                # print(fewshot_helper)
            data.append({"helper": fewshot_helper})
            # print("FS DATA")
            # print(data)
            all_data.append(data)
    return all_data

def map_example(example_data):

    if not example_data or len(example_data) < 1:
        raise ValueError("example_data must contain at least one element.")

    train_code = example_data[0]["input"]
    instruction = example_data[0]["output"]
    train_response = example_data[1:-1]
    start_line = example_data[1]["vulnerableLines"]
    helper = example_data[-1]
    return train_code, instruction, train_response, helper


In [81]:
# dataset = [
#     {"dataset_name":"ESC_timestamp", 
#      "vulnerability": "Timestamp Dependency"}
#     ]
# dataset = [
#     {"dataset_name": "source3_integeroverflow",
#            "vulnerability": "Integer overflow/underflow"}
# ]
# dataset = [
#     {"dataset_name":"source3_reentrancy", 
#      "vulnerability": "Reentrancy"}
#     ]

dataset = [
    {"dataset_name":"RE_FTSmartAudit_datasets", 
     "vulnerability": "Reentrancy"}
    ]

end_sample_num = 2

In [82]:
from pydantic import BaseModel, Field
from typing import List

class SingleVulnerability(BaseModel):
    vulnerableLines: str
    vulnerableCode: List[str]
    vulnerabilityReason: str=Field(description='Why the code snippet is vulnerable') 
    potentialSecurityRisk: str=Field(description='The potential security risk the vulnerability may cause.') 
    fixedCode: str=Field(description='The healthy code snippet, not the way to fix ') 

class FullVulnerability(BaseModel):
    vulnerabilities: List[SingleVulnerability]



dataset_name = dataset[0]["dataset_name"]
vulnerability =dataset[0]["vulnerability"]
raw_fname = os.path.join("..", "..", "..","data", "dataset", "raw", dataset_name+".json")
fewshot_dir = os.path.join("..", "..", "..","data", "dataset", "few_shots", dataset_name)
processed_dir = os.path.join("..", "..", "..","data", "dataset", "processed_data", dataset_name)
loc_dir = os.path.join("..", "..", "..","data", "processed_data", dataset_name, "LOCs")
os.makedirs(loc_dir, exist_ok=True)
loc_helper_dir = os.path.join("..", "..", "..","data", "processed_data", dataset_name, "LOCs_old")

fewshot_data = read_json_files(fewshot_dir, loc_helper_dir)
# print(fewshot_data[0])
with open(raw_fname, 'r') as f:
    raw_data = json.load(f)

schema = FullVulnerability
for i, raw_record in enumerate(raw_data[:end_sample_num]):
    if f"{i}.json" in os.listdir(loc_dir):
        print(f"Contract '{raw_record['File Name']}' is already processed - skipping")
        continue
    with open(os.path.join(loc_helper_dir, f"{i}.json"), 'r') as f:
        helper = json.load(f)
    print(f"Contract '{raw_record['File Name']}' is Being  processed :)")
    
    generator = Generator(vulnerability, schema)
    prompt = generator.create_prompt(fewshot_data, code=raw_record["input"], instruction="", helper=helper)
    response, completions = generator.generate()
    answer = json.loads(completions.choices[0].message.content)
    with open(os.path.join(loc_dir, f"{i}.json"), "w", encoding="utf-8") as file:
        json.dump(answer, file, ensure_ascii=False, indent=4)
    break
print("Done!")
# os.makedirs(output_dir, exist_ok=True)
# locs_dir = os.path.join(output_dir, "LOCs")
# os.makedirs(locs_dir, exist_ok=True)
# contracts_dir = os.path.join(output_dir, "contracts")
# os.makedirs(contracts_dir, exist_ok=True)
# raw_dir = "../../data/dataset/raw"

Contract 'reentrancy_bonus.sol' is Being  processed :)
Done!


In [90]:
print(generator.message[2]['content'])


In the code below, detect Reentrancy vulnerabilities and provide the risk, fixed code and the risk the vulnerability may cause regarding the vulnerable code snippet


Here is the vulnerable lines and corresponding vulnerable code:
{'helper': [{'start_line': 23, 'end_line': 27, 'code': ['            if(msg.sender.call.value(_am)())', '            {', '                acc.balance-=_am;', '                LogFile.AddMessage(msg.sender,_am,"Collect");', '            }']}]}
-----------------
Smart Contract Code:
1: pragma solidity ^0.4.25;
2: 
3: contract MY_BANK
4: {
5:     function Put(uint _unlockTime)
6:     public
7:     payable
8:     {
9:         var acc = Acc[msg.sender];
10:         acc.balance += msg.value;
11:         acc.unlockTime = _unlockTime>now?_unlockTime:now;
12:         LogFile.AddMessage(msg.sender,msg.value,"Put");
13:     }
14: 
15:     function Collect(uint _am)
16:     public
17:     payable
18:     {
19:         var acc = Acc[msg.sender];
20:         if( acc.balan

In [84]:
raw_record

{'Vulnerability Type': 'reentrancy',
 'File Name': 'reentrancy_bonus.sol',
 'Source Code': '/*\n * @source: https://consensys.github.io/smart-contract-best-practices/known_attacks/\n * @author: consensys\n * @vulnerable_at_lines: 28\n */\n\npragma solidity ^0.4.24;\n\ncontract Reentrancy_bonus{\n\n    // INSECURE\n    mapping (address => uint) private userBalances;\n    mapping (address => bool) private claimedBonus;\n    mapping (address => uint) private rewardsForA;\n\n    function withdrawReward(address recipient) public {\n        uint amountToWithdraw = rewardsForA[recipient];\n        rewardsForA[recipient] = 0;\n        (bool success, ) = recipient.call.value(amountToWithdraw)("");\n        require(success);\n    }\n\n    function getFirstWithdrawalBonus(address recipient) public {\n        require(!claimedBonus[recipient]); // Each recipient should only be able to claim the bonus once\n\n        rewardsForA[recipient] += 100;\n        // <yes> <report> REENTRANCY\n        withd

In [71]:
print(generator.message[6]['content'])


In the code below, detect Timestamp Dependency vulnerabilities and provide the risk, fixed code and the risk the vulnerability may cause regarding the vulnerable code snippet


Here is the vulnerable lines and corresponding vulnerable code:
[{'start_line': 36, 'end_line': 36, 'code': ['        var random = uint(sha3(block.timestamp)) % 2;']}, {'start_line': 42, 'end_line': 45, 'code': ['            bank.transfer(FEE_AMOUNT);', '', '            ', '            msg.sender.transfer(pot - FEE_AMOUNT);']}]
-----------------
Smart Contract Code:
1: pragma solidity ^0.4.15;
2: 
3: 
4: 
5: contract EtherLotto {
6: 
7:     
8:     uint constant TICKET_AMOUNT = 10;
9: 
10:     
11:     uint constant FEE_AMOUNT = 1;
12: 
13:     
14:     address public bank;
15: 
16:     
17:     uint public pot;
18: 
19:     
20:     function EtherLotto() {
21:         bank = msg.sender;
22:     }
23: 
24:     
25:     
26:     function play() payable {
27: 
28:         
29:         assert(msg.value == TICKET_A