In [15]:
# %load_ext autoreload
# %autoreload 2
import os
import sys

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm

# Modify path so that we can import local modules into notebook
module_path = os.path.abspath(os.path.join('./utils')) 
sys.path.insert(0, module_path)

from utils.exp_util import load_config, init_model, get_tokenizer, setup_exp_folders
from utils.data_util import get_datasets

In [2]:
if torch.cuda.is_available(): 
 dev = "cuda:0" 
else: 
 dev = "cpu" 
device = torch.device(dev) 
print(f"Device: {device}")

Device: cuda:0


## Config

In [16]:
EXPERIMENT_NAME = "small_test"
config = load_config(EXPERIMENT_NAME)

print(config)

{'GPT_SIZE': 'small', 'TARGET_TYPE': 'desc', 'TRAIN_SPLIT': 0.75, 'VAL_SPLIT': 0.15, 'TEST_SPLIT': 0.1, 'EPOCHS': 1, 'BATCH_SIZE': 8, 'LR': 5e-05, 'WARMUP_STEP': 100, 'GRADIENT_ACCUMULATION_STEPS': 32, 'MAX_GRAD_NORM': 1, 'RANDOM_SEED': 42}


In [4]:
RANDOM_SEED = config["RANDOM_SEED"]
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.random.manual_seed(RANDOM_SEED)
import random
random.seed(RANDOM_SEED)

## Load Dataset

In [13]:
tokenizer = get_tokenizer(config["GPT_SIZE"])
train_dataset, val_dataset, test_dataset = get_datasets(config, tokenizer)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
train_dataset.df.head()

10000 articles loaded.
6461 samples after cleaning
Train: 4845, Val: 970, Test: 646


Unnamed: 0,title,description,title_tokens,description_tokens
1703,freenet (FRA:FNTN) PT Set at EUR27.00 by Deuts...,freenet (FRA:FNTN - Get Rating) has been given...,"[f, reen, et, Ġ(, F, RA, :, F, NT, N, ), ĠPT, ...","[f, reen, et, Ġ(, F, RA, :, F, NT, N, Ġ-, ĠGet..."
3238,"Ally Financial (NYSE:ALLY) Upgraded to ""Hold"" ...",Ally Financial (NYSE:ALLY - Get Rating) was up...,"[All, y, ĠFinancial, Ġ(, NYSE, :, ALLY, ), ĠUp...","[All, y, ĠFinancial, Ġ(, NYSE, :, ALLY, Ġ-, ĠG..."
5840,Baytex/Ranger Oil Combination First Of Its Kin...,Baytex Energy Corp. announced on Tuesday that ...,"[Bay, tex, /, R, anger, ĠOil, ĠComb, ination, ...","[Bay, tex, ĠEnergy, ĠCorp, ., Ġannounced, Ġon,..."
5820,"Aaron Rodgers, QBs become top attractions at N...",Several teams including the Packers will have ...,"[Aaron, ĠRodgers, ,, ĠQB, s, Ġbecome, Ġtop, Ġa...","[Several, Ġteams, Ġincluding, Ġthe, ĠPackers, ..."
782,PACAF Airmen assist in ENCAP during Cobra Gold...,U.S. Air Force Airmen from Joint Base Elmendor...,"[PAC, AF, ĠA, irm, en, Ġassist, Ġin, ĠE, NC, A...","[U, ., S, ., ĠAir, ĠForce, ĠA, irm, en, Ġfrom,..."


In [7]:
sample_tokens = train_dataset[0]['token_ids']
sep_idx =  train_dataset[0]['sep_pos']
print("Example context: ", tokenizer.decode(sample_tokens[:sep_idx]))
print("Example target: ", tokenizer.decode(sample_tokens[sep_idx+1:]))

Example context:  Robert Platt honored as first in school history to bring home state title in boys wrestling
Example target:  Brawley wrestling has been around since 1968 - and in 55 years of existence, the school has never had a boy state champ, until now.The post Robert Platt honored as first in school history to bring home state title in boys wrestling appeared first on KYMA.<|endoftext|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pa

## Train

In [8]:
from torch.utils.tensorboard import SummaryWriter

def train(model, train_dataset, valid_dataset, ignore_index, config, checkpoint_every=0, checkpoint_path=None):
    log_dir, model_dir = setup_exp_folders(EXPERIMENT_NAME)
    writer = SummaryWriter(log_dir)

    train_dataloader = DataLoader(train_dataset,batch_size=config["BATCH_SIZE"])#,num_workers=args.num_workers)
    EPOCHS = config["EPOCHS"]
    GRADIENT_ACCUMULATION_STEPS = config["GRADIENT_ACCUMULATION_STEPS"]
    
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index) #ignores padding token for loss calculation
    optimizer = AdamW(model.parameters(),lr=config["LR"])
    total_steps = len(train_dataloader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=config["WARMUP_STEP"], num_training_steps=total_steps
    )

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    start_epoch = 1

    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path)
        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint['model_state_dict'])
        global_step = checkpoint["global_step"]
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loss = checkpoint["loss"]
        tr_loss = checkpoint["tr_loss"]

    model.train()
    model.zero_grad()
    
    for epoch in range(start_epoch, EPOCHS+1):
        print(f"Epoch {epoch}")
        for step, batch in enumerate(tqdm(train_dataloader)):
            inputs, labels = batch['token_ids'], batch['token_ids']

            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)[0]
            idx = batch['sep_pos']

            losses = []
            for i, sep_idx in enumerate(idx):
                shift_logits = logits[i, sep_idx:-1, :].contiguous()
                shift_labels = labels[i, sep_idx+1:].contiguous()
                l = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                losses.append(l)

            # Combine the losses
            loss = torch.stack(losses).mean()
            loss = loss/GRADIENT_ACCUMULATION_STEPS
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["MAX_GRAD_NORM"])
            tr_loss += loss.item()

            if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                writer.add_scalar('loss', (tr_loss - logging_loss)/GRADIENT_ACCUMULATION_STEPS, global_step)
                logging_loss = tr_loss

            if (step + 1) % (10*GRADIENT_ACCUMULATION_STEPS) == 0:
                results = evaluate(model, valid_dataset, config["BATCH_SIZE"], ignore_index)
                for key, value in results.items():
                    writer.add_scalar('eval_{}'.format(key), value, global_step)
                model.train()

        if epoch % checkpoint_every == 0:
            # Save checkpoint
            model_path = os.path.join(model_dir, f"checkpoint_epoch{epoch}.pt")
            torch.save({
                    'epoch': epoch,
                    'global_step': global_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'tr_loss': tr_loss
                    }, model_path)
            
            print(f"Saved checkpoint to {model_path}\n")

    if epoch % checkpoint_every != 0:
        # Save final checkpoint (if it wasn't already saved)
        model_path = os.path.join(model_dir, f"checkpoint_epoch{epoch}.pt")
        torch.save({
                'epoch': EPOCHS,
                'global_step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                'tr_loss': tr_loss
                }, model_path)
        print(f"Saved final checkpoint to {model_path}\n")
    print(f"Training complete!")


In [9]:
def evaluate(model, eval_dataset, batch_size, ignore_index):
    """ Returns perplexity score on validation dataset.
        global_step: no. of times gradients have backpropagated
        ignore_index: token not considered in loss calculation
    """
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index) #ignores padding token for loss calculation

    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in eval_dataloader:
        inputs, labels = batch['token_ids'].to(device), batch['token_ids'].to(device)
        
        with torch.no_grad():
            logits = model(inputs)[0]
            idx = batch['sep_pos']

            losses = []
            for i, sep_idx in enumerate(idx):
                shift_logits = logits[i, sep_idx:-1, :].contiguous()
                shift_labels = labels[i, sep_idx+1:].contiguous()
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                losses.append(loss)

            # Combine the losses
            eval_loss += torch.stack(losses).mean()

        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    return {"perplexity": perplexity, "loss": eval_loss}           

In [10]:
model = init_model(tokenizer, config["GPT_SIZE"])
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50259, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50259, bias=False)
)

In [None]:
train(model, train_dataset, val_dataset, tokenizer.pad_token_id, config, checkpoint_every=5)