In [85]:
from datasets import load_from_disk
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, DataCollatorForSeq2Seq, EncoderDecoderConfig, EncoderDecoderModel

In [None]:
# Initialize pretrained model for finetuning
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

## Preprocess Dataset

In [97]:
dataset = load_from_disk('arxiv_AI_dataset')

In [98]:
MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 128

In [99]:
def preprocess_data(example):
    
    model_inputs = tokenizer(example['abstract'], max_length=MAX_SOURCE_LEN, padding=False, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(example['title'], max_length=MAX_SOURCE_LEN, padding=False, truncation=True)

    # Replace all pad token ids in the labels by -100 to ignore padding in the loss
    labels["input_ids"] = [
        [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
    ]

    model_inputs['labels'] = labels["input_ids"]

    return model_inputs

In [100]:
# Apply preprocess_data() to the whole dataset
processed_dataset = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=['abstract', 'title'],
    desc="Running tokenizer on dataset",
)
processed_dataset

Running tokenizer on dataset: 100%|██████████| 37/37 [02:34<00:00,  4.18s/ba]
Running tokenizer on dataset: 100%|██████████| 3/3 [00:09<00:00,  3.17s/ba]
Running tokenizer on dataset: 100%|██████████| 3/3 [00:08<00:00,  2.89s/ba]


DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids'],
        num_rows: 36074
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids'],
        num_rows: 2005
    })
    val: Dataset({
        features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids'],
        num_rows: 2004
    })
})

In [101]:
# Set return type to torch tensors
processed_dataset.set_format(type='torch')

In [102]:
train_data, val_data, test_data = processed_dataset['train'], processed_dataset['val'], processed_dataset['test']

In [103]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100
)

In [104]:
train_loader = DataLoader(train_data, shuffle=True, collate_fn=data_collator, batch_size=2)
val_loader = DataLoader(val_data, collate_fn=data_collator, batch_size=2)
test_loader = DataLoader(test_data, collate_fn=data_collator, batch_size=2)

## Train

In [None]:
ARTICLE_TO_SUMMARIZE = "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data."
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=512, return_tensors='pt')

# Generate Summary
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=100, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

In [105]:
batch = next(train_loader.__iter__())

In [106]:
output = model(**batch)



In [107]:
output.loss

tensor(10.1948, grad_fn=<NllLossBackward0>)

In [108]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [89]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

In [None]:
input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids
labels = tokenizer("Salut, mon chien est mignon", return_tensors="pt").input_ids
outputs = model(input_ids=batch['input_ids'], labels=batch['labels'])
loss, logits = outputs.loss, outputs.logits
print(loss)

In [109]:
for e in range(50):
    optimizer.zero_grad()
    output = model(**batch)
    logits, loss = output.logits, output.loss
    print(loss.item())
    loss.backward()
    optimizer.step()

10.194809913635254
9.398115158081055
7.756308078765869
6.957815647125244
5.367212772369385
4.701416969299316
3.9998509883880615
3.793119192123413
3.600857973098755
3.40639328956604


In [115]:
pretrained = EncoderDecoderModel.from_pretrained('Callidior/bert2bert-base-arxiv-titlegen')

In [120]:
pretrained.config.decoder == model.config.decoder

False

In [163]:
for attr1, attr2 in zip(dir(pretrained.config.decoder), dir(model.config.decoder)):
    if not attr1.startswith('_') and not attr2.startswith('_') and eval(f"type(pretrained.config.decoder.{attr1}) in [int, str, float]"):
        if eval(f"pretrained.config.decoder.{attr1} == model.config.decoder.{attr2}"):
            pass
        else:
            print(False)
            print(attr1, attr2)
            print(eval(f"pretrained.config.decoder.{attr1}, model.config.decoder.{attr2}"))


False
transformers_version transformers_version
('4.3.2', '4.12.3')


In [165]:
model2 = EncoderDecoderModel.from_pretrained('Callidior/bert2bert-base-arxiv-titlegen', config=model.config)

In [None]:
output = model2.generate(batch['input_ids'], num_beams=4, max_length=100, early_stopping=True)

for i, l in enumerate(batch['labels']):
    print('\nPrediction:', tokenizer.decode(output[i], skip_special_tokens=True), '\n')
    print('Labels:', tokenizer.decode(l.where(l != -100, torch.ones(l.shape).long()), skip_special_tokens=True), '\n')

In [169]:
for i, l in enumerate(batch['labels']):
    print('\nPrediction:', tokenizer.decode(output[i], skip_special_tokens=True), '\n')
    print('Labels:', tokenizer.decode(l.where(l != -100, torch.ones(l.shape).long()), skip_special_tokens=True), '\n')


Prediction: rule discovery for obesity risk prediction ehr data mining ehr rule discovery method 

Labels: identifying the leading factors of significant weight gains using a new rule discovery method 


Prediction: state of ai ethics report ( june 2020 )ss on the state of ai ethics report on the state of ai ethics report ( soccer ethics report ) on the state of ai ethics report ( june 2020 on the state of ai ethics report ( soccer ethics report on the state of 

Labels: the state of ai ethics report ( june 2020 ) [unused0] [unused0] [unused0] [unused0] 



In [None]:
def train(
    model,
    train_data,
    val_data,
    batch_size,
    lr,
    epochs,
    device=None
):  
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    # criterion = nn.CrossEntropyLoss(ignore_index=0) # ignore pad idx

    n_train_steps = len(train_data) // batch_size

    model.to(device)

    train_losses = []
    val_losses = []

    for e in range(epochs):

        epoch_train_loss = []
        epoch_val_loss = []

        train_loader = DataGenerator(train_data, batch_size)
        val_loader = DataGenerator(val_data, batch_size)

        # Training
        model.train()
        pbar = tqdm(total=n_train_steps, desc=f"Epoch {e+1}")

        for i, (image_inputs, caption_inputs) in enumerate(train_loader):

            inputs = {
                'pixel_values': image_inputs['pixel_values'].to(device),
                'labels': caption_inputs['input_ids'].to(device)
            }

            optimizer.zero_grad()
            output = model(**inputs)
            
            logits = output.logits
#             logits = logits.reshape(-1, logits.shape[2])
#             targets = caption_inputs['input_ids'].to(device).reshape(-1)
            
            loss = output.loss
#             loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

            pbar.set_postfix({'Loss': loss.item()})
            pbar.update(1)
            epoch_train_loss.append(loss.item())

        scheduler.step()

        # Validation
        model.eval()

        with torch.no_grad():
            for i, (image_inputs, caption_inputs) in enumerate(val_loader):

                inputs = {
                    'pixel_values': image_inputs['pixel_values'].to(device),
#                     'decoder_input_ids': caption_inputs['input_ids'].to(device),
#                     'decoder_attention_mask': caption_inputs['attention_mask'].to(device),
                    'labels': caption_inputs['input_ids'].to(device)
                }

                output = model(**inputs)
                logits = output.logits
#                 logits = logits.reshape(-1, logits.shape[2])
#                 targets = caption_inputs['input_ids'].to(device).reshape(-1)

#                 loss = criterion(logits, targets)
                loss = output.loss
                epoch_val_loss.append(loss.item())

        mean_epoch_train_loss = np.array(epoch_train_loss).mean()
        mean_epoch_val_loss = np.array(epoch_val_loss).mean()

        train_losses.append(mean_epoch_train_loss)
        val_losses.append(mean_epoch_val_loss)

        pbar.set_postfix({'Train Loss': mean_epoch_train_loss, 'Val Loss': mean_epoch_val_loss})

    model.to('cpu')

    return {'model': model, 'train_losses': train_losses, 'val_losses': val_losses}
