In [1]:
import torch
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel

from torch.utils.data import DataLoader, Dataset, SequentialSampler

import os
import logging

import time


In [2]:
def _first_token_positions(batch, token_id):
    """Return a (batch_size, 1) long tensor with the first index of ``token_id`` in each row.

    Raises:
        ValueError: if some row of ``batch`` does not contain ``token_id``.
    """
    positions = []
    for row_number, row in enumerate(batch.tolist()):
        try:
            positions.append(row.index(token_id))
        except ValueError:
            raise ValueError(
                "token id {} not found in batch row {}".format(token_id, row_number)
            ) from None
    return torch.tensor(positions, dtype=torch.long, device=batch.device).view(-1, 1)


def _tag_span_mask(batch, tokenizer, open_tag, close_tag):
    """Return a bool mask shaped like ``batch``, True from ``open_tag`` through ``close_tag`` inclusive."""
    # Each tag is looked up as the first id produced by the tokenizer for that string.
    open_id = tokenizer.encode(open_tag)[0]
    close_id = tokenizer.encode(close_tag)[0]
    begin = _first_token_positions(batch, open_id)
    end = _first_token_positions(batch, close_id)
    # (1, seq_len) against (batch_size, 1) broadcasts to (batch_size, seq_len).
    positions = torch.arange(batch.size(1), device=batch.device).view(1, -1)
    return (positions >= begin) & (positions <= end)


def tag_scale(batch, tokenizer, scale_method, scale_factor):
    """Build a per-token weight tensor based on the tagged spans in ``batch``.

    A span runs from the first occurrence of its opening tag token to the first
    occurrence of its closing tag token in the row, both tag positions included.

    Args:
        batch: 2-D tensor of token ids, one sequence per row.
        tokenizer: object whose ``encode(str)`` returns a sequence whose first
            element is the id of the given tag token.
        scale_method: one of
            ``'both_vs_other'``    - antecedent and anaphor spans get ``scale_factor``, the rest 1;
            ``'both_only'``        - antecedent and anaphor spans get 1, the rest 0;
            ``'anaphor_vs_other'`` - anaphor span gets ``scale_factor``, the rest 1;
            ``'anaphor_only'``     - anaphor span gets 1, the rest 0.
        scale_factor: weight used inside the spans for the ``*_vs_other`` methods.

    Returns:
        Floating-point tensor with the same shape as ``batch``.

    Raises:
        ValueError: if ``scale_method`` is not one of the values above, or if a
            required tag token is missing from some row.
    """
    if scale_method in ('both_vs_other', 'both_only'):
        in_span = (_tag_span_mask(batch, tokenizer, "<anteced>", "</anteced>")
                   | _tag_span_mask(batch, tokenizer, "<anaphor>", "</anaphor>"))
    elif scale_method in ('anaphor_vs_other', 'anaphor_only'):
        in_span = _tag_span_mask(batch, tokenizer, "<anaphor>", "</anaphor>")
    else:
        # Previously an unknown method fell through to `return res` with `res` unbound.
        raise ValueError("unknown scale_method: {!r}".format(scale_method))

    if scale_method.endswith('_vs_other'):
        # float() keeps both torch.where branches floating point even for an int factor.
        inside, outside = float(scale_factor), 1.
    else:
        inside, outside = 1., 0.

    return torch.where(in_span,
                       torch.tensor(inside, device=batch.device),
                       torch.tensor(outside, device=batch.device))
    
def my_scaled_collate(batch, tokenizer, scale_method, scale_factor):
    """Collate variable-length token tensors into one padded batch plus its masks.

    Args:
        batch: iterable of 1-D token-id tensors.
        tokenizer: must be a ``GPT2Tokenizer``; its EOS id is used as padding value.
        scale_method, scale_factor: forwarded unchanged to ``tag_scale``.

    Returns:
        ``(padded, padding_mask, lengths, scaling)`` where the sequences are
        ordered longest first, ``padding_mask`` is a float tensor that is 1 on
        real tokens and 0 on padding, ``lengths`` holds the unpadded lengths and
        ``scaling`` is the result of ``tag_scale`` on the padded batch.

    Raises:
        NotImplementedError: for any tokenizer that is not a ``GPT2Tokenizer``.
    """
    if not isinstance(tokenizer, GPT2Tokenizer):
        raise NotImplementedError
    pad_token_id = tokenizer.eos_token_id

    # Longest sequence first.
    by_length_desc = sorted(batch, key=lambda seq: seq.shape[0], reverse=True)
    padded = torch.nn.utils.rnn.pad_sequence(
        by_length_desc, batch_first=True, padding_value=pad_token_id)
    scaling = tag_scale(padded, tokenizer, scale_method, scale_factor)

    lengths = torch.LongTensor([len(seq) for seq in by_length_desc])
    # Column j of row i is a real token exactly when j < lengths[i].
    column_index = torch.arange(padded.shape[1]).unsqueeze(0)
    padding_mask = (column_index < lengths.unsqueeze(1)).type(torch.FloatTensor)

    return padded, padding_mask, lengths, scaling

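    # Basic for-loop version of the tag scaling (superseded by the vectorized code in tag_scale); kept commented out for reference.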
#     start = time.time()
#     scaling = []
#     for s in batch:
#         s_scaling = []
#         in_tag = False
#         for x in s:
#             if x == antec_beg_tok_id or x == anaph_beg_tok_id:
#                 in_tag = True
#             s_scaling.append(scale_factor if in_tag else 1. )
#             if x == antec_end_tok_id or x == anaph_end_tok_id:
#                 in_tag = False
#         scaling.append(s_scaling)
#     res = torch.tensor(scaling)
#     print("for loop time", time.time() - start)
#     return torch.tensor(scaling)

In [3]:
class TextDataset(Dataset):
    """Line-per-example dataset of token ids read from a UTF-8 text file.

    Every line of the file (including blank ones) becomes one example: the line
    is stripped, tokenized with ``tokenizer.tokenize``, mapped to ids with
    ``tokenizer.convert_tokens_to_ids`` and finally passed through
    ``tokenizer.build_inputs_with_special_tokens``.
    """

    def __init__(self, tokenizer, file_path):
        """Read and tokenize ``file_path``.

        Raises:
            FileNotFoundError: if ``file_path`` is not an existing regular file.
        """
        # Explicit check instead of `assert`, which is skipped under `python -O`.
        if not os.path.isfile(file_path):
            raise FileNotFoundError("Input file not found: {}".format(file_path))

        self.examples = []

        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line.strip()))
                self.examples.append(tokenizer.build_inputs_with_special_tokens(token_ids))

    def __len__(self):
        """Number of examples (one per line of the input file)."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return example ``item`` as a 1-D long tensor of token ids."""
        # Explicit dtype so that an empty example is still an integer tensor.
        return torch.tensor(self.examples[item], dtype=torch.long)

In [10]:
from torch.nn import CrossEntropyLoss

class AnagenGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT-2 LM head model whose training loss is a per-token scaled cross-entropy."""

    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                labels=None, lengths=None, scaling=None):
        """Run the transformer and LM head; optionally compute the scaled LM loss.

        Args:
            input_ids, past, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds: forwarded unchanged to ``self.transformer``.
            labels: target token ids, same shape as ``input_ids``. When given,
                the loss is computed and prepended to the outputs.
            lengths: 1-D integer tensor with the unpadded length of each row;
                required when ``labels`` is given.
                NOTE(review): the masking below assumes right-padded rows whose
                ``attention_mask`` is nonzero on exactly the first ``lengths[i]``
                positions - confirm against the collate function.
            scaling: per-token loss weights, same shape as ``labels``. ``None``
                means every token has weight 1.

        Returns:
            ``(lm_logits,) + transformer_outputs[1:]``, with the scalar loss
            prepended when ``labels`` is not None.

        Raises:
            ValueError: if ``labels`` is given without ``attention_mask`` or ``lengths``.
        """
        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask,
                                               inputs_embeds=inputs_embeds)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        # NOTE(review): debug print on every forward pass; consider logging.debug instead.
        print("lm_logits.shape", lm_logits.shape)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            if attention_mask is None or lengths is None:
                raise ValueError("attention_mask and lengths are required when labels are given")

            # .bool() returns the very same tensor when the mask is already bool, so
            # clone before editing in place: otherwise the caller's mask would be
            # mutated and the two masks below would alias each other.
            logits_mask = attention_mask.bool().clone()
            row_index = torch.arange(lm_logits.size(0), device=lengths.device)
            # The logit at the last real position of each row has no target; drop it.
            logits_mask[row_index, lengths - 1] = False

            labels_mask = attention_mask.bool().clone()
            # The first token of each row is never a prediction target; drop it.
            labels_mask[:, 0] = False

            # Flatten, then keep only the selected positions. Each row contributes
            # the same number of logits and labels, so after filtering the logit at
            # position t lines up with the label at position t + 1.
            loss_fct = CrossEntropyLoss(ignore_index=-1, reduction="none")
            flat_logits_mask = logits_mask.reshape(-1)
            flat_labels_mask = labels_mask.reshape(-1)
            filtered_logits = lm_logits.view(-1, lm_logits.size(-1))[flat_logits_mask]
            filtered_labels = labels.reshape(-1)[flat_labels_mask]
            losses = loss_fct(filtered_logits, filtered_labels)

            if scaling is None:
                loss = torch.mean(losses)
            else:
                # Weights are taken at the label positions; the mean runs over all
                # kept positions, including those whose weight is 0.
                filtered_scaling = scaling.reshape(-1)[flat_labels_mask]
                loss = torch.mean(losses * filtered_scaling)

            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)

In [11]:
# Checkpoint directory. Set the ANAGEN_MODEL_DIR environment variable to point
# somewhere else; the default is the original hard-coded location.
model_dir = os.environ.get("ANAGEN_MODEL_DIR",
                           "/home/hansonlu/links/data/anagen_models/anagen_b28_model")
tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
model = AnagenGPT2LMHeadModel.from_pretrained(model_dir)
# Stock GPT-2 LM head model loaded from the same checkpoint, kept alongside the custom one.
old_model = GPT2LMHeadModel.from_pretrained(model_dir)

# Register the span tags as special tokens, then resize both embedding
# matrices to the tokenizer's vocabulary size.
special_tokens_dict = {'additional_special_tokens': ['<anteced>', '</anteced>', '<anaphor>', '</anaphor>']}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
old_model.resize_token_embeddings(len(tokenizer))

Embedding(50261, 768)

In [12]:
# Small sequential dataloader over the dummy file, three examples per batch.
train_batch_size = 3
train_dataset = TextDataset(tokenizer, file_path="data/dummy.txt")
train_sampler = SequentialSampler(train_dataset)


def collate_fn(b):
    """Collate with the 'anaphor_only' weighting (scale factor 1)."""
    return my_scaled_collate(b, tokenizer, 'anaphor_only', 1)


train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    sampler=train_sampler,
    collate_fn=collate_fn,
)

# Clear any existing gradients on both models.
model.zero_grad()
old_model.zero_grad()

In [13]:
# Forward pass of the custom model over every batch (no optimizer step yet).
for step, (batch, attention_mask, lengths, scaling) in enumerate(train_dataloader):
    print(*(t.shape for t in (batch, attention_mask, lengths, scaling)))
    # Language modelling: the inputs double as the labels.
    inputs = labels = batch
    model.train()
    outputs = model(
        inputs,
        attention_mask=attention_mask,
        labels=labels,
        lengths=lengths,
        scaling=scaling,
    )

    # Timing comparison against the stock model (disabled):
    #     old_start = time.time()
    #     old_outputs = old_model(inputs, attention_mask=attention_mask, labels=labels)
    #     print("old time", time.time() - old_start)
    # TODO: need to try optimizer

torch.Size([3, 73]) torch.Size([3, 73]) torch.Size([3]) torch.Size([3, 73])
lm_logits.shape torch.Size([3, 73, 50261])


In [None]:
# ctx_strs = ["With <anteced> their </anteced> unique charm , <anaphor>",
#             "The world 's fifth <anteced> Disney </anteced> park will soon open to the public here . The most important thing about <anaphor>"]

# examples = []
# for l in ctx_strs:
#     ctx_toks = tokenizer.encode(l, add_special_tokens=False)
#     examples.append(l)
#     print(ctx_toks)

# ctx = torch.tensor(ctx_toks, dtype=torch.long)
# print(ctx)
# ctx = ctx.unsqueeze(0).repeat(3, 1)
# print(ctx)

# inputs = {"input_ids": ctx}

# outputs = model(**inputs)
# print(outputs[0].shape)
# next_token_logits = outputs[0][:, -1, :]
# print(next_token_logits.shape)


In [None]:
0.0002567768096923828
0.005959749221801758