In [4]:
%load_ext autoreload
%autoreload 2

In [6]:
from base_data_collector import get_files
from custom_data_collector import SimpleSplitDataSampler

In [12]:
py_files = get_files(
    dir='../data/SDF-JEPA-main', 
    extension='.py', 
    filter_regex=r'^(?!import\b).*\S.*$',
    min_lines=15,
)

In [15]:
sampler = SimpleSplitDataSampler(
    py_files, 
    300,
    300,
    300,
    splitters=['.', ',', '(', ')', ' '],
)

In [18]:
dataset = sampler.sample(20, strategy='finish_line', strategy_kwargs={'max_tries': 5})

In [23]:
dataset.iloc[3]

filename        ../data/SDF-JEPA-main/app/main_distributed.py
prefix      # Copyright (c) Meta Platforms, Inc. and affil...
middle      (\n    help='yaml file containing config file ...
suffix          default='configs.yaml')\nparser.add_argume...
meta                                              finish_line
Name: 3, dtype: object

In [2]:
# Create dataset

import numpy as np
import pandas as pd
import re


num_examples = 350
splitters = ['.', ',', '(', ')', ' ']
pattern = '|'.join(map(re.escape, splitters))

dataset = {
    'filename': np.random.choice(py_files, num_examples, replace=False),
    'prefix': [],
    'middle': [],
    'suffix': [],
}

for filename in dataset['filename']:
    with open(filename, 'r') as file:
        lines = file.readlines()

    matches = []
    while len(matches) == 0:
        cursor_line = np.random.randint(len(lines) - 2)
        matches = [match.start() for match in re.finditer(pattern, lines[cursor_line])]

    cursor_pos = np.random.choice(matches, 1)[0]

    prefix = ''.join(lines[:cursor_line + 1]) + lines[cursor_line][:cursor_pos]
    middle = lines[cursor_line][cursor_pos:] + lines[cursor_line + 1]
    suffix = ''.join(lines[cursor_line + 2:])

    dataset['prefix'].append(prefix)
    dataset['middle'].append(middle)
    dataset['suffix'].append(suffix)

dataset = pd.DataFrame(dataset)

In [3]:
# pip install -q transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = 'bigcode/tiny_starcoder_py'
device = 'cuda' # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)



In [4]:
from tqdm.notebook import tqdm
import torch 

In [5]:
dataset = dataset.sample(50)

In [6]:
max_prefix_len = 300
max_suffix_len = 300
max_middle_len = 100

dataset['query'] = (
    '<fim_prefix>' + dataset['prefix'].str[-max_prefix_len:] + 
    '<fim_suffix>' + dataset['suffix'].str[:max_suffix_len] + 
    '<fim_middle>'
)

In [7]:
tokenizer.all_special_tokens

['<|endoftext|>',
 '<fim_prefix>',
 '<fim_middle>',
 '<fim_suffix>',
 '<fim_pad>',
 '<filename>',
 '<gh_stars>',
 '<issue_start>',
 '<issue_comment>',
 '<issue_closed>',
 '<jupyter_start>',
 '<jupyter_text>',
 '<jupyter_code>',
 '<jupyter_output>',
 '<empty_output>',
 '<commit_before>',
 '<commit_msg>',
 '<commit_after>',
 '<reponame>']

In [7]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
inputs, masks = tokenizer(dataset['query'].to_list(), padding=True, return_tensors='pt').values()
inputs = inputs.to(device)
masks = masks.to(device)

outputs = model.generate(inputs, attention_mask=masks, max_length=max_prefix_len + max_suffix_len + max_middle_len)

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


In [9]:
outputs

tensor([[    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,    81,  1172,    81],
        [    0,     0,     0,  ...,   347,    35,  6935],
        ...,
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,   645, 12643,    32],
        [    0,     0,     0,  ...,   280,   313,  2958]], device='cuda:0')

In [24]:
outputs_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [25]:
outputs_text[0].find('<fim_middle>')

-1

In [26]:
print(dataset['query'].iloc[0])

<fim_prefix>from itertools import product
from string import ascii_lowercase

import numpy as np
import pytest

from pandas import (
from pandas import <fim_suffix>    Index,
    MultiIndex,
    Period,
    Series,
    Timedelta,
    Timestamp,
    date_range,
)
import pandas._testing as tm


class TestCounting:
    def test_cumcount(self):
        df = DataFrame([["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"])
        g = df.groupby("A")
        sg = g.A

<fim_middle>


In [27]:
print(outputs_text[0])

from itertools import product
from string import ascii_lowercase

import numpy as np
import pytest

from pandas import (
from pandas import     Index,
    MultiIndex,
    Period,
    Series,
    Timedelta,
    Timestamp,
    date_range,
)
import pandas._testing as tm


class TestCounting:
    def test_cumcount(self):
        df = DataFrame([["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"])
        g = df.groupby("A")
        sg = g.A

DataFrame,

