In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter
from lib.word_vectors import obj_edge_vectors
ftensor = torch.FloatTensor
ltensor = torch.LongTensor

In [2]:
def corrupt_batch(subj_dists, obj_dists, batch, num_ent, _cb_var, _cb_mode='head-tail-cor'):
    # batch: ltensor type, contains positive triplets
    batch_size, _ = batch.size()
    
    corrupted = batch.clone()

    if len(_cb_var) == 0:
        _cb_var.append(ltensor(batch_size//2).cuda())
        #_cb_var.append(ltensor(batch_size//2))
        
    q_samples_l = _cb_var[0].random_(0, num_ent)
    q_samples_r = _cb_var[0].random_(0, num_ent)

    if _cb_mode == 'head-cor':
        #head-corrupted
        corrupted[:batch_size//2, 0] = q_samples_l
    elif _cb_mode == 'tail-cor':
        #tail-corrupted
        corrupted[batch_size//2:, 2] = q_samples_r
    elif _cb_mode == 'head-tail-cor':
        #head-tail-corrupted
        corrupted[:batch_size//2, 0] = q_samples_l
        corrupted[batch_size//2:, 2] = q_samples_r

    return corrupted.contiguous(), torch.cat([q_samples_l, q_samples_r])

In [5]:
def noisy_batch(subj_dists, obj_dists):
    
    subj_samples = torch.multinomial(subj_dists, 1).squeeze(1)
    obj_samples = torch.multinomial(obj_dists, 1).squeeze(1)
    
    return subj_samples, obj_samples

In [6]:
if __name__ == '__main__':
    p= 1
    num_ent = 151
    num_rel = 51
    embed_dim = 4096
    
    batch = 10
    
    subj_dists = Variable(torch.randn(batch,num_ent))
    obj_dists = Variable(torch.randn(batch,num_ent))
    
    subj_dists = F.softmax(subj_dists,1)
    obj_dists = F.softmax(obj_dists,1)
    
    lhs = torch.LongTensor(batch).random_(0,num_ent)
    rhs = torch.LongTensor(batch).random_(0,num_ent)
    rel = torch.LongTensor(batch).random_(0,num_rel)
    
    p_batch = torch.stack((lhs, rel, rhs),1)
    print(p_batch)
    
    _cb_var = []
    _cb_mode=['head-tail-cor','tail-cor','head-tail-cor']
    
    #import ipdb; ipdb.set_trace()
    nce_batch, q_samples = corrupt_batch(
        subj_dists, obj_dists, p_batch, num_ent, _cb_var, _cb_mode[0]
    )
    
    print(nce_batch)
    print(q_samples)
    
    subj_labels, obj_labels = noisy_batch(subj_dists, obj_dists)
    print(subj_labels)
    print(obj_labels)


  126     4   110
  146    47    73
   68     9   115
   45     4    66
  108    43   103
  140    18    76
   56    20    22
   22    17    91
   37     2     1
   92    26    48
[torch.LongTensor of size 10x3]


   62     4   110
   92    47    73
  138     9   115
   58     4    66
   70    43   103
  140    18    62
   56    20    92
   22    17   138
   37     2    58
   92    26    70
[torch.LongTensor of size 10x3]


  62
  92
 138
  58
  70
  62
  92
 138
  58
  70
[torch.cuda.LongTensor of size 10 (GPU 0)]

Variable containing:
  87
  15
   9
 125
  14
  94
  54
   1
  49
  86
[torch.LongTensor of size 10]

Variable containing:
 107
  66
  23
  69
  69
 136
  26
  75
   9
  42
[torch.LongTensor of size 10]

