In [None]:
import collections
import json
import re
from itertools import chain
from multiprocessing import Pool
from pathlib import Path
from typing import *

import dotenv
import git
import pandas as pd
import pydriller
from pymongo import MongoClient

In [None]:
# Load the variables of the local .env file, and additionally read that file into a plain mapping.
dotenv.load_dotenv()
ENV = dotenv.dotenv_values(".env")
# Root directory for all data files; raises KeyError when .env has no DATA_DIR entry.
DATA_DIR = Path(ENV["DATA_DIR"])
# Show the configured path and whether it exists on disk.
DATA_DIR, DATA_DIR.exists()

In [None]:
# Connect to the MongoDB server on localhost:42692 and load every document of the
# `patchUrls` collection (database `s5_snyk_libio`) into a DataFrame.
client = MongoClient("localhost", 42692)
db = client.s5_snyk_libio
patch_urls_mongo_collection = db.patchUrls
df = pd.DataFrame(list(patch_urls_mongo_collection.find()))
df

In [None]:
# Keep only the rows that list at least one patch URL.
# `.str.len()` measures list-like cells and yields NaN for a missing value; NaN > 0 is False,
# so rows without a PatchUrls array are dropped instead of raising TypeError as `map(len)` would.
df = df[df['PatchUrls'].str.len() > 0]
df

In [None]:
# Convert the filtered frame into a list of dicts (one per row, keyed by column name).
df_dict = df.to_dict(orient='records')
df_dict

In [None]:
# Inverted index: patch URL -> set of vulnerability URLs that reference it.
patch_urls = collections.defaultdict(set)
for record in df_dict:
    for patch_url in record['PatchUrls']:
        patch_urls[patch_url].add(record['VulnUrl'])
# Number of distinct patch URLs.
len(patch_urls)

In [None]:
def get_repo_and_commit(commit_url: str) -> Optional[Tuple[str, str]]:
    """Split a GitHub commit URL into its repository URL and commit hash.

    The first occurrence of ``github.com/<path>/commit/<hex>`` inside ``commit_url`` is used;
    an optional ``https://`` and/or ``www.`` prefix is kept as part of the repository URL, and
    a trailing ``#diff...`` fragment is ignored. The hash must consist of lowercase hex digits.

    Args:
        commit_url: URL expected to point at a GitHub commit.

    Returns:
        ``(repo_url, commit_hash)`` for the first match, or ``None`` when ``commit_url``
        contains no GitHub commit reference.
    """
    match = re.search(
        r"(?P<repo>(https://)?(www\.)?github\.com(?:/[^/]+)*)/commit/(?P<hash>[0-9a-f]+)(\#diff.*)?",
        commit_url)
    if match is None:
        return None
    return match.group('repo'), match.group('hash')

In [None]:
# repo_patches:        repository URL -> set of patch commit hashes
# repo_commit_to_url:  (repository URL, commit hash) -> patch URL it was parsed from
# unparsed_patch_urls: patch URLs that are not GitHub commit URLs and are therefore skipped
repo_patches = collections.defaultdict(set)
repo_commit_to_url = dict()
unparsed_patch_urls = []
for patch_url in patch_urls.keys():
    parsed = get_repo_and_commit(patch_url)
    if parsed is None:
        # get_repo_and_commit yields None for URLs it cannot parse; unpacking that would
        # abort the whole cell with a TypeError, so record the URL and move on.
        unparsed_patch_urls.append(patch_url)
        continue
    repo, commit = parsed
    repo_patches[repo].add(commit)
    repo_commit_to_url[(repo, commit)] = patch_url

if unparsed_patch_urls:
    print(f'skipped {len(unparsed_patch_urls)} patch URLs that are not GitHub commit URLs')

# Mapping, number of repositories, total number of patch commits.
repo_patches, len(repo_patches.keys()), len(list(chain(*repo_patches.values())))

In [None]:
# Directory the repositories are cloned into; created together with missing parents.
repo_path = DATA_DIR.joinpath('interim', 'repositories')
repo_path.mkdir(parents=True, exist_ok=True)


def get_new_records(repo_url: str) -> Tuple[List, Dict]:
    """Collect one record per patched, non-test Java file for each patch commit of a repository.

    Reads the module-level ``repo_patches`` (patch commit hashes per repository),
    ``repo_commit_to_url`` (patch URL per ``(repo_url, hash)``) and ``repo_path`` (clone target).

    For every patch hash, the matching commits are looked up by prefix in the full commit list.
    Each commit's modified files are narrowed in order to: ``.java`` files, files that have a
    previous version, files with at least one changed method, and files whose old path does not
    contain ``Test``/``test``. A commit whose file list becomes empty at some stage is recorded
    under that stage's key in ``error_data`` and skipped.

    Args:
        repo_url: Repository URL used to look up the patch commit hashes in ``repo_patches``.

    Returns:
        ``(new_records, error_data)``: ``new_records`` is a list of dicts (one per retained file);
        ``error_data`` maps each problem category to a list of repository URLs or
        ``(repo_url, commit_hash)`` tuples. On a clone/traversal failure ``new_records`` is empty.
    """
    error_data = {
        'clone': [],
        'git': [],
        'no_changed_method_gavs': [],
        'no_java_gavs': [],
        'no_source_code': [],
        'is_test': [],
        'traversal_problems': [],
        'merge_commits': [],
        'no_modifications': [],
    }
    new_records = list()
    commit_hashes = repo_patches[repo_url]

    # NOTE(review): PyDriller appears to clone lazily (inside traverse_commits) and to reuse an
    # existing folder named after the repository -- so clone failures probably end up under 'git',
    # and same-named repositories of different owners may share a clone. Confirm against PyDriller.
    try:
        repo = pydriller.Repository(repo_url, clone_repo_to=repo_path, include_remotes=True, include_deleted_files=True,
                                    include_refs=True)
    except Exception:
        print(f'error cloning {repo_url}')
        error_data['clone'].append(repo_url)
        return [], error_data

    try:
        repo_commits = [c for c in repo.traverse_commits()]
    except git.exc.CommandError as e:
        print(f'git cmd error ({repo_url}): [{type(e)}: {e}]')
        error_data['git'].append(repo_url)
        return [], error_data
    except Exception as e:
        print(f'error traversing commits ({repo_url}): [{type(e)}: {e}]')
        error_data['git'].append(repo_url)
        return [], error_data

    # Compiled once instead of once per inspected file.
    test_path_re = re.compile(r'[Tt]est')

    # Ordered filter stages: (error category recorded when nothing survives, keep-predicate).
    # The order matters: later predicates rely on earlier ones (e.g. `old_path` is only searched
    # for files that passed the `source_code_before is not None` stage).
    filter_stages = (
        ('no_java_gavs', lambda mf: mf.filename.endswith('.java')),
        ('no_source_code', lambda mf: mf.source_code_before is not None),
        ('no_changed_method_gavs', lambda mf: len(mf.changed_methods) > 0),
        ('is_test', lambda mf: test_path_re.search(mf.old_path) is None),
    )

    for ch in commit_hashes:
        # Patch URLs may carry abbreviated hashes, hence the prefix match in both directions.
        to_be_traversed = [c for c in repo_commits if c.hash.startswith(ch) or ch.startswith(c.hash)]
        if len(to_be_traversed) != 1:
            error_data['traversal_problems'].append((repo_url, ch))
            print(f'there is a traversal problem for {(repo_url, ch, len(to_be_traversed))}')

        for commit in to_be_traversed:
            # Read the property once and reuse the result for the emptiness check and the filters.
            candidate_files = commit.modified_files
            if len(candidate_files) == 0:
                if commit.merge:
                    print(f'merge commit {(repo_url, commit.hash)}')
                    error_data['merge_commits'].append((repo_url, commit.hash))
                else:
                    print(f'commit has no modified files: {(repo_url, commit.hash)}')
                    error_data['no_modifications'].append((repo_url, commit.hash))

                continue

            rejected_by = None
            for error_key, keep in filter_stages:
                candidate_files = [mf for mf in candidate_files if keep(mf)]
                if len(candidate_files) == 0:
                    rejected_by = error_key
                    break
            if rejected_by is not None:
                error_data[rejected_by].append((repo_url, commit.hash))
                continue

            for mf in candidate_files:
                try:
                    # `ch` is the hash taken from the patch URL that led to this commit, so it is
                    # the lookup key directly. Key order is kept (repo, commitHash, snykPatchUrl, ...)
                    # because it determines the column order of the resulting DataFrame.
                    new_row = {
                        'repo': repo_url,
                        'commitHash': commit.hash,
                        'snykPatchUrl': repo_commit_to_url[(repo_url, ch)],
                        'modifiedFilePathBefore': mf.old_path,
                        'modifiedFilePathAfter': mf.new_path,
                        'modifiedFileSrcBefore': mf.source_code_before,
                        'modifiedFileSrcAfter': mf.source_code,
                        'diffParsedJson': json.dumps(mf.diff_parsed),
                        'nloc': mf.nloc,
                        'changedMethods': [m.name for m in mf.changed_methods],
                    }
                    new_records.append(new_row)

                except Exception:
                    # Best effort: a file that cannot be converted is reported and skipped.
                    print(f'error parsing modified files {repo_url}: {ch} ({mf.filename})')
                    continue
    return new_records, error_data

In [None]:
# Mine all repositories in parallel with 32 worker processes. `res` holds one
# (new_records, error_data) tuple per repository, in the iteration order of repo_patches.
# NOTE(review): the workers call a function defined in the notebook itself, which presumably
# relies on the fork start method -- confirm on platforms that default to spawn.
with Pool(32) as p:
    res = p.map(get_new_records, repo_patches.keys())

In [None]:
# Flatten the per-repository record lists into one list of record dicts.
new_df = list(chain.from_iterable(result[0] for result in res))
# Number of records and the field names of the first one.
len(new_df), new_df[0].keys()

In [None]:
# Merge the per-repository error reports: category -> set of affected repos / (repo, commit) pairs.
errors = collections.defaultdict(set)
for result in res:
    for category, entries in result[1].items():
        errors[category].update(entries)
# One summary line per category.
for category, entries in errors.items():
    print(f'{category}: {len(entries)}')

In [None]:
# Distinct (repo, commit) pairs, repositories and commit hashes covered by the extracted records.
commit_urls = {(record['repo'], record['commitHash']) for record in new_df}
repos = {record['repo'] for record in new_df}
commits = {record['commitHash'] for record in new_df}
len(repos), len(commit_urls)  # 87, 266 commits

In [None]:
# Tabular view of the extracted records: one row per retained modified file.
res_df = pd.DataFrame(data=new_df)
res_df  # 1050 files

In [None]:
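# Writing the records to MongoDB is intentionally disabled; uncomment the next line to insert them.
# db.patchCommitsLibio.insert_many(new_df)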