# intermediate_SARC
This notebook takes our custom binary SARC dataset and trains an intermediate model on it.

## Imports & Settings

First, change the working directory to the parent folder so that we can import our custom functions.

In [1]:
import os
os.chdir('..')
# os.getcwd( )

In [3]:
import params
from utils import *
from trainer import *

import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk

from transformers import RobertaTokenizer, RobertaForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

# suppress model warning (aliased so it does not shadow the stdlib logging module)
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()

# set logging level
import logging
logging.basicConfig(format='%(message)s', level=logging.INFO)

In [None]:
# set general seeds
set_seeds(1)

# set dataloader generator seed
g = torch.Generator()
g.manual_seed(1)

# set params for this model
params.num_labels = 2
params.output_dir = "model_saves/intermediate_SARC_01"
params.dataset_path = "data/inter_SARC/itesd_sarc_balanced.hf"

# Ensure we're on an ARM environment if necessary.
platform_check()

We're Armed: macOS-13.1-arm64-i386-64bit


## Load Data

### SARC

In [4]:
datasets = load_from_disk(params.dataset_path)
datasets

|   | label | text | author | subreddit | score | ups | downs | date | created_utc | parent_comment | num_word_text |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | That's paid propaganda created by Big Law Enfo... | akronix10 | news | 15 | 15 | 0 | 2014-11 | 2014-11-26 20:50:39 | If TV tells me anything, most beat cops hate i... | 14 |
| 1 | 0 | I imagine they'll use it for regenerative brak... | f03nix | Futurology | 1 | 1 | 0 | 2015-07 | 2015-07-20 07:30:02 | Ummm, powered by a wind turbine? Sorry...but t... | 8 |
| 2 | 0 | Distraction and diversion from the real proble... | Shinranshonin | politics | 1 | -1 | -1 | 2016-12 | 2016-12-12 13:51:13 | The media call Trump a 'cyberbully,' even when... | 28 |
| 3 | 0 | Not at this point in the movie | spot35 | funny | 2 | 2 | 0 | 2015-11 | 2015-11-21 07:59:49 | But it's a protocol droid, fluent in over 6 mi... | 7 |
| 4 | 0 | I have a feeling you don't understand the Mand... | throwaway-11-20-2016 | MandelaEffect | 1 | -1 | -1 | 2016-12 | 2016-12-28 16:14:36 | so should the side bar say that a large number... | 10 |


In [5]:
# we will need to view and prep the datasets
# this is more easily done as dataframes
train_df = datasets['train'].to_pandas()
validate_df = datasets['validation'].to_pandas()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100000 entries, 0 to 99999
Data columns (total 11 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   label           100000 non-null  int64 
 1   text            100000 non-null  object
 2   author          100000 non-null  object
 3   subreddit       100000 non-null  object
 4   score           100000 non-null  int64 
 5   ups             100000 non-null  int64 
 6   downs           100000 non-null  int64 
 7   date            100000 non-null  object
 8   created_utc     100000 non-null  object
 9   parent_comment  100000 non-null  object
 10  num_word_text   100000 non-null  int64 
dtypes: int64(5), object(6)
memory usage: 8.4+ MB


In [6]:
# view training dataset
print("train_df Info:")
print(train_df.info())
print("\ntrain_df Value Counts")
print(train_df['label'].value_counts())

0    50000
1    50000
Name: label, dtype: int64
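
The counts above show a perfectly balanced training set, so a classifier that always predicts one class would reach only 0.5 accuracy on it. The sketch below makes that floor explicit. The helper `majority_baseline` is hypothetical, and reading validation accuracy against this floor assumes the validation split is similarly balanced, which this notebook does not show.

```python
from collections import Counter

def majority_baseline(label_counts):
    """Accuracy of always predicting the most frequent label (hypothetical helper)."""
    total = sum(label_counts.values())
    return max(label_counts.values()) / total

# Counts copied from the value_counts() output above.
train_counts = Counter({0: 50000, 1: 50000})
baseline = majority_baseline(train_counts)
print(f"majority-class baseline accuracy: {baseline:.2f}")
```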

In [None]:
# view validation dataset
print("validate_df Info:")
print(validate_df.info())
print("\nvalidate_df Value Counts")
print(validate_df['label'].value_counts())

## Preprocess

In [8]:
train_token_ids, train_attention_masks = encode_text(train_df.text.values)
validate_token_ids, validate_attention_masks = encode_text(validate_df.text.values)

In [None]:
# pair each encoded example with its label
train_features = [
    {'label': label, 'input_ids': ids, 'attention_mask': mask}
    for label, ids, mask in zip(train_df.label.values, train_token_ids, train_attention_masks)
]

validate_features = [
    {'label': label, 'input_ids': ids, 'attention_mask': mask}
    for label, ids, mask in zip(validate_df.label.values, validate_token_ids, validate_attention_masks)
]

## Data Split
The dataset is split into train (80%) and validation (20%) sets, each of which we wrap in a torch.utils.data.DataLoader object.

In [9]:
# Prepare DataLoader
train_dataloader = DataLoader(
            train_features,
            sampler = RandomSampler(train_features),
            batch_size = params.batch_size,
            worker_init_fn=seed_worker,
            generator=g,
            collate_fn=seq_class_collate
        )

validation_dataloader = DataLoader(
            validate_features,
            sampler = RandomSampler(validate_features),
            batch_size = params.batch_size,
            worker_init_fn=seed_worker,
            generator=g,
            collate_fn=seq_class_collate
        )
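
`seq_class_collate` comes from the project's `utils` module and its body is not shown in this notebook. As a rough illustration of what a collate function for sequence classification typically does, the sketch below pads every example in a batch to the longest sequence in that batch. It uses plain Python lists so it runs without PyTorch. The name `pad_collate`, the default `pad_id=1` (the `<pad>` token id in the `roberta-base` vocabulary), and the input layout are assumptions, not the project's implementation.

```python
def pad_collate(batch, pad_id=1):
    """Pad a list of feature dicts to a common length (hypothetical sketch).

    Each item is assumed to look like the dicts built in the Preprocess step:
    {'label': int, 'input_ids': list[int], 'attention_mask': list[int]}.
    """
    max_len = max(len(item['input_ids']) for item in batch)
    input_ids, attention_mask, labels = [], [], []
    for item in batch:
        pad = max_len - len(item['input_ids'])
        input_ids.append(list(item['input_ids']) + [pad_id] * pad)
        # padded positions get mask 0 so the model ignores them
        attention_mask.append(list(item['attention_mask']) + [0] * pad)
        labels.append(int(item['label']))
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

# Toy token ids, for illustration only.
batch = pad_collate([
    {'label': 0, 'input_ids': [0, 713, 2], 'attention_mask': [1, 1, 1]},
    {'label': 1, 'input_ids': [0, 713, 16, 205, 2], 'attention_mask': [1, 1, 1, 1, 1]},
])
print(batch['input_ids'])
```

A collate function used with a real DataLoader would also convert these lists to tensors before returning them.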

## Train

Download transformers.RobertaForSequenceClassification, a RoBERTa model with a linear layer on top of the pooled output for sentence classification (or regression):

In [10]:
# Load the RobertaForSequenceClassification model
model = RobertaForSequenceClassification.from_pretrained('roberta-base',
                                                         num_labels = params.num_labels,
                                                         output_attentions = False,
                                                         output_hidden_states = False,
                                                         )

from torchinfo import summary
summary(model, input_size=(1, 512), dtypes=['torch.IntTensor'])

Layer (type:depth-idx)                                       Output Shape              Param #
RobertaForSequenceClassification                             [1, 2]                    --
├─RobertaModel: 1-1                                          [1, 512, 768]             --
│    └─RobertaEmbeddings: 2-1                                [1, 512, 768]             --
│    │    └─Embedding: 3-1                                   [1, 512, 768]             38,603,520
│    │    └─Embedding: 3-2                                   [1, 512, 768]             768
│    │    └─Embedding: 3-3                                   [1, 512, 768]             394,752
│    │    └─LayerNorm: 3-4                                   [1, 512, 768]             1,536
│    │    └─Dropout: 3-5                                     [1, 512, 768]             --
│    └─RobertaEncoder: 2-2                                   [1, 512, 768]             --
│    │    └─ModuleList: 3-6                                  --               
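
The embedding parameter counts in the summary can be checked by hand. The sketch below assumes the standard `roberta-base` configuration (a vocabulary of 50,265 tokens, 514 position slots, one token type, hidden size 768). These configuration values are not printed anywhere in this notebook.

```python
# Assumed roberta-base configuration values (not shown in the notebook).
hidden_size = 768
vocab_size = 50265
max_positions = 514
type_vocab_size = 1

word_embeddings = vocab_size * hidden_size             # Embedding: 3-1
token_type_embeddings = type_vocab_size * hidden_size  # Embedding: 3-2
position_embeddings = max_positions * hidden_size      # Embedding: 3-3
layer_norm = 2 * hidden_size                           # weight + bias, LayerNorm: 3-4

print(f"{word_embeddings:,} {token_type_embeddings:,} {position_embeddings:,} {layer_norm:,}")
```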

Move the model to the device and initialize the trainer.

In [11]:
model.to(params.device)
print(f"Device: {params.device}")

optimizer = torch.optim.Adam(params=model.parameters(), 
                             lr=params.learning_rate,
                             weight_decay=params.weight_decay) #roberta

trainer = Trainer(model=model,
                  device=params.device,
                  tokenizer=params.tokenizer,
                  train_dataloader=train_dataloader,
                  validation_dataloader=validation_dataloader,
                  epochs=params.epochs,
                  optimizer=optimizer,
                  val_loss_fn=params.val_loss_fn,
                  num_labels=params.num_labels,
                  output_dir=params.output_dir,
                  save_freq=params.save_freq,
                  checkpoint_freq=params.checkpoint_freq)

output_parameters()

Trained Dataset: data/SARC/SARC_preped_sampled_train.csv
Device: mps


Fit the model to our training data.

In [12]:
trainer.fit()

Epoch 1: 100%|██████████| 5000/5000 [1:49:20<00:00,  1.31s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:19<00:00,  2.50batch/s]


	 - Train loss: 0.552180
	 - Validation Loss: 0.507575
	 - Validation Accuracy: 0.746200
	 - Validation F1: 0.711028
	 - Validation Recall: 0.671129
	 - Validation Precision: 0.789116
	 * Model @ epoch 1 saved to model_saves/intermediate_SARC_01/E01_A0.75_F0.71


Epoch 2: 100%|██████████| 5000/5000 [1:48:42<00:00,  1.30s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:20<00:00,  2.50batch/s]


	 - Train loss: 0.478418
	 - Validation Loss: 0.511777
	 - Validation Accuracy: 0.754750
	 - Validation F1: 0.725020
	 - Validation Recall: 0.696219
	 - Validation Precision: 0.788127
	 * Model @ epoch 2 saved to model_saves/intermediate_SARC_01/E02_A0.75_F0.73


Epoch 3: 100%|██████████| 5000/5000 [1:49:55<00:00,  1.32s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:23<00:00,  2.48batch/s]


	 - Train loss: 0.420092
	 - Validation Loss: 0.527644
	 - Validation Accuracy: 0.756500
	 - Validation F1: 0.731763
	 - Validation Recall: 0.712443
	 - Validation Precision: 0.780065
	 * Model @ epoch 3 saved to model_saves/intermediate_SARC_01/E03_A0.76_F0.73


Epoch 4: 100%|██████████| 5000/5000 [1:50:16<00:00,  1.32s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:21<00:00,  2.49batch/s]


	 - Train loss: 0.354929
	 - Validation Loss: 0.561886
	 - Validation Accuracy: 0.753000
	 - Validation F1: 0.740327
	 - Validation Recall: 0.755949
	 - Validation Precision: 0.749303
	 * Model @ epoch 4 saved to model_saves/intermediate_SARC_01/E04_A0.75_F0.74


Epoch 5: 100%|██████████| 5000/5000 [1:48:57<00:00,  1.31s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:19<00:00,  2.50batch/s]


	 - Train loss: 0.288942
	 - Validation Loss: 0.657629
	 - Validation Accuracy: 0.751000
	 - Validation F1: 0.734821
	 - Validation Recall: 0.742193
	 - Validation Precision: 0.754510
	 * Model @ epoch 5 saved to model_saves/intermediate_SARC_01/E05_A0.75_F0.73


Epoch 6: 100%|██████████| 5000/5000 [1:48:54<00:00,  1.31s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:21<00:00,  2.49batch/s]


	 - Train loss: 0.237295
	 - Validation Loss: 0.722420
	 - Validation Accuracy: 0.745850
	 - Validation F1: 0.725334
	 - Validation Recall: 0.720531
	 - Validation Precision: 0.758503
	 * Model @ epoch 6 saved to model_saves/intermediate_SARC_01/E06_A0.75_F0.73


Epoch 7: 100%|██████████| 5000/5000 [1:48:48<00:00,  1.31s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:19<00:00,  2.50batch/s]


	 - Train loss: 0.189573
	 - Validation Loss: 0.766260
	 - Validation Accuracy: 0.747050
	 - Validation F1: 0.732267
	 - Validation Recall: 0.740162
	 - Validation Precision: 0.750637
	 * Model @ epoch 7 saved to model_saves/intermediate_SARC_01/E07_A0.75_F0.73


Epoch 8: 100%|██████████| 5000/5000 [1:48:48<00:00,  1.31s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:21<00:00,  2.49batch/s]


	 - Train loss: 0.154791
	 - Validation Loss: 0.861222
	 - Validation Accuracy: 0.739600
	 - Validation F1: 0.729675
	 - Validation Recall: 0.750797
	 - Validation Precision: 0.733539
	 * Model @ epoch 8 saved to model_saves/intermediate_SARC_01/E08_A0.74_F0.73


Epoch 9: 100%|██████████| 5000/5000 [1:48:52<00:00,  1.31s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:21<00:00,  2.49batch/s]


	 - Train loss: 0.130389
	 - Validation Loss: 0.927679
	 - Validation Accuracy: 0.745100
	 - Validation F1: 0.737106
	 - Validation Recall: 0.762999
	 - Validation Precision: 0.737391
	 * Model @ epoch 9 saved to model_saves/intermediate_SARC_01/E09_A0.75_F0.74


Epoch 10: 100%|██████████| 5000/5000 [1:49:23<00:00,  1.31s/batch]
	 Validation 1249: 100%|██████████| 1250/1250 [08:30<00:00,  2.45batch/s]


	 - Train loss: 0.109669
	 - Validation Loss: 1.007827
	 - Validation Accuracy: 0.741000
	 - Validation F1: 0.727307
	 - Validation Recall: 0.740828
	 - Validation Precision: 0.739979
	 * Model @ epoch 10 saved to model_saves/intermediate_SARC_01/E10_A0.74_F0.73
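
The per-epoch logs above show training loss falling every epoch while validation loss rises after epoch 1, which is the usual signature of overfitting. The sketch below copies the logged validation metrics into a list and picks the best epoch under each criterion. It also rebuilds the checkpoint directory name, assuming the `E{epoch}_A{accuracy}_F{f1}` pattern visible in the save messages; the `Trainer` code that produces those names is not shown, and `checkpoint_name` is a hypothetical helper.

```python
# (epoch, val_loss, val_accuracy, val_f1) copied from the training log above
history = [
    (1, 0.507575, 0.746200, 0.711028),
    (2, 0.511777, 0.754750, 0.725020),
    (3, 0.527644, 0.756500, 0.731763),
    (4, 0.561886, 0.753000, 0.740327),
    (5, 0.657629, 0.751000, 0.734821),
    (6, 0.722420, 0.745850, 0.725334),
    (7, 0.766260, 0.747050, 0.732267),
    (8, 0.861222, 0.739600, 0.729675),
    (9, 0.927679, 0.745100, 0.737106),
    (10, 1.007827, 0.741000, 0.727307),
]

def checkpoint_name(epoch, accuracy, f1):
    """Rebuild the save-directory suffix (assumed naming pattern)."""
    return f"E{epoch:02d}_A{accuracy:.2f}_F{f1:.2f}"

best_by_loss = min(history, key=lambda row: row[1])
best_by_acc = max(history, key=lambda row: row[2])
best_by_f1 = max(history, key=lambda row: row[3])

for criterion, row in [("lowest val loss", best_by_loss),
                       ("highest val accuracy", best_by_acc),
                       ("highest val F1", best_by_f1)]:
    epoch, _, acc, f1 = row
    print(f"{criterion}: epoch {epoch} -> {checkpoint_name(epoch, acc, f1)}")
```

By validation loss the epoch 1 checkpoint is best, by accuracy epoch 3, and by F1 epoch 4. Which one to carry forward depends on the metric that matters for the downstream task.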
