# End-to-end

This notebook should form the core skeleton of the `run` function.

### Imports

In [35]:
# -- public imports
import gc
import os
import time

import pandas as pd
import torch
from pandas.testing import assert_frame_equal
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification

# -- private imports
from colabtools.utils import move_to_device

# -- dev imports
%load_ext autoreload
%autoreload 2

from argminer.data import ArgumentMiningDataset, TUDarmstadtProcessor
from argminer.evaluation import inference

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [36]:
# Run configuration -- these will later be replaced by the arguments passed to `run`.

# -- model / tokenisation
model_name = 'google/bigbird-roberta-base'
max_length = 1024

# -- training loop
epochs = 1
batch_size = 2
DEVICE = 'cpu'

# -- labelling strategy: the processor only receives the part after the first '_'
#    (e.g. 'standard_bio' -> 'bio')
strategy = 'standard_bio'
strat_name = strategy.split('_')[1]


### Tokenizer, Model and Optimizer

In [22]:
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)

# Size of the token-classification head: 'O' plus B-/I- tags for MajorClaim, Claim and
# Premise. This must equal the number of rows of `df_label_map` in the Dataset section;
# any extra outputs would be classes that no training label ever maps to.
num_labels = 7
model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=num_labels)
# TODO force_download
# TODO add option for optimizer
optimizer = torch.optim.Adam(params=model.parameters())

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Some weights of the model checkpoint at google/bigbird-roberta-base were not used when initializing BigBirdForTokenClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BigBirdForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BigBirdForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a Bert

### Dataset 
Note: this will change as the Processor develops. On the cluster you will need to use different options.

In [23]:
# Load the brat annotations and run the processor pipeline for the chosen tagging scheme.
processor = (
    TUDarmstadtProcessor('../../data/UCL/dataset2/ArgumentAnnotatedEssays-2.0/brat-project-final')
    .preprocess()
    .process(strat_name)
    .postprocess()
)
df_total = processor.dataframe

# Train = first 201 rows, test = last 201 rows. The assertion checks that the two
# halves together reproduce df_total exactly, i.e. the split neither overlaps nor drops rows.
df_train = df_total[['text', 'labels']].head(201)
df_test = df_total[['text', 'labels']].tail(201)
assert_frame_equal(df_total[['text', 'labels']], pd.concat([df_train, df_test]))

# TODO: the label map depends on `strategy`; it is currently hard-coded for BIO tagging.
df_label_map = pd.DataFrame(
    list(enumerate(['O', 'B-MajorClaim', 'I-MajorClaim', 'B-Claim', 'I-Claim', 'B-Premise', 'I-Premise'])),
    columns=['label_id', 'label'],
)

# Sanity check: the first training row uses exactly the labels in the map.
assert set(df_train['labels'].iloc[0]) == set(df_label_map['label'])

Found non-matching segments:--------------------------------------------------

murdering criminals is therefore immoral and hard to accept

"murdering" criminals is therefore immoral and hard to accept

Found non-matching segments:--------------------------------------------------

Click is a very interesting comedy, with a serious approach about the importance of having a balanced life between family and work businesses

"Click" is a very interesting comedy, with a serious approach about the importance of having a balanced life between family and work businesses

Found non-matching segments:--------------------------------------------------

Blood diamond, an adaptation of a real story in South Africa, focuses on the link between diamonds and conflict

"Blood diamond", an adaptation of a real story in South Africa, focuses on the link between diamonds and conflict

Found non-matching segments:--------------------------------------------------

rush hours are usually not the direct co

In [33]:
train_set = ArgumentMiningDataset(df_label_map, df_train, tokenizer, max_length, strategy)
test_set = ArgumentMiningDataset(df_label_map, df_test, tokenizer, max_length, strategy, is_train=False)

# Reshuffle the training rows every epoch so the model does not always see the same batches
# in the same order. The seeded generator keeps that order reproducible between runs.
# The test loader is left unshuffled so it iterates df_test in row order.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
test_loader = DataLoader(test_set, batch_size=batch_size)

In [None]:
# Training options. TODO: promote these to the constants cell / `run` arguments.
verbose = 1                 # 0: silent, 1: per-epoch summary, 2: also per-batch timings
save_freq = 1               # checkpoint every `save_freq` epochs, starting with the first
checkpoint_dir = 'models'   # created below if it does not exist yet


def _gpu_memory_summary(device):
    """Return a short string with the CUDA memory currently allocated / reserved on `device`."""
    allocated_mib = torch.cuda.memory_allocated(device) / 1024 ** 2
    reserved_mib = torch.cuda.memory_reserved(device) / 1024 ** 2
    return f'{allocated_mib:.0f}MiB allocated, {reserved_mib:.0f}MiB reserved'


def _checkpoint_path(name, tag):
    """Build the checkpoint file path for model `name` at `tag` (an epoch number or 'final').

    Stands in for the undefined `encode_model_name`. The '/' of hub names such as
    'google/bigbird-roberta-base' is replaced, so the file lands directly in `checkpoint_dir`
    instead of in a non-existent sub-directory.
    """
    safe_name = name.replace('/', '-')
    return os.path.join(checkpoint_dir, f'{safe_name}_{tag}.pt')


os.makedirs(checkpoint_dir, exist_ok=True)
# GPU reporting and cache clearing only make sense for CUDA devices
# (the old check was `DEVICE != 'cpu'`).
use_cuda = torch.device(DEVICE).type == 'cuda'

model.to(DEVICE)
print(f'Model pushed to device: {DEVICE}')
for epoch in range(epochs):
    model.train()
    start_epoch_message = f'EPOCH {epoch + 1} STARTED'
    print(start_epoch_message)
    print(f'{"-" * len(start_epoch_message)}')
    start_epoch = time.time()

    start_load = time.time()
    training_loss = 0
    n_batches = 0  # counted explicitly so an empty loader cannot leave `i` undefined
    for i, (inputs, targets) in enumerate(train_loader):
        start_train = time.time()
        inputs = move_to_device(inputs, DEVICE)
        targets = move_to_device(targets, DEVICE)
        if use_cuda:
            print(f'GPU memory at batch {i+1} after data loading: {_gpu_memory_summary(DEVICE)}')

        # forward pass -- gradients are cleared exactly once per step, before the forward pass
        optimizer.zero_grad()
        loss, outputs = model(
            labels=targets,
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            return_dict=False
        )
        if use_cuda:
            print(f'GPU memory at batch {i+1} after training: {_gpu_memory_summary(DEVICE)}')

        training_loss += loss.item()
        n_batches += 1

        # backward pass
        loss.backward()
        optimizer.step()

        # drop references to the batch tensors eagerly to keep peak memory down
        del targets, inputs, loss, outputs
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()

        end_train = time.time()

        if verbose > 1:
            print(
                f'Batch {i + 1} complete. Time taken: load({start_train - start_load:.3g}), '
                f'train({end_train - start_train:.3g}), total({end_train - start_load:.3g}). '
            )
        start_load = time.time()

    # nan (not a ZeroDivisionError) when the loader produced no batches
    mean_loss = training_loss / n_batches if n_batches else float('nan')
    print_message = (
        f'Epoch {epoch + 1}/{epochs} complete. '
        f'Time taken: {time.time() - start_epoch:.3g}. '
        f'Loss: {mean_loss: .3g}'
    )

    if verbose:
        print(f'{"-" * len(print_message)}')
        print(print_message)
        print(f'{"-" * len(print_message)}')

    if epoch % save_freq == 0:
        save_path = _checkpoint_path(model_name, epoch + 1)
        torch.save(model.state_dict(), save_path)
        print(f'Model saved at epoch {epoch+1} at: {save_path}')

# Always save the final weights; does not rely on `epoch`, which is undefined when epochs == 0.
save_path = _checkpoint_path(model_name, 'final')
torch.save(model.state_dict(), save_path)
print(f'Final model saved at: {save_path}')