# CNN News NLP

## By Edith Lee and Joleena Marshall

In [80]:
import pandas as pd
import torch.nn as nn
import numpy as np
import spacy
from tqdm.notebook import tqdm
import re
from collections import Counter
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import torch.optim as optim
from sklearn.feature_extraction.text import CountVectorizer
from sklearn import metrics
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict
from transformers import BertTokenizer, BertModel
from pytorch_transformers import BertConfig

In [2]:
articles = pd.read_csv('cnn_articles/CNN_Articels_clean_2/CNN_Articels_clean.csv')
articles

Unnamed: 0,Index,Author,Date published,Category,Section,Url,Headline,Description,Keywords,Second headline,Article text
0,0,"Jacopo Prisco, CNN",2021-07-15 02:46:59,news,world,https://www.cnn.com/2021/07/14/world/tusimple-...,"There's a shortage of truckers, but TuSimple t...",The e-commerce boom has exacerbated a global t...,"world, There's a shortage of truckers, but TuS...","There's a shortage of truckers, but TuSimple t...","(CNN)Right now, there's a shortage of truck d..."
1,1,"Stephanie Bailey, CNN",2021-05-12 07:52:09,news,world,https://www.cnn.com/2021/05/12/world/ironhand-...,Bioservo's robotic 'Ironhand' could protect fa...,Working in a factory can mean doing the same t...,"world, Bioservo's robotic 'Ironhand' could pro...",A robotic 'Ironhand' could protect factory wor...,(CNN)Working in a factory or warehouse can me...
2,2,"Words by Stephanie Bailey, video by Zahra Jamshed",2021-06-16 02:51:30,news,asia,https://www.cnn.com/2021/06/15/asia/swarm-robo...,This swarm of robots gets smarter the more it ...,"In a Hong Kong warehouse, a swarm of autonomou...","asia, This swarm of robots gets smarter the mo...",This swarm of robots gets smarter the more it ...,"(CNN)In a Hong Kong warehouse, a swarm of aut..."
3,3,Kathryn Vasel,2022-03-18 14:37:21,business,success,https://www.cnn.com/2022/03/18/success/pandemi...,"Two years later, remote work has changed milli...",Here's a look at how the pandemic reshaped peo...,"success, Two years later, remote work has chan...","Two years later, remote work has changed milli...",The pandemic thrust the working world into a n...
4,4,"Paul R. La Monica, CNN Business",2022-03-19 11:41:08,business,investing,https://www.cnn.com/2022/03/19/investing/march...,Why March is so volatile for stocks - CNN,March Madness isn't just for college basketbal...,"investing, Why March is so volatile for stocks...",Why March is so volatile for stocks,New York (CNN Business)March Madness isn't jus...
...,...,...,...,...,...,...,...,...,...,...,...
37944,44992,"Ben Church and Aleks Klosok, CNN",2022-03-01 10:59:10,sport,sport,https://www.cnn.com/2022/03/01/sport/vladimir-...,Russian President Vladimir Putin is being stri...,Russian President Vladimir Putin has been stri...,"sport, Russian President Vladimir Putin is bei...",Vladimir Putin is being stripped of his honora...,(CNN)Russian President Vladimir Putin has bee...
37945,44993,"Tamara Qiblawi, CNN",2022-03-01 12:55:37,news,europe,https://www.cnn.com/2022/03/01/europe/nato-ukr...,"On NATO's doorstep, a former tourist hotspot i...",A long line of men snakes out of an unassuming...,"europe, On NATO's doorstep, a former tourist h...","On NATO's doorstep, a former tourist hotspot i...","Lviv, Ukraine (CNN)A long line of men snakes o..."
37946,44994,"Wayne Sterling and Steve Almasy, CNN",2022-03-01 11:54:44,sport,sport,https://www.cnn.com/2022/03/01/sport/mlb-deadl...,MLB is postponing Opening Day after owners and...,Major League Baseball (MLB) is postponing its ...,"sport, MLB is postponing Opening Day after own...",MLB is postponing Opening Day after owners and...,(CNN)Major League Baseball (MLB) is postponin...
37947,44996,CNN Editorial Research,2013-01-12 01:42:49,news,europe,https://www.cnn.com/2013/01/11/world/europe/mi...,Mikhail Gorbachev Fast Facts - CNN,"Read CNN's Fast Facts on Mikhail Gorbachev, fo...","europe, Mikhail Gorbachev Fast Facts - CNN",Mikhail Gorbachev Fast Facts,Here's a look at the life of Mikhail Gorbachev...


# Pre-processing

In [3]:
# See types of categories
print(articles.Category.unique())

['news' 'business' 'health' 'entertainment' 'sport' 'politics' 'travel'
 'vr' 'style']


In [4]:
# Make sure no null values
articles.isna().sum()

Index              0
Author             0
Date published     0
Category           0
Section            0
Url                0
Headline           0
Description        0
Keywords           0
Second headline    0
Article text       9
dtype: int64

In [5]:
articles_nona = articles.dropna()
articles_nona.drop(['Index'],axis=1, inplace=True)
articles_nona

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return super().drop(


Unnamed: 0,Author,Date published,Category,Section,Url,Headline,Description,Keywords,Second headline,Article text
0,"Jacopo Prisco, CNN",2021-07-15 02:46:59,news,world,https://www.cnn.com/2021/07/14/world/tusimple-...,"There's a shortage of truckers, but TuSimple t...",The e-commerce boom has exacerbated a global t...,"world, There's a shortage of truckers, but TuS...","There's a shortage of truckers, but TuSimple t...","(CNN)Right now, there's a shortage of truck d..."
1,"Stephanie Bailey, CNN",2021-05-12 07:52:09,news,world,https://www.cnn.com/2021/05/12/world/ironhand-...,Bioservo's robotic 'Ironhand' could protect fa...,Working in a factory can mean doing the same t...,"world, Bioservo's robotic 'Ironhand' could pro...",A robotic 'Ironhand' could protect factory wor...,(CNN)Working in a factory or warehouse can me...
2,"Words by Stephanie Bailey, video by Zahra Jamshed",2021-06-16 02:51:30,news,asia,https://www.cnn.com/2021/06/15/asia/swarm-robo...,This swarm of robots gets smarter the more it ...,"In a Hong Kong warehouse, a swarm of autonomou...","asia, This swarm of robots gets smarter the mo...",This swarm of robots gets smarter the more it ...,"(CNN)In a Hong Kong warehouse, a swarm of aut..."
3,Kathryn Vasel,2022-03-18 14:37:21,business,success,https://www.cnn.com/2022/03/18/success/pandemi...,"Two years later, remote work has changed milli...",Here's a look at how the pandemic reshaped peo...,"success, Two years later, remote work has chan...","Two years later, remote work has changed milli...",The pandemic thrust the working world into a n...
4,"Paul R. La Monica, CNN Business",2022-03-19 11:41:08,business,investing,https://www.cnn.com/2022/03/19/investing/march...,Why March is so volatile for stocks - CNN,March Madness isn't just for college basketbal...,"investing, Why March is so volatile for stocks...",Why March is so volatile for stocks,New York (CNN Business)March Madness isn't jus...
...,...,...,...,...,...,...,...,...,...,...
37944,"Ben Church and Aleks Klosok, CNN",2022-03-01 10:59:10,sport,sport,https://www.cnn.com/2022/03/01/sport/vladimir-...,Russian President Vladimir Putin is being stri...,Russian President Vladimir Putin has been stri...,"sport, Russian President Vladimir Putin is bei...",Vladimir Putin is being stripped of his honora...,(CNN)Russian President Vladimir Putin has bee...
37945,"Tamara Qiblawi, CNN",2022-03-01 12:55:37,news,europe,https://www.cnn.com/2022/03/01/europe/nato-ukr...,"On NATO's doorstep, a former tourist hotspot i...",A long line of men snakes out of an unassuming...,"europe, On NATO's doorstep, a former tourist h...","On NATO's doorstep, a former tourist hotspot i...","Lviv, Ukraine (CNN)A long line of men snakes o..."
37946,"Wayne Sterling and Steve Almasy, CNN",2022-03-01 11:54:44,sport,sport,https://www.cnn.com/2022/03/01/sport/mlb-deadl...,MLB is postponing Opening Day after owners and...,Major League Baseball (MLB) is postponing its ...,"sport, MLB is postponing Opening Day after own...",MLB is postponing Opening Day after owners and...,(CNN)Major League Baseball (MLB) is postponin...
37947,CNN Editorial Research,2013-01-12 01:42:49,news,europe,https://www.cnn.com/2013/01/11/world/europe/mi...,Mikhail Gorbachev Fast Facts - CNN,"Read CNN's Fast Facts on Mikhail Gorbachev, fo...","europe, Mikhail Gorbachev Fast Facts - CNN",Mikhail Gorbachev Fast Facts,Here's a look at the life of Mikhail Gorbachev...


In [41]:
nlp = spacy.load('en_core_web_sm', disable = ['ner', 'parser'])

rows = []
for idx in tqdm(range(len(articles_nona))):
    row = articles_nona.iloc[idx].copy()
    
    # first we remove numeric characters and lowercase everything
    cleaned_review = re.sub("[^A-Za-z']+", ' ', row['Headline'].replace('<br />', ' ').replace('CNN', ' ')).lower()
    
    # we let spaCy tokenize and lemmatize the text for us
    tokenized_review = nlp(cleaned_review)
    cleaned_tokenized = [token.lemma_ for token in tokenized_review if ((not token.is_stop) or (' ' in token.text))]
    
    if len(cleaned_tokenized) > 1:
        row['cleaned_headline'] = ' '.join(cleaned_tokenized)
    rows.append(row)
df_clean = pd.DataFrame(rows)
df_clean.head()

  0%|          | 0/37940 [00:00<?, ?it/s]

Unnamed: 0,Author,Date published,Category,Section,Url,Headline,Description,Keywords,Second headline,Article text,cleaned_headline
0,"Jacopo Prisco, CNN",2021-07-15 02:46:59,news,world,https://www.cnn.com/2021/07/14/world/tusimple-...,"There's a shortage of truckers, but TuSimple t...",The e-commerce boom has exacerbated a global t...,"world, There's a shortage of truckers, but TuS...","There's a shortage of truckers, but TuSimple t...","(CNN)Right now, there's a shortage of truck d...",shortage trucker tusimple think solution drive...
1,"Stephanie Bailey, CNN",2021-05-12 07:52:09,news,world,https://www.cnn.com/2021/05/12/world/ironhand-...,Bioservo's robotic 'Ironhand' could protect fa...,Working in a factory can mean doing the same t...,"world, Bioservo's robotic 'Ironhand' could pro...",A robotic 'Ironhand' could protect factory wor...,(CNN)Working in a factory or warehouse can me...,bioservo robotic ' ironhand ' protect factory ...
2,"Words by Stephanie Bailey, video by Zahra Jamshed",2021-06-16 02:51:30,news,asia,https://www.cnn.com/2021/06/15/asia/swarm-robo...,This swarm of robots gets smarter the more it ...,"In a Hong Kong warehouse, a swarm of autonomou...","asia, This swarm of robots gets smarter the mo...",This swarm of robots gets smarter the more it ...,"(CNN)In a Hong Kong warehouse, a swarm of aut...",swarm robot get smart work
3,Kathryn Vasel,2022-03-18 14:37:21,business,success,https://www.cnn.com/2022/03/18/success/pandemi...,"Two years later, remote work has changed milli...",Here's a look at how the pandemic reshaped peo...,"success, Two years later, remote work has chan...","Two years later, remote work has changed milli...",The pandemic thrust the working world into a n...,year later remote work change million career
4,"Paul R. La Monica, CNN Business",2022-03-19 11:41:08,business,investing,https://www.cnn.com/2022/03/19/investing/march...,Why March is so volatile for stocks - CNN,March Madness isn't just for college basketbal...,"investing, Why March is so volatile for stocks...",Why March is so volatile for stocks,New York (CNN Business)March Madness isn't jus...,march volatile stock


In [42]:
clean_nona = df_clean.dropna()
clean_nona

Unnamed: 0,Author,Date published,Category,Section,Url,Headline,Description,Keywords,Second headline,Article text,cleaned_headline
0,"Jacopo Prisco, CNN",2021-07-15 02:46:59,news,world,https://www.cnn.com/2021/07/14/world/tusimple-...,"There's a shortage of truckers, but TuSimple t...",The e-commerce boom has exacerbated a global t...,"world, There's a shortage of truckers, but TuS...","There's a shortage of truckers, but TuSimple t...","(CNN)Right now, there's a shortage of truck d...",shortage trucker tusimple think solution drive...
1,"Stephanie Bailey, CNN",2021-05-12 07:52:09,news,world,https://www.cnn.com/2021/05/12/world/ironhand-...,Bioservo's robotic 'Ironhand' could protect fa...,Working in a factory can mean doing the same t...,"world, Bioservo's robotic 'Ironhand' could pro...",A robotic 'Ironhand' could protect factory wor...,(CNN)Working in a factory or warehouse can me...,bioservo robotic ' ironhand ' protect factory ...
2,"Words by Stephanie Bailey, video by Zahra Jamshed",2021-06-16 02:51:30,news,asia,https://www.cnn.com/2021/06/15/asia/swarm-robo...,This swarm of robots gets smarter the more it ...,"In a Hong Kong warehouse, a swarm of autonomou...","asia, This swarm of robots gets smarter the mo...",This swarm of robots gets smarter the more it ...,"(CNN)In a Hong Kong warehouse, a swarm of aut...",swarm robot get smart work
3,Kathryn Vasel,2022-03-18 14:37:21,business,success,https://www.cnn.com/2022/03/18/success/pandemi...,"Two years later, remote work has changed milli...",Here's a look at how the pandemic reshaped peo...,"success, Two years later, remote work has chan...","Two years later, remote work has changed milli...",The pandemic thrust the working world into a n...,year later remote work change million career
4,"Paul R. La Monica, CNN Business",2022-03-19 11:41:08,business,investing,https://www.cnn.com/2022/03/19/investing/march...,Why March is so volatile for stocks - CNN,March Madness isn't just for college basketbal...,"investing, Why March is so volatile for stocks...",Why March is so volatile for stocks,New York (CNN Business)March Madness isn't jus...,march volatile stock
...,...,...,...,...,...,...,...,...,...,...,...
37944,"Ben Church and Aleks Klosok, CNN",2022-03-01 10:59:10,sport,sport,https://www.cnn.com/2022/03/01/sport/vladimir-...,Russian President Vladimir Putin is being stri...,Russian President Vladimir Putin has been stri...,"sport, Russian President Vladimir Putin is bei...",Vladimir Putin is being stripped of his honora...,(CNN)Russian President Vladimir Putin has bee...,russian president vladimir putin strip honorar...
37945,"Tamara Qiblawi, CNN",2022-03-01 12:55:37,news,europe,https://www.cnn.com/2022/03/01/europe/nato-ukr...,"On NATO's doorstep, a former tourist hotspot i...",A long line of men snakes out of an unassuming...,"europe, On NATO's doorstep, a former tourist h...","On NATO's doorstep, a former tourist hotspot i...","Lviv, Ukraine (CNN)A long line of men snakes o...",nato doorstep tourist hotspot ukraine dig resi...
37946,"Wayne Sterling and Steve Almasy, CNN",2022-03-01 11:54:44,sport,sport,https://www.cnn.com/2022/03/01/sport/mlb-deadl...,MLB is postponing Opening Day after owners and...,Major League Baseball (MLB) is postponing its ...,"sport, MLB is postponing Opening Day after own...",MLB is postponing Opening Day after owners and...,(CNN)Major League Baseball (MLB) is postponin...,mlb postpone opening day owner player agree la...
37947,CNN Editorial Research,2013-01-12 01:42:49,news,europe,https://www.cnn.com/2013/01/11/world/europe/mi...,Mikhail Gorbachev Fast Facts - CNN,"Read CNN's Fast Facts on Mikhail Gorbachev, fo...","europe, Mikhail Gorbachev Fast Facts - CNN",Mikhail Gorbachev Fast Facts,Here's a look at the life of Mikhail Gorbachev...,mikhail gorbachev fast fact


In [44]:
clean_nona.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 37922 entries, 0 to 37948
Data columns (total 12 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   Author            37922 non-null  object
 1   Date published    37922 non-null  object
 2   Category          37922 non-null  object
 3   Section           37922 non-null  object
 4   Url               37922 non-null  object
 5   Headline          37922 non-null  object
 6   Description       37922 non-null  object
 7   Keywords          37922 non-null  object
 8   Second headline   37922 non-null  object
 9   Article text      37922 non-null  object
 10  cleaned_headline  37922 non-null  object
 11  Category_code     37922 non-null  object
dtypes: object(12)
memory usage: 3.8+ MB


In [45]:
# Category to number mapping
clean_nona['Category'] = clean_nona.Category.astype('category')
clean_nona['Category_code'] = clean_nona.Category
clean_nona['Category_code'] = clean_nona.Category.cat.codes
clean_nona.head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  clean_nona['Category'] = clean_nona.Category.astype('category')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  clean_nona['Category_code'] = clean_nona.Category
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  clean_nona['Category_code'] = clean_nona.Category.cat.codes


Unnamed: 0,Author,Date published,Category,Section,Url,Headline,Description,Keywords,Second headline,Article text,cleaned_headline,Category_code
0,"Jacopo Prisco, CNN",2021-07-15 02:46:59,news,world,https://www.cnn.com/2021/07/14/world/tusimple-...,"There's a shortage of truckers, but TuSimple t...",The e-commerce boom has exacerbated a global t...,"world, There's a shortage of truckers, but TuS...","There's a shortage of truckers, but TuSimple t...","(CNN)Right now, there's a shortage of truck d...",shortage trucker tusimple think solution drive...,3
1,"Stephanie Bailey, CNN",2021-05-12 07:52:09,news,world,https://www.cnn.com/2021/05/12/world/ironhand-...,Bioservo's robotic 'Ironhand' could protect fa...,Working in a factory can mean doing the same t...,"world, Bioservo's robotic 'Ironhand' could pro...",A robotic 'Ironhand' could protect factory wor...,(CNN)Working in a factory or warehouse can me...,bioservo robotic ' ironhand ' protect factory ...,3
2,"Words by Stephanie Bailey, video by Zahra Jamshed",2021-06-16 02:51:30,news,asia,https://www.cnn.com/2021/06/15/asia/swarm-robo...,This swarm of robots gets smarter the more it ...,"In a Hong Kong warehouse, a swarm of autonomou...","asia, This swarm of robots gets smarter the mo...",This swarm of robots gets smarter the more it ...,"(CNN)In a Hong Kong warehouse, a swarm of aut...",swarm robot get smart work,3
3,Kathryn Vasel,2022-03-18 14:37:21,business,success,https://www.cnn.com/2022/03/18/success/pandemi...,"Two years later, remote work has changed milli...",Here's a look at how the pandemic reshaped peo...,"success, Two years later, remote work has chan...","Two years later, remote work has changed milli...",The pandemic thrust the working world into a n...,year later remote work change million career,0
4,"Paul R. La Monica, CNN Business",2022-03-19 11:41:08,business,investing,https://www.cnn.com/2022/03/19/investing/march...,Why March is so volatile for stocks - CNN,March Madness isn't just for college basketbal...,"investing, Why March is so volatile for stocks...",Why March is so volatile for stocks,New York (CNN Business)March Madness isn't jus...,march volatile stock,0


In [49]:
# Split data
X_train, X_test, y_train, y_test = train_test_split(clean_nona['cleaned_headline'].values, clean_nona['Category_code'].values,
                                                      test_size=0.2, random_state=1)
X_train.shape, X_test.shape

((30337,), (7585,))

In [51]:
X_train_list = [t.split() for t in X_train]
X_test_list = [t.split() for t in X_test]

In [52]:
def vocab(headline):
    vocab = defaultdict(float)
    for line in headline:
        words = set(line)
        for word in words:
            vocab[word] += 1
    return vocab  

In [55]:
word_count = vocab(X_train_list)
word_count

defaultdict(float,
            {'dominant': 17.0,
             'further': 1.0,
             'sportswoman': 2.0,
             'shiffrin': 41.0,
             'claim': 294.0,
             'mikaela': 38.0,
             'german': 205.0,
             'love': 121.0,
             'climate': 230.0,
             'change': 337.0,
             'concern': 69.0,
             'car': 162.0,
             'sour': 8.0,
             'assange': 78.0,
             'julian': 58.0,
             'sweden': 77.0,
             'ecuador': 15.0,
             'allow': 76.0,
             'question': 151.0,
             'say': 1389.0,
             'democrat': 21.0,
             'california': 83.0,
             'politic': 1983.0,
             'house': 204.0,
             'seat': 38.0,
             'republican': 77.0,
             'lose': 221.0,
             'win': 1251.0,
             'concede': 7.0,
             'australian': 431.0,
             "'": 4594.0,
             'enemy': 10.0,
             'force': 199.0,
   

In [58]:
model = BertModel.from_pretrained('bert-base-uncased',
           output_hidden_states = True, return_dict=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [82]:
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, hidden_dropout_prob=0.25,
        num_hidden_layers=10, num_attention_heads=10, intermediate_size=3000)

In [83]:
def text_tokens(text, max_seq_length=50, padding_start = False):
    tok_text = tokenizer.tokenize(text)
    if len(tok_text) > max_seq_length:
        tok_text = tok_text[:max_seq_length]
    ids_text  = tokenizer.convert_tokens_to_ids(tok_text)
    padding = [0] * (max_seq_length - len(ids_text))
    if padding_start:
        out = padding + ids_text
    else:
        out = ids_text + padding
    return np.array(out)

In [84]:
class HeadlineDataset(Dataset):
    def __init__(self, X, y):
        self.x = X
        self.y = y
        
    def __getitem__(self, index):
        x = self.x[index]
        x = text_tokens(x, padding_start=False)
        return x, self.y[index]
    
    def __len__(self):
        return len(self.y)

In [85]:
train_ds = HeadlineDataset(X_train, y_train)
test_ds = HeadlineDataset(X_test, y_test)

In [86]:
batch_size = 10
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

In [87]:
class BertClassification(nn.Module):
    def __init__(self):
        super(BertClassification, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                               output_hidden_states = True, return_dict=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)
        nn.init.xavier_normal_(self.classifier.weight)
        
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        outputs = self.bert(input_ids, token_type_ids, attention_mask)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
    
    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False
    
    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

In [106]:
def train_model(model, loss_func, optimizer, num_epochs=25):
    for epoch in range(num_epochs):
        model.train()
        sum_loss = 0
        total = 0
        weights = [0.1]
        class_weight = torch.FloatTensor(weights)
        for x, y in train_dl:
            x = x
            y = y.unsqueeze(1).float()
            optimizer.zero_grad()
            logits = model(x)
            loss = loss_func(logits, y)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()*y.shape[0]
            total += y.shape[0]
        epoch_loss = sum_loss/total
        val_loss, accuracy = eval_model(model, loss_func)
        print('train loss: {:.3f}, valid loss {:.3f} accuracy {:.3f}'.format(epoch_loss, val_loss, accuracy))

In [107]:
def eval_model(model, loss_func):
    model.eval()
    sum_loss = 0
    total = 0
    correct = 0
    weights = [0.1, 0.9]
    class_weight = torch.FloatTensor(weights)
    for x, y in valid_dl:
        x = x
        y = y.unsqueeze(1).float()
        y_hat = model(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y) 
        y_pred = y_hat > 0
        correct += (y_pred.float() == y).float().sum()
        sum_loss += loss.item()*y.shape[0]
        total += y.shape[0]
    accuracy = correct/total
    epoch_loss = sum_loss/total
    return epoch_loss, accuracy

In [108]:
model = BertClassification()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [109]:
optimizer = optim.Adam(model.parameters(), lr = 0.01)
loss_func = nn.BCEWithLogitsLoss()

In [110]:
train_model(model, loss_func, optimizer, num_epochs=1)

KeyboardInterrupt: 