In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
torch.manual_seed(42)

import numpy as np
import pandas as pd

from utilities import load_coordinates, load_dtm, load_mini_batches
from utilities import report_macro, report_class, compute_roc, compute_prc, plot_curves
%matplotlib inline

# Specify hyperparameters

In [2]:
# Training hyperparameters
n_epochs, batch_size = 2, 64    # passes over the training mini-batches; documents per mini-batch
lr, weight_decay = 0.01, 1e-12  # Adam learning rate and weight-decay (L2) coefficient

In [3]:
prefix = "lstm"  # Prefix for output file names: the train/dev plots and the saved model (models/<prefix>.pt)

# Load the data

## Brain activation coordinates

In [4]:
# Documents x brain structures activation table; the column count is the number of structures
act_bin = load_coordinates()
n_structs = act_bin.shape[1]
print("\n".join("{:12s}{}".format(label, count)
                for label, count in (("Documents", act_bin.shape[0]), ("Structures", n_structs))))

Documents   18155
Structures  114


## Term embeddings

In [5]:
# GloVe vector space model: one row per term (the index), one column per embedding dimension
vsm = pd.read_csv(
    "data/text/glove_gen_n100_win15_min5_iter500_190428.txt",
    sep=" ",
    index_col=0,
    header=0,
)
n_emb = vsm.shape[1]
print("\n".join("{:21s}{}".format(label, value)
                for label, value in (("Embedding Dimension", n_emb), ("Terms", vsm.shape[0]))))

Embedding Dimension  100
Terms                350543


## Document-term matrix

In [6]:
# Keep only the lexicon terms that have an embedding in the vector space model
dtm_bin = load_dtm().pipe(lambda dtm: dtm[dtm.columns.intersection(vsm.index)])
n_terms = dtm_bin.shape[1]
print("\n".join("{:12s}{}".format(label, count)
                for label, count in (("Documents", dtm_bin.shape[0]), ("Terms", n_terms))))

Documents   18155
Terms       1542


## Text features 

#### 1. Concatenate embeddings for terms in the lexicon

In [7]:
# Look up each lexicon term's embedding (in dtm_bin column order) and lay them end to end in one row
emb = vsm.loc[dtm_bin.columns].to_numpy().reshape(1, n_terms * n_emb)
emb.shape

(1, 154200)

#### 2. Create a mask for embeddings of terms that occurred in documents

The mask is 1 for terms that occurred in a document and 0 for terms that did not, with each term's value repeated n_emb times so the mask lines up with the concatenated embeddings.

In [8]:
# Repeat every term column n_emb times so the mask lines up with the concatenated embeddings
dtm_mask = dtm_bin.to_numpy().repeat(n_emb, axis=1)
dtm_mask.shape

(18155, 154200)

#### 3. Apply the mask to term embeddings

In [9]:
# Broadcast the single embedding row across documents; masked-out terms become zeros
dtm_emb = np.multiply(dtm_mask, emb)
dtm_emb.shape

(18155, 154200)

In [None]:
# Wrap the embedding features in a DataFrame keyed by document id so rows can be selected with .loc
dtm_emb = pd.DataFrame(data=dtm_emb, index=dtm_bin.index)

# Split the data

## Training and dev sets

In [10]:
# Document ids (PMIDs) for each split, one per line in data/splits/<split>.txt
splits = {}
for split in ["train", "dev"]:
    # `with` closes the file; blank lines (e.g. a trailing empty line) are skipped instead of crashing int("")
    with open("data/splits/{}.txt".format(split), "r") as split_file:
        splits[split] = [int(line.strip()) for line in split_file if line.strip()]

## Randomized mini-batches

In [11]:
# Training features must have the width Net.fc1 expects (n_terms); the train/dev evaluation
# cells also feed dtm_bin, so train on the binary document-term matrix. Feeding dtm_emb here
# (n_terms * n_emb columns) fails on the first forward pass with a shape mismatch.
# NOTE(review): if the concatenated embeddings are the intended features, switch Net.fc1 to
# n_terms * n_emb inputs and the evaluation cells to dtm_emb together with this cell.
# Features and labels are transposed so that each column is one training document.
X = dtm_bin.loc[splits["train"]].transpose().values
Y = act_bin.loc[splits["train"]].transpose().values
mini_batches = load_mini_batches(X, Y, mini_batch_size=batch_size, seed=42)

# Specify the LSTM

In [None]:
# Demo: step a single-layer LSTM over a short random sequence of document vectors.
# Input is embedding dimension x N terms, output (hidden state) is N brain structures.
# A hidden size of n_emb * n_terms would need ~8 * (n_emb * n_terms)**2 weights (~1.9e11), far beyond memory.
context_length = 3
n_input = n_emb * n_terms
lstm = nn.LSTM(n_input, n_structs)
inputs = torch.randn(context_length, 1, n_input)  # (seq_len, batch=1, input_size): a random sequence of length context_length
hidden = (torch.zeros(1, 1, n_structs), torch.zeros(1, 1, n_structs))  # Clean out hidden state: (h_0, c_0), each (layers, batch, hidden)
for step_input in inputs:
    # Step through the sequence one element at a time;
    # after each step, hidden contains the hidden state.
    out, hidden = lstm(step_input.view(1, 1, -1), hidden)
print(out)
print(hidden)

# Specify the classifier

In [None]:
class Net(nn.Module):
    """Feed-forward multi-label classifier.

    Layers: ``fc1`` (n_terms -> 100), then ``fc2`` .. ``fc10`` (100 -> 100), each
    followed by ReLU, and finally ``fc11`` (100 -> n_structs) followed by a sigmoid,
    so each of the n_structs outputs lies in (0, 1). ``n_terms`` and ``n_structs``
    are read from the notebook globals when the model is constructed.
    """

    def __init__(self):
        super().__init__()
        # Widths from input to output; each consecutive pair defines one layer fc1 .. fc11.
        widths = [n_terms] + [100] * 10 + [n_structs]
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            # setattr registers the layer as a submodule; creating them in order keeps the
            # same parameter order and state_dict keys (fc1.weight, ..., fc11.bias).
            setattr(self, "fc{}".format(idx), nn.Linear(fan_in, fan_out))

    def forward(self, x):
        """Map a tensor whose last dimension is n_terms to n_structs sigmoid outputs."""
        for idx in range(1, 11):
            x = F.relu(getattr(self, "fc{}".format(idx))(x))
        return torch.sigmoid(self.fc11(x))

In [None]:
# Model, Adam optimizer, and mean binary cross-entropy on the sigmoid outputs
net = Net()
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.BCELoss()

# Train the model

In [None]:
for epoch in range(n_epochs):  # Loop over the dataset multiple times
    net.train()
    running_loss = 0.0
    n_batches = 0
    for batch_X, batch_Y in mini_batches:

        # Batches hold one column per document; transpose to (documents, features)
        inputs = torch.from_numpy(batch_X.T).float()
        labels = torch.from_numpy(batch_Y.T).float()

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate inside the batch loop so the epoch loss covers every batch
        running_loss += loss.item()
        n_batches += 1

    # Print the mean per-batch loss for this epoch (NaN if there were no batches)
    mean_loss = running_loss / n_batches if n_batches else float("nan")
    print("Epoch {:3d} \t Loss {:6.6f}".format(epoch + 1, mean_loss))

# Evaluate the model

In [None]:
def report_metrics(data_set):
    """Print macro-averaged and per-structure classification metrics for ``net``.

    Parameters
    ----------
    data_set : sequence
        Only ``data_set[0]`` is used: an ``(inputs, labels)`` pair of arrays with one
        column per document (both are transposed before use). In this notebook it comes
        from ``load_mini_batches`` with the batch size set to the whole split, presumably
        a single batch covering every document (confirm against ``utilities``).

    Predictions are the model outputs thresholded at 0.5. Reads the globals ``net``,
    ``n_structs`` and ``act_bin`` (for the structure names used as section headers).
    """
    rule = "-" * 50
    with torch.no_grad():
        features, targets = data_set[0]
        inputs = torch.from_numpy(features.T).float()
        labels = torch.from_numpy(targets.T).float()
        predictions = (net(inputs) > 0.5).float()

        print(rule + "\nMACRO-AVERAGED TOTAL\n" + rule)
        report_macro(labels, predictions)
        print("\n" + rule + "\n\n")

        for i in range(n_structs):
            struct_name = act_bin.columns[i].title().replace("_", " ")
            print(rule + "\n" + struct_name + "\n" + rule)
            report_class(labels[:, i], predictions[:, i])
            print("")

In [None]:
def report_curves(data_set, name):
    """Compute and plot ROC and precision-recall curves for ``net``.

    Parameters
    ----------
    data_set : sequence
        Only ``data_set[0]`` is used: an ``(inputs, labels)`` pair of arrays with one
        column per document (both are transposed before use).
    name : str
        Stem passed to ``plot_curves`` as ``<name>_roc`` and ``<name>_prc``
        (presumably the output file names; see the ``prefix`` cell).
    """
    with torch.no_grad():
        features, targets = data_set[0]
        inputs = torch.from_numpy(features.T).float()
        labels = torch.from_numpy(targets.T).float()
        pred_probs = net(inputs).float()

        # Raw sigmoid outputs (not thresholded) drive both curves
        fpr, tpr = compute_roc(labels, pred_probs)
        prec, rec = compute_prc(labels, pred_probs)

        plot_curves("{}_roc".format(name), fpr, tpr,
                    diag=True, alpha=0.25,
                    xlab="1 - Specificity", ylab="Sensitivity")
        plot_curves("{}_prc".format(name), rec, prec,
                    diag=False, alpha=0.5,
                    xlab="Recall", ylab="Precision")

## Training set performance

In [None]:
# Whole training split as one batch: features and labels transposed to one column per document
X_train, Y_train = (frame.loc[splits["train"]].transpose().values for frame in (dtm_bin, act_bin))
train_set = load_mini_batches(X_train, Y_train, mini_batch_size=len(splits["train"]), seed=42)
report_curves(train_set, "{}_train".format(prefix))

In [None]:
report_metrics(data_set=train_set)  # Macro-averaged and per-structure metrics on the training split

## Dev set performance

In [None]:
# Whole dev split as one batch: features and labels transposed to one column per document
X_dev, Y_dev = (frame.loc[splits["dev"]].transpose().values for frame in (dtm_bin, act_bin))
dev_set = load_mini_batches(X_dev, Y_dev, mini_batch_size=len(splits["dev"]), seed=42)
report_curves(dev_set, "{}_dev".format(prefix))

In [None]:
report_metrics(data_set=dev_set)  # Macro-averaged and per-structure metrics on the dev split

# Export the trained model

In [None]:
# Save the trained instance's weights. `Net.state_dict()` would be the unbound method on the
# class and raises TypeError (no module instance); the trained model is `net`.
os.makedirs("models", exist_ok=True)  # torch.save does not create missing directories
torch.save(net.state_dict(), "models/{}.pt".format(prefix))