# Syntax Repair with Language Models

## Dataset

TODO: Describe the dataset, cite DeepFix, and explain the
types of errors.

#

In [84]:
# Load the buggy training programs; reload utils so local edits to the module are picked up.
from importlib import reload
from repair import utils

utils = reload(utils)
train = utils.get_train_data()

In [None]:
# Number of buggy programs in the training split (shown as the cell output)
len(train)

In [None]:
from repair import utils
import tqdm

# Program length distribution, measured in lexer tokens (tqdm shows progress).
num_tokens = [
    len(utils.tokenize(program.source))
    for program in tqdm.tqdm(train)
]


In [None]:
import seaborn as sns

# Histogram of program lengths, in tokens.
ax = sns.histplot(num_tokens)
ax.set_xlabel("Number of tokens")
ax.set_ylabel("Count of cases")

In [None]:
# distribution of number of compile errors
errorcounts = [b.errorcount for b in train]
ax = sns.histplot(errorcounts)
ax.set_xlabel("Number of errors")
ax.set_ylabel("Count of cases")

In [None]:
# Relationship between length and errors
ax = sns.scatterplot(x=num_tokens, y=errorcounts, alpha=0.5)
ax.set_xlabel("Number of tokens")
ax.set_ylabel("Number of compilation errors")

In [None]:
# Random sample (without replacement) of k compile errors in the training split.
# Re-running the cell shows a different sample.
import random

k = 10
# random.sample raises ValueError if k exceeds the population size, so clamp it.
sampled = random.sample(list(train), min(k, len(train)))
for i, case in enumerate(sampled):
    print("Sample", i)
    print(case.source)
    print("Error:", case.error)
    print("--------------")

In [None]:
# Random sample (without replacement) of k compile errors in the test split.
import random

test = utils.get_test_data()
k = 10
# random.sample raises ValueError if k exceeds the population size; the test
# split recorded in this notebook has only 5 programs, so clamp k.
sampled = random.sample(list(test), min(k, len(test)))
for i, case in enumerate(sampled):
    print("Sample", i)
    print(case.source)
    print("Error:", case.error)
    print("--------------")

# Metrics

In [113]:
# Toy pair: the broken program lacks the ')' token that the fixed one has.
ok_simple = "int main() {return 0;}"
bad_simple = "int main( {return 0;}"

In [None]:
# Compile the broken toy program with gcc (expected to report a syntax error).
utils.gcc_compile(bad_simple)

In [None]:
# Compile the repaired toy program (expected to compile cleanly).
utils.gcc_compile(ok_simple)

In [None]:
# A repair typically has to satisfy an oracle (here, the compiler) while also
# keeping the edit small: deleting everything trivially yields a compilable unit.
# Token edit distance: standard Levenshtein distance computed over lexer tokens
# instead of raw characters (more meaningful than plain string distance).
utils.token_edit_distance(
    bad_simple,
    ok_simple
)

In [None]:
from importlib import reload

utils = reload(utils)
# Alternative metric: tree edit distance over a parse tree that keeps
# syntax-error nodes.
utils.tree_edit_distance(bad_simple, ok_simple)

# Running our benchmark suite

TODO

# Fixing Syntax Errors with Codex

In [None]:
# TODO: reference and explain https://arxiv.org/abs/2208.11640

In [None]:
from repair.engines import codex

In [None]:
import inspect
# Print the source of the Codex engine wrapper.
print(inspect.getsource(codex.CodexEngine))

In [154]:
import inspect
# Print the source of the baseline repair class.
print(inspect.getsource(codex.CodexBaseRepair))

class CodexBaseRepair(CodexRepair):
    """
    Based on https://beta.openai.com/examples/default-fix-python-bugs
    """

    def __init__(self, *args, **kwargs):
        self.codex = CodexEngine(*args, **kwargs)

    def get_prompt(self, code: str, **kwargs) -> str:
        # Based on https://beta.openai.com/examples/default-fix-python-bugs
        prompt = """//// Fix bugs in the below code\n"""
        prompt += f"/// Buggy C\n{code}\n\n"
        prompt += "/// Fixed C"
        if 'fixed' in kwargs and kwargs['fixed'] is not None:
            prompt += f"\n{kwargs['fixed']}\n\n"

        return prompt

    def get_repair_from_completion_(
            self, completion_dict: Dict[str, Any]) -> Dict[str, Any]:
        comp = completion_dict["completion"]
        # remove first \n and last \n, result from our prompt style
        if comp[0] == "\n":
            comp = comp[1:]
        if comp[-1] == "\n":
            comp = comp[:-1]
        return {"repair": comp, "score": completion_

In [160]:
from repair import tutorial_utils
# Presumably a nicer rendering of the class source than print(inspect.getsource(...)) — confirm.
tutorial_utils.show_code(codex.CodexBaseRepair)

In [100]:
import os
from importlib import reload

import openai
from repair import utils
from repair.engines import codex

codex = reload(codex)
utils = reload(utils)

# The engines below need an OpenAI key: fall back to the OPENAI_API_KEY
# environment variable, and fail early with a clear message if neither is set.
openai.api_key = openai.api_key or os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    raise RuntimeError("Set OPENAI_API_KEY (or openai.api_key) before running the Codex cells.")

codex_base = codex.CodexBaseRepair(openai.api_key)

In [101]:
# Load the buggy programs in the test split.
test = utils.get_test_data()

In [102]:
# Size of the test split (shown as the cell output)
len(test)

5

In [97]:
# Ask the baseline for repairs of the first test program; n=1 presumably requests one candidate.
pred = codex_base.repair(test[0].source, n=1)

In [103]:
# Run the benchmark suite with the baseline prompt (n=1).
res = utils.run_benchmark(codex_base, n=1)

100%|██████████| 5/5 [00:44<00:00,  8.85s/it]


### Add the compiler error message to improve error localization

In [110]:
import inspect
# Print the source of the variant that adds the compiler error to the prompt.
print(inspect.getsource(codex.CodexWithErrorInfo))

class CodexWithErrorInfo(CodexBaseRepair):
    """Codex repair whose prompt also includes gcc's error message for the buggy code."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps source code -> compiler error text, so each program is compiled once.
        # NOTE(review): unbounded; it grows with every distinct program prompted.
        self.error_cache = {}

    def _get_error_info(self, code: str) -> str:
        """Return gcc's error output for ``code``, compiling only on first request."""
        if code not in self.error_cache:
            self.error_cache[code] = gcc_compile(code).error
        return self.error_cache[code]

    def get_prompt(self, code: str, **kwargs) -> str:
        """Build a "Buggy C / Error Message / Fixed C" prompt for ``code``.

        If ``kwargs['fixed']`` is present and not None, the fixed program is
        appended after the "/// Fixed C" header. This produces a completed
        example, as used for few-shot prompts, rather than an open-ended query.
        """
        prompt = """//// Fix bugs in the below code\n"""
        prompt += f"/// Buggy C\n{code}\n\n"
        error_msg = self._get_error_info(code)
        prompt += f"/// Error Message\n{error_msg}\n\n"
        prompt += "/// Fixed C"
        if 'fixed' in kwargs and kwargs['fixed'] is not None:
            prompt += f"\n{kwargs['fixed']}\n\n"

        return prompt



In [111]:
# Same engine setup as the baseline, but prompts include the compiler error message.
codex_msg = codex.CodexWithErrorInfo(openai.api_key)

In [114]:
# Inspect the prompt built for the toy example.
print(codex_msg.get_prompt(bad_simple))

//// Fix bugs in the below code
/// Buggy C
int main( {return 0;}

/// Error Message
/tmp/tmp2av57t37.c:1:11: error: expected declaration specifiers or ‘...’ before ‘{’ token
    1 | int main( {return 0;}
      |           ^


/// Fixed C


In [115]:
# Repair the toy example; maxtokens presumably caps the completion length — confirm.
codex_msg.repair(bad_simple, maxtokens=100, n=1)

[{'repair': 'int main() {return 0;}\n', 'score': -0.33848194814375}]

In [116]:
# Benchmark the error-message variant (n=1).
res2 = utils.run_benchmark(codex_msg, n=1)

100%|██████████| 5/5 [00:38<00:00,  7.61s/it]


In [118]:
# First element of the benchmark result: appears to be the summary table shown below.
res2[0]

cutoff,stat,top-1,top-3,top-5
0,compile,0.2,0.2,0.2
1,compile+distance,0.2,0.2,0.2


### Add few-shot examples to show what kinds of edits may be needed

In [121]:
codex = reload(codex)
# Print the source of the few-shot variant (`inspect` was imported in an earlier cell).
print(inspect.getsource(codex.CodexWithFewShots))

class CodexWithFewShots(CodexBaseRepair):
    """Codex repair whose prompt is prefixed with solved (buggy, fixed) examples.

    Examples come from ``shot_selector.select_shots(code)``. Both the examples
    and the final query are rendered in CodexWithErrorInfo's prompt format.
    """

    def __init__(self, shot_selector: FewShotSelector, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.shot_selector = shot_selector
        # Used only to format prompts. NOTE(review): constructing it also builds a
        # second CodexEngine from the same arguments; consider sharing self.codex.
        self.prompt_helper = CodexWithErrorInfo(*args, **kwargs)

    def get_prompt(self, code: str, **kwargs) -> str:
        # NOTE(review): **kwargs is accepted but not used here.
        few_shots = self.shot_selector.select_shots(code)
        prompt = ""
        for (buggy_shot, fixed_shot) in few_shots:
            # Each shot is a completed example: buggy code, its error, its fix.
            prompt += self.prompt_helper.get_prompt(buggy_shot,
                                                    fixed=fixed_shot)

        # The query itself ends at "/// Fixed C", leaving the fix for the model.
        prompt += self.prompt_helper.get_prompt(code)
        return prompt



In [138]:
# Reload both modules so edits to their source are picked up.
codex = reload(codex)
utils = reload(utils)

In [139]:
# Shot pool: the first five training programs. Presumably FixedFewShots returns the same shots for every query — confirm.
fixed_shot_selector = codex.FixedFewShots(train[:5])
codex_shots = codex.CodexWithFewShots(fixed_shot_selector, openai.api_key)

In [135]:
# NOTE(review): k is presumably the number of shots; the printed get_prompt never reads its kwargs — confirm where k is used.
codex_shots.repair(bad_simple, n=1, k=2)

[{'repair': 'int main() {return 0;}\n', 'score': -0.09847468290307693}]

In [142]:
utils = reload(utils)
# Benchmark the fixed few-shot variant (n=1, k=2).
res3 = utils.run_benchmark(codex_shots, n=1, k=2)

100%|██████████| 5/5 [00:34<00:00,  6.94s/it]


In [143]:
# Summary table for the fixed few-shot run.
res3[0]

cutoff,stat,top-1,top-3,top-5
0,compile,0.2,0.2,0.2
1,compile+distance,0.2,0.2,0.2


In [146]:
# Same shot pool (first five training programs), presumably sampled at random per query — confirm.
random_shot_selector = codex.RandomFewShots(train[:5])
codex_shots = codex.CodexWithFewShots(random_shot_selector, openai.api_key)

In [147]:
utils = reload(utils)
# Benchmark the random few-shot variant (overwrites the fixed-shot results in res3);
# t=0.7 is presumably the sampling temperature — confirm.
res3 = utils.run_benchmark(codex_shots, n=1, k=2, t=0.7)
res3[0]

100%|██████████| 5/5 [00:37<00:00,  7.52s/it]


cutoff,stat,top-1,top-3,top-5
0,compile,0.2,0.2,0.2
1,compile+distance,0.2,0.2,0.2


In [153]:
utils = reload(utils)
# Tabulate the second element of the random few-shot benchmark result.
utils.basic_results_table(res3[1])


Unnamed: 0,stat,top-1,top-3,top-5
0,compile,0.2,0.2,0.2
1,compile+distance,0.2,0.2,0.2


In [None]:
# NOTE(review): the printed CodexWithFewShots.get_prompt takes shots from self.shot_selector and never reads a `fewshots` kwarg — confirm `repair` honours it.
print(codex_shots.repair(bad_simple, fewshots=[('void main(a{}', 'void main(a){}')], n=10))

### How to create an initial example bank?

## Synthetic noise in good data

In [None]:
### Generate data: use Codex (or another model) to fix real broken programs, then add the resulting pairs to the data (cite BIFI)

In [None]:
# Bootstrap an example bank yourself: let Codex repair real broken programs
# and keep the resulting (buggy, fixed) pairs.
# Built from the training split (not the test split) so test programs cannot
# leak into few-shot examples.
# NOTE(review): confirm generate_basic_example_bank expects a list of buggy programs like `train`.
codex = reload(codex)
example_bank = codex.generate_basic_example_bank(train, size=100)

### Picking few-shot examples

* Fixed (cover some basic examples by hand)
* Random (sample from other programs you have)
* Similar (what does it mean to be similar?)

(Some) Parameters that affect repair