# Code downstream tasks

In [1]:
import os
import gzip
import json
import jsonlines
import tqdm.notebook as tqdm 
import random
import numpy as np
from src.useful_utils import chunks, read_in_chunks

%load_ext autoreload
%autoreload 2

### MRR on code search net valid and test

In [4]:
def load_CSN_data(datapath):
    """
    This function loads all the CodeSearchNet data in memory for train, valid, and test
    datapath: String: the path leading to the train, valid, test folders. Ex: "python/final/jsonl" -> /test /train /valid
    """
    def jsonl_dir_to_data(path):
        data_samples = []
        files_list = os.listdir(path)
        files_list.sort()
        for file_name in  tqdm.tqdm(files_list):
            file_path = os.path.join(path, file_name)
            with gzip.GzipFile(file_path, 'r') as fin:
                data = jsonlines.Reader(fin)
                for line in data.iter():
                    data_samples.append(line)
        return data_samples
             
    train_data = jsonl_dir_to_data(os.path.join(datapath, "train"))
    valid_data = jsonl_dir_to_data(os.path.join(datapath, "valid"))
    test_data = jsonl_dir_to_data(os.path.join(datapath, "test"))
    
    return train_data, valid_data, test_data
    
train_data, valid_data, test_data = load_CSN_data("/nfs/code_search_net_archive/python/final/jsonl/")

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [5]:
full_sample_lookup = {}
for split in [train_data, valid_data, test_data]:
    for sample in split:
        full_sample_lookup[sample["url"].replace(" ","%20")] = sample

In [None]:
with open('/nfs/code_search_net_archive/python/final/jsonl/full_CSN_dictionary.json', 'r') as f:
    full_sample_lookup = json.load(f)

In [8]:
with open('/nfs/code_search_net_archive/python/final/jsonl/full_CSN_dictionary.json', 'w') as f:
    json.dump(full_sample_lookup, f)

**TrecRun format**

qid Q0 docno rank score tag

where:  
- **qid**	is the query number
- **Q0**	is the literal Q0
- **docno**	is the id of a document returned for qid
- **rank**	(1-999) is the rank of this response for this qid
- **score**	is a system-dependent indication of the quality of the response
- **tag**	is the identifier for the system

Example:  
1 Q0 nhslo3844_12_012186 1 1.73315273652 mySystem  
1 Q0 nhslo1393_12_003292 2 1.72581054377 mySystem  
1 Q0 nhslo3844_12_002212 3 1.72522727817 mySystem  
1 Q0 nhslo3844_12_012182 4 1.72522727817 mySystem  
1 Q0 nhslo1393_12_003296 5 1.71374426875 mySystem  

**TrecQrel format**

qid 0 docno relevance  

where:  
- **qid**	is the query number
- **0**	is the literal 0
- **docno**	is the id of a document in your collection
- **relevance**	is how relevant is docno for qid

Example:  
1	0	aldf.1864_12_000027	1  
1	0	aller1867_12_000032	2  
1	0	aller1868_12_000012	0  
1	0	aller1871_12_000640	1  
1	0	arthr0949_12_000945	0  
1	0	arthr0949_12_000974	1  

**TrecRes format**

label qid value

where:  
- **label**	is any string, usually representing a metric
- **qid**	is the query number or 'all' to represent a aggregate value
- **value**	is numeral result of a metric

In [51]:
def create_qrel_file(data, target_file_path):
    '''
    data: [{k:v,...}], contains 
    '''
    with open(target_file_path, "w") as qrel_f:
        for sample in data:
            qrel_f.write(f"{sample['url'].replace(' ','%20')} 0 {sample['url'].replace(' ','%20')} 1\n")

create_qrel_file(valid_data, "/nfs/code_search_net_archive/python/final/jsonl/valid.qrel")

In [52]:
def create_run_file(data, target_file_path, hit_length=1000):
    '''
    data: [{k:v,...}], contains 
    '''
    assert len(data) >= hit_length
    
    run_array = []
    print("Creating runs")
    for i in tqdm.tqdm(range(len(data))):
        arr = list(range(len(data)))
        arr.pop(i)
        distractor_doc_indexes = random.sample(arr, hit_length-1)
        run_array.append([i]+distractor_doc_indexes)
    
    with open(target_file_path, "w") as run_f:
        for i in tqdm.tqdm(range(len(run_array))):
            url_list = [data[j]['url'].replace(" ","%20") for j in run_array[i]]
            full_str=""
            for k, url in enumerate(url_list):
                full_str += f"{data[i]['url'].replace(' ','%20')} Q0 {url} {k} 0.99 first_sample\n"
            run_f.write(full_str)
                
create_run_file(valid_data, "/nfs/code_search_net_archive/python/final/jsonl/valid.run")

Creating runs


HBox(children=(FloatProgress(value=0.0, max=23107.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23107.0), HTML(value='')))




In [53]:
def dummy_scorer(q_doc_pairs):
    '''
    q_doc_pairs: [(query, doc)]
    '''
#     print(q_doc_pairs[0])
    return list(np.random.uniform(low=0.0, high=1.0, size=(len(q_doc_pairs),)))

In [5]:
def identity_scorer(q_doc_pairs):
    return list(range(len(q_doc_pairs)))

In [3]:
def calculate_MRR(qrel_file, run_file, lookup, ranking_fn, chunk_size=1000):
    '''
    data: [(query, true_doc)]
    ranking_fn: a function that can take in a list of [(query, doc)] and return a [score] in that same order.
    
    The index of the data in the array is going to be the id used. The first element in each subsequent 
    array corresponding to a query will be the corresponding ground truth with the rest till the hit_length be 
    randomly sampled ids from the data.
    It is assumed the data is ordered and continuous following TREC format.
    
    >>> qrel_file = "/nfs/code_search_net_archive/python/final/jsonl/valid.qrel"
    >>> run_file = "/nfs/code_search_net_archive/python/final/jsonl/valid.run"
    >>> calculate_MRR(qrel_file, run_file, full_sample_lookup, dummy_scorer)
    '''    
    # {(qid: docid)}
    qrel_lookup = {}
    num_lines = sum(1 for line in open(qrel_file))
    with open(qrel_file, "r") as q_rel_f:
        print("getting qrels")
        pbar = tqdm.tqdm(q_rel_f, total=num_lines)
        for line in pbar:
            split_line = line.strip().split()
            qrel_lookup[split_line[0]] = split_line[2]
    
    MRR_scores = []
    num_lines = sum(1 for line in open(run_file))
    with open(run_file, "r") as run_f:
        print("getting runs")
        query_chunk = []
        pbar = tqdm.tqdm(run_f, total=num_lines)
        for line in pbar:
            split_line = line.strip().split()
            query_chunk.append((split_line[0],split_line[2]))
            if len(query_chunk) >= chunk_size:
                #pocess
                idx_relevant_doc = query_chunk.index((query_chunk[0][0], qrel_lookup[query_chunk[0][0]]))
#                 query_doc_pairs = [(lookup[q_id]['docstring_tokens'], 
#                                     lookup[d_id]['code'].replace(lookup[query_chunk[0][1]]['docstring'],"")) for q_id, d_id in query_chunk]
                query_doc_pairs = query_chunk
                scores = ranking_fn(query_doc_pairs)
                relevant_doc_score = scores.pop(idx_relevant_doc)
                rank = sum([1 for s in scores if s>relevant_doc_score])
                MRR_scores.append(1.0/(rank+1))
                query_chunk = []
        
    return np.average(MRR_scores), MRR_scores

In [67]:
qrel_file = "/nfs/code_search_net_archive/python/final/jsonl/valid.qrel"
run_file = "/nfs/code_search_net_archive/python/final/jsonl/valid.run"
score, _ = calculate_MRR(qrel_file, run_file, full_sample_lookup, dummy_scorer)
score

getting qrels


HBox(children=(FloatProgress(value=0.0, max=23107.0), HTML(value='')))


getting runs


HBox(children=(FloatProgress(value=0.0, max=23107000.0), HTML(value='')))




0.007508186952385621

In [6]:
qrel_file = "/nfs/phd_by_carlos/notebooks/datasets/test_100.qrels"
run_file = "/nfs/phd_by_carlos/notebooks/datasets/test_100.run"
score, _ = calculate_MRR(qrel_file, run_file, {}, identity_scorer)
score

getting qrels


HBox(children=(FloatProgress(value=0.0, max=6192.0), HTML(value='')))


getting runs


HBox(children=(FloatProgress(value=0.0, max=225156.0), HTML(value='')))

ValueError: ('enwiki:Antibiotics', 'a1039843710fd87981e2a2f4175422b792ce7d82') is not in list

In [None]:
%debug

> [0;32m<ipython-input-3-80d23062df23>[0m(36)[0;36mcalculate_MRR[0;34m()[0m
[0;32m     34 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mquery_chunk[0m[0;34m)[0m [0;34m>=[0m [0mchunk_size[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     35 [0;31m                [0;31m#pocess[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 36 [0;31m                [0midx_relevant_doc[0m [0;34m=[0m [0mquery_chunk[0m[0;34m.[0m[0mindex[0m[0;34m([0m[0;34m([0m[0mquery_chunk[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m,[0m [0mqrel_lookup[0m[0;34m[[0m[0mquery_chunk[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m][0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m[0;31m#                 query_doc_pairs = [(lookup[q_id]['docstring_tokens'],[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     38 [0;31m[0;31m#                              

ipdb>  query_chunk


[('enwiki:Antibiotics', '232f55315d6a5c26ffc91083240e1a449d808a46'), ('enwiki:Antibiotics', 'edc624f6c22ed13953b4ad63e14ed16df5b63eff'), ('enwiki:Antibiotics', '3d1072c8efcbaeffcd9a4cb10e0a3a76a25d889a'), ('enwiki:Antibiotics', '8f4e75922aec3ff3b30b29f1c4e1ca8ca5a667c8'), ('enwiki:Antibiotics', '2cd19a548932bdce56570d8f9ebf3db9a4580d58'), ('enwiki:Antibiotics', '73e7e72ae422f96f34d7bb5d8ec38cf15cccaeb3'), ('enwiki:Antibiotics', 'baa2bab77bdeb466ce9a1966fc638fdbe25824e3'), ('enwiki:Antibiotics', '9887fef5148ac26de93db1930f744dcff0246ffa'), ('enwiki:Antibiotics', '7ec65b5bd31eb12ec646ae85d650bf5e868f5429'), ('enwiki:Antibiotics', '375b32354b02293152c322216cdb5f658bfc1d10'), ('enwiki:Antibiotics', '6477218b58c23ae322fdb6190ebf5c33eccc96f8'), ('enwiki:Antibiotics', '9ef99ba093759948a49b44a186ddd60ca98931ad'), ('enwiki:Antibiotics', '45166f5fb8ebf8bf02de2a1193615a4d93af83ba'), ('enwiki:Antibiotics', '76dbb24b39fac12a7be9486164f45f9a12fc8993'), ('enwiki:Antibiotics', '2cfcbe2c507ca31605cb19f