In [None]:
%run Prep_data_All_Patients.ipynb

In [None]:
import os

# Directories passed to `final()` when the data is loaded. They default to the
# original machine-specific locations; set the SUP_DIR / CLUSTER_DIR environment
# variables to run elsewhere instead of editing these lines.
supDir = os.environ.get('SUP_DIR', '/Users/elikond/Downloads/surprisal_analysis/')
clusterDir = os.environ.get('CLUSTER_DIR', '/Users/elikond/Downloads/clusters/')

In [None]:
import torch
import torchmetrics
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import tensorflow as tf

from sklearn import metrics
from sklearn import decomposition
from sklearn import manifold
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

import scipy

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import matplotlib.pyplot as plt

from operator import itemgetter

import copy
import random
import time

In [None]:
def create_iterator(data_list, BATCH_SIZE = 64):
    """Return a DataLoader over ``data_list`` that reshuffles on every pass.

    Parameters
    ----------
    data_list : sequence
        Map-style dataset; in this notebook a list of ``(features, label)`` tuples.
    BATCH_SIZE : int, default 64
        Number of samples per batch (the last batch may be smaller).

    Returns
    -------
    torch.utils.data.DataLoader
    """
    loader_options = {'batch_size': BATCH_SIZE, 'shuffle': True}
    return data.DataLoader(data_list, **loader_options)

In [None]:
def make_iterators(merged_df, barcode_len):
    """Wrap ``merged_df`` as a shuffled DataLoader of ``(features, cluster)`` pairs.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        One row per sample. Columns ``iloc[:, 4:-barcode_len-4]`` form the float64
        feature vector, and the ``seurat_clusters`` column supplies the label.
        Any index works; rows are paired with labels by position.
    barcode_len : int
        Number of trailing columns, in addition to the last 4, that are excluded
        from the features (presumably barcode columns; confirm upstream).

    Returns
    -------
    torch.utils.data.DataLoader
        Built by ``create_iterator`` (shuffled, default batch size).
    """
    import numpy as np

    # NOTE(review): the 4 leading / 4 trailing non-feature columns reflect the layout
    # produced by the prep notebook. Confirm if that layout changes.
    X_data = merged_df.iloc[:, 4:-barcode_len - 4]
    # C-contiguous float64 copy, as the former np.vstack produced.
    features = torch.from_numpy(np.ascontiguousarray(X_data.to_numpy(), dtype=float))
    # Positional labels: the previous label-based lookup merged_df['seurat_clusters'][i]
    # only worked when the index happened to be 0..n-1.
    labels = merged_df['seurat_clusters'].to_numpy()
    data_list = list(zip(features, labels))
    return create_iterator(data_list)

In [None]:
def transform_data(merged_df, barcode_len, random_state=None):
    """Split ``merged_df`` into stratified train / test / validation DataLoaders.

    15% of rows go to the test set, then 10% of the remainder (about 8.5% overall)
    go to validation, and the rest (about 76.5%) to training. Both splits are
    stratified on ``seurat_clusters``.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Must contain a ``seurat_clusters`` column. ``make_iterators`` takes the
        feature columns as ``iloc[:, 4:-barcode_len-4]``.
    barcode_len : int
        Number of trailing columns, in addition to the last 4, excluded from the features.
    random_state : int or None, optional
        Seed for both ``train_test_split`` calls. ``None`` (the default) keeps the
        previous non-deterministic behaviour.

    Returns
    -------
    tuple
        ``(train_iterator, test_iterator, valid_iterator, max_cluster, num_genes)``.
        ``max_cluster`` is the largest ``seurat_clusters`` value. ``num_genes`` is the
        number of feature columns per sample, i.e. the model's input width.
    """
    train_inter, test_df = train_test_split(
        merged_df, stratify=merged_df['seurat_clusters'], test_size=0.15,
        random_state=random_state)
    train_df, valid_df = train_test_split(
        train_inter, stratify=train_inter['seurat_clusters'], test_size=0.1,
        random_state=random_state)

    # train_test_split keeps the original row labels; give each split a fresh 0..n-1
    # index without mutating anything in place.
    train_df, valid_df, test_df = (
        df.reset_index(drop=True) for df in (train_df, valid_df, test_df))

    train_iterator = make_iterators(train_df, barcode_len)
    test_iterator = make_iterators(test_df, barcode_len)
    valid_iterator = make_iterators(valid_df, barcode_len)
    max_cluster = merged_df.seurat_clusters.max()
    # Must equal the width of the feature slice used by make_iterators. The former
    # len(merged_df.columns) also counted the non-feature columns, so the model's
    # first Linear layer did not match the input tensors.
    num_genes = merged_df.iloc[:, 4:-barcode_len - 4].shape[1]

    return train_iterator, test_iterator, valid_iterator, max_cluster, num_genes

In [None]:
# NOTE(review): `final` and `mesenDir` are not defined in any visible cell; presumably
# they come from the %run'd Prep_data_All_Patients notebook. Confirm.
merged_df, X_data, y_data, barcode_len, sigSubpopsDF = final(supDir, clusterDir, mesenDir, 'gb9')
# Keep the splits: later cells read these names, and the original cell discarded them.
train_iterator, test_iterator, valid_iterator, max_cluster, num_genes = transform_data(merged_df, barcode_len)

In [None]:
class MLP(nn.Module):
    """Fully connected classifier: input -> 1000 -> 300 -> dropout -> 50 -> output_dim.

    Hidden layers use tanh. ``forward`` returns raw logits together with the
    50-unit activation of the last hidden layer.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Layers are created in this order so seeded weight initialisation is
        # unchanged. The attribute names are the state_dict keys.
        self.input_fc = nn.Linear(input_dim, 1000)
        self.hidden_1 = nn.Linear(1000, 300)
        self.drop = nn.Dropout(p=0.2)
        self.hidden_2 = nn.Linear(300, 50)
        self.output_fc = nn.Linear(50, output_dim)

    def forward(self, x):
        """Return ``(logits, last_hidden)`` for a batch ``x``.

        All dimensions after the first are flattened, so the flattened width must
        equal ``input_dim``. ``logits`` has shape ``(batch, output_dim)`` and
        ``last_hidden`` has shape ``(batch, 50)``.
        """
        activation = x.view(x.shape[0], -1)
        for layer in (self.input_fc, self.hidden_1):
            activation = torch.tanh(layer(activation))
        last_hidden = torch.tanh(self.hidden_2(self.drop(activation)))
        logits = self.output_fc(last_hidden)
        return logits, last_hidden

# One input per feature column; one output logit per cluster id
# (assumes cluster labels are the integers 0..max_cluster; confirm).
INPUT_DIM, OUTPUT_DIM = num_genes, max_cluster + 1

model = MLP(input_dim=INPUT_DIM, output_dim=OUTPUT_DIM)

In [None]:
def count_parameters(model):
    """Return the total number of scalar weights in ``model`` that require gradients."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

print('The model has {:,} trainable parameters'.format(count_parameters(model)))

In [None]:
# Adam with PyTorch's default hyper-parameters.
optimizer = optim.Adam(model.parameters())
# run() reads a global `criterion`, but no visible cell defines one. MLP.forward
# returns raw logits and the targets are (presumably integer) cluster ids, so
# cross-entropy is the matching classification loss.
criterion = nn.CrossEntropyLoss()

In [None]:
def calculate_accuracy(y_pred, y):
    """Fraction of rows whose highest-scoring class in ``y_pred`` equals ``y``.

    ``y_pred`` is ``(batch, n_classes)``. ``y`` holds one target per row, in any
    shape with ``batch`` elements. Returns a 0-dim float tensor.
    """
    predicted = y_pred.argmax(dim=1)
    matches = predicted.eq(y.view_as(predicted))
    return matches.sum().float() / y.shape[0]

In [None]:
def train(model, iterator, optimizer, criterion):
    """Run one training epoch and return its per-sample mean loss and accuracy.

    Parameters
    ----------
    model : nn.Module
        Network whose forward returns ``(logits, extra)``, as ``MLP`` does.
    iterator : iterable of (x, y)
        Batches of features and targets; ``x`` is cast to float32.
    optimizer : torch.optim.Optimizer
        Stepped once per batch.
    criterion : callable
        Loss on ``(logits, y)``. Assumed to return the batch *mean* (the default
        reduction of torch.nn losses), so it is re-weighted by batch size below.

    Returns
    -------
    (float, float)
        Mean loss and accuracy over all samples seen this epoch.

    Raises
    ------
    ValueError
        If ``iterator`` yields no samples.
    """
    total_loss = 0.0
    total_correct = 0.0
    n_samples = 0

    model.train()

    # NOTE(review): `tqdm` is not imported in any visible cell; presumably provided by
    # the %run'd prep notebook. Confirm.
    for (x, y) in tqdm(iterator, desc="Training", leave=False):

        optimizer.zero_grad()
        y_pred, _ = model(x.float())
        loss = criterion(y_pred, y)
        acc = calculate_accuracy(y_pred, y)

        loss.backward()

        optimizer.step()

        # Weight each batch by its size. Averaging per-batch means (as before)
        # over-weights the usually smaller final batch.
        batch_size = y.shape[0]
        total_loss += loss.item() * batch_size
        total_correct += acc.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train: iterator yielded no samples")
    return total_loss / n_samples, total_correct / n_samples

In [None]:
def evaluate(model, iterator, criterion):
    """Return the per-sample mean loss and accuracy of ``model`` on ``iterator``.

    Calls ``model.eval()`` (e.g. dropout off) and disables gradient tracking, so
    the weights are not changed.

    Parameters
    ----------
    model : nn.Module
        Network whose forward returns ``(logits, extra)``, as ``MLP`` does.
    iterator : iterable of (x, y)
        Batches of features and targets; ``x`` is cast to float32.
    criterion : callable
        Loss on ``(logits, y)``, assumed to return the batch mean (default reduction).

    Returns
    -------
    (float, float)
        Mean loss and accuracy over all samples.

    Raises
    ------
    ValueError
        If ``iterator`` yields no samples.
    """
    total_loss = 0.0
    total_correct = 0.0
    n_samples = 0

    model.eval()

    with torch.no_grad():

        # NOTE(review): `tqdm` is not imported in any visible cell; presumably provided
        # by the %run'd prep notebook. Confirm.
        for (x, y) in tqdm(iterator, desc="Evaluating", leave=False):

            y_pred, _ = model(x.float())

            loss = criterion(y_pred, y)

            acc = calculate_accuracy(y_pred, y)

            # Sample-weighted accumulation; see train() for why.
            batch_size = y.shape[0]
            total_loss += loss.item() * batch_size
            total_correct += acc.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("evaluate: iterator yielded no samples")
    return total_loss / n_samples, total_correct / n_samples

In [None]:
def epoch_time(start_time, end_time):
    """Split ``end_time - start_time`` (seconds) into ``(minutes, seconds)``.

    Both parts are truncated toward zero with ``int``.
    """
    elapsed = end_time - start_time
    whole_minutes = int(elapsed / 60)
    return whole_minutes, int(elapsed - whole_minutes * 60)

In [None]:
def run(train_iterator, valid_iterator, test_iterator, epochs=10):
    """Train the notebook's global ``model`` and return its test accuracy.

    Each epoch trains on ``train_iterator`` and scores ``valid_iterator``. The
    weights from the epoch with the lowest validation loss are snapshotted. After
    the last epoch they are loaded back into ``model`` and ``test_iterator`` is
    evaluated once.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``, ``train``,
    ``evaluate``, ``epoch_time`` and ``trange``.

    Parameters
    ----------
    train_iterator, valid_iterator, test_iterator : iterables of (x, y)
        The three data splits.
    epochs : int, default 10
        Number of training epochs.

    Returns
    -------
    float
        Test-set accuracy of the best-validation weights (or of the untrained
        model if ``epochs`` is 0).
    """
    best_valid_loss = float('inf')
    best_state = None

    # NOTE(review): history is not returned; expose it if learning curves are needed.
    history = {'Train': {'Accuracy': [], 'Loss': []}, 'Test': {'Accuracy': [], 'Loss': []}, 'Validation': {'Accuracy': [], 'Loss': []}}

    for epoch in trange(epochs):

        start_time = time.monotonic()

        train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
        valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
        history['Train']['Loss'].append(train_loss)
        history['Train']['Accuracy'].append(train_acc)
        history['Validation']['Loss'].append(valid_loss)
        history['Validation']['Accuracy'].append(valid_acc)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            # Deep copy: state_dict() returns references to the live tensors, which
            # later optimizer steps would overwrite.
            best_state = copy.deepcopy(model.state_dict())

        end_time = time.monotonic()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

    if best_state is not None:
        model.load_state_dict(best_state)

    # Evaluate the held-out test set once, after model selection, rather than every epoch.
    test_loss, test_acc = evaluate(model, test_iterator, criterion)
    history['Test']['Loss'].append(test_loss)
    history['Test']['Accuracy'].append(test_acc)

    return test_acc

# Train, then report the test accuracy returned by run().
final_acc = run(
    train_iterator=train_iterator,
    valid_iterator=valid_iterator,
    test_iterator=test_iterator,
)
print(final_acc)