<a href="https://colab.research.google.com/github/kryuchkovdm/Distillation/blob/master/models/CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
class CNN(nn.Module):
    """Convolutional text classifier over token embeddings.

    Pipeline: embed token ids, run one 2-D convolution per entry of
    ``filter_sizes`` (each spanning ``filter_size`` tokens and the full
    embedding width), apply ReLU, max-pool each feature map over the whole
    sequence, concatenate the pooled features, apply dropout, and project to
    ``n_labels`` logits with a linear layer.

    Args:
        n_vocab: Number of rows in the embedding table.
        n_labels: Number of output logits.
        embedding_dim: Embedding width (and width of every conv kernel).
        n_filters: Output channels per convolution branch.
        filter_sizes: Kernel heights (in tokens), one conv branch per entry.
            Defaults to ``[3, 4, 5]``. Must be non-empty.
        dropout: Dropout probability applied to the pooled features.
        special_chars: Token indices whose embedding rows are set to zero at
            construction time. Defaults to no indices.
        pretrained_embeddings: Optional ``(n_vocab, embedding_dim)`` tensor
            used to initialise the embedding table via
            ``nn.Embedding.from_pretrained`` (which freezes it by default).
            The tensor is copied, so the caller's tensor is never modified.

    Raises:
        ValueError: If ``filter_sizes`` is empty.
        AssertionError: If ``pretrained_embeddings`` does not match
            ``(n_vocab, embedding_dim)``.
    """

    def __init__(self,
                 n_vocab,
                 n_labels,
                 embedding_dim=50,
                 n_filters=100,
                 filter_sizes=None,
                 dropout=0.5,
                 special_chars=None,
                 pretrained_embeddings=None):
        super(CNN, self).__init__()
        # None sentinels instead of list literals: default lists would be
        # shared by every instance created without these arguments.
        if filter_sizes is None:
            filter_sizes = [3, 4, 5]
        if special_chars is None:
            special_chars = []
        if len(filter_sizes) == 0:
            raise ValueError("filter_sizes must contain at least one kernel height")

        self.n_vocab = n_vocab
        self.n_labels = n_labels
        self.embedding_dim = embedding_dim
        self.n_filters = n_filters
        self.filter_sizes = filter_sizes
        self.dropout_p = dropout
        # Size of the concatenated pooled features fed into the classifier.
        self.width = len(filter_sizes) * n_filters

        if pretrained_embeddings is not None:
            assert n_vocab == pretrained_embeddings.shape[0]
            assert embedding_dim == pretrained_embeddings.shape[1]
            # from_pretrained wraps the given tensor without copying; clone so
            # zeroing special rows below cannot mutate the caller's tensor.
            self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings.clone())
        else:
            self.embedding = nn.Embedding(n_vocab, embedding_dim)

        # One conv branch per kernel height, registered as conv0, conv1, ...
        # so state_dict keys (and parameter-init order) match the original
        # three-branch layout for the default filter sizes.
        for i, size in enumerate(filter_sizes):
            setattr(self, f"conv{i}", nn.Conv2d(in_channels=1,
                                                out_channels=n_filters,
                                                kernel_size=(size, embedding_dim)))
        self.dropout = nn.Dropout(dropout)

        self.fc = nn.Linear(in_features=self.width, out_features=n_labels)

        # Zero the requested rows without recording the write in autograd.
        with torch.no_grad():
            for special in special_chars:
                self.embedding.weight[special] = 0

    def forward(self, input_ids, **kwargs):
        """Return label logits for a batch of token ids.

        Only ``input_ids`` is used; extra keyword arguments (e.g. an
        ``attention_mask``) are accepted and ignored so the model can be
        called with the same arguments as a BERT-style model.

        Args:
            input_ids: Integer token ids, presumably shaped
                ``(batch, seq_len)``. ``seq_len`` must be at least
                ``max(filter_sizes)``, or the convolution raises.

        Returns:
            Tensor of shape ``(batch, n_labels)``.
        """
        # (batch, seq, emb) -> (batch, 1, seq, emb): a single input channel.
        X = self.embedding(input_ids).unsqueeze(1)
        pooled = []
        for i in range(len(self.filter_sizes)):
            conv = getattr(self, f"conv{i}")
            # Kernel spans the full embedding width, so dim 3 has size 1.
            feature_map = F.relu(conv(X).squeeze(3))
            # Max over the whole remaining time axis -> (batch, n_filters).
            pooled.append(F.max_pool1d(feature_map, feature_map.shape[2]).squeeze(2))
        X = torch.cat(pooled, dim=1)
        X = self.dropout(X)
        X = self.fc(X)
        return X