In [53]:
import os , re , ast
from together import Together

## AST Utils

In [54]:
def get_functions_and_classes(filepath):
    """
    Read a Python source file and collect the source text of its classes and functions.

    Every ``class``, ``def`` and ``async def`` found anywhere in the file is
    collected, including methods and nested definitions, because the whole
    syntax tree is walked. Each snippet is made of complete source lines, so
    it keeps its original indentation, and it starts at the first decorator
    line when the definition is decorated.

    Args:
        filepath: Path of the Python file to read.

    Returns:
        A tuple ``(file_content, code_snippets)`` where ``file_content`` is the
        full text of the file and ``code_snippets`` is a dict with the keys
        ``"classes"`` and ``"functions"``, each mapping to a list of source strings.

    Raises:
        OSError: If the file cannot be opened.
        SyntaxError: If the file content is not valid Python.
    """
    with open(filepath, "r") as file:
        file_content = file.read()

    # Parse the file content into an Abstract Syntax Tree (AST)
    tree = ast.parse(file_content)

    # Split once up front instead of once per matching node.
    source_lines = file_content.splitlines(keepends=True)

    def _snippet(node):
        # node.lineno points at the `def`/`class` keyword line; decorators sit
        # above it, so start from the earliest decorator line when there is one.
        first_line = min([node.lineno] + [dec.lineno for dec in node.decorator_list])
        return "".join(source_lines[first_line - 1:node.end_lineno])

    code_snippets = {
        "classes": [],
        "functions": []
    }

    # Walk through the AST and find all functions and classes
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            code_snippets["classes"].append(_snippet(node))
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            code_snippets["functions"].append(_snippet(node))
    return file_content , code_snippets



# Read the whole `if __name__ == '__main__':` block (header line included) from a file

def read_main_block(filepath):
    """
    Return the source text of the ``if __name__ == "__main__":`` statement in a file.

    The returned text consists of complete source lines from the ``if`` header
    through the last line of the statement (``node.end_lineno``). Both operand
    orders of the equality test are recognised (``__name__ == "__main__"`` and
    ``"__main__" == __name__``). If several matching statements exist, the
    last one visited by ``ast.walk`` is returned.

    Args:
        filepath: Path of the Python file to read.

    Returns:
        The source text of the guard statement, or ``""`` when the file has none.

    Raises:
        OSError: If the file cannot be opened.
        SyntaxError: If the file content is not valid Python.
    """
    with open(filepath, "r") as file:
        file_content = file.read()
    tree = ast.parse(file_content)
    source_lines = file_content.splitlines(keepends=True)

    def _is_main_guard(test):
        # Only a single `==` comparison counts; `!=`, `is`, chained compares do not.
        if not isinstance(test, ast.Compare):
            return False
        if len(test.ops) != 1 or not isinstance(test.ops[0], ast.Eq):
            return False
        operands = (test.left, test.comparators[0])
        has_dunder_name = any(
            isinstance(operand, ast.Name) and operand.id == "__name__"
            for operand in operands
        )
        # String literals are ast.Constant nodes; the older ast.Str alias is
        # deprecated and no longer available on recent Python versions.
        has_main_literal = any(
            isinstance(operand, ast.Constant)
            and isinstance(operand.value, str)
            and operand.value == "__main__"
            for operand in operands
        )
        return has_dunder_name and has_main_literal

    main_block = ""
    for node in ast.walk(tree):
        if isinstance(node, ast.If) and _is_main_guard(node.test):
            main_block = "".join(source_lines[node.lineno - 1:node.end_lineno])
    return main_block


## CodeLlama Utils

In [1]:
def code_llama_prompt_formatter(query: str, system_prompt: str=None):
    """
    Wrap a user query and a system prompt in the [INST] / <<SYS>> prompt layout.

    The result has the form
    ``[INST]<<SYS>>\\n{system}\\n<</SYS>>\\n\\nUser: {query}[/INST]``.

    Args:
        query: The user's request; it is inserted after a ``"User: "`` prefix.
        system_prompt: Text for the system section. When ``None``, a built-in
            coding-assistant prompt is used instead.

    Returns:
        The assembled prompt string.
    """
    inst_open, inst_close = "[INST]", "[/INST]"
    sys_open, sys_close = "<<SYS>>\n", "\n<</SYS>>\n\n"

    # Built-in fallback; the embedded newline and indentation are part of the text.
    fallback_system_text = (
        "You are helpful coding assistant. User is asking you to write a function or class for a specific task. \n"
        "        Write the function or class in the programming language specified in the query."
    )
    system_text = fallback_system_text if system_prompt is None else system_prompt

    pieces = [inst_open, sys_open, system_text, sys_close, f"User: {query}", inst_close]
    return "".join(pieces)

In [78]:
# Create Code LLama Instruct Engine
class CodeLLamaInstructEngine:
    """
    Sends instruction prompts to a CodeLlama Instruct model through a ``Together`` client.

    The model id is ``codellama/CodeLlama-{paramcount}b-Instruct-hf`` and the
    client is built with the API key read from the ``TOGETHER_API_KEY``
    environment variable.
    """

    def __init__(self, systemPromptText: str=None, paramcount: int=7):
        """
        Args:
            systemPromptText: System prompt used for every query. When ``None``,
                a generic assistant prompt is used.
            paramcount: Model size selector; must be one of 7, 13 or 34.

        Raises:
            ValueError: If ``paramcount`` is not one of the supported sizes.
        """
        self.AVL_PARAMS = [7,13,34]
        is_supported = paramcount in self.AVL_PARAMS
        if not is_supported:
            raise ValueError(f"Invalid paramcount. Choose from {self.AVL_PARAMS}")

        self.paramcount = paramcount
        self.client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))
        self.model = f"codellama/CodeLlama-{self.paramcount}b-Instruct-hf"

        # A caller-supplied prompt wins; otherwise fall back to the generic one.
        if systemPromptText is not None:
            self.systemPromptText = systemPromptText
            return
        self.systemPromptText = """
            You are an AI assistant. You are helping a user with a task. The user is asking you to write a function or class for a specific task.
            """

    def run(self, query: str, extract_and_clean_code=None):
        """
        Format ``query`` with the stored system prompt and request a completion.

        Args:
            query: The user request to send.
            extract_and_clean_code: Optional callable applied to the completion
                text before it is returned.

        Returns:
            The text of the first completion choice, post-processed by
            ``extract_and_clean_code`` when one is given.
        """
        prompt_text = code_llama_prompt_formatter(
            query=query,
            system_prompt=self.systemPromptText,
        )
        completion = self.client.completions.create(model=self.model, prompt=prompt_text)
        raw_text = completion.choices[0].text

        if extract_and_clean_code is None:
            return raw_text
        return extract_and_clean_code(raw_text)

## Testing

In [88]:
# Try both AST helpers on local sample files.
# NOTE(review): these are absolute, machine-specific paths, so this cell only
# runs where those files exist — consider a configurable base directory.
_main_block = read_main_block("/Users/debasmitroy/Desktop/medtech/med-tech-codes-git/backend0/flask_server/log_emb.py")
# Only the file text is kept; the snippets dict is discarded into `_`.
_content,_ = get_functions_and_classes("/Users/debasmitroy/Desktop/medtech/med-tech-codes-git/backend0/deprecated/old_target_projs/prog1.py")

In [84]:
class MainFunctionCurator:
    """
    Rewrites the body of a file's ``if __name__ == '__main__':`` block with model output.

    The block is sent to a ``CodeLLamaInstructEngine`` (paramcount 13) together
    with instructions to handle user input through ``argparse``; the code found
    in the reply replaces the original block body in the file text.
    """

    def __init__(self):
        """Create the underlying engine with the argparse-rewrite system prompt."""
        self.code_llama_engine = CodeLLamaInstructEngine(
            systemPromptText="""
            You are a helpful coding assistant. 
            You are given a not so well written code snippet. The user inputs are not well managed.
            You have to rewrite handling the user inputs using python's argparse library inside the __main__ block.
            Carefully read the code snippet and understand each possible user input and their types.
            JUST WRITE THE CODE INSIDE THE __main__ BLOCK. DO NOT WRITE ANY THING ELSE. ASSUME THAT ALL THE FUNCTIONS USED IN THE CODE ARE DEFINED SOMEWHERE ELSE.
            MAKE SURE YOU COMPLETE THE CODE INSIDE THE __main__ BLOCK.
            """,paramcount=13)
        
    
    def generate_query_prompt(self, query):
        """
        Build the user message that embeds ``query`` (the code to rewrite) in a fenced block.

        Args:
            query: Source text of the code snippet to be rewritten.

        Returns:
            The full user prompt string.
        """
        return f"""
        The original code snippet is given below:
        ```
        {query}
        ```

        You have to rewrite the code snippet handling the user inputs using python's argparse library inside the __main__ block. Follow the instructions given in the system prompt.
        JUST WRITE THE CODE INSIDE THE __main__ BLOCK. DO NOT WRITE ANY THING ELSE. ASSUME THAT ALL THE FUNCTIONS USED IN THE CODE ARE DEFINED SOMEWHERE ELSE.
        MAKE SURE YOU COMPLETE THE CODE INSIDE THE __main__ BLOCK.
        """
    
    @staticmethod
    def extract_and_clean_code(s):
        """
        Pull the code out of a model reply and tidy it.

        The first fenced (triple-backtick) block is used. A language tag on the
        opening fence line (e.g. ``python``) is not treated as code. If the
        fence is never closed, everything after the opening fence is used; if
        the reply has no fence at all, the whole reply is used. Trailing
        whitespace is stripped from each line and any line containing
        ``__main__`` is dropped, so a repeated guard header does not end up
        inside the block body.

        Args:
            s: Raw text returned by the model.

        Returns:
            The cleaned code as a single newline-joined string.
        """
        # Optional "<lang>\n" right after the opening fence, then the shortest
        # run of text up to the closing fence or the end of the reply.
        fenced = re.search(r'```(?:[\w+#.-]*[ \t]*\n)?(.*?)(?:```|\Z)', s, re.DOTALL)
        code = fenced.group(1) if fenced else s

        kept_lines = [
            line.rstrip()
            for line in code.splitlines()
            if '__main__' not in line
        ]
        return '\n'.join(kept_lines)
        
    def curate_main_block(self, filepath):
        """
        Return the text of ``filepath`` with its ``__main__`` block body rewritten by the model.

        Args:
            filepath: Path of the Python file to process.

        Returns:
            The file text with the body of the ``if __name__ == '__main__':``
            block replaced. The text is returned unchanged (and the model is
            not called) when the file has no such block or the block has no
            body on separate lines, e.g. a one-line guard.
        """
        content,_ = get_functions_and_classes(filepath)
        main_block = read_main_block(filepath)

        # First line is the `if ...:` header; the rest is the body to replace.
        header, newline, main_block_body = main_block.partition('\n')
        if not main_block_body.strip():
            # Nothing to replace. Replacing an empty string would insert the
            # reply between every character of the file.
            return content

        response_text = self.code_llama_engine.run(query=self.generate_query_prompt(main_block), extract_and_clean_code=MainFunctionCurator.extract_and_clean_code)

        # Keep the line break that ended the old body so following code (or
        # the end of the file) is not glued onto the last generated line.
        if main_block_body.endswith('\n') and not response_text.endswith('\n'):
            response_text += '\n'

        # Replace the original main block with the curated main block, anchored
        # on the whole block so only the guard's own body is touched.
        curated_block = header + newline + response_text
        return content.replace(main_block, curated_block, 1)

In [85]:
# Build the curator; its constructor creates a CodeLLamaInstructEngine with paramcount=13.
main_function_curator = MainFunctionCurator()

In [86]:
# Rewrite the __main__ block of prog3.py and keep the resulting file text.
# NOTE(review): absolute, machine-specific path — only resolves where this file exists.
_content = main_function_curator.curate_main_block("/Users/debasmitroy/Desktop/medtech/med-tech-codes-git/backend0/deprecated/old_target_projs/prog3.py")

In [87]:
# Show the curated file text; the captured output follows this cell.
print(_content)

import time
import os
import argparse

# CPU Heavy Task
def cpu_bound(p):
    print("Starting CPU bound task")
    x = 0
    for i in range(10**p):
        x += i
        if i % 10**p == 0:
            print(f"CPU bound task: {i}")
            time.sleep(1)
    print("Exiting CPU bound task")
    return x

# Memory Heavy Task
def memory_bound(p):
    print("Starting memory bound task")
    x = []
    for i in range(10**p):
        x.append(i)
        if i % 10**(p-1) == 0:
            print(f"Memory bound task: {i}")
            time.sleep(1)
    print("Exiting memory bound task")
    return x

# I/O Heavy Task
def io_bound(p):
    # Generate a large file
    print("Starting I/O bound task")
    with open("large_file.txt", "w") as f:
        for i in range(10**p):
            f.write("Hello world!\n")
            if i % 10**(p-1) == 0:
                print(f"I/O bound task: {i}")
                time.sleep(1)

    # Now read the file
    with open("large_file.txt", "r") as f:
        