In [47]:
import torch
import random
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Choose the compute device once for the whole notebook: CUDA when PyTorch
# reports it as available, otherwise the CPU. The choice is echoed so the
# cell output records which device this run used.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [42]:

class CNNResidualBlock(nn.Module):
    """1-D residual block: two Conv1d + BatchNorm1d stages plus a shortcut branch.

    Both convolutions share ``kernel_size``, ``stride`` and ``padding``, so the
    main path strides over the length axis twice (overall factor ``stride ** 2``).
    When the block changes the channel count or strides, the shortcut is a
    1x1 convolution + batch norm that sub-samples by that same overall factor;
    otherwise it is the identity.

    Args:
        in_channels: Number of channels of the input (dimension 1 of ``x``).
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel width used by both convolutions.
        stride: Stride used by both convolutions.
        padding: Zero padding used by both convolutions.

    Note:
        The shortcut only sub-samples; it does not trim the length the way an
        un-padded kernel does. The two branches are guaranteed to produce the
        same length when ``kernel_size`` is odd and
        ``padding == (kernel_size - 1) // 2``; other combinations can make the
        residual addition fail with a shape mismatch.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(CNNResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, stride, padding)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # conv1 and conv2 each apply `stride`, so the projection has to
            # sub-sample by stride * stride to line up with the main path.
            # (The previous `stride * 2` only agreed with that for stride == 2.)
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride * stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))`` for ``x`` of shape (batch, in_channels, length)."""
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + residual
        return self.relu(out)

class BacteriaNN(nn.Module):
    """Convolutional encoder that maps a sequence of embeddings to one vector.

    Pipeline: Conv1d (``embedding_dim`` -> ``embedding_dim // 2`` channels) ->
    BatchNorm1d -> ReLU -> MaxPool1d -> ``resblock_num`` x ``CNNResidualBlock``
    -> AdaptiveAvgPool1d(``avg_poolsize``) -> flatten -> Dropout -> Linear.

    Args:
        embedding_dim: Size of each input embedding; used as the input channel count.
        output_dim: Size of the returned feature vector.
        kernel_size: Kernel width for the stem convolution, the max-pool and every residual block.
        padding: Padding for the stem convolution, the max-pool and every residual block.
        avg_poolsize: Output length of the adaptive average pool.
        embedding_reduction_factor: Accepted for interface compatibility but
            currently unused; the channel reduction is fixed at 2.
        resblock_num: Number of residual blocks (registered as ``resblock1`` .. ``resblockN``).
        stride: Stride for the stem convolution, the max-pool and every residual block.
        dropout_rate: Dropout probability applied before the final linear layer.
    """

    def __init__(self, embedding_dim, output_dim, kernel_size, padding, avg_poolsize, embedding_reduction_factor, resblock_num, stride, dropout_rate = 0.2):
        super(BacteriaNN, self).__init__()
        hidden_channels = embedding_dim // 2
        # Remember the block count on the instance so forward() does not depend
        # on a notebook-level global of the same name.
        self.resblock_num = resblock_num
        self.conv1 = nn.Conv1d(embedding_dim, hidden_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=kernel_size, stride=stride, padding=padding)
        for i in range(resblock_num):
            setattr(self, f'resblock{i+1}', CNNResidualBlock(hidden_channels, hidden_channels, kernel_size=kernel_size, padding=padding, stride=stride))
        self.avgpool = nn.AdaptiveAvgPool1d(avg_poolsize)
        self.dropout = nn.Dropout(dropout_rate)
        # Flattened width is channels * pooled length. Parenthesised via
        # hidden_channels: `avg_poolsize*embedding_dim//2` is evaluated as
        # (avg_poolsize*embedding_dim)//2, which is wrong for odd embedding_dim.
        self.fc = nn.Linear(avg_poolsize * hidden_channels, output_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch, length, embedding_dim) into (batch, output_dim)."""
        # Conv1d expects channels on dimension 1, so move the embedding axis there.
        x = x.permute(0, 2, 1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        for i in range(self.resblock_num):
            x = getattr(self, f'resblock{i+1}')(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        return self.fc(x)


In [43]:
class ViralNN(nn.Module):
    """Fully connected encoder applied to the flattened input.

    Each hidden layer is Linear -> ReLU -> Dropout; a final Linear maps to
    ``output_dim``. All layers live in ``self.fc_layers``.

    Args:
        embedding_dim: Size of each input embedding.
        length: Number of embeddings per sample; the first Linear expects
            ``embedding_dim * length`` features after flattening.
        output_dim: Size of the returned feature vector.
        hidden_dims: Widths of the hidden layers, in order. Only iterated, never modified.
        dropout_rate: Dropout probability after every hidden layer.
    """

    def __init__(self, embedding_dim, length, output_dim, hidden_dims=(128, 64), dropout_rate = 0.2):
        super(ViralNN, self).__init__()

        fc_input_dim = embedding_dim * length
        layers = []
        current_dim = fc_input_dim

        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(current_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, output_dim))
        self.fc_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten everything after the batch axis and run the MLP; returns (batch, output_dim)."""
        # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
        # tensors, are accepted instead of raising a RuntimeError.
        x = x.reshape(x.size(0), -1)
        return self.fc_layers(x)

In [53]:
# ---- Synthetic data sizes ----
data_size = 1234
viral_max_seq_len = 2
MAX_BACTERIA_LENGTH = 50000000
# Number of whole 131000-sized windows that fit in MAX_BACTERIA_LENGTH.
# NOTE(review): 131000 is presumably the per-embedding window length — confirm.
bacteria_max_seq_len = MAX_BACTERIA_LENGTH // 131000

# ---- Embedding / head sizes ----
embedding_dim = 64 # use small to test, need to change to 4096
embedding_reduction_factor = 2
output_dim = 256
hidden_dims = [embedding_dim // 2, embedding_dim // 4]
dropout_rate = 0.2

# ---- Convolution geometry ----
kernel_size = 11 # set to odd number smaller than embedding size
padding = (kernel_size - 1) // 2
stride = 2
avg_poolsize = 2
resblock_num = 2

class PhINN(nn.Module):
    """Two-branch model: encode each input, concatenate, and score the pair.

    Args:
        bacteria_nn: Module applied to ``bacteria_input``; must return (batch, output_dim).
        viral_nn: Module applied to ``viral_input``; must return (batch, output_dim).
        output_dim: Feature width produced by each branch.

    ``forward`` returns sigmoid probabilities of shape (batch, 1).
    """

    def __init__(self, bacteria_nn, viral_nn, output_dim):
        super().__init__()
        self.bacteria_nn = bacteria_nn
        self.viral_nn = viral_nn
        # Head: concatenated branch features -> output_dim -> single logit.
        # NOTE(review): there is no activation between fc and fc2, so the two
        # layers compose into one affine map — confirm this is intended.
        self.fc = nn.Linear(2 * output_dim, output_dim)
        self.fc2 = nn.Linear(output_dim, 1)

    def forward(self, bacteria_input, viral_input):
        branch_features = [
            self.bacteria_nn(bacteria_input),
            self.viral_nn(viral_input),
        ]
        fused = self.fc(torch.cat(branch_features, dim=1))
        return self.fc2(fused).sigmoid()


# Build both encoders from the hyperparameters defined in this cell.
# dropout_rate is passed to both branches so that changing the config value
# affects the whole model (BacteriaNN previously fell back to its own default).
bacteria_nn = BacteriaNN(
    embedding_dim=embedding_dim,
    kernel_size=kernel_size,
    padding=padding,
    stride=stride,
    avg_poolsize=avg_poolsize,
    resblock_num=resblock_num,
    embedding_reduction_factor=embedding_reduction_factor,
    output_dim=output_dim,
    dropout_rate=dropout_rate,
)
viral_nn = ViralNN(
    embedding_dim=embedding_dim,
    length=viral_max_seq_len,
    output_dim=output_dim,
    hidden_dims=hidden_dims,
    dropout_rate=dropout_rate,
)

# Create an instance of PhINN
phinn_model = PhINN(bacteria_nn, viral_nn, output_dim=output_dim).to(device)

# Example input tensors: random embeddings of shape (sample, sequence, embedding)
# and random 0/1 labels, used only to exercise the model end to end.
bacteria_input = torch.randn(data_size, bacteria_max_seq_len, embedding_dim)
viral_input = torch.randn(data_size, viral_max_seq_len, embedding_dim)
example_labels = torch.randint(0, 2, (data_size,))

In [54]:
# Mini-batch training loop over the in-memory example tensors.
batch_size = 12
criterion = nn.BCELoss()  # the model already returns sigmoid probabilities
optimizer = optim.Adam(phinn_model.parameters(), lr=0.001)

num_epochs = 10
data_size = len(bacteria_input)
# Loop-invariant: number of mini-batches per epoch (the last one may be partial).
num_batches = (data_size + batch_size - 1) // batch_size

for epoch in range(num_epochs):
    phinn_model.train()
    train_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for batch_idx in tqdm(range(num_batches)):
        # Get the batch
        i = batch_idx * batch_size
        bacteria_input_load = bacteria_input[i:i+batch_size].to(device)
        viral_input_load = viral_input[i:i+batch_size].to(device)
        labels_load = example_labels[i:i+batch_size].to(device)
        targets = labels_load.float()
        batch_count = labels_load.size(0)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass. Drop only the trailing size-1 axis: a bare squeeze()
        # would also drop the batch axis when the last batch holds one sample,
        # and BCELoss would then reject the shape mismatch with the targets.
        outputs = phinn_model(bacteria_input_load, viral_input_load)
        probs = outputs.squeeze(1)

        # Compute the loss
        loss = criterion(probs, targets)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate the summed loss: criterion returns the batch mean, so
        # weight it by the batch size to keep a partial last batch from being
        # over-represented in the epoch average.
        train_loss += loss.item() * batch_count

        # Calculate accuracy
        predicted_labels = (probs > 0.5).float()  # Threshold at 0.5
        correct_predictions += (predicted_labels == targets).sum().item()
        total_samples += batch_count

    # Calculate per-sample average loss and accuracy for the epoch
    # (max(..., 1) avoids a ZeroDivisionError when there is no data).
    avg_loss = train_loss / max(total_samples, 1)
    accuracy = correct_predictions / max(total_samples, 1)

    # Print the average loss and accuracy for the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")


100%|██████████| 103/103 [00:01<00:00, 63.48it/s]


Epoch 1/10, Loss: 0.7175, Accuracy: 0.4903


100%|██████████| 103/103 [00:01<00:00, 61.73it/s]


Epoch 2/10, Loss: 0.5432, Accuracy: 0.7334


100%|██████████| 103/103 [00:01<00:00, 67.68it/s]


Epoch 3/10, Loss: 0.2806, Accuracy: 0.8849


100%|██████████| 103/103 [00:01<00:00, 64.78it/s]


Epoch 4/10, Loss: 0.2410, Accuracy: 0.9100


100%|██████████| 103/103 [00:01<00:00, 59.33it/s]


Epoch 5/10, Loss: 0.1048, Accuracy: 0.9643


100%|██████████| 103/103 [00:01<00:00, 66.32it/s]


Epoch 6/10, Loss: 0.0205, Accuracy: 0.9943


100%|██████████| 103/103 [00:01<00:00, 67.77it/s]


Epoch 7/10, Loss: 0.0146, Accuracy: 0.9943


100%|██████████| 103/103 [00:01<00:00, 68.36it/s]


Epoch 8/10, Loss: 0.0767, Accuracy: 0.9708


100%|██████████| 103/103 [00:01<00:00, 68.16it/s]


Epoch 9/10, Loss: 0.0813, Accuracy: 0.9668


100%|██████████| 103/103 [00:01<00:00, 69.62it/s]

Epoch 10/10, Loss: 0.0321, Accuracy: 0.9919



