In [1]:
# libs
import argparse
import logging
import math
import os
import sys

import numpy as np
import pandas as pd
import scanpy as sc

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# PyTorch build this notebook was executed with (shown as the cell output).
str(torch.__version__)

'1.8.1'

In [3]:
sys.path.insert(0, '../../bin')
from helpers import Load_Dataset
from RNA_functions import celltype_function
from trainer import train

In [4]:
# Report what GPU hardware PyTorch can see. These are INFO-level messages, so
# they only appear if logging has been configured at INFO (or lower).
is_cuda = torch.cuda.is_available()
logging.info(f"Cuda available: {is_cuda}")
if is_cuda:
    current_device = torch.cuda.current_device()
    device_count = torch.cuda.device_count()
    device_name = torch.cuda.get_device_name(current_device)
    for prefix, value in (("Cuda device count: ", device_count),
                          ("Cuda device name: ", device_name)):
        logging.info(prefix + str(value))

2021-10-23 23:43:09 INFO     Cuda available: True
2021-10-23 23:43:09 INFO     Cuda device count: 4
2021-10-23 23:43:09 INFO     Cuda device name: Tesla V100-SXM2-32GB


In [5]:
def celltype_function(RNA_adata):
    """
    Map each cell's ``cell_type`` to an integer class id (adapted from Chris's code).

    Ids are positions in the index produced by ``obs.groupby("cell_type")``:
    0 for the first cell type in that order, 1 for the next, and so on.

    The returned array follows the row order of ``RNA_adata.obs``. The earlier
    version merged the per-type table back onto ``obs`` with an inner join.
    On older pandas (before 2.2) that join regroups rows by key, so labels
    came back in a different order from the cells and ``obs`` was reordered
    relative to ``X``. It also broke on a second call, because the merge then
    produced suffixed ``label_id_x``/``label_id_y`` columns.

    Side effect: writes a ``count`` column (number of cells of that type) and a
    ``label_id`` column into ``RNA_adata.obs``. Row order and index are not changed.

    Parameters
    ----------
    RNA_adata : AnnData-like
        Object whose ``.obs`` DataFrame has a ``cell_type`` column.

    Returns
    -------
    numpy.ndarray of int64
        One label per row of ``RNA_adata.obs``.

    Raises
    ------
    ValueError
        If some cells have a missing ``cell_type``. groupby drops those, so
        they have no id; the old inner join silently dropped the rows instead.
    """
    obs = RNA_adata.obs
    ct_counts = obs.groupby("cell_type").size()

    # A cell's label id is the position of its type in ct_counts.index;
    # get_indexer returns -1 for values not present (e.g. missing types).
    label_ids = ct_counts.index.get_indexer(obs["cell_type"])
    n_missing = int((label_ids < 0).sum())
    if n_missing:
        raise ValueError(
            f"{n_missing} cell(s) have no cell_type; cannot assign labels")

    # Plain column assignment keeps rows aligned and is idempotent on re-run.
    obs["count"] = ct_counts.to_numpy()[label_ids]
    obs["label_id"] = label_ids
    return label_ids.astype(np.int64)

In [6]:
# Directory holding the NeurIPS 2021 multiome training files. Set the
# MULTIOME_DATA_DIR environment variable to point elsewhere without editing the notebook.
MULTIOME_DATA_DIR = os.environ.get(
    'MULTIOME_DATA_DIR',
    '/camp/lab/briscoej/working/Rory/transcriptomics/NeurIPS_2021/nips_2021/multiome',
)
atac_path = os.path.join(MULTIOME_DATA_DIR, 'multiome_atac_processed_training.h5ad')
rna_path = os.path.join(MULTIOME_DATA_DIR, 'multiome_gex_processed_training.h5ad')

In [8]:
# class Net(nn.Module):
#     def __init__(self, dataset):
#         super().__init__()
#         self.embed1 = nn.Linear(dataset.n_peaks, 100)
#         self.fc1 = nn.Linear(100, 200)
#         self.fc2 = nn.Linear(200, 100)
#         self.fc3 = nn.Linear(100, dataset.n_labels)

#     def forward(self, x):
#         x = F.relu(self.embed1(x))
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x

In [21]:


class ConvNet(nn.Module):
    """
    Small 1-D convolutional classifier over ATAC peak vectors.

    Pipeline: Linear embedding (n_peaks -> embed_dim) -> ReLU -> BatchNorm ->
    Dropout -> Conv1d -> ReLU -> BatchNorm -> MaxPool -> Dropout -> Conv1d ->
    ReLU -> BatchNorm -> flatten -> Linear(64) -> ReLU -> Linear(n_labels).

    Expects input of shape ``(batch, 1, n_peaks)``: the embedding acts on the
    last axis, and BatchNorm1d(1) and Conv1d(in_channels=1) need the single
    channel axis. Returns raw logits of shape ``(batch, n_labels)``.

    Parameters
    ----------
    dataset : object
        Provides ``n_peaks`` (input width) and ``n_labels`` (number of classes).
    embed_dim : int, default 5000
        Width of the linear embedding, i.e. the sequence length the
        convolutions run over.

    Raises
    ------
    ValueError
        If ``embed_dim`` is too short for the conv/pool layers.
    """
    def __init__(self, dataset, embed_dim=5000):
        super().__init__()
        self.embed = nn.Linear(dataset.n_peaks, embed_dim)
        self.conv1a = nn.Conv1d(in_channels=1, out_channels=1,
                                kernel_size=128, stride=2, padding=0, dilation=1, bias=True)
        self.conv1b = nn.Conv1d(in_channels=1, out_channels=1,
                                kernel_size=128, stride=2, padding=0, dilation=1, bias=True)
        self.pool1d = nn.MaxPool1d(kernel_size=4)
        self.batch_embed = nn.BatchNorm1d(1)
        self.batch1a = nn.BatchNorm1d(1)
        self.batch1b = nn.BatchNorm1d(1)

        # Derive fc1's input width from the layer settings. It used to be
        # hard-coded as 241, which is correct only for embed_dim=5000.
        length = embed_dim
        for layer in (self.conv1a, self.pool1d, self.conv1b):
            length = self._out_len(layer, length)
            if length < 1:
                raise ValueError(
                    f"embed_dim={embed_dim} is too short for the conv/pool layers")
        self.fc1 = nn.Linear(self.conv1b.out_channels * length, 64)
        self.fc2 = nn.Linear(64, dataset.n_labels)
        self.drop = nn.Dropout(p=0.3)

        # Only these layers are re-initialised; the embedding keeps PyTorch's default init.
        for layer in [self.conv1a, self.conv1b, self.fc1, self.fc2]:
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.1)

    @staticmethod
    def _out_len(layer, length):
        """Output length of a 1-D conv/pool ``layer`` for an input of ``length`` (floor formula)."""
        def first(value):
            # Conv1d stores its hyper-parameters as 1-tuples, MaxPool1d as plain ints.
            return value[0] if isinstance(value, tuple) else value
        kernel, stride = first(layer.kernel_size), first(layer.stride)
        padding, dilation = first(layer.padding), first(layer.dilation)
        return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    def forward(self, x):
        x = F.relu(self.embed(x))
        x = self.batch_embed(x)
        x = self.drop(x)
        x = F.relu(self.conv1a(x))
        x = self.batch1a(x)
        x = self.pool1d(x)
        x = self.drop(x)
        x = F.relu(self.conv1b(x))
        x = self.batch1b(x)
        # Flatten (batch, channels, length) -> (batch, channels * length) for fc1.
        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
def _evaluate(model, dataloader, criterion):
    """
    Score ``model`` on every batch of ``dataloader`` without tracking gradients.

    Returns ``(mean_loss, accuracy)``. Accuracy is the fraction of examples
    whose arg-max output equals the label. Each batch's loss is weighted by its
    size, which gives the per-example mean when ``criterion`` uses mean
    reduction (the ``nn.CrossEntropyLoss`` default). ``dataloader`` must yield
    at least one example. Leaves ``model`` in train mode.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            n_batch = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * n_batch
            # argmax over logits picks the same class as argmax over exp(logits).
            n_correct += (outputs.argmax(dim=1) == labels.view(-1)).sum().item()
            n_seen += n_batch
    model.train()
    return loss_sum / n_seen, n_correct / n_seen


def train(dataset, model, batch_size, optimizer, learning_rate, criterion, epochs,
          test_pct=0.05, loss_print_freq=50, eval_freq=100, eval_batch_size=1, seed=None):
    """
    Split ``dataset`` into train/test subsets and train ``model`` on the train part.

    Every ``loss_print_freq`` batches, the mean training loss over those batches
    is logged. Every ``eval_freq`` batches, the model is scored on the held-out
    subset; this is skipped when that subset is empty.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Must also expose ``n_cells`` (number of examples) and yield
        ``(inputs, label)`` pairs.
    model : nn.Module
        Put into train mode at the start. Evaluation temporarily switches it to eval mode.
    batch_size : int
        Training batch size.
    optimizer : torch.optim.Optimizer
        Already bound to ``model.parameters()``.
    learning_rate : float
        Unused: the learning rate is configured on ``optimizer``. Kept so
        existing calls keep working.
    criterion : callable
        Loss, called as ``criterion(outputs, labels)``.
    epochs : int
        Number of passes over the training subset.
    test_pct : float
        Fraction of ``dataset.n_cells`` held out (rounded down).
    loss_print_freq, eval_freq : int
        Logging / evaluation period, counted in training batches within an epoch.
    eval_batch_size : int, default 1
        Batch size for the held-out pass. Larger values evaluate much faster.
    seed : int or None
        If given, makes the train/test split reproducible.

    Returns
    -------
    dict
        ``train_losses``, ``test_losses`` and ``test_accuracy`` (a fraction in
        [0, 1]), accumulated across all epochs in logging order.
    """
    logging.info('Initialising')
    logging.info('Setup')

    test_num = math.floor(dataset.n_cells * test_pct)
    train_num = dataset.n_cells - test_num

    logging.info('Train examples: ' + str(train_num))
    logging.info('Test examples: ' + str(test_num))

    # Load and split dataset
    logging.info('Loading and splitting dataset')
    if seed is None:
        train_set, test_set = torch.utils.data.random_split(dataset, [train_num, test_num])
    else:
        generator = torch.Generator().manual_seed(seed)
        train_set, test_set = torch.utils.data.random_split(
            dataset, [train_num, test_num], generator=generator)

    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Evaluation results do not depend on order, so the test loader is not shuffled.
    test_dataloader = DataLoader(test_set, batch_size=eval_batch_size, shuffle=False)

    history = {'train_losses': [], 'test_losses': [], 'test_accuracy': []}

    logging.info('Training model')
    # The caller may have left the model in eval mode; dropout/batch-norm must train.
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % loss_print_freq == loss_print_freq - 1:
                single_loss = running_loss / loss_print_freq
                history['train_losses'].append(single_loss)
                logging.info('[epoch-%d, %5d] loss: %.3f' % (epoch + 1, i + 1, single_loss))
                running_loss = 0.0

            # An empty test split would otherwise divide by zero.
            if test_num > 0 and i % eval_freq == eval_freq - 1:
                single_test_loss, single_test_accuracy = _evaluate(model, test_dataloader, criterion)
                history['test_losses'].append(single_test_loss)
                history['test_accuracy'].append(single_test_accuracy)
                logging.info('EVAL - [epoch-%d, %5d] test_loss: %.3f test_accuracy: %.3f' % (epoch + 1, i + 1, single_test_loss, single_test_accuracy * 100))

    logging.info('Finished Training')
    return history

In [26]:
class Load_Dataset(Dataset):
    """
    Paired ATAC features and RNA-derived labels, exposed as a PyTorch Dataset.

    Reads the ATAC AnnData at ``ATAC_path`` (features) and the RNA AnnData at
    ``RNA_path``. ``r_func`` turns the RNA object into one label per cell.
    Items are ``(data, label)`` tensors created directly on ``self.device``.

    Parameters
    ----------
    ATAC_path, RNA_path : str
        Files readable by ``scanpy.read``.
    r_func : callable
        Called as ``r_func(rna_adata)``. Must return one integer label per cell,
        in the same cell order as the ATAC matrix.
    use_cuda : bool
        Use ``cuda:0`` when CUDA is available; otherwise CPU.
    float_size : {16, 32, 64}
        Floating-point width of the feature tensors.

    Raises
    ------
    ValueError
        If ``float_size`` is unsupported, or the number of labels does not
        match the number of ATAC cells.
    """
    _DTYPES = {16: torch.float16, 32: torch.float32, 64: torch.float64}

    def __init__(self, ATAC_path, RNA_path, r_func, use_cuda=True, float_size=32):
        # Validate before the (slow) file reads. An unsupported size used to
        # leave self.dtype unset and only fail later inside __getitem__.
        if float_size not in self._DTYPES:
            raise ValueError(
                f"float_size must be one of {sorted(self._DTYPES)}, got {float_size!r}")
        self.dtype = self._DTYPES[float_size]
        if torch.cuda.is_available() and use_cuda:
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        self.dataset = sc.read(ATAC_path)
        # Densifies the whole ATAC matrix up front (memory-heavy for many peaks).
        # todense() yields np.matrix, so self.X[idx] stays 2-D: (1, n_peaks).
        self.X = self.dataset.X.todense()
        self.n_cells = self.X.shape[0]
        self.n_peaks = self.X.shape[1]

        # get ground truth
        self.RNA = sc.read(RNA_path)
        self.Y = r_func(self.RNA)
        if len(self.Y) != self.n_cells:
            raise ValueError(
                f"r_func returned {len(self.Y)} labels but the ATAC data has {self.n_cells} cells")
        # CrossEntropyLoss needs n_labels > largest label id. Counting unique
        # labels alone undercounts when ids have gaps.
        unique_labels = np.unique(self.Y)
        self.n_labels = int(len(unique_labels))
        if len(unique_labels):
            self.n_labels = max(self.n_labels, int(unique_labels[-1]) + 1)
        self.r_func = r_func

    def __len__(self):
        return len(self.dataset.obs.index)

    def __getitem__(self, idx):
        # Tensors are built on self.device here. With CUDA this is only safe
        # while the DataLoader uses the default num_workers=0.
        label = torch.tensor(self.Y[idx], device=self.device, dtype=torch.long)
        data = torch.tensor(self.X[idx], device=self.device, dtype=self.dtype)
        return data, label

In [107]:
class CNN(nn.Module):
    """
    1-D convolutional classifier that runs directly on raw ATAC peak vectors.

    Pipeline: Dropout(input) -> Conv1d -> ReLU -> BatchNorm -> MaxPool ->
    Dropout -> Conv1d -> ReLU -> BatchNorm -> flatten -> Linear(32) -> ReLU ->
    Linear(n_labels).

    Expects input of shape ``(batch, 1, n_peaks)`` (Conv1d(in_channels=1)).
    Returns raw logits of shape ``(batch, n_labels)``.

    Parameters
    ----------
    dataset : object
        Provides ``n_peaks`` (input length) and ``n_labels`` (number of classes).

    Raises
    ------
    ValueError
        If ``n_peaks`` is too short for the conv/pool layers.
    """
    def __init__(self, dataset):
        super().__init__()
        self.conv1a = nn.Conv1d(in_channels=1, out_channels=1,
                                kernel_size=64, stride=32, padding=0, dilation=1, bias=True)
        self.conv1b = nn.Conv1d(in_channels=1, out_channels=1,
                                kernel_size=64, stride=32, padding=0, dilation=1, bias=True)
        self.pool1d = nn.MaxPool1d(kernel_size=8)
        # Not used in forward(); kept so the state_dict keys stay unchanged.
        self.batch_embed = nn.BatchNorm1d(1)
        self.batch1a = nn.BatchNorm1d(1)
        self.batch1b = nn.BatchNorm1d(1)

        # fc1's input width used to be hard-coded as 352. The flattened length
        # depends on n_peaks (e.g. n_peaks=116490 gives 13), so any other
        # input size failed with a shape mismatch. Derive it instead.
        length = dataset.n_peaks
        for layer in (self.conv1a, self.pool1d, self.conv1b):
            length = self._out_len(layer, length)
            if length < 1:
                raise ValueError(
                    f"n_peaks={dataset.n_peaks} is too short for the conv/pool layers")
        self.fc1 = nn.Linear(self.conv1b.out_channels * length, 32)
        self.fc2 = nn.Linear(32, dataset.n_labels)
        self.drop = nn.Dropout(p=0.2)

        for layer in [self.conv1a, self.conv1b, self.fc1, self.fc2]:
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.1)

    @staticmethod
    def _out_len(layer, length):
        """Output length of a 1-D conv/pool ``layer`` for an input of ``length`` (floor formula)."""
        def first(value):
            # Conv1d stores its hyper-parameters as 1-tuples, MaxPool1d as plain ints.
            return value[0] if isinstance(value, tuple) else value
        kernel, stride = first(layer.kernel_size), first(layer.stride)
        padding, dilation = first(layer.padding), first(layer.dilation)
        return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    def forward(self, x):
        # Dropout on the raw input randomly zeroes peaks during training.
        x = self.drop(x)
        x = F.relu(self.conv1a(x))
        x = self.batch1a(x)
        x = self.pool1d(x)
        x = self.drop(x)
        x = F.relu(self.conv1b(x))
        x = self.batch1b(x)
        # Flatten (batch, channels, length) -> (batch, channels * length) for fc1.
        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
# Seed PyTorch's global RNG so weight init, the train/test split and batch
# shuffling are reproducible from run to run.
SEED = 0
torch.manual_seed(SEED)

torch.cuda.empty_cache()
dataset = Load_Dataset(atac_path, rna_path, celltype_function)
model = CNN(dataset)
# Load_Dataset creates its tensors on dataset.device (cuda:0 when available,
# otherwise CPU), so the model must live on the same device. A hard-coded
# 'cuda' crashes on CPU-only machines.
model.to(dataset.device)

########### VARIABLES ##########

test_pct = 0.2
batch_size = 128
learning_rate = 0.001
epochs = 50
loss_print_freq = 10
eval_freq = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

########### LET'S GO ##########

history = train(dataset, model, batch_size, optimizer, learning_rate, criterion, epochs,
                test_pct, loss_print_freq, eval_freq)