# Train a GitHub copilot-style code completion model

This trains a minimum viable character-level code completion model, similar to GitHub's copilot. The code in this notebook is based on Andrej Karpathy's `play_char` notebook, trained to complete Shakespeare test.

A pretrained model is available [here](https://drive.google.com/file/d/1K_P0PYJBjanq8YTAzS8FF_kflcT9sOlI/view?usp=sharing), trained on my 8GB GTX 1070 for about 2 days. It spits out almost valid code, although isn't coming for anyone's job just yet. 

In [1]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [2]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [3]:
from pathlib import Path

from tqdm import tqdm
import numpy as np
import random
import json
from pprint import pprint
import torch
import torch.nn as nn
from torch.nn import functional as F

In [4]:
"""Uncomment to download and extract data. Only necessary if retraining."""
# !mkdir play_copilot_data
# !wget https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/python.zip -P play_copilot_data/
# !unzip play_copilot_data/python.zip -d play_copilot_data
# !gzip -d -r play_copilot_data/python/final

'Uncomment to download and extract data. Only necessary if retraining.'

In [5]:
python_files_train = list(Path('play_copilot_data/python/final/jsonl/train/').glob('*.jsonl'))
python_files_train += list(Path('play_copilot_data/python/final/jsonl/valid/').glob('*.jsonl'))
python_files_test = list(Path('play_copilot_data/python/final/jsonl/test/').glob('*.jsonl'))

def load_files(json_files):
    """Load raw data into list for training."""
    data = []
    for f in tqdm(json_files, desc='Loading code into memory'):
        with open(f, 'r') as fp:
            file_data = fp.readlines()
        data += [
            str(json.loads(line)['code'].encode('ascii', 'ignore'))  # Drop non-ascii characters
            for line in file_data if len(line) > 100
        ]
    return data

train_data = load_files(python_files_train)
test_data = load_files(python_files_test)

print(f"\n{len(train_data):,} training and {len(test_data):,} testing functions found.")

Loading code into memory: 100%|██████████| 15/15 [00:19<00:00,  1.27s/it]
Loading code into memory: 100%|██████████| 1/1 [00:00<00:00,  1.03it/s]
435,285 training and 22,176 testing functions found.



In [6]:
from torch.utils.data import Dataset


class CopilotDataset(Dataset):
    """Dataset for training a GitHub copilot-style code completion model.

    `data` should be a list of strings where each item is a continuous
    block of Python code, in all ASCII characters. In the data provided 
    in this notebook, each item is a complete function.

    For every sample, a single function is loaded up and a random slice
    is taken. Any padding applied is a null character, so that in the
    future, you could generate data until a null character is returned.

    Each sample is the converted to its ASCII character value, and the
    model predicts a single character at a time.
    """
    def __init__(self, data, block_size):
        self.vocab_size = 128  # Use all ascii characters
        self.block_size = block_size
        self.data = data
        random.shuffle(self.data)

        self.null = '\x00'

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        chunk = self.data[idx]
        chunklen = len(chunk)

        # Get random slice of data, allowing ends to be overlapped
        final_idx = np.random.randint(
            self.block_size // 16, 
            chunklen + self.block_size // 16,
        )

        if final_idx > chunklen:
            # Pad with null if selection is overrun
            dix = chunk + self.null * (final_idx - chunklen)
            first_idx = final_idx - self.block_size - 1
            dix = dix[first_idx:final_idx]

            # If chunk is still too short, add leading spaces
            if len(dix) < self.block_size:
                dix = self.null * (self.block_size - len(dix) + 1) + dix
        
        elif final_idx <= self.block_size:
            # Pad with leading spaces if selection is too short
            dix = chunk[:final_idx + 1]
            dix = self.null * (self.block_size - len(dix) + 1) + dix
  
        elif final_idx > self.block_size:
            first_idx = final_idx - self.block_size - 1
            dix = chunk[first_idx:final_idx]
        
        dix = [ord(s) for s in dix]

        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        
        return x, y

In [7]:
block_size = 512
train_dataset = CopilotDataset(train_data, block_size)
test_dataset = CopilotDataset(test_data, block_size)

# Load pretrained model
# Checkpoint available at https://drive.google.com/file/d/1K_P0PYJBjanq8YTAzS8FF_kflcT9sOlI/view?usp=sharing
checkpoint_weights = 'play_copilot_ckpt_trn_0.4793_tst_0.4076.pt'

In [8]:
from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  n_layer=8, n_head=8, n_embd=512)
model = GPT(mconf)

if checkpoint_weights is not None:
    model.load_state_dict(torch.load(checkpoint_weights))
    model.eval()

07/17/2021 14:30:06 - INFO - mingpt.model -   number of parameters: 2.561331e+07


In [9]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(
    max_epochs=10, 
    batch_size=12, 
    learning_rate=6e-4,
    lr_decay=True, 
    warmup_tokens=512*20, 
    final_tokens=2*len(train_dataset)*block_size,
    num_workers=4,
    ckpt_path='play_copilot_checkpoint.pt',
)
trainer = Trainer(model, train_dataset, test_dataset, tconf)

retrain = False
if retrain:
    trainer.train()

In [10]:
from mingpt.utils import sample

def sample_model(context, n_characters=500, temperature=1):
    x = torch.tensor([ord(s) for s in context], dtype=torch.long)[None, ...].to(trainer.device)
    y = sample(model, x, n_characters, temperature=1, sample=True, top_k=10)[0]
    completion = ''.join([chr(i) for i in y])
    completion = completion.replace('\\n', '\n')

    return completion

In [11]:
print(sample_model('def multiply(x, y):\n    """Multiply two numbers together."""\n    ', temperature=0))

def multiply(x, y):
    """Multiply two numbers together."""
                    """
        if x < 0:
            return
        if y > 0:
            temp = x                                             """
            temp = x - 1
            y = y - 1.0
            single = int(x ** 2 + 1)
            for ind in temp:
                if len(temp) % 10 == 0:
                    single = -temp[ind+1]
                    temp = y * (single - 1) % 10
                    y = y                                   """
                    single = y - 


In [12]:
print(sample_model('def add(a, b):', temperature=0))

def add(a, b):
    """Add an array to a sequence of arrays b to the collection. """

    sequence = []
    for i in range(num_percent_sequences):
        a_id = a << b
        if a_id is not None:
            sequence.append(sequence[i] * 2)
            a_id = a_id
            sequence.append(a_id)
        a_array = a_id
        a_array.append(num_array)
        if not a_array:
            a_array = a_array.replace(\'\\t\', \'\\t\')
        #if not a_array:
        #                            


In [13]:
print(sample_model('    x = np.linspace(-10, 10, 1000)\n    y = np.sin(x) / x\n', temperature=1))


    x = np.linspace(-10, 10, 1000)
    y = np.sin(x) / x
      # multiply this slice (or
            (x.shape[1] >= node[\'colormap\']).sum())
        else:
            # slice the original slice before the slice was not set.
            # It computes the single slice to set the sort of nodes,
            # but this can be set up to 1000 seconds.
            slice = (slice(x.shape[1]) - node[\'shape\']).sum()
            x.start()
            if slice is not None:
                self.__slice_slices[slice] = slice.sort_keys == slice'         
