# SM Model PyTorch Walkthrough

The purpose of this notebook is to show new PyTorch users how to implement the SM Model with PyTorch. Here are the recommended prerequisites before reading this walkthrough:

* Be familiar with Convolutional Neural Networks. If you are not, these slides are a helpful introduction: https://cs.uwaterloo.ca/~mli/Deep-Learning-2017-Lecture5CNN.ppt
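* Read the SM Model paper (Severyn & Moschitti, "Learning to Rank Short Text Pairs with Convolutional Deep Neural Networks", SIGIR 2015): http://dl.acm.org/citation.cfm?id=2767738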

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class QAModel(nn.Module):
    """Convolutional model that scores a (question, answer) pair.

    Each side is encoded by its own Conv1d + tanh branch followed by a
    max-pool over the whole sequence length. The two pooled vectors
    (optionally concatenated with external features) go through a
    linear + tanh layer, dropout, and a final linear layer. The output is
    log-probabilities over ``n_classes``.

    Args:
        input_n_dim: Number of input channels per position; this is the
            ``in_channels`` of both Conv1d layers (presumably the
            word-embedding size).
        filter_width: Convolution kernel width. Inputs are padded by
            ``filter_width - 1`` on each side, so a length-L input yields
            ``L + filter_width - 1`` outputs ("wide" convolution).
        conv_filters: Number of convolution filters in each branch.
        no_ext_feats: If True, external features are not used.
        ext_feats_size: Width of the external-feature vector when used.
        n_classes: Number of output classes.
        dropout_p: Dropout probability applied before the output layer.
    """

    def __init__(self, input_n_dim, filter_width, conv_filters=100,
                 no_ext_feats=False, ext_feats_size=4, n_classes=2,
                 dropout_p=0.5):
        super(QAModel, self).__init__()

        self.no_ext_feats = no_ext_feats
        self.conv_channels = conv_filters

        # Width of the joined vector: pooled question + pooled answer
        # (+ external features when enabled).
        n_hidden = 2 * self.conv_channels + (0 if no_ext_feats else ext_feats_size)

        # Attribute names and the Sequential layout determine state_dict
        # keys, so they are kept unchanged.
        self.conv_q = nn.Sequential(
            nn.Conv1d(input_n_dim, self.conv_channels, filter_width,
                      padding=filter_width - 1),
            nn.Tanh(),
        )
        self.conv_a = nn.Sequential(
            nn.Conv1d(input_n_dim, self.conv_channels, filter_width,
                      padding=filter_width - 1),
            nn.Tanh(),
        )

        self.combined_feature_vector = nn.Linear(n_hidden, n_hidden)
        self.combined_features_activation = nn.Tanh()
        self.dropout = nn.Dropout(dropout_p)
        # Despite its name, this is the output projection to class scores.
        self.hidden = nn.Linear(n_hidden, n_classes)
        # Explicit dim=1 normalizes over classes of a (batch, n_classes)
        # tensor; constructing LogSoftmax without dim is deprecated.
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, question, answer, ext_feats=None):
        """Return class log-probabilities of shape (batch, n_classes).

        Args:
            question: Tensor of shape (batch, input_n_dim, q_len).
            answer: Tensor of shape (batch, input_n_dim, a_len).
            ext_feats: Tensor of shape (batch, ext_feats_size). Required
                unless the model was built with ``no_ext_feats=True``, in
                which case it is ignored.

        Raises:
            TypeError: If external features are required but ``ext_feats``
                is None.
        """
        # Call submodules directly (not .forward) so registered hooks run.
        q = self.conv_q(question)
        # Pool with a kernel spanning the full length: (batch, C, 1) -> (batch, C).
        q = F.max_pool1d(q, q.size(2)).view(-1, self.conv_channels)

        a = self.conv_a(answer)
        a = F.max_pool1d(a, a.size(2)).view(-1, self.conv_channels)

        if self.no_ext_feats:
            x = torch.cat([q, a], 1)
        else:
            if ext_feats is None:
                raise TypeError("ext_feats is required unless no_ext_feats=True")
            x = torch.cat([q, a, ext_feats], 1)

        x = self.combined_features_activation(self.combined_feature_vector(x))
        x = self.dropout(x)
        x = self.hidden(x)
        return self.logsoftmax(x)