In [21]:
import os
import sys
import pickle
import psutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split

# Define data path
DATA_PATH = "data/"
CHECKPOINT_PATH = "models/"

In [22]:
pids = pickle.load(open(os.path.join(DATA_PATH,'pids.pkl'), 'rb'))
vids = pickle.load(open(os.path.join(DATA_PATH,'vids.pkl'), 'rb'))
targs = pickle.load(open(os.path.join(DATA_PATH,'targets.pkl'), 'rb'))
prob_targs = pickle.load(open(os.path.join(DATA_PATH, 'prob_targets_allvisits.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH,'seqs.pkl'), 'rb'))
diags = pickle.load(open(os.path.join(DATA_PATH,'diags.pkl'), 'rb'))
codes = pickle.load(open(os.path.join(DATA_PATH,'icd9.pkl'), 'rb'))
categories = pickle.load(open(os.path.join(DATA_PATH,'categories.pkl'), 'rb'))
sub_categories = pickle.load(open(os.path.join(DATA_PATH,'subcategories.pkl'), 'rb'))
assert len(pids) == len(vids) == len(targs) == len(seqs)

In [23]:
embedding_matrix = torch.load(os.path.join(DATA_PATH, 'embedding_matrix.pt'))

In [24]:
class CustomDataset(Dataset):
    
    def __init__(self, seqs, targets):
        self.x = seqs
        self.y = targets
    
    def __len__(self): 
        return(len(self.x))
    
    def __getitem__(self, index):
        return (self.x[index], self.y[index])

In [25]:
# dataset = CustomDataset(seqs, targs)
dataset = CustomDataset(seqs, prob_targs)

In [26]:
def collate_fn(data):
    """
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        x_masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time.
        rev_masks: same as x_masks but in reversed time.
        y: a tensor of shape (# patients, max # diagnosis categories) of type torch.float
        y_masks: a tensor of shape (# patients, max # diagnosis categories) of type torch.bool
    """
#     sequences, targets = zip(*data)

#     num_patients = len(sequences)
#     num_visits = [len(patient) for patient in sequences]
#     num_codes = [len(visit) for patient in sequences for visit in patient]
#     batch_num_categories = [len(visit) for patient in targets for visit in patient]

#     max_num_visits = max(num_visits)
#     max_num_codes = max(num_codes)
#     max_num_categories = max(batch_num_categories)
    
#     x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
#     y = torch.zeros((num_patients, max_num_categories), dtype=torch.long)
#     rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
#     x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
#     rev_x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
#     y_masks = torch.zeros((num_patients, max_num_categories), dtype=torch.bool)
#     for i_patient, patient in enumerate(sequences):   
#         for j_visit, visit in enumerate(patient[:-1]):
#             for k_code, code in enumerate(visit):
#                 x[i_patient, j_visit, k_code] = code
#                 x_masks[i_patient, j_visit, k_code] = 1
#                 if j_visit == len(patient) - 2 and k_code == len(visit) - 1:
#                     rev_visit = x_masks[i_patient].any(dim=1)
#                     rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
#                     rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)
                
#     for i_patient, patient in enumerate(targets):   
#         for visit in patient[-1:]:
#             for k_code, code in enumerate(visit):
#                 y[i_patient, k_code] = code
#                 y_masks[i_patient, k_code] = 1
    
#     return x, x_masks, rev_x, rev_x_masks, y, y_masks

    sequences, targets = zip(*data)

    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]
    batch_num_categories = [len(visit) for patient in targets for visit in patient]
    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
#     max_num_categories = max(batch_num_categories)
    max_num_categories = len(targets[0][0])
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
#     y = torch.zeros((num_patients, max_num_categories), dtype=torch.long)
    y = torch.zeros((num_patients, max_num_categories), dtype=torch.float)
    y_masks = torch.zeros((num_patients, max_num_categories), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):   
        for j_visit, visit in enumerate(patient[:-1]):
            for k_code, code in enumerate(visit):
                x[i_patient, j_visit, k_code] = code
                x_masks[i_patient, j_visit, k_code] = 1
                if j_visit == len(patient) - 2 and k_code == len(visit) - 1:
                    rev_visit = x_masks[i_patient].any(dim=1)
                    rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
                    rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)

#     for i_patient, patient in enumerate(targets):
# #         import pdb; pdb.set_trace()
#         last_visit = patient[-1]
#         y[i_patient,:len(last_visit)] = torch.LongTensor(last_visit)
#         y_masks[i_patient,:len(last_visit)] = torch.BoolTensor(np.ones(len(last_visit)))
    for i_patient, patient in enumerate(targets):
#         import pdb; pdb.set_trace()
        last_visit = patient[-1]
        y[i_patient] = torch.FloatTensor(last_visit)
        y_masks[i_patient] = torch.BoolTensor(np.ones(max_num_categories))
    
    return x, x_masks, rev_x, rev_x_masks, y, y_masks

In [27]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)
val_split = int(len(dataset)*0.10)

In [28]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)

lengths = [train_split, test_split, len(dataset) - (train_split + test_split)]
train_dataset, test_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of test dataset:", len(test_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 6561
Length of test dataset: 1312
Length of val dataset: 875


In [29]:
def load_data(train_dataset, test_dataset, val_dataset, collate_fn):
    '''
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        test dataset: test dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, test_loader, val_loader: train, test and validation dataloaders
    '''
    batch_size = 100
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=collate_fn,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                           batch_size=batch_size,
                                           collate_fn=collate_fn,
                                           shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             collate_fn=collate_fn,
                                             shuffle=False)
    
    return train_loader, test_loader, val_loader


train_loader, test_loader, val_loader = load_data(train_dataset, test_dataset, val_dataset, collate_fn)

In [30]:
def sum_embeddings_with_mask(x, masks):
    """
    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
    """
    x[~masks] = 0
    return x.sum(2)

In [31]:
def attention_sum(alpha, beta, rev_v, rev_masks):
    """
    Arguments:
        alpha: the alpha attention weights of shape (batch_size, # visits, 1)
        beta: the beta attention weights of shape (batch_size, # visits, hidden_dim)
        rev_v: the visit embeddings in reversed time of shape (batch_size, # visits, embedding_dim)
        rev_masks: the padding masks in reversed time of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        c: the context vector of shape (batch_size, hidden_dim)
    """
    rev_masks = rev_masks[:,:].any(dim=2)
    rev_v_masked = rev_v.clone()
    rev_v_masked[~rev_masks] = 0
    weights = alpha * beta
    a_sum = torch.sum(weights * rev_v_masked, dim=1)
    return a_sum

In [32]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """ 
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """
        Arguments:
            g: the output tensor from RNN-alpha of shape (batch_size, # visits, hidden_dim) 
        
        Outputs:
            alpha: the corresponding attention weights of shape (batch_size, # visits, 1)
        """
        m = nn.Softmax(dim=1)
        return m(self.a_att(g))

In [33]:
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        """
        Arguments:
            h: the output tensor from RNN-beta of shape (batch_size, # visits, hidden_dim) 
        
        Outputs:
            beta: the corresponding attention weights of shape (batch_size, # visits, hidden_dim)
        """
        return torch.tanh(self.b_att(h))

In [34]:
class EnhancedRETAIN(nn.Module):
    
    def __init__(self, num_codes, num_categories):
        super().__init__()
        """
        Arguments:
            num_codes: total number of diagnosis codes
            num_categories: total number of diagnosis categories to predict
        """
        self.embedding = nn.Linear(num_codes, 300)
        self.embedding.weight.data = embedding_matrix
        self.rnn_a = nn.GRU(300, hidden_size=128, batch_first=True)
        self.rnn_b = nn.GRU(300, hidden_size=128, batch_first=True)
        self.att_a = AlphaAttention(128)
        self.att_b = AlphaAttention(128)
        self.fc = nn.Linear(300, num_categories)
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)
            rev_x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            rev_masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

        Outputs:
            logits: logits of shape (batch_size, # categories)
        """
        rev_x = indices_to_multihot(rev_x, rev_masks, 4903)
        rev_x = self.embedding(rev_x)
        rev_x[~rev_masks.any(dim=2)] = 0  # Set masked visits to 0
        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        alpha = self.att_a(g)
        beta = self.att_b(h)
        c = attention_sum(alpha, beta, rev_x, rev_masks)
        logits = self.fc(c)
        return logits

# load the model here
enhanced_retain = EnhancedRETAIN(num_codes = len(codes), num_categories=len(sub_categories))
enhanced_retain

EnhancedRETAIN(
  (embedding): Linear(in_features=4903, out_features=300, bias=True)
  (rnn_a): GRU(300, 128, batch_first=True)
  (rnn_b): GRU(300, 128, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (fc): Linear(in_features=300, out_features=184, bias=True)
)

In [35]:
criterion = nn.CrossEntropyLoss()
#optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001)
optimizer = torch.optim.Adadelta(enhanced_retain.parameters(), weight_decay=0.001)

In [36]:
def train(model, train_loader, test_loader, n_epochs):
    """
    Arguments:
        model: the EnhancedRETAIN model
        train_loader: training dataloder
        test_loader: validation dataloader
        n_epochs: total number of epochs
    """
#     max_cpu, max_ram = print_cpu_usage()
#     for epoch in range(n_epochs):
#         model.train()
#         train_loss = 0
#         for x, x_masks, rev_x, rev_x_masks, y, y_masks in train_loader:
#             y_hat = model(x, x_masks, rev_x, rev_x_masks) 
#             y_mh = indices_to_multihot(y, y_masks, y_hat.shape[-1])
#             loss = criterion(y_hat, y_mh)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             train_loss += loss.item()
#         train_loss = train_loss / len(train_loader)
#         cpu, ram = print_cpu_usage()
#         max_cpu = cpu if cpu > max_cpu else max_cpu
#         max_ram = ram if ram > max_ram else max_ram
#         print(f'Epoch: {epoch+1} \t Training Loss: {train_loss:.6f}')
#         for k in range(5, 31, 5):
#             precision_k, accuracy_k = eval_model(model, test_loader, k=k)
#             print(f'Epoch: {epoch+1} \t Validation precision@k{k}: {precision_k:.4f}, accuracy@k{k}: {accuracy_k:.4f}')
#     final_cpu, final_ram = print_cpu_usage()
#     print(f"Max CPU usage: {max_cpu:.3f}\tMax RAM % usage: {max_ram}")

    max_cpu, max_ram = print_cpu_usage()
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, x_masks, rev_x, rev_x_masks, y, y_masks in train_loader:
            y_hat = model(x, x_masks, rev_x, rev_x_masks)                  
#             y_mh = indices_to_multihot(y, y_masks, y_hat)
            loss = criterion(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        cpu, ram = print_cpu_usage()
        max_cpu = cpu if cpu > max_cpu else max_cpu
        max_ram = ram if ram > max_ram else max_ram
        print(f'Epoch: {epoch+1} \t Training Loss: {train_loss:.6f}')
        for k in range(5, 31, 5):
            precision_k, accuracy_k = eval_model(model, val_loader, k=k)
            print(f'Epoch: {epoch+1} \t Validation precision@k{k}: {precision_k:.4f}, accuracy@k{k}: {accuracy_k:.4f}')
    final_cpu, final_ram = print_cpu_usage()
    print(f"Max CPU usage: {max_cpu:.3f}\tMax RAM % usage: {max_ram}")

In [37]:
def eval_model(model, test_loader, k=15, n=-1):
    """
    Arguments:
        model: the EnhancedRETAIN model
        test_loader: validation dataloader
        k: value for top k predictions
        n: num of records to evaluate in the batch, value -1 evaulates all records
        
    Outputs:
        precision_k: visit-level precison@k
        accuracy_k: code-level accuracy@k
    """
#     y_pred = torch.LongTensor()
#     y_true = torch.LongTensor()
#     all_precision = []
#     all_accuracy = []
    
#     model.eval()
#     with torch.no_grad():
#         for x, x_masks, rev_x, rev_x_masks, y, y_masks in test_loader:
#             n_eval = y.shape[0] - 1 if n == -1 else n
#             y_hat = model(x, x_masks, rev_x, rev_x_masks)
#             y_hat = F.softmax(y_hat, dim=-1)
#             y_multihot = indices_to_multihot(y, y_masks, y_hat.shape[-1])
#             k_correct = 0
#             total_precision = 0
#             total_accuracy = 0
#             for i in range(n_eval):
#                 visit_correct = 0
#                 y_true = y[i, y_masks[i]]
#                 _, y_pred = torch.topk(y_hat[i], k)

#                 for v in y_true:
#                     if v in y_pred:
#                         visit_correct += 1
                        
#                 visit_precision = visit_correct / min(k, len(y_true))
#                 visit_accuracy = visit_correct / len(y_true)
#                 k_correct += visit_correct
#                 total_precision += visit_precision
#                 total_accuracy += visit_accuracy

#             precision_k = total_precision / n_eval
#             accuracy_k = total_accuracy / n_eval
#             all_precision.append(precision_k)
#             all_accuracy.append(accuracy_k)

#     total_precision_k = np.mean(all_precision)
#     total_accuracy_k = np.mean(all_accuracy)
#     return total_precision_k, total_accuracy_k

    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    all_precision = []
    all_accuracy = []
    
    model.eval()
    with torch.no_grad():
        for x, x_masks, rev_x, rev_x_masks, y, y_masks in test_loader:
            n_eval = y.shape[0] - 1 if n == -1 else n
            y_hat = model(x, x_masks, rev_x, rev_x_masks)
            y_hat = F.softmax(y_hat, dim=-1)
            #y_multihot = indices_to_multihot(y, y_masks, y_hat)
            nz_rows, nz_cols = torch.nonzero(y, as_tuple=True)
            k_correct = 0
            total_precision = 0
            total_accuracy = 0
            for i in range(n_eval):
                visit_correct = 0
                #y_true = y[i, y_masks[i]]
                y_true = nz_cols[nz_rows == i]
                _, y_pred = torch.topk(y_hat[i], k)
                for v in y_true:
                    if v in y_pred:
                        visit_correct += 1
                visit_precision = visit_correct / min(k, len(y_true))
                visit_accuracy = visit_correct / len(y_true)
                k_correct += visit_correct
                total_precision += visit_precision
                total_accuracy += visit_accuracy
            precision_k = total_precision / n_eval
            accuracy_k = total_accuracy / n_eval
            all_precision.append(precision_k)
            all_accuracy.append(accuracy_k)
            
    total_precision_k = np.mean(all_precision)
    total_accuracy_k = np.mean(all_accuracy)
    return total_precision_k, total_accuracy_k

In [38]:
def indices_to_multihot(indices, masks, dim):
    mh_size = tuple(indices.shape[:-1]) + (dim,)
    multihot = torch.zeros(mh_size, dtype=torch.float)
    if len(mh_size) == 3:
        for i, patient in enumerate(indices):
            for j, visit in enumerate(patient):
                if masks[i,j].sum() == 0:
                    break
                y_idx = visit[masks[i,j]].unique()
                multihot[i,j] = F.one_hot(y_idx.to(torch.int64), multihot.shape[-1]).sum(0)
    else:
        for idx, row in enumerate(indices):
            y_idx = row[masks[idx]].unique()
            multihot[idx] = F.one_hot(y_idx, dim).sum(0).float()
    return multihot

In [39]:
def print_cpu_usage():
    load = psutil.getloadavg()[2]
    cpu_usage = (load/os.cpu_count()) * 100
    ram = psutil.virtual_memory()[2]
    print(f"CPU: {cpu_usage:0.2f}")
    print(f"RAM %: {ram}")
    return cpu_usage, ram

In [40]:
n_epochs = 100
%time train(enhanced_retain, train_loader, val_loader, n_epochs)

CPU: 15.82
RAM %: 61.4
CPU: 15.90
RAM %: 61.9
Epoch: 1 	 Training Loss: 4.313518
Epoch: 1 	 Validation precision@k5: 0.4301, accuracy@k5: 0.2212
Epoch: 1 	 Validation precision@k10: 0.4155, accuracy@k10: 0.3537
Epoch: 1 	 Validation precision@k15: 0.4788, accuracy@k15: 0.4651
Epoch: 1 	 Validation precision@k20: 0.5684, accuracy@k20: 0.5670
Epoch: 1 	 Validation precision@k25: 0.6451, accuracy@k25: 0.6450
Epoch: 1 	 Validation precision@k30: 0.7171, accuracy@k30: 0.7171
CPU: 15.92
RAM %: 61.9
Epoch: 2 	 Training Loss: 4.065195
Epoch: 2 	 Validation precision@k5: 0.5084, accuracy@k5: 0.2672
Epoch: 2 	 Validation precision@k10: 0.4983, accuracy@k10: 0.4269
Epoch: 2 	 Validation precision@k15: 0.5555, accuracy@k15: 0.5401
Epoch: 2 	 Validation precision@k20: 0.6267, accuracy@k20: 0.6251
Epoch: 2 	 Validation precision@k25: 0.7006, accuracy@k25: 0.7005
Epoch: 2 	 Validation precision@k30: 0.7666, accuracy@k30: 0.7666
CPU: 15.91
RAM %: 61.8
Epoch: 3 	 Training Loss: 3.974746
Epoch: 3 	 Vali

CPU: 16.81
RAM %: 61.8
Epoch: 19 	 Training Loss: 3.559059
Epoch: 19 	 Validation precision@k5: 0.6546, accuracy@k5: 0.3539
Epoch: 19 	 Validation precision@k10: 0.6095, accuracy@k10: 0.5286
Epoch: 19 	 Validation precision@k15: 0.6518, accuracy@k15: 0.6352
Epoch: 19 	 Validation precision@k20: 0.7144, accuracy@k20: 0.7128
Epoch: 19 	 Validation precision@k25: 0.7734, accuracy@k25: 0.7733
Epoch: 19 	 Validation precision@k30: 0.8183, accuracy@k30: 0.8183
CPU: 16.74
RAM %: 61.8
Epoch: 20 	 Training Loss: 3.552010
Epoch: 20 	 Validation precision@k5: 0.6297, accuracy@k5: 0.3427
Epoch: 20 	 Validation precision@k10: 0.5813, accuracy@k10: 0.5049
Epoch: 20 	 Validation precision@k15: 0.6372, accuracy@k15: 0.6208
Epoch: 20 	 Validation precision@k20: 0.7081, accuracy@k20: 0.7065
Epoch: 20 	 Validation precision@k25: 0.7743, accuracy@k25: 0.7742
Epoch: 20 	 Validation precision@k30: 0.8226, accuracy@k30: 0.8226
CPU: 16.89
RAM %: 61.8
Epoch: 21 	 Training Loss: 3.535679
Epoch: 21 	 Validation 

Epoch: 36 	 Validation precision@k30: 0.8364, accuracy@k30: 0.8364
CPU: 18.19
RAM %: 61.8
Epoch: 37 	 Training Loss: 3.437498
Epoch: 37 	 Validation precision@k5: 0.6911, accuracy@k5: 0.3730
Epoch: 37 	 Validation precision@k10: 0.6447, accuracy@k10: 0.5575
Epoch: 37 	 Validation precision@k15: 0.6822, accuracy@k15: 0.6642
Epoch: 37 	 Validation precision@k20: 0.7340, accuracy@k20: 0.7323
Epoch: 37 	 Validation precision@k25: 0.7886, accuracy@k25: 0.7885
Epoch: 37 	 Validation precision@k30: 0.8363, accuracy@k30: 0.8363
CPU: 18.09
RAM %: 61.8
Epoch: 38 	 Training Loss: 3.432912
Epoch: 38 	 Validation precision@k5: 0.6923, accuracy@k5: 0.3767
Epoch: 38 	 Validation precision@k10: 0.6351, accuracy@k10: 0.5491
Epoch: 38 	 Validation precision@k15: 0.6758, accuracy@k15: 0.6579
Epoch: 38 	 Validation precision@k20: 0.7365, accuracy@k20: 0.7347
Epoch: 38 	 Validation precision@k25: 0.7901, accuracy@k25: 0.7900
Epoch: 38 	 Validation precision@k30: 0.8340, accuracy@k30: 0.8340
CPU: 18.32
RAM 

Epoch: 54 	 Validation precision@k25: 0.7980, accuracy@k25: 0.7979
Epoch: 54 	 Validation precision@k30: 0.8415, accuracy@k30: 0.8415
CPU: 19.33
RAM %: 61.9
Epoch: 55 	 Training Loss: 3.394301
Epoch: 55 	 Validation precision@k5: 0.6971, accuracy@k5: 0.3749
Epoch: 55 	 Validation precision@k10: 0.6556, accuracy@k10: 0.5646
Epoch: 55 	 Validation precision@k15: 0.6843, accuracy@k15: 0.6653
Epoch: 55 	 Validation precision@k20: 0.7416, accuracy@k20: 0.7398
Epoch: 55 	 Validation precision@k25: 0.7987, accuracy@k25: 0.7986
Epoch: 55 	 Validation precision@k30: 0.8405, accuracy@k30: 0.8405
CPU: 19.27
RAM %: 62.0
Epoch: 56 	 Training Loss: 3.393245
Epoch: 56 	 Validation precision@k5: 0.6903, accuracy@k5: 0.3713
Epoch: 56 	 Validation precision@k10: 0.6547, accuracy@k10: 0.5633
Epoch: 56 	 Validation precision@k15: 0.6890, accuracy@k15: 0.6698
Epoch: 56 	 Validation precision@k20: 0.7459, accuracy@k20: 0.7441
Epoch: 56 	 Validation precision@k25: 0.7992, accuracy@k25: 0.7991
Epoch: 56 	 Val

Epoch: 72 	 Validation precision@k20: 0.7433, accuracy@k20: 0.7415
Epoch: 72 	 Validation precision@k25: 0.7972, accuracy@k25: 0.7971
Epoch: 72 	 Validation precision@k30: 0.8376, accuracy@k30: 0.8376
CPU: 21.29
RAM %: 62.7
Epoch: 73 	 Training Loss: 3.388022
Epoch: 73 	 Validation precision@k5: 0.6870, accuracy@k5: 0.3685
Epoch: 73 	 Validation precision@k10: 0.6503, accuracy@k10: 0.5596
Epoch: 73 	 Validation precision@k15: 0.6856, accuracy@k15: 0.6666
Epoch: 73 	 Validation precision@k20: 0.7430, accuracy@k20: 0.7412
Epoch: 73 	 Validation precision@k25: 0.7980, accuracy@k25: 0.7979
Epoch: 73 	 Validation precision@k30: 0.8400, accuracy@k30: 0.8400
CPU: 21.14
RAM %: 62.3
Epoch: 74 	 Training Loss: 3.387979
Epoch: 74 	 Validation precision@k5: 0.6975, accuracy@k5: 0.3747
Epoch: 74 	 Validation precision@k10: 0.6584, accuracy@k10: 0.5669
Epoch: 74 	 Validation precision@k15: 0.6852, accuracy@k15: 0.6660
Epoch: 74 	 Validation precision@k20: 0.7451, accuracy@k20: 0.7432
Epoch: 74 	 Val

Epoch: 90 	 Validation precision@k15: 0.6895, accuracy@k15: 0.6702
Epoch: 90 	 Validation precision@k20: 0.7422, accuracy@k20: 0.7403
Epoch: 90 	 Validation precision@k25: 0.7950, accuracy@k25: 0.7949
Epoch: 90 	 Validation precision@k30: 0.8350, accuracy@k30: 0.8350
CPU: 20.87
RAM %: 62.3
Epoch: 91 	 Training Loss: 3.385465
Epoch: 91 	 Validation precision@k5: 0.6969, accuracy@k5: 0.3748
Epoch: 91 	 Validation precision@k10: 0.6544, accuracy@k10: 0.5634
Epoch: 91 	 Validation precision@k15: 0.6876, accuracy@k15: 0.6682
Epoch: 91 	 Validation precision@k20: 0.7431, accuracy@k20: 0.7413
Epoch: 91 	 Validation precision@k25: 0.7964, accuracy@k25: 0.7963
Epoch: 91 	 Validation precision@k30: 0.8350, accuracy@k30: 0.8350
CPU: 21.34
RAM %: 62.3
Epoch: 92 	 Training Loss: 3.385214
Epoch: 92 	 Validation precision@k5: 0.6926, accuracy@k5: 0.3726
Epoch: 92 	 Validation precision@k10: 0.6537, accuracy@k10: 0.5620
Epoch: 92 	 Validation precision@k15: 0.6843, accuracy@k15: 0.6651
Epoch: 92 	 Val

In [41]:
for k in range(5, 31, 5):
    precision_k, accuracy_k = eval_model(enhanced_retain, test_loader, k=k)
    print(f'Validation precision@k{k}: {precision_k:.4f}, accuracy@k{k}: {accuracy_k:.4f}')

Validation precision@k5: 0.6896, accuracy@k5: 0.3794
Validation precision@k10: 0.6472, accuracy@k10: 0.5581
Validation precision@k15: 0.6867, accuracy@k15: 0.6666
Validation precision@k20: 0.7440, accuracy@k20: 0.7409
Validation precision@k25: 0.7936, accuracy@k25: 0.7934
Validation precision@k30: 0.8325, accuracy@k30: 0.8325


In [42]:
torch.save(enhanced_retain, os.path.join(CHECKPOINT_PATH, "EnhancedRETAIN_100.pth"))