In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from scipy.stats import norm

"""
Gaussian Mixture Model (GMM) using PyTorch
This code implements a simple Gaussian Mixture Model (GMM) using PyTorch.
It includes methods for fitting the model to data and calculating the likelihood of new data points.
"""

class GaussianMixture(nn.Module):
    """Univariate Gaussian mixture model fitted with expectation-maximisation (EM).

    The component parameters are stored as ``nn.Parameter`` tensors of shape
    ``(n_components,)`` but are updated in closed form by EM, not by gradient
    descent.

    Args:
        n_components: Number of Gaussian components in the mixture.

    Attributes:
        means, stds, weights: Per-component parameters, each of shape
            ``(n_components,)``.
        best_likelihood: Highest total log-likelihood seen by the last ``fit``
            call (``-inf`` before fitting).
        best_params: Dict of cloned ``'means'``/``'stds'``/``'weights'`` tensors
            that achieved ``best_likelihood`` (``None`` before fitting).
    """

    def __init__(self, n_components):
        super().__init__()
        self.n_components = n_components

        # Initial values; fit() re-randomises the means on every restart.
        self.means = nn.Parameter(torch.linspace(0.1, 0.9, n_components))
        self.stds = nn.Parameter(torch.ones(n_components) * 0.1)  # Smaller initial std
        self.weights = nn.Parameter(torch.ones(n_components) / n_components)
        self.best_likelihood = -float('inf')
        self.best_params = None

    @staticmethod
    def _as_sample_tensor(data):
        """Convert array-like ``data`` to a flat float32 tensor of samples.

        The model is univariate, so any input shape (e.g. an ``(n, 1)`` column)
        is flattened to ``(n,)``. This keeps broadcasting against the
        ``(n_components,)`` parameters correct.
        """
        array = np.ascontiguousarray(data, dtype=np.float32)
        return torch.from_numpy(array).reshape(-1)

    def _compute_responsibilities(self, x):
        """E-step: posterior component probabilities and total log-likelihood.

        Works in log space (``logsumexp``) so points far from every component
        do not underflow to zero density.

        Args:
            x: Tensor of samples; flattened to ``(n_samples,)`` internally.

        Returns:
            ``(responsibilities, log_likelihood)``. ``responsibilities`` has
            shape ``(n_samples, n_components)`` with rows summing to 1.
            ``log_likelihood`` is the scalar sum of per-sample log densities.
        """
        # Column of samples, shape (n_samples, 1), broadcasts against (n_components,).
        x = x.reshape(-1, 1)

        # log(w_k * N(x | mu_k, sigma_k)) for every sample/component pair.
        z = (x - self.means) / self.stds
        log_weighted = (
            torch.log(self.weights)
            - 0.5 * z ** 2
            - torch.log(self.stds)
            - 0.5 * float(np.log(2.0 * np.pi))
        )

        # log p(x) is the logsumexp over components; responsibilities are normalised ratios.
        log_total = torch.logsumexp(log_weighted, dim=1, keepdim=True)
        responsibilities = torch.exp(log_weighted - log_total)
        return responsibilities, log_total.sum()

    def _m_step(self, x, responsibilities):
        """M-step: closed-form updates of means, stds and weights.

        Args:
            x: Tensor of samples; flattened to ``(n_samples,)`` internally.
            responsibilities: ``(n_samples, n_components)`` output of the E-step.
        """
        x = x.reshape(-1, 1)
        total_resp = responsibilities.sum(0)  # effective sample count per component
        denom = total_resp + 1e-10  # guards components that received no responsibility

        means = (responsibilities * x).sum(0) / denom
        variance = (responsibilities * (x - means) ** 2).sum(0) / denom

        self.means.data = means
        # Small floor keeps a collapsed component's std strictly positive.
        self.stds.data = torch.sqrt(variance + 1e-10)
        self.weights.data = total_resp / x.shape[0]

    def _record_if_best(self, log_likelihood, restart, iteration, verbose):
        """Snapshot the current parameters if ``log_likelihood`` beats the best so far."""
        if log_likelihood > self.best_likelihood:
            self.best_likelihood = log_likelihood
            self.best_params = {
                'means': self.means.data.clone(),
                'stds': self.stds.data.clone(),
                'weights': self.weights.data.clone(),
            }
            if verbose:
                print(f'Restart {restart}, Iteration {iteration}: Log Likelihood {log_likelihood:.4f}')

    def fit(self, data, n_iterations=100, n_restarts=10, tol=1e-6, verbose=True):
        """Fit the mixture with EM, keeping the best of several random restarts.

        Each restart works as follows:

        - Draw the means uniformly from [0.1, 0.9].
        - Reset the stds to 0.1 and the weights to uniform.
        - Alternate E- and M-steps until the log-likelihood changes by less
          than ``tol``, or until ``n_iterations`` iterations have run.

        The parameters with the highest log-likelihood across all restarts are
        stored in ``best_params`` and loaded back into the model.

        Args:
            data: Array-like of samples; flattened to 1-D, so an ``(n, 1)``
                column is accepted.
            n_iterations: Maximum number of EM iterations per restart.
            n_restarts: Number of random initialisations (must be >= 1).
            tol: Absolute log-likelihood change treated as convergence.
            verbose: If True, print each new best log-likelihood.

        Returns:
            self, for chaining.

        Raises:
            ValueError: If ``data`` is empty or ``n_restarts`` < 1.
            RuntimeError: If no restart produced a finite log-likelihood
                (e.g. NaN values in ``data``).
        """
        if n_restarts < 1:
            raise ValueError(f"n_restarts must be >= 1, got {n_restarts}")
        x = self._as_sample_tensor(data)
        if x.numel() == 0:
            raise ValueError("cannot fit a GaussianMixture to empty data")

        self.best_likelihood = -float('inf')
        self.best_params = None

        # EM updates are closed form, so no autograd graph is needed.
        with torch.no_grad():
            for restart in range(n_restarts):
                # Random initialization
                self.means.data = torch.empty(self.n_components).uniform_(0.1, 0.9)
                self.stds.data = torch.full((self.n_components,), 0.1)
                self.weights.data = torch.full((self.n_components,), 1.0 / self.n_components)

                prev_likelihood = -float('inf')
                for iteration in range(n_iterations):
                    # The E-step scores the *current* parameters, so snapshot them
                    # before the M-step moves them.
                    responsibilities, log_likelihood = self._compute_responsibilities(x)
                    log_likelihood = log_likelihood.item()
                    self._record_if_best(log_likelihood, restart, iteration, verbose)

                    if abs(log_likelihood - prev_likelihood) < tol:
                        break
                    prev_likelihood = log_likelihood

                    self._m_step(x, responsibilities)
                else:
                    # Iteration budget exhausted: also score the parameters from the last M-step.
                    _, log_likelihood = self._compute_responsibilities(x)
                    self._record_if_best(log_likelihood.item(), restart, n_iterations, verbose)

            if self.best_params is None:
                raise RuntimeError(
                    "EM produced no finite log-likelihood; check the data for NaN/inf values"
                )

            # Restore best parameters (cloned so later updates don't alias best_params).
            self.means.data = self.best_params['means'].clone()
            self.stds.data = self.best_params['stds'].clone()
            self.weights.data = self.best_params['weights'].clone()

        return self

    def score(self, data):
        """Return the total log-likelihood of ``data`` under the current parameters.

        Args:
            data: Array-like of samples; flattened to 1-D.

        Returns:
            Sum of per-sample log densities as a Python float.
        """
        with torch.no_grad():
            _, log_likelihood = self._compute_responsibilities(self._as_sample_tensor(data))
        return log_likelihood.item()




def fit_gmm_to_ab(ind_dat, n_components, **fit_kwargs):
    """
    Fit a Gaussian Mixture Model (GMM) to one individual's allele balance data.

    Parameters:
        ind_dat (array-like): Allele balance values. Any shape is accepted;
            values are flattened to 1-D because the GMM is univariate.
        n_components (int): Number of components in the GMM.
        **fit_kwargs: Extra keyword arguments forwarded to ``GaussianMixture.fit``
            (e.g. ``n_iterations``, ``n_restarts``, ``tol``).

    Returns:
        GaussianMixture: The fitted model, with its best parameters loaded.
    """
    # GaussianMixture models univariate data with 1-D per-component parameters.
    # Pass a flat vector (not an (n, 1) column) so those parameters broadcast
    # against a single sample axis.
    dat = np.asarray(ind_dat, dtype=float).ravel()

    gmm = GaussianMixture(n_components=n_components)
    gmm.fit(dat, **fit_kwargs)

    print(f'Best likelihood: {gmm.best_likelihood}')
    print("Fitted GMM parameters:")
    print(f"Means: {gmm.best_params['means'].numpy()}")
    print(f"Std devs: {gmm.best_params['stds'].numpy()}")
    print(f"Weights: {gmm.best_params['weights'].numpy()}")
    return gmm

def plot_gmm_fit(data, gmm, n_points=1000, title="GMM Fit to Data"):
    """
    Plot a histogram of observed data with the fitted GMM components overlaid.

    Parameters:
        data (array-like): Original data used to fit the GMM; flattened to 1-D.
        gmm (GaussianMixture): Fitted GMM model; its current means, stds and
            weights are plotted.
        n_points (int): Number of x positions at which the GMM curves are evaluated.
        title (str): Plot title.
    """
    values = np.asarray(data, dtype=float).ravel()

    plt.figure(figsize=(10, 6))
    plt.hist(values, bins=50, density=True, alpha=0.5, label='Observed Data')

    # Evaluate the densities over the observed range. Using ndarray min/max on the
    # flattened values (not the builtins) yields scalars even for 2-D input.
    x = np.linspace(values.min(), values.max(), n_points)

    means = gmm.means.detach().cpu().numpy().ravel()
    stds = gmm.stds.detach().cpu().numpy().ravel()
    weights = gmm.weights.detach().cpu().numpy().ravel()

    # Each weighted component is computed once and reused for the total mixture curve.
    total = np.zeros_like(x)
    for i, (mu, sigma, weight) in enumerate(zip(means, stds, weights)):
        component = weight * norm.pdf(x, mu, sigma)
        total += component
        plt.plot(x, component, '--', label=f'Component {i+1}')

    plt.plot(x, total, 'r-', label='Total Mixture', linewidth=2)

    plt.xlabel('Value')
    plt.ylabel('Density')
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()



# Demo: simulate allele balance for several samples drawn from three clusters
# (means 0.25 / 0.5 / 0.75), mask sites at random, then fit a 3-component GMM
# to each sample's unmasked values.
PLOT_FITS = False  # set True to draw the histogram and fitted mixture per sample

if __name__ == "__main__":
    rng = np.random.default_rng(0)  # fixed seeds so the demo is reproducible
    torch.manual_seed(0)

    n_samples = 4
    ab_left = rng.normal(0.25, 0.05, (500, n_samples))
    ab_middle = rng.normal(0.5, 0.05, (1000, n_samples))
    ab_right = rng.normal(0.75, 0.05, (500, n_samples))
    allele_balance_array = np.concatenate([ab_left, ab_middle, ab_right], axis=0)

    # Random 0/1 mask with the same shape as the data; entries equal to 1 are kept.
    allele_mask_array = rng.integers(0, 2, size=allele_balance_array.shape)

    # Layout (2, n_sites, n_samples): index 0 holds values, index 1 holds the mask.
    ab_dat = np.array([allele_balance_array, allele_mask_array])

    for i in range(ab_dat.shape[2]):
        print(i)
        ind_dat = ab_dat[0, :, i]
        ind_mask = ab_dat[1, :, i] == 1
        ind_dat_filtered = ind_dat[ind_mask]
        print(f"{ind_dat_filtered.size} unmasked sites")
        gmm = fit_gmm_to_ab(ind_dat_filtered, 3)

        if PLOT_FITS:
            plot_gmm_fit(ind_dat_filtered, gmm, title="Allele Balance Distribution")

0
[0.18263036 0.25813817 0.20526038 0.16640501 0.28496844 0.23090735
 0.27235816 0.25569745 0.28030261 0.24374318 0.2353714  0.23131575
 0.2294658  0.32187165 0.28220606 0.2862367  0.31185006 0.26646425
 0.23607619 0.16835915 0.32641967 0.25359488 0.28062352 0.31307439
 0.28039707 0.26958186 0.26095936 0.23410695 0.29336549 0.16959949
 0.27387514 0.21280568 0.2745401  0.32670125 0.27215285 0.21648054
 0.24690275 0.30591486 0.20552612 0.20843385 0.15490108 0.23183242
 0.27679358 0.23566815 0.31029912 0.29406207 0.22631564 0.18612553
 0.26180641 0.33078365 0.25963182 0.18064332 0.25896542 0.18038975
 0.18818787 0.20631614 0.31742014 0.20682184 0.23920743 0.23116783
 0.29400299 0.26181964 0.30192497 0.2609838  0.21375674 0.26953073
 0.24546906 0.29384058 0.28723809 0.25781153 0.28119496 0.22336953
 0.18457302 0.20697908 0.18569348 0.14854803 0.33786253 0.19476952
 0.22465693 0.36541527 0.31319896 0.34188456 0.23112034 0.19298583
 0.26667152 0.27340631 0.28642749 0.28806055 0.26894791 0.23