In [None]:
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed PyTorch's global random number generator so that anything drawn from it
# later in the notebook (e.g. random parameter initialisation) is repeatable.
torch.manual_seed(1)

# Research Question 3

## Joint inference across both models

In [None]:
# Pair each entry of `diff_model_one_hot` with the entry of `reasons` at the
# same position, giving one (covariates, reason) tuple per instance.
# NOTE: zip() stops at the shorter input, so a length mismatch between the two
# sequences would silently drop trailing entries.
data = list(zip(diff_model_one_hot, reasons))

# Show only a short preview instead of dumping the whole dataset into the
# cell output.
data[:5]

In [None]:
class ChoiceReasonTuplePredictor(nn.Module):
    """Predict a choice probability and a reason distribution from shared logits.

    A single bias-free linear layer maps the covariates to one logit per
    aspect. Both outputs are derived from those same logits:

    * choice head: sigmoid of the sum of all logits;
    * reason head: log-softmax over the logits.

    Args:
        num_terms: number of input features of the linear layer.
        num_aspects: number of output features (logits) of the linear layer.
    """

    def __init__(self, num_terms, num_aspects):
        super().__init__()

        # No bias term: the logits are a pure linear function of the covariates.
        self.linear = nn.Linear(num_terms, num_aspects, bias=False)

    def forward(self, covariates):
        """Run both heads on ``covariates``.

        Args:
            covariates: float tensor whose last dimension has ``num_terms``
                entries.

        Returns:
            A ``(choice_prob, reason_log_prob)`` tuple where ``choice_prob`` is
            a 0-d tensor holding ``sigmoid(sum(logits))`` and
            ``reason_log_prob`` holds the log-softmax of the logits along the
            last dimension.

        Note:
            The sum runs over *all* elements of the logits, so a batched input
            would still produce a single probability for the whole batch.
        """
        # Evaluate the shared linear layer once and reuse it for both heads.
        aspect_logits = self.linear(covariates)

        choice_prob = torch.sigmoid(torch.sum(aspect_logits))
        # Pass `dim` explicitly: the implicit-dimension form is deprecated and
        # emits a UserWarning on every call.
        reason_log_prob = F.log_softmax(aspect_logits, dim=-1)
        return choice_prob, reason_log_prob

# --- Training configuration --------------------------------------------------
NUM_EPOCHS = 100
LEARNING_RATE = 0.1
# Relative weights of the two loss terms; the total loss is their weighted mean.
CHOICE_LOSS_WEIGHT = 1
REASON_LOSS_WEIGHT = 10

model = ChoiceReasonTuplePredictor(num_terms=2, num_aspects=6)

# Not used by the loop below (it calls F.nll_loss directly); kept so that any
# other cell referring to `loss_function` keeps working.
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

# Convert the numpy covariates and the reason indices to tensors once, instead
# of repeating the conversion for every instance in every epoch.
training_pairs = [
    (torch.from_numpy(covariates).float(), torch.tensor(reason))
    for covariates, reason in data
]

# Per-epoch mean losses (averaged over the instances in `data`).
losses_total = []
losses_choice = []
losses_reason = []
for epoch in range(NUM_EPOCHS):
    running_loss_choice = 0
    running_loss_reason = 0
    running_loss_total = 0
    # One SGD step per instance.
    for covariates_tensor, reason_target in training_pairs:
        model.zero_grad()
        # forward pass
        choice_prob, reason_log_prob = model(covariates_tensor)

        # The choice target is always 1: every instance is scored against an
        # all-ones target of the same shape as the predicted probability.
        choice_loss = F.binary_cross_entropy(choice_prob, torch.ones_like(choice_prob))
        reason_loss = F.nll_loss(reason_log_prob, reason_target)
        loss_total = (
            CHOICE_LOSS_WEIGHT * choice_loss + REASON_LOSS_WEIGHT * reason_loss
        ) / (CHOICE_LOSS_WEIGHT + REASON_LOSS_WEIGHT)
        loss_total.backward()

        running_loss_choice += choice_loss.item()
        running_loss_reason += reason_loss.item()
        running_loss_total += loss_total.item()
        optimizer.step()
    losses_total.append(running_loss_total / len(data))
    losses_choice.append(running_loss_choice / len(data))
    losses_reason.append(running_loss_reason / len(data))

# Plot the three per-epoch loss curves on one labelled set of axes.
fig, ax = plt.subplots()
ax.plot(losses_total, label='total')
ax.plot(losses_choice, label='choice')
ax.plot(losses_reason, label='reason')
ax.set_xlabel('epoch')
ax.set_ylabel('mean loss per instance')
ax.set_title('Training loss per epoch')
ax.legend()
# Render the figure explicitly so the cell does not echo the Legend repr.
plt.show()

