In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from Bio import SeqIO
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from tqdm import tqdm


# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise CPU.
# The MPS check only runs when CUDA is unavailable.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {device}")

Using device: mps


In [2]:
# Function to read sequences from a FASTA file
def read_fasta(file_path):
    try:
        sequences = []
        for record in SeqIO.parse(file_path, "fasta"):
            sequences.append(str(record.seq))
        return sequences
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []

# Read both labelled FASTA files (read_fasta returns [] for a file it cannot parse).
positive_pairs, negative_pairs = (
    read_fasta('data/positive_pairs.fasta'),
    read_fasta('data/negative_pairs.fasta'),
)

for pair_kind, pair_seqs in (("positive", positive_pairs), ("negative", negative_pairs)):
    print(f"Number of {pair_kind} sequences: {len(pair_seqs)}")

Number of positive sequences: 1047308
Number of negative sequences: 6679


In [3]:
# Accepted residue alphabet: the 20 standard amino acids plus the non-standard
# codes B, U, X, Z and the '_' character.
# NOTE(review): '_' presumably separates the two members of a pair — confirm.
# Built once here instead of on every call (this check runs once per sequence).
VALID_AMINO_ACIDS = frozenset('ABCDEFGHIKLMNPQRSTUVWYXZ_')


def is_valid_sequence(sequence, valid_vocab=VALID_AMINO_ACIDS):
    """Check whether every character of ``sequence`` belongs to ``valid_vocab``.

    Args:
        sequence: Residue string to check.
        valid_vocab: Iterable of allowed characters; defaults to ``VALID_AMINO_ACIDS``.

    Returns:
        ``(is_valid, invalid_chars)``: ``is_valid`` is True when no character falls
        outside ``valid_vocab``; ``invalid_chars`` is the set of offending
        characters (empty when valid). An empty sequence is valid.
    """
    # One pass over the sequence: the offending characters determine validity.
    invalid_chars = set(sequence).difference(valid_vocab)
    return not invalid_chars, invalid_chars

def _filter_valid_sequences(sequences):
    """Return the sequences that pass ``is_valid_sequence``, reporting each one dropped."""
    kept = []
    for seq in sequences:
        is_valid, invalid_chars = is_valid_sequence(seq)
        if is_valid:
            kept.append(seq)
        else:
            # Sorted so the reported characters appear in a stable order across runs.
            print(f"Sequence {seq} contains unknown amino acids: {', '.join(sorted(invalid_chars))} and will be excluded.")
    return kept


# Drop any sequence containing characters outside the accepted alphabet.
filtered_positive = _filter_valid_sequences(positive_pairs)
filtered_negative = _filter_valid_sequences(negative_pairs)

print(f"Number of filtered positive sequences: {len(filtered_positive)}")
print(f"Number of filtered negative sequences: {len(filtered_negative)}")

Number of filtered positive sequences: 1047308
Number of filtered negative sequences: 6679


In [4]:
# Create labels
positive_labels = [1] * len(filtered_positive)
negative_labels = [0] * len(filtered_negative)

# Combine sequences and labels
all_sequences = filtered_positive + filtered_negative
all_labels = positive_labels + negative_labels

In [5]:
# Vocabulary: every distinct residue character present in the combined data.
all_letters = set(''.join(all_sequences))
print(f'Unique amino acids: {all_letters}')

# Index 0 is reserved for padding, so residues are numbered from 1 in sorted order.
letter_to_int = dict(zip(sorted(all_letters), range(1, len(all_letters) + 1)))
letter_to_int['<PAD>'] = 0  # Padding token

print(f'Letter to integer mapping: {letter_to_int}')

Unique amino acids: {'E', 'P', 'Z', 'C', 'N', 'Q', '_', 'G', 'W', 'A', 'S', 'T', 'U', 'D', 'Y', 'I', 'K', 'L', 'V', 'X', 'M', 'B', 'H', 'F', 'R'}
Letter to integer mapping: {'A': 1, 'B': 2, 'C': 3, 'D': 4, 'E': 5, 'F': 6, 'G': 7, 'H': 8, 'I': 9, 'K': 10, 'L': 11, 'M': 12, 'N': 13, 'P': 14, 'Q': 15, 'R': 16, 'S': 17, 'T': 18, 'U': 19, 'V': 20, 'W': 21, 'X': 22, 'Y': 23, 'Z': 24, '_': 25, '<PAD>': 0}


In [6]:
# Encode each sequence as a LongTensor of vocabulary indices (padding happens later).
# Labels are collected in the same loop: a sequence skipped on an unmapped character
# must also drop its label, otherwise every later sequence/label pair is misaligned.
encoded_sequences = []
encoded_labels = []
for seq, label in zip(all_sequences, all_labels):
    try:
        encoded_seq = [letter_to_int[aa] for aa in seq]
    except KeyError as e:
        print(f"Error encoding sequence: {seq}. Unmapped character {e}.")
        continue
    encoded_sequences.append(torch.tensor(encoded_seq, dtype=torch.long))
    encoded_labels.append(label)

all_labels_tensor = torch.tensor(encoded_labels, dtype=torch.long)
assert len(encoded_sequences) == len(all_labels_tensor), "sequence/label misalignment"

In [7]:
# Fixed padded width; longer sequences are dropped entirely (not truncated).
max_seq_len = 3000

# Keep only the (sequence, label) pairs whose encoded length fits within max_seq_len.
filtered_sequences_with_labels = list(
    filter(
        lambda pair: len(pair[0]) <= max_seq_len,
        zip(encoded_sequences, all_labels_tensor),
    )
)

print(f"Number of sequences after filtering: {len(filtered_sequences_with_labels)}")

# Right-pad a batch of encoded sequences to a fixed width and stack their labels.
def batch_pad_with_labels(batch, max_length, padding_value):
    """Right-pad a batch of 1-D token tensors to ``max_length``.

    Args:
        batch: Sequence of ``(seq, label)`` pairs; ``seq`` is a 1-D tensor of token
            ids and ``label`` is an int or a 0-d tensor.
        max_length: Length of every padded row.
        padding_value: Token id written into the padded positions.

    Returns:
        ``(padded, labels)``: a ``(len(batch), max_length)`` long tensor and a
        ``(len(batch),)`` long tensor. An empty batch yields zero-row tensors.

    Raises:
        ValueError: if a sequence is longer than ``max_length``.
    """
    # Preallocate the whole padded block instead of building a padding tensor per
    # row and cat/stack-ing them (which also fails on an empty batch).
    padded = torch.full((len(batch), max_length), padding_value, dtype=torch.long)
    labels = torch.empty(len(batch), dtype=torch.long)
    for row, (seq, label) in enumerate(batch):
        if len(seq) > max_length:
            # Previously surfaced as an obscure negative-dimension error from torch.full.
            raise ValueError(
                f"Sequence of length {len(seq)} exceeds max_length={max_length}"
            )
        padded[row, :len(seq)] = seq
        labels[row] = int(label)
    return padded, labels

# Pad every kept sequence to max_seq_len, batch_size rows at a time.
# The outputs are preallocated and filled in place. Collecting every padded batch and
# then torch.cat-ing them held two full copies of the padded data at peak.
# NOTE(review): even one copy is large (N x max_seq_len int64, ~24 GB for the
# 1,020,444 x 3000 shape printed below); consider padding per batch in a
# DataLoader collate_fn instead.
batch_size = 64
num_sequences = len(filtered_sequences_with_labels)
padded_sequences = torch.empty((num_sequences, max_seq_len), dtype=torch.long)
labels = torch.empty(num_sequences, dtype=torch.long)
for start in range(0, num_sequences, batch_size):
    batch = filtered_sequences_with_labels[start:start + batch_size]
    stop = start + len(batch)
    padded_batch, label_batch = batch_pad_with_labels(batch, max_seq_len, letter_to_int['<PAD>'])
    padded_sequences[start:stop] = padded_batch
    labels[start:stop] = label_batch

print(f"Padded sequences shape: {padded_sequences.shape}")
print(f"Labels shape: {labels.shape}")

Number of sequences after filtering: 1020444
Padded sequences shape: torch.Size([1020444, 3000])
Labels shape: torch.Size([1020444])


In [8]:
# Create a simple Dataset class
class SequenceDataset(Dataset):
    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]

# Stratified 70/15/15 split performed on row indices rather than on the data itself.
indices = np.arange(len(labels))
labels_np = labels.cpu().numpy()  # numpy copy of the labels for sklearn's `stratify`

train_idx, temp_idx = train_test_split(
    indices,
    test_size=0.3,
    random_state=42,
    stratify=labels_np
)

# Split the held-out 30% evenly into validation and test, again stratified.
val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=0.5,
    random_state=42,
    stratify=labels_np[temp_idx]
)

# Wrap the full tensors once and expose each split as a Subset view.
# Fancy-indexing (`padded_sequences[train_idx]`) would materialise a copy of every
# split, roughly doubling peak memory for the already very large padded tensor.
full_dataset = SequenceDataset(padded_sequences, labels)
train_dataset = Subset(full_dataset, train_idx.tolist())
val_dataset = Subset(full_dataset, val_idx.tolist())
test_dataset = Subset(full_dataset, test_idx.tolist())

# Create data loaders; the seeded generator makes the training shuffle reproducible.
batch_size = 32  # Adjust this based on your memory constraints
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Print summary
print("\nData split summary:")
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

# Calculate class distribution for each set
def get_class_distribution(dataset):
    pos = sum(label.item() == 1 for _, label in dataset)
    total = len(dataset)
    return pos, total - pos

train_pos, train_neg = get_class_distribution(train_dataset)
val_pos, val_neg = get_class_distribution(val_dataset)
test_pos, test_neg = get_class_distribution(test_dataset)

# Report each split's class balance as counts and percentages of the split size.
for split_title, split_dataset, n_pos, n_neg in (
    ("Training", train_dataset, train_pos, train_neg),
    ("Validation", val_dataset, val_pos, val_neg),
    ("Test", test_dataset, test_pos, test_neg),
):
    split_size = len(split_dataset)
    print(f"\n{split_title} set distribution:")
    print(f"  Positive: {n_pos} ({n_pos/split_size*100:.1f}%)")
    print(f"  Negative: {n_neg} ({n_neg/split_size*100:.1f}%)")

: 

In [None]:
# Summarise the DataLoaders built in the data-split cell. They are reused rather than
# rebuilt here: the X_train/y_train/X_val/y_val/X_test/y_test tensors a TensorDataset
# would need are never defined in this notebook, so rebuilding failed on a fresh kernel.
print("DataLoaders created:")
print(f"- Training batches: {len(train_loader)}")
print(f"- Validation batches: {len(val_loader)}")
print(f"- Test batches: {len(test_loader)}")

In [None]:
# Per-class counts over the TRAINING split only, so validation/test labels do not
# influence the loss weighting. Uses `labels_np` and `train_idx` from the split cell
# (y_train is never defined in this notebook).
train_labels_np = labels_np[train_idx]
class_counts = np.bincount(train_labels_np, minlength=2)  # index 0 = negative, 1 = positive
print(f"Class counts: {class_counts}")

# A class absent from the training split would get an infinite weight (1/0).
if (class_counts == 0).any():
    raise ValueError(f"Training split is missing a class; counts = {class_counts}")

# Inverse-frequency weights: the rarer class contributes more per sample.
class_weights = 1. / class_counts
print(f"Class weights: {class_weights}")

# float32 on the training device, as CrossEntropyLoss expects for its `weight`.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)
print(f"Class weights tensor: {class_weights_tensor}")

# Apply weights in the loss function
criterion = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)