In [None]:
import os
import subprocess
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import sys
import utils

In [None]:
# Pick the compute device first: Apple's Metal Performance Shaders (MPS) backend
# when it is available, otherwise fall back to the CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# device = torch.device("cpu")  # uncomment to force CPU execution

# Release cached MPS memory left over from earlier runs of this cell.
# Only done when MPS is actually in use: calling torch.mps.empty_cache() on a
# machine without an MPS backend fails, which would defeat the CPU fallback above.
if device.type == "mps":
    torch.mps.empty_cache()

# Load tokenizer and model from the Hugging Face Hub
model_reference = "LLM4Binary/llm4decompile-1.3b-v2"
# model_reference = "LLM4Binary/llm4decompile-6.7b-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_reference)
model = AutoModelForCausalLM.from_pretrained(model_reference)

# Convert model to FP16 to reduce memory usage, then move it to the selected device
# NOTE(review): FP16 is also applied on the CPU fallback; confirm the installed
# PyTorch version supports half-precision inference on CPU for this model.
model = model.half().to(device)

In [None]:
def decompile_with_llm4decompile(disassembled_code, max_new_tokens=200):
    """Run the loaded LLM4Decompile model on a prompt and return the decoded text.

    Relies on the notebook-level ``tokenizer``, ``model`` and ``device``.
    As a side effect, the decoded text is written to ``decompiled_code.c``
    inside ``utils.WORKSPACE_DIR``, overwriting the file from any previous call.

    Args:
        disassembled_code: The complete prompt text to feed to the model.
        max_new_tokens: Upper bound on the number of tokens the model may
            generate. Defaults to 200, the value previously hard-coded here.

    Returns:
        The first generated sequence decoded with special tokens skipped.
        The caller in this notebook treats the result as the prompt followed
        by the generated continuation and strips the prompt itself.
    """
    # Tokenize the prompt; the returned batch carries input_ids and attention_mask.
    inputs = tokenizer(disassembled_code, return_tensors="pt", padding=True).to(device)

    # Print the number of tokens in the input
    print(f"Number of tokens in input: {inputs.input_ids.size(1)}")

    # Pass pad_token_id (set to the EOS id) straight to generate() instead of
    # mutating the shared model.config on every call: the setting stays local to
    # this request and does not depend on how generate() consults the model config.
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=model.config.eos_token_id,
    )
    decompiled_code = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Save the decoded text for inspection. The encoding is explicit so that
    # non-ASCII characters in the output do not depend on the platform default.
    decompiled_path = os.path.join(utils.WORKSPACE_DIR, "decompiled_code.c")
    with open(decompiled_path, "w", encoding="utf-8") as f:
        f.write(decompiled_code)

    return decompiled_code

In [None]:
# Question line that terminates the prompt; the model's answer follows it.
_PROMPT_QUESTION = "# What is the source code?"


def _strip_prompt(generated_text, prompt):
    """Return ``generated_text`` with the echoed prompt removed and whitespace trimmed."""
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].strip()
    # The decoded text is not guaranteed to reproduce the prompt character for
    # character, so cutting by len(prompt) could land in the wrong place.
    # Cut after the first question line (the one closing the prompt) instead.
    _, marker, completion = generated_text.partition(_PROMPT_QUESTION)
    if marker:
        return completion.strip()
    # No recognisable prompt echo at all: keep everything the model produced.
    return generated_text.strip()


def process_file(input_path, function_name):
    """Disassemble one function and print the model's decompilation of it.

    The disassembly is obtained from ``utils.disassemble``, wrapped in the
    assembly/question prompt, and passed to ``decompile_with_llm4decompile``.
    The disassembly and the decompiled code (with the echoed prompt removed)
    are printed; nothing is returned.

    Args:
        input_path: Path handed to ``utils.disassemble``.
        function_name: Name of the function to disassemble.

    A ``subprocess.CalledProcessError`` raised during processing is caught and
    reported with a printed message; other exceptions propagate.
    """
    try:
        disassembled_code = utils.disassemble(input_path, function_name)

        print(disassembled_code)

        prompt = f"# This is the assembly code:\n{disassembled_code}\n{_PROMPT_QUESTION}\n"

        # Decompile with LLM4Decompile
        print("\nDecompiling...")
        generated_text = decompile_with_llm4decompile(prompt)

        # Remove the echoed input from the generated text
        decompiled_code = _strip_prompt(generated_text, prompt)

        print("\nDecompiled Code:")
        print(decompiled_code)

    except subprocess.CalledProcessError as e:
        print(f"Error during processing: {e}")

In [None]:
# Example run: disassemble and decompile function `func0` from the C file below.
c_code_path, function_name = "c_code/fibonacci.c", "func0"
process_file(input_path=c_code_path, function_name=function_name)