# 检查生成的 vCross Corpus 结构
本 Notebook 用于加载并验证生成的 `ResMicroDB_90338_vCross.pkl` 文件，确保所有数据结构（Input IDs, Value IDs, Condition IDs）均正确实现。

In [14]:
import pickle
import torch
import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt
from MiCoGPT.utils_vCross.corpus_vCross import MiCoGPTCorpusVCross

# 路径设置 (请根据实际情况修改)
corpus_path = "../data/vCorss/ResMicroDB_90338_vCross.pkl"
encoder_path = "../data/vCorss/meta_encoders.joblib"

print(f"Loading corpus from {corpus_path}...")
with open(corpus_path, "rb") as f:
    corpus = pickle.load(f)
    
print(f"Corpus loaded successfully. Total samples: {len(corpus)}")

Loading corpus from ../data/vCorss/ResMicroDB_90338_vCross.pkl...
Corpus loaded successfully. Total samples: 90338


## 1. 基础属性检查

In [15]:
print(f"Attributes check:")
print(f"- num_bins: {corpus.num_bins}")
print(f"- log1p: {corpus.log1p}")
print(f"- normalize_total: {corpus.normalize_total} (If None, means adaptive)")
print(f"- max_len: {corpus.max_len}")
print(f"- use_meta_cols: {corpus.use_meta_cols}")
print(f"- Metadata Encoders keys: {list(corpus.meta_encoders.keys())}")

Attributes check:
- num_bins: 51
- log1p: True
- normalize_total: None (If None, means adaptive)
- max_len: 512
- use_meta_cols: ['Sample_Site']
- Metadata Encoders keys: ['Sample_Site']


## 2. 样本数据结构深度检查
我们将检查以下关键点：
1. **BOS/EOS 完整性**: 每个样本必须以 `<bos>` 开头，以 `<eos>` 结尾（除非被截断）。
2. **多模态对齐**: Input IDs (Taxon), Value IDs (Bin), Condition IDs (Meta) 必须严格对应。
3. **典型样本展示**: 选取最短、最长、以及前三个样本进行详细 Token 级展示。

In [16]:
def inspect_sample(idx, corpus, label="Sample"):
    sample_id = corpus.sample_ids[idx]
    input_ids = corpus.input_ids[idx]
    value_ids = corpus.value_ids[idx]
    cond_ids = corpus.condition_ids[idx]
    
    # 获取实际长度（去除 padding）
    # padding token id 通常是 0 或者 tokenizer.pad_token_id
    # 这里我们通过 attention mask 或者直接找 pad token
    pad_token_id = corpus.tokenizer.pad_token_id
    if pad_token_id is None: pad_token_id = 0
    
    non_pad_mask = (input_ids != pad_token_id)
    actual_len = non_pad_mask.sum().item()
    
    print(f"\n{'='*20} {label} (Index: {idx}, ID: {sample_id}) {'='*20}")
    print(f"Total Length (with pad): {len(input_ids)}")
    print(f"Actual Length (tokens):  {actual_len}")
    
    # 1. BOS/EOS Check
    bos_id = corpus.tokenizer.bos_token_id
    eos_id = corpus.tokenizer.eos_token_id
    
    has_bos = (input_ids[0].item() == bos_id)
    # EOS 应该在 actual_len - 1 的位置 (0-based index)
    # 如果样本被截断 (actual_len == max_len)，可能没有 EOS？或者 corpus 处理时保留了？
    # vCross 代码逻辑：input_ids[:self.max_len-1] + [input_ids[-1]] 
    # 如果 input_ids[-1] 原本就是 eos，那么截断后最后一个也是 eos。
    # 如果原序列超长，最后一个 token 强制变为 eos 吗？
    # vCross: sent = ['<bos>'] + taxa + ['<eos>']
    # Truncate: input_ids[:max_len-1] + [input_ids[-1]] -> 这里 input_ids[-1] 是 eos_id。
    # 所以理论上最后一个有效 token 必须是 EOS。
    last_token = input_ids[actual_len-1].item()
    has_eos = (last_token == eos_id)
    
    print(f"Structure Check:")
    print(f"  - Starts with BOS? {has_bos} (ID: {bos_id})")
    print(f"  - Ends with EOS?   {has_eos} (ID: {eos_id}, Found: {last_token})")
    
    # 2. Token-Level Detail (Show first 5 and last 5 valid tokens)
    print(f"Token Details (Taxon ID | Bin ID):")
    
    # Helper to print token
    def print_tok(i):
        tid = input_ids[i].item()
        vid = value_ids[i].item()
        try:
            # 尝试多种解码方式
            if hasattr(corpus.tokenizer.vocab, 'lookup_token'):
                token_str = corpus.tokenizer.vocab.lookup_token(tid)
            elif hasattr(corpus.tokenizer.vocab, 'itos'):
                token_str = corpus.tokenizer.vocab.itos[tid]
            elif hasattr(corpus.tokenizer, 'decode'):
                token_str = corpus.tokenizer.decode([tid])
            else:
                token_str = str(tid)
        except:
            token_str = "???"
        print(f"  [{i:3d}] {token_str:<20} (ID: {tid:5d}) | Bin: {vid:3d}")

    print("  --- Start ---")
    for i in range(min(5, actual_len)):
        print_tok(i)
        
    if actual_len > 10:
        print("  ... ...")
        
    print("  --- End ---")
    for i in range(max(0, actual_len-5), actual_len):
        print_tok(i)
        
    # 3. Metadata Check
    print(f"Metadata Conditions:")
    if corpus.use_meta_cols:
        original_meta = corpus.metadata.loc[sample_id]
        for k, col in enumerate(corpus.use_meta_cols):
            code = cond_ids[k].item()
            decoded = corpus.meta_encoders[col].inverse_transform([code])[0]
            original = str(original_meta[col])
            match = "✅" if (decoded == original or (original=='nan' and 'Unknown' in decoded)) else "❌"
            print(f"  - {col}: Code={code} -> '{decoded}' (Org: '{original}') {match}")
    else:
        print("  (No metadata used)")

# 找出最长和最短样本
lengths = []
pad_id = corpus.tokenizer.pad_token_id if corpus.tokenizer.pad_token_id is not None else 0
for i in range(len(corpus)):
    l = (corpus.input_ids[i] != pad_id).sum().item()
    lengths.append(l)

max_idx = np.argmax(lengths)
min_idx = np.argmin(lengths)

# 展示
inspect_sample(0, corpus, "First Sample")
inspect_sample(1, corpus, "Second Sample")
inspect_sample(2, corpus, "Third Sample")
inspect_sample(max_idx, corpus, f"Longest Sample (Len={lengths[max_idx]})")
inspect_sample(min_idx, corpus, f"Shortest Sample (Len={lengths[min_idx]})")


Total Length (with pad): 512
Actual Length (tokens):  132
Structure Check:
  - Starts with BOS? True (ID: 2)
  - Ends with EOS?   True (ID: 3, Found: 3)
Token Details (Taxon ID | Bin ID):
  --- Start ---
  [  0] <bos>                (ID:     2) | Bin:   0
  [  1] g__Staphylococcus    (ID:   359) | Bin:  51
  [  2] g__Mycoplasma        (ID:   476) | Bin:  50
  [  3] g__Bacteroides       (ID:   370) | Bin:  50
  [  4] g__Corynebacterium   (ID:   374) | Bin:  49
  ... ...
  --- End ---
  [127] g__Flavobacterium    (ID:   507) | Bin:   5
  [128] g__TM7a              (ID:   472) | Bin:   9
  [129] g__Candidatus_Kaiserbacteria (ID:   450) | Bin:   6
  [130] g__Vibrio            (ID:   590) | Bin:   7
  [131] <eos>                (ID:     3) | Bin:   0
Metadata Conditions:
  - Sample_Site: Code=0 -> 'BALF' (Org: 'BALF') ✅

Total Length (with pad): 512
Actual Length (tokens):  144
Structure Check:
  - Starts with BOS? True (ID: 2)
  - Ends with EOS?   True (ID: 3, Found: 3)
Token Details (Tax

## 3. Value IDs (Binning) 分布统计
如果不画图，我们需要详细的统计数据来确认分布的合理性。

In [18]:
# 聚合所有样本的 Value IDs (排除 0 padding)
all_values = corpus.value_ids.flatten()
non_zero_values = all_values[all_values > 0].numpy()

# 计算各 Bin 的频次
bin_counts = pd.Series(non_zero_values).value_counts().sort_index()

print(f"Binning Statistics (Total {len(non_zero_values)} tokens):")
print(f"- Min Bin ID: {non_zero_values.min()} (Expected: 1)")
print(f"- Max Bin ID: {non_zero_values.max()} (Expected: {corpus.num_bins})")
print(f"- Coverage: {len(bin_counts)}/{corpus.num_bins} bins used")

# 展示分布的百分比 (Quantiles)
print("\nBin Distribution Percentiles:")
percentiles = np.percentile(non_zero_values, [0, 25, 50, 75, 100])
print(f"  0% (Min):   {percentiles[0]}")
print(f"  25% (Q1):   {percentiles[1]}")
print(f"  50% (Med):  {percentiles[2]}")
print(f"  75% (Q3):   {percentiles[3]}")
print(f"  100% (Max): {percentiles[4]}")

print("\nDetailed Counts per Bin (First 10 & Last 10):")
print(bin_counts.head(51))


# 检查是否存在极度不均衡（例如某个 Bin 只有几个 Token）
min_count = bin_counts.min()
max_count = bin_counts.max()
print(f"\nImbalance Check:")
print(f"- Least frequent bin count: {min_count}")
print(f"- Most frequent bin count:  {max_count}")
print(f"- Ratio (Max/Min): {max_count/min_count:.2f}")

Binning Statistics (Total 5012551 tokens):
- Min Bin ID: 2 (Expected: 1)
- Max Bin ID: 51 (Expected: 51)
- Coverage: 50/51 bins used

Bin Distribution Percentiles:
  0% (Min):   2.0
  25% (Q1):   14.0
  50% (Med):  26.0
  75% (Q3):   39.0
  100% (Max): 51.0

Detailed Counts per Bin (First 10 & Last 10):
2     110359
3     102316
4     101660
5     100216
6     100136
7      98807
8      97830
9     102088
10     99199
11     99652
12     99545
13     99451
14    101073
15     97837
16    104884
17    100188
18    101591
19    100864
20    101780
21    100951
22     98297
23    106217
24    101555
25    101719
26    103048
27    100361
28    102618
29     97110
30    107197
31    101956
32    101632
33    102120
34    102141
35    101481
36     96919
37    107265
38    101742
39    101621
40    101505
41    101651
42    101502
43     95434
44    106778
45    100814
46    100954
47    100939
48    100832
49    100312
50     56069
51     90335
Name: count, dtype: int64

Imbalance Check:
-

## 4. Metadata Condition IDs 检查 (Skipped)
已集成在样本深度检查中。

## 5. Ranking 逻辑验证
验证 Input IDs 对应的 Token 序列，是否真的是按照 Value IDs (Bin) 的大小顺序排列的（大致单调递减）。

In [19]:
input_ids = corpus.input_ids[idx]
value_ids = corpus.value_ids[idx]

# 过滤掉 padding 和 special tokens (0)
valid_mask = (value_ids > 0)
valid_bins = value_ids[valid_mask].numpy()
valid_tokens = input_ids[valid_mask].numpy()

print(f"Valid Token Sequence (first 20): {valid_tokens[:20]}")
print(f"Corresponding Bins (first 20):   {valid_bins[:20]}")

# 检查 Bin 是否大致递减
# 注意：Binning 是基于数值的，且我们之前对样本进行了 sort_values(ascending=False)
# 所以 Bin ID 应该也是总体呈现从大到小（允许局部相等）。
is_sorted = np.all(valid_bins[:-1] >= valid_bins[1:])
print(f"\nAre Bins monotonically non-increasing? {is_sorted}")

if not is_sorted:
    print("Note: Strict monotonicity might be broken slightly due to random noise in binning edge cases, but overall trend should be decreasing.")
    print("Explanation: This is expected behavior with Seeded Random Noise Binning. Tied values (e.g. low counts) are randomly distributed across adjacent bins to preserve magnitude information.")
    # plt.plot(valid_bins)
    # plt.title("Bin Values along Sequence Position")
    # plt.ylabel("Bin ID")
    # plt.xlabel("Token Position")
    # plt.show()

Valid Token Sequence (first 20): [359 476 370 374 712 748 564 345 703   4 301 405 540 634  65 526 721 484
 677 399]
Corresponding Bins (first 20):   [51 50 50 49 49 49 48 48 47 47 47 46 46 46 45 45 44 44 44 43]

Are Bins monotonically non-increasing? False
Note: Strict monotonicity might be broken slightly due to random noise in binning edge cases, but overall trend should be decreasing.
Explanation: This is expected behavior with Seeded Random Noise Binning. Tied values (e.g. low counts) are randomly distributed across adjacent bins to preserve magnitude information.
