In [23]:
import os, math
import torch
import wandb
from torch import nn
from datasets import load_dataset
from transformers import ElectraTokenizer, TextDataset, ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, \
    set_seed, DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizer, Trainer, ElectraTokenizerFast, \
    TrainingArguments, EvaluationStrategy
from transformers.tokenization_utils_base import PaddingStrategy

In [24]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33megordm[0m (use `wandb login --relogin` to force relogin)


True

In [25]:

tokenizer_custom = {
    '@HTAG': '[HTAG]',
    '@USR': '[USR]',
    '@CURR': '[CURR]',
    '@EMOJI': '[EMOJI]',
    '@URL': '[URL]',
    '@TIME': '[TIME]',
    '@DATE': '[DATE]',
    '@NUM': '[NUM]'
}

DATASET_DIR = '../../data/bitcoin_twitter_corpus'
VOCAB_FILE = '../../data/vocab/bitcoin_twitter/bitcoin_twitter-vocab.txt'
TRAIN_DS = os.path.join(DATASET_DIR, 'train.tokens')
TEST_DS = os.path.join(DATASET_DIR, 'test.tokens')
VALIDATE_DS = os.path.join(DATASET_DIR, 'validate.tokens')

model_path = './bitcoin_twitter'
seq_length = 256
accum_multipler = 1
batch_size = 128
epochs = 1
warmup_ratio = 0.06
lr = 5e-4
vocab_size = 16537
block_size = 200
seed = 1337

set_seed(seed)

In [26]:
tokenizer = ElectraTokenizerFast(vocab_file=VOCAB_FILE)
tokenizer.add_special_tokens({
    'additional_special_tokens': list(tokenizer_custom.values())
})

assert tokenizer.vocab_size == vocab_size

In [27]:
class CombinedModel(nn.Module):
    def __init__(self, discriminator: PreTrainedModel, generator: PreTrainedModel, tokenizer: PreTrainedTokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.discriminator = discriminator
        self.generator = generator

        # Embeddings are shared
        self.discriminator.set_input_embeddings(self.generator.get_input_embeddings())

    @staticmethod
    def mask_inputs(
            input_ids: torch.Tensor,
            mask_token_id, mask_probability,
            tokens_to_ignore,
            max_predictions_per_seq,
            proposal_distribution=1.0
    ):
        inputs_which_can_be_masked = torch.ones_like(input_ids)
        for token in tokens_to_ignore:
            inputs_which_can_be_masked -= torch.eq(input_ids, token).long()

        total_number_of_tokens = input_ids.shape[-1]

        # Identify the number of tokens to be masked, which should be: 1 < num < max_predictions per seq.
        # It is set to be: n_tokens * mask_probability, but is truncated if it goes beyond bounds.
        number_of_tokens_to_be_masked = torch.max(
            torch.tensor(1),
            torch.min(
                torch.tensor(max_predictions_per_seq),
                torch.tensor(total_number_of_tokens * mask_probability, dtype=torch.long)
            )
        )

        # The probability of each token being masked
        sample_prob = proposal_distribution * inputs_which_can_be_masked
        sample_prob /= torch.sum(sample_prob)
        # Should be passed through a log function here

        # Weight of each position: 1 the position will be masked, 0 the position won't be masked
        masked_lm_weights = torch.tensor([0] * max_predictions_per_seq, dtype=torch.bool)
        masked_lm_weights[:number_of_tokens_to_be_masked] = True

        # Sample from the probabilities
        masked_lm_positions = sample_prob.multinomial(max_predictions_per_seq)

        # Apply the weights to the positions
        masked_lm_positions *= masked_lm_weights.long()

        # Gather the IDs from the positions
        masked_lm_ids = input_ids.gather(-1, masked_lm_positions)

        # Apply weights to the IDs
        masked_lm_ids *= masked_lm_weights.long()

        replace_with_mask_positions = masked_lm_positions * (torch.rand(masked_lm_positions.shape) < 0.85)

        # Replace the input IDs with masks on given positions
        masked_input_ids = input_ids.scatter(-1, replace_with_mask_positions, mask_token_id)

        # Updates to index 0 should be ignored
        masked_input_ids[..., 0] = input_ids[..., 0]

        return masked_input_ids, masked_lm_positions

    @staticmethod
    def gather_positions(
            sequence,
            positions
    ):
        batch_size, sequence_length, dimension = sequence.shape
        position_shift = (sequence_length * torch.arange(batch_size)).unsqueeze(-1)
        flat_positions = torch.reshape(positions + position_shift, [-1]).long()
        flat_sequence = torch.reshape(sequence, [batch_size * sequence_length, dimension])
        gathered = flat_sequence.index_select(0, flat_positions)
        return torch.reshape(gathered, [batch_size, -1, dimension])

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None
    ):
        masked_input_ids, masked_lm_positions = self.mask_inputs(
            input_ids,  self.tokenizer.mask_token_id, 0.2,
            [self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.mask_token_id],
            30
        )

        generator_loss, generator_output = self.generator(
            masked_input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            position_ids,
            labels=labels
        )[:2]

        fake_logits = self.gather_positions(generator_output, masked_lm_positions)
        fake_argmaxes = fake_logits.argmax(-1)
        fake_tokens = masked_input_ids.scatter(-1, masked_lm_positions, fake_argmaxes)
        fake_tokens[:, 0] = input_ids[:, 0]

        # discriminator_output
        discriminator_loss, discriminator_output = self.discriminator(
            fake_tokens,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            position_ids,
            labels=labels
        )[:2]

        discriminator_predictions = torch.round((torch.sign(discriminator_output) + 1) / 2).int().tolist()

        total_loss = discriminator_loss + generator_loss
        return (
            total_loss,
            (discriminator_predictions, generator_output),
            (fake_tokens, masked_input_ids)
        )

    def save_pretrained(self, directory):
        generator_path = os.path.join(directory, "generator")
        discriminator_path = os.path.join(directory, "discriminator")

        if not os.path.exists(generator_path):
            os.makedirs(generator_path)

        if not os.path.exists(discriminator_path):
            os.makedirs(discriminator_path)

        self.generator.save_pretrained(generator_path)
        self.discriminator.save_pretrained(discriminator_path)

In [28]:
generator_config = ElectraConfig(
    embedding_size=128,
    hidden_size = 256,
    intermediate_size = 1024,
    max_position_embeddings=seq_length,
    num_attention_heads=4,
    num_hidden_layers=12,
    vocab_size=vocab_size,
)

discriminator_config = ElectraConfig(
    embedding_size=128,
    hidden_size=256,
    intermediate_size=1024,
    max_position_embeddings=seq_length,
    num_attention_heads=4,
    num_hidden_layers=12,
    vocab_size=vocab_size,
)

In [29]:
generator = ElectraForMaskedLM(config=generator_config)
discriminator = ElectraForPreTraining(config=discriminator_config)
model = CombinedModel(discriminator, generator, tokenizer)
# wandb.watch(model)

In [32]:
dataset = load_dataset("text", data_files={
    'train': TEST_DS,
    # 'train': TRAIN_DS,
    # 'test': TEST_DS,
    'validate': VALIDATE_DS
}, cache_dir='./cache')

Using custom data configuration default-73d35a8d0d4c2a4e


Downloading and preparing dataset text/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to ./cache/text/default-73d35a8d0d4c2a4e/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691...
Dataset text downloaded and prepared to ./cache/text/default-73d35a8d0d4c2a4e/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

In [33]:
def tokenize_function(examples):
    return tokenizer(
        examples['text'], truncation=True,
        padding=PaddingStrategy.MAX_LENGTH, max_length=seq_length
    )

tokenized_datasets = dataset.map(tokenize_function, batched=True)

HBox(children=(FloatProgress(value=0.0, max=1198.0), HTML(value='')))





HBox(children=(FloatProgress(value=0.0, max=1198.0), HTML(value='')))

In [34]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, mlm_probability=0
)

In [35]:
arguments = TrainingArguments(
    output_dir=model_path,
    do_train=True,
    evaluation_strategy=EvaluationStrategy.STEPS,
    eval_steps = 50000,
    prediction_loss_only=True,
    learning_rate=lr,
    report_to=["wandb"],
    load_best_model_at_end=True,
    num_train_epochs=20
)

# Initialize our Trainer
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validate'],
)

In [36]:
# Training
trainer.train(model_path=model_path)
trainer.save_model()

  torch.tensor(total_number_of_tokens * mask_probability, dtype=torch.long)


Step,Training Loss


KeyboardInterrupt: 

In [None]:
# Evaluation
results = {}
print("*** Evaluate ***")

eval_output = trainer.evaluate()

perplexity = math.exp(eval_output["loss"])
result = {"perplexity": perplexity}

output_eval_file = "eval_results_lm.txt"
with open(output_eval_file, "w") as writer:
    print("***** Eval results *****")
    for key in sorted(result.keys()):
        print("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))

results.update(result)