<a href="https://colab.research.google.com/github/BBB-WU/NLP/blob/BBB-WU-patch-1/NLP_with_BERT_multi_class_text__classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [283]:
pip install transformers



In [284]:
import torch
from tqdm.notebook import tqdm

from transformers import BertTokenizer
from torch.utils.data import TensorDataset

from transformers import BertForSequenceClassification

In [285]:
import pandas as pd

In [286]:
import numpy as np


In [287]:
df = pd.read_csv('/content/train.tsv',sep='\t',header=None)

dt = pd.read_csv('/content/valid.tsv',sep='\t',header=None)

dtt = pd.read_csv('/content/test.tsv',sep='\t',header=None)

In [288]:
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13
0,2635.json,false,Says the Annies List political group supports ...,abortion,dwayne-bohac,State representative,Texas,republican,0.0,1.0,0.0,0.0,0.0,a mailer
1,10540.json,half-true,When did the decline of coal start? It started...,"energy,history,job-accomplishments",scott-surovell,State delegate,Virginia,democrat,0.0,0.0,1.0,1.0,0.0,a floor speech.
2,324.json,mostly-true,"Hillary Clinton agrees with John McCain ""by vo...",foreign-policy,barack-obama,President,Illinois,democrat,70.0,71.0,160.0,163.0,9.0,Denver
3,1123.json,false,Health care reform legislation is likely to ma...,health-care,blog-posting,,,none,7.0,19.0,3.0,5.0,44.0,a news release
4,9028.json,half-true,The economic turnaround started at the end of ...,"economy,jobs",charlie-crist,,Florida,democrat,15.0,9.0,20.0,19.0,2.0,an interview on CNN


In [289]:
#去掉原来资料库的其他资料，留下label和text
df=df[[1,2]]
df.columns=['Tvalue','Title']

dt=dt[[1,2]]
dt.columns=['Tvalue','Title']

dtt=dtt[[1,2]]
dtt.columns=['Tvalue','Title']

In [290]:
#将六类label分成两类
label_dict = {'barely-true': 0,
 'false': 0,
 'half-true': 0,
 'mostly-true': 1,
 'pants-fire': 0,
 'true': 1}

In [291]:
df['label'] = df.Tvalue.replace(label_dict)
dt['label'] = dt.Tvalue.replace(label_dict)
dtt['label'] = dtt.Tvalue.replace(label_dict)

In [292]:
dt['label'].value_counts()

0    864
1    420
Name: label, dtype: int64

In [293]:
#为了让数据比较平均，选择去掉了两类。

df=df[ ~ df['Tvalue'].str.contains('half-true') ]
df=df[ ~ df['Tvalue'].str.contains('barely-true') ]

dt=dt[ ~ dt['Tvalue'].str.contains('half-true') ]
dt=dt[ ~ dt['Tvalue'].str.contains('barely-true') ]

dtt=dtt[ ~ dtt['Tvalue'].str.contains('barely-true') ]
dtt=dtt[ ~ dtt['Tvalue'].str.contains('barely-true') ]

In [294]:
df.head(10)

Unnamed: 0,Tvalue,Title,label
0,false,Says the Annies List political group supports ...,0
2,mostly-true,"Hillary Clinton agrees with John McCain ""by vo...",1
3,false,Health care reform legislation is likely to ma...,0
5,true,The Chicago Bears have had more starting quart...,1
9,mostly-true,Says GOP primary opponents Glenn Grothman and ...,1
10,mostly-true,"For the first time in history, the share of th...",1
12,false,When Mitt Romney was governor of Massachusetts...,0
13,mostly-true,The economy bled $24 billion due to the govern...,1
16,true,McCain opposed a requirement that the governme...,1
19,mostly-true,"Almost 100,000 people left Puerto Rico last year.",1


In [295]:
df['label'].value_counts()

1    3638
0    2834
Name: label, dtype: int64

In [296]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', 
                                          do_lower_case=True)

In [297]:
#利用tokenizer，把句子变成Bert能处理的token形式。

encoded_data_train = tokenizer.batch_encode_plus(
    df.Title.values, 
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=32, 
    return_tensors='pt'
)

encoded_data_val = tokenizer.batch_encode_plus(
    dt.Title.values, 
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=32, 
    return_tensors='pt'
)

encoded_data_test = tokenizer.batch_encode_plus(
    dt.Title.values, 
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=32, 
    return_tensors='pt'
)

input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(df.label.values)

input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(dt.label.values)

input_ids_test = encoded_data_test['input_ids']
attention_masks_test = encoded_data_test['attention_mask']
labels_test = torch.tensor(dt.label.values)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [298]:
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)

dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

dataset_test = TensorDataset(input_ids_test, attention_masks_test, labels_test)

In [299]:
len(dataset_train), len(dataset_val), len(dataset_test)

(6472, 799, 799)

In [300]:
#应用已经写好的model做二元分类。这个model是在bert-base上在加一个nn.linear做分类。

model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=2,
                                                      output_attentions=False,
                                                      output_hidden_states=False)

print("""
name            module
----------------------""")
for name, module in model.named_children():
    if name == "bert":
        for n, _ in module.named_children():
            print(f"{name}:{n}")
    else:
        print("{:15} {}".format(name, module))


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at


name            module
----------------------
bert:embeddings
bert:encoder
bert:pooler
dropout         Dropout(p=0.1, inplace=False)
classifier      Linear(in_features=768, out_features=2, bias=True)


In [301]:
#把数据分成一个个batch组成的dataloader。

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

batch_size = 32

dataloader_train = DataLoader(dataset_train, 
                              sampler=RandomSampler(dataset_train), 
                              batch_size=batch_size)

dataloader_validation = DataLoader(dataset_val, 
                                   sampler=SequentialSampler(dataset_val), 
                                   batch_size=batch_size)

dataloader_test = DataLoader(dataset_test, 
                                   sampler=SequentialSampler(dataset_test), 
                                   batch_size=batch_size)

In [302]:
#optimizer应用AdamW，还有一个让learning rate逐渐变小的scheduler。
from transformers import AdamW, get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(),
                  lr=1e-6, 
                  eps=1e-8)

In [303]:
epochs = 120

scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_train)*epochs)


In [304]:
#加入f1_score来做model预测效果评估。

from sklearn.metrics import f1_score

def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average='weighted')

def accuracy_per_class(preds, labels):
    label_dict_inverse = {v: k for k, v in label_dict.items()}
    
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        print(f'Class: {label_dict_inverse[label]}')
        print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\n')

In [305]:
import random

seed_val = 17
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

In [306]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(device)

cuda


In [307]:
#用来评估test的结果，会记录下每次预测的结果和正确率。

def evaluate(dataloader_val):

    model.eval()
    
    loss_val_total = 0
    predictions, true_vals = [], []
    
    for batch in dataloader_val:
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }

        with torch.no_grad():        
            outputs = model(**inputs)
            
        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()

        logits = logits.detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)
    
    loss_val_avg = loss_val_total/len(dataloader_val) 
    
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
            
    return loss_val_avg, predictions, true_vals

In [308]:
#这是model train 的部分，每个epoch后，都会把loss和f1_score显示。
#然后，每个epoch， train 完的model都会存下来。在test测试哪个更好。
for epoch in tqdm(range(1, epochs+1)):
    
    model.train()
    
    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:

        model.zero_grad()
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }       

        outputs = model(**inputs)
        
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})
         
    if epoch%5==0:
      torch.save(model.state_dict(), f'data_volume/finetuned_BERT_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader_train)            
    tqdm.write(f'Training loss: {loss_train_avg}')
    
    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')

HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=203.0, style=ProgressStyle(description_widt…


Epoch 1
Training loss: 0.6821269666032838
Validation loss: 0.6751371955871582
F1 Score (Weighted): 0.5185427755087415


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=203.0, style=ProgressStyle(description_widt…


Epoch 2
Training loss: 0.6620779841991481
Validation loss: 0.6589302968978882
F1 Score (Weighted): 0.6044375484114513


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=203.0, style=ProgressStyle(description_widt…


Epoch 3
Training loss: 0.6520691402440001
Validation loss: 0.6498772740364075
F1 Score (Weighted): 0.6179404968181841


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=203.0, style=ProgressStyle(description_widt…


Epoch 4
Training loss: 0.6435911015336737
Validation loss: 0.6465421938896179
F1 Score (Weighted): 0.620538116084019


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=203.0, style=ProgressStyle(description_widt…


Epoch 5
Training loss: 0.6357499866062785
Validation loss: 0.6457441234588623
F1 Score (Weighted): 0.6224308135900157


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=203.0, style=ProgressStyle(description_widt…


Epoch 6
Training loss: 0.6286133134306358
Validation loss: 0.6364796257019043
F1 Score (Weighted): 0.6443485129160089


HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=203.0, style=ProgressStyle(description_widt…


Epoch 7
Training loss: 0.6196202123400025
Validation loss: 0.6324529886245728
F1 Score (Weighted): 0.6526940585337989


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=203.0, style=ProgressStyle(description_widt…


Epoch 8
Training loss: 0.615317724168007
Validation loss: 0.6296440243721009
F1 Score (Weighted): 0.6520946362172477


HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=203.0, style=ProgressStyle(description_widt…


Epoch 9
Training loss: 0.6100422177702335
Validation loss: 0.6357916569709778
F1 Score (Weighted): 0.6526117405848917


HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=203.0, style=ProgressStyle(description_wid…


Epoch 10
Training loss: 0.5992758592654919
Validation loss: 0.6340423285961151
F1 Score (Weighted): 0.653938828244182


HBox(children=(FloatProgress(value=0.0, description='Epoch 11', max=203.0, style=ProgressStyle(description_wid…


Epoch 11
Training loss: 0.5938559898308345
Validation loss: 0.6354938960075378
F1 Score (Weighted): 0.6582712322186985


HBox(children=(FloatProgress(value=0.0, description='Epoch 12', max=203.0, style=ProgressStyle(description_wid…


Epoch 12
Training loss: 0.5845785283396396
Validation loss: 0.6288318061828613
F1 Score (Weighted): 0.6492808524024014


HBox(children=(FloatProgress(value=0.0, description='Epoch 13', max=203.0, style=ProgressStyle(description_wid…


Epoch 13
Training loss: 0.5743314910111169
Validation loss: 0.6286163532733917
F1 Score (Weighted): 0.6553128419894327


HBox(children=(FloatProgress(value=0.0, description='Epoch 14', max=203.0, style=ProgressStyle(description_wid…


Epoch 14
Training loss: 0.5698167176669454
Validation loss: 0.6292099583148957
F1 Score (Weighted): 0.6589200924171538


HBox(children=(FloatProgress(value=0.0, description='Epoch 15', max=203.0, style=ProgressStyle(description_wid…


Epoch 15
Training loss: 0.5616120838472995
Validation loss: 0.6337929463386536
F1 Score (Weighted): 0.6584529485752441


HBox(children=(FloatProgress(value=0.0, description='Epoch 16', max=203.0, style=ProgressStyle(description_wid…


Epoch 16
Training loss: 0.5524259951020697
Validation loss: 0.6318966293334961
F1 Score (Weighted): 0.6503517936061431


HBox(children=(FloatProgress(value=0.0, description='Epoch 17', max=203.0, style=ProgressStyle(description_wid…


Epoch 17
Training loss: 0.5514050353630423
Validation loss: 0.631215147972107
F1 Score (Weighted): 0.6503517936061431


HBox(children=(FloatProgress(value=0.0, description='Epoch 18', max=203.0, style=ProgressStyle(description_wid…


Epoch 18
Training loss: 0.5376480372962106
Validation loss: 0.6334270596504211
F1 Score (Weighted): 0.660821514537174


HBox(children=(FloatProgress(value=0.0, description='Epoch 19', max=203.0, style=ProgressStyle(description_wid…


Epoch 19
Training loss: 0.5311623907147958
Validation loss: 0.6339023399353028
F1 Score (Weighted): 0.665111177810876


HBox(children=(FloatProgress(value=0.0, description='Epoch 20', max=203.0, style=ProgressStyle(description_wid…


Epoch 20
Training loss: 0.5206427428816339
Validation loss: 0.6371698343753814
F1 Score (Weighted): 0.6551783437730256


HBox(children=(FloatProgress(value=0.0, description='Epoch 21', max=203.0, style=ProgressStyle(description_wid…


Epoch 21
Training loss: 0.5110589949955494
Validation loss: 0.6415237271785736
F1 Score (Weighted): 0.6573330930984888


HBox(children=(FloatProgress(value=0.0, description='Epoch 22', max=203.0, style=ProgressStyle(description_wid…


Epoch 22
Training loss: 0.5034295019551451
Validation loss: 0.6487340128421784
F1 Score (Weighted): 0.6626948200424759


HBox(children=(FloatProgress(value=0.0, description='Epoch 23', max=203.0, style=ProgressStyle(description_wid…


Epoch 23
Training loss: 0.4985780312216341
Validation loss: 0.6547680175304413
F1 Score (Weighted): 0.6642401276506104


HBox(children=(FloatProgress(value=0.0, description='Epoch 24', max=203.0, style=ProgressStyle(description_wid…


Epoch 24
Training loss: 0.4815123120845832
Validation loss: 0.6594659435749054
F1 Score (Weighted): 0.6633338806302733


HBox(children=(FloatProgress(value=0.0, description='Epoch 25', max=203.0, style=ProgressStyle(description_wid…


Epoch 25
Training loss: 0.47823708926515623
Validation loss: 0.6580501902103424
F1 Score (Weighted): 0.6593999068683711


HBox(children=(FloatProgress(value=0.0, description='Epoch 26', max=203.0, style=ProgressStyle(description_wid…


Epoch 26
Training loss: 0.47018687654598595
Validation loss: 0.6589371383190155
F1 Score (Weighted): 0.6599338163116556


HBox(children=(FloatProgress(value=0.0, description='Epoch 27', max=203.0, style=ProgressStyle(description_wid…


Epoch 27
Training loss: 0.46910258775274155
Validation loss: 0.6644612991809845
F1 Score (Weighted): 0.6609735910211832


HBox(children=(FloatProgress(value=0.0, description='Epoch 28', max=203.0, style=ProgressStyle(description_wid…


Epoch 28
Training loss: 0.4554373466322575
Validation loss: 0.6726744985580444
F1 Score (Weighted): 0.6609879561458658


HBox(children=(FloatProgress(value=0.0, description='Epoch 29', max=203.0, style=ProgressStyle(description_wid…


Epoch 29
Training loss: 0.44787942687866134
Validation loss: 0.6791918480396271
F1 Score (Weighted): 0.6625521447152973


HBox(children=(FloatProgress(value=0.0, description='Epoch 30', max=203.0, style=ProgressStyle(description_wid…


Epoch 30
Training loss: 0.44102759842802153
Validation loss: 0.6807375657558441
F1 Score (Weighted): 0.6723883537329923


HBox(children=(FloatProgress(value=0.0, description='Epoch 31', max=203.0, style=ProgressStyle(description_wid…


Epoch 31
Training loss: 0.4328007582079601
Validation loss: 0.6851294040679932
F1 Score (Weighted): 0.6643070933335958


HBox(children=(FloatProgress(value=0.0, description='Epoch 32', max=203.0, style=ProgressStyle(description_wid…


Epoch 32
Training loss: 0.429276391570204
Validation loss: 0.6866304779052734
F1 Score (Weighted): 0.6637828152616545


HBox(children=(FloatProgress(value=0.0, description='Epoch 33', max=203.0, style=ProgressStyle(description_wid…


Epoch 33
Training loss: 0.4107545826382238
Validation loss: 0.6983645343780518
F1 Score (Weighted): 0.6593999068683711


HBox(children=(FloatProgress(value=0.0, description='Epoch 34', max=203.0, style=ProgressStyle(description_wid…


Epoch 34
Training loss: 0.4081934796178282
Validation loss: 0.7063979518413543
F1 Score (Weighted): 0.6575572946036026


HBox(children=(FloatProgress(value=0.0, description='Epoch 35', max=203.0, style=ProgressStyle(description_wid…


Epoch 35
Training loss: 0.39549807859171787
Validation loss: 0.7081471979618073
F1 Score (Weighted): 0.6604513578824912


HBox(children=(FloatProgress(value=0.0, description='Epoch 36', max=203.0, style=ProgressStyle(description_wid…


Epoch 36
Training loss: 0.3939010042128305
Validation loss: 0.7202964055538178
F1 Score (Weighted): 0.6639204770640299


HBox(children=(FloatProgress(value=0.0, description='Epoch 37', max=203.0, style=ProgressStyle(description_wid…


Epoch 37
Training loss: 0.38817512475211047
Validation loss: 0.7317546558380127
F1 Score (Weighted): 0.6553495984470749


HBox(children=(FloatProgress(value=0.0, description='Epoch 38', max=203.0, style=ProgressStyle(description_wid…


Epoch 38
Training loss: 0.36897354914343417
Validation loss: 0.7344744491577149
F1 Score (Weighted): 0.6579578560773734


HBox(children=(FloatProgress(value=0.0, description='Epoch 39', max=203.0, style=ProgressStyle(description_wid…


Epoch 39
Training loss: 0.36606504573610615
Validation loss: 0.7343372344970703
F1 Score (Weighted): 0.6644272145160705


HBox(children=(FloatProgress(value=0.0, description='Epoch 40', max=203.0, style=ProgressStyle(description_wid…


Epoch 40
Training loss: 0.36097931340703826
Validation loss: 0.7493275284767151
F1 Score (Weighted): 0.6621587147077375


HBox(children=(FloatProgress(value=0.0, description='Epoch 41', max=203.0, style=ProgressStyle(description_wid…


Epoch 41
Training loss: 0.3579154166535204
Validation loss: 0.7491866493225098
F1 Score (Weighted): 0.6719681759876195


HBox(children=(FloatProgress(value=0.0, description='Epoch 42', max=203.0, style=ProgressStyle(description_wid…


Epoch 42
Training loss: 0.34596404616762266
Validation loss: 0.7610406005382537
F1 Score (Weighted): 0.6605038842638105


HBox(children=(FloatProgress(value=0.0, description='Epoch 43', max=203.0, style=ProgressStyle(description_wid…


Epoch 43
Training loss: 0.345190421113827
Validation loss: 0.758297609090805
F1 Score (Weighted): 0.6649626727160971


HBox(children=(FloatProgress(value=0.0, description='Epoch 44', max=203.0, style=ProgressStyle(description_wid…


Epoch 44
Training loss: 0.3307138866391675
Validation loss: 0.7772690522670745
F1 Score (Weighted): 0.6637828152616545


HBox(children=(FloatProgress(value=0.0, description='Epoch 45', max=203.0, style=ProgressStyle(description_wid…


Epoch 45
Training loss: 0.322610579137438
Validation loss: 0.7799732255935669
F1 Score (Weighted): 0.6653767006174787


HBox(children=(FloatProgress(value=0.0, description='Epoch 46', max=203.0, style=ProgressStyle(description_wid…


Epoch 46
Training loss: 0.3110563302862233
Validation loss: 0.7792389214038848
F1 Score (Weighted): 0.6636589885440446


HBox(children=(FloatProgress(value=0.0, description='Epoch 47', max=203.0, style=ProgressStyle(description_wid…


Epoch 47
Training loss: 0.30635724488297117
Validation loss: 0.8007002687454223
F1 Score (Weighted): 0.6650568040139802


HBox(children=(FloatProgress(value=0.0, description='Epoch 48', max=203.0, style=ProgressStyle(description_wid…


Epoch 48
Training loss: 0.3152364166866382
Validation loss: 0.805652985572815
F1 Score (Weighted): 0.6605965497702395


HBox(children=(FloatProgress(value=0.0, description='Epoch 49', max=203.0, style=ProgressStyle(description_wid…


Epoch 49
Training loss: 0.29200859962425796
Validation loss: 0.8284885978698731
F1 Score (Weighted): 0.6549376726592981


HBox(children=(FloatProgress(value=0.0, description='Epoch 50', max=203.0, style=ProgressStyle(description_wid…


Epoch 50
Training loss: 0.29544268254869677
Validation loss: 0.818501740694046
F1 Score (Weighted): 0.6612463467617903


HBox(children=(FloatProgress(value=0.0, description='Epoch 51', max=203.0, style=ProgressStyle(description_wid…


Epoch 51
Training loss: 0.2892878823753061
Validation loss: 0.8350990903377533
F1 Score (Weighted): 0.6620301920830355


HBox(children=(FloatProgress(value=0.0, description='Epoch 52', max=203.0, style=ProgressStyle(description_wid…


Epoch 52
Training loss: 0.274936620531411
Validation loss: 0.8452976655960083
F1 Score (Weighted): 0.6563221664905201


HBox(children=(FloatProgress(value=0.0, description='Epoch 53', max=203.0, style=ProgressStyle(description_wid…


Epoch 53
Training loss: 0.2865786156410654
Validation loss: 0.8609718453884124
F1 Score (Weighted): 0.6434686497888398


HBox(children=(FloatProgress(value=0.0, description='Epoch 54', max=203.0, style=ProgressStyle(description_wid…


Epoch 54
Training loss: 0.2603430162216055
Validation loss: 0.8662269723415374
F1 Score (Weighted): 0.6518589453968222


HBox(children=(FloatProgress(value=0.0, description='Epoch 55', max=203.0, style=ProgressStyle(description_wid…


Epoch 55
Training loss: 0.2606708713866807
Validation loss: 0.8794700586795807
F1 Score (Weighted): 0.6461814134908369


HBox(children=(FloatProgress(value=0.0, description='Epoch 56', max=203.0, style=ProgressStyle(description_wid…


Epoch 56
Training loss: 0.2477597962385915
Validation loss: 0.8818091082572938
F1 Score (Weighted): 0.6535625581779256


HBox(children=(FloatProgress(value=0.0, description='Epoch 57', max=203.0, style=ProgressStyle(description_wid…


Epoch 57
Training loss: 0.2500929684731467
Validation loss: 0.8851963293552398
F1 Score (Weighted): 0.655932306129305


HBox(children=(FloatProgress(value=0.0, description='Epoch 58', max=203.0, style=ProgressStyle(description_wid…


Epoch 58
Training loss: 0.24642001019029194
Validation loss: 0.9011220037937164
F1 Score (Weighted): 0.6554556148848338


HBox(children=(FloatProgress(value=0.0, description='Epoch 59', max=203.0, style=ProgressStyle(description_wid…


Epoch 59
Training loss: 0.24332910541243155
Validation loss: 0.9005171728134155
F1 Score (Weighted): 0.6525882977965576


HBox(children=(FloatProgress(value=0.0, description='Epoch 60', max=203.0, style=ProgressStyle(description_wid…


Epoch 60
Training loss: 0.23430099044674135
Validation loss: 0.9032139205932617
F1 Score (Weighted): 0.6574148075435938


HBox(children=(FloatProgress(value=0.0, description='Epoch 61', max=203.0, style=ProgressStyle(description_wid…


Epoch 61
Training loss: 0.21696887097408618
Validation loss: 0.9149790835380555
F1 Score (Weighted): 0.6597391784489816


HBox(children=(FloatProgress(value=0.0, description='Epoch 62', max=203.0, style=ProgressStyle(description_wid…


Epoch 62
Training loss: 0.2209328675431571
Validation loss: 0.929707715511322
F1 Score (Weighted): 0.6554556148848338


HBox(children=(FloatProgress(value=0.0, description='Epoch 63', max=203.0, style=ProgressStyle(description_wid…


Epoch 63
Training loss: 0.2139700526197262
Validation loss: 0.9351388728618621
F1 Score (Weighted): 0.6532935864272679


HBox(children=(FloatProgress(value=0.0, description='Epoch 64', max=203.0, style=ProgressStyle(description_wid…


Epoch 64
Training loss: 0.21298491844696366
Validation loss: 0.9532213163375854
F1 Score (Weighted): 0.6496972646291386


HBox(children=(FloatProgress(value=0.0, description='Epoch 65', max=203.0, style=ProgressStyle(description_wid…


Epoch 65
Training loss: 0.21443824003923115
Validation loss: 0.9472564339637757
F1 Score (Weighted): 0.6562273729079837


HBox(children=(FloatProgress(value=0.0, description='Epoch 66', max=203.0, style=ProgressStyle(description_wid…


Epoch 66
Training loss: 0.2074533836741753
Validation loss: 0.9647048199176789
F1 Score (Weighted): 0.6510544231975879


HBox(children=(FloatProgress(value=0.0, description='Epoch 67', max=203.0, style=ProgressStyle(description_wid…


Epoch 67
Training loss: 0.20113518395520782
Validation loss: 0.9694018077850342
F1 Score (Weighted): 0.650061667235078


HBox(children=(FloatProgress(value=0.0, description='Epoch 68', max=203.0, style=ProgressStyle(description_wid…


Epoch 68
Training loss: 0.2004718818871552
Validation loss: 0.9747402238845825
F1 Score (Weighted): 0.6522276370396378


HBox(children=(FloatProgress(value=0.0, description='Epoch 69', max=203.0, style=ProgressStyle(description_wid…


Epoch 69
Training loss: 0.19175045417961228
Validation loss: 0.9869231271743775
F1 Score (Weighted): 0.6519155658326894


HBox(children=(FloatProgress(value=0.0, description='Epoch 70', max=203.0, style=ProgressStyle(description_wid…


Epoch 70
Training loss: 0.19434226350170639
Validation loss: 1.0072290289402008
F1 Score (Weighted): 0.6481459353955488


HBox(children=(FloatProgress(value=0.0, description='Epoch 71', max=203.0, style=ProgressStyle(description_wid…


Epoch 71
Training loss: 0.1856099598917174
Validation loss: 1.0041364324092865
F1 Score (Weighted): 0.650061667235078


HBox(children=(FloatProgress(value=0.0, description='Epoch 72', max=203.0, style=ProgressStyle(description_wid…


Epoch 72
Training loss: 0.18580089269086644
Validation loss: 1.0104974508285522
F1 Score (Weighted): 0.6557779682849386


HBox(children=(FloatProgress(value=0.0, description='Epoch 73', max=203.0, style=ProgressStyle(description_wid…


Epoch 73
Training loss: 0.18170965150952925
Validation loss: 1.0216861522197724
F1 Score (Weighted): 0.6519155658326894


HBox(children=(FloatProgress(value=0.0, description='Epoch 74', max=203.0, style=ProgressStyle(description_wid…


Epoch 74
Training loss: 0.1838298047631245
Validation loss: 1.0281403863430023
F1 Score (Weighted): 0.6492290880302635


HBox(children=(FloatProgress(value=0.0, description='Epoch 75', max=203.0, style=ProgressStyle(description_wid…


Epoch 75
Training loss: 0.1745435918309712
Validation loss: 1.0392587566375733
F1 Score (Weighted): 0.6498818143832439


HBox(children=(FloatProgress(value=0.0, description='Epoch 76', max=203.0, style=ProgressStyle(description_wid…


Epoch 76
Training loss: 0.17649541226367058
Validation loss: 1.0399472534656524
F1 Score (Weighted): 0.652762204885172


HBox(children=(FloatProgress(value=0.0, description='Epoch 77', max=203.0, style=ProgressStyle(description_wid…


Epoch 77
Training loss: 0.1662990489029532
Validation loss: 1.0493358755111695
F1 Score (Weighted): 0.6574148075435938


HBox(children=(FloatProgress(value=0.0, description='Epoch 78', max=203.0, style=ProgressStyle(description_wid…


Epoch 78
Training loss: 0.16486885730782752
Validation loss: 1.0565361511707305
F1 Score (Weighted): 0.6520739002214013


HBox(children=(FloatProgress(value=0.0, description='Epoch 79', max=203.0, style=ProgressStyle(description_wid…


Epoch 79
Training loss: 0.1569986556532788
Validation loss: 1.0712515938282012
F1 Score (Weighted): 0.6502368391025186


HBox(children=(FloatProgress(value=0.0, description='Epoch 80', max=203.0, style=ProgressStyle(description_wid…


Epoch 80
Training loss: 0.16313850595225843
Validation loss: 1.0796304440498352
F1 Score (Weighted): 0.6477114733198654


HBox(children=(FloatProgress(value=0.0, description='Epoch 81', max=203.0, style=ProgressStyle(description_wid…


Epoch 81
Training loss: 0.15309129104825664
Validation loss: 1.0863329124450685
F1 Score (Weighted): 0.6520739002214013


HBox(children=(FloatProgress(value=0.0, description='Epoch 82', max=203.0, style=ProgressStyle(description_wid…


Epoch 82
Training loss: 0.15513952479243573
Validation loss: 1.084949107170105
F1 Score (Weighted): 0.6511904113388699


HBox(children=(FloatProgress(value=0.0, description='Epoch 83', max=203.0, style=ProgressStyle(description_wid…


Epoch 83
Training loss: 0.16246927362577668
Validation loss: 1.0850500869750976
F1 Score (Weighted): 0.6527968773324214


HBox(children=(FloatProgress(value=0.0, description='Epoch 84', max=203.0, style=ProgressStyle(description_wid…


Epoch 84
Training loss: 0.14901621344274488
Validation loss: 1.0991541707515717
F1 Score (Weighted): 0.651752619771053


HBox(children=(FloatProgress(value=0.0, description='Epoch 85', max=203.0, style=ProgressStyle(description_wid…


Epoch 85
Training loss: 0.14131008175298057
Validation loss: 1.105390567779541
F1 Score (Weighted): 0.6437808801839718


HBox(children=(FloatProgress(value=0.0, description='Epoch 86', max=203.0, style=ProgressStyle(description_wid…


Epoch 86
Training loss: 0.14053684928605856
Validation loss: 1.1228185033798217
F1 Score (Weighted): 0.6481459353955488


HBox(children=(FloatProgress(value=0.0, description='Epoch 87', max=203.0, style=ProgressStyle(description_wid…

KeyboardInterrupt: ignored

In [None]:
#后面就是拿train完的model在test做测试
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=2,
                                                      output_attentions=False,
                                                      output_hidden_states=False)

model.to(device)

In [338]:

model.load_state_dict(torch.load('data_volume/finetuned_BERT_epoch_80.model', map_location=torch.device('cuda')))

<All keys matched successfully>

In [339]:
_, predictions, true_vals = evaluate(dataloader_test)

In [340]:
accuracy_per_class(predictions, true_vals)

Class: pants-fire
Accuracy: 211/379

Class: true
Accuracy: 309/420



In [342]:

inputs = tokenizer("5G is the cause of covid-19", return_tensors="pt")
inputs.to(device)
  # Batch size 1
outputs = model(**inputs)
logits=outputs[0]
pred = logits.detach().cpu().numpy()
pred=np.argmax(pred, axis=1)
pred

array([0])