In [None]:
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# device = "cpu"

In [None]:
# Load the ontology/sequence table and pivot it into one row per sequence
# with one 0/1 indicator column per retained GO term.
ontology_and_seq__fp = "../data/intermediary/drosophila_protein_ontology_and_seqs.csv"
df = pd.read_csv(ontology_and_seq__fp)

# Complete rows whose qualifier is "enables" or "involved_in".
qualifier_mask = df.qualifier.isin(["enables", "involved_in"])
relevant_subset = df[qualifier_mask].dropna()

# GO terms that occur more than once in that subset.
# <- probably need to change the filter step !!
go_name_counts = relevant_subset.go_name.value_counts()
interesting_go_names = list(go_name_counts[go_name_counts > 1].index)

# Note: the go_name filter is applied to the full table, not to relevant_subset.
df = df[df.go_name.isin(interesting_go_names)]

# Start from an all-zero table, then flag every (sequence, GO term) pair seen.
one_row_per_gene = pd.DataFrame(index=df.seq.unique(), columns=interesting_go_names).fillna(0)
for seq, go_name in zip(df.seq, df.go_name):
    one_row_per_gene.loc[seq, go_name] = 1

# Move the sequences out of the index into a regular "seq" column.
one_row_per_gene = one_row_per_gene.reset_index().rename(columns={"index": "seq"})
one_row_per_gene.head()

In [None]:
# Build the multi-label table used for modelling: one row per unique protein
# sequence, one 0/1 column per retained GO term, and a random "training" flag.
df_original = pd.read_csv(ontology_and_seq__fp)
relevant_subset = df_original[df_original.qualifier.isin(["enables", "involved_in"])].dropna()

# Keep GO terms that annotate more than 7 rows of the filtered subset.
interesting_go_names = [
    name for (name, freq)
    in relevant_subset.go_name.value_counts().to_dict().items()
    if 7 < freq  # !! probably need to change the filter step !!
]
relevant_subset = relevant_subset[relevant_subset.go_name.isin(interesting_go_names)]

# Allocate the label matrix as integer zeros directly. Building an empty
# (object-dtype) frame and calling .fillna(0) relies on pandas down-casting
# object columns, which is deprecated and leaves the labels as `object`.
df = pd.DataFrame(0, index=relevant_subset.seq.unique(), columns=interesting_go_names)
for seq, go_name in zip(relevant_subset.seq, relevant_subset.go_name):
    df.at[seq, go_name] = 1

# Flag roughly 75% of the rows as training rows (one random draw per row).
df["training"] = [random.random() < 0.75 for _ in range(len(df))]
df = df.reset_index().rename(columns={"index": "seqs"})

# Vocabulary: every character that appears in any sequence, plus a padding token.
vocab = set()
for seq in df.seqs:
    vocab.update(seq)
vocab.add("<pad>")

# Assign indices in a deterministic order. Enumerating the set directly gives
# an ordering that can change between interpreter runs (string hash
# randomisation), which silently changes the meaning of every embedding row.
# "<pad>" is pinned to index 0; the remaining tokens follow in sorted order.
to_ix = {token: i for i, token in enumerate(["<pad>"] + sorted(vocab - {"<pad>"}))}

In [None]:
# Problem sizes derived from the prepared table and vocabulary.
N_EXAMPLES = df.shape[0]
N_ONTOLOGICAL_CATEGORIES = len(interesting_go_names)
VOCAB_SIZE = len(vocab)

# Width of the residue embedding and of the LSTM hidden state.
EMBEDDING_DIM = HIDDEN_DIM = 6

In [None]:
# Scratch check: display one sixth of the number of examples.
N_EXAMPLES/6

In [None]:
class SeqOntologyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each sequence with its GO-term label vector.

    Parameters
    ----------
    frame : pandas.DataFrame, optional
        Table with a ``seqs`` column and one 0/1 column per label.
        Defaults to the notebook-level ``df``.
    label_columns : list of str, optional
        Names of the label columns, in output order.
        Defaults to the notebook-level ``interesting_go_names``.
    vocab_index : dict, optional
        Mapping from sequence character to integer index.
        Defaults to the notebook-level ``to_ix``.

    Each item is ``(seq_tensor, label)`` where ``seq_tensor`` is a 1-D
    ``torch.long`` tensor of token indices and ``label`` is a 1-D
    ``torch.float32`` tensor (the dtype ``BCEWithLogitsLoss`` expects for
    its targets).
    """

    def __init__(self, frame=None, label_columns=None, vocab_index=None):
        # Fall back to the notebook globals so `SeqOntologyDataset()` keeps working.
        frame = df if frame is None else frame
        label_columns = interesting_go_names if label_columns is None else label_columns
        self.to_ix = to_ix if vocab_index is None else vocab_index

        self.X = list(frame.seqs)
        # Convert the labels once, up front: per-item DataFrame.iloc lookups are
        # slow, and integer/object rows cannot be fed to a float loss (object
        # arrays cannot even be collated by the default DataLoader collate).
        label_matrix = frame.loc[:, label_columns].to_numpy(dtype="float32")
        self.y = torch.tensor(label_matrix)

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.X)

    def __getitem__(self, i):
        """Return the encoded sequence and label vector at position ``i``.

        Raises ``KeyError`` if the sequence contains a character missing
        from the vocabulary mapping.
        """
        seq = self.X[i]
        # Explicit dtype: an empty sequence would otherwise yield a float tensor,
        # which nn.Embedding rejects.
        seq_tensor = torch.tensor([self.to_ix[residue] for residue in seq], dtype=torch.long)
        return seq_tensor, self.y[i]

# Hold out a quarter of the examples (integer division) for testing.
# The split uses its own seeded generator so that the same examples land in
# the same partition on every Restart & Run All.
SPLIT_SEED = 0
ds_train, ds_test = torch.utils.data.random_split(
    SeqOntologyDataset(),
    lengths=[N_EXAMPLES-(N_EXAMPLES//4),N_EXAMPLES//4],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# batch_size=1: sequences have different lengths and are not padded here, and
# the model reshapes its input to a batch of exactly one sequence.
dl_args = {"batch_size": 1, "shuffle": True}
dl = {"train": torch.utils.data.DataLoader(ds_train, **dl_args),
      # Evaluation order does not need to be random; keep it fixed.
      "test": torch.utils.data.DataLoader(ds_test, **{**dl_args, "shuffle": False})}

In [None]:
"""https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/sequence_models_tutorial.ipynb#scrollTo=CLqVNguZ1gOX"""

class OntologyLSTM(nn.Module):
    """Embedding -> single-layer LSTM -> linear head producing one logit per GO term.

    Parameters
    ----------
    vocab_size, embedding_dim, hidden_dim, n_categories : int, optional
        Layer sizes. Each defaults to the corresponding notebook-level constant
        (``VOCAB_SIZE``, ``EMBEDDING_DIM``, ``HIDDEN_DIM``,
        ``N_ONTOLOGICAL_CATEGORIES``).

    The model processes one sequence at a time (batch size 1). Call
    :meth:`reset` before each new sequence so the LSTM state from the previous
    example is not carried over.
    """

    def __init__(self, vocab_size=None, embedding_dim=None, hidden_dim=None, n_categories=None):
        super().__init__()
        # Fall back to the notebook globals so `OntologyLSTM()` keeps working.
        vocab_size = VOCAB_SIZE if vocab_size is None else vocab_size
        embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        n_categories = N_ONTOLOGICAL_CATEGORIES if n_categories is None else n_categories

        self.hidden_dim = hidden_dim
        self.seq_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.linear = nn.Linear(hidden_dim, n_categories)
        # Created lazily so it lands on whatever device/dtype the weights use.
        self.hidden = None

    def _initial_hidden(self):
        """Return a fresh all-zero (h0, c0) pair matching the weights' device and dtype."""
        reference = self.linear.weight
        return (reference.new_zeros(1, 1, self.hidden_dim),
                reference.new_zeros(1, 1, self.hidden_dim))

    def reset(self):
        """clear gradients and LSTM state from earlier examples"""
        self.zero_grad()
        self.hidden = self._initial_hidden()

    def forward(self, X):
        """Return logits of shape ``(1, n_categories)`` for one sequence.

        ``X`` holds the token indices of a single sequence with the sequence
        dimension first (``len(X)`` is taken as the sequence length).
        """
        # Tolerate a missing reset() call instead of failing with AttributeError.
        if self.hidden is None:
            self.hidden = self._initial_hidden()
        seq_len = len(X)
        embedded = self.seq_embedding(X)
        outputs, self.hidden = self.lstm(embedded.view(seq_len, 1, -1), self.hidden)
        # Only the final timestep is used, so project just that one.
        logits = self.linear(outputs[-1])
        return logits

In [None]:
# Sanity check: pull the first batch from the training loader.
# `a` is the batch of encoded sequences, `b` the matching batch of labels.
a, b = next(iter(dl["train"]))
a.shape

In [None]:
# Shape of the label batch fetched in the previous cell.
b.shape

In [None]:
# Single pass over the training set, one sequence per optimisation step.
clf = OntologyLSTM()
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(clf.parameters())

for seq, ontology in dl["train"]:
    # reset() zeroes the accumulated gradients and the LSTM state.
    clf.reset()
    # The loader yields (batch, seq_len); the model wants the sequence axis first.
    ontology_pred = clf(seq.T)
    # BCEWithLogitsLoss requires floating-point targets.
    loss = criterion(ontology_pred, ontology.float())
    loss.backward()
    optimizer.step()