# HW4 CAML

In this question, we will implement Convolutional Attention for Multi-Label classification (CAML) proposed by Mullenbach et al. in the paper "[Explainable Prediction of Medical Codes from Clinical Text](https://www.aclweb.org/anthology/N18-1100/)".

Clinical notes are text documents that are created by clinicians for each patient encounter. They are typically accompanied by medical codes, which describe the diagnosis and treatment. Annotating these codes is labor intensive and error prone; furthermore, the connection between the codes and the text is not annotated, obscuring the reasons and details behind specific diagnoses and treatments. Thus, let us implement CAML, an attentional convolutional network to predict medical codes from clinical text.

<img src='img/clinical notes.png'>

Image courtsey: [link](https://www.aclweb.org/anthology/2020.acl-demos.33/)

---

In [339]:
import os
import csv
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [340]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

In [341]:
DATA_PATH = "../HW4-CAML-lib/data/"

assert os.path.isdir(DATA_PATH)

---

## Dataset

Navigate to the data folder `DATA_PATH`, there are several files:

- `train_50.csv`, `test_50.csv`: these two files contains the data used for training and testing.
    - `SUBJECT_ID` refers to a unique patient.
    - `HADM_ID` refers to a unique hospital admission.
    - `TEXT` refers to the clinical text.
    - `LABELS` refers to the medical codes. We will predict the top 50 most frequent codes.
    - `length` refers to the length of the clinical text.
- `vocab.csv`: this file contains the vocabularies used in the clinical text.
- `TOP_50_CODES.csv`: this file contains the top 50 medical codes.

In [342]:
!ls {DATA_PATH}

test_50.csv  TOP_50_CODES.csv  train_50.csv  vocab.csv


For example, the first admission in `train_50.csv` has:
- `SUBJECT_ID`: 43909
- `HADM_ID`: 167612
- `TEXT`: admission date discharge date date of birth sex f service medicine allergies patient recorded as having no known allergies to drugs attending last name namepattern4 chief complaint pea arrest major surgical or invasive procedure intubation history of present illness yof w a h o metastatic cancer to lung presents with pea arrest resuscitated and intubated found to have large pneumonia patient with no potential cure per hospital oncology providers family had decided to make patient comfort care measures only she was transferred to icu for extubation past medical history metastatic cancer social history nc family history nc physical exam afebrile on ventilator sating well in general nad follows commands understands situation and wants to be extubated lungs ctab while vented rrr no m r g pertinent results 43pm lactate 43pm comments green top 45pm urine hyaline 45pm urine rbc wbc bacteria few yeast many epi 45pm urine blood tr nitrite neg protein glucose tr ketone neg bilirubin neg urobilngn neg ph leuk tr 45pm urine color yellow appear hazy sp last name un 45pm pt ptt inr pt 45pm plt count 45pm neuts lymphs monos eos basos 45pm wbc rbc hgb hct mcv mch mchc rdw 45pm calcium phosphate magnesium 45pm ck mb notdone 45pm ctropnt 45pm ck cpk 45pm estgfr using this 45pm glucose urea n creat sodium potassium chloride total co2 anion gap 16pm o2 sat 16pm lactate 16pm type central ve brief hospital course the patient was admitted to the icu for terminal extubation she was made comfort care measures only and given her widely metastatic cancer and pneumonia as well as pea arrest her family had decided to make her comfort care only she was extubated in the icu and died within minutes medications on admission unknown discharge medications expired discharge disposition expired discharge diagnosis patient expired discharge condition patient expired discharge instructions patient expired followup instructions patient expired initials namepattern4 last name namepattern4 name8 md md md number
- `LABELS`: 311;496;486;96.71;427.31
- `length`: 323

## 1 Prepare the Dataset [40 points]

### 1.1 Helper Functions [20 points]

To begin, weith, let us first implement some helper functions we will use later.

In [343]:
def to_index(sequence, token2idx):
    """
    TODO: convert the sequnce of tokens to indices. 
    If the word in unknown, then map it to `len(token2idx)'.
    
    INPUT:
        sequence (type: list of str): a sequence of tokens
        stoken2idx (type: dict): a dictionary mapping token to the corresponding index
    
    OUTPUT:
        indices (type: list of int): a sequence of indicies
        
    EXAMPLE:
        >>> sequence = ['hello', 'world', 'unknown_word']
        >>> token2idx = {'hello': 0, 'world': 1}
        >>> to_index(sequence, token2idx)
        [0, 1, 2]
    """
    # your code here
    indices = []
    for word in sequence:
        if word in token2idx:
            indices.append(token2idx[word])
        else:
#             indices.append(sequence.index(word))
            indices.append(len(token2idx))
            
    return indices
            
#     raise NotImplementedError

In [344]:
'''
AUTOGRADER CELL. DO NOT MODIFY THIS.
'''

sequence = ['hello', 'world', 'unknown_word']
token2idx = {'hello': 0, 'world': 1}
assert to_index(sequence, token2idx) == [0, 1, 2], "to_index() is wrong!"



In [345]:
def to_multi_hot(label, size):
    """
    TODO: convert the label to multi-hot.
    
    INPUT:
        label (type: list of int): class indices
        size (type: int): total number of distinct classes
    
    OUTPUT:
        multi_hot_label (type: list of int): multi-hot encoding for the input label
        
    EXAMPLE:
        >>> label = [1, 2, 3]
        >>> size = 4
        >>> to_multi_hot(label, size)
        [0, 1, 1, 1]
    """
    # your code here
    multi_hot_label = [0] * (size)
    i = 0
    while i in range(size):
            if i in label:
                multi_hot_label[i] = 1
            i+=1
    print(multi_hot_label)
    return multi_hot_label
#     raise NotImplementedError

In [346]:
'''
AUTOGRADER CELL. DO NOT MODIFY THIS.
'''

assert to_multi_hot([1, 2, 3], 4) == [0, 1, 1, 1], "to_multi_hot is wrong!"



[0, 1, 1, 1]


### 1.2 CustomDataset [10 points]

Now, let us implement a custom dataset using PyTorch class `Dataset`, which will characterize the key features of the dataset we want to generate.

We will use the clinical text as input and medical codes as output.

In [366]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, filename):        
        # read in the data files
        data = []
        with open(filename, "r") as file:
            csv_reader = csv.DictReader(file, delimiter=',')
            for row in csv_reader:
                data.append(row)
        self.data = data
        # load word lookup
        self.idx2word, self.word2idx = self.load_lookup(f'{DATA_PATH}/vocab.csv', padding=True)
        # load code lookup
        self.idx2code, self.code2idx = self.load_lookup(f'{DATA_PATH}/TOP_50_CODES.csv')
        
    def load_lookup(self, filename, padding=False):
        """ load lookup for word or code """
        tokens = set()
        with open(filename, 'r') as vocabfile:
            for i, line in enumerate(vocabfile):
                line = line.rstrip()
                if line != '':
                    tokens.add(line.strip())
        idx2token = {}
        if padding:  # padding with index 0
            idx2token[0] = '**PAD**'
        for w in sorted(tokens):
            idx2token[len(idx2token)] = w
        token2idx = {w:i for i,w in idx2token.items()}
        return idx2token, token2idx
    
    def to_multi_hot(self, label):
        multi_hot_label = [0] * len(self.idx2code)
        for idx in label:
            multi_hot_label[idx] = 1
        return multi_hot_label
        
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. admissions).
        """
        
        # your code here
        return len(self.data)
#         raise NotImplementedError
    
    def __getitem__(self, index):
        
        """
        TODO: Generates one sample of data.

        STEP: 1. convert text to indices using to_index();
              2. convert labels to indices using to_index();
              3. convert labels to multi-hot using to_multi_hot();
        """
        data = self.data[index]
        text = data['TEXT'].split(' ')
        labels = data['LABELS'].split(';')
        
        # your code here
        text = to_index(text, self.word2idx)
        labels = to_index(labels, self.code2idx)
        labels = to_multi_hot(labels, len(self.code2idx))
        print(text)
    
#         raise NotImplementedError
        # return text as long tensor, labels as float tensor;
        return torch.tensor(text, dtype=torch.long), torch.tensor(labels, dtype=torch.float)

In [367]:
'''
AUTOGRADER CELL. DO NOT MODIFY THIS.
'''

dataset = CustomDataset(f'{DATA_PATH}/train_50.csv')
assert len(dataset) == 84, "__len__() is wrong!"

text, labels = dataset[1]

assert type(text) is torch.Tensor, "__getitem__(): text is not tensor!"
assert type(labels) is torch.Tensor, "__getitem__(): labels is not tensor!"
assert text.dtype is torch.int64, "__getitem__(): text is not of type long!"
assert labels.dtype is torch.float32, "__getitem__(): labels is not of type float!"



[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 2891, 3303, 3112, 174, 920, 3032, 2476, 3590, 3590, 962, 920, 3032, 2476, 3590, 3590, 3590, 3303, 3536, 1931, 1850, 1758, 2520, 666, 1458, 1658, 1758, 2233, 2963, 482, 3590, 1690, 855, 1510, 2233, 2500, 1592, 3197, 2356, 1710, 106, 3572, 2240, 1914, 3530, 2720, 3062, 106, 3590, 3536, 1508, 3590, 1457, 3492, 3590, 351, 3197, 2841, 253, 509, 521, 1112, 3251, 3197, 1104, 894, 3525, 1211, 2760, 3492, 2385, 3197, 2356, 3492, 1572, 3536, 3590, 495, 253, 3492, 1106, 1698, 447, 626, 3333, 3519, 2425, 1607, 3197, 1104, 894, 1457, 930, 351, 2981, 2453, 1445, 106, 562, 319, 253, 3590, 2549, 3492, 1654, 3197, 2356, 3492, 509, 1106, 3251, 3197, 2262, 2790, 189, 626, 3558, 2636, 2768, 3196, 3197, 2777, 626, 3332, 295, 3251, 414, 433, 3197, 3590, 2410, 1188, 3197, 2356, 1437, 106, 1371, 2233, 1457, 3492, 743, 

### 1.3 Collate Function [10 points]

The collate function `collate_fn()` will be called by `DataLoader` after fetching a list of samples using the indices from `CustomDataset` to collate the list of samples into batches.

For example, assume the `DataLoader` gets a list of two samples.

```
[ [3,  1,  2, 8, 5], 
  [12, 13, 6, 7, 12, 23, 11] ]
```

where the first sample has text `[3, 1, 2, 8, 5]` the second sample has text `[12, 13, 6, 7, 12, 23, 11]`.

The collate function `collate_fn()` is supposed to pad them into the same shape (7), where 7 is the maximum number of tokens.

``` 
[ [3,  1,  2, 8, 5, *0*, *0*], 
  [12, 13, 6, 7, 12, 23,  11 ]
```

where `*0*` indicates the padding token.

We need to pad the sequences into the same length so that we can do batch training on GPU. And we also need this mask so that when training, we can ignored the padded value as they actually do not contain any information.

In [368]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(data):
    """
    TODO: implement the collate function.
    
    STEP: 1. pad the text using pad_sequence(). Set `batch_first=True`.
          2. stack the labels using torch.stack().
          
    OUTPUT:
        text: the padded text, shape: (batch size, max length)
        labels: the stacked labels, shape: (batch size, num classes)
    """
    text, labels = zip(*data)
    
    # your code here
    
    text = pad_sequence(text, batch_first=True, padding_value=0)
    labels = torch.stack(labels, dim=0)
    
    print(labels.shape)

    
    
    
#     raise NotImplementedError
    
    return text, labels

In [369]:
'''
AUTOGRADER CELL. DO NOT MODIFY THIS.
'''

from torch.utils.data import DataLoader

dataset = CustomDataset(f'{DATA_PATH}/train_50.csv')
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
text, labels = next(loader_iter)

assert text.shape == (10, 612), "collate_fn(): text has incorrect shape!"
assert labels.shape == (10, 50), "collate_fn(): labels has incorrect shape!"



[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1228, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1773, 2081, 2093, 629, 719, 2367, 319, 1911, 3113, 2266, 1700, 2525, 1699, 1510, 2233, 2500, 1592, 3580, 3480, 106, 1434, 2206, 1983, 549, 3251, 1886, 2504, 3536, 2367, 319, 3590, 253, 1698, 1332, 3251, 1445, 1770, 2446, 2356, 3536, 2151, 2482, 3590, 2380, 1518, 2249, 3590, 1240, 1437, 857, 3251, 1912, 2356, 700, 572, 1949, 2253, 2908, 3492, 3290, 3251, 1579, 1329, 1224, 2351, 1956, 1510, 1983, 549, 2974, 1510, 2098, 1240, 1510, 2098, 2410, 1187, 185, 2244, 3438, 3590, 3517, 1607, 1373, 2080, 1324, 704, 3590, 3590, 253, 3590, 3251, 414, 1223, 1887, 823, 3528, 3590, 2797, 2151, 1900, 2611, 1357, 2399, 2758, 69, 1760, 69, 705, 1416, 3267, 71, 3397, 1539, 71, 3397, 2637, 3500, 396, 1270, 3574, 1921, 1143, 71, 3397, 469,

All done, now let us load the dataset and data loader.

In [370]:
train_set = CustomDataset(f'{DATA_PATH}/train_50.csv')
test_set = CustomDataset(f'{DATA_PATH}/test_50.csv')
train_loader = DataLoader(train_set, batch_size=32, collate_fn=collate_fn, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, collate_fn=collate_fn)

## 2 Model [50 points]

Next, we will implement the CAML model.

<img src='img/caml.png'>

CAML is a convolutional neural network (CNN)-based model. It employs a per-label attention mechanism, which allows the model to learn distinct document representations for each label.

In [371]:
def load_embeddings(embed_file):
    """ helper function used to load the word2vec word embeddings """
    W = []
    with open(embed_file) as ef:
        for line in ef:
            line = line.rstrip().split()
            vec = np.array(line[1:]).astype(np.float)
            # normalizes the embeddings
            vec = vec / float(np.linalg.norm(vec) + 1e-6)
            W.append(vec)
        # UNK embedding, gaussian randomly initialized 
        vec = np.random.randn(len(W[-1]))
        vec = vec / float(np.linalg.norm(vec) + 1e-6)
        W.append(vec)
    W = np.array(W)
    return W

In [372]:
from math import floor
from torch.nn.init import xavier_uniform_


class CAML(nn.Module):

    def __init__(self, Y=50, embed_file=f'../HW4-CAML-lib/weights/processed_full.embed', kernel_size=10, num_filter_maps=16, embed_size=100, dropout=0.5):
        super(CAML, self).__init__()
        
        self.Y = Y  # number of codes
        self.embed_size = embed_size  # size of each embedding
        self.embed_drop = nn.Dropout(p=dropout)

        # make embedding layer from pre-trained word2vec embeddings
        W = torch.Tensor(load_embeddings(embed_file))
        self.embed = nn.Embedding(W.size()[0], W.size()[1], padding_idx=0)
        self.embed.weight.data = W.clone()

        # initialize conv layer as in 2.1
        self.conv = nn.Conv1d(self.embed_size, num_filter_maps, kernel_size=kernel_size, padding=int(floor(kernel_size/2)))
        xavier_uniform_(self.conv.weight)

        # context vectors for computing attention as in 2.2
        self.U = nn.Linear(num_filter_maps, Y)
        xavier_uniform_(self.U.weight)

        # final layer: create a matrix to use for the L binary classifiers as in 2.3
        self.final = nn.Linear(num_filter_maps, Y)
        xavier_uniform_(self.final.weight)
        
    def forward_embed(self, text):
        """
        TODO: Feed text through the embedding (self.embed) and dropout layer (self.embed_drop).
        
        INPUT: 
            text: (batch size, seq_len)
            
        OURPUT:
            text: (batch size, seq_len, embed_size)
        """
        # your code here
        text = self.embed(text)
        text = self.embed_drop(text)
    
        return text
#         raise NotImplementedError
        
    def forward_conv(self, text):
        """
        TODO: Feed text through the convolution layer (self.conv) and tanh activation function (F.tanh) 
        in eq (1) in the paper.
        
        INTPUT:
            text: (batch size, embed_size, seq_len)
            
        OUTPUT:
            text: (batch size, num_filter_maps, seq_len)
        """
        # your code here
        text = self.conv(text)
        text = F.tanh(text)
        return text
#         raise NotImplementedError
        
    def forward_calc_atten(self, text):
        """
        TODO: calculate the attention weights in eq (2) in the paper.
        
        INPUT:
            text: (batch size, seq_len, num_filter_maps)

        OUTPUT:
            alpha: (batch size, num_class, seq_len), the attention weights
            
        STEP: 1. multiply `self.U.weight` with `text` using torch.matmul();
              2. apply softmax using `F.softmax()`.
        """
        # (batch size, seq_len, num_filter_maps) -> (batch size, num_filter_mapsseq_len)
        text = text.transpose(1,2)
        # your code here
        alpha = torch.matmul(self.U.weight, text)
        alpha = F.softmax(alpha, dim=2)
        
        return alpha
#         raise NotImplementedError
        
    def forward_aply_atten(self, alpha, text):
        """
        TODO: apply the attention in eq (3) in the paper.

        INPUT: 
            text: (batch size, seq_len, num_filter_maps)
            alpha: (batch size, num_class, seq_len), the attention weights
            
        OUTPUT:
            v: (batch size, num_class, num_filter_maps), vector representations for each label
            
        STEP: multiply `alpha` with `text` using torch.matmul().
        """
        # your code here
        v = torch.matmul(alpha, text)
        
        return v
    
    def forward_linear(self, v):
        """
        TODO: apply the final linear classification in eq (5) in the paper.
        
        INPUT: 
            v: (batch size, num_class, num_filter_maps), vector representations for each label
            
        OUTPUT:
            y_hat: (batch size, num_class), label probability
            
        STEP: 1. multiply `self.final.weight` v `text` element-wise using torch.mul();
              2. sum the result over dim 2 (i.e. num_filter_maps);
              3. add the result with `self.final.bias`;
              4. apply sigmoid with torch.sigmoid().
        """
        # your code here
        text = torch.mul(self.final.weight, v)
        num_filter_maps = torch.sum(text, dim = 2)
        
        return torch.sigmoid(num_filter_maps + self.final.bias)
        
#         raise NotImplementedError
        
    def forward(self, text):
        """ 1. get embeddings and apply dropout """
        text = self.forward_embed(text)
        # (batch size, seq_len, embed_size) -> (batch size, embed_size, seq_len);
        text = text.transpose(1, 2)

        """ 2. apply convolution and nonlinearity (tanh) """
        text = self.forward_conv(text)
        # (batch size, num_filter_maps, seq_len) -> (batch size, seq_len, num_filter_maps);
        text = text.transpose(1,2)
        
        """ 3. calculate attention """
        alpha = self.forward_calc_atten(text)
        
        """ 3. apply attention """
        v = self.forward_aply_atten(alpha, text)
           
        """ 4. final layer classification """
        y_hat = self.forward_linear(v)
        print(y_hat)
        return y_hat
    
    
model = CAML()

In [373]:
'''
AUTOGRADER CELL. DO NOT MODIFY THIS.
'''

model = CAML()
model.eval()  # disable dropout



CAML(
  (embed_drop): Dropout(p=0.5, inplace=False)
  (embed): Embedding(3591, 100, padding_idx=0)
  (conv): Conv1d(100, 16, kernel_size=(10,), stride=(1,), padding=(5,))
  (U): Linear(in_features=16, out_features=50, bias=True)
  (final): Linear(in_features=16, out_features=50, bias=True)
)

## 3 Training and Inferencing [10 points]

In [374]:
model = CAML()

In [375]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()

Now let us implement the `eval()` and `train()` function. Note that `train()` should call `eval()` at the end of each training epoch to see the results on the validaion dataset.

In [376]:
from sklearn.metrics import precision_recall_fscore_support


def eval(model, test_loader):
    
    """    
    INPUT:
        model: the CAML model
        test_loader: dataloader
        
    OUTPUT:
        precision: overall micro precision score
        recall: overall micro recall score
        f1: overall micro f1 score
        
    REFERENCE: checkout https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """

    model.eval()
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    model.eval()
    for sequences, labels in test_loader:
        """
        TODO: 1. preform forward pass
              2. obtain the predicted class (0, 1) by comparing forward pass output against 0.5, 
                 assign the predicted class to y_hat.
        """
        # your code here
        f_w = model.forward(sequences)
        y_hat = (f_w > 0.5)
#         raise NotImplementedError
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, labels.detach().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
    return p, r, f

In [377]:
def train(model, train_loader, test_loader, n_epochs):
    """    
    INPUT:
        model: the CAML model
        train_loader: dataloder
        val_loader: dataloader
        n_epochs: total number of epochs
    """
    model.train()
    for epoch in range(n_epochs):
        train_loss = 0
        for sequences, labels in train_loader:
            optimizer.zero_grad()
            """ 
            TODO: 1. perform forward pass using `model`, save the output to y_hat;
                  2. calculate the loss using `criterion`, save the output to loss.
            """
            y_hat, loss = None, None
            # your code here
            y_hat = model(sequences)
            loss = criterion(y_hat, labels)
#             raise NotImplementedError
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f = eval(model, test_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'.format(epoch+1, p, r, f))

    
# number of epochs to train the model
n_epochs = 5

train(model, train_loader, test_loader, n_epochs)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
[174, 849, 962, 849, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1291, 2085, 1820, 629, 719, 1566, 3536, 2472, 469, 828, 1911, 3113, 2266, 1700, 2525, 3355, 1433, 1797, 1506, 1311, 1024, 1797, 2414, 2426, 1510, 2233, 2500, 1592, 3579, 1228, 3536, 2440, 2233, 358, 1273, 253, 1535, 2502, 3251, 3197, 1081, 2244, 3536, 1235, 3251, 3224, 134, 3251, 2356, 2808, 850, 2356, 3536, 1618, 1245, 253, 3590, 3558, 852, 3536, 347, 3552, 3590, 3558, 851, 2356, 2808, 850, 3200, 542, 3197, 1379, 2244, 541, 366, 1022, 1773, 2081, 3048, 3530, 2656, 3196, 2908, 699, 1685, 3197, 1081, 3389, 174, 3251, 3197, 1081, 2356, 1437, 106, 1268, 3251, 253, 3492, 1386, 3343, 2908, 3492, 1386, 32, 2188, 331, 2356, 1869, 3447, 1044, 2908, 3492, 176, 3251, 3197, 1961, 2891, 1652, 2908, 3492, 2179, 3251, 1445, 1878, 1405, 3184, 1797, 1506, 2321

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1900, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1291, 2085, 1820, 629, 719, 1005, 1911, 3113, 2266, 1700, 2525, 3218, 2387, 1023, 3590, 253, 2705, 562, 578, 1510, 2233, 2500, 1592, 2050, 1748, 1774, 1710, 106, 3579, 1900, 1434, 2206, 1154, 529, 1056, 2808, 2310, 838, 2233, 1345, 253, 648, 253, 3590, 3569, 3530, 1710, 3032, 2476, 1155, 2244, 3530, 3492, 963, 1515, 2476, 2255, 2244, 253, 912, 2537, 1005, 2299, 3197, 2135, 3511, 253, 1439, 3590, 1505, 3251, 2500, 3251, 3197, 1081, 2244, 2176, 2356, 3492, 2864, 1607, 3219, 3112, 660, 2244, 351, 3527, 3240, 1457, 3492, 2179, 3251, 1445, 106, 2023, 2951, 1797, 2436, 1088, 2356, 2180, 3196, 1457, 1199, 1061, 3590, 3524, 3590, 3018, 253, 236, 316, 1508, 1526, 1457, 930, 2174, 1445, 2913, 2233, 494, 2233, 2749, 518, 930, 1

tensor([[0.4975, 0.4401, 0.4799,  ..., 0.4542, 0.4525, 0.4910],
        [0.4966, 0.4401, 0.4803,  ..., 0.4533, 0.4535, 0.4908],
        [0.4885, 0.4354, 0.4819,  ..., 0.4529, 0.4595, 0.4940],
        ...,
        [0.5081, 0.4451, 0.4785,  ..., 0.4530, 0.4466, 0.4892],
        [0.4959, 0.4377, 0.4806,  ..., 0.4540, 0.4535, 0.4917],
        [0.4999, 0.4406, 0.4797,  ..., 0.4537, 0.4513, 0.4908]],
       grad_fn=<SigmoidBackward>)
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1900, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1291, 2085, 1820, 629, 719, 1268, 229, 1973, 3032, 1911, 3113, 2266, 1700, 2525, 922, 1510, 2233, 2500, 1592, 3217, 1710, 106, 3572, 2240, 1918, 3536, 106, 1510, 2233, 995, 1587, 532, 628, 1084, 253, 1455, 895, 1158, 3530, 2504, 3251, 3197, 1081, 1343, 922, 3536, 1268, 1361, 3590, 253, 229, 19

tensor([[0.4971, 0.4366, 0.4764,  ..., 0.4583, 0.4512, 0.4861],
        [0.4970, 0.4367, 0.4760,  ..., 0.4582, 0.4510, 0.4856],
        [0.4875, 0.4308, 0.4755,  ..., 0.4617, 0.4554, 0.4843],
        ...,
        [0.4443, 0.4017, 0.4733,  ..., 0.4754, 0.4839, 0.4767],
        [0.4400, 0.4015, 0.4725,  ..., 0.4751, 0.4862, 0.4734],
        [0.4311, 0.3935, 0.4736,  ..., 0.4802, 0.4951, 0.4769]],
       grad_fn=<SigmoidBackward>)
Epoch: 1 	 Validation p: 0.22, r:0.36, f: 0.27
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1900, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1291, 2085, 1820, 629, 719, 2913, 2233, 494, 1911, 3113, 2266, 1700, 2525, 2161, 1510, 2233, 2500, 1592, 2050, 1748, 1774, 1710, 106, 3572, 2240, 1914, 3536, 2351, 1956, 1510, 2233, 782, 2244, 62, 2095, 553, 351, 1515, 2284, 3234, 549, 642, 2708, 

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1900, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1291, 2085, 1820, 629, 719, 1005, 1911, 3113, 2266, 1700, 2525, 3218, 2387, 1023, 3590, 253, 2705, 562, 578, 1510, 2233, 2500, 1592, 2050, 1748, 1774, 1710, 106, 3579, 1900, 1434, 2206, 1154, 529, 1056, 2808, 2310, 838, 2233, 1345, 253, 648, 253, 3590, 3569, 3530, 1710, 3032, 2476, 1155, 2244, 3530, 3492, 963, 1515, 2476, 2255, 2244, 253, 912, 2537, 1005, 2299, 3197, 2135, 3511, 253, 1439, 3590, 1505, 3251, 2500, 3251, 3197, 1081, 2244, 2176, 2356, 3492, 2864, 1607, 3219, 3112, 660, 2244, 351, 3527, 3240, 1457, 3492, 2179, 3251, 1445, 106, 2023, 2951, 1797, 2436, 1088, 2356, 2180, 3196, 1457, 1199, 1061, 3590, 3524, 3590, 3018, 253, 236, 316, 1508, 1526, 1457, 930, 2174, 1445, 2913, 2233, 494, 2233, 2749, 518, 930, 1

[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 2891, 2129, 219, 398, 3590, 645, 366, 1291, 2085, 1820, 629, 719, 590, 590, 761, 1642, 1911, 3113, 2266, 1700, 2525, 3590, 1329, 446, 3590, 2640, 3590, 2233, 3590, 2244, 3284, 1510, 2233, 2500, 1592, 1529, 2356, 1710, 106, 3590, 3530, 1710, 2808, 2310, 2777, 3590, 3590, 3590, 2233, 3590, 2244, 2908, 3492, 963, 2244, 1343, 1519, 3251, 1518, 2691, 3217, 2036, 3528, 3462, 3536, 1491, 1537, 2908, 3492, 3590, 3251, 1445, 2897, 3590, 2380, 1537, 2808, 2718, 1146, 2233, 2868, 149, 2380, 3197, 1537, 2808, 2718, 3197, 2869, 2863, 3251, 414, 3590, 3251, 3197, 2777, 2926, 2233, 1491, 1458, 253, 2109, 253, 902, 331, 2068, 3590, 3197, 1537, 891, 3590, 281, 1353, 149, 1607, 230, 311, 2233, 1491, 473, 2351, 1956, 1510, 3590, 390, 1750, 1344, 1750, 1506, 1356, 2808, 2310, 2275, 1607, 2650, 604, 1853, 927, 628, 3329, 886, 2185, 2808,

Epoch: 2 	 Training Loss: 0.668202
[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 2891, 1961, 1510, 2233, 2500, 1592, 3197, 2356, 1710, 250, 3572, 2240, 1260, 3536, 106, 2351, 1956, 1510, 2233, 1554, 789, 326, 971, 1552, 253, 2158, 3590, 2808, 1895, 3530, 3492, 3290, 3251, 1291, 2086, 2091, 1773, 2081, 2091, 1291, 2081, 3247, 1773, 2081, 3247, 1343, 250, 2298, 1518, 3536, 2446, 727, 521, 106, 2076, 1632, 253, 2881, 3197, 2356, 3492, 1652, 176, 3251, 1520, 1520, 2244, 3536, 2913, 2233, 494, 796, 253, 2446, 2908, 3200, 1437, 250, 1145, 2233, 1566, 1607, 3197, 1104, 2790, 253, 3492, 3590, 3536, 1312, 2908, 3492, 3290, 3251, 3197, 1306, 3525, 2908, 426, 717, 122, 626, 2321, 2908, 1437, 1146, 2233, 2626, 358, 1273, 2908, 3492, 3200, 3290, 3251, 3197, 1956, 1673, 572, 3374, 351, 1520, 3525, 2908, 2648, 944, 253, 1873, 3536, 3196, 2908, 416, 1567, 192, 3536, 3143, 469, 

[0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1900, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1291, 2085, 1820, 629, 719, 1061, 1911, 3113, 2266, 1700, 2525, 2611, 1589, 1699, 1510, 2233, 2500, 1592, 3590, 3536, 1434, 2206, 2898, 782, 2244, 90, 1515, 2207, 2503, 3251, 2285, 3536, 3590, 2233, 1618, 2973, 3536, 796, 2531, 2233, 3575, 3529, 3008, 3220, 2380, 1508, 3533, 2151, 655, 1269, 626, 2321, 2077, 3409, 112, 2321, 926, 2630, 1729, 2323, 1508, 3533, 2180, 2461, 726, 3536, 1508, 1311, 253, 933, 2754, 253, 1619, 3516, 1343, 1781, 1044, 3251, 1781, 830, 1457, 1444, 423, 3161, 1772, 846, 2380, 3533, 1457, 2502, 3251, 2081, 2140, 3197, 851, 2518, 3251, 174, 518, 2932, 2295, 234, 3389, 321, 3251, 2285, 1508, 3477, 97, 1457, 3492, 2993, 1607, 3590, 2880, 3536, 3590, 3523, 253, 930, 2174, 3259, 1892, 1301, 1457, 34

tensor([[0.3597, 0.3369, 0.4465,  ..., 0.5259, 0.5224, 0.4235],
        [0.4511, 0.4033, 0.4639,  ..., 0.4805, 0.4609, 0.4647],
        [0.4820, 0.4246, 0.4701,  ..., 0.4653, 0.4506, 0.4759],
        ...,
        [0.4688, 0.4161, 0.4672,  ..., 0.4719, 0.4543, 0.4713],
        [0.4583, 0.4077, 0.4656,  ..., 0.4772, 0.4586, 0.4674],
        [0.4648, 0.4117, 0.4665,  ..., 0.4744, 0.4565, 0.4698]],
       grad_fn=<SigmoidBackward>)


IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
[174, 849, 962, 849, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1291, 2085, 1820, 629, 719, 1566, 3536, 2472, 469, 828, 1911, 3113, 2266, 1700, 2525, 3355, 1433, 1797, 1506, 1311, 1024, 1797, 2414, 2426, 1510, 2233, 2500, 1592, 3579, 1228, 3536, 2440, 2233, 358, 1273, 253, 1535, 2502, 3251, 3197, 1081, 2244, 3536, 1235, 3251, 3224, 134, 3251, 2356, 2808, 850, 2356, 3536, 1618, 1245, 253, 3590, 3558, 852, 3536, 347, 3552, 3590, 3558, 851, 2356, 2808, 850, 3200, 542, 3197, 1379, 2244, 541, 366, 1022, 1773, 2081, 3048, 3530, 2656, 3196, 2908, 699, 1685, 3197, 1081, 3389, 174, 3251, 3197, 1081, 2356, 1437, 106, 1268, 3251, 253, 3492, 1386, 3343, 2908, 3492, 1386, 32, 2188, 331, 2356, 1869, 3447, 1044, 2908, 3492, 176, 3251, 3197, 1961, 2891, 1652, 2908, 3492, 2179, 3251, 1445, 1878, 1405, 3184, 1797, 1506, 2321

tensor([[0.4502, 0.4004, 0.4572,  ..., 0.4766, 0.4495, 0.4575],
        [0.4487, 0.3995, 0.4564,  ..., 0.4773, 0.4497, 0.4564],
        [0.4172, 0.3768, 0.4468,  ..., 0.4926, 0.4553, 0.4398],
        ...,
        [0.3104, 0.2967, 0.4095,  ..., 0.5550, 0.5058, 0.3674],
        [0.3034, 0.2929, 0.4063,  ..., 0.5590, 0.5112, 0.3606],
        [0.2910, 0.2824, 0.4040,  ..., 0.5680, 0.5273, 0.3547]],
       grad_fn=<SigmoidBackward>)
Epoch: 4 	 Validation p: 0.22, r:0.25, f: 0.24
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1228, 2891, 2241, 219, 3590, 3590, 763, 366, 1291, 2085, 1820, 629, 719, 1570, 1911, 3113, 2266, 1700, 2525, 506, 3355, 1433, 3218, 1510, 2233, 2500, 1592, 3217, 2560, 1710, 106, 3584, 2240, 1260, 3536, 2441, 2933, 1329, 3016, 3590, 2158, 2963, 602, 1886, 529, 2808, 2310, 2251, 838, 2233, 3590, 3590, 253, 3569, 176, 1329, 353

[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1228, 2891, 1961, 219, 3095, 3097, 275, 2377, 3590, 245, 366, 1291, 2085, 1820, 629, 719, 3135, 1237, 3590, 1565, 1911, 3113, 2266, 1700, 2525, 2161, 1510, 2233, 2500, 1592, 3590, 3536, 1565, 3290, 1343, 2285, 1329, 3385, 1237, 1343, 3019, 3536, 3068, 1484, 2560, 1403, 3387, 3251, 1398, 3251, 3197, 412, 3590, 3387, 2244, 3197, 1306, 2244, 332, 2650, 3408, 2244, 644, 850, 509, 1491, 3251, 3197, 1518, 417, 2908, 1258, 3196, 3197, 2356, 3492, 2035, 1806, 2560, 891, 281, 3590, 626, 2321, 2330, 2518, 3251, 3197, 1237, 2560, 227, 2180, 1797, 688, 477, 2321, 1607, 3197, 1081, 1651, 3466, 2612, 3590, 1187, 2719, 331, 723, 3381, 2864, 521, 2129, 3530, 2656, 2711, 1596, 253, 3544, 3387, 2233, 3135, 1403, 32, 2188, 2104, 253, 354, 1458, 821, 2917, 390, 1750, 1344, 253, 1797, 3183, 3068, 3590, 1695, 3590, 3

[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 849, 2233, 456, 2901, 1900, 2891, 1961, 219, 2356, 2659, 331, 1446, 2151, 1748, 219, 3251, 1043, 366, 1291, 2085, 1820, 629, 719, 1459, 1268, 2109, 2321, 1911, 3113, 2266, 1700, 2525, 1882, 2573, 609, 259, 1510, 2233, 2500, 1592, 3571, 2206, 1918, 3536, 2518, 397, 3590, 3536, 2897, 851, 2961, 3590, 1435, 2109, 3590, 3590, 3536, 1268, 253, 631, 891, 1973, 3032, 615, 912, 2097, 253, 3476, 851, 2561, 3251, 2285, 351, 2285, 1437, 1458, 821, 2113, 1329, 463, 518, 1880, 2917, 1770, 244, 2637, 3536, 2114, 397, 1411, 3017, 253, 827, 3290, 1492, 1329, 3590, 3480, 3346, 2711, 1880, 741, 2637, 3536, 3590, 253, 3590, 3590, 609, 259, 2114, 1329, 3590, 253, 980, 2648, 153, 3590, 247, 3421, 3590, 3492, 842, 523, 842, 2721, 3590, 3590, 1607, 1323, 2118, 3590, 253, 1144, 1657, 1329, 642, 393, 2321, 3133, 426, 3538, 852, 2233, 1773, 1

tensor([[0.4330, 0.3876, 0.4481,  ..., 0.4791, 0.4470, 0.4455],
        [0.4309, 0.3862, 0.4470,  ..., 0.4800, 0.4472, 0.4440],
        [0.3932, 0.3589, 0.4335,  ..., 0.4974, 0.4518, 0.4218],
        ...,
        [0.2776, 0.2707, 0.3836,  ..., 0.5695, 0.4981, 0.3325],
        [0.2710, 0.2666, 0.3798,  ..., 0.5743, 0.5034, 0.3254],
        [0.2595, 0.2568, 0.3764,  ..., 0.5838, 0.5194, 0.3180]],
       grad_fn=<SigmoidBackward>)
Epoch: 5 	 Validation p: 0.22, r:0.24, f: 0.23


In [380]:
'''
AUTOGRADER CELL. DO NOT MODIFY THIS.
'''

p, r, f = eval(model, test_loader)
assert f > 0.20, "f1 below 0.25!"



[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[174, 849, 962, 849, 2891, 1961, 1510, 2233, 2500, 1592, 3197, 2356, 1710, 250, 3572, 2240, 1260, 3536, 106, 2351, 1956, 1510, 2233, 1554, 789, 326, 971, 1552, 253, 2158, 3590, 2808, 1895, 3530, 3492, 3290, 3251, 1291, 2086, 2091, 1773, 2081, 2091, 1291, 2081, 3247, 1773, 2081, 3247, 1343, 250, 2298, 1518, 3536, 2446, 727, 521, 106, 2076, 1632, 253, 2881, 3197, 2356, 3492, 1652, 176, 3251, 1520, 1520, 2244, 3536, 2913, 2233, 494, 796, 253, 2446, 2908, 3200, 1437, 250, 1145, 2233, 1566, 1607, 3197, 1104, 2790, 253, 3492, 3590, 3536, 1312, 2908, 3492, 3290, 3251, 3197, 1306, 3525, 2908, 426, 717, 122, 626, 2321, 2908, 1437, 1146, 2233, 2626, 358, 1273, 2908, 3492, 3200, 3290, 3251, 3197, 1956, 1673, 572, 3374, 351, 1520, 3525, 2908, 2648, 944, 253, 1873, 3536, 3196, 2908, 416, 1567, 192, 3536, 3143, 469, 2507, 1607, 3197, 1878, 98, 3197, 2