# TRANSCRIPTION FACTOR BINDING PREDICTION USING MULTILAYER PERCEPTRON

#### Load packages

In [56]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as du
import joblib

#### Compute dataset class

In [57]:
class JUND_Dataset(du.Dataset):
    def __init__(self, data_dir):
        super(JUND_Dataset, self).__init__()
        self.X = torch.tensor(joblib.load(f'{data_dir}/shard-0-X.joblib'), dtype=torch.float32)
        self.y = torch.tensor(joblib.load(f'{data_dir}/shard-0-y.joblib'), dtype=torch.float32)
        self.w = torch.tensor(joblib.load(f'{data_dir}/shard-0-w.joblib'), dtype=torch.float32)
        self.a = torch.tensor(joblib.load(f'{data_dir}/shard-0-a.joblib'), dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.w[idx], self.a[idx]


#### Loading Datasets

In [58]:
train_dataset = JUND_Dataset('/Users/christophertarkaa/Bioinformatics/Machine Learning/TF_data/train_dataset')
valid_dataset = JUND_Dataset('/Users/christophertarkaa/Bioinformatics/Machine Learning/TF_data/valid_dataset')
test_dataset = JUND_Dataset('/Users/christophertarkaa/Bioinformatics/Machine Learning/TF_data/test_dataset')

#### Define MLP model

In [59]:
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        '''in_dim: input layer dim
            hidden_dim: hidden layer dim
            out_dim: output layer dim'''
        
        super(MLP, self).__init__()
    
        # nucleotides are 101 x 4 so flatten them into 404d vector
        self.flatten = nn.Flatten()

        #two fully connected layers
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim + 1, out_dim) 
    
    def forward(self, x, a):
        # since x is 104 x 4, flatten it first
        x = self.flatten(x)

        # compute output of fc1, and apply relu activation
        x = nn.functional.relu(self.fc1(x))
        a = a.view(-1, 1)
        # compute output layer
        x = torch.cat((x, a), dim=1)
        x = self.fc2(x)
        return x

#### Set up training

In [60]:
# Hyperparameters
input_size = 101*4 
hidden_size = 128  
output_size = 1      

learning_rate = 0.001
batch_size = 20
epochs = 5

# Initialize model and optimizer
device = torch.device("cpu")  # Explicitly set to CPU
model = MLP(input_size, hidden_size, output_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# set model and optimizer
model = MLP(101*4, 128, 1)
optimizer =optim.Adam(model.parameters(), lr=learning_rate)

# Load training data in batches
train_loader = du.DataLoader(dataset=train_dataset, 
                             batch_size=batch_size, 
                             shuffle=True)

# Send model over to device
model = model.to(device)
model.train()

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=404, out_features=128, bias=True)
  (fc2): Linear(in_features=129, out_features=1, bias=True)
)

#### Training Loop over batches

In [61]:
for epoch in range(1, epochs + 1):
    sum_loss = 0.
    for batch_idx, (data, target, weight, accessibility) in enumerate(train_loader):
        # send batch over to device
        data, target, weight, accessibility = data.to(device), target.to(device), weight.to(device), accessibility.to(device)

        # zero out prev gradients
        optimizer.zero_grad()

        # run the forward pass
        output = model(data, accessibility)

        # compute loss/error
        loss = nn.BCEWithLogitsLoss(weight=weight)(output, target)

        # sum up batch losses
        sum_loss += loss.item()

        # compute gradients and take a step
        loss.backward()
        optimizer.step()

    # average loss per example
    sum_loss /= len(train_loader.dataset)
    print(f'Epoch: {epoch}, Loss: {sum_loss:.6f}')

    # average loss per batch
    avg_loss = sum_loss / len(train_loader)
    print(f"Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.6f}")



Epoch: 1, Loss: 0.038957
Epoch [1/5] - Train Loss: 0.000003
Epoch: 2, Loss: 0.039468
Epoch [2/5] - Train Loss: 0.000003
Epoch: 3, Loss: 0.036321
Epoch [3/5] - Train Loss: 0.000003
Epoch: 4, Loss: 0.034099
Epoch [4/5] - Train Loss: 0.000002
Epoch: 5, Loss: 0.033095
Epoch [5/5] - Train Loss: 0.000002


#### Evaluate on validation data

In [75]:
# Load validation data in batches
valid_loader = du.DataLoader(dataset=valid_dataset, 
                             batch_size=batch_size, 
                             shuffle=True)
# Initialize correct beffore the loop
correct = 0
# Set model to evaluation mode
model.eval()
with torch.no_grad():
    valid_loss = 0.0
    total_weight = 0.0
    for data, target, weight, accessibility in valid_loader:
        data, target, weight, accessibility = data.to(device), target.to(device), weight.to(device), accessibility.to(device)
        output = model(data, accessibility)
        valid_loss += nn.BCEWithLogitsLoss(weight=weight)(output, target).item()
            
        # Compute weighted prediction accuracy
        predicted = (output > 0).float()
        correct += (predicted == target).sum().item()
        total_weight += weight.sum().item()
    
    # Compute weighted accuracy 
    valid_accuracy = correct / total_weight
    print(f"Epoch [{epoch}/{epochs}] - Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}")

Epoch [5/5] - Valid Loss: 1188.5757, Valid Accuracy: 0.6923


#### Testing

In [77]:
# Load validation data in batches
test_loader = du.DataLoader(dataset=test_dataset, 
                             batch_size=batch_size, 
                             shuffle=True)
model.eval()
with torch.no_grad():
    test_correct = 0
    test_total_weight = 0.0

    for data, target, weight, accessibility in test_loader:
        data, target, weight, accessibility = data.to(device), target.to(device), weight.to(device), accessibility.to(device)
        output = model(data, accessibility)
        predicted = (output > 0).float()
        test_correct += (predicted == target).sum().item()
        test_total_weight += weight.sum().item()

    test_weighted_accuracy = test_correct / test_total_weight
    print(f"Test Weighted Accuracy: {test_weighted_accuracy:.4f}")

Test Weighted Accuracy: 0.6978
