In [None]:
from transformers import GPT2LMHeadModel, GPT2Config, GPT2TokenizerFast, DataCollatorForLanguageModeling
import torch
import numpy as np

In [None]:
def load_pretrained_tokenizer(pretrained_model_name_or_path):
    """
    Load a fast GPT-2 tokenizer and give it a padding token.

    Args:
        pretrained_model_name_or_path: model identifier or local path that is
            forwarded unchanged to ``GPT2TokenizerFast.from_pretrained``.

    Returns:
        The loaded tokenizer, created with ``add_prefix_space=True`` and with
        ``pad_token`` set to its own ``eos_token``.
    """
    gpt2_tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_model_name_or_path, add_prefix_space=True)
    # Reuse the end-of-sequence token for padding so that batches can be padded
    # (presumably because the GPT-2 vocabulary has no dedicated pad token -- confirm).
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    return gpt2_tokenizer

In [None]:
class BidirectionalLM:
    """
    GPT-2 based model wrapper for single-token infilling.

    Adapted from pqian11/fragment-completion (Qian and Levy, 2022)

    Every token of a sequence is blanked out in turn; the model sees the
    sequence with a [BLANK] placeholder followed by [SEP] [FILLER] and is
    trained to produce the removed token at the final position.
    """
    def __init__(self, device='cuda'):
        """
        Build the model, tokenizer and collator, and register the special tokens.

        Args:
            device: device the model is moved to; it is also remembered as
                ``self.device`` and reused by ``load``.
        """
        self.device = device
        configuration = GPT2Config()
        # Built from a default config, i.e. not via ``from_pretrained``.
        self.model = GPT2LMHeadModel(configuration).to(device)
        self.tokenizer = load_pretrained_tokenizer('gpt2')
        # NOTE(review): with mlm=False this collator presumably derives labels
        # from input_ids, which would replace the infilling labels produced by
        # expand_inputs -- confirm before using it for infilling training.
        self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)

        self.BLANK = '[BLANK]'
        self.FILLER = '[FILLER]'
        self.SEP = '[SEP]'
        self.num_added_tokens = self.tokenizer.add_tokens([self.BLANK, self.FILLER, self.SEP])
        # The embedding matrix must grow to cover the newly added tokens.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.BLANK_id = self.tokenizer.convert_tokens_to_ids(self.BLANK)
        self.FILLER_id = self.tokenizer.convert_tokens_to_ids(self.FILLER)
        self.SEP_id = self.tokenizer.convert_tokens_to_ids(self.SEP)

    @staticmethod
    def _as_sequences(ids):
        """Return ``ids`` as a list of sequences, wrapping a single flat sequence."""
        if len(ids) > 0 and isinstance(ids[0], (list, tuple)):
            return [list(sequence) for sequence in ids]
        return [list(ids)]

    def expand_inputs(self, inputs):
        """
        Turn one tokenized sequence into one infilling example per token.

        Args:
            inputs: mapping with an ``'input_ids'`` entry holding a list of
                token ids.

        Returns:
            Dict of plain Python lists with keys ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``. For ``n`` input tokens each
            entry holds ``n`` rows of length ``n + 2``: row ``i`` is the input
            with token ``i`` replaced by BLANK, followed by SEP and FILLER.
            Labels are -100 everywhere except the last position, which holds
            the token that was blanked out.
        """
        input_ids = list(inputs['input_ids'])
        n_tokens = len(input_ids)
        expanded_len = n_tokens + 2  # original tokens + SEP + FILLER
        suffix = [self.SEP_id, self.FILLER_id]

        bidi_input_ids = [
            input_ids[:i] + [self.BLANK_id] + input_ids[i + 1:] + suffix
            for i in range(n_tokens)
        ]
        bidi_attention_mask = [[1] * expanded_len for _ in range(n_tokens)]
        bidi_labels = [
            [-100] * (expanded_len - 1) + [answer_token]
            for answer_token in input_ids
        ]

        return {
            'input_ids': bidi_input_ids,
            'attention_mask': bidi_attention_mask,
            'labels': bidi_labels
        }

    def make_batch(self, mini_batches, device='cuda'):
        """
        Given mini_batches (List[Dict]), create batch (Dict)
        containing input_ids, attention_mask, and labels tensors

        Each mini-batch's ``'input_ids'`` / ``'labels'`` may be either a single
        flat sequence or a list of sequences (the shape ``expand_inputs``
        returns); nested mini-batches are flattened into individual rows.

        Args:
            mini_batches: list of dicts with ``'input_ids'`` and ``'labels'``.
            device: device on which the returned tensors are created.

        Returns:
            Dict of ``torch.long`` tensors of shape (rows, longest sequence):
            ``'input_ids'`` padded with the tokenizer's pad token id,
            ``'attention_mask'`` with 1 for real tokens and 0 for padding, and
            ``'labels'`` padded with -100.

        Raises:
            ValueError: if ``mini_batches`` contains no sequences.
        """
        sequences, labels = [], []
        for mini_batch in mini_batches:
            sequences.extend(self._as_sequences(mini_batch['input_ids']))
            labels.extend(self._as_sequences(mini_batch['labels']))
        if not sequences:
            raise ValueError("make_batch() requires at least one sequence")

        # Length of the longest sequence (not the sequence itself).
        batch_max_len = max(len(sequence) for sequence in sequences)
        pad_id = self.tokenizer.pad_token_id

        input_ids_padded = [
            sequence + [pad_id] * (batch_max_len - len(sequence))
            for sequence in sequences
        ]
        attention_mask_padded = [
            [1] * len(sequence) + [0] * (batch_max_len - len(sequence))
            for sequence in sequences
        ]
        labels_padded = [
            sequence_labels + [-100] * (batch_max_len - len(sequence_labels))
            for sequence_labels in labels
        ]

        target_device = torch.device(device)
        return {
            'input_ids': torch.tensor(input_ids_padded, dtype=torch.long, device=target_device),
            'attention_mask': torch.tensor(attention_mask_padded, dtype=torch.long, device=target_device),
            'labels': torch.tensor(labels_padded, dtype=torch.long, device=target_device)
        }

    def get_batch_loss(self, batch):
        """
        Run the model on a batch and return its loss.

        Args:
            batch: dict with ``'input_ids'``, ``'attention_mask'`` and
                ``'labels'`` as produced by ``make_batch``. The older key
                ``'label_ids'`` is still accepted when ``'labels'`` is absent.

        Returns:
            The first element of the model output, i.e. the loss.
        """
        labels = batch["labels"] if "labels" in batch else batch["label_ids"]
        loss = self.model(batch["input_ids"], labels=labels, attention_mask=batch["attention_mask"])[0]
        return loss

    def load(self, model_path):
        """
        Load model weights from a state dict saved with ``save``.

        Args:
            model_path: path to the saved state dict.
        """
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        # map_location lets a checkpoint saved on another device be restored here.
        state_dict = torch.load(model_path, map_location=getattr(self, 'device', None))
        self.model.load_state_dict(state_dict)

    def save(self, model_path):
        """Save the model's state dict to ``model_path``."""
        torch.save(self.model.state_dict(), model_path)


In [None]:
# Instantiate the wrapper with its default device ('cuda'). The constructor
# builds the GPT-2 model from a default GPT2Config (not from_pretrained) and
# loads the 'gpt2' tokenizer.
BidiLM = BidirectionalLM()

In [None]:
# Environment check: display the installed transformers and accelerate versions.
import accelerate
import transformers

transformers.__version__, accelerate.__version__

In [None]:
import torch

# A notebook cell only renders its final expression, so report both facts in a
# single tuple. The device name is only queried when CUDA is actually available,
# which keeps this cell from raising on CPU-only machines.
cuda_available = torch.cuda.is_available()
cuda_available, torch.cuda.get_device_name(0) if cuda_available else None

In [None]:
from transformers import Trainer, TrainingArguments

In [None]:
# Directory handed to TrainingArguments as the output directory.
training_output_dir = '../models/test2/'

args = TrainingArguments(
    output_dir=training_output_dir,
    group_by_length=True,  # bucketing
    per_device_train_batch_size=128,  # change to fit GPU specs
    per_device_eval_batch_size=128,
)

# Show which device the training arguments resolved to.
print(args.device)

In [None]:
# TODO: placeholder -- replace with the tokenized dataset before building the
# Trainer. The Trainer cell indexes this object with 'train' and 'val', so
# leaving it as None makes that cell fail.
tokenized_datasets = None

In [None]:
# Assemble the Trainer from the wrapper's model, tokenizer and collator,
# training on the 'train' split and evaluating on the 'val' split.
# NOTE(review): BidiLM.data_collator is DataCollatorForLanguageModeling(mlm=False);
# confirm it keeps the infilling labels this notebook intends to train on.
trainer = Trainer(
    args=args,
    model=BidiLM.model,
    tokenizer=BidiLM.tokenizer,
    data_collator=BidiLM.data_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['val'],
)

In [None]:
import os

from transformers.trainer_utils import get_last_checkpoint

# resume_from_checkpoint=True raises when the output directory holds no
# checkpoint yet (e.g. on the very first run). Resolve the latest checkpoint
# explicitly and fall back to training from scratch when there is none.
checkpoint_dir = trainer.args.output_dir
last_checkpoint = get_last_checkpoint(checkpoint_dir) if os.path.isdir(checkpoint_dir) else None
trainer.train(resume_from_checkpoint=last_checkpoint)