In [123]:
import datasets 
import numpy as np
import json
import gzip
from pathlib import Path

def pass_k(n: int, c: int, k: int) -> float:
    """
    Return 1 - comb(n - c, k) / comb(n, k).

    With n samples of which c pass, this is the probability that a random
    size-k subset (drawn without replacement) contains at least one passing
    sample. The ratio of binomials is evaluated as a running product,
    prod_{i = n-c+1}^{n} (1 - k / i), so no large coefficients are formed.
    """
    n_failing = n - c
    if n_failing < k:
        # Fewer than k failing samples: every size-k subset includes a passing one.
        return 1.0
    divisors = np.arange(n_failing + 1, n + 1)
    survival = np.prod(1.0 - k / divisors)
    return 1.0 - survival

def add_success_ratio(ds):
    """
    Add a "success_ratio" field to one result row and return that same row.

    Written as a per-example function for `datasets.Dataset.map` (non-batched).
    `ds["status"]` is expected to be a list of per-sample status strings; the
    ratio is the fraction of entries equal to "OK".

    A row with no statuses gets 0.0 instead of raising ZeroDivisionError,
    which would otherwise abort the whole `map` call.
    """
    statuses = ds["status"]
    ds["success_ratio"] = statuses.count("OK") / len(statuses) if statuses else 0.0
    return ds
def gunzip_json(path):
    """
    Decode the JSON document stored gzip-compressed at `path`.

    Best-effort by design: any failure (missing file, corrupt gzip stream,
    invalid JSON, ...) yields None instead of raising, so callers can skip
    unreadable files.
    """
    try:
        with gzip.open(path, "rt") as stream:
            text = stream.read()
        return json.loads(text)
    except Exception:
        return None


def parse_problem_name(name):
    """
    Split an underscore-separated problem identifier "<prefix>_<number>_<rest>".

    Returns (number, rest): the second field parsed as an int, and every field
    after it re-joined with "_". The first field is discarded.
    """
    parts = name.split("_")
    problem_number = int(parts[1])
    problem_slug = "_".join(parts[2:])
    return problem_number, problem_slug

def trim_problem_name(name):
    """
    Drop the first underscore-separated field of a problem identifier.

    Returns "<number>_<rest>" as produced by parse_problem_name; the number is
    round-tripped through int, so e.g. "X_078_hex_key" -> "78_hex_key".
    """
    return "{}_{}".format(*parse_problem_name(name))

In [142]:
# Load the base-model results file and keep only the Racket rows, annotated
# with the per-problem fraction of "OK" samples.
base_full_ds = datasets.load_dataset("json", data_files="starcoderbase-15b-results.jsonl", split="train")
base_rkt_ds = (
    base_full_ds
    .filter(lambda row: row["language"] == "rkt")
    .remove_columns(["language"])
    .map(add_success_ratio)
)
# "Bad" problems for the base model: solved rarely, but not never (0 < ratio <= 0.1).
base_bad_rkt = base_rkt_ds.filter(lambda row: 0.0 < row["success_ratio"] <= 0.1)


Found cached dataset json (/home/jgouwar/.cache/huggingface/datasets/json/default-8ba98df331d2e277/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Loading cached processed dataset at /home/jgouwar/.cache/huggingface/datasets/json/default-8ba98df331d2e277/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-87e1329ebd82bf3b.arrow
Loading cached processed dataset at /home/jgouwar/.cache/huggingface/datasets/json/default-8ba98df331d2e277/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-f8c57ba37ad0b39c.arrow
Loading cached processed dataset at /home/jgouwar/.cache/huggingface/datasets/json/default-8ba98df331d2e277/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-4ffd0b533ea6c83a.arrow


200

In [125]:
# Every passing ("OK") base-model sample on the "bad" problems, stored as
# (problem, prompt + completion) so each entry is a complete program.
bad_rkt_good_compls = [
    (prob, prompt + compl)
    for prob, statuses, compls, prompt in zip(
        base_bad_rkt["problem"],
        base_bad_rkt["status"],
        base_bad_rkt["completion"],
        base_bad_rkt["prompt"],
    )
    for stat, compl in zip(statuses, compls)
    if stat == "OK"
]

# Group the passing programs by problem name.
bad_rkt_probs = {}
for prob, compl in bad_rkt_good_compls:
    bad_rkt_probs.setdefault(prob, []).append(compl)



In [126]:
# Load the per-problem result files for the fine-tuned checkpoint.
tuned_rkt_path = Path("tuned-rkt-results/eval_checkpoint-584/")
tuned_rkt_dict = {
    "problem": [],
    "status": [],
    "program": [],
}
# Path.glob yields entries in filesystem order, which is not guaranteed;
# sorting keeps the dataset row order (and everything derived from it) reproducible.
for file in sorted(tuned_rkt_path.glob("*.results.json.gz")):
    data = gunzip_json(file)
    if data is None:
        # gunzip_json returns None on any read/parse error; such files are skipped.
        continue
    tuned_rkt_dict["problem"].append(data["name"])
    tuned_rkt_dict["status"].append([res["status"] for res in data["results"]])
    tuned_rkt_dict["program"].append([res["program"] for res in data["results"]])

tuned_rkt_ds = datasets.Dataset.from_dict(tuned_rkt_dict).map(add_success_ratio)
# "Good" problems for the tuned model: at least 70% of samples pass.
tuned_good_rkt_ds = tuned_rkt_ds.filter(lambda x: x["success_ratio"] >= 0.7)

# Every passing ("OK") tuned-model program on the good problems.
tuned_rkt_good_compls = [
    (prob, prog)
    for prob, statuses, progs in zip(
        tuned_good_rkt_ds["problem"],
        tuned_good_rkt_ds["status"],
        tuned_good_rkt_ds["program"],
    )
    for stat, prog in zip(statuses, progs)
    if stat == "OK"
]

# Group by trimmed problem name (the key later intersected with bad_rkt_probs).
good_rkt_probs = {}
for prob, prog in tuned_rkt_good_compls:
    good_rkt_probs.setdefault(trim_problem_name(prob), []).append(prog)


                                                      

In [127]:

# Problems that are "good" for the tuned model and also "bad" for the base model,
# pairing the base model's passing programs with the tuned model's.
candidates = {
    prob: {"bad": bad_rkt_probs[prob], "good": good_programs}
    for prob, good_programs in good_rkt_probs.items()
    if prob in bad_rkt_probs
}
print(candidates.keys())


dict_keys(['78_hex_key', '122_add_elements', '48_is_palindrome', '103_rounded_avg', '79_decimal_to_binary', '52_below_threshold', '7_filter_by_substring'])


In [139]:
# Inspect one shared problem: how many passing programs each model contributed.
candidate = candidates["78_hex_key"]
bad_candidates, good_candidates = candidate["bad"], candidate["good"]
print(f"Good len:{len(good_candidates)}")
print(f"Bad len:{len(bad_candidates)}")

Good len:19
Bad len:1


In [140]:
# First passing base-model program for this problem (prompt + completion).
first_bad_program = bad_candidates[0]
print(first_bad_program)

#lang racket

;; You have been tasked to write a function that receives 
;; a hexadecimal number as a string and counts the number of hexadecimal 
;; digits that are primes (prime number, or a prime, is a natural number 
;; greater than 1 that is not a product of two smaller natural numbers).
;; Hexadecimal digits are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F.
;; Prime numbers are 2, 3, 5, 7, 11, 13, 17,...
;; So you have to determine a number of the following digits: 2, 3, 5, 7, 
;; B (=decimal 11), D (=decimal 13).
;; Note: you may assume the input is always correct or empty string, 
;; and symbols A,B,C,D,E,F are always uppercase.
;; Examples:
;; >>> (hex_key "AB")
;; 1
;; >>> (hex_key "1077E")
;; 2
;; >>> (hex_key "ABED1A33")
;; 4
;; >>> (hex_key "123456789ABCDEF0")
;; 6
;; >>> (hex_key "2020")
;; 2
(define (hex_key num)

  (define (hex_key_helper num count)
    (if (string=? num "")
        count
        (hex_key_helper (substring num 1 (string-length num))
                  

In [141]:
# First passing tuned-model program for this problem.
first_good_program = good_candidates[0]
print(first_good_program)

#lang racket

;; You have been tasked to write a function that receives 
;; a hexadecimal number as a string and counts the number of hexadecimal 
;; digits that are primes (prime number, or a prime, is a natural number 
;; greater than 1 that is not a product of two smaller natural numbers).
;; Hexadecimal digits are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F.
;; Prime numbers are 2, 3, 5, 7, 11, 13, 17,...
;; So you have to determine a number of the following digits: 2, 3, 5, 7, 
;; B (=decimal 11), D (=decimal 13).
;; Note: you may assume the input is always correct or empty string, 
;; and symbols A,B,C,D,E,F are always uppercase.
;; Examples:
;; >>> (hex_key "AB")
;; 1
;; >>> (hex_key "1077E")
;; 2
;; >>> (hex_key "ABED1A33")
;; 4
;; >>> (hex_key "123456789ABCDEF0")
;; 6
;; >>> (hex_key "2020")
;; 2
(define (hex_key num)

	(length 
		(filter 
			(lambda (x) 
				(member x '(2 3 5 7 11 13)))
			(map 
				(lambda (x) 
					(string->number (string x) 16))
				(string->list num)))