In [1]:
import os
import pickle
import logging
import gc
from torch.utils.data import Dataset
import torch

# Notebook-wide logger. The name is spelled "cur_logger" so that every
# logging.getLogger("cur_logger") call in the notebook resolves to this same
# logger object (the previous spelling "cur_loger" created a separate one).
logger = logging.getLogger("cur_logger")

class EvalDataset(Dataset):
    """Evaluation dataset of fixed-length blocks of token ids.

    The text file ``<data_dir>/<file_type>.txt`` is read line by line, each
    line is wrapped in ``<s> ... </s>`` (unless it already is), tokenized, and
    the resulting flat id stream is cut into blocks of ``block_size`` ids,
    padded with ``tokenizer.pad_token_id``. The blocks are cached with pickle
    in ``<output_dir>/<file_type>_blocksize_<block_size>`` and reloaded from
    there on later constructions.

    Attributes:
        inputs: list of blocks, each a list of ``block_size`` token ids.
    """

    def __init__(self, tokenizer, output_dir, data_dir,  logger, file_type='test', block_size=512):
        """Load the cached blocks, or build them from the raw text file.

        Args:
            tokenizer: object providing ``encode``, ``convert_ids_to_tokens``,
                ``decode`` and the ``bos/eos/sep/pad_token_id`` attributes.
                Not used when a cache file already exists.
            output_dir: directory holding the cache file (created if missing).
            data_dir: directory containing ``<file_type>.txt``.
            logger: logger used for progress messages.
            file_type: base name of the data file and cache-name prefix.
            block_size: number of token ids per block.
        """
        os.makedirs(output_dir, exist_ok=True)
        cached_file = os.path.join(output_dir, file_type+"_blocksize_%d"%(block_size))

        if os.path.exists(cached_file):
            # NOTE(security): pickle.load can execute arbitrary code. Only
            # load cache files that this notebook produced itself.
            with open(cached_file, 'rb') as handle:
                self.inputs = pickle.load(handle)
            return

        self.inputs = []

        datafile = os.path.join(data_dir, f"{file_type}.txt")
        with open(datafile) as f:
            data = f.readlines()

        length = len(data)
        logger.info("Data size: %d"%(length))

        # Report roughly every 10% of the file. max(1, ...) keeps the modulo
        # valid for files with fewer than 10 lines.
        report_every = max(1, length // 10)
        input_ids = []
        skipped = 0
        for idx, x in enumerate(data):
            x = x.strip()
            if not (x.startswith("<s>") and x.endswith("</s>")):
                x = "<s> " + x + " </s>"
            try:
                input_ids.extend(tokenizer.encode(x))
            except Exception as err:
                # Best effort: a line that cannot be tokenized is skipped,
                # but it is reported instead of being dropped silently.
                skipped += 1
                logger.warning("skipping line %d: tokenizer.encode failed (%s)", idx, err)
            if idx % report_every == 0:
                percent = 100 * idx / length
                logger.warning("load %d"%(percent))
        if skipped:
            logger.warning("%d of %d lines could not be tokenized", skipped, length)
        del data
        gc.collect()

        logger.info(f"tokens: {len(input_ids)}")
        self.split(input_ids, tokenizer, logger, block_size=block_size)
        del input_ids
        gc.collect()

        with open(cached_file, 'wb') as handle:
            pickle.dump(self.inputs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def split(self, input_ids, tokenizer, logger, block_size=1024):
        """Cut ``input_ids`` into padded blocks and append them to ``self.inputs``.

        A full block is shortened so that the next block starts on a token
        that either begins with '\\u0120' (presumably the tokenizer's
        leading-space marker), starts with ``<NUM_LIT``, or is ``<s>``; a
        trailing ``</s>`` / separator token stays in the current block.

        Args:
            input_ids: flat list of token ids.
            tokenizer: see ``__init__``.
            logger: logger used for progress messages.
            block_size: number of ids per block (must be >= 1).

        Raises:
            ValueError: if ``block_size`` is smaller than 1, or if a full
                block contains no usable cut point (the loop could not
                advance past it).
        """
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")

        boundary_ids = (tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id)
        i = 0
        while i < len(input_ids):
            sample = input_ids[i: i+block_size]
            if len(sample) == block_size:
                # Scan backwards from the end of the block for a cut point.
                for j in range(block_size):
                    token_id = sample[block_size-1-j]
                    token = tokenizer.convert_ids_to_tokens(token_id)
                    if token[0] == '\u0120' or token.startswith("<NUM_LIT"):
                        break
                    if token_id in boundary_ids:
                        if token_id != tokenizer.bos_token_id:
                            # Keep </s> / separator inside the current block.
                            j -= 1
                        break
                if j == block_size-1:
                    # No cut point found: the block would become empty and the
                    # loop would never advance. Fail loudly instead of calling
                    # exit(), which is an interactive-shell helper and is not
                    # a reliable way to stop execution inside a notebook.
                    raise ValueError(
                        "cannot split block: no token boundary found in sample: "
                        + tokenizer.decode(sample)
                    )
                sample = sample[: block_size-1-j]
            i += len(sample)
            pad_len = block_size-len(sample)
            sample += [tokenizer.pad_token_id]*pad_len
            self.inputs.append(sample)

            if len(self.inputs) % 10000 == 0:
                logger.info(f"{len(self.inputs)} samples")

    def __len__(self):
        """Return the number of blocks."""
        return len(self.inputs)

    def __getitem__(self, item):
        """Return block ``item`` as a tensor of token ids."""
        return torch.tensor(self.inputs[item])

In [2]:
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm.notebook import tqdm

# --- Evaluation configuration ---------------------------------------------
# Single-process defaults. `distributed`, `local_rank` and `gpu_per_node` are
# not read by any code visible in this notebook.
distributed, local_rank, gpu_per_node = False, -1, -1

# Effective batch size is per_gpu_eval_batch_size * max(1, n_gpu).
n_gpu, per_gpu_eval_batch_size = 1, 16

# Progress is logged every `logging_steps` batches.
logging_steps = 10

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

logger = logging.getLogger("cur_logger")


def post_process(preds, gts, true_gts, saved_file):
    """Regroup word-level predictions into lines and write them to a file.

    ``preds`` and ``gts`` are walked in parallel. Entries whose ground truth
    is ``""`` or ``"<pad>"`` are ignored. Every time the ground truth is
    ``"</s>"`` the words collected so far form one sample: the joined ground
    truth is checked against the corresponding line of ``true_gts`` and the
    joined predictions are written as one line of ``saved_file``.

    Args:
        preds: predicted words, aligned with ``gts``.
        gts: ground-truth words.
        true_gts: original data lines, one per sample, used as a sanity check.
        saved_file: path of the prediction file to (over)write.

    Returns:
        The number of samples written.

    Raises:
        AssertionError: if a reconstructed ground-truth line differs from the
            corresponding entry of ``true_gts``.
    """
    cnt = 0
    new_gt = []
    new_pred = []
    # The context manager guarantees the predictions are flushed and the
    # handle is closed even when the consistency assertion fails.
    with open(saved_file, "w") as wf:
        for pred, gt in zip(preds, gts):
            if gt in ["", "<pad>"]:
                continue
            new_gt.append(gt)
            # Spaces are removed so each prediction stays a single word.
            new_pred.append(pred.replace(" ", ""))
            if gt == "</s>":
                gt_str = " ".join(new_gt)
                pred_str = " ".join(new_pred)
                assert gt_str == true_gts[cnt].strip(), f"{cnt} sample gt_str != true_gt"
                wf.write(pred_str+"\n")
                cnt += 1
                new_gt = []
                new_pred = []

    return cnt

def eval_acc(model, tokenizer, output_dir, data_dir, file_type='test'):
    """Compute word-level next-token accuracy of ``model`` on a data split.

    The split ``<data_dir>/<file_type>.txt`` is turned into fixed-size blocks
    by ``EvalDataset``. For every block the model's argmax predictions are
    shifted by one position (position ``i`` is predicted from ``pred[i-1]``),
    sub-word tokens are regrouped into words, and each predicted word is
    compared with the ground-truth word. ``<s>``, ``</s>``, ``<EOL>`` and
    ``<pad>`` are excluded from the score. Predictions are also written, one
    sample per line, to ``<output_dir>/predictions.txt``.

    Relies on the module-level settings ``per_gpu_eval_batch_size``,
    ``n_gpu``, ``device``, ``logging_steps`` and ``logger``.

    Args:
        model: callable model; ``model(inputs)[0]`` must hold the scores over
            the vocabulary for every position.
        tokenizer: tokenizer matching the model.
        output_dir: directory for the dataset cache and the prediction file.
        data_dir: directory containing ``<file_type>.txt``.
        file_type: name of the split to evaluate.

    Returns:
        ``(total, correct)``: number of scored words and number of correctly
        predicted words (``correct`` is a float).
    """
    # file_type is forwarded so the dataset is built from the same split
    # whose lines are later used as ground truth in post_process.
    eval_dataset = EvalDataset(
        tokenizer=tokenizer,
        output_dir=output_dir,
        data_dir=data_dir,
        logger=logger,
        file_type=file_type,
    )

    eval_batch_size = per_gpu_eval_batch_size * max(1, n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size)
    model.to(device)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Tokens that always form a word of their own.
    special_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]

    def DecodeIds(idxs):
        """Join token ids into text, turning a leading '\\u0120' into a space."""
        codes = ""
        for idx in idxs:
            to_add = tokenizer.convert_ids_to_tokens(idx)
            if to_add[0] == '\u0120':
                if not codes.endswith(" "):
                    codes += " " + to_add[1:]
                else:
                    codes += to_add[1:]
            elif idx in special_ids or to_add.startswith("<NUM_LIT"):
                codes += " " + to_add + " "
            else:
                codes += to_add
        return codes.strip(" ")

    def first_word(ids):
        """Return the first decoded word of ``ids``, or "<SPACE>" if there is none."""
        try:
            return DecodeIds(ids).strip().split()[0]
        except IndexError:
            return "<SPACE>"

    model.eval()

    correct = 0.0
    total = 0

    total_pred = []
    total_gt = []

    for step, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader)):
        inputs = batch.to(device)

        with torch.no_grad():
            outputs = model(inputs)
            pred_scores = outputs[0]
            pred_ids = pred_scores.argmax(-1)

        all_pred = []
        all_gt = []
        # prev_pred is never reassigned below, so the first position of every
        # block is always scored against the placeholder token id 0.
        prev_pred = None
        for pred, gt in zip(pred_ids, inputs):
            pred = pred.cpu().tolist()
            gt = gt.cpu().tolist()

            for i, y in enumerate(gt):
                if i == 0:
                    if y in special_ids:
                        now_gt = [y]
                        now_pred = [0] if prev_pred is None else [prev_pred]
                        all_pred.append(first_word(now_pred))
                        all_gt.append(DecodeIds(now_gt).strip())
                        now_gt = []
                        now_pred = []
                    else:
                        now_gt = [y]
                        now_pred = [0] if prev_pred is None else [prev_pred]
                else:
                    token = tokenizer.convert_ids_to_tokens(y)
                    if token[0] == '\u0120':
                        # A new word starts here: flush the word built so far.
                        if len(now_gt) > 0:
                            all_pred.append(first_word(now_pred))
                            all_gt.append(DecodeIds(now_gt).strip())
                            now_gt = []
                            now_pred = []
                    if y in special_ids or token.startswith("<NUM_LIT"):
                        # Special tokens are words of their own: flush any
                        # pending word, then emit the special token itself.
                        if len(now_gt) > 0:
                            all_pred.append(first_word(now_pred))
                            all_gt.append(DecodeIds(now_gt).strip())
                        now_gt = [y]
                        now_pred = [pred[i-1]]
                        all_pred.append(first_word(now_pred))
                        all_gt.append(DecodeIds(now_gt).strip())
                        now_gt = []
                        now_pred = []
                        continue
                    now_gt.append(y)
                    now_pred.append(pred[i-1])
        assert len(all_pred) == len(all_gt)

        total_pred.extend(all_pred)
        total_gt.extend(all_gt)

        for x, y in zip(all_pred, all_gt):
            if y not in ["<s>", "</s>", "<EOL>", "<pad>"]:
                total += 1
                if x == y:
                    correct += 1

        if step % logging_steps == 0:
            logger.info(f"{step} are done!")
            # Guard against a batch that contained no scorable word yet.
            accuracy = correct / total if total else 0.0
            logger.info(f"{total}, {accuracy}")

    saved_file = os.path.join(output_dir, "predictions.txt")
    with open(os.path.join(data_dir, f"{file_type}.txt")) as gt_file:
        true_gts = gt_file.readlines()
    total_samples = post_process(total_pred, total_gt, true_gts, saved_file)
    logger.info(f"Eval on {total_samples}, saved at {saved_file}")

    return total, correct

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, PhiForCausalLM
import json

# Make CUDA the default device only when a GPU is actually available, so new
# tensors can still be created on CPU-only machines (same fallback rule as
# the `device` setting used for evaluation).
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

# python tokenizer

def get_special_tokens(path):
    """Build the list of literal placeholder tokens from a JSON file.

    The JSON file must be an object with the keys ``"str"``, ``"num"`` and
    ``"char"``, each holding an iterable of literal values.

    Args:
        path: path of the literals JSON file.

    Returns:
        A list starting with the generic ``<STR_LIT>``, ``<NUM_LIT>`` and
        ``<CHAR_LIT>`` tokens, followed by one ``<STR_LIT:value>``,
        ``<NUM_LIT:value>`` and ``<CHAR_LIT:value>`` token per listed literal
        (in that order).
    """
    # Read through a context manager so the file handle is always closed.
    with open(path) as handle:
        lits = json.load(handle)

    tokens = ["<STR_LIT>", "<NUM_LIT>", "<CHAR_LIT>"]
    tokens.extend(f"<STR_LIT:{lit}>" for lit in lits["str"])
    tokens.extend(f"<NUM_LIT:{lit}>" for lit in lits["num"])
    tokens.extend(f"<CHAR_LIT:{lit}>" for lit in lits["char"])
    return tokens

# The tokenizer and the model must come from the same checkpoint, so the
# checkpoint id is defined once instead of being repeated in both calls.
MODEL_NAME = "microsoft/phi-1_5"
LITERALS_PATH = "CodeXGLUE_python/CodeCompletion-token/dataset/py150/literals.json"

special_tokens = get_special_tokens(LITERALS_PATH)

python_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    sep_token='<EOL>',
    bos_token='<s>',
    eos_token='</s>',
    pad_token='<pad>',
    unk_token='<|UNKNOWN|>',
    additional_special_tokens=special_tokens,
)

# model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype="auto")

# Resize the embedding matrix to the tokenizer's vocabulary size, which now
# includes the special tokens registered above.
# NOTE(review): the rows for the newly added tokens are presumably freshly
# initialized and untrained - confirm before trusting the accuracy numbers.
model.resize_token_embeddings(len(python_tokenizer))

Embedding(50533, 2048)

In [5]:
# Evaluate the model on the default split; the cell displays the returned
# (total, correct) tuple.
eval_acc(model, python_tokenizer, output_dir="output_dir", data_dir="datasets/python")

  0%|          | 0/8209 [00:00<?, ?it/s]

KeyboardInterrupt: 