Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 77 additions & 66 deletions cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,38 @@
import torch.nn.functional as F


class SmallCNNFeature(nn.Module):
class BaseCNN(nn.Module):
"""
Base class for CNN models providing reusable factory methods for common layers.
This class is intended to be inherited by specific CNN implementations to reduce redundancy.
"""

@staticmethod
def conv(in_ch, out_ch, kernel, stride=1, padding=0, bias=True, dim=2):
"""Return a conv layer (1-D or 2-D)."""
conv_cls = {1: nn.Conv1d, 2: nn.Conv2d}[dim]
return conv_cls(in_ch, out_ch, kernel, stride, padding, bias=bias)

@staticmethod
def batch_norm(num_features, dim=2):
"""Return a batch-norm layer (1-D or 2-D)."""
bn_cls = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d}[dim]
return bn_cls(num_features)

@staticmethod
def max_pool(kernel_size, stride=None, dim=2):
"""Return a max-pool layer (1-D or 2-D)."""
pool_cls = {1: nn.MaxPool1d, 2: nn.MaxPool2d}[dim]
return pool_cls(kernel_size, stride)

@staticmethod
def adaptive_avg_pool(output_size, dim=2):
"""Return an adaptive average-pool layer."""
pool_cls = {1: nn.AdaptiveAvgPool1d, 2: nn.AdaptiveAvgPool2d}[dim]
return pool_cls(output_size)


class SmallCNNFeature(BaseCNN):
"""
A feature extractor for small 32x32 images (e.g. CIFAR, MNIST) that outputs a feature vector of length 128.

Expand All @@ -15,34 +46,29 @@ class SmallCNNFeature(nn.Module):
"""

def __init__(self, num_channels=3, kernel_size=5):
super(SmallCNNFeature, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=kernel_size)
self.bn1 = nn.BatchNorm2d(64)
self.pool1 = nn.MaxPool2d(2)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(64, 64, kernel_size=kernel_size)
self.bn2 = nn.BatchNorm2d(64)
self.pool2 = nn.MaxPool2d(2)
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2d(64, 64 * 2, kernel_size=kernel_size)
self.bn3 = nn.BatchNorm2d(64 * 2)
super().__init__()
self.conv1 = self.conv(num_channels, 64, kernel_size)
self.bn1 = self.batch_norm(64)
self.pool1 = self.max_pool(2)
self.conv2 = self.conv(64, 64, kernel_size)
self.bn2 = self.batch_norm(64)
self.pool2 = self.max_pool(2)
self.conv3 = self.conv(64, 64 * 2, kernel_size)
self.bn3 = self.batch_norm(64 * 2)
self.sigmoid = nn.Sigmoid()
self._out_features = 128

def forward(self, input_):
x = self.bn1(self.conv1(input_))
x = self.relu1(self.pool1(x))
x = self.bn2(self.conv2(x))
x = self.relu2(self.pool2(x))
x = F.relu(self.pool1(self.bn1(self.conv1(input_))))
x = F.relu(self.pool2(self.bn2(self.conv2(x))))
x = self.sigmoid(self.bn3(self.conv3(x)))
x = x.view(x.size(0), -1)
return x
return x.view(x.size(0), -1)

def output_size(self):
return self._out_features


class SignalVAEEncoder(nn.Module):
class SignalVAEEncoder(BaseCNN):
"""
SignalVAEEncoder encodes 1D signals into a latent representation suitable for variational autoencoders (VAE).

Expand All @@ -68,9 +94,9 @@ class SignalVAEEncoder(nn.Module):

def __init__(self, input_dim=60000, latent_dim=256):
super().__init__()
self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=1)
self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1)
self.conv1 = self.conv(1, 16, 3, stride=2, padding=1, dim=1)
self.conv2 = self.conv(16, 32, 3, stride=2, padding=1, dim=1)
self.conv3 = self.conv(32, 64, 3, stride=2, padding=1, dim=1)
self.flatten = nn.Flatten()
self.fc_mu = nn.Linear(64 * (input_dim // 8), latent_dim)
self.fc_log_var = nn.Linear(64 * (input_dim // 8), latent_dim)
Expand All @@ -81,12 +107,10 @@ def forward(self, x):
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.flatten(x)
mean = self.fc_mu(x)
log_var = self.fc_log_var(x)
return mean, log_var
return self.fc_mu(x), self.fc_log_var(x)


class ProteinCNN(nn.Module):
class ProteinCNN(BaseCNN):
"""
A protein feature extractor using Convolutional Neural Networks (CNNs).

Expand All @@ -102,32 +126,29 @@ class ProteinCNN(nn.Module):
"""

def __init__(self, embedding_dim, num_filters, kernel_size, padding=True):
super(ProteinCNN, self).__init__()
super().__init__()
if padding:
self.embedding = nn.Embedding(26, embedding_dim, padding_idx=0)
else:
self.embedding = nn.Embedding(26, embedding_dim)
in_ch = [embedding_dim] + num_filters
# self.in_ch = in_ch[-1]
kernels = kernel_size
self.conv1 = nn.Conv1d(in_channels=in_ch[0], out_channels=in_ch[1], kernel_size=kernels[0])
self.bn1 = nn.BatchNorm1d(in_ch[1])
self.conv2 = nn.Conv1d(in_channels=in_ch[1], out_channels=in_ch[2], kernel_size=kernels[1])
self.bn2 = nn.BatchNorm1d(in_ch[2])
self.conv3 = nn.Conv1d(in_channels=in_ch[2], out_channels=in_ch[3], kernel_size=kernels[2])
self.bn3 = nn.BatchNorm1d(in_ch[3])
k = kernel_size
self.conv1 = self.conv(in_ch[0], in_ch[1], k[0], dim=1)
self.bn1 = self.batch_norm(in_ch[1], dim=1)
self.conv2 = self.conv(in_ch[1], in_ch[2], k[1], dim=1)
self.bn2 = self.batch_norm(in_ch[2], dim=1)
self.conv3 = self.conv(in_ch[2], in_ch[3], k[2], dim=1)
self.bn3 = self.batch_norm(in_ch[3], dim=1)

def forward(self, v):
v = self.embedding(v.long())
v = v.transpose(2, 1)
v = self.embedding(v.long()).transpose(2, 1)
v = self.bn1(F.relu(self.conv1(v)))
v = self.bn2(F.relu(self.conv2(v)))
v = self.bn3(F.relu(self.conv3(v)))
v = v.view(v.size(0), v.size(2), -1)
return v
return v.view(v.size(0), v.size(2), -1)


class LeNet(nn.Module):
class LeNet(BaseCNN):
"""LeNet is a customizable Convolutional Neural Network (CNN) model based on the LeNet architecture, designed for
feature extraction from image and audio modalities.
LeNet supports several layers of 2D convolution, followed by batch normalization, max pooling, and adaptive
Expand Down Expand Up @@ -161,20 +182,18 @@ def __init__(
linear=None,
squeeze_output=True,
):
super(LeNet, self).__init__()
super().__init__()
self.output_each_layer = output_each_layer
self.conv_layers = [nn.Conv2d(input_channels, output_channels, kernel_size=5, padding=2, bias=False)]
self.batch_norms = [nn.BatchNorm2d(output_channels)]
self.global_pools = [nn.AdaptiveAvgPool2d(1)]
self.conv_layers = [self.conv(input_channels, output_channels, 5, padding=2, bias=False)]
self.batch_norms = [self.batch_norm(output_channels)]
self.global_pools = [self.adaptive_avg_pool(1)]

for i in range(additional_layers):
self.conv_layers.append(
nn.Conv2d(
(2**i) * output_channels, (2 ** (i + 1)) * output_channels, kernel_size=3, padding=1, bias=False
)
)
self.batch_norms.append(nn.BatchNorm2d(output_channels * (2 ** (i + 1))))
self.global_pools.append(nn.AdaptiveAvgPool2d(1))
in_ch = (2**i) * output_channels
out_ch = (2 ** (i + 1)) * output_channels
self.conv_layers.append(self.conv(in_ch, out_ch, 3, padding=1, bias=False))
self.batch_norms.append(self.batch_norm(out_ch))
self.global_pools.append(self.adaptive_avg_pool(1))

self.conv_layers = nn.ModuleList(self.conv_layers)
self.batch_norms = nn.ModuleList(self.batch_norms)
Expand Down Expand Up @@ -203,16 +222,11 @@ def forward(self, x):
intermediate_outputs.append(output)

if self.output_each_layer:
if self.squeeze_output:
return [t.squeeze() for t in intermediate_outputs]
return intermediate_outputs

if self.squeeze_output:
return output.squeeze()
return output
return [t.squeeze() for t in intermediate_outputs] if self.squeeze_output else intermediate_outputs
return output.squeeze() if self.squeeze_output else output


class ImageVAEEncoder(nn.Module):
class ImageVAEEncoder(BaseCNN):
"""
ImageVAEEncoder encodes 2D image data into a latent representation for use in a Variational Autoencoder (VAE).

Expand Down Expand Up @@ -243,10 +257,9 @@ class ImageVAEEncoder(nn.Module):

def __init__(self, input_channels=1, latent_dim=256):
super().__init__()
# Convolutional layers for 224x224 input
self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=2, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
self.conv1 = self.conv(input_channels, 16, 3, stride=2, padding=1)
self.conv2 = self.conv(16, 32, 3, stride=2, padding=1)
self.conv3 = self.conv(32, 64, 3, stride=2, padding=1)
self.flatten = nn.Flatten()
self.fc_mu = nn.Linear(64 * 28 * 28, latent_dim)
self.fc_log_var = nn.Linear(64 * 28 * 28, latent_dim)
Expand All @@ -267,6 +280,4 @@ def forward(self, x):
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.flatten(x)
mean = self.fc_mu(x)
log_var = self.fc_log_var(x)
return mean, log_var
return self.fc_mu(x), self.fc_log_var(x)