# TODO
---
1. Convert codes from list to string
---
2. Contract line numbers


In [25]:
import json
import os, sys
from openai import OpenAI
from pydantic import BaseModel, Field

from tqdm.notebook import tqdm
import time
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..','..'))
sys.path.append(project_root)
from config.keys import OPENAI_API_KEY

In [26]:
def get_numbered_contract(contract):
    """Return ``contract`` with every line prefixed by its 1-based line number.

    Lines are split on ``"\\n"`` and re-joined with ``"\\n"``, so ``"a\\nb"``
    becomes ``"1: a\\n2: b"``. An empty string yields ``"1: "``.
    """
    return "\n".join(
        f"{number}: {text}" for number, text in enumerate(contract.split("\n"), start=1)
    )
    
class Generator:
    """Build few-shot chat prompts for vulnerability localisation and query OpenAI.

    ``create_prompt`` fills ``self.message`` with: the system message, the
    vulnerability-specific instruction document, one user/assistant pair per
    few-shot example, and a final user message for the contract under analysis.
    ``generate`` sends that conversation and returns the decoded JSON answer.
    Relies on the module-level helpers ``map_example`` and ``get_numbered_contract``.
    """

    def __init__(self, vulnerability, schema):
        """Prepare the OpenAI client, instruction document and prompt templates.

        Args:
            vulnerability: Vulnerability name; it must contain "Timestamp Dependency",
                "Integer overflow/underflow" or "Reentrancy" (checked in that order).
            schema: ``response_format`` passed to the structured-output ``parse`` call.

        Raises:
            ValueError: if ``vulnerability`` matches none of the supported names.
        """
        # Resolve the instruction file first so an unsupported vulnerability fails
        # with a clear error instead of an UnboundLocalError further down.
        if "Timestamp Dependency" in vulnerability:
            documentation_file = "TD_instruction.txt"
        elif "Integer overflow/underflow" in vulnerability:
            documentation_file = "IoU_instruction.txt"
        elif "Reentrancy" in vulnerability:
            documentation_file = "RE_instruction.txt"
        else:
            raise ValueError(f"Unsupported vulnerability type: {vulnerability!r}")
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.vulnerability = vulnerability
        self.schema = schema
        self.message = []
        with open(os.path.join("documentation", documentation_file), "r", encoding="utf-8") as f:
            self.documentation = f.read()
        self.system_message = {"role":"system","content":"You are a cyber-security programmer that can detect vulnerable lines of the contract based on the instruction."}
        self.user_prefix = f"""In the code below, detect {vulnerability} vulnerabilities and provide extra information regarding the vulnerable code snippet based on the given instruction"""
        # Example answer shape shown to the model; "..." marks that entries repeat.
        self.output_formatter = """[
   {"vulnerableLines": "l1-l2",
    "vulnerableCode": "<a list containing lines of vulnerable piece of code>",
    "vulnerabilityReason": "<Reasons of the vulnerability of lines l1 to l2>",
    "potentialSecurityRisk": "<Potential risks the vulnerability causes>",
    "fixedCode": "<the healthy code snippet>"
  }
  ...
  ]"""
        # Alternative formatting instruction; not currently included in the prompt.
        self.formatter = f"Return the response in RFC8259 compliant JSON according to the ResponseFormat schema with no other text. Follow the example format:\n{self.output_formatter}"

    def set_target_vulnerability(self, vulnerability):
        """Replace the stored vulnerability name.

        NOTE(review): ``user_prefix`` and ``documentation`` are built in
        ``__init__`` and are not refreshed here, so generated prompts still target
        the original vulnerability. Create a new Generator to switch types.
        """
        self.vulnerability = vulnerability

    def update_message(self, new_message):
        """Append one chat message dict to the current conversation."""
        self.message.append(new_message)

    def get_user_message(self, code, instruction, helper):
        """Build the user message for one contract.

        Args:
            code: Contract source; callers in this class pass it line-numbered.
            instruction: Per-contract instruction text.
            helper: Vulnerable-line hint, formatted into the prompt as-is.

        Returns:
            dict: ``{"role": "user", "content": ...}``. The content is also kept
            on ``self.user_content``.
        """
        self.user_content = f"""
{self.user_prefix}

Instruction:
{instruction}

Here is the vulnerable lines:
{helper}
-----------------
Smart Contract Code:
{code}

-----------------
{self.output_formatter}
###
        """
        # Alternative ending that was tried: "-----------------\n{self.formatter}"
        user_message = {"role": "user", "content":self.user_content}
        return user_message

    def get_example_message(self, example_data):
        """Turn one few-shot example into a [user, assistant] message pair."""
        train_code, instruction, train_response, helper = map_example(example_data)
        numbered_train_code = get_numbered_contract(train_code)

        train_user_message = self.get_user_message(numbered_train_code, instruction, helper)
        # The expected answer is sent as its Python repr, as in the original pipeline.
        train_assistant_message = {"role": "assistant", "content": str(train_response)}
        return [train_user_message, train_assistant_message]

    def create_prompt(self, train_data, code, instruction, helper):
        """Rebuild ``self.message`` for ``code`` using the few-shot ``train_data``.

        Order: system message, instruction document, one user/assistant pair per
        example in ``train_data``, then the user message for ``code``. Returns None.
        """
        self.message = []
        self.message.append(self.system_message)
        self.message.append({"role": "user", "content":self.documentation})
        for example_data in train_data:
            self.message.extend(self.get_example_message(example_data))
        self.message.append(self.get_user_message(get_numbered_contract(code), instruction, helper))

    def generate(self, max_retries=5, wait_seconds=120):
        """Send ``self.message`` to the model and decode the JSON reply.

        Each failed attempt (API, network or JSON-decoding error) is printed and
        retried after ``wait_seconds``, up to ``max_retries`` attempts in total
        (at least one).

        Returns:
            tuple: ``(answer, completion)``. ``answer`` is the decoded JSON content
            of the first choice; ``completion`` is the raw SDK response.

        Raises:
            RuntimeError: if every attempt fails; the last error is chained as the cause.
        """
        attempts = max(1, max_retries)
        last_error = None
        for attempt in range(1, attempts + 1):
            try:
                completion = self.client.beta.chat.completions.parse(
                    model="gpt-4o-mini",
                    messages=self.message,
                    response_format=self.schema,
                )
                answer = json.loads(completion.choices[0].message.content)
            except Exception as e:  # retry boundary: any request/decoding failure
                last_error = e
                print(e)
                print(f"Attempt {attempt}/{attempts} failed.")
                if attempt < attempts:
                    print(f"Paused for {wait_seconds} seconds!")
                    time.sleep(wait_seconds)
            else:
                return answer, completion
        raise RuntimeError(f"Generation failed after {attempts} attempts") from last_error


def read_json_files(fewshot_dir, loc_dir):
    """Load every few-shot example in ``fewshot_dir`` together with its helper.

    For each ``*.json`` file in ``fewshot_dir`` the JSON list it contains is
    loaded, the same-named file in ``loc_dir`` is loaded as the helper, and
    ``{"helper": <helper>}`` is appended to the list. Files are visited in
    sorted name order so the few-shot prompt is reproducible (``os.listdir``
    order is arbitrary).

    Args:
        fewshot_dir: Directory holding the few-shot example files.
        loc_dir: Directory holding one helper file per example, same filename.

    Returns:
        list: one list per example, its records followed by the helper entry.

    Raises:
        FileNotFoundError: if an example has no same-named file in ``loc_dir``.
    """
    all_data = []
    for filename in sorted(os.listdir(fewshot_dir)):
        if not filename.endswith(".json"):
            continue
        with open(os.path.join(fewshot_dir, filename), "r", encoding="utf-8") as f:
            data = json.load(f)
        with open(os.path.join(loc_dir, filename), "r", encoding="utf-8") as f:
            fewshot_helper = json.load(f)
        data.append({"helper": fewshot_helper})
        all_data.append(data)
    return all_data

def map_example(example_data):
    """Split one few-shot example into its prompt components.

    ``example_data`` is a list as built by ``read_json_files``. The first
    record holds the contract (``"input"``) and its instruction (``"output"``).
    The last entry is the helper. Everything in between is the expected
    assistant response.

    Returns:
        tuple: ``(train_code, instruction, train_response, helper)``.

    Raises:
        ValueError: if ``example_data`` has fewer than two entries (it needs at
            least the contract record and the helper).
    """
    if not example_data or len(example_data) < 2:
        raise ValueError(
            "example_data must contain at least two elements (the contract record and the helper)."
        )

    train_code = example_data[0]["input"]
    instruction = example_data[0]["output"]
    train_response = example_data[1:-1]
    helper = example_data[-1]
    return train_code, instruction, train_response, helper


In [31]:
# Alternative dataset configurations (uncomment one to switch):
# dataset = [
#     {"dataset_name":"ESC_timestamp", 
#      "vulnerability": "Timestamp Dependency"}
#     ]
# dataset = [
#     {"dataset_name": "source3_integeroverflow",
#            "vulnerability": "Integer overflow/underflow"}
#  ]

# Active configuration: raw dataset file stem and the vulnerability type to detect.
dataset = [
    dict(dataset_name="source3_reentrancy", vulnerability="Reentrancy"),
]

# Upper bound on raw records to process. For smartbugs only 629 vulnerable
# contracts are available, not 1078.
end_sample_num = 630

In [28]:
# For full_source2_reentrancy_218 the record keys differ: use "contract" instead of "input" and "target" instead of "output".

In [None]:
from pydantic import BaseModel, Field
from typing import List

class SingleVulnerability(BaseModel):
    # One vulnerable snippet reported by the model. Field names mirror the keys
    # of the example answer format in Generator.output_formatter.
    # Documented with comments rather than a docstring: pydantic uses a model
    # docstring as the JSON-schema description that is sent as response_format.
    # Line range of the snippet, e.g. "l1-l2" in the prompt's example format.
    vulnerableLines: str
    # The vulnerable source lines, one string per line.
    vulnerableCode: List[str]
    vulnerabilityReason: str=Field(description='Why the code snippet is vulnerable') 
    potentialSecurityRisk: str=Field(description='The potential security risk the vulnerability may cause.') 
    fixedCode: str=Field(description='The healthy code snippet, not the way to fix ') 

class FullVulnerability(BaseModel):
    # Top-level structured-output schema passed to Generator as `schema`. The
    # model's answer is an object whose "vulnerabilities" key lists the findings.
    vulnerabilities: List[SingleVulnerability]



dataset_name = dataset[0]["dataset_name"]
vulnerability = dataset[0]["vulnerability"]
data_root = os.path.join("..", "..", "..", "data")
raw_fname = os.path.join(data_root, "dataset", "raw", dataset_name + ".json")
fewshot_dir = os.path.join(data_root, "dataset", "few_shots", dataset_name)
processed_dir = os.path.join(data_root, "dataset", "processed_data", dataset_name)
# NOTE(review): LOC outputs live under data/processed_data, whereas processed_dir
# points at data/dataset/processed_data. Confirm which location is intended.
loc_dir = os.path.join(data_root, "processed_data", dataset_name, "LOCs")
os.makedirs(loc_dir, exist_ok=True)
loc_helper_dir = os.path.join(data_root, "processed_data", dataset_name, "LOCs_old")

fewshot_data = read_json_files(fewshot_dir, loc_helper_dir)
print(raw_fname)
with open(raw_fname, 'r', encoding="utf-8") as f:
    raw_data = json.load(f)

schema = FullVulnerability
# create_prompt() rebuilds the message list from scratch, so one Generator
# (one client, one instruction-file read) can serve every contract.
generator = Generator(vulnerability, schema)
for i, raw_record in enumerate(raw_data[:end_sample_num]):
    if raw_record["target"][0] == "0":
        print(f"Contract {i} is marked as healthy - skipping")
        continue
    output_path = os.path.join(loc_dir, f"{i}.json")
    # Existing output means a previous run finished this contract (makes the loop resumable).
    if os.path.exists(output_path):
        print(f"Contract {i} is already processed - skipping")
        continue
    with open(os.path.join(loc_helper_dir, f"{i}.json"), 'r', encoding="utf-8") as f:
        helper = json.load(f)
    print(f"Contract {i} is Being  processed :)")

    generator.create_prompt(fewshot_data, code=raw_record["input"], instruction=raw_record["output"], helper=helper)
    # generate() already decodes the JSON content of the first choice.
    answer, completions = generator.generate()
    with open(output_path, "w", encoding="utf-8") as file:
        json.dump(answer, file, ensure_ascii=False, indent=4)
print("Done!")
# os.makedirs(output_dir, exist_ok=True)
# locs_dir = os.path.join(output_dir, "LOCs")
# os.makedirs(locs_dir, exist_ok=True)
# contracts_dir = os.path.join(output_dir, "contracts")
# os.makedirs(contracts_dir, exist_ok=True)
# raw_dir = "../../data/dataset/raw"

In [None]:
# answer = json.loads(completions.choices[0].message.content)
# answer

In [None]:
# Debug: print the content of message index 6. create_prompt() puts the system
# message at 0, the instruction document at 1, then one user/assistant pair per
# few-shot example, then the final user prompt. So index 6 is the final prompt
# only when there are exactly two few-shot examples.
# NOTE(review): generator.message[-1] would always show the final prompt.
print(generator.message[6]['content'])

In [None]:
#generator.message

In [None]:
#### **Adapt to the SmartAudit datasets: IoU, RE, TD**

In [None]:
import json
import os, sys
from openai import OpenAI
from pydantic import BaseModel, Field
from tqdm.notebook import tqdm
import time

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..', '..'))
sys.path.append(project_root)
from config.keys import OPENAI_API_KEY

def get_numbered_contract(contract):
    """Prefix each contract line with its 1-based number, e.g. ``"1: pragma ..."``.

    The text is split on ``"\\n"`` and re-joined with ``"\\n"``. An empty
    string yields ``"1: "``.
    """
    source_lines = contract.split("\n")
    prefixed = [f"{n}: {line}" for n, line in enumerate(source_lines, 1)]
    return "\n".join(prefixed)

class SingleVulnerability(BaseModel):
    # One vulnerable snippet reported by the model; field names mirror the keys of
    # the example answer format in Generator.output_formatter.
    # Documented with comments rather than a docstring: pydantic uses a model
    # docstring as the JSON-schema description that is sent as response_format.
    # Line range of the snippet, e.g. "l1-l2" in the prompt's example format.
    vulnerableLines: str
    # The vulnerable source lines, one string per line. Typed items give the
    # structured-output schema a concrete array item type.
    vulnerableCode: list[str]
    vulnerabilityReason: str = Field(description="Why the code snippet is vulnerable")
    potentialSecurityRisk: str = Field(description="Potential security risks caused by this vulnerability")
    fixedCode: str = Field(description="A secure version of the vulnerable code")

class FullVulnerability(BaseModel):
    # Top-level structured-output schema passed to Generator as `schema`. The
    # model's answer is an object whose "vulnerabilities" key lists the findings.
    vulnerabilities: list[SingleVulnerability]

class Generator:
    """Build few-shot chat prompts (without per-contract instructions) and query OpenAI.

    ``create_prompt`` fills ``self.message`` with: the system message, the
    vulnerability-specific instruction document, one user/assistant pair per
    few-shot example, and a final user message for the contract under analysis.
    ``generate`` sends that conversation and returns the decoded JSON answer.
    Relies on the module-level helpers ``map_example`` and ``get_numbered_contract``.
    """

    def __init__(self, vulnerability, schema):
        """Prepare the OpenAI client, instruction document and prompt templates.

        Args:
            vulnerability: Vulnerability name; it must contain "Timestamp Dependency",
                "Integer overflow/underflow" or "Reentrancy" (checked in that order).
            schema: Pydantic model used as the structured-output ``response_format``.

        Raises:
            ValueError: if ``vulnerability`` matches none of the supported names.
        """
        # Resolve the instruction file first so an unsupported vulnerability fails
        # with a clear error instead of an UnboundLocalError further down.
        if "Timestamp Dependency" in vulnerability:
            documentation_file = "TD_instruction.txt"
        elif "Integer overflow/underflow" in vulnerability:
            documentation_file = "IoU_instruction.txt"
        elif "Reentrancy" in vulnerability:
            documentation_file = "RE_instruction.txt"
        else:
            raise ValueError(f"Unsupported vulnerability type: {vulnerability!r}")

        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.vulnerability = vulnerability
        self.schema = schema
        self.message = []

        with open(os.path.join("documentation", documentation_file), "r", encoding="utf-8") as f:
            self.documentation = f.read()

        self.system_message = {"role": "system", "content": "You are a cyber-security programmer that can detect vulnerable lines of the contract based on the instruction."}
        self.user_prefix = f"""In the code below, detect {vulnerability} vulnerabilities and provide extra information regarding the vulnerable code snippet based on the given instruction."""

        # Example answer shape shown to the model at the end of every user message.
        self.output_formatter = """[
            {"vulnerableLines": "l1-l2",
            "vulnerableCode": "<a list containing lines of vulnerable piece of code>",
            "vulnerabilityReason": "<Reasons for the vulnerability of lines l1 to l2>",
            "potentialSecurityRisk": "<Potential risks the vulnerability causes>",
            "fixedCode": "<The corrected secure code snippet>"
            }
        ]"""

        # Alternative formatting instruction; not currently included in the prompt.
        self.formatter = f"Return the response in RFC8259 compliant JSON according to the ResponseFormat schema with no other text. Follow the example format:\n{self.output_formatter}"

    def update_message(self, new_message):
        """Append one chat message dict to the current conversation."""
        self.message.append(new_message)

    def get_user_message(self, code, helper):
        """Build the user message for one contract (no per-contract instruction).

        Args:
            code: Contract source; callers in this class pass it line-numbered.
            helper: Vulnerable-line hint, formatted into the prompt as-is.

        Returns:
            dict: ``{"role": "user", "content": ...}``. The content is also kept
            on ``self.user_content``.
        """
        self.user_content = f"""
{self.user_prefix}

Here is the vulnerable lines:
{helper}
-----------------
Smart Contract Code:
{code}

-----------------
{self.output_formatter}
###
        """
        user_message = {"role": "user", "content": self.user_content}
        return user_message

    def create_prompt(self, train_data, code, helper):
        """Rebuild ``self.message`` for ``code`` using the few-shot ``train_data``.

        Order: system message, instruction document, one user/assistant pair per
        example in ``train_data``, then the user message for ``code``. Returns None.
        """
        self.message = []
        self.message.append(self.system_message)
        self.message.append({"role": "user", "content": self.documentation})

        for example_data in train_data:
            self.message.extend(self.get_example_message(example_data))

        self.message.append(self.get_user_message(get_numbered_contract(code), helper))

    def get_example_message(self, example_data):
        """Turn one few-shot example into a [user, assistant] message pair."""
        train_code, train_response, helper = map_example(example_data)
        numbered_train_code = get_numbered_contract(train_code)

        train_user_message = self.get_user_message(numbered_train_code, helper)
        # The expected answer is sent as its Python repr, as in the original pipeline.
        train_assistant_message = {"role": "assistant", "content": str(train_response)}
        return [train_user_message, train_assistant_message]

    def generate(self, max_retries=5, wait_seconds=120):
        """Send ``self.message`` to the model and decode the JSON reply.

        Uses the structured-output ``parse`` endpoint. The plain
        ``chat.completions.create`` call does not accept a Pydantic class as
        ``response_format``. Each failed attempt (API, network or JSON-decoding
        error) is printed and retried after ``wait_seconds``, up to
        ``max_retries`` attempts in total (at least one).

        Returns:
            tuple: ``(answer, completion)``. ``answer`` is the decoded JSON content
            of the first choice; ``completion`` is the raw SDK response.

        Raises:
            RuntimeError: if every attempt fails; the last error is chained as the cause.
        """
        attempts = max(1, max_retries)
        last_error = None
        for attempt in range(1, attempts + 1):
            try:
                completion = self.client.beta.chat.completions.parse(
                    model="gpt-4o",
                    messages=self.message,
                    response_format=self.schema,
                )
                answer = json.loads(completion.choices[0].message.content)
            except Exception as e:  # retry boundary: any request/decoding failure
                last_error = e
                print(e)
                print(f"Attempt {attempt}/{attempts} failed.")
                if attempt < attempts:
                    print(f"Paused for {wait_seconds} seconds!")
                    time.sleep(wait_seconds)
            else:
                return answer, completion
        raise RuntimeError(f"Generation failed after {attempts} attempts") from last_error

def read_json_files(fewshot_dir, loc_dir):
    """Load every few-shot example in ``fewshot_dir`` together with its helper.

    For each ``*.json`` file in ``fewshot_dir`` the JSON list it contains is
    loaded, the same-named file in ``loc_dir`` is loaded as the helper, and
    ``{"helper": <helper>}`` is appended to the list. Files are visited in
    sorted name order so the few-shot prompt is reproducible (``os.listdir``
    order is arbitrary).

    Args:
        fewshot_dir: Directory holding the few-shot example files.
        loc_dir: Directory holding one helper file per example, same filename.

    Returns:
        list: one list per example, its records followed by the helper entry.

    Raises:
        FileNotFoundError: if an example has no same-named file in ``loc_dir``.
    """
    examples = []
    json_names = sorted(name for name in os.listdir(fewshot_dir) if name.endswith(".json"))
    for filename in json_names:
        with open(os.path.join(fewshot_dir, filename), "r", encoding="utf-8") as f:
            data = json.load(f)
        with open(os.path.join(loc_dir, filename), "r", encoding="utf-8") as f:
            fewshot_helper = json.load(f)
        data.append({"helper": fewshot_helper})
        examples.append(data)
    return examples

def map_example(example_data):
    """Split one few-shot example into its prompt components.

    ``example_data`` is a list as built by ``read_json_files``. The first
    record holds the contract (``"input"``). The last entry is the helper.
    Everything in between is the expected assistant response.

    Returns:
        tuple: ``(train_code, train_response, helper)``.

    Raises:
        ValueError: if ``example_data`` has fewer than two entries. With a single
            entry the contract record would otherwise be reused as the helper.
    """
    if not example_data or len(example_data) < 2:
        raise ValueError(
            "example_data must contain at least two elements (the contract record and the helper)."
        )

    train_code = example_data[0]["input"]
    train_response = example_data[1:-1]
    helper = example_data[-1]
    return train_code, train_response, helper

# Define dataset structure. The vulnerability must match the dataset, because
# Generator picks its instruction file from it (TD_instruction.txt,
# IoU_instruction.txt or RE_instruction.txt). An "IoU" dataset therefore needs
# "Integer overflow/underflow", not "Reentrancy".
dataset = [
    {"dataset_name": "IoU_FTSmartAudit_datasets",
     "vulnerability": "Integer overflow/underflow"}
]

end_sample_num = 2  # Limit processing to the first two contracts

# Define directory paths
dataset_name = dataset[0]["dataset_name"]
vulnerability = dataset[0]["vulnerability"]

raw_fname = os.path.join("..", "..", "..", "data", "dataset", "raw", dataset_name + ".json")
fewshot_dir = os.path.join("..", "..", "..", "data", "dataset", "few_shots", dataset_name)
processed_dir = os.path.join("..", "..", "..", "data", "dataset", "processed_data", dataset_name)
loc_dir = os.path.join("..", "..", "..", "data", "processed_data", dataset_name, "LOCs")
os.makedirs(loc_dir, exist_ok=True)
loc_helper_dir = os.path.join("..", "..", "..", "data", "processed_data", dataset_name, "LOCs_old")

# Load few-shot examples
fewshot_data = read_json_files(fewshot_dir, loc_helper_dir)

print(raw_fname)
with open(raw_fname, 'r', encoding="utf-8") as f:
    raw_data = json.load(f)

schema = FullVulnerability  # Pydantic model used as the structured-output format

# create_prompt() rebuilds the message list from scratch, so one Generator
# (one client, one instruction-file read) can serve every contract.
generator = Generator(vulnerability, schema)

# Process each contract
for i, raw_record in enumerate(raw_data[:end_sample_num]):
    output_filename = os.path.join(loc_dir, f"{i}.json")

    if os.path.exists(output_filename):
        print(f"Contract {i} is already processed - skipping")
        continue

    with open(os.path.join(loc_helper_dir, f"{i}.json"), 'r', encoding="utf-8") as f:
        helper = json.load(f)

    print(f"Processing contract {i}...")

    generator.create_prompt(fewshot_data, code=raw_record["input"], helper=helper)
    response, completions = generator.generate()

    # Only persist a non-empty answer. A missing output file means the contract
    # is picked up again on the next run.
    if response:
        with open(output_filename, "w", encoding="utf-8") as file:
            json.dump(response, file, ensure_ascii=False, indent=4)

print("Done!")
