In [None]:
# Reload edited local modules (render, util, cfg, ...) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import yaml
import numpy as np
import pandas as pd
import ray

from render import RenderParams, get_line_diff_range
import ray_util
import util
import render
import process_pr_events
import cfg

## Ray cluster setup

In [None]:
import ray_server
# Handle to the Ray cluster manager; used to scale the cluster up before the pipeline and back to 0 at the end.
server = ray_server.get_ray_server()

In [None]:
# Size requested from the cluster manager; the unit (nodes vs. workers) is defined by ray_server — TODO confirm.
CLUSTER_SIZE = 60
server.scale_cluster(CLUSTER_SIZE)

## Load secrets

In [None]:
# Credentials (e.g. the Hugging Face API key) are kept out of the notebook in a local YAML file.
with open('secrets.yaml') as secrets_file:
    secrets = yaml.safe_load(secrets_file)

## Get opt-outs

In [None]:
# Fetch the four opt-out collections from the configured dataset, authenticating with the HF token.
opt_outs = util.get_opt_outs(
    src=cfg.opt_outs_dataset_name,
    token=secrets['hf_api_key'],
)
(
    repos_opt_out,
    users_for_repo_opt_out,
    users_for_commits_opt_out,
    users_for_issues_opt_out,
) = opt_outs

## Filter out opt-outs and non-permissive licenses, and add stats

In [None]:
# Input: grouped PR bucket files. Output: a directory for the filtered buckets.
files = [bucket_file for bucket_file in cfg.prs_grouped_path.glob('*.parquet')]
dst = cfg.prs_grouped_filtered_path
dst.mkdir(exist_ok=True, parents=True)

In [None]:
# Minimum description/title lengths are taken from the default render parameters.
render_defaults = RenderParams()
# One Ray task per bucket file; returns task references, gathered in the next cell.
# NOTE(review): users_for_commits_opt_out is fetched above but not passed here — confirm that is intended.
res = util.ray_map(
    process_pr_events.process_pr_bucket,
    files,
    dst=dst,
    repos_opt_out=repos_opt_out,
    users_for_repo_opt_out=users_for_repo_opt_out,
    users_for_issues_opt_out=users_for_issues_opt_out,
    min_desc_length=render_defaults.min_text_size,
    min_title_length=render_defaults.min_title_size,
)

In [None]:
# Track the filtering tasks (ray_util helper, presumably a progress display), then block until all results are in.
ray_util.ray_tasks_progress(res)
res = ray.get(res)

## Add PR count per repo

In [None]:
@ray.remote
def get_df_repo_pr_bucket(file):
    """Read the repo/PR id columns of one bucket file and tag every row with the bucket name.

    Args:
        file: path to a bucket parquet file (a ``pathlib.Path``; its ``stem`` names the bucket).

    Returns:
        DataFrame with ``repo_name``, ``pull_request.guid`` and ``bucket`` columns.
    """
    wanted_columns = ['repo_name', 'pull_request.guid']
    return pd.read_parquet(file, columns=wanted_columns).assign(bucket=file.stem)

@ray.remote
def merge_pr_count_per_repo(data):
    """Add a ``pr_count_per_repo`` column to one bucket parquet file and rewrite it.

    Args:
        data: ``(file, df_bucket)`` pair. ``file`` is the bucket's parquet path;
            ``df_bucket`` must contain ``pull_request.guid`` and ``pr_count_per_repo``
            for the PRs of that bucket.

    Returns:
        1 if the file already had ``pr_count_per_repo`` (file left untouched),
        0 if the column was merged in and the file rewritten.
    """
    file, df_bucket = data[0], data[1]
    df = pd.read_parquet(file)
    if 'pr_count_per_repo' in df.columns:
        # Already processed (e.g. by an earlier run of this cell): skip so re-runs are idempotent.
        return 1
    # Keep one row per PR on the right-hand side so the left merge can never
    # multiply rows of the bucket file when a guid appears more than once.
    df_pr_per_repo = (
        df_bucket[['pull_request.guid', 'pr_count_per_repo']]
        .drop_duplicates('pull_request.guid')
    )
    df = df.merge(df_pr_per_repo, on='pull_request.guid', how='left')
    util.df_to_parquet_safe(df, file)
    return 0

In [None]:
files = list(cfg.prs_grouped_filtered_path.glob('*.parquet'))
# One (repo_name, pull_request.guid, bucket) frame per filtered bucket file.
res = ray.get(util.ray_map(get_df_repo_pr_bucket, files))

In [None]:
# Combine the per-bucket id frames into a NEW name (not `res`) so this cell can be re-run:
# rebinding `res` to the DataFrame made a second run call pd.concat on a DataFrame and fail.
df_repo_pr = pd.concat(res, ignore_index=True)
# Counts rows (non-null guids) per repo; this equals the number of PRs assuming one row per PR — TODO confirm.
df_repo_pr['pr_count_per_repo'] = (
    df_repo_pr.groupby('repo_name')['pull_request.guid'].transform('count')
)
path = cfg.prs_grouped_filtered_path
# (bucket file path, rows of that bucket) pairs consumed by merge_pr_count_per_repo.
src = [(path / f'{bucket}.parquet', group) for bucket, group in df_repo_pr.groupby('bucket')]

In [None]:
# Each task returns 1 if its file already had the column (skipped) and 0 if it was rewritten.
pending = util.ray_map(merge_pr_count_per_repo, src)
res = ray.get(pending)

## Render

In [None]:
# Inputs: filtered PR buckets and the matching commit data. Output: the renders directory.
pr_files = list(cfg.prs_grouped_filtered_path.glob('*.parquet'))
commits_path = cfg.pr_commid_pairs_files_filtered_cleaned_grouped_path
dst = cfg.prs_renders_path
dst.mkdir(parents=True, exist_ok=True)

# Rows whose `include_final` flag is False are handed to the renderer as `language_blacklist`.
# NOTE(review): this passes the filtered DataFrame, not a list of language names — confirm the renderer expects that.
df_bw_lang_list = pd.read_csv('language_labels.csv')
exclude_mask = df_bw_lang_list['include_final'].eq(False)
blacklisted_languaged = df_bw_lang_list.loc[exclude_mask]

render_params = render.RenderParams()
render_params.subsample_pr_per_repo = True

In [None]:
# Launch one Ray render task per filtered PR bucket. This previously iterated over
# `rest_files`, which no cell defines (leftover kernel state); the bucket list is `pr_files`.
# To resume a partial run, filter `pr_files` to buckets whose output file is missing.
res = [
    render.get_renders_for_bucket.remote(
        f, commits_path,
        render_params,
        return_render=False,
        return_lang_distr=False,
        return_data=False,
        base_seed=42,
        # The bucket file stem is parsed as a hex integer to give each bucket its own seed.
        seed=int(f.stem, 16),
        dst_file_name=dst / f.name,
        language_blacklist=blacklisted_languaged,
    )
    for f in pr_files
]

In [None]:
# Track the render tasks (ray_util helper, presumably a progress display), then block until all have finished.
ray_util.ray_tasks_progress(res)
res = ray.get(res)

## Ray cluster teardown

In [None]:
# Disconnect this session from Ray, then scale the cluster down to 0.
ray.shutdown()
server.scale_cluster(0)