In [None]:
!pip install transformers[sentencepiece]
!pip install wandb

In [None]:
import gc
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding, EarlyStoppingCallback
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
import json
from urllib.request import urlopen
import seaborn as sns
import matplotlib.pyplot as plt
import wandb

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# Setup device

# Build the torch device first, then derive the string and integer forms from it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_string = device.type
# Integer device index (0 = first GPU, -1 = CPU); presumably meant for Hugging Face
# pipeline-style APIs — it is not used in the visible part of this notebook.
device_hf = 0 if device_string == 'cuda' else -1
print("Device:", device)
# Number of dataloader worker processes handed to TrainingArguments.
NUM_WORKERS = 0

In [None]:
# Setup wandb

# Log this session in to Weights & Biases.
wandb.login()
# Export WANDB_PROJECT as an environment variable for the rest of the session.
# Presumably wandb picks it up as the project name for runs started by the Trainer
# (report_to='wandb') — confirm against the wandb documentation.
%env WANDB_PROJECT=annotype_text_classification

In [None]:
# Config
# Every tunable used by the cells below lives here.

IGNORED_CLASSES = [] # e.g. set it to ['sentiment'] to remove the objects whose annotation type is 'sentiment' from the dataset
MODEL_NAME = 'distilbert-base-cased' # checkpoint name used for both the tokenizer and the classifier
INPUT_TYPE = 'TEXT_HEAD' # Possible values: 'TEXT_HEAD', 'TEXT_ONLY', 'HEAD_ONLY'
TRAIN_BATCH_SIZE = 16 # passed as per_device_train_batch_size
EVAL_BATCH_SIZE = 64 # passed as per_device_eval_batch_size
LOGGING_STEPS = 100 # used for both logging_steps and save_steps in TrainingArguments
EVAL_STRATEGY = 'steps' # passed as evaluation_strategy
SAVE_STRATEGY = 'steps' # passed as save_strategy
WEIGHT_DECAY = 0.1
LOAD_BEST_MODEL_AT_END = True
NUM_TRAIN_EPOCHS = 10
CALLBACKS = [EarlyStoppingCallback(4)] # presumably a patience of 4 evaluations — confirm against the transformers docs
SEED = 0 # shared by NumPy, PyTorch, the train/validation split and TrainingArguments
DATA = 'MPQA3.0_v211021' # key into data_name_to_google_drive_url

In [None]:
# Getting data & augmented data urls

# Maps each dataset name (the value of DATA) to its Google Drive sharing link.
data_name_to_google_drive_url = {
    'MPQA3.0_v211021': 'https://drive.google.com/file/d/1e-pDfZ2cyBzgD9MEerP9YCcDnPvIQuGo/view?usp=sharing',
}

# Get direct download link
def get_download_url_from_google_drive_url(google_drive_url):
    """Convert a Google Drive link into a direct-download URL.

    The file id is located by pattern instead of by position in the path, so
    the common link shapes are all accepted:
      * ``.../file/d/<id>/view?usp=sharing`` (also with ``/u/0/`` in the path)
      * ``.../open?id=<id>``
      * ``.../uc?id=<id>&export=download`` (already direct; returned unchanged)

    Args:
        google_drive_url: A Google Drive URL containing a file id.

    Returns:
        ``https://drive.google.com/uc?id=<id>&export=download``.

    Raises:
        ValueError: If no file id can be found in ``google_drive_url``.
    """
    # Local import keeps this helper self-contained.
    import re

    file_id_chars = r'([A-Za-z0-9_-]+)'
    # Prefer the path form (/d/<id>); fall back to the query-string form (id=<id>).
    match = (
        re.search(r'/d/' + file_id_chars, google_drive_url)
        or re.search(r'[?&]id=' + file_id_chars, google_drive_url)
    )
    if match is None:
        raise ValueError(
            f'Could not find a Google Drive file id in URL: {google_drive_url!r}'
        )
    return f'https://drive.google.com/uc?id={match.group(1)}&export=download'

# Data URL
# Look up the sharing link of the configured dataset and turn it into a
# direct-download URL.
google_drive_url = data_name_to_google_drive_url[DATA]
data_url = get_download_url_from_google_drive_url(google_drive_url)

In [None]:
# Seed NumPy's global RNG and PyTorch before the data is split and the model is loaded.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Preparing the dataset

In [None]:
# Fetch the dataset

FETCH_FROM_WEB = True ### Set it to True to download the dataset from Google Drive; set it to False to read a local copy ###

if FETCH_FROM_WEB:
    # Use the response as a context manager so the HTTP connection is closed
    # as soon as the payload has been read.
    with urlopen(data_url) as response:
        csds_collection = json.loads(response.read())
else:
    # Forward slashes work on Windows as well as POSIX; the previous
    # backslash-separated path only resolved on Windows.
    file_address = '../json2csds/data.json'
    # Read as UTF-8 explicitly rather than relying on the platform's default encoding.
    with open(file_address, encoding='utf-8') as file:
        csds_collection = json.load(file)

In [None]:
# Preparing inputs and targets

# Parallel lists: position k in each list describes the same csds object.
inputs_text = []
inputs_head = []
inputs_tuple_text_head = []
targets_annotype = []

for csds_object in csds_collection['csds_objects']:
    # Skip objects whose annotation type was excluded in the config.
    if csds_object['annotation_type'] in IGNORED_CLASSES:
        continue
    inputs_text.append(csds_object['text'])
    inputs_head.append(csds_object['head'])
    inputs_tuple_text_head.append((csds_object['text'], csds_object['head']))
    targets_annotype.append(csds_object['annotation_type'])

# One target per kept object, so the target count is the sample count.
n_samples = len(targets_annotype)

i = 128 # A sample
print(f'inputs and targets for {i+1}-th csds object (out of {n_samples}):')
print('inputs_text:\t\t', inputs_text[i])
print('inputs_head:\t\t', inputs_head[i])
print('inputs_tuple_text_head:\t', inputs_tuple_text_head[i])
print('targets_annotype:\t', targets_annotype[i])

In [None]:
# Count the number of each annotation type and extract the labels

# Tally how many times each annotation type occurs.
num_annotype = {}
for annotype in targets_annotype:
    if annotype in num_annotype:
        num_annotype[annotype] += 1
    else:
        num_annotype[annotype] = 1
print(sorted(num_annotype.items()))
# Sorted class names; a class's position in this list becomes its integer id.
classes = sorted(num_annotype)

In [None]:
# Create a map for class ids and class names

# name -> id and id -> name, both following the order of `classes`.
classname2classid = {name: idx for idx, name in enumerate(classes)}
classid2classname = dict(enumerate(classes))

In [None]:
# Apply classname2classid mapping

# Integer class id for every sample, in the same order as targets_annotype.
y = [classname2classid[annotype_name] for annotype_name in targets_annotype]

In [None]:
# Shuffle and split the dataset into training and validation sets

# 80/20 stratified split over the (text, head) pairs, seeded for repeatability.
X_train, X_val, y_train, y_val = train_test_split(
    np.array(inputs_tuple_text_head),
    y,
    test_size=0.2,
    stratify=y,
    shuffle=True,
    random_state=SEED,
)

# Each split has two columns (text, head); transpose to unpack them as plain lists.
X_train_text, X_train_head = (column.tolist() for column in X_train.T)
X_val_text, X_val_head = (column.tolist() for column in X_val.T)

# Preparing the model and torch dataset

In [None]:
# Load the model, tokenizer

# Tokenizer and classifier are loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# num_labels is the number of distinct annotation types found in the data.
# NOTE(review): `resume_download` is reported as deprecated by recent
# huggingface_hub releases — confirm against the installed version.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=len(classes), resume_download=True,
)

In [None]:
# Tokenize the inputs

# Exactly one branch must run. An unrecognised INPUT_TYPE fails loudly here
# instead of leaving X_train_tokenized / X_val_tokenized undefined, or worse,
# stale from an earlier run of this cell.
if INPUT_TYPE == 'TEXT_HEAD':
    # Text and head are passed as the two sequences of a single tokenizer call.
    X_train_tokenized = tokenizer(X_train_text, X_train_head, truncation=True)
    X_val_tokenized   = tokenizer(X_val_text,   X_val_head,   truncation=True)
elif INPUT_TYPE == 'TEXT_ONLY':
    X_train_tokenized = tokenizer(X_train_text, truncation=True)
    X_val_tokenized   = tokenizer(X_val_text,   truncation=True)
elif INPUT_TYPE == 'HEAD_ONLY':
    X_train_tokenized = tokenizer(X_train_head, truncation=True)
    X_val_tokenized   = tokenizer(X_val_head,   truncation=True)
else:
    raise ValueError(
        f"Unknown INPUT_TYPE {INPUT_TYPE!r}; "
        "expected 'TEXT_HEAD', 'TEXT_ONLY' or 'HEAD_ONLY'"
    )

In [None]:
# Find the largest input size

# Running maximum of the token count over both splits.
t = 0
for i in X_train_tokenized['input_ids']:
    if len(i) > t:
        t = len(i)
for i in X_val_tokenized['input_ids']:
    if len(i) > t:
        t = len(i)
print("Maximum input length:", t)

In [None]:
# Create torch dataset

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over tokenizer encodings with optional labels.

    Args:
        encodings: Mapping from feature name to a per-sample sequence of
            values. It must contain the key "input_ids", which defines the
            dataset length.
        labels: Optional per-sample labels (list, numpy array, tensor, ...).
            When None or empty, items carry no "labels" entry.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return sample `idx` as a dict of tensors, plus "labels" if available."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Test for presence explicitly: the bare truth value of a numpy array
        # or tensor with more than one element raises. An empty container is
        # still treated as "no labels", as before.
        if self.labels is not None and len(self.labels) > 0:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of samples, taken from the "input_ids" feature."""
        return len(self.encodings["input_ids"])

# Wrap the tokenized splits and their integer class ids as torch datasets.
train_dataset = Dataset(X_train_tokenized, y_train)
val_dataset = Dataset(X_val_tokenized, y_val)

In [None]:
# Data collator

# Collator built from the tokenizer and handed to the Trainer. Presumably it pads
# each batch to its longest sequence — confirm in the transformers documentation.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Metrics

def compute_metrics(pred):
    """Compute accuracy plus macro and per-class precision/recall/F1.

    Args:
        pred: Object exposing `label_ids` (true class ids) and `predictions`
            (per-class scores; the arg-max over the last axis is the
            predicted class id).

    Returns:
        Dict with 'accuracy' (unrounded), macro 'f1'/'precision'/'recall'
        rounded to 4 decimals, and 'f1-list'/'precision-list'/'recall-list'
        holding the per-class values (rounded to 4 decimals) as plain lists,
        ordered by class id.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    # Score every known class id, [0, 1, ..., len(classes)-1], even if a class
    # is missing from this batch; undefined ratios count as 0.
    shared_kwargs = {'labels': list(range(len(classes))), 'zero_division': 0}

    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', **shared_kwargs
    )
    class_precision, class_recall, class_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, **shared_kwargs
    )

    ndigits = 4
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': np.round(macro_f1, ndigits),
        'precision': np.round(macro_precision, ndigits),
        'recall': np.round(macro_recall, ndigits),
        'f1-list': np.round(class_f1, ndigits).tolist(),
        'precision-list': np.round(class_precision, ndigits).tolist(),
        'recall-list': np.round(class_recall, ndigits).tolist(),
    }

In [None]:
# Training Arguments

training_args = TrainingArguments(
    # Checkpoint location, named after the model and the input type.
    output_dir='models/pretrain_' + '_'.join((MODEL_NAME, INPUT_TYPE)),
    overwrite_output_dir=True,
    # Optimisation
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    weight_decay=WEIGHT_DECAY,
    # Evaluation, logging and checkpointing share the same step interval.
    evaluation_strategy=EVAL_STRATEGY,
    logging_steps=LOGGING_STEPS,
    save_strategy=SAVE_STRATEGY,
    save_steps=LOGGING_STEPS,
    save_total_limit=2,
    load_best_model_at_end=LOAD_BEST_MODEL_AT_END,
    # Runtime
    dataloader_num_workers=NUM_WORKERS,
    seed=SEED,
    report_to='wandb',
)

# Train

In [None]:
# Free some space

# Drop any trainer left over from a previous run of this notebook; does nothing
# when no such global exists.
globals().pop('trainer', None)
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Setup trainer

# Wire the model, data, collator, metrics and callbacks into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=CALLBACKS,
)

In [None]:
# Run fine-tuning with the trainer built in the previous cell.
trainer.train()

In [None]:
# Run the trained model over the validation set; the result's `predictions` and
# `label_ids` attributes are read by the confusion-matrix cell.
pred = trainer.predict(val_dataset)

In [None]:
# Show confusion matrix

# True class ids, and predicted class ids taken as the arg-max over the last axis.
targets = pred.label_ids
preds = pred.predictions.argmax(-1)

def show_confusion_matrix(confusion_matrix):
    """Draw an annotated, integer-formatted heatmap of a confusion matrix.

    Args:
        confusion_matrix: Matrix-like data accepted by seaborn's heatmap; the
            caller in this notebook passes a DataFrame indexed and labelled by
            class name. Note that, inside this function, the parameter shadows
            the sklearn function of the same name.
    """
    ax = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap='Blues')
    # Keep row labels horizontal and tilt column labels so long names stay readable.
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right')
    plt.ylabel('True annotation type')
    plt.xlabel('Predicted annotation type')

# Pin the label set to every known class id, as compute_metrics does. Without it
# the matrix only covers ids that occur in targets or preds, so a class absent
# from both would shrink it and the DataFrame construction below would fail on
# the shape mismatch.
cm = confusion_matrix(targets, preds, labels=list(range(len(classes))))
# Rows are true classes, columns are predicted classes, both labelled by class name.
df_cm = pd.DataFrame(cm, index=classes, columns=classes)
show_confusion_matrix(df_cm)

In [None]:
# Finish the active wandb run now that training and evaluation are done.
wandb.finish()