# Train the Paper Title Generator

In [2]:
from tqdm.notebook import tqdm
import wandb
from datasets import load_from_disk
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, DataCollatorForSeq2Seq, EncoderDecoderConfig, EncoderDecoderModel

In [3]:
wandb.init(project="abstract-to-title", entity="nerdimite")

[34m[1mwandb[0m: Currently logged in as: [33mnerdimite[0m (use `wandb login --relogin` to force relogin)


## Initialize BERT

In [4]:
# Initialize bert model for finetuning
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relatio

In [5]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

## Preprocess Dataset

In [6]:
dataset = load_from_disk('arxiv_AI_dataset')

In [7]:
MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 128

In [8]:
def preprocess_data(example):
    
    model_inputs = tokenizer(example['abstract'], max_length=MAX_SOURCE_LEN, padding=False, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(example['title'], max_length=MAX_SOURCE_LEN, padding=False, truncation=True)

    # Replace all pad token ids in the labels by -100 to ignore padding in the loss
    labels["input_ids"] = [
        [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
    ]

    model_inputs['labels'] = labels["input_ids"]
    # model_inputs['decoder_input_ids'] = [label[1:] for label in labels["input_ids"]]

    return model_inputs

In [9]:
# Apply preprocess_data() to the whole dataset
processed_dataset = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=['abstract', 'title'],
    desc="Running tokenizer on dataset",
)

# Set return type to torch tensors
processed_dataset.set_format(type='torch')

processed_dataset

Loading cached processed dataset at arxiv_AI_dataset/train\cache-4fc5b5eb43cd14e2.arrow
Loading cached processed dataset at arxiv_AI_dataset/test\cache-d6fa3acc4f0ed90d.arrow
Loading cached processed dataset at arxiv_AI_dataset/val\cache-bd67eee0b8029284.arrow


DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids'],
        num_rows: 36074
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids'],
        num_rows: 2005
    })
    val: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids'],
        num_rows: 2004
    })
})

In [10]:
train_data, val_data, test_data = processed_dataset['train'], processed_dataset['val'], processed_dataset['test']

In [12]:
train_data[0]

{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 'input_ids': tensor([  101,  7674, 16700, 13384, 26633,  2015,  1996,  2529,  3754,  2000,
          2191,  3653, 17421, 16790,  2015,  2055,  2256,  3558,  2088,  1010,
          1998,  2009,  2003,  2019, 27427,  2483, 11837, 19150, 23354,  1999,
          2311,  2236,  9932,  3001,  1012,  2057, 16599,  1037,  2047,  7674,
         16700, 13384,  2951, 13462,  2241,  2006,  2529,  1005,  1055,  9123,
          4349,  2208,  2652,  2015,  2004,  2529,  2867, 10580, 20228,  4765,
         18424,  1998,  7578,  7674, 16700, 13384,  1012,  1996,  2047,  2951,
         13462,

In [13]:
# Dynamic padding using a collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100
)

## Train

In [19]:
def train(
    model,
    train_loader,
    val_loader,
    epochs,
    optimizer,
    scheduler,
    device=None,
):  
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    n_train_steps = len(train_loader)

    model.to(device)

    best_val_loss = 1e10
    
    train_steps = 0
    val_steps = 0

    for e in range(epochs):

        # Training
        model.train()
        pbar = tqdm(total=n_train_steps, desc=f"Epoch {e+1}")

        for i, batch in enumerate(train_loader):
            
            for key, value in batch.items():
                batch[key] = batch[key].to(device)
            del(batch['decoder_input_ids'])
            
            optimizer.zero_grad()
            
            outputs = model(**batch)
            loss, logits = outputs.loss, outputs.logits

            loss.backward()
            optimizer.step()
            
            pbar.set_postfix({'Train Loss': loss.item()})
            pbar.update(1)
            
            train_steps += 1
            
            wandb.log({'train': {"loss": loss.item(), 'epoch': e+1, 'batch': i+1}})
            
            scheduler.step()
        
        # Validation
        model.eval()
        
        val_losses = []

        with torch.no_grad():
            for i, batch in enumerate(val_loader):

                for key, value in batch.items():
                    batch[key] = batch[key].to(device)
                del(batch['decoder_input_ids'])

                outputs = model(**batch)
                loss, logits = outputs.loss, outputs.logits
                
                pbar.set_postfix({'Val Loss': loss.item()})
                pbar.update(1)

                val_steps += 1
                
                wandb.log({'val': {"loss": loss.item(), 'epoch': e+1, 'batch': i+1}})
                val_losses.append(loss.item())
        
        if np.array(val_losses).mean() < best_val_loss:
            torch.save(model.state_dict(), f'epoch-{e}-{int(np.array(val_losses).mean())}.pt')
            best_val_loss = np.array(val_losses).mean()

    model.cpu()

In [15]:
wandb.config = {
  "learning_rate": 0.001,
  "epochs": 5,
  "batch_size": 1
}

In [16]:
# Create dataloaders
train_loader = DataLoader(train_data, shuffle=True, collate_fn=data_collator, batch_size=wandb.config['batch_size'])
val_loader = DataLoader(val_data, collate_fn=data_collator, batch_size=wandb.config['batch_size'])
test_loader = DataLoader(test_data, collate_fn=data_collator, batch_size=wandb.config['batch_size'])

In [17]:
optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config['learning_rate'])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

In [18]:
train(model, train_loader, val_loader, wandb.config['epochs'], optimizer, scheduler, device='cuda')

Epoch 1:   0%|          | 0/36074 [00:00<?, ?it/s]



RuntimeError: CUDA out of memory. Tried to allocate 90.00 MiB (GPU 0; 6.00 GiB total capacity; 4.34 GiB already allocated; 0 bytes free; 4.50 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF