In [1]:
import orjson
import transformers
import torch
import wandb
import time
import random

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from tqdm.notebook import tqdm

# Mixed-precision switch. When True, the apex package must be installed because `amp` is imported right here;
# the later cells use `amp` for model/optimizer initialization and loss scaling.
fp16 = True
if fp16:
    from apex import amp

# Pretrained checkpoint name handed to from_pretrained() for the tokenizer, config and model.
model_name = "xlnet-base-cased"

In [2]:
def load_dataset(input_file, _max_seq_len):
    """Load a pre-tokenized JSON file into one TensorDataset per like-length group.

    The root of the JSON document is a list of groups; each group is a list of feature dicts. Every feature
    has a 'text' entry (with 'input_ids', 'attention_mask' and 'token_type_ids') and a 'title' entry (with
    'input_ids'). Sequences inside one group must share a length so they can be stacked into tensors.

    Args:
        input_file: Path of the JSON file to read.
        _max_seq_len: Chunk length used by the trainer; only used here to count chunks.

    Returns:
        A tuple ``(datasets, count)``. ``datasets`` holds one
        ``TensorDataset(input_ids, attention_mask, token_type_ids, title_input_ids)`` per group. ``count`` is
        an int: the number of (sequence, chunk) pairs across all groups, using the same floor division the
        trainer applies when it splits a sequence into chunks.
    """
    # Read through a context manager so the file handle is released as soon as the bytes are loaded.
    with open(input_file, "rb") as _fh:
        data = orjson.loads(_fh.read())
    # Root data is a list of lists of features. The first-order list organizes the sequences into sets of
    # like-length that can be batched together.
    datasets = []
    count = 0
    for features in data:
        # Each feature is a dictionary of a 'text' sequence and a 'title' sequence. The goal of this model is to
        # predict the 'title' given the 'text'. Process them out together, the model trainer will do the rest of
        # the work.
        input_ids = torch.tensor([f['text']['input_ids'] for f in features], dtype=torch.long)
        attention_mask = torch.tensor([f['text']['attention_mask'] for f in features], dtype=torch.float)
        token_type_ids = torch.tensor([f['text']['token_type_ids'] for f in features], dtype=torch.long)
        title_input_ids = torch.tensor([f['title']['input_ids'] for f in features], dtype=torch.long)
        datasets.append(TensorDataset(input_ids, attention_mask, token_type_ids, title_input_ids))
        # This trainer "chunks" the dataset lower and trains per-chunk. Floor division keeps the count a whole
        # number that matches the number of chunks the trainer actually produces per sequence.
        _num_chunks = input_ids.shape[-1] // _max_seq_len
        count += len(features) * _num_chunks
    return datasets, count

# Read the training and validation splits from the preprocessing output folder. The chunk length handed to
# load_dataset only affects the returned chunk counts.
input_folder = "C:\\Users\\jbetk\\Documents\\data\\ml\\title_prediction\\outputs\\"
train_datasets, total_train_data_sz = load_dataset(
    input_file=input_folder + "processed.json",
    _max_seq_len=128,
)
val_datasets, total_val_data_sz = load_dataset(
    input_file=input_folder + "validation.json",
    _max_seq_len=128,
)

In [3]:
EPOCHS = 1
BATCH_SIZE = 4
CHUNK_SEQ_LEN = 128
TITLE_PRED_MAX_LEN = 64

# Load model
tokenizer = transformers.XLNetTokenizer.from_pretrained(model_name)
config = transformers.XLNetConfig.from_pretrained(model_name)
# Enable XLNet's memory cache; the training loop feeds the mems returned for one chunk into the next chunk.
config.mem_len = 1024
model = transformers.XLNetLMHeadModel.from_pretrained(model_name, config=config)
device = torch.device("cuda")
cpu = torch.device("cpu")

# Parameters whose names contain one of these substrings are placed in the second optimizer group.
# NOTE(review): both groups currently use a weight decay of 0, so the split has no effect. Confirm whether a
# non-zero decay was intended for the first group.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0,
    },
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)

# The scheduler is stepped once per chunk of every *batch*, so the schedule length has to be counted in
# optimizer steps: (batches in a dataset group) x (chunks per sequence), summed over the groups.
# total_train_data_sz counts chunks per individual sequence, which overstates the step count by roughly a
# factor of BATCH_SIZE and would keep the linear schedule from ever decaying to zero.
steps_per_epoch = sum(
    -(-len(_ds) // BATCH_SIZE) * (_ds.tensors[0].shape[-1] // CHUNK_SEQ_LEN)  # ceil(batches) * chunks
    for _ds in train_datasets
)
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=max(1, EPOCHS * steps_per_epoch),
)

# Shift model to cuda & enable fp16 if applicable.
model.to(device)
if fp16:
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Initialize w&b logger
do_wandb = True
if do_wandb:
    wandb.init(
        project="nonint-transformers-torch",
        name="xlnet_title_prediction",
        config={"dataset": "title_pred"},
    )
    # There's something bugged about this, but it doesnt really seem to do much anyways. Apparently it enables some
    # sort of gradient exploration map.
    #wandb.watch(model)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


wandb: Wandb version 0.8.29 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [None]:
# Wall-clock durations in seconds, one entry per measured operation. The training helpers append to these
# module-level lists.
preprocess_times = []
xfer_times = []
forward_times = []
backward_times = []
opt_times = []
sched_times = []

def clear_timers():
    """Empty every timing list in place.

    The lists are cleared rather than rebound so other references to them stay valid. preprocess_times is
    included so it resets together with the other timers instead of growing across epochs.
    """
    for _timings in (preprocess_times, xfer_times, forward_times, backward_times, opt_times, sched_times):
        _timings.clear()

def prepare_chunked_inputs(_batched_inputs, _max_seq_len, _max_title_len):
    """Split a batch into fixed-width chunks and append a masked title region to each chunk.

    Args:
        _batched_inputs: Sequence of tensors ``[input_ids, attention_mask, token_type_ids, labels]``. The first
            three are sliced along their last dimension; the fourth is returned untouched.
        _max_seq_len: Width of every chunk. The last dimension of the inputs must be a positive multiple of it.
        _max_title_len: Number of mask tokens appended to every chunk for the title prediction.

    Returns:
        A tuple ``(chunks, target_mapping, labels)``. ``chunks`` is a list with one
        ``[input_ids, attention_mask, token_type_ids]`` entry per chunk, each ``_max_seq_len + _max_title_len``
        wide. ``target_mapping`` is a float tensor of shape
        ``(batch, _max_title_len, _max_seq_len + _max_title_len)`` with a 1 at ``[b, i, _max_seq_len + i]``.

    Raises:
        ValueError: If the sequence length is not a positive multiple of ``_max_seq_len``. Chunks of any other
            width would not line up with ``target_mapping``.
    """
    _seq_len = _batched_inputs[0].shape[-1]
    if _max_seq_len <= 0 or _seq_len < _max_seq_len or _seq_len % _max_seq_len != 0:
        raise ValueError(
            "Sequence length %i must be a positive multiple of the chunk length %i." % (_seq_len, _max_seq_len)
        )

    # We need to do a lot more data preparation before feeding into the model.
    with torch.no_grad():
        _batch_sz = _batched_inputs[0].shape[0]
        _nr_chunks = _seq_len // _max_seq_len
        # First, cut each of the three model inputs into pieces exactly _max_seq_len wide.
        _chunked_batch_inputs_by_input_nr = [
            torch.split(_bi, _max_seq_len, dim=-1) for _bi in _batched_inputs[0:3]
        ]
        _chunked_batch_inputs = []

        # These tensors will be used to append on to the input tensors where the prediction will occur.
        _input_mask_tensor = torch.full((_batch_sz, _max_title_len), tokenizer.mask_token_id, dtype=torch.long)
        _ones_float_tensor_for_title = torch.ones((_batch_sz, _max_title_len), dtype=torch.float)
        _ones_long_tensor_for_title = torch.ones((_batch_sz, _max_title_len), dtype=torch.long)
        for i in range(_nr_chunks):
            # For the input_ids (index 0), append on _max_title_len masks.
            _chunked_input_ids = torch.cat([_chunked_batch_inputs_by_input_nr[0][i], _input_mask_tensor], dim=-1)
            # For the attention mask, just add all 1s because this is not padding.
            _chunked_attention_mask = torch.cat(
                [_chunked_batch_inputs_by_input_nr[1][i], _ones_float_tensor_for_title], dim=-1
            )
            # For token type IDs, also all 1s since this is the "second sentence".
            _chunked_token_type_ids = torch.cat(
                [_chunked_batch_inputs_by_input_nr[2][i], _ones_long_tensor_for_title], dim=-1
            )
            _chunked_batch_inputs.append([_chunked_input_ids, _chunked_attention_mask, _chunked_token_type_ids])

        # Create a target mapping that will be used for all inputs, since they all follow a similar format:
        # prediction slot i maps to the i-th appended mask position. Index tensors set every
        # (batch, i, _max_seq_len + i) entry in one assignment instead of a Python double loop.
        _target_mapping = torch.zeros((_batch_sz, _max_title_len, _max_seq_len + _max_title_len), dtype=torch.float)
        _title_positions = torch.arange(_max_title_len)
        _target_mapping[:, _title_positions, _max_seq_len + _title_positions] = 1

        # Next, gather the expected output IDs and generate the 'labels' format that transformers is expecting.
        _labels = _batched_inputs[3]
    return _chunked_batch_inputs, _target_mapping, _labels

def chunk_to_inputs(_chunk, _target_mapping, _labels, _mems, _device):
    """Assemble the keyword arguments for one forward pass over a single chunk.

    Args:
        _chunk: Indexable holding ``input_ids``, ``attention_mask`` and ``token_type_ids`` at positions 0-2.
        _target_mapping: Tensor passed to the model as ``target_mapping``.
        _labels: Tensor passed as ``labels``, or None to leave that key out.
        _mems: Value passed through as ``mems`` without a device transfer, or None to leave that key out.
        _device: Device every tensor other than ``_mems`` is moved to.

    Returns:
        A dict of model keyword arguments. The time spent on the device transfers is appended to the
        module-level ``xfer_times`` list.
    """
    _host_tensors = {
        "input_ids": _chunk[0],
        "attention_mask": _chunk[1],
        "token_type_ids": _chunk[2],
        "target_mapping": _target_mapping,
    }
    if _labels is not None:
        _host_tensors["labels"] = _labels

    # Move everything gathered so far onto the target device and record how long the transfers took.
    _transfer_start = time.time()
    _inputs = {_name: _tensor.to(_device) for _name, _tensor in _host_tensors.items()}
    xfer_times.append(time.time() - _transfer_start)

    # Mems are handed over as-is and added last, after the timed transfer.
    if _mems is not None:
        _inputs["mems"] = _mems
    return _inputs

def train_epoch(_model, _optimizer, _scheduler, _device, _dataloader, _max_seq_len, _max_title_len, _fp16):
    """Train for one pass over ``_dataloader``, taking an optimizer and scheduler step after every chunk.

    Each batch is split into chunks by prepare_chunked_inputs. The chunks are run in order and the mems
    returned for one chunk are fed into the next. After the last chunk of a batch, that chunk's loss and the
    current learning rate are logged to wandb when ``do_wandb`` is set.

    Args:
        _model: Model to train; called with the keyword arguments built by chunk_to_inputs and expected to
            return ``(loss, logits, mems)``.
        _optimizer: Optimizer stepped after every chunk.
        _scheduler: Learning-rate scheduler stepped after every chunk.
        _device: Device the inputs are moved to.
        _dataloader: Iterable of batches accepted by prepare_chunked_inputs.
        _max_seq_len: Chunk width.
        _max_title_len: Number of title positions predicted per chunk.
        _fp16: When True, the backward pass goes through apex ``amp.scale_loss`` and gradient clipping is
            applied to ``amp.master_params``.
    """
    clear_timers()

    _epoch_iterator = tqdm(_dataloader, desc="Iteration")
    _model.train()

    for _batch in _epoch_iterator:
        _t0 = time.time()
        _chunked_batch_inputs, _target_mapping, _labels = prepare_chunked_inputs(_batch, _max_seq_len, _max_title_len)
        _num_chunks = len(_chunked_batch_inputs)
        preprocess_times.append(time.time() - _t0)

        _mems = None
        _loss = None
        for _chunk in _chunked_batch_inputs:
            _inputs = chunk_to_inputs(_chunk, _target_mapping, _labels, _mems, _device)

            # Forward
            _t0 = time.time()
            _loss, _logits, _mems = _model(**_inputs)
            forward_times.append(time.time() - _t0)

            # Backwards. Branch on the _fp16 argument (not the notebook-level flag) so this agrees with the
            # gradient-clipping branch below.
            _t0 = time.time()
            if _fp16:
                with amp.scale_loss(_loss, _optimizer) as _scaled_loss:
                    _scaled_loss.backward()
                    _backward_time = time.time() - _t0
            else:
                _loss.backward()
                _backward_time = time.time() - _t0
            backward_times.append(_backward_time)

            # Update weights
            if _fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(_optimizer), 1)
            else:
                torch.nn.utils.clip_grad_norm_(_model.parameters(), 1)
            _t0 = time.time()
            _optimizer.step()
            opt_times.append(time.time() - _t0)
            _t0 = time.time()
            _scheduler.step()
            sched_times.append(time.time() - _t0)
            _model.zero_grad()

        # Always log the loss of the last chunk, where it should be lowest. That's the goal of this model.
        # The metric name carries the chunk count so batches of different lengths land in separate series.
        _logs = {
            "loss_" + str(_num_chunks): _loss.item(),
            "learning_rate": _scheduler.get_lr()[0],
        }
        if do_wandb:
            wandb.log(_logs)


def validate_epoch(_model, _device, _dataloader, _max_seq_len, _max_title_len, _max_batches=1):
    """Compute the last-chunk loss over batches from ``_dataloader`` without updating the model.

    The model is switched to evaluation mode for the forward passes, so dropout is disabled, and its previous
    train/eval mode is restored before returning.

    Args:
        _model: Model to evaluate; called with the keyword arguments built by chunk_to_inputs and expected to
            return ``(loss, logits, mems)``.
        _device: Device the inputs are moved to.
        _dataloader: Iterable of batches accepted by prepare_chunked_inputs.
        _max_seq_len: Chunk width.
        _max_title_len: Number of title positions predicted per chunk.
        _max_batches: Maximum number of batches to evaluate. The default of 1 keeps the previous behaviour of
            sampling only the first batch; pass None to evaluate every batch.

    Returns:
        A tuple ``(summed_loss, steps)``: the sum of the last-chunk loss of every evaluated batch and the
        number of batches evaluated.
    """
    if _max_batches is not None and _max_batches <= 0:
        return 0, 0

    _epoch_iterator = tqdm(_dataloader, desc="Iteration")
    _steps = 0
    _summed_loss = 0
    _was_training = _model.training
    _model.eval()
    try:
        for _batch in _epoch_iterator:
            _chunked_batch_inputs, _target_mapping, _labels = prepare_chunked_inputs(
                _batch, _max_seq_len, _max_title_len
            )

            _mems = None
            _loss = None
            with torch.no_grad():
                for _chunk in _chunked_batch_inputs:
                    _inputs = chunk_to_inputs(_chunk, _target_mapping, _labels, _mems, _device)
                    _loss, _logits, _mems = _model(**_inputs)

            # Always accumulate loss across the last chunk, where it should be lowest. That's the goal of this
            # model.
            _steps += 1
            _summed_loss += _loss.item()

            # Stop right after the last wanted batch so no further batch is pulled from the dataloader.
            if _max_batches is not None and _steps >= _max_batches:
                break
    finally:
        _model.train(_was_training)
    return _summed_loss, _steps

print("***** Running training *****")

def full_validate():
    """Run validate_epoch on every validation dataset group and report the mean loss per step.

    Reads the notebook-level ``val_datasets``, ``model``, ``device``, ``BATCH_SIZE``, ``CHUNK_SEQ_LEN`` and
    ``TITLE_PRED_MAX_LEN``. The mean loss is printed and, when ``do_wandb`` is set, logged to wandb as
    ``val_loss``. If no validation step ran at all, nothing is logged, because no mean can be computed.
    """
    combined_val_steps, combined_val_loss = 0, 0
    for i, val_dataset in enumerate(val_datasets):
        print("Running validation %i.." % (i))
        val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
        l, s = validate_epoch(model, device, val_dataloader, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN)
        combined_val_steps += s
        combined_val_loss += l

    # Guard the division: with no validation data (or only empty groups) there is no mean to report.
    if combined_val_steps == 0:
        print("No validation steps were run; skipping validation loss logging.")
        return

    mean_val_loss = combined_val_loss / combined_val_steps
    if do_wandb:
        wandb.log({"val_loss": mean_val_loss})
    print("Validation loss averaged over %i steps: %f" % (int(combined_val_steps), mean_val_loss))

# Zero the model's gradients before the first optimizer step.
model.zero_grad()
for _ in range(EPOCHS):
    # Visit the dataset groups in a new random order each epoch; this shuffles the notebook-level list in place.
    random.shuffle(train_datasets)
    for train_dataset in train_datasets:
        # NOTE(review): validation runs before each group is trained and never after the last one, so the
        # final weights are not validated by this loop.
        full_validate()
        # Batches within a group are drawn in random order.
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)
        
        train_epoch(model, optimizer, scheduler, device, train_dataloader, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN, fp16)
    

***** Running training *****
Running validation 0..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 1..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 2..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…



Running validation 3..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 4..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…


Running validation 5..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 6..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…


Running validation 7..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 8..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 9..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 10..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 11..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 12..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 13..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 14..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 15..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 16..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 17..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 18..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 19..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 20..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 21..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 22..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 23..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 24..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 25..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 26..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 27..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 28..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

Running validation 29..


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=8.0, style=ProgressStyle(description_widt…

wandb: Wandb version 0.8.29 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Validation loss averaged over 30 steps: 14.817024


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=864.0, style=ProgressStyle(description_wi…

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0




Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1024.0




Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0


In [5]:
import os

# Save the model
output_dir = os.path.join("c:/Users/jbetk/Documents/data/ml/saved_models", "xlnet_title_generation")
# exist_ok makes directory creation a single step instead of a separate check followed by a create.
os.makedirs(output_dir, exist_ok=True)
model_to_save = (
    model.module if hasattr(model, "module") else model
)  # Take care of distributed/parallel training
# Pretrained-format export of the model and tokenizer.
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# Raw state dicts of the model, optimizer and scheduler, written next to the pretrained export.
torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

print("Save completed.")


Save completed.


In [None]:
# Test the model.
actual_article_title = "Italy announces lockdown as global coronavirus cases surpass 105,000"
article_text = """
Italian Prime Minister Giuseppe Conte signed a decree early Sunday that will put millions of people across northern Italy under lockdown due to the novel coronavirus.
The sweeping move puts the entire Lombardy region, as well as 14 other provinces, under travel restrictions, and is one of the toughest responses implemented outside of mainland China to get the Covid-19 epidemic under control.
CNN is verifying exactly when the lockdown will go into effect.
The announcement came after Italy saw a dramatic spike of 1,247 confirmed novel coronavirus cases on Saturday, the Civil Protection Department said in a statement.
The country has now recorded 5,883 cases and 233 deaths, the most fatalities outside mainland China and the biggest outbreak in Europe.
Announcing the new measures, Conte said: "There will be an obligation to avoid any movement of people who are either entering or leaving" the affected areas. "Even within the areas moving around will occur only for essential work or health reasons," he said, according to Reuters.
While the lockdown only applies to northern Italy, other measures will be applied to the entire country. These include the suspension of schools, university classes, theaters and cinemas, as well as bars, nightclubs, and sports events. Religious ceremonies, including funerals, will also be suspended.
Other countries in Europe are also struggling to contain outbreaks as cases continue to rise.
On Saturday, France's general director of health, Jerome Salomon, confirmed 16 dead and 949 infected nationwide, and Germany now has 795 cases. The United Kingdom confirmed a second death from the novel coronavirus on Saturday, while 206 people have tested positive, British health officials said in a statement.
The World Health Organization (WHO) has called on "all countries to continue efforts that have been effective in limiting the number of cases and slowing the spread of the virus."
In a statement, the WHO said: "Allowing uncontrolled spread should not be a choice of any government, as it will harm not only the citizens of that country but affect other countries as well."
Meanwhile in China, search and rescue efforts continued on Sunday for survivors from the collapse of a hotel that was being used as a coronavirus quarantine center.
The hotel, in the southeastern city of Quanzhou, in Fujian province, came down Saturday night with 80 people inside. Four people died, one person remains in critical condition and four others are seriously injured, according to China's Ministry of Emergency Management.
"We are using life detection instruments to monitor signs of life and professional breaking-in tools to make forcible entries. We are trying our utmost to save trapped people," said Guo Yutuan, squadron leader of the Quanzhou armed police detachment's mobile unit.
The building's owner is in police custody, according to state news agency Xinhua and an investigation is underway.
"""

def test_model(_text_input, _test_model, _seq_len, _title_len, _test_device):
    """Predict a title for ``_text_input`` and return it as decoded text.

    The text is tokenized with the notebook-level ``tokenizer``, padded to a whole number of chunks, and run
    through the model chunk by chunk with mems carried forward. The title is the argmax over the logits of the
    final chunk.

    Args:
        _text_input: Article text to summarize into a title.
        _test_model: Model to run; put into evaluation mode for the call and restored afterwards.
        _seq_len: Chunk width, as used during training.
        _title_len: Number of title positions to predict.
        _test_device: Device the inputs are moved to.

    Returns:
        The decoded string for the ``_title_len`` predicted token ids.
    """
    # First pass without padding, only to learn how many tokens the text produces.
    _unpadded = tokenizer.encode_plus(_text_input, add_special_tokens=True, max_length=None, pad_to_max_length=False,
                                      return_token_type_ids=True, return_attention_mask=True)
    _token_count = len(_unpadded["input_ids"])
    # The chunker needs a length that is a whole multiple of _seq_len, so round up to the next multiple and let
    # the tokenizer pad to it.
    _padded_len = max(1, -(-_token_count // _seq_len)) * _seq_len
    _encoded = tokenizer.encode_plus(_text_input, add_special_tokens=True, max_length=_padded_len,
                                     pad_to_max_length=True, return_token_type_ids=True,
                                     return_attention_mask=True)

    # Build a batch of one in the layout the chunker expects: [input_ids, attention_mask, token_type_ids, labels].
    # The chunker expects a labels element, but we dont actually want to supply one for test; just supply an
    # empty tensor.
    _test_batch = [
        torch.tensor([_encoded["input_ids"]], dtype=torch.long),
        torch.tensor([_encoded["attention_mask"]], dtype=torch.float),
        torch.tensor([_encoded["token_type_ids"]], dtype=torch.long),
        torch.empty((0,), dtype=torch.long),
    ]

    _test_chunked_batch_inputs, _test_target_mapping, _test_labels = prepare_chunked_inputs(_test_batch, _seq_len, _title_len)

    _test_mems = None
    _test_logits = None
    _was_training = _test_model.training
    _test_model.eval()
    try:
        with torch.no_grad():
            for _test_chunk in _test_chunked_batch_inputs:
                _test_inputs = chunk_to_inputs(_test_chunk, _test_target_mapping, None, _test_mems, _test_device)
                # Without labels no loss is returned; the outputs are assumed to start with (logits, mems),
                # which relies on the model config having mem_len set.
                _test_outputs = _test_model(**_test_inputs)
                _test_logits, _test_mems = _test_outputs[0], _test_outputs[1]
    finally:
        _test_model.train(_was_training)

    # Move the ids to the CPU before converting so this also works when the model ran on a GPU.
    _predicted_ids = torch.argmax(_test_logits, dim=-1)[0].cpu().tolist()
    return tokenizer.decode(_predicted_ids)

print(test_model(article_text, model, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN, device))