### SEED GATHERING GET CONTENT

In [2]:
from tree_sitter_parser import LANGUAGE, make_parser, node_to_string
import datasets
import os
import signal
from multiprocessing import Pool
import boto3
import smart_open
from datasets import load_dataset,Dataset
from botocore import UNSIGNED
from botocore.config import Config

s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
def download_contents(blob_id, src_encoding):
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": s3}) as fin:
        content = fin.read().decode(src_encoding)
    return content

In [3]:
# TOPLEVEL_DOCSTRING_QUERY = LANGUAGE.query("""
# (
#     (function_definition
#       name: (identifier)
#       body: (block .
#         (expression_statement
#             (string
#                 (string_start) @docstring.start
#                 (string_content)
#                 (string_end) @docstring.end)))) @function.def
#     (#eq? @docstring.start "\\\"\\\"\\\"")
#     (#eq? @docstring.end "\\\"\\\"\\\"")
# )
# """)

TOPLEVEL_DOCSTRING_QUERY_RUBY = LANGUAGE.query("""
(
  [
    (method
      name: (_) @name) @definition.method
  ]
  (#strip! name "^#\\s*")
  (#select-adjacent! name @definition.method)
)
""")


def get_fns_with_docstrings(src, tree):
    captures = TOPLEVEL_DOCSTRING_QUERY_RUBY.captures(tree.root_node)
    res = []
    for capture in captures:
        node, ty = capture
        if ty != "definition.method":
            continue
        # if the starting col is not 0, then it's not a top-level fn
        _, col = node.start_point
        if col != 0:
            continue
        res.append(node_to_string(src, node))
    return res


def parse_ex(parser, ex):
    #ex = ex["content"]
    ex = download_contents(ex["blob_id"], ex["src_encoding"])
    try:
        buf = bytes(ex, "utf8")
        tree = parser.parse(buf)
        return get_fns_with_docstrings(buf, tree)
    except:
        return []


# if one parser segfaults, we can just make a new one and other parsers will still be fine
# WE LOVE TREE SITTER!
PARSERS = None


def process_chunk(idx_and_chunk):
    assert PARSERS is not None
    idx, chunk = idx_and_chunk
    parser = PARSERS[idx]
    chunk_new_funs = set()
    
    for ex in chunk:
        chunk_new_funs.update(parse_ex(parser, ex))
        break
    return chunk_new_funs


def main(args):
    global PARSERS
    ds = datasets.load_dataset(
        args.dataset,
        data_dir=args.data_dir,
        split="train",
    )
    funs = set()
    PARSERS = [make_parser() for _ in range(args.num_workers)]
    total_len = len(ds)
    CHUNK_SIZE = 1000 * args.num_workers

    print(f"Total length: {total_len}")
    print(f"Chunk size: {CHUNK_SIZE}")

    chunk = []
    p = Pool(args.num_workers)
    for i, ex in enumerate(ds):
        if i % (total_len // 100) == 0:
            print(f"{i}/{total_len}")
        try:
            chunk.append(ex)
            if len(chunk) == CHUNK_SIZE or i == total_len - 1:
                print(f"Processing chunk {i // CHUNK_SIZE}")
                # divide the chunk into NUM_WORKERS chunks
                subchunk_size = len(chunk) // args.num_workers
                subchunks = [chunk[i:i + subchunk_size]
                             for i in range(0, len(chunk), subchunk_size)]
                new_funs_iter = p.imap(
                    process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)])
                print(new_funs_iter)
                print("Getting new functions")
                len_before = len(funs)
                while True:
                    try:
                        def timeout_handler(_, __):
                            raise KeyboardInterrupt  # it's fineeeeeee
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(60)
                        funs.update(next(new_funs_iter))
                        signal.alarm(0)
                    except KeyboardInterrupt:
                        signal.alarm(0)
                        print("Keyboard interrupt. Terminating pool")
                        p.terminate()
                        p = Pool(args.num_workers)
                        break
                    except StopIteration:
                        break
                    except Exception as e:
                        print(e)

                signal.alarm(0)

                PARSERS = [make_parser() for _ in range(args.num_workers)]

                print(
                    f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions")

                chunk = []
        except Exception as e:
            print(e)
            chunk = []

        if i == total_len - 1:
            break

    p.close()

    new_ds_dict = {
        "content": list(funs),
        "id": list(range(len(funs)))
    }

    new_ds = datasets.Dataset.from_dict(new_ds_dict)
    #new_ds.push_to_hub(args.push, private=True)




In [5]:
NUMWORKERS = os.cpu_count()
print(NUMWORKERS)

128


In [7]:
ds = load_dataset("bigcode/the-stack-v2-dedup", "Ruby", cache_dir = f"../thai/stack", streaming=True, split="train")

Resolving data files:   0%|          | 0/757 [00:00<?, ?it/s]

In [8]:
funs = set()
PARSERS = [make_parser() for _ in range(NUMWORKERS)]
CHUNK_SIZE = 1000 * NUMWORKERS

print(f"Chunk size: {CHUNK_SIZE}")

chunk = []
p = Pool(NUMWORKERS)

Chunk size: 128000


In [9]:
for i, ex in enumerate(iter(ds)):
    try:
        chunk.append(ex)
        if len(chunk) == 1000:
            print(f"Processing chunk {i // CHUNK_SIZE}")
            # divide the chunk into NUM_WORKERS chunks
            subchunk_size = len(chunk) // NUMWORKERS
            subchunks = [chunk[i:i + subchunk_size]
                         for i in range(0, len(chunk), subchunk_size)]
            new_funs_iter = p.imap(process_chunk, [(j, subchunk) for j, subchunk in enumerate(subchunks)])
            print("Getting new functions")
            len_before = len(funs)
            while True:
                try:
                    def timeout_handler(_, __):
                        raise KeyboardInterrupt  # it's fineeeeeee
                    signal.signal(signal.SIGALRM, timeout_handler)
                    signal.alarm(60)
                    funs.update(next(new_funs_iter))
                    signal.alarm(0)
                except KeyboardInterrupt:
                    signal.alarm(0)
                    print("Keyboard interrupt. Terminating pool")
                    p.terminate()
                    p = Pool(NUMWORKERS)
                    break
                except StopIteration:
                    break
                except Exception as e:
                    print(e)

            signal.alarm(0)

            PARSERS = [make_parser() for _ in range(NUMWORKERS)]

            print(
                f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions")

            chunk = []
    except Exception as e:
        print(e)
        chunk = []

    if i >= 1000:
        break


p.close()
new_ds_dict = {
    "content": list(funs),
    "id": list(range(len(funs)))
}

new_ds = datasets.Dataset.from_dict(new_ds_dict)

Processing chunk 0
Getting new functions
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
Done processing chunk 0. Got 11 new functions


In [10]:
new_ds

Dataset({
    features: ['content', 'id'],
    num_rows: 11
})

In [11]:
for ex in new_ds:
    print(ex['content'])
    break

def update()
  sql = "UPDATE films SET (title, price) = ($1, $2) WHERE id = $3"
  values = [@title, @price]
  SqlRunner.run(sql, values)
end


In [12]:
ds = new_ds

### SEED GATHERING HIGH-QUALITY SUBSET

In [13]:
import subprocess
import tempfile
import signal
import hashlib
import os
import argparse
from typing import List, Dict
from tqdm import tqdm
from tree_sitter_parser import LANGUAGE, global_parser

# RETURN_QUERY = LANGUAGE.query("""
# (return) @return
# """)

# def does_have_return(src):
#     tree = global_parser.parse(bytes(src, "utf8"))
#     root = tree.root_node
#     captures = RETURN_QUERY.captures(root)
#     for node, _ in captures:
#         # if it doesn't have an argument, it's not a return with a value
#         if len(node.children) <= 1:  # includes "return" itself
#             continue
#         else:
#             return True
#     return False

# runs sorbet in the given directory for typecheck, returns stdout
# then, it logs the number of errors for each file
def run_typeprof(d):
    try:
        outs = subprocess.run(
            ['bundle', 'exec', 'rake', './typeprof/test'],  # Command to run
            #cwd=d,               # Current working directory
            capture_output=True, # Capture stdout and stderr
            timeout=120,         # Timeout after 120 seconds
            text=True            # Capture output as text
        ).stdout
    except Exception as e:
        print(e)
        return None
    filemap = {}
    lines = outs.split("\n")
    for line in lines:
        if line.strip():
            parts = line.split(":")
            if len(parts) >= 2:
                file = parts[0].split("/")[-1]
                if file not in filemap:
                    filemap[file] = 0
                if "error:" in line:
                    filemap[file] += 1

    return filemap

def typecheck_batch(files: List[str]) -> Dict[str, str]:
    # Create a temporary directory using the tempfile module
    filemap: Dict[str, str] = {}
    with tempfile.TemporaryDirectory() as tempdir:
        for contents in files:
            hash_object = hashlib.sha1(bytes(contents, "utf8"))
            hex_dig = hash_object.hexdigest()
            filemap[hex_dig] = contents
            name = os.path.join(tempdir, hex_dig + ".rb")
            with open(name, "w") as f:
                f.write(contents)

        # Run typeprof in the temporary directory
        typecheck_map = run_typeprof(tempdir)
        print(typecheck_map)

        if typecheck_map is None:
            return {}

        for contents, errors in typecheck_map.items():
            no_py = contents.replace(".rb", "")
            if errors == 0:
                continue
            if no_py in filemap:
                del filemap[no_py]

        print(f"Pass rate: {len(filemap)}/{len(files)}")
        return filemap

def infer_imports(code: str) -> str:
    import autoimport
    try:
        def handler(signum, frame):
            raise Exception("Timeout")
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(10)
        inferred = autoimport.fix_code(code)
        signal.alarm(0)
        return inferred
    except Exception as e:
        signal.alarm(0)
        print(f"Error while inferring imports: {e}")
        return code

In [11]:
# print("Filtering to only functions with return statements")
# ds = ds.filter(lambda ex: does_have_return(
#     ex["content"]), num_proc=os.cpu_count())


In [14]:
ds

Dataset({
    features: ['content', 'id'],
    num_rows: 11
})

In [15]:
# if args.infer_imports:
#     print("Inferring imports for functions")
#     ds = ds.map(lambda ex: {"content": infer_imports(
#         ex["content"])}, num_proc=os.cpu_count())

batch = []
max_i = len(ds) - 1

new_ds = {
    "content": [],
    "sha1": [],
    "id": [],
}

e_id = 0

for i, ex in enumerate(tqdm(ds, total=len(ds))):
    try:
        code = ex["content"]

        batch.append(code)

        if len(batch) == 250 or i == max_i:
            filemap = typecheck_batch(batch)
            for sha1, contents in filemap.items():
                new_ds["content"].append(contents)
                new_ds["sha1"].append(sha1)
                new_ds["id"].append(e_id)
                e_id += 1
            batch = []
            
    except Exception as e:
        print(f"There was an error: {e}")
        continue

new_ds_hf = datasets.Dataset.from_dict(new_ds)

100%|██████████| 11/11 [00:00<00:00, 3512.28it/s]

[Errno 2] No such file or directory: 'bundle'
None





In [16]:
print(new_ds_hf['content'][0])

IndexError: list index out of range

In [18]:
save_dir = "./datasets/seed2"

In [19]:
ds.save_to_disk(save_dir)

Saving the dataset (0/1 shards):   0%|          | 0/11 [00:00<?, ? examples/s]

In [1]:
from datasets import Dataset

ds = Dataset.from_file("./datasets/seed2/data-00000-of-00001.arrow")

### SEED GATHERING FILTER DATASET

In [20]:
import torch
print(torch.__version__)

2.4.0+cu121


In [21]:
import datasets
import os
from tree_sitter_parser import global_parser, LANGUAGE, does_have_return, make_parser
import benchmark_data
from tqdm import tqdm
import torch
import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM, AutoTokenizer
import random

ModuleNotFoundError: No module named 'benchmark_data'

In [4]:
FN_BLOCK_QUERY = LANGUAGE.query("""
(
[
    (begin_block) @definition.begin_block
    (end_block) @definition.end_block
]
)
""")

def template_few_shot_ruby(code, answer, rationale):
    doc, code = ruby_extract_docstring(code)
    assert answer == "No" or answer == "Yes"
    prompt = f"""<issue_start>username_0: I have a function in Ruby and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```rb
{code}
```

Here is my description of this program:
```
{doc}
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is: {answer}

{rationale}

Upvotes: 200"""
    return prompt


FEW_SHOTS_RUBY = [
    (
        '''def greet_user(name)
  # Greet the user with a friendly message
  "Hello, #{name}!"
end''',
        "Yes",
        "The docstring accurately describes the function. The `greet_user` method simply returns a greeting string using the provided `name`."
    ),
    (
        '''def calculate_area(radius)
  # Calculate the circumference of a circle given its radius
  Math::PI * radius**2
end''',
        "No",
        "The description states that it calculates the circumference, but the method actually calculates the area of the circle based on the radius."
    ),
    (
        '''def reverse_string(str)
  # Reverse the characters in a string
  str.reverse
end''',
        "Yes",
        "The docstring accurately describes the function's behavior. It reverses the characters in the provided string using Ruby's `reverse` method."
    ),
    (
        '''def calculate_total(price, tax_rate)
  # Adds tax to the price to get the total amount
  price + (price * tax_rate)
end''',
        "Yes",
        "The docstring is clear and accurately describes the method's purpose of calculating the total price by adding tax."
    ),
    (
        '''def print_numbers(n)
  # Print numbers from 1 to n, inclusive
  (1..n).each { |num| puts num }
end''',
        "Yes",
        "The docstring provides an accurate description of the method's function, which is to print numbers from 1 to `n`."
    ),
    (
        '''def process_data(data)
  # Process the data by removing duplicates and sorting it
  data.uniq.sort
end''',
        "Yes",
        "The docstring accurately describes what the `process_data` method does. It removes duplicates and sorts the `data` array."
    ),
    (
        '''def find_max(arr)
  # Find the maximum number in an array of integers
  arr.max
end''',
        "Yes",
        "The docstring provides an accurate description of the method, which finds the maximum number in the provided array using Ruby's `max` method."
    ),
    (
        '''def send_email(address, subject, body)
  # Set up and send an email
  # Warning: This function does not implement sending
end''',
        "No",
        "The description implies that the method sends an email, but there is no implementation for actually sending an email in this code. It only sets up the parameters."
    ),
]


def prompt_fmt_ruby(code): 
    doc, code = ruby_extract_docstring(code) 
    random.shuffle(FEW_SHOTS_RUBY) 
    buf = "" 
    for few in FEW_SHOTS_RUBY: 
        buf += template_few_shot_ruby(*few) 
    buf += f"""<issue_start>username_0: I have a function in Ruby and I'd like someone to check my description of this function. 
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```rb
{code}
```

Here is my description of this program:
```
{doc}
```
Do not attempt to execute the function or to judge its correctness. 
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function. 
Also, answer with "No" if the description does not match the function. 
Upvotes: 100<issue_comment>username_1: Sure, no problem. I will be able to help. My answer is:""" 
    return buf

def auto_dtype():
    if torch.cuda.is_bf16_supported():
        return "bfloat16"
    return "auto"


def chunkify(lst, n):
    chunks = []
    for i in range(0, len(lst), n):
        chunk = []
        for j in range(n):
            if i + j < len(lst):
                chunk.append(lst[i + j])
        chunks.append(chunk)
    return chunks


In [5]:
dataset = ds

In [6]:
print(f"Loaded {len(dataset)} examples. Running pre-filtering...")

BAD_WORDS = ["todo", "fixme", "bug"]
BAD_IMPORTS = ["argparse", "os", "subprocess", "sys", "setuptools",
               "distutils", "matplotlib", "seaborn"]
BAD_IMPORTS = [f"import {b}" for b in BAD_IMPORTS] + \
    [f"from {b}" for b in BAD_IMPORTS]
BAD_SUBSTRINGS = BAD_WORDS + BAD_IMPORTS

bench_filter = benchmark_data.filter_out()
all_bench = bench_filter["human_eval_docstrings"] + \
    bench_filter["human_eval_solutions"] + \
    bench_filter["mbpp_docstrings"] + \
    bench_filter["mbpp_solutions"]

Loaded 1 examples. Running pre-filtering...
num strings from mbpp_docstrings: 120
num strings from mbpp_solutions: 120
num strings from human_eval_docstrings: 164
num strings from human_eval_solutions: 161


In [7]:
def pre_filtering(ex):
    code = ex["content"]
    code_bytes = code.encode('utf-8')

    # filter out bad substrings
    lower = code.lower()
    for word in BAD_SUBSTRINGS:
        if word in lower:
            return False

    for b in all_bench:
        if b in code:  # contaminated sample!
            return False

    # too many lines of code -- say 150
    lines = code.split("\n")
    if len(lines) > 150:
        return False

    # filter functions which don't have an argument
    # 1. find first def statement in lines
    # 2. check if contains ():
    for line in lines:
        if line.startswith("def ") and "():" in line:
            return False

    # filter out functions with no return statement
    # parser = make_parser()
    # if not does_have_return(code, parser=parser):
    #     return False

    # try:
    #     tree = global_parser.parse(code_bytes)
    #     for node in tree.root_node.children:
    #         print(node)
    #     # print(FN_BLOCK_QUERY.captures(tree.root_node))
    #     # block, _ = FN_BLOCK_QUERY.captures(tree.root_node)

    #     # get the docstring, filter if not a docstring
    #     # exp = block.children[0]
    #     # print(exp)
    #     # if not exp.type == 'expression_statement' and not exp.children[0].type == 'string':
    #     #     return False

    #     # docstring = exp.children[0]
    #     # docstring_text = docstring.text.decode('utf-8')
    #     # if not docstring_text.startswith('"""') and not docstring_text.endswith('"""'):
    #     #     return False
    # except Exception as e:
    #     print(f"Error in filtering: {e}")
    #     return False

    return True  # all good!


threads = os.cpu_count() - 1  # type: ignore
dataset = dataset.filter(pre_filtering, num_proc=threads)

num_proc must be <= 1. Reducing num_proc to 1 for dataset of size 1.


In [8]:
dataset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 1
})

In [9]:
model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [None]:
# tokenizer = model.get_tokenizer()

In [None]:
print(f"Now running stage 3 filtering on {len(dataset)} examples...")

In [10]:
def unindent(s):
    lines = s.splitlines()
    non_blank_lines = [line for line in lines if line.strip()]
    min_indent = min(len(line) - len(line.lstrip())
                     for line in non_blank_lines) if non_blank_lines else 0
    unindented_lines = [line[min_indent:] if len(
        line) >= min_indent else line for line in lines]
    return '\n'.join(unindented_lines)


def ruby_extract_docstring(code): 
    # In Ruby, comments usually start with # for each line of the description 
    lines = code.splitlines() 
    doc_lines = [] 
    code_lines = [] 
    in_doc = False 
    for line in lines: 
        if line.strip().startswith("#"): 
            in_doc = True 
            doc_lines.append(line.strip("# ").strip()) 
        else: 
            in_doc = False 
            code_lines.append(line) 
            doc = "\n".join(doc_lines) 
            code = "\n".join(code_lines) 
    return doc, code



In [13]:
dummy = 'def dummy(): \n    """\n    """\n pass'
dummy_prompt = prompt_fmt_ruby(dummy)
few_shot_toks = len(tokenizer.encode(
    dummy_prompt)) - len(tokenizer.encode(dummy))
print(f"Few-shot prompt has {few_shot_toks} tokens")
print(tokenizer.encode(
    dummy_prompt))

Few-shot prompt has 1806 tokens
[27, 11159, 4906, 29, 5113, 62, 15, 25, 358, 614, 264, 729, 304, 23726, 323, 358, 4172, 1075, 4325, 311, 1779, 847, 4008, 315, 419, 729, 624, 40, 2776, 3730, 419, 773, 429, 358, 646, 3270, 264, 1661, 4629, 917, 369, 419, 729, 382, 8420, 374, 279, 2038, 369, 279, 729, 510, 73594, 10681, 198, 750, 11047, 15030, 61022, 340, 220, 4149, 486, 1893, 353, 10578, 334, 17, 198, 408, 198, 13874, 19324, 8420, 374, 847, 4008, 315, 419, 2025, 510, 13874, 3989, 47866, 279, 74926, 315, 264, 12671, 2661, 1181, 10578, 198, 13874, 19324, 5404, 537, 4774, 311, 9026, 279, 729, 476, 311, 11651, 1181, 57323, 624, 16141, 448, 330, 9454, 1, 476, 330, 2753, 1, 11649, 389, 421, 847, 4008, 702, 3322, 1995, 7484, 311, 312, 36925, 2764, 279, 729, 624, 13394, 11, 4226, 448, 330, 2753, 1, 421, 279, 4008, 1558, 537, 2432, 279, 729, 15757, 11159, 17638, 29, 5113, 62, 16, 25, 22555, 11, 902, 3491, 13, 358, 686, 387, 2952, 311, 1492, 624, 5050, 4226, 374, 25, 2308, 271, 785, 4008, 5302, 42

In [None]:
prompts = []
for ex in tqdm(dataset, total=len(dataset), desc="Generating prompts"):
    code = ex["content"]
    oks = len(tokenizer.encode(code)) + few_shot_toks
    if toks > 16380:
        print(f"Skipping example with {toks} tokens")
        # to skip, just add dummy prompt
        prompts.append(dummy_prompt)
        continue
    p = prompt_fmt_ruby(code)
    prompts.append(p)

responses = []
for chunk in tqdm(chunkify(prompts, 512), desc="Generating responses"):
    outs = model.generate(chunk, SamplingParams(
        temperature=0.0, stop="\n", max_tokens=5))
    contents = [o.outputs[0].text for o in outs]
    for c in contents:
        yes_count = c.lower().count("yes")
        no_count = c.lower().count("no")
        if yes_count > no_count:
            responses.append(True)
        elif yes_count < no_count:
            responses.append(False)
        else:
            # default to No
            responses.append(False)



In [None]:
dataset

In [None]:
subset = dataset

In [None]:
for ex in prompts:
    print(ex)
    break

In [None]:
new_ds = subset.filter(  # horrible hack!
    #lambda ex, i: responses[i] and "def dummy()" not in ex["content"], with_indices=True)
    lambda ex, i: prompts[i] and "def dummy()" not in ex["content"], with_indices=True)
print(f"Filtered {len(dataset) - len(new_ds)} examples")

In [None]:
for ex in new_ds:
    print(ex['content'])

In [None]:
new_ds.save_to_disk("./datasets/seed3")

In [None]:
new_ds