# Test dataset

In [42]:
# Test the ProteinDataset
import logging
from pathlib import Path
import sys
import os
curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))

from src.utils.data import read_pickle, read_json
from src.utils.losses import contrastive_loss
from src.data.collators import collate_variable_sequence_length
from src.data.datasets import ProteinDataset
from src.models.ProTCL import ProTCL
import torch
import wandb
import os
import datetime
from torchmetrics.classification import BinaryPrecision, BinaryRecall

%load_ext autoreload
%autoreload 2

# Data paths
FULL_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/full_GO.fasta'
TRAIN_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/train_GO.fasta'
VAL_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/dev_GO.fasta'
TEST_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/test_GO.fasta'
AMINO_ACID_VOCAB_PATH = '/home/ncorley/protein/ProteinFunctions/data/vocabularies/amino_acid_vocab.json'
GO_LABEL_VOCAB_PATH = '/home/ncorley/protein/ProteinFunctions/data/vocabularies/GO_label_vocab.json'
SEQUENCE_VOCAB_PATH = '/home/ncorley/protein/ProteinFunctions/data/vocabularies/sequence_vocab.json'
SEQUENCE_ID_MAP_PATH = '/home/ncorley/protein/ProteinFunctions/data/embeddings/sequence_id_map.pkl'

# Embedding paths
LABEL_EMBEDDING_PATH = "/home/ncorley/protein/ProteinFunctions/data/embeddings/frozen_PubMedBERT_label_embeddings.pt"
SEQUENCE_EMBEDDING_PATH = "/home/ncorley/protein/ProteinFunctions/data/embeddings/frozen_proteinfer_sequence_embeddings.pkl"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [43]:
import logging

logging.basicConfig(level=logging.INFO)

# Create a ProteinDataset
paths = {
    'data_path': FULL_DATA_PATH,
    'amino_acid_vocabulary_path': AMINO_ACID_VOCAB_PATH,
    'label_vocabulary_path': GO_LABEL_VOCAB_PATH,
    'sequence_id_vocabulary_path': '/home/ncorley/protein/ProteinFunctions/data/vocabularies/sequence_id_vocab.json',
    'sequence_id_map_path': SEQUENCE_ID_MAP_PATH,
}
full_datset = ProteinDataset(paths=paths)

INFO:root:Loaded 522607 sequences from /home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/full_GO.fasta.


In [32]:
# Get the length of the dataset
print(f"Length of the dataset: {len(full_datset)}")

Length of the dataset: 522607


In [33]:
# Make a data laoder
from torch.utils.data import DataLoader

loader = DataLoader(
            full_datset,
            batch_size=2,
            shuffle=True,
            collate_fn=collate_variable_sequence_length,
            num_workers=2,
            pin_memory=True
        )

In [35]:
# Load one batch from the loader
for batch in loader:
    print(batch)
    # Print shape for each item in batch
    for item in batch:
        print(item.shape)
    break
    

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/anaconda/envs/protein_functions/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/protein_functions/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ncorley/protein/ProteinFunctions/src/data/collators.py", line 46, in collate_variable_sequence_length
    return torch.stack(processed_sequence_ids), torch.stack(processed_sequence_onehots), torch.stack(processed_label_ids), torch.stack(processed_label_multihots), torch.stack(processed_sequence_lengths)
                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [50] at entry 0 and [38] at entry 1


In [9]:
# !export ROOT_PATH="/home/ncorley/protein/ProteinFunctions"

# ROOT_PATH = os.environ.get('ROOT_PATH', '.')
# print(f"Root path: {ROOT_PATH}")
from src.utils.data import read_pickle, read_json, read_yaml    

ROOT_PATH = "/home/ncorley/protein/ProteinFunctions"

# Load the configuration file
config = read_yaml(ROOT_PATH + '/config.yaml')
params = config['params']
paths = config['relative_paths']

In [38]:
from src.data.datasets import ProteinDataset, create_multiple_loaders
import os

common_paths = {
    'amino_acid_vocabulary_path': os.path.join(ROOT_PATH, paths['AMINO_ACID_VOCAB_PATH']),
    'label_vocabulary_path': os.path.join(ROOT_PATH, paths['GO_LABEL_VOCAB_PATH']),
    'sequence_id_vocabulary_path': os.path.join(ROOT_PATH, paths['SEQUENCE_ID_VOCAB_PATH']),
    'sequence_id_map_path': os.path.join(ROOT_PATH, paths['SEQUENCE_ID_MAP_PATH']),
}

paths_list = [
    {
        **common_paths,
        'data_path': os.path.join(ROOT_PATH, data_path)
    }
    for data_path in [
        paths['TRAIN_DATA_PATH'],
        paths['VAL_DATA_PATH'],
        paths['TEST_DATA_PATH']
    ]
]

print(paths_list)

# Load datasets from config file paths; the same vocabulary is used for all datasets
train_dataset, val_dataset, test_dataset = ProteinDataset.create_multiple_datasets(paths_list)


KeyError: 'AMINO_ACID_VOCAB_PATH'

In [45]:
# Define data loaders
train_loader, val_loader, test_loader = create_multiple_loaders(
    [train_dataset, val_dataset, test_dataset],
    [params['TRAIN_BATCH_SIZE'], params['VALIDATION_BATCH_SIZE'],
        params['TEST_BATCH_SIZE']],
    num_workers=params['NUM_WORKERS'],
)

In [50]:
#### EMBEDDINGS ####
sequence_embedding_path = '/home/ncorley/protein/ProteinFunctions/data/embeddings/frozen_proteinfer_sequence_embeddings.pkl'

sequence_embeddings = read_pickle(sequence_embedding_path)

sequence_id_map = read_pickle('/home/ncorley/protein/ProteinFunctions/data/embeddings/sequence_id_map.pkl')

In [7]:
# Create a map from unique numeric sequence ID to sequence embedding
numeric_id_embedding_map = {sequence_id_map[alphanumeric_id]: embedding for alphanumeric_id, embedding in sequence_embeddings.items()}


In [10]:
# Get the maximum numeric sequence ID (may be different than length)
max_id = max(numeric_id_embedding_map.keys())

# Get the embedding dimension for the sequence
embedding_dim = params['PROTEIN_EMBEDDING_DIM']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create an embedding matrix with zeros for all IDs
sequence_embedding_matrix = torch.zeros(max_id + 1, embedding_dim, device=device)

# Fill the embedding matrix with the embeddings from the map
for numeric_id, numpy_embedding in numeric_id_embedding_map.items():
    tensor_embedding = torch.tensor(numpy_embedding, device=device)
    sequence_embedding_matrix[numeric_id] = tensor_embedding



In [11]:
import torch.nn as nn
sequence_embedding_layer = nn.Embedding.from_pretrained(sequence_embedding_matrix, freeze=True).to(device)  # Assuming you want to freeze the embeddings and not train them


In [16]:
sequence_embedding_layer
# Create a tensor with numbers 1-10
x = torch.arange(1, 4).to(device)

sequence_embeddings = sequence_embedding_layer(x)
print(sequence_embeddings.shape)
print(sequence_embeddings)

torch.Size([3, 1100])
tensor([[ 1.0928, -0.0995, -2.6018,  ...,  3.2864,  0.9259,  0.6427],
        [-0.1883, -0.4159, -1.2945,  ..., -0.9814,  0.5722,  0.4089],
        [-0.8475,  1.3617,  0.5347,  ...,  2.5408, -0.4552, -0.5581]],
       device='cuda:0')


Bad pipe message: %s [b'\x1a\xf7&\xd7\xbb\x9a\xbbH\x1c\xc5\xcb\xdd\xe7t`lgU .\xf7\xb0\x9e\xecB`\x85*\xa2\n3s\xe96\x15ku*', b'\xaa5\x9c\x94sx\x97\x01\xe0\x95\x0c\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 \xcdI\xcd\xddVw\xa4m']
Bad pipe message: %s [b"y\xbf`C\xc1v-6\xab6\xc2\xf6\xd0\x11B\xe1\xfb\x83\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc

In [52]:
### Now do the same thing for the labels
label_embedding_path = '/home/ncorley/protein/ProteinFunctions/data/embeddings/frozen_PubMedBERT_label_embeddings.pkl'

# Load label embeddings
label_embeddings = read_pickle(label_embedding_path)

In [56]:
label_id_map = full_datset.label2int

In [60]:
label_numeric_id_embedding_map = {
    label_id_map[k]: v for k, v in label_embeddings.items() if k in label_id_map
}


In [64]:
 # Get the maximum numeric sequence ID (may be different than length)
max_id = max(label_numeric_id_embedding_map.keys())
print(f"Max ID: {max_id}")

Max ID: 32101


In [66]:
# Create an embedding matrix with zeros for all IDs
label_embedding_matrix = torch.zeros(
    max_id + 1, params['LABEL_EMBEDDING_DIM'], device=device)

In [67]:
# Fill the embedding matrix with the embeddings from the map
for numeric_id, embedding in label_numeric_id_embedding_map.items():
    tensor_embedding = torch.tensor(embedding, device=device)
    label_embedding_matrix[numeric_id] = tensor_embedding

In [68]:
label_embedding_matrix.shape

torch.Size([32102, 768])

In [69]:
label_embedding_layer = nn.Embedding.from_pretrained(
    label_embedding_matrix, freeze=True).to(device)  # Assuming you want to freeze the embeddings and not train them

In [25]:
embeddings = label_embeddings['embedding']
# Convert the embeddings to a tensor
label_embedding_matrix = torch.tensor(embeddings, device=device)

  label_embedding_matrix = torch.tensor(embeddings, device=device)


In [29]:
label_embeddings['embedding']

GO:0000001    [-0.18025249, -0.02392223, 0.33825943, 0.15572...
GO:0000002    [-0.2219825, -0.051336795, 0.33668807, 0.14193...
GO:0000003    [-0.2465049, -0.19750126, 0.16020584, 0.038120...
GO:0000005    [-0.15026829, 0.14948216, 0.203052, 0.06520696...
GO:0000006    [-0.07081436, -0.40411177, 0.14208032, 0.02184...
                                    ...                        
GO:2001313    [-0.079838775, -0.22622626, 0.097667456, 0.431...
GO:2001314    [-0.09654722, -0.2291644, 0.187317, 0.35629278...
GO:2001315    [-0.22437151, 0.0012431458, 0.16394448, 0.3430...
GO:2001316    [-0.11879057, -0.1450008, 0.21268089, 0.440057...
GO:2001317    [-0.12991917, -0.11843379, 0.18660213, 0.36462...
Name: embedding, Length: 47401, dtype: object

In [26]:
label_embedding_matrix.shape

torch.Size([47401, 768])

# Test model by hand

In [5]:
import logging
from pathlib import Path
import sys
import os
curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))

from src.utils.data import read_pickle, read_json
from src.utils.losses import contrastive_loss
from src.data.collators import collate_variable_sequence_length
from src.data.datasets import ProteinDataset
from src.models.ProTCL import ProTCL
import torch
import wandb
import os
import datetime
from torchmetrics.classification import BinaryPrecision, BinaryRecall

  from .autonotebook import tqdm as notebook_tqdm


In [31]:
# Example sequences and labels
sequences = ["SEQ1", "SEQ2", "SEQ3"]
labels = torch.tensor([
    [1, 0, 0, 1, 0, 0],
    [0, 1, 0, 1, 0, 0],
    [0, 0, 0, 1, 0, 0]
], dtype=torch.float32)

# Example sequence embeddings
sequence_to_embeddings_dict = {
    "SEQ1": torch.tensor([0.1, 0.2, 0.3]),
    "SEQ2": torch.tensor([0.4, 0.5, 0.6]),
    "SEQ3": torch.tensor([0.7, 0.8, 0.9])
}

# Example label embeddings
label_embeddings = torch.tensor([
    [0.01, 0.02],
    [0.03, 0.04],
    [0.05, 0.06], # Embeddings for all-zero columns
    [0.07, 0.08],
    [0.09, 0.10],  # Embeddings for all-zero columns
    [0.11, 0.12]   # Embeddings for all-zero columns
])

print("Label embeddings:", label_embeddings.shape)
print("Labels:", labels.shape)

# Constants
PROTEIN_EMBEDDING_DIM = 3
LABEL_EMBEDDING_DIM = 2
LATENT_EMBEDDING_DIM = 2
TEMPERATURE = 1.0
LEARNING_RATE = 0.001

# Assuming you're using a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move tensors to GPU
label_embeddings = label_embeddings.to(device)
sequence_to_embeddings_dict = {seq: embedding.to(device) for seq, embedding in sequence_to_embeddings_dict.items()}

Label embeddings: torch.Size([6, 2])
Labels: torch.Size([3, 6])


In [32]:
# Initialize the model
model = ProTCL(protein_embedding_dim=PROTEIN_EMBEDDING_DIM, 
                  label_embedding_dim=LABEL_EMBEDDING_DIM, 
                  latent_dim=LATENT_EMBEDDING_DIM,
                  temperature=TEMPERATURE,
                  sequence_to_embeddings_dict=sequence_to_embeddings_dict,
                  ordered_label_embeddings=label_embeddings).to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [33]:
%load_ext autoreload
%autoreload 2

# Test the forward method
P_e, L_e, target = model(sequences, labels)
print(P_e)
print(L_e)
print(target)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
tensor([[-0.8432,  0.5377],
        [-0.8763,  0.4817],
        [-0.8845,  0.4665]], device='cuda:0', grad_fn=<DivBackward0>)
tensor([[-0.7129,  0.7013],
        [-0.7731,  0.6343],
        [-0.7986,  0.6019]], device='cuda:0', grad_fn=<DivBackward0>)
tensor([[1., 0., 1.],
        [0., 1., 1.],
        [0., 0., 1.]])


In [34]:
# Test the contrastive loss function
loss = contrastive_loss(P_e, L_e, model.t, target)
print(loss)

tensor(1.6116, grad_fn=<DivBackward0>)


# Testing loop

In [1]:
import logging
from pathlib import Path
import sys
import os
curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))

from src.utils.data import read_pickle, read_json
from src.utils.losses import contrastive_loss
from src.data.collators import collate_variable_sequence_length
from src.data.datasets import ProteinDataset
from ProteinFunctions.src.models.ProTCL import ProTCL
import torch
import wandb
import os
import datetime
from torchmetrics.classification import BinaryPrecision, BinaryRecall

# Get the root path from the environment variable
ROOT_PATH = os.environ.get('ROOT_PATH', '.')  # Default to current directory if ROOT_PATH is not set

# Load the configuration file
config = read_json(os.path.join(ROOT_PATH, 'config.json'))
params = config['params']
relative_paths = config['relative_paths']

# Initialize logging
logging.basicConfig(filename='train.log', filemode='w',
                    format='%(asctime)s %(levelname)-4s %(message)s',
                    level=logging.INFO,
                    datefmt='%Y-%m-%d %H:%M:%S %Z')

# Initialize new run in W&B
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logging.info(
    f"################## {timestamp} RUNNING test_ProConNet.py ##################")

config={
    "LEARNING_RATE": 0.001,
    "TEMPERATURE": 0.07,
    "PROTEIN_EMBEDDING_DIM": 1100,
    "LABEL_EMBEDDING_DIM": 768,
    "LATENT_EMBEDDING_DIM": 934,
    "NUM_EPOCHS": 10,
    "BATCH_SIZE": 1000,
    "DECISION_TH": 0.88
}

# Data paths
TRAIN_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/train_GO.fasta'
VAL_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/dev_GO.fasta'
TEST_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/test_GO.fasta'
AMINO_ACID_VOCAB_PATH = '/home/ncorley/protein/ProteinFunctions/data/vocabularies/amino_acid_vocab.json'
GO_LABEL_VOCAB_PATH = '/home/ncorley/protein/ProteinFunctions/data/vocabularies/GO_label_vocab.json'

# Embedding paths
LABEL_EMBEDDING_PATH = "/home/ncorley/protein/ProteinFunctions/data/embeddings/label_embeddings.pk1"
SEQUENCE_EMBEDDING_PATH = "/home/ncorley/protein/ProteinFunctions/data/embeddings/sequence_embeddings.pk1"

# Directory to save models
OUTPUT_MODEL_DIR = "/home/ncorley/protein/ProteinFunctions/models/ProConNet"

# Model to load
LOAD_MODEL_PATH = "/home/ncorley/protein/ProteinFunctions/models/ProConNet/2023-09-05_23-28-25_best_Pro_ConNet.pth"

# Test only
TRAIN_MODEL = False
TEST_MODEL = True

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load datasets
test_dataset = ProteinDataset(data_path=TEST_DATA_PATH,
                              sequence_vocabulary_path=AMINO_ACID_VOCAB_PATH,
                              label_vocabulary_path=GO_LABEL_VOCAB_PATH)

# Define data loaders
# train_loader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, collate_fn=collate_variable_sequence_length, num_workers=2)
# val_loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, collate_fn=collate_variable_sequence_length, num_workers=2)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=10, shuffle=True, collate_fn=collate_variable_sequence_length, num_workers=2)

# Load sequence embeddings
if os.path.exists(SEQUENCE_EMBEDDING_PATH):
    sequence_to_embeddings_dict_raw = read_pickle(SEQUENCE_EMBEDDING_PATH)[
        ['sequence', 'embedding']].set_index('sequence')['embedding'].to_dict()

    # Convert embeddings in the dictionary to tensors
    sequence_to_embeddings_dict = {seq: torch.tensor(
        embedding) for seq, embedding in sequence_to_embeddings_dict_raw.items()}
else:
    raise ValueError("Sequence embeddings not found.")

# Load label embeddings (in the same order as the label vocabulary)
if os.path.exists(LABEL_EMBEDDING_PATH):
    # Load a dictionary mapping GO IDs to embeddings
    label_embedding_dict = read_pickle(LABEL_EMBEDDING_PATH)[
        ['go_id', 'embedding']].set_index('go_id')['embedding'].to_dict()
    # Create a tensor of embeddings in the correct order (i.e., the order of the label vocabulary)
    label_embeddings = torch.stack(
        # All label vocabularies are the same, so we can use the train dataset
        [torch.tensor(label_embedding_dict[label]) for label in test_dataset.label_vocabulary])
else:
    raise ValueError("Label embeddings not found.")

# Move tensors to GPU
label_embeddings = label_embeddings.to(device)
sequence_to_embeddings_dict = {seq: embedding.to(
    device) for seq, embedding in sequence_to_embeddings_dict.items()}


# Initialize the model
model = ProTCL(protein_embedding_dim=config['PROTEIN_EMBEDDING_DIM'],
                  label_embedding_dim=config['LABEL_EMBEDDING_DIM'],
                  latent_dim=config['LATENT_EMBEDDING_DIM'],
                  temperature=config['TEMPERATURE'],
                  sequence_to_embeddings_dict=sequence_to_embeddings_dict,
                  ordered_label_embeddings=label_embeddings).to(device)

# Load the model weights if LOAD_MODEL_PATH is provided and exists
if LOAD_MODEL_PATH is not None and os.path.exists(LOAD_MODEL_PATH):
    logging.info(f"Loading model weights from {LOAD_MODEL_PATH}...")
    model.load_state_dict(torch.load(LOAD_MODEL_PATH))

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1000, shuffle=True, collate_fn=collate_variable_sequence_length, num_workers=2)

In [None]:
print(len(test_loader))

In [6]:
 ####### TESTING LOOP #######
logging.info("Starting testing...")
model.eval()

# Initialize metrics
total_test_loss = 0
at_least_one_positive_pred = torch.tensor(0, dtype=int).to(
    device)  # seqs with at least one positive label prediction
n = torch.tensor(0, dtype=int).to(device)
seqwise_precision = BinaryPrecision(threshold=config['DECISION_TH'],
                                    multidim_average='samplewise').to(device)
seqwise_recall = BinaryRecall(threshold=config['DECISION_TH'],
                              multidim_average='samplewise').to(device)


with torch.no_grad():
    for test_batch in test_loader:
        # Unpack the test batch
        sequences, sequence_lengths, labels = test_batch

        # Convert labels to floats and move to GPU, if available
        labels = labels.float().to(device)

        # Forward pass
        P_e, L_e, true_labels = model(sequences, labels)

        # Compute test loss for the batch
        test_loss = contrastive_loss(P_e, L_e, model.t, true_labels)

        # Accumulate the total test loss
        total_test_loss += test_loss.item()

        # Compute cosine similarities for zero-shot classification
        logits = torch.mm(P_e, L_e.t()) * torch.exp(model.t)

        # Apply sigmoid to get the probabilities for multi-label classification
        probabilities = torch.sigmoid(logits)

        # Throw error
        # throw_error = True
        # if throw_error:
        #     raise ValueError("Error!")

        # Update metrics
        at_least_one_positive_pred += (probabilities >
                                       config['DECISION_TH']).any(axis=1).sum()
        n += true_labels.size(0)
        seqwise_precision(probabilities, true_labels)
        seqwise_recall(probabilities, true_labels)

# Compute average test loss
avg_test_loss = total_test_loss / len(test_loader)

# Compute average precision, recall, coverage, and F1 score
average_precision = seqwise_precision.compute().sum()/at_least_one_positive_pred
average_recall = seqwise_recall.compute().mean()
average_f1_score = 2*average_precision * \
    average_recall/(average_precision+average_recall)
coverage = at_least_one_positive_pred/n

logging.info(
    f"Test Loss: {avg_test_loss}, F1 Score: {average_f1_score}, Precision: {average_precision}, Recall: {average_recall}, Coverage: {coverage}")

logging.info("Testing complete.")

print(f"Test Loss: {avg_test_loss}, F1 Score: {average_f1_score}, Precision: {average_precision}, Recall: {average_recall}, Coverage: {coverage}")

# Close the W&B run
# wandb.log({"test_loss": avg_test_loss, "f1_score": average_f1_score,
#           "precision": average_precision, "recall": average_recall, "coverage": coverage})
# wandb.finish()

Error: You must call wandb.init() before wandb.log()