In [1]:
import torch
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel

from torch.utils.data import DataLoader, Dataset, SequentialSampler

import os
import logging



In [2]:
def my_collate(batch, tokenizer):
    """Pad a batch of 1-D token-id tensors and build the matching padding mask.

    The sequences are ordered longest-first, right-padded to the length of the
    longest one, and returned together with a float mask that is 1.0 on real
    tokens and 0.0 on padded positions.

    Args:
        batch: list of 1-D tensors of token ids, one per example.
        tokenizer: a ``GPT2Tokenizer`` (padding then uses its ``eos_token_id``),
            or any other object exposing a non-None ``pad_token_id``.

    Returns:
        A ``(padded, padding_mask)`` tuple, both of shape
        ``(len(batch), max_len)``. ``padding_mask`` is a float tensor placed on
        the same device as ``padded``.

    Raises:
        NotImplementedError: if no padding id can be determined for ``tokenizer``.
    """
    if isinstance(tokenizer, GPT2Tokenizer):
        # Unchanged behaviour for GPT-2: pad with the end-of-sequence id.
        pad_token_id = tokenizer.eos_token_id
    else:
        # Other tokenizers are accepted as long as they declare a padding id.
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            raise NotImplementedError

    sorted_batch = sorted(batch, key=lambda b: b.shape[0], reverse=True)
    padded = torch.nn.utils.rnn.pad_sequence(sorted_batch, batch_first=True,
                                             padding_value=pad_token_id)

    # Build the helper tensors on the device of the padded batch so the mask
    # never ends up on a different device than the ids it describes.
    device = padded.device
    lengths = torch.tensor([len(x) for x in sorted_batch],
                           dtype=torch.long, device=device)
    positions = torch.arange(padded.shape[1], device=device)
    # Row i is 1.0 for the first lengths[i] columns and 0.0 afterwards.
    padding_mask = (positions[None, :] < lengths[:, None]).float()

    return padded, padding_mask

In [3]:
class TextDataset(Dataset):
    """Dataset that yields one tensor of token ids per non-blank line of a text file.

    Each line is stripped, tokenized with ``tokenizer.tokenize``, mapped to ids
    with ``tokenizer.convert_tokens_to_ids`` and passed through
    ``tokenizer.build_inputs_with_special_tokens``. The whole file is processed
    eagerly in ``__init__`` and kept in ``self.examples`` as lists of ints.
    """

    def __init__(self, tokenizer, file_path):
        """Read and tokenize ``file_path``.

        Args:
            tokenizer: object providing ``tokenize``, ``convert_tokens_to_ids``
                and ``build_inputs_with_special_tokens``.
            file_path: path to a UTF-8 text file with one example per line.

        Raises:
            AssertionError: if ``file_path`` is not an existing regular file.
        """
        assert os.path.isfile(file_path), \
            "Input file not found: {}".format(file_path)

        self.examples = []

        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                text = line.strip()
                if not text:
                    # A blank line would become a zero-length example, which
                    # carries no tokens to train on; skip it.
                    continue
                token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
                self.examples.append(tokenizer.build_inputs_with_special_tokens(token_ids))

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return example ``item`` as a 1-D ``torch.long`` tensor of token ids."""
        # Explicit dtype: torch.tensor([]) would otherwise default to a float
        # tensor, which cannot be used as embedding indices.
        return torch.tensor(self.examples[item], dtype=torch.long)

In [27]:
from torch.nn import CrossEntropyLoss

class AnagenGPT2LMHeadModel(GPT2LMHeadModel):
    """``GPT2LMHeadModel`` variant whose forward pass prints the tensors feeding the LM loss."""

    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                labels=None):
        """Run the transformer and LM head, optionally computing a next-token loss.

        All arguments except ``labels`` are forwarded unchanged to
        ``self.transformer``. When ``labels`` is given, logits at position t are
        scored against the label at position t + 1 with a cross-entropy loss
        that ignores label value -1, and the shapes involved (plus the first
        ten flattened labels) are printed.

        Returns:
            Tuple ``(loss), lm_logits, presents, (all hidden_states), (attentions)``;
            ``loss`` is present only when ``labels`` is not None.
        """
        transformer_kwargs = dict(
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        transformer_outputs = self.transformer(input_ids, **transformer_kwargs)

        # First element holds the hidden states; project them onto the vocabulary.
        lm_logits = self.lm_head(transformer_outputs[0])
        outputs = (lm_logits,) + transformer_outputs[1:]

        if labels is None:
            return outputs  # lm_logits, presents, (all hidden_states), (attentions)

        # Drop the last logit and the first label so tokens < n predict n.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        print("shift_logits.shape", shift_logits.shape)
        print("shift_labels.shape", shift_labels.shape)

        # Collapse every leading dimension so each row is one prediction.
        vocab_size = shift_logits.size(-1)
        flat_logits = shift_logits.view(-1, vocab_size)
        print("loss1.shape", flat_logits.shape)
        flat_labels = shift_labels.view(-1)
        print("loss2.shape", flat_labels.shape)

        criterion = CrossEntropyLoss(ignore_index=-1)
        loss = criterion(flat_logits, flat_labels)

        print("loss2[:10]", flat_labels[:10])

        return (loss,) + outputs  # loss, lm_logits, presents, (all hidden_states), (attentions)

In [28]:
# Directory holding the saved tokenizer and model weights. Set the
# ANAGEN_MODEL_DIR environment variable to point somewhere else; the previous
# hard-coded location remains the default so existing setups keep working.
model_dir = os.environ.get(
    "ANAGEN_MODEL_DIR",
    "/home/hansonlu/links/data/anagen_models/anagen_b28_model",
)
tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
model = AnagenGPT2LMHeadModel.from_pretrained(model_dir)

# Register the antecedent/anaphor marker strings as special tokens of the tokenizer.
special_tokens_dict = {'additional_special_tokens': ['<anteced>', '</anteced>', '<anaphor>', '</anaphor>']}
tokenizer.add_special_tokens(special_tokens_dict)
# Resize the embedding matrix to the tokenizer's current vocabulary size.
# Kept as the last expression so the cell displays the resulting embedding.
model.resize_token_embeddings(len(tokenizer))

Embedding(50261, 768)

In [29]:
# Tokenize data/dummy.txt and serve it in file order, train_batch_size examples at a time.
train_batch_size = 3
train_dataset = TextDataset(tokenizer, file_path="data/dummy.txt")
train_sampler = SequentialSampler(train_dataset)


def collate_fn(b):
    """Collate a list of examples with my_collate, using the notebook-level tokenizer."""
    return my_collate(b, tokenizer)


train_dataloader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=train_batch_size,
    collate_fn=collate_fn,
)

# Clear any gradients accumulated on the model before running batches through it.
model.zero_grad()

In [31]:
# Smoke test: push a single batch through the model in training mode.

# Label value skipped by the loss; must match CrossEntropyLoss(ignore_index=-1)
# in AnagenGPT2LMHeadModel.forward.
ignore_index = -1

# Switching to training mode does not depend on the batch, so do it once up front.
model.train()

for step, (batch, attention_mask) in enumerate(train_dataloader):
    print(batch.shape)
    inputs = batch
    # my_collate pads with the EOS id, so padded positions cannot be told apart
    # from real tokens by value. Use the mask to blank them out of the labels;
    # otherwise the loss would also be computed on padding.
    labels = batch.masked_fill(attention_mask == 0, ignore_index)
    outputs = model(inputs, attention_mask=attention_mask, labels=labels)
    break  # one batch is enough to inspect the shapes printed by the model

torch.Size([3, 73])
tensor([17947,  2873,  6379, 14306,   319,   262,  3878,  3668,   286,  2807])
shift_logits.shape torch.Size([3, 72, 50261])
shift_labels.shape torch.Size([3, 72])
loss1.shape torch.Size([216, 50261])
loss2.shape torch.Size([216])
loss2[:10] tensor([ 2873,  6379, 14306,   319,   262,  3878,  3668,   286,  2807,  1058])


In [None]:
# Inspect next-token logits for a hand-written context ending in "<anaphor>".
ctx_strs = ["With <anteced> their </anteced> unique charm , <anaphor>",
            "The world 's fifth <anteced> Disney </anteced> park will soon open to the public here . The most important thing about <anaphor>"]

examples = []
for ctx_str in ctx_strs:
    ctx_toks = tokenizer.encode(ctx_str, add_special_tokens=False)
    # NOTE(review): this stores the raw string, not ctx_toks, and `examples` is
    # not read again in this cell — confirm whether token ids were intended.
    examples.append(ctx_str)
    print(ctx_toks)

# NOTE(review): only the token ids of the LAST context survive the loop above;
# the batch below is that single context repeated three times.
ctx = torch.tensor(ctx_toks, dtype=torch.long)
print(ctx)
ctx = ctx.unsqueeze(0).repeat(3, 1)
print(ctx)

inputs = {"input_ids": ctx}

# An earlier cell leaves the model in training mode. Switch to eval mode so
# dropout is disabled and the logits are deterministic, and skip building an
# autograd graph since nothing is back-propagated here.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
print(outputs[0].shape)
# Logits at the final position of every row, i.e. the scores for the next token.
next_token_logits = outputs[0][:, -1, :]
print(next_token_logits.shape)
