**Approach-<br>- Regular expression clean-up<br>- Spacy Lemmatization<br>- BERT Transformer**

In [1]:
!pip install pytorch-pretrained-bert pytorch-nlp

Collecting pytorch-pretrained-bert
  Downloading pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123 kB)
[K     |████████████████████████████████| 123 kB 1.3 MB/s 
[?25hCollecting pytorch-nlp
  Downloading pytorch_nlp-0.5.0-py3-none-any.whl (90 kB)
[K     |████████████████████████████████| 90 kB 3.0 MB/s 
Installing collected packages: pytorch-pretrained-bert, pytorch-nlp
Successfully installed pytorch-nlp-0.5.0 pytorch-pretrained-bert-0.6.2


In [2]:
import warnings
warnings.filterwarnings('ignore')

import regex as re
import string
import random
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tensorflow.keras.backend as K

import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler, RandomSampler
from pytorch_pretrained_bert import BertTokenizer, BertConfig
from pytorch_pretrained_bert import BertForSequenceClassification, BertAdam

import spacy

In [3]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
#set display option
pd.options.display.max_colwidth = 100

seed_val=42
tf.random.set_seed(seed_val)
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

**Data Reading and Understanding**

In [6]:
train_df = pd.read_csv('/kaggle/input/nlp-getting-started/train.csv')
train_df.head()

Unnamed: 0,id,keyword,location,text,target
0,1,,,Our Deeds are the Reason of this #earthquake May ALLAH Forgive us all,1
1,4,,,Forest fire near La Ronge Sask. Canada,1
2,5,,,All residents asked to 'shelter in place' are being notified by officers. No other evacuation or...,1
3,6,,,"13,000 people receive #wildfires evacuation orders in California",1
4,7,,,Just got sent this photo from Ruby #Alaska as smoke from #wildfires pours into a school,1


In [7]:
print (train_df.info())
print ("# of training records: ", len(train_df))

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7613 entries, 0 to 7612
Data columns (total 5 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   id        7613 non-null   int64 
 1   keyword   7552 non-null   object
 2   location  5080 non-null   object
 3   text      7613 non-null   object
 4   target    7613 non-null   int64 
dtypes: int64(2), object(3)
memory usage: 297.5+ KB
None
# of training records:  7613


In [8]:
#to see distribution of target labels
train_df['target'].value_counts(normalize=True)

0    0.57034
1    0.42966
Name: target, dtype: float64

In [9]:
#read test data
test_df = pd.read_csv('/kaggle/input/nlp-getting-started/test.csv')
test_df.head()

Unnamed: 0,id,keyword,location,text
0,0,,,Just happened a terrible car crash
1,2,,,"Heard about #earthquake is different cities, stay safe everyone."
2,3,,,"there is a forest fire at spot pond, geese are fleeing across the street, I cannot save them all"
3,9,,,Apocalypse lighting. #Spokane #wildfires
4,11,,,Typhoon Soudelor kills 28 in China and Taiwan


**Look at few of the tweets**

In [10]:
train_df.head()[['text','target']]

Unnamed: 0,text,target
0,Our Deeds are the Reason of this #earthquake May ALLAH Forgive us all,1
1,Forest fire near La Ronge Sask. Canada,1
2,All residents asked to 'shelter in place' are being notified by officers. No other evacuation or...,1
3,"13,000 people receive #wildfires evacuation orders in California",1
4,Just got sent this photo from Ruby #Alaska as smoke from #wildfires pours into a school,1


**Checking- if features keyword and location are useful and should be retained**

In [11]:
train_df[~train_df['keyword'].isnull()][['keyword', 'text']]

Unnamed: 0,keyword,text
31,ablaze,@bbcmtd Wholesale Markets ablaze http://t.co/lHYXEOHY6C
32,ablaze,We always try to bring the heavy. #metal #RT http://t.co/YAo1e0xngw
33,ablaze,#AFRICANBAZE: Breaking news:Nigeria flag set ablaze in Aba. http://t.co/2nndBGwyEi
34,ablaze,Crying out for more! Set me ablaze
35,ablaze,On plus side LOOK AT THE SKY LAST NIGHT IT WAS ABLAZE http://t.co/qqsmshaJ3N
...,...,...
7578,wrecked,@jt_ruff23 @cameronhacker and I wrecked you both
7579,wrecked,Three days off from work and they've pretty much all been wrecked hahaha shoutout to my family f...
7580,wrecked,#FX #forex #trading Cramer: Iger's 3 words that wrecked Disney's stock http://t.co/7enNulLKzM
7581,wrecked,@engineshed Great atmosphere at the British Lion gig tonight. Hearing is wrecked. http://t.co/oM...


In [12]:
len(train_df['location'].unique())

3342

**Dropping the features: keyword and location**

In [13]:
train_df.drop(columns=['keyword', 'location'], inplace=True)

**Data Cleaning and Preprocessing**

In [14]:
spacy_en = spacy.load('en_core_web_sm', disable=['parser','ner'])
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

100%|██████████| 231508/231508 [00:00<00:00, 848748.47B/s]


In [15]:
abbreviations = {
    "$" : " dollar ",
    "€" : " euro ",
    "4ao" : "for adults only",
    "a.m" : "before midday",
    "a3" : "anytime anywhere anyplace",
    "aamof" : "as a matter of fact",
    "acct" : "account",
    "adih" : "another day in hell",
    "afaic" : "as far as i am concerned",
    "afaict" : "as far as i can tell",
    "afaik" : "as far as i know",
    "afair" : "as far as i remember",
    "afk" : "away from keyboard",
    "app" : "application",
    "approx" : "approximately",
    "apps" : "applications",
    "asap" : "as soon as possible",
    "asl" : "age, sex, location",
    "atk" : "at the keyboard",
    "ave." : "avenue",
    "aymm" : "are you my mother",
    "ayor" : "at your own risk", 
    "b&b" : "bed and breakfast",
    "b+b" : "bed and breakfast",
    "b.c" : "before christ",
    "b2b" : "business to business",
    "b2c" : "business to customer",
    "b4" : "before",
    "b4n" : "bye for now",
    "b@u" : "back at you",
    "bae" : "before anyone else",
    "bak" : "back at keyboard",
    "bbbg" : "bye bye be good",
    "bbc" : "british broadcasting corporation",
    "bbias" : "be back in a second",
    "bbl" : "be back later",
    "bbs" : "be back soon",
    "be4" : "before",
    "bfn" : "bye for now",
    "blvd" : "boulevard",
    "bout" : "about",
    "brb" : "be right back",
    "bros" : "brothers",
    "brt" : "be right there",
    "bsaaw" : "big smile and a wink",
    "btw" : "by the way",
    "bwl" : "bursting with laughter",
    "c/o" : "care of",
    "cet" : "central european time",
    "cf" : "compare",
    "cia" : "central intelligence agency",
    "csl" : "can not stop laughing",
    "cu" : "see you",
    "cul8r" : "see you later",
    "cv" : "curriculum vitae",
    "cwot" : "complete waste of time",
    "cya" : "see you",
    "cyt" : "see you tomorrow",
    "dae" : "does anyone else",
    "dbmib" : "do not bother me i am busy",
    "diy" : "do it yourself",
    "dm" : "direct message",
    "dwh" : "during work hours",
    "e123" : "easy as one two three",
    "eet" : "eastern european time",
    "eg" : "example",
    "embm" : "early morning business meeting",
    "encl" : "enclosed",
    "encl." : "enclosed",
    "etc" : "and so on",
    "faq" : "frequently asked questions",
    "fawc" : "for anyone who cares",
    "fb" : "facebook",
    "fc" : "fingers crossed",
    "fig" : "figure",
    "fimh" : "forever in my heart", 
    "ft." : "feet",
    "ft" : "featuring",
    "ftl" : "for the loss",
    "ftw" : "for the win",
    "fwiw" : "for what it is worth",
    "fyi" : "for your information",
    "g9" : "genius",
    "gahoy" : "get a hold of yourself",
    "gal" : "get a life",
    "gcse" : "general certificate of secondary education",
    "gfn" : "gone for now",
    "gg" : "good game",
    "gl" : "good luck",
    "glhf" : "good luck have fun",
    "gmt" : "greenwich mean time",
    "gmta" : "great minds think alike",
    "gn" : "good night",
    "g.o.a.t" : "greatest of all time",
    "goat" : "greatest of all time",
    "goi" : "get over it",
    "gps" : "global positioning system",
    "gr8" : "great",
    "gratz" : "congratulations",
    "gyal" : "girl",
    "h&c" : "hot and cold",
    "hp" : "horsepower",
    "hr" : "hour",
    "hrh" : "his royal highness",
    "ht" : "height",
    "ibrb" : "i will be right back",
    "ic" : "i see",
    "icq" : "i seek you",
    "icymi" : "in case you missed it",
    "idc" : "i do not care",
    "idgadf" : "i do not give a damn fuck",
    "idgaf" : "i do not give a fuck",
    "idk" : "i do not know",
    "ie" : "that is",
    "i.e" : "that is",
    "ifyp" : "i feel your pain",
    "IG" : "instagram",
    "iirc" : "if i remember correctly",
    "ilu" : "i love you",
    "ily" : "i love you",
    "imho" : "in my humble opinion",
    "imo" : "in my opinion",
    "imu" : "i miss you",
    "iow" : "in other words",
    "irl" : "in real life",
    "j4f" : "just for fun",
    "jic" : "just in case",
    "jk" : "just kidding",
    "jsyk" : "just so you know",
    "l8r" : "later",
    "lb" : "pound",
    "lbs" : "pounds",
    "ldr" : "long distance relationship",
    "lmao" : "laugh my ass off",
    "lmfao" : "laugh my fucking ass off",
    "lol" : "laughing out loud",
    "ltd" : "limited",
    "ltns" : "long time no see",
    "m8" : "mate",
    "mf" : "motherfucker",
    "mfs" : "motherfuckers",
    "mfw" : "my face when",
    "mofo" : "motherfucker",
    "mph" : "miles per hour",
    "mr" : "mister",
    "mrw" : "my reaction when",
    "ms" : "miss",
    "mte" : "my thoughts exactly",
    "nagi" : "not a good idea",
    "nbc" : "national broadcasting company",
    "nbd" : "not big deal",
    "nfs" : "not for sale",
    "ngl" : "not going to lie",
    "nhs" : "national health service",
    "nrn" : "no reply necessary",
    "nsfl" : "not safe for life",
    "nsfw" : "not safe for work",
    "nth" : "nice to have",
    "nvr" : "never",
    "nyc" : "new york city",
    "oc" : "original content",
    "og" : "original",
    "ohp" : "overhead projector",
    "oic" : "oh i see",
    "omdb" : "over my dead body",
    "omg" : "oh my god",
    "omw" : "on my way",
    "p.a" : "per annum",
    "p.m" : "after midday",
    "pm" : "prime minister",
    "poc" : "people of color",
    "pov" : "point of view",
    "pp" : "pages",
    "ppl" : "people",
    "prw" : "parents are watching",
    "ps" : "postscript",
    "pt" : "point",
    "ptb" : "please text back",
    "pto" : "please turn over",
    "qpsa" : "what happens", #"que pasa",
    "ratchet" : "rude",
    "rbtl" : "read between the lines",
    "rlrt" : "real life retweet", 
    "rofl" : "rolling on the floor laughing",
    "roflol" : "rolling on the floor laughing out loud",
    "rotflmao" : "rolling on the floor laughing my ass off",
    "rt" : "retweet",
    "ruok" : "are you ok",
    "sfw" : "safe for work",
    "sk8" : "skate",
    "smh" : "shake my head",
    "sq" : "square",
    "srsly" : "seriously", 
    "ssdd" : "same stuff different day",
    "tbh" : "to be honest",
    "tbs" : "tablespooful",
    "tbsp" : "tablespooful",
    "tfw" : "that feeling when",
    "thks" : "thank you",
    "tho" : "though",
    "thx" : "thank you",
    "tia" : "thanks in advance",
    "til" : "today i learned",
    "tl;dr" : "too long i did not read",
    "tldr" : "too long i did not read",
    "tmb" : "tweet me back",
    "tntl" : "trying not to laugh",
    "ttyl" : "talk to you later",
    "u" : "you",
    "u2" : "you too",
    "u4e" : "yours for ever",
    "utc" : "coordinated universal time",
    "w/" : "with",
    "w/o" : "without",
    "w8" : "wait",
    "wassup" : "what is up",
    "wb" : "welcome back",
    "wtf" : "what the fuck",
    "wtg" : "way to go",
    "wtpa" : "where the party at",
    "wuf" : "where are you from",
    "wuzup" : "what is up",
    "wywh" : "wish you were here",
    "yd" : "yard",
    "ygtr" : "you got that right",
    "ynk" : "you never know",
    "zzz" : "sleeping bored and tired"
}

In [16]:
special_characters = {
  "SuruÌ¤":"Suruc",
  "JapÌ_n":"Japan"  ,
  "\x89ÛÏWhen":"When",
  "å£3million":"3 million",
  "fromåÊwounds":"from wounds",
  "mÌ¼sica":"music",
  "donå«t":"do not",
  "didn`t":"did not",
  "i\x89Ûªm":"I am",
  "I\x89Ûªm":"I am",
  "it\x89Ûªs":"it is",
  "It\x89Ûªs":"It is",
  "i\x89Ûªd":"I would",
  "I\x89Ûªd":"I would",
  "i\x89Ûªve":"I have",
  "I\x89Ûªve":"I have",
  "let\x89Ûªs":"let us",
  "don\x89Ûªt":"do not",
  "Don\x89Ûªt":"Do not",
  "can\x89Ûªt":"cannot",
  "Can\x89Ûªt":"Cannot",
  "that\x89Ûªs":"that is",
  "That\x89Ûªs":"That is",
  "here\x89Ûªs":"here is",
  "Here\x89Ûªs":"Here is",
  "you\x89Ûªre":"you are",
  "You\x89Ûªre":"You are",
  "you\x89Ûªve":"you have",
  "You\x89Ûªve":"You have",
  "you\x89Ûªll":"you will",
  "You\x89Ûªll":"You will",
  "China\x89Ûªs":"China's",
  "doesn\x89Ûªt":"does not",
  "wouldn\x89Ûªt":"would not",
  "\x89Û_":"",
  "\x89Û¢":"",
  "\x89ÛÒ":"",
  "\x89ÛÓ":"",
  "\x89ÛÏ":"",
  "\x89Û÷":"",
  "\x89Ûª":"",
  "\x89Û¢åÊ":"",
  "\x89Û\x9d":"",
  "å_":"",
  "å¨":"",
  "åÀ":"",
  "åÇ":"",
  "åÊ":"",
  "åÈ":""  ,
  "Ì©":"",
  "&lt;":"<",
  "&gt;":">",
  "&amp;":"&"    
}

In [17]:
expand_contractions = {
  "I'm":"I am",
  "I'M":"I am",
  "i'm":"I am",
  "i'M":"I am",
  "i'd":"I would",
  "I'd":"I would",
  "i'll":"I will",
  "I'll":"I will",
  "i've":"I have",
  "I've":"I have",
  "you're":"you are",
  "You're":"You are",
  "you'd":"you would",
  "You'd":"You would",
  "you've":"you have",
  "You've":"You have",
  "you'll":"you will",
  "You'll":"You will"  ,
  "y'know":"you know"  ,
  "Y'know":"You know"  ,
  "y'all":"you all",
  "Y'all":"You all",
  "we're":"we are",
  "We're":"We are",
  "we've":"we have",
  "We've":"We have" ,
  "we'd":"we would",
  "We'd":"We would",
  "WE'VE":"We have",
  "we'll":"we will",
  "We'll":"We will",
  "they're":"they are",
  "They're":"They are",
  "they'd":"they would",
  "They'd":"They would"  ,
  "they've":"they have",
  "They've":"They have",
  "they'll":"they will",
  "They'll":"They will",
  "he's":"he is",
  "He's":"He is",
  "he'll":"he will",
  "He'll":"He will",
  "she's":"she is",
  "She's":"She is",
  "she'll":"she will",
  "She'll":"She will",
  "it's":"it is",
  "It's":"It is",
  "it'll":"it will",
  "It'll":"It will",
  "isn't":"is not",
  "Isn't":"Is not",
  "who's":"who is",
  "Who's":"Who is",
  "what's":"what is",
  "What's":"What is",
  "that's":"that is",
  "That's":"That is",
  "here's":"here is",
  "Here's":"Here is",
  "there's":"there is",
  "There's":"There is",
  "where's":"where is",
  "Where's":"Where is"  ,
  "wHeRE's":"where is" ,
  "how's":"how is"  ,
  "How's":"How is"  ,
  "how're":"how are"  ,
  "How're":"How are" ,
  "let's":"let us",
  "Let's":"Let us",
  "won't":"will not",
  "wasn't":"was not",
  "aren't":"are not",
  "couldn't":"could not",
  "shouldn't":"should not",
  "haven't":"have not",
  "Haven't":"Have not",
  "hasn't":"has not",
  "wouldn't":"would not",
  "weren't":"were not",
  "Weren't":"Were not",
  "ain't":"am not",
  "Ain't":"am not",
  "don't":"do not",
  "Don't":"do not",
  "DON'T":"Do not",
  "didn't":"did not",
  "Didn't":"Did not",
  "DIDN'T":"Did not",
  "doesn't":"does not",
  "can't":"cannot",
  "Can't":"Cannot",
  "Could've":"Could have",
  "should've":"should have",
  "would've":"would have"
}

In [18]:
informal_abbreviations = {
  "b/c":"because",
  "w/e":"whatever",
  "w/out":"without",
  "w/o":"without",
  "w/":"with ",
  "<3":"love",
  "c/o":"care of",
  "p/u":"pick up",
  "\n":" "
}

In [19]:
smileys = {
  "\:33333" : "smile",    # :33333
  "\:\)\)\)\)" : "smile", # :))))
  "\:\)\)\)" : "smile", # :)))
  "\:\)\)" : "smile",   # :))
  "\:-\)" : "smile",   # :-)
  "\;-\)" : "smile",   # ;-)
  "3\-D" : "smile",  # 3-D
  "\:O" : "smile",   # :O
  "\:D" : "smile",   # :D
  "\:P" : "smile",   # :P
  "\:p" : "smile",   # :p
  "\;\)" : "smile",  # ;)
  "\:\)" : "smile",  # :)
  "\=\)" : "smile",  # =)
  "\^\^" : "smile",  # ^^
  "\:-\(" : "sad",   # :-(
  "\:\(" : "sad",    # :(
  "\=\(" : "sad",    # =(
  "\-\_\_\-" : "",   # -__-
  "\.\_\." : "",     # ._.
  "T\_T" : "",       # T_T    
}

In [20]:
def clean_text(text):  
  cleaned_text = text.lower()
  
  #substitute web address with string "URL"
  cleaned_text = re.sub(r'https?:\S+|www\.\S+', '', cleaned_text)

  #remove html tags
  cleaned_text = re.sub(r'<.*?>', '', cleaned_text)

  #remove non-ascii characters
  cleaned_text = ''.join(ch for ch in cleaned_text if ch in string.printable)

  #replace abbreviation
  cleaned_text = ' '.join(abbreviations[word] if word in abbreviations else word for word in cleaned_text.split())

  #substitue @mention with string "USER"
  cleaned_text = re.sub(r'@\S+', 'USER', cleaned_text)

  #substitute numerics with number "NUMBER"
  cleaned_text = re.sub(r'\d+?[.,\d]*', 'NUMBER', cleaned_text)

  #remove emoji's
  emoji_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F" 
                           u"\U0001F300-\U0001F5FF" 
                           u"\U0001F680-\U0001F6FF" 
                           u"\U0001F1E0-\U0001F1FF" 
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
  cleaned_text = emoji_pattern.sub(r'', cleaned_text)

  #remove extra white spaces
  cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()

  cleaned_text = word_tokenize(cleaned_text)
  cleaned_text = [abbreviations[word] if word in abbreviations else word for word in cleaned_text]
  cleaned_text = [special_characters[word] if word in special_characters else word for word in cleaned_text]
  cleaned_text = [expand_contractions[word] if word in expand_contractions else word for word in cleaned_text]
  cleaned_text = [informal_abbreviations[word] if word in informal_abbreviations else word for word in cleaned_text]
  cleaned_text = [smileys[word] if word in smileys else word for word in cleaned_text]
  cleaned_text = ' '.join([word for word in cleaned_text])
    
  doc = spacy_en(cleaned_text)
  cleaned_text = ' '.join([token.lemma_ for token in doc if token.lemma_ not in ['-PRON-'] and token.pos_ not in ['DET','-PRON']])  
  return cleaned_text

In [21]:
#Bert needs special token at the front and end
def add_special_token(text):
  cleaned_text = "[CLS] " + text + " [SEP]"   
  return cleaned_text

In [22]:
train_df['cleaned_text'] = np.vectorize(clean_text)(train_df['text'])
train_df['cleaned_text'] = np.vectorize(add_special_token)(train_df['cleaned_text'])

In [23]:
#constants
MAX_LEN = 128
BATCH_SZ=32

In [24]:
def generate_input_attention_mask(tweets):   
  tokenized_tweets = [bert_tokenizer.tokenize(tweet) for tweet in tweets]
  input_ids = [bert_tokenizer.convert_tokens_to_ids(x) for x in tokenized_tweets]
  input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

  attention_masks = []  
  for seq in input_ids:
    seq_mask = [float(i>0) for i in seq]
    attention_masks.append(seq_mask)

  return input_ids, attention_masks

In [25]:
train_input_ids, train_attention_masks = generate_input_attention_mask(train_df['cleaned_text'])

**Modelling**

In [26]:
#Custom metric f1-score to be used for monitoring
def recall(y_true, y_pred):    
    true_positives = K.sum(K.round(y_true * y_pred))
    possible_positives = K.sum(y_true)
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision(y_true, y_pred):
    true_positives = K.sum(K.round(y_true * y_pred))
    predicted_positives = K.sum(K.round(y_pred))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_score(y_true, y_pred):
    prec = precision(y_true, y_pred)
    rec = recall(y_true, y_pred)
    return 2*((prec*rec)/(prec+rec+K.epsilon()))

**Train - 80% Val - 20% of entire training data available**

In [27]:
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(train_input_ids, train_df['target'], train_size=0.8, random_state=100)
train_masks, validation_masks, _, _ = train_test_split(train_attention_masks, train_df['target'], train_size=0.8, random_state=100)

train_inputs = torch.tensor(train_inputs, dtype=torch.long)
validation_inputs = torch.tensor(validation_inputs, dtype=torch.long)
train_labels = torch.tensor(train_labels, dtype=torch.long)
validation_labels = torch.tensor(validation_labels.values, dtype=torch.long)
train_masks = torch.tensor(train_masks, dtype=torch.long)
validation_masks = torch.tensor(validation_masks, dtype=torch.long)

In [28]:
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SZ)

validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=BATCH_SZ)

In [29]:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.cuda()

100%|██████████| 407873900/407873900 [00:11<00:00, 34326509.25B/s]


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
   

In [30]:
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.0}
]

optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=2e-5,
                     warmup=.1)

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

**Training**

In [31]:
train_loss_set = []
epochs = 4

for _ in range(epochs):  
  model.train()  
  tr_loss = 0
  nb_tr_examples, nb_tr_steps = 0, 0

  for step, batch in enumerate(train_dataloader):
    # Add batch to GPU
    batch = tuple(t.to(device) for t in batch)
    # Unpack the inputs from our dataloader
    b_input_ids, b_input_mask, b_labels = batch
    # Clear out the gradients (by default they accumulate)
    optimizer.zero_grad()
    # Forward pass
    loss = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
    train_loss_set.append(loss.item())    
    # Backward pass
    loss.backward()
    # Update parameters and take a step using the computed gradient
    optimizer.step()
    # Update tracking variables
    tr_loss += loss.item()
    nb_tr_examples += b_input_ids.size(0)
    nb_tr_steps += 1
  print("Train loss: {}".format(tr_loss/nb_tr_steps))

Train loss: 0.46157938663247994
Train loss: 0.2970094054703313
Train loss: 0.1845028461664135
Train loss: 0.11477488360041022


**Validation**

In [32]:
model.eval()

eval_accuracy = 0
nb_eval_steps = 0

for batch in validation_dataloader:
   # Add batch to GPU
   batch = tuple(t.to(device) for t in batch)
   # Unpack the inputs from our dataloader
   b_input_ids, b_input_mask, b_labels = batch
   # Telling the model not to compute or store gradients, saving memory and speeding up validation
   with torch.no_grad():
    # Forward pass, calculate logit predictions
    logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)    
    # Move logits and labels to CPU
    logits = logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()
    tmp_eval_accuracy = flat_accuracy(logits, label_ids)    
    eval_accuracy += tmp_eval_accuracy
    nb_eval_steps += 1
print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))

Validation Accuracy: 0.8159265350877193


**Prediction on test set to be evaluated**

In [33]:
test_df.drop(columns=['keyword', 'location'], inplace=True)
test_df['cleaned_text'] = np.vectorize(clean_text)(test_df['text'])
test_df['cleaned_text'] = np.vectorize(add_special_token)(test_df['cleaned_text'])
test_input_ids, test_attention_masks = generate_input_attention_mask(test_df['cleaned_text'])

In [34]:
test_inputs = torch.tensor(test_input_ids, dtype=torch.long)
test_attention = torch.tensor(test_attention_masks, dtype=torch.long)

test_data = TensorDataset(test_inputs, test_attention)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size = BATCH_SZ)

In [35]:
model.eval()

predictions = []
for batch in test_dataloader:
  batch = tuple(t.to(device) for t in batch)
  b_input_ids, b_mask = batch

  with torch.no_grad():
    logits = model(b_input_ids, token_type_ids=None, attention_mask = b_mask)

  logits = logits.detach().cpu().numpy()
  predictions.append(logits)

test_predictions = [item for sublist in predictions for item in sublist]
test_predictions = np.argmax(test_predictions, axis=1).flatten()

In [36]:
submission = pd.DataFrame({'id':test_df['id'], 'target':test_predictions})
submission.to_csv('./submission.csv', index=False)