In [1]:
from tree_sitter_parser import LANGUAGE, make_parser, node_to_string
import datasets
import os
import signal
from multiprocessing import Pool
import boto3
import smart_open
from datasets import load_dataset,Dataset
from botocore import UNSIGNED
from botocore.config import Config

s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
def download_contents(blob_id, src_encoding):
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": s3}) as fin:
        content = fin.read().decode(src_encoding)
    return content

In [2]:
# TOPLEVEL_DOCSTRING_QUERY = LANGUAGE.query("""
# (
#     (function_definition
#       name: (identifier)
#       body: (block .
#         (expression_statement
#             (string
#                 (string_start) @docstring.start
#                 (string_content)
#                 (string_end) @docstring.end)))) @function.def
#     (#eq? @docstring.start "\\\"\\\"\\\"")
#     (#eq? @docstring.end "\\\"\\\"\\\"")
# )
# """)

TOPLEVEL_DOCSTRING_QUERY_RUST = LANGUAGE.query("""
(function_item
    name: (identifier)
    body: (block)) @function.def
""")



def get_fns_with_docstrings(src, tree):
    captures = TOPLEVEL_DOCSTRING_QUERY_RUST.captures(tree.root_node)
    res = []
    for capture in captures:
        node, ty = capture
        if ty != "function.def":
            continue
        # if the starting col is not 0, then it's not a top-level fn
        _, col = node.start_point
        if col != 0:
            continue
        res.append(node_to_string(src, node))
    return res


def parse_ex(parser, ex):
    # ex = ex["content"]
    ex = download_contents(ex["blob_id"], ex["src_encoding"])
    try:
        buf = bytes(ex, "utf8")
        tree = parser.parse(buf)
        return get_fns_with_docstrings(buf, tree)
    except:
        return []


# if one parser segfaults, we can just make a new one and other parsers will still be fine
# WE LOVE TREE SITTER!
PARSERS = None

def process_chunk(idx_and_chunk):
    assert PARSERS is not None
    idx, chunk = idx_and_chunk
    parser = PARSERS[idx]
    chunk_new_funs = set()
    
    for ex in chunk:
        chunk_new_funs.update(parse_ex(parser, ex))
        break
    return chunk_new_funs

In [3]:
ds = load_dataset("bigcode/the-stack-v2-dedup", "Rust", cache_dir = f"./data/rust", streaming=True, split="train")

Resolving data files:   0%|          | 0/757 [00:00<?, ?it/s]

In [4]:
import torch

print(torch.cuda.device_count())

NUMWORKERS = os.cpu_count()
print(NUMWORKERS)

funs = set()
PARSERS = [make_parser() for _ in range(NUMWORKERS)]
CHUNK_SIZE = 1000 * NUMWORKERS
print(f"Chunk size: {CHUNK_SIZE}")

chunk = []
p = Pool(NUMWORKERS)

2
128
Chunk size: 128000


In [5]:
for i, ex in enumerate(iter(ds)):
    try:
        chunk.append(ex)
        if len(chunk) == 200:
            print(f"Processing chunk {i // CHUNK_SIZE}")
            # divide the chunk into NUM_WORKERS chunks
            subchunk_size = len(chunk) // NUMWORKERS
            subchunks = [chunk[i:i + subchunk_size]
                         for i in range(0, len(chunk), subchunk_size)]
            new_funs_iter = p.imap(process_chunk, [(j, subchunk) for j, subchunk in enumerate(subchunks)])
            print("Getting new functions")
            len_before = len(funs)
            while True:
                try:
                    def timeout_handler(_, __):
                        raise KeyboardInterrupt  # it's fineeeeeee
                    signal.signal(signal.SIGALRM, timeout_handler)
                    signal.alarm(60)
                    funs.update(next(new_funs_iter))
                    signal.alarm(0)
                except KeyboardInterrupt:
                    signal.alarm(0)
                    print("Keyboard interrupt. Terminating pool")
                    p.terminate()
                    p = Pool(NUMWORKERS)
                    break
                except StopIteration:
                    break
                except Exception as e:
                    print(e)

            signal.alarm(0)

            PARSERS = [make_parser() for _ in range(NUMWORKERS)]

            print(
                f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions")

            chunk = []
    except Exception as e:
        print(e)
        chunk = []

    if i >= 200:
        break


p.close()
new_ds_dict = {
    "seed": list(funs),
    "id": list(range(len(funs)))
}

new_ds = datasets.Dataset.from_dict(new_ds_dict)

Processing chunk 0
Getting new functions
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range
list index out of range

In [6]:
new_ds

Dataset({
    features: ['seed', 'id'],
    num_rows: 284
})

In [7]:
for ex in new_ds:
    print(ex['seed'])
    break

async fn main() {
    env::set_var("RUST_LOG", "warp_server");
    env_logger::init();

    let log = warp::log("warp_server");

    let homepage = warp::path::end().map(|| {
        Response::builder()
            .header("content-type", "text/html")
            .body(
                "<html><h1>juniper_warp</h1><div>visit <a href=\"/playground\">/playground</a></html>"
                    .to_string(),
            )
    });

    log::info!("Listening on 127.0.0.1:8080");

    let state = warp::any().map(move || Context {});
    let graphql_filter = juniper_warp::make_graphql_filter(schema(), state.boxed());

    warp::serve(
        warp::get()
            .and(warp::path("playground"))
            .and(juniper_warp::playground_filter("/graphql", None))
            .or(homepage)
            .or(warp::path("graphql").and(graphql_filter))
            .with(log),
    )
    .run(([127, 0, 0, 1], 8080))
    .await
}


In [8]:
ds = new_ds
save_dir = "./datasets_rust/seed2"
ds.save_to_disk(save_dir)

Saving the dataset (0/1 shards):   0%|          | 0/284 [00:00<?, ? examples/s]

In [9]:
save_dir = "./datasets_rust/seed2/output_step1.json"
ds.to_json(save_dir)

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

244517

In [1]:
from datasets import Dataset

ds = Dataset.from_file("./datasets_rust/seed2/data-00000-of-00001.arrow")

In [2]:
import subprocess
import tempfile
import signal
import hashlib
import os
import argparse
from typing import List, Dict
from tqdm import tqdm
from tree_sitter_parser import LANGUAGE, global_parser

RETURN_QUERY = LANGUAGE.query("""
(return_expression) @return
""")

def does_have_return(src):
    tree = global_parser.parse(bytes(src, "utf8"))
    root = tree.root_node
    captures = RETURN_QUERY.captures(root)
    for node, _ in captures:
        # if it doesn't have an argument, it's not a return with a value
        if len(node.children) <= 1:  # includes "return" itself
            continue
        else:
            return True
    return False

# runs sorbet in the given directory for typecheck, returns stdout
# then, it logs the number of errors for each file
# def run_typeprof(d):
#     try:
#         # outs = subprocess.run(
#         #     ['bundle', 'exec', 'rake', './typeprof/test'],  # Command to run
#         #     #cwd=d,               # Current working directory
#         #     capture_output=True, # Capture stdout and stderr
#         #     timeout=120,         # Timeout after 120 seconds
#         #     text=True            # Capture output as text
#         # ).stdout
#     except Exception as e:
#         print(e)
#         return None
#     filemap = {}
#     lines = outs.split("\n")
#     for line in lines:
#         if line.strip():
#             parts = line.split(":")
#             if len(parts) >= 2:
#                 file = parts[0].split("/")[-1]
#                 if file not in filemap:
#                     filemap[file] = 0
#                 if "error:" in line:
#                     filemap[file] += 1

    return filemap

def typecheck_batch(files: List[str]) -> Dict[str, str]:
    # Create a temporary directory using the tempfile module
    filemap: Dict[str, str] = {}
    with tempfile.TemporaryDirectory() as tempdir:
        for contents in files:
            hash_object = hashlib.sha1(bytes(contents, "utf8"))
            hex_dig = hash_object.hexdigest()
            filemap[hex_dig] = contents
            name = os.path.join(tempdir, hex_dig + ".rs")
            with open(name, "w") as f:
                f.write(contents)

        # Run typeprof in the temporary directory
        # typecheck_map = run_typeprof(tempdir)
        # print(typecheck_map)

        # if typecheck_map is None:
        #     return {}

        # for contents, errors in typecheck_map.items():
        #     no_py = contents.replace(".rs", "")
        #     if errors == 0:
        #         continue
        #     if no_py in filemap:
        #         del filemap[no_py]

        print(f"Pass rate: {len(filemap)}/{len(files)}")
        return filemap

def infer_imports(code: str) -> str:
    import autoimport
    try:
        def handler(signum, frame):
            raise Exception("Timeout")
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(10)
        inferred = autoimport.fix_code(code)
        signal.alarm(0)
        return inferred
    except Exception as e:
        signal.alarm(0)
        print(f"Error while inferring imports: {e}")
        return code

In [None]:
# print("Filtering to only functions with return statements")
# ds = ds.filter(lambda ex: does_have_return(
#     ex["seed"]), num_proc=os.cpu_count())




Filtering to only functions with return statements


Filter (num_proc=128):   0%|          | 0/237 [00:00<?, ? examples/s]

In [None]:
# if args.infer_imports:
#     print("Inferring imports for functions")
#     ds = ds.map(lambda ex: {"content": infer_imports(
#         ex["content"])}, num_proc=os.cpu_count())

batch = []
max_i = len(ds) - 1

new_ds = {
    "seed": [],
    "sha1": [],
    "id": [],
}

e_id = 0

for i, ex in enumerate(tqdm(ds, total=len(ds))):
    try:
        code = ex["seed"]

        batch.append(code)

        if len(batch) == 250 or i == max_i:
            filemap = typecheck_batch(batch)
            for sha1, contents in filemap.items():
                new_ds["seed"].append(contents)
                new_ds["sha1"].append(sha1)
                new_ds["id"].append(e_id)
                e_id += 1
            batch = []
            
    except Exception as e:
        print(f"There was an error: {e}")
        continue

new_ds_hf = datasets.Dataset.from_dict(new_ds)

In [11]:
ds

Dataset({
    features: ['seed', 'id'],
    num_rows: 284
})

In [3]:
import datasets
import os
from tree_sitter_parser import global_parser, LANGUAGE, does_have_return, make_parser
import benchmark_data
from tqdm import tqdm
import torch
import argparse
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
from vllm import LLM, SamplingParams
# from transformers import AutoModelForCausalLM, AutoTokenizer
import random

In [4]:
def unindent(s):
    lines = s.splitlines()
    non_blank_lines = [line for line in lines if line.strip()]
    min_indent = min(len(line) - len(line.lstrip()) for line in non_blank_lines) if non_blank_lines else 0
    unindented_lines = [line[min_indent:] if len(line) >= min_indent else line for line in lines]
    return '\n'.join(unindented_lines)


def rust_extract_docstring(code): 
    # In Rust, comments usually start with `//` for single-line comments or `///` for doc comments
    lines = code.splitlines() 
    doc_lines = [] 
    code_lines = [] 
    in_doc = False 
    for line in lines: 
        if line.strip().startswith("///"): 
            in_doc = True 
            doc_lines.append(line.strip("///").strip())  # Extract doc comment
        else: 
            in_doc = False 
            code_lines.append(line) 
    doc = "\n".join(doc_lines) 
    code = "\n".join(code_lines) 
    return doc, code

In [5]:
FN_BLOCK_QUERY = LANGUAGE.query("""
(
    (block) @definition.block
)
"""
)

def template_few_shot_rust(code, answer, rationale):
    doc, code = rust_extract_docstring(code)
    assert answer == "No" or answer == "Yes"
    prompt = f"""<issue_start>username_0: I have a function in Rust and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
rust
{code}


Here is my description of this program:
{doc}


Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is: {answer}

{rationale}

Upvotes: 200"""
    return prompt


FEW_SHOTS_RUST = [
    (
        '''fn greet_user(name: &str) -> String {
    // Greet the user with a friendly message
    format!("Hello, {}!", name)
}''',
        "Yes",
        "The docstring accurately describes the function. The greet_user method simply returns a greeting string using the provided name."
    ),
    (
        '''fn calculate_area(radius: f64) -> f64 {
    // Calculate the circumference of a circle given its radius
    std::f64::consts::PI * radius.powi(2)
}''',
        "No",
        "The description states that it calculates the circumference, but the method actually calculates the area of the circle based on the radius."
    ),
    (
        '''fn reverse_string(s: &str) -> String {
    // Reverse the characters in a string
    s.chars().rev().collect()
}''',
        "Yes",
        "The docstring accurately describes the function's behavior. It reverses the characters in the provided string using Rust's chars() and rev() methods."
    ),
    (
        '''fn calculate_total(price: f64, tax_rate: f64) -> f64 {
    // Adds tax to the price to get the total amount
    price + (price * tax_rate)
}''',
        "Yes",
        "The docstring is clear and accurately describes the method's purpose of calculating the total price by adding tax."
    ),
    (
        '''fn print_numbers(n: u32) {
    // Print numbers from 1 to n, inclusive
    for num in 1..=n {
        println!("{}", num);
    }
}''',
        "Yes",
        "The docstring provides an accurate description of the method's function, which is to print numbers from 1 to n."
    ),
    (
        '''fn process_data(mut data: Vec<i32>) -> Vec<i32> {
    // Process the data by removing duplicates and sorting it
    data.sort();
    data.dedup();
    data
}''',
        "Yes",
        "The docstring accurately describes what the process_data method does. It removes duplicates and sorts the data array."
    ),
    (
        '''fn find_max(arr: Vec<i32>) -> Option<i32> {
    // Find the maximum number in an array of integers
    arr.into_iter().max()
}''',
        "Yes",
        "The docstring provides an accurate description of the method, which finds the maximum number in the provided array using Rust's max method."
    ),
    (
        '''fn send_email(address: &str, subject: &str, body: &str) {
    // Set up and send an email
    // Warning: This function does not implement sending
}''',
        "No",
        "The description implies that the method sends an email, but there is no implementation for actually sending an email in this code. It only sets up the parameters."
    ),
]


def prompt_fmt_rust(code): 
    doc, code = rust_extract_docstring(code) 
    random.shuffle(FEW_SHOTS_RUST) 
    buf = "" 
    for few in FEW_SHOTS_RUST: 
        buf += template_few_shot_rust(*few) 
    buf += f"""<issue_start>username_0: I have a function in Rust and I'd like someone to check my description of this function. 
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
rust
{code}


Here is my description of this program:
{doc}

Do not attempt to execute the function or to judge its correctness. 
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function. 
Also, answer with "No" if the description does not match the function. 
Upvotes: 100<issue_comment>username_1: Sure, no problem. I will be able to help. My answer is:""" 
    return buf


def chunkify(lst, n):
    chunks = []
    for i in range(0, len(lst), n):
        chunk = []
        for j in range(n):
            if i + j < len(lst):
                chunk.append(lst[i + j])
        chunks.append(chunk)
    return chunks

In [6]:
dataset = ds
dataset

Dataset({
    features: ['seed', 'id'],
    num_rows: 284
})

In [7]:
print(f"Loaded {len(dataset)} examples. Running pre-filtering...")

BAD_WORDS = ["todo", "fixme", "bug"]
BAD_IMPORTS = ["argparse", "os", "subprocess", "sys", "setuptools", "distutils", "matplotlib", "seaborn"]
BAD_IMPORTS = [f"import {b}" for b in BAD_IMPORTS] + [f"from {b}" for b in BAD_IMPORTS]
BAD_SUBSTRINGS = BAD_WORDS + BAD_IMPORTS

bench_filter = benchmark_data.filter_out()
all_bench = bench_filter["human_eval_docstrings"] + \
    bench_filter["human_eval_solutions"] + \
    bench_filter["mbpp_docstrings"] + \
    bench_filter["mbpp_solutions"]

Loaded 284 examples. Running pre-filtering...
num strings from mbpp_docstrings: 120
num strings from mbpp_solutions: 120
num strings from human_eval_docstrings: 164
num strings from human_eval_solutions: 161


In [17]:
all_bench

['Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True',
 "Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n    separate those group into separate strings and return the list of those.\n    Separate groups are balanced (each open brace is properly closed) and not nested within each other\n    Ignore any spaces in the input string.\n    >>> separate_paren_groups('( ) (( )) (( )( ))')\n    ['()', '(())', '(()())']",
 'Given a positive floating point number, it can be decomposed into\n    and integer part (largest integer smaller than given number) and decimals\n    (leftover part always smaller than 1).\n\n    Return the decimal part of the number.\n    >>> truncate_number(3.5)\n    0.5',
 "You're given a list of deposit and withdrawal operations on a b

In [8]:
def pre_filtering(ex):
    code = ex["seed"]
    # code_bytes = code.encode('utf-8')

    # filter out bad substrings
    lower = code.lower()
    for word in BAD_SUBSTRINGS:
        if word in lower:
            return False

    for b in all_bench:
        if b in code:  # contaminated sample!
            return False

    # too many lines of code -- say 150
    lines = code.split("\n")
    if len(lines) > 150:
        return False

    # filter functions which don't have an argument
    # 1. find first def statement in lines
    # 2. check if contains ():
    for line in lines:
        if line.startswith("fn ") and "():" in line:
            return False

    # filter out functions with no return statement
    # parser = make_parser()
    # if not does_have_return(code, parser=parser):
    #     return False

    # try:
    #     tree = global_parser.parse(code_bytes)
    #     # for node in tree.root_node.children:
    #     #     print(node)
    #     # print(FN_BLOCK_QUERY.captures(tree.root_node))
    #     block, _ = FN_BLOCK_QUERY.captures(tree.root_node)[0]

    #     # get the docstring, filter if not a docstring
    #     exp = block.children[1]
    #     print(exp)
    #     if not exp.type == 'expression_statement':
    #         return False

    #     # docstring = exp.children[0]
    #     # docstring_text = docstring.text.decode('utf-8')
    #     # if not docstring_text.startswith('"""') and not docstring_text.endswith('"""'):
    #     #     return False
    # except Exception as e:
    #     print(f"Error in filtering: {e}")
    #     return False

    return True  # all good!


threads = os.cpu_count() - 1  # type: ignore
dataset = dataset.filter(pre_filtering, num_proc=threads)

Filter (num_proc=127):   0%|          | 0/284 [00:00<?, ? examples/s]

In [9]:
dataset

Dataset({
    features: ['seed', 'id'],
    num_rows: 249
})

In [20]:
import torch
print(torch.__version__)

2.5.1+cu124


In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LLM(model='bigcode/starcoder2-3b', device=device)
tokenizer = model.get_tokenizer()

INFO 11-17 18:12:07 config.py:1861] Downcasting torch.float32 to torch.float16.
INFO 11-17 18:12:14 llm_engine.py:249] Initializing an LLM engine (v0.6.4.post1) with config: model='bigcode/starcoder2-3b', speculative_config=None, tokenizer='bigcode/starcoder2-3b', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=16384, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=bigcode/starcoder2-3b, num_scheduler_steps=1, chunked_prefill_enabled=False mult

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 11-17 18:12:20 model_runner.py:1077] Loading model weights took 5.6777 GB
INFO 11-17 18:12:24 worker.py:232] Memory profiling results: total_gpu_memory=79.32GiB initial_memory_usage=6.18GiB peak_torch_memory=6.72GiB memory_usage_post_profile=6.20GiB non_torch_memory=0.51GiB kv_cache_size=64.16GiB gpu_memory_utilization=0.90
INFO 11-17 18:12:24 gpu_executor.py:113] # GPU blocks: 140154, # CPU blocks: 8738
INFO 11-17 18:12:24 gpu_executor.py:117] Maximum concurrency for 16384 tokens per request: 136.87x
INFO 11-17 18:12:30 model_runner.py:1400] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 11-17 18:12:30 model_runner.py:1404] If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
IN

In [22]:
print(f"Now running stage 3 filtering on {len(dataset)} examples...")

Now running stage 3 filtering on 249 examples...


In [23]:
dummy = 'def dummy(): \n    """\n    """\n pass'
dummy_prompt = prompt_fmt_rust(dummy)
few_shot_toks = len(tokenizer.encode(
    dummy_prompt)) - len(tokenizer.encode(dummy))
print(f"Few-shot prompt has {few_shot_toks} tokens")
print(tokenizer.encode(
    dummy_prompt))

Few-shot prompt has 1951 tokens
[7, 715, 100, 53, 63, 457, 1178, 331, 686, 347, 19262, 480, 457, 5456, 2144, 12765, 391, 1524, 1690, 3066, 451, 477, 686, 51, 222, 78, 3480, 8132, 477, 1278, 708, 457, 902, 2886, 331, 4667, 48830, 456, 477, 686, 51, 222, 222, 10934, 458, 341, 1361, 456, 341, 686, 63, 222, 7427, 222, 3379, 1489, 100, 14995, 45, 115, 63, 1064, 56, 55, 46, 320, 303, 453, 9232, 7560, 664, 244, 54, 391, 329, 49, 35364, 303, 456, 1952, 347, 244, 54, 516, 66, 115, 320, 310, 11666, 48686, 1952, 312, 303, 339, 222, 130, 499, 222, 10934, 458, 1690, 3066, 451, 477, 3477, 63, 3067, 222, 2573, 666, 11570, 391, 5755, 341, 686, 575, 391, 39977, 2840, 3831, 4342, 51, 222, 10966, 642, 332, 10933, 39, 575, 332, 2042, 39, 14732, 563, 434, 1690, 3066, 1421, 8473, 2490, 27399, 391, 334, 50, 9173, 341, 686, 51, 222, 12936, 49, 7618, 642, 332, 2042, 39, 434, 341, 3066, 1976, 666, 2549, 341, 686, 51, 8, 715, 100, 54, 63, 33136, 49, 1307, 3732, 51, 457, 1118, 545, 5320, 391, 3071, 51, 222, 3781,

In [25]:
prompts = []
for ex in tqdm(dataset, total=len(dataset), desc="Generating prompts"):
    code = ex["seed"]
    toks = len(tokenizer.encode(code)) + few_shot_toks
    if toks > 16380:
        print(f"Skipping example with {toks} tokens")
        # to skip, just add dummy prompt
        prompts.append(dummy_prompt)
        continue
    p = prompt_fmt_rust(code)
    prompts.append(p)

responses = []
for chunk in tqdm(chunkify(prompts, 512), desc="Generating responses"):
    outs = model.generate(chunk, SamplingParams(
        temperature=0.0, stop="\n", max_tokens=5))
    contents = [o.outputs[0].text for o in outs]
    for c in contents:
        yes_count = c.lower().count("yes")
        no_count = c.lower().count("no")
        if yes_count > no_count:
            responses.append(True)
        elif yes_count < no_count:
            responses.append(False)
        else:
            # default to No
            responses.append(False)



Generating prompts: 100%|██████████| 249/249 [00:00<00:00, 2193.86it/s]
Processed prompts: 100%|██████████| 249/249 [00:16<00:00, 15.07it/s, est. speed input: 32128.12 toks/s, output: 30.15 toks/s]
Generating responses: 100%|██████████| 1/1 [00:17<00:00, 17.38s/it]


In [27]:
for state in contents:
    if state == 'No': 
        print(state)

In [28]:
subset = dataset

In [29]:
new_ds = subset.filter(  # horrible hack!
    lambda ex, i: responses[i] and "def dummy()" not in ex["content"], with_indices=True)
print(f"Filtered {len(dataset) - len(new_ds)} examples")

Filter:   0%|          | 0/249 [00:00<?, ? examples/s]

KeyError: 'content'

In [11]:
save_dir = "./datasets_rust/seed2/output_step2.json"
new_ds.to_json(save_dir)

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

162072