In [None]:
# !pip install flask huggingface_hub

In [None]:
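# Install a CPU-only build of llama-cpp-python (OpenBLAS) -- inference will likely be slow.
# NOTE: these CMAKE flag names match older llama-cpp-python releases; newer releases
# renamed them (e.g. -DGGML_BLAS=ON, -DGGML_CUDA=on, -DGGML_METAL=on).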
# !CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp-python # cpu only

# Install llama-cpp-python with CUDA (cuBLAS) GPU support:
# !CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python # cuda gpu

# Install with macOS (Metal) support
# !CMAKE_ARGS="-DLLAMA_METAL=on" pip install -U llama-cpp-python --no-cache-dir 
# !pip install 'llama-cpp-python[server]'

In [None]:
import os
import threading
from datetime import datetime
from flask import Flask, jsonify, request
import time
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
import multiprocessing
import re

# Number of CPUs reported by multiprocessing.cpu_count(); used below as llama.cpp's n_threads.
num_cpu = multiprocessing.cpu_count()
print(num_cpu)

In [None]:
def log(string, path="gendocstring.log"):
    """Append a timestamped entry to the gendocstring log file.

    Each call writes two parts: the current local time formatted as
    ``dd/mm/YYYY HH:MM:SS`` on its own line, then ``string``. A trailing
    newline is appended to ``string`` when it lacks one, so the next entry's
    timestamp always starts on a fresh line.

    Args:
        string: Text to log. An iterable of strings is also accepted and is
            concatenated, as the previous ``writelines`` call allowed.
        path: File to append to. Defaults to ``gendocstring.log`` in the
            current working directory.
    """
    dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")

    # "".join is a no-op for a str and also flattens a list of strings.
    text = "".join(string)
    if not text.endswith("\n"):
        text += "\n"

    # Single open, closed by the context manager even if a write fails;
    # explicit UTF-8 so non-ASCII source code never fails to encode.
    with open(path, "a", encoding="utf-8") as f:
        f.write(dt_string + "\n")
        f.write(text)

In [None]:
def download_weights(
    repo_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
    filename="mistral-7b-instruct-v0.1.Q6_K.gguf",
    weights_dir="weights",
):
    """Make sure a GGUF weight file from the Hugging Face Hub exists locally.

    Args:
        repo_id: Hub repository to download from.
        filename: Name of the file inside ``repo_id``.
        weights_dir: Local directory holding the weights; created if missing.

    Returns:
        str: Path to the local weight file. It is returned both when the file
        was already present and when it was just downloaded.
    """
    weights_dir = Path(weights_dir)
    weights_dir.mkdir(parents=True, exist_ok=True)
    local_path = weights_dir / filename

    if local_path.exists():
        print(
            f"{filename} exists. Delete manually if you wish to download again."
            )
        return str(local_path)

    # local_dir_use_symlinks=False stores a real file instead of a symlink into
    # the HF cache (newer huggingface_hub releases deprecate/ignore this flag).
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=str(weights_dir),
        local_dir_use_symlinks=False,
        )

# Fetch the Mistral-7B-Instruct GGUF weights into ./weights/ (skipped when the file already exists).
download_weights()

In [None]:
# Local GGUF file fetched by download_weights(); the instruction-tuned model.
model = "weights/mistral-7b-instruct-v0.1.Q6_K.gguf"

# llama.cpp settings: 8192-token context, prompt batches of 128 tokens,
# num_cpu worker threads, n_gpu_layers=-1 (offload all layers), fixed seed.
llama_settings = {
    "n_ctx": 8192,
    "n_batch": 128,
    "n_threads": num_cpu,
    "n_gpu_layers": -1,
    "verbose": True,
    "seed": 42,
}
llm = Llama(model_path=model, **llama_settings)

In [None]:
def get_docstring(code, model=None):
    """Ask the LLM for a docstring for ``code`` and return only that docstring.

    Args:
        code: Source text of a single Python function.
        model: Callable with the ``llama_cpp.Llama`` completion interface,
            i.e. ``model(prompt, **kwargs)`` returning
            ``{"choices": [{"text": ...}]}``. Defaults to the notebook's ``llm``.

    Returns:
        The contents of the first triple-quoted string (single or double
        quotes) in the generated text, or ``None`` if the model produced none.
    """
    if model is None:
        model = llm

    instruction = (
        "Produce a docstring for the following python function. "
        "Return the full function definition with the docstring."
    )
    # Mistral-Instruct format: the whole request goes inside [INST] ... [/INST].
    # llama-cpp-python adds the BOS token when tokenizing, so no literal "<s>".
    message = f"[INST] {instruction}\n\n{code} [/INST]"

    # echo=False: search only the generated text, so a triple-quoted string
    # already present in the input code cannot be returned by mistake.
    output = model(message, echo=False, stream=False, max_tokens=4096)
    text = output['choices'][0]['text']
    print(text)

    # \1 closes the string with the same quote style that opened it, and
    # group(2) holds the contents for either quote style.
    match = re.search(r"('''|\"\"\")(.*?)\1", text, re.DOTALL)
    return match.group(2) if match else None

In [None]:
def get_docstring_from_template(code, template, model=None):
    """Ask the LLM for a docstring for ``code`` that follows ``template``.

    Args:
        code: Source text of a single Python function.
        template: Docstring skeleton the model should fill in.
        model: Callable with the ``llama_cpp.Llama`` completion interface,
            i.e. ``model(prompt, **kwargs)`` returning
            ``{"choices": [{"text": ...}]}``. Defaults to the notebook's ``llm``.

    Returns:
        The contents of the first triple-quoted string (single or double
        quotes) in the generated text, or ``None`` if the model produced none.
    """
    if model is None:
        model = llm

    instruction = (
        "According to the template, produce a docstring for the "
        "following python function. "
        "Return the full function definition with the docstring."
    )
    # Mistral-Instruct format: instruction, template and code all go inside
    # [INST] ... [/INST]; llama-cpp-python adds the BOS token itself.
    message = (
        f"[INST] {instruction}\n\n"
        f"Docstring template:\n{template}\n\n"
        f"Function:\n{code} [/INST]"
    )

    # echo=False: search only the generated text, not the echoed prompt.
    output = model(message, echo=False, stream=False, max_tokens=4096)
    text = output['choices'][0]['text']
    print(text)

    # \1 closes the string with the same quote style that opened it.
    match = re.search(r"('''|\"\"\")(.*?)\1", text, re.DOTALL)
    return match.group(2) if match else None

In [None]:
# Demo input: an undocumented function (only sent to the model, never executed).
code = """
def fibonacci_of(n):
    if n in cache:
        return cache[n]
    cache[n] = fibonacci_of(n - 1) + fibonacci_of(n - 2)
    return cache[n]
"""

# Google-style docstring skeleton the model is asked to fill in.
template = """
    [Summary of the function fibonacci_of]

    Args:
        n ([type]): [description]

    Returns:
        [type]: [description]
    """

# Template-free alternative for comparison: get_docstring(code)
docstr_template = get_docstring_from_template(code=code, template=template)

In [None]:
# Show the extracted docstring (prints "None" when the model output contained no triple-quoted string).
print(docstr_template)

In [None]:
# Port for the development server, kept in string form.
port = "5000"
# Flask application that hosts the docstring-generation routes.
app = Flask(__name__)

# Flask routes
@app.route("/")
def index():
    """Liveness check: answer GET / with a fixed greeting string."""
    greeting = "Hello from gendocstring server."
    return greeting

# Flask's development server handles requests on multiple threads by default,
# while one shared llama_cpp.Llama instance should only run one generation at
# a time; this lock serializes access to the model.
_llm_lock = threading.Lock()


@app.route("/summary", methods=["POST"])
def summary():
    """Generate a docstring for a function posted as JSON.

    Expects a JSON object with two string fields:
        code: source of the function to document.
        snippet: docstring template. Every triple double-quote immediately
            followed by a newline is stripped before it is sent to the model.

    Returns:
        200 with ``message`` (one-element list holding the docstring, with one
        leading newline removed and trailing whitespace normalized to a single
        newline), ``time`` (seconds spent), ``device`` and ``length``.
        400 if the body is not a JSON object with string ``code``/``snippet``.
        500 if the model output contained no docstring.
    """
    # silent=True: malformed or non-JSON bodies yield None instead of raising,
    # so we can answer with a JSON error below.
    payload = request.get_json(silent=True)
    if not (
        isinstance(payload, dict)
        and isinstance(payload.get("code"), str)
        and isinstance(payload.get("snippet"), str)
    ):
        return jsonify(
            error="Expected a JSON object with string 'code' and 'snippet' fields."
        ), 400

    t0 = time.time()

    code = payload["code"]
    snippet = payload["snippet"]
    template = snippet.replace('"""\n', '')

    log("code:")
    log(code)
    log("snippet:")
    log(snippet)

    with _llm_lock:
        docstring = get_docstring_from_template(code, template)

    # The extractor returns None when no triple-quoted string was generated.
    if docstring is None:
        log("no docstring found in model output")
        return jsonify(error="The model did not return a docstring."), 500

    if docstring.startswith('\n'):
        docstring = docstring[1:]
    docstring = docstring.rstrip() + '\n'

    log(docstring)

    t1 = time.time()
    result = {
        'message' : [docstring],
        'time' : (t1 - t0),
        'device' : "computer",
        'length' : len(docstring)
    }

    return jsonify(**result)

# Start the Flask development server in a background thread so the notebook
# kernel stays usable. The reloader stays off because it needs the main thread;
# `port` from the app cell selects the listening port.
threading.Thread(
    target=app.run,
    kwargs={"port": int(port), "use_reloader": False},
).start()