In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter
from torch import optim
import ipdb

import matplotlib
import matplotlib.pyplot as plt

class obj_message(nn.Module):
    def __init__(self,num_objs, embed_dim, opt, debug):
        super(obj_message, self).__init__()
        
        self.embed_dim = embed_dim
        self.num_objs = num_objs
        
        """object parameters"""
        self.ow1 = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.ou1 = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.ol1 = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.ofc_l = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        
        """object classifier"""
        self.cls = nn.Linear(self.embed_dim, self.num_objs, bias=True)
        
        self.bias = 0.1
        self.debug = debug
        
    def graph(self, subj,obj):
        """ Create a graph """
        # subjects and objects
        adj_matrix = Variable(
            torch.ones([self.num_objs,self.num_objs])
        ) * self.bias
        
        adj_matrix[subj,obj] = 1.0
        adj_matrix[obj,subj] = 1.0
        
        l_adj_m = torch.tril(adj_matrix) * -1.0
        u_adj_m = torch.tril(adj_matrix).transpose(0,1)
        adj_matrix = l_adj_m + u_adj_m
        
        if self.debug:
            print(adj_matrix)

        return adj_matrix
        
    def forward(self,r_obj_rep, obj_rep, subj, obj, is_train, use_root):
        
        
        if use_root:
            norm_Q = Variable(
                torch.ones([self.num_objs+1,self.num_objs+1])) * self.bias
            obj_rep = torch.cat([r_obj_rep, obj_rep], 0)
        else:
            norm_Q = Variable(
                torch.ones([self.num_objs,self.num_objs])) * self.bias
        
        sub_Q = self.graph(subj,obj)
        
        #import ipdb; ipdb.set_trace()
        mask = (sub_Q.abs() > 0).float()
        deg_m = 1/mask.sum(1).unsqueeze(1).repeat(1,self.num_objs)
        sub_Q = sub_Q * deg_m
        
        if use_root:
            norm_Q[1:,1:] = sub_Q
            norm_Q[1:,0] = (1.0 - norm_Q[1:,1:].sum(1))
        else:
            norm_Q = sub_Q
            
        norm_Q[norm_Q != norm_Q] = 0.0
        
        if self.debug:
            print(norm_Q.sum(1))
        
        #norm_Q[range(self.num_objs), range(self.num_objs)] = 1.0- norm_Q.sum(1)
        
        if self.debug:
            print(norm_Q.sum(1))
            print(norm_Q)
            
        """relational object features"""
        obj_w1 = self.ow1(F.relu(obj_rep))
        obj_u1 = self.ou1(F.relu(obj_rep))
        obj_l1 = self.ol1(F.relu(obj_rep))
        
        #import ipdb; ipdb.set_trace()
        ofc_l = torch.matmul(norm_Q, obj_l1)
        out_obj_rep = self.ofc_l(F.relu(ofc_l)) + obj_w1
        
        if use_root:
            out_obj_rep = out_obj_rep[1:,:]
        
        obj_dists = self.cls(out_obj_rep)
        
        return obj_dists, out_obj_rep
    
print(obj_message)

<class '__main__.obj_message'>


In [25]:
#import ipdb
if __name__ == '__main__':
    p= 1
    num_obj_cls = 5
    embed_dim = 16
    rgb = ['c', 'y', 'b', 'r', 'm']
    opt = ['proj', 'rotatE', 'rotatEC']
    n_var = 0.5
    debug = False
    mps = 2
    num_objs = 5
    te_num_objs = 5
    
    # model and optimizer
    obj_mps = obj_message(num_objs, embed_dim, opt[0],debug) 
    params = [p for n,p in obj_mps.named_parameters() if p.requires_grad]
    optimizer = optim.Adam(params, weight_decay=0.0001, lr=0.001, eps=1e-3)
    
    # define object embedding
    obj_embed = nn.Embedding(num_obj_cls, embed_dim)
    
    if debug:
        steps = 1
        epochs = 1
        display = 1
    else:
        steps = 200
        epochs = 1
        display = 40
        
    rnd_objs = torch.LongTensor(te_num_objs).random_(0,num_obj_cls)
    te_obj_labels = Variable(rnd_objs)
    te_obj_rep = obj_embed(te_obj_labels)
    te_r_obj_rep = te_obj_rep.sum(0).unsqueeze(0)
            
    te_subj = torch.LongTensor(int(num_objs/2)).random_(0,num_obj_cls)
    te_obj = torch.LongTensor(int(num_objs/2)).random_(0,num_obj_cls)
        
    for epoch in range(epochs):
        for step in range(steps):
            batch = 100
            
            # training sets
            rnd_objs = torch.LongTensor(num_objs).random_(0,num_obj_cls)
            tr_obj_labels = Variable(rnd_objs)
            tr_obj_rep = obj_embed(tr_obj_labels)
            tr_r_obj_rep = tr_obj_rep.sum(0).unsqueeze(0)
            
            tr_subj = torch.LongTensor(int(num_objs/2)).random_(0,num_obj_cls)
            tr_obj = torch.LongTensor(int(num_objs/2)).random_(0,num_obj_cls)
                                
            obj_mps.train()
            for i in range(mps):
                tr_obj_dists, tr_obj_rep = obj_mps(
                    tr_r_obj_rep, tr_obj_rep, tr_subj, tr_obj, True, use_root=(i==0))

            losses = {}
            losses['obj_loss'] = F.cross_entropy(tr_obj_dists, tr_obj_labels)
            loss = sum(losses.values())

            ##############################
            optimizer.zero_grad()
            loss.backward()
            losses['total'] = loss
            optimizer.step()
            ##############################
            
            if step % display == 0:
                print('train:epoch:{}, step:{}, loss:{}'.format(epoch, step,losses['total'].data))
          
        obj_mps.eval()
        for i in range(mps):
            te_obj_dists, te_obj_rep = obj_mps(
                te_r_obj_rep, te_obj_rep, te_subj, te_obj, True, use_root=(i==0))

        te_p_objs = te_obj_dists.max(1)[1]
        te_acc = np.array(te_obj_labels == te_p_objs).sum()/te_num_objs
        losses = {}
        losses['total'] = F.cross_entropy(te_obj_dists, te_obj_labels)
        loss = sum(losses.values())
        print('test:epoch:{}, step:{}, loss:{}, acc:{}'.format(epoch, step,losses['total'].data, te_acc))
        
                

train:epoch:0, step:0, loss:
 1.5706
[torch.FloatTensor of size 1]

train:epoch:0, step:40, loss:
 1.5812
[torch.FloatTensor of size 1]

train:epoch:0, step:80, loss:
 1.3195
[torch.FloatTensor of size 1]

train:epoch:0, step:120, loss:
 1.0126
[torch.FloatTensor of size 1]

train:epoch:0, step:160, loss:
 0.7953
[torch.FloatTensor of size 1]

test:epoch:0, step:199, loss:
 0.5296
[torch.FloatTensor of size 1]
, acc:Variable containing:
 1
[torch.ByteTensor of size 1]



In [None]:
            te_obj_labels = Variable(
                torch.from_numpy(np.array([0,1,2,3]))).type(torch.LongTensor)
            
            # training sets
            te_subj = torch.from_numpy(
                np.array([0,0,1,1,2,3])).type(torch.LongTensor)
            ote_bj = torch.from_numpy( 
                np.array([2,3,0,2,3,0])).type(torch.LongTensor)