### Data Sourcing and Processing

In [None]:
import muon as mu
import scanpy as sc
import numpy as np
import pandas as pd
import torch
import random
#import itertools

In [None]:
PATHWAY = "/media/data/single_cell/brent10070/side_project/SHARE_seq/datasets"

# Training dataset stored as a MuData (.h5mu) file; its 'RNA' and 'ATAC'
# modalities are accessed by later cells.
Train_pathway = f"{PATHWAY}/Train.h5mu"
Train_Data_Load = mu.read(Train_pathway)

In [None]:
# Show a summary of the loaded MuData object as this cell's output.
Train_Data_Load

In [None]:
# Neighbour tables: the first CSV column becomes the index; the integer
# entries are later turned into index labels by sample_idx_transform.
Train_neighbor_pathway, Validation_neighbor_pathway = (
    f"{PATHWAY}/Train_neighbor.csv",
    f"{PATHWAY}/Validation_neighbor.csv",
)
Train_Neighbor_df = pd.read_csv(Train_neighbor_pathway, index_col=0)
Validation_Neighbor_df = pd.read_csv(Validation_neighbor_pathway, index_col=0)

In [None]:
def batch_index(sample_index, BATCH_SIZE):
    """Shuffle ``sample_index`` and split it into equally sized batches.

    Parameters
    ----------
    sample_index : sized iterable
        Labels to batch (e.g. a DataFrame index); must support ``len()``.
    BATCH_SIZE : int
        Number of labels per batch.

    Returns
    -------
    list[list]
        ``len(sample_index) // BATCH_SIZE`` lists, each holding exactly
        ``BATCH_SIZE`` labels in random order. Labels left over after the last
        complete batch are dropped on purpose: ``get_expn_tensor`` reshapes each
        batch to a fixed batch size. A non-positive ``BATCH_SIZE`` yields ``[]``.
    """
    # Sampling the whole population is a shuffle that leaves the input intact.
    shuffled = random.sample(list(sample_index), len(sample_index))
    # zip() over BATCH_SIZE references to the *same* iterator pulls consecutive
    # chunks and stops at the first incomplete chunk.
    it = iter(shuffled)
    return [list(group) for group in zip(*([it] * BATCH_SIZE))]

In [None]:
def sample_idx_transform(knn_df, Neighbor_df):
    """Replace positional neighbour indices with the corresponding index labels.

    Parameters
    ----------
    knn_df : pandas.DataFrame
        Neighbour-table rows for one batch of cells. Every value must be an
        integer *position* into ``Neighbor_df.index``.
    Neighbor_df : pandas.DataFrame
        Full neighbour table whose index supplies the labels.

    Returns
    -------
    pandas.DataFrame
        A new frame with a default integer index. Its first column,
        ``'Neigh_0'``, holds the row labels of ``knn_df`` (the query cells).
        The remaining columns keep ``knn_df``'s column order and hold the
        labels looked up from the positions. ``knn_df`` is not modified.
    """
    # Build the position -> label lookup once rather than once per column.
    labels = np.asarray(Neighbor_df.index)
    mapped = pd.DataFrame(
        labels[knn_df.to_numpy()],
        index=knn_df.index,
        columns=knn_df.columns,
    )
    # Move the query-cell labels into a leading 'Neigh_0' column.
    return mapped.rename_axis('Neigh_0').reset_index()

def get_expn_tensor(knn_df, Data, modal, BATCH_SIZE=None):
    """Gather values for every cell listed in ``knn_df`` as a 3-D tensor.

    Parameters
    ----------
    knn_df : pandas.DataFrame
        Output of ``sample_idx_transform``: one row per query cell; every
        entry is an observation name of ``Data.mod[modal]``.
    Data : MuData
        Container whose ``mod[modal]`` AnnData supplies the values.
    modal : str
        Modality key, e.g. ``'RNA'`` or ``'ATAC'``.
    BATCH_SIZE : int, optional
        Number of rows of ``knn_df``; defaults to ``knn_df.shape[0]``.

    Returns
    -------
    torch.Tensor
        Shape ``(n_rows_of_var_df, BATCH_SIZE, knn_df.shape[1])``.
        ``knn_df.shape[1]`` is 32 in this notebook: the cell plus 31 neighbours.
        Axis 1 follows the rows of ``knn_df`` and axis 2 its columns. The
        dtype is inherited from the values returned by ``sc.get.var_df``.

    Raises
    ------
    ValueError
        If ``sc.get.var_df`` does not return ``BATCH_SIZE * knn_df.shape[1]``
        columns, so the reshape would be invalid.
    """
    if BATCH_SIZE is None:
        BATCH_SIZE = knn_df.shape[0]
    n_cols = knn_df.shape[1]

    # Row-major flattening: entry (row i, column j) becomes key i * n_cols + j,
    # which the reshape below undoes.
    keys = knn_df.values.flatten().tolist()
    values = sc.get.var_df(Data.mod[modal], keys=keys).values

    expected = BATCH_SIZE * n_cols
    if values.shape[1] != expected:
        # NOTE(review): keys often repeat (shared neighbours); confirm the
        # installed scanpy keeps one column per key, including duplicates.
        raise ValueError(
            f"get_expn_tensor: expected {expected} columns "
            f"(BATCH_SIZE={BATCH_SIZE} x {n_cols}), got {values.shape[1]}"
        )

    T = torch.tensor(values)
    return T.reshape(T.size()[0], BATCH_SIZE, n_cols)

In [None]:
# Exploratory check: shuffle the training cells into batches of 32 and map
# the neighbour table of one batch (index 2) from positions to labels.
BATCH_SIZE = 32
Train_batch_index = batch_index(Train_Neighbor_df.index, BATCH_SIZE)
KNNdf = sample_idx_transform(
    Train_Neighbor_df.loc[Train_batch_index[2]], Train_Neighbor_df
)

In [None]:
# ATAC values for the inspected batch, laid out by get_expn_tensor.
tgt = get_expn_tensor(knn_df=KNNdf, Data=Train_Data_Load, modal='ATAC', BATCH_SIZE=BATCH_SIZE)

In [None]:
# Bytes the element data of `tgt` would occupy as float64. Python object
# overhead is excluded. This is computed arithmetically for three reasons:
# - no float64 copy of a potentially large tensor is allocated;
# - `sys` is not imported in the visible cells;
# - the deprecated `Tensor.storage()` API is avoided.
tgt.nelement() * (torch.finfo(torch.float64).bits // 8)

### Seq2Seq Network using Transformer

In [None]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    """Thin wrapper around ``torch.nn.Transformer``.

    No embedding, positional encoding or output projection is applied.
    ``src`` and ``tgt`` go to ``nn.Transformer`` as-is, so their last
    dimension must equal ``emb_size``. ``batch_first`` is left at its PyTorch
    default (``False``), so inputs are ``(seq_len, batch, emb_size)``.

    ``src_vocab_size`` and ``tgt_vocab_size`` are accepted for signature
    compatibility but are not used.

    Every mask argument is optional and defaults to ``None`` (no masking),
    which reproduces the mask-free behaviour.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)

    def forward(self,
                src: Tensor,
                tgt: Tensor,
                src_mask=None,
                tgt_mask=None,
                memory_mask=None) -> Tensor:
        """Encode ``src`` and decode ``tgt`` against it; masks default to none."""
        return self.transformer(src, tgt,
                                src_mask=src_mask,
                                tgt_mask=tgt_mask,
                                memory_mask=memory_mask)

    def encode(self, src: Tensor, src_mask=None) -> Tensor:
        """Run only the encoder stack and return the memory tensor."""
        return self.transformer.encoder(src, mask=src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask=None,
               memory_mask=None) -> Tensor:
        """Run only the decoder stack of ``tgt`` against encoder ``memory``."""
        return self.transformer.decoder(tgt, memory,
                                        tgt_mask=tgt_mask,
                                        memory_mask=memory_mask)

In [None]:
# Seed before building the model so the Xavier initialisation is reproducible.
torch.manual_seed(0)

# Passed to Seq2SeqTransformer, which currently does not use them.
SRC_VOCAB_SIZE = Train_Data_Load.mod['RNA'].n_vars
TGT_VOCAB_SIZE = Train_Data_Load.mod['ATAC'].n_vars
EMB_SIZE = 32 # 1 + KNN31 = 32; must equal the last dim produced by get_expn_tensor
NHEAD = 8  # EMB_SIZE must be divisible by NHEAD
FFN_HID_DIM = 512
BATCH_SIZE = 1  # overrides the BATCH_SIZE = 32 used in the exploration cell
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

# Xavier-initialise every weight matrix (parameters with more than one dim).
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

# The model returns raw nn.Transformer outputs with no sigmoid. nn.BCELoss
# rejects inputs outside [0, 1], so use BCEWithLogitsLoss, which applies the
# sigmoid internally in a numerically stable way. Apply torch.sigmoid to
# model outputs when probabilities are needed.
# NOTE(review): targets (ATAC values) must still lie in [0, 1]; this assumes
# binarised accessibility. Confirm.
loss_fn = torch.nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [None]:
def get_n_params(model):
    """Return the total number of scalar parameters in ``model``.

    Sums ``numel()`` over every tensor yielded by ``model.parameters()``,
    trainable or not. A 0-dim parameter counts as 1 and an empty model gives 0.
    """
    return sum(p.numel() for p in model.parameters())

# Show the model's total parameter count as this cell's output.
get_n_params(transformer)

In [None]:
def train_epoch(model, optimizer, Train_batch_index):
    """Train ``model`` for one pass over ``Train_batch_index``.

    For each batch of cell labels:
    1. Look up the neighbour rows in the global ``Train_Neighbor_df``.
    2. Map positions to labels with ``sample_idx_transform``.
    3. Build an RNA tensor (``src``) and an ATAC tensor (``tgt``) from the
       global ``Train_Data_Load``.
    4. Call ``model(src, tgt)`` and score the output against ``tgt`` with the
       global ``loss_fn``.
    5. Update the parameters with ``optimizer``.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train; it is switched to train mode.
    optimizer : torch.optim.Optimizer
        Optimiser over ``model``'s parameters.
    Train_batch_index : list of list
        Batches of ``Train_Neighbor_df`` index labels, e.g. from ``batch_index``.

    Returns
    -------
    float
        Sum of the per-batch losses divided by the number of batches,
        multiplied by the global ``BATCH_SIZE``.

    Raises
    ------
    ValueError
        If ``Train_batch_index`` is empty.
    """
    if len(Train_batch_index) == 0:
        raise ValueError("train_epoch: Train_batch_index contains no batches")

    model.train()
    # The gathered tensors inherit the dtype of the stored values. Cast them
    # to the model's parameter dtype so the forward pass does not hit a dtype
    # mismatch.
    param_dtype = next(model.parameters()).dtype
    losses = 0.0

    for batch in Train_batch_index:
        KNNdf = sample_idx_transform(Train_Neighbor_df.loc[batch], Train_Neighbor_df)
        # Use the real batch length so the reshape always matches this batch.
        src = get_expn_tensor(KNNdf, Train_Data_Load, 'RNA', len(batch))
        tgt = get_expn_tensor(KNNdf, Train_Data_Load, 'ATAC', len(batch))
        src = src.to(DEVICE, dtype=param_dtype)
        tgt = tgt.to(DEVICE, dtype=param_dtype)

        optimizer.zero_grad()
        # NOTE(review): tgt is given to the decoder unmasked and is also the
        # training target, so the model can learn to copy it. Confirm intent.
        output = model(src, tgt)
        loss = loss_fn(output, tgt)
        loss.backward()
        optimizer.step()

        losses += loss.item()

    # NOTE(review): loss_fn already averages over elements (default
    # reduction='mean'); scaling by BATCH_SIZE inflates the reported value.
    # Kept for compatibility; confirm intent.
    return losses / len(Train_batch_index) * BATCH_SIZE


def evaluate(model, Validation_batch_index):
    """Compute the validation loss summary without updating ``model``.

    This mirrors ``train_epoch`` on the validation set. It uses the global
    ``Validation_Neighbor_df``, ``Validation_Data_Load``, ``loss_fn``,
    ``DEVICE`` and ``BATCH_SIZE``, and runs under ``torch.no_grad()``.

    NOTE(review): ``Validation_Data_Load`` is not loaded in the visible cells.
    Load it (e.g. from a validation .h5mu) before calling.

    Parameters
    ----------
    model : torch.nn.Module
        Model to evaluate; it is switched to eval mode.
    Validation_batch_index : list of list
        Batches of ``Validation_Neighbor_df`` index labels.

    Returns
    -------
    float
        Sum of the per-batch losses divided by the number of batches,
        multiplied by the global ``BATCH_SIZE`` (same convention as
        ``train_epoch``).

    Raises
    ------
    ValueError
        If ``Validation_batch_index`` is empty.
    """
    if len(Validation_batch_index) == 0:
        raise ValueError("evaluate: Validation_batch_index contains no batches")

    model.eval()
    param_dtype = next(model.parameters()).dtype
    losses = 0.0

    # No autograd graph is needed for evaluation; this also saves memory.
    with torch.no_grad():
        for batch in Validation_batch_index:
            KNNdf = sample_idx_transform(Validation_Neighbor_df.loc[batch], Validation_Neighbor_df)
            src = get_expn_tensor(KNNdf, Validation_Data_Load, 'RNA', len(batch))
            tgt = get_expn_tensor(KNNdf, Validation_Data_Load, 'ATAC', len(batch))
            src = src.to(DEVICE, dtype=param_dtype)
            tgt = tgt.to(DEVICE, dtype=param_dtype)

            output = model(src, tgt)
            losses += loss_fn(output, tgt).item()

    return losses / len(Validation_batch_index) * BATCH_SIZE

In [None]:
from timeit import default_timer as timer

NUM_EPOCHS = 1

for epoch in range(1, NUM_EPOCHS + 1):
    # Re-shuffle the training cells into fresh batches every epoch; only the
    # batching and the training pass are timed.
    start_time = timer()
    Train_batch_index = batch_index(Train_Neighbor_df.index, BATCH_SIZE)
    train_loss = train_epoch(transformer, optimizer, Train_batch_index)
    end_time = timer()
    epoch_time = end_time - start_time

    # Validation is switched off. To enable it (requires Validation_Data_Load,
    # which the visible cells do not load):
    #   Validation_batch_index = batch_index(Validation_Neighbor_df.index, BATCH_SIZE)
    #   val_loss = evaluate(transformer, Validation_batch_index)
    # and add f", Val loss: {val_loss:.3f}" to the message below.

    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Epoch time = {epoch_time:.3f}s")

In [None]:
# Greedy, token-by-token decoder written for a text-translation setup
# (see `translate` below).
# NOTE(review): this appears incompatible with the RNA/ATAC model in this
# notebook; see the notes in the body before using it.
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode up to ``max_len`` tokens starting from ``start_symbol``.

    Each step appends the arg-max token of ``model.generator`` to ``ys`` and
    stops early when that token equals ``EOS_IDX``. Returns ``ys``, a column
    tensor of decoded token ids with the start token first.

    NOTE(review): ``ys`` holds integer token ids, but ``Seq2SeqTransformer``
    has no embedding layer and feeds its inputs straight into
    ``nn.Transformer``, which expects float tensors of width ``emb_size``.
    Also, ``generate_square_subsequent_mask`` and ``EOS_IDX`` are not defined
    in the visible cells.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Encode the source once; the decoder re-reads this memory every step.
    # NOTE(review): verify that model.encode accepts a mask argument.
    memory = model.encode(src, src_mask)
    # (1, 1) long tensor holding the start token; new tokens are appended on dim 0.
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        # Causal mask over the tokens decoded so far.
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        # NOTE(review): verify that model.decode accepts a tgt_mask argument.
        out = model.decode(ys, memory, tgt_mask)
        # Assuming seq-first decoder output, make it batch-first so out[:, -1]
        # is the newest position.
        out = out.transpose(0, 1)
        # NOTE(review): Seq2SeqTransformer defines no `generator` attribute.
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        # Stop once the end-of-sequence token is produced.
        if next_word == EOS_IDX:
            break
    return ys


# Translate a source sentence into the target language via greedy_decode.
# NOTE(review): appears to be a leftover from a text-translation setup; it does
# not operate on the RNA/ATAC data used elsewhere in this notebook.
def translate(model: torch.nn.Module, src_sentence: str):
    """Greedily translate ``src_sentence`` and return the decoded text.

    The decoded token ids are mapped back to tokens and joined with spaces.
    The literal ``"<bos>"`` and ``"<eos>"`` markers are removed from the
    resulting string.

    NOTE(review): ``text_transform``, ``SRC_LANGUAGE``, ``BOS_IDX``,
    ``vocab_transform`` and ``TGT_LANGUAGE`` are not defined in the visible
    cells, so this raises NameError as written.
    """
    model.eval()
    # Tokenise/numericalise the sentence into a (num_tokens, 1) column.
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    # All-False boolean mask: no source positions are masked.
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    # Allow up to 5 more tokens than the source length.
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")