In [1]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
import pandas as pd

from text_classification.utils.train_evaluate import Trainer
from text_classification.models.TextLSTM import TextLSTM
from text_classification.data_ag_news.data.data_process_glove import DataProcess

%run ..\..\models\TextLSTM.py

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the data pipeline from the train/test CSV files.
# NOTE(review): the meaning of the 141 passed to get_dataLoader is defined by
# DataProcess and is not visible in this notebook — confirm before tuning it.
dp = DataProcess(
    '../datasets/train.csv',
    '../datasets/test.csv',
    device,
)
train_loader, test_loader = dp.get_dataLoader(141)

# Load the pre-trained GloVe matrix; its shape is printed as a sanity check and
# is later used to size the embedding layer.
pre_vector = dp.get_pre_trained(
    "glove.6B.50d.txt",
    '../../extra/glove_vector/',
)
print(pre_vector.shape)

torch.Size([95812, 50])


In [3]:
# Build the classifier; the embedding layer is sized from the pre-trained
# matrix (rows = vocabulary size, columns = embedding dimension).
model = TextLSTM(num_class=4, vocab_size=pre_vector.shape[0], embedding_size=pre_vector.shape[1],
                 hidden_size=256, num_layers=2, dropout_ratio=0.3, bidirectional=True)
# Initialise the embedding layer with the pre-trained word-vector matrix.
# The in-place copy runs under torch.no_grad() instead of going through the
# legacy `.data` attribute, so autograd's bookkeeping is not bypassed.
with torch.no_grad():
    model.embed.weight.copy_(pre_vector)
model.embed.weight.requires_grad = False  # freeze the embedding so it is not trained
model = model.to(device)

epochs = 1
lr = 0.01
criterion = torch.nn.CrossEntropyLoss()
# Only parameters that still require gradients are handed to the optimizer,
# which leaves the frozen embedding weights out.
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

In [4]:
def compute_metrics_f1(predict_all, y_true):
    """Compute the micro-averaged F1 score of the model's predictions.

    Args:
        predict_all: Per-class scores; the predicted class of each sample is
            taken as the argmax over the last axis.
        y_true: Ground-truth labels, passed to ``f1_score`` unchanged.

    Returns:
        A dict with the single key ``"f1"`` holding the micro-averaged F1.
    """
    predicted_labels = predict_all.argmax(-1)
    return {"f1": f1_score(y_true, predicted_labels, average='micro')}


# Bundle the model, optimizer and loss into the project's Trainer.
t_and_v = Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    epochs=epochs,
)

# Train and evaluate; note that the test loader is what gets passed as the
# validation set here.
metric_result = t_and_v.train(
    train_loader=train_loader,
    valid_loader=test_loader,
    compute_metrics=compute_metrics_f1,
)

# Tabulate the collected metrics with the row index labelled 'epoch'.
metric_result_df = pd.DataFrame(metric_result).rename_axis('epoch')
metric_result_df

**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
**********************
torch.Size([4, 16, 256])
**********************
******

KeyboardInterrupt: 

In [None]:
# Plot training vs. validation F1 per epoch.
# Markers are drawn explicitly: with a single epoch each series contains one
# point, and a marker-less line plot of one point renders nothing.
fig, ax = plt.subplots()
ax.plot(metric_result['Training f1'], marker='o', label='Training f1 score')
ax.plot(metric_result['Validation f1'], marker='o', label='Validation f1 score')
ax.set_ylabel('f1')
ax.set_xlabel('epoch')
ax.legend()
plt.show()
