In [1]:
# libs
import logging
import argparse
import math
import os
import sys

import numpy as np
import pandas as pd
import scanpy as sc

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

  # NOTE(review): no config file is loaded in this cell; a stray indented
  # `yaml.load(f.read())` line (undefined `yaml`/`f`, result never used) was dropped.


In [2]:
sys.path.insert(0, '../../bin')
from helpers import Load_Dataset
from RNA_functions import celltype_function
from trainer import train

In [3]:
# Record which compute device torch can see; the messages appear in the cell output.
is_cuda = torch.cuda.is_available()
logging.info("Cuda available: %s", is_cuda)
if is_cuda:
    current_device = torch.cuda.current_device()
    device_count = torch.cuda.device_count()
    logging.info("Cuda device count: %s", device_count)
    device_name = torch.cuda.get_device_name(current_device)
    logging.info("Cuda device name: %s", device_name)

2021-10-21 13:49:20 INFO     Cuda available: False


In [4]:
def celltype_function(RNA_adata):
    """
    Assign an integer class label to every cell from ``obs["cell_type"]``.

    Adapted from Chris's code. This definition shadows the
    ``celltype_function`` imported from ``RNA_functions`` in the import cell,
    so this is the version later passed to ``Load_Dataset``.

    Label ids are 0..K-1 in the (sorted) key order of
    ``groupby("cell_type")``. As a side effect, ``RNA_adata.obs`` is replaced
    by a copy with two extra columns:

    * ``count``    -- number of cells sharing the row's cell type
    * ``label_id`` -- the integer label

    Rows keep their original order and index, so the labels stay aligned with
    the rows of the data matrix. Calling the function again on the same object
    recomputes both columns rather than creating ``_x``/``_y`` duplicates.

    Parameters
    ----------
    RNA_adata : object with an ``obs`` DataFrame containing a ``cell_type``
        column (e.g. an AnnData).

    Returns
    -------
    numpy.ndarray
        Integer label ids, one per row of ``obs``, in row order.

    Raises
    ------
    ValueError
        If any ``cell_type`` value is missing.
    """
    # Drop columns from a previous call so the join below does not collide.
    obs = RNA_adata.obs.drop(columns=["count", "label_id"], errors="ignore")
    if obs["cell_type"].isna().any():
        raise ValueError("cell_type has missing values; every cell needs a type to be labelled")

    counts = obs.groupby("cell_type").size()
    lookup = pd.DataFrame(
        {"count": counts.to_numpy(), "label_id": np.arange(len(counts), dtype=np.int64)},
        index=counts.index,
    )
    # A left join on the column keeps obs' row order and index. The former
    # reset_index().merge(how="inner").set_index('index') could regroup rows by
    # cell type (pandas < 2.2) and failed outright when the obs index was named.
    labelled = obs.join(lookup, on="cell_type")
    RNA_adata.obs = labelled
    return np.array(labelled["label_id"])

In [5]:
# Directory holding the processed NeurIPS 2021 multiome training files.
# Set MULTIOME_DATA_DIR to run outside the CAMP cluster; the default is the
# original cluster location.
MULTIOME_DIR = os.environ.get(
    "MULTIOME_DATA_DIR",
    "/camp/lab/briscoej/working/Rory/transcriptomics/NeurIPS_2021/nips_2021/multiome",
)
atac_path = os.path.join(MULTIOME_DIR, "multiome_atac_processed_training.h5ad")
rna_path = os.path.join(MULTIOME_DIR, "multiome_gex_processed_training.h5ad")

In [6]:
# Build the training dataset from the ATAC and RNA files; `celltype_function`
# is handed to Load_Dataset to produce the per-cell class labels. The models
# below read `dataset.n_peaks` and `dataset.n_labels` from the result.
dataset = Load_Dataset(atac_path, rna_path, celltype_function)

In [7]:
# Number of ATAC peak features: the input width of each model's first Linear layer.
dataset.n_peaks

116490

In [8]:
class Net(nn.Module):
    """
    Fully connected baseline classifier: n_peaks -> 100 -> 200 -> 100 -> n_labels.

    ReLU follows every layer except the last, which returns raw logits.
    ``dataset`` only needs ``n_peaks`` and ``n_labels`` attributes.
    """

    def __init__(self, dataset):
        super().__init__()
        n_in, n_out = dataset.n_peaks, dataset.n_labels
        # Same creation order as before, so parameter names/order and the
        # random initialisation sequence are unchanged.
        self.embed1 = nn.Linear(n_in, 100)
        self.fc1 = nn.Linear(100, 200)
        self.fc2 = nn.Linear(200, 100)
        self.fc3 = nn.Linear(100, n_out)

    def forward(self, x):
        for hidden in (self.embed1, self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [49]:
class ConvNet(nn.Module):
    """
    Small 1-D convolutional classifier over ATAC peak features.

    A dense ``embed`` layer projects the ``dataset.n_peaks`` inputs to
    ``embed_dim`` values. These are treated as a single-channel 1-D signal and
    passed through conv(k=128, s=2) -> max-pool(4) -> conv(k=128, s=2), with
    ReLU + BatchNorm after the embedding and each conv and dropout in between.
    Two fully connected layers then produce ``dataset.n_labels`` logits.

    Parameters
    ----------
    dataset : object exposing ``n_peaks`` (input width) and ``n_labels``
        (number of classes).
    embed_dim : int, default 5000
        Width of the dense embedding fed to the convolutions. The input size of
        ``fc1`` is derived from it (241 for the default), so it can be changed
        without editing any other layer.
    dropout : float, default 0.3
        Dropout probability.

    Raises
    ------
    ValueError
        If ``embed_dim`` is too small for the conv/pool stack.

    Notes
    -----
    ``embed`` alone holds ``n_peaks * embed_dim`` weights (about 5.8e8 for
    116490 peaks), which dominates the model's memory footprint.
    """

    def __init__(self, dataset, embed_dim=5000, dropout=0.3):
        super().__init__()
        # Layers are created in the original order so parameter registration
        # (state_dict keys) and the random initialisation sequence are unchanged.
        self.embed = nn.Linear(dataset.n_peaks, embed_dim)
        self.conv1a = nn.Conv1d(in_channels=1, out_channels=1,
                                kernel_size=128, stride=2, padding=0, dilation=1, bias=True)
        self.conv1b = nn.Conv1d(in_channels=1, out_channels=1,
                                kernel_size=128, stride=2, padding=0, dilation=1, bias=True)
        self.pool1d = nn.MaxPool1d(kernel_size=4)
        self.batch_embed = nn.BatchNorm1d(1)
        self.batch1a = nn.BatchNorm1d(1)
        self.batch1b = nn.BatchNorm1d(1)

        # Signal length after conv1a -> pool1d -> conv1b, i.e. fc1's input size
        # (previously hard-coded as 241, which only holds for embed_dim=5000).
        flat_len = embed_dim
        for layer in (self.conv1a, self.pool1d, self.conv1b):
            flat_len = self._out_len(flat_len, layer)
            if flat_len < 1:
                raise ValueError(
                    f"embed_dim={embed_dim} is too small for the conv/pool stack "
                    f"(intermediate length {flat_len})")

        self.fc1 = nn.Linear(flat_len, 64)
        self.fc2 = nn.Linear(64, dataset.n_labels)
        self.drop = nn.Dropout(p=dropout)

        # Xavier-normal weights and a small positive bias for conv/fc layers;
        # `embed` keeps PyTorch's default Linear initialisation.
        for layer in (self.conv1a, self.conv1b, self.fc1, self.fc2):
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.1)

    @staticmethod
    def _out_len(length, layer):
        """Output length of a Conv1d/MaxPool1d ``layer`` for an input of ``length``."""
        def first(value):
            # Conv1d stores hyperparameters as 1-tuples, MaxPool1d as plain ints.
            return value[0] if isinstance(value, tuple) else value

        kernel, stride = first(layer.kernel_size), first(layer.stride)
        padding, dilation = first(layer.padding), first(layer.dilation)
        return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    def forward(self, x):
        """Map a (batch, n_peaks) tensor to (batch, n_labels) logits."""
        x = F.relu(self.embed(x))
        x = x.unsqueeze(1)  # (batch, 1, embed_dim): single input channel for Conv1d
        x = self.batch_embed(x)
        x = self.drop(x)
        x = F.relu(self.conv1a(x))
        x = self.batch1a(x)
        x = self.pool1d(x)
        x = self.drop(x)
        x = F.relu(self.conv1b(x))
        x = self.batch1b(x)
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [50]:
# Seed the RNGs so weight initialisation is reproducible. If `train` draws from
# the torch or numpy RNGs for its split/shuffling, that also becomes reproducible.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

model = ConvNet(dataset)

########### VARIABLES ##########

test_pct = 0.05
batch_size = 20
learning_rate = 0.001
epochs = 10
loss_print_freq = 10
# NOTE(review): in the log below one evaluation pass takes ~35 s while 10 training
# batches take ~15 s, so evaluating this often dominates run time. Consider a larger
# eval_freq (assuming it counts batches, as the "[epoch-1, 10]" log lines suggest).
eval_freq = 10
criterion = nn.CrossEntropyLoss()
# Adam's momentum comes from its `betas` (left at their defaults); there is no
# SGD-style momentum hyperparameter to set here.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

########### LET'S GO ##########

train(dataset, model, batch_size, optimizer, learning_rate, criterion, epochs,
      test_pct, loss_print_freq, eval_freq)

2021-10-21 15:13:46 INFO     Initialising
2021-10-21 15:13:46 INFO     Setup
2021-10-21 15:13:46 INFO     Train examples: 21340
2021-10-21 15:13:46 INFO     Test examples: 1123
2021-10-21 15:13:46 INFO     Loading and splitting dataset
2021-10-21 15:13:46 INFO     Training model
2021-10-21 15:14:03 INFO     [epoch-1,    10] loss: 3.089
2021-10-21 15:14:38 INFO     EVAL - [epoch-1,    10] test_loss: 2.876 test_accuracy: 12.467


140.0
1123


2021-10-21 15:14:53 INFO     [epoch-1,    20] loss: 2.737
2021-10-21 15:15:27 INFO     EVAL - [epoch-1,    20] test_loss: 2.692 test_accuracy: 18.789


211.0
1123


2021-10-21 15:15:47 INFO     [epoch-1,    30] loss: 2.749
2021-10-21 15:16:21 INFO     EVAL - [epoch-1,    30] test_loss: 2.612 test_accuracy: 19.947


224.0
1123


2021-10-21 15:16:37 INFO     [epoch-1,    40] loss: 2.579
2021-10-21 15:17:11 INFO     EVAL - [epoch-1,    40] test_loss: 2.566 test_accuracy: 20.570


231.0
1123


2021-10-21 15:17:25 INFO     [epoch-1,    50] loss: 2.619
2021-10-21 15:17:59 INFO     EVAL - [epoch-1,    50] test_loss: 2.574 test_accuracy: 20.303


228.0
1123


2021-10-21 15:18:12 INFO     [epoch-1,    60] loss: 2.628
2021-10-21 15:18:47 INFO     EVAL - [epoch-1,    60] test_loss: 2.548 test_accuracy: 21.638


243.0
1123


2021-10-21 15:19:00 INFO     [epoch-1,    70] loss: 2.645
2021-10-21 15:19:34 INFO     EVAL - [epoch-1,    70] test_loss: 2.535 test_accuracy: 20.570


231.0
1123


2021-10-21 15:19:50 INFO     [epoch-1,    80] loss: 2.604
2021-10-21 15:20:24 INFO     EVAL - [epoch-1,    80] test_loss: 2.500 test_accuracy: 26.180


294.0
1123


2021-10-21 15:20:37 INFO     [epoch-1,    90] loss: 2.585
2021-10-21 15:21:11 INFO     EVAL - [epoch-1,    90] test_loss: 2.503 test_accuracy: 22.262


250.0
1123


2021-10-21 15:21:25 INFO     [epoch-1,   100] loss: 2.585
2021-10-21 15:21:59 INFO     EVAL - [epoch-1,   100] test_loss: 2.481 test_accuracy: 26.803


301.0
1123


2021-10-21 15:22:16 INFO     [epoch-1,   110] loss: 2.561
2021-10-21 15:22:50 INFO     EVAL - [epoch-1,   110] test_loss: 2.564 test_accuracy: 19.412


218.0
1123


2021-10-21 15:23:04 INFO     [epoch-1,   120] loss: 2.613
2021-10-21 15:23:38 INFO     EVAL - [epoch-1,   120] test_loss: 2.500 test_accuracy: 20.748


233.0
1123


2021-10-21 15:23:52 INFO     [epoch-1,   130] loss: 2.515
2021-10-21 15:24:26 INFO     EVAL - [epoch-1,   130] test_loss: 2.487 test_accuracy: 25.735


289.0
1123


2021-10-21 15:24:40 INFO     [epoch-1,   140] loss: 2.533
2021-10-21 15:25:15 INFO     EVAL - [epoch-1,   140] test_loss: 2.487 test_accuracy: 21.193


238.0
1123


2021-10-21 15:25:28 INFO     [epoch-1,   150] loss: 2.446
2021-10-21 15:26:02 INFO     EVAL - [epoch-1,   150] test_loss: 2.459 test_accuracy: 27.159


305.0
1123


2021-10-21 15:26:16 INFO     [epoch-1,   160] loss: 2.499
2021-10-21 15:26:50 INFO     EVAL - [epoch-1,   160] test_loss: 2.525 test_accuracy: 25.111


282.0
1123


2021-10-21 15:27:03 INFO     [epoch-1,   170] loss: 2.466
2021-10-21 15:27:37 INFO     EVAL - [epoch-1,   170] test_loss: 2.430 test_accuracy: 24.132


271.0
1123


2021-10-21 15:27:50 INFO     [epoch-1,   180] loss: 2.454
2021-10-21 15:28:24 INFO     EVAL - [epoch-1,   180] test_loss: 2.409 test_accuracy: 24.755


278.0
1123


2021-10-21 15:28:38 INFO     [epoch-1,   190] loss: 2.520
2021-10-21 15:29:12 INFO     EVAL - [epoch-1,   190] test_loss: 2.434 test_accuracy: 23.954


269.0
1123


2021-10-21 15:29:30 INFO     [epoch-1,   200] loss: 2.443
2021-10-21 15:30:04 INFO     EVAL - [epoch-1,   200] test_loss: 2.431 test_accuracy: 26.091


293.0
1123


2021-10-21 15:30:18 INFO     [epoch-1,   210] loss: 2.450
2021-10-21 15:30:52 INFO     EVAL - [epoch-1,   210] test_loss: 2.389 test_accuracy: 26.536


298.0
1123


2021-10-21 15:31:06 INFO     [epoch-1,   220] loss: 2.488
2021-10-21 15:31:40 INFO     EVAL - [epoch-1,   220] test_loss: 2.479 test_accuracy: 24.666


277.0
1123


2021-10-21 15:31:53 INFO     [epoch-1,   230] loss: 2.390
2021-10-21 15:32:27 INFO     EVAL - [epoch-1,   230] test_loss: 2.414 test_accuracy: 26.358


296.0
1123


2021-10-21 15:32:40 INFO     [epoch-1,   240] loss: 2.476
2021-10-21 15:33:15 INFO     EVAL - [epoch-1,   240] test_loss: 2.515 test_accuracy: 24.488


275.0
1123


2021-10-21 15:33:30 INFO     [epoch-1,   250] loss: 2.426
2021-10-21 15:34:04 INFO     EVAL - [epoch-1,   250] test_loss: 2.377 test_accuracy: 26.447


297.0
1123


2021-10-21 15:34:17 INFO     [epoch-1,   260] loss: 2.372


KeyboardInterrupt: 