In [77]:
import gzip
import json
import warnings
from pathlib import Path

import datasets
import numpy as np

def pass_k(n: int, c: int, k: int) -> float:
    """
    pass@k estimate: 1 - comb(n - c, k) / comb(n, k).

    This is the probability that at least one of k samples, drawn without
    replacement from n samples of which c are correct, is correct.

    Args:
        n: total number of samples.
        c: number of correct samples.
        k: number of samples drawn.

    Returns:
        1.0 when fewer than k incorrect samples exist (every draw of k must
        include a correct one). Otherwise the estimate, computed as a product
        to avoid forming large binomial coefficients.
    """
    n_wrong = n - c
    if n_wrong >= k:
        # comb(n - c, k) / comb(n, k) == prod_{i = n-c+1}^{n} (1 - k / i)
        denominators = np.arange(n_wrong + 1, n + 1)
        return 1.0 - np.prod(1.0 - k / denominators)
    return 1.0

def add_success_ratio(ds):
    """
    Adds a ``success_ratio`` field to a single example.

    The ratio is the fraction of entries in ``ds["status"]`` that equal
    ``"OK"``. Used as a non-batched ``Dataset.map`` function, so ``ds`` is one
    example dict. It is updated in place and returned.

    An example with no statuses (empty or missing list) gets a ratio of 0.0
    instead of raising ZeroDivisionError. Otherwise a single empty results
    file would abort the whole ``map`` call.
    """
    statuses = ds["status"]
    ds["success_ratio"] = statuses.count("OK") / len(statuses) if statuses else 0.0
    return ds
def gunzip_json(path):
    """
    Reads a gzip-compressed JSON file (e.g. a ``*.results.json.gz`` file).

    Returns the decoded JSON value, or None if the file cannot be opened,
    decompressed, decoded or parsed. This is deliberately best-effort so
    callers can skip bad files. Each failure is reported with a warning
    instead of being swallowed silently, so dropped files stay visible.
    """
    try:
        # Decode explicitly as UTF-8. Without `encoding`, gzip.open's text mode
        # uses the locale's preferred encoding, which can mis-decode JSON.
        with gzip.open(path, "rt", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        warnings.warn(f"gunzip_json: skipping {path}: {e!r}")
        return None


def parse_problem_name(name):
    """
    Splits ``<prefix>_<number>_<rest>`` on underscores.

    Returns ``(int(number), rest)``, where ``rest`` keeps its own underscores.
    Raises IndexError if the name has no underscore, and ValueError if the
    second component is not an integer.
    """
    pieces = name.split("_")
    return int(pieces[1]), "_".join(pieces[2:])

def trim_problem_name(name):
    """
    Drops the leading prefix from a problem name, keeping ``<number>_<rest>``.

    For example, ``"<prefix>_78_hex_key"`` becomes ``"78_hex_key"``. Because
    the number round-trips through ``int``, leading zeros are removed.
    """
    return "{}_{}".format(*parse_problem_name(name))



In [78]:
base_full_ds = datasets.load_dataset("json", data_files="starcoderbase-15b-results.jsonl", split="train")
# Keep only Racket rows; the language column is constant afterwards, so drop it.
base_rkt_ds = (
    base_full_ds
    .filter(lambda ex: ex["language"] == "rkt")
    .remove_columns(["language"])
    .map(add_success_ratio)
)
# "Bad" Racket problems: at least one passing completion, but at most 10% pass.
base_bad_rkt = base_rkt_ds.filter(lambda ex: 0.0 < ex["success_ratio"] <= 0.1)


Found cached dataset json (/home/jgouwar/.cache/huggingface/datasets/json/default-8ba98df331d2e277/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Loading cached processed dataset at /home/jgouwar/.cache/huggingface/datasets/json/default-8ba98df331d2e277/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-87e1329ebd82bf3b.arrow
Loading cached processed dataset at /home/jgouwar/.cache/huggingface/datasets/json/default-8ba98df331d2e277/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-f8c57ba37ad0b39c.arrow
Loading cached processed dataset at /home/jgouwar/.cache/huggingface/datasets/json/default-8ba98df331d2e277/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-4ffd0b533ea6c83a.arrow


In [79]:
# Every passing ("OK") completion of the bad Racket problems, as a full
# program (prompt + completion), paired with its problem name.
bad_rkt_good_compls = [
    (prob, prompt + compl)
    for prob, statuses, compls, prompt in zip(
        base_bad_rkt["problem"],
        base_bad_rkt["status"],
        base_bad_rkt["completion"],
        base_bad_rkt["prompt"],
    )
    for stat, compl in zip(statuses, compls)
    if stat == "OK"
]

# Group the passing programs by problem name.
# NOTE(review): unlike get_probs, names are NOT passed through
# trim_problem_name. Matching in collect_candidates relies on the base
# results already using trimmed names (e.g. "78_hex_key") — confirm.
bad_rkt_probs = {}
for prob, program in bad_rkt_good_compls:
    bad_rkt_probs.setdefault(prob, []).append(program)



In [90]:
def multiple_results_to_ds(path):
    """
    Builds a ``datasets.Dataset`` from a directory of ``*.results.json.gz`` files.

    Each readable file contributes one row:
      - ``problem``: the file's ``"name"`` field,
      - ``status``: the ``"status"`` of every entry in the file's ``"results"``,
      - ``program``: the ``"program"`` of every entry in the file's ``"results"``,
    plus a ``success_ratio`` column added by ``add_success_ratio``.
    Files that ``gunzip_json`` cannot read (it returns None) are skipped.

    Args:
        path: directory to scan, as a ``pathlib.Path`` or a string.

    Returns:
        The dataset, with rows in sorted file-name order so that runs are
        reproducible.
    """
    ds_dict = {"problem": [], "status": [], "program": []}
    # Path.glob does not guarantee an order; sort so row order is stable.
    for results_file in sorted(Path(path).glob("*.results.json.gz")):
        data = gunzip_json(results_file)
        if data is None:
            continue
        results = data["results"]
        ds_dict["problem"].append(data["name"])
        ds_dict["status"].append([res["status"] for res in results])
        ds_dict["program"].append([res["program"] for res in results])
    return datasets.Dataset.from_dict(ds_dict).map(add_success_ratio)

def get_probs(ds, type, thresh):
    """
    Groups the passing ("OK") programs of selected problems by trimmed name.

    Args:
        ds: dataset with ``problem``, ``status``, ``program`` and
            ``success_ratio`` columns (as built by ``multiple_results_to_ds``).
        type: ``"good"`` keeps problems with ``success_ratio >= thresh``;
            ``"bad"`` keeps problems with ``0 < success_ratio <= thresh``.
        thresh: success-ratio cut-off used by the selection above.

    Returns:
        dict mapping ``trim_problem_name(problem)`` to the list of that
        problem's programs whose status is ``"OK"``, in dataset order.
        Selected problems with no passing program do not appear.

    Raises:
        ValueError: if ``type`` is neither ``"good"`` nor ``"bad"``.
    """
    if type == "good":
        selected = ds.filter(lambda ex: ex["success_ratio"] >= thresh)
    elif type == "bad":
        selected = ds.filter(lambda ex: 0.0 < ex["success_ratio"] <= thresh)
    else:
        raise ValueError(f"Unknown type {type}")

    passing = {}
    rows = zip(selected["problem"], selected["status"], selected["program"])
    for problem, statuses, programs in rows:
        for status, program in zip(statuses, programs):
            if status == "OK":
                passing.setdefault(trim_problem_name(problem), []).append(program)
    return passing

def collect_candidates(bad_probs, good_probs):
    """
    Pairs up problems present in both mappings.

    Returns ``{name: {"bad": bad_probs[name], "good": good_probs[name]}}`` for
    every name in ``good_probs`` whose ``bad_probs`` entry exists and is not
    None. Keys follow ``good_probs`` iteration order.
    """
    return {
        name: {"bad": bad_probs[name], "good": good_programs}
        for name, good_programs in good_probs.items()
        if bad_probs.get(name) is not None
    }
   
# Racket: "bad" problems come from the base-model results (bad_rkt_probs);
# "good" ones come from the fine-tuned checkpoint's per-problem results.
tuned_rkt_path = Path("tuned-rkt-results/eval_checkpoint-584/")
tuned_rkt_ds = multiple_results_to_ds(tuned_rkt_path)
good_rkt_probs = get_probs(tuned_rkt_ds, type="good", thresh=0.7)
rkt_candidates = collect_candidates(bad_probs=bad_rkt_probs, good_probs=good_rkt_probs)

# OCaml: both sides come from per-problem result directories.
# NOTE(review): the "good" threshold here (0.9) differs from Racket's (0.7) — confirm intended.
base_ocaml_path = Path("base-ml-results")
base_ocaml_ds = multiple_results_to_ds(base_ocaml_path)
bad_ocaml_probs = get_probs(base_ocaml_ds, type="bad", thresh=0.1)
tuned_ocaml_path = Path("tuned-ml-results/eval_checkpoint-376/")
tuned_ocaml_ds = multiple_results_to_ds(tuned_ocaml_path)
good_ocaml_probs = get_probs(tuned_ocaml_ds, type="good", thresh=0.9)
ocaml_candidates = collect_candidates(bad_probs=bad_ocaml_probs, good_probs=good_ocaml_probs)



                                                                  

In [110]:
# Inspect one Racket candidate: how many passing programs each side has.
candidate = rkt_candidates["78_hex_key"]
bad_candidates, good_candidates = candidate["bad"], candidate["good"]
print(f"Good len:{len(good_candidates)}")
print(f"Bad len:{len(bad_candidates)}")

Good len:19
Bad len:1


In [92]:
#print(good_candidates[0])

In [93]:
# Trimmed problem names that are "bad" for the base OCaml results and "good" for the tuned checkpoint.
ocaml_candidates.keys()

dict_keys(['27_flip_case', '82_prime_length', '56_correct_bracketing', '45_triangle_area', '61_correct_bracketing'])

In [103]:
# Inspect one OCaml candidate: how many passing programs each side has.
oc_candidate = ocaml_candidates['82_prime_length']
oc_good_candidates, oc_bad_candidates = oc_candidate["good"], oc_candidate["bad"]
print("Good len:", len(oc_good_candidates))
print("Bad len:", len(oc_bad_candidates))

Good len: 19
Bad len: 2


In [107]:
# First passing base-model program for this problem; print() renders the source text rather than its repr.
print(oc_bad_candidates[0])

(**Write a function that takes a string and returns True if the string
length is a prime number or False otherwise
Examples
prime_length('Hello') == True
prime_length('abcdcba') == True
prime_length('kittens') == True
prime_length('orange') == False
*)
let prime_length (string : string) : bool =
    let len = String.length string in
    let rec is_prime n =
        if n <= 1 then false
        else if n = 2 then true
        else if n mod 2 = 0 then false
        else
            let rec check i =
                if i = n then true
                else if n mod i = 0 then false
                else check (i + 1)
            in check 3
    in is_prime len

let assertions =
 let candidate = prime_length in
  (assert ((candidate "Hello") = true));
  (assert ((candidate "abcdcba") = true));
  (assert ((candidate "kittens") = true));
  (assert ((candidate "orange") = false));
  (assert ((candidate "wow") = true));
  (assert ((candidate "world") = true));
  (assert ((candidate "MadaM") = tru

In [109]:
# One passing tuned-checkpoint program for the same problem.
# NOTE(review): index 15 looks hand-picked — document why or pick programmatically.
print(oc_good_candidates[15])

(**Write a function that takes a string and returns True if the string
length is a prime number or False otherwise
Examples
prime_length('Hello') == True
prime_length('abcdcba') == True
prime_length('kittens') == True
prime_length('orange') == False
*)
let prime_length (string : string) : bool =
  let rec is_prime (n : int) : bool =
    if n <= 1 then false
    else if n <= 3 then true
    else if n mod 2 = 0 || n mod 3 = 0 then false
    else
      let rec loop i =
        if i * i > n then true
        else if n mod i = 0 || n mod (i + 2) = 0 then false
        else loop (i + 6)
      in loop 5
  in is_prime (String.length string)

let assertions =
 let candidate = prime_length in
  (assert ((candidate "Hello") = true));
  (assert ((candidate "abcdcba") = true));
  (assert ((candidate "kittens") = true));
  (assert ((candidate "orange") = false));
  (assert ((candidate "wow") = true));
  (assert ((candidate "world") = true));
  (assert ((candidate "MadaM") = true));
  (assert ((candi