In [None]:
from functools import cache
from adapter_args_helper import (
    DataArguments,
    ModelArguments,
    TrainingArguments
)
from datasets import load_from_disk, load_metric, set_caching_enabled, DatasetDict
from data_utils import load_dataset
from itertools import chain
from models import ClipCaptionModel
from torch.nn.functional import cross_entropy
from tqdm import tqdm
from transformers import (
    default_data_collator,
    get_linear_schedule_with_warmup,
    set_seed,
    AdamW,
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2Tokenizer,
    GPT2Model,
    EarlyStoppingCallback,
    HfArgumentParser,
    Trainer
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

import logging
import math
import numpy as np
import os
import sys
import torch
import transformers
from torch.utils.data import Dataset

from ppcm_models.pytorch_pretrained_bert.modeling_adapter import GPT2LMHeadModel, GPT2Config
from utils.helper import load_model_recursive

# Turn on the `datasets` library cache for processed datasets.
# NOTE(review): newer `datasets` releases deprecate `set_caching_enabled` in favour of
# `enable_caching()` / `disable_caching()` — confirm against the installed version.
set_caching_enabled(True)
# Module-level logger, named after the current module.
logger = logging.getLogger(__name__)

In [None]:
class DataArguments():
    """
    Notebook-local container for dataset settings.

    This definition replaces the `DataArguments` name imported from
    `adapter_args_helper` at the top of the notebook. Every setting can be
    overridden through a keyword argument; calling `DataArguments()` with no
    arguments yields the same values the notebook has always used.

    Args:
        dataset_path: location passed to `load_from_disk` by the dataset class.
        bookcorpusopen_story_column_name: name of the text column that gets tokenized.
        preprocessing_num_workers: `num_proc` used for dataset filter/map calls.
        genre: genre string used to filter rows.
        adapter_id: id returned as `task_id` with every dataset item.
        match_up_to_n_genres: how many leading genres of a row are considered when matching.
        sample_row: number of leading rows to keep, or None to keep the whole split.
    """
    def __init__(self,
                 # NOTE(review): machine-specific absolute path kept as the default for
                 # backward compatibility; pass `dataset_path=...` on other machines.
                 dataset_path='/home/bryan/datasets/bookcorpusopen/bookcorpusopen_chunked.arrow',
                 bookcorpusopen_story_column_name='chunk',
                 preprocessing_num_workers=8,
                 genre='Romance',
                 adapter_id=1,
                 match_up_to_n_genres=3,
                 sample_row=None):
        self.dataset_path = dataset_path
        self.bookcorpusopen_story_column_name = bookcorpusopen_story_column_name
        self.preprocessing_num_workers = preprocessing_num_workers
        self.genre = genre
        self.adapter_id = adapter_id
        self.match_up_to_n_genres = match_up_to_n_genres
        self.sample_row = sample_row
        
class ModelArguments():
    """
    Notebook-local container for model settings.

    This definition replaces the `ModelArguments` name imported from
    `adapter_args_helper` at the top of the notebook. Every setting can be
    overridden through a keyword argument; `ModelArguments()` keeps the
    notebook's original values.

    Args:
        model_size: size tag used to build the DialoGPT model directory and
            checkpoint file name (e.g. 'medium').
        load_checkpoint_adapter: path of an adapter checkpoint to load; the empty
            string means "load the fine-tuned DialoGPT checkpoint instead".
        max_seq_len: chunk length (in tokens) used when building the datasets.
    """
    def __init__(self, model_size='medium', load_checkpoint_adapter="", max_seq_len=512):
        self.model_size = model_size
        self.load_checkpoint_adapter = load_checkpoint_adapter
        self.max_seq_len = max_seq_len
        # self.lr = 2e-4 #, help="Learning rate")

class TrainingArguments(TrainingArguments):
    """
    Notebook-local training settings.

    Subclasses the `TrainingArguments` imported from `adapter_args_helper` and
    rebinds the same name, so after this cell the imported class is only
    reachable as this class's base.

    NOTE(review): `__init__` does not call the parent initializer, so the only
    attributes this constructor sets are the two below. If the parent class
    defines further fields that downstream code reads, they will be missing or
    left at class-level defaults — confirm this is intended.
    """
    def __init__(self):
        """Set the output directory and the eval accumulation setting."""
        self.output_dir = "./save"
        self.eval_accumulation_steps = None
        
# Build the three settings objects used by the rest of the notebook
# (constructed left to right: model, data, training).
model_args, data_args, training_args = ModelArguments(), DataArguments(), TrainingArguments()

In [None]:
# Directory that holds the DialoGPT config, tokenizer files and fine-tuned
# checkpoint for the selected model size.
model_args.model_path = f'ppcm_models/dialoGPT/{model_args.model_size}/'

config_path = os.path.join(model_args.model_path, 'config.json')
config = GPT2Config.from_json_file(config_path)
tokenizer = GPT2Tokenizer.from_pretrained(model_args.model_path)

## Load either Adapters' checkpoint, or just finetuned DialoGPT
# Decide which checkpoint file to read, then load it once.
use_adapter_checkpoint = model_args.load_checkpoint_adapter != ""
if use_adapter_checkpoint:
    print("Loading ADAPTERS")
    checkpoint_path = model_args.load_checkpoint_adapter
else:
    checkpoint_path = model_args.model_path + f"{model_args.model_size}_ft.pkl"
model = load_model_recursive(GPT2LMHeadModel(config), checkpoint_path, model_args, verbose=True)

## Load GPT2 instead of DialoGPT

# Pick the pretrained GPT-2 checkpoint that matches the configured model size,
# so changing `model_args.model_size` cannot silently pair a DialoGPT backbone
# of one size with GPT-2 weights of another. 'medium' still resolves to
# 'gpt2-medium', exactly as before.
gpt2_checkpoint_by_size = {'small': 'gpt2', 'medium': 'gpt2-medium', 'large': 'gpt2-large'}
pt_gpt2_model = GPT2Model.from_pretrained(
    gpt2_checkpoint_by_size.get(model_args.model_size, f'gpt2-{model_args.model_size}')
)


def _share_weight_and_bias(target_module, source_module):
    """Point `target_module` at `source_module`'s weight and, when both have one, its bias."""
    target_module.weight = source_module.weight
    # Previously only `.weight` was replaced, which left the DialoGPT biases
    # paired with GPT-2 weights. Guarded so modules without a bias are untouched.
    if getattr(source_module, 'bias', None) is not None and hasattr(target_module, 'bias'):
        target_module.bias = source_module.bias


# Embedding tables (token and position).
model.transformer.wte.weight = pt_gpt2_model.wte.weight
model.transformer.wpe.weight = pt_gpt2_model.wpe.weight
# NOTE(review): if the adapter model's LM head is tied to the *old* `wte.weight`
# Parameter object, rebinding `wte.weight` here does not update the head —
# confirm against the GPT2LMHeadModel implementation and re-tie if needed.

# Per-block layer norms, attention projections and MLP projections.
layers = np.arange(0,len(pt_gpt2_model.h),1)
for layer in layers:
    source_block = pt_gpt2_model.h[layer]
    target_block = model.transformer.h[layer]
    for target_module, source_module in (
        (target_block.ln_1, source_block.ln_1),
        (target_block.attn.c_attn, source_block.attn.c_attn),
        (target_block.attn.c_proj, source_block.attn.c_proj),
        (target_block.ln_2, source_block.ln_2),
        (target_block.mlp.c_fc, source_block.mlp.c_fc),
        (target_block.mlp.c_proj, source_block.mlp.c_proj),
    ):
        _share_weight_and_bias(target_module, source_module)

# Final layer norm, when both models expose one under `ln_f`; the original
# code left the DialoGPT values in place here.
if hasattr(model.transformer, 'ln_f') and hasattr(pt_gpt2_model, 'ln_f'):
    _share_weight_and_bias(model.transformer.ln_f, pt_gpt2_model.ln_f)
# model.to(model_args.device)
print('GPT2 loaded instead DialoGPT')

# Single pass over the model: parameters whose name contains "adapter" are
# collected for the optimizer, every other parameter is frozen.
parameters_to_update = []
for n, p in model.named_parameters():
    if "adapter" in str(n):
        parameters_to_update.append(p)
        continue
    p.requires_grad = False
# optimizer = AdamW(parameters_to_update, lr=model_args.lr, correct_bias=True)
# NOTE(review): the message below mentions AdamW, but the optimizer line above
# is commented out, so no optimizer is created in this cell.
print('GPT2 param frozen, Adapter is trainable and initialized with AdamW')

In [None]:
class BookcorpusopenGenreAdapterDataset(Dataset):
    """
    Torch ``Dataset`` over one bookcorpusopen split, filtered to a single genre,
    tokenized, and regrouped into chunks of ``max_seq_len`` tokens.

    Each item is a dict with ``task_id`` (the adapter id), ``input_ids`` (a list
    holding one token-id list) and ``labels`` (an independent copy of ``input_ids``).
    """

    def __init__(self, data_args, split, tokenizer, genre=None, adapter_id=-1,
                         sample_row=100, match_up_to_n_genres=None, truncate=True,
                         max_seq_len=512, add_special_tokens=True,
                         *args, **kwargs):
        """
        Load, filter, tokenize and chunk the requested split.

        Args:
            data_args: settings object; this class reads `dataset_path`,
                        `bookcorpusopen_story_column_name` and `preprocessing_num_workers`
            split: string, key of the split to load from the dataset on disk
            tokenizer: tokenizer called on the story column
            genre: string, genre to keep; must not be None (a ValueError is raised otherwise)
            adapter_id: int, adapter_id for the genre we want the adapter to be trained with
            sample_row: int number of leading rows to keep before filtering,
                        None means using all rows of the split
            match_up_to_n_genres: int, how many leading genres of a row are matched
                        against `genre`; None means all of them
            truncate: passed to the tokenizer as `truncation`
            max_seq_len: tokenizer `max_length` and the chunk length used for grouping
            add_special_tokens: passed to the tokenizer as `add_special_tokens`

        Raises:
            ValueError: if `genre` is None.
        """
        super(BookcorpusopenGenreAdapterDataset, self).__init__(*args, **kwargs)

        self.data_args = data_args
        self.tokenizer = tokenizer
        self.add_special_tokens = add_special_tokens
        self.truncate = truncate
        self.max_seq_len = max_seq_len
        self.adapter_id = adapter_id
        self.preprocessing_num_workers = data_args.preprocessing_num_workers
        self.dataset = self.load_bookcorpusopen(split, genre,
                                                match_up_to_n_genres,
                                                sample_row)

    @staticmethod
    def _genre_match(entry_genres_string_list, genre, match_up_to_n_genres):
        """
        Return True when `genre` occurs (case-insensitive substring match) in the
        first `match_up_to_n_genres` genres of an entry, else False.

        The entry is parsed by stripping its first/last character and then the
        first/last character of every ', '-separated item, i.e. it is assumed to
        look like "['A', 'B']" — TODO confirm against the stored 'genre' column.
        """
        story_genre_list = [entry[1:-1] for entry in entry_genres_string_list[1:-1].split(', ')]
        story_genre_stringlist = ", ".join(story_genre_list[:match_up_to_n_genres])

        return genre.lower() in story_genre_stringlist.lower()

    def _map_tokenization(self, batch):
        """Tokenize the story column of `batch` with the configured tokenizer options."""
        self.tokenizer.pad_token = self.tokenizer.eos_token
        tokenized = self.tokenizer(batch[self.data_args.bookcorpusopen_story_column_name],
                                      truncation=self.truncate,
                                      max_length=self.max_seq_len,
                                      add_special_tokens=self.add_special_tokens)
        return tokenized

    def _group_texts(self, examples):
        """
        Concatenate every column of a batch and re-split it into chunks of
        `max_seq_len` items.

        When the concatenated length is at least `max_seq_len`, the trailing
        remainder is dropped; a batch shorter than `max_seq_len` is kept as one
        short chunk.
        """
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported
        # it instead of this drop, you can customize this part to your needs.
        if total_length >= self.max_seq_len:
            total_length = (total_length // self.max_seq_len) * self.max_seq_len
        # Split by chunks of max_len.
        return {
            k: [t[i : i + self.max_seq_len]
                for i in range(0, total_length, self.max_seq_len)]
            for k, t in concatenated_examples.items()
        }

    def load_bookcorpusopen(self, split, genre='Fiction',
                            match_up_to_n_genres=None, sample_row=None):
        """
        Load bookcorpusopen from disk, keep the rows matching `genre`, tokenize
        them and group the tokens into chunks of `max_seq_len`.

        Further improvement:
        Group, concat, and truncate entries based on the adapter_id after tokenization

        Args:
            split: string, key of the split to use
            genre: string, genre that we want the adapter to be trained with, e.g. 'Fiction'
            match_up_to_n_genres: int, how many of the first bookcorpusopen genre entries
                                    are matched against the genre input.
                                    None defaults to use all bookcorpusopen genres to match.
            sample_row: int, number of leading rows to keep before filtering (clamped to
                        the split size); None means using all the rows available

        Returns:
            dataset: tokenized and chunked huggingface dataset for the requested split,
                        containing only the tokenizer output columns

        Raises:
            ValueError: if `genre` is None.
        """
        # Fail before any loading work: a None genre used to crash later, inside
        # the filter workers, on `None.lower()`.
        if genre is None:
            raise ValueError("genre must be a string such as 'Fiction'; got None")

        # load bookcorpusopen from arrow file
        print('Loading train, validation, test dataset...')
        datasets = load_from_disk(self.data_args.dataset_path)
        print('Loaded')

        # Select rows sampled and filter for the matching genres.
        # The row count is clamped so an oversized `sample_row` cannot index
        # past the end of the split.
        split_dataset = datasets[split]
        available_rows = len(split_dataset)
        sample_row = available_rows if sample_row is None else min(sample_row, available_rows)

        # Bind the plain function so the filter lambda does not capture `self`.
        match_fn = self._genre_match
        dataset = split_dataset.select(range(sample_row)).filter(
            lambda x: match_fn(x['genre'], genre, match_up_to_n_genres),
            num_proc=self.preprocessing_num_workers,
        )

        # Tokenize with huggingface datasets mapping function.
        # Every original column is dropped here (not just the story column):
        # the grouping step concatenates *all* remaining columns and derives the
        # chunk count from the first one, so leftover columns such as 'genre'
        # would corrupt the chunking.
        tokenized_dataset = dataset.map(
            self._map_tokenization,
            remove_columns=dataset.column_names,
            num_proc=self.preprocessing_num_workers,
            load_from_cache_file=True
        )
        print(split, 'split tokenized')

        group_concatted_dataset = tokenized_dataset.map(
            self._group_texts,
            batched=True,
            num_proc=self.preprocessing_num_workers,
            load_from_cache_file=True,
            desc=f"Grouping texts in chunks of {self.max_seq_len}",
        )

        return group_concatted_dataset

    def __getitem__(self, index):
        """Return the model inputs for row `index`: task id, wrapped input ids and labels."""
        token_ids = self.dataset[index]['input_ids']

        forward_inputs = {}
        forward_inputs['task_id'] = self.adapter_id
        forward_inputs['input_ids'] = [token_ids]
        # Independent copy so in-place edits of the labels cannot alter the inputs.
        forward_inputs["labels"] = [list(token_ids)]

        return forward_inputs

    def __len__(self):
        """Number of chunks in the processed dataset."""
        return self.dataset.num_rows

## Load and check the dataset splits

In [None]:
# Build the dataset for each split, then report every item whose token list
# is shorter than the configured chunk length.
dataset_dict = {}
for split in ['train', 'valid']:
    split_dataset = BookcorpusopenGenreAdapterDataset(
        data_args,
        split,
        tokenizer,
        genre=data_args.genre,
        adapter_id=data_args.adapter_id,
        sample_row=data_args.sample_row,
        match_up_to_n_genres=data_args.match_up_to_n_genres,
        max_seq_len=model_args.max_seq_len,
    )
    dataset_dict[split] = split_dataset

    for i in range(len(split_dataset)):
        # Each item wraps its token ids in a one-element list, hence the [0].
        input_ids_len = len(split_dataset[i]['input_ids'][0])
        if not input_ids_len < model_args.max_seq_len:
            continue
        print(split, i, input_ids_len)

In [None]:
def remove_remainder(input_ids, max_seq_len=None):
    """
    Tell whether a token sequence is a full-length chunk.

    Accepts either a flat list of token ids (a raw dataset row) or a list that
    wraps one such list (the form returned by the torch dataset's items).

    Args:
        input_ids: flat token-id list, or a list containing one token-id list.
        max_seq_len: expected chunk length; when None, the notebook-level
            `model_args.max_seq_len` is used.

    Returns:
        True when the sequence has exactly `max_seq_len` tokens, else False
        (an empty input is never full-length unless `max_seq_len` is 0).
    """
    # The original read `self.max_seq_len`, but this is a free function with no
    # `self`; fall back to the notebook's model settings instead.
    if max_seq_len is None:
        max_seq_len = model_args.max_seq_len

    # Unwrap the one-element nesting when present; a flat row is measured as-is
    # (the original `len(input_ids[0])` failed on flat rows with a TypeError).
    sequence = input_ids
    if len(input_ids) > 0 and isinstance(input_ids[0], (list, tuple)):
        sequence = input_ids[0]

    return len(sequence) == max_seq_len

In [None]:
# Show the train rows that pass `remove_remainder`.
# NOTE(review): the filtered dataset is only displayed, not assigned back, so
# `dataset_dict['train'].dataset` itself is left unchanged.
dataset_dict['train'].dataset.filter(
    lambda row: remove_remainder(row['input_ids']),
    num_proc=8,
)

In [None]:
# Display the token ids of the first row of the processed train split.
dataset_dict['train'].dataset[0]['input_ids']