# Load packages

In [None]:
import session_info
import pandas as pd
import seaborn as sns
from pyhere import here
import matplotlib.pyplot as plt
from rfmix_reader import read_rfmix

# Load data

In [None]:
# Input locations, built with pyhere's `here` (presumably project-root relative).
prefix_path, binary_dir = (
    here("input/real_data/_m/"),
    here("real_data/gpu_version/_m/binary_files/"),
)
# Unpack the locus table, the global-ancestry table and the local-ancestry
# matrix returned by read_rfmix, in that order.
loci, rf_q, admix = read_rfmix(prefix_path, binary_dir=binary_dir)

# Visualize global ancestry by chromosome

## Organize chromosome

In [None]:
# Autosome order chr1..chr22, so plots are not sorted lexically
# (chr1, chr10, chr11, ...).
chrom_order = [f'chr{i}' for i in range(1, 23)]
# rf_q is a cuDF DataFrame on the GPU path, but may already be a pandas
# DataFrame on the CPU path, which has no `.to_pandas()`. Copy pandas input
# so the categorical conversion below never mutates rf_q itself.
if hasattr(rf_q, "to_pandas"):
    rf_q_pandas = rf_q.to_pandas()
else:
    rf_q_pandas = rf_q.copy()
# Ordered categorical makes sorting/plotting follow chrom_order. Labels not in
# chrom_order (e.g. sex chromosomes, or names without the "chr" prefix)
# become NaN here.
rf_q_pandas['chrom'] = pd.Categorical(rf_q_pandas['chrom'],
                                      categories=chrom_order, ordered=True)
rf_q_sorted = rf_q_pandas.sort_values('chrom')

## Create and save the plot

In [None]:
# Distribution of AFR global ancestry across samples for each chromosome:
# a box plot with every sample overlaid as a jittered point.
plt.figure(figsize=(15, 8))
# showfliers=False: the strip plot below already draws every sample, so the
# box plot's outlier markers would plot those points a second time.
sns.boxplot(x='chrom', y='AFR', data=rf_q_sorted,
            color='lightgray', width=0.6, showfliers=False)
sns.stripplot(x='chrom', y='AFR', data=rf_q_sorted,
              color='black', alpha=0.1, jitter=True)
# Reference line at 50% African ancestry.
plt.axhline(y=0.5, color='black', linestyle='--', linewidth=1)
plt.title('Global Ancestry (AFR) by Chromosome', fontsize=18)
plt.xlabel('Chromosome', fontsize=14)
plt.ylabel('African Genetic Ancestry', fontsize=14)
plt.xticks(rotation=45, fontsize=12)
plt.yticks(fontsize=12)
plt.tight_layout()
# Save the figure, as the section heading promises; the naming follows the
# local-ancestry PDFs written later in this notebook.
plt.savefig('global_ancestry.AFR_by_chromosome.pdf', dpi=300,
            bbox_inches='tight')

## Mean and median African ancestry proportion

In [None]:
# Per-chromosome mean and median AFR proportion. `chrom` is categorical:
# observed=False keeps all 22 categories (the current default), and stating
# it explicitly avoids pandas' FutureWarning about that default changing.
rf_q_sorted.groupby("chrom", observed=False).agg({"AFR": ["mean", "median"]})

In [None]:
# Overall mean and median AFR across every sample/chromosome row.
afr_values = rf_q_sorted.AFR
print(afr_values.mean(), afr_values.median())

**Select chromosomes with:**
  * High African Ancestry -- chromosome 13
  * Low African Ancestry -- chromosome 19
  * Average African Ancestry -- chromosome 16

# Prepare data

## Helper functions

In [None]:
import dask.dataframe as dd
from multiprocessing import cpu_count

In [None]:
try:
    from torch.cuda import is_available
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError. Catching the parent
    # also falls back to CPU when PyTorch is present but fails to import.
    print("Warning: PyTorch is not installed. Using CPU!")

    def is_available():
        """CPU fallback used when PyTorch is unavailable: no CUDA device."""
        return False


In [None]:
def _get_pops(rf_q):
    return rf_q.drop(["sample_id", "chrom"], axis=1).columns.values


def _get_sample_names(rf_q):
    if is_available():
        return rf_q.sample_id.unique().to_arrow()
    else:
        return rf_q.sample_id.unique()

## Define column names

In [None]:
# Unique sample IDs (future haplotype column names) and population labels.
sample_ids = _get_sample_names(rf_q)
pops = _get_pops(rf_q)

## Convert data to dask dataframe

In [None]:
# One Dask partition per CPU core for the locus table.
parts = cpu_count()
# Columns per population; assumes `admix` holds len(pops) equally sized
# blocks of per-sample columns (TODO confirm layout against rfmix_reader).
ncols = admix.shape[1] // len(pops)
# `loci` is cuDF on the GPU path (needs .to_pandas()), but may already be a
# pandas DataFrame on the CPU path, which has no .to_pandas().
ddf = dd.from_pandas(
    loci.to_pandas() if hasattr(loci, "to_pandas") else loci,
    npartitions=parts,
)
# Keep only the first population's block. The plotted values are haplotype
# counts (0-2), so for two-way admixture the other population is presumably
# 2 minus this count rather than 1 minus it -- confirm.
data_matrix = admix[:, :ncols]
dask_df = dd.from_dask_array(data_matrix, columns=sample_ids)

## Combine loci with haplotype data

In [None]:
# Concatenate column-wise: locus annotation columns followed by one
# haplotype-count column per sample.
ddf = dd.concat([ddf, dask_df], axis=1)
del dask_df # drop the separate reference; only `ddf` is used from here on

## Select chromosomes

In [None]:
# Lazily filter the combined frame down to the three chromosomes of interest,
# then drop the full frame.
chrom13, chrom16, chrom19 = (
    ddf[ddf["chromosome"] == f"chr{num}"] for num in (13, 16, 19)
)
del ddf

# Plot a section of each chromosome

## Helper function

In [None]:
from numpy import random

In [None]:
def _materialize(obj):
    """Return ``obj.compute()`` for lazy (e.g. Dask) objects, else ``obj``."""
    return obj.compute() if hasattr(obj, "compute") else obj


def select_random_section(df, section_size=10000):
    """Return the rows of *df* that fall inside a randomly placed window.

    The window is ``[start, start + section_size)`` on the
    ``physical_position`` column. ``start`` is drawn from NumPy's global
    ``random`` state, so seeding that state makes the choice reproducible.
    If the positions span no more than *section_size*, every row is
    returned instead (a random start would be impossible).

    Parameters
    ----------
    df : dask or pandas DataFrame
        Must contain a ``physical_position`` column.
    section_size : int
        Width of the window, in ``physical_position`` units.

    Returns
    -------
    pandas.DataFrame
        The in-memory subset of *df* inside the window.

    Raises
    ------
    ValueError
        If *df* has no positions to sample from (e.g. it is empty).
    """
    positions = df['physical_position']
    min_pos = _materialize(positions.min())
    max_pos = _materialize(positions.max())
    if pd.isna(min_pos) or pd.isna(max_pos):
        raise ValueError("DataFrame has no physical positions to sample from")
    if max_pos - min_pos <= section_size:
        # randint(low, high) would fail with low >= high: keep all rows.
        start_pos, end_pos = min_pos, max_pos + 1
    else:
        # NumPy's randint excludes `high`; kept unchanged so seeded runs
        # select exactly the same windows.
        start_pos = random.randint(min_pos, max_pos - section_size)
        end_pos = start_pos + section_size
    in_window = (positions >= start_pos) & (positions < end_pos)
    return _materialize(df[in_window])


def plot_section(df, section_size, fname, chrom, sample_id='Br2585'):
    """Plot one sample's haplotype counts over a random window and save it.

    A window of *section_size* is chosen with ``select_random_section``.
    The *sample_id* column is scatter-plotted against
    ``physical_position``, and the figure is written to
    ``local_ancestry.<fname>.pdf`` in the current working directory.

    Parameters
    ----------
    df : dask or pandas DataFrame
        Must contain ``physical_position`` and a *sample_id* column.
    section_size : int
        Width of the random window, in ``physical_position`` units.
    fname : str
        Tag inserted into the output file name.
    chrom : str
        Chromosome label shown in the title.
    sample_id : str, optional
        Column (sample) to plot; defaults to the originally hard-coded
        ``'Br2585'``.
    """
    section = select_random_section(df, section_size)
    section = section.sort_values("physical_position")
    first_pos = section["physical_position"].min()
    last_pos = section["physical_position"].max()
    plt.figure(figsize=(9, 2))
    sns.scatterplot(data=section, x='physical_position', y=sample_id,
                    color='black', legend=False, s=20)
    plt.title(f'Chromosome {chrom} Section Plot (Positions {first_pos} to {last_pos})', fontsize=16)
    plt.xlabel('Chromosome Position', fontsize=12)
    plt.ylabel('Haplotypes', fontsize=12)
    plt.yticks([0, 1, 2])
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(f'local_ancestry.{fname}.pdf', dpi=300, bbox_inches='tight')

## Seed for reproducibility

In [None]:
# Seed NumPy's global RNG so the randomly chosen windows are reproducible.
seed_value = 13
section_size = 1_000_000  # window width in physical_position units
random.seed(seed_value)

## Plotting

In [None]:
# One random window per selected chromosome. Each call draws from the seeded
# global RNG, so the order 13, 16, 19 is significant.
for label, frame in (("13", chrom13), ("16", chrom16), ("19", chrom19)):
    plot_section(frame, section_size, f"chr{label}", label)

# Session information

In [None]:
# Display session information for reproducibility.
session_info.show()