In [None]:
import csv
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import json

In [None]:
# Directory holding one sub-folder per question. Later cells list it with
# os.listdir to seed the status CSV and to drive the batch loop.
# NOTE(review): the path is relative, so it resolves against whatever the
# working directory is when those cells run (a later cell calls os.chdir).
base_path = "code"

In [None]:
def load_plm(model_name):
  """Load a causal-LM checkpoint and its tokenizer from the local hub folder.

  Args:
    model_name: Folder name under ``huggingface/hub/`` (relative to the
      current working directory).

  Returns:
    ``(tokenizer, model)``; the model has been moved to CUDA when
    ``torch.cuda.is_available()`` is true, otherwise to the CPU.
  """
  # Both artefacts are read from the same checkpoint directory.
  checkpoint_dir = f'huggingface/hub/{model_name}'
  tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, trust_remote_code=True)
  model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)

  device_name = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(torch.device(device_name))
  return tokenizer, model

In [None]:
tokenizer, model = load_plm('CodeLlama-13b-Instruct-hf')

# tokenizer, model = load_plm('starchat-alpha')

In [None]:
from transformers import pipeline

# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
# The chat_with_model helper in this cell only reaches into its .model and
# .tokenizer attributes.
chat_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# 
def chat_with_model(prompt):
    """Generate a completion for a raw prompt string (no chat template).

    Args:
        prompt: Text fed to the tokenizer as-is.

    Returns:
        The first generated sequence decoded with special tokens removed.
        Generation is capped at 128 new tokens and forbids repeated bigrams
        (``no_repeat_ngram_size=2``).
    """
    # Send the ids to wherever the model actually lives instead of assuming
    # 'cuda': load_plm falls back to the CPU when no GPU is available.
    input_ids = chat_pipe.tokenizer.encode(prompt, return_tensors='pt').to(chat_pipe.model.device)
    response = chat_pipe.model.generate(input_ids, max_new_tokens=128, no_repeat_ngram_size=2)
    return chat_pipe.tokenizer.decode(response[0], skip_special_tokens=True)

# # 
# prompt = " Python ？"  # 
# print(chat_with_model(prompt))


In [None]:
from transformers import pipeline


def get_answer(question):
  """Answer ``question`` through a text-generation pipeline.

  The question is wrapped in the ``<|system|>/<|user|>/<|assistant|>``
  template before generation.

  Args:
    question: User query inserted into the template.

  Returns:
    The ``generated_text`` of the first pipeline output. A later cell splits
    it on ``<|assistant|>`` to isolate the reply, so callers should expect
    the prompt to be part of the returned text.
  """
  # Run the pipeline on the device the model already occupies rather than a
  # hard-coded GPU index 0, which fails on CPU-only hosts.
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device)

  prompt_template = "<|system|>\n<|end|>\n<|user|>\n{query}<|end|>\n<|assistant|>"
  prompt = prompt_template.format(query=question)

  outputs = pipe(prompt, max_new_tokens=256, no_repeat_ngram_size=2)

  # Only the first candidate is used.
  return outputs[0]["generated_text"]

In [None]:
from transformers import pipeline

# Text-generation pipeline around the already-loaded model and tokenizer.
# NOTE(review): this rebinds the chat_pipe name created in an earlier cell
# with the same arguments.
chat_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# 
def chat_with_model(question):
    """Generate an answer for ``question`` using an ``[INST]``-style prompt.

    Args:
        question: Text substituted for ``{query}`` in the prompt template.

    Returns:
        The first generated sequence decoded with special tokens removed.
        Generation is capped at 512 new tokens and forbids repeated bigrams
        (``no_repeat_ngram_size=2``). Later cells split the result on
        ``"# [/INST]\\n# "`` to drop the prompt portion.
    """
    # Same text as a triple-quoted block, written with explicit escapes so
    # the trailing spaces (which are part of the prompt) stay visible and
    # cannot be stripped by an editor.
    prompt_template = (
        " \n"
        "# <s>[INST] \n"
        "# {query}\n"
        "# [/INST]\n"
        "#  \n"
    )

    prompt = prompt_template.format(query=question)
    # Follow the model's real device instead of assuming 'cuda'; load_plm
    # falls back to the CPU when no GPU is available.
    input_ids = chat_pipe.tokenizer.encode(prompt, return_tensors='pt').to(chat_pipe.model.device)
    response = chat_pipe.model.generate(input_ids, max_new_tokens=512, no_repeat_ngram_size=2)
    return chat_pipe.tokenizer.decode(response[0], skip_special_tokens=True)

In [None]:
question = """
You are an AI trained to detect similar code expressions. Given a Smart Contract code and a specific target code expression, your task is to find and list the most similar expressions within the provided Smart Contract code. I will show you the answer format and then please analyze the new input following code file and search for expressions that closely resemble the target code piece provided.

```
{"Answer":"Yes" or "No", "similar_expressions": [
    {
      "function_name": the matched funciton name,
      "line_number": line_number,
      "expression": the similar code
    }
  ]
  "Reason": your reason
  }



Input Smart Contract Code:
```Solidity
function getUserLimitIntraAsset ( address payable userAddr ) external view override returns (uint256, uint256) {
		uint256 userTotalBorrowLimitAsset;
		uint256 userTotalMarginCallLimitAsset;
		for (uint256 handlerID;
 handlerID < tokenHandlerLength;
 handlerID++)		{
			if (dataStorageInstance.getTokenHandlerSupport(handlerID))			{
				uint256 depositHandlerAsset;
				uint256 borrowHandlerAsset;
				(depositHandlerAsset, borrowHandlerAsset) = _getUserIntraHandlerAssetWithInterest(userAddr, handlerID);
				uint256 borrowLimit = _getTokenHandlerBorrowLimit(handlerID);
				uint256 marginCallLimit = _getTokenHandlerMarginCallLimit(handlerID);
				uint256 userBorrowLimitAsset = depositHandlerAsset.unifiedMul(borrowLimit);
				uint256 userMarginCallLimitAsset = depositHandlerAsset.unifiedMul(marginCallLimit);
				userTotalBorrowLimitAsset = userTotalBorrowLimitAsset.add(userBorrowLimitAsset);
				userTotalMarginCallLimitAsset = userTotalMarginCallLimitAsset.add(userMarginCallLimitAsset);
			}
			else			{
				continue;
			}
		}
		return (userTotalBorrowLimitAsset, userTotalMarginCallLimitAsset);
	}

function _getUserIntraHandlerAssetWithInterest ( address payable userAddr , uint256 handlerID ) internal view returns (uint256, uint256) {
		uint256 price = _getTokenHandlerPrice(handlerID);
		proxyContractInterface tokenHandler = proxyContractInterface(dataStorageInstance.getTokenHandlerAddr(handlerID));
		uint256 depositAmount;
		uint256 borrowAmount;
		bytes memory data;
		(, data) = tokenHandler.handlerViewProxy(			abi.encodeWithSelector(				marketHandlerInterface.getUserAmountWithInterest.selector,				userAddr			)		);
		(depositAmount, borrowAmount) = abi.decode(data, (uint256, uint256));
		uint256 depositAsset = depositAmount.unifiedMul(price);
		uint256 borrowAsset = borrowAmount.unifiedMul(price);
		return (depositAsset, borrowAsset);
	}
function _checkLiquidation ( address payable userAddr ) internal view returns (bool) {
		uint256 userBorrowAssetSum;
		uint256 liquidationLimitAssetSum;
		uint256 tokenListLength = marketManager.getTokenHandlersLength();
		for (uint256 handlerID = 0;
 handlerID < tokenListLength;
 handlerID++)		{
			if (marketManager.getTokenHandlerSupport(handlerID))			{
				uint256 depositAsset;
				uint256 borrowAsset;
				(depositAsset, borrowAsset) = marketManager.getUserIntraHandlerAssetWithInterest(userAddr, handlerID);
				uint256 marginCallLimit = marketManager.getTokenHandlerMarginCallLimit(handlerID);
				liquidationLimitAssetSum = add(liquidationLimitAssetSum, unifiedMul(depositAsset, marginCallLimit));
				userBorrowAssetSum = add(userBorrowAssetSum, borrowAsset);
			}
		}
		if (liquidationLimitAssetSum <= userBorrowAssetSum)		{
			return true;
		}
		return false;
	}

function partialLiquidation ( address payable delinquentBorrower , uint256 targetHandler , uint256 liquidateAmount , uint256 receiveHandler ) circuitBreaker external override returns (uint256) {
		address payable liquidator = msg.sender;
		LiquidationModel memory vars;
		if (_checkLiquidation(delinquentBorrower) == false)		{
			revert(NO_DELINQUENT);
		}
		(vars.liquidateAmount, vars.delinquentDepositAsset, vars.delinquentBorrowAsset) = marketManager.partialLiquidationUser(delinquentBorrower, liquidateAmount, liquidator, targetHandler, receiveHandler);
		vars.liquidatePrice = marketManager.getTokenHandlerPrice(targetHandler);
		vars.liquidateAsset = unifiedMul(vars.liquidateAmount, vars.liquidatePrice);
		vars.rewardAsset = unifiedDiv(unifiedMul(vars.liquidateAsset, vars.delinquentDepositAsset), vars.delinquentBorrowAsset);
		vars.receivePrice = marketManager.getTokenHandlerPrice(receiveHandler);
		vars.rewardAmount = unifiedDiv(vars.rewardAsset, vars.receivePrice);
		marketManager.partialLiquidationUserReward(delinquentBorrower, vars.rewardAmount, liquidator, receiveHandler);
    emit Liquidate(liquidator, delinquentBorrower, targetHandler, vars.liquidateAmount, receiveHandler, vars.rewardAmount);
    return vars.rewardAmount;
	}

function partialLiquidation ( address payable delinquentBorrower , uint256 targetHandler , uint256 liquidateAmount , uint256 receiveHandler ) circuitBreaker external override returns (uint256) {
		address payable liquidator = msg.sender;
		LiquidationModel memory vars;
		if (_checkLiquidation(delinquentBorrower) == false)		{
			revert(NO_DELINQUENT);
		}
		(vars.liquidateAmount, vars.delinquentDepositAsset, vars.delinquentBorrowAsset) = marketManager.partialLiquidationUser(delinquentBorrower, liquidateAmount, liquidator, targetHandler, receiveHandler);
		vars.liquidatePrice = marketManager.getTokenHandlerPrice(targetHandler);
		vars.liquidateAsset = unifiedMul(vars.liquidateAmount, vars.liquidatePrice);
		vars.rewardAsset = unifiedDiv(unifiedMul(vars.liquidateAsset, vars.delinquentDepositAsset), vars.delinquentBorrowAsset);
		vars.receivePrice = marketManager.getTokenHandlerPrice(receiveHandler);
		vars.rewardAmount = unifiedDiv(vars.rewardAsset, vars.receivePrice);
		marketManager.partialLiquidationUserReward(delinquentBorrower, vars.rewardAmount, liquidator, receiveHandler);
    emit Liquidate(liquidator, delinquentBorrower, targetHandler, vars.liquidateAmount, receiveHandler, vars.rewardAmount);
    return vars.rewardAmount;
	}

function partialLiquidation ( address payable delinquentBorrower , uint256 targetHandler , uint256 liquidateAmount , uint256 receiveHandler ) circuitBreaker external override returns (uint256) {
		address payable liquidator = msg.sender;
		LiquidationModel memory vars;
		if (_checkLiquidation(delinquentBorrower) == false)		{
			revert(NO_DELINQUENT);
		}
		(vars.liquidateAmount, vars.delinquentDepositAsset, vars.delinquentBorrowAsset) = marketManager.partialLiquidationUser(delinquentBorrower, liquidateAmount, liquidator, targetHandler, receiveHandler);
		vars.liquidatePrice = marketManager.getTokenHandlerPrice(targetHandler);
		vars.liquidateAsset = unifiedMul(vars.liquidateAmount, vars.liquidatePrice);
		vars.rewardAsset = unifiedDiv(unifiedMul(vars.liquidateAsset, vars.delinquentDepositAsset), vars.delinquentBorrowAsset);
		vars.receivePrice = marketManager.getTokenHandlerPrice(receiveHandler);
		vars.rewardAmount = unifiedDiv(vars.rewardAsset, vars.receivePrice);
		marketManager.partialLiquidationUserReward(delinquentBorrower, vars.rewardAmount, liquidator, receiveHandler);
    emit Liquidate(liquidator, delinquentBorrower, targetHandler, vars.liquidateAmount, receiveHandler, vars.rewardAmount);
    return vars.rewardAmount;
	}


function _getSIRandBIR ( uint256 depositTotalAmount , uint256 borrowTotalAmount ) internal view returns (uint256, uint256) {
		uint256 utilRate = _getUtilizationRate(depositTotalAmount, borrowTotalAmount);
		uint256 BIR;
		uint256 _jmpPoint = jumpPoint;
		if(utilRate < _jmpPoint) {
			BIR = utilRate.unifiedMul(basicSensitivity).add(minRate);
		}
 else {
      BIR = minRate      .add( _jmpPoint.unifiedMul(basicSensitivity) )      .add( utilRate.sub(_jmpPoint).unifiedMul(jumpSensitivity) );
		}
		uint256 SIR = utilRate.unifiedMul(BIR).unifiedMul(spreadRate);
		return (SIR, BIR);
	}

function _getSIRandBIRonBlock ( uint256 depositTotalAmount , uint256 borrowTotalAmount ) internal view returns (uint256, uint256) {
		uint256 SIR;
		uint256 BIR;
		(SIR, BIR) = _getSIRandBIR(depositTotalAmount, borrowTotalAmount);
		return ( SIR.div(blocksPerYear), BIR.div(blocksPerYear) );
	}

``` 

Input Specific Target Code Expression:
```Target Expression
if(utilization < optimal_utilization_level){
            annual_borrowing_interest_rate = minimum_interest_rate + utilization *  sensitivity_one;
        }
else if (utilization > optimal_utilization_level) {
            annual_borrowing_interest_rate = minimum_interest_rate + optimal_utilization_level *  sensitivity_one + (utilization - optimal_utilization_level) * sensitivity_two;
        }
        
```
Please identify the similar expressions, their corresponding function name and their corresponding line numbers in the code file. You also need to replace the function calls "add", "sub", "div", "mul", "divCeil" in the found similar expressions with "+", "-", "/" and "*".  Put your results in JSON format at the beginning.
"""

In [None]:
ans = chat_with_model(question)

In [None]:
ans = get_answer(question)

In [None]:
print(ans)

In [None]:
# Matches the '<|assistant|>' template used by the pipeline-based get_answer:
# keep only the text after the last '<|assistant|>' tag.
print(ans.split('<|assistant|>')[-1])

In [None]:
# Matches the '[INST]' template used by chat_with_model: keep only the text
# after the last "# [/INST]\n# " marker that ends that template.
print(ans.split("# [/INST]\n# ")[-1])

In [None]:
from transformers import pipeline


def get_answer(question):
  """Answer ``question`` through a text-generation pipeline.

  NOTE(review): this cell repeats the template-based ``get_answer`` defined
  earlier in the notebook; whichever definition ran last is the one in effect.

  Args:
    question: User query inserted into the
      ``<|system|>/<|user|>/<|assistant|>`` template.

  Returns:
    The ``generated_text`` of the first pipeline output.
  """
  # Use the device the model already sits on instead of a hard-coded GPU
  # index 0, which fails on CPU-only hosts.
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device)

  prompt_template = "<|system|>\n<|end|>\n<|user|>\n{query}<|end|>\n<|assistant|>"
  prompt = prompt_template.format(query=question)

  outputs = pipe(prompt, max_new_tokens=256, no_repeat_ngram_size=2)

  # Only the first candidate is used.
  return outputs[0]["generated_text"]

In [None]:
def get_answer(question):
  """Generate directly from ``question`` with the module-level model.

  Unlike the pipeline-based variants, no prompt template is applied and the
  decoded text keeps special tokens.

  Args:
    question: Prompt text, tokenized as-is.

  Returns:
    The decoded first sequence returned by ``model.generate``, where
    ``max_length=4608`` bounds the sequence length.
  """
  prompt_ids = tokenizer.encode(question, return_tensors="pt")
  generated = model.generate(prompt_ids.to(model.device), max_length=4608)
  return tokenizer.decode(generated[0])

In [None]:

# Progress ledger for the batch run: a CSV with 'filename,status' rows.
# Status 0 is the initial value; the batch loop sets it to 1 after an answer
# has been written and skips entries already marked 1.
csv_filename = 'expression_status_codellama.csv'

In [None]:
import csv
import os

# Show the directory the kernel is in before switching.
print(os.getcwd())

# NOTE(review): this is a relative chdir, so the cell is not idempotent.
# Running it a second time looks for all_datasets/expression_match inside the
# directory it already moved into. Every relative path used afterwards
# (base_path, csv_filename, the 'huggingface/hub/...' paths in load_plm)
# resolves against the new working directory.
os.chdir("all_datasets/expression_match")

In [None]:
def create_csv_with_files(base_dir, csv_filename):
    """Create the status CSV with one row per entry of ``base_dir``.

    The file gets a ``filename,status`` header followed by the directory
    entries in sorted order, each with status ``0``.

    Args:
        base_dir: Directory whose entries are listed.
        csv_filename: Path of the CSV to (over)write.

    Raises:
        OSError: If ``base_dir`` cannot be listed (for example
            ``FileNotFoundError``). Nothing is written in that case.
    """
    # List the directory before opening the output. Otherwise a bad base_dir
    # would leave a header-only CSV behind, and a caller that only creates
    # the CSV when it does not exist would never rebuild it.
    entries = sorted(os.listdir(base_dir))

    with open(csv_filename, 'w', newline='', encoding='utf-8') as csvfile:
        filewriter = csv.writer(csvfile)
        filewriter.writerow(['filename', 'status'])  # header row
        filewriter.writerows([entry, 0] for entry in entries)  # initial status 0

# Seed the status CSV only on the first run; an existing file is left alone
# so previously recorded progress is kept.
if not os.path.exists(csv_filename):
    create_csv_with_files(base_path, csv_filename)


In [None]:
import csv

def read_status_from_csv(path=None):
    """Read the status CSV into a ``{filename: status}`` dictionary.

    Args:
        path: CSV to read. Defaults to the module-level ``csv_filename``.

    Returns:
        dict mapping the first column of each data row to the second column
        converted to ``int``. An empty file yields an empty dict.

    Raises:
        ValueError: If a status cell is not an integer.
        IndexError: If a non-blank row has fewer than two columns.
    """
    if path is None:
        path = csv_filename

    status_dict = {}
    with open(path, 'r', newline='', encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        # Skip the header; the default keeps an empty file from raising
        # StopIteration.
        next(reader, None)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; nothing to record.
                continue
            status_dict[row[0]] = int(row[1])
    return status_dict


In [None]:
def update_csv_status(filename, new_status, path=None):
    """Set the status of ``filename`` in the status CSV.

    Every row whose first column equals ``filename`` gets ``new_status`` in
    its second column; all other rows are written back unchanged.

    Args:
        filename: Value to look for in the first column.
        new_status: Value stored in the second column of matching rows.
        path: CSV to update. Defaults to the module-level ``csv_filename``.
    """
    if path is None:
        path = csv_filename

    with open(path, 'r', newline='', encoding='utf-8') as file:
        rows = list(csv.reader(file))

    for row in rows:
        # Blank lines parse as [], so check the row is non-empty before
        # indexing into it.
        if row and row[0] == filename:
            row[1] = new_status

    # Write to a sibling temp file and swap it in. An interruption during
    # the write then leaves the existing ledger intact instead of truncated.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', newline='', encoding='utf-8') as file:
        csv.writer(file).writerows(rows)
    os.replace(tmp_path, path)


In [None]:
import glob

In [None]:
torch.cuda.empty_cache()

In [None]:
# Resume-aware batch run: for every entry under base_path, read its
# question*.md prompt, generate an answer with chat_with_model, write it to
# codellama.md in the same folder, and mark the entry as done in the status CSV.
file_status = read_status_from_csv()
files = sorted(os.listdir(base_path))
for f in files:
    question_dir = os.path.join(base_path, f)
    print(question_dir)

    # Pick the prompt file: name starts with 'question' and ends with '.md'.
    # sorted() makes the choice deterministic when several files match, and
    # the emptiness check keeps one folder without a prompt (or a stray
    # non-directory entry) from aborting the whole run.
    question_files = sorted(glob.glob(f'{question_dir}/question*.md'))
    if not question_files:
        print(f"no question*.md found in {question_dir}, skipping")
        continue
    question_file = question_files[0]
    answer_file = os.path.join(question_dir, "codellama.md")

    # Entries already marked 1 in the status CSV were answered in a previous run.
    if file_status.get(f, 0) == 1:
        print(f": {question_file}")
        continue

    # Best-effort per entry: report the failure and move on, so one bad
    # prompt does not stop the batch. The status stays 0 and the entry is
    # retried on the next run.
    try:
        with open(question_file, "r", encoding="utf-8") as file:
            question = file.read()

        answer = chat_with_model(question)
        # Keep only the text after the last template-tail marker.
        res = answer.split('# [/INST]\n# ')[-1]

        with open(answer_file, "w", encoding="utf-8") as file:
            file.write(res)

        # Record success only after the answer file has been written.
        update_csv_status(f, 1)
    except Exception as e:
        print(f"{question_dir}: {type(e).__name__}: {e}")

In [None]:
def extract_answer(text):
    """Return the contents of the first ```` ```json ```` block in ``text``.

    The block must start with ```` ```json ```` and be closed by
    ```` ```<|end|> ````.

    Args:
        text: Model output to search.

    Returns:
        The text between the two markers with surrounding whitespace
        stripped, or ``None`` when either marker is missing.
    """
    start_marker = "```json"
    end_marker = "```<|end|>"

    start_index = text.find(start_marker)
    if start_index == -1:
        return None

    # Start right after the marker. Skipping one extra character (for an
    # assumed newline) would drop the first payload character when the JSON
    # follows the fence directly; strip() already removes a leading newline.
    start_index += len(start_marker)
    end_index = text.find(end_marker, start_index)
    if end_index == -1:
        return None

    return text[start_index:end_index].strip()