Commit
Squashed commit history:

- Prepare for metapath sampler: this file is just for reviewing the metapath sampling algorithm (Python version).
- Delete metapath_sampler
- Prepare for metapath sampler: this file is just for reviewing the metapath sampling algorithm (Python code).
- Add files via upload
- Create metapath2vec.md
- Add files via upload
- Delete data_handler.py
- Delete word2vec.py
- Delete word_train.py
- Add files via upload: Metapath2vec implementations. Metapath2vec++ needs negative sampler optimization.
- Delete shuffle_training.py
- Delete test.py
- Add files via upload
- Delete sampler.py
- Delete metapath_sampler.md
- Add files via upload
- Update and rename shuffle_training.py to metapath2vec.py
- Update reading_data.py
- Update metapath2vec.md (5 commits)
- Create label 2
- Delete label 2
- Create testing.md
- Add files via upload
- Create sample.md
- Add files via upload
- Delete sampler.py
- Add files via upload
- Delete googlescholar.8area.author.label.txt
- Delete googlescholar.8area.venue.label.txt
- Delete testing.md
- Delete id_author.txt
- Delete id_conf.txt
- Delete paper.txt
- Delete paper_author.txt
- Delete paper_conf.txt
- Delete sample.md
- Delete sampler.py
- Add files via upload (3 commits)
- Delete reading_data.py
- Add files via upload (2 commits)
- Delete metapath2vec.py
- Add files via upload
- Rename shuffle_training.py to metapath2vec.py
- Update metapath2vec.md
- Delete reading_data.py
- Add comments and remove commented code
1 parent 98954c5 · commit ddeb86f
Showing 7 changed files with 623 additions and 0 deletions.
download.py (+38)
-----------------

```python
import os


class AminerDataset:
    """
    Download the AMiner dataset from an Amazon S3 bucket.
    """
    def __init__(self, path):
        self.url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/aminer.zip'
        # Set `fn` here as well, so it is defined even when the dataset has
        # already been downloaded (the original only set it inside
        # _download_and_extract, which is skipped on a warm start).
        self.fn = os.path.join(path, 'aminer.zip')

        if not os.path.exists(os.path.join(path, 'aminer')):
            print('File not found. Downloading from', self.url)
            self._download_and_extract(path, 'aminer.zip')

    def _download_and_extract(self, path, filename):
        import shutil, zipfile
        import requests
        from tqdm import tqdm

        fn = os.path.join(path, filename)

        # Start from a clean directory, then stream the archive in 3 MB chunks.
        if os.path.exists(path):
            shutil.rmtree(path, ignore_errors=True)
        os.makedirs(path)
        f_remote = requests.get(self.url, stream=True)
        assert f_remote.status_code == 200, 'fail to open {}'.format(self.url)
        with open(fn, 'wb') as writer:
            for chunk in tqdm(f_remote.iter_content(chunk_size=1024 * 1024 * 3)):
                writer.write(chunk)
        print('Download finished. Unzipping the file...')

        with zipfile.ZipFile(fn) as zf:
            zf.extractall(path)
        print('Unzip finished.')
        self.fn = fn
```
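A minimal usage sketch of the class above; the `./data` path is an arbitrary example, not one the repository prescribes:

```python
# Hypothetical usage; './data' is an example path.
from download import AminerDataset

dataset = AminerDataset('./data')       # downloads and extracts on first run
print('walk corpus file:', dataset.fn)
```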
metapath2vec.md (+56)
---------------------
Metapath2vec
============

- Paper link: [metapath2vec: Scalable Representation Learning for Heterogeneous Networks](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf)
- Author's code repo: [https://ericdongyx.github.io/metapath2vec/m2v.html](https://ericdongyx.github.io/metapath2vec/m2v.html)
Dependencies
------------
- PyTorch 1.0.1+

How to run the code
-------------------
Follow these steps:
1. Run `sampler.py` on your graph dataset to generate metapath walks. The training input is a plain text file with one space-separated walk per line, so you will probably need to preprocess your graph dataset first. Sample files in this format are available under "net_dbis". You can also use your own metapath sampler implementation; a rough sketch of one is given below the command in step 2.

2. Run the following command:

```bash
python metapath2vec.py --download "where/you/want/to/download" --output_file "your_output_file_path"
```
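For illustration only, here is a minimal sketch of a metapath-guided random walk. The `neighbors` structure (neighbor lists keyed by node and node type) and the `'APVPA'`-style metapath string are assumptions for this sketch, not the actual interface of the repository's `sampler.py`:

```python
import random

def metapath_walk(neighbors, start, metapath, walk_length):
    # `metapath` is a node-type string such as 'APVPA' (author-paper-venue-
    # paper-author) whose first and last types match the type of `start`;
    # the walk cycles through it until `walk_length` nodes are collected.
    walk = [start]
    step = 0
    while len(walk) < walk_length:
        next_type = metapath[step % (len(metapath) - 1) + 1]
        candidates = neighbors[walk[-1]].get(next_type, [])
        if not candidates:  # dead end: stop this walk early
            break
        walk.append(random.choice(candidates))
        step += 1
    return walk

# Example with a tiny hypothetical graph:
nbrs = {'a1': {'P': ['p1']}, 'p1': {'A': ['a1'], 'V': ['v1']},
        'v1': {'P': ['p1']}}
print(metapath_walk(nbrs, 'a1', 'APVPA', walk_length=5))
```

Each walk is then written out as one space-separated line, which is exactly what `reading_data.py` consumes.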
Tips: tune `num_workers` for your GPU instance; running 3 or 4 epochs is usually enough.
Tricks included in the implementation
-------------------------------------
1. Sub-sampling of frequent tokens.

2. Negative sampling without repeatedly calling `numpy.random.choice`.
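Concretely (matching `reading_data.py` in this commit): a token with relative corpus frequency $f(w)$ is kept during sub-sampling with probability

$$P_{\mathrm{keep}}(w) = \sqrt{t/f(w)} + t/f(w), \qquad t = 10^{-4},$$

and negatives are drawn from the smoothed unigram distribution

$$P_n(w) \propto \mathrm{freq}(w)^{3/4},$$

materialized once as a shuffled table of $10^8$ word ids that is then read sequentially, instead of calling `numpy.random.choice` per sample.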
Performance and Explanations
----------------------------
Venue Classification Results for Metapath2vec:

| Metric | 5% | 10% | 20% | 30% | 40% | 50% | 60% | 70% | 80% | 90% |
| ------ | -- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Macro-F1 | 0.3033 | 0.5247 | 0.8033 | 0.8971 | 0.9406 | 0.9532 | 0.9529 | 0.9701 | 0.9683 | 0.9670 |
| Micro-F1 | 0.4173 | 0.5975 | 0.8327 | 0.9011 | 0.9400 | 0.9522 | 0.9537 | 0.9725 | 0.9815 | 0.9857 |

Author Classification Results for Metapath2vec:

| Metric | 5% | 10% | 20% | 30% | 40% | 50% | 60% | 70% | 80% | 90% |
| ------ | -- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Macro-F1 | 0.9216 | 0.9262 | 0.9292 | 0.9303 | 0.9309 | 0.9314 | 0.9315 | 0.9316 | 0.9319 | 0.9320 |
| Micro-F1 | 0.9279 | 0.9319 | 0.9346 | 0.9356 | 0.9361 | 0.9365 | 0.9365 | 0.9365 | 0.9367 | 0.9369 |
Note that:

- Testing files are available in "label 2".

- The tables above are the results listed in the paper; in real experiments the exact numbers may differ slightly:

  1. For venue node classification, when the training set is small (e.g. 5%), the variance of the performance is large, since only a few labeled venues are available.

  2. For author node classification, the performance is stable because many labeled authors are available, so even 5% training data is sufficient.

  3. In test.py you can change how many times the experiment is repeated; testing author classification in particular is very slow, so running it only once or twice may be enough.
metapath2vec.py (+75)
---------------------
```python
import argparse

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from reading_data import DataReader, Metapath2vecDataset
from model import SkipGramModel


class Metapath2VecTrainer:
    def __init__(self, args):
        self.data = DataReader(args.download, args.min_count, args.care_type)
        dataset = Metapath2vecDataset(self.data, args.window_size)
        self.dataloader = DataLoader(dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.num_workers,
                                     collate_fn=dataset.collate)

        self.output_file_name = args.output_file
        self.emb_size = len(self.data.word2id)
        self.emb_dimension = args.dim
        self.batch_size = args.batch_size
        self.iterations = args.iterations
        self.initial_lr = args.initial_lr
        self.skip_gram_model = SkipGramModel(self.emb_size, self.emb_dimension)

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        if self.use_cuda:
            self.skip_gram_model.cuda()

    def train(self):
        for iteration in range(self.iterations):
            print("\n\n\nIteration: " + str(iteration + 1))
            # SparseAdam is required because the embeddings use sparse=True;
            # the cosine schedule anneals the lr over one epoch of batches.
            optimizer = optim.SparseAdam(self.skip_gram_model.parameters(), lr=self.initial_lr)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(self.dataloader))

            running_loss = 0.0
            for i, sample_batched in enumerate(tqdm(self.dataloader)):
                if len(sample_batched[0]) > 1:
                    pos_u = sample_batched[0].to(self.device)
                    pos_v = sample_batched[1].to(self.device)
                    neg_v = sample_batched[2].to(self.device)

                    optimizer.zero_grad()
                    loss = self.skip_gram_model(pos_u, pos_v, neg_v)
                    loss.backward()
                    optimizer.step()
                    # Step the scheduler after the optimizer (PyTorch >= 1.1 order).
                    scheduler.step()

                    # Exponential moving average of the loss, for logging only.
                    running_loss = running_loss * 0.9 + loss.item() * 0.1
                    if i > 0 and i % 500 == 0:
                        print(" Loss: " + str(running_loss))

            # Overwrite the output file after every epoch.
            self.skip_gram_model.save_embedding(self.data.id2word, self.output_file_name)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Metapath2vec")
    parser.add_argument('--download', type=str, help="download path for the AMiner dataset")
    parser.add_argument('--output_file', type=str, help="path of the output embedding file")
    parser.add_argument('--dim', default=128, type=int, help="embedding dimensions")
    parser.add_argument('--window_size', default=7, type=int, help="context window size")
    parser.add_argument('--iterations', default=5, type=int, help="number of training epochs")
    parser.add_argument('--batch_size', default=50, type=int, help="batch size")
    parser.add_argument('--care_type', default=0, type=int, help="if 1, heterogeneous negative sampling; else normal negative sampling")
    parser.add_argument('--initial_lr', default=0.025, type=float, help="initial learning rate")
    parser.add_argument('--min_count', default=5, type=int, help="minimum token frequency")
    parser.add_argument('--num_workers', default=16, type=int, help="number of DataLoader workers")
    args = parser.parse_args()
    m2v = Metapath2VecTrainer(args)
    m2v.train()
```
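Since `save_embedding` in `model.py` writes the standard word2vec text format, the trained embeddings can be inspected with any compatible reader; for example with gensim 4.x (gensim is an assumption here, not a dependency of this code):

```python
# Hypothetical inspection of the output file; requires `pip install gensim`.
from gensim.models import KeyedVectors

emb = KeyedVectors.load_word2vec_format('your_output_file_path')
print(emb.most_similar(emb.index_to_key[0], topn=5))
```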
model.py (+46)
--------------
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

"""
u_embedding: Embedding for center word.
v_embedding: Embedding for context (neighbor) words.
"""


class SkipGramModel(nn.Module):

    def __init__(self, emb_size, emb_dimension):
        super(SkipGramModel, self).__init__()
        self.emb_size = emb_size
        self.emb_dimension = emb_dimension
        # sparse=True so the trainer can use SparseAdam for the updates.
        self.u_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)
        self.v_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)

        initrange = 1.0 / self.emb_dimension
        init.uniform_(self.u_embeddings.weight.data, -initrange, initrange)
        init.constant_(self.v_embeddings.weight.data, 0)

    def forward(self, pos_u, pos_v, neg_v):
        emb_u = self.u_embeddings(pos_u)      # (batch, dim) center words
        emb_v = self.v_embeddings(pos_v)      # (batch, dim) context words
        emb_neg_v = self.v_embeddings(neg_v)  # (batch, K, dim) negative samples

        # Positive part: -log sigmoid(u . v), clamped for numerical stability.
        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)
        score = torch.clamp(score, max=10, min=-10)
        score = -F.logsigmoid(score)

        # Negative part: -sum_k log sigmoid(-u . v_k); squeeze(-1) keeps the
        # batch dimension intact even when the batch size is 1.
        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze(-1)
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)

        return torch.mean(score + neg_score)

    def save_embedding(self, id2word, file_name):
        # Write the center-word embeddings in the word2vec text format.
        embedding = self.u_embeddings.weight.cpu().data.numpy()
        with open(file_name, 'w') as f:
            f.write('%d %d\n' % (len(id2word), self.emb_dimension))
            for wid, w in id2word.items():
                e = ' '.join(map(str, embedding[wid]))
                f.write('%s %s\n' % (w, e))
```
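For reference, `forward` above computes the usual skip-gram negative-sampling loss, averaged over the batch:

$$\mathcal{L} = -\log \sigma\!\left(u_c^\top v_o\right) \;-\; \sum_{k=1}^{K} \log \sigma\!\left(-u_c^\top v_{n_k}\right),$$

where $u_c$ is a center embedding (`u_embeddings`), $v_o$ its context embedding, and $v_{n_1},\dots,v_{n_K}$ the $K=5$ negative samples (all from `v_embeddings`); the dot products are clamped to $[-10, 10]$ before the log-sigmoid for numerical stability.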
reading_data.py (+129)
----------------------
```python
import numpy as np
import torch
from torch.utils.data import Dataset
from download import AminerDataset

np.random.seed(12345)


class DataReader:
    NEGATIVE_TABLE_SIZE = 1e8

    def __init__(self, download, min_count, care_type):
        self.negatives = []
        self.discards = []
        self.negpos = 0
        self.care_type = care_type
        self.word2id = dict()
        self.id2word = dict()
        self.sentences_count = 0
        self.token_count = 0
        self.word_frequency = dict()
        self.download = download
        FB = AminerDataset(self.download)
        self.inputFileName = FB.fn
        self.read_words(min_count)
        self.initTableNegatives()
        self.initTableDiscards()

    def read_words(self, min_count):
        # Build the vocabulary, dropping tokens rarer than min_count.
        word_frequency = dict()
        for line in open(self.inputFileName, encoding="ISO-8859-1"):
            line = line.split()
            if len(line) > 1:
                self.sentences_count += 1
                for word in line:
                    if len(word) > 0:
                        self.token_count += 1
                        word_frequency[word] = word_frequency.get(word, 0) + 1

                        if self.token_count % 1000000 == 0:
                            print("Read " + str(int(self.token_count / 1000000)) + "M words.")

        wid = 0
        for w, c in word_frequency.items():
            if c < min_count:
                continue
            self.word2id[w] = wid
            self.id2word[wid] = w
            self.word_frequency[wid] = c
            wid += 1

        self.word_count = len(self.word2id)
        print("Total embeddings: " + str(len(self.word2id)))

    def initTableDiscards(self):
        # Keep-probability table for sub-sampling: frequent tokens are
        # discarded with higher probability (the word2vec trick).
        t = 0.0001
        f = np.array(list(self.word_frequency.values())) / self.token_count
        self.discards = np.sqrt(t / f) + (t / f)

    def initTableNegatives(self):
        # Negative-sampling table: each word id appears proportionally to
        # its frequency raised to the 3/4 power, so if a word with index 2
        # gets twice the mass of another, 2 is listed twice as often.
        pow_frequency = np.array(list(self.word_frequency.values())) ** 0.75
        words_pow = sum(pow_frequency)
        ratio = pow_frequency / words_pow
        count = np.round(ratio * DataReader.NEGATIVE_TABLE_SIZE)
        for wid, c in enumerate(count):
            self.negatives += [wid] * int(c)
        self.negatives = np.array(self.negatives)
        np.random.shuffle(self.negatives)
        self.sampling_prob = ratio

    def getNegatives(self, target, size):  # TODO check equality with target
        # Heterogeneous negative sampling (care_type == 1, metapath2vec++)
        # is not implemented in this commit.
        if self.care_type == 0:
            # Read the pre-shuffled table sequentially instead of calling
            # np.random.choice for every sample; wrap around at the end.
            response = self.negatives[self.negpos:self.negpos + size]
            self.negpos = (self.negpos + size) % len(self.negatives)
            if len(response) != size:
                return np.concatenate((response, self.negatives[0:self.negpos]))
            return response


# -----------------------------------------------------------------------------------------------------------------

class Metapath2vecDataset(Dataset):
    def __init__(self, data, window_size):
        # Keep the vocabulary, window size, and an open handle on the walk file.
        self.data = data
        self.window_size = window_size
        self.input_file = open(data.inputFileName, encoding="ISO-8859-1")

    def __len__(self):
        # One item per walk ("sentence").
        return self.data.sentences_count

    def __getitem__(self, idx):
        # Return the list of (center, context, 5 negatives) pairs for one walk.
        while True:
            line = self.input_file.readline()
            if not line:
                # Wrap around to the beginning of the file.
                self.input_file.seek(0, 0)
                line = self.input_file.readline()

            if len(line) > 1:
                words = line.split()

                if len(words) > 1:
                    # Sub-sample tokens using the discard table.
                    word_ids = [self.data.word2id[w] for w in words if
                                w in self.data.word2id and np.random.rand() < self.data.discards[self.data.word2id[w]]]

                    pair_catch = []
                    for i, u in enumerate(word_ids):
                        start = max(i - self.window_size, 0)
                        for j, v in enumerate(word_ids[start:i + self.window_size]):
                            assert u < self.data.word_count
                            assert v < self.data.word_count
                            if start + j == i:
                                # Skip the center word itself. (The original
                                # code compared i == j, which misses the center
                                # once i exceeds the window size.)
                                continue
                            pair_catch.append((u, v, self.data.getNegatives(v, 5)))
                    return pair_catch

    @staticmethod
    def collate(batches):
        # Flatten a list of per-walk pair lists into three LongTensors.
        all_u = [u for batch in batches for u, _, _ in batch if len(batch) > 0]
        all_v = [v for batch in batches for _, v, _ in batch if len(batch) > 0]
        all_neg_v = [neg_v for batch in batches for _, _, neg_v in batch if len(batch) > 0]

        return torch.LongTensor(all_u), torch.LongTensor(all_v), torch.LongTensor(all_neg_v)
```
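A standalone toy run of the (center, context) pair generation in `__getitem__` above; the walk and window size are made up for illustration:

```python
# Toy illustration (hypothetical data): same windowing logic as __getitem__.
walk = ['a1', 'p1', 'v1', 'p2', 'a2']
window = 2

pairs = []
for i, u in enumerate(walk):
    start = max(i - window, 0)
    for j, v in enumerate(walk[start:i + window]):
        if start + j == i:  # skip the center word itself
            continue
        pairs.append((u, v))

print(pairs)
# [('a1', 'p1'), ('p1', 'a1'), ('p1', 'v1'), ('v1', 'a1'), ('v1', 'p1'), ...]
```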