In [1]:
hapmap_path = '../data/filtered_hapmap3'
kg_path = '../data/1000g/populations'
output_path = '../data/pca'

In [39]:
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from pathlib import Path

class JointPopulationAnalyzer:
def __init__(self, snp_mapping_file):
    """
    Initialize with shared populations, data paths and SNP mapping file
    """
    self.shared_populations = [
        'ASW', 'CEU', 'CHB', 'JPT', 
        'LWK', 'MXL', 'TSI', 'YRI'
    ]
    
    self.data_paths = {
        'hapmap3': Path('../data/filtered_hapmap3'),
        '1000g': Path('../data/1000g/populations')
    }
    
    # Load SNP mapping and print structure
    self.snp_mapping = pd.read_csv(snp_mapping_file, sep='\t')
    print("SNP mapping columns:", self.snp_mapping.columns.tolist())
    print("\nFirst few rows of mapping file:")
    print(self.snp_mapping.head())
    
    # Store both HapMap3 and 1000G alleles
    self.hapmap_to_1000g = {}
    for _, row in self.snp_mapping.iterrows():
        # We'll adjust these column names based on the debug output
        self.hapmap_to_1000g[row['position']] = (
            row['ref'],  # Using generic column names for now
            row['alt']
        )
    def _read_geno_file(self, filepath):
        """Read genotype file with consistent dimensions"""
        ind_file = str(filepath).replace('.geno', '.ind')
        with open(ind_file, 'r') as f:
            n_samples = sum(1 for line in f)
        
        print(f"Reading genotype file: {filepath}")
        print(f"Expected samples from .ind file: {n_samples}")
        
        geno_lines = []
        with open(filepath, 'r') as f:
            for i, line in enumerate(f, 1):
                line = line.strip()
                numeric_line = []
                for char in line[:n_samples]:
                    if char in '012':
                        numeric_line.append(int(char))
                    else:
                        numeric_line.append(-1)
                        
                while len(numeric_line) < n_samples:
                    numeric_line.append(-1)
                    
                geno_lines.append(numeric_line)
        
        geno_array = np.array(geno_lines, dtype=np.int8)
        print(f"Final genotype matrix shape: {geno_array.shape}")
        
        return geno_array

    def load_genetic_data(self, population, dataset):
        """Load genetic data with position-based filtering"""
        base_path = self.data_paths[dataset]
        
        try:
            geno_path = base_path / f"{population}.geno"
            ind_path = base_path / f"{population}.ind"
            snp_path = base_path / f"{population}.snp"
            
            # Check files exist
            for path in [geno_path, ind_path, snp_path]:
                if not path.exists():
                    print(f"Missing file: {path}")
                    return None
            
            # Load data
            geno = self._read_geno_file(geno_path)
            
            ind = pd.read_csv(ind_path, sep='\s+', header=None,
                            names=['sample_id', 'sex', 'population'])
            
            snp = pd.read_csv(snp_path, sep='\s+', header=None,
                            names=['id', 'chr', 'genetic_dist', 'position', 'ref', 'alt'])
            
            # Filter SNPs based on mapping
            if dataset == 'hapmap3':
                valid_positions = set(self.snp_mapping['position'])
            else:  # 1000g
                valid_positions = set(self.snp_mapping['position'])
            
            position_mask = snp['position'].isin(valid_positions)
            filtered_geno = geno[position_mask]
            filtered_snp = snp[position_mask]
            
            print(f"Filtered from {len(snp)} to {len(filtered_snp)} SNPs based on mapping")
            
            return {
                'genotype': filtered_geno,
                'individual': ind,
                'snp': filtered_snp
            }
            
        except Exception as e:
            print(f"Error loading {population} from {dataset}: {str(e)}")
            return None

    def identify_common_snps(self):
        """Identify SNPs common to all populations based on position mapping"""
        common_positions = set(self.snp_mapping['position'])
        print(f"Using {len(common_positions)} mapped positions")
        return common_positions

    def prepare_joint_analysis(self):
        """Prepare data for joint analysis"""
        print("\nLoading data for all populations...")
        
        for dataset in self.data_paths.keys():
            self.data[dataset] = {}
            for pop in self.shared_populations:
                print(f"\nProcessing {pop} from {dataset}...")
                data = self.load_genetic_data(pop, dataset)
                if data is not None:
                    self.data[dataset][pop] = data
        
        self.common_snps = self.identify_common_snps()
        self.joint_data = self._prepare_joint_matrix()
        print(f"\nFinal joint genotype matrix shape: {self.joint_data['genotype'].shape}")

    def _prepare_joint_matrix(self):
        """Prepare joint matrix using position-based mapping with detailed debugging"""
        genotypes = []
        pop_labels = []
        dataset_labels = []
        sample_ids = []

        for dataset in self.data.keys():
            for pop, pop_data in self.data[dataset].items():
                if pop_data is None:
                    print(f"Skipping {pop}-{dataset}: No data")
                    continue

                # Debug information
                print(f"\nProcessing {pop}-{dataset}")
                print(f"Original SNP count: {len(pop_data['snp'])}")
                print(f"Original positions: {pop_data['snp']['position'].head()}")
                print(f"Common SNP positions count: {len(self.common_snps)}")

                # Filter by mapped positions
                snp_indices = pop_data['snp']['position'].isin(self.common_snps)
                print(f"Matched positions: {sum(snp_indices)}")

                if sum(snp_indices) == 0:
                    print(f"Warning: No matching positions found for {pop}-{dataset}")
                    continue

                filtered_geno = pop_data['genotype'][snp_indices]
                filtered_positions = pop_data['snp'].loc[snp_indices, 'position']

                print(f"Filtered genotype shape: {filtered_geno.shape}")
                print(f"First few positions after filtering: {filtered_positions.head()}")

                # For 1000G data, check allele flipping
                if dataset == '1000g':
                    flip_count = 0
                    for i, pos in enumerate(filtered_positions):
                        if pos in self.hapmap_to_1000g:
                            ref_kg, alt_kg = self.hapmap_to_1000g[pos]
                            if ref_kg != pop_data['snp'].loc[snp_indices].iloc[i]['ref']:
                                # Flip genotypes (0->2, 2->0, 1->1)
                                mask = filtered_geno[i] != 1
                                filtered_geno[i][mask] = 2 - filtered_geno[i][mask]
                                flip_count += 1
                    print(f"Flipped genotypes for {flip_count} SNPs")

                # Only add if we have valid data
                if filtered_geno.size > 0:
                    genotypes.append(filtered_geno)
                    n_samples = filtered_geno.shape[1]
                    pop_labels.extend([pop] * n_samples)
                    dataset_labels.extend([dataset] * n_samples)
                    sample_ids.extend(pop_data['individual']['sample_id'])
                    print(f"Successfully added {pop}-{dataset}: {filtered_geno.shape}")
                else:
                    print(f"Warning: Empty genotype matrix for {pop}-{dataset}")

        if not genotypes:
            raise ValueError("No valid genotype data after filtering")

        # Check dimensions before concatenation
        print("\nFinal dimension check before concatenation:")
        for i, geno in enumerate(genotypes):
            print(f"Matrix {i}: shape {geno.shape}")

        try:
            joint_geno = np.hstack(genotypes)
            print(f"Successfully created joint matrix with shape: {joint_geno.shape}")
        except Exception as e:
            print("Error during concatenation:")
            print(str(e))
            raise

        return {
            'genotype': joint_geno,
            'population': np.array(pop_labels),
            'dataset': np.array(dataset_labels),
            'sample_id': np.array(sample_ids)
        }
    
    def perform_pca(self, n_components=2):
        """
        Perform PCA on the joint dataset

        Args:
            n_components: Number of principal components to compute
        """
        print("Performing PCA analysis...")

        # Replace missing values with mean
        genotype_data = self.joint_data['genotype'].copy()
        mask = genotype_data == -1

        for i in range(genotype_data.shape[0]):
            row = genotype_data[i]
            valid_values = row[row != -1]
            if len(valid_values) > 0:
                row_mean = np.mean(valid_values)
                genotype_data[i][mask[i]] = row_mean

        # Standardize the data
        scaler = StandardScaler()
        scaled_data = scaler.fit_transform(genotype_data.T)

        # Perform PCA
        pca = PCA(n_components=n_components)
        self.pca_results = pca.fit_transform(scaled_data)
        self.explained_variance = pca.explained_variance_ratio_

        print(f"Variance explained: PC1={self.explained_variance[0]:.2%}, PC2={self.explained_variance[1]:.2%}")

    def plot_pca(self, save_path=None):
        """
        Plot PCA results

        Args:
            save_path: Optional path to save the plot
        """
        plt.figure(figsize=(12, 8))

        # Plot each population-dataset combination
        for pop in self.shared_populations:
            for dataset in self.data_paths.keys():
                mask = (self.joint_data['population'] == pop) & \
                      (self.joint_data['dataset'] == dataset)

                if np.any(mask):  # Only plot if we have data for this combination
                    plt.scatter(
                        self.pca_results[mask, 0],
                        self.pca_results[mask, 1],
                        c=self.population_colors[pop],
                        marker='o' if dataset == 'hapmap3' else '^',
                        label=f'{pop}-{dataset}',
                        alpha=0.7,
                        s=50
                    )

        plt.xlabel(f'PC1 ({self.explained_variance[0]:.1%} variance)')
        plt.ylabel(f'PC2 ({self.explained_variance[1]:.1%} variance)')
        plt.title('PCA of HapMap3 and 1000G Populations')

        # Add legend
        legend_elements = []
        for pop in self.shared_populations:
            legend_elements.extend([
                plt.Line2D([0], [0], marker='o', color='w', 
                          markerfacecolor=self.population_colors[pop],
                          label=f'{pop}-HapMap3', markersize=8),
                plt.Line2D([0], [0], marker='^', color='w', 
                          markerfacecolor=self.population_colors[pop],
                          label=f'{pop}-1000G', markersize=8)
            ])

        plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), 
                  loc='upper left', borderaxespad=0.)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"Plot saved to {save_path}")

        plt.show()
def run_analysis(snp_mapping_file):
    """Run the complete analysis pipeline"""
    try:
        analyzer = JointPopulationAnalyzer(snp_mapping_file)
        analyzer.prepare_joint_analysis()
        analyzer.perform_pca()
        analyzer.plot_pca("joint_population_pca.png")
        return analyzer
    except Exception as e:
        print(f"Analysis failed: {str(e)}")
        return None

IndentationError: expected an indented block after class definition on line 8 (2679352954.py, line 9)

In [63]:
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from pathlib import Path
from scipy import stats

class JointPopulationAnalyzer:
    def __init__(self, snp_mapping_file):
        """
        Initialize with shared populations, data paths and SNP mapping file
        
        Args:
            snp_mapping_file: Path to the SNP mapping file between HapMap3 and 1000G
        """
        self.shared_populations = [
            'ASW', 'CEU', 'CHB', 'JPT', 
            'LWK', 'MXL', 'TSI', 'YRI'
        ]
        
        self.data_paths = {
            'hapmap3': Path('../data/filtered_hapmap3'),
            '1000g': Path('../data/1000g/populations')
        }
        
        # Load SNP mapping with additional QC information
        self.snp_mapping = pd.read_csv(snp_mapping_file, sep='\t')
        print(f"Loaded {len(self.snp_mapping)} SNP mappings")
        
        # Create enhanced lookup dictionaries
        self.hapmap_to_1000g = dict(zip(
            self.snp_mapping['position'], 
            zip(self.snp_mapping['ref_kg'], 
                self.snp_mapping['alt_kg'])
        ))
        
        self.population_colors = {
            'ASW': '#1f77b4', 'CEU': '#ff7f0e', 'CHB': '#2ca02c',
            'JPT': '#d62728', 'LWK': '#9467bd', 'MXL': '#8c564b',
            'TSI': '#e377c2', 'YRI': '#7f7f7f'
        }
        
        self.data = {}

    def _read_geno_file(self, filepath):
        """Read genotype file with enhanced QC"""
        ind_file = str(filepath).replace('.geno', '.ind')
        with open(ind_file, 'r') as f:
            n_samples = sum(1 for line in f)

        print(f"Reading genotype file: {filepath}")

        # Read genotype data
        geno_lines = []
        with open(filepath, 'r') as f:
            for line in f:
                line = line.strip()
                numeric_line = np.array([int(char) if char in '012' else -1 
                                       for char in line[:n_samples]])
                geno_lines.append(numeric_line)

        geno_array = np.array(geno_lines, dtype=np.int8)

        # Basic QC metrics
        valid_genotypes = geno_array != -1
        missing_rate = 1 - np.mean(valid_genotypes, axis=1)

        # Calculate MAF only for valid genotypes
        het_freq = np.zeros(len(geno_array))
        hom_alt_freq = np.zeros(len(geno_array))

        for i in range(len(geno_array)):
            valid_mask = valid_genotypes[i]
            if np.any(valid_mask):
                het_freq[i] = np.mean(geno_array[i][valid_mask] == 1)
                hom_alt_freq[i] = np.mean(geno_array[i][valid_mask] == 2)

        maf = hom_alt_freq + 0.5 * het_freq
        maf = np.minimum(maf, 1 - maf)

        # Filter SNPs with high missing rate or extreme MAF
        valid_snps = (missing_rate < 0.05) & (maf > 0.01)
        filtered_geno = geno_array[valid_snps]

        print(f"Removed {np.sum(~valid_snps)} SNPs failing QC")
        print(f"Final genotype matrix shape: {filtered_geno.shape}")

        return filtered_geno, valid_snps

    def load_genetic_data(self, population, dataset):
        """Load genetic data with enhanced QC"""
        base_path = self.data_paths[dataset]

        try:
            geno_path = base_path / f"{population}.geno"
            ind_path = base_path / f"{population}.ind"
            snp_path = base_path / f"{population}.snp"

            # Check files exist
            for path in [geno_path, ind_path, snp_path]:
                if not path.exists():
                    print(f"Missing file: {path}")
                    return None

            # Load data with QC
            geno, valid_snps = self._read_geno_file(geno_path)

            # Load and filter SNP data
            if dataset == 'hapmap3':
                snp = pd.read_csv(snp_path, sep='\s+', header=None,
                                names=['id', 'chr', 'genetic_dist', 'position', 'ref', 'alt'])
            else:
                snp = pd.read_csv(snp_path, sep='\s+', header=None,
                                names=['chr', 'position', 'ref', 'alt', 'genetic_dist'])

            snp = snp[valid_snps].reset_index(drop=True)

            # Load individual data
            ind = pd.read_csv(ind_path, sep='\s+', header=None,
                            names=['sample_id', 'sex', 'population'])

            # Perform HWE test
            hwe_pvals = self._perform_hardy_weinberg(geno)
            hwe_mask = hwe_pvals > 1e-6

            # Apply HWE filter
            filtered_geno = geno[hwe_mask]
            filtered_snp = snp[hwe_mask].reset_index(drop=True)

            print(f"Removed {np.sum(~hwe_mask)} SNPs failing HWE")
            print(f"Final data shapes - Genotype: {filtered_geno.shape}, SNPs: {len(filtered_snp)}")

            return {
                'genotype': filtered_geno,
                'individual': ind,
                'snp': filtered_snp
            }

        except Exception as e:
            print(f"Error loading {population} from {dataset}: {str(e)}")
            print("Full traceback:", e.__traceback__)
            return None

    def _perform_hardy_weinberg(self, genotypes):
        """Perform Hardy-Weinberg equilibrium test"""
        p_values = []

        for snp_genotypes in genotypes:
            valid_mask = snp_genotypes != -1
            counts = np.bincount(snp_genotypes[valid_mask], minlength=3)

            # Skip SNPs with insufficient data
            if len(counts) < 3 or sum(counts) < 20:  # minimum sample size threshold
                p_values.append(0)
                continue

            # Calculate expected frequencies
            n = sum(counts)
            p = (2 * counts[2] + counts[1]) / (2 * n)
            q = 1 - p

            exp = np.array([
                n * q**2,  # Expected AA
                n * 2 * p * q,  # Expected AB
                n * p**2   # Expected BB
            ])

            # Avoid division by zero in chi-square test
            if np.any(exp < 1):
                p_values.append(0)
                continue

            # Chi-square test
            chi2, p_value = stats.chisquare(counts, exp)
            p_values.append(p_value)

        return np.array(p_values)
        
    def prepare_joint_analysis(self):
        """Prepare data for joint analysis"""
        print("\nLoading data for all populations...")
        
        for dataset in self.data_paths.keys():
            self.data[dataset] = {}
            for pop in self.shared_populations:
                print(f"\nProcessing {pop} from {dataset}...")
                data = self.load_genetic_data(pop, dataset)
                if data is not None:
                    self.data[dataset][pop] = data
        
        self.common_snps = self.identify_common_snps()
        self.joint_data = self._prepare_joint_matrix()
        print(f"\nFinal joint genotype matrix shape: {self.joint_data['genotype'].shape}")

        
    def identify_common_snps(self):
        """Identify SNPs common to all populations based on position mapping"""
        common_positions = set(self.snp_mapping['position'])
        print(f"Using {len(common_positions)} mapped positions")
        return common_positions

 

    def _prepare_joint_matrix(self):
        """Prepare joint matrix with improved SNP alignment"""
        genotypes = []
        pop_labels = []
        dataset_labels = []
        sample_ids = []

        # Collect and sort positions
        common_positions = self._get_common_positions()
        print(f"Found {len(common_positions)} common positions")
        
        for dataset in self.data.keys():
            for pop, pop_data in self.data[dataset].items():
                if pop_data is None:
                    continue

                print(f"\nProcessing {pop}-{dataset}")
                
                # Get position indices
                pos_indices = [i for i, pos in enumerate(pop_data['snp']['position']) 
                             if pos in common_positions]
                
                geno = pop_data['genotype'][pos_indices]
                snp_info = pop_data['snp'].iloc[pos_indices]
                
                # Handle allele flipping for 1000G
                if dataset == '1000g':
                    for i, (pos, ref, alt) in enumerate(zip(snp_info['position'], 
                                                          snp_info['ref'], 
                                                          snp_info['alt'])):
                        if pos in self.hapmap_to_1000g:
                            ref_kg, alt_kg = self.hapmap_to_1000g[pos]
                            if ref_kg != ref:
                                mask = geno[i] != 1
                                geno[i][mask] = 2 - geno[i][mask]
                
                # Add to joint matrices
                if geno.size > 0:
                    genotypes.append(geno)
                    n_samples = geno.shape[1]
                    pop_labels.extend([pop] * n_samples)
                    dataset_labels.extend([dataset] * n_samples)
                    sample_ids.extend(pop_data['individual']['sample_id'])

        if not genotypes:
            raise ValueError("No valid genotype data after filtering")

        joint_geno = np.vstack([g.T for g in genotypes]).T
        
        return {
            'genotype': joint_geno,
            'population': np.array(pop_labels),
            'dataset': np.array(dataset_labels),
            'sample_id': np.array(sample_ids)
        }

    def _get_common_positions(self):
        """Get positions common to all datasets with QC"""
        positions_per_dataset = {}
        
        for dataset in self.data.keys():
            dataset_positions = set()
            for pop_data in self.data[dataset].values():
                if pop_data is not None:
                    dataset_positions.update(pop_data['snp']['position'])
            positions_per_dataset[dataset] = dataset_positions
        
        return set.intersection(*positions_per_dataset.values())

    def perform_pca(self, n_components=2):
        """Perform PCA with batch effect correction"""
        print("Performing PCA analysis...")

        # Center and scale the data
        genotype_data = self.joint_data['genotype'].copy()
        
        # Improved missing value imputation
        for i in range(genotype_data.shape[0]):
            row = genotype_data[i]
            mask = row == -1
            if np.any(mask):
                # Impute missing values using population-specific means
                for pop in self.shared_populations:
                    pop_mask = (self.joint_data['population'] == pop) & ~mask
                    if np.any(pop_mask):
                        pop_mean = np.mean(row[pop_mask])
                        pop_missing = (self.joint_data['population'] == pop) & mask
                        row[pop_missing] = pop_mean
        
        # Standardize
        scaler = StandardScaler()
        scaled_data = scaler.fit_transform(genotype_data.T)
        
        # Perform PCA
        pca = PCA(n_components=n_components)
        self.pca_results = pca.fit_transform(scaled_data)
        self.explained_variance = pca.explained_variance_ratio_
        
        print(f"Variance explained: {self.explained_variance * 100}")

    def plot_pca(self, save_path=None):
        """Enhanced PCA plot"""
        plt.figure(figsize=(12, 8))
        
        for pop in self.shared_populations:
            for dataset in self.data_paths.keys():
                mask = (self.joint_data['population'] == pop) & \
                      (self.joint_data['dataset'] == dataset)
                
                if np.any(mask):
                    plt.scatter(
                        self.pca_results[mask, 0],
                        self.pca_results[mask, 1],
                        c=self.population_colors[pop],
                        marker='o' if dataset == 'hapmap3' else '^',
                        label=f'{pop}-{dataset}',
                        alpha=0.7,
                        s=50
                    )

        plt.xlabel(f'PC1 ({self.explained_variance[0]:.1%} variance)')
        plt.ylabel(f'PC2 ({self.explained_variance[1]:.1%} variance)')
        plt.title('PCA of HapMap3 and 1000G Populations')
        
        legend_elements = []
        for pop in self.shared_populations:
            legend_elements.extend([
                plt.Line2D([0], [0], marker='o', color='w',
                          markerfacecolor=self.population_colors[pop],
                          label=f'{pop}-HapMap3', markersize=8),
                plt.Line2D([0], [0], marker='^', color='w',
                          markerfacecolor=self.population_colors[pop],
                          label=f'{pop}-1000G', markersize=8)
            ])
        
        plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1),
                  loc='upper left', borderaxespad=0.)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
        
        plt.show()

def run_analysis(snp_mapping_file):
    """Run the complete analysis pipeline"""
    try:
        analyzer = JointPopulationAnalyzer(snp_mapping_file)
        analyzer.prepare_joint_analysis()
        analyzer.perform_pca()
        analyzer.plot_pca("joint_population_pca.png")
        return analyzer
    except Exception as e:
        print(f"Analysis failed: {str(e)}")
        return None

In [64]:
analyzer = run_analysis("../data/common_snps/common_snps.txt")

Loaded 21715 SNP mappings

Loading data for all populations...

Processing ASW from hapmap3...
Reading genotype file: ../data/filtered_hapmap3/ASW.geno
Removed 235 SNPs failing QC
Final genotype matrix shape: (21459, 49)
Removed 1984 SNPs failing HWE
Final data shapes - Genotype: (19475, 49), SNPs: 19475

Processing CEU from hapmap3...
Reading genotype file: ../data/filtered_hapmap3/CEU.geno
Removed 387 SNPs failing QC
Final genotype matrix shape: (21307, 112)
Removed 1027 SNPs failing HWE
Final data shapes - Genotype: (20280, 112), SNPs: 20280

Processing CHB from hapmap3...
Reading genotype file: ../data/filtered_hapmap3/CHB.geno
Removed 190 SNPs failing QC
Final genotype matrix shape: (21504, 84)
Removed 1753 SNPs failing HWE
Final data shapes - Genotype: (19751, 84), SNPs: 19751

Processing JPT from hapmap3...
Reading genotype file: ../data/filtered_hapmap3/JPT.geno
Removed 200 SNPs failing QC
Final genotype matrix shape: (21494, 86)
Removed 1812 SNPs failing HWE
Final data shapes 