In [None]:
import transformers
import torch
import wandb
import time
import os

from chunked_text_dataloader import ChunkedTextDataset
from tqdm.notebook import tqdm

fp16 = True
if fp16:
    from apex import amp
    

In [2]:
EPOCHS = 1
BATCH_SIZE = 4
CHUNK_SEQ_LEN = 256
TITLE_PRED_MAX_LEN = 32
model_name = "xlnet-base-cased"

tokenizer = transformers.XLNetTokenizer.from_pretrained(model_name)

# Get the datasets
input_folder = "C:\\Users\\jbetk\\Documents\\data\\ml\\title_prediction\\outputs\\"
train_set = ChunkedTextDataset(input_folder + "train.pt", tokenizer, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN, mask_target_percentage=.5,
                               pad_left=True, force_max_len_gen=False)
val_set = ChunkedTextDataset(input_folder + "val.pt", tokenizer, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN, mask_target_percentage=.5,
                               pad_left=True, force_max_len_gen=False)
test_set = ChunkedTextDataset(input_folder + "test.pt", tokenizer, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN, mask_target_percentage=.5,
                               pad_left=True, force_max_len_gen=False)
train_loader = train_set.get_dataloader(BATCH_SIZE, num_workers=0)
val_loader = val_set.get_dataloader(BATCH_SIZE, num_workers=0, random=False)
test_loader = test_set.get_dataloader(BATCH_SIZE, num_workers=0, random=False)

# Initialize w&b logger
do_wandb = True
if do_wandb:
    wandb.init(project="nonint-transformers-torch",\
               name="xlnet_title_prediction_front_unfixed_maskes_256_seq",\
               config={"dataset": "title_pred"})
    # There's something bugged about this, but it doesnt really seem to do much anyways. Apparently it enables some 
    # sort of gradient exploration map.
    #wandb.watch(model)

wandb: Wandb version 0.8.29 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [3]:
# Load model
config = transformers.XLNetConfig.from_pretrained(model_name)
config.mem_len = 1024
model = transformers.XLNetLMHeadModel.from_pretrained(model_name, config=config)
device = torch.device("cuda")
cpu = torch.device("cpu")

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0,
    },
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 
                                                         num_warmup_steps=0, num_training_steps=EPOCHS * len(train_set))

# Shift model to cuda & enable fp16 if applicable.
model.to(device)
if fp16:
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [None]:
preprocess_times = []
forward_times = []
backward_times = []
opt_times = []

def clear_timers():
    forward_times.clear()
    backward_times.clear()
    opt_times.clear()

def save_model(_model, _chkpt_name):
    # Save the model 
    _output_dir = os.path.join("c:/Users/jbetk/Documents/data/ml/saved_models", "xlnet_title_generation", _chkpt_name)
    if not os.path.exists(_output_dir):
        os.makedirs(_output_dir)
    _model_to_save = (
        _model.module if hasattr(_model, "module") else _model
    )  # Take care of distributed/parallel training
    _model_to_save.save_pretrained(_output_dir)
    print("Save completed. %s" % (_output_dir))

def train_epoch(_model, _optimizer, _scheduler, _device, _dataloader, _max_seq_len, _max_title_len, _fp16):
    _logging_steps = 5
    _steps_till_save = 2000
    _steps_till_validate = 2000
    
    clear_timers()
    
    _epoch_iterator = tqdm(_dataloader, desc="Iteration")
    _steps = 0
    _tr_loss, _logging_loss = 0, 0
    _chunks = 0
    _accuracy_accum, _accuracy_last = 0, 0
    _model.train()
    
    __s = time.time()
    for _step, _batch in enumerate(_epoch_iterator):
        preprocess_times.append(time.time() - __s)
        
        _mems = None
        _loss = None
        _chunk_loss_schedule = []
        _num_chunks = len(_batch["input_ids"])
        _chunks += _num_chunks
        for _masked_input_ids, _attention_masks, _labels in zip(_batch['input_ids_masked'],
                                                                _batch['attention_masks'],
                                                                _batch['labels']):            
            # Forward
            _inputs = {
                'input_ids': _masked_input_ids.to(_device),
                'attention_mask': _attention_masks.to(_device),
                'labels': _labels.to(_device)
            }
            if _mems is not None:
                _inputs['mems'] = _mems
            
            __s = time.time()
            _loss, _logits, _mems = _model.forward(**_inputs)
            forward_times.append(time.time() - __s)            
            
            # Backwards
            __s = time.time()
            if fp16:
                with amp.scale_loss(_loss, _optimizer) as _scaled_loss:
                    _scaled_loss.backward()
                    backward_time = time.time() - __s
            else:
                _loss.backward()
                backward_time = time.time() - __s
            backward_times.append(backward_time)
            
            # Update weights
            if _fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(_optimizer), 1)
            else:
                torch.nn.utils.clip_grad_norm_(_model.parameters(), 1)
            __s = time.time()
            _optimizer.step()
            opt_times.append(time.time() - __s)
            _scheduler.step()
            _model.zero_grad()
            
            _chunk_loss_schedule.append(_loss.item())
        
        # Always accumulate loss across the last chunk, where it should be lowest. That's the goal of this model.
        _tr_loss += _loss.item()
        
        if _steps % _logging_steps == 0:
            _loss_scalar = (_tr_loss - _logging_loss) / _logging_steps
            _logging_loss = _tr_loss
            _logs = {}
            _logs["avg_chunks"] = _chunks / _logging_steps
            _chunks = 0
            _logs["loss"] = _loss_scalar
            _logs["learning_rate"] = _scheduler.get_lr()[0]
            if do_wandb:
                wandb.log(_logs)
        
        if _steps % _steps_till_save == 0:
            save_model(model, "chkpt_%i" % (_steps))
        if _steps % _steps_till_validate == 0:
            validate(_model, _device, _max_seq_len, _max_title_len)
            
        _steps += 1
        # Record time so we see how long it takes to fetch a batch.
        __s = time.time()


def validate(_model, _device, _max_seq_len, _max_title_len):
    _epoch_iterator = tqdm(val_loader, desc="Validation Iteration")
    _actual_steps = 0
    _total_loss = 0
    
    with torch.no_grad():
        for _step, _batch in enumerate(_epoch_iterator):
            # Skip 1 in 10 steps, because the validator just takes too long otherwise. It's not as easy as just cutting
            # down the dataset, either, since we run into chunk/batch size mismatches then.
            if _step % 10 != 0:
                continue
            _mems = None
            _loss = None
            for _masked_input_ids, _attention_masks, _labels in zip(_batch['input_ids_masked'],
                                                                    _batch['attention_masks'],
                                                                    _batch['labels']):            
                # Forward
                _inputs = {
                    'input_ids': _masked_input_ids.to(_device),
                    'attention_mask': _attention_masks.to(_device),
                    'labels': _labels.to(_device)
                }
                if _mems is not None:
                    _inputs['mems'] = _mems
                
                _loss, _logits, _mems = _model.forward(**_inputs)
            
            # Always accumulate loss across the last chunk, where it should be lowest. That's the goal of this model.
            _actual_steps += 1
            _total_loss += _loss.item()
    
        _logs = {}
        _val_loss = _total_loss / _actual_steps
        _logs["val_loss"] = _val_loss
        if do_wandb:
            wandb.log(_logs)
        print("Validation loss: " + str(_val_loss))

print("***** Running training *****")

model.zero_grad()
for _ in range(EPOCHS):    
    train_epoch(model, optimizer, scheduler, device, train_loader, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN, fp16)
    # Slowly increase the mask percentage per epoch to make the model have to work harder.
    train_set.mask_target_percentage += .1


***** Running training *****


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=24894.0, style=ProgressStyle(description_…

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0




Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0


wandb: Wandb version 0.8.29 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Save completed. c:/Users/jbetk/Documents/data/ml/saved_models\xlnet_title_generation\chkpt_0


HBox(children=(FloatProgress(value=0.0, description='Validation Iteration', max=512.0, style=ProgressStyle(des…


Validation loss: 7.229584409640386
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Save completed. c:/Users/jbetk/Documents/data/ml/saved_models\xlnet_title_generation\chkpt_2000


HBox(children=(FloatProgress(value=0.0, description='Validation Iteration', max=512.0, style=ProgressStyle(des…


Validation loss: 1.9668402855212872
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0


In [None]:
validate(model, device, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN)

In [None]:
import os
output_dir = os.path.join("c:/Users/jbetk/Documents/data/ml/saved_models", "xlnet_title_generation")

# Load model from saved state.
config = transformers.XLNetConfig.from_pretrained(output_dir)
config.mem_len = 1024
model = transformers.XLNetLMHeadModel.from_pretrained(output_dir, config=config)
device = torch.device("cuda")
cpu = torch.device("cpu")

#optimizer = torch.load(os.path.join(output_dir, "optimizer.pt"))
#scheduler = torch.load(os.path.join(output_dir, "scheduler.pt"))
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0,
    },
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 
                                                         num_warmup_steps=0, num_training_steps=EPOCHS * len(train_set))

# Shift model to cuda & enable fp16 if applicable.
model.to(device)
if fp16:
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

In [None]:
# Test the model.
actual_article_title = "Italy announces lockdown as global coronavirus cases surpass 105,000"
article_text = """
Italian Prime Minister Giuseppe Conte signed a decree early Sunday that will put millions of people across northern Italy under lockdown due to the novel coronavirus.
The sweeping move puts the entire Lombardy region, as well as 14 other provinces, under travel restrictions, and is one of the toughest responses implemented outside of mainland China to get the Covid-19 epidemic under control.
CNN is verifying exactly when the lockdown will go into effect.
The announcement came after Italy saw a dramatic spike of 1,247 confirmed novel coronavirus cases on Saturday, the Civil Protection Department said in a statement.
The country has now recorded 5,883 cases and 233 deaths, the most fatalities outside mainland China and the biggest outbreak in Europe.
Announcing the new measures, Conte said: "There will be an obligation to avoid any movement of people who are either entering or leaving" the affected areas. "Even within the areas moving around will occur only for essential work or health reasons," he said, according to Reuters.
While the lockdown only applies to northern Italy, other measures will be applied to the entire country. These include the suspension of schools, university classes, theaters and cinemas, as well as bars, nightclubs, and sports events. Religious ceremonies, including funerals, will also be suspended.
Other countries in Europe are also struggling to contain outbreaks as cases continue to rise.
On Saturday, France's general director of health, Jerome Salomon, confirmed 16 dead and 949 infected nationwide, and Germany now has 795 cases. The United Kingdom confirmed a second death from the novel coronavirus on Saturday, while 206 people have tested positive, British health officials said in a statement.
The World Health Organization (WHO) has called on "all countries to continue efforts that have been effective in limiting the number of cases and slowing the spread of the virus."
In a statement, the WHO said: "Allowing uncontrolled spread should not be a choice of any government, as it will harm not only the citizens of that country but affect other countries as well."
Meanwhile in China, search and rescue efforts continued on Sunday for survivors from the collapse of a hotel that was being used as a coronavirus quarantine center.
The hotel, in the southeastern city of Quanzhou, in Fujian province, came down Saturday night with 80 people inside. Four people died, one person remains in critical condition and four others are seriously injured, according to China's Ministry of Emergency Management.
"We are using life detection instruments to monitor signs of life and professional breaking-in tools to make forcible entries. We are trying our utmost to save trapped people," said Guo Yutuan, squadron leader of the Quanzhou armed police detachment's mobile unit.
The building's owner is in police custody, according to state news agency Xinhua and an investigation is underway.
"""
import math

# Create inputs to the model for a given chunk of input_ids and a partially generated output.
def create_inputs_for_chunk(_chunk: torch.Tensor, 
                            _mems: torch.Tensor, 
                            _tokenizer: transformers.PreTrainedTokenizer, 
                            _seq_len: int, 
                            _title_len: int, 
                            _test_device: torch.device):
    assert(_outputs_so_far.shape[0] == _title_len)
    _padding_needed = _seq_len - _chunk.shape[0] - _title_len
    _pad_tensor = torch.full((_padding_needed,), _tokenizer.pad_token_id, dtype=torch.long)
    
    _input_ids = torch.cat([_pad_tensor, _chunk, _outputs_so_far]).unsqueeze(dim=0)
    _attention_mask = torch.cat([torch.zeros((_padding_needed,), dtype=torch.float), torch.ones((_seq_len - _padding_needed), dtype=torch.float)]).unsqueeze(dim=0)
    _token_type_ids = torch.cat([torch.zeros((_seq_len - _title_len), dtype=torch.long), torch.ones((_title_len,), dtype=torch.long)]).unsqueeze(dim=0)
    _inputs = {
        "input_ids": _input_ids.to(_test_device),
        "attention_mask": _attention_mask.to(_test_device),
        "token_type_ids": _token_type_ids.to(_test_device)
    }
    if _mems is not None:
        _inputs["mems"] = _mems
    return _inputs

# Returns top-k words the model predicts given _text_tensor and computed _outputs_so_far.
def predict_words(_text_tensor: torch.Tensor, 
                  _outputs_so_far: torch.Tensor,
                  _tokenizer: transformers.PreTrainedTokenizer, 
                  _test_model: transformers.PreTrainedModel, 
                  _seq_len: int, 
                  _title_len: int, 
                  _test_device: torch.device, 
                  _k_count=3):
    _chunk_count = math.ceil(_text_tensor.shape[0] / (_seq_len - _title_len))
    _tok_text_chunked = torch.chunk(_text_tensor, _chunk_count, dim=0)
    _mems = None
    for _chunk in _tok_text_chunked:
        _inputs = create_inputs_for_chunk(_chunk, _mems, _tokenizer, _seq_len, _title_len, _test_device)
        _logits, _mems = _test_model.forward(**_inputs)
    # Remove the batch dimension.
    _logits = _logits[0]
    return torch.topk(_logits, _k_count)
    
def predict_forward(_text_tensor: torch.Tensor, 
                    _predict_tensor: torch.Tensor,
                    _predict_index: int,
                    _tokenizer: transformers.PreTrainedTokenizer, 
                    _test_model: transformers.PreTrainedModel, 
                    _seq_len: int, 
                    _title_len: int, 
                    _test_device: torch.device):
    if _predict_index == _title_len:
        return _predict_tensor
    _words = predict_words(_text_tensor, _predict_tensor, _tokenizer, _test_model, _seq_len, _title_len, _test_device, 3)
    for _prob, _word in _words:
        print("Predict %s at %f probability" % (_tokenizer.decode([_word]), _prob))
    _predict_tensor[_predict_index] = _words[0][1]
    if _predict_tensor[_predict_index] == _tokenizer.eos_token_id:
        return _predict_tensor
    return predict_forward(_text_tensor, _predict_tensor, _predict_index + 1, _tokenizer, _test_model, _seq_len, _title_len, _test_device)

def test_model(_text_input: string, 
               _tokenizer: transformers.PreTrainedTokenizer, 
               _test_model: transformers.PreTrainedModel, 
               _seq_len: int, 
               _title_len: int, 
               _test_device: torch.device):
    with torch.no_grad():
        _tok_text = torch.tensor(_tokenizer.encode(_tokenizer.bos_token + _text_input + _tokenizer.sep_token, add_special_tokens=False), dtype=torch.long)
        _tok_title = torch.full((title_len,), _tokenizer.mask_token_id, dtype=torch.long)
        _predicted_tensor = predict_forward(_text_input, _tok_title, 0, _tokenizer, _test_model, _seq_len, _title_len, _test_device)
        return _tokenizer.decode(_predicted_tensor)
    
print(test_model(article_text, model, CHUNK_SEQ_LEN, TITLE_PRED_MAX_LEN, device))