In [51]:
import os
import sys
import torch
from torch import nn
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
from torch.nn import Module
from torch.nn import functional as F
from torch import LongTensor as LT
from torch import FloatTensor as FT
from torch import nn
from sklearn.base import BaseEstimator, ClassifierMixin
from collections import OrderedDict
from pathlib import Path

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device( "cpu")

class proxy_clf_network(Module):
    def __init__(
        self,
        domain_dims: list = [], 
        layer_dims:list  = [32,256,256],
        dropout_prob = 0.1
    ):
        """
        A simple 3 layered neural network, as per the paper
        """
        super(proxy_clf_network, self).__init__()
        self.domain_dims = domain_dims
        self.num_domains = len(domain_dims)
        self.layer_dims = layer_dims
        self.dropout_prob = dropout_prob
        self.__build__()
        return
    
    def __build__(self):
        """
        Build the architecture
        """
        emb_layer_dim = self.layer_dims[0]
        # Create an embedding layer for each domain
        embModule_list = []
        for  dim in self.domain_dims:
            embModule_list.append(nn.Embedding(dim, emb_layer_dim))
        self.embModule_list = nn.ModuleList(embModule_list)
        
        # The outputs should be concatenated
        fcn_layers = []
        dropout_prob = self.dropout_prob
        num_layers = len(self.layer_dims)
        inp_dim = emb_layer_dim * len(self.domain_dims)
        for i in range(1, num_layers):
            op_dim =  self.layer_dims[i]
            fcn_layers.append(nn.Linear(inp_dim,op_dim))
            fcn_layers.append(nn.Dropout(dropout_prob))
            fcn_layers.append(nn.ReLU())
            inp_dim = op_dim
        
        # Last layer for binary output
        fcn_layers.append(nn.Linear(inp_dim, 1))
        fcn_layers.append(nn.Sigmoid())                 
        self.fcn = nn.Sequential(*fcn_layers)
        return 
    
    def forward(self,X):
        """ 
        Input X : has shape [batch, num_domains, 1]
        """
       
        emb = []
        for i in range(self.num_domains):
            r = self.embModule_list[i](X[:,i])
            emb.append(r)
        emb = torch.cat(emb, dim =-1)
        
        x1 = self.fcn(emb)
        return x1


class proxy_clf(ClassifierMixin, BaseEstimator):
    """
    Container for the proxy model 
    """
    def __init__(
        self, 
        model: proxy_clf_network,
        dataset :str = None,
        batch_size: int = 512,
        LR: float = 0.001,
        device = torch.device("cpu")
    ):
        self.model = model
        self.signature = 'proxy_{}'.format(dataset) 
        self.device = device
        self.batch_size = batch_size
        self.LR = LR 
        return
    
    def fit(
        self,
        X : np.array, 
        Y : np.array,
        num_epochs: int = 50,
        log_interval = 100
    ):
        self.model.train()
        self.model.to(self.device)
        bs = self.batch_size
        opt = torch.optim.Adam(list(self.model.parameters()), lr = self.LR)
        num_batches = X.shape[0] // bs + 1
        idx = np.arange(X.shape[0])
        loss_values = []
        clip_value = 5
        # train model 
        for epoch in tqdm(range(num_epochs)):
            np.random.shuffle(idx)
            epoch_loss = []
            for b in range(num_batches):
                opt.zero_grad() 
                b_idx = idx[b*bs:(b+1)*bs]
                x = LT(X[b_idx]).to(self.device) 
                pred_y = self.model(x)
                target_y = FT(Y[b_idx]).to(self.device)
                # Calculate loss
                loss = F.binary_cross_entropy(pred_y, target_y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_value)
                opt.step()
                
                if b % log_interval == 0 :
                    print('[Epoch] {}  | batch {} | Loss {:4f}'.format(epoch, b, loss.cpu().data.numpy()))
                epoch_loss.append(loss.cpu().data.numpy())
            epoch_loss = np.mean(epoch_loss)
            loss_values.append(epoch_loss)
        return  loss_values  
    
    def predict(
        self, 
        X
    ):
        self.model.eval()
        result = []
        with torch.no_grad():
            bs = self.batch_size
            num_batches = X.shape[0] // bs + 1
            idx = np.arange(train_x_pos.shape[0])
            for b in range(num_batches):
                b_idx = idx[b*bs:(b+1)*bs]
                x = LT(X[b_idx]).to(self.device)
                pred_y = self.model(x)
                pred_y = pred_y.cpu().data.numpy()
                result.extend(pred_y)
        return result
    

    def save_model(
        self, 
        loc: str =None
    ):
        """
        Save model 
        """
        if loc is None:
            loc = './saved_models'
        path_obj = Path(loc)
        path_obj.mkdir( parents=True, exist_ok=True )
        loc = os.path.join(loc, self.signature  + '.pth')
        self.save_path = loc
        torch.save(self.model, loc)
        return

    def load_model(
        self, 
        path: str = None
    ):
        """
        Load Model
        """
        if self.save_path is None and path is None:
            print('Error . Null path given to load model ')
            return None
        print('Device', self.device)
        if path is None:
            path = self.save_path 
        
        self.model = torch.load(path)
        self.model.eval()
        
        return
    

In [52]:
with open('./../../GeneratedData/us_import1/domain_dims.pkl','rb') as fh:
    domain_dims = OrderedDict(pickle.load(fh))
df = pd.read_csv('./../../GeneratedData/us_import1/train_data.csv', index_col=None)

try:
    del df['PanjivaRecordID']
except:
    pass

X = df.head(1000).values
Y = np.random.randint(0,2, size=[1000,1])
network = proxy_clf_network(list(domain_dims.values()))

In [53]:
network.num_domains

8

In [56]:
clf_obj = proxy_clf(
    model = network,
    batch_size=512,
    device = DEVICE, 
    
)

In [57]:
clf_obj.fit(X,Y)

 12%|█▏        | 6/50 [00:00<00:01, 27.55it/s]

[Epoch] 0  | batch 0 | Loss 0.080982
[Epoch] 1  | batch 0 | Loss 0.096085
[Epoch] 2  | batch 0 | Loss 0.106306
[Epoch] 3  | batch 0 | Loss 0.087081
[Epoch] 4  | batch 0 | Loss 0.087011
[Epoch] 5  | batch 0 | Loss 0.089132


 24%|██▍       | 12/50 [00:00<00:01, 27.62it/s]

[Epoch] 6  | batch 0 | Loss 0.100844
[Epoch] 7  | batch 0 | Loss 0.095563
[Epoch] 8  | batch 0 | Loss 0.084265
[Epoch] 9  | batch 0 | Loss 0.091712
[Epoch] 10  | batch 0 | Loss 0.079760
[Epoch] 11  | batch 0 | Loss 0.098931


 36%|███▌      | 18/50 [00:00<00:01, 28.08it/s]

[Epoch] 12  | batch 0 | Loss 0.098052
[Epoch] 13  | batch 0 | Loss 0.083551
[Epoch] 14  | batch 0 | Loss 0.099131
[Epoch] 15  | batch 0 | Loss 0.090963
[Epoch] 16  | batch 0 | Loss 0.101459
[Epoch] 17  | batch 0 | Loss 0.082238


 48%|████▊     | 24/50 [00:00<00:00, 27.75it/s]

[Epoch] 18  | batch 0 | Loss 0.086479
[Epoch] 19  | batch 0 | Loss 0.089221
[Epoch] 20  | batch 0 | Loss 0.088265
[Epoch] 21  | batch 0 | Loss 0.094852
[Epoch] 22  | batch 0 | Loss 0.092860
[Epoch] 23  | batch 0 | Loss 0.099797


 60%|██████    | 30/50 [00:01<00:00, 28.32it/s]

[Epoch] 24  | batch 0 | Loss 0.089205
[Epoch] 25  | batch 0 | Loss 0.085608
[Epoch] 26  | batch 0 | Loss 0.100779
[Epoch] 27  | batch 0 | Loss 0.087697
[Epoch] 28  | batch 0 | Loss 0.076064
[Epoch] 29  | batch 0 | Loss 0.083351


 72%|███████▏  | 36/50 [00:01<00:00, 28.52it/s]

[Epoch] 30  | batch 0 | Loss 0.090127
[Epoch] 31  | batch 0 | Loss 0.085253
[Epoch] 32  | batch 0 | Loss 0.084880
[Epoch] 33  | batch 0 | Loss 0.075837
[Epoch] 34  | batch 0 | Loss 0.073307
[Epoch] 35  | batch 0 | Loss 0.073080


 78%|███████▊  | 39/50 [00:01<00:00, 27.19it/s]

[Epoch] 36  | batch 0 | Loss 0.094942
[Epoch] 37  | batch 0 | Loss 0.084815
[Epoch] 38  | batch 0 | Loss 0.098647
[Epoch] 39  | batch 0 | Loss 0.085531
[Epoch] 40  | batch 0 | Loss 0.080960
[Epoch] 41  | batch 0 | Loss 0.087804


 96%|█████████▌| 48/50 [00:01<00:00, 28.04it/s]

[Epoch] 42  | batch 0 | Loss 0.099845
[Epoch] 43  | batch 0 | Loss 0.101130
[Epoch] 44  | batch 0 | Loss 0.090926
[Epoch] 45  | batch 0 | Loss 0.109143
[Epoch] 46  | batch 0 | Loss 0.099153
[Epoch] 47  | batch 0 | Loss 0.094416


100%|██████████| 50/50 [00:01<00:00, 27.86it/s]

[Epoch] 48  | batch 0 | Loss 0.089763
[Epoch] 49  | batch 0 | Loss 0.079244





[0.118137464,
 0.10159801,
 0.09743577,
 0.096901715,
 0.10228263,
 0.093336955,
 0.09660366,
 0.094229355,
 0.09220202,
 0.09332855,
 0.08989188,
 0.09619245,
 0.09572525,
 0.08973353,
 0.094560504,
 0.08983621,
 0.093169905,
 0.08906898,
 0.08936441,
 0.08913475,
 0.09161727,
 0.089209974,
 0.09191942,
 0.08674375,
 0.09162404,
 0.08950473,
 0.09042431,
 0.08846916,
 0.088125736,
 0.08903943,
 0.08640144,
 0.087355465,
 0.08899748,
 0.08839207,
 0.08808977,
 0.08728139,
 0.08798537,
 0.08679101,
 0.088233605,
 0.090156645,
 0.08692327,
 0.08809258,
 0.08893137,
 0.08816779,
 0.086825594,
 0.09150988,
 0.08748473,
 0.08837591,
 0.08923798,
 0.0842368]