In [1]:
from datasets import load_dataset
import fasttext

# Load the AG News dataset (identified by name so the id is easy to spot and change).
DATASET_NAME = "ag_news"
dataset = load_dataset(DATASET_NAME)

# Export a dataset to a text file in fastText's supervised-learning format.
def save_to_fasttext_format(dataset, filename):
    """Write ``dataset`` to ``filename`` as one ``__label__<label> <text>`` line per item.

    Args:
        dataset: Iterable of mappings, each providing a ``'label'`` and a
            ``'text'`` entry (``'text'`` must be a string).
        filename: Path of the output file; it is created or overwritten.

    The file is always written as UTF-8 instead of the platform's default
    encoding, so non-ASCII text is exported identically on every OS/locale.
    """
    # Explicit UTF-8: relying on the locale default can raise UnicodeEncodeError
    # or silently produce a differently-encoded file on non-UTF-8 systems.
    with open(filename, 'w', encoding='utf-8') as f:
        for item in dataset:
            label = f"__label__{item['label']}"
            # Each example must occupy exactly one line, so flatten embedded newlines.
            text = item['text'].replace('\n', ' ')
            f.write(f"{label} {text}\n")

In [2]:
# Display the loaded object to inspect its splits, features, and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

In [3]:
# Export the training and test splits to fastText-formatted text files.
for split_name, out_path in (("train", "train.txt"), ("test", "test.txt")):
    save_to_fasttext_format(dataset[split_name], out_path)

In [4]:
# Train the supervised classifier; hyperparameters are gathered in one place
# so they can be reviewed and tuned without touching the call itself.
train_params = dict(
    input="train.txt",
    epoch=25,
    lr=1.0,
    wordNgrams=2,
    verbose=2,
    minCount=1,
)
model = fasttext.train_supervised(**train_params)

# Persist the trained model to disk.
model.save_model("model_fasttext.bin")

Read 4M words
Number of words:  188111
Number of labels: 4
Progress: 100.0% words/sec/thread: 2751251 lr:  0.000000 avg.loss:  0.034338 ETA:   0h 0m 0s


In [5]:
# Evaluate the trained model on the exported test file.
result = model.test("test.txt")

# Print the values stored at positions 1 and 2 of the returned result.
precision, recall = result[1], result[2]
print(f"Precision: {precision}")
print(f"Recall: {recall}")

Precision: 0.9163157894736842
Recall: 0.9163157894736842


In [6]:
# Classify a few previously unseen sentences.
texts = [
    "The president gave a speech at the UN.",
    "A breakthrough in quantum computing was announced.",
    "The basketball team celebrated their victory.",
    "Tech companies are investing in blockchain technology.",
]
labels, probabilities = model.predict(texts)

# Each prediction is indexed at position 0 to get the reported label and score.
# NOTE(review): a reported probability can come out marginally above 1.0
# (presumably floating-point rounding inside the library) — it is printed as-is.
for text, top_labels, top_probs in zip(texts, labels, probabilities):
    best_label, best_prob = top_labels[0], top_probs[0]
    print(f"Text: {text}\nPredicted Label: {best_label}\nProbability: {best_prob}\n")

Text: The president gave a speech at the UN.
Predicted Label: __label__0
Probability: 0.9992398023605347

Text: A breakthrough in quantum computing was announced.
Predicted Label: __label__3
Probability: 0.9750422835350037

Text: The basketball team celebrated their victory.
Predicted Label: __label__1
Probability: 1.0000100135803223

Text: Tech companies are investing in blockchain technology.
Predicted Label: __label__3
Probability: 0.9999887943267822

