- This notebook demonstrates how to convert kmer genotype matrix to motif genotype matrix
    - In the previous preprocessing step, I computed the correspondence from kmer to motif.
      So once I have a kmer matrix, where each row is a kmer and each col is a sample,
      I can simply do row operations to get the motif dosage.
    - It takes about 10 min to compute each batch of data in `compute_gt_cgt_batch`.
      Optionally, you can speed this up by moving the code to a Python script and running it as a SLURM array job.
    - Watch out for memory usage and remember to allocate enough memory for your compute node,
      otherwise the IPython kernel will die
- With motif genotype matrix, we can then perform LD pruning over motifs
    - Add another function to save in tsv.gz format instead of pickle for portability
    - Use PLINK or write your own code given certain LD threshold
- Data structure in this notebook that can be useful for LD pruning
    - `ccki_tr`: cumulative number of canonical compressed kmer (cck) per TR locus
        - cck is equivalent to motif in our definition
        - vector of size NTR (Number of TR loci)
        - retrieve # of cck in each locus `i` by
            - `ccki_tr[i] -  ccki_tr[i-1]` if i > 0
            - `ccki_tr[i]` if i == 0
        - we will need this to perform LD pruning for each locus, or a block in the matrix

In [None]:
#!/usr/bin/env python3
import sys
# Put the project script directory at the front of the module search path so the
# local modules imported in the next cell resolve to it first
# (presumably where `vntrutils` and `utils` live -- confirm).
# NOTE(review): absolute, cluster-specific path; adjust when running elsewhere.
srcdir = "/project/mchaisso_100/cmb-16/tsungyul/work/vntr/danbing-tk/script/"
sys.path.insert(0, srcdir)

In [2]:
import numpy as np
import pandas as pd
import vntrutils as vu
import utils
import matplotlib
import matplotlib.pyplot as plt
from collections import defaultdict
from collections import Counter
import pickle
import itertools
import gc
import glob
import os
import statsmodels.api as sm
import gzip
from sklearn.metrics import r2_score
import seaborn as sns
import time

# Global matplotlib defaults: 7 pt base font and axes titles, 5 pt tick labels.
matplotlib.rc('font', size=7)
matplotlib.rc('axes', titlesize=7)
matplotlib.rc('xtick', labelsize=5)
matplotlib.rc('ytick', labelsize=5)
# IPython autoreload in mode 2: re-import modules before executing each cell,
# so edits to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
def get_1(file_path):
    """Read the pickled `(ki_tr, ccki_tr)` pair stored at `file_path`.

    Args:
        file_path: Path to a pickle whose payload unpacks into exactly two items.

    Returns:
        Tuple `(ki_tr, ccki_tr)` exactly as stored in the pickle.
    """
    # NOTE: pickle can execute arbitrary code; only load files you trust.
    with open(file_path, 'rb') as handle:
        payload = pickle.load(handle)
    first, second = payload
    return first, second

def get_2(file_path):
    """Read the pickled `(ks, ccks, tr_cck_ns, ki_map)` 4-tuple stored at `file_path`.

    Args:
        file_path: Path to a pickle whose payload unpacks into exactly four items.

    Returns:
        Tuple `(ks, ccks, tr_cck_ns, ki_map)`; `tr_cck_ns` is converted to a
        NumPy array, the other three items are returned as stored.
    """
    # NOTE: pickle can execute arbitrary code; only load files you trust.
    with open(file_path, 'rb') as handle:
        payload = pickle.load(handle)
    ks, ccks, counts, ki_map = payload
    return ks, ccks, np.array(counts), ki_map

def gather_motifs(gt_HPRC, NCCK, NB, out_dir):
    """Stitch the per-batch matrices `cgt.{i}.pickle` into one matrix and pickle it.

    Args:
        gt_HPRC: Path to a text file read with `np.loadtxt`; only its number of
            entries is used here (it fixes the number of columns).
        NCCK: Number of rows of the merged matrix.
        NB: Number of batch files `{out_dir}/cgt.0.pickle` .. `cgt.{NB-1}.pickle`.
        out_dir: Directory holding the batch pickles; `cgt.pickle` is written here.

    Returns:
        float32 array of shape `(NCCK, n_entries)`. Every batch fills
        `n_entries // NB` consecutive columns, except the last one, which also
        takes the remainder.
    """
    sample_ids = np.loadtxt(gt_HPRC, dtype=object)
    n_samples = sample_ids.size
    batch_width = n_samples // NB

    print("Loading batches...")
    print(batch_width, n_samples, NB)
    sys.stdout.flush()

    merged = np.zeros([NCCK, n_samples], dtype=np.float32)
    final_batch = NB - 1
    for batch in range(NB):
        print(f"Loading batch {batch+1}")
        sys.stdout.flush()
        col_lo = batch * batch_width
        # The final batch runs to the last column so the remainder is not dropped.
        col_hi = n_samples if batch == final_batch else col_lo + batch_width
        with open(f"{out_dir}/cgt.{batch}.pickle", 'rb') as fh:
            merged[:, col_lo:col_hi] = pickle.load(fh)

    print("Dumping cgt...")
    sys.stdout.flush()
    with open(f"{out_dir}/cgt.pickle", 'wb') as fh:
        pickle.dump(merged, fh, protocol=pickle.HIGHEST_PROTOCOL)
    return merged

def adjust_coverage(cgt, gt_HPRC, HPRC_chr1_cov, out_dir):
    """Divide every sample column of `cgt` by that sample's coverage and pickle the result.

    The coverage vector is built in the order of the genome list, so column `j`
    of `cgt` is always divided by the coverage of the `j`-th genome listed in
    `gt_HPRC`, regardless of the row order of the coverage file.

    Args:
        cgt: 2-D float array with one column per genome listed in `gt_HPRC`.
            It is modified IN PLACE (no copy is made).
        gt_HPRC: Path to a text file with one genome name per line.
        HPRC_chr1_cov: Path to a whitespace/tab-delimited two-column file of
            `<genome> <coverage>` rows. Rows for genomes not listed in
            `gt_HPRC` are ignored.
        out_dir: Directory in which `acgt.pickle` is written.

    Returns:
        The same array object as `cgt`, now holding coverage-adjusted values.

    Raises:
        ValueError: If a genome in `gt_HPRC` has no row in the coverage file.
    """
    print("Loading coverage...")
    sys.stdout.flush()
    # atleast_1d / ndmin=2 keep single-genome and single-row inputs iterable.
    genomes = np.atleast_1d(np.loadtxt(gt_HPRC, dtype=object))
    cov_rows = np.loadtxt(HPRC_chr1_cov, dtype=object, ndmin=2)
    cov_by_genome = {g: float(c) for g, c in cov_rows}

    missing = [g for g in genomes if g not in cov_by_genome]
    if missing:
        raise ValueError(
            f"{len(missing)} genome(s) from {gt_HPRC} have no coverage in "
            f"{HPRC_chr1_cov}, e.g. {missing[:5]}"
        )
    # Look coverage up per genome so the vector follows the column order of cgt.
    cov = np.array([cov_by_genome[g] for g in genomes])

    print("Computing acgt...")
    sys.stdout.flush()
    # In-place division avoids holding a second full-size matrix in memory.
    cgt /= cov
    print("Dumping acgt...")
    sys.stdout.flush()
    with open(f"{out_dir}/acgt.pickle", 'wb') as f:
        pickle.dump(cgt, f, protocol=pickle.HIGHEST_PROTOCOL)
    return cgt

def _unit_centered_rows(block):
    """Return `block` with each row mean-centered and scaled to unit L2 norm (float64)."""
    rows = np.asarray(block, dtype=np.float64)
    rows = rows - rows.mean(axis=1, keepdims=True)
    norms = np.sqrt((rows * rows).sum(axis=1, keepdims=True))
    # Constant rows have zero norm; keeping them as all-zero vectors makes their
    # correlation with every other row 0 instead of NaN.
    norms[norms == 0] = 1.0
    return rows / norms


def compute_partial_ld_r2(acgt, ccki_tr, ccks, r2_threshold, out_dir, start_idx, end_idx):
    """Greedy within-locus LD pruning of motifs for loci `start_idx..end_idx` (inclusive).

    Within each locus, motifs are visited in row order. Every motif that is
    still unpruned prunes all later motifs of the same locus whose squared
    Pearson correlation with it exceeds `r2_threshold`. Motifs from different
    loci are never compared.

    LD r^2 is the squared Pearson correlation. The coefficient of determination
    (`sklearn.metrics.r2_score`) is not used because it is asymmetric, depends
    on the scale of the two dosage vectors and can be negative.

    Args:
        acgt: 2-D array, one row per motif and one column per sample.
        ccki_tr: Cumulative motif counts per locus; locus `i` owns rows
            `ccki_tr[i-1]:ccki_tr[i]` (from row 0 when `i == 0`).
        ccks: Unused; kept so existing callers keep working.
        r2_threshold: Motifs with r^2 strictly above this value are pruned.
        out_dir: Directory in which the result pickle is written.
        start_idx: Index of the first locus to process.
        end_idx: Index of the last locus to process (inclusive).

    Returns:
        Boolean vector with one entry per motif of the processed loci, in row
        order starting at the first motif of locus `start_idx`; True = pruned.
        The same vector is pickled to
        `{out_dir}/cck_pruned_{r2_threshold}_{start_idx}_{end_idx}.pickle`.
    """
    # Row of the first motif of locus `start_idx`, i.e. the cumulative count of
    # the PREVIOUS locus. Using ccki_tr[start_idx] here would skip a whole locus
    # and make the per-locus offsets below negative.
    init_locus_start = ccki_tr[start_idx - 1] if start_idx != 0 else 0
    pruned_size = ccki_tr[end_idx] - init_locus_start
    pruned = np.zeros(pruned_size, dtype=bool)
    print(f"Pruning loci from {start_idx} to {end_idx}...")
    print(f"r^2 threshold = {r2_threshold}")
    print(f"loci {start_idx}: {ccki_tr[start_idx]} motifs \nloci {end_idx}: {ccki_tr[end_idx]} motifs \npartial motif count: {pruned_size} / {ccki_tr[-1]}")
    sys.stdout.flush()
    start_time = time.time()

    for i in range(start_idx, end_idx + 1):
        locus_s = ccki_tr[i - 1] if i != 0 else 0
        locus_e = ccki_tr[i]
        n_motifs = locus_e - locus_s
        if n_motifs > 1:
            unit_rows = _unit_centered_rows(acgt[locus_s:locus_e])
            offset = locus_s - init_locus_start
            # Basic slice => view, so updates land directly in `pruned`.
            locus_pruned = pruned[offset:offset + n_motifs]
            for j in range(n_motifs - 1):
                # A pruned motif is represented by an earlier one; it prunes nothing.
                if locus_pruned[j]:
                    continue
                # Dot products of unit, centered rows are Pearson correlations.
                r2 = (unit_rows[j + 1:] @ unit_rows[j]) ** 2
                locus_pruned[j + 1:] |= r2 > r2_threshold
        if (i + 1) % 100 == 0:
            compute_time = time.time() - start_time
            print(f"Pruned {i + 1} loci in {compute_time:.2f} seconds")
            sys.stdout.flush()

    # pickle vector of pruned status
    print(f"Dumping pruned...")
    sys.stdout.flush()
    with open(f"{out_dir}/cck_pruned_{r2_threshold}_{start_idx}_{end_idx}.pickle", 'wb') as f:
        pickle.dump(pruned, f, protocol=pickle.HIGHEST_PROTOCOL)
    return pruned

def combine_pruned_files(out_dir, r2_threshold, num_jobs, total_loci):
    """Concatenate the per-job pruned vectors into one vector and pickle it.

    Job `i` is expected to have written
    `{out_dir}/cck_pruned_{r2_threshold}_{start}_{end}.pickle` with
    `start = i * (total_loci // num_jobs)` and
    `end = (i + 1) * (total_loci // num_jobs)` (`total_loci` for the last job).

    The partial vectors hold one entry per motif, not per locus, so their
    lengths generally differ from `end - start`. They are therefore joined
    end-to-end in job order rather than written into locus-indexed slices.

    NOTE(review): `compute_partial_ld_r2` treats its `end_idx` as inclusive,
    while the bounds used for the file names here look half-open -- confirm the
    job submission script names its outputs with exactly these bounds.

    Args:
        out_dir: Directory holding the partial pickles; the combined pickle is
            written here too.
        r2_threshold: Threshold value embedded in the file names.
        num_jobs: Number of partial files to combine.
        total_loci: Total number of loci that was split across the jobs.

    Returns:
        The combined boolean vector (also pickled to
        `{out_dir}/cck_pruned_combined_{r2_threshold}.pickle`).
    """
    partials = []
    for i in range(num_jobs):
        start_idx = i * (total_loci // num_jobs)
        end_idx = (i + 1) * (total_loci // num_jobs) if i != num_jobs - 1 else total_loci

        # NOTE: pickle can execute arbitrary code; only load files you trust.
        with open(f"{out_dir}/cck_pruned_{r2_threshold}_{start_idx}_{end_idx}.pickle", 'rb') as f:
            partials.append(np.atleast_1d(np.asarray(pickle.load(f), dtype=bool)))

    if partials:
        pruned_combined = np.concatenate(partials)
    else:
        # No jobs: keep the previous behavior of writing an all-False vector.
        pruned_combined = np.zeros(total_loci, dtype=bool)

    # Dump the combined pruned array
    with open(f"{out_dir}/cck_pruned_combined_{r2_threshold}.pickle", 'wb') as f:
        pickle.dump(pruned_combined, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"Combined pruned file saved at {out_dir}/cck_pruned_combined_{r2_threshold}.pickle")
    return pruned_combined

In [7]:
# Input pickles read by get_1 / get_2 below.
# NOTE(review): absolute, cluster-specific paths; adjust when running elsewhere.
get_1_file="/project/mchaisso_100/cmb-17/vntr_genotyping/rpgg2_k21_84k/hprc/full.v1/output8/cdbg/ki_tr.ccki_tr.pickle"
get_2_file="/project/mchaisso_100/cmb-17/vntr_genotyping/rpgg2_k21_84k/hprc/full.v1/output8/cdbg/ks.ccks.tr_cck_ns.ki_map.pickle"
# Genome list and genome/coverage table consumed by gather_motifs / adjust_coverage.
gt_HPRC="/project/mchaisso_100/cmb-17/vntr_genotyping/aydin/LD_prune/input/genomes.txt"
HPRC_chr1_cov="/project/mchaisso_100/cmb-17/vntr_genotyping/aydin/LD_prune/input/1kg_all.cov.tsv"
# Directory holding the per-batch cgt.{i}.pickle files; all outputs are written here too.
out="/scratch1/tsungyul/aydin/k2m_output"

ki_tr, ccki_tr = get_1(get_1_file)
ks, ccks, tr_cck_ns, ki_map = get_2(get_2_file)

NK = len(ks)      # number of entries in ks (presumably kmers -- confirm)
NCCK = len(ccks)  # number of ccks (motifs); passed to gather_motifs as the row count
NB = 40           # number of cgt.{i}.pickle batch files gather_motifs will read

# IL dosage

In [None]:
# Load the coverage-adjusted motif matrix `acgt`, rebuilding only the stages
# whose cache file is missing:
#   acgt.pickle  ->  cgt.pickle  ->  per-batch cgt.{i}.pickle files
# NOTE: pickle can execute arbitrary code; only load cache files you trust.
if os.path.exists(f"{out}/acgt.pickle"):
    print("acgt file found")
    with open(f"{out}/acgt.pickle", 'rb') as f:
        acgt = pickle.load(f)
else:
    if os.path.exists(f"{out}/cgt.pickle"):
        # Reuse the merged matrix instead of re-gathering every batch.
        print("cgt file found")
        with open(f"{out}/cgt.pickle", 'rb') as f:
            cgt = pickle.load(f)
    else:
        print("creating cgt file")
        cgt = gather_motifs(gt_HPRC, NCCK, NB, out)
    print("creating acgt file")
    # adjust_coverage divides in place, so `cgt` and `acgt` are the same array afterwards.
    acgt = adjust_coverage(cgt, gt_HPRC, HPRC_chr1_cov, out)

In [None]:
# Small trial run of the pruning on a handful of loci.
r2_threshold = 0.4  # motifs scoring above this against a kept motif are marked pruned
start_idx = 1       # first locus index to process
end_idx = 3         # last locus index to process (the loop includes end_idx)
# Also writes {out}/cck_pruned_{r2_threshold}_{start_idx}_{end_idx}.pickle
cck_pruned = compute_partial_ld_r2(acgt, ccki_tr, ccks, r2_threshold, out, start_idx, end_idx)