In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter
from lib.word_vectors import obj_edge_vectors
ftensor = torch.FloatTensor
ltensor = torch.LongTensor

import ipdb

class MarginRankingLoss(nn.Module):
    def __init__(self, margin):
        super(MarginRankingLoss, self).__init__()
        #ipdb.set_trace()
        self.margin = margin
        
    def forward(self, p_enrgs, l_enrgs, ln_enrgs, weights=None):
        #ipdb.set_trace()
        # output: pl_rel_term, lnl_rel_term
        # input: p_rel_enrgs, l_rel_enrgs, nl_rel_enrgs
        scores_pl = (self.margin + p_enrgs - ln_enrgs).clamp(min=0)
        #scores_lnl = (self.margin + l_enrgs - ln_enrgs).clamp(min=0)
        
        if weights is not None:
            scores_pl = scores_pl * weights / weights.mean()
            #scores_lnl = scores_lnl * weights / weights.mean()
            
        scores_lnl = None
        
        return scores_pl.mean(), scores_lnl
        #return scores_pl.mean(), scores_lnl.mean()
    
def corrupt_batch(batch, num_ent, _cb_var, _cb_mode='head-tail-cor'):
    # batch: ltensor type, contains positive triplets
    batch_size, _ = batch.size()
    
    corrupted = batch.clone()

    if len(_cb_var) == 0:
        _cb_var.append(ltensor(batch_size//2).cuda())
        #_cb_var.append(ltensor(batch_size//2))

    q_samples_l = _cb_var[0].random_(0, num_ent)
    q_samples_r = _cb_var[0].random_(0, num_ent)
    
    if _cb_mode == 'head-cor':
        #head-corrupted
        corrupted[:batch_size//2, 0] = q_samples_l
    elif _cb_mode == 'tail-cor':
        #tail-corrupted
        corrupted[batch_size//2:, 2] = q_samples_r
    elif _cb_mode == 'head-tail-cor':
        #head-tail-corrupted
        corrupted[:batch_size//2, 0] = q_samples_l
        corrupted[batch_size//2:, 2] = q_samples_r

    return corrupted.contiguous(), torch.cat([q_samples_l, q_samples_r])

class TransRelE(nn.Module):
    def __init__(self, num_ent, num_rel, embed_dim, p):
        super(TransRelE, self).__init__()
        self.num_ent = num_ent
        self.num_rel = num_rel
        self.embed_dim = embed_dim
        self.p = p

        r = 6 / np.sqrt(self.embed_dim)
        self.ent_embeds = nn.Embedding(self.num_ent, self.embed_dim, max_norm=1, norm_type=2, sparse=True)
        self.rel_embeds = nn.Embedding(self.num_rel, self.embed_dim, max_norm=1, norm_type=2, sparse=True)

        self.ent_embeds.weight.data.uniform_(-r, r)#.renorm_(p=2, dim=1, maxnorm=1)
        self.rel_embeds.weight.data.uniform_(-r, r)#.renorm_(p=2, dim=1, maxnorm=1)
        

    def forward(self, triplets):

        lhs_idxs = triplets[:, 0]
        rel_idxs = triplets[:, 1]
        rhs_idxs = triplets[:, 2]
        lhs_es = self.ent_embeds(lhs_idxs)
        rel_es = self.rel_embeds(rel_idxs)
        rhs_es = self.ent_embeds(rhs_idxs)

        enrgs = (lhs_es + rel_es - rhs_es).norm(p=self.p, dim=1)
        return enrgs, lhs_es, rhs_es, rel_es

    def save(self, fn):
        torch.save(self.state_dict(), fn)

    def load(self, fn):
        self.load_state_dict(torch.load(fn))


class RotatE(nn.Module):
    def __init__(self, classes, rel_classes, embed_dim, p, use_tran_vis, obj_embed, rel_embed, mode):
        super(RotatE, self).__init__()
        self.num_ent = len(classes)
        self.num_rel = len(rel_classes)
        self.embed_dim = embed_dim
        self.p = p
        
        self.use_tran_vis = use_tran_vis
        self._ent_embeds = obj_embed
        self.rel_embeds = rel_embed
        self.mode = mode
        
        self.ent_transfer = nn.Embedding(self.num_ent, self.embed_dim, max_norm=1.0)
        self.rel_transfer = nn.Embedding(self.num_rel, self.embed_dim, max_norm=1.0)

        r = 6/np.sqrt(self.embed_dim)
        self.ent_transfer.weight.data.uniform_(-r, r)
        self.rel_transfer.weight.data.uniform_(-r, r)
        
        self.gamma = 12.0
        self.epsilon = 2.0
        self.embedding_range = (self.gamma + self.epsilon) / self.embed_dim 
        
    def transfer(self, emb, e_transfer, r_transfer):
        return emb + (emb * e_transfer).sum(dim=1, keepdim=True) * r_transfer
    
    #@profile
    def ent_embeds(self, rep, idx, rel_idx):
            
        if self.use_tran_vis:
            es = rep
        else:
            es = self._ent_embeds(idx)
            
        ts = self.ent_transfer(idx)
        
        rel_ts = self.rel_transfer(rel_idx)
        proj_es = self.transfer(es, ts, rel_ts)
        return proj_es
    
    #def forward(self, triplets, return_ent_embed=True):
    def forward(self, triplets, subj_rep, rel_rep, obj_rep, return_ent_emb=True):
        pi = 3.14159265358979323846
        
        lhs_idxs = triplets[:, 0]
        rel_idxs = triplets[:, 1]
        rhs_idxs = triplets[:, 2]

        if self.use_tran_vis:
            rel_es = rel_rep
        else:
            rel_es = self.rel_embeds(rel_idxs)
            
        subj_es = self.ent_embeds(subj_rep, lhs_idxs, rel_idxs)
        obj_es = self.ent_embeds(obj_rep, rhs_idxs, rel_idxs)
        
        re_head, im_head = torch.chunk(subj_es, 2, dim=1) # [512, 1, 2000]
        re_tail, im_tail = torch.chunk(obj_es, 2, dim=1) # [512, 1, 2000]
        
        #Make phases of relations uniformly distributed in [-pi, pi]
        phase_relation = rel_es/(self.embedding_range/pi) # 0.0260/pi, relation [512, 1, 1000]

        re_phase, im_phase = torch.chunk(phase_relation, 2, dim=1)
        re_relation = torch.cos(re_phase) 
        im_relation = torch.sin(im_phase)
        
        if self.mode == 'head-batch':
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            re_score = re_score - re_head
            im_score = im_score - im_head
        else: # 'tail-batch'
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            re_score = re_score - re_tail
            im_score = im_score - im_tail

        score = torch.stack([re_score, im_score], dim = 0)
        score = score.norm(dim = 0)

        #import ipdb; ipdb.set_trace()
        enrgs = self.gamma - score.sum(dim=1)
        enrgs = -F.logsigmoid(enrgs)
        
        if not return_ent_emb:
            return enrgs
        else:
            return enrgs,subj_es,obj_es,rel_es
        
class TransRelD(nn.Module):
    def __init__(self, classes, rel_classes, embed_dim, p, use_tran_vis, obj_embed, rel_embed):
        super(TransRelD, self).__init__()
        self.num_ent = len(classes)
        self.num_rel = len(rel_classes)
        self.embed_dim = embed_dim
        self.p = p
        
        self.use_tran_vis = use_tran_vis
        self._ent_embeds = obj_embed
        self.rel_embeds = rel_embed
        
        self.ent_transfer = nn.Embedding(self.num_ent, self.embed_dim, max_norm=1.0)
        self.rel_transfer = nn.Embedding(self.num_rel, self.embed_dim, max_norm=1.0)

        r = 6/np.sqrt(self.embed_dim)
        self.ent_transfer.weight.data.uniform_(-r, r)
        self.rel_transfer.weight.data.uniform_(-r, r)
        
    def transfer(self, emb, e_transfer, r_transfer):
        return emb + (emb * e_transfer).sum(dim=1, keepdim=True) * r_transfer
    
    #@profile
    def ent_embeds(self, rep, idx, rel_idx):
            
        if self.use_tran_vis:
            es = rep
        else:
            es = self._ent_embeds(idx)
            
        ts = self.ent_transfer(idx)
        
        rel_ts = self.rel_transfer(rel_idx)
        proj_es = self.transfer(es, ts, rel_ts)
        return proj_es

    def forward(self, triplets, subj_rep, rel_rep, obj_rep, return_ent_emb=True):
        lhs_idxs = triplets[:, 0]
        rel_idxs = triplets[:, 1]
        rhs_idxs = triplets[:, 2]
        
        if self.use_tran_vis:
            rel_es = rel_rep
        else:
            rel_es = self.rel_embeds(rel_idxs)
            
        lhs = self.ent_embeds(subj_rep, lhs_idxs, rel_idxs)
        rhs = self.ent_embeds(obj_rep, rhs_idxs, rel_idxs)
        
        if not return_ent_emb:
            enrgs = (lhs + rel_es - rhs).norm(p=self.p, dim=1)
            return enrgs
        else:
            enrgs = (lhs + rel_es - rhs).norm(p=self.p, dim=1)
            return enrgs,lhs,rhs,rel_es
        
    def get_embed(self, ents, rel_idxs):
        ent_embed = self.ent_embeds(ents, rel_idxs)
        return ent_embed

    def save(self, fn):
        torch.save(self.state_dict(), fn)

    def load(self, fn):
        self.load_state_dict(torch.load(fn))
        
def rel_trans_rep_e(tranRelE, lossF, p_rels, l_rels, num_objs,subj_rep, rel_rep, obj_rep, _cb_mode, is_train=False):
    
    #ipdb.set_trace()
    _cb_var = []
    if is_train:
        #nl_rels, q_samples = corrupt_batch(l_rels, num_objs, _cb_var, _cb_mode)
        nl_rels, q_samples = corrupt_batch(p_rels, num_objs, _cb_var, _cb_mode)
        d_ins = torch.cat([p_rels, l_rels, nl_rels], dim=0).contiguous()
        
        # prediction samples
        if tranRelE.use_tran_vis:
            subj_rep = subj_rep
            rel_rep = rel_rep
            obj_rep = obj_rep
        else:
            subj_rep = tranRelE._ent_embeds(p_rels[:,0])
            #rel_rep = tranRelE.rel_embeds(p_rels[:,1])
            rel_rep = tranRelE.rel_embeds(p_rels[:,0] * 151 + p_rels[:,2])
            obj_rep = tranRelE._ent_embeds(p_rels[:,2])
            
        # positive samples
        l_subj_rep = tranRelE._ent_embeds(l_rels[:,0])
        #l_rel_rep = tranRelE.rel_embeds(l_rels[:,1])
        l_rel_rep = tranRelE.rel_embeds(l_rels[:,0] * 151 + l_rels[:,2])
        l_obj_rep = tranRelE._ent_embeds(l_rels[:,2])

        # negative sample
        nl_subj_rep = tranRelE._ent_embeds(nl_rels[:,0])
        #nl_rel_rep = tranRelE.rel_embeds(nl_rels[:,1])
        nl_rel_rep = tranRelE.rel_embeds(nl_rels[:,0] * 151 + nl_rels[:,2])
        nl_obj_rep = tranRelE._ent_embeds(nl_rels[:,2])

        # prediction, pos/negative samples
        subj_rep = torch.cat([subj_rep, l_subj_rep, nl_subj_rep], dim=0).contiguous()
        rel_rep = torch.cat([rel_rep, l_rel_rep, nl_rel_rep], dim=0).contiguous()
        obj_rep = torch.cat([obj_rep, l_obj_rep, nl_obj_rep], dim=0).contiguous()
        
    else:
        nl_rels, q_samples = corrupt_batch(p_rels, num_objs, _cb_var)
        d_ins = torch.cat([p_rels, nl_rels], dim=0).contiguous()

        # prediction samples 
        if tranRelE.use_tran_vis:
            subj_rep = subj_rep
            rel_rep = rel_rep
            obj_rep = obj_rep
        else:
            subj_rep = tranRelE._ent_embeds(p_rels[:,0])
            #rel_rep = tranRelE.rel_embeds(p_rels[:,1])
            rel_rep = tranRelE.rel_embeds(p_rels[:,0] * 151 + p_rels[:,2])
            obj_rep = tranRelE._ent_embeds(p_rels[:,2])

        # negative samples 
        nl_subj_rep = tranRelE._ent_embeds(nl_rels[:,0])
        #nl_rel_rep = tranRelE.rel_embeds(nl_rels[:,1])
        nl_rel_rep = tranRelE.rel_embeds(nl_rels[:,0] * 151 + nl_rels[:,2])
        nl_obj_rep = tranRelE._ent_embeds(nl_rels[:,2])
        
        # prediction/ negative samples
        subj_rep = torch.cat([subj_rep, nl_subj_rep], dim=0).contiguous()
        rel_rep = torch.cat([rel_rep, nl_rel_rep], dim=0).contiguous()
        obj_rep = torch.cat([obj_rep, nl_obj_rep], dim=0).contiguous()
        
    #ipdb.set_trace()
    #d_out, subj_out, rel_out, obj_out = tranRelE(d_ins)
    d_out, subj_out, rel_out, obj_out = tranRelE(d_ins, subj_rep, rel_rep, obj_rep)
    
    #ipdb.set_trace()
    if is_train:
        p_rel_enrgs = d_out[:len(p_rels)]
        l_rel_enrgs = d_out[len(p_rels):len(p_rels)*2]
        nl_rel_enrgs = d_out[len(p_rels)*2:]
        
        pl_rel_score, lnl_rel_score = lossF(p_rel_enrgs, l_rel_enrgs, nl_rel_enrgs)
        
    else:
        pl_rel_score = None
        lnl_rel_score = None
        
    return pl_rel_score, lnl_rel_score, subj_out[:len(p_rels)], rel_out[:len(p_rels)], obj_out[:len(p_rels)]


if __name__ == '__main__':
    p= 1
    num_ent = 151
    num_rel = 51
    embed_dim = 4096
    
    batch = 10
    
    margin = 3.0
    #transrele = TransRelE(num_ent, num_rel, embed_dim, p)
    transrele = TransRelD(num_ent, num_rel, embed_dim, p)
    
    loss_func = MarginRankingLoss(margin)
    
    print(transrele)
    lhs = torch.LongTensor(batch).random_(0,num_ent)
    rhs = torch.LongTensor(batch).random_(0,num_ent)
    rel = torch.LongTensor(batch).random_(0,num_rel)
    
    p_batch = torch.stack((lhs, rel, rhs),1)
    print(p_batch)
    
    nce_batch, q_samples = corrupt_batch(p_batch, num_ent)
    nce_np = nce_batch.cpu().numpy()
    #train_hash = bias
    #nce_falseNs = ftensor(np.array([int(x.tobytes() in train_hash) for x in nce_np], dtype=np.float32))
    nce_falseNs = None
    print(nce_batch)
    print(q_samples)
    
    d_ins = torch.cat([p_batch, nce_batch], dim=0).contiguous()
    
    d_out,lhs_out, rhs_out, rel_out = transrele(d_ins)
    
    #ipdb.set_trace()
    p_enrgs = d_out[:len(p_batch)]
    nce_enrgs = d_out[len(p_batch):]
    #nce_term, nce_term_scores = loss_func(p_enrgs, nce_enrgs, weights=(1.- nce_falseNs))
    nce_term, nce_term_scores = loss_func(p_enrgs, nce_enrgs)
    print(nce_term)
    print(nce_term_scores)

TypeError: __init__() missing 3 required positional arguments: 'use_tran_vis', 'obj_embed', and 'rel_embed'