<a href="https://colab.research.google.com/github/markNZed/GPT-NeoX-Colab/blob/main/notebooks/codecompletion_benchmark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Detect the runtime: `google.colab` is importable only on Colab; otherwise assume the Docker image.
try:
    import google.colab  # noqa: F401  (imported only to probe the environment)
    DOCKER = False
except ImportError:
    # Only a missing module means "not Colab"; a bare `except:` would also hide
    # unrelated failures (even KeyboardInterrupt) and silently pick the Docker paths.
    DOCKER = True
print(DOCKER)

True


In [4]:
# Central path configuration; these could be changed to "stub" behavior for test/dev.
workspaceDir = "/content"
GPTNeoXColabDirName = "GPT-NeoX-Colab"
# In the Docker environment the repo lives at /workspace; on Colab it is cloned under workspaceDir.
GPTNeoXColabDir = "/workspace" if DOCKER else f"{workspaceDir}/{GPTNeoXColabDirName}"

# Clone CodeXGLUE Repo

In [5]:
# Not used at all yet, but for a final sanity check we should use the data and evaluate.py from this repo
#%cd {workspaceDir}
#!git clone --depth 1 https://github.com/microsoft/CodeXGLUE.git

In [6]:
%%time
#@title Clone GPT-NeoX-Colab
# Make GPTNeoXColabDir the working directory; on Colab, clone the repo and
# install its requirements first.
if DOCKER:
    # Nothing is cloned on this path: the Docker environment presumably already
    # contains the repo at GPTNeoXColabDir — TODO confirm in the image setup.
    %cd {GPTNeoXColabDir}
else:
    %cd {workspaceDir}
    # Don't use --depth 1 because that does not play nice with git-annex
    # NOTE(review): re-running this cell on Colab makes `git clone` fail because the
    # directory already exists; the following %cd/%pip lines still run.
    !git clone https://github.com/markNZed/GPT-NeoX-Colab.git
    %cd {GPTNeoXColabDir}
    # Install the Colab requirements, then the repo itself as a package
    # (presumably what makes `import GPTNeoXColab` in the next cell work).
    %pip install -q -r requirements_colab.txt
    %pip install -q .

/workspace
CPU times: user 2.04 ms, sys: 1 ms, total: 3.05 ms
Wall time: 2.1 ms


In [7]:
%cd {GPTNeoXColabDir}
from dotenv import load_dotenv
import os
load_dotenv(f"{GPTNeoXColabDir}/.env")
import GPTNeoXColab
GPTNeoXColab.utils.colab.fetch_data("data/codecompletion/token_completion.tar.gz")
%cd {GPTNeoXColabDir}/data/codecompletion
if not os.path.exists(f"data/codecompletion/token_completion"):
    !tar -xzf token_completion.tar.gz
%cd {GPTNeoXColabDir}
GPTNeoXColab.utils.colab.fetch_data("models/codecompletion/global_step7000_HF.tar.gz")
%cd {GPTNeoXColabDir}/models/codecompletion
if not os.path.exists(f"latest"):
    !tar -xzf global_step7000_HF.tar.gz
    !mv global_step7000_HF latest

/workspace
Data retrieval successful.
/workspace/data/codecompletion
/workspace
Data retrieval successful.
/workspace/models/codecompletion


# Fetch the GPT-2 Byte-Pair Encoding (BPE) tokenizer files

In [8]:
%cd {GPTNeoXColabDir}/models/codecompletion/latest
# GPT2Tokenizer.from_pretrained (used in the evaluation cell) reads vocab.json and
# merges.txt from this model directory; download the standard GPT-2 BPE files
# only if they are not already present.
# NOTE(review): if a wget fails, the following mv fails too and the file stays missing.
if not os.path.exists("vocab.json"):
    !wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
    !mv gpt2-vocab.json vocab.json
if not os.path.exists("merges.txt"):
    !wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
    !mv gpt2-merges.txt merges.txt

/workspace/models/codecompletion/latest


In [53]:
from torch.utils.data import Dataset
import gc

class EvalDataset(Dataset):
    """Token-completion evaluation set made of fixed-length, padded blocks of token ids.

    Each line of ``{args.data_dir}/{file_type}.txt`` is wrapped in ``<s> ... </s>``
    (unless it already is), tokenized, and appended to one long id stream. ``split``
    then cuts that stream into blocks of ``seq_length`` ids. Each block ends on a
    token boundary and is right-padded with ``tokenizer.pad_token_id``. The blocks
    are pickled to ``args.output_dir`` and reused unless ``args.overwrite_cache`` is set.

    Args:
        tokenizer: provides ``encode``, ``decode``, ``convert_ids_to_tokens`` and the
            bos/eos/sep/pad token ids.
        args: namespace with ``output_dir``, ``data_dir``, ``overwrite_cache`` and an
            optional ``max_eval_length`` (max number of input lines to read; ``None``
            or absent means all lines).
        logger: ``logging.Logger`` for progress messages.
        file_type: split name; selects ``{file_type}.txt`` and names the cache file.
        seq_length: length of every block.

    NOTE: relies on ``pickle`` and ``torch`` being imported at notebook level
    (they are imported in the next cell).
    """

    def __init__(self, tokenizer, args, logger, file_type='train', seq_length=1024):
        os.makedirs(args.output_dir, exist_ok=True)
        cached_file = os.path.join(args.output_dir, file_type + "_blocksize_%d" % (seq_length))
        if os.path.exists(cached_file) and not args.overwrite_cache:
            # pickle.load can execute arbitrary code: only load caches this class wrote.
            with open(cached_file, 'rb') as handle:
                self.inputs = pickle.load(handle)
            return

        self.inputs = []
        datafile = os.path.join(args.data_dir, f"{file_type}.txt")
        with open(datafile, encoding="utf-8") as f:
            data = f.readlines()

        length = len(data)
        logger.info("Data size: %d" % (length))
        # Log roughly every 10% of the file; max(..., 1) avoids modulo-by-zero for files with < 10 lines.
        log_every = max(length // 10, 1)
        max_eval_length = getattr(args, "max_eval_length", None)
        input_ids = []
        for idx, x in enumerate(data):
            x = x.strip()
            if not (x.startswith("<s>") and x.endswith("</s>")):
                x = "<s> " + x + " </s>"
            try:
                input_ids.extend(tokenizer.encode(x))
            except Exception as exc:
                # Best-effort: skip lines the tokenizer rejects, but leave a trace instead of hiding it.
                logger.warning(f"skipping line {idx}: tokenization failed ({exc})")
            if idx % log_every == 0:
                logger.info("load %d" % (100 * idx // length))
            if max_eval_length is not None and (idx + 1) == max_eval_length:
                logger.info(f"max eval length reached at {idx}")
                break
        del data
        gc.collect()

        logger.info(f"tokens: {len(input_ids)}")
        self.split(input_ids, tokenizer, logger, seq_length=seq_length)
        del input_ids
        gc.collect()

        with open(cached_file, 'wb') as handle:
            pickle.dump(self.inputs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def split(self, input_ids, tokenizer, logger, seq_length=1024):
        """Cut ``input_ids`` into right-padded blocks of ``seq_length`` and append them to ``self.inputs``.

        A full-length window is trimmed so that no word is split across blocks. The
        window is scanned backwards, and the block ends at the first of these found:
        - just before a token starting with ``\\u0120`` (a new word) or ``<NUM_LIT``;
        - just after an eos/sep token;
        - just before a bos token.

        Raises:
            ValueError: if ``seq_length`` is not positive, or if a full window contains
                no usable boundary. Trimming it would yield an empty block and the
                loop would never advance.
        """
        if seq_length <= 0:
            raise ValueError(f"seq_length must be positive, got {seq_length}")
        boundary_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id]
        i = 0
        while i < len(input_ids):
            sample = input_ids[i: i + seq_length]
            if len(sample) == seq_length:
                keep = 0  # number of ids this block keeps
                for pos in range(seq_length - 1, -1, -1):
                    token_id = sample[pos]
                    token = tokenizer.convert_ids_to_tokens(token_id)
                    if token.startswith('\u0120') or token.startswith("<NUM_LIT"):
                        keep = pos  # the next block starts with this word
                        break
                    if token_id in boundary_ids:
                        # eos/sep close this block; bos opens the next one.
                        keep = pos if token_id == tokenizer.bos_token_id else pos + 1
                        break
                if keep == 0:
                    # Previously exit(): inside a notebook kernel that does not stop execution
                    # and the loop spun forever on an empty block.
                    raise ValueError("no token boundary found in block; cannot split:\n" + tokenizer.decode(sample))
                sample = sample[:keep]
            i += len(sample)
            sample = sample + [tokenizer.pad_token_id] * (seq_length - len(sample))
            self.inputs.append(sample)

            if len(self.inputs) % 10000 == 0:
                logger.info(f"{len(self.inputs)} samples")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, item):
        return torch.tensor(self.inputs[item])


In [63]:
import logging
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler
from transformers import GPTNeoXForCausalLM, GPT2Tokenizer
from types import SimpleNamespace
import os
import pickle

def decode_token_ids(token_ids, tokenizer):
    """
    Rebuild a code string from token ids, normalising spacing around words and markers.

    - Tokens prefixed with ``\\u0120`` start a new word: exactly one space is emitted
      before the rest of the token.
    - bos/eos/sep/pad tokens and ``<NUM_LIT``/``<STR_LIT`` literals are surrounded
      by spaces.
    - Any other token is glued onto the preceding text.

    Returns the result with leading/trailing whitespace removed.
    """
    special_ids = (
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    )
    text = ""
    for tid in token_ids:
        tok = tokenizer.convert_ids_to_tokens(tid)
        if tok.startswith('\u0120'):
            # Word start: add a single separating space (never two), then drop the marker.
            text = (text if text.endswith(" ") else text + " ") + tok[1:]
        elif tid in special_ids or tok.startswith(("<NUM_LIT", "<STR_LIT")):
            text = f"{text} {tok} "
        else:
            text += tok
    return text.strip()

def _first_word(token_ids, tokenizer):
    """Return the first whitespace-separated word of the decoded ids, or ``"<SPACE>"`` if they decode to nothing."""
    words = decode_token_ids(token_ids, tokenizer).split()
    return words[0] if words else "<SPACE>"


def _align_sequence(pred_seq, gt_seq, tokenizer):
    """Regroup one sequence's sub-tokens into benchmark-level words.

    Consecutive ground-truth sub-tokens are merged into one word until the next token
    starting with ``\\u0120`` (a new word). bos/eos/sep/pad tokens and
    ``<NUM_LIT``/``<STR_LIT`` literals always form a word of their own. The prediction
    for position ``i`` is ``pred_seq[i - 1]``, because the model's output is one
    token ahead of its input. Position 0 has no earlier output, so the placeholder
    id 0 is used there.
    NOTE(review): a non-special first token is still scored (against the placeholder).

    Args:
        pred_seq: list of argmax token ids from the model for one sequence.
        gt_seq: list of input (ground-truth) token ids for the same sequence.
        tokenizer: maps ids to tokens and exposes the special token ids.

    Returns:
        ``(preds, gts)``: equal-length lists of decoded words; each prediction is
        reduced to its first word.
    """
    special_ids = (tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)
    preds, gts = [], []

    def emit(pred_ids, gt_ids):
        preds.append(_first_word(pred_ids, tokenizer))
        gts.append(decode_token_ids(gt_ids, tokenizer).strip())

    # Sub-tokens of the word currently being assembled.
    now_pred, now_gt = [], []
    for i, (_, gt_id) in enumerate(zip(pred_seq, gt_seq)):
        gt_token = tokenizer.convert_ids_to_tokens(gt_id)
        if i == 0:
            if gt_id in special_ids:
                # Excluded from the accuracy metric later; the placeholder keeps lists aligned.
                emit([0], [gt_id])
            else:
                now_pred, now_gt = [0], [gt_id]
            continue

        # A \u0120 token starts a new word: flush the one in progress.
        if gt_token.startswith('\u0120') and now_gt:
            emit(now_pred, now_gt)
            now_pred, now_gt = [], []

        # Special tokens and literals are words on their own.
        if gt_id in special_ids or gt_token.startswith(("<NUM_LIT", "<STR_LIT")):
            if now_gt:
                emit(now_pred, now_gt)
            emit([pred_seq[i - 1]], [gt_id])
            now_pred, now_gt = [], []
            continue

        now_gt.append(gt_id)
        now_pred.append(pred_seq[i - 1])  # prediction is one token ahead of gt
    return preds, gts


def eval_acc(args, model, tokenizer, file_type='test'):
    """
    Evaluate the model's token-level code-completion accuracy.

    Runs the model over ``EvalDataset`` blocks and takes the argmax prediction at each
    position. Sub-tokens are regrouped into benchmark words. Accuracy is the fraction
    of words whose first predicted word equals the ground truth, ignoring
    ``<s>``, ``</s>``, ``<EOL>``, ``<pad>`` and empty words. Predictions and
    ground truths are also written to ``args.output_dir`` (``predictions.txt`` /
    ``answers.txt``) through ``post_process``.

    Uses the notebook-level ``logger``.

    Returns:
        ``(total_predictions, total_correct)``.
    """
    eval_dataset = EvalDataset(tokenizer, args, logger, file_type=file_type, seq_length=args.seq_length)
    eval_dataloader = DataLoader(eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=args.eval_batch_size)
    model.to(args.device)
    model.eval()

    excluded_gts = ["<s>", "</s>", "<EOL>", "<pad>", ""]
    total_correct, total_predictions = 0, 0
    total_pred_tokens, total_gt_tokens = [], []

    for step, batch in enumerate(eval_dataloader):
        inputs = batch.to(args.device)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            predicted_token_ids = model(inputs).logits.argmax(-1)

        # Move to CPU (no-op when already there) before converting to Python lists.
        all_pred, all_gt = [], []
        for pred_seq, gt_seq in zip(predicted_token_ids.cpu(), inputs.cpu()):
            seq_preds, seq_gts = _align_sequence(pred_seq.tolist(), gt_seq.tolist(), tokenizer)
            all_pred.extend(seq_preds)
            all_gt.extend(seq_gts)
        assert len(all_pred) == len(all_gt)

        total_pred_tokens.extend(all_pred)
        total_gt_tokens.extend(all_gt)

        for pred_token, gt_token in zip(all_pred, all_gt):
            if gt_token not in excluded_gts:
                total_predictions += 1
                if pred_token == gt_token:
                    total_correct += 1

        if step % args.logging_steps == 0:
            accuracy = total_correct / total_predictions if total_predictions > 0 else 0
            logger.info(f"Step {step} processed with cumulative accuracy: {accuracy:.2%}")

    accuracy = total_correct / total_predictions if total_predictions > 0 else 0
    logger.info(f"Final Test Accuracy: {accuracy:.2%}")

    # Write predictions.txt / answers.txt, checking each sample against the raw test file.
    pred_file = os.path.join(args.output_dir, "predictions.txt")
    gt_file = os.path.join(args.output_dir, "answers.txt")
    with open(os.path.join(args.data_dir, f"{file_type}.txt"), encoding="utf-8") as handle:
        true_texts = handle.readlines()
    total_samples = post_process(total_pred_tokens, total_gt_tokens, true_texts, pred_file, gt_file)
    logger.info(f"Evaluated on {total_samples} samples, saved predictions at {pred_file} and ground truths at {gt_file}")

    return total_predictions, total_correct

def post_process(preds, gts, true_gts, pred_file_path, gt_file_path):
    """
    Save the post-processed predictions and ground truths, verifying each sample against the expected text.

    Words whose ground truth is ``""`` or ``<pad>`` are dropped. The remaining words
    are collected until a ``</s>`` ground truth closes a sample. That sample's
    space-joined ground truth must equal the stripped line ``true_gts[count]``; it
    is then written (with its predictions) as one line to each output file. Words
    after the last ``</s>`` are not written.

    Args:
        preds: List of predicted tokens from the model.
        gts: List of ground truth tokens for each prediction.
        true_gts: List of full ground truth sequences for each input, used for verification.
        pred_file_path: Path to the file where the processed predictions will be saved.
        gt_file_path: Path to the file where the processed ground truths will be saved.

    Returns:
        int: The count of sequences processed and saved.

    Raises:
        ValueError: if a reconstructed sample differs from its expected line, or if
            there are more samples than lines in ``true_gts``.
    """
    count = 0
    new_gt, new_pred = [], []
    with open(pred_file_path, "w") as pred_file, open(gt_file_path, "w") as gt_file:
        for pred, gt in zip(preds, gts):
            if gt in ("", "<pad>"):
                continue
            new_gt.append(gt)
            # Spaces separate the tokens in predictions.txt, so remove any inside a prediction.
            new_pred.append(pred.replace(" ", ""))

            if gt != "</s>":
                continue

            gt_str = " ".join(new_gt)
            pred_str = " ".join(new_pred)
            if count >= len(true_gts):
                raise ValueError(f"Sample {count} has no matching line in the expected text ({len(true_gts)} lines)")
            expected = true_gts[count].strip()
            if gt_str != expected:
                print(f"gt_str   {gt_str}")
                print(f"true_gts {expected}")
                raise ValueError(f"Sample {count} mismatch between ground truth and expected text")
            pred_file.write(pred_str + "\n")
            gt_file.write(gt_str + "\n")
            count += 1
            new_gt, new_pred = [], []

    return count


In [57]:
import logging

# Detach every handler inherited by the root logger so messages are not duplicated.
for inherited in list(logging.root.handlers):
    logging.root.removeHandler(inherited)

# Notebook-level logger used by the evaluation code.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Drop handlers left over from an earlier run of this cell (a no-op when there are none).
logger.handlers.clear()

# A single stream handler so log records show up in the notebook output.
formatter = logging.Formatter(
    fmt='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)


In [64]:
import humanize

pretrained_model_path = f"{GPTNeoXColabDir}/models/codecompletion/latest"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set up evaluation arguments
args = {
    "logging_steps": 1,  # log cumulative accuracy after every batch
    "output_dir": f"{GPTNeoXColabDir}/out",
    "data_dir": f"{GPTNeoXColabDir}/data/codecompletion/token_completion",
    "device": device,
    "seq_length": 2048,
    "max_eval_length": 10,  # read only the first 10 lines of test.txt; set None for the full benchmark
    "overwrite_cache": True,  # re-tokenize instead of loading the pickled block cache
    "eval_batch_size": 1,
}

# Wrap args dictionary in a namespace to allow dot notation
args = SimpleNamespace(**args)

# Set random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# GPT2Tokenizer reads vocab.json/merges.txt from the model directory.
# NOTE(review): the checkpoint's config names GPTNeoXTokenizerFast (see the warning in the
# output); confirm GPT-2 BPE is the tokenizer this model was trained with.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_path, sep_token='<EOL>', bos_token='<s>', eos_token='</s>', pad_token='<pad>')
model = GPTNeoXForCausalLM.from_pretrained(pretrained_model_path)
# The special tokens above may extend the vocabulary; resize the embeddings to match.
model.resize_token_embeddings(len(tokenizer))

total_params = sum(p.numel() for p in model.parameters())
# The log line reports trainable parameters, so count only those requiring gradients.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
readable_params = humanize.intword(trainable_params)
logger.info(f"Model has {readable_params} trainable parameters")

# Evaluate model
total_predictions, total_correct = eval_acc(args, model, tokenizer, 'test')
accuracy = total_correct / total_predictions if total_predictions > 0 else 0
logger.info(f"Test accuracy: {accuracy:.2%}")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPTNeoXTokenizerFast'. 
The class this function is called from is 'GPT2Tokenizer'.
11/16/2024 00:04:14 - INFO - __main__ -   Model has 44.6 million trainable parameters
11/16/2024 00:04:14 - INFO - __main__ -   Data size: 50000
11/16/2024 00:04:14 - INFO - __main__ -   load 0
11/16/2024 00:04:15 - INFO - __main__ -   max eval length reached at 9
11/16/2024 00:04:15 - INFO - __main__ -   tokens: 20751
11/16/2024 00:04:18 - INFO - __main__ -   Step 0 processed with cumulative accuracy: 34.67%
11/16/2024 00:04:21 - INFO - __main__ -   Step 1 processed with cumulative accuracy: 37.67%
11/16/2024 00:04:24 - INFO - __main__ -   Step 2 processed with cumulative accuracy: 37.95%
11/16/2024 00:04:26 - INFO - __main__ -   Step 3 processed with cumulative accuracy: 35.33%
11/16/2024 00:04:

In [65]:
#@title Run evaluator.py on the generated files
evaluator_script = f"{GPTNeoXColabDir}/scripts/evaluator.py"
answers_file = f"{GPTNeoXColabDir}/out/answers.txt"
predictions_file = f"{GPTNeoXColabDir}/out/predictions.txt"
!python {evaluator_script} --answers {answers_file} --predictions {predictions_file}

INFO:__main__:Total 9412 tokens, accuracy: 34.57
