In [17]:
import torch
import torch.nn as nn
import numpy as np

In [42]:
class CharTextCNN(nn.Module):
    """Character-level CNN text classifier: six Conv1d blocks followed by three linear layers.

    ``config`` must provide:
        char_num     -- number of input channels (size of the character alphabet)
        features     -- sequence of 6 output-channel counts, one per conv layer
        kernel_sizes -- sequence of 6 Conv1d kernel sizes
        dropout      -- dropout probability applied after fc1 and fc2
        num_classes  -- number of output scores
        seq_len      -- optional input sequence length; defaults to 1014

    Input is a float tensor of shape (batch, char_num, seq_len); the output has
    shape (batch, num_classes) and is the raw fc3 output (no softmax applied).

    Raises:
        ValueError: if ``seq_len`` is too short to survive the conv/pool stack.
    """

    # Which of the six conv blocks end with MaxPool1d(kernel_size=3, stride=3).
    _POOLED = (True, True, False, False, False, True)
    _POOL = 3
    # Input length the architecture was originally sized for (34 * 256 = 8704 fc1 inputs).
    _DEFAULT_SEQ_LEN = 1014
    _HIDDEN = 1024

    def __init__(self, config):
        super(CharTextCNN, self).__init__()
        out_features = list(config.features)                  # e.g. [256,256,256,256,256,256]
        in_features = [config.char_num] + out_features[:-1]   # e.g. [70,256,256,256,256,256]
        kernel_size = list(config.kernel_sizes)               # e.g. [7,7,3,3,3,3]

        # setattr on an nn.Module registers the submodule, so the attribute and
        # state_dict names (conv1.0.weight, ..., conv6.1.running_var) are the same
        # as when each block was written out by hand.
        for idx, pool in enumerate(self._POOLED):
            setattr(self, f"conv{idx + 1}",
                    self._conv_block(in_features[idx], out_features[idx], kernel_size[idx], pool))

        # Size of the flattened conv output, derived instead of hard-coded so the
        # model also works for other sequence lengths.
        seq_len = getattr(config, "seq_len", self._DEFAULT_SEQ_LEN)
        flat_dim = out_features[5] * self._conv_output_length(seq_len, kernel_size)

        self.fc1 = nn.Sequential(
            nn.Linear(flat_dim, self._HIDDEN),
            nn.ReLU(),
            nn.Dropout(p=config.dropout)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(self._HIDDEN, self._HIDDEN),
            nn.ReLU(),
            nn.Dropout(p=config.dropout)
        )
        self.fc3 = nn.Linear(self._HIDDEN, config.num_classes)

    @classmethod
    def _conv_block(cls, in_channels, out_channels, kernel, pool):
        """Build Conv1d -> BatchNorm1d -> ReLU, optionally followed by MaxPool1d(3, stride=3)."""
        layers = [
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        ]
        if pool:
            layers.append(nn.MaxPool1d(kernel_size=cls._POOL, stride=cls._POOL))
        return nn.Sequential(*layers)

    @classmethod
    def _conv_output_length(cls, length, kernel_sizes):
        """Return the time dimension left after the six conv blocks for an input of ``length``."""
        original = length
        for kernel, pool in zip(kernel_sizes, cls._POOLED):
            length = length - kernel + 1  # Conv1d with stride 1 and no padding
            if pool:
                # MaxPool1d(kernel_size=3, stride=3) in floor mode
                length = (length - cls._POOL) // cls._POOL + 1
            if length < 1:
                raise ValueError(
                    f"seq_len={original} is too short for CharTextCNN's conv/pool stack")
        return length

    def forward(self, x):
        """Map a (batch, char_num, seq_len) tensor to (batch, num_classes) scores."""
        # Shape comments are for the default config with a (64, 70, 1014) input.
        x = self.conv1(x)  # (64, 256, 336): 1014 - 7 + 1 = 1008, pooled -> 336
        x = self.conv2(x)  # (64, 256, 110): 336 - 7 + 1 = 330, pooled -> 110
        x = self.conv3(x)  # (64, 256, 108)
        x = self.conv4(x)  # (64, 256, 106)
        x = self.conv5(x)  # (64, 256, 104)
        x = self.conv6(x)  # (64, 256, 34): 104 - 3 + 1 = 102, pooled -> 34
        x = torch.flatten(x, 1)  # (64, 8704)
        x = self.fc1(x)  # (64, 1024)
        x = self.fc2(x)  # (64, 1024)
        return self.fc3(x)  # (64, 4)
        

In [43]:
class config:
    """Hyper-parameters for CharTextCNN."""

    def __init__(self):
        self.char_num = 70  # number of characters in the alphabet (input channels)
        self.features = [256, 256, 256, 256, 256, 256]  # output channels of each conv layer
        self.kernel_sizes = [7, 7, 3, 3, 3, 3]  # kernel size of each conv layer
        self.dropout = 0.5  # dropout probability
        self.num_classes = 4  # number of target classes
        self.seq_len = 1014  # input sequence length (characters per sample)

In [44]:
# Smoke test: push an all-zero dummy batch through the model.
# The instance gets its own name: `config = config()` would shadow the class,
# so re-running this cell raised "TypeError: 'config' object is not callable".
torch.manual_seed(0)  # make the random weight init (and hence `out`) reproducible
char_cnn_config = config()
chartextcnn = CharTextCNN(char_cnn_config)
test = torch.zeros([64, char_cnn_config.char_num, 1014])  # (batch, char_num, seq_len)
out = chartextcnn(test)
assert out.shape == (64, char_cnn_config.num_classes)

torch.Size([64, 256, 336])
torch.Size([64, 256, 110])
torch.Size([64, 256, 108])
torch.Size([64, 256, 106])
torch.Size([64, 256, 104])
torch.Size([64, 256, 34])
torch.Size([64, 8704])
torch.Size([64, 1024])
torch.Size([64, 1024])
torch.Size([64, 4])


In [45]:
out.size()  # torch.Size of the smoke-test logits, (batch, num_classes)

torch.Size([64, 4])