In [171]:
from pymongo import MongoClient
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.llms import openai
from langchain.chains import retrieval_qa
import os
from dotenv import load_dotenv, find_dotenv
import utils
from utils import logger
from pathlib import Path
import json
import pickle
load_dotenv(find_dotenv())

ATLAS_DB_URI = ''.join(["mongodb+srv://",
                f"{os.environ.get('ATLAS_DB_USERNAME')}:",
                f"{os.environ.get('ATLAS_DB_PASSWORD')}",
                "@ai-auditor-dev.2mfs29p.mongodb.net/?retryWrites=true&w=majority",
                "&appName=ai-auditor-dev"])

def ping_mongodb():
    client = MongoClient(ATLAS_DB_URI)
    try:
        client.admin.command('ping')
        print("Pinged your deployment. You successfully connected to MongoDB!")
    except Exception as e:
        print(e)


In [24]:
def load_data(
        save_to_pickle=True, 
        force_reload=False, 
        suffix="_vulnerabilities_formatted.txt"
):
    if os.path.exists(Path(utils.DATADIR) / 'data.pickle') and not force_reload:
        logger.info("Loading data from pickle file")
        with open(Path(utils.DATADIR) / 'data.pickle', 'rb') as f:
            data = pickle.load(f)
        return data

    datafiles = [file for file in os.listdir(utils.DATADIR)
                if os.path.isfile(Path(utils.DATADIR) / file)
                and file.endswith('.txt')]
    data = {}
    logger.info("Starting data load, loading data from %s files", len(datafiles))
    for file in datafiles:
        logger.info("Reading file: %s", file)
        section_data = []
        with open(Path(utils.DATADIR) / file, 'r', encoding='utf-8') as f:
            read_data = False
            json_data = ""
            for line in f:
                if "----Start JSON----" in line:
                    read_data = True
                    continue

                if "----End JSON----" in line:
                    read_data = False
                    try:
                        section_data.append(json.loads(json_data))
                    except json.JSONDecodeError as e:
                        pass
                    logger.debug("Read JSON Data: %s", section_data[-1].keys())
                    json_data = ""
                    continue

                if read_data:
                    json_data += line
            
            data[file.removesuffix(suffix)] = section_data
            
    logger.info("Loaded data from files")
    if save_to_pickle:
        logger.info("Saving data to pickle file")
        with open(Path(utils.DATADIR) / 'data.pickle', 'wb') as f:
            pickle.dump(data, f)
    
    return data

In [25]:
vuln_data = load_data(force_reload=True)

2024-05-26 13:31:02,393 - INFO - 517238322 - load_data - Starting data load, loading data from 5 files
2024-05-26 13:31:02,395 - INFO - 517238322 - load_data - Reading file: ConsenSys_vulnerabilities_formatted.txt
2024-05-26 13:31:02,398 - DEBUG - 517238322 - load_data - Read JSON Data: dict_keys(['code', 'Resolution', 'Description', 'Example', 'Recommendation'])
2024-05-26 13:31:02,399 - DEBUG - 517238322 - load_data - Read JSON Data: dict_keys(['code', 'Resolution', 'Description', 'Example', 'Recommendation'])
2024-05-26 13:31:02,401 - DEBUG - 517238322 - load_data - Read JSON Data: dict_keys(['code', 'Resolution', 'Description', 'Examples', 'Recommendation'])
2024-05-26 13:31:02,402 - DEBUG - 517238322 - load_data - Read JSON Data: dict_keys(['code', 'Resolution', 'Description', 'Examples', 'Recommendation'])
2024-05-26 13:31:02,405 - DEBUG - 517238322 - load_data - Read JSON Data: dict_keys(['code', 'Resolution', 'Description', 'Examples', 'Recommendation'])
2024-05-26 13:31:02,407

In [26]:
vuln_data.keys()

dict_keys(['ConsenSys', 'Cyfrin', 'Pashov_Audit_Group', 'Sherlock', 'Trust_Security'])

In [36]:
import itertools
no_sherlock_data = [vuln_data[k] for k in vuln_data.keys() 
                    if k.lower() != "sherlock"]

no_sherlock_data = list(itertools.chain(*no_sherlock_data))
sherlock_data = vuln_data["Sherlock"]

In [60]:
def get_cleaned_explanations(entry, exclude_keys=[]):
    text = ''
    EXCLUDE_KEYS = [
        'code',
        'Recommendation',
        "Mitigation Review",
        "Recommendations",
        "Cyfrin",
        "Client",
        "Recommended Mitigation",
        "Resolution",
        "Example",
    ] + exclude_keys
    
    for k in entry:
        if k not in EXCLUDE_KEYS:
            text += k + ': ' + ''.join(entry[k]) + ' '
    
    text.replace('\n', ' ')

    return text

In [159]:
def split_code_by_function(code):
    snippets = []
    parentheses = []
    opened = False
    curr_function = ''

    opposite = {
        '}': '{',
        ']': '[',
        ')': '('
    }
    for char in code:
        if char in '{([':
            if char == '{':
                
                opened = True
            parentheses.append(char)
        elif char in '}])':
            if parentheses:
                if opposite[char] == parentheses[-1]:  
                    parentheses.pop()
        
        curr_function += char

        if opened:
            if not parentheses:
                opened = False
                snippets.append(curr_function)
                curr_function = ''
    
    snippets.append(curr_function)

    return snippets

In [156]:
def split_and_combine_code(code, min_snippet_len=10, max_snippet_len=1000):
    return_code = []
    for snippet in code:
        return_code += [s for s in split_code_by_function(snippet) if s
                        and len(s) > min_snippet_len
                        and len(s) < max_snippet_len]
    
    return return_code


In [149]:
split_and_combine_code(no_sherlock_data[1]['code'])[3]

';\nconst json = await response.json();\n\n'

In [152]:
print(no_sherlock_data[1]['code'][0])

if (assetName.startsWith('W')) {
 // Assume this is a wrapped token
 assetName = assetName.slice(1); // remove W
}
try {




In [157]:
split_and_combine_code(no_sherlock_data[1]['code'])

["if (assetName.startsWith('W')) {\n // Assume this is a wrapped token\n assetName = assetName.slice(1); // remove W\n}",
 'const response = await fetch(\n `https://api.binance.com/api/v3/ticker/price?symbol=${assetName.toUpperCase()}USDT`,\n)',
 ';\nconst json = await response.json();\n\n']

In [155]:
no_sherlock_data[1]['code']

["if (assetName.startsWith('W')) {\n // Assume this is a wrapped token\n assetName = assetName.slice(1); // remove W\n}\ntry {\n\n",
 'const response = await fetch(\n `https://api.binance.com/api/v3/ticker/price?symbol=${assetName.toUpperCase()}USDT`,\n);\nconst json = await response.json();\n\n']

In [166]:
cleaned_no_sherlock_data = []

for entry in no_sherlock_data:
    if not entry['code']:
        logger.debug(entry)
        continue

    cleaned_no_sherlock_data.append(
        (
            split_and_combine_code(entry['code']),
            get_cleaned_explanations(entry)
        )
    )
    


2024-05-26 17:34:39,039 - DEBUG - 1109072594 - <module> - {'code': [], 'Resolution': ['Addressed with the following changesets: fort-major/msq@7f9cde2 and fort-major/msq@0b9f8d1 (removing whitelisted method names, only allowing  icrc1_transfer)', 'The client provided the following statement:'], 'Description': ['Identities are bound to their origin (URL). Third-party origins are outside the scope of this Snap and are therefore in a lower trust zone where it is unsure what security measures are in place to protect the dApp from impersonating the users’ wallet identity. dApps may be hosted on integrity protecting endpoints (ipfs/IC), however, this is not enforced.', 'Protected RPC functions can only be invoked by the MSQ administrative origin. User consent may not consistently be enforced on the administrative origin.', 'The administrative origin is identified by the origin URL. According to the client the dApp is hosted on an integrity protecting endpoint (IC). This already protects from

In [145]:
code = "    function getFees(...)\n        internal\n        view\n        returns (uint256 feeInToken, uint256 nativeFees)\n    {\n@>      nativeFees = controller.getMinFees(connector, gasLimit, payloadSize);\n        feeInToken = Configuration.getStaticWithdrawFee(token, connector);\n    }\n    function executeBridging(...)\n        internal\n    {\n        ISocketControllerWithPayload socketController =\n            ISocketControllerWithPayload(Configuration.getController(withdrawToken));\n\n        (uint256 tokenFees, uint256 nativeFees) =\n            getFees(withdrawToken, socketController, socketConnector, socketMsgGasLimit, socketPayloadSize);\n        if (tokenAmount > tokenFees) {\n            uint256 tokensToWithdraw = tokenAmount - tokenFees;\n@>          socketController.bridge{ value: nativeFees }({\n                receiver_: receiver,\n                amount_: tokensToWithdraw,\n                msgGasLimit_: socketMsgGasLimit,\n                connector_: socketConnector,\n                execPayload_: abi.encode(),\n                options_: abi.encode()\n            });\n            withdrawToken.safeTransfer(OwnableStorage.getOwner(), tokenFees);\n        } else {\n            revert Errors.NotEnoughFees(tokenAmount, tokenFees);\n        }\n    }\n"
print(split_code_by_function(code)[0])

    function getFees(...)
        internal
        view
        returns (uint256 feeInToken, uint256 nativeFees)
    {
@>      nativeFees = controller.getMinFees(connector, gasLimit, payloadSize);
        feeInToken = Configuration.getStaticWithdrawFee(token, connector);
    }


In [169]:
cleaned_no_sherlock_data[850]


(['    function depositETHV2(address _recipient, uint256 _proxyFeeInWei) external payable nonETHReuse {\n        address _cEther = address(cEther);\n\n        ICEther(_cEther).mint{value: msg.value - _proxyFeeInWei}();\n    ....\n    ....\n'],
 "Severity:  Impact: \nHigh, as the protocol can lose potential yield in the form of fees Likelihood: \nHigh, as users can craft such transactions in a permissionless way Description: The codebase is using a fee mechanism where the users pay a fee for using some functionality. An example where this is done is the Compound::depositETHV2 method, as we can see here:The problem with this approach is that the value of the fee is controlled by the user through the _proxyFeeInWei argument, meaning he can always send 0 value to it so he doesn't pay any fees. ")

## Data embedding

Now that the data is cleaned and split into functions and put together with all the explanations of the data, we will try to create some vector embeddings

In [172]:
from openai import OpenAI

client = OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))

def get_embedding(text, model="text-embedding-3-small"):
   text = text.replace("\n", " ")
   return client.embeddings.create(input = [text], model=model).data[0].embedding

In [175]:
get_embedding(text=''.join(cleaned_no_sherlock_data[0][0]))

[-0.0048248483799397945,
 0.019132444635033607,
 -0.011129984632134438,
 0.03833166882395744,
 0.07804345339536667,
 -0.005976801738142967,
 -0.05159860849380493,
 0.047591812908649445,
 0.04000116512179375,
 -0.00766299432143569,
 0.03428035229444504,
 -0.0006831028149463236,
 0.032610855996608734,
 -0.026956822723150253,
 -0.006026886869221926,
 0.017006617039442062,
 -0.01942182332277298,
 0.03405775502324104,
 -0.016372207552194595,
 0.016316557303071022,
 0.023595567792654037,
 0.0062383562326431274,
 0.04870481416583061,
 -0.04567745700478554,
 -0.0063385264948010445,
 0.008675822988152504,
 0.009660826995968819,
 -0.010673655197024345,
 -0.014123951084911823,
 -0.05021849274635315,
 0.01973346248269081,
 -0.02918281964957714,
 0.0415593646466732,
 -0.022860988974571228,
 0.05836563929915428,
 0.04231620207428932,
 0.003558812662959099,
 -0.00011469101445982233,
 -0.026133203878998756,
 0.034035492688417435,
 -0.018275434151291847,
 -0.09349187463521957,
 -0.004017924424260855,
 

The above simple model allows us to get all the embeddings, we can now simply store these into a vector DB

In [None]:
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    openai_api_key=os.environ.get('OPENAI_API_KEY')
)

dbName = "ai_aiditor_dev"
collectionName = "code_snippets"
collection = client[dbName][collectionName]

vectorStore = MongoDBAtlasVectorSearch.from_documents(
    cleaned_no_sherlock_data,
    embedding=embeddings,
    collection=collection,
)