Equations & Citation #4

Open
valentinwust opened this issue Jul 8, 2024 · 7 comments
Labels
question Further information is requested

Comments

@valentinwust

Hello,

I'd like to try out the different unmixing methods you implemented (specifically removing some spectrum from the overall fluorescence, which I haven't seen done anywhere else), but I'd rather not work in R or in a specific GUI. It probably isn't that complicated to reimplement this in PyTorch; do you have an equation for what exactly you are doing/optimizing? I would guess the MAPE stuff is from the Generalized Unmixing Model paper. Is there anything other than the git repo I can/should cite if I end up using your method somewhere?

Greetings
Valentin

@exaexa
Owner

exaexa commented Jul 8, 2024

Hi,

I had a paper draft about this somewhere, but it never really got published. I will try to find it ASAP.

To save you some effort, I guess a very simple description might be sufficient; basically:

  • the unmixing can be implemented using any gradient descent that predicts the value of x (the unmixed signal) from y (whatever was measured for the cytometer event), given the spectra S, with three sources of error gradient (sketched in code below):
    • sum(x^2) over all x values that are negative (this penalizes guessing "negative light")
    • the sum of squared residuals, split into positive and negative residual errors with different weights (equal weights would give the usual OLS, but splitting them lets you penalize "overshooting" the channel contributions while still allowing a bit of positive residual noise to be left behind; positive Poisson-like noise is the realistic kind here)
  • as the actual gradient descent algorithm, it is very good to use something that can accelerate quickly in good directions, because there is typically a very long linear valley. We used something like RMSprop, but others should work just as well -- the more momentum, the better. Normally people would start with Adam, but it's a bit too conservative for a relatively simple problem like this.
  • for the "spectrum subtraction" we just run the unmixing with negative residuals ("overshooting") penalized a lot, while there is only a minimal error cost on undershooting. The residual is then the "subtracted" result. This turned out to be quite useful for removing e.g. the very predictable UV noise peaks, making it possible to get "clean" spectra estimates without the UV peak. That helps a lot later in unmixing, because your spectra do not have to "compete" for that omnipresent UV peak.
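
In code, a minimal sketch of that loss (the names x, y, S and the weights nw, rnw, rpw follow the bullets above; they are illustrative, not taken from the actual nougad sources):

```python
import numpy as np

# Minimal sketch of the loss from the bullets above; names are illustrative,
# not from the nougad sources. x: unmixed guess, y: measured event, S: spectra.
def unmix_loss(x, y, S, nw=4.0, rnw=1.0, rpw=2.0):
    r = y - x @ S                                  # residuals in channel space
    loss  = nw  * (np.minimum(x, 0.0) ** 2).sum()  # penalize "negative light"
    loss += rnw * (np.minimum(r, 0.0) ** 2).sum()  # penalize overshooting
    loss += rpw * (np.maximum(r, 0.0) ** 2).sum()  # leftover positive noise
    return loss
```

The spectrum subtraction is then the same loss with rnw set much larger than rpw, keeping r as the result.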

Hope this helps, please ask away in case of any questions :)

@valentinwust
Author

Yes, thank you, that was very helpful. RMSprop was pretty hard to get to converge in a reasonable time, but simple stochastic gradient descent with a bit of momentum works just fine.

You implemented a bunch of other non-OLS unmixing methods for panelbuilder; how commonly used are they? I think I might propose your method as an addition to FlowKit, and if you think someone might be interested in the other ones as well, I would also include them.

@exaexa
Owner

exaexa commented Jul 8, 2024

> RMSprop

I just checked the code at https://github.com/exaexa/nougad/blame/master/src/nougad.cpp#L57 ; the two things of note are:

  • momentum feeds back into itself, so crawling a valley of length X can be done "very optimistically, in log(X) steps"
  • if the sign of the residual flips, the momentum resets in that dimension (sketched below)
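
A purely illustrative sketch of that update rule (this is not the actual nougad.cpp code; `accel` and `lr` are made-up names):

```python
import numpy as np

# Illustrative only -- not the actual nougad.cpp update. Momentum grows
# geometrically while the gradient keeps its sign, and is reset per-dimension
# whenever the sign flips.
def momentum_step(x, grad, prev_grad, mom, lr=0.05, accel=1.5):
    flipped = np.sign(grad) != np.sign(prev_grad)  # did this dimension flip?
    mom = np.where(flipped, 0.0, mom)              # reset momentum there
    mom = accel * mom + lr * grad                  # momentum feeds back into itself
    return x - mom, mom

# In a consistent direction the step size grows by `accel` each iteration,
# which is how a valley of length X gets crawled in roughly log(X) steps.
```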

Citation: for the code, please follow the GPL (but I assume you'll rewrite it, so that's hopefully not an issue). There is no other real citation yet, so feel free to just acknowledge names.

Re FlowKit: that's a nice package, I wish I had more time for this stuff. :) I eventually wanted to also port EmbedSOM to Python (https://github.com/exaexa/EmbedSOM); it's super nice for quick-checking various intermediate results, so I guess it could help here.

@exaexa exaexa closed this as completed Jul 8, 2024
@exaexa exaexa reopened this Jul 8, 2024
@exaexa
Owner

exaexa commented Jul 8, 2024

(misclick)

@exaexa
Owner

exaexa commented Jul 8, 2024

I updated the misclicked comment. In the meantime, please feel free to leave this issue open for tracking and questions.

@exaexa exaexa added the question Further information is requested label Jul 8, 2024
@valentinwust
Author

valentinwust commented Jul 8, 2024

I'll be honest, I never looked at optimizers in enough detail to understand your implementation without putting more time into it. But I think my implementation works well enough that I'm happy for now; I'll see how well it works in practice.

Just in case I end up forgetting about this and someone finds it in the future: this code should basically reproduce the C/R implementation using PyTorch:

```python
import numpy as np
import torch

def nougad_unmix(x, spectra, device="cpu", verbose=False, **kwargs):
    """ Unmix data using nougad.
        Initializes with OLS solution for faster convergence.
        x: (events, channels)
        spectra: (markers, channels)
    """
    ngd = Nougad(x, spectra, rnw=1, rpw=2, nw=4, device=device, verbose=verbose, lr=2e-2, init="ols", **kwargs)
    if verbose: ngd.plot_loss()
    return ngd.unmixed

def nougad_remove(x, spectra, device="cpu", verbose=False, **kwargs):
    """ Remove spectra from data using nougad.
        Seems to tolerate a higher learning rate than unmixing.
        Initialize as 0., OLS doesn't necessarily make sense here.
        x: (events, channels)
        spectra: (markers, channels)
    """
    ngd = Nougad(x, spectra, rnw=10., rpw=.1, nw=1., device=device, verbose=verbose, lr=1e-1, init="null", **kwargs)
    if verbose: ngd.plot_loss()
    return ngd.residuals


def _extend(v, N):
    """ Broadcast a scalar to length N, or validate an iterable of length N. """
    v = np.asarray(v, dtype=float).reshape(-1)
    if   v.shape[0]==1: return v.repeat(N)
    elif v.shape[0]==N: return v
    else: raise ValueError(f"Cannot cast {v} as array of length {N}!")

class Nougad:
    """ Non-linear unmixing by gradient descent
Reimplementation of https://github.com/exaexa/nougad, possibly up to the orientation of the spillover matrix.
        
        Optimizes unmixed values with three losses:
        - negative residuals after unmixing, weighted by rnw
        - positive residuals after unmixing, weighted by rpw
        - negative unmixed values to punish negative intensity, weighted by nw
        
        (rnw, rpw, nw) = (1., 1., 0.) reproduces standard OLS compensation; for numerical stability the implementation defaults to plain OLS in that case.
        (1., 2., 4.) produces standard nougad unmixing.
        (10., .1, 1.) removes spectra.
        
        Weights can be set globally or separately by marker/channel.
        
        Can be either initialised with zeros, or with the OLS solution.
        
        Spectral not tested yet!!!
    """
    def __init__(self,
                 x, # (events, channels), raw mixed data
                 spill, # (markers, channels), spillover matrix, oriented such that unmixed @ spill = raw
                 
                 rnw=1., # weight of negative residuals, float or iterable
                 rpw=1., # weight of positive residuals, float or iterable
                 nw =0., # weight of negative unmixed values, float or iterable
                 
                 init="ols", # initialisation, choose between 'null' for 0., and 'ols' for OLS solution
                 lr=1e-2, # learning rate for SGD
                 momentum=.5, # SGD momentum
                 maxiters=1000, # max iterations
                 opt_tolerance=1e-6, # if loss doesn't improve stop iterations, set to zero to disable
                 device="cpu", # device on which to run the computations, either 'cpu' or CUDA device
                 verbose=False,
                ):
        self.x = np.asarray(x)
        self.spill = np.asarray(spill)
        self.Nevents, self.Nmarker, self.Nchannel = self.x.shape[0], self.spill.shape[0], self.spill.shape[1]
        
        self.rnw, self.rpw, self.nw = _extend(rnw, self.Nchannel), _extend(rpw, self.Nchannel), _extend(nw, self.Nmarker)
        self.is_OLS = np.all([np.isclose(self.rnw, 1.).all(),
                              np.isclose(self.rpw, 1.).all(),
                              np.isclose(self.nw,  0.).all()]) # if basically OLS
        
        self.init = init
        self.sgd_lr, self.sgd_momentum, self.sgd_maxiters, self.opt_tolerance = lr, momentum, maxiters, opt_tolerance
        self.device = device
        self.verbose = verbose
        
        self.add_inverted_spill()
        self.do_unmixing()
    
    def add_inverted_spill(self):
        """ Add inverted spillover, if spectral do (M^T * M)^(-1) * M^T instead of direct inverse
        """
        if self.spill.shape[0] == self.spill.shape[1]: # regular flow
            self.spill_inv = np.linalg.inv(self.spill)
        elif self.spill.shape[1] > self.spill.shape[0]: # spectral flow
            self.spill_inv = (np.linalg.inv(self.spill @ self.spill.T) @ self.spill).T
        else: # not invertible!
            raise ValueError(f"Passed spillover matrix has {self.spill.shape[0]} markers but only {self.spill.shape[1]} channels, not invertible!")
        
    def do_unmixing(self):
        """ Do unmixing with gradient descent.
        """
        if self.is_OLS:
            self.unmixed = self.x @ self.spill_inv
            self.residuals = self.x - self.unmixed @ self.spill # zero for square spillover, nonzero for spectral
        else:
            if self.init=="null":  unmx_torch = torch.zeros( (self.Nevents, self.Nmarker), dtype=torch.float32, device=self.device, requires_grad=True)
            elif self.init=="ols": unmx_torch = torch.tensor( self.x @ self.spill_inv,     dtype=torch.float32, device=self.device, requires_grad=True)
            else: raise ValueError(f"Initialisation method '{self.init}' for nougad is not known! Choose between 'null' and 'ols'.")
            
            if self.verbose:
                from tqdm import tqdm
                fct_iter = lambda x: tqdm(x)
            else:
                fct_iter = lambda x: x
            
            self.losshistory = []
            optimizer = torch.optim.SGD([unmx_torch], lr=self.sgd_lr, momentum=self.sgd_momentum)
            to_tensor = lambda x: torch.tensor(x, dtype=torch.float32, device=self.device)
            x_torch = to_tensor(self.x)
            spill_torch = to_tensor(self.spill)
            rnw, rpw, nw = to_tensor(self.rnw)[None].sqrt(), to_tensor(self.rpw)[None].sqrt(), to_tensor(self.nw)[None].sqrt()
            
            for it in fct_iter(range(self.sgd_maxiters)):
                # A few more multiplications than necessary for the weights, probably possible to improve this
                optimizer.zero_grad()
                loss = ( nw * unmx_torch)[unmx_torch<0].square().sum()
                residuals_torch = x_torch - unmx_torch @ spill_torch
                loss += ( rnw * residuals_torch )[residuals_torch<0].square().sum()
                loss += ( rpw * residuals_torch )[residuals_torch>0].square().sum()
                loss.backward()
                optimizer.step()
                self.losshistory.append(loss.item())
                
                if it > 5:
                    crit = np.abs((self.losshistory[-2] - self.losshistory[-1]) / self.losshistory[-2])
                    if crit < self.opt_tolerance:
                        break
    
            residuals_torch = x_torch - unmx_torch @ spill_torch
            self.unmixed = unmx_torch.detach().cpu().numpy()
            self.residuals = residuals_torch.detach().cpu().numpy()
    
    def plot_loss(self):
        """ Plot loss over iterations
        """
        import matplotlib.pyplot as plt
        
        plt.plot(self.losshistory, zorder=10)
        plt.yscale("log")
        plt.axhline(self.losshistory[-1], color="black")
        plt.xlabel("Iteration")
        plt.ylabel("Total Loss")
        plt.show()
```
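
For completeness, a hypothetical call with random arrays, just to show the expected shapes (the data itself is meaningless):

```python
import numpy as np

# Shapes-only illustration; the random data carries no real signal.
raw     = np.random.rand(1000, 30)   # (events, channels)
spectra = np.random.rand(10, 30)     # (markers, channels)

unmixed = nougad_unmix(raw, spectra)       # -> (events, markers)
cleaned = nougad_remove(raw, spectra[:1])  # -> (events, channels), with the
                                           #    first spectrum subtracted
```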

@exaexa
Owner

exaexa commented Jul 8, 2024

That all looks right to me; the main component is the loss function, which is precisely what you have there. The C implementation I have is a little bit over-optimized (it auto-vectorizes to SSE on most compilers), which really doesn't help with readability 😅
