In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"


# Approach 1: binary classification using an RNN on raw, variable-length data
We use the annual data of each subject to predict whether his/her scores on `seven`, `bwcount` and `recall` goes up or down. 

The number of waves varies from subject to subject. So we will feed each subject's data into an RNN, which is a neural network that can handle variable-length sequences. The expected output is the final year's cognitive score minus the first year's cognitive score. Since there are three cognitive scores, we will train models separately on the three scores. 


## Prepare the Y's 

In [2]:
import pickle
import numpy

In [3]:
[X, Y] = pickle.load(open('XY.pickle', 'br'))
 

In [4]:
Y[0]
Y[0][-1, 1]
Y[0][0, 1]

array([[0.3 , 0.8 , 1.  ],
       [0.6 , 0.8 , 1.  ],
       [0.35, 0.4 , 1.  ],
       [0.2 , 0.2 , 1.  ]])

0.2

0.8

In [5]:
def prepare_one_y(Y, target_index: int):
    """Separate the target columns and generate the target for an approach 

    Y: 1D list of 2D numpy arrays 
    """
    y = [subjectY[-1, target_index] - subjectY[0, target_index] for subjectY in Y]
    y = numpy.array(y)
    y = numpy.heaviside(y, 0)
    return y 

y = prepare_one_y(Y, 0)

In [6]:
total_subjects = len(Y)
metric_names = ["recall", "seven", "bwcount"]
for target_index in [0,1,2]:
    y = prepare_one_y(Y, target_index)
    number_of_increasing_subjects = numpy.count_nonzero(y)
    percentage_of_increasing_subjects = number_of_increasing_subjects/total_subjects*100
    metric_name = metric_names[target_index]
    print (f'{number_of_increasing_subjects} out of {total_subjects}, \
    or {percentage_of_increasing_subjects:.2f}% show end-to-end improvement \
    on metric {metric_name}')

9216 out of 26178,     or 35.21% show end-to-end improvement     on metric recall
6167 out of 26178,     or 23.56% show end-to-end improvement     on metric seven
1206 out of 26178,     or 4.61% show end-to-end improvement     on metric bwcount


## build the network

In [7]:
X[0].shape

(4, 23)

In [8]:
import collections
collections.Counter(map(len, Y))

Counter({4: 3600, 5: 5115, 6: 3624, 7: 5562, 3: 2348, 2: 5929})

In [21]:
import torch.nn as nn
import torch

class Net(nn.Module):
  # cannot be batched as input sequence lengths are different 
  def __init__(self, input_size, hidden_size, num_layers):
      super().__init__()
      self.rnn = torch.nn.RNN(input_size, hidden_size, num_layers)
    #   self.fc1 = torch.nn.Linear(hidden_size, 3)
      self.fc1 = torch.nn.Linear(hidden_size, 1)
      self.fc2 = torch.nn.Linear(3, 1)

  def forward(self, x):
      output, hn = self.rnn(x)
      x = self.fc1(hn[0])
    #   x = self.fc2(x)
      x = torch.sigmoid(x)
      return x

net = Net(23, 5, 2)

In [31]:
# test 1 
net = Net(23, 5, 2)
x = torch.randn(7, 23)
net(x)
# x

tensor([0.4182], grad_fn=<SigmoidBackward0>)

In [23]:
# test 2 
# we run into a weird error see the cell below. So we have to do this workaround. 
# CREATE a new cell
# net = Net(23, 5, 2)
# net = net.float()
net(torch.from_numpy(X[2000]).float())

tensor([0.5056], grad_fn=<SigmoidBackward0>)

In [18]:
# why this didn't work? 
net = Net(23, 5, 2) # init a fresh one
net(torch.from_numpy(X[0]).float())

tensor([0.4664], grad_fn=<SigmoidBackward0>)

In [13]:
# Convert X to cope with the error 
# again, we cannot do batch due to a weird issue 
new_X = [torch.from_numpy(x).float() for x in X]

## The trainer

In [14]:
def train(net, X, Y, target_index, learning_rate, momentum):
    """

    net: a torch.nn instance
    target_index: integer 0, 1, 2 (recall, seven, bwcount)
    """
    import torch.optim as optim

    criterion = nn.BCELoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)

    optimizer.zero_grad()    # zero the parameter gradients

    y = prepare_one_y(Y, target_index)
    y = torch.tensor(y).float() # workaround for the weird problem 

    batch_size = 2000 
    print_batch = 2000

    loss_log = [] 

    for epoch in range(5):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, inputs in enumerate(X, 1):
            labels = y[i-1]

            optimizer.zero_grad()  # should not be here.

            # forward + backward + optimize
            prediction = net(inputs)
            prediction = prediction[0]

            loss = criterion(prediction, labels)
            # loss /= batch_size # normalize the loss, may not needed 

            loss.backward()        

            # print statistics
            running_loss += loss.item()
            if i % batch_size == 0:    #every mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / batch_size:.3f}')
                loss_log.append(running_loss / batch_size)
                optimizer.step()
                running_loss = 0.0
                
            # running_loss += loss.item()
            # if i % print_batch == 0: 
            #     print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / print_batch:.3f}')
            #     running_loss = 0.0

    return loss_log

### Training log for `bw_count` (target_index =2)
Slow loss convergence with high initial loss (0.9): `optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)`
Slow loss convergence with low inital loss (0.7): `optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)`

In [15]:
net = Net(23, 5, 2)
loss_log_005_09 = train(net, new_X, Y, 2, 0.05, 0.9)

[1,  2001] loss: 0.855
[1,  4001] loss: 0.787
[1,  6001] loss: 0.684
[1,  8001] loss: 0.550
[1, 10001] loss: 0.423
[1, 12001] loss: 0.335


KeyboardInterrupt: 

### Training log for `seven` (target_index=1)


In [16]:
net = Net(23, 5, 2)
loss_log_005_09 = train(net, new_X, Y, 1, 0.05, 0.9)

[1,  2001] loss: 0.673
[1,  4001] loss: 0.653
[1,  6001] loss: 0.612
[1,  8001] loss: 0.575
[1, 10001] loss: 0.563
[1, 12001] loss: 0.497
[1, 14001] loss: 0.472
[1, 16001] loss: 0.608
[1, 18001] loss: 0.647
[1, 20001] loss: 0.745
[1, 22001] loss: 0.714
[1, 24001] loss: 0.736
[1, 26001] loss: 0.786
[2,  2001] loss: 0.548
[2,  4001] loss: 0.529
[2,  6001] loss: 0.512
[2,  8001] loss: 0.519
[2, 10001] loss: 0.549
[2, 12001] loss: 0.492
[2, 14001] loss: 0.470
[2, 16001] loss: 0.591
[2, 18001] loss: 0.599
[2, 20001] loss: 0.649
[2, 22001] loss: 0.603
[2, 24001] loss: 0.603
[2, 26001] loss: 0.629
[3,  2001] loss: 0.519
[3,  4001] loss: 0.541
[3,  6001] loss: 0.561
[3,  8001] loss: 0.577
[3, 10001] loss: 0.586
[3, 12001] loss: 0.529
[3, 14001] loss: 0.494
[3, 16001] loss: 0.570
[3, 18001] loss: 0.579
[3, 20001] loss: 0.638
[3, 22001] loss: 0.611
[3, 24001] loss: 0.626
[3, 26001] loss: 0.667
[4,  2001] loss: 0.513
[4,  4001] loss: 0.508
[4,  6001] loss: 0.507
[4,  8001] loss: 0.516
[4, 10001] 

### Training log for `recall` (target_index=0)


In [17]:
net = Net(23, 5, 2)
loss_log_005_09 = train(net, new_X, Y, 0, 0.05, 0.9)

[1,  2001] loss: 0.683
[1,  4001] loss: 0.672
[1,  6001] loss: 0.640
[1,  8001] loss: 0.609
[1, 10001] loss: 0.634
[1, 12001] loss: 0.563
[1, 14001] loss: 0.547
[1, 16001] loss: 0.694
[1, 18001] loss: 0.811
[1, 20001] loss: 0.916
[1, 22001] loss: 1.041
[1, 24001] loss: 1.151
[1, 26001] loss: 1.072
[2,  2001] loss: 0.714
[2,  4001] loss: 0.764
[2,  6001] loss: 0.759
[2,  8001] loss: 0.757
[2, 10001] loss: 0.863
[2, 12001] loss: 0.635
[2, 14001] loss: 0.609
[2, 16001] loss: 0.891
[2, 18001] loss: 1.015
[2, 20001] loss: 1.067
[2, 22001] loss: 1.173
[2, 24001] loss: 1.213
[2, 26001] loss: 1.040
[3,  2001] loss: 0.695
[3,  4001] loss: 0.728
[3,  6001] loss: 0.713
[3,  8001] loss: 0.687
[3, 10001] loss: 0.780
[3, 12001] loss: 0.592
[3, 14001] loss: 0.579
[3, 16001] loss: 0.801
[3, 18001] loss: 0.891
[3, 20001] loss: 0.933
[3, 22001] loss: 1.084
[3, 24001] loss: 1.170
[3, 26001] loss: 1.040
[4,  2001] loss: 0.741
[4,  4001] loss: 0.809
[4,  6001] loss: 0.815
[4,  8001] loss: 0.795
[4, 10001] 

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=38ce621b-4696-4047-8b25-0501b493ce55' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>