# Notebook Summary
- Get BERTweet from huggingface
- Use BERTweet to
    - extract embeddings from Tweets
    - extract embeddings from concatenations of Tweet + OCR text

# 0. Imports and Constants
- Do not forget to select dataset version in the #CONSTANTS# part

In [79]:
############## AUTORELOAD MAGIC ###################
%load_ext autoreload
%autoreload 2
###################################################

############## FUNDAMENTAL MODULES ################
import json
import os
import sys
import copy
import numpy as np
import pickle
import re
##################################################

############## TASK-SPECIFIC MODULES #############
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath("__file__"))))
from data import TweetNormalizer, utils, feature_extraction
###################################################

############## DATA SCIENCE & ML MODULES ##########
import torch
import pandas as pd
from transformers import AutoModel, AutoTokenizer
from scipy import stats
###################################################

############# CONSTANT DICT KEYS ###################
# Constant dict keys
TRAIN = "train"
DEV = "dev"
TEST = "test"
GOLD = "gold"
TXT = "txt"
IMG = "img"
OCR = "ocr"
TXT_OCR = "txt_ocr"
SPLITS = [TRAIN, DEV, TEST, GOLD]
####################################################

####################### SELECT ###########################
users = ["patriziopalmisano", "onurdenizguler", "jockl"]
user = users[2] # SELECT USER
version = "v2" # SELECT DATASET VERSION
dataset_version = version
##########################################################

if user in users[:2]:
    cw_dir = f"/Users/{user}/Library/CloudStorage/GoogleDrive-check.worthiness@gmail.com/My Drive"
    data_dir = f"{cw_dir}/data/CT23_1A_checkworthy_multimodal_english"
    data_dir_with_version = f"{data_dir}_{dataset_version}"
    gold_dir = f"{cw_dir}/data/CT23_1A_checkworthy_multimodal_english_test_gold"

else:
    cw_dir = f"/home/jockl/Insync/check.worthiness@gmail.com/Google Drive"
    data_dir = f"{cw_dir}/data/CT23_1A_checkworthy_multimodal_english"
    data_dir_with_version = f"{data_dir}_{dataset_version}"
    gold_dir = f"{cw_dir}/data/CT23_1A_checkworthy_multimodal_english_test_gold"


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# 1. Load all Datasets
First, we extract all the raw texts from the JSON files. Note that the concatenation of tweet and OCR text is realized with a new line in-between.

In [80]:
# Load the datasets
raw_dataset, tweet_texts, imgs, tweet_ids, ocr_texts, tweet_concat_ocr = utils.load_data_splits_with_gold_dataset(data_dir, version)

Sizes of txt, img, ocr, txt+ocr arrays in train, test, dev, gold:
2356 2356 2356 2356
271 271 2356 2356
548 548 2356 2356
736 736 2356 2356


In [81]:
# Inspect Tweet and OCR concatenation
print(f"Tweet:\n{tweet_texts[TRAIN][6]}")
print(f"\nOCR:\n{ocr_texts[TRAIN][6]}")
print(f"\nConcat:\n{tweet_concat_ocr[TRAIN][6]}")

Tweet:
Despite calls for calm, some local people are panicking over the deadly coronavirus outbreak.

"I really am. It's really something that I'm really frightened about right now."

"It might be actually worse than what they're telling us."

https://t.co/2aeFCPdg4T https://t.co/Stjz0XlcwA

OCR:
CORONAVIRUS
MGN


Concat:
Despite calls for calm, some local people are panicking over the deadly coronavirus outbreak.

"I really am. It's really something that I'm really frightened about right now."

"It might be actually worse than what they're telling us."

https://t.co/2aeFCPdg4T https://t.co/Stjz0XlcwA
CORONAVIRUS
MGN



# 2. Normalize Texts

In [82]:
# Normalize all tweets using TweetNormalizer()
normalized_tweets = {split: [TweetNormalizer.normalizeTweet(tweet) for tweet in tweet_texts[split]] for split in SPLITS}
normalized_tweet_concat_ocr = {split: [TweetNormalizer.normalizeTweet(concat) for concat in tweet_concat_ocr[split]] for split in SPLITS}
print(len(normalized_tweets[TRAIN]))
print(len(normalized_tweet_concat_ocr[TRAIN]))

2356
2356


In [83]:
# Inspect normalization of tweets
print(f"Tweet:\n{tweet_texts[TRAIN][6]}")
print(f"\nNormalized Tweet:\n{normalized_tweets[TRAIN][6]}")

Tweet:
Despite calls for calm, some local people are panicking over the deadly coronavirus outbreak.

"I really am. It's really something that I'm really frightened about right now."

"It might be actually worse than what they're telling us."

https://t.co/2aeFCPdg4T https://t.co/Stjz0XlcwA

Normalized Tweet:
Despite calls for calm , some local people are panicking over the deadly coronavirus outbreak . " I really am . It 's really something that I 'm really frightened about right now . " " It might be actually worse than what they 're telling us . " HTTPURL HTTPURL


In [84]:
# Inspect normalization of tweet_ocr_concat
print(f"Concat:\n{tweet_texts[TRAIN][6]}")
print(f"Normalized Concat:\n{normalized_tweet_concat_ocr[TRAIN][6]}")

Concat:
Despite calls for calm, some local people are panicking over the deadly coronavirus outbreak.

"I really am. It's really something that I'm really frightened about right now."

"It might be actually worse than what they're telling us."

https://t.co/2aeFCPdg4T https://t.co/Stjz0XlcwA
Normalized Concat:
Despite calls for calm , some local people are panicking over the deadly coronavirus outbreak . " I really am . It 's really something that I 'm really frightened about right now . " " It might be actually worse than what they 're telling us . " HTTPURL HTTPURL CORONAVIRUS MGN


# 3. Set Up BERTweet and Embed Minimal Example

In [85]:
# Set up devicde
device = "cuda" if torch.cuda.is_available() else \
         ("mps" if torch.backends.mps.is_available() else "cpu")

In [86]:
# Get the model
bertweet = AutoModel.from_pretrained("vinai/bertweet-base").to(device)

Some weights of the model checkpoint at vinai/bertweet-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [87]:
# Set up the tokenizer
tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", use_fast=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [88]:
# Encode one example text
input_text = normalized_tweets[TRAIN][1]
input_ids = torch.tensor([tokenizer.encode(input_text)])
print(input_text)
print(input_ids.shape)

Chicago Cardinal : Global Warming , Migrants Are ‘ Bigger Agenda ' than Sex Abuse HTTPURL via @USER HTTPURL
torch.Size([1, 21])


In [89]:
# Embed the example
with torch.no_grad():
    features = bertweet(input_ids)
print(features.pooler_output.shape)

torch.Size([1, 768])


# 4. Resolve the Sequence Length Issue
Maximum number of tokens for BERTweet input: 128

In [90]:
# Set token limit and padding ID
token_limit = 128
padding_idx = 1

As we can see, the blunt attempt of tokenizing yields the warning that BERTweet only accepts input sequences of length < 128.

In [91]:
# First, tokenize everything
normalized_tweets_tokenized = {split: tokenizer(normalized_tweets[split], padding=True, return_tensors="pt")["input_ids"] for split in SPLITS}
normalized_concat_tokenized = {split: tokenizer(normalized_tweet_concat_ocr[split], padding=True, return_tensors="pt")["input_ids"] for split in SPLITS}

Token indices sequence length is longer than the specified maximum sequence length for this model (138 > 128). Running this sequence through the model will result in indexing errors


While the plain tweets never exceed the token limit, at least one sample from the tweet+OCR concatenations conatains 3586 tokens. This results in padding any sequence to this number.  

In [92]:
# Inspect tokenized samples
print(normalized_tweets_tokenized[TRAIN].shape)
print(normalized_concat_tokenized[TRAIN].shape)

torch.Size([2356, 116])
torch.Size([2356, 3586])


Which samples exactly feature excess tokens?

In [93]:
# Find all the examples with excess tokens
split_to_tweets_with_excess_tokens = utils.get_samples_with_excess_tokens(token_limit, normalized_tweets_tokenized, padding_idx)
split_to_concat_with_excess_tokens = utils.get_samples_with_excess_tokens(token_limit, normalized_concat_tokenized, padding_idx)
print(f"(ID, Length, Excess Tokens)")
print(f"Tweets: {len(split_to_tweets_with_excess_tokens[TRAIN])}\t{split_to_tweets_with_excess_tokens[TRAIN]}")
print(f"Concat: {len(split_to_concat_with_excess_tokens[TRAIN])}\t{split_to_concat_with_excess_tokens[TRAIN]}")

(ID, Length, Excess Tokens)
Tweets: 0	[]
Concat: 296	[(651, 129, 1), (1144, 129, 1), (1985, 129, 1), (526, 130, 2), (1410, 130, 2), (1454, 130, 2), (1738, 130, 2), (1338, 131, 3), (1978, 131, 3), (769, 132, 4), (1615, 132, 4), (67, 133, 5), (471, 133, 5), (1070, 133, 5), (1136, 133, 5), (1519, 133, 5), (1984, 133, 5), (1991, 133, 5), (330, 134, 6), (367, 134, 6), (1715, 134, 6), (2007, 134, 6), (120, 135, 7), (2045, 135, 7), (2205, 135, 7), (318, 136, 8), (672, 136, 8), (137, 137, 9), (1770, 137, 9), (279, 138, 10), (1053, 138, 10), (1429, 138, 10), (59, 139, 11), (762, 139, 11), (1902, 139, 11), (830, 140, 12), (2130, 140, 12), (1097, 141, 13), (1821, 141, 13), (785, 142, 14), (1700, 142, 14), (2290, 142, 14), (215, 143, 15), (1408, 143, 15), (1710, 144, 16), (2006, 144, 16), (631, 148, 20), (1425, 149, 21), (1956, 149, 21), (2138, 149, 21), (109, 150, 22), (1026, 150, 22), (1780, 150, 22), (371, 151, 23), (862, 151, 23), (2016, 151, 23), (2222, 151, 23), (657, 152, 24), (51, 153, 25)

Worst case: We truncate 296 examples (for the train set) from the tweet+OCR concatenations. Not a lot.

Let's now further normalize the samples that feature excess tokens. If this further normalizations still yields a sequence that is too long, the following function automatically truncates it.

In [94]:
# Further normalize/truncate those examples with excess tokens
final_normalized_tweets = utils.further_normalize_samples_with_excess_tokens(token_limit, normalized_tweets, split_to_tweets_with_excess_tokens, normalized_tweets_tokenized, tokenizer, padding_idx)
final_normalized_concat = utils.further_normalize_samples_with_excess_tokens(token_limit, normalized_tweet_concat_ocr, split_to_concat_with_excess_tokens, normalized_concat_tokenized, tokenizer, padding_idx)

In [95]:
# Tokenize and encode final texts
tweets_encoded = {split: tokenizer(final_normalized_tweets[split], padding=True, return_tensors="pt")["input_ids"] for split in SPLITS}
concat_encoded = {split: tokenizer(final_normalized_concat[split], padding=True, return_tensors="pt")["input_ids"] for split in SPLITS}

In [96]:
# Inspect tokenized split before/after normalization and truncation
print(f"Tweet: {normalized_tweets_tokenized[TRAIN].shape}\t{tweets_encoded[TRAIN].shape}")
print(f"Concat: {normalized_concat_tokenized[TRAIN].shape}\t{concat_encoded[TRAIN].shape}")

Tweet: torch.Size([2356, 116])	torch.Size([2356, 116])
Concat: torch.Size([2356, 3586])	torch.Size([2356, 128])


Perfect! Now, the concat sequences are padded only up to BERTweet's limit of 128.

# 5. Embed the Tokenized Samples

First, we set the batch size with which BERTweet should embed our encoded data splits.

In [97]:
# Set batch size here
batch_size = 8

## 5.1 Embed the Tweets Only

In [98]:
# Embed every split
for split in SPLITS:
    utils.embed_and_pickle_split_with_bertweet(bertweet, data_dir_with_version, split, tweets_encoded[split], with_ocr=False, batch_size=batch_size, device=device)

Split: train
Num samples: 2356
Num batches: 294
train batch 0/294
train batch 1/294
train batch 2/294
train batch 3/294
train batch 4/294
train batch 5/294
train batch 6/294
train batch 7/294
train batch 8/294
train batch 9/294
train batch 10/294
train batch 11/294
train batch 12/294
train batch 13/294
train batch 14/294
train batch 15/294
train batch 16/294
train batch 17/294
train batch 18/294
train batch 19/294
train batch 20/294
train batch 21/294
train batch 22/294
train batch 23/294
train batch 24/294
train batch 25/294
train batch 26/294
train batch 27/294
train batch 28/294
train batch 29/294
train batch 30/294
train batch 31/294
train batch 32/294
train batch 33/294
train batch 34/294
train batch 35/294
train batch 36/294
train batch 37/294
train batch 38/294
train batch 39/294
train batch 40/294
train batch 41/294
train batch 42/294
train batch 43/294
train batch 44/294
train batch 45/294
train batch 46/294
train batch 47/294
train batch 48/294
train batch 49/294
train batch 

In [99]:
# Embed every split
for split in SPLITS:
    utils.embed_and_pickle_split_with_bertweet(bertweet, data_dir_with_version, split, concat_encoded[split], with_ocr=True, batch_size=batch_size, device=device)

Split: train
Num samples: 2356
Num batches: 294
train batch 0/294
train batch 1/294
train batch 2/294
train batch 3/294
train batch 4/294
train batch 5/294
train batch 6/294
train batch 7/294
train batch 8/294
train batch 9/294
train batch 10/294
train batch 11/294
train batch 12/294
train batch 13/294
train batch 14/294
train batch 15/294
train batch 16/294
train batch 17/294
train batch 18/294
train batch 19/294
train batch 20/294
train batch 21/294
train batch 22/294
train batch 23/294
train batch 24/294
train batch 25/294
train batch 26/294
train batch 27/294
train batch 28/294
train batch 29/294
train batch 30/294
train batch 31/294
train batch 32/294
train batch 33/294
train batch 34/294
train batch 35/294
train batch 36/294
train batch 37/294
train batch 38/294
train batch 39/294
train batch 40/294
train batch 41/294
train batch 42/294
train batch 43/294
train batch 44/294
train batch 45/294
train batch 46/294
train batch 47/294
train batch 48/294
train batch 49/294
train batch 

# 6. Load the Pickled Embeddings
Example Usage:

In [100]:
# Load embedding tensor of one data split from pickle file
pickle_file = f"{data_dir_with_version}/BERTweet_embeddings_with_ocr_{GOLD}.pickle"
with open(pickle_file, 'rb') as handle:
    embeddings_tensor = pickle.load(handle)
    print(embeddings_tensor.shape)

torch.Size([736, 768])
