In [1]:
import os
import kagglehub

os.environ["KAGGLEHUB_CACHE"] = "./data/"  # 指定下载路径
path = kagglehub.dataset_download("lakshmi25npathi/imdb-dataset-of-50k-movie-reviews")
print("Path to dataset files:", path)  

Path to dataset files: ./data/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews/versions/1


In [10]:
import pandas as pd

def csv_to_fasttext(input_csv, output_file):
    df = pd.read_csv(input_csv)
    
    with open(output_file, "w", encoding="utf-8") as f:
        for _, row in df.iterrows():
            label = f"__label__{row['sentiment']}"
            text = row['review'].replace("\n", " ").strip()
            f.write(f"{label} {text}\n")

output_file = "./data/imdb_fasttext.txt"
csv_to_fasttext(
    input_csv=f"{path}/IMDB Dataset.csv",  
    output_file=output_file 
)

In [12]:
import random

with open(output_file, "r", encoding="utf-8") as f:
    lines = f.readlines()

random.shuffle(lines)
split_idx = int(0.8 * len(lines))

train_file = "./data/imdb_train.txt"
valid_file = "./data/imdb_valid.txt"

with open(train_file, "w", encoding="utf-8") as f_train:
    f_train.writelines(lines[:split_idx])
    
with open(valid_file, "w", encoding="utf-8") as f_valid:
    f_valid.writelines(lines[split_idx:])

In [15]:
import fasttext

model = fasttext.train_supervised(
    input=train_file,
    epoch=25,          # 迭代次数
    lr=0.5,            # 学习率
    wordNgrams=2,      # 使用2-gram特征
    dim=100,           # 词向量维度
    loss='ova'         # 二分类推荐One-vs-All
)

Read 9M words
Number of words:  383007
Number of labels: 2
Progress: 100.0% words/sec/thread:  880067 lr:  0.000000 avg.loss:  0.097113 ETA:   0h 0m 0s  9.6% words/sec/thread:  991839 lr:  0.451998 avg.loss:  0.687160 ETA:   0h 0m19s


In [16]:
print("验证集性能:", model.test(valid_file))

验证集性能: (10000, 0.898, 0.898)


In [17]:
model_path = "./data/imdb_model.bin"
model.save_model(model_path)  

In [20]:
loaded_model = fasttext.load_model(model_path)
text = "This movie was a brilliant portrayal of human resilience"
labels, probs = loaded_model.predict(text, k=2)  # k=2返回前2个预测
print(f"预测: {labels[0]} (置信度: {probs[0]:.2f})")

预测: __label__positive (置信度: 1.00)


In [2]:
os.environ["KAGGLEHUB_CACHE"] = "./data/"
path = kagglehub.dataset_download("hughxusu/thucnewsyuletrainingwords")

print("Path to dataset files:", path)

Path to dataset files: ./data/datasets/hughxusu/thucnewsyuletrainingwords/versions/1


In [3]:
import fasttext

input_file = f"{path}/THUCNews_yule.txt"
model = fasttext.train_unsupervised(
    input=input_file,
    model='skipgram',  # 可以是'skipgram'或'cbow'
    dim=300,           # 词向量维度
    ws=5,              # 上下文窗口大小
    epoch=5,           # 训练轮数
    minCount=5,        # 最小词频
)

Read 35M words
Number of words:  135377
Number of labels: 0
Progress: 100.0% words/sec/thread:   26677 lr:  0.000000 avg.loss:  1.151860 ETA:   0h 0m 0s1.940292 ETA:   0h 8m23s 18.7% words/sec/thread:   27860 lr:  0.040638 avg.loss:  1.911608 ETA:   0h 7m55s 20.5% words/sec/thread:   27432 lr:  0.039745 avg.loss:  1.907373 ETA:   0h 7m51s 22.4% words/sec/thread:   26729 lr:  0.038784 avg.loss:  1.901894 ETA:   0h 7m52s 26.4% words/sec/thread:   26429 lr:  0.036801 avg.loss:  1.892128 ETA:   0h 7m33s 29.2% words/sec/thread:   26500 lr:  0.035383 avg.loss:  1.862262 ETA:   0h 7m14s 30.3% words/sec/thread:   26483 lr:  0.034846 avg.loss:  1.850049 ETA:   0h 7m 8s 32.3% words/sec/thread:   26393 lr:  0.033825 avg.loss:  1.828202 ETA:   0h 6m57s 32.6% words/sec/thread:   26402 lr:  0.033679 avg.loss:  1.825910 ETA:   0h 6m55s 36.3% words/sec/thread:   26346 lr:  0.031851 avg.loss:  1.793900 ETA:   0h 6m33s 43.9% words/sec/thread:   26186 lr:  0.028067 avg.loss:  1.745266 ETA:   0h 5m49s 47

In [4]:
output_model = "./data/THUCNews_yule.bin"
model.save_model(output_model)

In [5]:
loaded_word_model = fasttext.load_model("./data/THUCNews_yule.bin")

In [7]:
loaded_word_model.get_word_vector("演唱会")

array([-1.36941954e-01, -6.91107959e-02, -4.50979322e-01, -1.00790765e-02,
       -1.79097299e-02,  5.43925762e-02, -1.65300652e-01,  4.95170578e-02,
       -1.37026059e-02, -1.96755417e-02,  1.99646518e-01, -3.60187322e-01,
       -2.76700258e-01, -2.08970562e-01,  2.20449399e-02,  1.38366967e-01,
        1.94086999e-01,  1.01207010e-01,  2.71101326e-01, -5.47323301e-02,
       -2.85441786e-01,  3.47651780e-01, -1.98001012e-01, -9.02976543e-02,
       -4.94873673e-01, -4.05424744e-01,  4.88335937e-02, -3.07663679e-01,
        2.97929764e-01,  4.82776612e-02, -1.52842730e-01, -1.75745875e-01,
        3.39250594e-01, -1.07334882e-01, -2.02305630e-01, -5.56828417e-02,
       -1.82944462e-01, -1.00628540e-01,  5.50660118e-02,  2.58158088e-01,
       -2.21999250e-02, -9.45993885e-03,  2.82346398e-01, -1.29430637e-01,
        1.47254333e-01, -1.01202711e-01,  5.69567978e-02, -2.49726847e-01,
        3.04114491e-01,  5.59997149e-02, -3.19920659e-01, -3.45625430e-01,
       -2.95162834e-02, -

In [12]:
loaded_word_model.get_nearest_neighbors("演唱会")

[(0.7165132761001587, '巡回演唱'),
 (0.6898350715637207, '开唱'),
 (0.6895934343338013, '个唱会'),
 (0.6618766784667969, '安可场'),
 (0.6557812690734863, 'LIVEDVD'),
 (0.6470138430595398, '踏红馆'),
 (0.646662712097168, '独唱会'),
 (0.6439158320426941, '音乐会'),
 (0.6410515308380127, '一忆莲'),
 (0.6387121081352234, '进红馆')]

In [14]:
loaded_word_model.get_nearest_neighbors("粉丝")

[(0.6937841773033142, '歌迷'),
 (0.671504020690918, '影迷'),
 (0.6690942645072937, 'fans'),
 (0.6440088748931885, 'Fans'),
 (0.6379040479660034, '庚饭'),
 (0.6054631471633911, '步迷'),
 (0.5840657949447632, 'FANS'),
 (0.580436110496521, '庚饭们'),
 (0.5663548707962036, '笔亲们'),
 (0.5612307786941528, '逾千')]

In [22]:
loaded_word_model.get_nearest_neighbors("周杰伦")

[(0.8326612710952759, '周董'),
 (0.6159330010414124, '杰伦'),
 (0.5920110940933228, '蔡依林'),
 (0.5822582244873047, '退罗志祥'),
 (0.5650357604026794, '刺陵'),
 (0.5632564425468445, '王力宏'),
 (0.5565342307090759, '罗康妮'),
 (0.5519910454750061, '刺陵时'),
 (0.5510969161987305, 'Jay'),
 (0.5508455634117126, '江语晨')]

In [None]:
from fasttext import util
util.download_model('zh', if_exists='ignore')

In [None]:
model = fasttext.load_model("cc.zh.300.bin")
model.words[:100]
loaded_word_model.get_nearest_neighbors("演唱会")
loaded_word_model.get_nearest_neighbors("粉丝")
loaded_word_model.get_nearest_neighbors("周杰伦")