In [1]:
import torch 
import anndata as ad
import pandas as pd
import numpy as np
import tqdm
from sklearn.metrics import r2_score

In [2]:
seed = 16
use_gpu = True

np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fb960c30950>

In [5]:
adata = ad.read_h5ad('..//output/preprocess/bulk_adata_f.h5ad')
df = pd.DataFrame(adata.X, index=pd.MultiIndex.from_frame(adata.obs[['cell_type', 'sm_name', 'plate_name']]), columns=adata.var_names)
data = df.values

In [23]:
class NN(torch.nn.Module):
    def __init__(self, n_genes:int, n_nodes_latent:int=16):
        torch.nn.Module.__init__(self)
        self.n_genes = n_genes
        dropout_p = 0.2

        bias = False
        nonlinear = torch.nn.PReLU()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(n_genes, 64),
            nonlinear,
            torch.nn.Dropout(dropout_p),

            torch.nn.Linear(64, n_nodes_latent*2, bias=bias),
            nonlinear
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(n_nodes_latent, n_genes, bias=bias)
        )
        for module in self.modules():
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight, nonlinearity='relu', mode='fan_in')
                if module.bias is not None:
                    torch.nn.init.constant_(module.bias, 0.001)

    def reparametrize(self, mu, log_var):
        # max_log_var = 100
        # log_var = torch.clamp(log_var, max=-max_log_var)
        # print(log_var)
        # aaa
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        # print(std)
        # return mu + std*eps
        return mu + log_var*eps

    def forward(self, x: torch.Tensor):
        x = self.encoder(x)
        mu, log_var = torch.chunk(x,2, dim=1)
        x = self.reparametrize(mu, log_var)
        x = self.decoder(x)
        return x
#TODO: dataloader = DataLoader(train_data, batch_size=self.opt.batch_size, shuffle=True, num_workers=1)
#TODO: how to account for batch effects        
def train(train_data: np.array, batch_size = 400, n_epoch = 40):
    # train_data = (train_data - train_data.mean(axis=0))/(train_data.std(axis=0) +1e-8) #TODO: take care of this norm
    # train_data = np.log10(train_data)
    # print(train_data)
    # aa

    # torch data
    train_data = torch.FloatTensor(train_data)
    if use_gpu:
        train_data = train_data.cuda()
    
    model = NN(n_genes=train_data.shape[1])
    if use_gpu:
        model = model.cuda()
    #TODO: loss function. cross entropy loss? other types  
    criterion = lambda Y_true, Y_pred: torch.mean(torch.abs(Y_true-Y_pred))
    # optimizer
    # optimizer = torch.optim.SGD(model.parameters(), lr=.00001, momentum=.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-7)
    
    # scheduler = torch.lr_scheduler.StepLR(optimizer, step_size=10, gamma=.1) 
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, 
                patience=5, threshold_mode='rel', threshold=0.0001, cooldown=5, min_lr=1e-10, eps=1e-8)

    pbar = tqdm.tqdm(range(n_epoch))
    for i_epoch in pbar:
        # shuffle the data
        train_data = train_data[torch.randperm(train_data.size(dim=0))]
        # train for a epoch
        model.train()
        rel_loss_store = []
        Y_pred_stack = []
        Y_true_stack = []
        for X in np.array_split(train_data, batch_size, axis=0):
            optimizer.zero_grad()
            
            Y_true = X #TODO: add noise to the data
            Y_pred = model.forward(X)
            
            Y_true_stack.append(Y_true.cpu().detach().numpy())
            Y_pred_stack.append(Y_pred.cpu().detach().numpy())
            # calculate loss and backpropagate 
            loss = criterion(Y_true, Y_pred)
            loss.backward()
            optimizer.step()

            # baseline pred
            Y_pred_mean = torch.mean(Y_true, axis=0)
            loss_baseline = criterion(Y_true, Y_pred_mean)
            rel_loss = loss/loss_baseline
            rel_loss_store.append(rel_loss.item())

        mean_rel_loss = np.mean(rel_loss_store)
        scheduler.step(mean_rel_loss)

        y_pred = np.concatenate(Y_pred_stack, axis=0)
        y_true = np.concatenate(Y_true_stack, axis=0)

        r2 = r2_score(y_true, y_pred, multioutput='variance_weighted')

        pbar.set_description(f'Rel loss: {mean_rel_loss:.3f}, R2:{r2:.3f}')
    return model
            
model = train(data, batch_size = 400, n_epoch = 80)

Rel loss: 1.524, R2:-0.978:   4%|▍         | 3/80 [00:05<02:27,  1.92s/it]


KeyboardInterrupt: 

1.546

In [10]:
std = 10
np.log10(10**2) -> 2*log10 -> 

2.0