In [None]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from sklearn.metrics import mean_absolute_error
from tokenizer.tokenizer import MolTranBertTokenizer

# --- Evaluation setup: tokenizer, fine-tuned model, test data -------------
# `margs`, `LightningModule`, `get_dataset` and `PropertyPredictionDataModule`
# are not defined in this cell; they must already be in scope when it runs.

# Tokenizer built from the vocabulary file; it is passed to the model below.
tokenizer = MolTranBertTokenizer('bert_vocab.txt')

# Restore the model from a checkpoint, forwarding the run configuration and
# the tokenizer as extra keyword arguments.
# NOTE(review): the placeholder below looks like a directory name; confirm it
# is replaced with the path of an actual checkpoint file before running.
checkpoint_path = "./checkpoints_u0"  # placeholder: set to the real checkpoint
model = LightningModule.load_from_checkpoint(
    checkpoint_path,
    config=margs,
    tokenizer=tokenizer,
)
model.eval()  # evaluation mode for inference

# Build the test split with augmentation turned off.
test_filename = "path/to/test.csv"  # placeholder: set to the real test file
test_dataset = get_dataset(
    margs.data_root,
    test_filename,
    None,
    aug=False,
    measure_name=margs.measure_name,
)

# Fixed (unshuffled) iteration order; batches are assembled by the data
# module's own collate function.
loader_options = dict(
    batch_size=margs.batch_size,
    num_workers=margs.num_workers,
    shuffle=False,
    collate_fn=PropertyPredictionDataModule(margs).collate,
)
test_dataloader = DataLoader(test_dataset, **loader_options)
import numpy as np

# --- Test-set inference and MAE --------------------------------------------
# Runs every batch of `test_dataloader` through the model, collects the
# predictions and the ground-truth targets, and prints the mean absolute
# error over the whole test set.
all_preds = []
all_actuals = []

# Inference only: no gradients are needed.
with torch.no_grad():
    for batch in test_dataloader:
        # Each batch is unpacked as (token ids, attention mask, targets).
        idx, mask, targets = batch

        idx = idx.to(model.device)
        mask = mask.to(model.device)
        targets = targets.to(model.device)

        # Forward pass: embed tokens, apply the model's dropout module, then
        # run the encoder blocks. The per-row mask sum is passed as the
        # sequence length for the length mask.
        token_embeddings = model.tok_emb(idx)
        x = model.drop(token_embeddings)
        x = model.blocks(x, length_mask=LM(mask.sum(-1)))
        token_embeddings = x

        # Masked mean over dim 1: positions where the mask is 0 contribute
        # nothing to the sum, and the clamp avoids a division by zero when a
        # row's mask is entirely zero.
        input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        loss_input = sum_embeddings / sum_mask

        # Drop only the trailing dimension. A bare `.squeeze()` would also
        # drop the batch dimension when a batch holds a single sample (e.g.
        # the last batch), producing a 0-d array that `np.concatenate`
        # rejects.
        # Assumes `model.net` returns a trailing dimension of size 1 for a
        # single-target regression head -- TODO confirm.
        preds = model.net(loss_input).squeeze(-1)

        # Move to host memory and keep as NumPy arrays for the metric.
        all_preds.append(preds.cpu().numpy())
        all_actuals.append(targets.cpu().numpy())

# An empty dataloader would otherwise fail inside `np.concatenate` with an
# unhelpful message; MAE is undefined without any samples.
if not all_preds:
    raise ValueError("Test dataloader yielded no batches; cannot compute MAE.")

# Stitch the per-batch arrays into one array each.
all_preds = np.concatenate(all_preds)
all_actuals = np.concatenate(all_actuals)

# Mean absolute error between ground truth and predictions.
mae = mean_absolute_error(all_actuals, all_preds)
print(f"Test MAE: {mae:.4f}")
