In [1]:
from get_model.model.model import GETFinetuneAuto
import torch.nn.functional as F
from torch import nn
import torch
from timm import create_model

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [56]:
class motif2seqScanner(nn.Module):
    """
    Conv1d-based motif decoder that turns per-position motif embeddings into
    per-position sequence scores.

    Shape contract: [BATCH_SIZE, SEQ_LEN, MOTIF] -> [BATCH_SIZE, SEQ_LEN, output_dim]

    Args:
        motif_dim: number of motif channels per position in the input.
        hidden_dim: number of channels produced by the first convolution.
        output_dim: number of channels per position in the output (default 4).
    """
    def __init__(self, motif_dim, hidden_dim, output_dim=4):
        super().__init__()

        # kernel_size=3 with padding=1 (default stride 1) keeps SEQ_LEN unchanged
        # through both convolutions.
        length_preserving = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv1d(motif_dim, hidden_dim, **length_preserving)
        self.conv2 = nn.Conv1d(hidden_dim, output_dim, **length_preserving)

    def forward(self, x):
        """Map x of shape (BATCH_SIZE, SEQ_LEN, MOTIF_DIM) to un-activated scores
        of shape (BATCH_SIZE, SEQ_LEN, output_dim)."""
        # Conv1d wants channels ahead of length: (B, SEQ_LEN, MOTIF_DIM) -> (B, MOTIF_DIM, SEQ_LEN)
        channels_first = torch.transpose(x, 1, 2)

        # motif_dim -> hidden_dim, followed by ReLU
        hidden = F.relu(self.conv1(channels_first))

        # hidden_dim -> output_dim; no activation is applied to the result here
        scores = self.conv2(hidden)

        # Restore the caller's layout: (B, output_dim, SEQ_LEN) -> (B, SEQ_LEN, output_dim)
        return torch.transpose(scores, 1, 2)
# Drop probability shared by every Dropout2d layer in the encoder and decoder.
dropout = 0.5

# Encoder: six Conv2d(kernel 2x2, stride 2) stages, each followed by ReLU and Dropout2d.
# Tensors are channels-first, [B, C, H, W]. Every stage doubles the channel count
# (1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64) and halves H and W, flooring odd sizes, so
# overall [B, 1, H, W] -> [B, 64, H // 64, W // 64].
encoder = nn.Sequential(*[
    layer
    for c_in, c_out in [(1, 2), (2, 4), (4, 8), (8, 16), (16, 32), (32, 64)]
    for layer in (
        nn.Conv2d(c_in, c_out, kernel_size=(2, 2), stride=2),
        nn.ReLU(),
        nn.Dropout2d(dropout),
    )
])

# Decoder: mirror of the encoder built from ConvTranspose2d(kernel 2x2, stride 2).
# Every stage halves the channel count and doubles H and W, so overall
# [B, 64, h, w] -> [B, 1, 64 * h, 64 * w]. The last stage ends in ReLU with no dropout.
decoder = nn.Sequential(
    *[
        layer
        for c_in, c_out in [(64, 32), (32, 16), (16, 8), (8, 4), (4, 2)]
        for layer in (
            nn.ConvTranspose2d(c_in, c_out, kernel_size=(2, 2), stride=2),
            nn.ReLU(),
            nn.Dropout2d(dropout),
        )
    ],
    nn.ConvTranspose2d(2, 1, kernel_size=(2, 2), stride=2),
    nn.ReLU(),
)
# Pointwise (kernel_size=1) convolution mapping 639 input channels to 256 channels per position.
motif_proj = nn.Conv1d(in_channels=639, out_channels=256, kernel_size=1, bias=True)
# Motif decoder: 639 input channels -> 639 // 2 hidden channels -> 4 output channels.
motifdecoder = motif2seqScanner(motif_dim=639, hidden_dim=639 // 2, output_dim=4)


In [73]:
# Random stand-in batch created as [16, 11240, 639], then moved to channels-first
# [16, 639, 11240] because Conv1d convolves over the last axis.
x = torch.randn(16, 11240, 639).transpose(1, 2)
print(x.shape)
# 1x1 projection of the 639 channels down to 256: [16, 256, 11240]
x = motif_proj(x)
print(x.shape)
# Insert a singleton channel axis for the Conv2d encoder: [16, 1, 256, 11240]
x = torch.unsqueeze(x, 1)
print(x.shape)

torch.Size([16, 639, 11240])
torch.Size([16, 256, 11240])
torch.Size([16, 1, 256, 11240])


In [72]:
# Compress the projected motif map with the encoder, then expand it with the decoder.
latent = encoder(x)
print(latent.shape)
output = decoder(latent)
print(output.shape)
# NOTE(review): every stride-2 encoder stage floors odd sizes, so the decoder output
# width is a multiple of 64 and can be shorter than the width of x -- confirm that
# downstream code expects the shortened length.
output = torch.squeeze(output, 1)  # drop the singleton channel axis
# 1x1 convolution taking the 256 projected channels back to 639 channels.
motif_proj_back = nn.Conv1d(256, 639, kernel_size=1, bias=True)
reformed_x = motif_proj_back(output)
print(reformed_x.shape)
# motifdecoder takes (B, SEQ_LEN, 639); its result is shown channels-first below.
output_x = motifdecoder(torch.transpose(reformed_x, 1, 2))
print(torch.transpose(output_x, 1, 2).shape)

torch.Size([16, 64, 4, 171])
torch.Size([16, 1, 256, 10944])
torch.Size([16, 639, 10944])
torch.Size([16, 4, 10944])
