In [None]:
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

from tqdm import tqdm, trange

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device  # echo the selected device as the cell output

In [None]:
# Uncomment the next line to force CPU execution (for example, when debugging):
# device = "cpu"

In [None]:
ontology_and_seq__fp = "../data/intermediary/drosophila_protein_ontology_and_seqs.csv"
df = pd.read_csv(ontology_and_seq__fp)

# Keep only rows whose qualifier is "enables" or "involved_in" and that have
# no missing values.
relevant_subset = df[df.qualifier.isin(["enables", "involved_in"])].dropna()

# GO names that occur more than once in the filtered annotations.
# TODO: the frequency threshold probably needs tuning.
go_name_counts = relevant_subset.go_name.value_counts()
interesting_go_names = go_name_counts[go_name_counts > 1].index.tolist()

# Select rows from the *filtered* subset. Filtering the raw frame here would
# re-admit rows with other qualifiers or missing values that merely share a
# GO name with a kept annotation.
df = relevant_subset[relevant_subset.go_name.isin(interesting_go_names)]

# Multi-hot matrix: one row per unique sequence, one column per GO name.
# Created as an explicit integer frame of zeros instead of relying on
# `fillna(0)` to downcast an all-NaN object frame.
one_row_per_gene = pd.DataFrame(
    0, index=df.seq.unique(), columns=interesting_go_names, dtype="int64"
)
for seq, go_name in zip(df.seq, df.go_name):
    one_row_per_gene.at[seq, go_name] = 1
one_row_per_gene = one_row_per_gene.reset_index().rename(columns={"index": "seq"})
one_row_per_gene.head()

In [None]:
df_original = pd.read_csv(ontology_and_seq__fp)

# Keep only rows whose qualifier is "enables" or "involved_in" and that have
# no missing values.
relevant_subset = df_original[df_original.qualifier.isin(["enables", "involved_in"])].dropna()

# GO names annotated more than 7 times in the filtered subset.
# TODO: the frequency threshold probably needs tuning.
go_name_counts = relevant_subset.go_name.value_counts()
interesting_go_names = go_name_counts[go_name_counts > 7].index.tolist()
relevant_subset = relevant_subset[relevant_subset.go_name.isin(interesting_go_names)]

# Multi-hot label matrix: one row per unique sequence, one column per GO name.
# Created as an explicit integer frame of zeros instead of relying on
# `fillna(0)` to downcast an all-NaN object frame, so the label columns are
# guaranteed to be numeric when they are later converted to tensors.
df = pd.DataFrame(
    0, index=relevant_subset.seq.unique(), columns=interesting_go_names, dtype="int64"
)
for seq, go_name in zip(relevant_subset.seq, relevant_subset.go_name):
    df.at[seq, go_name] = 1

# Random ~75% flag per row (one `random.random()` draw per row, in row order).
df["training"] = [random.random() < 0.75 for _ in range(len(df))]
df = df.reset_index().rename(columns={"index": "seqs"})

# Character vocabulary of all sequences, plus a padding token.
vocab = set()
for seq in df.seqs:
    vocab.update(seq)
vocab.add("<pad>")
# Sort before enumerating: set iteration order is not stable across kernel
# restarts, which would silently change every token index between runs.
to_ix = {char: i for i, char in enumerate(sorted(vocab))}

In [None]:
# Sizes derived from the prepared data.
N_EXAMPLES = len(df)                                  # number of rows in `df`
N_ONTOLOGICAL_CATEGORIES = len(interesting_go_names)  # number of label columns
VOCAB_SIZE = len(vocab)                               # number of entries in `vocab`

# Model and data-loading hyperparameters.
EMBEDDING_DIM = 8
HIDDEN_DIM = 16
BATCH_SIZE = 1

In [None]:
class SeqOntologyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing an index-encoded sequence with its label row.

    All constructor arguments are optional. Each one left as ``None`` falls
    back to the notebook-level global shown, so ``SeqOntologyDataset()`` keeps
    working as before.

    Args:
        frame: DataFrame with a ``seqs`` column and one column per label
            (default: ``df``).
        label_columns: names of the label columns
            (default: ``interesting_go_names``).
        token_to_ix: mapping from sequence character to integer index
            (default: ``to_ix``).
    """

    def __init__(self, frame=None, label_columns=None, token_to_ix=None):
        frame = df if frame is None else frame
        label_columns = interesting_go_names if label_columns is None else label_columns
        self.token_to_ix = to_ix if token_to_ix is None else token_to_ix

        self.X = list(frame.seqs)
        self.y = frame.loc[:, label_columns]
        # Convert the labels to a float array once. This avoids a pandas row
        # lookup on every __getitem__ call and also copes with label columns
        # stored with object dtype.
        self._labels = self.y.to_numpy(dtype=np.float64)

    def __len__(self):
        """Return the number of sequences."""
        return len(self.X)

    def __getitem__(self, i):
        """Return ``(seq_tensor, label)`` for example ``i``.

        ``seq_tensor`` is a 1-D long tensor of token indices; ``label`` is a
        double tensor of shape ``(1, n_labels)``.

        Raises:
            KeyError: if the sequence contains a character missing from the
                token mapping.
        """
        token_ids = [self.token_to_ix[residue] for residue in self.X[i]]
        seq_tensor = torch.tensor(token_ids, dtype=torch.long)
        label = torch.tensor(self._labels[i][np.newaxis, :], dtype=torch.double)
        return seq_tensor, label

# Randomly hold out a quarter of the examples (rounded down) as the test set;
# the remainder is used for training.
ds_train, ds_test = torch.utils.data.random_split(
    SeqOntologyDataset(),
    lengths=[N_EXAMPLES - N_EXAMPLES // 4, N_EXAMPLES // 4],
)

# Both loaders share the same settings (batch size and shuffling).
dl_args = dict(batch_size=BATCH_SIZE, shuffle=True)
dl = {
    phase: DataLoader(subset, **dl_args)
    for phase, subset in (("train", ds_train), ("test", ds_test))
}

In [None]:
# Inspection cell: display the vocabulary size.
VOCAB_SIZE

In [None]:
# Take one (sequence, label) batch from the training loader; only the
# sequence tensor is needed for this shape check.
seq, _ = next(iter(dl["train"]))

# Stand-alone embedding layer for inspection, with the "<pad>" index
# registered as the padding index.
embed = nn.Embedding(
    num_embeddings=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    padding_idx=to_ix["<pad>"],
)

# Cell output: tensor shape before and after embedding.
(seq.shape, embed(seq).shape)

In [None]:
# Inspection cell: display the embedded batch itself (values, not just shape).
embed(seq)

In [None]:
# Scratch cell: a random-normal tensor with dimensions of size 2, 3 and 4.
torch.randn(2, 3, 4)

In [None]:
# Shape walkthrough for nn.LSTM. Positional constructor arguments:
#   input_size=10  - number of expected features in the input x
#   hidden_size=20 - number of features in the hidden state h
#   num_layers=2   - number of recurrent layers (2 = two stacked LSTMs)
rnn = nn.LSTM(10, 20, 2)

# Without batch_first=True, nn.LSTM expects input as (seq_len, batch, input_size).
# Named `inputs` rather than `input` so the builtin `input` is not shadowed.
inputs = torch.randn(5, 3, 10)  # seq_len=5 x batch=3 x input_size=10
h0 = torch.randn(2, 3, 20)      # num_layers x batch x hidden_size
c0 = torch.randn(2, 3, 20)      # same shape as h0
output, (hn, cn) = rnn(inputs, (h0, c0))
cn.shape

In [None]:
"""https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/sequence_models_tutorial.ipynb#scrollTo=CLqVNguZ1gOX"""

class OntologyLSTM(nn.Module):
    """Single-layer LSTM mapping an index-encoded sequence to one logit per category.

    All constructor arguments are optional. Each one left as ``None`` falls
    back to the notebook-level name shown, so ``OntologyLSTM()`` keeps working.

    Args:
        vocab_size: number of embedding rows (default: ``VOCAB_SIZE``).
        embedding_dim: embedding width (default: ``EMBEDDING_DIM``).
        hidden_dim: LSTM hidden size (default: ``HIDDEN_DIM``).
        n_categories: number of output logits
            (default: ``N_ONTOLOGICAL_CATEGORIES``).
        target_device: device the module is moved to (default: ``device``).
    """

    def __init__(self, vocab_size=None, embedding_dim=None, hidden_dim=None,
                 n_categories=None, target_device=None):
        super(OntologyLSTM, self).__init__()
        vocab_size = VOCAB_SIZE if vocab_size is None else vocab_size
        embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        n_categories = N_ONTOLOGICAL_CATEGORIES if n_categories is None else n_categories
        target_device = device if target_device is None else target_device

        self.seq_embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True makes the LSTM read (batch, seq_len, features),
        # which is the layout the DataLoader produces. With the time-major
        # default, a (1, seq_len) batch was interpreted as seq_len independent
        # one-step sequences, so only the last residue reached the classifier.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, n_categories)
        self.sigmoid = nn.Sigmoid()
        # Final LSTM states of the most recent forward pass (for inspection).
        self.h0 = None
        self.c0 = None
        self.to(target_device)

    def reset(self, seq_len=None):
        """Clear gradients from earlier examples.

        ``seq_len`` is accepted for backward compatibility and is not used:
        every forward pass starts from a zero hidden and cell state.
        """
        self.zero_grad()
        self.h0 = None
        self.c0 = None

    def forward(self, seq):
        """Return ``(logits, likelihoods)`` for a batch of token indices.

        Args:
            seq: integer tensor of shape ``(batch, seq_len)``; a 1-D tensor is
                treated as a single sequence.

        Returns:
            Two tensors of shape ``(batch, 1, n_categories)``: the raw logits
            and their element-wise sigmoid.

        Note:
            Sequences are not packed, so padded positions in a batch of
            unequal-length sequences would influence the final state.
        """
        if seq.dim() == 1:
            seq = seq.unsqueeze(0)
        seq_embedded = self.seq_embedding(seq)  # (batch, seq_len, embedding_dim)
        # No initial state is passed, so nn.LSTM starts from zeros.
        _, (self.h0, self.c0) = self.lstm(seq_embedded)
        # Classify from the final cell state: (batch, hidden) -> (batch, 1, hidden).
        logits = self.fc(self.c0[-1].unsqueeze(1))
        likelihoods = self.sigmoid(logits)
        return logits, likelihoods

In [None]:
"abc"[-2:-1]

In [None]:
clf = OntologyLSTM()
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(clf.parameters())
N_EPOCHS = 100
for epoch in trange(N_EPOCHS, unit="epoch"):
    for phase in ["train", "test"]:
        is_training = phase == "train"
        clf.train(is_training)  # train mode for "train", eval mode otherwise
        running_loss = 0.0
        for seq, ontology in dl[phase]:
            seq, ontology = seq.to(device), ontology.to(device)
            seq_len = seq.shape[1]
            clf.reset(seq_len)  # also zeroes the gradients of clf's parameters
            with torch.set_grad_enabled(is_training):
                ontology_logits, ontology_likelihoods = clf(seq)
                # The dataset yields double targets; cast them to the logits'
                # dtype so the loss does not see mixed float/double inputs.
                loss = criterion(ontology_logits, ontology.to(ontology_logits.dtype))
                if is_training:
                    loss.backward()
                    optimizer.step()
            # Accumulate inside the batch loop so every batch contributes,
            # not just the last one.
            running_loss += loss.item()
        # Report once per phase (previously only the "test" phase was printed).
        # tqdm.write keeps the progress bar intact; max() guards an empty loader.
        n_batches = len(dl[phase])
        tqdm.write(f"{phase} loss: {running_loss / max(n_batches, 1):.2f}")

In [None]:
# `__file__` is usually not defined inside a notebook kernel, so a bare
# reference raises NameError. Look it up defensively instead; the expression
# evaluates to None when the name is missing.
globals().get("__file__")