# Test the model by hand on a toy batch

In [5]:
import logging
from pathlib import Path
import sys
import os
# Make the parent of the current working directory (presumably the repository root)
# importable, so the project-local `src.*` imports below resolve. This must run before them.
curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))

# Project-local helpers, resolved through the sys.path entry added above
from src.utils.data import read_pickle, read_json
from src.utils.losses import contrastive_loss
from src.data.collators import collate_variable_sequence_length
from src.data.datasets import ProteinDataset
from src.models.ProTCL import ProTCL
# Third-party / stdlib dependencies used by the later cells
import torch
import wandb
import os
import datetime
from torchmetrics.classification import BinaryPrecision, BinaryRecall

  from .autonotebook import tqdm as notebook_tqdm


In [31]:
# Example sequences and labels: 3 sequences x 6 candidate labels, multi-hot encoded
sequences = ["SEQ1", "SEQ2", "SEQ3"]
labels = torch.tensor([
    [1, 0, 0, 1, 0, 0],
    [0, 1, 0, 1, 0, 0],
    [0, 0, 0, 1, 0, 0]
], dtype=torch.float32)

# Example sequence embeddings, one PROTEIN_EMBEDDING_DIM-vector per sequence
sequence_to_embeddings_dict = {
    "SEQ1": torch.tensor([0.1, 0.2, 0.3]),
    "SEQ2": torch.tensor([0.4, 0.5, 0.6]),
    "SEQ3": torch.tensor([0.7, 0.8, 0.9])
}

# Example label embeddings: row i belongs to label column i of `labels`
label_embeddings = torch.tensor([
    [0.01, 0.02],
    [0.03, 0.04],
    [0.05, 0.06],  # label column 2 is all-zero in `labels`
    [0.07, 0.08],
    [0.09, 0.10],  # label column 4 is all-zero in `labels`
    [0.11, 0.12]   # label column 5 is all-zero in `labels`
])

print("Label embeddings:", label_embeddings.shape)
print("Labels:", labels.shape)

# Constants
PROTEIN_EMBEDDING_DIM = 3
LABEL_EMBEDDING_DIM = 2
LATENT_EMBEDDING_DIM = 2
TEMPERATURE = 1.0
LEARNING_RATE = 0.001

# Sanity-check that the toy data agrees with the constants passed to the model
assert label_embeddings.shape == (labels.shape[1], LABEL_EMBEDDING_DIM), "one label embedding per label column"
assert set(sequence_to_embeddings_dict) == set(sequences), "every sequence needs an embedding"
assert all(emb.shape == (PROTEIN_EMBEDDING_DIM,) for emb in sequence_to_embeddings_dict.values())

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move every toy tensor to `device`. This includes the labels, just as the test loop
# does with real batches, so the toy forward pass follows the same code path.
labels = labels.to(device)
label_embeddings = label_embeddings.to(device)
sequence_to_embeddings_dict = {seq: embedding.to(device) for seq, embedding in sequence_to_embeddings_dict.items()}

Label embeddings: torch.Size([6, 2])
Labels: torch.Size([3, 6])


In [32]:
# Seed torch's RNG so that any randomly initialised model weights, and therefore the
# P_e / L_e values printed in the next cells, are reproducible across kernel restarts.
TOY_SEED = 0
torch.manual_seed(TOY_SEED)

# Initialize the model on the toy embeddings and constants defined above
model = ProTCL(protein_embedding_dim=PROTEIN_EMBEDDING_DIM,
               label_embedding_dim=LABEL_EMBEDDING_DIM,
               latent_dim=LATENT_EMBEDDING_DIM,
               temperature=TEMPERATURE,
               sequence_to_embeddings_dict=sequence_to_embeddings_dict,
               ordered_label_embeddings=label_embeddings).to(device)

# Define optimizer over all model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [33]:
%load_ext autoreload
%autoreload 2

# Test the forward method
P_e, L_e, target = model(sequences, labels)
print(P_e)
print(L_e)
print(target)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
tensor([[-0.8432,  0.5377],
        [-0.8763,  0.4817],
        [-0.8845,  0.4665]], device='cuda:0', grad_fn=<DivBackward0>)
tensor([[-0.7129,  0.7013],
        [-0.7731,  0.6343],
        [-0.7986,  0.6019]], device='cuda:0', grad_fn=<DivBackward0>)
tensor([[1., 0., 1.],
        [0., 1., 1.],
        [0., 0., 1.]])


In [34]:
# Contrastive loss on the toy batch: the forward-pass outputs plus the model's
# temperature parameter `model.t`.
loss = contrastive_loss(
    P_e,
    L_e,
    model.t,
    target,
)
print(loss)

tensor(1.6116, grad_fn=<DivBackward0>)


# Testing loop: evaluate a saved checkpoint on the test split

In [1]:
import logging
from pathlib import Path
import sys
import os
curdir = Path(os.getcwd())
# Make the parent of the working directory (presumably the repository root) importable
# so the `src.*` project imports below resolve.
sys.path.append(str(curdir.parent.absolute()))

from src.utils.data import read_pickle, read_json
from src.utils.losses import contrastive_loss
from src.data.collators import collate_variable_sequence_length
from src.data.datasets import ProteinDataset
# Import ProTCL through the same `src.` path as the other project imports (and the
# hand-test cell above), rather than `ProteinFunctions.src...`. That path relies on a
# sys.path entry this notebook never adds.
from src.models.ProTCL import ProTCL
import torch
import wandb
import os
import datetime
from torchmetrics.classification import BinaryPrecision, BinaryRecall

# Get the root path from the environment variable
ROOT_PATH = os.environ.get('ROOT_PATH', '.')  # Default to current directory if ROOT_PATH is not set

# Load the project configuration file. It has its own name so it is not silently
# overwritten by the evaluation hyper-parameter dict `config` defined below.
file_config = read_json(os.path.join(ROOT_PATH, 'config.json'))
params = file_config['params']
relative_paths = file_config['relative_paths']

# Initialize logging
logging.basicConfig(filename='train.log', filemode='w',
                    format='%(asctime)s %(levelname)-4s %(message)s',
                    level=logging.INFO,
                    datefmt='%Y-%m-%d %H:%M:%S %Z')

# Mark the start of this run in the log file
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logging.info(
    f"################## {timestamp} RUNNING test_ProConNet.py ##################")

# Hyper-parameters of the checkpoint being evaluated. The dimensions presumably must
# match those used in training, or load_state_dict below will fail on shape mismatches.
config = {
    "LEARNING_RATE": 0.001,
    "TEMPERATURE": 0.07,
    "PROTEIN_EMBEDDING_DIM": 1100,
    "LABEL_EMBEDDING_DIM": 768,
    "LATENT_EMBEDDING_DIM": 934,
    "NUM_EPOCHS": 10,
    "BATCH_SIZE": 1000,
    "DECISION_TH": 0.88
}

# Seed for the test-loader shuffle. The model appears to build its targets from the labels
# present in each batch (see the toy example: 6 label columns -> 3x3 target), so batch
# composition affects the loss and metrics. A fixed seed makes them reproducible.
EVAL_SEED = 0

# Project checkout root. Override it with the PROTEIN_FUNCTIONS_DIR environment variable
# instead of editing the paths below; the default is the original hard-coded location.
PROJECT_DIR = os.environ.get('PROTEIN_FUNCTIONS_DIR', '/home/ncorley/protein/ProteinFunctions')
DATA_DIR = os.path.join(PROJECT_DIR, 'data')
SPLITS_DIR = os.path.join(DATA_DIR, 'swissprot', 'proteinfer_splits', 'random')

# Data paths
TRAIN_DATA_PATH = os.path.join(SPLITS_DIR, 'train_GO.fasta')
VAL_DATA_PATH = os.path.join(SPLITS_DIR, 'dev_GO.fasta')
TEST_DATA_PATH = os.path.join(SPLITS_DIR, 'test_GO.fasta')
AMINO_ACID_VOCAB_PATH = os.path.join(DATA_DIR, 'vocabularies', 'amino_acid_vocab.json')
GO_LABEL_VOCAB_PATH = os.path.join(DATA_DIR, 'vocabularies', 'GO_label_vocab.json')

# Embedding paths
LABEL_EMBEDDING_PATH = os.path.join(DATA_DIR, 'embeddings', 'label_embeddings.pk1')
SEQUENCE_EMBEDDING_PATH = os.path.join(DATA_DIR, 'embeddings', 'sequence_embeddings.pk1')

# Directory to save models
OUTPUT_MODEL_DIR = os.path.join(PROJECT_DIR, 'models', 'ProConNet')

# Model to load
LOAD_MODEL_PATH = os.path.join(OUTPUT_MODEL_DIR, '2023-09-05_23-28-25_best_Pro_ConNet.pth')

# Test only
TRAIN_MODEL = False
TEST_MODEL = True

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load datasets
test_dataset = ProteinDataset(data_path=TEST_DATA_PATH,
                              sequence_vocabulary_path=AMINO_ACID_VOCAB_PATH,
                              label_vocabulary_path=GO_LABEL_VOCAB_PATH)

# Test data loader: evaluation batch size from `config`, shuffled with a seeded generator
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config['BATCH_SIZE'],
    shuffle=True,
    generator=torch.Generator().manual_seed(EVAL_SEED),
    collate_fn=collate_variable_sequence_length,
    num_workers=2)

# Load sequence embeddings.
# NOTE(review): if read_pickle unpickles, only point it at trusted files (pickle can execute code).
if os.path.exists(SEQUENCE_EMBEDDING_PATH):
    sequence_to_embeddings_dict_raw = read_pickle(SEQUENCE_EMBEDDING_PATH)[
        ['sequence', 'embedding']].set_index('sequence')['embedding'].to_dict()

    # Convert embeddings in the dictionary to tensors
    sequence_to_embeddings_dict = {seq: torch.tensor(
        embedding) for seq, embedding in sequence_to_embeddings_dict_raw.items()}
else:
    raise ValueError("Sequence embeddings not found.")

# Load label embeddings (in the same order as the label vocabulary)
if os.path.exists(LABEL_EMBEDDING_PATH):
    # Load a dictionary mapping GO IDs to embeddings
    label_embedding_dict = read_pickle(LABEL_EMBEDDING_PATH)[
        ['go_id', 'embedding']].set_index('go_id')['embedding'].to_dict()
    # Create a tensor of embeddings in the correct order (i.e., the order of the label vocabulary)
    label_embeddings = torch.stack(
        # All label vocabularies are the same, so the test dataset's vocabulary is used
        [torch.tensor(label_embedding_dict[label]) for label in test_dataset.label_vocabulary])
else:
    raise ValueError("Label embeddings not found.")

# Move tensors to GPU (or CPU when no GPU is available)
label_embeddings = label_embeddings.to(device)
sequence_to_embeddings_dict = {seq: embedding.to(
    device) for seq, embedding in sequence_to_embeddings_dict.items()}


# Initialize the model
model = ProTCL(protein_embedding_dim=config['PROTEIN_EMBEDDING_DIM'],
               label_embedding_dim=config['LABEL_EMBEDDING_DIM'],
               latent_dim=config['LATENT_EMBEDDING_DIM'],
               temperature=config['TEMPERATURE'],
               sequence_to_embeddings_dict=sequence_to_embeddings_dict,
               ordered_label_embeddings=label_embeddings).to(device)

# Load the model weights if LOAD_MODEL_PATH is provided and exists
if LOAD_MODEL_PATH is not None and os.path.exists(LOAD_MODEL_PATH):
    logging.info(f"Loading model weights from {LOAD_MODEL_PATH}...")
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
    model.load_state_dict(torch.load(LOAD_MODEL_PATH, map_location=device))
else:
    # Without weights, the test loop would silently evaluate a randomly initialised model
    logging.warning(
        f"No model weights found at {LOAD_MODEL_PATH}; evaluating an untrained model.")

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Rebuild the test loader with the evaluation batch size from `config` (1000). The shuffle
# uses a fixed-seed generator so batch composition is identical between runs. The model's
# targets, and so the contrastive loss and metrics, appear to depend on which labels share
# a batch.
TEST_LOADER_SEED = 0
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config['BATCH_SIZE'],
    shuffle=True,
    generator=torch.Generator().manual_seed(TEST_LOADER_SEED),
    collate_fn=collate_variable_sequence_length,
    num_workers=2,
)

In [None]:
# Number of test batches the loop below will iterate over. Left as the bare last
# expression so Jupyter displays it as the cell's output.
len(test_loader)

In [6]:
 ####### TESTING LOOP #######
# Evaluate `model` on `test_loader`. Reports the mean per-batch contrastive loss plus
# sequence-wise precision / recall / F1 and coverage at the threshold config['DECISION_TH'].
logging.info("Starting testing...")
model.eval()

# An empty loader would otherwise surface later as an opaque ZeroDivisionError
if len(test_loader) == 0:
    raise ValueError("test_loader is empty; nothing to evaluate.")

# Initialize metrics
total_test_loss = 0
# Sequences with at least one label predicted positive; used as the precision denominator
at_least_one_positive_pred = torch.tensor(0, dtype=int).to(device)
n = torch.tensor(0, dtype=int).to(device)  # total number of evaluated sequences
# Per-sequence ("samplewise") precision / recall, accumulated across all batches
seqwise_precision = BinaryPrecision(threshold=config['DECISION_TH'],
                                    multidim_average='samplewise').to(device)
seqwise_recall = BinaryRecall(threshold=config['DECISION_TH'],
                              multidim_average='samplewise').to(device)


with torch.no_grad():
    for test_batch in test_loader:
        # Unpack the test batch (sequence lengths are not needed here)
        sequences, sequence_lengths, labels = test_batch

        # Convert labels to floats and move to GPU, if available
        labels = labels.float().to(device)

        # Forward pass
        P_e, L_e, true_labels = model(sequences, labels)

        # Compute and accumulate the test loss for the batch
        test_loss = contrastive_loss(P_e, L_e, model.t, true_labels)
        total_test_loss += test_loss.item()

        # Cosine similarities for zero-shot classification, scaled by exp(model.t).
        # NOTE(review): model.t is presumably a (log-)temperature parameter; confirm in ProTCL.
        logits = torch.mm(P_e, L_e.t()) * torch.exp(model.t)

        # Independent sigmoid per label, because this is multi-label classification
        probabilities = torch.sigmoid(logits)

        # Update metrics
        at_least_one_positive_pred += (probabilities >
                                       config['DECISION_TH']).any(dim=1).sum()
        n += true_labels.size(0)
        seqwise_precision(probabilities, true_labels)
        seqwise_recall(probabilities, true_labels)

# Compute average test loss
avg_test_loss = total_test_loss / len(test_loader)

# Sequences with no positive prediction contribute precision 0 to the sum (torchmetrics'
# zero-division default). Dividing by the number of covered sequences therefore gives the
# mean precision over covered sequences. If nothing at all was predicted positive, report
# precision 0 instead of 0/0 = nan.
seq_precisions = seqwise_precision.compute()
if at_least_one_positive_pred > 0:
    average_precision = seq_precisions.sum() / at_least_one_positive_pred
else:
    average_precision = torch.zeros((), device=seq_precisions.device)
average_recall = seqwise_recall.compute().mean()
# F1 is 0 (not nan) when precision and recall are both 0
precision_plus_recall = average_precision + average_recall
if precision_plus_recall > 0:
    average_f1_score = 2 * average_precision * average_recall / precision_plus_recall
else:
    average_f1_score = torch.zeros_like(precision_plus_recall)
coverage = at_least_one_positive_pred / n

# Plain floats keep the log and printout free of tensor/device reprs
summary = (f"Test Loss: {avg_test_loss}, F1 Score: {float(average_f1_score)}, "
           f"Precision: {float(average_precision)}, Recall: {float(average_recall)}, "
           f"Coverage: {float(coverage)}")

logging.info(summary)

logging.info("Testing complete.")

print(summary)

# W&B logging is disabled: this notebook never calls wandb.init(), which wandb.log()
# requires. Initialise a run first if these metrics should be sent to W&B.

Error: You must call wandb.init() before wandb.log()