# Defining and Training Models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from numpy import log, sqrt, log2, ceil, exp

## Load Data

In [None]:
# Load the pickled training dataset and vendor tensor from the working directory.
# NOTE: unpickling can execute arbitrary code -- only load files you produced yourself.
def _load_pickle(path):
    """Open *path* in binary mode and return the unpickled object."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


train_sequences_padded_dataset = _load_pickle("train_sequences_padded_dataset.pkl")
vendors_tensor = _load_pickle("vendors_tensor.pkl")

In [None]:
# Wrap the training dataset in a shuffling, mini-batched DataLoader.
batch_size = 512
num_workers = 0  # 0 => batches are loaded in the main process
train_loader = torch.utils.data.DataLoader(
    dataset=train_sequences_padded_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

## Define Model

### Model 1
<p align="center">
  <img src="Recommender1.png" width="1000"/>
</p>

In [None]:
# Hard-coded column boundaries used to split each vendor feature row into chunks.
# Every *_idx_hi is an exclusive bound: a chunk spans [previous hi, this hi).
cont_idx_lo, cont_idx_hi, misc_idx_hi, ptag_idx_hi, vtag_idx_hi = (
    0,    # first (cont) chunk starts at column 0
    8,    # cont chunk: up to avg_sale_log
    12,   # misc chunk: up to rank
    55,   # primary_tags chunk: up to primary_tags_is_42
    123,  # vendor_tag chunk: up to vendor_tag_is_67
)

In [None]:
# vtag embed size, computed exactly as Model1 does: ceil(log2(d_vtag)),
# where the vendor_tag chunk is vtag_idx_hi - ptag_idx_hi columns wide.
print(ceil(log2(vtag_idx_hi - ptag_idx_hi)))

In [None]:
# ptag embed size, computed exactly as Model1 does: ceil(log2(d_ptag)),
# where the primary_tags chunk is ptag_idx_hi - misc_idx_hi columns wide.
print(ceil(log2(ptag_idx_hi - misc_idx_hi)))

In [None]:
# Final embed size (input width of Model1.fc1): customer and vendor embeddings,
# each ceil(log2(d_cont + d_misc + d_emb_ptag + d_emb_vtag)) wide.
print(2 * ceil(log2(
    cont_idx_hi                                 # d_cont
    + (misc_idx_hi - cont_idx_hi)               # d_misc
    + ceil(log2(ptag_idx_hi - misc_idx_hi))     # d_emb_ptag
    + ceil(log2(vtag_idx_hi - ptag_idx_hi))     # d_emb_vtag
)))

In [None]:
class Model1(nn.Module):
    """Score how well a vendor matches a customer's vendor history.

    Each vendor is a fixed row of ``vendors`` (looked up through a frozen
    embedding table). A customer is the column-wise sum of the rows of the
    vendors in their history. Both representations are split into four column
    chunks -- ``[0, cont_idx_hi)``, ``[cont_idx_hi, misc_idx_hi)``,
    ``[misc_idx_hi, ptag_idx_hi)`` (primary_tags) and
    ``[ptag_idx_hi, vtag_idx_hi)`` (vendor_tag). The two tag chunks are
    compressed by linear+ELU layers, the chunks are re-concatenated and
    embedded, and the customer/vendor embeddings are fed through a 4-layer
    MLP that outputs one raw (un-activated) score.

    Args:
        vendors: 2-D float tensor, one feature row per vendor id. It must have
            at least ``vtag_idx_hi`` columns; columns past it are ignored.
            The table is not trained.
        cont_idx_hi: exclusive end of the first chunk (which starts at column 0).
        misc_idx_hi: exclusive end of the second chunk.
        ptag_idx_hi: exclusive end of the primary_tags chunk.
        vtag_idx_hi: exclusive end of the vendor_tag chunk.
        d_fc: width of the first dense layer; the next two use d_fc // 2 and
            d_fc // 4.
    """

    def __init__(self, vendors, cont_idx_hi, misc_idx_hi, ptag_idx_hi, vtag_idx_hi, d_fc):
        super().__init__()

        # Frozen vendor-id -> feature-row lookup; the table is never trained.
        self.vendor_lookup = nn.Embedding.from_pretrained(vendors)
        self.vendor_lookup.weight.requires_grad = False

        # Exclusive column boundaries used to slice inputs in forward().
        self.cont_idx_hi = cont_idx_hi
        self.misc_idx_hi = misc_idx_hi
        self.ptag_idx_hi = ptag_idx_hi
        self.vtag_idx_hi = vtag_idx_hi

        # Widths of the four chunks.
        d_cont = cont_idx_hi
        d_misc = misc_idx_hi - cont_idx_hi
        d_ptag = ptag_idx_hi - misc_idx_hi
        d_vtag = vtag_idx_hi - ptag_idx_hi

        # Embedding widths follow a ceil(log2(width)) rule. numpy's ceil returns
        # a float, so every size is cast to int before it reaches nn.Linear.
        d_emb_ptag = int(ceil(log2(d_ptag)))
        d_emb_vtag = int(ceil(log2(d_vtag)))
        d_concat = d_cont + d_misc + d_emb_ptag + d_emb_vtag
        d_emb = int(ceil(log2(d_concat)))

        # primary_tags embeddings (separate weights for customer and vendor side)
        self.c_emb_ptag = nn.Linear(d_ptag, d_emb_ptag)
        self.v_emb_ptag = nn.Linear(d_ptag, d_emb_ptag)

        # vendor_tag embeddings
        self.c_emb_vtag = nn.Linear(d_vtag, d_emb_vtag)
        self.v_emb_vtag = nn.Linear(d_vtag, d_emb_vtag)

        # customer and vendor embeddings
        self.c_emb = nn.Linear(d_concat, d_emb)
        self.v_emb = nn.Linear(d_concat, d_emb)

        # dense layers: concatenated embeddings -> single raw score
        self.fc1 = nn.Linear(2 * d_emb, d_fc)
        self.fc2 = nn.Linear(d_fc, d_fc // 2)
        self.fc3 = nn.Linear(d_fc // 2, d_fc // 4)
        self.fc4 = nn.Linear(d_fc // 4, 1)

    def _split(self, x):
        """Slice a (batch, n_features) tensor into its (cont, misc, ptag, vtag) chunks."""
        return (
            x[:, : self.cont_idx_hi],
            x[:, self.cont_idx_hi : self.misc_idx_hi],
            x[:, self.misc_idx_hi : self.ptag_idx_hi],
            # Bounded by vtag_idx_hi so trailing columns of the vendor table
            # cannot change the width the vtag layers were built for.
            x[:, self.ptag_idx_hi : self.vtag_idx_hi],
        )

    def forward(self, c_seq, v_id):
        """Return raw match scores of shape (batch, 1).

        Args:
            c_seq: LongTensor of vendor ids, (batch, seq_len); the customer
                representation is the sum over the sequence axis (dim 1).
                NOTE(review): if sequences are padded with a real vendor id,
                the padding rows are summed in too -- confirm the padding id
                and whether it needs masking.
            v_id: LongTensor of vendor ids, (batch,).
        """
        vendor = self.vendor_lookup(v_id)
        # (batch, seq_len, n_features) -> (batch, n_features)
        customer = self.vendor_lookup(c_seq).sum(dim=1)

        c_cont, c_misc, c_ptag, c_vtag = self._split(customer)
        v_cont, v_misc, v_ptag, v_vtag = self._split(vendor)

        # compress the tag chunks
        c_ptag = F.elu(self.c_emb_ptag(c_ptag))
        v_ptag = F.elu(self.v_emb_ptag(v_ptag))
        c_vtag = F.elu(self.c_emb_vtag(c_vtag))
        v_vtag = F.elu(self.v_emb_vtag(v_vtag))

        # embed customer and vendor
        customer = F.elu(self.c_emb(torch.cat((c_cont, c_misc, c_ptag, c_vtag), dim=1)))
        vendor = F.elu(self.v_emb(torch.cat((v_cont, v_misc, v_ptag, v_vtag), dim=1)))

        # classifier head
        out = torch.cat((customer, vendor), dim=1)
        out = F.elu(self.fc1(out))
        out = F.elu(self.fc2(out))
        out = F.elu(self.fc3(out))
        return self.fc4(out)  # raw score, no final activation

# Instantiate Model 1 over the vendor feature table with the hard-coded column
# boundaries and a 64-unit first dense layer.
model1 = Model1(
    vendors=vendors_tensor,
    cont_idx_hi=cont_idx_hi,
    misc_idx_hi=misc_idx_hi,
    ptag_idx_hi=ptag_idx_hi,
    vtag_idx_hi=vtag_idx_hi,
    d_fc=64,
)

### Loss Function

In [None]:
# Define ranking loss

def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + e^-x).

    Torch tensors go through ``torch.sigmoid`` so the result stays on the
    autograd graph (NumPy's ``exp`` cannot consume a tensor that requires
    grad). Any other input (Python/NumPy scalars or arrays) uses NumPy, as before.
    """
    if isinstance(x, torch.Tensor):
        return torch.sigmoid(x)
    return 1 / (1 + exp(-x))


def ranking_loss(pos_pred, neg_pred):
    """Pairwise ranking loss ``1 - sigmoid(pos_pred - neg_pred)``, element-wise.

    The loss approaches 0 when the positive scores well above the negative,
    and approaches 1 when it scores well below. One value is returned per
    pair; reduce it (e.g. ``.mean()``) before calling ``backward``.
    """
    return 1 - sigmoid(pos_pred - neg_pred)

### Optimizer

In [None]:
# Adam with PyTorch's default hyper-parameters over every parameter of model1.
optimizer = torch.optim.Adam(params=model1.parameters())

## Training

In [None]:
# Define the training process for Model 1.
# Pairwise ranking: score each observed (customer history, vendor) pair against a
# uniformly sampled vendor and push the observed pair's score above it.
# NOTE(review): assumes each batch from train_loader unpacks as (c_seq, v_id)
# LongTensors, and that every row index of vendors_tensor is a valid negative
# (e.g. no reserved padding id) -- TODO confirm against the dataset.

epochs = 100
print_every = 10
n_vendors = vendors_tensor.shape[0]

model1.train()
for epoch in range(epochs):
    running_loss = 0.0
    for c_seq, v_id in train_loader:
        pos_pred = model1(c_seq, v_id)
        # Random negatives; these may occasionally coincide with the positive vendor.
        neg_id = torch.randint(0, n_vendors, v_id.shape, device=v_id.device)
        neg_pred = model1(c_seq, neg_id)
        # 1 - sigmoid(pos - neg), computed with torch so gradients flow.
        batch_loss = (1 - torch.sigmoid(pos_pred - neg_pred)).mean()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    if epoch % print_every == 0:
        print(f'Epoch [{epoch}/{epochs}]: sum(batch_losses) = {running_loss:.4f}')

def train(model, dataloader, loss, optimizer, epochs: int = 100, print_every: int = 10):
    """Train ``model`` with a point-wise loss for ``epochs`` passes over ``dataloader``.

    Args:
        model: ``nn.Module`` called as ``model(X)``; it is put in training mode.
        dataloader: iterable yielding ``(X, y)`` batches.
        loss: callable ``loss(y_score, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``dataloader`` (0 is allowed).
        print_every: print the summed batch loss every this many epochs.
    """
    # Defined up front so the final summary line works even when epochs == 0.
    running_loss = 0.0
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for X, y in dataloader:
            # Clear stale gradients before this batch's backward pass.
            optimizer.zero_grad()
            y_score = model(X)  # __call__ rather than .forward so module hooks run
            pred_loss = loss(y_score, y)
            pred_loss.backward()
            optimizer.step()
            running_loss += pred_loss.item()
        if epoch % print_every == 0:
            print(f'Epoch [{epoch}/{epochs}]: sum(batch_losses) = {running_loss:.4f}')
    print(f'Epoch [{epochs}/{epochs}]: sum(batch_losses) = {running_loss:.4f}')
    print('Done!')