In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random

import string
import re

import scipy
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

from sklearn import metrics
from sklearn.metrics import confusion_matrix,accuracy_score,roc_auc_score,roc_curve,auc,f1_score
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data

import torchsummary

import time
import os
import shutil

import transformers
from transformers import BertTokenizer, BertModel

from tqdm.notebook import tqdm
tqdm.pandas()

device="cuda"

# Load Datasets
1. Load the datasets
2. Verify no article bodies overlap

In [2]:
train_stances = pd.read_csv("./dataset/train_stances.csv")
train_bodies = pd.read_csv("./dataset/train_bodies.csv")

test_stances = pd.read_csv("./dataset/competition_test_stances.csv")
test_bodies = pd.read_csv("./dataset/competition_test_bodies.csv")

In [3]:
train_stances["Headline"].value_counts()

ISIL Beheads American Photojournalist in Iraq                                                      127
WHO says reports of suspected Ebola cases in Iraq are untrue                                       124
James Foley remembered as 'brave and tireless' journalist                                          121
Islamic Militants Post Video Claiming to Show Beheading of U.S. Journalist                         118
US officials: Video shows American's execution                                                     112
                                                                                                  ... 
Apple hopes to sell over 50 million watches in 2015                                                  2
Mom Calls 911 On Masturbating Teenage Son; Boy Arrested, Charged With New ‘Self-Rape’ State Law      2
Federal Judge: Enough With the Stupid Names                                                          2
Sushi lover's entire body left riddled with WORMS after eating contaminat

In [4]:
test_stances["Headline"].value_counts()

Source: Joan Rivers' doc did biopsy, selfie                                                                                                                  160
Joan Rivers Personal Doctor Allegedly Took A Selfie Before Her Biopsy And Cardiac Arrest                                                                     138
‘Three-boobed’ woman: They’re not fake                                                                                                                        92
Adopting Potential Werewolves Is Routine Business for Argentine Presidents                                                                                    92
Justin Bieber Basically Saves A Russian Man From A Bear                                                                                                       88
                                                                                                                                                            ... 
Madonna pledges oral sex for Clint

In [5]:
train_stances.loc[train_stances["Headline"].isin(test_stances["Headline"])]

Unnamed: 0,Headline,Body ID,Stance
173,Cheese addiction breaks Kim Jong-un's ankles,2210,unrelated
1415,WSJ: Apple cut watch health features due to er...,1917,discuss
1418,Cheese blamed for North Korean leader Kim Jong...,1689,unrelated
1479,Cheese blamed for North Korean leader Kim Jong...,186,unrelated
1503,Cheese addiction breaks Kim Jong-un's ankles,2329,unrelated
...,...,...,...
48270,Apple was forced to nix key health features fr...,407,discuss
48852,Cheese blamed for North Korean leader Kim Jong...,2042,unrelated
49279,Cheese addiction breaks Kim Jong-un's ankles,1854,discuss
49349,Cheese blamed for North Korean leader Kim Jong...,2344,unrelated


In [6]:
train_stances.loc[train_stances["Body ID"].isin(test_stances["Body ID"])]

Unnamed: 0,Headline,Body ID,Stance


So there is overlap of the headlines in both the training and the test data but there is no overlap with the bodies of the articles so we should split the training set into training and validation such that the bodies of the articles are disjoint.

In [7]:
test_df = test_stances.merge(test_bodies, on="Body ID")
test_df["Related"] = (test_df["Stance"] != "unrelated").astype(int)

In [8]:
test_df

Unnamed: 0,Headline,Body ID,Stance,articleBody,Related
0,Ferguson riots: Pregnant woman loses eye after...,2008,unrelated,A RESPECTED senior French police officer inves...,0
1,Apple Stores to install safes to secure gold A...,2008,unrelated,A RESPECTED senior French police officer inves...,0
2,Pregnant woman loses eye after police shoot be...,2008,unrelated,A RESPECTED senior French police officer inves...,0
3,We just found out the #Ferguson Protester who ...,2008,unrelated,A RESPECTED senior French police officer inves...,0
4,Police Chief In Charge of Paris Attacks Commit...,2008,discuss,A RESPECTED senior French police officer inves...,1
...,...,...,...,...,...
25408,A Sign That Obamacare Exchanges Are Failing,2586,disagree,Remember how much Republicans wanted to repeal...,1
25409,Republicans call Obamacare a 'failure.' These ...,2586,agree,Remember how much Republicans wanted to repeal...,1
25410,CBO’s Alternate Facts Show Obamacare is Unsust...,2586,disagree,Remember how much Republicans wanted to repeal...,1
25411,Why Obamacare failed,2586,disagree,Remember how much Republicans wanted to repeal...,1


In [9]:
train_and_val_df = train_stances.merge(train_bodies, on="Body ID")
train_and_val_df["Related"] = (train_and_val_df["Stance"] != "unrelated").astype(int)

# Exploratory Data Analysis
TODO: Data analysis

In [10]:
train_and_val_df

Unnamed: 0,Headline,Body ID,Stance,articleBody,Related
0,Police find mass graves with at least '15 bodi...,712,unrelated,Danny Boyle is directing the untitled film\r\n...,0
1,Seth Rogen to Play Apple’s Steve Wozniak,712,discuss,Danny Boyle is directing the untitled film\r\n...,1
2,Mexico police find mass grave near site 43 stu...,712,unrelated,Danny Boyle is directing the untitled film\r\n...,0
3,Mexico Says Missing Students Not Found In Firs...,712,unrelated,Danny Boyle is directing the untitled film\r\n...,0
4,New iOS 8 bug can delete all of your iCloud do...,712,unrelated,Danny Boyle is directing the untitled film\r\n...,0
...,...,...,...,...,...
49967,Amazon Is Opening a Brick-and-Mortar Store in ...,464,agree,"Amazon, the cyber store that sells everything,...",1
49968,Elon University has not banned the term ‘fresh...,362,agree,"ELON, N.C. – A recent rumor claims that Elon U...",1
49969,Fake BBC News website set up to carry Charlie ...,915,agree,A realistic-looking fake BBC News website has ...,1
49970,Apple was forced to nix key health features fr...,407,discuss,The health-focused smartwatch that Apple initi...,1


# Data Splitting
Now we need to split the training data into a training and validation set. We leave the test set untouched in this respect.

In [11]:
val_split_ratio = 0.2

In [12]:
def split_train_val(df, ratio):
    val_count = int(ratio * df["Body ID"].nunique())
    all_ids = list(df["Body ID"].unique())
    val_body_ids = random.sample(all_ids, val_count)
    train_body_ids = set(all_ids) - set(val_body_ids)
    
    assert len(set(val_body_ids) & train_body_ids) == 0
    
    val_df = df.loc[df["Body ID"].isin(val_body_ids)]
    train_df = df.loc[df["Body ID"].isin(train_body_ids)]
    
    return val_df, train_df

In [13]:
val_df, train_df = split_train_val(train_and_val_df, val_split_ratio)

In [14]:
val_df["Body ID"].nunique()

336

In [15]:
test_df["Body ID"].nunique()

904

In [16]:
train_df["Body ID"].nunique()

1347

In [17]:
def prepare_df(df):
    df = df.drop("Body ID", axis=1)
    df = df.reset_index()
    df = df.drop("index", axis=1)
    df["Related"] = df["Stance"] != "unrelated"
    return df

In [18]:
val_df = prepare_df(val_df)
train_df = prepare_df(train_df)
test_df = prepare_df(test_df)

In [19]:
val_df["Stance"].value_counts(normalize=True)#.plot(kind="bar")

unrelated    0.743208
discuss      0.163193
agree        0.077224
disagree     0.016375
Name: Stance, dtype: float64

In [20]:
train_df["Stance"].value_counts(normalize=True)#.plot(kind="bar")

unrelated    0.728049
discuss      0.182414
agree        0.072609
disagree     0.016928
Name: Stance, dtype: float64

In [21]:
test_df["Stance"].value_counts(normalize=True)#.plot(kind="bar")

unrelated    0.722032
discuss      0.175658
agree        0.074883
disagree     0.027427
Name: Stance, dtype: float64

# Clean the Data
Now we have the data, need to clean the data and extract the TF-IDF features. Basic things to consider doing:
* Remove punctuation
* Remove URLs
* Remove HTML
* Remove numbers
* Remove emojis
* Convert to lowercase

And we should also:
* Tokenise
* Remove stopwords
* Lemmatisation or Stemming

In [22]:
# Most of this from the first practical
additional_specials = ["—", "”", "“", "’", "‘"]

def remove_excess_whitespace(text):
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    text = text.replace("\r", " ")
    text = text.strip()
    return text

def remove_punctuation(text):
    punc = str.maketrans('', '', string.punctuation)
    text = text.translate(punc)
    
    for special in additional_specials:
        text = text.replace(special, "")
    
    return text

def remove_urls(text):
    url = re.compile(r'https?://\S+|www\.\S+')
    return url.sub('', text)

def remove_html(text):
    html = re.compile(r'<.*?>')
    return html.sub('', text)

def remove_numbers(text):
    numbers = re.compile(r'\d+')
    return numbers.sub('', text)

def remove_emojis(text):
    emoji_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F"  # emoticons
                           u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                           u"\U0001F680-\U0001F6FF"  # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
    
    return emoji_pattern.sub(r'', text)

def apply_cleaning(text, excess=True, punc=True, urls=True, html=True, numbers=True, emojis=True, lower=True):
    if excess:
        text = " ".join(text.split())
        
    if punc:
        text = remove_punctuation(text)
    
    if urls:
        text = remove_urls(text)
    
    if html:
        text = remove_html(text)
    
    if numbers:
        text = remove_numbers(text)
        
    if emojis:
        text = remove_emojis(text)
        
    if lower:
        text = text.lower()
    
    return text

In [23]:
config_remove_excess_whitespace = True
config_remove_punctuation = False
config_remove_urls = True
config_remove_html = True
config_remove_numbers = False
config_remove_emojis = True
config_convert_to_lowercase = False

In [24]:
def process_text(text):
    text = apply_cleaning(
        text, 
        excess=config_remove_excess_whitespace, 
        punc=config_remove_punctuation, 
        urls=config_remove_urls, 
        html=config_remove_html, 
        numbers=config_remove_numbers, 
        emojis=config_remove_emojis, 
        lower=config_convert_to_lowercase
    )
    
    return text

In [25]:
test_text = train_df.iloc[812]["articleBody"]
print("Unprocessed:")
print(test_text)
print()
test_processed = process_text(test_text)
print("Processed:")
print(test_processed)

Unprocessed:
(CNN) -- Boko Haram laughed off Nigeria's announcement of a ceasefire agreement, saying there is no such deal and schoolgirls abducted in spring have been converted to Islam and married off.

Nigerian officials announced two weeks ago that they had struck a deal with the Islamist terror group.

The deal, the government said, included the release of more than 200 girls whose kidnapping in April at a boarding school in the nation's north stunned the world.

In a video released Saturday, the Islamist group's notorious leader fired off a series of denials.

"Don't you know the over 200 Chibok schoolgirls have converted to Islam?" Abubakar Shekau said. "They have now memorized two chapters of the Quran."

Shekau slammed reports of their planned release.

"We married them off. They are in their marital homes," he said, chuckling.

The group's leader also denied knowing the negotiator with whom the government claimed it worked out a deal, saying he does not represent Boko Haram.


In [26]:
train_df

Unnamed: 0,Headline,Stance,articleBody,Related
0,HBO and Apple in Talks for $15/Month Apple TV ...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False
1,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False
2,Elderly Woman Arrested for Kidnapping Neighbor...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False
3,Kim Jong-Un 'bans the name Kim Jong-un',unrelated,(Reuters) - A Canadian soldier was shot at the...,False
4,Two blokes dared to eat 20-year-old burger for...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False
...,...,...,...,...
39219,“Willie Nelson dead 2015” : Guitarist killed b...,agree,News of guitarist Willie Nelson’s death spread...,True
39220,Elon University has not banned the term ‘fresh...,agree,"ELON, N.C. – A recent rumor claims that Elon U...",True
39221,Fake BBC News website set up to carry Charlie ...,agree,A realistic-looking fake BBC News website has ...,True
39222,Apple was forced to nix key health features fr...,discuss,The health-focused smartwatch that Apple initi...,True


In [27]:
train_df["Processed Headline"] = train_df["Headline"].progress_apply(process_text)
train_df["Processed Body"] = train_df["articleBody"].progress_apply(process_text)

  0%|          | 0/39224 [00:00<?, ?it/s]

  0%|          | 0/39224 [00:00<?, ?it/s]

In [28]:
train_df

Unnamed: 0,Headline,Stance,articleBody,Related,Processed Headline,Processed Body
0,HBO and Apple in Talks for $15/Month Apple TV ...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,HBO and Apple in Talks for $15/Month Apple TV ...,(Reuters) - A Canadian soldier was shot at the...
1,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,(Reuters) - A Canadian soldier was shot at the...
2,Elderly Woman Arrested for Kidnapping Neighbor...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Elderly Woman Arrested for Kidnapping Neighbor...,(Reuters) - A Canadian soldier was shot at the...
3,Kim Jong-Un 'bans the name Kim Jong-un',unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Kim Jong-Un 'bans the name Kim Jong-un',(Reuters) - A Canadian soldier was shot at the...
4,Two blokes dared to eat 20-year-old burger for...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Two blokes dared to eat 20-year-old burger for...,(Reuters) - A Canadian soldier was shot at the...
...,...,...,...,...,...,...
39219,“Willie Nelson dead 2015” : Guitarist killed b...,agree,News of guitarist Willie Nelson’s death spread...,True,“Willie Nelson dead 2015” : Guitarist killed b...,News of guitarist Willie Nelson’s death spread...
39220,Elon University has not banned the term ‘fresh...,agree,"ELON, N.C. – A recent rumor claims that Elon U...",True,Elon University has not banned the term ‘fresh...,"ELON, N.C. – A recent rumor claims that Elon U..."
39221,Fake BBC News website set up to carry Charlie ...,agree,A realistic-looking fake BBC News website has ...,True,Fake BBC News website set up to carry Charlie ...,A realistic-looking fake BBC News website has ...
39222,Apple was forced to nix key health features fr...,discuss,The health-focused smartwatch that Apple initi...,True,Apple was forced to nix key health features fr...,The health-focused smartwatch that Apple initi...


In [29]:
val_df["Processed Headline"] = val_df["Headline"].progress_apply(process_text)
val_df["Processed Body"] = val_df["articleBody"].progress_apply(process_text)

  0%|          | 0/10748 [00:00<?, ?it/s]

  0%|          | 0/10748 [00:00<?, ?it/s]

In [30]:
val_df

Unnamed: 0,Headline,Stance,articleBody,Related,Processed Headline,Processed Body
0,Police find mass graves with at least '15 bodi...,unrelated,Danny Boyle is directing the untitled film\r\n...,False,Police find mass graves with at least '15 bodi...,Danny Boyle is directing the untitled film Set...
1,Seth Rogen to Play Apple’s Steve Wozniak,discuss,Danny Boyle is directing the untitled film\r\n...,True,Seth Rogen to Play Apple’s Steve Wozniak,Danny Boyle is directing the untitled film Set...
2,Mexico police find mass grave near site 43 stu...,unrelated,Danny Boyle is directing the untitled film\r\n...,False,Mexico police find mass grave near site 43 stu...,Danny Boyle is directing the untitled film Set...
3,Mexico Says Missing Students Not Found In Firs...,unrelated,Danny Boyle is directing the untitled film\r\n...,False,Mexico Says Missing Students Not Found In Firs...,Danny Boyle is directing the untitled film Set...
4,New iOS 8 bug can delete all of your iCloud do...,unrelated,Danny Boyle is directing the untitled film\r\n...,False,New iOS 8 bug can delete all of your iCloud do...,Danny Boyle is directing the untitled film Set...
...,...,...,...,...,...,...
10743,US probing claims ISIS fighters seized airdrop...,discuss,Syrian activists claim militants from the Isla...,True,US probing claims ISIS fighters seized airdrop...,Syrian activists claim militants from the Isla...
10744,ISIS fighters seize weapons airdrop meant for ...,discuss,Syrian activists claim militants from the Isla...,True,ISIS fighters seize weapons airdrop meant for ...,Syrian activists claim militants from the Isla...
10745,Heartbroken girl spends week in KFC after gett...,discuss,A Chinese woman spent an entire week inside a ...,True,Heartbroken girl spends week in KFC after gett...,A Chinese woman spent an entire week inside a ...
10746,"Comfort eating? Chinese woman, 26, spends an e...",discuss,A Chinese woman spent an entire week inside a ...,True,"Comfort eating? Chinese woman, 26, spends an e...",A Chinese woman spent an entire week inside a ...


In [31]:
test_df["Processed Headline"] = test_df["Headline"].progress_apply(process_text)
test_df["Processed Body"] = test_df["articleBody"].progress_apply(process_text)

  0%|          | 0/25413 [00:00<?, ?it/s]

  0%|          | 0/25413 [00:00<?, ?it/s]

In [32]:
test_df

Unnamed: 0,Headline,Stance,articleBody,Related,Processed Headline,Processed Body
0,Ferguson riots: Pregnant woman loses eye after...,unrelated,A RESPECTED senior French police officer inves...,False,Ferguson riots: Pregnant woman loses eye after...,A RESPECTED senior French police officer inves...
1,Apple Stores to install safes to secure gold A...,unrelated,A RESPECTED senior French police officer inves...,False,Apple Stores to install safes to secure gold A...,A RESPECTED senior French police officer inves...
2,Pregnant woman loses eye after police shoot be...,unrelated,A RESPECTED senior French police officer inves...,False,Pregnant woman loses eye after police shoot be...,A RESPECTED senior French police officer inves...
3,We just found out the #Ferguson Protester who ...,unrelated,A RESPECTED senior French police officer inves...,False,We just found out the #Ferguson Protester who ...,A RESPECTED senior French police officer inves...
4,Police Chief In Charge of Paris Attacks Commit...,discuss,A RESPECTED senior French police officer inves...,True,Police Chief In Charge of Paris Attacks Commit...,A RESPECTED senior French police officer inves...
...,...,...,...,...,...,...
25408,A Sign That Obamacare Exchanges Are Failing,disagree,Remember how much Republicans wanted to repeal...,True,A Sign That Obamacare Exchanges Are Failing,Remember how much Republicans wanted to repeal...
25409,Republicans call Obamacare a 'failure.' These ...,agree,Remember how much Republicans wanted to repeal...,True,Republicans call Obamacare a 'failure.' These ...,Remember how much Republicans wanted to repeal...
25410,CBO’s Alternate Facts Show Obamacare is Unsust...,disagree,Remember how much Republicans wanted to repeal...,True,CBO’s Alternate Facts Show Obamacare is Unsust...,Remember how much Republicans wanted to repeal...
25411,Why Obamacare failed,disagree,Remember how much Republicans wanted to repeal...,True,Why Obamacare failed,Remember how much Republicans wanted to repeal...


# BERT Tokeniser

In [33]:
selected_model = "bert-base-uncased"

In [34]:
tokeniser = BertTokenizer.from_pretrained(selected_model)

In [35]:
test_processed

'(CNN) -- Boko Haram laughed off Nigeria\'s announcement of a ceasefire agreement, saying there is no such deal and schoolgirls abducted in spring have been converted to Islam and married off. Nigerian officials announced two weeks ago that they had struck a deal with the Islamist terror group. The deal, the government said, included the release of more than 200 girls whose kidnapping in April at a boarding school in the nation\'s north stunned the world. In a video released Saturday, the Islamist group\'s notorious leader fired off a series of denials. "Don\'t you know the over 200 Chibok schoolgirls have converted to Islam?" Abubakar Shekau said. "They have now memorized two chapters of the Quran." Shekau slammed reports of their planned release. "We married them off. They are in their marital homes," he said, chuckling. The group\'s leader also denied knowing the negotiator with whom the government claimed it worked out a deal, saying he does not represent Boko Haram. "We will not s

In [36]:
test_processed_tokens = tokeniser.tokenize(test_processed)
print(test_processed_tokens)

['(', 'cnn', ')', '-', '-', 'bo', '##ko', 'hara', '##m', 'laughed', 'off', 'nigeria', "'", 's', 'announcement', 'of', 'a', 'ceasefire', 'agreement', ',', 'saying', 'there', 'is', 'no', 'such', 'deal', 'and', 'school', '##girl', '##s', 'abducted', 'in', 'spring', 'have', 'been', 'converted', 'to', 'islam', 'and', 'married', 'off', '.', 'nigerian', 'officials', 'announced', 'two', 'weeks', 'ago', 'that', 'they', 'had', 'struck', 'a', 'deal', 'with', 'the', 'islamist', 'terror', 'group', '.', 'the', 'deal', ',', 'the', 'government', 'said', ',', 'included', 'the', 'release', 'of', 'more', 'than', '200', 'girls', 'whose', 'kidnapping', 'in', 'april', 'at', 'a', 'boarding', 'school', 'in', 'the', 'nation', "'", 's', 'north', 'stunned', 'the', 'world', '.', 'in', 'a', 'video', 'released', 'saturday', ',', 'the', 'islamist', 'group', "'", 's', 'notorious', 'leader', 'fired', 'off', 'a', 'series', 'of', 'denial', '##s', '.', '"', 'don', "'", 't', 'you', 'know', 'the', 'over', '200', 'chi', '##

In [37]:
test_processed_indexes = tokeniser.convert_tokens_to_ids(test_processed_tokens)
print(test_processed_indexes)

[1006, 13229, 1007, 1011, 1011, 8945, 3683, 18820, 2213, 4191, 2125, 7387, 1005, 1055, 8874, 1997, 1037, 26277, 3820, 1010, 3038, 2045, 2003, 2053, 2107, 3066, 1998, 2082, 15239, 2015, 20361, 1999, 3500, 2031, 2042, 4991, 2000, 7025, 1998, 2496, 2125, 1012, 11884, 4584, 2623, 2048, 3134, 3283, 2008, 2027, 2018, 4930, 1037, 3066, 2007, 1996, 27256, 7404, 2177, 1012, 1996, 3066, 1010, 1996, 2231, 2056, 1010, 2443, 1996, 2713, 1997, 2062, 2084, 3263, 3057, 3005, 15071, 1999, 2258, 2012, 1037, 9405, 2082, 1999, 1996, 3842, 1005, 1055, 2167, 9860, 1996, 2088, 1012, 1999, 1037, 2678, 2207, 5095, 1010, 1996, 27256, 2177, 1005, 1055, 12536, 3003, 5045, 2125, 1037, 2186, 1997, 14920, 2015, 1012, 1000, 2123, 1005, 1056, 2017, 2113, 1996, 2058, 3263, 9610, 5092, 2243, 2082, 15239, 2015, 2031, 4991, 2000, 7025, 1029, 1000, 8273, 3676, 6673, 2016, 2912, 2226, 2056, 1012, 1000, 2027, 2031, 2085, 24443, 18425, 2048, 9159, 1997, 1996, 21288, 1012, 1000, 2016, 2912, 2226, 7549, 4311, 1997, 2037, 3740, 

In [38]:
max_input_length = tokeniser.max_model_input_sizes[selected_model]
max_input_length

512

In [39]:
test_processed_encoded = tokeniser.encode(test_processed[:max_input_length])
test_processed_encoded_rev = tokeniser.convert_ids_to_tokens(test_processed_encoded)
print(test_processed_encoded_rev)

['[CLS]', '(', 'cnn', ')', '-', '-', 'bo', '##ko', 'hara', '##m', 'laughed', 'off', 'nigeria', "'", 's', 'announcement', 'of', 'a', 'ceasefire', 'agreement', ',', 'saying', 'there', 'is', 'no', 'such', 'deal', 'and', 'school', '##girl', '##s', 'abducted', 'in', 'spring', 'have', 'been', 'converted', 'to', 'islam', 'and', 'married', 'off', '.', 'nigerian', 'officials', 'announced', 'two', 'weeks', 'ago', 'that', 'they', 'had', 'struck', 'a', 'deal', 'with', 'the', 'islamist', 'terror', 'group', '.', 'the', 'deal', ',', 'the', 'government', 'said', ',', 'included', 'the', 'release', 'of', 'more', 'than', '200', 'girls', 'whose', 'kidnapping', 'in', 'april', 'at', 'a', 'boarding', 'school', 'in', 'the', 'nation', "'", 's', 'north', 'stunned', 'the', 'world', '.', 'in', 'a', 'video', 'released', 'saturday', ',', 'the', 'islamist', 'group', "'", 's', 'not', '##o', '[SEP]']


In [40]:
tokeniser(test_processed, truncation=True)

{'input_ids': [101, 1006, 13229, 1007, 1011, 1011, 8945, 3683, 18820, 2213, 4191, 2125, 7387, 1005, 1055, 8874, 1997, 1037, 26277, 3820, 1010, 3038, 2045, 2003, 2053, 2107, 3066, 1998, 2082, 15239, 2015, 20361, 1999, 3500, 2031, 2042, 4991, 2000, 7025, 1998, 2496, 2125, 1012, 11884, 4584, 2623, 2048, 3134, 3283, 2008, 2027, 2018, 4930, 1037, 3066, 2007, 1996, 27256, 7404, 2177, 1012, 1996, 3066, 1010, 1996, 2231, 2056, 1010, 2443, 1996, 2713, 1997, 2062, 2084, 3263, 3057, 3005, 15071, 1999, 2258, 2012, 1037, 9405, 2082, 1999, 1996, 3842, 1005, 1055, 2167, 9860, 1996, 2088, 1012, 1999, 1037, 2678, 2207, 5095, 1010, 1996, 27256, 2177, 1005, 1055, 12536, 3003, 5045, 2125, 1037, 2186, 1997, 14920, 2015, 1012, 1000, 2123, 1005, 1056, 2017, 2113, 1996, 2058, 3263, 9610, 5092, 2243, 2082, 15239, 2015, 2031, 4991, 2000, 7025, 1029, 1000, 8273, 3676, 6673, 2016, 2912, 2226, 2056, 1012, 1000, 2027, 2031, 2085, 24443, 18425, 2048, 9159, 1997, 1996, 21288, 1012, 1000, 2016, 2912, 2226, 7549, 4311,

In [41]:
test_headline = train_df.iloc[812]["Processed Headline"]
test_body = train_df.iloc[812]["Processed Body"]

In [42]:
test_headline

'Reports: Jihadists Steal Commercial Jets, Raise 9/11 Fears'

In [43]:
test_body

'(CNN) -- Boko Haram laughed off Nigeria\'s announcement of a ceasefire agreement, saying there is no such deal and schoolgirls abducted in spring have been converted to Islam and married off. Nigerian officials announced two weeks ago that they had struck a deal with the Islamist terror group. The deal, the government said, included the release of more than 200 girls whose kidnapping in April at a boarding school in the nation\'s north stunned the world. In a video released Saturday, the Islamist group\'s notorious leader fired off a series of denials. "Don\'t you know the over 200 Chibok schoolgirls have converted to Islam?" Abubakar Shekau said. "They have now memorized two chapters of the Quran." Shekau slammed reports of their planned release. "We married them off. They are in their marital homes," he said, chuckling. The group\'s leader also denied knowing the negotiator with whom the government claimed it worked out a deal, saying he does not represent Boko Haram. "We will not s

In [44]:
test_concat_ids = tokeniser(test_headline, test_body, truncation="longest_first", padding="max_length")["input_ids"]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


In [45]:
test_concat_tokens = tokeniser.convert_ids_to_tokens(test_concat_ids)
print(test_concat_tokens)

['[CLS]', 'reports', ':', 'jihad', '##ists', 'steal', 'commercial', 'jets', ',', 'raise', '9', '/', '11', 'fears', '[SEP]', '(', 'cnn', ')', '-', '-', 'bo', '##ko', 'hara', '##m', 'laughed', 'off', 'nigeria', "'", 's', 'announcement', 'of', 'a', 'ceasefire', 'agreement', ',', 'saying', 'there', 'is', 'no', 'such', 'deal', 'and', 'school', '##girl', '##s', 'abducted', 'in', 'spring', 'have', 'been', 'converted', 'to', 'islam', 'and', 'married', 'off', '.', 'nigerian', 'officials', 'announced', 'two', 'weeks', 'ago', 'that', 'they', 'had', 'struck', 'a', 'deal', 'with', 'the', 'islamist', 'terror', 'group', '.', 'the', 'deal', ',', 'the', 'government', 'said', ',', 'included', 'the', 'release', 'of', 'more', 'than', '200', 'girls', 'whose', 'kidnapping', 'in', 'april', 'at', 'a', 'boarding', 'school', 'in', 'the', 'nation', "'", 's', 'north', 'stunned', 'the', 'world', '.', 'in', 'a', 'video', 'released', 'saturday', ',', 'the', 'islamist', 'group', "'", 's', 'notorious', 'leader', 'fire

In [46]:
def concated_headline_body_tokens(headline, body):
    concated_ids = tokeniser(headline, body, truncation="longest_first", padding="max_length", return_tensors="pt")["input_ids"]
    return concated_ids

In [47]:
train_df

Unnamed: 0,Headline,Stance,articleBody,Related,Processed Headline,Processed Body
0,HBO and Apple in Talks for $15/Month Apple TV ...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,HBO and Apple in Talks for $15/Month Apple TV ...,(Reuters) - A Canadian soldier was shot at the...
1,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,(Reuters) - A Canadian soldier was shot at the...
2,Elderly Woman Arrested for Kidnapping Neighbor...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Elderly Woman Arrested for Kidnapping Neighbor...,(Reuters) - A Canadian soldier was shot at the...
3,Kim Jong-Un 'bans the name Kim Jong-un',unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Kim Jong-Un 'bans the name Kim Jong-un',(Reuters) - A Canadian soldier was shot at the...
4,Two blokes dared to eat 20-year-old burger for...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Two blokes dared to eat 20-year-old burger for...,(Reuters) - A Canadian soldier was shot at the...
...,...,...,...,...,...,...
39219,“Willie Nelson dead 2015” : Guitarist killed b...,agree,News of guitarist Willie Nelson’s death spread...,True,“Willie Nelson dead 2015” : Guitarist killed b...,News of guitarist Willie Nelson’s death spread...
39220,Elon University has not banned the term ‘fresh...,agree,"ELON, N.C. – A recent rumor claims that Elon U...",True,Elon University has not banned the term ‘fresh...,"ELON, N.C. – A recent rumor claims that Elon U..."
39221,Fake BBC News website set up to carry Charlie ...,agree,A realistic-looking fake BBC News website has ...,True,Fake BBC News website set up to carry Charlie ...,A realistic-looking fake BBC News website has ...
39222,Apple was forced to nix key health features fr...,discuss,The health-focused smartwatch that Apple initi...,True,Apple was forced to nix key health features fr...,The health-focused smartwatch that Apple initi...


In [48]:
train_df

Unnamed: 0,Headline,Stance,articleBody,Related,Processed Headline,Processed Body
0,HBO and Apple in Talks for $15/Month Apple TV ...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,HBO and Apple in Talks for $15/Month Apple TV ...,(Reuters) - A Canadian soldier was shot at the...
1,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,(Reuters) - A Canadian soldier was shot at the...
2,Elderly Woman Arrested for Kidnapping Neighbor...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Elderly Woman Arrested for Kidnapping Neighbor...,(Reuters) - A Canadian soldier was shot at the...
3,Kim Jong-Un 'bans the name Kim Jong-un',unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Kim Jong-Un 'bans the name Kim Jong-un',(Reuters) - A Canadian soldier was shot at the...
4,Two blokes dared to eat 20-year-old burger for...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Two blokes dared to eat 20-year-old burger for...,(Reuters) - A Canadian soldier was shot at the...
...,...,...,...,...,...,...
39219,“Willie Nelson dead 2015” : Guitarist killed b...,agree,News of guitarist Willie Nelson’s death spread...,True,“Willie Nelson dead 2015” : Guitarist killed b...,News of guitarist Willie Nelson’s death spread...
39220,Elon University has not banned the term ‘fresh...,agree,"ELON, N.C. – A recent rumor claims that Elon U...",True,Elon University has not banned the term ‘fresh...,"ELON, N.C. – A recent rumor claims that Elon U..."
39221,Fake BBC News website set up to carry Charlie ...,agree,A realistic-looking fake BBC News website has ...,True,Fake BBC News website set up to carry Charlie ...,A realistic-looking fake BBC News website has ...
39222,Apple was forced to nix key health features fr...,discuss,The health-focused smartwatch that Apple initi...,True,Apple was forced to nix key health features fr...,The health-focused smartwatch that Apple initi...


In [49]:
transformers.logging.set_verbosity_error()
train_df["Transformer Input"] = train_df.progress_apply(lambda row: concated_headline_body_tokens(row["Processed Headline"], row["Processed Body"]), axis=1)
transformers.logging.set_verbosity_warning()

  0%|          | 0/39224 [00:00<?, ?it/s]

In [50]:
train_df

Unnamed: 0,Headline,Stance,articleBody,Related,Processed Headline,Processed Body,Transformer Input
0,HBO and Apple in Talks for $15/Month Apple TV ...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,HBO and Apple in Talks for $15/Month Apple TV ...,(Reuters) - A Canadian soldier was shot at the...,"[[tensor(101), tensor(14633), tensor(1998), te..."
1,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,ISLAMIC STATE BEHEADS MISSING AMERICAN JOURNAL...,(Reuters) - A Canadian soldier was shot at the...,"[[tensor(101), tensor(5499), tensor(2110), ten..."
2,Elderly Woman Arrested for Kidnapping Neighbor...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Elderly Woman Arrested for Kidnapping Neighbor...,(Reuters) - A Canadian soldier was shot at the...,"[[tensor(101), tensor(9750), tensor(2450), ten..."
3,Kim Jong-Un 'bans the name Kim Jong-un',unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Kim Jong-Un 'bans the name Kim Jong-un',(Reuters) - A Canadian soldier was shot at the...,"[[tensor(101), tensor(5035), tensor(18528), te..."
4,Two blokes dared to eat 20-year-old burger for...,unrelated,(Reuters) - A Canadian soldier was shot at the...,False,Two blokes dared to eat 20-year-old burger for...,(Reuters) - A Canadian soldier was shot at the...,"[[tensor(101), tensor(2048), tensor(1038), ten..."
...,...,...,...,...,...,...,...
39219,“Willie Nelson dead 2015” : Guitarist killed b...,agree,News of guitarist Willie Nelson’s death spread...,True,“Willie Nelson dead 2015” : Guitarist killed b...,News of guitarist Willie Nelson’s death spread...,"[[tensor(101), tensor(1523), tensor(9893), ten..."
39220,Elon University has not banned the term ‘fresh...,agree,"ELON, N.C. – A recent rumor claims that Elon U...",True,Elon University has not banned the term ‘fresh...,"ELON, N.C. – A recent rumor claims that Elon U...","[[tensor(101), tensor(3449), tensor(2239), ten..."
39221,Fake BBC News website set up to carry Charlie ...,agree,A realistic-looking fake BBC News website has ...,True,Fake BBC News website set up to carry Charlie ...,A realistic-looking fake BBC News website has ...,"[[tensor(101), tensor(8275), tensor(4035), ten..."
39222,Apple was forced to nix key health features fr...,discuss,The health-focused smartwatch that Apple initi...,True,Apple was forced to nix key health features fr...,The health-focused smartwatch that Apple initi...,"[[tensor(101), tensor(6207), tensor(2001), ten..."


In [51]:
transformers.logging.set_verbosity_error()
val_df["Transformer Input"] = val_df.progress_apply(lambda row: concated_headline_body_tokens(row["Processed Headline"], row["Processed Body"]), axis=1)
transformers.logging.set_verbosity_warning()

  0%|          | 0/10748 [00:00<?, ?it/s]

In [52]:
bert = BertModel.from_pretrained(selected_model)
bert = bert.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [53]:
test_row = train_df.iloc[812]
test_row_tokens = test_row["Transformer Input"].to(device)
test_row_tokens.shape

torch.Size([1, 512])

In [54]:
bert_embedded_test = bert(test_row_tokens)[0]

In [55]:
bert_embedded_test.shape
# So the CLS token output is at [i][0] for input i

torch.Size([1, 512, 768])

In [56]:
m = lambda x: x #-1 if x == 0 else 1
train_labels = np.array([m(x) for x in train_df["Related"].values.astype(int)])# [:, np.newaxis]
train_labels_tensor = torch.LongTensor(train_labels).unsqueeze(1)
train_labels_tensor.shape

torch.Size([39224, 1])

In [57]:
val_labels = np.array([m(x) for x in val_df["Related"].values.astype(int)])# [:, np.newaxis]
val_labels_tensor = torch.LongTensor(val_labels).unsqueeze(1)
val_labels_tensor.shape

torch.Size([10748, 1])

In [58]:
train_transformer_token_ids = torch.concat(list(train_df["Transformer Input"].values))

In [59]:
train_transformer_token_ids.shape

torch.Size([39224, 512])

In [60]:
val_transformer_token_ids = torch.concat(list(val_df["Transformer Input"].values))

In [61]:
batch_size = 32
apply_loss_weighting = False

In [62]:
train_dataset = torch.utils.data.TensorDataset(train_transformer_token_ids, train_labels_tensor)
train_dataloader_enc = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [63]:
unique_class_labels = np.unique(train_labels)
class_weights = compute_class_weight("balanced", classes=unique_class_labels, y=train_labels)
class_weights = torch.from_numpy(class_weights)
class_weights

tensor([0.6868, 1.8386], dtype=torch.float64)

In [64]:
val_dataset = torch.utils.data.TensorDataset(val_transformer_token_ids, val_labels_tensor)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Standard ML Models

In [74]:
# Create the confusion matrix - from Practical 1
def plot_confusion_matrix(y_test, y_pred):
    ''' Plot the confusion matrix for the target labels and predictions '''
    cm = confusion_matrix(y_test, y_pred)

    # Create a dataframe with the confusion matrix values
    df_cm = pd.DataFrame(cm, range(cm.shape[0]),
                  range(cm.shape[1]))

    # Plot the confusion matrix
    sns.set(font_scale=1.4) # for label size
    sns.heatmap(df_cm, annot=True,fmt='.0f',cmap="YlGnBu",annot_kws={"size": 10}) # font size
    plt.show()

In [75]:
# ROC Curve
# plot no skill
# Calculate the points in the ROC curve - from Practical 1
def plot_roc_curve(y_test, y_pred):
    ''' Plot the ROC curve for the target labels and predictions'''
    fpr, tpr, thresholds = roc_curve(y_test, y_pred, pos_label=1)
    roc_auc= auc(fpr,tpr)
    plt.figure(figsize=(12, 12))
    ax = plt.subplot(121)
    ax.set_aspect(1)
    
    plt.title('Receiver Operating Characteristic')
    plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

## Need to extract BERT CLS tokens for these

In [76]:
train_bert_df = pd.DataFrame({
    "CLS Output": pd.Series(dtype=object),
    "Related": pd.Series(dtype=bool)
})

with torch.no_grad():
    for tokens, labels in tqdm(train_dataloader_enc):
        tokens = tokens.to(device)
        cls_outputs = bert(tokens)[0].detach().cpu()
        rows = []
        
        for cls_output, label in zip(cls_outputs, labels):
            rows.append({"CLS Output": cls_output.numpy(), "Related": label.item() == 1})
        
        train_bert_df = pd.concat([train_bert_df, pd.DataFrame(rows)], axis=0, ignore_index=True)

train_bert_df = train_bert_df.sample(frac=1).reset_index(drop=True)

  0%|          | 0/1226 [00:00<?, ?it/s]

In [77]:
val_bert_df = pd.DataFrame({
    "CLS Output": pd.Series(dtype=object),
    "Related": pd.Series(dtype=bool)
})

with torch.no_grad():
    for tokens, labels in tqdm(val_dataloader):
        tokens = tokens.to(device)
        cls_outputs = bert(tokens)[0].detach().cpu()
        rows = []
        
        for cls_output, label in zip(cls_outputs, labels):
            rows.append({"CLS Output": cls_output.numpy(), "Related": label.item() == 1})
        
        val_bert_df = pd.concat([val_bert_df, pd.DataFrame(rows)], axis=0, ignore_index=True)

val_bert_df = val_bert_df.sample(frac=1).reset_index(drop=True)

  0%|          | 0/336 [00:00<?, ?it/s]

In [78]:
train_bert_df

Unnamed: 0,CLS Output,Related
0,"[[-0.34326658, 0.36517367, -0.5840294, 0.11459...",False
1,"[[0.016007526, -0.21476755, -0.048804443, -0.3...",False
2,"[[-0.53133136, -0.21920921, 0.2819307, -0.1196...",False
3,"[[-0.39362964, -0.29681164, 0.42485055, -0.071...",False
4,"[[-0.2369354, 0.10395724, -0.23445871, 0.06743...",False
...,...,...
39219,"[[-0.31434163, -1.0140011, -0.22602957, 0.2552...",False
39220,"[[-0.46506372, 0.23114684, -0.14275485, 0.2136...",False
39221,"[[-0.7953826, 0.26478225, -0.17036456, 0.31143...",False
39222,"[[-0.99910885, 0.3724223, -0.509763, 0.2427289...",False


In [79]:
val_bert_df

Unnamed: 0,CLS Output,Related
0,"[[-0.57361656, -0.18418948, 0.24991547, 0.3978...",False
1,"[[-0.13854364, -0.1367868, 0.4549456, 0.289825...",False
2,"[[-0.6371334, -0.58085626, -0.7009059, -0.0713...",True
3,"[[-0.5669046, 0.37112597, -0.27369532, 0.07890...",True
4,"[[-0.14785111, 0.34502283, -0.50307065, 0.4218...",False
...,...,...
10743,"[[-0.15648781, -0.23415855, -0.058399748, 0.09...",False
10744,"[[-0.60972524, -0.64969605, -0.46111962, 0.140...",False
10745,"[[-0.4086986, 0.27373657, -0.3493088, -0.01856...",False
10746,"[[-0.7675553, 0.13341556, -0.6063807, 0.110620...",True


In [80]:
train_bert_features = np.array(list(train_bert_df["CLS Output"].values))
train_bert_labels = np.array([m(x) for x in train_bert_df["Related"].values.astype(int)])

MemoryError: Unable to allocate 57.5 GiB for an array with shape (39224, 512, 768) and data type float32

In [None]:
val_bert_features = np.array(list(val_bert_df["CLS Output"].values))
val_bert_labels = np.array([m(x) for x in val_bert_df["Related"].values.astype(int)])

In [None]:
train_bert_features.shape

In [None]:
train_bert_labels.shape

## Gradient Boost

In [None]:
from lightgbm import LGBMClassifier
from sklearn.metrics import f1_score

def f1_metric(ytrue,preds):
    ''' Return the F1 Score value for the preds and true values, ytrue '''
    return 'f1_score', f1_score((preds>=0.5).astype('int'), ytrue, average='macro'), True

# set the model parameters
params = {
    'learning_rate': 0.06,
    'n_estimators': 1500,
    'colsample_bytree': 0.5,
    'metric': 'f1_score'
}

full_clf = LGBMClassifier(**params)



# Fit or train the xgboost model
full_clf.fit(train_bert_features.astype(np.float32), train_bert_labels, eval_set=[(train_bert_features.astype(np.float32), train_bert_labels), (val_bert_features.astype(np.float32), val_bert_labels)],
             verbose=400, eval_metric=f1_metric)

#Show the results
print("train score:", full_clf.score(train_bert_features.astype(np.float32), train_bert_labels))
print("val score:", full_clf.score(val_bert_features.astype(np.float32), val_bert_labels))

In [None]:
# Predicting the Test set results
y_pred = full_clf.predict(val_bert_features.astype(np.float32))

print(metrics.classification_report(val_bert_labels, y_pred))
plot_confusion_matrix(val_bert_labels, y_pred)
plot_roc_curve(val_bert_labels, y_pred)

# Everything here is from the Practical

In [66]:
class BERTGRUSentiment(nn.Module):
    def __init__(self,
                 bert,
                 hidden_dim,
                 output_dim,
                 n_layers,
                 bidirectional,
                 dropout):
        
        super().__init__()
        
        self.bert = bert
        
        embedding_dim = bert.config.to_dict()['hidden_size']
        
        self.rnn = nn.GRU(embedding_dim,
                          hidden_dim,
                          num_layers = n_layers,
                          bidirectional = bidirectional,
                          batch_first = True,
                          dropout = 0 if n_layers < 2 else dropout)
        
        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, ids):
        with torch.no_grad():
            embedded = self.bert(ids)[0]
        
        print(embedded.shape)
        
        _, hidden = self.rnn(embedded)
    
        
        if self.rnn.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))
        else:
            hidden = self.dropout(hidden[-1,:,:])
        
        output = self.out(hidden)        
        
        return output

In [67]:
HIDDEN_DIM = 256
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25

model = BERTGRUSentiment(bert,
                         HIDDEN_DIM,
                         OUTPUT_DIM,
                         N_LAYERS,
                         BIDIRECTIONAL,
                         DROPOUT)

In [68]:
for name, param in model.named_parameters():                
    if name.startswith('bert'):
        param.requires_grad = False

In [69]:
for name, param in model.named_parameters():                
    if param.requires_grad:
        print(name)

rnn.weight_ih_l0
rnn.weight_hh_l0
rnn.bias_ih_l0
rnn.bias_hh_l0
rnn.weight_ih_l0_reverse
rnn.weight_hh_l0_reverse
rnn.bias_ih_l0_reverse
rnn.bias_hh_l0_reverse
rnn.weight_ih_l1
rnn.weight_hh_l1
rnn.bias_ih_l1
rnn.bias_hh_l1
rnn.weight_ih_l1_reverse
rnn.weight_hh_l1_reverse
rnn.bias_ih_l1_reverse
rnn.bias_hh_l1_reverse
out.weight
out.bias


In [70]:
opt = optim.Adam(model.parameters())

In [71]:
model = model.to(device)

In [72]:
def evaluate_model(model, dataloader, labels, apply_loss_weighting):
    if not apply_loss_weighting:
        loss_func = nn.BCEWithLogitsLoss()
    
    with torch.no_grad():
        model.eval()
        
        all_pred = np.array([])
        total_loss = 0
        
        for inp, labels in tqdm(dataloader):
            inp = inp.to(device)
            labels = labels.to(device)
            
            loss_weights = torch.zeros_like(labels)
            loss_weights[labels == 0] = class_weights[0]
            loss_weights[labels == 1] = class_weights[1]
            loss_func = nn.BCEWithLogitsLoss(weight=loss_weights)
        
            pred = model(inp)
            loss = loss_func(pred, labels.float())
            pred = torch.sigmoid(pred)
            
            all_pred = np.concatenate([all_pred, pred.cpu().round().numpy().squeeze()])
            total_loss += loss.item()
        
        model.train()
    
    return all_pred, total_loss / len(dataloader)

In [73]:
best_validation_loss = 1e9
save_path = f"./runs/{time.time()}_{model.__class__.__qualname__}_{'weighted' if apply_loss_weighting else 'no_weight'}"
epochs_since_best_validation = 0

os.makedirs(save_path, exist_ok=True)

if not apply_loss_weighting:
    loss_func = nn.BCEWithLogitsLoss()
    
for epoch in range(1, 101):
    model.train()
    epoch_st = time.time()
    epoch_loss = 0
    
    batch_acc_loss = 0
    
    for batch_no, (inp, labels) in enumerate(tqdm(train_dataloader)):
        inp = inp.to(device)
        labels = labels.to(device)

        if apply_loss_weighting:
            loss_weights = torch.zeros_like(labels)
            loss_weights[labels == 0] = class_weights[0]
            loss_weights[labels == 1] = class_weights[1]
            loss_func = nn.BCEWithLogitsLoss(weight=loss_weights)

        predictions = model(inp)
        loss = loss_func(predictions, labels.float())

        opt.zero_grad()
        loss.backward()
        opt.step()

        epoch_loss += loss.item()
        batch_acc_loss += loss.item()

        if batch_no != 0 and batch_no % 40 == 0:
            print(f"[{epoch}:{batch_no}] Loss: {(batch_acc_loss / 40):.3f}")
            batch_acc_loss = 0

    epoch_dt = time.time() - epoch_st
    print(f"[{epoch}:END] Took {epoch_dt:.3f}s")
    print(f"[{epoch}:END] Training Loss: {(epoch_loss / len(train_dataloader)):.3f}")
    
    validation_pred, avg_validation_loss = evaluate_model(model, val_dataloader, val_labels, apply_loss_weighting)
    print(f"[{epoch}:END] Validation Loss: {avg_validation_loss:.3f}")
    print(f"[{epoch}:END] Validation Accuracy: {(validation_pred == val_labels).mean() * 100:.3f}%")
    
    print(metrics.classification_report(val_labels, validation_pred))
    plot_confusion_matrix(val_labels, validation_pred)
    plot_roc_curve(val_labels, validation_pred)
    
    epochs_since_best_validation += 1
    
    if avg_validation_loss < best_validation_loss:
        epochs_since_best_validation = 0
        best_validation_loss = avg_validation_loss
        torch.save(model.state_dict(), f"{save_path}/model_{epoch}.pth")
        print(f"[{epoch}:END] Validation loss improved, saved model to {save_path}/model_{epoch}.pth")
    
    if epochs_since_best_validation >= 2:
        print(f"[END] Stopping early as validation loss hasn't improved in 2 epochs")
        break
        
print("Training complete")

  0%|          | 0/1226 [00:00<?, ?it/s]

torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
torch.Size([32, 512, 768])
t

KeyboardInterrupt: 