## Getting A Simple Model Running

Goals:
1. Re-implement Chris's code, changing things so that the data loader accepts ATAC, RNA and a function to transform RNA (here, still just selecting pre-defined cell types).

2. Make a stepped up model based on convolutions.

In [6]:
import logging
import math
import sys

import numpy as np
import pandas as pd
import scanpy as sc

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [7]:
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

In [8]:
example_count = 22463
peak_count = 116490
label_count = 21
test_perc = 0.05

batch_size = 20
learning_rate = 0.001
momentum = 0.9
epochs = 10
loss_print_freq = 100
eval_freq = 1000

In [35]:
class Load_Dataset(Dataset):
    """
    This takes your ATAC data (as a directory location) and RNA data and a user defined function 
    to transform the RNA in a meaningful way.
    
    It stores your data, performs your transformation and sets them up in an iterable way for pytorch.
    """
    def __init__(self, ATAC_path, RNA_path, r_func):
        # main peak vs cell matrix to be fed to transformer
        self.dataset = sc.read(ATAC_path)
        self.X = self.dataset.X.todense()
        
        # get ground truth
        self.RNA = sc.read(RNA_path)    
        self.Y = r_func(self.RNA)
        self.r_func = r_func

    def __len__(self):
        return len(self.dataset.obs.index)

    def __getitem__(self, idx):
        label = np.squeeze(np.asarray(self.Y[idx]))
        data = np.asarray(self.X[idx])
        data = np.squeeze(data)
        return data, label

In [36]:
class Net(nn.Module):
    """
    basic feed forward model to test things out...
    """
    def __init__(self):
        super().__init__()
        self.embed1 = nn.Linear(peak_count, 20)
        self.fc1 = nn.Linear(20, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, label_count)

    def forward(self, x):
        x = F.relu(self.embed1(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [37]:
def train(model, ATAC, RNA, r_function):
    """
    wrapping chris's learning code into a function
    """
    # Log GPU status
    is_cuda = torch.cuda.is_available()
    logging.info("Cuda available: " + str(is_cuda))
    if is_cuda:
        current_device = torch.cuda.current_device()
        #torch.cuda.device(current_device)
        device_count = torch.cuda.device_count()
        logging.info("Cuda device count: " + device_count)
        device_name = torch.cuda.get_device_name(current_device)
        logging.info("Cuda device name: " + device_name)

    logging.info('Initialising')

    logging.info('Setup')

    # Init parameters
    test_num = math.floor(example_count * test_perc)
    train_num = example_count - test_num

    logging.info('Train examples: ' + str(train_num))
    logging.info('Test examples: ' + str(test_num))

    # Load and split dataset
    logging.info('Loading and splitting dataset')
    dataset = Load_Dataset(ATAC, RNA, r_function)
    
    train_set, test_set = torch.utils.data.random_split(dataset, [train_num, test_num])

    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

    ## DEBUG ##
    train_features, train_labels = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

    # Train model
    logging.info('Training model')

    for epoch in range(epochs):
        running_loss = 0.0
        train_losses = []
        test_losses = []
        test_accuracy = []
        for i, data in enumerate(train_dataloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % loss_print_freq == loss_print_freq - 1:
                single_loss = running_loss / loss_print_freq
                train_losses.append(single_loss)
                logging.info('[epoch-%d, %5d] loss: %.3f' % (epoch + 1, i + 1, single_loss))
                running_loss = 0.0

            # eval
            if i % eval_freq == eval_freq - 1:
                model.eval()
                test_loss = 0.0
                accuracy = 0.0

                with torch.no_grad():
                    for inputs, labels in test_dataloader:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        ps = torch.exp(outputs)
                        top_p, top_class = ps.topk(1, dim=1)
                        equals = top_class == labels.view(*top_class.shape)
                        test_loss += loss.item()
                        accuracy += torch.mean(equals.type(torch.FloatTensor)).item()

                test_count = len(test_dataloader)
                single_test_loss = test_loss / test_count
                single_test_accuracy = accuracy / test_count
                test_losses.append(single_test_loss)
                test_accuracy.append(single_test_accuracy)
                model.train()

                logging.info('EVAL - [epoch-%d, %5d] test_loss: %.3f test_accuracy: %.3f' % (epoch + 1, i + 1, single_test_loss, single_test_accuracy * 100))


    logging.info('Finished Training')

In [38]:
atac_path = '/camp/lab/briscoej/working/Rory/transcriptomics/NeurIPS_2021/nips_2021/multiome/multiome_atac_processed_training.h5ad'
rna_path = '/camp/lab/briscoej/working/Rory/transcriptomics/NeurIPS_2021/nips_2021/multiome/multiome_gex_processed_training.h5ad'

In [39]:
def celltype_function(RNA):
    """
    this code is taken from Chris, just to demonstrate wrapping into a function
    """
    sc_raw_training = sc.read(rna_path) # load data
    # Find cell types and get index labels
    ct_grouped = sc_raw_training.obs.groupby("cell_type").size()
    df_ct_grouped = pd.DataFrame(ct_grouped, columns=["count"])
    df_ct_grouped = df_ct_grouped.reset_index()
    df_ct_grouped['label_id'] = df_ct_grouped.index

    # Merge label ids with obs
    sc_raw_training.obs = sc_raw_training.obs.reset_index().merge(df_ct_grouped, on='cell_type', how='inner').set_index('index')
    return np.array(sc_raw_training.obs.label_id)

so now to learn, we instantiate a model, then train this model on ATAC data, given RNA data and a function to transform the RNA...

In [220]:
model1 = Net()

train(model1, atac_path, rna_path, celltype_function)

# CNN Model

It learns! But can we build a better model?

The end goal is to have an attention-based, Transformer-inspired model - but for now, let's do a simple CNN.

The structure:

    through 1d convolutions, with increased stride and max pools, the channel width is substantially decreased from 100k to ~5000 to ~100. Meanwhile, the single channel is expanded to multiple parallel kernels, ultimately numbering around 100. The result is a 100 x 100 square with width capturing genome space and height capturing different correlative structures; this is fed through 2d convolution layers (a slightly experimental move) to try to capture higher order interactions between the correlative structures. 


In [290]:
example_count = 22463
peak_count = 116490
label_count = 21
test_perc = 0.05

batch_size = 128
learning_rate = 0.001
momentum = 0.9
epochs = 10
loss_print_freq = 1000
eval_freq = 50

In [300]:
class ConvNet(nn.Module):
    """
    basic feed forward model to test things out...
    """
    def __init__(self):
        super().__init__()
        # could just do two 1d convs to a feed forward network, keep simple...
        
        self.conv1a = nn.Conv1d(1, 1, 
                                kernel_size=128, stride=64, padding=0, dilation=1, bias=True)
        self.conv1b = nn.Conv1d(4, 2, 
                                kernel_size=128, stride=4, padding=0, dilation=1, bias=True)
        self.pool1d = nn.MaxPool1d(kernel_size=4)
        self.batch1a = nn.BatchNorm1d(2)
        self.batch1b = nn.BatchNorm1d(4)
        
        # I would probably get rid of these:
        self.conv2a = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=8, stride=1, padding=1, dilation=1)
        self.conv2b = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=8, stride=1, padding=1, dilation=1)
        self.pool2d = nn.MaxPool2d(2,2)
        self.batch2a = nn.BatchNorm2d(1)
        self.batch2b = nn.BatchNorm2d(1)
        
        self.fc1 = nn.Linear(80, 32)
        self.fc2 = nn.Linear(32, label_count)
        self.drop = nn.Dropout(p=0.3)


    def forward(self, x):
        x = x.reshape(x.shape[0], 1, x.shape[1])
        x = self.drop(x)
        x = F.relu(self.conv1a(x))
        x = self.batch1a(x)
        x = self.pool1d(x)
        x = self.drop(x)
        x = F.relu(self.conv1b(x))
        x = self.batch1b(x)
        
        x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
        x = self.drop(x)
        x = F.relu(self.conv2a(x))
        x = self.batch2a(x)
        x = self.pool2d(x)
        x = self.drop(x)
        x = F.relu(self.conv2b(x))
        x = self.batch2b(x)
        
        x = x.reshape(x.shape[0],-1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [285]:
m = ConvNet()

In [274]:
xi = torch.randn(batch_size, peak_count)

In [269]:
xi = xi.reshape(xi.shape[0], 1, xi.shape[1])

In [270]:
x2 = m.conv1a(xi)

In [271]:
x3 = m.conv1b(x2)

In [272]:
x3.shape

torch.Size([128, 64, 427])

In [301]:
model2 = ConvNet()

train(model1, atac_path, rna_path, celltype_function)

INFO:root:Cuda available: False
INFO:root:Initialising
INFO:root:Setup
INFO:root:Train examples: 21340
INFO:root:Test examples: 1123
INFO:root:Loading and splitting dataset
Feature batch shape: torch.Size([128, 116490])
Labels batch shape: torch.Size([128])
INFO:root:Training model
INFO:root:EVAL - [epoch-1,    50] test_loss: 0.236 test_accuracy: 95.369
INFO:root:EVAL - [epoch-1,   100] test_loss: 0.238 test_accuracy: 95.283
INFO:root:EVAL - [epoch-1,   150] test_loss: 0.242 test_accuracy: 94.588
INFO:root:EVAL - [epoch-2,    50] test_loss: 0.232 test_accuracy: 94.823
INFO:root:EVAL - [epoch-2,   100] test_loss: 0.229 test_accuracy: 94.950
INFO:root:EVAL - [epoch-2,   150] test_loss: 0.230 test_accuracy: 95.048
INFO:root:EVAL - [epoch-3,    50] test_loss: 0.223 test_accuracy: 95.471
INFO:root:EVAL - [epoch-3,   100] test_loss: 0.230 test_accuracy: 95.109
INFO:root:EVAL - [epoch-3,   150] test_loss: 0.227 test_accuracy: 95.236
INFO:root:EVAL - [epoch-4,    50] test_loss: 0.231 test_accu

KeyboardInterrupt: 