In [None]:
import glob
import os
import wandb
import copy
from tqdm import tqdm
from datasets import load_dataset
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model

In [None]:
# Weights & Biases configuration, exported as environment variables in one go.
# NOTE(review): these are presumably read later by wandb / the training
# integration; nothing in this part of the file consumes them directly.
os.environ.update(
    {
        "WANDB_ENTITY": "reviewco",
        "WANDB_PROJECT": "Autocompletion with evaluation",
        "WANDB_USERNAME": "keisuke-kamata",
        "WANDB_LOG_MODEL": "checkpoint",
        "WANDB_WATCH": "gradients",
    }
)

In [None]:
# Hub id of the checkpoint whose tokenizer encodes the prompts and responses.
TOKENIZER_CHECKPOINT = "facebook/opt-125m"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [None]:
# Fixed seed so the train/validation/test partitions are identical on every
# run. Without it each execution reshuffles the data, which makes metrics
# non-comparable and lets test rows end up in a later run's training set.
SPLIT_SEED = 42

data = load_dataset("databricks/databricks-dolly-15k", split="train")

# Hold out 20% of all records as the test set.
data_train_test = data.train_test_split(test_size=0.2, seed=SPLIT_SEED)
data_test = data_train_test["test"]

# Hold out 20% of the remaining records (16% of the total) for validation;
# what is left (64% of the total) is the training set.
data_train_valid = data_train_test["train"].train_test_split(
    test_size=0.2, seed=SPLIT_SEED
)
data_train = data_train_valid["train"]
data_valid = data_train_valid["test"]

In [None]:
# Prompt templates, filled with str.format_map() from a record's fields.
#
# Both templates end with "### Response:" followed by a newline. The response
# text is appended directly after the template, so without that delimiter the
# answer would be fused onto the word "Response" (e.g. "### ResponseParis"),
# and the token boundary between prompt and response would depend on the first
# characters of the answer.

# Used when a record has no additional context; needs {instruction}.
PROMPT_NO_INPUT_FORMAT = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
"""

# Used when a record supplies context; needs {instruction} and {context}.
PROMPT_WITH_INPUT_FORMAT = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
Input:
{context}
### Response:
"""

In [None]:
class InstructDataset(Dataset):
    """Instruction-tuning examples, tokenized once up front.

    Each record is rendered into a prompt, the response and the tokenizer's
    EOS token are appended, and the result is tokenized. The stored labels are
    a copy of the input ids in which the prompt positions are replaced by
    ``ignore_index``, so that only the response tokens carry a training target.

    Args:
        json_list: Iterable of dict records with the keys ``instruction`` and
            ``response`` and, optionally, ``context``.
        tokenizer: Tokenizer used to encode the text; must define
            ``eos_token``.
        ignore_index: Label value written over the prompt positions.
        max_length: Maximum number of tokens kept per example; longer
            sequences are truncated.

    Every item is a dict with two 1-D tensors of equal length, ``input_ids``
    and ``labels``.
    """

    def __init__(self, json_list, tokenizer, ignore_index=-100, max_length=512):
        self.tokenizer = tokenizer
        self.ignore_index = ignore_index
        self.max_length = max_length
        self.features = []

        for j in tqdm(json_list):
            # A record can carry a 'context' key whose value is empty (this is
            # how databricks-dolly-15k stores tasks such as open_qa), so the
            # value is tested rather than the mere presence of the key.
            # Otherwise the no-input template would never be selected.
            if j.get('context'):
                source_text = PROMPT_WITH_INPUT_FORMAT.format_map(j)
            else:
                source_text = PROMPT_NO_INPUT_FORMAT.format_map(j)

            # Full training text: prompt, then response, then an EOS token.
            example_text = source_text + j['response'] + self.tokenizer.eos_token

            # Tokenize the prompt on its own; only its token count is needed.
            # No padding is requested: a single sequence has nothing to be
            # padded against.
            source_len = len(
                self.tokenizer(
                    source_text,
                    truncation=True,
                    max_length=max_length,
                )['input_ids']
            )

            # Tokenize prompt + response together; [0] drops the batch axis.
            input_ids = self.tokenizer(
                example_text,
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            )['input_ids'][0]

            # The labels start as a copy of the inputs, then the prompt part
            # is overwritten with ignore_index so that it is excluded from the
            # loss.
            labels = input_ids.clone()
            labels[:source_len] = self.ignore_index

            self.features.append({
                'input_ids': input_ids,
                'labels': labels
            })

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``{'input_ids', 'labels'}`` dict stored at ``idx``."""
        return self.features[idx]
        
class InstructCollator():
    """Pad a list of examples into one batch of equally long tensors.

    Args:
        tokenizer: Object exposing ``pad_token_id``, used to pad ``input_ids``.
        ignore_index: Value used to pad ``labels``.

    Calling the collator with a list of ``{'input_ids', 'labels'}`` dicts of
    1-D tensors returns a dict with ``input_ids``, ``labels`` and a boolean
    ``attention_mask``, each of shape ``(batch, longest_sequence)``.
    """

    def __init__(self, tokenizer, ignore_index=-100):
        self.tokenizer = tokenizer
        # Honour the caller's value (it used to be overwritten with -100).
        self.ignore_index = ignore_index

    def __call__(self, examples):
        input_batch = [example['input_ids'] for example in examples]
        label_batch = [example['labels'] for example in examples]

        input_ids = pad_sequence(
            input_batch, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )

        # Padding positions of the labels are filled with ignore_index, the
        # same value that already marks the prompt positions.
        labels = pad_sequence(
            label_batch, batch_first=True, padding_value=self.ignore_index
        )

        # Build the mask from the true sequence lengths instead of comparing
        # against pad_token_id: a real token that shares the pad id (e.g. when
        # the pad token is set to EOS) must stay visible. The mask is boolean.
        lengths = torch.tensor(
            [len(ids) for ids in input_batch], device=input_ids.device
        )
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask
        }