# Transcription Factor Binding Prediction via LSTMs

In [1]:
# Load libraries
import os
import torch
import torch.nn as nn
import torch.utils.data as du
import joblib
import torch.optim as optim

In [2]:
# Compute dataset class
class JUND_Dataset(du.Dataset):
    def __init__(self, data_dir):
        super(JUND_Dataset, self).__init__()
        self.X = torch.tensor(joblib.load(f'{data_dir}/shard-0-X.joblib'), dtype=torch.float32)
        self.y = torch.tensor(joblib.load(f'{data_dir}/shard-0-y.joblib'), dtype=torch.float32)
        self.w = torch.tensor(joblib.load(f'{data_dir}/shard-0-w.joblib'), dtype=torch.float32)
        self.a = torch.tensor(joblib.load(f'{data_dir}/shard-0-a.joblib'), dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.w[idx], self.a[idx]

In [3]:
# Load dataset
train_dataset = JUND_Dataset('/Users/christophertarkaa/Bioinformatics/Machine Learning/TF_data/train_dataset')
valid_dataset = JUND_Dataset('/Users/christophertarkaa/Bioinformatics/Machine Learning/TF_data/valid_dataset')
test_dataset = JUND_Dataset('/Users/christophertarkaa/Bioinformatics/Machine Learning/TF_data/test_dataset')

In [4]:
# Define LSTM model
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim

        # LSTM layer
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)

        # Fully connected layers
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim + 1, output_dim)

    def forward(self, x, a):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).to(x.device)
        # Initialize cell state with zeros
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).to(x.device)

        # LSTM forward pass
        out, _ = self.lstm(x, (h0, c0))
        
        # We only want the last output of the sequence
        out = out[:, -1, :]
        
        # Pass the last output through the first fully connected layer
        out = self.fc1(out)
        
        # Concatenate the output with the additional feature 'a'
        out = torch.cat((out, a.view(-1, 1)), dim=1)
        
        # Pass the combined features through the second fully connected layer
        out = self.fc2(out)
        
        return out

In [5]:
# Hyperparameters
input_size = 4
hidden_size = 128    
lstm_layers = 1      
output_size = 1      
learning_rate = 0.001
batch_size = 20
epochs = 10

In [6]:
# Initialize model and optimizer
device = torch.device("cpu")  # Set device to CPU (can be changed to 'cuda' for GPU)
model = LSTMModel(input_size, hidden_size, lstm_layers, output_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()  # Assuming a regression problem


In [7]:
# Load training data in batches
train_loader = du.DataLoader(dataset=train_dataset, 
                             batch_size=batch_size, 
                             shuffle=True)

In [8]:
# Send model to device and set to training mode
model.train()


LSTMModel(
  (lstm): LSTM(4, 128, batch_first=True)
  (fc1): Linear(in_features=128, out_features=128, bias=True)
  (fc2): Linear(in_features=129, out_features=1, bias=True)
)

In [9]:
# Training loop
for epoch in range(epochs):
    for batch_idx, (data, additional_feature, target, loss_tensor) in enumerate(train_loader):
        data, additional_feature, target = data.to(device), additional_feature.to(device), target.to(device)
        
        optimizer.zero_grad()  # Zero the gradients
        output = model(data, additional_feature)  # Forward pass
        loss = criterion(output, target)  # Calculate loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        
        if batch_idx % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], Loss: {loss.item():.4f}')


Epoch [1/10], Batch [0], Loss: 0.2789
Epoch [1/10], Batch [10], Loss: 0.0275
Epoch [1/10], Batch [20], Loss: 0.0067
Epoch [1/10], Batch [30], Loss: 0.0422
Epoch [1/10], Batch [40], Loss: 0.1185
Epoch [1/10], Batch [50], Loss: 0.2996
Epoch [1/10], Batch [60], Loss: 0.1738
Epoch [1/10], Batch [70], Loss: 0.1528
Epoch [1/10], Batch [80], Loss: 0.0229
Epoch [1/10], Batch [90], Loss: 0.0024
Epoch [1/10], Batch [100], Loss: 0.0023
Epoch [1/10], Batch [110], Loss: 0.0010
Epoch [1/10], Batch [120], Loss: 0.1022
Epoch [1/10], Batch [130], Loss: 0.0821
Epoch [1/10], Batch [140], Loss: 0.3605
Epoch [1/10], Batch [150], Loss: 0.2456
Epoch [1/10], Batch [160], Loss: 0.3217
Epoch [1/10], Batch [170], Loss: 1.5837
Epoch [1/10], Batch [180], Loss: 0.4761
Epoch [1/10], Batch [190], Loss: 0.0707
Epoch [1/10], Batch [200], Loss: 0.0361
Epoch [1/10], Batch [210], Loss: 0.0240
Epoch [1/10], Batch [220], Loss: 0.0098
Epoch [1/10], Batch [230], Loss: 0.0117
Epoch [1/10], Batch [240], Loss: 0.1388
Epoch [1/10

In [10]:
# Save the model after training
torch.save(model.state_dict(), 'lstm_model.pth')

In [11]:
# Load validation data in batches
valid_loader = du.DataLoader(dataset=valid_dataset, 
                             batch_size=batch_size, 
                             shuffle=True)

In [13]:
# Evaluation loop
model.eval()
valid_loss = 0.0
total_samples = 0
total_weighted_correct = 0.0

with torch.no_grad():
    for data, target, weight, additional_feature in valid_loader:
        data, target, weight, additional_feature = data.to(device), target.to(device), weight.to(device), additional_feature.to(device)
        
        output = model(data, additional_feature)
        
        # Compute the weighted MSE loss manually
        mse_loss = (weight * (output - target) ** 2).mean()
        test_loss += mse_loss.item() * data.size(0)
        
        # Compute weighted prediction accuracy (assuming binary classification)
        predicted = (output > 0).float()
        correct = (predicted == target).float() * weight
        total_weighted_correct += correct.sum().item()
        total_samples += weight.sum().item()
    
    # Calculate average loss and accuracy
    valid_loss /= total_samples
    valid_weighted_accuracy = total_weighted_correct / total_samples

    print(f"Validation Weighted MSE Loss: {valid_loss:.4f}")
    print(f"Validation Weighted Accuracy: {valid_weighted_accuracy:.4f}") 

Validation Weighted MSE Loss: 0.0000
Validation Weighted Accuracy: 0.5000


In [14]:
# Load test data in batches
test_loader = du.DataLoader(dataset=test_dataset, 
                             batch_size=batch_size, 
                             shuffle=False)

In [15]:
# Load test data in batches
test_loader = du.DataLoader(dataset=test_dataset, 
                            batch_size=batch_size, 
                            shuffle=False)

# Set model to evaluation mode
model.eval()

# Disable gradient calculation
with torch.no_grad():
    test_loss = 0.0
    test_correct = 0
    test_total_weight = 0.0

    for data, target, weight, accessibility in test_loader:  # Unpacking four values
        data, target, weight, accessibility = data.to(device), target.to(device), weight.to(device), accessibility.to(device)
        
        # Forward pass
        output = model(data, accessibility)
        
        # Calculate loss
        loss = criterion(output, target)
        test_loss += loss.item()
        
        # Compute weighted prediction accuracy
        predicted = (output > 0).float()
        test_correct += (predicted == target).sum().item()
        test_total_weight += weight.sum().item()  # Use weight.sum() as before for total weight
        
    # Calculate and print the average loss and accuracy
    test_loss /= len(test_loader.dataset)
    test_accuracy = test_correct / test_total_weight
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")


Test Loss: 0.0508, Test Accuracy: 0.0042
