In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.checkpoint import checkpoint_sequential
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from collections import OrderedDict

# ----------------------------
# Data Loading Utility Functions
# ----------------------------
def print_to_file(msg):
    print(msg)  # Replace with your logging function if needed

def fast_load_dfs(train_path='train.feather', val_path='val.feather', frac=0.01):
    """
    Load train and validation DataFrames from Feather files,
    and return a fraction of the rows (default 1%).

    Parameters:
        train_path (str): Path to the training DataFrame feather file.
        val_path (str): Path to the validation DataFrame feather file.
        frac (float): Fraction of rows to return from each DataFrame (0 < frac <= 1).
                      Default is 0.01 (i.e. 1% of the data).

    Returns:
        tuple: (train_df, val_df) as sampled Pandas DataFrames.
    """
    train_df = pd.read_feather(train_path)
    val_df = pd.read_feather(val_path)
    print_to_file(f"Loaded train DataFrame from {train_path} and validation DataFrame from {val_path}.")

    if frac < 1.0:
        train_df = train_df.sample(frac=frac, random_state=42)
        val_df = val_df.sample(frac=frac, random_state=42)

    return train_df, val_df

def get_global_vocab_and_cog2idx_from_df(df):
    """
    Extract columns from the DataFrame that correspond to COG or arCOG entries.
    """
    global_vocab = [col for col in df.columns if col.startswith('COG') or col.startswith('arCOG')]
    global_vocab = sorted(global_vocab)
    cog2idx = {cog: idx for idx, cog in enumerate(global_vocab)}
    return global_vocab, cog2idx

class GenomeDataset(Dataset):
    """
    A PyTorch Dataset for genome data that simulates noise in gene copy number observations.
    
    For each sample (a row in a DataFrame), the dataset:
      1. Extracts gene counts for a given vocabulary.
      2. Constructs a binary ground truth vector (1 if count > 0, else 0).
      3. Simulates false negatives by randomly dropping some present genes.
      4. Simulates false positives by injecting spurious signals in genes that are absent.
      5. Returns a dictionary containing:
            - 'tokens': a 2D NumPy array of shape (num_observed_tokens, 2) where the first
              column holds the gene (COG) index and the second column the (possibly noisy) count.
            - 'target': the full ground truth binary vector.
            
    Parameters:
      - df: pandas DataFrame with rows as genomes and columns corresponding to gene families.
      - global_vocab: list of column names that define the gene families.
      - cog2idx: dictionary mapping gene family names to integer indices.
      - false_negative_rate: probability of dropping a gene that is truly present.
      - false_positive_rate: rate (as a fraction of vocab_size) for spurious false positive injections.
      - count_noise_std: standard deviation for multiplicative noise on the count (currently not used,
                         but can be incorporated for more realistic simulation).
      - random_state: seed for reproducibility.
    """
    def __init__(self, df, global_vocab, cog2idx,
                 false_negative_rate=0.3, false_positive_rate=0.005,
                 count_noise_std=0.0, random_state=None):
        self.df = df.reset_index(drop=True)
        self.global_vocab = global_vocab
        self.cog2idx = cog2idx
        self.vocab_size = len(global_vocab)
        self.false_negative_rate = false_negative_rate
        self.false_positive_rate = false_positive_rate
        self.count_noise_std = count_noise_std
        # Use a new numpy random generator for reproducibility.
        self.rng = np.random.default_rng(random_state)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        # Extract the row corresponding to the genome.
        row = self.df.iloc[idx]
        # Get gene counts for the defined vocabulary, as floats.
        counts = row[self.global_vocab].values.astype(np.float32)
        # Create the ground truth binary vector: 1 if count > 0, else 0.
        target = (counts > 0).astype(np.float32)
        
        observed_indices = []
        observed_counts = []
        
        # Loop over each gene family (by its index in target)
        for cog_idx, present in enumerate(target):
            if present:
                # With probability equal to false_negative_rate, simulate a false negative by dropping the gene.
                if self.rng.random() < self.false_negative_rate:
                    continue
                original_count = counts[cog_idx]
                # If you want to simulate count noise, you can modify the following line.
                # For example: noise_factor = self.rng.normal(1, self.count_noise_std) if self.count_noise_std > 0 else 1
                noise_factor = 1  # Currently no noise is applied.
                noisy_count = max(original_count * noise_factor, 0.0)
                observed_indices.append(cog_idx)
                observed_counts.append(noisy_count)
        
        # Simulate false positives: determine how many FP events to inject.
        num_false_positives = self.rng.poisson(lam=self.false_positive_rate * self.vocab_size)
        # Get indices of gene families that are absent.
        absent_indices = np.where(target == 0)[0]
        if len(absent_indices) > 0 and num_false_positives > 0:
            # Choose a subset of absent indices to be false positives.
            false_pos = self.rng.choice(absent_indices,
                                        size=min(num_false_positives, len(absent_indices)),
                                        replace=False)
            for fp in false_pos:
                # Here we assign a count of 1 for false positives.
                noisy_count = 1
                observed_indices.append(fp)
                observed_counts.append(noisy_count)
        
        # Create the tokens array: each row is [gene_index, observed_count].
        if len(observed_indices) == 0:
            tokens = np.empty((0, 2), dtype=np.float32)
        else:
            tokens = np.stack([np.array(observed_indices, dtype=np.int64),
                               np.array(observed_counts, dtype=np.float32)], axis=-1)
        
        # Return a dictionary with tokens and the full ground truth target vector.
        sample = {
            'tokens': tokens,
            'target': target
        }
        return sample
def collate_genomes(batch, pad_idx):
    batch_tokens = [sample['tokens'] for sample in batch]
    targets = [sample['target'] for sample in batch]
    batch_size = len(batch_tokens)
    max_len = max(tokens.shape[0] for tokens in batch_tokens)
    tokens_padded = np.full((batch_size, max_len, 2), fill_value=0, dtype=np.float32)
    mask = np.ones((batch_size, max_len), dtype=bool)
    
    for i, tokens in enumerate(batch_tokens):
        length = tokens.shape[0]
        if length > 0:
            tokens_padded[i, :length, :] = tokens
            mask[i, :length] = False
        if length < max_len:
            tokens_padded[i, length:, 0] = pad_idx
    tokens_padded = torch.tensor(tokens_padded)
    mask = torch.tensor(mask)
    targets = torch.tensor(np.stack(targets, axis=0), dtype=torch.float32)
    return tokens_padded, mask, targets

# ----------------------------
# Model Definitions (SetTransformer)
# ----------------------------
class SAB(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.0):
        super(SAB, self).__init__()
        self.mha = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout)
        self.ln1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )
        self.ln2 = nn.LayerNorm(dim)
    
    def forward(self, X, mask=None):
        attn_out, _ = self.mha(X, X, X, key_padding_mask=mask)
        X = self.ln1(X + attn_out)
        ff_out = self.ff(X)
        out = self.ln2(X + ff_out)
        return out

class PMA(nn.Module):
    def __init__(self, dim, num_seeds, num_heads, dropout=0.0):
        super(PMA, self).__init__()
        self.num_seeds = num_seeds
        self.seed = nn.Parameter(torch.randn(num_seeds, dim))
        self.mha = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout)
        self.ln = nn.LayerNorm(dim)
    
    def forward(self, X, mask=None):
        batch_size = X.shape[1]
        S = self.seed.unsqueeze(1).expand(-1, batch_size, -1)
        pooled, _ = self.mha(S, X, X, key_padding_mask=mask)
        pooled = pooled + S
        pooled = self.ln(pooled)
        return pooled

class GenomeSetTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_heads=4, num_sab=2, dropout=0.1):
        super(GenomeSetTransformer, self).__init__()
        self.vocab_size = vocab_size
        self.pad_idx = vocab_size  # extra token for padding
        self.cog_embedding = nn.Embedding(num_embeddings=vocab_size + 1, embedding_dim=d_model, padding_idx=self.pad_idx)
        self.count_linear = nn.Linear(1, d_model)
        self.sab_blocks = nn.ModuleList([SAB(dim=d_model, num_heads=num_heads, dropout=dropout) for _ in range(num_sab)])
        self.pma = PMA(dim=d_model, num_seeds=1, num_heads=num_heads, dropout=dropout)
        self.decoder = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.ReLU(),
            nn.Linear(d_model, vocab_size)
        )
    
    def forward(self, tokens, mask):
        B, N, _ = tokens.size()
        cog_ids = tokens[:, :, 0].long()
        binary_counts = (tokens[:, :, 1].float() > 0).float().unsqueeze(-1)
        emb_cog = self.cog_embedding(cog_ids)
        emb_count = self.count_linear(binary_counts)
        X = emb_cog + emb_count
        
        local_features = X.mean(dim=1)
        X = X.transpose(0, 1)
        for sab in self.sab_blocks:
            X = sab(X, mask=mask)
        pooled = self.pma(X, mask=mask)
        pooled = pooled.squeeze(0)
        combined = torch.cat([pooled, local_features], dim=1)
        logits = self.decoder(combined)
        probs = torch.sigmoid(logits)
        return probs

# ----------------------------
# Model Definitions (BinaryMLM)
# ----------------------------
class BinaryMLMModel(nn.Module):
    def __init__(self, 
                 vocab_size=3,
                 embed_dim=512, 
                 num_layers=6, 
                 num_heads=8, 
                 dropout=0.1, 
                 max_seq_len=5000):
        super(BinaryMLMModel, self).__init__()
        self.embed_dim = embed_dim
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=num_heads, 
            dropout=dropout,
            activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.mlm_head = nn.Linear(embed_dim, 2)
        self.init_weights()
        
    def init_weights(self):
        nn.init.xavier_uniform_(self.token_embedding.weight)
        nn.init.xavier_uniform_(self.position_embedding.weight)
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() > 1:
                nn.init.xavier_uniform_(param)
                
    def forward(self, input_ids, labels=None):
        batch_size, seq_length = input_ids.shape
        positions = torch.arange(0, seq_length, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_length)
        token_embeds = self.token_embedding(input_ids)
        pos_embeds = self.position_embedding(positions)
        x = token_embeds + pos_embeds
        x = self.dropout(x)
        x = x.transpose(0, 1)
        num_layers = len(self.transformer.layers)
        x = checkpoint_sequential(self.transformer.layers, num_layers, x)
        x = x.transpose(0, 1)
        logits = self.mlm_head(x)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1), ignore_index=-100)
        return logits, loss





In [2]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from collections import OrderedDict
from tqdm import tqdm

# --- Assume helper functions and classes are defined:
# fast_load_dfs, get_global_vocab_and_cog2idx_from_df,
# GenomeDataset, collate_genomes,
# GenomeSetTransformer, BinaryMLMModel
# --------------------------------------------------

def validate_per_sample_setT(model, dataloader, device, threshold=0.5, lower_thresh=0.2, upper_thresh=0.8):
    """
    Evaluates the SetTransformer on the given dataloader.
    
    For each sample, this function computes two sets of metrics:
      1. "Do Nothing" (noisy) metrics computed directly on the observed input
         (i.e. the binary vector reconstructed from the input tokens).
      2. Processed metrics computed from the SetTransformerâ€™s output probabilities,
         including standard metrics (accuracy, precision, recall, f1, genome_size_diff)
         and confidence fractions (top, bottom, and masked).
    
    Returns a list of dictionaries (one per sample) with all computed metrics.
    """
    model.eval()
    sample_metrics = []
    sample_index = 0
    for tokens, mask, targets in tqdm(dataloader, desc="Evaluating (extended)", leave=False):
        tokens = tokens.to(device)
        mask = mask.to(device)
        targets = targets.to(device)
        
        # Compute model probabilities from SetTransformer
        with torch.no_grad():
            probs = model(tokens, mask)  # shape: (B, vocab_size)
        # For SetTransformer output, binary predictions using fixed threshold
        preds_np = (probs.cpu().detach().numpy() >= threshold).astype(int)
        # Ground truth (binary) target
        targets_np = targets.cpu().detach().numpy().astype(int)
        # Also keep the raw probabilities (for computing confidence fractions)
        st_probs_np = probs.cpu().detach().numpy()
        
        # Reconstruct the "observed" (noisy) input from the raw tokens
        tokens_np = tokens.cpu().detach().numpy()
        mask_np = mask.cpu().detach().numpy()
        
        batch_size = targets_np.shape[0]
        V = targets_np.shape[1]
        
        for i in range(batch_size):
            # -------------------------
            # 1. "Do Nothing" / Noisy Input Metrics
            # -------------------------
            # Reconstruct observed binary vector from tokens (ignoring padded positions)
            observed = np.zeros(V, dtype=int)
            valid_tokens = tokens_np[i][mask_np[i] == False]
            for token in valid_tokens:
                cog_id = int(token[0])
                if cog_id < V:
                    observed[cog_id] = 1
            # Compute "noisy" metrics
            TP_noisy = np.sum((targets_np[i] == 1) & (observed == 1))
            TN_noisy = np.sum((targets_np[i] == 0) & (observed == 0))
            FP_noisy = np.sum((targets_np[i] == 0) & (observed == 1))
            FN_noisy = np.sum((targets_np[i] == 1) & (observed == 0))
            noisy_acc = (TP_noisy + TN_noisy) / V
            noisy_prec = TP_noisy / (TP_noisy + FP_noisy) if (TP_noisy + FP_noisy) > 0 else 0.0
            noisy_rec = TP_noisy / (TP_noisy + FN_noisy) if (TP_noisy + FN_noisy) > 0 else 0.0
            noisy_f1 = 2 * noisy_prec * noisy_rec / (noisy_prec + noisy_rec) if (noisy_prec + noisy_rec) > 0 else 0.0
            observed_size = np.sum(observed)
            true_size = np.sum(targets_np[i])
            noisy_genome_diff = np.abs(observed_size / true_size) if true_size > 0 else np.nan
            
            # -------------------------
            # 2. SetTransformer (Pre-MLM) Metrics from Processed Output
            # -------------------------
            # Binary predictions computed earlier:
            proc_preds = preds_np[i]
            TP = np.sum((targets_np[i] == 1) & (proc_preds == 1))
            TN = np.sum((targets_np[i] == 0) & (proc_preds == 0))
            FP = np.sum((targets_np[i] == 0) & (proc_preds == 1))
            FN = np.sum((targets_np[i] == 1) & (proc_preds == 0))
            acc = (TP + TN) / V
            prec = TP / (TP + FP) if (TP + FP) > 0 else 0.0
            rec = TP / (TP + FN) if (TP + FN) > 0 else 0.0
            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
            proc_pred_size = np.sum(proc_preds)
            genome_diff = np.abs(proc_pred_size / true_size) if true_size > 0 else np.nan
            
            # Confidence fractions from the SetTransformer probabilities:
            sample_probs = st_probs_np[i]
            top_conf_fraction = np.mean(sample_probs > upper_thresh)
            bottom_conf_fraction = np.mean(sample_probs < lower_thresh)
            masked_fraction = np.mean((sample_probs >= lower_thresh) & (sample_probs <= upper_thresh))
            
            # FP removed & FN recovered are computed relative to the observed input:
            fp_noise = (observed == 1) & (targets_np[i] == 0)
            fn_noise = (observed == 0) & (targets_np[i] == 1)
            fp_removed = (np.sum((proc_preds == 0) & fp_noise) / np.sum(fp_noise)
                          if np.sum(fp_noise) > 0 else np.nan)
            fn_recovered = (np.sum((proc_preds == 1) & fn_noise) / np.sum(fn_noise)
                            if np.sum(fn_noise) > 0 else np.nan)
            
            # -------------------------
            # Combine metrics into one dictionary for this sample.
            # -------------------------
            sample_metrics.append({
                "sample_id": sample_index,
                # Noisy (do nothing) metrics:
                "noisy_accuracy": noisy_acc,
                "noisy_precision": noisy_prec,
                "noisy_recall": noisy_rec,
                "noisy_f1": noisy_f1,
                "noisy_genome_diff": noisy_genome_diff,
                # Processed (SetTransformer output) metrics:
                "accuracy": acc,
                "precision": prec,
                "recall": rec,
                "f1": f1,
                "genome_size_diff": genome_diff,
                "fp_removed": fp_removed,
                "fn_recovered": fn_recovered
            })
            sample_index += 1
    return sample_metrics


In [3]:
def validate_per_sample_extended(set_transformer, binary_mlm, dataloader, device,global_vocab,
                                 threshold=0.5, apply_mlm=True, block_frac=0.15, dummy_mode=False,
                                 label_string="default"):
    """
    Evaluates the SetTransformer model on the dataloader and computes per-sample metrics,
    while also accumulating TP, TN, FP, FN counts per COG (token) over all samples.
    
    For each sample:
      1. Compute SetTransformer probabilities and a discrete assignment by thresholding at 0.5.
      2. Reconstruct the observed (noisy) input from the tokens.
      3. Compute pre-MLM metrics using the discrete assignment.
      4. Accumulate per-COG counts from these predictions.
      5. Optionally (if apply_mlm is True), refine the combined input:
           - Form the combined base by OR-ing the observed input and the discrete assignment.
           - Partition the entire vocabulary [0, V) into blocks of size ceil(V*block_frac) (at least 1 token per block).
           - For each block, mask the tokens (set to 2), run binaryMLM, and replace the tokens in that block
             with the binary MLM predictions.
           - Compute MLM metrics from the final refined prediction and also accumulate per-COG counts.
      6. Append per-sample metrics to a list.
    
    After processing all samples, write out the per-COG aggregated counts to
         "COG_metrics_"+label_string+".csv".
    
    Returns:
      A list of dictionaries (one per sample) with the computed metrics.
    """
    sample_metrics = []
    global_sample_index = 0

    # These will accumulate counts per COG. They will be initialized on the first batch.
    global_pre_TP = None
    global_pre_TN = None
    global_pre_FP = None
    global_pre_FN = None
    if apply_mlm:
        global_mlm_TP = None
        global_mlm_TN = None
        global_mlm_FP = None
        global_mlm_FN = None

    for tokens, mask, targets in tqdm(dataloader, desc="Evaluating (extended)", leave=False):
        tokens = tokens.to(device)
        mask = mask.to(device)
        targets = targets.to(device)
        
        with torch.no_grad():
            probs = set_transformer(tokens, mask)  # shape: (B, vocab_size)
        st_probs_np = probs.cpu().detach().numpy()  # shape: (B, V)
        # Pre-MLM discrete assignment: threshold at 0.5.
        discrete_assignment = (st_probs_np >= threshold).astype(int)
        
        targets_np = targets.cpu().detach().numpy().astype(int)
        tokens_np = tokens.cpu().detach().numpy()
        mask_np = mask.cpu().detach().numpy()
        batch_size = targets_np.shape[0]
        V = targets_np.shape[1]
        
        # Initialize per-COG accumulators on the first batch.
        if global_pre_TP is None:
            global_pre_TP = np.zeros(V, dtype=int)
            global_pre_TN = np.zeros(V, dtype=int)
            global_pre_FP = np.zeros(V, dtype=int)
            global_pre_FN = np.zeros(V, dtype=int)
            if apply_mlm:
                global_mlm_TP = np.zeros(V, dtype=int)
                global_mlm_TN = np.zeros(V, dtype=int)
                global_mlm_FP = np.zeros(V, dtype=int)
                global_mlm_FN = np.zeros(V, dtype=int)
        
        for i in range(batch_size):
            # Reconstruct observed (noisy) input.
            observed = np.zeros(V, dtype=int)
            valid_tokens = tokens_np[i][mask_np[i] == False]
            for token in valid_tokens:
                cog_id = int(token[0])
                if cog_id < V:
                    observed[cog_id] = 1
            
            # Pre-MLM predictions from SetTransformer.
            pre_preds = discrete_assignment[i]
            
            # Accumulate per-COG counts for pre-MLM predictions.
            global_pre_TP += ((targets_np[i] == 1) & (pre_preds == 1)).astype(int)
            global_pre_TN += ((targets_np[i] == 0) & (pre_preds == 0)).astype(int)
            global_pre_FP += ((targets_np[i] == 0) & (pre_preds == 1)).astype(int)
            global_pre_FN += ((targets_np[i] == 1) & (pre_preds == 0)).astype(int)
            
            # Compute "noisy" metrics
            TP_noisy = np.sum((targets_np[i] == 1) & (observed == 1))
            TN_noisy = np.sum((targets_np[i] == 0) & (observed == 0))
            FP_noisy = np.sum((targets_np[i] == 0) & (observed == 1))
            FN_noisy = np.sum((targets_np[i] == 1) & (observed == 0))
            noisy_acc = (TP_noisy + TN_noisy) / V
            prec_noisy = TP_noisy / (TP_noisy + FP_noisy) if (TP_noisy + FP_noisy) > 0 else 0.0
            rec_noisy = TP_noisy / (TP_noisy + FN_noisy) if (TP_noisy + FN_noisy) > 0 else 0.0

            # Compute pre-MLM metrics.
            TP = np.sum((targets_np[i] == 1) & (pre_preds == 1))
            TN = np.sum((targets_np[i] == 0) & (pre_preds == 0))
            FP = np.sum((targets_np[i] == 0) & (pre_preds == 1))
            FN = np.sum((targets_np[i] == 1) & (pre_preds == 0))
            acc = (TP + TN) / V
            prec = TP / (TP + FP) if (TP + FP) > 0 else 0.0
            rec = TP / (TP + FN) if (TP + FN) > 0 else 0.0
            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
            true_size = np.sum(targets_np[i])
            pre_pred_size = np.sum(pre_preds)
            genome_diff = np.abs(pre_pred_size / true_size) if true_size > 0 else np.nan
            fp_noise = (observed == 1) & (targets_np[i] == 0)
            fn_noise = (observed == 0) & (targets_np[i] == 1)
            fp_removed = (np.sum((pre_preds == 0) & fp_noise) / np.sum(fp_noise)
                          if np.sum(fp_noise) > 0 else np.nan)
            fn_recovered = (np.sum((pre_preds == 1) & fn_noise) / np.sum(fn_noise)
                            if np.sum(fn_noise) > 0 else np.nan)
            
            sample_dict = {
                "sample_id": global_sample_index,
                "noisy_accuracy": (np.sum((targets_np[i] == 1) & (observed == 1)) +
                                    np.sum((targets_np[i] == 0) & (observed == 0))) / V,
                "noisy_precision": TP_noisy / (TP_noisy + FP_noisy) if (TP_noisy + FP_noisy) > 0 else 0.0,
                "noisy_recall": TP_noisy / (TP_noisy + FN_noisy) if (TP_noisy + FN_noisy) > 0 else 0.0,
                "noisy_f1": 2 * prec_noisy * rec_noisy / (prec_noisy + rec_noisy) if (prec_noisy + rec_noisy) > 0 else 0.0,
                "noisy_genome_diff": np.abs(np.sum(observed) / true_size) if true_size > 0 else np.nan,
                "accuracy": acc,
                "precision": prec,
                "recall": rec,
                "f1": f1,
                "genome_size_diff": genome_diff,
                "fp_removed": fp_removed,
                "fn_recovered": fn_recovered
            }
            
            # Optional: MLM refinement step.
            if apply_mlm:
                true_size = np.sum(targets_np[i])
                # Combined base: OR of observed and pre_preds.
                combined_base = np.maximum(observed, pre_preds)
#                combined_base = np.maximum(pre_preds, pre_preds)
                current_output = combined_base.copy()
                # Partition the vocabulary [0, V) into blocks.
                block_size = max(1, int(np.ceil(V * block_frac)))
                all_indices = np.arange(V)
                # Process each block sequentially.
                for start in range(0, V, block_size):
                    current_block = all_indices[start:start+block_size]
                    modified_input = current_output.copy()
                    modified_input[current_block] = 2  # mask these tokens.
                    if dummy_mode:
                        refined_preds = current_output.copy()
                    else:
                        discrete_tensor = torch.tensor(modified_input).unsqueeze(0).to(device)
                        with torch.no_grad():
                            logits, _ = binary_mlm(discrete_tensor)
                            mlm_probs = torch.softmax(logits, dim=-1)[:, :, 1]
                        refined_preds = (mlm_probs.cpu().detach().numpy() >= threshold).astype(int).squeeze(0)
                    # Replace the tokens in the current block with the MLM predictions.
                    current_output[current_block] = refined_preds[current_block]
                
                final_mlm_preds = current_output
                # Accumulate per-COG counts for MLM predictions.

                global_mlm_TP += ((targets_np[i] == 1) & (final_mlm_preds == 1)).astype(int)
                global_mlm_TN += ((targets_np[i] == 0) & (final_mlm_preds == 0)).astype(int)
                global_mlm_FP += ((targets_np[i] == 0) & (final_mlm_preds == 1)).astype(int)
                global_mlm_FN += ((targets_np[i] == 1) & (final_mlm_preds == 0)).astype(int)
                    
                TP_mlm = np.sum((targets_np[i] == 1) & (final_mlm_preds == 1))
                TN_mlm = np.sum((targets_np[i] == 0) & (final_mlm_preds == 0))
                FP_mlm = np.sum((targets_np[i] == 0) & (final_mlm_preds == 1))
                FN_mlm = np.sum((targets_np[i] == 1) & (final_mlm_preds == 0))
                MLM_acc = (TP_mlm + TN_mlm) / V
                MLM_prec = TP_mlm / (TP_mlm + FP_mlm) if (TP_mlm + FP_mlm) > 0 else 0.0
                MLM_rec = TP_mlm / (TP_mlm + FN_mlm) if (TP_mlm + FN_mlm) > 0 else 0.0
                MLM_f1 = 2 * MLM_prec * MLM_rec / (MLM_prec + MLM_rec) if (MLM_prec + MLM_rec) > 0 else 0.0                
                MLM_pred_size = np.sum(final_mlm_preds)
                MLM_genome_diff = MLM_pred_size / true_size if true_size > 0 else np.nan
                #print(MLM_pred_size,true_size,MLM_genome_diff)
                fp_removed_mlm = (np.sum((final_mlm_preds == 0) & fp_noise) / np.sum(fp_noise)
                                  if np.sum(fp_noise) > 0 else np.nan)
                fn_recovered_mlm = (np.sum((final_mlm_preds == 1) & fn_noise) / np.sum(fn_noise)
                                    if np.sum(fn_noise) > 0 else np.nan)
                sample_dict.update({
                    "MLM_accuracy": MLM_acc,
                    "MLM_precision": MLM_prec,
                    "MLM_recall": MLM_rec,
                    "MLM_f1": MLM_f1,
                    "MLM_genome_diff": MLM_genome_diff,
                    "MLM_fp_removed": fp_removed_mlm,
                    "MLM_fn_recovered": fn_recovered_mlm
                })
            
            sample_metrics.append(sample_dict)
            global_sample_index += 1

    # After processing all samples, write per-COG metrics.
    # Create a dictionary with rows per COG index.
    per_cog_data = {
        "COG": global_vocab,
        "pre_TP": global_pre_TP,
        "pre_TN": global_pre_TN,
        "pre_FP": global_pre_FP,
        "pre_FN": global_pre_FN
    }
    if apply_mlm:
        per_cog_data.update({
            "MLM_TP": global_mlm_TP,
            "MLM_TN": global_mlm_TN,
            "MLM_FP": global_mlm_FP,
            "MLM_FN": global_mlm_FN
        })
    df_cog_metrics = pd.DataFrame(per_cog_data)
    df_cog_metrics.to_csv("COG_metrics_" + label_string + ".csv", index=False)
    
    return sample_metrics

In [27]:
global_vocab

['COG0001',
 'COG0002',
 'COG0003',
 'COG0004',
 'COG0005',
 'COG0006',
 'COG0007',
 'COG0008',
 'COG0009',
 'COG0010',
 'COG0011',
 'COG0012',
 'COG0013',
 'COG0014',
 'COG0015',
 'COG0016',
 'COG0017',
 'COG0018',
 'COG0019',
 'COG0020',
 'COG0021',
 'COG0022',
 'COG0023',
 'COG0024',
 'COG0025',
 'COG0026',
 'COG0027',
 'COG0028',
 'COG0029',
 'COG0030',
 'COG0031',
 'COG0033',
 'COG0034',
 'COG0035',
 'COG0036',
 'COG0037',
 'COG0038',
 'COG0039',
 'COG0040',
 'COG0041',
 'COG0042',
 'COG0043',
 'COG0044',
 'COG0045',
 'COG0046',
 'COG0047',
 'COG0048',
 'COG0049',
 'COG0050',
 'COG0051',
 'COG0052',
 'COG0053',
 'COG0054',
 'COG0055',
 'COG0056',
 'COG0057',
 'COG0058',
 'COG0059',
 'COG0060',
 'COG0061',
 'COG0062',
 'COG0063',
 'COG0064',
 'COG0065',
 'COG0066',
 'COG0067',
 'COG0068',
 'COG0069',
 'COG0070',
 'COG0071',
 'COG0072',
 'COG0073',
 'COG0074',
 'COG0075',
 'COG0076',
 'COG0077',
 'COG0078',
 'COG0079',
 'COG0080',
 'COG0081',
 'COG0082',
 'COG0083',
 'COG0084',
 'CO

In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load DataFrames and vocabulary
train_df, val_df = fast_load_dfs('COG_train1.feather', 'COG_val1.feather', frac=0.1)
global_vocab, cog2idx = get_global_vocab_and_cog2idx_from_df(train_df)

# Load pretrained SetTransformer
set_transformer = GenomeSetTransformer(vocab_size=len(global_vocab), d_model=1024,
                                         num_heads=8, num_sab=4, dropout=0.1)
state_dict = torch.load('COG_high_1024_8_4_BCE_60.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
set_transformer.load_state_dict(new_state_dict)
#set_transformer.load_state_dict(torch.load('model_checkpoint_full.pth'))
set_transformer = set_transformer.to(device)

# Load pretrained BinaryMLM
binary_mlm = BinaryMLMModel(vocab_size=3, embed_dim=512, num_layers=6,
                            num_heads=8, dropout=0.1, max_seq_len=len(global_vocab))
state_dict = torch.load('binMLM_512_6_8_01_e50.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
binary_mlm.load_state_dict(new_state_dict)
binary_mlm = binary_mlm.to(device)

# Define FN rates to test and fixed FP rate
fn_rates = [0.0, 0.1, 0.25, 0.5, 0.75, 0.85,0.95,1.0]
fp_rate = 0.0

all_sample_records = []  # Will collect records from both stages
# We'll print summary statistics for Pre-MLM separately for "noisy" (do nothing) metrics and processed ones.

for fn in fn_rates:
    print(f"\nEvaluating for FN rate = {fn} and FP rate = {fp_rate}...")
    val_dataset = GenomeDataset(val_df, global_vocab=global_vocab, cog2idx=cog2idx,
                                false_negative_rate=fn, false_positive_rate=fp_rate,
                                count_noise_std=0.0, random_state=42)
    val_dataloader = DataLoader(val_dataset, batch_size=16,
                                collate_fn=lambda batch: collate_genomes(batch, pad_idx=len(global_vocab)))

    # Evaluate Pre-MLM metrics (using SetTransformer output)
    #pre_metrics = validate_per_sample_extended(set_transformer, val_dataloader, device,
    #                                             threshold=0.5, lower_thresh=-0.2, upper_thresh=0.8)
    pre_metrics = validate_per_sample_extended(set_transformer, binary_mlm, val_dataloader, device,global_vocab,label_string="FN"+repr(fn)+"_FP"+repr(fp_rate))

    for rec in pre_metrics:
        rec["FN_rate"] = fn
        rec["Stage"] = "Pre-MLM"
    # Print summary statistics for Pre-MLM stage (both "noisy" and processed metrics)
    print(f"\nSummary for FN rate {fn}, Stage Pre-MLM (Noisy Input):")
    for metric in ["noisy_accuracy", "noisy_precision", "noisy_recall", "noisy_f1", "noisy_genome_diff"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")

    print(f"\nSummary for FN rate {fn}, Stage Post-SetTransformer (Processed):")
    for metric in ["accuracy", "precision", "recall", "f1", "genome_size_diff", "fp_removed", "fn_recovered"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")

    print(f"\nSummary for FN rate {fn}, Stage Post-MLM (Processed):")
    for metric in ["MLM_accuracy", "MLM_precision", "MLM_recall", "MLM_f1", "MLM_genome_diff", "MLM_fp_removed", "MLM_fn_recovered"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")


    print("-" * 60)
    all_sample_records.extend(pre_metrics)

# Combine all records into a DataFrame and save to CSV.
df_metrics = pd.DataFrame(all_sample_records)
df_metrics.to_csv("per_sample_combined_metrics_FP"+repr(fp_rate)+".csv", index=False)
print("\nSaved per-sample combined metrics to per_sample_combined_metrics_FP"+repr(fp_rate)+".csv")



Loaded train DataFrame from COG_train1.feather and validation DataFrame from COG_val1.feather.

Evaluating for FN rate = 0.0 and FP rate = 0.0...


                                                                                                                                                                         


Summary for FN rate 0.0, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Precision     : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Recall        : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_F1            : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Genome_Diff   : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000

Summary for FN rate 0.0, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9348, 25th=0.9229, Median=0.9351, 75th=0.9474
  Precision           : Mean=0.8815, 25th=0.8560, Median=0.8955, 75th=0.9216
  Recall              : Mean=0.8353, 25th=0.8085, Median=0.8426, 75th=0.8744
  F1                  : Mean=0.8572, 25th=0.8302, Median=0.8661, 75th=0.8948
  Genome_Size_Diff    : Mean=0.9490, 25th=0.9259, Median=0.9467, 75th=0.9671
  Fp_Removed          : No valid values
  Fn_Recovered        : No valid values

Summary for FN rate 0.0, Stag

                                                                                                                                                                         


Summary for FN rate 0.1, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.9746, 25th=0.9683, Median=0.9747, 75th=0.9804
  Noisy_Precision     : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Recall        : Mean=0.9002, 25th=0.8944, Median=0.9003, 75th=0.9061
  Noisy_F1            : Mean=0.9475, 25th=0.9442, Median=0.9475, 75th=0.9507
  Noisy_Genome_Diff   : Mean=0.9002, 25th=0.8944, Median=0.9003, 75th=0.9061

Summary for FN rate 0.1, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9340, 25th=0.9221, Median=0.9340, 75th=0.9463
  Precision           : Mean=0.8817, 25th=0.8566, Median=0.8953, 75th=0.9224
  Recall              : Mean=0.8314, 25th=0.8025, Median=0.8396, 75th=0.8708
  F1                  : Mean=0.8552, 25th=0.8276, Median=0.8642, 75th=0.8933
  Genome_Size_Diff    : Mean=0.9444, 25th=0.9196, Median=0.9426, 75th=0.9659
  Fp_Removed          : No valid values
  Fn_Recovered        : Mean=0.8091, 25th=0.7714, Median=0.8184, 75th=

                                                                                                                                                                         


Summary for FN rate 0.25, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.9363, 25th=0.9207, Median=0.9361, 75th=0.9503
  Noisy_Precision     : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Recall        : Mean=0.7500, 25th=0.7421, Median=0.7498, 75th=0.7581
  Noisy_F1            : Mean=0.8571, 25th=0.8520, Median=0.8570, 75th=0.8624
  Noisy_Genome_Diff   : Mean=0.7500, 25th=0.7421, Median=0.7498, 75th=0.7581

Summary for FN rate 0.25, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9323, 25th=0.9198, Median=0.9324, 75th=0.9449
  Precision           : Mean=0.8821, 25th=0.8562, Median=0.8956, 75th=0.9229
  Recall              : Mean=0.8231, 25th=0.7941, Median=0.8318, 75th=0.8631
  F1                  : Mean=0.8509, 25th=0.8214, Median=0.8598, 75th=0.8894
  Genome_Size_Diff    : Mean=0.9346, 25th=0.9070, Median=0.9329, 75th=0.9582
  Fp_Removed          : No valid values
  Fn_Recovered        : Mean=0.8016, 25th=0.7641, Median=0.8095, 75t

                                                                                                                                                                         


Summary for FN rate 0.5, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.8725, 25th=0.8419, Median=0.8716, 75th=0.9004
  Noisy_Precision     : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Recall        : Mean=0.4998, 25th=0.4899, Median=0.4999, 75th=0.5090
  Noisy_F1            : Mean=0.6663, 25th=0.6576, Median=0.6666, 75th=0.6746
  Noisy_Genome_Diff   : Mean=0.4998, 25th=0.4899, Median=0.4999, 75th=0.5090

Summary for FN rate 0.5, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9273, 25th=0.9140, Median=0.9270, 75th=0.9409
  Precision           : Mean=0.8827, 25th=0.8551, Median=0.8973, 75th=0.9260
  Recall              : Mean=0.7996, 25th=0.7655, Median=0.8083, 75th=0.8417
  F1                  : Mean=0.8382, 25th=0.8080, Median=0.8474, 75th=0.8788
  Genome_Size_Diff    : Mean=0.9076, 25th=0.8761, Median=0.9060, 75th=0.9365
  Fp_Removed          : No valid values
  Fn_Recovered        : Mean=0.7783, 25th=0.7407, Median=0.7863, 75th=

                                                                                                                                                                         


Summary for FN rate 0.75, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.8088, 25th=0.7637, Median=0.8083, 75th=0.8508
  Noisy_Precision     : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Recall        : Mean=0.2500, 25th=0.2413, Median=0.2500, 75th=0.2586
  Noisy_F1            : Mean=0.3999, 25th=0.3888, Median=0.4000, 75th=0.4109
  Noisy_Genome_Diff   : Mean=0.2500, 25th=0.2413, Median=0.2500, 75th=0.2586

Summary for FN rate 0.75, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9122, 25th=0.8960, Median=0.9109, 75th=0.9271
  Precision           : Mean=0.8802, 25th=0.8484, Median=0.8970, 75th=0.9304
  Recall              : Mean=0.7335, 25th=0.6955, Median=0.7406, 75th=0.7807
  F1                  : Mean=0.7985, 25th=0.7634, Median=0.8075, 75th=0.8432
  Genome_Size_Diff    : Mean=0.8362, 25th=0.7890, Median=0.8346, 75th=0.8775
  Fp_Removed          : No valid values
  Fn_Recovered        : Mean=0.7134, 25th=0.6726, Median=0.7206, 75t

                                                                                                                                                                         


Summary for FN rate 0.85, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7833, 25th=0.7321, Median=0.7828, 75th=0.8313
  Noisy_Precision     : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Recall        : Mean=0.1500, 25th=0.1429, Median=0.1499, 75th=0.1567
  Noisy_F1            : Mean=0.2607, 25th=0.2500, Median=0.2607, 75th=0.2709
  Noisy_Genome_Diff   : Mean=0.1500, 25th=0.1429, Median=0.1499, 75th=0.1567

Summary for FN rate 0.85, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.8949, 25th=0.8737, Median=0.8933, 75th=0.9136
  Precision           : Mean=0.8707, 25th=0.8333, Median=0.8927, 75th=0.9310
  Recall              : Mean=0.6678, 25th=0.6271, Median=0.6749, 75th=0.7147
  F1                  : Mean=0.7529, 25th=0.7182, Median=0.7601, 75th=0.7975
  Genome_Size_Diff    : Mean=0.7731, 25th=0.7114, Median=0.7625, 75th=0.8180
  Fp_Removed          : No valid values
  Fn_Recovered        : Mean=0.6503, 25th=0.6069, Median=0.6571, 75t

                                                                                                                                                                         


Summary for FN rate 0.95, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7578, 25th=0.7004, Median=0.7569, 75th=0.8112
  Noisy_Precision     : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_Recall        : Mean=0.0500, 25th=0.0455, Median=0.0499, 75th=0.0542
  Noisy_F1            : Mean=0.0951, 25th=0.0871, Median=0.0950, 75th=0.1028
  Noisy_Genome_Diff   : Mean=0.0500, 25th=0.0455, Median=0.0499, 75th=0.0542

Summary for FN rate 0.95, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.8577, 25th=0.8309, Median=0.8584, 75th=0.8855
  Precision           : Mean=0.8323, 25th=0.7801, Median=0.8644, 75th=0.9201
  Recall              : Mean=0.5456, 25th=0.4985, Median=0.5450, 75th=0.5941
  F1                  : Mean=0.6505, 25th=0.6203, Median=0.6590, 75th=0.6913
  Genome_Size_Diff    : Mean=0.6783, 25th=0.5637, Median=0.6327, 75th=0.7277
  Fp_Removed          : No valid values
  Fn_Recovered        : Mean=0.5370, 25th=0.4898, Median=0.5369, 75t

                                                                                                                                                                         


Summary for FN rate 1.0, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7451, 25th=0.6837, Median=0.7447, 75th=0.8010
  Noisy_Precision     : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_Recall        : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_F1            : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_Genome_Diff   : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000

Summary for FN rate 1.0, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.7451, 25th=0.6837, Median=0.7447, 75th=0.8010
  Precision           : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Recall              : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  F1                  : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Genome_Size_Diff    : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Fp_Removed          : No valid values
  Fn_Recovered        : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=

SyntaxError: 'return' outside function (1054510047.py, line 98)

In [44]:

global_vocab

['COG0001',
 'COG0002',
 'COG0003',
 'COG0004',
 'COG0005',
 'COG0006',
 'COG0007',
 'COG0008',
 'COG0009',
 'COG0010',
 'COG0011',
 'COG0012',
 'COG0013',
 'COG0014',
 'COG0015',
 'COG0016',
 'COG0017',
 'COG0018',
 'COG0019',
 'COG0020',
 'COG0021',
 'COG0022',
 'COG0023',
 'COG0024',
 'COG0025',
 'COG0026',
 'COG0027',
 'COG0028',
 'COG0029',
 'COG0030',
 'COG0031',
 'COG0033',
 'COG0034',
 'COG0035',
 'COG0036',
 'COG0037',
 'COG0038',
 'COG0039',
 'COG0040',
 'COG0041',
 'COG0042',
 'COG0043',
 'COG0044',
 'COG0045',
 'COG0046',
 'COG0047',
 'COG0048',
 'COG0049',
 'COG0050',
 'COG0051',
 'COG0052',
 'COG0053',
 'COG0054',
 'COG0055',
 'COG0056',
 'COG0057',
 'COG0058',
 'COG0059',
 'COG0060',
 'COG0061',
 'COG0062',
 'COG0063',
 'COG0064',
 'COG0065',
 'COG0066',
 'COG0067',
 'COG0068',
 'COG0069',
 'COG0070',
 'COG0071',
 'COG0072',
 'COG0073',
 'COG0074',
 'COG0075',
 'COG0076',
 'COG0077',
 'COG0078',
 'COG0079',
 'COG0080',
 'COG0081',
 'COG0082',
 'COG0083',
 'COG0084',
 'CO

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load DataFrames and vocabulary
train_df, val_df = fast_load_dfs('COG_train1.feather', 'COG_val1.feather', frac=0.1)
global_vocab, cog2idx = get_global_vocab_and_cog2idx_from_df(train_df)

# Load pretrained SetTransformer
set_transformer = GenomeSetTransformer(vocab_size=len(global_vocab), d_model=1024,
                                         num_heads=8, num_sab=4, dropout=0.1)
state_dict = torch.load('COG_high_1024_8_4_BCE_60.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
set_transformer.load_state_dict(new_state_dict)
#set_transformer.load_state_dict(torch.load('model_checkpoint_full.pth'))
set_transformer = set_transformer.to(device)

# Load pretrained BinaryMLM
binary_mlm = BinaryMLMModel(vocab_size=3, embed_dim=512, num_layers=6,
                            num_heads=8, dropout=0.1, max_seq_len=len(global_vocab))
state_dict = torch.load('binMLM_512_6_8_01_e50.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
binary_mlm.load_state_dict(new_state_dict)
binary_mlm = binary_mlm.to(device)

# Define FN rates to test and fixed FP rate
fn_rates = [0.0, 0.1, 0.25, 0.5, 0.75, 0.85,0.95,1.0]
fp_rate = 0.01

all_sample_records = []  # Will collect records from both stages
# We'll print summary statistics for Pre-MLM separately for "noisy" (do nothing) metrics and processed ones.

for fn in fn_rates:
    print(f"\nEvaluating for FN rate = {fn} and FP rate = {fp_rate}...")
    val_dataset = GenomeDataset(val_df, global_vocab=global_vocab, cog2idx=cog2idx,
                                false_negative_rate=fn, false_positive_rate=fp_rate,
                                count_noise_std=0.0, random_state=42)
    val_dataloader = DataLoader(val_dataset, batch_size=16,
                                collate_fn=lambda batch: collate_genomes(batch, pad_idx=len(global_vocab)))

    # Evaluate Pre-MLM metrics (using SetTransformer output)
    #pre_metrics = validate_per_sample_extended(set_transformer, val_dataloader, device,
    #                                             threshold=0.5, lower_thresh=-0.2, upper_thresh=0.8)
    pre_metrics = validate_per_sample_extended(set_transformer, binary_mlm, val_dataloader, device,global_vocab,label_string="FN"+repr(fn)+"_FP"+repr(fp_rate))

    for rec in pre_metrics:
        rec["FN_rate"] = fn
        rec["Stage"] = "Pre-MLM"
    # Print summary statistics for Pre-MLM stage (both "noisy" and processed metrics)
    print(f"\nSummary for FN rate {fn}, Stage Pre-MLM (Noisy Input):")
    for metric in ["noisy_accuracy", "noisy_precision", "noisy_recall", "noisy_f1", "noisy_genome_diff"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")

    print(f"\nSummary for FN rate {fn}, Stage Post-SetTransformer (Processed):")
    for metric in ["accuracy", "precision", "recall", "f1", "genome_size_diff", "fp_removed", "fn_recovered"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")

    print(f"\nSummary for FN rate {fn}, Stage Post-MLM (Processed):")
    for metric in ["MLM_accuracy", "MLM_precision", "MLM_recall", "MLM_f1", "MLM_genome_diff", "MLM_fp_removed", "MLM_fn_recovered"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")


    print("-" * 60)
    all_sample_records.extend(pre_metrics)

# Combine all records into a DataFrame and save to CSV.
df_metrics = pd.DataFrame(all_sample_records)
df_metrics.to_csv("per_sample_combined_metrics_FP"+repr(fp_rate)+".csv", index=False)
print("\nSaved per-sample combined metrics to per_sample_combined_metrics_FP"+repr(fp_rate)+".csv")



Loaded train DataFrame from COG_train1.feather and validation DataFrame from COG_val1.feather.

Evaluating for FN rate = 0.0 and FP rate = 0.01...


                                                                                                                                                                                                       


Summary for FN rate 0.0, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.9901, 25th=0.9891, Median=0.9902, 75th=0.9910
  Noisy_Precision     : Mean=0.9564, 25th=0.9514, Median=0.9626, 75th=0.9701
  Noisy_Recall        : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_F1            : Mean=0.9776, 25th=0.9751, Median=0.9809, 75th=0.9848
  Noisy_Genome_Diff   : Mean=1.0462, 25th=1.0308, Median=1.0389, 75th=1.0511

Summary for FN rate 0.0, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9274, 25th=0.9148, Median=0.9275, 75th=0.9401
  Precision           : Mean=0.8712, 25th=0.8359, Median=0.8906, 75th=0.9265
  Recall              : Mean=0.8138, 25th=0.7880, Median=0.8183, 75th=0.8461
  F1                  : Mean=0.8402, 25th=0.8116, Median=0.8489, 75th=0.8790
  Genome_Size_Diff    : Mean=0.9395, 25th=0.8896, Median=0.9192, 75th=0.9673
  Fp_Removed          : Mean=0.9607, 25th=0.9423, Median=0.9643, 75th=0.9811
  Fn_Recovered        : No valid 

                                                                                                                                                                                                       


Summary for FN rate 0.1, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.9646, 25th=0.9580, Median=0.9647, 75th=0.9708
  Noisy_Precision     : Mean=0.9518, 25th=0.9462, Median=0.9586, 75th=0.9670
  Noisy_Recall        : Mean=0.9001, 25th=0.8945, Median=0.9004, 75th=0.9063
  Noisy_F1            : Mean=0.9251, 25th=0.9213, Median=0.9282, 75th=0.9332
  Noisy_Genome_Diff   : Mean=0.9463, 25th=0.9297, Median=0.9398, 75th=0.9532

Summary for FN rate 0.1, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9270, 25th=0.9144, Median=0.9271, 75th=0.9401
  Precision           : Mean=0.8712, 25th=0.8367, Median=0.8905, 75th=0.9264
  Recall              : Mean=0.8122, 25th=0.7845, Median=0.8171, 75th=0.8452
  F1                  : Mean=0.8393, 25th=0.8113, Median=0.8478, 75th=0.8782
  Genome_Size_Diff    : Mean=0.9376, 25th=0.8878, Median=0.9179, 75th=0.9663
  Fp_Removed          : Mean=0.9606, 25th=0.9412, Median=0.9643, 75th=0.9811
  Fn_Recovered        : Mean=0.80

                                                                                                                                                                                                       


Summary for FN rate 0.25, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.9263, 25th=0.9108, Median=0.9259, 75th=0.9403
  Noisy_Precision     : Mean=0.9428, 25th=0.9363, Median=0.9506, 75th=0.9607
  Noisy_Recall        : Mean=0.7499, 25th=0.7418, Median=0.7498, 75th=0.7583
  Noisy_F1            : Mean=0.8351, 25th=0.8293, Median=0.8379, 75th=0.8443
  Noisy_Genome_Diff   : Mean=0.7961, 25th=0.7781, Median=0.7897, 75th=0.8054

Summary for FN rate 0.25, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9265, 25th=0.9134, Median=0.9265, 75th=0.9397
  Precision           : Mean=0.8715, 25th=0.8360, Median=0.8907, 75th=0.9265
  Recall              : Mean=0.8090, 25th=0.7804, Median=0.8140, 75th=0.8424
  F1                  : Mean=0.8378, 25th=0.8091, Median=0.8460, 75th=0.8773
  Genome_Size_Diff    : Mean=0.9335, 25th=0.8838, Median=0.9156, 75th=0.9627
  Fp_Removed          : Mean=0.9596, 25th=0.9400, Median=0.9630, 75th=0.9808
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.5, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.8625, 25th=0.8317, Median=0.8618, 75th=0.8910
  Noisy_Precision     : Mean=0.9171, 25th=0.9076, Median=0.9276, 75th=0.9420
  Noisy_Recall        : Mean=0.4996, 25th=0.4899, Median=0.4996, 75th=0.5093
  Noisy_F1            : Mean=0.6464, 25th=0.6384, Median=0.6482, 75th=0.6573
  Noisy_Genome_Diff   : Mean=0.5458, 25th=0.5268, Median=0.5393, 75th=0.5562

Summary for FN rate 0.5, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9245, 25th=0.9113, Median=0.9246, 75th=0.9380
  Precision           : Mean=0.8717, 25th=0.8344, Median=0.8918, 75th=0.9271
  Recall              : Mean=0.7993, 25th=0.7697, Median=0.8045, 75th=0.8337
  F1                  : Mean=0.8326, 25th=0.8029, Median=0.8414, 75th=0.8732
  Genome_Size_Diff    : Mean=0.9219, 25th=0.8728, Median=0.9055, 75th=0.9517
  Fp_Removed          : Mean=0.9577, 25th=0.9375, Median=0.9608, 75th=0.9804
  Fn_Recovered        : Mean=0.78

                                                                                                                                                                                                       


Summary for FN rate 0.75, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7989, 25th=0.7526, Median=0.7978, 75th=0.8415
  Noisy_Precision     : Mean=0.8489, 25th=0.8311, Median=0.8651, 75th=0.8906
  Noisy_Recall        : Mean=0.2500, 25th=0.2413, Median=0.2498, 75th=0.2585
  Noisy_F1            : Mean=0.3857, 25th=0.3750, Median=0.3868, 75th=0.3975
  Noisy_Genome_Diff   : Mean=0.2962, 25th=0.2783, Median=0.2895, 75th=0.3050

Summary for FN rate 0.75, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9162, 25th=0.9014, Median=0.9158, 75th=0.9309
  Precision           : Mean=0.8703, 25th=0.8341, Median=0.8904, 75th=0.9271
  Recall              : Mean=0.7607, 25th=0.7258, Median=0.7663, 75th=0.8031
  F1                  : Mean=0.8102, 25th=0.7765, Median=0.8180, 75th=0.8550
  Genome_Size_Diff    : Mean=0.8788, 25th=0.8278, Median=0.8673, 75th=0.9168
  Fp_Removed          : Mean=0.9515, 25th=0.9318, Median=0.9556, 75th=0.9787
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.85, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7733, 25th=0.7207, Median=0.7726, 75th=0.8213
  Noisy_Precision     : Mean=0.7738, 25th=0.7462, Median=0.7937, 75th=0.8301
  Noisy_Recall        : Mean=0.1498, 25th=0.1429, Median=0.1499, 75th=0.1568
  Noisy_F1            : Mean=0.2505, 25th=0.2404, Median=0.2518, 75th=0.2616
  Noisy_Genome_Diff   : Mean=0.1960, 25th=0.1789, Median=0.1889, 75th=0.2039

Summary for FN rate 0.85, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9006, 25th=0.8829, Median=0.9000, 75th=0.9186
  Precision           : Mean=0.8597, 25th=0.8226, Median=0.8825, 75th=0.9244
  Recall              : Mean=0.6996, 25th=0.6559, Median=0.7074, 75th=0.7493
  F1                  : Mean=0.7687, 25th=0.7308, Median=0.7779, 75th=0.8177
  Genome_Size_Diff    : Mean=0.8207, 25th=0.7544, Median=0.8086, 75th=0.8679
  Fp_Removed          : Mean=0.9418, 25th=0.9167, Median=0.9474, 75th=0.9762
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.95, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7479, 25th=0.6898, Median=0.7467, 75th=0.8014
  Noisy_Precision     : Mean=0.5434, 25th=0.4873, Median=0.5603, 75th=0.6209
  Noisy_Recall        : Mean=0.0500, 25th=0.0454, Median=0.0498, 75th=0.0543
  Noisy_F1            : Mean=0.0912, 25th=0.0835, Median=0.0911, 75th=0.0990
  Noisy_Genome_Diff   : Mean=0.0962, 25th=0.0800, Median=0.0890, 75th=0.1023

Summary for FN rate 0.95, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.8471, 25th=0.8215, Median=0.8496, 75th=0.8745
  Precision           : Mean=0.7885, 25th=0.7264, Median=0.8254, 75th=0.8854
  Recall              : Mean=0.5390, 25th=0.4826, Median=0.5398, 75th=0.5995
  F1                  : Mean=0.6281, 25th=0.5904, Median=0.6387, 75th=0.6774
  Genome_Size_Diff    : Mean=0.7212, 25th=0.5694, Median=0.6599, 75th=0.7824
  Fp_Removed          : Mean=0.9060, 25th=0.8718, Median=0.9111, 75th=0.9474
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 1.0, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7351, 25th=0.6740, Median=0.7347, 75th=0.7918
  Noisy_Precision     : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_Recall        : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_F1            : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_Genome_Diff   : Mean=0.0462, 25th=0.0308, Median=0.0389, 75th=0.0511

Summary for FN rate 1.0, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.8116, 25th=0.7847, Median=0.8223, 75th=0.8478
  Precision           : Mean=0.7311, 25th=0.6366, Median=0.7691, 75th=0.8685
  Recall              : Mean=0.4615, 25th=0.3789, Median=0.4607, 75th=0.5364
  F1                  : Mean=0.5363, 25th=0.4964, Median=0.5487, 75th=0.5932
  Genome_Size_Diff    : Mean=0.7244, 25th=0.4495, Median=0.6143, 75th=0.8198
  Fp_Removed          : Mean=0.8942, 25th=0.8537, Median=0.9048, 75th=0.9474
  Fn_Recovered        : Mean=0.46

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load DataFrames and vocabulary
train_df, val_df = fast_load_dfs('COG_train1.feather', 'COG_val1.feather', frac=0.1)
global_vocab, cog2idx = get_global_vocab_and_cog2idx_from_df(train_df)

# Load pretrained SetTransformer
set_transformer = GenomeSetTransformer(vocab_size=len(global_vocab), d_model=1024,
                                         num_heads=8, num_sab=4, dropout=0.1)
state_dict = torch.load('COG_high_1024_8_4_BCE_60.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
set_transformer.load_state_dict(new_state_dict)
#set_transformer.load_state_dict(torch.load('model_checkpoint_full.pth'))
set_transformer = set_transformer.to(device)

# Load pretrained BinaryMLM
binary_mlm = BinaryMLMModel(vocab_size=3, embed_dim=512, num_layers=6,
                            num_heads=8, dropout=0.1, max_seq_len=len(global_vocab))
state_dict = torch.load('binMLM_512_6_8_01_e50.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
binary_mlm.load_state_dict(new_state_dict)
binary_mlm = binary_mlm.to(device)

# Define FN rates to test and fixed FP rate
fn_rates = [0.0, 0.1, 0.25, 0.5, 0.75, 0.85,0.95,1.0]
fp_rate = 0.05

all_sample_records = []  # Will collect records from both stages
# We'll print summary statistics for Pre-MLM separately for "noisy" (do nothing) metrics and processed ones.

for fn in fn_rates:
    print(f"\nEvaluating for FN rate = {fn} and FP rate = {fp_rate}...")
    val_dataset = GenomeDataset(val_df, global_vocab=global_vocab, cog2idx=cog2idx,
                                false_negative_rate=fn, false_positive_rate=fp_rate,
                                count_noise_std=0.0, random_state=42)
    val_dataloader = DataLoader(val_dataset, batch_size=16,
                                collate_fn=lambda batch: collate_genomes(batch, pad_idx=len(global_vocab)))

    # Evaluate Pre-MLM metrics (using SetTransformer output)
    #pre_metrics = validate_per_sample_extended(set_transformer, val_dataloader, device,
    #                                             threshold=0.5, lower_thresh=-0.2, upper_thresh=0.8)
    pre_metrics = validate_per_sample_extended(set_transformer, binary_mlm, val_dataloader, device,global_vocab,label_string="FN"+repr(fn)+"_FP"+repr(fp_rate))

    for rec in pre_metrics:
        rec["FN_rate"] = fn
        rec["Stage"] = "Pre-MLM"
    # Print summary statistics for Pre-MLM stage (both "noisy" and processed metrics)
    print(f"\nSummary for FN rate {fn}, Stage Pre-MLM (Noisy Input):")
    for metric in ["noisy_accuracy", "noisy_precision", "noisy_recall", "noisy_f1", "noisy_genome_diff"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")

    print(f"\nSummary for FN rate {fn}, Stage Post-SetTransformer (Processed):")
    for metric in ["accuracy", "precision", "recall", "f1", "genome_size_diff", "fp_removed", "fn_recovered"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")

    print(f"\nSummary for FN rate {fn}, Stage Post-MLM (Processed):")
    for metric in ["MLM_accuracy", "MLM_precision", "MLM_recall", "MLM_f1", "MLM_genome_diff", "MLM_fp_removed", "MLM_fn_recovered"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")


    print("-" * 60)
    all_sample_records.extend(pre_metrics)

# Combine all records into a DataFrame and save to CSV.
df_metrics = pd.DataFrame(all_sample_records)
df_metrics.to_csv("per_sample_combined_metrics_FP"+repr(fp_rate)+".csv", index=False)
print("\nSaved per-sample combined metrics to per_sample_combined_metrics_FP"+repr(fp_rate)+".csv")



Loaded train DataFrame from COG_train1.feather and validation DataFrame from COG_val1.feather.

Evaluating for FN rate = 0.0 and FP rate = 0.05...


                                                                                                                                                                                                       


Summary for FN rate 0.0, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.9501, 25th=0.9480, Median=0.9503, 75th=0.9524
  Noisy_Precision     : Mean=0.8186, 25th=0.7998, Median=0.8363, 75th=0.8639
  Noisy_Recall        : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_F1            : Mean=0.8986, 25th=0.8888, Median=0.9109, 75th=0.9270
  Noisy_Genome_Diff   : Mean=1.2316, 25th=1.1575, Median=1.1957, 75th=1.2503

Summary for FN rate 0.0, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9301, 25th=0.9173, Median=0.9303, 75th=0.9424
  Precision           : Mean=0.8679, 25th=0.8347, Median=0.8828, 75th=0.9161
  Recall              : Mean=0.8315, 25th=0.8070, Median=0.8381, 75th=0.8646
  F1                  : Mean=0.8484, 25th=0.8196, Median=0.8563, 75th=0.8854
  Genome_Size_Diff    : Mean=0.9616, 25th=0.9218, Median=0.9475, 75th=0.9870
  Fp_Removed          : Mean=0.9568, 25th=0.9440, Median=0.9593, 75th=0.9724
  Fn_Recovered        : No valid 

                                                                                                                                                                                                       


Summary for FN rate 0.1, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.9247, 25th=0.9179, Median=0.9246, 75th=0.9313
  Noisy_Precision     : Mean=0.8029, 25th=0.7816, Median=0.8215, 75th=0.8511
  Noisy_Recall        : Mean=0.9001, 25th=0.8943, Median=0.9001, 75th=0.9059
  Noisy_F1            : Mean=0.8469, 25th=0.8368, Median=0.8586, 75th=0.8746
  Noisy_Genome_Diff   : Mean=1.1316, 25th=1.0571, Median=1.0955, 75th=1.1509

Summary for FN rate 0.1, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9298, 25th=0.9171, Median=0.9298, 75th=0.9426
  Precision           : Mean=0.8673, 25th=0.8362, Median=0.8818, 75th=0.9140
  Recall              : Mean=0.8310, 25th=0.8054, Median=0.8377, 75th=0.8651
  F1                  : Mean=0.8478, 25th=0.8197, Median=0.8557, 75th=0.8858
  Genome_Size_Diff    : Mean=0.9614, 25th=0.9233, Median=0.9493, 75th=0.9859
  Fp_Removed          : Mean=0.9558, 25th=0.9425, Median=0.9582, 75th=0.9720
  Fn_Recovered        : Mean=0.81

                                                                                                                                                                                                       


Summary for FN rate 0.25, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.8864, 25th=0.8708, Median=0.8858, 75th=0.9010
  Noisy_Precision     : Mean=0.7735, 25th=0.7501, Median=0.7925, 75th=0.8267
  Noisy_Recall        : Mean=0.7500, 25th=0.7416, Median=0.7500, 75th=0.7583
  Noisy_F1            : Mean=0.7594, 25th=0.7489, Median=0.7709, 75th=0.7872
  Noisy_Genome_Diff   : Mean=0.9815, 25th=0.9075, Median=0.9447, 75th=1.0023

Summary for FN rate 0.25, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9294, 25th=0.9165, Median=0.9292, 75th=0.9420
  Precision           : Mean=0.8662, 25th=0.8352, Median=0.8800, 75th=0.9119
  Recall              : Mean=0.8298, 25th=0.8027, Median=0.8369, 75th=0.8667
  F1                  : Mean=0.8467, 25th=0.8177, Median=0.8550, 75th=0.8852
  Genome_Size_Diff    : Mean=0.9607, 25th=0.9235, Median=0.9525, 75th=0.9871
  Fp_Removed          : Mean=0.9541, 25th=0.9412, Median=0.9562, 75th=0.9702
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.5, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.8226, 25th=0.7920, Median=0.8221, 75th=0.8509
  Noisy_Precision     : Mean=0.6977, 25th=0.6647, Median=0.7184, 75th=0.7608
  Noisy_Recall        : Mean=0.4996, 25th=0.4900, Median=0.5000, 75th=0.5095
  Noisy_F1            : Mean=0.5795, 25th=0.5674, Median=0.5893, 75th=0.6047
  Noisy_Genome_Diff   : Mean=0.7311, 25th=0.6575, Median=0.6959, 75th=0.7497

Summary for FN rate 0.5, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9260, 25th=0.9127, Median=0.9259, 75th=0.9392
  Precision           : Mean=0.8631, 25th=0.8348, Median=0.8755, 75th=0.9046
  Recall              : Mean=0.8161, 25th=0.7841, Median=0.8284, 75th=0.8656
  F1                  : Mean=0.8376, 25th=0.8095, Median=0.8481, 75th=0.8799
  Genome_Size_Diff    : Mean=0.9471, 25th=0.9122, Median=0.9512, 75th=0.9897
  Fp_Removed          : Mean=0.9478, 25th=0.9309, Median=0.9502, 75th=0.9680
  Fn_Recovered        : Mean=0.80

                                                                                                                                                                                                       


Summary for FN rate 0.75, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7590, 25th=0.7137, Median=0.7585, 75th=0.8012
  Noisy_Precision     : Mean=0.5429, 25th=0.4967, Median=0.5596, 75th=0.6143
  Noisy_Recall        : Mean=0.2503, 25th=0.2417, Median=0.2500, 75th=0.2588
  Noisy_F1            : Mean=0.3396, 25th=0.3273, Median=0.3458, 75th=0.3584
  Noisy_Genome_Diff   : Mean=0.4818, 25th=0.4085, Median=0.4456, 75th=0.5008

Summary for FN rate 0.75, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.8891, 25th=0.8682, Median=0.8906, 75th=0.9095
  Precision           : Mean=0.8355, 25th=0.7983, Median=0.8466, 75th=0.8854
  Recall              : Mean=0.6578, 25th=0.5650, Median=0.6996, 75th=0.7908
  F1                  : Mean=0.7220, 25th=0.6715, Median=0.7601, 75th=0.8166
  Genome_Size_Diff    : Mean=0.7882, 25th=0.6667, Median=0.8303, 75th=0.9337
  Fp_Removed          : Mean=0.9279, 25th=0.9017, Median=0.9347, 75th=0.9622
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.85, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7334, 25th=0.6818, Median=0.7327, 75th=0.7814
  Noisy_Precision     : Mean=0.4204, 25th=0.3715, Median=0.4325, 75th=0.4883
  Noisy_Recall        : Mean=0.1500, 25th=0.1430, Median=0.1500, 75th=0.1570
  Noisy_F1            : Mean=0.2185, 25th=0.2076, Median=0.2219, 75th=0.2332
  Noisy_Genome_Diff   : Mean=0.3816, 25th=0.3080, Median=0.3454, 75th=0.4015

Summary for FN rate 0.85, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.8179, 25th=0.7856, Median=0.8178, 75th=0.8490
  Precision           : Mean=0.7450, 25th=0.6813, Median=0.7603, 75th=0.8261
  Recall              : Mean=0.3856, 25th=0.2349, Median=0.3905, 75th=0.5301
  F1                  : Mean=0.4835, 25th=0.3544, Median=0.5096, 75th=0.6175
  Genome_Size_Diff    : Mean=0.5197, 25th=0.3255, Median=0.5076, 75th=0.6940
  Fp_Removed          : Mean=0.9201, 25th=0.8915, Median=0.9310, 75th=0.9583
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.95, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7079, 25th=0.6505, Median=0.7075, 75th=0.7617
  Noisy_Precision     : Mean=0.1996, 25th=0.1618, Median=0.2026, 75th=0.2439
  Noisy_Recall        : Mean=0.0501, 25th=0.0455, Median=0.0500, 75th=0.0544
  Noisy_F1            : Mean=0.0787, 25th=0.0716, Median=0.0795, 75th=0.0867
  Noisy_Genome_Diff   : Mean=0.2816, 25th=0.2077, Median=0.2466, 75th=0.2998

Summary for FN rate 0.95, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.7744, 25th=0.7347, Median=0.7806, 75th=0.8194
  Precision           : Mean=0.6691, 25th=0.5704, Median=0.6880, 75th=0.7919
  Recall              : Mean=0.2555, 25th=0.1326, Median=0.2293, 75th=0.3578
  F1                  : Mean=0.3337, 25th=0.2204, Median=0.3339, 75th=0.4523
  Genome_Size_Diff    : Mean=0.4308, 25th=0.1876, Median=0.3385, 75th=0.5643
  Fp_Removed          : Mean=0.9327, 25th=0.9042, Median=0.9438, 75th=0.9706
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 1.0, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.6952, 25th=0.6341, Median=0.6945, 75th=0.7515
  Noisy_Precision     : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_Recall        : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_F1            : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_Genome_Diff   : Mean=0.2316, 25th=0.1575, Median=0.1957, 75th=0.2503

Summary for FN rate 1.0, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.7792, 25th=0.7438, Median=0.7895, 75th=0.8215
  Precision           : Mean=0.6929, 25th=0.5827, Median=0.7243, 75th=0.8393
  Recall              : Mean=0.3015, 25th=0.1871, Median=0.2815, 75th=0.4012
  F1                  : Mean=0.3817, 25th=0.2913, Median=0.3942, 75th=0.4764
  Genome_Size_Diff    : Mean=0.5180, 25th=0.2437, Median=0.4023, 75th=0.6603
  Fp_Removed          : Mean=0.9351, 25th=0.9061, Median=0.9478, 75th=0.9764
  Fn_Recovered        : Mean=0.30

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load DataFrames and vocabulary
train_df, val_df = fast_load_dfs('COG_train1.feather', 'COG_val1.feather', frac=0.1)
global_vocab, cog2idx = get_global_vocab_and_cog2idx_from_df(train_df)

# Load pretrained SetTransformer
set_transformer = GenomeSetTransformer(vocab_size=len(global_vocab), d_model=1024,
                                         num_heads=8, num_sab=4, dropout=0.1)
state_dict = torch.load('COG_high_1024_8_4_BCE_60.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
set_transformer.load_state_dict(new_state_dict)
#set_transformer.load_state_dict(torch.load('model_checkpoint_full.pth'))
set_transformer = set_transformer.to(device)

# Load pretrained BinaryMLM
binary_mlm = BinaryMLMModel(vocab_size=3, embed_dim=512, num_layers=6,
                            num_heads=8, dropout=0.1, max_seq_len=len(global_vocab))
state_dict = torch.load('binMLM_512_6_8_01_e50.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
binary_mlm.load_state_dict(new_state_dict)
binary_mlm = binary_mlm.to(device)

# Define FN rates to test and fixed FP rate
fn_rates = [0.0, 0.1, 0.25, 0.5, 0.75, 0.85,0.95,1.0]
fp_rate = 0.1

all_sample_records = []  # Will collect records from both stages
# We'll print summary statistics for Pre-MLM separately for "noisy" (do nothing) metrics and processed ones.

for fn in fn_rates:
    print(f"\nEvaluating for FN rate = {fn} and FP rate = {fp_rate}...")
    val_dataset = GenomeDataset(val_df, global_vocab=global_vocab, cog2idx=cog2idx,
                                false_negative_rate=fn, false_positive_rate=fp_rate,
                                count_noise_std=0.0, random_state=42)
    val_dataloader = DataLoader(val_dataset, batch_size=16,
                                collate_fn=lambda batch: collate_genomes(batch, pad_idx=len(global_vocab)))

    # Evaluate Pre-MLM metrics (using SetTransformer output)
    #pre_metrics = validate_per_sample_extended(set_transformer, val_dataloader, device,
    #                                             threshold=0.5, lower_thresh=-0.2, upper_thresh=0.8)
    pre_metrics = validate_per_sample_extended(set_transformer, binary_mlm, val_dataloader, device,global_vocab,label_string="FN"+repr(fn)+"_FP"+repr(fp_rate))

    for rec in pre_metrics:
        rec["FN_rate"] = fn
        rec["Stage"] = "Pre-MLM"
    # Print summary statistics for Pre-MLM stage (both "noisy" and processed metrics)
    print(f"\nSummary for FN rate {fn}, Stage Pre-MLM (Noisy Input):")
    for metric in ["noisy_accuracy", "noisy_precision", "noisy_recall", "noisy_f1", "noisy_genome_diff"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")

    print(f"\nSummary for FN rate {fn}, Stage Post-SetTransformer (Processed):")
    for metric in ["accuracy", "precision", "recall", "f1", "genome_size_diff", "fp_removed", "fn_recovered"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")

    print(f"\nSummary for FN rate {fn}, Stage Post-MLM (Processed):")
    for metric in ["MLM_accuracy", "MLM_precision", "MLM_recall", "MLM_f1", "MLM_genome_diff", "MLM_fp_removed", "MLM_fn_recovered"]:
        vals = [r[metric] for r in pre_metrics if r.get(metric) is not None and not np.isnan(r[metric])]
        if len(vals) > 0:
            mean_val = np.mean(vals)
            q25 = np.nanpercentile(vals, 25)
            median_val = np.nanpercentile(vals, 50)
            q75 = np.nanpercentile(vals, 75)
            print(f"  {metric.title():<20}: Mean={mean_val:.4f}, 25th={q25:.4f}, Median={median_val:.4f}, 75th={q75:.4f}")
        else:
            print(f"  {metric.title():<20}: No valid values")


    print("-" * 60)
    all_sample_records.extend(pre_metrics)

# Combine all records into a DataFrame and save to CSV.
df_metrics = pd.DataFrame(all_sample_records)
df_metrics.to_csv("per_sample_combined_metrics_FP"+repr(fp_rate)+".csv", index=False)
print("\nSaved per-sample combined metrics to per_sample_combined_metrics_FP"+repr(fp_rate)+".csv")



Loaded train DataFrame from COG_train1.feather and validation DataFrame from COG_val1.feather.

Evaluating for FN rate = 0.0 and FP rate = 0.1...


                                                                                                                                                                                                       


Summary for FN rate 0.0, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.9000, 25th=0.8971, Median=0.9000, 75th=0.9029
  Noisy_Precision     : Mean=0.6974, 25th=0.6646, Median=0.7189, 75th=0.7601
  Noisy_Recall        : Mean=1.0000, 25th=1.0000, Median=1.0000, 75th=1.0000
  Noisy_F1            : Mean=0.8181, 25th=0.7985, Median=0.8364, 75th=0.8637
  Noisy_Genome_Diff   : Mean=1.4647, 25th=1.3156, Median=1.3911, 75th=1.5047

Summary for FN rate 0.0, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9305, 25th=0.9177, Median=0.9305, 75th=0.9432
  Precision           : Mean=0.8613, 25th=0.8352, Median=0.8739, 75th=0.9006
  Recall              : Mean=0.8405, 25th=0.8138, Median=0.8531, 75th=0.8859
  F1                  : Mean=0.8498, 25th=0.8223, Median=0.8597, 75th=0.8890
  Genome_Size_Diff    : Mean=0.9774, 25th=0.9483, Median=0.9802, 75th=1.0094
  Fp_Removed          : Mean=0.9508, 25th=0.9367, Median=0.9523, 75th=0.9671
  Fn_Recovered        : No valid 

                                                                                                                                                                                                       


Summary for FN rate 0.1, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.8745, 25th=0.8674, Median=0.8743, 75th=0.8814
  Noisy_Precision     : Mean=0.6756, 25th=0.6410, Median=0.6971, 75th=0.7410
  Noisy_Recall        : Mean=0.9000, 25th=0.8940, Median=0.9000, 75th=0.9058
  Noisy_F1            : Mean=0.7679, 25th=0.7484, Median=0.7852, 75th=0.8119
  Noisy_Genome_Diff   : Mean=1.3646, 25th=1.2158, Median=1.2903, 75th=1.4068

Summary for FN rate 0.1, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9295, 25th=0.9167, Median=0.9292, 75th=0.9428
  Precision           : Mean=0.8601, 25th=0.8347, Median=0.8714, 75th=0.8978
  Recall              : Mean=0.8373, 25th=0.8087, Median=0.8530, 75th=0.8881
  F1                  : Mean=0.8471, 25th=0.8202, Median=0.8579, 75th=0.8877
  Genome_Size_Diff    : Mean=0.9747, 25th=0.9444, Median=0.9823, 75th=1.0154
  Fp_Removed          : Mean=0.9489, 25th=0.9336, Median=0.9507, 75th=0.9664
  Fn_Recovered        : Mean=0.82

                                                                                                                                                                                                       


Summary for FN rate 0.25, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.8362, 25th=0.8198, Median=0.8359, 75th=0.8507
  Noisy_Precision     : Mean=0.6359, 25th=0.5983, Median=0.6571, 75th=0.7034
  Noisy_Recall        : Mean=0.7496, 25th=0.7413, Median=0.7492, 75th=0.7582
  Noisy_F1            : Mean=0.6838, 25th=0.6649, Median=0.6996, 75th=0.7249
  Noisy_Genome_Diff   : Mean=1.2143, 25th=1.0637, Median=1.1408, 75th=1.2546

Summary for FN rate 0.25, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9265, 25th=0.9131, Median=0.9265, 75th=0.9403
  Precision           : Mean=0.8577, 25th=0.8314, Median=0.8675, 75th=0.8951
  Recall              : Mean=0.8247, 25th=0.7924, Median=0.8490, 75th=0.8873
  F1                  : Mean=0.8382, 25th=0.8124, Median=0.8528, 75th=0.8839
  Genome_Size_Diff    : Mean=0.9627, 25th=0.9267, Median=0.9817, 75th=1.0246
  Fp_Removed          : Mean=0.9452, 25th=0.9264, Median=0.9475, 75th=0.9662
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.5, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7723, 25th=0.7407, Median=0.7726, 75th=0.8008
  Noisy_Precision     : Mean=0.5416, 25th=0.4975, Median=0.5619, 75th=0.6115
  Noisy_Recall        : Mean=0.4991, 25th=0.4896, Median=0.4992, 75th=0.5085
  Noisy_F1            : Mean=0.5145, 25th=0.4970, Median=0.5281, 75th=0.5501
  Noisy_Genome_Diff   : Mean=0.9637, 25th=0.8134, Median=0.8921, 75th=1.0027

Summary for FN rate 0.5, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.9032, 25th=0.8864, Median=0.9052, 75th=0.9213
  Precision           : Mean=0.8420, 25th=0.8076, Median=0.8497, 75th=0.8848
  Recall              : Mean=0.7199, 25th=0.6399, Median=0.7766, 75th=0.8570
  F1                  : Mean=0.7604, 25th=0.7248, Median=0.8040, 75th=0.8494
  Genome_Size_Diff    : Mean=0.8580, 25th=0.7477, Median=0.9164, 75th=1.0123
  Fp_Removed          : Mean=0.9343, 25th=0.9085, Median=0.9404, 75th=0.9691
  Fn_Recovered        : Mean=0.69

                                                                                                                                                                                                       


Summary for FN rate 0.75, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.7087, 25th=0.6624, Median=0.7087, 75th=0.7513
  Noisy_Precision     : Mean=0.3772, 25th=0.3305, Median=0.3882, 75th=0.4413
  Noisy_Recall        : Mean=0.2495, 25th=0.2406, Median=0.2493, 75th=0.2579
  Noisy_F1            : Mean=0.2957, 25th=0.2816, Median=0.3035, 75th=0.3200
  Noisy_Genome_Diff   : Mean=0.7142, 25th=0.5655, Median=0.6410, 75th=0.7530

Summary for FN rate 0.75, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.7885, 25th=0.7517, Median=0.7841, 75th=0.8213
  Precision           : Mean=0.6824, 25th=0.6148, Median=0.6939, 75th=0.7700
  Recall              : Mean=0.2557, 25th=0.1313, Median=0.2277, 75th=0.3553
  F1                  : Mean=0.3496, 25th=0.2171, Median=0.3414, 75th=0.4711
  Genome_Size_Diff    : Mean=0.3710, 25th=0.2055, Median=0.3274, 75th=0.5023
  Fp_Removed          : Mean=0.9401, 25th=0.9199, Median=0.9501, 75th=0.9717
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.85, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.6833, 25th=0.6312, Median=0.6831, 75th=0.7314
  Noisy_Precision     : Mean=0.2695, 25th=0.2270, Median=0.2756, 75th=0.3226
  Noisy_Recall        : Mean=0.1498, 25th=0.1429, Median=0.1495, 75th=0.1568
  Noisy_F1            : Mean=0.1887, 25th=0.1770, Median=0.1930, 75th=0.2063
  Noisy_Genome_Diff   : Mean=0.6144, 25th=0.4668, Median=0.5401, 75th=0.6544

Summary for FN rate 0.85, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.7572, 25th=0.7093, Median=0.7564, 75th=0.8025
  Precision           : Mean=0.5729, 25th=0.4699, Median=0.5838, 75th=0.6839
  Recall              : Mean=0.1357, 25th=0.0616, Median=0.1093, 75th=0.1868
  F1                  : Mean=0.2037, 25th=0.1100, Median=0.1821, 75th=0.2827
  Genome_Size_Diff    : Mean=0.2460, 25th=0.1137, Median=0.1925, 75th=0.3235
  Fp_Removed          : Mean=0.9537, 25th=0.9377, Median=0.9611, 75th=0.9764
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 0.95, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.6579, 25th=0.5993, Median=0.6577, 75th=0.7108
  Noisy_Precision     : Mean=0.1118, 25th=0.0878, Median=0.1123, 75th=0.1379
  Noisy_Recall        : Mean=0.0501, 25th=0.0456, Median=0.0498, 75th=0.0543
  Noisy_F1            : Mean=0.0674, 25th=0.0608, Median=0.0683, 75th=0.0752
  Noisy_Genome_Diff   : Mean=0.5147, 25th=0.3654, Median=0.4409, 75th=0.5549

Summary for FN rate 0.95, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.7589, 25th=0.7167, Median=0.7634, 75th=0.8024
  Precision           : Mean=0.6042, 25th=0.4719, Median=0.6207, 75th=0.7606
  Recall              : Mean=0.1695, 25th=0.0846, Median=0.1445, 75th=0.2324
  F1                  : Mean=0.2404, 25th=0.1435, Median=0.2309, 75th=0.3263
  Genome_Size_Diff    : Mean=0.3226, 25th=0.1371, Median=0.2390, 75th=0.4103
  Fp_Removed          : Mean=0.9535, 25th=0.9362, Median=0.9633, 75th=0.9809
  Fn_Recovered        : Mean=0.

                                                                                                                                                                                                       


Summary for FN rate 1.0, Stage Pre-MLM (Noisy Input):
  Noisy_Accuracy      : Mean=0.6451, 25th=0.5840, Median=0.6449, 75th=0.7008
  Noisy_Precision     : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_Recall        : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_F1            : Mean=0.0000, 25th=0.0000, Median=0.0000, 75th=0.0000
  Noisy_Genome_Diff   : Mean=0.4647, 25th=0.3156, Median=0.3911, 75th=0.5047

Summary for FN rate 1.0, Stage Post-SetTransformer (Processed):
  Accuracy            : Mean=0.7655, 25th=0.7271, Median=0.7731, 75th=0.8083
  Precision           : Mean=0.6484, 25th=0.5174, Median=0.6770, 75th=0.8103
  Recall              : Mean=0.2106, 25th=0.1261, Median=0.1898, 75th=0.2823
  F1                  : Mean=0.2910, 25th=0.2073, Median=0.2881, 75th=0.3752
  Genome_Size_Diff    : Mean=0.3860, 25th=0.1799, Median=0.2912, 75th=0.4835
  Fp_Removed          : Mean=0.9526, 25th=0.9349, Median=0.9628, 75th=0.9816
  Fn_Recovered        : Mean=0.21