In [1]:
from collections import Counter, defaultdict
from functools import partial
import math

from datasets import Dataset, load_from_disk
import numpy as np
from tqdm import tqdm

In [2]:
# Directory holding the saved training split; each row is treated as one
# source file by the cells that follow.
dataset_dir = 'data/python/train'
# keep_in_memory=False is passed explicitly; presumably this keeps the data
# on disk rather than copying it into RAM -- confirm against the datasets docs.
dataset = load_from_disk(dataset_dir, keep_in_memory=False)

In [3]:
def gen_project_data(dataset, batch_size=10000):
  """
  Generator that yields a dictionary for each project in the dataset.

  Rows of `dataset` are grouped by their 'max_stars_repo_name' value. A
  project is yielded as soon as its last row has been read, so projects
  come out in the order in which they are completed, not in the order in
  which they first appear.

  Args:
    dataset: Object that supports `len(dataset)`, column access by name
      (`dataset['max_stars_repo_name']`) and slicing (`dataset[a:b]`)
      returning a dict of equal-length lists with the keys
      'max_stars_repo_name', 'max_stars_repo_path' and 'content'.
    batch_size: Number of rows read from `dataset` per slice. Must be
      positive. Defaults to 10000.

  Yields:
    A dict with the keys:
      'project_name': the project's 'max_stars_repo_name' value,
      'file_names': the 'max_stars_repo_path' of every row in the project,
      'file_contents': the 'content' of every row, aligned with 'file_names',
      'id': an integer assigned in order of the project's first appearance.

  Raises:
    ValueError: If `batch_size` is not positive (raised when the generator
      is first iterated).
    AssertionError: If some project is still incomplete after all rows
      have been read.
  """
  if batch_size <= 0:
    raise ValueError(f'batch_size must be positive, got {batch_size}')

  # First, count how many rows belong to each project so that we can tell
  # when a project is complete. The column is fed straight into the Counter
  # so that no second full copy of it stays alive during generation.
  print('Building project to indices map...')
  project_sizes = Counter(dataset['max_stars_repo_name'])

  print('Building project data...')

  # Projects that have been started but whose files have not all been seen
  partial_project_dict = {}
  project_idx = 0

  # Break the dataset into batches
  n_batches = int(math.ceil(len(dataset) / batch_size))
  for i in tqdm(range(n_batches), total=n_batches):
    batch = dataset[i * batch_size : (i + 1) * batch_size]

    # Look the columns up once per batch instead of once per sample
    rows = zip(
      batch['max_stars_repo_name'],
      batch['max_stars_repo_path'],
      batch['content'],
    )

    for project_name, file_name, file_content in rows:

      # Create a new entry for the sample's project if it doesn't exist
      project = partial_project_dict.get(project_name)
      if project is None:
        project = {
          'project_name': project_name,
          'file_names': [],
          'file_contents': [],
          'id': project_idx
        }
        partial_project_dict[project_name] = project
        project_idx += 1

      # Update the project with info from the new sample
      project['file_names'].append(file_name)
      project['file_contents'].append(file_content)

      # If all the project files have been added, release the entry and
      # yield the project
      if len(project['file_names']) == project_sizes[project_name]:
        del partial_project_dict[project_name]
        yield project

  # Sanity check that all projects have been yielded
  assert len(partial_project_dict) == 0, 'Not all projects have been yielded!'

In [72]:
# Total number of rows in the loaded dataset.
len(dataset)

12866649

In [57]:
# Time reading the first 10,000 rows via select() followed by a full slice.
%timeit dataset.select(range(10000))[:]

47.8 ms ± 562 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
# Time reading the first 10,000 rows with a plain contiguous slice.
%timeit dataset[:10000]

45.4 ms ± 813 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [53]:
# Time reading the 'content' column for 10,000 rows taken every 100th row
# from the first 1,000,000 (non-contiguous access).
%timeit dataset.select(range(0, 1000000, 100))['content'][:]

94.6 ms ± 4.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Regroup the per-file rows into one row per project. The generator is
# bound to the source dataset first, then handed to from_generator.
project_generator = partial(gen_project_data, dataset)
project_dataset = Dataset.from_generator(project_generator)

Downloading and preparing dataset generator/default to /home/edan/.cache/huggingface/datasets/generator/default-2e606bd0fdd0be85/0.0.0...


Generating train split: 0 examples [00:00, ? examples/s]

Building project to indices map...
Building project data...


100%|██████████| 1287/1287 [1:22:55<00:00,  3.87s/it]


Dataset generator downloaded and prepared to /home/edan/.cache/huggingface/datasets/generator/default-2e606bd0fdd0be85/0.0.0. Subsequent calls will reuse this data.


In [5]:
# Write the project-level dataset to disk
projects_dir = 'data/python/projects'
project_dataset.save_to_disk(projects_dir)

Saving the dataset (0/123 shards):   0%|          | 0/1678572 [00:00<?, ? examples/s]

In [4]:
# Number of files contributed by each project
counts = Counter(dataset['max_stars_repo_name'])

# Print the 22 largest projects, biggest first
n_shown = 22
for name, count in counts.most_common(n_shown):
    print(f'{name}: {count}')

ckamtsikis/cmssw: 7440
tefra/xsdata-w3c-tests: 4986
jnthn/intellij-community: 3997
tdiprima/code: 3716
kkcookies99/UAST: 3662
osoco/better-ways-of-thinking-about-software: 3107
antopen/alipay-sdk-python-all: 3072
usegalaxy-no/usegalaxy: 2800
webdevhub42/Lambda: 2683
kagemeka/atcoder-submissions: 2615
jjhenkel/dockerizeme: 2509
Amourspirit/ooo_uno_tmpl: 2482
harshp8l/deep-learning-lang-detection: 2377
ch1huizong/learning: 2207
lukaszlaszuk/insightconnect-plugins: 2160
enthought/etsproxy: 2150
scottwedge/OpenStack-Stein: 2034
google-cloud-sdk-unofficial/google-cloud-sdk: 2001
prorevizor/noc: 1966
the-zebulan/CodeWars: 1947
balmasea/genieparser: 1932
MrDelik/core: 1893


In [12]:
# Median number of files per project, as a counterpoint to the handful of
# very large projects listed by the most-common counts.
np.median(list(counts.values()))

3.0

In [None]:
class ProjectDataset(Dataset):
  """A Dataset organized by projects."""

  def __init__(
    self,
    arrow_table: Table,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    indices_table: Optional[Table] = None,
    fingerprint: Optional[str] = None,
  )