# Подготовка датасета

In [1]:
import datasets
import os
import pandas as pd
import subprocess
import re
from tqdm.notebook import tqdm
from pathlib import Path
import asyncio
import aiopath
import glob

REPOS_DIR = os.getcwd()+'/data/repos'

gvm_root = os.environ['GVM_ROOT']
os.environ['PATH'] = f"{gvm_root}/bin:{gvm_root}/pkgsets/go1.24.2/global/bin:{gvm_root}/gos/go1.24.2/bin:{gvm_root}/pkgsets/go1.24.2/global/overlay/bin:{os.environ['PATH']}"

In [4]:
# Go 7.84M
# Go_Checksums 62.4K
# Go_Module 351K

if not os.path.isdir('data/source_dedup_ds'):
    go_ds = datasets.load_dataset("bigcode/the-stack-v2-dedup", "Go", split="train")
    go_checksums_ds = datasets.load_dataset("bigcode/the-stack-v2-dedup", "Go_Checksums", split="train")
    go_module_ds = datasets.load_dataset("bigcode/the-stack-v2-dedup", "Go_Module", split="train")
    source_ds = datasets.concatenate_datasets([
        go_ds,
        go_checksums_ds,
        go_module_ds,
    ])
    del go_ds, go_checksums_ds, go_module_ds

    source_ds = source_ds.remove_columns([
        'blob_id', 'directory_id', 'content_id', 'detected_licenses', 'license_type',
        'snapshot_id', 'branch_name', 'visit_date', 'revision_date', 'committer_date',
        'github_id', 'gha_license_id', 'gha_event_created_at', 'gha_created_at',
        'gha_language', 'src_encoding', 'language',
    ])
    source_ds = source_ds.sort(["repo_name", "path"])
    source_ds.save_to_disk('data/source_dedup_ds')

source_ds = datasets.load_from_disk('data/source_dedup_ds')

print(len(source_ds))
source_ds.take(5).to_pandas()

8248983


Unnamed: 0,path,repo_name,revision_id,star_events_count,fork_events_count,is_vendor,is_generated,length_bytes,extension,filename
0,/backend/cmd/main.go,0-10V/real-time-metrics,766dc882d779f07821bde740ce49802f67ae42b3,0,0,False,False,118,go,main.go
1,/backend/controllers/controllers.go,0-10V/real-time-metrics,766dc882d779f07821bde740ce49802f67ae42b3,0,0,False,False,1164,go,controllers.go
2,/backend/go.mod,0-10V/real-time-metrics,766dc882d779f07821bde740ce49802f67ae42b3,0,0,False,False,126,mod,go.mod
3,/backend/models/models.go,0-10V/real-time-metrics,766dc882d779f07821bde740ce49802f67ae42b3,0,0,False,False,767,go,models.go
4,/backend/routes/router.go,0-10V/real-time-metrics,766dc882d779f07821bde740ce49802f67ae42b3,0,0,False,False,233,go,router.go


In [21]:
import hashlib

if not os.path.isdir('data/files_ds'):

    def check_file(row):
        download_path = f"{REPOS_DIR}/{row['revision_id']}{row['path']}"    

        try:
            with open(download_path, "rb") as f:
                digest = hashlib.file_digest(f, "sha256")
                hash = digest.hexdigest()

                return {'file_exists': True, "file_hash": hash}
        except OSError:
            return {'file_exists': False, "file_hash": ""}

    files_ds = source_ds.sort(['star_events_count', 'repo_name', 'path'], reverse=[True, False, False]).map(check_file, num_proc=32)

    files_ds.save_to_disk('data/files_ds')

files_ds = datasets.load_from_disk('data/files_ds')

print(files_ds)
files_ds[0]

Dataset({
    features: ['path', 'repo_name', 'revision_id', 'star_events_count', 'fork_events_count', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'file_exists', 'file_hash'],
    num_rows: 8248983
})


{'path': '/tensorflow/go/attrs.go',
 'repo_name': 'tensorflow/tensorflow',
 'revision_id': 'a7f3934a67900720af3d3b15389551483bee50b8',
 'star_events_count': 208740,
 'fork_events_count': 109943,
 'is_vendor': False,
 'is_generated': False,
 'length_bytes': 7032,
 'extension': 'go',
 'filename': 'attrs.go',
 'file_exists': False,
 'file_hash': ''}

### Очистим тестовые файлы в проектах

In [6]:
revisions = files_ds.unique('revision_id')
print(len(revisions))
print(revisions[0])

for revision_id in revisions:
    project_path = REPOS_DIR+'/'+revision_id
    for p in Path(project_path).glob("**/*_test.go"):
        p.unlink()

868880
a7f3934a67900720af3d3b15389551483bee50b8


KeyboardInterrupt: 

In [22]:
filter_stats = {
    "not_exists_count": 0,
    "test_files_count": 0,
    "wrong_ext_count": 0,
}

def filter_files(row) -> bool:
    if not row['file_exists']:
        filter_stats['not_exists_count'] += 1
        return False

    lower_path = row['path'].lower()
    if lower_path.endswith("_test.go"):
        filter_stats['test_files_count'] += 1
        return False
    if not lower_path.endswith(".go") and not lower_path.endswith("/go.mod") and not lower_path.endswith("/go.sum"):
        filter_stats['wrong_ext_count'] += 1
        return False

    return True

ds = files_ds.filter(filter_files, num_proc=1)

print(ds)
filter_stats

Dataset({
    features: ['path', 'repo_name', 'revision_id', 'star_events_count', 'fork_events_count', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'file_exists', 'file_hash'],
    num_rows: 5625886
})


{'not_exists_count': 0, 'test_files_count': 0, 'wrong_ext_count': 0}

In [23]:
projects_by_repo: dict[str, set[str]] = {}

def get_go_package(relative_project_path: str) -> str:
    filename_slash_pos = relative_project_path.rfind('/')
    return relative_project_path[:filename_slash_pos+1]

def collect_projects(row):
    lower_path = row['path'].lower()
    if lower_path.endswith("/go.mod"):
        project = row['path'][:len(row['path'])-6]

        repo_projects = projects_by_repo.get(row['repo_name'], set())
        repo_projects.add(project)
        projects_by_repo[row['repo_name']] = repo_projects

ds.map(collect_projects)

def add_project_column(row):
    lower_path = row['path'].lower()
    if lower_path.endswith("/go.mod"):
        project = row['path'][:len(row['path'])-6]
        return {'project': project, 'relative_project_path': row['path'][len(row['path'])-7:], 'relative_go_package': ''}

    repo_projects = list(projects_by_repo.get(row["repo_name"], set()))
    repo_projects.sort(key=len, reverse=True)

    for repo_project in repo_projects:
        if ((lower_path.endswith('.go') and row['path'].startswith(repo_project)) or
            (lower_path == repo_project.lower()+"go.sum")):

            relative_project_path = row['path'][len(repo_project):]
            return {'project': repo_project, 'relative_project_path': relative_project_path, 'relative_go_package': get_go_package(relative_project_path)}

    return {'project': '', 'relative_project_path': '', 'relative_go_package': ''}

ds = ds.map(add_project_column, num_proc=32)
# .filter(lambda row: row['repo_name'] == 'go-gitea/gitea')

print(ds)
ds[0]
#ds.to_pandas().query('path == "/modules/git/tree.go"')

Map:   0%|          | 0/5625886 [00:00<?, ? examples/s]

Dataset({
    features: ['path', 'repo_name', 'revision_id', 'star_events_count', 'fork_events_count', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'file_exists', 'file_hash', 'project', 'relative_project_path', 'relative_go_package'],
    num_rows: 5625886
})


{'path': '/go.mod',
 'repo_name': 'avelino/awesome-go',
 'revision_id': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85',
 'star_events_count': 112752,
 'fork_events_count': 13739,
 'is_vendor': False,
 'is_generated': False,
 'length_bytes': 603,
 'extension': 'mod',
 'filename': 'go.mod',
 'file_exists': True,
 'file_hash': '0adeb56bf0ed5a68b4e0ac51da9f3d38e6623754e262f9e4cd1a9f7c4c5ce817',
 'project': '/',
 'relative_project_path': '/go.mod',
 'relative_go_package': ''}

### Отфильтруем файлы go без go module

Проекты без go.mod - это legacy проекты, который с большой вероятностью не получится запустить без ручной работы

In [24]:
ds = ds.filter(lambda row: row['project'] != '', num_proc=32)

print(ds)
ds.skip(100)[0]

Dataset({
    features: ['path', 'repo_name', 'revision_id', 'star_events_count', 'fork_events_count', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'file_exists', 'file_hash', 'project', 'relative_project_path', 'relative_go_package'],
    num_rows: 2671142
})


{'path': '/pkg/msg/msg.go',
 'repo_name': 'fatedier/frp',
 'revision_id': 'f1454e91f56508603e4c2e3c7bf37ccb534458c2',
 'star_events_count': 75141,
 'fork_events_count': 14118,
 'is_vendor': False,
 'is_generated': False,
 'length_bytes': 8303,
 'extension': 'go',
 'filename': 'msg.go',
 'file_exists': True,
 'file_hash': '8477b4eb0d5c5aa1197c1eb077a4b791d48952af2aed2acf85f49223c96be333',
 'project': '/',
 'relative_project_path': 'pkg/msg/msg.go',
 'relative_go_package': 'pkg/msg/'}

### Отфильтруем проекты с /generated/ в пути

Это с большой вероятностью означает, что они сгенерированы

In [25]:
ds = ds.filter(lambda row: '/generated/' not in '/'+row['relative_go_package'], num_proc=32)

print(ds)

Dataset({
    features: ['path', 'repo_name', 'revision_id', 'star_events_count', 'fork_events_count', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'file_exists', 'file_hash', 'project', 'relative_project_path', 'relative_go_package'],
    num_rows: 2664663
})


In [2]:
if not os.path.isdir('data/before_files_funcs_ds'):
    ds.save_to_disk('data/before_files_funcs_ds')

ds = datasets.load_from_disk('data/before_files_funcs_ds')

print(ds)
ds[0]

Dataset({
    features: ['path', 'repo_name', 'revision_id', 'star_events_count', 'fork_events_count', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'file_exists', 'file_hash', 'project', 'relative_project_path', 'relative_go_package'],
    num_rows: 2664663
})


{'path': '/go.mod',
 'repo_name': 'avelino/awesome-go',
 'revision_id': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85',
 'star_events_count': 112752,
 'fork_events_count': 13739,
 'is_vendor': False,
 'is_generated': False,
 'length_bytes': 603,
 'extension': 'mod',
 'filename': 'go.mod',
 'file_exists': True,
 'file_hash': '0adeb56bf0ed5a68b4e0ac51da9f3d38e6623754e262f9e4cd1a9f7c4c5ce817',
 'project': '/',
 'relative_project_path': '/go.mod',
 'relative_go_package': ''}

### Отфильтруем файлы без функций

Это могут быть константы, структуры без методов, которые не нужно тестить

In [3]:
if not os.path.isdir('data/files_funcs_ds'):

    def extract_funcs(row):
        download_path = f"{REPOS_DIR}/{row['revision_id']}{row['path']}"

        if row['extension'] != 'go':
            return {'funcs': '', 'methods': ''}

        try:
            with open(download_path, "rb") as f:
                content = f.read().decode('utf-8')
                test_fns: list[str] = re.findall(r'\nfunc\W*(\w+)\W*\(', content)
                test_methods: list[str] = re.findall(r'\nfunc\W*\(\w+ +(\*?\w+)\)\W*(\w+)\W*\(', content)
                test_methods_joined = [t+'.'+m for (t, m) in test_methods]

                return {'funcs': ','.join(test_fns), 'methods': ','.join(test_methods_joined)}
        except Exception as e:
            return {'funcs': '', 'methods': ''}

    ds = ds.map(extract_funcs, num_proc=16)

    ds.save_to_disk('data/files_funcs_ds')

ds = datasets.load_from_disk('data/files_funcs_ds')

print(ds)
ds[0]

Dataset({
    features: ['path', 'repo_name', 'revision_id', 'star_events_count', 'fork_events_count', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'file_exists', 'file_hash', 'project', 'relative_project_path', 'relative_go_package', 'funcs', 'methods'],
    num_rows: 2664663
})


{'path': '/go.mod',
 'repo_name': 'avelino/awesome-go',
 'revision_id': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85',
 'star_events_count': 112752,
 'fork_events_count': 13739,
 'is_vendor': False,
 'is_generated': False,
 'length_bytes': 603,
 'extension': 'mod',
 'filename': 'go.mod',
 'file_exists': True,
 'file_hash': '0adeb56bf0ed5a68b4e0ac51da9f3d38e6623754e262f9e4cd1a9f7c4c5ce817',
 'project': '/',
 'relative_project_path': '/go.mod',
 'relative_go_package': '',
 'funcs': '',
 'methods': ''}

In [6]:
ds = ds.filter(lambda row: row['extension'] != 'go' or row['funcs'] != ''
    # or row['methods'] != '' # Отфильтруем файлы с методами типов. Маленькая вероятность генерации успешных тестов для таких методов, так как нужно обеспечить работу моков
)

print(ds)
ds[0]

Filter:   0%|          | 0/2372969 [00:00<?, ? examples/s]

Dataset({
    features: ['path', 'repo_name', 'revision_id', 'star_events_count', 'fork_events_count', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'file_exists', 'file_hash', 'project', 'relative_project_path', 'relative_go_package', 'funcs', 'methods', 'project_path'],
    num_rows: 2048497
})


{'path': '/go.mod',
 'repo_name': 'avelino/awesome-go',
 'revision_id': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85',
 'star_events_count': 112752,
 'fork_events_count': 13739,
 'is_vendor': False,
 'is_generated': False,
 'length_bytes': 603,
 'extension': 'mod',
 'filename': 'go.mod',
 'file_exists': True,
 'file_hash': '0adeb56bf0ed5a68b4e0ac51da9f3d38e6623754e262f9e4cd1a9f7c4c5ce817',
 'project': '/',
 'relative_project_path': '/go.mod',
 'relative_go_package': '',
 'funcs': '',
 'methods': '',
 'project_path': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85/'}

### Скачивание зависимостей проектов

In [7]:
ds = ds.map(lambda row: {'project_path': row['revision_id']+row['project']})

Map:   0%|          | 0/2048497 [00:00<?, ? examples/s]

Зависимости слишком тяжелый для скачивания во всех репозиториях, возьмем 100k файлов из самых популярных репозиториев

In [8]:
ds = ds.take(100_000)

In [9]:
project_paths = ds.unique('project_path')

print(len(project_paths))
project_paths[0]

2293


'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85/'

In [None]:
project_download_deps_errors: dict[str, str] = {}

concurrency = 32
next_index = 0
last_error = ''
updated_count = 0

async def worker(pbar, i):
    global next_index
    global last_error
    global updated_count
    while next_index < len(project_paths):
        project_path = project_paths[next_index]
        next_index += 1

        go_mod_content = ''
        try:
            go_mod_content = await aiopath.AsyncPath(f"{REPOS_DIR}/{project_path}go.mod").read_text()
        except Exception:
            project_download_deps_errors[project_path] = 'go.mod not found'
            last_error = 'go.mod not found'
            pbar.update(1)
            pbar.set_postfix(errors_count=len(project_download_deps_errors), updated_count=updated_count, last_error=last_error)
            continue

        async for p in aiopath.AsyncPath(f"{REPOS_DIR}/{project_path}").glob("**/*_test.go"):
            await p.unlink()

        update_libs = ['github.com/!azure/azure-sdk-for-go', 'github.com/aws/aws-sdk-go-v2', 'github.com/aws/aws-sdk-go',
                       'google.golang.org/genproto', 'google.golang.org/api', 'google.golang.org/grpc', 'google.golang.org/protobuf']
        presented_libs = [lib for lib in update_libs if lib+' ' in go_mod_content]
        if len(presented_libs) > 0:
            libs_str = ' '.join([lib+'@latest' if lib != 'github.com/!azure/azure-sdk-for-go' else 'github.com/\!azure/azure-sdk-for-go' for lib in presented_libs])
            proc = await asyncio.create_subprocess_shell(
                "go get -u "+libs_str,
                cwd=f"{REPOS_DIR}/{project_path}",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE)

            stdout, stderr = await proc.communicate()
            
            if proc.returncode != 0:
                project_download_deps_errors[project_path] = 'go get -u: '+stderr.decode()
                last_error = 'go get -u: '+stderr.decode()[:100]

                pbar.update(1)
                pbar.set_postfix(errors_count=len(project_download_deps_errors), updated_count=updated_count, last_error=last_error)

            updated_count += 1
        

        proc = await asyncio.create_subprocess_shell(
            "go mod tidy",
            cwd=f"{REPOS_DIR}/{project_path}",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE)

        stdout, stderr = await proc.communicate()
        
        if proc.returncode != 0:
            project_download_deps_errors[project_path] = stderr
            last_error = stderr.decode()[:100]

        pbar.update(1)
        pbar.set_postfix(errors_count=len(project_download_deps_errors), updated_count=updated_count, last_error=last_error)



with tqdm(total=len(project_paths)) as pbar:
    features = [worker(pbar, i) for i in range(concurrency)]

    await asyncio.gather(*features)

len(project_download_deps_errors)

  0%|          | 0/2293 [00:00<?, ?it/s]

168

In [16]:
{k: project_download_deps_errors[k] for k in list(project_download_deps_errors)[:10]}

{'bf1c2a0126ed28afecc9a2ff874d29a9941910bb/POSIX/golang/test/': b'go: tencent.com/mmkv@v0.0.0-00010101000000-000000000000 (replaced by ../tencent.com/mmkv): reading ../tencent.com/mmkv/go.mod: open /media/hdd_1/vkr/data/repos/bf1c2a0126ed28afecc9a2ff874d29a9941910bb/POSIX/golang/tencent.com/mmkv/go.mod: no such file or directory\n',
 'e18340618f1bf512aa7327fedd64196ea763719e/tests/apps/gogrpc/': b'go: found google.golang.org/grpc/examples/helloworld/helloworld in google.golang.org/grpc/examples/helloworld/helloworld v0.0.0-00010101000000-000000000000\ngo: github.com/dokku/dokko/tests/apps/gorpc/greeter_client imports\n\tgoogle.golang.org/grpc/examples/helloworld/helloworld: module ./helloworld: reading helloworld/go.mod: open /media/hdd_1/vkr/data/repos/e18340618f1bf512aa7327fedd64196ea763719e/tests/apps/gogrpc/helloworld/go.mod: no such file or directory\n',
 '32c3d0a3ff4f6f5dd2dcfba17d66ef96601202b3/': b'go: finding module for package go.opentelemetry.io/otel/metric/instrument\ngo: f

### Проверка валидности исходных файлов по go пакетам

In [17]:
go_project_packages: set[(str, str)] = set()
for row in ds:
    if row['project_path'] in project_download_deps_errors:
        continue
    go_project_packages.add((row['project_path'], row['relative_go_package']))

go_project_packages = list(go_project_packages)

len(go_project_packages)

28794

In [18]:
compile_project_package_errors: dict[tuple[str, str], str] = {}

concurrency = 32
next_index = 0
last_error = ''

async def worker(pbar, i):
    global next_index
    global last_error
    while next_index < len(project_paths):
        (project_path, relative_go_package) = go_project_packages[next_index]
        next_index += 1

        proc = await asyncio.create_subprocess_shell(
            "go build -o /dev/null ./"+relative_go_package,
            cwd=f"{REPOS_DIR}/{project_path}",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE)

        stdout, stderr = await proc.communicate()
        
        if proc.returncode != 0:
            compile_project_package_errors[(project_path, relative_go_package)] = stderr
            last_error = stderr.decode()[:100]

        pbar.update(1)
        pbar.set_postfix(errors_count=len(compile_project_package_errors), last_error=last_error)



with tqdm(total=len(project_paths)) as pbar:
    features = [worker(pbar, i) for i in range(concurrency)]

    await asyncio.gather(*features)

len(compile_project_package_errors)
    

  0%|          | 0/2293 [00:00<?, ?it/s]

606

In [19]:
ds = ds.filter(lambda row: row['project_path'] not in project_download_deps_errors and (row['project_path'], row['relative_go_package']) not in compile_project_package_errors)

print(len(ds))
ds[0]

Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]

75169


{'path': '/go.mod',
 'repo_name': 'avelino/awesome-go',
 'revision_id': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85',
 'star_events_count': 112752,
 'fork_events_count': 13739,
 'is_vendor': False,
 'is_generated': False,
 'length_bytes': 603,
 'extension': 'mod',
 'filename': 'go.mod',
 'file_exists': True,
 'file_hash': '0adeb56bf0ed5a68b4e0ac51da9f3d38e6623754e262f9e4cd1a9f7c4c5ce817',
 'project': '/',
 'relative_project_path': '/go.mod',
 'relative_go_package': '',
 'funcs': '',
 'methods': '',
 'project_path': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85/'}

In [20]:
list(compile_project_package_errors.items())[:10]

[(('a9031b7d6aeebb561e2ca9cd4aa2e3fce3b9ae88/',
   'tests/tools/vendor/github.com/go-toolsmith/astequal/'),
  b'main module (github.com/containers/buildah) does not contain package github.com/containers/buildah/tests/tools/vendor/github.com/go-toolsmith/astequal\n'),
 (('f69247ff2478057ddbbc6692fc96ff6520b6a8c1/', 'dcrutil/txsort/'),
  b'main module (github.com/decred/dcrd) does not contain package github.com/decred/dcrd/dcrutil/txsort\n'),
 (('ddc1d8de8b0d8f7fcbe251c2c18ac9ce82750611/tools/', ''),
  b'package github.com/douyu/jupiter-layout/tools: build constraints exclude all Go files in /media/hdd_1/vkr/data/repos/ddc1d8de8b0d8f7fcbe251c2c18ac9ce82750611/tools\n'),
 (('959dce294c0a43b675f80419f7189393221613d4/pkg/init/',
   'vendor/google.golang.org/grpc/internal/transport/'),
  b'go: inconsistent vendoring in /media/hdd_1/vkr/data/repos/959dce294c0a43b675f80419f7189393221613d4/pkg/init:\n\tgithub.com/containerd/containerd@v1.5.0-beta.1: is explicitly required in go.mod, but not mar

In [21]:
test_candidates_ds = ds.filter(lambda row: not row['is_generated']
                               and not row['is_vendor']
                               and row['path'].lower().endswith('.go')
                               and (row['filename'] != 'main.go' or row['funcs'] != 'main' or row['methods'] != ''))

test_candidates_ds.save_to_disk('data/test_candidates_ds')

print(len(test_candidates_ds))
test_candidates_ds[0]

Filter:   0%|          | 0/75169 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/64263 [00:00<?, ? examples/s]

64263


{'path': '/main.go',
 'repo_name': 'avelino/awesome-go',
 'revision_id': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85',
 'star_events_count': 112752,
 'fork_events_count': 13739,
 'is_vendor': False,
 'is_generated': False,
 'length_bytes': 8755,
 'extension': 'go',
 'filename': 'main.go',
 'file_exists': True,
 'file_hash': '105132c75e5c98fa0379971e18dd1bb6544cabc7c5b18a81ae44585a4db25705',
 'project': '/',
 'relative_project_path': 'main.go',
 'relative_go_package': '',
 'funcs': 'main,buildStaticSite,dropCreateDir,mkdirAll,renderCategories,renderSitemap,extractCategories,extractCategory,rewriteLinksInIndex,renderIndex',
 'methods': '',
 'project_path': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85/'}

In [35]:
test_candidates_ds = datasets.load_from_disk('data/test_candidates_ds')

def end_of_block(s: str, openBracketPos: int) -> int:
    if s[openBracketPos] != '{' and s[openBracketPos] != '(' and s[openBracketPos] != '[':
        raise Exception('want openBracketPos in second argument')
    
    quoteOpened = ''
    multilineComment = False
    single_line_comment_end = 0
    brackets = []
    brackets.append(s[openBracketPos])
    
    p = re.compile('[`\'"{}()]|\/\/|\/\*|\*\/')
    for m in p.finditer(s):
        x = m.group()
        pos = m.start()
        if pos <= openBracketPos or pos <= single_line_comment_end:
            continue
        # print(x, pos)
        if multilineComment:
            if x == "*/":
                multilineComment = False
            continue
        if quoteOpened != '':
            if x != quoteOpened:
                continue # inside string
            backslash_count = 0
            while pos-1-backslash_count >= 0:
                if s[pos-1-backslash_count] != '\\':
                    break
                backslash_count += 1
            if backslash_count % 2 == 1:
                continue # quote is escaped
            quoteOpened = ''
            continue
        if x == '//':
            single_line_comment_end = s.find('\n', pos)
            continue
        if x == '/*':
            multilineComment = True
            continue
        if x in '\'"`':
            quoteOpened = x
            continue
        if x in '{[(':
            brackets.append(x)
            continue
        wantBracket = '}' if brackets[-1] == '{' else ']' if brackets[-1] == '[' else ')'
        if x == wantBracket:
            brackets.pop()
            if len(brackets) == 0:
                return pos+1
            continue
        raise Exception(f"invalid syntax, wrong bracket: expected {wantBracket} got {x} at pos={pos}")
    raise Exception('invalid syntax, no end of block')
        

def get_prompt(row) -> str:
    package_path = './data/repos/'+row['project_path']+row['relative_go_package']

    package_files = glob.glob(package_path+'*.go')

    file_path = './data/repos/'+row['project_path']+row['relative_project_path']
    f = open(file_path, 'r')
    file_content = f.read()
    f.close()

    #print('### file', file_content, sep="\n")

    func_body: dict[str, str] = {}
    func_deps: dict[str, list[str]] = {}

    # for fn in row['funcs'].split(','):
    #     if fn == 'main':
    #         continue
            
    #     #p = re.compile('\nfunc\W*'+fn+'\W*\([^)]*\)[^{]*\{')
    #     m = re.search(r'\nfunc\W*'+fn+'\W*\([^)]*\)[^{]*\{', file_content)

    #     if m is None:
    #         continue

    #     fn_start_pos = m.start()
    #     block_start_pos = m.end()-1
    #     print("\n", fn)
    #     block_end_pos = end_of_block(file_content, block_start_pos)
        
    #     func_body[fn] = file_content[fn_start_pos:block_end_pos]

    #     print(fn, file_content[fn_start_pos:block_end_pos])

    for file_path in package_files:
        f = open(file_path, 'r')
        file_content = f.read()
        f.close()

        p = re.compile(r'\nfunc\W*(\w+)\W*\([^)]*\)[^{]*\{')
        for m in p.finditer(file_content):
            fn = m.group(1)
            fn_start_pos = m.start()
            block_start_pos = m.end()-1
            #print("\n", fn)
            block_end_pos = end_of_block(file_content, block_start_pos)
            
            func_body[fn] = file_content[fn_start_pos:block_end_pos]

            #print(fn, file_content[fn_start_pos:block_end_pos])

    for fn, body in func_body.items():
        other_fns = '|'.join([f for f in func_body.keys() if f != fn])
        func_deps[fn] = re.findall(r'\W+('+other_fns+')\(', body)

    print(func_body, func_deps, sep="\n")

    return ''

def finalize_row(row) -> dict:
    return {
        'project_path': row['project_path'],
        'relative_package_path': row['relative_go_package'],
        'relative_file_path': row['relative_project_path'],
        'prompt': get_prompt(row)
    }

ds2 = test_candidates_ds.take(100).map(finalize_row, num_proc=1).select_columns(['project_path', 'relative_package_path', 'relative_file_path', 'prompt'])

#final_ds.save_to_disk('./data/final_ds')

# final_ds
# final_ds[0]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

{'main': '\nfunc main() {\n\tif err := buildStaticSite(); err != nil {\n\t\tpanic(err)\n\t}\n}', 'buildStaticSite': '\nfunc buildStaticSite() error {\n\tif err := dropCreateDir(outDir); err != nil {\n\t\treturn fmt.Errorf("drop-create out dir: %w", err)\n\t}\n\n\tif err := renderIndex(readmePath, outIndexFile); err != nil {\n\t\treturn fmt.Errorf("convert markdown to html: %w", err)\n\t}\n\n\tinput, err := os.ReadFile(outIndexFile)\n\tif err != nil {\n\t\treturn fmt.Errorf("read converted html: %w", err)\n\t}\n\n\tdoc, err := goquery.NewDocumentFromReader(bytes.NewReader(input))\n\tif err != nil {\n\t\treturn fmt.Errorf("create goquery instance: %w", err)\n\t}\n\n\tcategories, err := extractCategories(doc)\n\tif err != nil {\n\t\treturn fmt.Errorf("extract categories: %w", err)\n\t}\n\n\tif err := renderCategories(categories); err != nil {\n\t\treturn fmt.Errorf("render categories: %w", err)\n\t}\n\n\tif err := rewriteLinksInIndex(doc, categories); err != nil {\n\t\treturn fmt.Errorf("

In [22]:
test_candidates_ds = datasets.load_from_disk('data/test_candidates_ds')

system_message = """
You are an expert programmer. 
You should only return output test file containing working code.
The user is going to give you code and would like to have unit tests for the first file.
All the other files are just dependencies to give you context of all the possible test cases to produce.
Cover all possible inputs and their respective outputs using tests.
Each subtest must be wrapped into t.Run
"""

def get_prompt(row) -> str:
    file_path = './data/repos/'+row['project_path']+row['relative_project_path']
    f = open(file_path, 'r')
    file_content = f.read()
    f.close()

    prompt = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": f"{file_content}"}
    ]

    return prompt

def finalize_row(row) -> dict:
    return {
        'project_path': row['project_path'],
        'relative_package_path': row['relative_go_package'],
        'relative_file_path': row['relative_project_path'],
        'prompt': get_prompt(row)
    }

final_ds = test_candidates_ds.map(finalize_row, num_proc=32).select_columns(['project_path', 'relative_package_path', 'relative_file_path', 'prompt'])

final_ds.save_to_disk('./data/final_ds')

final_ds
final_ds[0]

Map (num_proc=32):   0%|          | 0/64263 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/64263 [00:00<?, ? examples/s]

{'project_path': 'c3643eb9da5c673101f8fe15a6deb40bfc4a1c85/',
 'relative_package_path': '',
 'relative_file_path': 'main.go',
 'prompt': [{'content': '\nYou are an expert programmer. \nYou should only return output test file containing working code.\nThe user is going to give you code and would like to have unit tests for the first file.\nAll the other files are just dependencies to give you context of all the possible test cases to produce.\nCover all possible inputs and their respective outputs using tests.\nEach subtest must be wrapped into t.Run\n',
   'role': 'system'},
  {'content': '// Package main contains code for generate static site.\npackage main\n\nimport (\n\t"bytes"\n\t"errors"\n\t"fmt"\n\t"github.com/avelino/awesome-go/pkg/markdown"\n\tcp "github.com/otiai10/copy"\n\ttemplate2 "html/template"\n\t"net/url"\n\t"os"\n\t"path/filepath"\n\t"text/template"\n\n\t"github.com/PuerkitoBio/goquery"\n\t"github.com/avelino/awesome-go/pkg/slug"\n)\n\n// Link contains info about aweso

In [23]:
splitted_ds = final_ds.train_test_split(test_size=0.2)

splitted_ds.save_to_disk('./data/splitted_ds')

splitted_ds

Saving the dataset (0/1 shards):   0%|          | 0/51410 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/12853 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['project_path', 'relative_package_path', 'relative_file_path', 'prompt'],
        num_rows: 51410
    })
    test: Dataset({
        features: ['project_path', 'relative_package_path', 'relative_file_path', 'prompt'],
        num_rows: 12853
    })
})