In [2]:
import torch
import argparse
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import Ipynb_importer
from ReadingData import DataReader, Metapath2vecDataset
from model import SkipGramModel

importing Jupyter notebook from ReadingData.ipynb
importing Jupyter notebook from model.ipynb


In [5]:
class Metapath2VecTrainer:
    """Train metapath2vec skip-gram embeddings and save them to disk.

    The constructor builds the corpus reader (``DataReader``), the dataset
    (``Metapath2vecDataset``), a shuffling ``DataLoader`` and the
    ``SkipGramModel``, and moves the model to the GPU when one is available.
    ``train`` runs the optimization loop and saves the embeddings via
    ``SkipGramModel.save_embedding`` after every iteration.

    Args:
        file: Input path handed to ``DataReader`` (presumably the metapath
            random-walk corpus — confirm in ReadingData.ipynb).
        min_count: Frequency cutoff forwarded to ``DataReader``.
        window_size: Context window size forwarded to ``Metapath2vecDataset``.
        batch_size: Number of samples per DataLoader batch.
        output_file: Path passed to ``save_embedding``.
        dim: Embedding dimension.
        iterations: Number of passes over the DataLoader.
        initial_lr: Starting learning rate for ``SparseAdam``.
        num_workers: DataLoader worker processes (default 4, the value that
            used to be hard-coded).
    """

    def __init__(self, file, min_count, window_size, batch_size, output_file, dim, iterations, initial_lr,
                 num_workers=4):
        self.data = DataReader(file, min_count)
        dataset = Metapath2vecDataset(self.data, window_size)
        self.dataloader = DataLoader(dataset, batch_size=batch_size,
                                     shuffle=True, num_workers=num_workers, collate_fn=dataset.collate)

        self.output_file_name = output_file
        # One embedding row per id in the reader's vocabulary.
        self.emb_size = len(self.data.word2id)
        self.emb_dimension = dim
        self.batch_size = batch_size
        self.iterations = iterations
        self.initial_lr = initial_lr  # learning rate
        self.skip_gram_model = SkipGramModel(self.emb_size, self.emb_dimension)

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        # No-op on CPU; moves the parameters to the GPU otherwise.
        self.skip_gram_model.to(self.device)

    def train(self, log_every=500):
        """Run ``self.iterations`` passes over the data, saving embeddings after each.

        Each pass builds a fresh ``SparseAdam`` optimizer and a
        ``CosineAnnealingLR`` schedule whose period is one pass
        (``len(self.dataloader)`` steps). The learning rate therefore restarts
        at ``initial_lr`` at the beginning of every iteration.

        Args:
            log_every: Print the smoothed loss every ``log_every`` batches
                (default 500, the value that used to be hard-coded). A value of
                0 or None disables the periodic loss print.
        """
        for iteration in range(self.iterations):
            print("\n\n\nIteration: " + str(iteration + 1))
            # NOTE(review): re-creating the optimizer every iteration also
            # resets Adam's moment estimates (a warm restart). Kept as in the
            # original; hoist it out of the loop if continuous state is wanted.
            optimizer = optim.SparseAdam(self.skip_gram_model.parameters(), lr=self.initial_lr)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(self.dataloader))

            running_loss = 0.0
            for i, sample_batched in enumerate(tqdm(self.dataloader)):
                # Batches holding a single sample are skipped.
                # NOTE(review): presumably the model cannot handle a batch of
                # size 1 — confirm in model.ipynb.
                if len(sample_batched[0]) <= 1:
                    continue

                pos_u = sample_batched[0].to(self.device)
                pos_v = sample_batched[1].to(self.device)
                neg_v = sample_batched[2].to(self.device)

                optimizer.zero_grad()
                # Call the module itself (not .forward) so nn.Module hooks run.
                loss = self.skip_gram_model(pos_u, pos_v, neg_v)
                loss.backward()
                optimizer.step()
                # Since PyTorch 1.1 the scheduler must step *after* the
                # optimizer; stepping it first skips the first LR value.
                scheduler.step()

                # Exponential moving average of the batch loss; it starts at 0,
                # so the earliest values are biased low.
                running_loss = running_loss * 0.9 + loss.item() * 0.1
                if log_every and i > 0 and i % log_every == 0:
                    # tqdm.write prints without corrupting the progress bar.
                    tqdm.write(" Loss: " + str(running_loss))

            # Saved after every iteration, always to the same path.
            self.skip_gram_model.save_embedding(self.data.id2word, self.output_file_name)


In [8]:
file = '../output/output_path.txt'  # input passed to DataReader (presumably the metapath walk corpus written by an earlier step — confirm)
min_count = 5  # frequency cutoff passed to DataReader (presumably rarer tokens are dropped)
window_size = 5  # context window size passed to Metapath2vecDataset
batch_size = 128  # DataLoader batch size
output_file = '../output/embeddings.txt'  # path the learned embeddings are saved to
dim = 128  # embedding dimension
iterations = 10  # number of passes (epochs) over the training data
initial_lr = 0.01  # starting learning rate
seed = 42  # seed for torch's global RNG (DataLoader shuffling, presumably model init)

# Seed torch so that Restart & Run All reproduces the run.
# NOTE(review): any numpy/random sampling inside ReadingData is not seeded here.
torch.manual_seed(seed)

In [9]:
# Build the trainer from the configuration cell and run every iteration.
m2v = Metapath2VecTrainer(
    file=file,
    min_count=min_count,
    window_size=window_size,
    batch_size=batch_size,
    output_file=output_file,
    dim=dim,
    iterations=iterations,
    initial_lr=initial_lr,
)
m2v.train()

Total embeddings: 20864


  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 1


 92%|█████████▏| 502/546 [01:13<00:06,  6.83it/s]

 Loss: 2.9129858167833333


100%|██████████| 546/546 [01:19<00:00,  6.89it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 2


 92%|█████████▏| 502/546 [01:13<00:06,  6.51it/s]

 Loss: 2.786793152517554


100%|██████████| 546/546 [01:19<00:00,  6.89it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 3


 92%|█████████▏| 502/546 [01:07<00:06,  7.11it/s]

 Loss: 2.7567128685322677


100%|██████████| 546/546 [01:13<00:00,  7.43it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 4


 92%|█████████▏| 502/546 [01:12<00:07,  5.50it/s]

 Loss: 2.788184064071736


100%|██████████| 546/546 [01:20<00:00,  6.80it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 5


 92%|█████████▏| 502/546 [01:01<00:04, 10.81it/s]

 Loss: 2.8321075558231787


100%|██████████| 546/546 [01:07<00:00,  8.14it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 6


 92%|█████████▏| 503/546 [00:58<00:03, 10.85it/s]

 Loss: 2.940539925790936


100%|██████████| 546/546 [01:02<00:00,  8.70it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 7


 92%|█████████▏| 502/546 [01:08<00:05,  7.88it/s]

 Loss: 3.0533957102930245


100%|██████████| 546/546 [01:14<00:00,  7.35it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 8


 92%|█████████▏| 502/546 [01:08<00:05,  7.52it/s]

 Loss: 3.2011887262040517


100%|██████████| 546/546 [01:13<00:00,  7.40it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 9


 92%|█████████▏| 502/546 [00:59<00:04, 10.47it/s]

 Loss: 3.310098170198761


100%|██████████| 546/546 [01:04<00:00,  8.49it/s]
  0%|          | 0/546 [00:00<?, ?it/s]




Iteration: 10


 92%|█████████▏| 502/546 [00:59<00:04, 10.60it/s]

 Loss: 3.4719487614076523


100%|██████████| 546/546 [01:03<00:00,  8.60it/s]
