In [33]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data import get_tokenizer
from tqdm.auto import tqdm

In [34]:
class Corpus(Dataset):
    """CBOW training set of (context words, centre word) pairs read from a text file.

    Every line of the file is treated as one sentence. For each position that
    has `window_size` tokens on both sides, one sample is produced whose
    context is the `2 * window_size` surrounding tokens and whose target is the
    token at that position. Lines shorter than `2 * window_size + 1` tokens
    contribute no samples (but their tokens are still added to the vocabulary).

    Attributes:
        window_size: number of context tokens taken on each side of the target.
        context_target: list of `(context_tokens, target_token)` tuples (strings).
        vocab: vocabulary built from every token in the file.
        word_to_idx: token -> index mapping taken from `vocab`.
        idx_to_word: index -> token list taken from `vocab`.
    """

    def __init__(self, data_file_path, window_size=2, tokenizer=None):
        """Read the file, build the vocabulary and enumerate all samples.

        Args:
            data_file_path: path of a UTF-8 text file, one sentence per line.
            window_size: number of context tokens on each side of the target.
            tokenizer: optional callable mapping a stripped line to a list of
                tokens. When omitted, lines are split on whitespace.
        """
        self.window_size = window_size
        self.context_target = []
        with open(data_file_path, "r", encoding="utf-8") as f:
            if tokenizer:
                lines = [tokenizer(line.strip()) for line in f]
            else:
                # split() without an argument collapses runs of whitespace and
                # returns [] for blank lines; split(" ") would instead emit
                # empty-string tokens that pollute the vocabulary and contexts.
                lines = [line.split() for line in f]
        self.vocab = build_vocab_from_iterator(lines)
        self.word_to_idx = self.vocab.get_stoi()
        self.idx_to_word = self.vocab.get_itos()
        for line in lines:
            self.context_target.extend(self._context_target_pairs(line, window_size))

    @staticmethod
    def _context_target_pairs(tokens, window_size):
        """Return the `(context, target)` pairs of one tokenised sentence.

        The context lists the left neighbours nearest-first, followed by the
        right neighbours nearest-first.
        """
        pairs = []
        for i in range(window_size, len(tokens) - window_size):
            left = [tokens[i - (j + 1)] for j in range(window_size)]
            right = [tokens[i + (j + 1)] for j in range(window_size)]
            pairs.append((left + right, tokens[i]))
        return pairs

    def __getitem__(self, idx):
        """Return sample `idx` as `(context, target)` index tensors of dtype long.

        `context` holds the vocabulary indices of the `2 * window_size` context
        tokens; `target` is a scalar tensor with the index of the centre token.
        """
        context_tokens, target_token = self.context_target[idx]
        context = torch.tensor(self.vocab(context_tokens), dtype=torch.long)
        target = torch.tensor(self.vocab[target_token], dtype=torch.long)
        return context, target

    def __len__(self):
        """Return the number of (context, target) samples."""
        return len(self.context_target)

In [35]:
class CBOW(nn.Module):
    """Continuous bag-of-words model.

    Embeds every context word, averages the embeddings along dimension 1 and
    projects the average onto one score per vocabulary entry.
    """

    def __init__(self, vocab_size, embedding_dim):
        """Create the embedding table and the output projection.

        Args:
            vocab_size: number of rows of the embedding table and number of
                output scores.
            embedding_dim: size of each embedding vector.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs):
        """Return unnormalised scores over the vocabulary.

        Args:
            inputs: integer index tensor; dimension 1 is the one averaged over
                (the context positions for the batches built in this notebook).
        """
        context_vectors = self.embedding(inputs)
        pooled = context_vectors.mean(dim=1)
        logits = self.linear(pooled)
        return logits

In [36]:
# Tunable settings for the whole notebook.
WINDOWS_SIZE = 2    # context words taken on each side of the target word
EMBEDDING_DIM = 30  # size of each learned word vector
BATCH_SIZE = 512
NUM_EPOCH = 300
NUM_WORKERS = 12    # DataLoader worker processes
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the generators before any weights are initialised or batches are
# shuffled, so that a Restart & Run All starts from the same random state.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [37]:
# Build the training set from the Chinese corpus (tokens are split on spaces).
# English alternative, tokenised with torchtext's basic_english tokenizer:
# corpus = Corpus('en.txt', window_size=WINDOWS_SIZE, tokenizer=get_tokenizer("basic_english"))
corpus = Corpus(data_file_path='zh.txt', window_size=WINDOWS_SIZE)

# Short alias used by the nearest-neighbour cells further down.
vocab = corpus.vocab

In [38]:
# Shuffled mini-batches of (context, target) index tensors.
dataloader = DataLoader(
    corpus,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Keep the worker processes alive between epochs instead of re-spawning
    # all of them at the start of every epoch; only valid when workers exist.
    persistent_workers=NUM_WORKERS > 0,
)

In [39]:
# One output score per vocabulary entry, so the model is sized from the vocab.
model = CBOW(vocab_size=len(corpus.vocab), embedding_dim=EMBEDDING_DIM)
model = model.to(DEVICE)

# Cross-entropy over the vocabulary scores, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [40]:
# Train for NUM_EPOCH passes over the data, reporting the sample-weighted mean
# loss and the top-1 accuracy on the training batches after every pass.
for epoch in range(NUM_EPOCH):
    total_loss, correct, total_num = 0, 0, 0
    for context, target in tqdm(dataloader, leave=False):
        context = context.to(DEVICE)
        target = target.to(DEVICE)
        batch_size = context.size(0)

        optimizer.zero_grad()
        pred = model(context)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size to
        # get a correct per-sample average even when the last batch is smaller.
        total_num += batch_size
        total_loss += loss.item() * batch_size
        correct += (pred.argmax(dim=1) == target).sum().item()
    print(f"Epoch {epoch + 1} Loss: {total_loss / total_num:.4f} Acc: {correct / total_num:.4f}")

                                                  

Epoch 1 Loss: 7.2787 Acc: 0.0841


                                                  

Epoch 2 Loss: 6.1037 Acc: 0.1319


                                                  

Epoch 3 Loss: 5.5928 Acc: 0.1569


                                                  

Epoch 4 Loss: 5.2003 Acc: 0.1772


                                                  

Epoch 5 Loss: 4.8770 Acc: 0.1981


                                                  

Epoch 6 Loss: 4.6020 Acc: 0.2216


                                                  

Epoch 7 Loss: 4.3664 Acc: 0.2455


                                                  

Epoch 8 Loss: 4.1620 Acc: 0.2697


                                                  

Epoch 9 Loss: 3.9857 Acc: 0.2918


                                                  

Epoch 10 Loss: 3.8303 Acc: 0.3114


                                                  

Epoch 11 Loss: 3.6934 Acc: 0.3313


                                                  

Epoch 12 Loss: 3.5733 Acc: 0.3491


                                                  

Epoch 13 Loss: 3.4660 Acc: 0.3643


                                                  

Epoch 14 Loss: 3.3708 Acc: 0.3784


                                                  

Epoch 15 Loss: 3.2851 Acc: 0.3919


                                                  

Epoch 16 Loss: 3.2085 Acc: 0.4029


                                                  

Epoch 17 Loss: 3.1388 Acc: 0.4138


                                                  

Epoch 18 Loss: 3.0763 Acc: 0.4232


                                                  

Epoch 19 Loss: 3.0198 Acc: 0.4326


                                                  

Epoch 20 Loss: 2.9671 Acc: 0.4413


                                                  

Epoch 21 Loss: 2.9193 Acc: 0.4480


                                                  

Epoch 22 Loss: 2.8762 Acc: 0.4543


                                                  

Epoch 23 Loss: 2.8363 Acc: 0.4603


                                                  

Epoch 24 Loss: 2.7985 Acc: 0.4665


                                                  

Epoch 25 Loss: 2.7657 Acc: 0.4707


                                                  

Epoch 26 Loss: 2.7356 Acc: 0.4760


                                                  

Epoch 27 Loss: 2.7057 Acc: 0.4807


                                                  

Epoch 28 Loss: 2.6790 Acc: 0.4844


                                                  

Epoch 29 Loss: 2.6533 Acc: 0.4887


                                                  

Epoch 30 Loss: 2.6303 Acc: 0.4918


                                                  

Epoch 31 Loss: 2.6069 Acc: 0.4950


                                                  

Epoch 32 Loss: 2.5872 Acc: 0.4982


                                                  

Epoch 33 Loss: 2.5675 Acc: 0.5011


                                                  

Epoch 34 Loss: 2.5488 Acc: 0.5034


                                                  

Epoch 35 Loss: 2.5316 Acc: 0.5058


                                                  

Epoch 36 Loss: 2.5145 Acc: 0.5090


                                                  

Epoch 37 Loss: 2.4996 Acc: 0.5118


                                                  

Epoch 38 Loss: 2.4855 Acc: 0.5128


                                                  

Epoch 39 Loss: 2.4704 Acc: 0.5151


                                                  

Epoch 40 Loss: 2.4580 Acc: 0.5162


                                                  

Epoch 41 Loss: 2.4453 Acc: 0.5184


                                                  

Epoch 42 Loss: 2.4340 Acc: 0.5200


                                                  

Epoch 43 Loss: 2.4204 Acc: 0.5224


                                                  

Epoch 44 Loss: 2.4105 Acc: 0.5228


                                                  

Epoch 45 Loss: 2.3998 Acc: 0.5256


                                                  

Epoch 46 Loss: 2.3908 Acc: 0.5259


                                                  

Epoch 47 Loss: 2.3810 Acc: 0.5276


                                                  

Epoch 48 Loss: 2.3703 Acc: 0.5295


                                                  

Epoch 49 Loss: 2.3622 Acc: 0.5308


                                                  

Epoch 50 Loss: 2.3545 Acc: 0.5309


                                                  

Epoch 51 Loss: 2.3460 Acc: 0.5320


                                                  

Epoch 52 Loss: 2.3385 Acc: 0.5331


                                                  

Epoch 53 Loss: 2.3299 Acc: 0.5348


                                                  

Epoch 54 Loss: 2.3242 Acc: 0.5352


                                                  

Epoch 55 Loss: 2.3167 Acc: 0.5359


                                                  

Epoch 56 Loss: 2.3092 Acc: 0.5377


                                                  

Epoch 57 Loss: 2.3034 Acc: 0.5391


                                                  

Epoch 58 Loss: 2.2968 Acc: 0.5392


                                                  

Epoch 59 Loss: 2.2907 Acc: 0.5405


                                                  

Epoch 60 Loss: 2.2848 Acc: 0.5408


                                                  

Epoch 61 Loss: 2.2786 Acc: 0.5418


                                                  

Epoch 62 Loss: 2.2734 Acc: 0.5422


                                                  

Epoch 63 Loss: 2.2690 Acc: 0.5431


                                                  

Epoch 64 Loss: 2.2630 Acc: 0.5433


                                                  

Epoch 65 Loss: 2.2591 Acc: 0.5448


                                                  

Epoch 66 Loss: 2.2534 Acc: 0.5456


                                                  

Epoch 67 Loss: 2.2479 Acc: 0.5462


                                                  

Epoch 68 Loss: 2.2437 Acc: 0.5473


                                                  

Epoch 69 Loss: 2.2400 Acc: 0.5469


                                                  

Epoch 70 Loss: 2.2358 Acc: 0.5478


                                                  

Epoch 71 Loss: 2.2311 Acc: 0.5493


                                                  

Epoch 72 Loss: 2.2268 Acc: 0.5496


                                                  

Epoch 73 Loss: 2.2219 Acc: 0.5501


                                                  

Epoch 74 Loss: 2.2184 Acc: 0.5498


                                                  

Epoch 75 Loss: 2.2152 Acc: 0.5511


                                                  

Epoch 76 Loss: 2.2114 Acc: 0.5512


                                                  

Epoch 77 Loss: 2.2077 Acc: 0.5514


                                                  

Epoch 78 Loss: 2.2043 Acc: 0.5528


                                                  

Epoch 79 Loss: 2.2018 Acc: 0.5520


                                                  

Epoch 80 Loss: 2.1982 Acc: 0.5525


                                                  

Epoch 81 Loss: 2.1960 Acc: 0.5536


                                                  

Epoch 82 Loss: 2.1907 Acc: 0.5543


                                                  

Epoch 83 Loss: 2.1871 Acc: 0.5549


                                                  

Epoch 84 Loss: 2.1843 Acc: 0.5549


                                                  

Epoch 85 Loss: 2.1829 Acc: 0.5552


                                                  

Epoch 86 Loss: 2.1786 Acc: 0.5561


                                                  

Epoch 87 Loss: 2.1770 Acc: 0.5561


                                                  

Epoch 88 Loss: 2.1741 Acc: 0.5573


                                                  

Epoch 89 Loss: 2.1714 Acc: 0.5567


                                                  

Epoch 90 Loss: 2.1693 Acc: 0.5572


                                                  

Epoch 91 Loss: 2.1646 Acc: 0.5577


                                                  

Epoch 92 Loss: 2.1619 Acc: 0.5589


                                                  

Epoch 93 Loss: 2.1605 Acc: 0.5591


                                                  

Epoch 94 Loss: 2.1582 Acc: 0.5586


                                                  

Epoch 95 Loss: 2.1548 Acc: 0.5594


                                                  

Epoch 96 Loss: 2.1537 Acc: 0.5590


                                                  

Epoch 97 Loss: 2.1513 Acc: 0.5596


                                                  

Epoch 98 Loss: 2.1493 Acc: 0.5597


                                                  

Epoch 99 Loss: 2.1471 Acc: 0.5605


                                                  

Epoch 100 Loss: 2.1445 Acc: 0.5608


                                                  

Epoch 101 Loss: 2.1424 Acc: 0.5609


                                                  

Epoch 102 Loss: 2.1390 Acc: 0.5635


                                                  

Epoch 103 Loss: 2.1374 Acc: 0.5620


                                                  

Epoch 104 Loss: 2.1371 Acc: 0.5613


                                                  

Epoch 105 Loss: 2.1340 Acc: 0.5621


                                                  

Epoch 106 Loss: 2.1323 Acc: 0.5627


                                                  

Epoch 107 Loss: 2.1290 Acc: 0.5633


                                                  

Epoch 108 Loss: 2.1275 Acc: 0.5634


                                                  

Epoch 109 Loss: 2.1265 Acc: 0.5626


                                                  

Epoch 110 Loss: 2.1241 Acc: 0.5642


                                                  

Epoch 111 Loss: 2.1226 Acc: 0.5642


                                                  

Epoch 112 Loss: 2.1223 Acc: 0.5628


                                                  

Epoch 113 Loss: 2.1195 Acc: 0.5642


                                                  

Epoch 114 Loss: 2.1173 Acc: 0.5644


                                                  

Epoch 115 Loss: 2.1160 Acc: 0.5644


                                                  

Epoch 116 Loss: 2.1141 Acc: 0.5640


                                                  

Epoch 117 Loss: 2.1120 Acc: 0.5656


                                                  

Epoch 118 Loss: 2.1108 Acc: 0.5648


                                                  

Epoch 119 Loss: 2.1091 Acc: 0.5657


                                                  

Epoch 120 Loss: 2.1082 Acc: 0.5660


                                                  

Epoch 121 Loss: 2.1056 Acc: 0.5654


                                                  

Epoch 122 Loss: 2.1054 Acc: 0.5655


                                                  

Epoch 123 Loss: 2.1032 Acc: 0.5664


                                                  

Epoch 124 Loss: 2.1029 Acc: 0.5653


                                                  

Epoch 125 Loss: 2.1012 Acc: 0.5667


                                                  

Epoch 126 Loss: 2.0984 Acc: 0.5672


                                                  

Epoch 127 Loss: 2.0975 Acc: 0.5668


                                                  

Epoch 128 Loss: 2.0953 Acc: 0.5674


                                                  

Epoch 129 Loss: 2.0940 Acc: 0.5680


                                                  

Epoch 130 Loss: 2.0935 Acc: 0.5671


                                                  

Epoch 131 Loss: 2.0930 Acc: 0.5676


                                                  

Epoch 132 Loss: 2.0914 Acc: 0.5666


                                                  

Epoch 133 Loss: 2.0898 Acc: 0.5673


                                                  

Epoch 134 Loss: 2.0882 Acc: 0.5684


                                                  

Epoch 135 Loss: 2.0865 Acc: 0.5689


                                                  

Epoch 136 Loss: 2.0853 Acc: 0.5681


                                                  

Epoch 137 Loss: 2.0844 Acc: 0.5688


                                                  

Epoch 138 Loss: 2.0821 Acc: 0.5698


                                                  

Epoch 139 Loss: 2.0818 Acc: 0.5696


                                                  

Epoch 140 Loss: 2.0812 Acc: 0.5690


                                                  

Epoch 141 Loss: 2.0785 Acc: 0.5696


                                                  

Epoch 142 Loss: 2.0782 Acc: 0.5696


                                                  

Epoch 143 Loss: 2.0768 Acc: 0.5708


                                                  

Epoch 144 Loss: 2.0761 Acc: 0.5702


                                                  

Epoch 145 Loss: 2.0745 Acc: 0.5698


                                                  

Epoch 146 Loss: 2.0729 Acc: 0.5701


                                                  

Epoch 147 Loss: 2.0725 Acc: 0.5705


                                                  

Epoch 148 Loss: 2.0714 Acc: 0.5710


                                                  

Epoch 149 Loss: 2.0704 Acc: 0.5702


                                                  

Epoch 150 Loss: 2.0702 Acc: 0.5700


                                                  

Epoch 151 Loss: 2.0677 Acc: 0.5719


                                                  

Epoch 152 Loss: 2.0661 Acc: 0.5707


                                                  

Epoch 153 Loss: 2.0668 Acc: 0.5705


                                                  

Epoch 154 Loss: 2.0655 Acc: 0.5716


                                                  

Epoch 155 Loss: 2.0651 Acc: 0.5712


                                                  

Epoch 156 Loss: 2.0641 Acc: 0.5709


                                                  

Epoch 157 Loss: 2.0609 Acc: 0.5715


                                                  

Epoch 158 Loss: 2.0604 Acc: 0.5731


                                                  

Epoch 159 Loss: 2.0611 Acc: 0.5714


                                                  

Epoch 160 Loss: 2.0584 Acc: 0.5718


                                                  

Epoch 161 Loss: 2.0587 Acc: 0.5715


                                                  

Epoch 162 Loss: 2.0582 Acc: 0.5729


                                                  

Epoch 163 Loss: 2.0568 Acc: 0.5729


                                                  

Epoch 164 Loss: 2.0553 Acc: 0.5732


                                                  

Epoch 165 Loss: 2.0548 Acc: 0.5721


                                                  

Epoch 166 Loss: 2.0532 Acc: 0.5727


                                                  

Epoch 167 Loss: 2.0528 Acc: 0.5724


                                                  

Epoch 168 Loss: 2.0522 Acc: 0.5729


                                                  

Epoch 169 Loss: 2.0510 Acc: 0.5730


                                                  

Epoch 170 Loss: 2.0505 Acc: 0.5731


                                                  

Epoch 171 Loss: 2.0502 Acc: 0.5725


                                                  

Epoch 172 Loss: 2.0502 Acc: 0.5731


                                                  

Epoch 173 Loss: 2.0459 Acc: 0.5743


                                                  

Epoch 174 Loss: 2.0474 Acc: 0.5728


                                                  

Epoch 175 Loss: 2.0463 Acc: 0.5737


                                                  

Epoch 176 Loss: 2.0457 Acc: 0.5739


                                                  

Epoch 177 Loss: 2.0447 Acc: 0.5745


                                                  

Epoch 178 Loss: 2.0437 Acc: 0.5742


                                                  

Epoch 179 Loss: 2.0442 Acc: 0.5740


                                                  

Epoch 180 Loss: 2.0431 Acc: 0.5744


                                                  

Epoch 181 Loss: 2.0417 Acc: 0.5740


                                                  

Epoch 182 Loss: 2.0425 Acc: 0.5750


                                                  

Epoch 183 Loss: 2.0400 Acc: 0.5750


                                                  

Epoch 184 Loss: 2.0400 Acc: 0.5740


                                                  

Epoch 185 Loss: 2.0393 Acc: 0.5755


                                                  

Epoch 186 Loss: 2.0382 Acc: 0.5741


                                                  

Epoch 187 Loss: 2.0369 Acc: 0.5758


                                                  

Epoch 188 Loss: 2.0356 Acc: 0.5756


                                                  

Epoch 189 Loss: 2.0343 Acc: 0.5755


                                                  

Epoch 190 Loss: 2.0346 Acc: 0.5757


                                                  

Epoch 191 Loss: 2.0342 Acc: 0.5752


                                                  

Epoch 192 Loss: 2.0339 Acc: 0.5750


                                                  

Epoch 193 Loss: 2.0339 Acc: 0.5756


                                                  

Epoch 194 Loss: 2.0323 Acc: 0.5758


                                                  

Epoch 195 Loss: 2.0316 Acc: 0.5756


                                                  

Epoch 196 Loss: 2.0309 Acc: 0.5765


                                                  

Epoch 197 Loss: 2.0300 Acc: 0.5764


                                                  

Epoch 198 Loss: 2.0301 Acc: 0.5757


                                                  

Epoch 199 Loss: 2.0302 Acc: 0.5754


                                                  

Epoch 200 Loss: 2.0281 Acc: 0.5760


                                                  

Epoch 201 Loss: 2.0265 Acc: 0.5772


                                                  

Epoch 202 Loss: 2.0267 Acc: 0.5761


                                                  

Epoch 203 Loss: 2.0261 Acc: 0.5765


                                                  

Epoch 204 Loss: 2.0260 Acc: 0.5762


                                                  

Epoch 205 Loss: 2.0250 Acc: 0.5765


                                                  

Epoch 206 Loss: 2.0234 Acc: 0.5769


                                                  

Epoch 207 Loss: 2.0237 Acc: 0.5765


                                                  

Epoch 208 Loss: 2.0238 Acc: 0.5758


                                                  

Epoch 209 Loss: 2.0236 Acc: 0.5764


                                                  

Epoch 210 Loss: 2.0221 Acc: 0.5773


                                                  

Epoch 211 Loss: 2.0221 Acc: 0.5771


                                                  

Epoch 212 Loss: 2.0213 Acc: 0.5768


                                                  

Epoch 213 Loss: 2.0199 Acc: 0.5780


                                                  

Epoch 214 Loss: 2.0195 Acc: 0.5769


                                                  

Epoch 215 Loss: 2.0178 Acc: 0.5775


                                                  

Epoch 216 Loss: 2.0184 Acc: 0.5790


                                                  

Epoch 217 Loss: 2.0179 Acc: 0.5777


                                                  

Epoch 218 Loss: 2.0183 Acc: 0.5782


                                                  

Epoch 219 Loss: 2.0158 Acc: 0.5787


                                                  

Epoch 220 Loss: 2.0159 Acc: 0.5780


                                                  

Epoch 221 Loss: 2.0162 Acc: 0.5783


                                                  

Epoch 222 Loss: 2.0151 Acc: 0.5773


                                                  

Epoch 223 Loss: 2.0155 Acc: 0.5784


                                                  

Epoch 224 Loss: 2.0136 Acc: 0.5782


                                                  

Epoch 225 Loss: 2.0149 Acc: 0.5782


                                                  

Epoch 226 Loss: 2.0135 Acc: 0.5778


                                                  

Epoch 227 Loss: 2.0129 Acc: 0.5784


                                                  

Epoch 228 Loss: 2.0123 Acc: 0.5776


                                                  

Epoch 229 Loss: 2.0119 Acc: 0.5786


                                                  

Epoch 230 Loss: 2.0110 Acc: 0.5788


                                                  

Epoch 231 Loss: 2.0105 Acc: 0.5786


                                                  

Epoch 232 Loss: 2.0110 Acc: 0.5780


                                                  

Epoch 233 Loss: 2.0088 Acc: 0.5793


                                                  

Epoch 234 Loss: 2.0103 Acc: 0.5779


                                                  

Epoch 235 Loss: 2.0085 Acc: 0.5795


                                                  

Epoch 236 Loss: 2.0076 Acc: 0.5796


                                                  

Epoch 237 Loss: 2.0070 Acc: 0.5796


                                                  

Epoch 238 Loss: 2.0066 Acc: 0.5795


                                                  

Epoch 239 Loss: 2.0070 Acc: 0.5798


                                                  

Epoch 240 Loss: 2.0055 Acc: 0.5791


                                                  

Epoch 241 Loss: 2.0064 Acc: 0.5788


                                                  

Epoch 242 Loss: 2.0025 Acc: 0.5800


                                                  

Epoch 243 Loss: 2.0042 Acc: 0.5801


                                                  

Epoch 244 Loss: 2.0043 Acc: 0.5783


                                                  

Epoch 245 Loss: 2.0036 Acc: 0.5787


                                                  

Epoch 246 Loss: 2.0026 Acc: 0.5802


                                                  

Epoch 247 Loss: 2.0024 Acc: 0.5794


                                                  

Epoch 248 Loss: 2.0015 Acc: 0.5805


                                                  

Epoch 249 Loss: 2.0014 Acc: 0.5809


                                                  

Epoch 250 Loss: 2.0012 Acc: 0.5806


                                                  

Epoch 251 Loss: 2.0014 Acc: 0.5808


                                                  

Epoch 252 Loss: 2.0020 Acc: 0.5797


                                                  

Epoch 253 Loss: 1.9995 Acc: 0.5805


                                                  

Epoch 254 Loss: 1.9998 Acc: 0.5806


                                                  

Epoch 255 Loss: 1.9996 Acc: 0.5808


                                                  

Epoch 256 Loss: 1.9990 Acc: 0.5798


                                                  

Epoch 257 Loss: 1.9983 Acc: 0.5812


                                                  

Epoch 258 Loss: 1.9988 Acc: 0.5798


                                                  

Epoch 259 Loss: 1.9991 Acc: 0.5803


                                                  

Epoch 260 Loss: 1.9966 Acc: 0.5826


                                                  

Epoch 261 Loss: 1.9957 Acc: 0.5793


                                                  

Epoch 262 Loss: 1.9963 Acc: 0.5809


                                                  

Epoch 263 Loss: 1.9946 Acc: 0.5816


                                                  

Epoch 264 Loss: 1.9957 Acc: 0.5821


                                                  

Epoch 265 Loss: 1.9963 Acc: 0.5811


                                                  

Epoch 266 Loss: 1.9949 Acc: 0.5810


                                                  

Epoch 267 Loss: 1.9923 Acc: 0.5816


                                                  

Epoch 268 Loss: 1.9937 Acc: 0.5799


                                                  

Epoch 269 Loss: 1.9948 Acc: 0.5812


                                                  

Epoch 270 Loss: 1.9946 Acc: 0.5803


                                                  

Epoch 271 Loss: 1.9919 Acc: 0.5813


                                                  

Epoch 272 Loss: 1.9924 Acc: 0.5812


                                                  

Epoch 273 Loss: 1.9917 Acc: 0.5813


                                                  

Epoch 274 Loss: 1.9902 Acc: 0.5821


                                                  

Epoch 275 Loss: 1.9910 Acc: 0.5813


                                                  

Epoch 276 Loss: 1.9899 Acc: 0.5821


                                                  

Epoch 277 Loss: 1.9912 Acc: 0.5823


                                                  

Epoch 278 Loss: 1.9910 Acc: 0.5811


                                                  

Epoch 279 Loss: 1.9892 Acc: 0.5817


                                                  

Epoch 280 Loss: 1.9896 Acc: 0.5812


                                                  

Epoch 281 Loss: 1.9895 Acc: 0.5821


                                                  

Epoch 282 Loss: 1.9885 Acc: 0.5823


                                                  

Epoch 283 Loss: 1.9885 Acc: 0.5822


                                                  

Epoch 284 Loss: 1.9875 Acc: 0.5819


                                                  

Epoch 285 Loss: 1.9868 Acc: 0.5813


                                                  

Epoch 286 Loss: 1.9862 Acc: 0.5818


                                                  

Epoch 287 Loss: 1.9881 Acc: 0.5810


                                                  

Epoch 288 Loss: 1.9859 Acc: 0.5831


                                                  

Epoch 289 Loss: 1.9858 Acc: 0.5815


                                                  

Epoch 290 Loss: 1.9845 Acc: 0.5827


                                                  

Epoch 291 Loss: 1.9861 Acc: 0.5820


                                                  

Epoch 292 Loss: 1.9850 Acc: 0.5829


                                                  

Epoch 293 Loss: 1.9825 Acc: 0.5830


                                                  

Epoch 294 Loss: 1.9833 Acc: 0.5831


                                                  

Epoch 295 Loss: 1.9847 Acc: 0.5816


                                                  

Epoch 296 Loss: 1.9852 Acc: 0.5818


                                                  

Epoch 297 Loss: 1.9838 Acc: 0.5829


                                                  

Epoch 298 Loss: 1.9824 Acc: 0.5828


                                                  

Epoch 299 Loss: 1.9821 Acc: 0.5821


                                                  

Epoch 300 Loss: 1.9809 Acc: 0.5829


In [41]:
# Map every vocabulary word to its learned vector (a plain list of floats).
# The weight matrix is copied to host memory once and converted in bulk,
# instead of indexing the (possibly GPU-resident) parameter once per word.
word_embeddings = dict(
    zip(corpus.idx_to_word, model.embedding.weight.detach().cpu().tolist())
)

In [42]:
# Write one "word: [v0, v1, ...]" line per vocabulary entry.
# The encoding is explicit because the corpus is read as UTF-8 and the
# platform default encoding may not be able to represent the words.
# The loop variable is called `vector` so that re-running this cell cannot
# overwrite the notebook-level `embedding` module used by later cells.
with open('zh_embeddings.txt', 'w', encoding='utf-8') as f:
    for word, vector in word_embeddings.items():
        f.write(f'{word}: {vector}\n')

In [43]:
# Frozen CPU copy of the trained embedding table for the lookups below.
# nn.Module.to() moves a module in place, so calling it on model.embedding
# would leave `model` with layers on different devices; copying the weights
# keeps the trained model untouched.
embedding = nn.Embedding.from_pretrained(model.embedding.weight.detach().cpu().clone())

In [44]:
# Matrix with one row per vocabulary index, in index order. A single batched
# lookup over all indices replaces one lookup and one forward call per word.
vectors = embedding(torch.arange(len(vocab))).detach()

In [45]:
def close_words(x, n = 5):
    """Return the `n` vocabulary words whose vectors are closest to that of `x`.

    Closeness is the Euclidean distance between the embedding of `x` and every
    row of the notebook-level `vectors` matrix; results are ordered from
    nearest to farthest, so `x` itself is normally the first entry.

    Args:
        x: query word; looked up in the notebook-level `vocab`.
        n: maximum number of words to return.

    Returns:
        A list of at most `n` words.
    """
    query = embedding(torch.tensor(vocab[x])).detach().numpy()
    distances = np.linalg.norm(vectors.detach().numpy() - query, axis=1)
    nearest = distances.argsort()[:n]
    # Fetch the index -> word list once rather than once per returned word.
    itos = vocab.get_itos()
    return [itos[i] for i in nearest]

In [48]:
# Example query: list the default number (5) of nearest words for one word.
close_words('中国')

['中国', '美国', '日本', '我国', '我们']