In [1]:
from nltk.corpus import wordnet


def load_node():
    #全体词表,取出主词
    node = list({i.name().split('.')[0] for i in wordnet.all_synsets()})

    #全体可能的同义词
    #all_lemma_names是无重复的
    lemma = list(wordnet.all_lemma_names())

    #合并node和lemma,其实node基本上是lemma的子集,只多了4个词
    node = list(set(node + lemma))

    #编码
    node = {name: i for i, name in enumerate(node)}

    return node


node = load_node()

len(node), node['girl']

(147311, 73123)

In [2]:
def load_edge():
    #所有的关系
    edge_fs = [
        'synsets', 'hypernyms', 'hyponyms', 'instance_hyponyms',
        'member_meronyms', 'part_meronyms', 'topic_domains', 'usage_domains',
        'region_domains', 'attributes', 'entailments', 'causes', 'also_sees',
        'verb_groups', 'similar_tos', 'lemma_names'
    ]

    edge = []

    #全体词表
    for n1 in node.keys():

        #遍历同义词
        for synset in wordnet.synsets(n1):
            n2 = synset.name().split('.')[0]

            #添加同义词关系
            edge.append((n1, n2))

            #遍历所有可能的关系
            for f in edge_fs:
                #同义词刚刚已经处理过了,这里不再重复
                if f == 'synsets':
                    continue

                #调用同义词上的edge函数
                for n2 in getattr(synset, f)():
                    if f == 'lemma_names':
                        n2 = n2.lower()
                    else:
                        n2 = n2.name().split('.')[0]
                    edge.append((n1, n2))

    #删除所有自己到自己的关联
    #编码
    edge = [(node[i[0]], node[i[1]]) for i in edge if i[0] != i[1]]

    return edge


edge = load_edge()

len(edge), edge[:15]

(1022595,
 [(2, 117854),
  (2, 121790),
  (2, 8137),
  (2, 54195),
  (3, 98401),
  (3, 4697),
  (3, 98401),
  (3, 4697),
  (4, 15540),
  (4, 115369),
  (5, 121100),
  (6, 140980),
  (6, 63977),
  (6, 36451),
  (6, 123276)])

In [3]:
import torch
import random

edge_set = set(edge)


def get_batch():
    sample = random.sample(edge, 128)

    pos = []
    neg = []
    for i in sample:
        n1, n2 = i
        while (n1, n2) in edge_set:
            n1 = random.randint(0, len(node) - 1)
        neg.append((n1, n2))

        n1, n2 = i
        while (n1, n2) in edge_set:
            n2 = random.randint(0, len(node) - 1)
        neg.append((n1, n2))

        pos.append(i)
        pos.append(i)

    pos = torch.LongTensor(pos)
    neg = torch.LongTensor(neg)

    return pos, neg


get_batch()

(tensor([[ 25540,  28902],
         [ 25540,  28902],
         [116118,  25345],
         [116118,  25345],
         [ 81782,  49763],
         [ 81782,  49763],
         [ 78927,  61715],
         [ 78927,  61715],
         [ 95352,  95272],
         [ 95352,  95272],
         [ 29257,  63276],
         [ 29257,  63276],
         [ 91624, 146374],
         [ 91624, 146374],
         [146374,  77592],
         [146374,  77592],
         [ 79753,  80034],
         [ 79753,  80034],
         [120815,  59600],
         [120815,  59600],
         [126600,  34256],
         [126600,  34256],
         [  4613,  28360],
         [  4613,  28360],
         [ 40318,   8057],
         [ 40318,   8057],
         [ 34769, 131205],
         [ 34769, 131205],
         [ 66247,  85162],
         [ 66247,  85162],
         [ 41153,  57616],
         [ 41153,  57616],
         [137484,  73803],
         [137484,  73803],
         [144460,  43499],
         [144460,  43499],
         [ 12201, 105067],
 

In [4]:
def get_cos_loss(pos, neg):
    #pos -> [8, 2, 2]
    #neg -> [8, 2, 2]

    #[8, 1, 2]
    pos_n1 = pos[:, 0].unsqueeze(dim=1)
    #[8, 2, 1]
    pos_n2 = pos[:, 1].unsqueeze(dim=2)

    #[8, 1, 2],[8, 2, 1] -> [8]
    loss_pos = torch.bmm(pos_n1, pos_n2).squeeze()

    #[8, 1, 2]
    neg_n1 = neg[:, 0].unsqueeze(dim=1)
    #[8, 2, 1]
    neg_n2 = neg[:, 1].unsqueeze(dim=2)

    #[8, 1, 2],[8, 2, 1] -> [8]
    loss_neg = torch.bmm(neg_n1, -neg_n2).squeeze()

    #[8, 2] -> [8]
    loss_pos = loss_pos.sigmoid().clip(min=1e-8).log()

    #[8, 2] -> [8]
    loss_neg = loss_neg.sigmoid().clip(min=1e-8).log()

    return -(loss_pos + loss_neg).mean()


a, b = torch.randn(8, 2, 2), torch.randn(8, 2, 2)
get_cos_loss(a, b)

tensor(2.3421)

In [5]:
class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

        #V是词汇量
        self.embed = torch.nn.Embedding(num_embeddings=len(node),
                                        embedding_dim=150)

        #初始化参数
        self.embed.weight.data.uniform_(-0.01, 0.01)

    def forward(self, pos, neg):
        #编码
        #[8, 6] -> [8, 6, 2]
        pos = self.embed(pos)

        #[8, 12] -> [8, 12, 2]
        neg = self.embed(neg)

        return get_cos_loss(pos, neg)


model = Model()

model(*get_batch())

tensor(1.3863, grad_fn=<NegBackward0>)

In [6]:
def test(test_words):
    embed = model.embed.weight.data.clone()
    node_keys = list(node.keys())

    for word in test_words:
        x = embed[node[word]]
        score = torch.nn.functional.cosine_similarity(x, embed)
        topk = score.topk(k=5).indices
        topk = [node_keys[k] for k in topk]
        print(word, topk)


test(['girl', 'bus', 'green', 'doctor', 'dog', 'queen', 'italy'])

girl ['girl', 'aleksandr_nikolayevich_scriabin', 'gobiesox_strumosus', 'diltiazem', 'patriotic']
bus ['bus', 'ineluctable', 'begum', 'lilac-blue', 'quick_time']
green ['green', 'golden_fern', 'keep_to_oneself', 'exterminable', 'movie']
doctor ['doctor', 'nonpasserine', 'redstart', 'uncombined', 'announce']
dog ['dog', 'virtuous', 'rock_sandwort', 'cattle', 'bestialise']
queen ['queen', 'crossbench', 'dhal', 'hydrologist', 'wild_sarsaparilla']
italy ['italy', 'rear_admiral', 'hokan', 'shipmate', 'family_tupaiidae']


In [7]:
def train():
    global model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_sum = 0
    for epoch in range(2000001):
        batch = get_batch()
        batch = [i.to(device) for i in batch]

        loss = model(*batch)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_sum += loss.item()

        if epoch % 10000 == 0:
            print(epoch, loss_sum)
            test(['girl', 'bus', 'green', 'doctor', 'dog', 'queen', 'italy'])
            loss_sum = 0

        if epoch % 100000 == 0:
            torch.save(model.cpu(), 'models/wordnet_%d.model' % epoch)
            model = model.to(device)

    model = model.cpu()


train()

0 1.3862968683242798
girl ['girl', 'aleksandr_nikolayevich_scriabin', 'gobiesox_strumosus', 'diltiazem', 'patriotic']
bus ['bus', 'ineluctable', 'begum', 'lilac-blue', 'quick_time']
green ['green', 'golden_fern', 'keep_to_oneself', 'exterminable', 'movie']
doctor ['doctor', 'nonpasserine', 'redstart', 'uncombined', 'announce']
dog ['dog', 'virtuous', 'rock_sandwort', 'cattle', 'bestialise']
queen ['queen', 'crossbench', 'dhal', 'hydrologist', 'wild_sarsaparilla']
italy ['italy', 'rear_admiral', 'hokan', 'shipmate', 'family_tupaiidae']
10000 10527.33760380745
girl ['girl', 'missy', 'fille', 'young_lady', 'sweater_girl']
bus ['bus', 'jitney', 'buss', 'motorbus', 'charabanc']
green ['green', 'greens', 'viridity', 'light-green', 'greenishness']
doctor ['doctor', 'physician', 'doc', 'extern', 'dr.']
dog ['dog', 'domestic_dog', 'canis_familiaris', 'leonberg', 'perisher']
queen ['queen', 'queens', 'female_monarch', 'mary_queen_of_scots', 'queen_regnant']
italy ['italy', 'italian_republic', 'i

170000 8123.10838162899
girl ['girl', 'young_woman', 'young_lady', 'woman', 'miss']
bus ['bus', 'charabanc', 'buss', 'double-decker', 'motorbus']
green ['green', 'greens', 'chromatic', 'park', 'greenness']
doctor ['doctor', 'medico', 'md', 'doc', 'fix']
dog ['dog', 'domestic_dog', 'track', 'hound', 'chase']
queen ['queen', 'queens', 'female_monarch', 'drift', 'king']
italy ['italy', 'italian_republic', 'italia', 'italian_region', 'piedmont']
180000 8155.00337523222
girl ['girl', 'young_woman', 'young_lady', 'miss', 'woman']
bus ['bus', 'buss', 'car', 'double-decker', 'charabanc']
green ['green', 'greens', 'chromatic', 'park', 'greenness']
doctor ['doctor', 'doc', 'md', 'medico', 'repair']
dog ['dog', 'domestic_dog', 'chase', 'track', 'canis_familiaris']
queen ['queen', 'queens', 'picture_card', 'female_monarch', 'female_aristocrat']
italy ['italy', 'italian_republic', 'italia', 'italian_region', 'italian']
190000 8187.046461999416
girl ['girl', 'young_woman', 'miss', 'young_lady', 'wom

360000 8558.566155672073
girl ['girl', 'miss', 'young_lady', 'fille', 'young_woman']
bus ['bus', 'car', 'automobile', 'coach', 'buss']
green ['green', 'greens', 'k', 'site', 'greenness']
doctor ['doctor', 'medico', 'dr.', 'physician', 'doc']
dog ['dog', 'run_down', 'quest', 'bang', 'tracked']
queen ['queen', 'queens', 'device', 'grounds', 'driven']
italy ['italy', 'italia', 'italian_republic', 'italian_region', 'metropolis']
370000 8564.151451826096
girl ['girl', 'miss', 'young_lady', 'woman', 'fille']
bus ['bus', 'car', 'buss', 'coach', 'auto']
green ['green', 'greens', 'k', 'site', 'dig']
doctor ['doctor', 'medico', 'physician', 'dr.', 'doc']
dog ['dog', 'tracked', 'run_down', 'search', 'tail']
queen ['queen', 'queens', 'driven', 'soul', 'grounds']
italy ['italy', 'italia', 'italian_republic', 'italian_region', 'germany']
380000 8581.40454441309
girl ['girl', 'miss', 'young_lady', 'fille', 'woman']
bus ['bus', 'car', 'drive', 'coach', 'smack']
green ['green', 'greens', 'k', 'chromati

560000 8747.873451769352
girl ['girl', 'miss', 'young_woman', 'fille', 'woman']
bus ['bus', 'coach', 'heaps', 'buss', 'auto']
green ['green', 'greens', 'colour', 'chromatic_color', 'colors']
doctor ['doctor', 'dr.', 'doc', 'medico', 'physician']
dog ['dog', 'tug', 'make_out', 'admit', 'use']
queen ['queen', 'queens', 'promote', 'fagot', 'find']
italy ['italy', 'italia', 'italian_republic', 'metropolis', 'urban_center']
570000 8742.292772591114
girl ['girl', 'fille', 'miss', 'missy', 'young_woman']
bus ['bus', 'buss', 'auto', 'coach', 'blue']
green ['green', 'greens', 'colour', 'colors', 'color']
doctor ['doctor', 'dr.', 'physician', 'doc', 'medico']
dog ['dog', 'make_out', 'pursue', 'tug', 'admit']
queen ['queen', 'queens', 'promote', 'fagot', 'frustrate']
italy ['italy', 'italia', 'italian_republic', 'city', 'russian_federation']
580000 8758.05969953537
girl ['girl', 'fille', 'miss', 'young_woman', 'missy']
bus ['bus', 'articulate', 'heaps', 'buss', 'coach']
green ['green', 'greens', 

760000 8816.75190281868
girl ['girl', 'miss', 'young_woman', 'young_lady', 'missy']
bus ['bus', 'buss', 'slew', 'spatter', 'babble']
green ['green', 'greens', 'k', 'color', 'colour']
doctor ['doctor', 'physician', 'dr.', 'medico', 'md']
dog ['dog', 'run_down', 'domestic_dog', 'search', 'chase']
queen ['queen', 'queens', 'fag', 'connected', 'binding']
italy ['italy', 'italian_region', 'italian_republic', 'italia', 'mexico']
770000 8807.41715836525
girl ['girl', 'miss', 'young_woman', 'missy', 'young_lady']
bus ['bus', 'buss', 'coach', 'posed', 'draw']
green ['green', 'greens', 'k', 'colors', 'color']
doctor ['doctor', 'physician', 'medico', 'dr.', 'md']
dog ['dog', 'run_down', 'domestic_dog', 'search', 'chase']
queen ['queen', 'queens', 'fag', 'raise', 'queer']
italy ['italy', 'italian_republic', 'italian_region', 'italia', 'portugal']
780000 8835.600167930126
girl ['girl', 'miss', 'young_woman', 'young_lady', 'missy']
bus ['bus', 'buss', 'posed', 'draw', 'taken']
green ['green', 'green

960000 8856.743384599686
girl ['girl', 'young_lady', 'miss', 'young_woman', 'missy']
bus ['bus', 'coach', 'car', 'passenger_vehicle', 'motorcoach']
green ['green', 'greens', 'color', 'chromatic', 'harmonize']
doctor ['doctor', 'dr.', 'medico', 'physician', 'repair']
dog ['dog', 'domestic_dog', 'pursue', 'tagged', 'chase']
queen ['queen', 'queens', 'drawing', 'playing_card', 'faggot']
italy ['italy', 'italia', 'italian_republic', 'spain', 'pitched_battle']
970000 8857.102682292461
girl ['girl', 'young_woman', 'young_lady', 'miss', 'missy']
bus ['bus', 'coach', 'ride', 'car', 'removed']
green ['green', 'greens', 'chromatic', 'greenness', 'raise']
doctor ['doctor', 'dr.', 'medico', 'physician', 'repair']
dog ['dog', 'domestic_dog', 'tag', 'tagged', 'chase_after']
queen ['queen', 'queens', 'playing_card', 'drawing', 'faggot']
italy ['italy', 'italia', 'italian_republic', 'pitched_battle', 'germany']
980000 8852.512466013432
girl ['girl', 'miss', 'young_woman', 'young_lady', 'fille']
bus ['

1160000 8870.240259706974
girl ['girl', 'miss', 'fille', 'woman', 'young_woman']
bus ['bus', 'buss', 'coach', 'displace', 'car']
green ['green', 'greens', 'draw', 'take_in', 'colours']
doctor ['doctor', 'md', 'medico', 'doc', 'physician']
dog ['dog', 'risk', 'domestic_dog', 'chap', 'chase']
queen ['queen', 'queens', 'king', 'flies', 'writings']
italy ['italy', 'italia', 'italian_republic', 'germany', 'pitched_battle']
1170000 8863.238327622414
girl ['girl', 'miss', 'young_woman', 'fille', 'woman']
bus ['bus', 'buss', 'coach', 'car', 'ram']
green ['green', 'greens', 'colors', 'color', 'colour']
doctor ['doctor', 'doc', 'medico', 'md', 'physician']
dog ['dog', 'chase', 'occupy', 'incline', 'risk']
queen ['queen', 'queens', 'king', 'someone', 'male_monarch']
italy ['italy', 'italia', 'italian_republic', 'germany', 'pitched_battle']
1180000 8867.57926082611
girl ['girl', 'young_woman', 'miss', 'fille', 'woman']
bus ['bus', 'buss', 'coach', 'hitting', 'synchronize']
green ['green', 'greens'

1360000 8879.9641264081
girl ['girl', 'miss', 'woman', 'young_woman', 'travelling']
bus ['bus', 'buss', 'heap', 'automobile', 'coach']
green ['green', 'greens', 'park', 'pick', 'fool']
doctor ['doctor', 'doc', 'md', 'raising', 'excite']
dog ['dog', 'trail', 'hound', 'track', 'choke']
queen ['queen', 'queens', 'king', 'pouf', 'visualize']
italy ['italy', 'italian_republic', 'italia', 'spain', 'ireland']
1370000 8885.337143599987
girl ['girl', 'miss', 'woman', 'feel', 'takings']
bus ['bus', 'buss', 'motorbus', 'heap', 'coach']
green ['green', 'greens', 'park', 'stern', 'colour']
doctor ['doctor', 'doc', 'have', 'fixings', 'workings']
dog ['dog', 'tracking', 'clog', 'hound', 'tag']
queen ['queen', 'queens', 'poove', 'king', 'pouf']
italy ['italy', 'italian_republic', 'italia', 'spain', 'chains']
1380000 8887.206500470638
girl ['girl', 'miss', 'woman', 'young_woman', 'young_lady']
bus ['bus', 'coach', 'heap', 'buss', 'windows']
green ['green', 'greens', 'parcel', 'chromatic_color', 'chroma

1560000 8890.718221724033
girl ['girl', 'miss', 'missy', 'woman', 'young_woman']
bus ['bus', 'heap', 'coach', 'time_period', 'muckle']
green ['green', 'greens', 'chromatic', 'park', 'color']
doctor ['doctor', 'physician', 'md', 'doc', 'medico']
dog ['dog', 'quest', 'heel', 'assuring', 'assure']
queen ['queen', 'queens', 'faggot', 'king', 'fagot']
italy ['italy', 'italian_republic', 'italia', 'spain', 'italian_region']
1570000 8895.757581174374
girl ['girl', 'miss', 'missy', 'adult_female', 'woman']
bus ['bus', 'heap', 'buss', 'peck', 'muckle']
green ['green', 'greens', 'chromatic', 'color', 'park']
doctor ['doctor', 'physician', 'medico', 'doc', 'md']
dog ['dog', 'quest', 'cuss', 'chase', 'heel']
queen ['queen', 'queens', 'fagot', 'king', 'add']
italy ['italy', 'italian_republic', 'italia', 'italian_region', 'spain']
1580000 8897.728899359703
girl ['girl', 'miss', 'missy', 'fille', 'young_lady']
bus ['bus', 'push', 'buss', 'heap', 'auto']
green ['green', 'greens', 'tan', 'color', 'chro

1760000 8904.765096187592
girl ['girl', 'miss', 'woman', 'young_woman', 'missy']
bus ['bus', 'buss', 'coach', 'machine', 'car']
green ['green', 'greens', 'colouring', 'chromatic', 'colored']
doctor ['doctor', 'md', 'dr.', 'physician', 'fixings']
dog ['dog', 'domestic_dog', 'chase', 'run_down', 'dropping']
queen ['queen', 'queens', 'fag', 'faggot', 'fagot']
italy ['italy', 'italia', 'italian_republic', 'pitched_battle', 'italian_region']
1770000 8894.346966207027
girl ['girl', 'miss', 'woman', 'missy', 'female']
bus ['bus', 'buss', 'autobus', 'car', 'motorcoach']
green ['green', 'greens', 'chromatic', 'park', 'color']
doctor ['doctor', 'md', 'dr.', 'repair', 'doc']
dog ['dog', 'domestic_dog', 'chase', 'hook', 'support']
queen ['queen', 'queens', 'fag', 'rise', 'bind']
italy ['italy', 'italian_republic', 'italia', 'pitched_battle', 'river']
1780000 8908.387778818607
girl ['girl', 'miss', 'woman', 'missy', 'young_woman']
bus ['bus', 'motorbus', 'autobus', 'buss', 'passenger_vehicle']
gree

1960000 8918.72321152687
girl ['girl', 'miss', 'fille', 'dame', 'missy']
bus ['bus', 'buss', 'coach', 'prepare', 'train']
green ['green', 'greens', 'k', 'chromatic', 'colored']
doctor ['doctor', 'doc', 'md', 'dr.', 'fix']
dog ['dog', 'tag', 'run_down', 'tail', 'domestic_dog']
queen ['queen', 'queens', 'chess', 'increase', 'faggot']
italy ['italy', 'italia', 'italian_republic', 'spain', 'pitched_battle']
1970000 8903.006738603115
girl ['girl', 'miss', 'woman', 'fille', 'missy']
bus ['bus', 'prepare', 'train', 'coach', 'making']
green ['green', 'greens', 'k', 'chromatic', 'color']
doctor ['doctor', 'doc', 'down', 'md', 'gilbert']
dog ['dog', 'hound', 'tail', 'tag', 'run_down']
queen ['queen', 'queens', 'compel', 'fagot', 'fly']
italy ['italy', 'italia', 'italian_republic', 'spain', 'ireland']
1980000 8919.666934072971
girl ['girl', 'miss', 'fille', 'missy', 'woman']
bus ['bus', 'coach', 'train', 'buss', 'slews']
green ['green', 'greens', 'k', 'chromatic', 'color']
doctor ['doctor', 'doc'

In [8]:
model = torch.load('models/wordnet_2000000.model')

test(['girl', 'bus', 'green', 'doctor', 'dog', 'queen', 'italy'])

girl ['girl', 'miss', 'fille', 'missy', 'boast']
bus ['bus', 'coach', 'heap', 'buss', 'jitney']
green ['green', 'greens', 'k', 'chromatic', 'colour']
doctor ['doctor', 'doc', 'dr.', 'medico', 'physician']
dog ['dog', 'tag', 'quest', 'plunder', 'tail']
queen ['queen', 'queens', 'king', 'faggot', 'effect']
italy ['italy', 'italian_republic', 'italia', 'spain', 'france']
