<a href="https://colab.research.google.com/github/day253/labs/blob/master/labs/text_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class TextCNN(nn.Module):
    """Multichannel CNN for sentence classification (Kim, 2014 style).

    Every token is looked up in two embedding tables: a trainable one and a
    frozen ("static") one. The two lookups are stacked as 2 input channels and
    fed to parallel 2-D convolutions of different widths. The outputs go
    through ReLU, max-over-time pooling, dropout and a final linear layer.

    Args:
        vocab_size: number of rows in each embedding table.
        embedding_dim: size of each embedding vector.
        num_filters: output channels of every convolution.
        filter_sizes: convolution widths in tokens (one convolution per entry).
        num_classes: size of the output logits.
        dropout_rate: dropout probability applied to the pooled features.
        pretrained_embeddings: optional ``(vocab_size, embedding_dim)`` tensor
            (or array-like) used to initialise BOTH embedding tables. When
            ``None`` (the default), both tables keep ``nn.Embedding``'s own
            random initialisation.

    Raises:
        ValueError: if ``filter_sizes`` is empty, or if
            ``pretrained_embeddings`` does not have shape
            ``(vocab_size, embedding_dim)``.
    """

    def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, num_classes, dropout_rate,
                 pretrained_embeddings=None):
        super().__init__()

        filter_sizes = list(filter_sizes)
        if not filter_sizes:
            raise ValueError("filter_sizes must contain at least one filter width")
        # Shortest sequence every convolution can slide over; forward() pads
        # shorter inputs up to this length instead of crashing in Conv2d.
        self.max_filter_size = max(filter_sizes)

        # Trainable (non-static) word embeddings.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Static channel: same lookup, but its weights receive no gradient.
        self.static_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.static_embedding.weight.requires_grad = False

        if pretrained_embeddings is not None:
            pretrained_embeddings = torch.as_tensor(pretrained_embeddings)
            expected_shape = (vocab_size, embedding_dim)
            if tuple(pretrained_embeddings.shape) != expected_shape:
                raise ValueError(
                    f"pretrained_embeddings must have shape {expected_shape}, "
                    f"got {tuple(pretrained_embeddings.shape)}"
                )
            # copy_ gives each table its own storage, so fine-tuning the
            # non-static table never leaks into the static one.
            with torch.no_grad():
                self.embedding.weight.copy_(pretrained_embeddings)
                self.static_embedding.weight.copy_(pretrained_embeddings)

        # One convolution per filter width. in_channels=2 because forward()
        # stacks the non-static and static embeddings as two channels. Each
        # kernel spans the full embedding dimension, so the output width is 1.
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=2, out_channels=num_filters, kernel_size=(fs, embedding_dim))
            for fs in filter_sizes
        ])

        # Fully connected classifier over the concatenated pooled features.
        self.fc = nn.Linear(len(filter_sizes) * num_filters, num_classes)

        # Dropout applied to the pooled features before the classifier.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Compute class logits for a batch of token-index sequences.

        Args:
            x: integer token-index tensor, presumably ``(batch_size, seq_len)``
               (the shape comments below assume this; confirm with callers).

        Returns:
            Logits tensor of shape ``(batch_size, num_classes)``.
        """
        # Embedding lookups, each given a channel axis.
        embedded = self.embedding(x).unsqueeze(1)  # (batch_size, 1, seq_len, embedding_dim)
        static_embedded = self.static_embedding(x).unsqueeze(1)  # (batch_size, 1, seq_len, embedding_dim)

        # Stack the non-static and static channels.
        combined_embedded = torch.cat([embedded, static_embedded], dim=1)  # (batch_size, 2, seq_len, embedding_dim)

        # Zero-pad the sequence axis (dim 2) so the widest kernel fits.
        # F.pad's pad tuple starts from the LAST dim:
        # (emb_left, emb_right, seq_front, seq_back).
        seq_len = combined_embedded.size(2)
        if seq_len < self.max_filter_size:
            combined_embedded = F.pad(combined_embedded, (0, 0, 0, self.max_filter_size - seq_len))

        # Convolution + ReLU; drop the singleton width axis.
        conv_outputs = [F.relu(conv(combined_embedded)).squeeze(3) for conv in self.convs]  # [(batch_size, num_filters, L - filter_size + 1)]

        # Max-over-time pooling: keep each filter's strongest response.
        pooled_outputs = [F.max_pool1d(conv_output, conv_output.size(2)).squeeze(2) for conv_output in conv_outputs]  # [(batch_size, num_filters)]

        # Concatenate the features from all filter widths.
        pooled = torch.cat(pooled_outputs, dim=1)  # (batch_size, num_filters * len(filter_sizes))

        pooled = self.dropout(pooled)

        # Project to class logits.
        logits = self.fc(pooled)

        return logits