### Fine-tuning a pre-trained BERT model for sequence classification
#### Emotion labels used initially
['serenity', 'joy', 'disgust', 'anxious', 'optimism', 'vigilance', 'sad', 'fear']

### Observations from the implementation in emotion_analysis.ipynb
#### Model scores were poor, mostly because there was little annotated data (emotions were annotated only for subsets of the social and media groups' responses).
#### Training loss was logged at every step and fluctuated up and down.
#### Validation loss decreased at each evaluation step, but it sometimes fell below the training loss.
#### Validation loss did not trend upward even when the number of epochs was increased, which suggests the model was not (yet) overfitting. Overfitting would have been a plausible risk, since all BERT layers were unfrozen (a very large number of trainable parameters) and the dataset was small.

### Things tried to reduce the training-loss fluctuation
##### 1] Learning rates: 2e-5, the `TrainingArguments` default (5e-5), 0.0001, 0.001, 0.1
##### 2] Training + validation dataset sizes: 50, 100, 200
##### 3] Using BERT instead of DistilBERT
##### 4] Changing the token max_length: 512, 250, 200, 100, 80
##### 5] Freezing all layers to reduce memory requirements; because the training loss still appeared to oscillate, all layers were then unfrozen, and freezing only the embeddings and the initial transformer layers (up to 3) was also tried
##### 6] Adding/removing the following pipeline components for this specific task: stop-word removal, lemmatization, punctuation removal
##### 7] Changing the number of epochs: 3, 5, 1, 10, 15, 30
##### 8] Given the limited data, a larger training set with a very small validation set was also tried
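If the settings in items 1, 4 and 7 were crossed exhaustively rather than varied separately, the number of fine-tuning runs would grow multiplicatively. A minimal sketch using only the values listed above (the dictionary keys are illustrative names chosen for this sketch):

```python
from itertools import product

# Values copied from items 1, 4 and 7 above; the key names are illustrative only.
search_space = {
    "learning_rate": [2e-5, 5e-5, 1e-4, 1e-3, 1e-1],
    "max_length": [512, 250, 200, 100, 80],
    "num_train_epochs": [3, 5, 1, 10, 15, 30],
}

# One dict per combination; dict iteration order keeps keys and values aligned.
configs = [dict(zip(search_space, values)) for values in product(*search_space.values())]
print(len(configs))  # 5 * 5 * 6 = 150 configurations
```

On a dataset this small, each configuration's validation score is itself noisy, so differences between neighbouring settings should be read cautiously.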

##### Finally, to check whether too few training examples per label (with 8 labels) was causing the problem, we reduced the number of labels:
##### optimism was mapped to serenity
##### responses tagged vigilance, sad, and fear all involved some level of anxiety, so all three were mapped to anxious

#### Emotion labels used in the final model
['serenity', 'joy', 'disgust', 'anxious']
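The 8-to-4 label reduction described above amounts to a lookup table. A minimal sketch, assuming the original annotations use exactly the label strings listed at the top of this section (`'sad'`, `'vigilance'`, ...); any label not in the table is kept unchanged:

```python
# Merge map from the original 8 labels to the final 4 (unlisted labels map to themselves).
LABEL_MERGE = {
    "optimism": "serenity",
    "vigilance": "anxious",
    "sad": "anxious",
    "fear": "anxious",
}

def merge_label(label: str) -> str:
    """Return the final label for an original annotation."""
    return LABEL_MERGE.get(label, label)

original = ['serenity', 'joy', 'disgust', 'anxious', 'optimism', 'vigilance', 'sad', 'fear']
final = sorted({merge_label(lbl) for lbl in original})
print(final)  # ['anxious', 'disgust', 'joy', 'serenity']
```

On the DataFrame this could be applied as `data['emotion'] = data['emotion'].map(merge_label)`.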

In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, classification_report
from sklearn_crfsuite.metrics import flat_classification_report
import torch
from transformers import TrainingArguments, Trainer
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import EarlyStoppingCallback
from torch.optim import AdamW  # transformers.AdamW is deprecated in favour of the PyTorch implementation
import torch.optim as optim



In [2]:
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [2]:
#pip install contractions
#import contractions
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
nltk.download('wordnet')
from nltk.stem import WordNetLemmatizer
import regex as re



[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\vhpld\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\vhpld\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
def case_handling(df,columns=[]):
    
    for col in columns:
        df[col] = df[col].str.lower() 
        
    return df       

In [4]:
def remove_punctuations(df,columns=[]):
    
    for col in columns:
        df[col] = df[col].apply(lambda text: re.sub(r'[^\w\s]', '', text))
        
    return df

In [5]:
def remove_stopwords(df, columns=[]):
    
    stop_words = set(stopwords.words('english'))
    
    def remove_sw(text):
        txt_output = " ".join([word for word in str(text).split() if word not in stop_words])
        return txt_output
    
    for col in columns:
        df[col] = df[col].apply(lambda text: remove_sw(text))
    
    return df

In [6]:
def lemmatize_words(df, columns=[]):
    
    lemmatizer = WordNetLemmatizer()
    
    def lemmatize(text):
        text_output = " ".join([lemmatizer.lemmatize(word) for word in text.split()])
        return text_output
    
    for col in columns:
        df[col] = df[col].apply(lambda text: lemmatize(text))
        
    return df

In [7]:
def remove_extra_spaces(df,columns=[]):
    
    for col in columns:
        df[col] = df[col].apply(lambda text: re.sub(' +', ' ', text))
        
    return df 

In [15]:
def data_preprocessing(df, columns=[]):
    
    #df = expand_contraction(df,columns)
    df = case_handling(df,columns) 
    df = remove_punctuations(df,columns)
    #df = remove_words_dgits(df,columns)  
    df = remove_stopwords(df,columns) 
    df = lemmatize_words(df, columns)
    df = remove_extra_spaces(df,columns) 
    
    return df
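The order of steps in `data_preprocessing` matters: punctuation is stripped before stop words are removed, so a contraction such as "didn't" becomes "didnt", which no longer matches the stop-word list. This is consistent with "didnt" surviving in the processed `parent_answer` column. A minimal sketch using the standard-library `re` module (same pattern as `remove_punctuations`) and a tiny hand-written stop-word set standing in for NLTK's; the assumption is that the real list contains "didn't" and "didn" but not "didnt":

```python
import re

# Hypothetical stand-in for stopwords.words('english'); only a few entries.
STOP_WORDS = {"i", "it", "didn't", "didn"}

def strip_punctuation(text: str) -> str:
    # Drop anything that is not a word character or whitespace.
    return re.sub(r"[^\w\s]", "", text)

def drop_stopwords(text: str) -> str:
    return " ".join(word for word in text.split() if word not in STOP_WORDS)

sample = "i didn't use it much"
punct_first = drop_stopwords(strip_punctuation(sample))  # order used in data_preprocessing
stop_first = strip_punctuation(drop_stopwords(sample))   # alternative order
print(punct_first)  # didnt use much
print(stop_first)   # use much
```

Removing stop words before punctuation, or expanding contractions first (as the commented-out `expand_contraction` step suggests), avoids leaving these fragments in the text.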

In [16]:
data = pd.read_csv("focus_groups_convos_emotion_analysis.csv", encoding="ISO-8859-1")

In [84]:
# response before pre-processing of text
data.iloc[2]['parent_answer']

"I have a similar experience to Parent 3. My high schooler had a lot of his homework that he had to type and word process... I don't know if that's the right word, into Docs, but it wasn't a huge amount. It wasn't doing research, it wasn't hours and hours. My elementary school guy, he didn't have any homework on the computer at all. And my middle schooler had very little, like type this one thing or something like that. It was very little on the computer. And for myself, in terms of my family, I lock the video games so that there was very little access to video games during the week."

In [17]:
columns=['parent_answer']
data =  data_preprocessing(data, columns)

In [18]:
data.head()

| Unnamed: 0 | focus_group_subtype | focus_group_subtype_id | doc_no_within_subtype | question_id | question_text | parent_num | parent_answer | emotion |
|---|---|---|---|---|---|---|---|---|
| 0 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 2 | oh okay well didnt use much mean teacher would... | serenity |
| 1 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 3 | would say thing daughter fifth grade would cou... | serenity |
| 2 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 5 | similar experience parent 3 high schooler lot ... | serenity |
| 3 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 5 | weekend hour something like wasnt much youtube... | joy |
| 4 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 4 | go kid probably sound like online far rest dau... | serenity |


In [19]:
# a response after pre-processing (note: .max() returns the lexicographically largest answer, not row 2)
data['parent_answer'].max()

'youtube bottomless hole cant even begin bad transition looking google classroom google meet here youtube easy always still job thing cant sit top accessibility technology type content like youtube thats beneficial kid thats problematic say least'

In [20]:
data['emotion'].value_counts()

anxious     101
serenity     71
disgust      44
joy          13
Name: emotion, dtype: int64

In [21]:
emotions = list(data['emotion'].unique())
label_to_index = {emotion : index for index, emotion in enumerate(emotions)}
label_to_index

{'serenity': 0, 'joy': 1, 'disgust': 2, 'anxious': 3}

In [22]:
data['emotion_label_index'] = data['emotion'].map(label_to_index)

In [23]:
data.head()

| Unnamed: 0 | focus_group_subtype | focus_group_subtype_id | doc_no_within_subtype | question_id | question_text | parent_num | parent_answer | emotion | emotion_label_index |
|---|---|---|---|---|---|---|---|---|---|
| 0 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 2 | oh okay well didnt use much mean teacher would... | serenity | 0 |
| 1 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 3 | would say thing daughter fifth grade would cou... | serenity | 0 |
| 2 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 5 | similar experience parent 3 high schooler lot ... | serenity | 0 |
| 3 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 5 | weekend hour something like wasnt much youtube... | joy | 1 |
| 4 | gaming_group | 1 | 1 | 1 | So just starting to think through some of the ... | 4 | go kid probably sound like online far rest dau... | serenity | 0 |


In [24]:
index_to_label = {index: label for label, index in label_to_index.items()}
index_to_label

{0: 'serenity', 1: 'joy', 2: 'disgust', 3: 'anxious'}
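`label_to_index` is built from `data['emotion'].unique()`, whose order follows the first appearance of each label in the CSV, so reordering or extending the file could silently change the label indices. A minimal sketch of a deterministic alternative (sorting the label names is an assumption of this sketch, not what the notebook does):

```python
# Order of first appearance, as returned by data['emotion'].unique() in this notebook.
emotions = ['serenity', 'joy', 'disgust', 'anxious']

# Sorting makes the mapping independent of the row order in the CSV.
label_to_index = {label: i for i, label in enumerate(sorted(emotions))}
index_to_label = {i: label for label, i in label_to_index.items()}

print(label_to_index)  # {'anxious': 0, 'disgust': 1, 'joy': 2, 'serenity': 3}
```

Passing these dictionaries as `id2label=index_to_label, label2id=label_to_index` to `BertForSequenceClassification.from_pretrained(...)` stores the names in the model config, so a saved checkpoint carries its own label mapping.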

In [217]:
""" tried but not used finally
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=len(label_to_index))
"""

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.bias', 'pre_classi

In [25]:
learning_rate = 2e-5

In [128]:
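### freezing all layers except the classification head (used when DistilBERT was tried)
for param in model.distilbert.parameters():
    param.requires_grad = False

# The DistilBERT head is pre_classifier + classifier; both are newly initialized, so both
# must be optimized (lr=1 was far too large for Adam, so the notebook's learning_rate is used).
# Trainer only uses a custom optimizer if it is passed as Trainer(..., optimizers=(optimizer, None)).
optimizer = optim.Adam(
    list(model.pre_classifier.parameters()) + list(model.classifier.parameters()),
    lr=learning_rate,
)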


"""
Following was not used -
changing the optimizer
differential learning rates

optimizer = AdamW(filter(lambda p : p.requires_grad, model.parameters()), lr = learning_rate)
optim.SGD([
                {'params': model.base.parameters()},
                {'params': model.classifier.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)
"""

"\noptimizer = AdamW(filter(lambda p : p.requires_grad, model.parameters()), lr = learning_rate)\noptim.SGD([\n                {'params': model.base.parameters()},\n                {'params': model.classifier.parameters(), 'lr': 1e-3}\n            ], lr=1e-2, momentum=0.9)\n"

In [16]:
### freezing only the embeddings and the first two transformer blocks of the pre-trained DistilBERT model
modules = [model.distilbert.embeddings, model.distilbert.transformer.layer[:2]]
for module in modules:
    for param in module.parameters():
        param.requires_grad = False  # the earlier "require_grad" (no s) silently froze nothing
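Because `param.require_grad = False` (missing "s") simply attaches a new, unused attribute to each tensor, it raises no error and freezes nothing. Counting trainable parameters before and after freezing is a cheap way to confirm it took effect. A minimal sketch on a hypothetical toy model (three `nn.Linear` "blocks" plus a head) rather than DistilBERT itself:

```python
import torch.nn as nn

def count_params(model: nn.Module) -> tuple:
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

# Hypothetical toy model: three Linear(4, 4) "blocks" (20 params each) and a head (10 params)
encoder = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
model = nn.ModuleDict({"encoder": encoder, "head": nn.Linear(4, 2)})

# Typo: sets an unrelated attribute, so every parameter stays trainable
for param in encoder[:2].parameters():
    param.require_grad = False
after_typo = count_params(model)
print(after_typo)  # (70, 70)

# Correct attribute name: the first two blocks are now frozen
for param in encoder[:2].parameters():
    param.requires_grad = False
after_fix = count_params(model)
print(after_fix)  # (30, 70)
```

Calling `count_params(model)` on the notebook's model after the freezing cell should report fewer trainable than total parameters.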

In [64]:
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [26]:
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=len(label_to_index))


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [27]:
X = list(data["parent_answer"])
y = list(data["emotion_label_index"])

In [188]:
X

["oh okay. well, use much. mean, teacher would assign reading math assignment ready time time homework. yeah, it. would monitor much time roblox thing entertain themselves, easier control back go much. mean, that's was.",
 'would say thing. daughter fifth grade, would couple newsela projects, half hour homework. playground, neighborhood friends, school projects, school swimming. little bit tablet weekends, internet access youtube life, roblox, gaming. so, different world.',
 "similar experience parent 3. high schooler lot homework type word process... know that's right word, docs, huge amount. research, hour hours. elementary school guy, homework computer all. middle schooler little, like type one thing something like that. little computer. myself, term family, lock video game little access video game week.",
 "go. kid probably more... sound like online than, far, rest them. daughter high school definitely almost homework... assigned school, definitely homework, feel like, good amount.

In [189]:
y

[0, 0, 0, 0, 1, 1, 2, 1, 2, 0, 2, 2, 1, 0, 0, 1, 0, 0, 3, 1,
 2, 3, 2, 2, 2, 1, 0, 1, 1, 1, 2, 1, 1, 1, 3, 3, 3, 3, 1, 3,
 2, 2, 3, 2, 1, 3, 3, 1, 2, 1, 0, 3, 2, 2, 3, 2, 0, 2, 0, 2,
 2, 0, 2, 0, 2, 3, 3, 3, 2, 3, 1, 3, 0, 2, 0, 1, 0, 2, 0, 2]

In [35]:
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.15)
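The split above is neither seeded nor stratified, so the validation set changes on every run and, with only 13 `joy` examples, may contain very few of them. The evaluation accuracies in the training log are multiples of 1/35 (e.g. 0.571429 = 20/35), consistent with a 35-example validation set from `test_size=0.15` on 229 rows. A minimal sketch of a seeded, stratified split on toy data with the same class counts as `data['emotion'].value_counts()` (the texts are placeholders):

```python
from collections import Counter
from sklearn.model_selection import train_test_split

# Toy labels with the class counts printed by data['emotion'].value_counts():
# anxious (3) = 101, serenity (0) = 71, disgust (2) = 44, joy (1) = 13
y = [3] * 101 + [0] * 71 + [2] * 44 + [1] * 13
X = [f"placeholder response {i}" for i in range(len(y))]

# stratify=y keeps class proportions similar in both splits; random_state makes it repeatable
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.15, random_state=0, stratify=y
)

print(len(y_train), len(y_val))  # 194 35
print(sorted(Counter(y_val).items()))
```

In the notebook this would be `train_test_split(X, y, test_size=0.15, random_state=0, stratify=y)`; `random_state=0` matches the `seed=0` already passed to `TrainingArguments`.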

In [36]:
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=80)

In [39]:
type(X_train_tokenized)

transformers.tokenization_utils_base.BatchEncoding

In [37]:
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=80)

In [38]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels
        #print(labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:  # explicit check; also safe if labels is a NumPy array
            item["labels"] = torch.tensor(self.labels[idx])
        #print(item)
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

In [39]:
train_dataset = Dataset(X_train_tokenized, y_train)

In [47]:
print(train_dataset[3])

{'input_ids': tensor([  101,  3398,  1012,  2714,  6687,  1019,  6687,  1017,  1010,  2812,
         1010,  4845,  2052, 17781,  3042,  1012, 18546,  8654, 25249,  3047,
         1012,  2288, 18546, 17470,  2082,  1012,  2812,  1010,  1045,  1005,
         1049,  2551,  2188,  1012,  2093,  4845, 26536,  9354,  1012,  3034,
         4455,  1010,  2562,  4251,  1012,  2288, 25249,  2015,  2346, 23467,
         3047,  1010,  2008,  1005,  1055, 10639,  2814,  1012,  2253, 19448,
         1010,  4121,  2518,  2064,  1005,  1056,  1012,  1012,  1012,  2027,
         1005,  2128, 23042,  1012,  2448,  2188,  2082,  1010,  1000,  2073,
         1005,  1055, 25249,  1029,  2215,  2377,  1012,  1000,  2009,  1005,
         1055,  1010,  1045,  1005,  1049,  2145,  2551,  1012,  2507,  2210,
         2978,  1010,  2215,  2131,  2125,  1012,  2812,  1010,  2428,  1012,
         1012,  1012,  2052,  2377,  2208,  2030,  1012,  1012,  1012,  2082,
         1010, 10124,  2707,  1012,  2059,  1010, 

In [40]:
val_dataset = Dataset(X_val_tokenized, y_val)

In [41]:
def compute_metrics(p):
    # p unpacks into (logits of shape [n_samples, n_labels], true label ids)
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    # zero_division=0 scores a class that received no predictions as 0
    # instead of emitting sklearn's UndefinedMetricWarning
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average='weighted', zero_division=0)
    precision = precision_score(y_true=labels, y_pred=pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, average='weighted', zero_division=0)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

In [42]:

args = TrainingArguments(
    output_dir="focus_groups_emotion_output",
    evaluation_strategy="steps",
    eval_steps=5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=15,
    num_train_epochs=10,
    seed=0,
    load_best_model_at_end=True,
    logging_steps=1,
    learning_rate = learning_rate
)
#learning_rate = learning_rate
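With `per_device_train_batch_size=64` and a small training set, each epoch is only a few optimizer steps, and the last batch is tiny. A worked sketch, assuming the 229 rows counted by `value_counts()`, the 15% split, a single device, no gradient accumulation, and `dataloader_drop_last=False` (as in the printed `args`):

```python
import math

n_rows = 229                        # 101 + 71 + 44 + 13 from value_counts()
n_val = math.ceil(0.15 * n_rows)    # train_test_split rounds the test size up
n_train = n_rows - n_val
batch_size = 64

steps_per_epoch = math.ceil(n_train / batch_size)  # the last partial batch is kept
last_batch_size = n_train - batch_size * (steps_per_epoch - 1)

print(n_train, n_val)          # 194 35
print(steps_per_epoch)         # 4
print(last_batch_size)         # 2
print(30 / steps_per_epoch)    # 7.5, the epoch reported at global_step=30
```

With `logging_steps=1`, every fourth logged training loss comes from a 2-example batch, which likely contributes to the step-to-step fluctuation noted in the observations.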

In [198]:
args

TrainingArguments(output_dir=focus_groups_emotion_output, overwrite_output_dir=False, do_train=False, do_eval=True, do_predict=False, evaluation_strategy=IntervalStrategy.STEPS, prediction_loss_only=False, per_device_train_batch_size=64, per_device_eval_batch_size=18, gradient_accumulation_steps=1, eval_accumulation_steps=None, learning_rate=5e-05, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=5, max_steps=-1, lr_scheduler_type=SchedulerType.LINEAR, warmup_ratio=0.0, warmup_steps=0, logging_dir=runs\Dec04_19-18-40_DESKTOP-RR0FD0F, logging_strategy=IntervalStrategy.STEPS, logging_first_step=False, logging_steps=1, save_strategy=IntervalStrategy.STEPS, save_steps=500, save_total_limit=None, no_cuda=False, seed=0, fp16=False, fp16_opt_level=O1, fp16_backend=auto, fp16_full_eval=False, local_rank=-1, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=10, dataloader_num_workers=0, past_i

In [43]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    
)

In [44]:
trainer.train()

| Step | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 5 | 1.2725 | 1.159274 | 0.571429 | 0.441964 | 0.571429 | 0.473046 |
| 10 | 1.2347 | 1.092833 | 0.571429 | 0.441964 | 0.571429 | 0.473046 |
| 15 | 1.1027 | 1.065651 | 0.542857 | 0.345455 | 0.542857 | 0.422222 |
| 20 | 0.9942 | 1.079866 | 0.6 | 0.495961 | 0.6 | 0.529469 |
| 25 | 1.1085 | 1.073334 | 0.571429 | 0.471429 | 0.571429 | 0.509388 |
| 30 | 0.9482 | 1.084212 | 0.571429 | 0.471429 | 0.571429 | 0.509388 |



TrainOutput(global_step=30, training_loss=1.1574009736378987, metrics={'train_runtime': 2648.5589, 'train_samples_per_second': 0.015, 'total_flos': 0, 'epoch': 7.5, 'init_mem_cpu_alloc_delta': 147456, 'init_mem_cpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 1603510272, 'train_mem_cpu_peaked_delta': 3570049024})
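The stopping point is consistent with a simple patience rule on validation loss: the best value (1.065651) comes at step 15, and the next three evaluations (steps 20, 25, 30) do not improve on it, so with `early_stopping_patience=3` training halts at step 30. `TrainOutput` reports epoch 7.5 at step 30, i.e. 4 steps per epoch, so the full 10 epochs would have been 40 steps. A simplified sketch of that rule, not the exact `EarlyStoppingCallback` implementation (which also supports a threshold and higher-is-better metrics):

```python
def patience_stop(steps, losses, patience=3):
    """Return (best_step, stop_step); stop_step is None if patience never runs out."""
    best_loss = float("inf")
    best_step = None
    bad_evals = 0
    for step, loss in zip(steps, losses):
        if loss < best_loss:
            best_loss, best_step, bad_evals = loss, step, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return best_step, step
    return best_step, None

# Validation losses from the training log above (eval_steps=5)
eval_steps = [5, 10, 15, 20, 25, 30]
val_losses = [1.159274, 1.092833, 1.065651, 1.079866, 1.073334, 1.084212]
print(patience_stop(eval_steps, val_losses))  # (15, 30)
```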

In [None]:
## earlier run (8 labels) when the training loss was fluctuating
trainer.train()

| Step | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.1078 | 2.059592 | 0.25 | 0.105114 | 0.25 | 0.147431 |
| 2 | 2.187 | 2.048275 | 0.318182 | 0.10124 | 0.318182 | 0.153605 |
| 3 | 2.077 | 2.036327 | 0.318182 | 0.10124 | 0.318182 | 0.153605 |
| 4 | 2.15 | 2.02973 | 0.318182 | 0.10124 | 0.318182 | 0.153605 |
| 5 | 2.0669 | 2.022511 | 0.318182 | 0.10124 | 0.318182 | 0.153605 |
| 6 | 1.9867 | 2.016541 | 0.318182 | 0.10124 | 0.318182 | 0.153605 |
| 7 | 2.0318 | 2.010084 | 0.318182 | 0.103594 | 0.318182 | 0.1563 |
| 8 | 2.0786 | 2.006244 | 0.318182 | 0.103594 | 0.318182 | 0.1563 |
| 9 | 2.023 | 2.002145 | 0.295455 | 0.100887 | 0.295455 | 0.150413 |
| 10 | 1.8523 | 1.996485 | 0.295455 | 0.108852 | 0.295455 | 0.159091 |
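With `logging_steps=1`, each logged training loss comes from a single batch, so step-to-step swings are expected; a moving average makes the underlying trend easier to judge. A minimal sketch over the ten training losses logged in the 8-label run above (the window size of 3 is an arbitrary choice):

```python
def moving_average(values, window=3):
    """Mean of each run of `window` consecutive values (len(values) - window + 1 results)."""
    return [sum(values[i:i + window]) / window for i in range(len(values) - window + 1)]

# Training-loss column of the 8-label run (steps 1-10)
train_loss = [2.1078, 2.187, 2.077, 2.15, 2.0669, 1.9867, 2.0318, 2.0786, 2.023, 1.8523]
smoothed = moving_average(train_loss)
print([round(v, 3) for v in smoothed])
```

The smoothed values drift down only from about 2.12 to about 1.98. For reference, a uniform prediction over 8 classes has cross-entropy ln 8 ≈ 2.079, so losses in this range indicate the model was still close to a uniform guess.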


(Debug output omitted: the raw 8-column logit arrays and per-evaluation accuracy/precision/recall/F1 lines printed from inside `compute_metrics`, interleaved with sklearn `_warn_prf` warnings. The printed metrics repeat the values in the table above.)
   0.19108376 -0.30356538]
 [ 0.07081744 -0.09539009  0.27830836  0.26457247  0.00448884 -0.3450395
   0.14042647 -0.40597704]
 [ 0.08969261 -0.09189506  0.29221544  0.29161462  0.00192129 -0.33751175
   0.16786095 -0.30183625]
 [ 0.07697675 -0.10193251  0.26372543  0.34176108 -0.03020959 -0.31350902
   0.15987591 -0.34463844]
 [ 0.04043044 -0.10615558  0.26369047  0.28940737 -0.00395952 -0.35042548
   0.13398352 -0.33029464]
 [ 0.06005338 -0.13915327  0.18899176  0.27741405  0.02583635 -0.37956268
   0.21491848 -0.32556838]
 [ 0.05712185 -0.10037515  0.21928203  0.27283484  0.00324626 -0.35884616
   0.18425086 -0.3403083 ]
 [ 0.06909467 -0.08687287  0.32977176  0.28838342  0.02442422 -0.3363965
   0.15617396 -0.32

  _warn_prf(average, modifier, msg_start, len(result))


[[ 5.20535670e-02 -9.93702114e-02  3.47956300e-01  2.67195612e-01
   9.57131386e-03 -3.24277461e-01  1.66540995e-01 -3.04242581e-01]
 [ 6.98672011e-02 -1.07372344e-01  2.94862866e-01  2.91760266e-01
   1.99082438e-02 -3.64357769e-01  1.51210576e-01 -4.06230003e-01]
 [ 1.08985946e-01 -8.38700533e-02  3.07773530e-01  3.30933452e-01
   1.50514189e-02 -3.69976401e-01  1.92898497e-01 -3.23287517e-01]
 [ 9.52155814e-02 -1.06712729e-01  2.92375982e-01  2.80757010e-01
  -5.43617271e-03 -3.62930745e-01  1.42012805e-01 -4.26475793e-01]
 [ 1.12215891e-01 -1.02489367e-01  3.05929124e-01  3.07483673e-01
  -8.62699375e-03 -3.55920374e-01  1.69035167e-01 -3.20876807e-01]
 [ 1.00391030e-01 -1.11893743e-01  2.77828217e-01  3.58826995e-01
  -4.04616185e-02 -3.32303464e-01  1.61761433e-01 -3.62676024e-01]
 [ 6.23559095e-02 -1.16646767e-01  2.78164536e-01  3.05037856e-01
  -1.49229802e-02 -3.66966993e-01  1.35009080e-01 -3.49806130e-01]
 [ 8.30643624e-02 -1.49696708e-01  2.01752320e-01  2.93176919e-01
   

  _warn_prf(average, modifier, msg_start, len(result))


[[ 0.07499629 -0.11166995  0.36409518  0.28153592 -0.00221539 -0.33997434
   0.1663169  -0.3244757 ]
 [ 0.09322125 -0.12070661  0.31158453  0.30701923  0.00882862 -0.38154066
   0.1519279  -0.42663854]
 [ 0.1325534  -0.09583221  0.32335347  0.34649727  0.00495267 -0.3874631
   0.19432226 -0.34285575]
 [ 0.12074802 -0.11968702  0.30928326  0.29666996 -0.01724277 -0.37977946
   0.14288364 -0.44666895]
 [ 0.13569875 -0.11531164  0.32235876  0.32299334 -0.02068552 -0.37289158
   0.16974598 -0.33974922]
 [ 0.12516487 -0.12388715  0.2946959   0.37479    -0.05231774 -0.3496982
   0.1630181  -0.38113514]
 [ 0.0853347  -0.12922643  0.29525393  0.31999278 -0.02693705 -0.38164866
   0.13532136 -0.3698151 ]
 [ 0.10718121 -0.16206363  0.21696466  0.30757397  0.0038668  -0.4113068
   0.21839413 -0.3628202 ]
 [ 0.10145164 -0.12505296  0.24813867  0.29879394 -0.01910781 -0.39150238
   0.1860824  -0.37799692]
 [ 0.11832237 -0.11043528  0.3604321   0.31860965  0.00571061 -0.372428
   0.15852925 -0.36529

  _warn_prf(average, modifier, msg_start, len(result))


[[ 9.44075659e-02 -1.11148663e-01  3.76468450e-01  2.92311519e-01
  -1.39776431e-02 -3.54752421e-01  1.63927138e-01 -3.43730599e-01]
 [ 1.12845987e-01 -1.20534413e-01  3.24195713e-01  3.18637729e-01
  -2.89307535e-03 -3.97455454e-01  1.49773225e-01 -4.46047217e-01]
 [ 1.52759641e-01 -9.49786156e-02  3.35168540e-01  3.58135134e-01
  -5.93894720e-03 -4.03750420e-01  1.92901030e-01 -3.61991227e-01]
 [ 1.42264530e-01 -1.19356133e-01  3.22095454e-01  3.08485806e-01
  -2.97156069e-02 -3.95547152e-01  1.40913919e-01 -4.65903312e-01]
 [ 1.55937850e-01 -1.14277877e-01  3.34805846e-01  3.34149718e-01
  -3.30209322e-02 -3.88836086e-01  1.68015972e-01 -3.58213454e-01]
 [ 1.46110862e-01 -1.22849181e-01  3.07590604e-01  3.86740863e-01
  -6.41849935e-02 -3.66414845e-01  1.61570102e-01 -3.98901165e-01]
 [ 1.04778908e-01 -1.27857879e-01  3.08316827e-01  3.30986619e-01
  -3.91276106e-02 -3.95816803e-01  1.32788762e-01 -3.89166355e-01]
 [ 1.27830684e-01 -1.61881566e-01  2.29058310e-01  3.18735272e-01
  -

  _warn_prf(average, modifier, msg_start, len(result))


[[ 0.1142029  -0.112532    0.39045805  0.30103123 -0.02610958 -0.3684249
   0.1615665  -0.36219805]
 [ 0.13340147 -0.12235773  0.33880803  0.32808405 -0.01503944 -0.41244173
   0.14754707 -0.46524456]
 [ 0.17401347 -0.09584987  0.3485597   0.36797583 -0.01751913 -0.41896325
   0.19148831 -0.38095838]
 [ 0.16469873 -0.12090849  0.33687687  0.31780043 -0.04225421 -0.4104855
   0.13862967 -0.4845584 ]
 [ 0.1771927  -0.11487778  0.3491007   0.34286553 -0.0461405  -0.40408057
   0.1667565  -0.37624165]
 [ 0.16784576 -0.12360421  0.32230547  0.39682454 -0.07677814 -0.38221115
   0.16012236 -0.41657916]
 [ 0.12506819 -0.12814683  0.32326296  0.34017915 -0.05135081 -0.40963092
   0.12970999 -0.4081995 ]
 [ 0.14912196 -0.1634126   0.24306735  0.32796538 -0.02005508 -0.43911114
   0.21650161 -0.39805472]
 [ 0.14169154 -0.12716815  0.27398396  0.31511712 -0.04434883 -0.41995168
   0.1824665  -0.41334182]
 [ 0.16274169 -0.10965482  0.38759506  0.33897102 -0.01673116 -0.40464813
   0.15480304 -0.40

  _warn_prf(average, modifier, msg_start, len(result))


[[ 0.12989926 -0.11455944  0.40196976  0.31764126 -0.03735325 -0.38203472
   0.15859608 -0.37830454]
 [ 0.15044    -0.1254076   0.35055935  0.34646955 -0.02684665 -0.4274203
   0.14476955 -0.4817218 ]
 [ 0.1912888  -0.09791174  0.35931075  0.38593188 -0.02837535 -0.4341695
   0.18900844 -0.3976578 ]
 [ 0.1827335  -0.12349804  0.3483393   0.3355044  -0.0542955  -0.425529
   0.13588084 -0.5006645 ]
 [ 0.1941638  -0.11661668  0.36079612  0.3594078  -0.05859518 -0.4192931
   0.16480716 -0.39188427]
 [ 0.18549168 -0.12562165  0.33431065  0.41486728 -0.0889423  -0.39811474
   0.15801138 -0.4318853 ]
 [ 0.14156328 -0.1295017   0.3354383   0.35637835 -0.06274357 -0.42341405
   0.12593858 -0.4246909 ]
 [ 0.16627294 -0.16589214  0.25448087  0.34387875 -0.03196765 -0.45206156
   0.2143681  -0.41321218]
 [ 0.15848444 -0.13000211  0.2850898   0.3302502  -0.05679934 -0.43318287
   0.17996529 -0.42882237]
 [ 0.1806633  -0.11148345  0.39908382  0.35647658 -0.02762647 -0.41963983
   0.15194559 -0.42517

  _warn_prf(average, modifier, msg_start, len(result))


[[ 0.14586958 -0.11805348  0.41531557  0.33215988 -0.04944798 -0.3937784
   0.15523098 -0.39451975]
 [ 0.16781434 -0.12944607  0.36451313  0.3628558  -0.03980876 -0.44073197
   0.14169772 -0.4985282 ]
 [ 0.20893814 -0.10131364  0.3723322   0.40230054 -0.04038761 -0.447434
   0.18651223 -0.4144318 ]
 [ 0.20096299 -0.12713276  0.36223087  0.35139918 -0.06756404 -0.43899328
   0.13309921 -0.51690197]
 [ 0.2115447  -0.12027229  0.37479135  0.37408718 -0.07272825 -0.43274385
   0.16271992 -0.40814894]
 [ 0.20356265 -0.1289644   0.3483945   0.43110538 -0.10266415 -0.4125116
   0.15554094 -0.4477072 ]
 [ 0.15790936 -0.13240254  0.34933704  0.37054113 -0.0753291  -0.4354167
   0.12229703 -0.4411427 ]
 [ 0.18331987 -0.16971198  0.26792687  0.35780764 -0.04520351 -0.46354604
   0.21201852 -0.42888987]
 [ 0.1756982  -0.13423625  0.29835868  0.34386593 -0.07055531 -0.4450639
   0.17718808 -0.444214  ]
 [ 0.19874531 -0.11454386  0.41318285  0.37176555 -0.03999527 -0.43275842
   0.14879756 -0.443806

  _warn_prf(average, modifier, msg_start, len(result))


[[ 0.1571669  -0.12309908  0.42519957  0.34303176 -0.05071744 -0.40513095
   0.15102682 -0.41021633]
 [ 0.18024707 -0.13489366  0.37482888  0.3753653  -0.04075402 -0.45375103
   0.13716446 -0.5155712 ]
 [ 0.22168106 -0.10642496  0.38160613  0.41458827 -0.04082709 -0.4599381
   0.18266498 -0.43138438]
 [ 0.21491241 -0.1326639   0.37245157  0.36333957 -0.06964923 -0.45210433
   0.12907058 -0.5328825 ]
 [ 0.2244261  -0.12520777  0.38534307  0.38527107 -0.07490876 -0.44551054
   0.15927437 -0.42454243]
 [ 0.21700561 -0.13401623  0.35895875  0.4434224  -0.10495431 -0.42624328
   0.15133038 -0.46349517]
 [ 0.16991447 -0.13652405  0.35951456  0.38097388 -0.0769722  -0.44702628
   0.117555   -0.45742658]
 [ 0.19621195 -0.17497252  0.27804583  0.36842638 -0.04783891 -0.4751957
   0.20885095 -0.4442712 ]
 [ 0.18914984 -0.13980253  0.30793783  0.35416615 -0.07376457 -0.45663518
   0.17363554 -0.45948768]
 [ 0.21207386 -0.11897133  0.42344922  0.38337767 -0.04154857 -0.44527516
   0.14437231 -0.46

  _warn_prf(average, modifier, msg_start, len(result))


[[ 0.1701029  -0.12869859  0.43581378  0.35237604 -0.05302569 -0.4149249
   0.14788772 -0.42574978]
 [ 0.19441018 -0.14089799  0.3857175   0.38613573 -0.04308029 -0.46517956
   0.13337392 -0.5323505 ]
 [ 0.23581836 -0.11218637  0.39112747  0.42505288 -0.04212994 -0.47075206
   0.17968155 -0.4478426 ]
 [ 0.23058856 -0.13873206  0.38316968  0.373416   -0.07268393 -0.46370676
   0.1253979  -0.54840755]
 [ 0.23895524 -0.13042323  0.39640582  0.3945656  -0.07841083 -0.4567799
   0.15680146 -0.44059208]
 [ 0.23197995 -0.13950586  0.37011623  0.45430902 -0.10868491 -0.4384696
   0.14788523 -0.47885942]
 [ 0.18331414 -0.14103708  0.37001136  0.3895989  -0.07966995 -0.45715445
   0.1137986  -0.47351056]
 [ 0.21049464 -0.18071051  0.2884139   0.37739837 -0.05168004 -0.48535728
   0.20643686 -0.45967025]
 [ 0.204211   -0.14582975  0.31778294  0.3626092  -0.07784801 -0.4667578
   0.17072462 -0.47451532]
 [ 0.22711477 -0.12419406  0.4341843   0.39301088 -0.04456954 -0.45620525
   0.14079604 -0.4784

  _warn_prf(average, modifier, msg_start, len(result))


[[ 0.17939964 -0.13435483  0.4521311   0.3596064  -0.05612822 -0.42465118
   0.14289804 -0.44083667]
 [ 0.20481053 -0.14675087  0.40184847  0.39417794 -0.04631893 -0.47638908
   0.12822188 -0.5487033 ]
 [ 0.24640788 -0.1175879   0.4061222   0.43298316 -0.04446376 -0.48119855
   0.17521471 -0.46378523]
 [ 0.24222031 -0.14464006  0.39957845  0.38096362 -0.07662984 -0.47493696
   0.12031317 -0.56339204]
 [ 0.24956134 -0.13554943  0.4124986   0.40156    -0.08264989 -0.46786514
   0.1529988  -0.45639044]
 [ 0.2431571  -0.14472097  0.38646173  0.46248955 -0.11310992 -0.45016676
   0.14293161 -0.4937638 ]
 [ 0.19282153 -0.1453598   0.3855518   0.39604542 -0.08310805 -0.46741825
   0.10908917 -0.4890769 ]
 [ 0.22098225 -0.18610261  0.3036906   0.38421708 -0.05627968 -0.49522737
   0.20266104 -0.47449052]
 [ 0.21543375 -0.1513539   0.3329238   0.3686035  -0.08258102 -0.4769563
   0.16659398 -0.48901173]
 [ 0.2382538  -0.12944642  0.4507034   0.39997303 -0.04865025 -0.4667049
   0.1356172  -0.49

  _warn_prf(average, modifier, msg_start, len(result))


In [115]:
trainer.train??

In [45]:
test_data = pd.read_csv("focus_groups_convos_emotion_analysis_test.csv")

In [46]:
test_data["emotion"].value_counts()

anxious     4
serenity    3
disgust     1
joy         1
Name: emotion, dtype: int64

In [47]:
emotions = list(test_data['emotion'].unique())
# TODO: handle test labels that are missing from label_to_index (the lookup below fails for them)
test_data['emotion_label_index'] = test_data.apply(lambda x : label_to_index[x['emotion']], axis=1)
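One way the TODO above could be handled, sketched under assumptions: the merge dictionary follows the label consolidation described earlier (optimism to serenity; vigilance, sad and fear to anxious), the index order `serenity=0, joy=1, disgust=2, anxious=3` is inferred from the printed `label_ids` and their decoded labels further down, and the helper name `encode_labels` is hypothetical. Rows whose label is still unknown after merging are reported and dropped instead of making the lookup fail.

```python
import pandas as pd

# Label consolidation from 8 to 4 classes, as described in the notes:
# optimism -> serenity; vigilance, sad and fear -> anxious.
LABEL_MERGE = {
    "optimism": "serenity",
    "vigilance": "anxious",
    "sad": "anxious",
    "fear": "anxious",
}

# Assumed index order, inferred from the notebook's decoded predictions.
label_to_index = {"serenity": 0, "joy": 1, "disgust": 2, "anxious": 3}
index_to_label = {index: label for label, index in label_to_index.items()}


def encode_labels(df: pd.DataFrame, column: str = "emotion") -> pd.DataFrame:
    """Return a copy of df with an integer 'emotion_label_index' column.

    Old labels are first merged into the 4-class scheme; rows whose label is
    still not in label_to_index are reported and dropped instead of raising.
    """
    out = df.copy()
    merged = out[column].map(lambda label: LABEL_MERGE.get(label, label))
    out["emotion_label_index"] = merged.map(label_to_index)  # NaN for unknown labels
    unknown = out["emotion_label_index"].isna()
    if unknown.any():
        print("Dropping rows with unmapped labels:", sorted(out.loc[unknown, column].unique()))
    out = out.loc[~unknown].copy()
    out["emotion_label_index"] = out["emotion_label_index"].astype(int)
    return out


# Hypothetical usage: 'surprise' is not part of either label scheme.
example = pd.DataFrame({"emotion": ["fear", "joy", "optimism", "surprise", "anxious"]})
print(encode_labels(example))
```

Dropping is only one policy; mapping unknown labels to a catch-all class or stopping with an error are equally valid choices depending on how the test annotations are produced.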

In [48]:
columns = ['parent_answer']
test_data = data_preprocessing(test_data, columns)

In [49]:
X_test = list(test_data["parent_answer"])
y_test = list(test_data["emotion_label_index"])

In [34]:
X_test

['yeah similar experience eightyearold prepandemic still second grade use computer limited maybe week would type word week one option picked interaction computer course device play would go father house every weekend would exposed computer still busy busy week school getting school getting school thats trip',
 'hockey taekwondo busy even didnt seem like enough didnt seem like enough certainly lot busier get play mostly probably weekend weekend also pretty packed dont know time play lot course able control item earn didnt well dont get play tool could use incentive week thats window available right day long getting away computer project every day',
 'like bomb went figured figured workarounds like pandora box trying figure block phone doe phone parent coordinator calling principal like block youtube without putting timer thing driving crazy downhill finally figured son banging much he actually broken screen whatever control put could bang would whatever wanted month later exchange ipad 

In [206]:
y_test

[0, 2, 1, 0, 2, 2, 1, 0, 1, 2, 2, 1, 0, 3, 1, 0, 1, 1, 2, 2, 3]

In [50]:
X_test_tokenized = tokenizer(X_test, padding=True, truncation=True, max_length=80)
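To make the effect of `padding=True, truncation=True, max_length=80` on the printed encodings concrete, here is a simplified pure-Python model of what the tokenizer does after word-piece splitting. The special-token ids 101 ([CLS]), 102 ([SEP]) and 0 ([PAD]) are the ones visible in the encodings printed below; the helper `pad_and_truncate` is hypothetical and skips the word-piece tokenization itself. `padding=True` pads every sequence to the longest one in the batch, which is why the first encoded response ends in four 0s and its attention mask switches to 0 at the same positions.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token ids seen in the encodings below


def pad_and_truncate(token_id_lists, max_length=80):
    """Simplified model of padding=True + truncation=True for single sentences.

    token_id_lists: word-piece ids for each text, without special tokens.
    """
    body_limit = max_length - 2  # leave room for [CLS] and [SEP]
    sequences = [[CLS_ID] + ids[:body_limit] + [SEP_ID] for ids in token_id_lists]
    longest = max(len(seq) for seq in sequences)  # padding=True pads to the longest in the batch
    input_ids = [seq + [PAD_ID] * (longest - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_and_truncate([[3398, 1010, 2714], [2066]], max_length=80)
print(batch["input_ids"])       # [[101, 3398, 1010, 2714, 102], [101, 2066, 102, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```

With `max_length=80`, any response longer than 78 word pieces loses its tail, so long focus-group answers are only partially seen by the model.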

In [161]:
X_test_tokenized

{'input_ids': [[101, 3398, 1010, 2714, 3325, 2809, 1011, 2095, 1011, 2214, 1012, 3653, 1011, 6090, 3207, 7712, 2145, 2117, 3694, 2224, 3274, 3132, 1012, 2672, 2733, 2052, 2828, 1010, 2773, 2733, 2028, 5724, 3856, 2079, 1012, 8290, 3274, 1012, 2607, 5080, 2377, 2006, 1010, 2052, 2175, 2269, 1005, 1055, 2160, 2296, 5353, 2052, 6086, 7588, 1012, 2145, 5697, 1012, 5697, 2733, 1010, 2082, 1010, 2893, 2082, 1010, 2893, 2082, 1010, 2008, 1005, 1055, 4440, 2993, 1012, 102, 0, 0, 0, 0], [101, 2066, 5968, 2253, 2125, 1010, 6618, 2041, 1010, 6618, 2147, 24490, 2015, 2066, 19066, 1005, 1055, 3482, 1012, 2667, 3275, 3796, 1012, 1012, 1012, 3042, 18629, 1010, 3042, 6687, 10669, 1010, 4214, 4054, 1010, 2066, 1010, 1000, 2129, 3796, 7858, 2302, 1012, 1012, 1012, 1000, 5128, 25309, 2477, 1010, 4439, 4689, 1010, 19448, 2045, 1012, 2633, 6618, 2365, 1010, 22255, 2172, 2002, 1005, 1055, 2941, 3714, 3898, 3649, 2491, 2404, 1999, 1010, 2071, 9748, 2052, 3649, 2359, 1012, 102], [101, 2036, 6134, 1012, 2569, 

In [51]:
test_dataset = Dataset(X_test_tokenized, y_test)

In [209]:
test_dataset[0]

{'input_ids': tensor([ 101, 3398, 1010, 2714, 3325, 2809, 1011, 2095, 1011, 2214, 1012, 3653,
         1011, 6090, 3207, 7712, 2145, 2117, 3694, 2224, 3274, 3132, 1012, 2672,
         2733, 2052, 2828, 1010, 2773, 2733, 2028, 5724, 3856, 2079, 1012, 8290,
         3274, 1012, 2607, 5080, 2377, 2006, 1010, 2052, 2175, 2269, 1005, 1055,
         2160, 2296, 5353, 2052, 6086, 7588, 1012, 2145, 5697, 1012, 5697, 2733,
         1010, 2082, 1010, 2893, 2082, 1010, 2893, 2082, 1010, 2008, 1005, 1055,
         4440, 2993, 1012,  102,    0,    0,    0,    0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 0, 0, 0, 0]),
 'labels': tensor(0)}

In [52]:
test_trainer = Trainer(model)

In [132]:
test_trainer.predict??

In [53]:
raw_pred, label_ids, metrics = test_trainer.predict(test_dataset)

In [39]:
raw_pred

array([[ 0.06766115, -0.3253832 , -0.06493668,  0.35471714, -0.15618443,
        -0.11512066, -0.16246825, -0.06664449],
       [ 0.0916235 , -0.35716227, -0.06221877,  0.41920525, -0.1804009 ,
        -0.17619711, -0.06709153,  0.04494047],
       [ 0.04280779, -0.37059927, -0.04707661,  0.34497947, -0.17658278,
        -0.14099593, -0.0756123 ,  0.05449395],
       [-0.1484484 , -0.6329201 , -0.34090024,  0.02485409,  0.0903853 ,
        -0.28749904, -0.05994736,  0.20962328],
       [ 0.03568046, -0.19395709,  0.02585419,  0.43622717, -0.05606452,
        -0.24524978, -0.17073739,  0.12783825],
       [ 0.00079341, -0.51856947, -0.12783314,  0.2570194 , -0.02892824,
        -0.38981146,  0.03351265,  0.04641962],
       [-0.07587789, -0.75548315, -0.28536007,  0.22248712,  0.08766535,
        -0.57228875,  0.26338446,  0.34008944],
       [-0.01004304, -0.46937793, -0.19732882,  0.2263671 , -0.07728738,
        -0.17946137, -0.05250129,  0.01034988],
       [-0.053477  , -0.5714641 

In [54]:
y_pred = np.argmax(raw_pred, axis=1)

In [55]:
y_pred = list(y_pred)
y_pred

[3, 3, 3, 3, 3, 0, 3, 3, 0]
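`raw_pred` holds unnormalised logits, so `np.argmax` picks a class but says nothing about how confident the model is. The sketch below (hypothetical helper `logits_to_predictions`) applies a numerically stable softmax. Run on the first two rows of the `raw_pred` printout above, it assigns class 3 to both, but with a top probability below 0.2, not far from the 1/8 = 0.125 of a uniform guess over the eight printed columns, which suggests the model barely separates the classes.

```python
import numpy as np


def logits_to_predictions(logits):
    """Return predicted class indices, their softmax probability, and the full probability matrix."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    pred_idx = probs.argmax(axis=1)
    return pred_idx, probs.max(axis=1), probs


# First two rows of the raw_pred printout above.
sample_logits = [
    [0.06766115, -0.3253832, -0.06493668, 0.35471714, -0.15618443, -0.11512066, -0.16246825, -0.06664449],
    [0.0916235, -0.35716227, -0.06221877, 0.41920525, -0.1804009, -0.17619711, -0.06709153, 0.04494047],
]
pred_idx, confidence, probs = logits_to_predictions(sample_logits)
print(pred_idx, confidence.round(3))
```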

In [56]:
label_ids = list(label_ids)
label_ids

[3, 0, 2, 1, 3, 0, 0, 3, 3]

In [57]:
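pred = [index_to_label[p] for p in y_pred]
print(pred)
true = [index_to_label[label_id] for label_id in label_ids]
print(true)
# zero_division=0 scores never-predicted classes as 0 without emitting UndefinedMetricWarning
report = classification_report(y_true=true, y_pred=pred, zero_division=0)
print(report)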

['anxious', 'anxious', 'anxious', 'anxious', 'anxious', 'serenity', 'anxious', 'anxious', 'serenity']
['anxious', 'serenity', 'disgust', 'joy', 'anxious', 'serenity', 'serenity', 'anxious', 'anxious']
              precision    recall  f1-score   support

     anxious       0.43      0.75      0.55         4
     disgust       0.00      0.00      0.00         1
         joy       0.00      0.00      0.00         1
    serenity       0.50      0.33      0.40         3

    accuracy                           0.44         9
   macro avg       0.23      0.27      0.24         9
weighted avg       0.36      0.44      0.38         9
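The 0.00 scores for `disgust` and `joy` are partly a convention: neither class is ever predicted, so their precision is 0/0, which scikit-learn sets to 0 and flags with an `UndefinedMetricWarning`. The hypothetical helper below recomputes the per-class numbers by hand from the `pred` and `true` lists printed above, keeping the undefined case explicit (`None`) instead of zeroing it. It reproduces the anxious (0.43 / 0.75 / 0.55) and serenity (0.50 / 0.33 / 0.40) rows and the 0.44 accuracy.

```python
from collections import Counter


def per_class_metrics(y_true, y_pred):
    """Per-class precision/recall/F1/support; precision is None when a class is never predicted."""
    true_count = Counter(y_true)
    pred_count = Counter(y_pred)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    metrics = {}
    for label in sorted(set(y_true) | set(y_pred)):
        precision = hits[label] / pred_count[label] if pred_count[label] else None
        recall = hits[label] / true_count[label] if true_count[label] else None
        f1 = 2 * precision * recall / (precision + recall) if precision and recall else 0.0
        metrics[label] = {"precision": precision, "recall": recall, "f1": f1, "support": true_count[label]}
    return metrics


# Decoded predictions and gold labels as printed in the notebook output.
pred = ["anxious"] * 5 + ["serenity"] + ["anxious"] * 2 + ["serenity"]
true = ["anxious", "serenity", "disgust", "joy", "anxious", "serenity", "serenity", "anxious", "anxious"]

for label, m in per_class_metrics(true, pred).items():
    print(label, m)
accuracy = sum(t == p for t, p in zip(true, pred)) / len(true)
print("accuracy", round(accuracy, 2))
```

With only one or two supporting examples per minority class, a single changed prediction moves these scores by 0.5 or more, so the table is better read as a smoke test than as a measurement.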



In [59]:
## save the model once satisfactory scores are obtained
model.save_pretrained('emotion_analysis_output')
### NOTE: the serialized pytorch_model.bin file (427 MB) could not be uploaded to Git because of its size

In [73]:
# load the model fine-tuned for emotion analysis from the directory written by save_pretrained above
finetuned_model = BertForSequenceClassification.from_pretrained('emotion_analysis_output')

In [74]:
# check that the reloaded model reproduces the results of the model above
test_trainer1 = Trainer(finetuned_model)

In [75]:
raw_pred, label_ids, metrics = test_trainer1.predict(test_dataset)

In [76]:
y_pred1 = list(np.argmax(raw_pred, axis=1))

In [77]:
label_ids1 = list(label_ids)

In [78]:
pred = [index_to_label[p] for p in y_pred1]
print(pred)
true = [index_to_label[label_id] for label_id in label_ids1]
print(true)
# zero_division=0 scores never-predicted classes as 0 without emitting UndefinedMetricWarning
report = classification_report(y_true=true, y_pred=pred, zero_division=0)
print(report)

['anxious', 'anxious', 'anxious', 'anxious', 'anxious', 'serenity', 'anxious', 'anxious', 'serenity']
['anxious', 'serenity', 'disgust', 'joy', 'anxious', 'serenity', 'serenity', 'anxious', 'anxious']
              precision    recall  f1-score   support

     anxious       0.43      0.75      0.55         4
     disgust       0.00      0.00      0.00         1
         joy       0.00      0.00      0.00         1
    serenity       0.50      0.33      0.40         3

    accuracy                           0.44         9
   macro avg       0.23      0.27      0.24         9
weighted avg       0.36      0.44      0.38         9
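Comparing two classification reports only confirms that the decoded labels agree. A stricter check of the save/reload round trip compares the logits themselves. Because the reload cell reuses the name `raw_pred`, the first model's logits would have to be kept under another name beforehand. The helper `predictions_match`, the variable names in the usage comment and the tolerance value are assumptions, not part of the notebook.

```python
import numpy as np


def predictions_match(logits_a, logits_b, atol=1e-5):
    """True if two logit arrays agree within atol and give identical argmax predictions."""
    a = np.asarray(logits_a, dtype=np.float64)
    b = np.asarray(logits_b, dtype=np.float64)
    if a.shape != b.shape:
        return False
    same_labels = np.array_equal(a.argmax(axis=1), b.argmax(axis=1))
    return bool(same_labels and np.allclose(a, b, atol=atol))


# Hypothetical usage inside the notebook (Trainer.predict returns a PredictionOutput):
#   raw_pred_original = test_trainer.predict(test_dataset).predictions
#   raw_pred_reloaded = test_trainer1.predict(test_dataset).predictions
#   assert predictions_match(raw_pred_original, raw_pred_reloaded)
```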



#### Conclusion -
##### With limited data, hyperparameter tuning with 8 classes did not lead to good scores, and the training loss kept fluctuating
##### The labels were therefore reduced from 8 to 4
##### The 4 labels were chosen so as to capture varying degrees of positive (serenity, joy) and negative (anxious, disgust) emotions
##### Reducing the labels made the training-loss fluctuation disappear, and the model began predicting more than one label. Before the reduction, only a single label was predicted while the training loss fluctuated and failed to decline even after many epochs (one hypothesis was that the batch size was making the training loss appear not to converge)

#### Next steps -
##### Annotate the lowPIU and media groups' parent-level responses with emotions
##### The improvement was observed on a very small test set (9 responses), which needs to be enlarged