In [40]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, roc_curve
import matplotlib.pyplot as plt
import pickle
import zipfile

In [41]:
# Step 1: Extract every sequence archive into the current working directory
zip_filenames = [
    'Dengue-4-sequences.zip',
    'Ebola-sequences.zip',
    'SARS-CoV-2-sequences.zip',
    'hepatitis-C-3a-sequences.zip',
    'influenza-A-sequences.zip',
    'mers-sequences.zip',
]

for zip_filename in zip_filenames:
    with zipfile.ZipFile(zip_filename, mode='r') as archive:
        archive.extractall()

# Archive names with the trailing 4 characters ('.zip') dropped. The order of
# this list fixes the integer class label assigned to each virus when the
# encoded data is loaded.
fasta_filenames = [archive_name[:-4] for archive_name in zip_filenames]

In [42]:
# Step 2: Apply k-mer encoding
def k_mer_enc(filename: str, k: int):
    """Encode every sequence of a FASTA file as rolling base-4 k-mer codes.

    Reads ``filename + '.fasta'``. Lines starting with ``>`` are record
    headers; all other lines of a record are concatenated into one sequence
    (records without any sequence characters are skipped). Each sequence is
    converted to a list holding one integer per character: the running code
    is shifted by one base (``* 4``), the new base is added (A=0, C=1, G=2,
    T=3) and the result is reduced modulo ``4 ** k``, so every entry lies in
    ``[0, 4 ** k)``. The first ``k - 1`` entries therefore describe prefixes
    shorter than ``k``.

    Bases are matched case-insensitively and surrounding whitespace on each
    line is ignored. Any other character (for example ``N``) contributes 0,
    which makes it indistinguishable from ``A``.

    The list of per-sequence encodings is pickled to ``filename + '.pkl'``
    and the number of sequences is printed.

    Args:
        filename: Path of the FASTA file without its ``.fasta`` extension.
        k: Length of the k-mer window; must be at least 1.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f'k must be a positive integer, got {k!r}')

    base_codes = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    modulus = 4 ** k  # loop-invariant, so compute it once

    seqs = []
    chunks = []  # pieces of the record currently being read
    with open(filename + '.fasta', 'r') as file:
        for line in file:
            if line.startswith('>'):
                seq = ''.join(chunks)
                if seq:
                    seqs.append(seq)
                chunks = []
            else:
                # Upper-case so soft-masked (lower-case) bases are not
                # silently encoded as 'A'.
                chunks.append(line.strip().upper())
    seq = ''.join(chunks)
    if seq:
        seqs.append(seq)

    print(f'The file {filename} has {len(seqs)} sequences.')

    encodings = []
    for seq in seqs:
        code = 0
        encoding = []
        for base in seq:
            code = (code * 4 + base_codes.get(base, 0)) % modulus
            encoding.append(code)
        encodings.append(encoding)

    with open(filename + '.pkl', 'wb') as pkfile:
        pickle.dump(encodings, pkfile)

# Encode each extracted FASTA file with k = 4 and write a .pkl file next to it.
filenames = [
    'Dengue-4-sequences',
    'Ebola-sequences',
    'hepatitis-C-3a-sequences',
    'influenza-A-sequences',
    'mers-sequences',
    'SARS-CoV-2-sequences',
]
for filename in filenames:
    k_mer_enc(filename, k=4)


The file Dengue-4-sequences has 3587 sequences.
The file Ebola-sequences has 4001 sequences.
The file hepatitis-C-3a-sequences has 3331 sequences.
The file influenza-A-sequences has 3669 sequences.
The file mers-sequences has 1633 sequences.
The file SARS-CoV-2-sequences has 4752 sequences.


In [43]:
# Step 3: Build the ConvLSTM model
class ConvLSTM(nn.Module):
    """Embedding -> stacked LSTM -> linear classifier for integer-coded sequences.

    Despite the name, the network contains no convolution: ``kernel_size`` is
    accepted to keep the constructor signature stable but is not used.

    Args:
        input_size: Number of distinct input codes (embedding vocabulary size).
        hidden_size: Width of the embedding and of every LSTM layer.
        output_size: Number of classes.
        kernel_size: Unused.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability passed to ``nn.LSTM``.
    """

    def __init__(self, input_size, hidden_size, output_size, kernel_size, num_layers, dropout):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout = dropout

        self.embed = nn.Embedding(input_size, hidden_size)
        self.convlstm = nn.LSTM(hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, output_size)``.

        ``x`` is an integer tensor of shape ``(batch, seq_len)``. The logits
        are returned unsquashed: ``nn.CrossEntropyLoss`` applies log-softmax
        itself, so passing them through a sigmoid first would confine them to
        (0, 1) and flatten the loss. Arg-max predictions are unaffected by
        this because a sigmoid is monotonic.
        """
        embedded = self.embed(x)
        hidden_states, _ = self.convlstm(embedded)
        # Classify from the hidden state at the final time step.
        return self.fc(hidden_states[:, -1, :])

In [44]:
# Step 4: Dataset wrapper used for model training and evaluation
class VirusDataset(Dataset):
    """Pairs each encoded sequence with the class label stored at the same index."""

    def __init__(self, encodings, labels):
        self.encodings, self.labels = encodings, labels

    def __len__(self):
        # The size is taken from the encodings container alone.
        n_samples = len(self.encodings)
        return n_samples

    def __getitem__(self, idx):
        sample = self.encodings[idx]
        target = self.labels[idx]
        return sample, target

In [45]:
def train_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Batches are moved to the device the model's parameters live on, so the
    function does not rely on a module-level ``device`` variable having been
    defined before it is called.

    Args:
        model: Module to train; it is switched to training mode.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar loss.
        optimizer: Optimizer updating ``model``'s parameters.

    Returns:
        The arithmetic mean of the per-batch loss values as a float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    running_loss = 0.0
    num_batches = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(target_device)
        targets = targets.to(target_device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_epoch received an empty dataloader")

    return running_loss / num_batches


In [46]:
def evaluate(model, dataloader):
    """Predict a class for every sample in ``dataloader``.

    Batches are moved to the device the model's parameters live on, so the
    function does not rely on a module-level ``device`` variable having been
    defined before it is called.

    Args:
        model: Module producing per-class scores of shape ``(batch, classes)``;
            it is switched to evaluation mode.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.

    Returns:
        ``(predictions, truths)``: two lists in dataloader order, holding the
        arg-max class index of each sample and its true label respectively.
        Only class indices are returned, not the underlying scores.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    predictions = []
    truths = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            predictions.extend(predicted.cpu().numpy())
            truths.extend(labels.cpu().numpy())

    return predictions, truths

In [47]:
def plot_roc_curve(y_true, y_pred_prob, num_classes):
    """Plot one ROC curve per class on the current matplotlib figure.

    For class ``i`` the curve is computed from column ``i`` of ``y_pred_prob``
    with ``i`` treated as the positive label in ``y_true``.
    """
    # Compute every curve before drawing anything.
    curves = [
        roc_curve(y_true, y_pred_prob[:, class_idx], pos_label=class_idx)[:2]
        for class_idx in range(num_classes)
    ]

    for class_idx, (false_pos_rate, true_pos_rate) in enumerate(curves):
        plt.plot(false_pos_rate, true_pos_rate, label=f"Class {class_idx}")

    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.show()

In [57]:
# Parameters

# -- Model architecture --
input_size = 4 ** 4                  # number of distinct codes for k = 4 (256)
output_size = len(fasta_filenames)   # one class per virus file
hidden_size = 256
num_layers = 3
kernel_size = 3
dropout = 0.5

# -- Optimisation --
learning_rate = 1e-3
batch_size = 32
num_epochs = 25

In [58]:
# Load data
# NOTE: pickle.load can execute arbitrary code; only read the .pkl files that
# this notebook wrote itself, never files from an untrusted source.
sequences = []
labels = []
for label, fasta_filename in enumerate(fasta_filenames):
    with open(fasta_filename + '.pkl', 'rb') as pkfile:
        file_encodings = pickle.load(pkfile)
    sequences.extend(file_encodings)
    labels.extend([label] * len(file_encodings))

labels = np.array(labels)

# Find the maximum sequence length
max_seq_len = max(len(seq) for seq in sequences)

# Right-pad with zeros by filling one preallocated matrix in place. This avoids
# keeping a ragged object array, a list of padded copies and the stacked result
# in memory at the same time, and always yields an integer (not object) array,
# even when every sequence happens to have the same length.
encodings = np.zeros((len(sequences), max_seq_len), dtype=np.int64)
for row, seq in enumerate(sequences):
    encodings[row, :len(seq)] = seq

del sequences  # the unpadded lists are no longer needed

In [59]:
# Step 1: Split the data into train (70%), validation (15%) and test (15%) sets
X_train, X_temp, y_train, y_temp = train_test_split(
    encodings,
    labels,
    test_size=0.3,
    random_state=42,
)
# The 30% hold-out is halved into validation and test portions.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=0.5,
    random_state=42,
)

train_dataset, val_dataset, test_dataset = (
    VirusDataset(features, targets)
    for features, targets in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)

In [60]:
# Step 2: Create an EarlyStopping class
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the current loss. A loss counts as
    an improvement when it is lower than the best value seen so far by more
    than ``min_delta``. After ``patience`` consecutive calls without an
    improvement, ``stop`` is set to ``True``.
    """

    def __init__(self, patience=5, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = None
        self.counter = 0
        self.stop = False

    def __call__(self, val_loss):
        # The very first value only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        improved = (self.best_loss - val_loss) > self.min_delta
        if improved:
            self.best_loss, self.counter = val_loss, 0
            return

        self.counter += 1
        print(f"Early stopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.stop = True

In [61]:
# Initialize model, loss, and optimizer
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = ConvLSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    kernel_size=kernel_size,
    num_layers=num_layers,
    dropout=dropout,
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [62]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Scale the learning rate by 0.1 when the value passed to scheduler.step()
# stops decreasing (mode='min', patience of 5 steps).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# against the installed version.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
    threshold=0.00001,
    verbose=True,
)

In [63]:
# Step 3: Train the model, using the EarlyStopping class in the training loop
def _mean_loss(model, dataloader, criterion):
    """Return the sample-weighted mean of ``criterion`` over ``dataloader``.

    Runs in evaluation mode without gradients. The loss is computed from the
    model's per-class outputs; arg-max class labels cannot be used for this
    because ``nn.CrossEntropyLoss`` needs one score per class. Each batch loss
    is weighted by its batch size, which gives the per-sample mean for a
    criterion with mean reduction (the ``nn.CrossEntropyLoss`` default).
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            n_samples = targets.size(0)
            total_loss += criterion(model(inputs), targets).item() * n_samples
            total_samples += n_samples
    return total_loss / max(total_samples, 1)


early_stopping = EarlyStopping(patience=8, min_delta=0.001)

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_dataloader, criterion, optimizer)
    val_loss = _mean_loss(model, val_dataloader, criterion)

    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    # NOTE(review): the scheduler tracks the training loss while early stopping
    # tracks the validation loss - confirm this split is intended.
    scheduler.step(train_loss)
    early_stopping(val_loss)
    if early_stopping.stop:
        print("Early stopping triggered")
        break

RuntimeError: CUDA out of memory. Tried to allocate 7.94 GiB (GPU 0; 23.69 GiB total capacity; 21.52 GiB already allocated; 420.69 MiB free; 22.01 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Persist only the model's state_dict (its parameters and buffers).
torch.save(obj=model.state_dict(), f="virus_classification_model_0.001.pt")

In [None]:
# Evaluate model
# One pass over the test set that keeps the per-class outputs: the loss needs
# one score per class, while the other metrics need arg-max class labels.
model.eval()
logit_batches, target_batches = [], []
with torch.no_grad():
    for inputs, targets in test_dataloader:
        logit_batches.append(model(inputs.to(device)).cpu())
        target_batches.append(targets)
test_logits = torch.cat(logit_batches)
test_targets = torch.cat(target_batches)

y_pred = test_logits.argmax(dim=1).numpy()
y_true = test_targets.numpy()

# Metrics
test_loss = criterion(test_logits, test_targets).item()
test_acc = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average='macro')
recall = recall_score(y_true, y_pred, average='macro')
f_score = f1_score(y_true, y_pred, average='macro')

# Macro specificity = mean over classes of TN / (TN + FP), taken from the
# confusion matrix (rows: true class, columns: predicted class).
# recall_score(..., average='macro', pos_label=0) cannot provide this because
# pos_label is ignored unless average='binary', so it just repeats the recall.
class_counts = confusion_matrix(y_true, y_pred)
true_pos = np.diag(class_counts)
false_pos = class_counts.sum(axis=0) - true_pos
false_neg = class_counts.sum(axis=1) - true_pos
true_neg = class_counts.sum() - true_pos - false_pos - false_neg
specificity = float(np.mean(true_neg / (true_neg + false_pos)))

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Acc: {test_acc:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"F-score: {f_score:.4f}")

In [None]:
# AUC scores
# Class probabilities have to come from the model's per-class outputs; the
# arg-max labels in y_pred are one-dimensional and carry no scores to rank.
# test_dataloader does not shuffle, so the rows line up with y_true.
model.eval()
prob_batches = []
with torch.no_grad():
    for inputs, _ in test_dataloader:
        batch_outputs = model(inputs.to(device))
        prob_batches.append(torch.softmax(batch_outputs, dim=1).cpu())
y_pred_prob = torch.cat(prob_batches).numpy()

for i in range(output_size):
    auc = roc_auc_score(np.array(y_true) == i, y_pred_prob[:, i])
    print(f"AUC score (Class {i} vs. Other Classes): {auc:.4f}")

In [None]:
# Confusion matrix of true labels against predicted labels
cm = confusion_matrix(y_true=y_true, y_pred=y_pred)
print("Confusion matrix:\n", cm)

In [None]:
# ROC curve: one line per virus class
plot_roc_curve(y_true=y_true, y_pred_prob=y_pred_prob, num_classes=output_size)

In [None]:
import seaborn as sns

def plot_confusion_matrix(cm, labels):
    """Show ``cm`` as an annotated heatmap with ``labels`` on both axes."""
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    plt.show()

# Label both heatmap axes with the virus file names.
plot_confusion_matrix(cm=cm, labels=fasta_filenames)

In [None]:
# One-vs-rest AUC for each class, computed from the y_pred_prob columns.
true_labels = np.array(y_true)
for i in range(output_size):
    is_class_i = true_labels == i
    auc = roc_auc_score(is_class_i, y_pred_prob[:, i])
    print(f"AUC score (Class {i} vs. Other Classes): {auc:.4f}")

In [55]:
import torch
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()