In [2]:
import pandas as pd
import numpy as np
from torch import nn
import torch
from torch import tensor
 
import multiprocessing
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import RandomSampler, SequentialSampler
import yaml
import pickle
import torch
import torch.nn as nn


In [105]:
class clf_dfm(nn.Module):
    
    def __init__(
        self,
        num_domains,
        inp_emb_dimensions,
        fm_inp_dim = 64,
        dnn_layer_dimensions =[512,128]
    ):
        super(clf_dfm, self).__init__()
        self.dnn_fc = nn.ModuleList()
        dnn_num_layers = len(dnn_layer_dimensions)
        self.dnn_num_layers = dnn_num_layers
        for i in range(dnn_num_layers):    
            if i == 0: inp_dim = inp_emb_dimensions * num_domains
            if i == dnn_num_layers-1:
                op_dim = 1 
            else:
                op_dim = dnn_layer_dimensions[i] 
            self.dnn_fc.append( nn.Linear(inp_dim, op_dim ))
            inp_dim = op_dim
            
        self.fc2 = nn.Linear(2,2,bias=False)    
        self.activation_1 = torch.nn.Tanh()
        self.softmax = torch.nn.Softmax()
        self.num_domains = num_domains
        self.xform_fm_1 = nn.ModuleList([nn.Linear(inp_emb_dimensions,fm_inp_dim) for _ in range(num_domains)])
        print(self.xform_fm_1 )
        return
    
    
    def forward(self, input_x):
        '''
        input x should be of shape [Batch, num_domains, emb_vec]
        '''
        
        # ----- DNN ------- #
        x_dnn = input_x.view(-1,12)
        
        for i in range(self.dnn_num_layers):    
            x_dnn = self.dnn_fc[i](x_dnn)
        print(x_dnn.shape)
        # ----- FM --------- #
        
        fm_input = input_x
        # transform inner products with a MLP
        fm_input = torch.chunk(
            fm_input,
            self.num_domains,
            dim=1
        )
        print(fm_input[0].shape, len(fm_input))
        x_fm_input = []
        for i in range(self.num_domains):
            x_fm_input.append( self.xform_fm_1[i](fm_input[i]).squeeze(1) )
        x_fm_input = torch.stack(
             x_fm_input, dim=2   
        )    
        print(x_fm_input.shape)      
        
        square_of_sum = torch.pow(torch.sum(x_fm_input, dim=1, keepdim=True), 2)
        sum_of_square = torch.sum(x_fm_input * x_fm_input, dim=1, keepdim=True)
        cross_term = square_of_sum - sum_of_square
        cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False)
        print(cross_term.shape)
        x2 = torch.cat ([x_dnn, cross_term],dim = 1)
        op = self.fc2(x2)
        print(op.shape)
        op = self.activation_1(op)
        op = self.softmax (op)
        return op
    
        

In [106]:
model =  clf_dfm(
    num_domains = 3,
    inp_emb_dimensions = 4,
    fm_inp_dim= 8,
    dnn_layer_dimensions = [5,6]
)

ModuleList(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): Linear(in_features=4, out_features=8, bias=True)
  (2): Linear(in_features=4, out_features=8, bias=True)
)


In [107]:
x = torch.FloatTensor(np.random.random([10,3,4]))

In [108]:
model(x)

torch.Size([10, 1])
torch.Size([10, 1, 4]) 3
torch.Size([10, 8, 3])
torch.Size([10, 1])
torch.Size([10, 2])




tensor([[0.3853, 0.6147],
        [0.5029, 0.4971],
        [0.3913, 0.6087],
        [0.6236, 0.3764],
        [0.4327, 0.5673],
        [0.4662, 0.5338],
        [0.4908, 0.5092],
        [0.4709, 0.5291],
        [0.3612, 0.6388],
        [0.4063, 0.5937]], grad_fn=<SoftmaxBackward>)

In [63]:
x1 = np.random.random([1,3,4])
fm_input = x1

square_of_sum = np.power(
    np.sum(fm_input, axis=1, keepdims=True), 
    2
)
sum_of_square = np.sum(fm_input * fm_input, axis=1, keepdims=True)
cross_term = square_of_sum - sum_of_square

cross_term = 0.5 * np.sum(cross_term, axis=2, keepdims=False)


In [64]:
fm_input

array([[[0.19505893, 0.19231236, 0.94403774, 0.89043176],
        [0.43569211, 0.25410309, 0.2999439 , 0.69454437],
        [0.21133718, 0.69357648, 0.39821617, 0.74579648]]])

In [65]:
cross_term

array([[3.15582316]])

In [66]:
sum_of_square

array([[[0.27253901, 0.58260076, 1.13974973, 1.83147299]]])

In [71]:
x1[0][0]

array([0.19505893, 0.19231236, 0.94403774, 0.89043176])

In [72]:
r = 0

for i in range(3):
    for j in range(i+1,3):
        r += np.dot(x1[0][i],x1[0][j])
    

In [74]:
r

3.1558231612885437