In [None]:
from sklearn.mixture import GaussianMixture
import torch

class NoiseModeling:
    """Estimate how likely each label is clean and relabel noisy samples.

    The estimate combines two signals:

    * a two-component Gaussian mixture fitted, per class, on the per-sample
      cross-entropy loss (the low-mean component is taken as "clean");
    * the entropy of the Monte Carlo Dropout averaged prediction.

    Args:
        model: Callable ``torch.nn.Module`` mapping inputs to logits whose
            last dimension is presumably ``num_classes`` (required by the
            one-hot blending in ``update_labels``).
        num_classes: Number of label classes.
        num_mc_samples: Number of stochastic forward passes used for
            Monte Carlo Dropout.
        uncertainty_weight: Exponent ``w`` of the geometric blend
            ``confidence ** w * gmm_clean_prob ** (1 - w)``.
    """

    def __init__(self, model, num_classes, num_mc_samples=10, uncertainty_weight=0.1):
        self.model = model
        self.num_classes = num_classes
        self.num_mc_samples = num_mc_samples
        self.uncertainty_weight = uncertainty_weight

    def _snapshot_modes(self):
        """Record the train/eval flag of the model and all its sub-modules."""
        return [(module, module.training) for module in self.model.modules()]

    @staticmethod
    def _restore_modes(snapshot):
        """Put every module back into the mode captured by ``_snapshot_modes``."""
        for module, was_training in snapshot:
            module.training = was_training

    def compute_epistemic_uncertainty(self, x):
        """Run Monte Carlo Dropout and return the mean prediction and its entropy.

        The model is switched to train mode for the stochastic passes and is
        put back into whatever mode each sub-module had before the call.

        Args:
            x: Input batch accepted by ``self.model``.

        Returns:
            ``(mean_prediction, entropy)`` where ``mean_prediction`` is the
            softmax output averaged over ``num_mc_samples`` passes and
            ``entropy`` is the (unnormalised) entropy of that average over the
            last dimension, i.e. a value in ``[0, log(num_classes)]``.
        """
        snapshot = self._snapshot_modes()
        # train() is what activates dropout; note it also puts layers such as
        # batch norm into training behaviour for the duration of the passes.
        self.model.train()
        try:
            predictions = [
                torch.nn.functional.softmax(self.model(x), dim=-1)
                for _ in range(self.num_mc_samples)
            ]
        finally:
            # Never leak train mode to the caller (e.g. during evaluation).
            self._restore_modes(snapshot)

        mean_prediction = torch.stack(predictions).mean(dim=0)
        # The small epsilon guards against log(0) for saturated predictions.
        epistemic_uncertainty = -torch.sum(
            mean_prediction * torch.log(mean_prediction + 1e-10), dim=-1
        )
        return mean_prediction, epistemic_uncertainty

    def fit_gmm(self, loss_values):
        """Return the posterior probability that each loss is from the clean mode.

        Args:
            loss_values: Tensor of per-sample losses (any shape; flattened).

        Returns:
            Float32 1-D tensor on ``loss_values.device`` with one probability
            per loss value.
        """
        num_values = loss_values.numel()
        if num_values < 2:
            # A 2-component mixture cannot be fitted on fewer than 2 samples.
            # With no evidence of noise, keep the given label (probability 1).
            return torch.ones(num_values, dtype=torch.float32, device=loss_values.device)

        # sklearn expects a (n_samples, n_features) NumPy array on the CPU.
        loss_values_np = loss_values.cpu().detach().numpy().reshape(-1, 1)

        gmm = GaussianMixture(n_components=2, max_iter=100, random_state=0)
        gmm.fit(loss_values_np)
        gmm_probs = gmm.predict_proba(loss_values_np)

        # The component with the lower mean loss models the clean labels.
        clean_component = gmm.means_.argmin()
        clean_prob = gmm_probs[:, clean_component]

        return torch.from_numpy(clean_prob).float().to(loss_values.device)

    def compute_clean_probability(self, loss_values, epistemic_uncertainty, class_labels):
        """Blend the per-class GMM clean probability with model confidence.

        Args:
            loss_values: 1-D tensor of per-sample losses.
            epistemic_uncertainty: 1-D tensor of per-sample entropies, as
                returned by ``compute_epistemic_uncertainty``.
            class_labels: 1-D integer tensor of the (possibly noisy) labels.

        Returns:
            1-D tensor aligned with the inputs: element ``i`` is the clean
            probability of sample ``i``. Samples whose label is outside
            ``range(num_classes)`` keep probability 0.
        """
        import math

        # Entropy ranges over [0, log(C)]; rescale to [0, 1] so that
        # ``1 - uncertainty`` is a valid base for a fractional power
        # (a negative base would produce NaN).
        max_entropy = math.log(self.num_classes) if self.num_classes > 1 else 0.0
        if max_entropy > 0.0:
            confidence = (1 - epistemic_uncertainty / max_entropy).clamp(0.0, 1.0)
        else:
            confidence = torch.ones_like(epistemic_uncertainty)

        result_dtype = torch.promote_types(epistemic_uncertainty.dtype, torch.float32)
        clean_probs = torch.zeros(
            loss_values.shape, dtype=result_dtype, device=loss_values.device
        )

        for c in range(self.num_classes):
            class_mask = class_labels == c
            if not bool(class_mask.any()):
                continue

            gmm_clean_prob = self.fit_gmm(loss_values[class_mask])
            combined = (
                confidence[class_mask] ** self.uncertainty_weight
                * gmm_clean_prob ** (1 - self.uncertainty_weight)
            )
            # Write back through the mask so results stay in sample order
            # instead of being regrouped by class.
            clean_probs[class_mask] = combined.to(result_dtype)

        return clean_probs

    def update_labels(self, x, labels):
        """Return refined hard labels for a batch.

        Each sample's target is a mix of its given one-hot label and the
        MC-Dropout mean prediction, weighted by the sample's clean
        probability; the arg-max of that mix is the new label.

        Args:
            x: Input batch accepted by ``self.model``.
            labels: 1-D integer tensor of the current labels.

        Returns:
            1-D integer tensor of updated labels, aligned with ``labels``.
        """
        functional = torch.nn.functional
        snapshot = self._snapshot_modes()
        try:
            # Only an arg-max is returned, so no autograd graph is needed.
            with torch.no_grad():
                self.model.eval()
                logits = self.model(x)
                loss_values = functional.cross_entropy(logits, labels, reduction='none')

                mean_prediction, epistemic_uncertainty = self.compute_epistemic_uncertainty(x)

                clean_probs = self.compute_clean_probability(
                    loss_values, epistemic_uncertainty, labels
                )

                keep_weight = clean_probs.unsqueeze(-1)
                given_targets = functional.one_hot(
                    labels, num_classes=self.num_classes
                ).to(mean_prediction.dtype)
                soft_targets = (
                    keep_weight * given_targets + (1 - keep_weight) * mean_prediction
                )
                return soft_targets.argmax(dim=-1)
        finally:
            # Leave the model in the mode the caller had it in.
            self._restore_modes(snapshot)
