From 128a5d3c65f50494219cd1af72992f53c1ad1e11 Mon Sep 17 00:00:00 2001 From: haipinglu Date: Wed, 5 Nov 2025 22:49:30 +0000 Subject: [PATCH] Add T4 --- cnn.py | 143 +++++++++++++++++++++++++++++++-------------------------- 1 file changed, 77 insertions(+), 66 deletions(-) diff --git a/cnn.py b/cnn.py index 9bcea33..ee60f75 100644 --- a/cnn.py +++ b/cnn.py @@ -2,7 +2,38 @@ import torch.nn.functional as F -class SmallCNNFeature(nn.Module): +class BaseCNN(nn.Module): + """ + Base class for CNN models providing reusable factory methods for common layers. + This class is intended to be inherited by specific CNN implementations to reduce redundancy. + """ + + @staticmethod + def conv(in_ch, out_ch, kernel, stride=1, padding=0, bias=True, dim=2): + """Return a conv layer (1-D or 2-D).""" + conv_cls = {1: nn.Conv1d, 2: nn.Conv2d}[dim] + return conv_cls(in_ch, out_ch, kernel, stride, padding, bias=bias) + + @staticmethod + def batch_norm(num_features, dim=2): + """Return a batch-norm layer (1-D or 2-D).""" + bn_cls = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d}[dim] + return bn_cls(num_features) + + @staticmethod + def max_pool(kernel_size, stride=None, dim=2): + """Return a max-pool layer (1-D or 2-D).""" + pool_cls = {1: nn.MaxPool1d, 2: nn.MaxPool2d}[dim] + return pool_cls(kernel_size, stride) + + @staticmethod + def adaptive_avg_pool(output_size, dim=2): + """Return an adaptive average-pool layer.""" + pool_cls = {1: nn.AdaptiveAvgPool1d, 2: nn.AdaptiveAvgPool2d}[dim] + return pool_cls(output_size) + + +class SmallCNNFeature(BaseCNN): """ A feature extractor for small 32x32 images (e.g. CIFAR, MNIST) that outputs a feature vector of length 128. @@ -15,34 +46,29 @@ class SmallCNNFeature(nn.Module): """ def __init__(self, num_channels=3, kernel_size=5): - super(SmallCNNFeature, self).__init__() - self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=kernel_size) - self.bn1 = nn.BatchNorm2d(64) - self.pool1 = nn.MaxPool2d(2) - self.relu1 = nn.ReLU() - self.conv2 = nn.Conv2d(64, 64, kernel_size=kernel_size) - self.bn2 = nn.BatchNorm2d(64) - self.pool2 = nn.MaxPool2d(2) - self.relu2 = nn.ReLU() - self.conv3 = nn.Conv2d(64, 64 * 2, kernel_size=kernel_size) - self.bn3 = nn.BatchNorm2d(64 * 2) + super().__init__() + self.conv1 = self.conv(num_channels, 64, kernel_size) + self.bn1 = self.batch_norm(64) + self.pool1 = self.max_pool(2) + self.conv2 = self.conv(64, 64, kernel_size) + self.bn2 = self.batch_norm(64) + self.pool2 = self.max_pool(2) + self.conv3 = self.conv(64, 64 * 2, kernel_size) + self.bn3 = self.batch_norm(64 * 2) self.sigmoid = nn.Sigmoid() self._out_features = 128 def forward(self, input_): - x = self.bn1(self.conv1(input_)) - x = self.relu1(self.pool1(x)) - x = self.bn2(self.conv2(x)) - x = self.relu2(self.pool2(x)) + x = F.relu(self.pool1(self.bn1(self.conv1(input_)))) + x = F.relu(self.pool2(self.bn2(self.conv2(x)))) x = self.sigmoid(self.bn3(self.conv3(x))) - x = x.view(x.size(0), -1) - return x + return x.view(x.size(0), -1) def output_size(self): return self._out_features -class SignalVAEEncoder(nn.Module): +class SignalVAEEncoder(BaseCNN): """ SignalVAEEncoder encodes 1D signals into a latent representation suitable for variational autoencoders (VAE). @@ -68,9 +94,9 @@ class SignalVAEEncoder(nn.Module): def __init__(self, input_dim=60000, latent_dim=256): super().__init__() - self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=1) - self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1) - self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1) + self.conv1 = self.conv(1, 16, 3, stride=2, padding=1, dim=1) + self.conv2 = self.conv(16, 32, 3, stride=2, padding=1, dim=1) + self.conv3 = self.conv(32, 64, 3, stride=2, padding=1, dim=1) self.flatten = nn.Flatten() self.fc_mu = nn.Linear(64 * (input_dim // 8), latent_dim) self.fc_log_var = nn.Linear(64 * (input_dim // 8), latent_dim) @@ -81,12 +107,10 @@ def forward(self, x): x = self.relu(self.conv2(x)) x = self.relu(self.conv3(x)) x = self.flatten(x) - mean = self.fc_mu(x) - log_var = self.fc_log_var(x) - return mean, log_var + return self.fc_mu(x), self.fc_log_var(x) -class ProteinCNN(nn.Module): +class ProteinCNN(BaseCNN): """ A protein feature extractor using Convolutional Neural Networks (CNNs). @@ -102,32 +126,29 @@ class ProteinCNN(nn.Module): """ def __init__(self, embedding_dim, num_filters, kernel_size, padding=True): - super(ProteinCNN, self).__init__() + super().__init__() if padding: self.embedding = nn.Embedding(26, embedding_dim, padding_idx=0) else: self.embedding = nn.Embedding(26, embedding_dim) in_ch = [embedding_dim] + num_filters - # self.in_ch = in_ch[-1] - kernels = kernel_size - self.conv1 = nn.Conv1d(in_channels=in_ch[0], out_channels=in_ch[1], kernel_size=kernels[0]) - self.bn1 = nn.BatchNorm1d(in_ch[1]) - self.conv2 = nn.Conv1d(in_channels=in_ch[1], out_channels=in_ch[2], kernel_size=kernels[1]) - self.bn2 = nn.BatchNorm1d(in_ch[2]) - self.conv3 = nn.Conv1d(in_channels=in_ch[2], out_channels=in_ch[3], kernel_size=kernels[2]) - self.bn3 = nn.BatchNorm1d(in_ch[3]) + k = kernel_size + self.conv1 = self.conv(in_ch[0], in_ch[1], k[0], dim=1) + self.bn1 = self.batch_norm(in_ch[1], dim=1) + self.conv2 = self.conv(in_ch[1], in_ch[2], k[1], dim=1) + self.bn2 = self.batch_norm(in_ch[2], dim=1) + self.conv3 = self.conv(in_ch[2], in_ch[3], k[2], dim=1) + self.bn3 = self.batch_norm(in_ch[3], dim=1) def forward(self, v): - v = self.embedding(v.long()) - v = v.transpose(2, 1) + v = self.embedding(v.long()).transpose(2, 1) v = self.bn1(F.relu(self.conv1(v))) v = self.bn2(F.relu(self.conv2(v))) v = self.bn3(F.relu(self.conv3(v))) - v = v.view(v.size(0), v.size(2), -1) - return v + return v.view(v.size(0), v.size(2), -1) -class LeNet(nn.Module): +class LeNet(BaseCNN): """LeNet is a customizable Convolutional Neural Network (CNN) model based on the LeNet architecture, designed for feature extraction from image and audio modalities. LeNet supports several layers of 2D convolution, followed by batch normalization, max pooling, and adaptive @@ -161,20 +182,18 @@ def __init__( linear=None, squeeze_output=True, ): - super(LeNet, self).__init__() + super().__init__() self.output_each_layer = output_each_layer - self.conv_layers = [nn.Conv2d(input_channels, output_channels, kernel_size=5, padding=2, bias=False)] - self.batch_norms = [nn.BatchNorm2d(output_channels)] - self.global_pools = [nn.AdaptiveAvgPool2d(1)] + self.conv_layers = [self.conv(input_channels, output_channels, 5, padding=2, bias=False)] + self.batch_norms = [self.batch_norm(output_channels)] + self.global_pools = [self.adaptive_avg_pool(1)] for i in range(additional_layers): - self.conv_layers.append( - nn.Conv2d( - (2**i) * output_channels, (2 ** (i + 1)) * output_channels, kernel_size=3, padding=1, bias=False - ) - ) - self.batch_norms.append(nn.BatchNorm2d(output_channels * (2 ** (i + 1)))) - self.global_pools.append(nn.AdaptiveAvgPool2d(1)) + in_ch = (2**i) * output_channels + out_ch = (2 ** (i + 1)) * output_channels + self.conv_layers.append(self.conv(in_ch, out_ch, 3, padding=1, bias=False)) + self.batch_norms.append(self.batch_norm(out_ch)) + self.global_pools.append(self.adaptive_avg_pool(1)) self.conv_layers = nn.ModuleList(self.conv_layers) self.batch_norms = nn.ModuleList(self.batch_norms) @@ -203,16 +222,11 @@ def forward(self, x): intermediate_outputs.append(output) if self.output_each_layer: - if self.squeeze_output: - return [t.squeeze() for t in intermediate_outputs] - return intermediate_outputs - - if self.squeeze_output: - return output.squeeze() - return output + return [t.squeeze() for t in intermediate_outputs] if self.squeeze_output else intermediate_outputs + return output.squeeze() if self.squeeze_output else output -class ImageVAEEncoder(nn.Module): +class ImageVAEEncoder(BaseCNN): """ ImageVAEEncoder encodes 2D image data into a latent representation for use in a Variational Autoencoder (VAE). @@ -243,10 +257,9 @@ class ImageVAEEncoder(nn.Module): def __init__(self, input_channels=1, latent_dim=256): super().__init__() - # Convolutional layers for 224x224 input - self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=2, padding=1) - self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1) - self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) + self.conv1 = self.conv(input_channels, 16, 3, stride=2, padding=1) + self.conv2 = self.conv(16, 32, 3, stride=2, padding=1) + self.conv3 = self.conv(32, 64, 3, stride=2, padding=1) self.flatten = nn.Flatten() self.fc_mu = nn.Linear(64 * 28 * 28, latent_dim) self.fc_log_var = nn.Linear(64 * 28 * 28, latent_dim) @@ -267,6 +280,4 @@ def forward(self, x): x = self.relu(self.conv2(x)) x = self.relu(self.conv3(x)) x = self.flatten(x) - mean = self.fc_mu(x) - log_var = self.fc_log_var(x) - return mean, log_var + return self.fc_mu(x), self.fc_log_var(x)