# Explore threads metadata

Threads metadata have been added to the threads TreeSequence with the `annotate_tree`
script. Let's explore the data and check that everything is fine.

In [None]:
import itertools

import cyvcf2
import tszip
import pandas as pd
import numpy as np

from tskitetude import get_data_dir, get_project_dir

Let's load information on sample names:

In [None]:
# Tab-separated table without a header row; the two columns are labelled
# "population" and "individual" here.
sample_names_path = get_data_dir() / "toInfer/tsm100M300I.sample_names.txt"
sample_info = pd.read_csv(
    sample_names_path,
    sep="\t",
    header=None,
    names=["population", "individual"],
)
sample_info.head()

Now open the threads object:

:::{warning}
We are opening the fitted version of the trees, since they contain variant information.
:::

In [None]:
# Load the compressed (.tsz) tree sequence produced by the threads fit step.
threads_trees_path = (
    get_project_dir() / "results-threads/toInfer-fit/threads/ts300I25k.1.trees.tsz"
)
ts_threads = tszip.load(threads_trees_path)
ts_threads

Explore the metadata of the first 5 nodes:

In [None]:
# Resolve the population and individual rows referenced by each of the first
# five nodes, so the attached metadata can be inspected by eye.
for node in itertools.islice(ts_threads.nodes(), 5):
    population = ts_threads.population(node.population)
    individual = ts_threads.individual(node.individual)
    print(f"Node {node.id}: population={population}, individual={individual}")

OK, it seems that the metadata follow the sample order of the VCF file. Let's verify
this by checking the sample names in `ts_threads.individuals()`:

In [None]:
print(f"Num. samples: {ts_threads.num_samples}")
print(f"Num. individuals: {ts_threads.num_individuals}")

# Map each individual's sample name (metadata key "sample_id") to the list of
# node IDs that belong to that individual.
sample2nodes = {}

for i, individual in enumerate(ts_threads.individuals()):
    # Extract the values before formatting: reusing double quotes inside a
    # double-quoted f-string is a SyntaxError on Python versions before 3.12.
    sample_id = individual.metadata["sample_id"]
    node_ids = individual.nodes.tolist()

    # Only preview the first five individuals to keep the output short.
    if i < 5:
        print(f"Individual {individual.id}: sample='{sample_id}', nodes={node_ids}")

    sample2nodes[sample_id] = node_ids

Let's discover individual breeds using their node information:

In [None]:
# For every sample, gather the set of breeds recorded on the populations of its
# nodes. A set is used so that identical breeds on several nodes collapse.
sample2breed = {}

for sample_id, node_ids in sample2nodes.items():
    breeds = {
        ts_threads.population(ts_threads.node(node_id).population).metadata["breed"]
        for node_id in node_ids
    }
    # An individual without nodes gets no entry at all.
    if breeds:
        sample2breed[sample_id] = breeds


Test that sample and breed mapping is correct:

In [None]:
# Test that sample and breed mapping is correct: the population listed for an
# individual in the sample-names table must be among the breeds found on that
# individual's nodes.
#
# Build the individual -> population lookup once instead of filtering the whole
# table for every sample. When an individual is listed more than once, the
# first row wins.
expected_population = (
    sample_info.drop_duplicates("individual")
    .set_index("individual")["population"]
    .to_dict()
)

# A sample that is absent from the table is reported as a mismatch, so the cell
# fails with the assertion message below rather than with an IndexError.
wrong_breed_samples = [
    sample_id
    for sample_id, breed_set in sample2breed.items()
    if sample_id not in expected_population
    or expected_population[sample_id] not in breed_set
]

assert not wrong_breed_samples, (
    "Sample to breed mapping is incorrect! "
    f"{len(wrong_breed_samples)} mismatching sample(s), e.g. {wrong_breed_samples[:5]}"
)


Now test that the genotypes are consistent with sample names: using VCF as a reference,
read the `focal` sample generated through the pipeline:

In [None]:
# VCF used as the reference for the genotype checks that follow.
vcf_file = get_project_dir() / "results-threads/toInfer/focal/ts300I25k.1.vcf.gz"

# Only the header is needed here: grab the first sample names, then close the file.
with cyvcf2.VCF(vcf_file) as vcf_reader:
    first_vcf_samples = vcf_reader.samples[:5]

print(first_vcf_samples)  # show first 5 sample names in VCF

Check that the number of variants is the same:

In [None]:
# count number of variants in VCF file; the context manager closes the reader
# as soon as the records have been counted (as every other cell does)
with cyvcf2.VCF(vcf_file) as vcf_reader:
    num_variants = sum(1 for _ in vcf_reader)
print(f"N. of variants in VCF: {num_variants}")

# count number of variants in TS file
num_ts_variants = ts_threads.num_sites
print(f"N. of variants TreeSequence: {num_ts_variants}")

# check variant size
if num_variants == num_ts_variants:
    print("✓ N. of variants matches (between TS and VCF)!")
else:
    print(f"✗ Warning: VCF has {num_variants} variants, TreeSequence has {num_ts_variants}")

OK, the positions are not the same: from what I saw, the first position in `ts_threads` is 0,
so let's read the first position in the VCF file and determine the offset of the *TS* object:

In [None]:
# The position of the first VCF record is used as the offset to add to the
# TreeSequence site positions.
with cyvcf2.VCF(vcf_file) as vcf_reader:
    offset = next(vcf_reader).POS

# translate the TS positions
ts_positions = np.fromiter(
    (int(site.position) + offset for site in ts_threads.sites()),
    dtype=np.int64,
)
print(ts_positions[:10])  # show first 10 positions

# read positions from VCF
with cyvcf2.VCF(vcf_file) as vcf_reader:
    vcf_positions = np.array([record.POS for record in vcf_reader])
print(vcf_positions[:10])  # show first 10 positions

In [None]:
# Find which variants are missing from the TreeSequence
# Get positions from both VCF and TreeSequence
print(f"TreeSequence positions: {len(ts_positions)} variants")
print(f"VCF positions: {len(vcf_positions)} variants")

print("\nExtracting variant positions from VCF...")

# Find missing positions (in VCF but not in TreeSequence). This only needs the
# position arrays, so it is computed before reading the VCF records.
missing_in_ts = np.setdiff1d(vcf_positions, ts_positions)
# Plain-Python set for O(1) membership tests; `pos in ndarray` scans the whole
# array for every record.
missing_positions = set(missing_in_ts.tolist())

# collect information on variants; genotype statistics are gathered in the
# same pass, but only for the records that are missing from the TreeSequence
vcf_variants_info = []
genotype_stats = []

with cyvcf2.VCF(vcf_file) as vcf_reader:
    for variant in vcf_reader:
        vcf_variants_info.append({
            'chrom': variant.CHROM,
            'pos': variant.POS,
            'ref': variant.REF,
            'alt': ','.join(variant.ALT) if variant.ALT else '',
            'qual': variant.QUAL,
            'filter': variant.FILTER if variant.FILTER else 'PASS',
        })

        if variant.POS not in missing_positions:
            continue

        # keep the first two entries of each genotype (the two alleles)
        gts = np.array([gt[:2] for gt in variant.genotypes])
        unique_alleles = np.unique(gts[gts >= 0])  # Exclude missing (-1)

        genotype_stats.append({
            'pos': variant.POS,
            'unique_alleles': len(unique_alleles),
            'is_monomorphic': len(unique_alleles) == 1,
            'missing_gts': np.sum(gts == -1),
            'total_gts': gts.size,
            'allele_counts': dict(zip(*np.unique(gts.flatten(), return_counts=True)))
        })

# create a dataframe from variant information
vcf_variants_df = pd.DataFrame(vcf_variants_info)

print(f"\nVariants in VCF but missing in TreeSequence: {len(missing_in_ts)}")

if len(missing_in_ts) > 0:
    # keep only the VCF records whose position is missing from the TreeSequence
    missing_variants_df = vcf_variants_df[vcf_variants_df['pos'].isin(missing_in_ts)].copy()

    genotype_stats_df = pd.DataFrame(genotype_stats)

    # join information
    missing_variants_full = missing_variants_df.merge(genotype_stats_df, on='pos')

    # uncomment to save data
    # missing_variants_full.to_csv('missing_variants.csv', index=False)

else:
    print("\n✓ All VCF variants are present in TreeSequence")

    # Later cells read `missing_variants_full` unconditionally: define an empty
    # table with the same columns so they run instead of raising NameError.
    missing_variants_full = pd.DataFrame(columns=[
        'chrom', 'pos', 'ref', 'alt', 'qual', 'filter',
        'unique_alleles', 'is_monomorphic', 'missing_gts', 'total_gts',
        'allele_counts',
    ])

# Also check if TreeSequence has any variants not in VCF (unlikely but possible)
extra_in_ts = np.setdiff1d(ts_positions, vcf_positions)
if len(extra_in_ts) > 0:
    extra_df = pd.DataFrame({'pos': extra_in_ts})
    # Report the finding; previously the table was built but never shown.
    print(f"\n✗ Warning: {len(extra_in_ts)} TreeSequence positions are not in the VCF")

Let's explore those variants missing in the threads TreeSequence:

In [None]:
# Display the table of VCF variants whose position is absent from the TreeSequence
missing_variants_full

In [None]:
# How many of the variants missing from the TreeSequence carry a single allele?
is_monomorphic = missing_variants_full['is_monomorphic']
monomorphic_count = is_monomorphic.sum()
print(f"Number of monomorphic variants: {monomorphic_count}")

# Tally the missing variants by the number of distinct alleles observed
allele_counts = missing_variants_full.groupby('unique_alleles').size()
print("\nCounts by number of unique alleles:")
print(allele_counts)

# Same tally, expressed as a share of all missing variants
n_missing = len(missing_variants_full)
print("\nDetailed breakdown:")
for n_alleles, count in allele_counts.items():
    percentage = 100 * count / n_missing
    print(f"  {n_alleles} unique allele(s): {count} variants ({percentage:.1f}%)")

So the missing variants are monomorphic SNPs (no alternate alleles) and SNPs with more
than 2 alleles. What about the SNP with two alleles? Let's check it:

In [None]:
# Show the missing variants for which exactly two distinct alleles were observed
missing_variants_full.loc[missing_variants_full["unique_alleles"].eq(2)]

OK, it is a SNP with three alleles, but the reference allele is not found in any sample,
which is why it has 2 unique alleles instead of 3. Let's determine the positions to skip
in order to check genotypes:

In [None]:
# Positions of the VCF variants that are missing from the TreeSequence
skip_positions = missing_variants_full["pos"].to_list()

## Extract TreeSequence genotypes

Collect the full genotype matrix from TS object:

In [None]:
# genotype matrix of the tree sequence, laid out as (num_sites, num_samples)
ts_genotype_matrix = ts_threads.genotype_matrix()
n_ts_sites, n_ts_samples = ts_genotype_matrix.shape

print(f"Genotype matrix from TS: {ts_genotype_matrix.shape}")
print(f"Number of sites: {n_ts_sites}")
print(f"Number of samples: {n_ts_samples}")
print("\nSubsetting first 5 objects in both dimensions")
print(ts_genotype_matrix[:5, :5])

## Extract VCF genotypes

Collect the full genotype matrix from VCF (remember to skip filtered variants):

In [None]:
# Build the set once: membership is then O(1) per record, whereas
# `pos in skip_positions` scans the whole list for every VCF record.
positions_to_skip = set(skip_positions)

with cyvcf2.VCF(vcf_file) as vcf_reader:
    print("\nExtracting genotypes from VCF...")

    genotype_list = []

    for i, variant in enumerate(vcf_reader):
        if variant.POS in positions_to_skip:
            continue  # those variants are missing in TS

        # Extract alleles as a numpy array directly
        # variant.genotypes is a list, we take only the first 2 elements (alleles)
        # Unnest genotypes: create a (num_samples * 2,) flat array for this variant
        genotypes_array = np.array([gt[:2] for gt in variant.genotypes], dtype=np.int8)
        genotype_list.append(genotypes_array.flatten())

        # progress counter counts every record read, including skipped ones
        if (i + 1) % 10000 == 0:
            print(f"  Processed {i + 1} variants...")

    # Convert to a 2D numpy array: dimensions (num_variants, num_samples * 2)
    vcf_genotype_matrix = np.array(genotype_list, dtype=np.int8)

    print("\nVCF genotype matrix created:")
    print(f"  Shape: {vcf_genotype_matrix.shape}")
    print(f"  Dtype: {vcf_genotype_matrix.dtype}")
    print(f"  Number of variants: {vcf_genotype_matrix.shape[0]}")
    print(f"  Number of samples: {vcf_genotype_matrix.shape[1]}")

    print("\nExample - first 5 genotypes of the first sample:")
    print(vcf_genotype_matrix[:5, :5])

## Compare genotype matrices

Compare the two genotype matrices:

In [None]:
# Compare the VCF and TreeSequence genotype matrices: shapes first, then values
if vcf_genotype_matrix.shape != ts_genotype_matrix.shape:
    print("✗ Matrices have different shapes!")
    print(f"  VCF:          {vcf_genotype_matrix.shape}")
    print(f"  TreeSequence: {ts_genotype_matrix.shape}")
else:
    print(f"✓ Matrices have the same shape: {vcf_genotype_matrix.shape}")

    # Element-wise agreement, computed once and reduced along either axis below.
    # Rows are sites; columns are the haploid samples (num_samples * 2 in the VCF).
    genotypes_equal = vcf_genotype_matrix == ts_genotype_matrix

    # A site matches when every sample agrees at that site
    matches = np.all(genotypes_equal, axis=1)
    num_matches = np.sum(matches)
    num_total = len(matches)

    print("\nGlobal result:")
    print(f"  Sites with matching genotypes: {num_matches}/{num_total} ({100*num_matches/num_total:.2f}%)")

    if num_matches == num_total:
        print("\n✓ All matrices match perfectly!")
    else:
        # Indices of the sites with at least one disagreement (never empty here)
        mismatches = np.where(~matches)[0]
        print(f"\n  Number of sites with mismatches: {len(mismatches)}")
        print(f"  First 5 sites with mismatches: {mismatches[:5]}")

        # Walk through the first mismatching site as a worked example
        site_idx = mismatches[0]
        print(f"\n  Example - Site {site_idx}:")

        # Samples that disagree at this site; at least one exists because the
        # site was selected for not matching everywhere
        site_matches = genotypes_equal[site_idx]
        discordant_samples = np.where(~site_matches)[0]

        print(f"    Samples with mismatches: {len(discordant_samples)}")
        sample_idx = discordant_samples[0]
        print(f"    Sample {sample_idx}:")
        print(f"      VCF:          {vcf_genotype_matrix[site_idx, sample_idx]}")
        print(f"      TreeSequence: {ts_genotype_matrix[site_idx, sample_idx]}")

    # Per-sample view: a sample is "perfect" only when it agrees at every site
    sample_matches = np.all(genotypes_equal, axis=0)
    num_perfect_samples = int(np.sum(sample_matches))

    print(f"\nSamples with all matching genotypes: {num_perfect_samples}/{ts_threads.num_samples}")

    if num_perfect_samples != ts_threads.num_samples:
        mismatched_samples = np.where(~sample_matches)[0]
        print(f"  Number of samples with mismatches: {len(mismatched_samples)}")
        print(f"  First 10 samples with mismatches: {mismatched_samples[:10].tolist()}")