In [100]:
from datasets import load_dataset

In [132]:
def empty_rows(ds_path):# 返回所有生成失败的样本
    ds = load_dataset("json", data_files=ds_path, split="train")
    return ds.filter(lambda x: x['optimization_0']=='')

In [102]:
def count_message_tokens(content, tokenizer):
    tokens = tokenizer(content)['input_ids']
    num_tokens = len(tokens)

    return num_tokens

In [103]:
vicuna_prompt_template = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.

USER: As an expert software developer with years of experience, please meticulously inspect the following unoptimized inefficient code and give an optimized version of the code, making it solve the same exact problem while achieving faster execution time.
To pass the testcases, the generated optimized code should strictly follow the same input/output format as the original unoptimized code.
The detailed information are as follows:
1. Description of the problem: {task_description}
2. Programming language: {lang}
3. Unoptimized code: 
```
{baseline_code}
```
4. Example testcase input: {example_input}
5. Example testcase output: {example_output}

Respond only the optimized code in the following JSON format:
{{"optimized_code": code string}}
ASSISTANT:
"""

In [122]:
import transformers
from transformers import AutoTokenizer

def model_factory(model_name):
    user_message = """As an expert software developer with years of experience, please meticulously inspect the following the following unoptimized inefficient code and give an optimized version of the code, making it solve the same exact problem while achieving smaller memory usage.
To pass the testcases, the generated optimized code should strictly follow the same input/output format as the original unoptimized code.
The detailed information are as follows:
1. Description of the problem: {task_description}
2. Programming language: {lang}
3. Unoptimized code: 
```
{baseline_code}
```
4. Example testcase input: {example_input}
5. Example testcase output: {example_output}

Respond only the optimized code in the following JSON format:
{{"optimized_code": code string}}"""

    if model_name == 'vicuna':
        checkpoint_path = '/home/wyk/hf_cache/lmsys/vicuna-13b-v1.5-16k'
        prompt_template = f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.

USER: {user_message.strip()}
ASSISTANT:
"""
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        model = None
    elif model_name == 'wizardcoder':
        checkpoint_path = '/home/wyk/hf_cache/WizardCoder'
        prompt_template = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{user_message.strip()}

### Response:"""
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    elif model_name == 'codellama':
        checkpoint_path = '/home/wyk/hf_cache/codellama'
        prompt_template = f'<s>[INST] {user_message.strip()} [/INST]'
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    elif model_name == 'starcoder':
        pass
    return prompt_template, tokenizer


In [123]:
def get_data_info(ds, ds_type, model_name):
    assert ds_type in ['mem','time']
    assert model_name in ['vicuna','wizardcoder','codellama','starcoder']
    prompt_template, tokenizer = model_factory(model_name)
    ds = ds.map(lambda d: {
        'code_tokens': count_message_tokens(d[f'{ds_type}_baseline_code'], tokenizer), 
        'prompt_tokens': count_message_tokens(
            prompt_template.format(
                task_description=d['task_description'],
                lang=d['lang'],
                baseline_code=d[f'{ds_type}_baseline_code'],
                example_input=d['testcases'][0]['input'],
                example_output=d['testcases'][0]['output'][0]
            ), 
            tokenizer
        )
    })
    ds = ds.select_columns(['src_uid','lang',f'{ds_type}_baseline_code_uid','code_tokens','prompt_tokens'])
    return ds.to_pandas()

## 原数据集（有很多空数据）

In [138]:
wizardcoder_mem_path = "./mem_code_opt_inference_wizardcoder_replenish4.jsonl"
wizardcoder_time_path = "./time_code_opt_inference_wizardcoder_replenish4.jsonl"
vicuna_mem_path = "./mem_code_opt_inference_vicuna.jsonl"
vicuna_time_path = "./time_code_opt_inference_vicuna.jsonl"
codellama_mem_path = "./mem_code_opt_data_codellama_replenish.jsonl"
codellama_time_path = "./time_code_opt_data_codellama_replenish.jsonl"


vicuna

In [141]:
vicuna_failed_mem = get_data_info(empty_rows(vicuna_mem_path), 'mem', 'vicuna')
vicuna_failed_time = get_data_info(empty_rows(vicuna_time_path), 'time', 'vicuna')
display(vicuna_failed_mem)
display(vicuna_failed_time)

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Unnamed: 0,src_uid,lang,mem_baseline_code_uid,code_tokens,prompt_tokens
0,369f37d3487ba8c158e24f5ca759287b,GNU C++,8e30009ff12625f543ba7fb2a045d89a,2923,3345
1,1d73b315694f2ebbf796654193372730,GNU C++,71b4e1e108665548d02d6be01c6798cd,1766,2314
2,566adc43d2d6df257c26c5f5495a5745,GNU C++,ed358d58411a7ab7daa56c1035e6d7ac,2241,2589


Unnamed: 0,src_uid,lang,time_baseline_code_uid,code_tokens,prompt_tokens
0,1ae2942b72ebb7c55359c41e141900d7,Python 3,7753e47e11c7622bbe870861f92a4c1e,4476,5122
1,00480885be97002dca98fe98a4238aee,GNU C++,06d994ff01f10c12154ea5db118111a3,1958,2639
2,369f37d3487ba8c158e24f5ca759287b,GNU C++,001a05c8f74315f65ab4be3817e71a57,4417,4839
3,dfd0814d912a7f2dfe31744ad1c778ae,GNU C++,c78457e32745642efcc495ba47a6d3b4,1725,2333
4,970cd8ce0cf7214b7f2be337990557c9,GNU C++,969236f30c57aae82414bd5d0d72f2e9,3921,4280
5,1d73b315694f2ebbf796654193372730,GNU C++,c37ee62c72c01d23c416789e586b84a4,1673,2221
6,ffa25047060e4741d8eddf2b91b1ca23,GNU C++,d2ab0503ff0b707cedadf40b27428b5e,2175,2597
7,a6cba17c5ddb93f6741e00280fb6c54c,Mono C#,11781cdf1336a2ed99a2c9a339ccd30d,1248,1965
8,1ae2942b72ebb7c55359c41e141900d7,Mono C#,497a6eb76cb3a56e13acb48bc19b427b,5831,6478
9,566adc43d2d6df257c26c5f5495a5745,Mono C#,c9a4ce2cbee18773f505a1108261a965,1699,2047


wizardcoder

In [140]:
wizardcoder_failed_mem = get_data_info(empty_rows(wizardcoder_mem_path), 'mem', 'wizardcoder')
wizardcoder_failed_time = get_data_info(empty_rows(wizardcoder_time_path), 'time', 'wizardcoder')
display(wizardcoder_failed_mem)
display(wizardcoder_failed_time)

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (2663 > 2048). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/8 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (3883 > 2048). Running this sequence through the model will result in indexing errors


Unnamed: 0,src_uid,lang,mem_baseline_code_uid,code_tokens,prompt_tokens
0,369f37d3487ba8c158e24f5ca759287b,GNU C++,8e30009ff12625f543ba7fb2a045d89a,2663,3072
1,1d73b315694f2ebbf796654193372730,GNU C++,71b4e1e108665548d02d6be01c6798cd,1640,2169
2,566adc43d2d6df257c26c5f5495a5745,GNU C++,ed358d58411a7ab7daa56c1035e6d7ac,2075,2408


Unnamed: 0,src_uid,lang,time_baseline_code_uid,code_tokens,prompt_tokens
0,1ae2942b72ebb7c55359c41e141900d7,Python 3,7753e47e11c7622bbe870861f92a4c1e,3883,4513
1,00480885be97002dca98fe98a4238aee,GNU C++,06d994ff01f10c12154ea5db118111a3,1799,2467
2,369f37d3487ba8c158e24f5ca759287b,GNU C++,001a05c8f74315f65ab4be3817e71a57,4027,4436
3,dfd0814d912a7f2dfe31744ad1c778ae,GNU C++,c78457e32745642efcc495ba47a6d3b4,1481,2072
4,970cd8ce0cf7214b7f2be337990557c9,GNU C++,969236f30c57aae82414bd5d0d72f2e9,2961,3298
5,1d73b315694f2ebbf796654193372730,GNU C++,c37ee62c72c01d23c416789e586b84a4,1513,2042
6,ffa25047060e4741d8eddf2b91b1ca23,GNU C++,d2ab0503ff0b707cedadf40b27428b5e,2025,2449
7,1ae2942b72ebb7c55359c41e141900d7,Mono C#,497a6eb76cb3a56e13acb48bc19b427b,4876,5506


codellama

In [137]:
codellama_failed_mem = get_data_info(empty_rows(codellama_mem_path), 'mem', 'codellama')
codellama_failed_time = get_data_info(empty_rows(codellama_time_path), 'time', 'codellama')
display(codellama_failed_mem)
display(codellama_failed_time)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/45 [00:00<?, ? examples/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/33 [00:00<?, ? examples/s]

Unnamed: 0,src_uid,lang,mem_baseline_code_uid,code_tokens,prompt_tokens
0,a37df9b239a40473516d1525d56a0da7,Python 3,974cba32d5751fefaf2213b5276c63dd,813,1246
1,4ecbfc792da55f458342c6eff2d5da5a,GNU C,418aa85079b2c6cb5335d15f06e8d2df,809,1250
2,a6cba17c5ddb93f6741e00280fb6c54c,GNU C,25f988034e63a2e037da470628f5bf33,723,1407
3,a17bac596b1f060209534cbffdf0f40e,GNU C,2a4472a4a10c95db0f8620d5c2b692ea,1624,2213
4,702ec7a08b8472fa843acb6f2107e583,GNU C++,bca62d60e7bfeb083167619714ea21ab,1015,1465
5,b0301a2d79a1ec126511ed769ec0b743,GNU C++,087bfe5d33cbc494ad7c9993e2ed14ba,1453,1879
6,c23d3ec2b9fb4b4d169bc8053bfd000e,GNU C++,c5916412903a6c823b9557f6796417c6,1406,2135
7,a9bad412597726f8cdc0cfa2da891bc4,GNU C++,fb7bfc8a1118ab4550e410e151a775c1,1430,1897
8,0996e41d0630e56472399bc81544756b,GNU C++,67ceb467d869143ef7e9f52cbedbb19f,608,1207
9,c16c49baf7b2d179764871204475036e,GNU C++,547eb29e1441f4a6414523b956d6f5ae,1215,1620


Unnamed: 0,src_uid,lang,time_baseline_code_uid,code_tokens,prompt_tokens
0,1ae2942b72ebb7c55359c41e141900d7,Python 3,7753e47e11c7622bbe870861f92a4c1e,4476,5090
1,5e055bad1da5bdc84599d6f2f89fbd12,GNU C,8d2e30036a9e1ba2f21c8d1e93023e00,890,1207
2,a6cba17c5ddb93f6741e00280fb6c54c,GNU C,674d4476f23920b540b31614cba86481,1198,1882
3,a17bac596b1f060209534cbffdf0f40e,GNU C,a19e8c72a9d869d9e9a02e2bed4a0443,1624,2213
4,702ec7a08b8472fa843acb6f2107e583,GNU C++,10887dde531ac9e6755e2b970e8169e5,1088,1538
5,c23d3ec2b9fb4b4d169bc8053bfd000e,GNU C++,64436b602ab2429463534e65451a19c7,1023,1752
6,0996e41d0630e56472399bc81544756b,GNU C++,b781fc676fbec394b21a5082ac0a0109,800,1399
7,00480885be97002dca98fe98a4238aee,GNU C++,06d994ff01f10c12154ea5db118111a3,1958,2607
8,369f37d3487ba8c158e24f5ca759287b,GNU C++,001a05c8f74315f65ab4be3817e71a57,4417,4807
9,dfd0814d912a7f2dfe31744ad1c778ae,GNU C++,c78457e32745642efcc495ba47a6d3b4,1725,2301


In [131]:
df = get_data_info(time_ds.filter(lambda x: x['optimization_0']!=''), 'mem', 'vicuna')
df[df['prompt_tokens']>2048]

Filter:   0%|          | 0/121 [00:00<?, ? examples/s]

Map:   0%|          | 0/113 [00:00<?, ? examples/s]

Unnamed: 0,src_uid,lang,mem_baseline_code_uid,code_tokens,prompt_tokens
42,a17bac596b1f060209534cbffdf0f40e,GNU C,2a4472a4a10c95db0f8620d5c2b692ea,1624,2245
63,c23d3ec2b9fb4b4d169bc8053bfd000e,GNU C++,c5916412903a6c823b9557f6796417c6,1406,2167
80,566adc43d2d6df257c26c5f5495a5745,GNU C++,ed358d58411a7ab7daa56c1035e6d7ac,2241,2589
87,f8315dc903b0542c453cab4577bcb20d,Mono C#,8dbecbe304244baa540fc15df607068e,1620,2073
94,a6cba17c5ddb93f6741e00280fb6c54c,Mono C#,4689cdd91cff9d9ff47ff863673c4a77,1348,2066
105,c175d010d75c391d0b25391fecff007c,Mono C#,310d028f0f6555fed72adcbbda69bb28,1577,2188


In [125]:
get_data_info(failed_mem, 'mem', 'vicuna')

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Unnamed: 0,src_uid,lang,mem_baseline_code_uid,code_tokens,prompt_tokens
0,369f37d3487ba8c158e24f5ca759287b,GNU C++,8e30009ff12625f543ba7fb2a045d89a,2923,3345
1,1d73b315694f2ebbf796654193372730,GNU C++,71b4e1e108665548d02d6be01c6798cd,1766,2314
2,566adc43d2d6df257c26c5f5495a5745,GNU C++,ed358d58411a7ab7daa56c1035e6d7ac,2241,2589


In [126]:
get_data_info(failed_time, 'time', 'vicuna')

Map:   0%|          | 0/8 [00:00<?, ? examples/s]

Unnamed: 0,src_uid,lang,time_baseline_code_uid,code_tokens,prompt_tokens
0,1ae2942b72ebb7c55359c41e141900d7,Python 3,7753e47e11c7622bbe870861f92a4c1e,4476,5122
1,00480885be97002dca98fe98a4238aee,GNU C++,06d994ff01f10c12154ea5db118111a3,1958,2639
2,369f37d3487ba8c158e24f5ca759287b,GNU C++,001a05c8f74315f65ab4be3817e71a57,4417,4839
3,dfd0814d912a7f2dfe31744ad1c778ae,GNU C++,c78457e32745642efcc495ba47a6d3b4,1725,2333
4,970cd8ce0cf7214b7f2be337990557c9,GNU C++,969236f30c57aae82414bd5d0d72f2e9,3921,4280
5,1d73b315694f2ebbf796654193372730,GNU C++,c37ee62c72c01d23c416789e586b84a4,1673,2221
6,ffa25047060e4741d8eddf2b91b1ca23,GNU C++,d2ab0503ff0b707cedadf40b27428b5e,2175,2597
7,1ae2942b72ebb7c55359c41e141900d7,Mono C#,497a6eb76cb3a56e13acb48bc19b427b,5831,6478


## 补充生成的数据集（也可能有空数据）

In [20]:
CS_mem_path = "./mem_code_opt_inference_wizardcoder_CS.jsonl"
CS_time_path = "./time_code_opt_inference_wizardcoder_CS.jsonl"
CS_mem_ds = load_dataset("json", data_files=CS_mem_path, split="train")
CS_time_ds = load_dataset("json", data_files=CS_time_path, split="train")
CS_failed_mem = CS_mem_ds.filter(lambda x: x['optimization_0']=='')
CS_failed_time = CS_time_ds.filter(lambda x: x['optimization_0']=='')
CS_failed_mem,CS_failed_time

(Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 0
 }),
 Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 1
 }))

In [21]:
CPP_mem_path = "./mem_code_opt_inference_wizardcoder_CPP.jsonl"
CPP_time_path = "./time_code_opt_inference_wizardcoder_CPP.jsonl"
CPP_mem_ds = load_dataset("json", data_files=CPP_mem_path, split="train")
CPP_time_ds = load_dataset("json", data_files=CPP_time_path, split="train")
CPP_failed_mem = CPP_mem_ds.filter(lambda x: x['optimization_0']=='')
CPP_failed_time = CPP_time_ds.filter(lambda x: x['optimization_0']=='')
CPP_failed_mem,CPP_failed_time

(Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 3
 }),
 Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 10
 }))

## 查看数据集和补充数据集的空项是否重合

In [30]:
failed_time.filter(lambda x: x['lang']=='Python 3'), failed_time.filter(lambda x: x['lang']=='GNU C++'), failed_time.filter(lambda x: x['lang']=='Mono C#')

(Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 1
 }),
 Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 24
 }),
 Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 30
 }))

In [34]:
all(src_uid in failed_time.filter(lambda x: x['lang']=='Mono C#')['src_uid'] for src_uid in CS_failed_time)

Filter:   0%|          | 0/55 [00:00<?, ? examples/s]

False

In [33]:
all(src_uid in failed_time.filter(lambda x: x['lang']=='GNU C++')['src_uid'] for src_uid in CPP_failed_time)

False

重合项

In [37]:
set(failed_time.filter(lambda x: x['lang']=='Mono C#')['src_uid']) & set(CS_failed_time['src_uid'])

{'1ae2942b72ebb7c55359c41e141900d7'}

In [38]:
set(failed_time.filter(lambda x: x['lang']=='GNU C++')['src_uid']) & set(CS_failed_time['src_uid'])

set()

## 将补充数据替换进数据集

In [43]:
def replace_empty_mem(data):
    if data['optimization_0']=='':# 生成失败的
        if data['lang']=='Mono C#':
            data = CS_mem_ds.filter(lambda x: x['lang']=='Mono C#' and x['src_uid']==data['src_uid'])[0]
        elif data['lang']=='GNU C++':
            data = CPP_mem_ds.filter(lambda x: x['lang']=='GNU C++' and x['src_uid']==data['src_uid'])[0]
        else:
            print(f"cannot find relacement for empty row: lang={data['lang']} src_uid={data['src_uid']}")
    return data
def replace_empty_time(data):
    if data['optimization_0']=='':# 生成失败的
        if data['lang']=='Mono C#':
            data = CS_time_ds.filter(lambda x: x['lang']=='Mono C#' and x['src_uid']==data['src_uid'])[0]
        elif data['lang']=='GNU C++':
            data = CPP_time_ds.filter(lambda x: x['lang']=='GNU C++' and x['src_uid']==data['src_uid'])[0]
        else:
            print(f"cannot find relacement for empty row: lang={data['lang']} src_uid={data['src_uid']}")
    return data


In [None]:
replaced_mem_ds = mem_ds.map(replace_empty_mem)
replaced_time_ds = time_ds.map(replace_empty_time)

In [None]:
replaced_mem_ds.to_json("./mem_code_opt_inference_wizardcoder_replaced1.jsonl")
replaced_time_ds.to_json("./time_code_opt_inference_wizardcoder_replaced1.jsonl")

## 查看补充之后的结果

wizardcoder

In [48]:
empty_rows(replaced_mem_ds)

Filter:   0%|          | 0/121 [00:00<?, ? examples/s]

Dataset({
    features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
    num_rows: 3
})

In [49]:
empty_rows(replaced_time_ds)

Filter:   0%|          | 0/121 [00:00<?, ? examples/s]

Dataset({
    features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
    num_rows: 12
})

Vicuna

In [54]:
mem_path = "./mem_code_opt_inference_vicuna.jsonl"
time_path = "./time_code_opt_inference_vicuna.jsonl"
mem_ds = load_dataset("json", data_files=mem_path, split="train")
time_ds = load_dataset("json", data_files=time_path, split="train")
empty_rows(mem_ds), empty_rows(time_ds)

Filter:   0%|          | 0/121 [00:00<?, ? examples/s]

Filter:   0%|          | 0/121 [00:00<?, ? examples/s]

(Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 3
 }),
 Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 10
 }))

codellama

In [57]:
mem_path = "./mem_code_opt_data_codellama.jsonl"
time_path = "./time_code_opt_data_codellama.jsonl"
mem_ds = load_dataset("json", data_files=mem_path, split="train")
time_ds = load_dataset("json", data_files=time_path, split="train")
empty_rows(mem_ds), empty_rows(time_ds)

(Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 51
 }),
 Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 86
 }))

gpt3.5

In [64]:
mem_path = "./mem_code_opt_data_gpt3.jsonl"
time_path = "./time_code_opt_data_gpt3.jsonl"
mem_ds = load_dataset("json", data_files=mem_path, split="train")
time_ds = load_dataset("json", data_files=time_path, split="train")
empty_rows(mem_ds), empty_rows(time_ds)

(Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 0
 }),
 Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 1
 }))

palm

In [61]:
mem_path = "./mem_code_opt_data_palm.jsonl"
time_path = "./time_code_opt_data_palm.jsonl"
mem_ds = load_dataset("json", data_files=mem_path, split="train")
time_ds = load_dataset("json", data_files=time_path, split="train")
empty_rows(mem_ds), empty_rows(time_ds)

(Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 0
 }),
 Dataset({
     features: ['src_uid', 'lang', 'mem_baseline_code_uid', 'mem_baseline_code', 'mem_baseline_perf', 'time_baseline_code_uid', 'time_baseline_code', 'time_baseline_perf', 'task_description', 'testcases', 'optimization_0', 'optimization_1', 'optimization_2', 'optimization_3', 'optimization_4'],
     num_rows: 1
 }))

In [68]:
empty_rows(time_ds)[0]

{'src_uid': '1ae2942b72ebb7c55359c41e141900d7',
 'lang': 'Mono C#',
 'mem_baseline_code_uid': 'd2f59a1950dc9af6f81f113cf4048abd',
 'mem_baseline_code': '\ufeffusing System;\nusing System.Collections.Generic;\nusing System.Text;\n\n\npublic class taskA\n{\n    static void Main(string[] args)\n    {\n        string inp = Console.ReadLine();\n        int n = int.Parse(inp);\n\n        string ps = Console.ReadLine();\n        long[] p = new long[n];\n        string[] spl = ps.Split(\' \');\n        for (int i = 0; i < n; i++)\n            p[i] = long.Parse(spl[i]);\n\n        string costs = Console.ReadLine();\n        spl = costs.Split(\' \');\n        long[] c = new long[5];\n        for(int i=0; i<5; i++)\n            c[i] = long.Parse(spl[i]);\n\n        long balance = 0;\n        long[] counts = new long[5];\n\n        for (int i = 0; i < n; i++)\n        {\n            balance += p[i];\n\n            for (int j = 4; j >= 0; j--)\n            {\n                if (balance >= c[j])\n 

llama2

In [69]:
mem_path = "./mem_code_opt_data_llama2.jsonl"
time_path = "./time_code_opt_data_llama2.jsonl"
mem_ds = load_dataset("json", data_files=mem_path, split="train")
time_ds = load_dataset("json", data_files=time_path, split="train")
empty_mem_ds = empty_rows(mem_ds)
empty_time_ds = empty_rows(time_ds)

In [71]:
for d in empty_mem_ds:
    print(d['src_uid'], d['lang'])

a17bac596b1f060209534cbffdf0f40e GNU C
c23d3ec2b9fb4b4d169bc8053bfd000e GNU C++
00480885be97002dca98fe98a4238aee GNU C++
369f37d3487ba8c158e24f5ca759287b GNU C++
1d73b315694f2ebbf796654193372730 GNU C++
566adc43d2d6df257c26c5f5495a5745 GNU C++
cb47d710361979de0f975cc34fc22c7a GNU C++
f8315dc903b0542c453cab4577bcb20d Mono C#
a6cba17c5ddb93f6741e00280fb6c54c Mono C#
c175d010d75c391d0b25391fecff007c Mono C#


In [72]:
for d in empty_time_ds:
    print(d['src_uid'], d['lang'])

1ae2942b72ebb7c55359c41e141900d7 Python 3
a17bac596b1f060209534cbffdf0f40e GNU C
00480885be97002dca98fe98a4238aee GNU C++
369f37d3487ba8c158e24f5ca759287b GNU C++
dfd0814d912a7f2dfe31744ad1c778ae GNU C++
970cd8ce0cf7214b7f2be337990557c9 GNU C++
1d73b315694f2ebbf796654193372730 GNU C++
a6cba17c5ddb93f6741e00280fb6c54c GNU C++
ffa25047060e4741d8eddf2b91b1ca23 GNU C++
a6cba17c5ddb93f6741e00280fb6c54c Mono C#
1ae2942b72ebb7c55359c41e141900d7 Mono C#


In [75]:
len(empty_time_ds)

11