In [1]:
from dataclasses import dataclass, field
import json
import math
import pathlib
from typing import Dict, Optional, Sequence

import numpy as np
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template

from tokenizer import Tokenizer

# Training
import os
import time

from model import Transformer, ModelArgs
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from train import TrainingArgs, process_checkpoints
from pathlib import Path
from contextlib import nullcontext
from datetime import datetime
from tqdm import tqdm
from functools import partial

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@dataclass
class ModelArguments:
    """Model-related settings for the fine-tuning run."""

    # Checkpoint file path; may be set to None.
    checkpoint: Optional[str] = "little-checkpoints/ckpt-best-98500.pt"

@dataclass
class DataArgs:
    """Locations of the fine-tuning data and data-loading options."""

    # JSON file with the training conversations.
    data_path: str = field(
        default="data/clean_alpaca_data.json", metadata={"help": "Path to the training data."}
    )
    # Optional separate evaluation file; None when no such file is configured.
    eval_data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    # NOTE(review): not read by any code visible in this file.
    lazy_preprocess: bool = False


# Rank of the current process. It is never reassigned in the code shown here,
# so rank0_print() stays silent unless something else sets it to 0.
local_rank = None


def rank0_print(*args):
    """Forward ``args`` to ``print`` only when ``local_rank`` equals 0."""
    if local_rank != 0:
        return
    print(*args)


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int = 2048,
) -> Dict:
    """Render conversations with the "vicuna" template, tokenize, and mask labels.

    Args:
        sources: Sequence of conversations. Each conversation is a list of
            dicts with a "from" key ("human" or "gpt") and a "value" key
            holding the message text. A leading non-human message is dropped.
        tokenizer: Tokenizer used for the batch encoding and for measuring the
            token length of each turn.
        max_length: Length every sequence is padded or truncated to. Defaults
            to 2048, the value that used to be hard-coded here.

    Returns:
        A dict with ``input_ids`` (the padded token ids) and ``labels`` (a
        copy of ``input_ids`` in which every position outside the assistant
        replies is set to ``IGNORE_TOKEN_ID``).

    Raises:
        AssertionError: If the roles in a conversation do not alternate
            human/gpt, or if the template's separator style is not
            ``ADD_COLON_TWO``.
    """
    conv = get_conversation_template("vicuna")
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens actually present in the encoded row.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        # Position 0 is always masked; the walk over turns starts right after it.
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for turn in turns:
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        # Everything after the last processed turn (padding included) is masked.
        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))

        # cur_len is measured on the untruncated text, so it can only be
        # compared with total_len for rows that were not truncated. The limit
        # that matters is the one the rows were actually encoded with
        # (max_length), not tokenizer.model_max_length.
        if cur_len < max_length and cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            rank0_print(
                f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                f" (ignored)"
            )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )

class PretokDataset(torch.utils.data.IterableDataset):
    """Tokenizes a conversation JSON file once and streams (input, label) rows.

    The file at ``data_path`` must contain a list of objects with a
    "conversations" entry. All conversations are run through ``preprocess``
    when the dataset is constructed; iteration then cycles endlessly over the
    rows belonging to the requested split.
    """

    def __init__(self, data_path, split, tokenizer: "transformers.PreTrainedTokenizer"):
        """Load ``data_path``, tokenize every conversation, and store the tensors.

        Args:
            data_path: Path of the JSON file to read.
            split: "train" selects the first 98% of the rows; any other value
                selects the remaining rows.
            tokenizer: Tokenizer handed to ``preprocess``.
        """
        super().__init__()

        rank0_print("Formatting inputs...")
        # Close the file deterministically and decode it as UTF-8 (the JSON
        # interchange encoding) instead of the platform default.
        with open(data_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        sources = [example["conversations"] for example in raw_data]
        data_dict = preprocess(sources, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

        self.split = split

    def __iter__(self):
        """Yield ``(input_ids_row, labels_row)`` pairs for this split, forever.

        Raises:
            ValueError: If the selected split contains no rows. Without this
                check the endless loop below would spin without ever yielding.
        """
        # Deterministic train/eval split in file order (no shuffling).
        perm = np.arange(len(self.input_ids))
        split = int(len(perm) * 0.98)
        indices = perm[:split] if self.split == "train" else perm[split:]

        if len(indices) == 0:
            raise ValueError(
                f"split {self.split!r} is empty: dataset has "
                f"{len(perm)} example(s)"
            )

        while True:
            for i in indices:
                yield self.input_ids[i], self.labels[i]
            
class Task:
    """Namespace for the batch-iteration helper."""

    @staticmethod
    def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
        """Yield ``(x, y)`` batches from a ``PretokDataset`` moved to ``device``.

        Args:
            batch_size: Batch size passed to the ``DataLoader``.
            device: Target device for both tensors of every batch.
            num_workers: Worker count passed to the ``DataLoader``.
            **dataset_kwargs: Forwarded unchanged to ``PretokDataset``.
        """
        rank0_print("Loading data...")
        loader = torch.utils.data.DataLoader(
            PretokDataset(**dataset_kwargs),
            batch_size=batch_size,
            pin_memory=True,
            num_workers=num_workers,
        )
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            yield inputs, labels

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Run configuration: checkpoints go to "finetune_ckpt", wandb logging is off.
training_args = TrainingArgs(out_dir="finetune_ckpt", wandb_log=False)
data_args = DataArgs()
# Slow (non-fast) Llama-2 chat tokenizer, padding on the right, with its
# model_max_length tied to the trainer's max_seq_len.
# NOTE(review): preprocess() pads/truncates to 2048 tokens by default,
# independent of the model_max_length set here.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Llama-2-7b-chat-hf",
    model_max_length=training_args.max_seq_len,
    padding_side="right",
    use_fast=False,
)
# Reuse the unknown token as the padding token.
tokenizer.pad_token = tokenizer.unk_token
# Bind everything except `split`, so later cells only call
# iter_batches(split=...).
iter_batches = partial(
    Task.iter_batches,
    batch_size=training_args.max_batch_size,
    device=training_args.device,
    num_workers=0,
    data_path=data_args.data_path,
    tokenizer=tokenizer,
)

In [5]:
# Build the training batch generator and pull the first (inputs, labels) batch.
batch_iter = iter_batches(split="train")
X, Y = next(batch_iter)

print(X)

tensor([[    1,   319, 13563,  ...,     0,     0,     0],
        [    1,   319, 13563,  ...,     0,     0,     0]], device='cuda:0')


In [6]:
# Check the shape of the input batch.
X.shape

torch.Size([2, 2048])

In [8]:
# Decode the first sequence of the batch back to text, padding included.
tokenizer.decode(X[0,:])

"<s> A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Give three tips for staying healthy. ASSISTANT: 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.</s><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><un

In [10]:
# Advance to the next batch and decode its first sequence.
# NOTE(review): the recorded output of this cell is an IndexError; the cause
# is not visible in this file.
X, Y = next(batch_iter)
tokenizer.decode(X[0,:])

IndexError: piece id is out of range.

In [13]:
# Print the label tensor of the current batch.
print(Y)

tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100]], device='cuda:0')


In [17]:
# Count the label positions in row 0 that are not masked with -100.
print(torch.sum(Y[0,:] != -100).item())

160


In [19]:
# Positions in row 0 whose label is not -100.
indices = torch.nonzero(Y[0] != -100)

print(indices)

tensor([[ 52],
        [ 53],
        [ 54],
        [ 55],
        [ 56],
        [ 57],
        [ 58],
        [ 59],
        [ 60],
        [ 61],
        [ 62],
        [ 63],
        [ 64],
        [ 65],
        [ 66],
        [ 67],
        [ 68],
        [ 69],
        [ 70],
        [ 71],
        [ 72],
        [ 73],
        [ 74],
        [ 75],
        [ 76],
        [ 77],
        [ 78],
        [ 79],
        [ 80],
        [ 81],
        [ 82],
        [ 83],
        [ 84],
        [ 85],
        [ 86],
        [ 87],
        [ 88],
        [ 89],
        [ 90],
        [ 91],
        [ 92],
        [ 93],
        [ 94],
        [ 95],
        [ 96],
        [ 97],
        [ 98],
        [ 99],
        [100],
        [101],
        [102],
        [103],
        [104],
        [105],
        [106],
        [107],
        [108],
        [109],
        [110],
        [111],
        [112],
        [113],
        [114],
        [115],
        [116],
        [117],
        [1

In [23]:
# Decode the label ids at positions 52..211 of row 0.
print(tokenizer.decode(Y[0,52:212]))

I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client’s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team’s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client’s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities.</s>


In [21]:
# Decode the full input row 0 to compare it with the decoded labels.
print(tokenizer.decode(X[0]))

<s> A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Describe a time when you had to make a difficult decision. ASSISTANT: I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client’s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team’s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client’s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities.</s><unk><unk><

In [24]:
# Show the value written into masked label positions.
print(IGNORE_TOKEN_ID)

-100
