In [None]:
# Colab-only setup: mount Google Drive at /content/drive so later cells can
# read files stored under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import sys
# Add the project folder on Drive to the import path.
# NOTE(review): presumably this is where the `source` package imported in the
# next cell lives -- confirm against the Drive layout.
sys.path.append('/content/drive/MyDrive/BioInf-Final/')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import numpy as np
import time
import random

from source.dataset import DataSet, prepare_interaction_pairs

def concordance_index(y_true, y_pred):
    """Compute the concordance index (CI) of predictions against true values.

    Every pair of samples whose true values differ is "comparable". A
    comparable pair scores 1 when the predictions are ordered the same way as
    the true values, 0.5 when the two predictions are tied, and 0 otherwise.
    The CI is the mean score over all comparable pairs.

    Args:
        y_true: Sequence (list or array) of ground-truth values.
        y_pred: Sequence of predicted values, same length as ``y_true``.

    Returns:
        float: CI in ``[0, 1]``. Returns ``0.0`` when there are no comparable
        pairs (fewer than two samples, or all true values equal) instead of
        dividing by zero.

    Raises:
        ValueError: If the two inputs do not contain the same number of values.
    """
    true_arr = np.asarray(y_true, dtype=float).ravel()
    pred_arr = np.asarray(y_pred, dtype=float).ravel()
    if true_arr.shape != pred_arr.shape:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {true_arr.size} and {pred_arr.size}"
        )

    comparable = 0
    score = 0.0
    # Compare sample i against all later samples at once: O(n) extra memory
    # instead of a materialised list of O(n^2) index pairs.
    for i in range(true_arr.size - 1):
        rest_true = true_arr[i + 1:]
        rest_pred = pred_arr[i + 1:]
        usable = rest_true != true_arr[i]  # pairs with equal labels are skipped
        comparable += int(np.count_nonzero(usable))

        true_sign = np.sign(true_arr[i] - rest_true[usable])
        pred_sign = np.sign(pred_arr[i] - rest_pred[usable])
        agreement = true_sign * pred_sign  # +1 concordant, 0 tied prediction, -1 discordant
        score += float(np.count_nonzero(agreement > 0))
        score += 0.5 * float(np.count_nonzero(agreement == 0))

    if comparable == 0:
        return 0.0
    return score / comparable

def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    Batches are dicts with the keys ``'drug'``, ``'target'`` and ``'affinity'``.

    The loss is accumulated weighted by batch size, so the reported value is
    the mean over samples rather than a mean of per-batch means (which would
    over-weight a smaller final batch).
    NOTE(review): this assumes ``criterion`` returns a per-batch mean, as
    ``nn.MSELoss()`` does with its default reduction -- confirm for other
    criteria.

    Args:
        model: Callable module invoked as ``model(drugs, proteins)``.
        loader: Iterable of batch dicts.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: ``(mse, ci)`` -- the sample-weighted mean loss and the
        concordance index of all predictions against all labels.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    weighted_loss = 0.0
    n_samples = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in loader:
            drugs = batch['drug'].to(device)
            proteins = batch['target'].to(device)
            labels = batch['affinity'].to(device)

            outputs = model(drugs, proteins)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            weighted_loss += loss.item() * batch_size
            n_samples += batch_size

            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if n_samples == 0:
        raise ValueError("evaluate() received an empty loader")

    mse = weighted_loss / n_samples
    ci = concordance_index(all_labels, all_preds)
    return mse, ci

def unite_lists(list_of_lists):
    """Flatten one level of nesting: concatenate the sub-iterables into one list."""
    flattened = []
    for sublist in list_of_lists:
        flattened.extend(sublist)
    return flattened

def set_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    # PyTorch generators: CPU, current CUDA device, and all CUDA devices.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    np.random.seed(seed)
    random.seed(seed)

    # Deterministic cuDNN kernels, with auto-tuning (benchmark) disabled.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

from torch.utils.data import DataLoader

def create_reproducible_dataloader(dataset, batch_size, shuffle=True, seed=42, num_workers=0):
    """Build a ``DataLoader`` whose shuffling is driven by a seeded generator.

    Args:
        dataset: Dataset to draw batches from.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle the samples.
        seed: Seed for the ``torch.Generator`` handed to the loader.
        num_workers: Number of loader worker processes. The default of 0 keeps
            loading in the main process; with a positive value each worker
            runs ``seed_worker`` to seed its own NumPy / ``random`` state.

    Returns:
        DataLoader: Loader configured with the seeded generator.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        generator=generator,
        worker_init_fn=seed_worker,  # only invoked when num_workers > 0
    )

def seed_worker(worker_id):
    """Seed NumPy and ``random`` from the current PyTorch initial seed.

    Args:
        worker_id: Worker index passed by the caller; not used.
    """
    # Keep only the low 32 bits of the torch seed (same as ``% 2**32``).
    seed_32bit = torch.initial_seed() & 0xFFFFFFFF
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed_32bit)

def run_experiment(model_class, config, train_dataset, test_dataset, dataset='davis', num_epochs=100):
    """Train ``model_class(config)`` and track its best test-set metrics.

    Each epoch trains on ``train_dataset`` with Adam and MSE loss, then
    evaluates on ``test_dataset``. Whenever the test MSE improves, the model
    weights are saved to ``{model_class.__name__}_{dataset}_best_model.pt``
    in the current working directory.

    NOTE(review): the best epoch is selected on the test set itself, so the
    returned numbers are optimistic; consider a separate validation split.

    Args:
        model_class: ``nn.Module`` subclass constructed as ``model_class(config)``.
        config: Dict providing at least ``'batch_size'`` and ``'learning_rate'``
            (plus whatever ``model_class`` reads).
        train_dataset: Dataset yielding dicts with ``'drug'``, ``'target'`` and
            ``'affinity'`` entries.
        test_dataset: Dataset with the same structure, used for evaluation.
        dataset: Name used only to label the checkpoint file. It defaults to
            ``'davis'`` so that calls which omit it keep working.
        num_epochs: Number of training epochs.

    Returns:
        tuple: ``(best_test_mse, best_test_ci)`` -- the lowest test MSE seen and
        the concordance index measured at that same epoch.
    """
    set_seeds(seed=42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data loaders: shuffle only the training data.
    train_loader = create_reproducible_dataloader(train_dataset, batch_size=config['batch_size'])
    test_loader = create_reproducible_dataloader(test_dataset, batch_size=config['batch_size'], shuffle=False)

    model = model_class(config).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    best_test_mse = float('inf')
    best_test_ci = 0
    checkpoint_path = f'{model_class.__name__}_{dataset}_best_model.pt'

    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        # ---- training pass ----
        model.train()
        train_loss = 0
        for batch in train_loader:
            drugs = batch['drug'].to(device)
            proteins = batch['target'].to(device)
            labels = batch['affinity'].to(device)

            optimizer.zero_grad()
            outputs = model(drugs, proteins)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # ---- evaluation pass (shared helper instead of a duplicated loop) ----
        test_loss, test_ci = evaluate(model, test_loader, criterion, device)

        if test_loss < best_test_mse:
            best_test_mse = test_loss
            best_test_ci = test_ci
            torch.save(model.state_dict(), checkpoint_path)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Test MSE: {test_loss:.4f}, Test CI: {test_ci:.4f}")

    print(f"Best Test MSE: {best_test_mse:.4f}, Best Test CI: {best_test_ci:.4f}")
    return best_test_mse, best_test_ci

In [None]:
class DTIDataset(Dataset):
    """Map-style dataset of (drug, target, affinity) triples.

    The constructor stores ``drugs`` and ``targets`` as ``torch.LongTensor``
    and ``affinity`` as ``torch.FloatTensor``. Indexing returns a dict with
    the keys ``'drug'``, ``'target'`` and ``'affinity'``.
    """

    def __init__(self, drugs, targets, affinity):
        # NOTE(review): drugs/targets are presumably integer-encoded sequences
        # of equal length per sample -- confirm against prepare_interaction_pairs.
        self.drugs, self.targets = torch.LongTensor(drugs), torch.LongTensor(targets)
        self.affinity = torch.FloatTensor(affinity)

    def __len__(self):
        """Number of samples, i.e. the length of the affinity tensor."""
        return len(self.affinity)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors."""
        fields = (
            ('drug', self.drugs),
            ('target', self.targets),
            ('affinity', self.affinity),
        )
        return {key: tensor[idx] for key, tensor in fields}

def get_datasets(dataset_type, config):
    """Load a benchmark and wrap its train/test interaction pairs as datasets.

    Args:
        dataset_type: ``'davis'`` selects the Davis settings; any other value
            selects the KIBA settings.
        config: Dict with ``'problem_type'``, ``'max_seq_len'``,
            ``'max_smi_len'`` and, for the chosen benchmark, the
            ``'<name>_path'`` and ``'<name>_convert_to_log'`` entries.

    Returns:
        tuple: ``(train_dataset, test_dataset)`` as ``DTIDataset`` instances.
    """
    if dataset_type == 'davis':
        data_path, dataset_name = config['davis_path'], 'davis'
        convert_to_log = config['davis_convert_to_log']
    else:
        data_path, dataset_name = config['kiba_path'], 'kiba'
        convert_to_log = config['kiba_convert_to_log']

    dataset = DataSet(data_path, config['problem_type'], config['max_seq_len'],
                      config['max_smi_len'], dataset_name, convert_to_log)

    XD, XT, Y, label_row_inds, label_col_inds, test_fold, train_folds = dataset.get_data()

    # The training folds arrive as a list of index lists; merge them into one.
    train_fold = unite_lists(train_folds)

    def build_split(fold_indices):
        # Select the fold's (row, col) label indices and wrap the resulting pairs.
        drugs, targets, affinity = prepare_interaction_pairs(
            XD, XT, Y, label_row_inds[fold_indices], label_col_inds[fold_indices])
        return DTIDataset(drugs, targets, affinity)

    return build_split(train_fold), build_split(test_fold)

In [None]:
class SimpleLSTM(nn.Module):
    """Two unidirectional LSTM encoders (drug, protein) feeding a small MLP.

    Each input is embedded and run through its own LSTM. The final hidden
    state of the top LSTM layer of each encoder is concatenated and passed
    through ``fc1 -> ReLU -> dropout -> fc2`` to yield one value per sample.

    Config keys read: ``charsmiset_size``, ``charseqset_size``, ``embed_dim``,
    ``lstm_dim``, ``lstm_layers``, ``dropout_rate``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim, hidden_dim = config['embed_dim'], config['lstm_dim']
        lstm_options = dict(num_layers=config['lstm_layers'], batch_first=True)

        # Vocabulary size + 1 rows in each embedding table.
        self.drug_embedding = nn.Embedding(config['charsmiset_size'] + 1, embed_dim)
        self.protein_embedding = nn.Embedding(config['charseqset_size'] + 1, embed_dim)

        self.drug_lstm = nn.LSTM(embed_dim, hidden_dim, **lstm_options)
        self.protein_lstm = nn.LSTM(embed_dim, hidden_dim, **lstm_options)

        # Input width is 2 * lstm_dim: one final hidden state per encoder.
        self.fc1 = nn.Linear(hidden_dim * 2, 32)
        self.fc2 = nn.Linear(32, 1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(config['dropout_rate'])

    def forward(self, drug, protein):
        """Return one prediction per sample (last dimension squeezed away)."""
        drug_tokens = self.drug_embedding(drug)
        protein_tokens = self.protein_embedding(protein)

        _, (drug_states, _) = self.drug_lstm(drug_tokens)
        _, (protein_states, _) = self.protein_lstm(protein_tokens)

        # Index -1 selects the final hidden state of the top LSTM layer.
        pair_features = torch.cat([drug_states[-1], protein_states[-1]], dim=1)

        hidden = self.dropout(self.relu(self.fc1(pair_features)))
        return self.fc2(hidden).squeeze(-1)

# Hyper-parameters and data settings for the SimpleLSTM run.
config = {
  'charsmiset_size': 64,
  'charseqset_size': 26,
  'embed_dim': 128,
  'lstm_dim': 64,
  'lstm_layers': 2,
  'dropout_rate': 0.3,
  'learning_rate': 0.001,
  'batch_size': 128,

  # Dataset specific
  'davis_convert_to_log': True,
  'kiba_convert_to_log': False,

  # Input dimensions
  'max_seq_len': 1000,
  'max_smi_len': 100,
  'davis_path': '/content/drive/MyDrive/BioInf-Final/data/davis/',
  'kiba_path': '/content/drive/MyDrive/BioInf-Final/data/kiba/',
  'problem_type': 1

}

train_dataset, test_dataset = get_datasets('davis', config)

# Run experiment with SimpleLSTM.
# The dataset name is passed explicitly: run_experiment uses it in the
# checkpoint file name, and omitting it made this call fail on a fresh kernel.
mse, ci = run_experiment(SimpleLSTM, config, train_dataset, test_dataset,
                         dataset='davis', num_epochs=100)
print(f"SimpleLSTM - Final MSE: {mse:.4f}, Final CI: {ci:.4f}")

Epoch 1/100, Train Loss: 3.2424, Test MSE: 0.7971, Test CI: 0.5373
Epoch 2/100, Train Loss: 1.4436, Test MSE: 0.7826, Test CI: 0.5644
Epoch 3/100, Train Loss: 1.4412, Test MSE: 0.7781, Test CI: 0.5782
Epoch 4/100, Train Loss: 1.4116, Test MSE: 0.8006, Test CI: 0.5729
Epoch 5/100, Train Loss: 1.4109, Test MSE: 0.7697, Test CI: 0.5857
Epoch 6/100, Train Loss: 1.4039, Test MSE: 0.7608, Test CI: 0.5931
Epoch 7/100, Train Loss: 1.3850, Test MSE: 0.7591, Test CI: 0.5878
Epoch 8/100, Train Loss: 1.3563, Test MSE: 0.7659, Test CI: 0.5821
Epoch 9/100, Train Loss: 1.3480, Test MSE: 0.7712, Test CI: 0.5911
Epoch 10/100, Train Loss: 1.3279, Test MSE: 0.7644, Test CI: 0.5901
Epoch 11/100, Train Loss: 1.3016, Test MSE: 0.7826, Test CI: 0.5932
Epoch 12/100, Train Loss: 1.2705, Test MSE: 0.7602, Test CI: 0.5986
Epoch 13/100, Train Loss: 1.2447, Test MSE: 0.8390, Test CI: 0.6018
Epoch 14/100, Train Loss: 1.2292, Test MSE: 0.7494, Test CI: 0.6000
Epoch 15/100, Train Loss: 1.2167, Test MSE: 0.7544, Test 

In [None]:
class BidirectionalLSTM(nn.Module):
    """Two bidirectional LSTM encoders (drug, protein) feeding a small MLP.

    For each encoder the last two entries of the final hidden-state tensor
    are concatenated, giving ``2 * lstm_dim`` features per input. The drug
    and protein features are joined (``4 * lstm_dim``) and passed through
    ``fc1 -> ReLU -> dropout -> fc2``.

    Config keys read: ``charsmiset_size``, ``charseqset_size``, ``embed_dim``,
    ``lstm_dim``, ``lstm_layers``, ``dropout_rate``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim, hidden_dim = config['embed_dim'], config['lstm_dim']
        lstm_options = dict(num_layers=config['lstm_layers'], batch_first=True, bidirectional=True)

        self.drug_embedding = nn.Embedding(config['charsmiset_size'] + 1, embed_dim)
        self.protein_embedding = nn.Embedding(config['charseqset_size'] + 1, embed_dim)

        self.drug_lstm = nn.LSTM(embed_dim, hidden_dim, **lstm_options)
        self.protein_lstm = nn.LSTM(embed_dim, hidden_dim, **lstm_options)

        # 4 * lstm_dim = (2 hidden-state entries) x (drug + protein).
        self.fc1 = nn.Linear(hidden_dim * 4, 64)
        self.fc2 = nn.Linear(64, 1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(config['dropout_rate'])

    @staticmethod
    def _join_last_two(states):
        """Concatenate the last two entries of a hidden-state tensor along dim 1."""
        return torch.cat((states[-2], states[-1]), dim=1)

    def forward(self, drug, protein):
        """Return one prediction per sample (last dimension squeezed away)."""
        drug_tokens = self.drug_embedding(drug)
        protein_tokens = self.protein_embedding(protein)

        _, (drug_states, _) = self.drug_lstm(drug_tokens)
        _, (protein_states, _) = self.protein_lstm(protein_tokens)

        pair_features = torch.cat(
            [self._join_last_two(drug_states), self._join_last_two(protein_states)], dim=1)

        hidden = self.dropout(self.relu(self.fc1(pair_features)))
        return self.fc2(hidden).squeeze(-1)

# Hyper-parameters and data settings for the BidirectionalLSTM run.
config = {
  'charsmiset_size': 64,
  'charseqset_size': 26,
  'embed_dim': 128,
  'lstm_dim': 64,
  'lstm_layers': 2,
  'dropout_rate': 0.3,
  'learning_rate': 0.0005,
  'batch_size': 128,

  # Dataset specific
  'davis_convert_to_log': True,
  'kiba_convert_to_log': False,

  # Input dimensions
  'max_seq_len': 1000,
  'max_smi_len': 100,
  'davis_path': '/content/drive/MyDrive/BioInf-Final/data/davis/',
  'kiba_path': '/content/drive/MyDrive/BioInf-Final/data/kiba/',
  'problem_type': 1

}

train_dataset, test_dataset = get_datasets('davis', config)

# Run experiment with BidirectionalLSTM.
# The dataset name is passed explicitly: run_experiment uses it in the
# checkpoint file name, and omitting it made this call fail on a fresh kernel.
mse, ci = run_experiment(BidirectionalLSTM, config, train_dataset, test_dataset,
                         dataset='davis', num_epochs=100)
print(f"BidirectionalLSTM - Final MSE: {mse:.4f}, Final CI: {ci:.4f}")

Epoch 1/100, Train Loss: 3.8459, Test MSE: 0.7826, Test CI: 0.5918
Epoch 2/100, Train Loss: 1.1685, Test MSE: 0.6755, Test CI: 0.7141
Epoch 3/100, Train Loss: 1.0384, Test MSE: 0.6323, Test CI: 0.7431
Epoch 4/100, Train Loss: 1.0268, Test MSE: 0.6983, Test CI: 0.7613
Epoch 5/100, Train Loss: 1.0011, Test MSE: 0.6050, Test CI: 0.7670
Epoch 6/100, Train Loss: 0.9823, Test MSE: 0.6106, Test CI: 0.7732
Epoch 7/100, Train Loss: 0.9913, Test MSE: 0.5937, Test CI: 0.7777
Epoch 8/100, Train Loss: 0.9818, Test MSE: 0.5914, Test CI: 0.7802
Epoch 9/100, Train Loss: 0.9552, Test MSE: 0.5918, Test CI: 0.7828
Epoch 10/100, Train Loss: 0.9618, Test MSE: 0.5818, Test CI: 0.7821
Epoch 11/100, Train Loss: 0.9467, Test MSE: 0.5954, Test CI: 0.7806
Epoch 12/100, Train Loss: 0.9537, Test MSE: 0.6222, Test CI: 0.7836
Epoch 13/100, Train Loss: 0.9395, Test MSE: 0.6300, Test CI: 0.7818
Epoch 14/100, Train Loss: 0.9346, Test MSE: 0.5784, Test CI: 0.7873
Epoch 15/100, Train Loss: 0.9202, Test MSE: 0.5649, Test 

In [None]:
import torch.nn.functional as F

class LSTMWithCrossAttention(nn.Module):
    """Bidirectional LSTM encoders joined by dot-product cross-attention.

    Each input is embedded and encoded by its own bidirectional LSTM. A linear
    projection of one encoder's outputs is used as the query against the other
    encoder's outputs (keys and values), in both directions. The two attended
    sequences are mean-pooled over dim 1, concatenated, and passed through a
    three-layer MLP with ReLU and dropout.

    Config keys read: ``charsmiset_size``, ``charseqset_size``, ``embed_dim``,
    ``lstm_dim``, ``lstm_layers``, ``dropout_rate``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim, hidden_dim = config['embed_dim'], config['lstm_dim']
        encoder_width = hidden_dim * 2  # forward + backward LSTM outputs
        lstm_options = dict(num_layers=config['lstm_layers'], batch_first=True, bidirectional=True)

        self.drug_embedding = nn.Embedding(config['charsmiset_size'] + 1, embed_dim)
        self.protein_embedding = nn.Embedding(config['charseqset_size'] + 1, embed_dim)

        self.drug_lstm = nn.LSTM(embed_dim, hidden_dim, **lstm_options)
        self.protein_lstm = nn.LSTM(embed_dim, hidden_dim, **lstm_options)

        # Query projections applied before cross-attention.
        self.drug_attention = nn.Linear(encoder_width, encoder_width)
        self.protein_attention = nn.Linear(encoder_width, encoder_width)

        self.fc1 = nn.Linear(hidden_dim * 4, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 1)

        self.dropout = nn.Dropout(config['dropout_rate'])

    def cross_attention(self, query, key, value):
        """Unscaled dot-product attention.

        Args:
            query: ``[batch_size, query_len, dim]``
            key: ``[batch_size, key_len, dim]``
            value: ``[batch_size, key_len, dim]``

        Returns:
            Tensor of shape ``[batch_size, query_len, dim]``: for every query
            position, the softmax-weighted sum of ``value`` rows.
        """
        scores = torch.bmm(query, key.transpose(1, 2))
        weights = torch.softmax(scores, dim=-1)  # normalise over key positions
        return torch.bmm(weights, value)

    def forward(self, drug, protein):
        """Return one prediction per sample (last dimension squeezed away)."""
        drug_tokens = self.drug_embedding(drug)
        protein_tokens = self.protein_embedding(protein)

        drug_seq, _ = self.drug_lstm(drug_tokens)
        protein_seq, _ = self.protein_lstm(protein_tokens)

        # Each side queries the other side's encoder outputs.
        drug_query = self.drug_attention(drug_seq)
        protein_query = self.protein_attention(protein_seq)
        drug_context = self.cross_attention(drug_query, protein_seq, protein_seq)
        protein_context = self.cross_attention(protein_query, drug_seq, drug_seq)

        # Mean over the sequence dimension, then join both representations.
        pooled = torch.cat([drug_context.mean(dim=1), protein_context.mean(dim=1)], dim=1)

        hidden = self.dropout(F.relu(self.fc1(pooled)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden).squeeze(-1)

# Hyper-parameters and data settings for the LSTMWithCrossAttention run.
config = {
    'charsmiset_size': 64,
    'charseqset_size': 26,
    'embed_dim': 128,
    'lstm_dim': 64,
    'lstm_layers': 2,
    'dropout_rate': 0.3,
    'learning_rate': 0.0005,
    'batch_size': 128,

    # Dataset specific
    'davis_convert_to_log': True,
    'kiba_convert_to_log': False,

    # Input dimensions
    'max_seq_len': 1000,
    'max_smi_len': 100,
    'davis_path': '/content/drive/MyDrive/BioInf-Final/data/davis/',
    'kiba_path': '/content/drive/MyDrive/BioInf-Final/data/kiba/',
    'problem_type': 1

}

train_dataset, test_dataset = get_datasets('davis', config)

# The dataset name is passed explicitly: run_experiment uses it in the
# checkpoint file name, and omitting it made this call fail on a fresh kernel.
mse, ci = run_experiment(LSTMWithCrossAttention, config, train_dataset, test_dataset,
                         dataset='davis')
print(f"LSTM with Cross-Attention - Final MSE: {mse:.4f}, Final CI: {ci:.4f}")

Epoch 1/100, Train Loss: 3.7172, Test MSE: 0.7958, Test CI: 0.5358
Epoch 2/100, Train Loss: 1.3072, Test MSE: 0.7662, Test CI: 0.6014
Epoch 3/100, Train Loss: 1.2361, Test MSE: 0.7402, Test CI: 0.6447
Epoch 4/100, Train Loss: 1.1697, Test MSE: 0.7410, Test CI: 0.6977
Epoch 5/100, Train Loss: 1.1333, Test MSE: 0.6272, Test CI: 0.7312
Epoch 6/100, Train Loss: 1.0788, Test MSE: 0.5924, Test CI: 0.7552
Epoch 7/100, Train Loss: 1.0467, Test MSE: 0.6160, Test CI: 0.7737
Epoch 8/100, Train Loss: 1.0219, Test MSE: 0.5764, Test CI: 0.7764
Epoch 9/100, Train Loss: 1.0198, Test MSE: 0.5644, Test CI: 0.7882
Epoch 10/100, Train Loss: 0.9859, Test MSE: 0.5471, Test CI: 0.7868
Epoch 11/100, Train Loss: 0.9643, Test MSE: 0.5394, Test CI: 0.7866
Epoch 12/100, Train Loss: 0.9513, Test MSE: 0.5298, Test CI: 0.7910
Epoch 13/100, Train Loss: 0.9332, Test MSE: 0.5496, Test CI: 0.7964
Epoch 14/100, Train Loss: 0.9268, Test MSE: 0.5153, Test CI: 0.8012
Epoch 15/100, Train Loss: 0.9037, Test MSE: 0.5282, Test 

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    A table of shape ``[max_len, 1, d_model]`` is precomputed and registered
    as the buffer ``pe``: even feature indices hold ``sin(pos * w_k)`` and odd
    feature indices hold ``cos(pos * w_k)``, with
    ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        d_model: Feature dimension of the inputs. Odd values are supported;
            the cosine columns then use only the first ``d_model // 2``
            frequencies.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-index column than
        # even-index column, so trim the angles to fit (a no-op when even).
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            Tensor of the same shape: ``x`` plus the first ``seq_len`` rows of
            the encoding table, passed through dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TransformerWithCrossAttention(nn.Module):
    """Transformer encoders for drug and protein joined by multi-head cross-attention.

    Each input is embedded, given positional encodings, and encoded by its own
    ``nn.TransformerEncoder``. A single ``nn.MultiheadAttention`` module is
    then applied in both directions (drug attends to protein, protein attends
    to drug). The two attended sequences are mean-pooled over dim 1,
    concatenated, and passed through a three-layer MLP with ReLU and dropout.

    Config keys read: ``charsmiset_size``, ``charseqset_size``, ``d_model``,
    ``n_head``, ``d_ff``, ``n_layers``, ``dropout_rate``.
    """

    def __init__(self, config):
        super().__init__()
        d_model, n_head = config['d_model'], config['n_head']
        dropout_rate = config['dropout_rate']

        self.drug_embedding = nn.Embedding(config['charsmiset_size'] + 1, d_model)
        self.protein_embedding = nn.Embedding(config['charseqset_size'] + 1, d_model)

        # One positional-encoding module shared by both inputs.
        self.pos_encoder = PositionalEncoding(d_model, dropout_rate)

        # NOTE(review): both encoders are built from this one layer object --
        # confirm how nn.TransformerEncoder treats it (cloned vs. shared).
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=config['d_ff'],
            dropout=dropout_rate,
            batch_first=True
        )
        self.drug_encoder = nn.TransformerEncoder(layer_template, num_layers=config['n_layers'])
        self.protein_encoder = nn.TransformerEncoder(layer_template, num_layers=config['n_layers'])

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_head,
            dropout=dropout_rate,
            batch_first=True
        )

        self.fc1 = nn.Linear(d_model * 2, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 1)

        self.dropout = nn.Dropout(dropout_rate)

    def _with_positions(self, embedded):
        """Apply the sequence-first positional encoder to a batch-first tensor."""
        sequence_first = embedded.transpose(0, 1)
        return self.pos_encoder(sequence_first).transpose(0, 1)

    def forward(self, drug, protein):
        """Return one prediction per sample (last dimension squeezed away)."""
        drug_tokens = self.drug_embedding(drug)
        protein_tokens = self.protein_embedding(protein)

        # Positional encoding (drug first, then protein).
        drug_tokens = self._with_positions(drug_tokens)
        protein_tokens = self._with_positions(protein_tokens)

        drug_seq = self.drug_encoder(drug_tokens)
        protein_seq = self.protein_encoder(protein_tokens)

        # Cross-attention in both directions; attention weights are discarded.
        drug_context, _ = self.cross_attention(drug_seq, protein_seq, protein_seq)
        protein_context, _ = self.cross_attention(protein_seq, drug_seq, drug_seq)

        # Mean over the sequence dimension, then join both representations.
        pooled = torch.cat([drug_context.mean(dim=1), protein_context.mean(dim=1)], dim=1)

        hidden = self.dropout(F.relu(self.fc1(pooled)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden).squeeze(-1)

# Hyper-parameters and data settings for the first Transformer run
# (dropout 0.1, learning rate 1e-4, batch size 64).
config = {
    'charsmiset_size': 64,
    'charseqset_size': 26,
    'd_model': 128,  # Dimension of the model
    'n_head': 8,     # Number of attention heads
    'n_layers': 2,   # Number of Transformer layers
    'd_ff': 512,     # Dimension of the feedforward network in Transformer
    'dropout_rate': 0.1,
    'learning_rate': 0.0001,
    'batch_size': 64,

    # Dataset specific
    'davis_convert_to_log': True,
    'kiba_convert_to_log': False,

    # Input dimensions
    'max_seq_len': 1000,
    'max_smi_len': 100,
    'davis_path': '/content/drive/MyDrive/BioInf-Final/data/davis/',
    'kiba_path': '/content/drive/MyDrive/BioInf-Final/data/kiba/',
    'problem_type': 1

}

train_dataset, test_dataset = get_datasets('davis', config)

# The dataset name is passed explicitly: run_experiment uses it in the
# checkpoint file name, and omitting it made this call fail on a fresh kernel.
mse, ci = run_experiment(TransformerWithCrossAttention, config, train_dataset, test_dataset,
                         dataset='davis')
print(f"Transformer with Cross-Attention - Final MSE: {mse:.4f}, Final CI: {ci:.4f}")

Epoch 1/100, Train Loss: 3.0500, Test MSE: 0.8052, Test CI: 0.5697
Epoch 2/100, Train Loss: 0.9485, Test MSE: 0.7775, Test CI: 0.6392
Epoch 3/100, Train Loss: 0.8436, Test MSE: 0.6742, Test CI: 0.7166
Epoch 4/100, Train Loss: 0.7923, Test MSE: 0.6392, Test CI: 0.7362
Epoch 5/100, Train Loss: 0.7577, Test MSE: 0.6726, Test CI: 0.7464
Epoch 6/100, Train Loss: 0.7575, Test MSE: 0.6038, Test CI: 0.7516
Epoch 7/100, Train Loss: 0.7314, Test MSE: 0.5865, Test CI: 0.7575
Epoch 8/100, Train Loss: 0.6987, Test MSE: 0.5846, Test CI: 0.7689
Epoch 9/100, Train Loss: 0.6956, Test MSE: 0.5842, Test CI: 0.7769
Epoch 10/100, Train Loss: 0.6743, Test MSE: 0.6608, Test CI: 0.7737
Epoch 11/100, Train Loss: 0.6691, Test MSE: 0.5714, Test CI: 0.7823
Epoch 12/100, Train Loss: 0.6574, Test MSE: 0.5525, Test CI: 0.7869
Epoch 13/100, Train Loss: 0.6528, Test MSE: 0.5301, Test CI: 0.7951
Epoch 14/100, Train Loss: 0.6437, Test MSE: 0.5582, Test CI: 0.7846
Epoch 15/100, Train Loss: 0.6435, Test MSE: 0.5317, Test 

In [None]:
# Hyper-parameters and data settings for the second Transformer run
# (dropout 0.3, learning rate 5e-4, batch size 128).
# NOTE(review): the checkpoint file name depends only on the model class and
# dataset name, so this run writes to the same file as any earlier
# TransformerWithCrossAttention run on 'davis'.
config = {
    'charsmiset_size': 64,
    'charseqset_size': 26,
    'd_model': 128,  # Dimension of the model
    'n_head': 8,     # Number of attention heads
    'n_layers': 2,   # Number of Transformer layers
    'd_ff': 512,     # Dimension of the feedforward network in Transformer
    'dropout_rate': 0.3,
    'learning_rate': 0.0005,
    'batch_size': 128,

    # Dataset specific
    'davis_convert_to_log': True,
    'kiba_convert_to_log': False,

    # Input dimensions
    'max_seq_len': 1000,
    'max_smi_len': 100,
    'davis_path': '/content/drive/MyDrive/BioInf-Final/data/davis/',
    'kiba_path': '/content/drive/MyDrive/BioInf-Final/data/kiba/',
    'problem_type': 1

}

train_dataset, test_dataset = get_datasets('davis', config)

# The dataset name is passed explicitly: run_experiment uses it in the
# checkpoint file name, and omitting it made this call fail on a fresh kernel.
mse, ci = run_experiment(TransformerWithCrossAttention, config, train_dataset, test_dataset,
                         dataset='davis')
print(f"Transformer with Cross-Attention - Final MSE: {mse:.4f}, Final CI: {ci:.4f}")

Epochs:   1%|          | 1/100 [00:46<1:16:06, 46.13s/it]

Epoch 1/100, Train Loss: 2.6403, Test MSE: 0.9272, Test CI: 0.5660


Epochs:   2%|▏         | 2/100 [01:32<1:15:16, 46.09s/it]

Epoch 2/100, Train Loss: 1.3347, Test MSE: 1.2470, Test CI: 0.5739


Epochs:   3%|▎         | 3/100 [02:18<1:14:28, 46.07s/it]

Epoch 3/100, Train Loss: 1.2784, Test MSE: 0.9216, Test CI: 0.6365


Epochs:   4%|▍         | 4/100 [03:04<1:13:42, 46.06s/it]

Epoch 4/100, Train Loss: 1.2307, Test MSE: 0.8570, Test CI: 0.6770


Epochs:   5%|▌         | 5/100 [03:50<1:12:55, 46.05s/it]

Epoch 5/100, Train Loss: 1.1630, Test MSE: 1.2556, Test CI: 0.7226


Epochs:   6%|▌         | 6/100 [04:36<1:12:12, 46.09s/it]

Epoch 6/100, Train Loss: 1.1633, Test MSE: 0.8630, Test CI: 0.7303


Epochs:   7%|▋         | 7/100 [05:22<1:11:24, 46.07s/it]

Epoch 7/100, Train Loss: 1.1367, Test MSE: 0.8925, Test CI: 0.7333


Epochs:   8%|▊         | 8/100 [06:08<1:10:36, 46.05s/it]

Epoch 8/100, Train Loss: 1.0784, Test MSE: 1.3011, Test CI: 0.7299


Epochs:   9%|▉         | 9/100 [06:54<1:09:50, 46.05s/it]

Epoch 9/100, Train Loss: 1.0864, Test MSE: 0.9285, Test CI: 0.7488


Epochs:  10%|█         | 10/100 [07:40<1:09:06, 46.07s/it]

Epoch 10/100, Train Loss: 1.0578, Test MSE: 0.7468, Test CI: 0.7499


Epochs:  11%|█         | 11/100 [08:26<1:08:19, 46.06s/it]

Epoch 11/100, Train Loss: 1.0388, Test MSE: 1.3652, Test CI: 0.7475


Epochs:  12%|█▏        | 12/100 [09:12<1:07:34, 46.07s/it]

Epoch 12/100, Train Loss: 1.0549, Test MSE: 0.6287, Test CI: 0.7518


Epochs:  13%|█▎        | 13/100 [09:58<1:06:49, 46.09s/it]

Epoch 13/100, Train Loss: 1.0439, Test MSE: 0.7098, Test CI: 0.7599


Epochs:  14%|█▍        | 14/100 [10:44<1:05:57, 46.02s/it]

Epoch 14/100, Train Loss: 1.0040, Test MSE: 0.7267, Test CI: 0.7602


Epochs:  15%|█▌        | 15/100 [11:31<1:05:18, 46.10s/it]

Epoch 15/100, Train Loss: 1.0048, Test MSE: 0.8344, Test CI: 0.7670


Epochs:  16%|█▌        | 16/100 [12:17<1:04:33, 46.12s/it]

Epoch 16/100, Train Loss: 0.9781, Test MSE: 0.9918, Test CI: 0.7634


Epochs:  17%|█▋        | 17/100 [13:03<1:03:48, 46.12s/it]

Epoch 17/100, Train Loss: 0.9617, Test MSE: 0.9416, Test CI: 0.7723


Epochs:  18%|█▊        | 18/100 [13:49<1:03:02, 46.12s/it]

Epoch 18/100, Train Loss: 0.9460, Test MSE: 0.7071, Test CI: 0.7688


Epochs:  19%|█▉        | 19/100 [14:35<1:02:14, 46.10s/it]

Epoch 19/100, Train Loss: 0.9349, Test MSE: 1.0196, Test CI: 0.7732


Epochs:  20%|██        | 20/100 [15:21<1:01:27, 46.09s/it]

Epoch 20/100, Train Loss: 0.9407, Test MSE: 0.6775, Test CI: 0.7775


Epochs:  21%|██        | 21/100 [16:07<1:00:40, 46.08s/it]

Epoch 21/100, Train Loss: 0.9310, Test MSE: 1.1269, Test CI: 0.7798


Epochs:  22%|██▏       | 22/100 [16:53<59:55, 46.10s/it]  

Epoch 22/100, Train Loss: 0.9077, Test MSE: 0.7816, Test CI: 0.7821


Epochs:  23%|██▎       | 23/100 [17:40<59:11, 46.13s/it]

Epoch 23/100, Train Loss: 0.9037, Test MSE: 1.1175, Test CI: 0.7806


Epochs:  24%|██▍       | 24/100 [18:26<58:23, 46.10s/it]

Epoch 24/100, Train Loss: 0.8899, Test MSE: 0.6526, Test CI: 0.7855


Epochs:  25%|██▌       | 25/100 [19:12<57:38, 46.11s/it]

Epoch 25/100, Train Loss: 0.8690, Test MSE: 0.6236, Test CI: 0.7855


Epochs:  26%|██▌       | 26/100 [19:58<56:52, 46.11s/it]

Epoch 26/100, Train Loss: 0.8593, Test MSE: 0.7266, Test CI: 0.7873


Epochs:  27%|██▋       | 27/100 [20:44<56:04, 46.09s/it]

Epoch 27/100, Train Loss: 0.8520, Test MSE: 0.6332, Test CI: 0.7883


Epochs:  28%|██▊       | 28/100 [21:30<55:17, 46.08s/it]

Epoch 28/100, Train Loss: 0.8504, Test MSE: 0.8797, Test CI: 0.7885


Epochs:  29%|██▉       | 29/100 [22:16<54:32, 46.09s/it]

Epoch 29/100, Train Loss: 0.8268, Test MSE: 0.6878, Test CI: 0.7866


Epochs:  30%|███       | 30/100 [23:02<53:46, 46.10s/it]

Epoch 30/100, Train Loss: 0.8299, Test MSE: 0.6275, Test CI: 0.7885


Epochs:  31%|███       | 31/100 [23:48<52:59, 46.08s/it]

Epoch 31/100, Train Loss: 0.8032, Test MSE: 0.9292, Test CI: 0.7939


Epochs:  32%|███▏      | 32/100 [24:34<52:12, 46.07s/it]

Epoch 32/100, Train Loss: 0.7978, Test MSE: 0.6613, Test CI: 0.7860


Epochs:  33%|███▎      | 33/100 [25:20<51:26, 46.06s/it]

Epoch 33/100, Train Loss: 0.7833, Test MSE: 0.5536, Test CI: 0.7907


Epochs:  34%|███▍      | 34/100 [26:06<50:40, 46.07s/it]

Epoch 34/100, Train Loss: 0.7847, Test MSE: 0.6040, Test CI: 0.7865


Epochs:  35%|███▌      | 35/100 [26:52<49:53, 46.06s/it]

Epoch 35/100, Train Loss: 0.7691, Test MSE: 0.5935, Test CI: 0.7954


Epochs:  36%|███▌      | 36/100 [27:38<49:08, 46.07s/it]

Epoch 36/100, Train Loss: 0.7484, Test MSE: 0.7056, Test CI: 0.7984


Epochs:  37%|███▋      | 37/100 [28:25<48:21, 46.06s/it]

Epoch 37/100, Train Loss: 0.7324, Test MSE: 0.5971, Test CI: 0.7901


Epochs:  38%|███▊      | 38/100 [29:11<47:36, 46.07s/it]

Epoch 38/100, Train Loss: 0.7135, Test MSE: 0.5147, Test CI: 0.7897


Epochs:  39%|███▉      | 39/100 [29:57<46:50, 46.08s/it]

Epoch 39/100, Train Loss: 0.6905, Test MSE: 0.5145, Test CI: 0.7868


Epochs:  40%|████      | 40/100 [30:43<46:04, 46.08s/it]

Epoch 40/100, Train Loss: 0.6979, Test MSE: 0.8321, Test CI: 0.7799


Epochs:  41%|████      | 41/100 [31:29<45:21, 46.12s/it]

Epoch 41/100, Train Loss: 0.6828, Test MSE: 0.5399, Test CI: 0.7963


Epochs:  42%|████▏     | 42/100 [32:15<44:34, 46.11s/it]

Epoch 42/100, Train Loss: 0.6707, Test MSE: 0.4668, Test CI: 0.8035


Epochs:  43%|████▎     | 43/100 [33:01<43:47, 46.09s/it]

Epoch 43/100, Train Loss: 0.6531, Test MSE: 0.6033, Test CI: 0.7962


Epochs:  44%|████▍     | 44/100 [33:47<43:00, 46.08s/it]

Epoch 44/100, Train Loss: 0.6445, Test MSE: 0.4937, Test CI: 0.8017


Epochs:  45%|████▌     | 45/100 [34:33<42:12, 46.05s/it]

Epoch 45/100, Train Loss: 0.6489, Test MSE: 0.5564, Test CI: 0.7992


Epochs:  46%|████▌     | 46/100 [35:19<41:27, 46.06s/it]

Epoch 46/100, Train Loss: 0.6379, Test MSE: 0.5499, Test CI: 0.8074


Epochs:  47%|████▋     | 47/100 [36:05<40:40, 46.05s/it]

Epoch 47/100, Train Loss: 0.6242, Test MSE: 0.4988, Test CI: 0.7962


Epochs:  48%|████▊     | 48/100 [36:51<39:55, 46.06s/it]

Epoch 48/100, Train Loss: 0.6174, Test MSE: 0.4630, Test CI: 0.8063


Epochs:  49%|████▉     | 49/100 [37:37<39:09, 46.07s/it]

Epoch 49/100, Train Loss: 0.6160, Test MSE: 0.5734, Test CI: 0.7935


Epochs:  50%|█████     | 50/100 [38:24<38:24, 46.08s/it]

Epoch 50/100, Train Loss: 0.5975, Test MSE: 0.4788, Test CI: 0.8057


Epochs:  51%|█████     | 51/100 [39:10<37:38, 46.09s/it]

Epoch 51/100, Train Loss: 0.6038, Test MSE: 0.4670, Test CI: 0.8074


Epochs:  52%|█████▏    | 52/100 [39:56<36:53, 46.12s/it]

Epoch 52/100, Train Loss: 0.5866, Test MSE: 0.4762, Test CI: 0.8012


Epochs:  53%|█████▎    | 53/100 [40:42<36:07, 46.11s/it]

Epoch 53/100, Train Loss: 0.5869, Test MSE: 0.4884, Test CI: 0.8075


Epochs:  54%|█████▍    | 54/100 [41:28<35:20, 46.10s/it]

Epoch 54/100, Train Loss: 0.5702, Test MSE: 0.4484, Test CI: 0.7971


Epochs:  55%|█████▌    | 55/100 [42:14<34:33, 46.09s/it]

Epoch 55/100, Train Loss: 0.5597, Test MSE: 0.4580, Test CI: 0.8029


Epochs:  56%|█████▌    | 56/100 [43:00<33:46, 46.06s/it]

Epoch 56/100, Train Loss: 0.5620, Test MSE: 0.4572, Test CI: 0.8048


Epochs:  57%|█████▋    | 57/100 [43:46<33:00, 46.07s/it]

Epoch 57/100, Train Loss: 0.5595, Test MSE: 0.4997, Test CI: 0.8033


Epochs:  58%|█████▊    | 58/100 [44:32<32:15, 46.08s/it]

Epoch 58/100, Train Loss: 0.5479, Test MSE: 0.4233, Test CI: 0.8157


Epochs:  59%|█████▉    | 59/100 [45:18<31:28, 46.07s/it]

Epoch 59/100, Train Loss: 0.5453, Test MSE: 0.4409, Test CI: 0.8156


Epochs:  60%|██████    | 60/100 [46:04<30:43, 46.09s/it]

Epoch 60/100, Train Loss: 0.5314, Test MSE: 0.5239, Test CI: 0.8043


Epochs:  61%|██████    | 61/100 [46:50<29:56, 46.07s/it]

Epoch 61/100, Train Loss: 0.5315, Test MSE: 0.4427, Test CI: 0.8188


Epochs:  62%|██████▏   | 62/100 [47:36<29:08, 46.02s/it]

Epoch 62/100, Train Loss: 0.5251, Test MSE: 0.4404, Test CI: 0.8155


Epochs:  63%|██████▎   | 63/100 [48:22<28:22, 46.02s/it]

Epoch 63/100, Train Loss: 0.5190, Test MSE: 0.4519, Test CI: 0.8110


Epochs:  64%|██████▍   | 64/100 [49:09<27:39, 46.10s/it]

Epoch 64/100, Train Loss: 0.5148, Test MSE: 0.5095, Test CI: 0.8048


Epochs:  65%|██████▌   | 65/100 [49:55<26:53, 46.10s/it]

Epoch 65/100, Train Loss: 0.5138, Test MSE: 0.4597, Test CI: 0.8206


Epochs:  66%|██████▌   | 66/100 [50:41<26:07, 46.11s/it]

Epoch 66/100, Train Loss: 0.5036, Test MSE: 0.4520, Test CI: 0.8129


Epochs:  67%|██████▋   | 67/100 [51:27<25:21, 46.11s/it]

Epoch 67/100, Train Loss: 0.5055, Test MSE: 0.4875, Test CI: 0.8085


Epochs:  68%|██████▊   | 68/100 [52:13<24:35, 46.11s/it]

Epoch 68/100, Train Loss: 0.4974, Test MSE: 0.4641, Test CI: 0.8137


Epochs:  69%|██████▉   | 69/100 [52:59<23:49, 46.10s/it]

Epoch 69/100, Train Loss: 0.4884, Test MSE: 0.4703, Test CI: 0.8126


Epochs:  70%|███████   | 70/100 [53:45<23:01, 46.04s/it]

Epoch 70/100, Train Loss: 0.4893, Test MSE: 0.4505, Test CI: 0.8121


Epochs:  71%|███████   | 71/100 [54:31<22:17, 46.13s/it]

Epoch 71/100, Train Loss: 0.4953, Test MSE: 0.4724, Test CI: 0.8130


Epochs:  72%|███████▏  | 72/100 [55:18<21:31, 46.13s/it]

Epoch 72/100, Train Loss: 0.4805, Test MSE: 0.4725, Test CI: 0.7993


Epochs:  73%|███████▎  | 73/100 [56:04<20:45, 46.14s/it]

Epoch 73/100, Train Loss: 0.4821, Test MSE: 0.4619, Test CI: 0.8182


Epochs:  74%|███████▍  | 74/100 [56:50<19:59, 46.14s/it]

Epoch 74/100, Train Loss: 0.4779, Test MSE: 0.4890, Test CI: 0.8182


Epochs:  75%|███████▌  | 75/100 [57:36<19:13, 46.12s/it]

Epoch 75/100, Train Loss: 0.4688, Test MSE: 0.4672, Test CI: 0.8182


Epochs:  76%|███████▌  | 76/100 [58:22<18:26, 46.11s/it]

Epoch 76/100, Train Loss: 0.4671, Test MSE: 0.4428, Test CI: 0.8155


Epochs:  77%|███████▋  | 77/100 [59:08<17:40, 46.09s/it]

Epoch 77/100, Train Loss: 0.4662, Test MSE: 0.4478, Test CI: 0.8246


Epochs:  78%|███████▊  | 78/100 [59:54<16:54, 46.10s/it]

Epoch 78/100, Train Loss: 0.4590, Test MSE: 0.4345, Test CI: 0.8186


Epochs:  79%|███████▉  | 79/100 [1:00:40<16:06, 46.05s/it]

Epoch 79/100, Train Loss: 0.4634, Test MSE: 0.4799, Test CI: 0.8147


Epochs:  80%|████████  | 80/100 [1:01:26<15:22, 46.12s/it]

Epoch 80/100, Train Loss: 0.4582, Test MSE: 0.4615, Test CI: 0.8166


Epochs:  81%|████████  | 81/100 [1:02:13<14:36, 46.13s/it]

Epoch 81/100, Train Loss: 0.4591, Test MSE: 0.4812, Test CI: 0.8102


Epochs:  82%|████████▏ | 82/100 [1:02:59<13:50, 46.13s/it]

Epoch 82/100, Train Loss: 0.4441, Test MSE: 0.5079, Test CI: 0.8107


Epochs:  83%|████████▎ | 83/100 [1:03:45<13:04, 46.13s/it]

Epoch 83/100, Train Loss: 0.4469, Test MSE: 0.4664, Test CI: 0.8077


Epochs:  84%|████████▍ | 84/100 [1:04:31<12:17, 46.11s/it]

Epoch 84/100, Train Loss: 0.4466, Test MSE: 0.4550, Test CI: 0.8108


Epochs:  85%|████████▌ | 85/100 [1:05:17<11:31, 46.09s/it]

Epoch 85/100, Train Loss: 0.4375, Test MSE: 0.4757, Test CI: 0.8123


Epochs:  86%|████████▌ | 86/100 [1:06:03<10:45, 46.10s/it]

Epoch 86/100, Train Loss: 0.4298, Test MSE: 0.4487, Test CI: 0.8187


Epochs:  87%|████████▋ | 87/100 [1:06:49<09:59, 46.09s/it]

Epoch 87/100, Train Loss: 0.4368, Test MSE: 0.4324, Test CI: 0.8211


Epochs:  88%|████████▊ | 88/100 [1:07:35<09:12, 46.08s/it]

Epoch 88/100, Train Loss: 0.4302, Test MSE: 0.4561, Test CI: 0.8161


Epochs:  89%|████████▉ | 89/100 [1:08:21<08:26, 46.09s/it]

Epoch 89/100, Train Loss: 0.4249, Test MSE: 0.4413, Test CI: 0.8266


Epochs:  90%|█████████ | 90/100 [1:09:07<07:40, 46.09s/it]

Epoch 90/100, Train Loss: 0.4219, Test MSE: 0.4484, Test CI: 0.8240


Epochs:  91%|█████████ | 91/100 [1:09:53<06:54, 46.04s/it]

Epoch 91/100, Train Loss: 0.4202, Test MSE: 0.5443, Test CI: 0.8159


Epochs:  92%|█████████▏| 92/100 [1:10:40<06:08, 46.11s/it]

Epoch 92/100, Train Loss: 0.4195, Test MSE: 0.4279, Test CI: 0.8277


Epochs:  93%|█████████▎| 93/100 [1:11:26<05:22, 46.13s/it]

Epoch 93/100, Train Loss: 0.4158, Test MSE: 0.4709, Test CI: 0.8113


Epochs:  94%|█████████▍| 94/100 [1:12:12<04:36, 46.13s/it]

Epoch 94/100, Train Loss: 0.4077, Test MSE: 0.4693, Test CI: 0.8214


Epochs:  95%|█████████▌| 95/100 [1:12:58<03:50, 46.14s/it]

Epoch 95/100, Train Loss: 0.4039, Test MSE: 0.4884, Test CI: 0.8232


Epochs:  96%|█████████▌| 96/100 [1:13:44<03:04, 46.10s/it]

Epoch 96/100, Train Loss: 0.4032, Test MSE: 0.4651, Test CI: 0.8265


Epochs:  97%|█████████▋| 97/100 [1:14:30<02:18, 46.10s/it]

Epoch 97/100, Train Loss: 0.4069, Test MSE: 0.4912, Test CI: 0.8141


Epochs:  98%|█████████▊| 98/100 [1:15:16<01:32, 46.11s/it]

Epoch 98/100, Train Loss: 0.4030, Test MSE: 0.4630, Test CI: 0.8289


Epochs:  99%|█████████▉| 99/100 [1:16:02<00:46, 46.05s/it]

Epoch 99/100, Train Loss: 0.3956, Test MSE: 0.5091, Test CI: 0.8112


Epochs: 100%|██████████| 100/100 [1:16:49<00:00, 46.09s/it]

Epoch 100/100, Train Loss: 0.3867, Test MSE: 0.4676, Test CI: 0.8262
Best Test MSE: 0.4233, Best Test CI: 0.8157
Transformer with Cross-Attention - Final MSE: 0.4233, Final CI: 0.8157





In [None]:
# Tested different architectures; the bi-directional LSTM with cross-attention performed best.

In [None]:
# Settings for the 'lstm' model type, followed by loading the KIBA dataset
# and wrapping its train/test splits in DTIDataset objects.
config = dict(
    # Model architecture
    model_type='lstm',
    lstm_layers=1,
    lstm_dim=32,
    dense_layers=1,
    dense_dim=32,
    dropout_rate=0.2,
    embed_dim=64,

    # Training parameters
    learning_rate=0.001,
    batch_size=64,
    num_epochs=100,

    # Input dimensions
    max_seq_len=1000,
    max_smi_len=100,

    # Data handling
    davis_path='/content/drive/MyDrive/BioInf-Final/data/davis/',
    kiba_path='/content/drive/MyDrive/BioInf-Final/data/kiba/',
    problem_type=1,
    binary_th=0.0,

    # Output and logging
    checkpoint_path='',
    log_dir='logs',

    # Dataset specific
    davis_convert_to_log=True,
    kiba_convert_to_log=False,

    # Alphabet sizes
    charsmiset_size=64,  # Size of SMILES alphabet
    charseqset_size=26,  # Size of protein sequence alphabet
)

# Read the KIBA data using the paths and length limits from `config`.
dataset = DataSet(
    config['kiba_path'],
    config['problem_type'],
    config['max_seq_len'],
    config['max_smi_len'],
    'kiba',
    config['kiba_convert_to_log'],
)

XD, XT, Y, label_row_inds, label_col_inds, test_fold, train_folds = dataset.get_data()

# Flatten the list of training folds into a single index list.
train_fold = unite_lists(train_folds)

# Build the drug / target / affinity inputs for each split.
train_drugs, train_targets, train_affinity = prepare_interaction_pairs(
    XD, XT, Y, label_row_inds[train_fold], label_col_inds[train_fold]
)
test_drugs, test_targets, test_affinity = prepare_interaction_pairs(
    XD, XT, Y, label_row_inds[test_fold], label_col_inds[test_fold]
)

train_dataset_kiba = DTIDataset(train_drugs, train_targets, train_affinity)
test_dataset_kiba = DTIDataset(test_drugs, test_targets, test_affinity)

Reading kiba dataset from /content/drive/MyDrive/BioInf-Final/data/kiba/
Parsing kiba dataset


In [None]:
# Rebinds `config` with the settings used for the LSTMWithCrossAttention run
# on KIBA, then trains/evaluates and prints the returned metrics.
config = dict(
    charsmiset_size=64,
    charseqset_size=26,
    embed_dim=128,
    lstm_dim=64,
    lstm_layers=2,
    dropout_rate=0.3,
    learning_rate=0.0005,
    batch_size=128,

    # Dataset specific
    davis_convert_to_log=True,
    kiba_convert_to_log=False,

    # Input dimensions
    max_seq_len=1000,
    max_smi_len=100,
    davis_path='/content/drive/MyDrive/BioInf-Final/data/davis/',
    kiba_path='/content/drive/MyDrive/BioInf-Final/data/kiba/',
    problem_type=1,
)

train_dataset, test_dataset = get_datasets('kiba', config)

# run_experiment returns the (mse, ci) pair reported below.
mse, ci = run_experiment(
    LSTMWithCrossAttention,
    config,
    train_dataset,
    test_dataset,
)
print(f"LSTM with Cross-Attention - Final MSE: {mse:.4f}, Final CI: {ci:.4f}")

Epochs:   1%|          | 1/100 [05:15<8:41:12, 315.88s/it]

Epoch 1/100, Train Loss: 6.3522, Test MSE: 0.6836, Test CI: 0.5935


Epochs:   2%|▏         | 2/100 [10:31<8:35:29, 315.61s/it]

Epoch 2/100, Train Loss: 2.8758, Test MSE: 0.6793, Test CI: 0.6408


Exception ignored in: <function _xla_gc_callback at 0x7854a480d510>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 98, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 
Epochs:   3%|▎         | 3/100 [15:45<8:28:49, 314.74s/it]

Epoch 3/100, Train Loss: 2.7781, Test MSE: 0.6370, Test CI: 0.6548


Epochs:   4%|▍         | 4/100 [20:57<8:22:17, 313.93s/it]

Epoch 4/100, Train Loss: 2.6347, Test MSE: 0.7229, Test CI: 0.6784


In [None]:
# Epoch 1/100, Train Loss: 2.9745, Test MSE: 0.7973, Test CI: 0.5264
# Epoch 2/100, Train Loss: 1.3261, Test MSE: 0.7863, Test CI: 0.5359