In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TextCNN(nn.Module):
    """Convolutional text classifier.

    Token ids are embedded, passed through parallel 1-D convolutions with
    different kernel widths, max-pooled over the time axis, concatenated,
    and mapped to class logits by a single linear layer.

    Args:
        vocab_size: Number of rows of a freshly initialised embedding table.
            Ignored when ``embedding_matrix`` is given.
        embed_size: Width of a freshly initialised embedding table. Ignored
            when ``embedding_matrix`` is given; the matrix's column count is
            used instead.
        num_filters: Output channels of each convolution.
        filter_sizes: Kernel widths, one convolution per entry. Must be
            non-empty.
        num_classes: Number of output logits.
        embedding_matrix: Optional pretrained 2-D weight table. Tensors are
            used as-is; other array-likes (e.g. NumPy arrays, nested lists)
            are converted to ``torch.get_default_dtype()``. The embedding is
            trainable (``freeze=False``).

    Raises:
        ValueError: If ``filter_sizes`` is empty.
    """

    def __init__(self, vocab_size, embed_size, num_filters, filter_sizes, num_classes, embedding_matrix=None):
        super().__init__()

        # Materialise once so generators work and len() below is valid.
        filter_sizes = tuple(filter_sizes)
        if not filter_sizes:
            raise ValueError("filter_sizes must contain at least one kernel size")

        # Embedding layer
        if embedding_matrix is None:
            self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        else:
            if not torch.is_tensor(embedding_matrix):
                # from_pretrained needs a tensor; use the default float dtype
                # so the weights match the (default-dtype) conv layers.
                embedding_matrix = torch.as_tensor(embedding_matrix, dtype=torch.get_default_dtype())
            self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)

        # The conv input width must equal the *actual* embedding width, which
        # comes from the pretrained matrix when one is supplied.
        embed_dim = self.embedding.embedding_dim

        # Convolutional layers: one per kernel width, so each extracts
        # features over a different window size.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])

        # Conv1d cannot run on inputs shorter than its kernel, so forward()
        # right-pads shorter sequences up to this length.
        self._min_seq_len = max(filter_sizes)

        # Fully connected layer over the concatenated pooled features.
        self.fc = nn.Linear(len(filter_sizes) * num_filters, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Token ids of shape ``(batch_size, seq_len)``.

        Returns:
            Logits of shape ``(batch_size, num_classes)``.
        """
        # (batch_size, seq_len) -> (batch_size, seq_len, embed_dim)
        x = self.embedding(x)

        # Conv1d expects channels first: (batch_size, embed_dim, seq_len)
        x = x.permute(0, 2, 1)

        # Zero-pad on the right so the widest kernel fits; inputs that are
        # already long enough are left untouched.
        shortfall = self._min_seq_len - x.size(2)
        if shortfall > 0:
            x = F.pad(x, (0, shortfall))

        # Conv + ReLU, then max over the time axis: (batch_size, num_filters) each.
        pooled_outputs = [F.relu(conv(x)).max(dim=2).values for conv in self.convs]

        # Concatenate along the feature axis: (batch_size, len(filter_sizes) * num_filters)
        x = torch.cat(pooled_outputs, dim=1)

        # Fully connected layer -> logits
        return self.fc(x)
