In [None]:
# !pip install datasets
!pip install transformers

In [1]:
import copy
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import torch
import transformers
# import utils
from torch.utils.data import Dataset
from transformers import Trainer

############################
from datetime import datetime
import os
# import deepspeed
import json, socket
import argparse
from datasets import load_dataset
import datasets
from functools import partial
from typing import Dict, Optional, Sequence, Any
import time

In [10]:
fileslist = {'train': '1_AM_wiki_00'}
train_data = load_dataset('json', data_files=fileslist, split='train')
# train_data = load_dataset('json', data_files=fileslist, split='train', streaming=True)

In [11]:

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

def preprocess(
        sources: Sequence[str],
        targets: Sequence[str],
        tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        # list_data_dict = utils.jload(data_path)
        
#         flist = get_local_files_list(data_path)
        
#         print('-------DD-------DATALOAD1: ', time.time())
#         train_data = load_dataset('json', data_files=flist, split='train')
        list_data_dict = train_data.to_list()
        # print('-------DD-------DATALOAD2: ', time.time())
            
        logging.warning("Formatting inputs...")
        # prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        # sources = [
        #     prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
        #     for example in list_data_dict
        # ]
        # targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        
        txtinput = "\n{title}\n\n\t\t{text}\n\n"

        # example = list_data_dict

        sources = [
            f"{tokenizer.bos_token}"
        ]
        targets = [txtinput.format_map(example) if example.get("title", "") != "" else example.get("text", "") for example in list_data_dict]
        
        print('---DDD---',len(sources),len(targets))
        
        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        # print('-------DD-------DATALOAD3: ', time.time())

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])



In [12]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    'TheBloke/Llama-2-7B-fp16',
    cache_dir='',
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
    legacy=False ############ p5 test
)

tokenizer.add_special_tokens({'pad_token': '[PAD]'})


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        
        
        
        
        
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path='')
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)





---DDD--- 1 790


In [13]:
next(iter(train_dataset))


{'input_ids': tensor([    1,     1,    13, 30630,   231,   187,   192,   232,   165,   134,
         30967,   313, 31756,   235,   193,   148, 29897,    13,    13,    12,
            12, 30630,   231,   187,   192,   232,   165,   134, 30967,   313,
         31756,   235,   193,   148, 29897,    13,    13, 30866, 30630,   231,
           187,   192,   232,   165,   134, 30967, 30843, 30419, 30409, 30392,
         30630, 30356, 31441, 30732, 31173, 30880,   235,   140,   193,   235,
           145,   140, 30602,   231,   189,   157, 30064,   232,   138,   178,
         31824, 30210, 30622, 31304, 31328, 31283, 30941,   232,   177,   167,
         31756,   235,   193,   148, 30214, 30909, 29906, 29900, 29896, 29953,
         30470, 29896, 29896, 30534, 29946, 30325,   236,   131,   146, 31138,
         29934,  5454,   232,   151,   180, 31122, 30910, 30448, 30267, 31756,
           235,   193,   148, 31062,   232,   138,   137, 30845, 30210, 31688,
         31541, 31166, 31467, 30866, 30