In [1]:
import copy
from dataclasses import dataclass, field
import json
import pathlib
from typing import Dict, Optional, Sequence

import numpy as np
import torch
from torch.utils.data import Dataset
import transformers
from transformers.trainer_pt_utils import LabelSmoother #código para evitar overconfidence no modelo
from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import json

from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = 'helloollel/vicuna-7b'

model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name
    )
model.config.use_cache = False
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.85s/it]


In [3]:
tokenizer.model_max_length

1000000000000000019884624838656

In [21]:
local_rank = None

def rank0_print(*args):
    if local_rank==0:
        print(*args)

In [22]:
source = json.load(open("dummy.json"))

In [23]:
new_source = []
for i in source:
    new_conv = []
    for j in i['conversations']:
        if j['from'] == 'human':
            new_conv.append({"from":"client", "value":j['value']})
        else:
            new_conv.append({"from":"agent", "value":j['value']})
    new_source.append(new_conv)
    

In [30]:
tokenizer.pad_token = tokenizer.eos_token
preprocess(new_source,tokenizer)

SeparatorStyle.ADD_COLON_TWO


{'input_ids': tensor([[    1,   319, 13563,  ...,     2,     2,     2],
         [    1,   319, 13563,  ...,     2,     2,     2],
         [    1,   319, 13563,  ...,     2,     2,     2],
         ...,
         [    1,   319, 13563,  ...,     2,     2,     2],
         [    1,   319, 13563,  ...,     2,     2,     2],
         [    1,   319, 13563,  ...,     2,     2,     2]]),
 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         ...,
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100]]),
 'attention_mask': tensor([[ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         ...,
         [ True,  True,  True,  ..., False, False, False],
     

In [29]:
def preprocess(sources,
               tokenizer: transformers.PreTrainedTokenizer,
               ) -> Dict:
    """Preprocesses the data into a format suitable for training."""
    conv = get_conversation_template("vicuna")
    roles = {"client":conv.roles[0],"agent":conv.roles[1]}

    #Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
             # Skip the first one if it is not from client
             source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j%2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=256,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()
    
    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
    
    #Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations,targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID

        for i, rou in enumerate(rounds):
            if rou =="":
                break
            
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            
            parts[0] +=sep
            round_len = len(tokenizer(rou).input_ids)
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
            
            cur_len += round_len
        target[cur_len:] = IGNORE_TOKEN_ID

        if False:
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id,z)
            rank0_print(tokenizer.decode(z))
        
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}"
                    f"(ignored)"
                )
    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask = input_ids.ne(tokenizer.pad_token_id),
    )