In [1]:
import os, sys
import numpy as np
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

## Preliminary

In [2]:
# Seed every RNG the notebook draws from: torch (weight initialisation),
# numpy (np.random.randint / np.random.choice in the sampling helpers) and
# the stdlib `random` module (random.shuffle in get_S_prime).
SEED = 12345
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

na_list = ['A', 'C', 'G', 'T'] #nucleic acids
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y'] #amino acids
NNK_freq = [0.09375]*3 + [0.0625]*5 + [0.03125]*13 #freq of 21 NNK codons including the stop codon
sum_20 = 0.0625*5 + 0.09375*3 + 0.03125*12 #sum of freq without the stop codon
# Normalised frequencies for the 20 amino acids, in aa_list order. The first 19
# entries are computed directly; the last one is set to 1 - sum(others) so the
# vector sums to exactly 1 despite floating-point rounding.
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11
pvals = pvals + [1 - sum(pvals)]
aa_dict = dict(zip(aa_list, pvals)) #amino acid -> normalised frequency

## Dataset

In [3]:
def construct_dataset(dataset_file=None, seed=12345):
    """Load the aptamer/peptide JSON file and return the unique binding pairs.

    The JSON maps each aptamer sequence to a list of two-element entries whose
    first element is a peptide sequence (the second element is ignored).

    Filtering applied:
      * the control aptamer is skipped entirely;
      * "_" (stop codon marker) is stripped from every peptide;
      * peptides containing "RRRRRR" (peptide control) are skipped;
      * only pairs with a 40-character aptamer and an 8-character peptide
        are kept.

    Args:
        dataset_file: path of the JSON file. Defaults to the notebook-level
            `aptamer_dataset_file` when omitted.
        seed: seed of the private RNG used to shuffle the result.

    Returns:
        A list of unique (aptamer, peptide) tuples. The pairs are sorted and
        then shuffled with a fixed-seed RNG, so the order (and therefore any
        positional train/test split) is identical on every run instead of
        depending on set iteration order.
    """
    if dataset_file is None:
        dataset_file = aptamer_dataset_file
    with open(dataset_file, 'r') as f:
        aptamer_data = json.load(f)

    control_aptamer = "CTTTGTAATTGGTTCTGAGTTCCGTTGTGGGAGGAACATG"
    unique_pairs = set() #a set removes duplicates
    for aptamer, peptides in aptamer_data.items():
        if aptamer == control_aptamer: #took out aptamer control
            continue
        if len(aptamer) != 40: #making sure right length
            continue
        for peptide, _ in peptides:
            peptide = peptide.replace("_", "") #removed stop codons
            if "RRRRRR" in peptide: #took out peptide control
                continue
            if len(peptide) == 8: #making sure right length
                unique_pairs.add((aptamer, peptide))

    # Sorting gives a canonical order; the seeded shuffle then removes the
    # lexicographic structure so that slicing off a test set is not biased.
    full_dataset = sorted(unique_pairs)
    random.Random(seed).shuffle(full_dataset)
    return full_dataset

class TrainDataset(torch.utils.data.Dataset):
    """Map-style dataset wrapping the training (aptamer, peptide) pairs."""

    def __init__(self, training_set):
        super().__init__()
        self.training_set = training_set

    def __len__(self):
        """Number of training pairs."""
        return len(self.training_set)

    def __getitem__(self, idx):
        """Return the idx-th pair as an (aptamer, peptide) tuple."""
        apt_seq, pep_seq = self.training_set[idx]
        return (apt_seq, pep_seq)
    
class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset wrapping the held-out (aptamer, peptide) pairs."""

    def __init__(self, test_set):
        super().__init__()
        self.test_set = test_set

    def __len__(self):
        """Number of test pairs."""
        return len(self.test_set)

    def __getitem__(self, idx):
        """Return the idx-th pair as an (aptamer, peptide) tuple."""
        apt_seq, pep_seq = self.test_set[idx]
        return (apt_seq, pep_seq)

In [4]:
# Load the filtered pairs and split them positionally: first 80% for
# training, remaining 20% for testing.
aptamer_dataset_file = "../data/aptamer_dataset.json"
full_dataset = construct_dataset()
n = len(full_dataset)
split_idx = int(0.8*n)
training_set, test_set = full_dataset[:split_idx], full_dataset[split_idx:]
train_dataset, test_dataset = TrainDataset(training_set), TestDataset(test_set)
# No batch_size / shuffle arguments are passed, so the DataLoader defaults apply.
train_loader = torch.utils.data.DataLoader(train_dataset)
test_loader = torch.utils.data.DataLoader(test_dataset)

## NN Models

In [5]:
class SimpleConvNet(nn.Module):
    """One conv layer per sequence type followed by a single linear read-out.

    The linear layer takes 210 features = 5*37 (aptamer branch) + 5*5
    (peptide branch), which corresponds to one-hot inputs of shape
    (B, 1, 40, 4) for aptamers and (B, 1, 8, 20) for peptides.
    """

    def __init__(self):
        super(SimpleConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 5, (4,4)) #kernel spans 4 consecutive bases x all 4 one-hot columns
        self.cnn_pep_1 = nn.Conv2d(1, 5, (4,20)) #kernel spans 4 consecutive residues x all 20 one-hot columns
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(210, 1)

    def forward(self, apt, pep):
        """Score a batch of pairs.

        Args:
            apt: float tensor (B, 1, 40, 4), one-hot aptamers.
            pep: float tensor (B, 1, 8, 20), one-hot peptides.

        Returns:
            Tensor (B, 1) of sigmoid outputs in (0, 1).
        """
        apt = self.relu(self.cnn_apt_1(apt))
        pep = self.relu(self.cnn_pep_1(pep))
        # Flatten everything except the batch axis. The previous
        # `view(-1, 1).T` folded the batch axis into the features and so only
        # worked for B == 1; for B == 1 the element order is unchanged.
        apt = torch.flatten(apt, start_dim=1)
        pep = torch.flatten(pep, start_dim=1)

        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        return torch.sigmoid(x)

def weights_init(m):
    """Initialiser meant for `model.apply`.

    Conv2d weights get Xavier-uniform values, Linear weights get
    Kaiming-uniform (ReLU gain) values; the bias of either is zeroed.
    Any other module type is left untouched.
    """
    if isinstance(m, nn.Conv2d):
        init_weight = nn.init.xavier_uniform_
    elif isinstance(m, nn.Linear):
        init_weight = lambda w: nn.init.kaiming_uniform_(w, nonlinearity='relu')
    else:
        return
    init_weight(m.weight.data)
    nn.init.zeros_(m.bias.data)

## Sampling

In [6]:
# Sample x from P_X (assume aptamers follow a uniform distribution)
def get_x():
    """Draw one random 40-base aptamer, each base picked uniformly from na_list."""
    base_indices = np.random.randint(0, 4, 40)
    return "".join(na_list[idx] for idx in base_indices)

# Sample y from P_y (assume peptides follow NNK)
def get_y():
    """Draw one random peptide: a leading "M" plus 7 residues sampled with probabilities pvals."""
    residue_indices = np.random.choice(20, 7, p=pvals)
    return "M" + "".join(aa_list[idx] for idx in residue_indices)

# Generate uniformly from S without replacement
def get_xy(k):
    """Return k distinct pairs of full_dataset, chosen uniformly at random."""
    picked = np.random.choice(len(full_dataset), k, replace=False)
    return [full_dataset[idx] for idx in picked]

# S' = S plus k freshly sampled pairs (domain for Importance Sampling)
def get_S_prime(k):
    """Build the importance-sampling domain.

    Returns:
        (S_prime, S_new) where S_new is the list of the k sampled
        (get_x(), get_y()) pairs in draw order, and S_prime is a shuffled list
        of [pair, indicator] entries covering every pair of full_dataset
        (indicator 0) and every sampled pair (indicator 1).
    """
    indicator = dict.fromkeys(full_dataset, 0) #0 means in the original dataset
    S_new = [(get_x(), get_y()) for _ in range(k)]
    indicator.update((pair, 1) for pair in S_new) #1 means not in the original dataset
    S_prime = [[pair, int(flag)] for pair, flag in indicator.items()]
    random.shuffle(S_prime)
    return S_prime, S_new

# Returns pmf of an aptamer
def get_x_pmf():
    """Probability of any single aptamer: 40 positions, each with probability 0.25."""
    per_base_prob = 0.25
    aptamer_length = 40
    return per_base_prob ** aptamer_length

# Returns pmf of a peptide
def get_y_pmf(y):
    """Multiply the aa_dict frequencies of every element of y after the first.

    The first element (the fixed "M") is skipped; an input with fewer than two
    elements yields 1.
    """
    pmf = 1
    for pos in range(1, len(y)): #position 0 is the leading "M"
        pmf *= aa_dict[y[pos]]
    return pmf

# Sample n synthetic pairs once at notebook level: sgd() indexes into S_prime,
# while get_out_prime() and eval_unknown() read S_new.
S_prime, S_new = get_S_prime(n) #use for sgd and eval

## SGD

In [27]:
## Takes a peptide or aptamer sequence and converts it to a one-hot matrix
def one_hot(sequence, seq_type='peptide'):
    """One-hot encode a sequence.

    Args:
        sequence: string over the alphabet selected by seq_type.
        seq_type: 'peptide' selects aa_list; any other value selects na_list.

    Returns:
        numpy float array of shape (len(sequence), len(alphabet)) with a single
        1 per row, in the column of that character's position in the alphabet.

    Raises:
        ValueError: if a character is not part of the selected alphabet.
    """
    letters = aa_list if seq_type == 'peptide' else na_list
    # One lookup table instead of repeated list.index() scans; setdefault keeps
    # the first position of a letter, matching list.index().
    column_of = {}
    for col, letter in enumerate(letters):
        column_of.setdefault(letter, col)

    encoded = np.zeros((len(sequence), len(letters)))
    for row, char in enumerate(sequence):
        if char not in column_of:
            raise ValueError("%r is not a valid %s character" % (char, seq_type))
        encoded[row, column_of[char]] = 1
    return encoded

# Convert a pair to one-hot tensors
def convert(apt, pep):
    """One-hot encode an (aptamer, peptide) pair as 4-D float tensors.

    Args:
        apt: aptamer sequence string.
        pep: peptide sequence string.

    Returns:
        (apt_tensor, pep_tensor) with shapes (1, 1, len(apt), 4) and
        (1, 1, len(pep), 20), placed on the notebook-level `device` (CUDA when
        available, otherwise CPU) rather than unconditionally on the GPU.
    """
    apt = one_hot(apt, seq_type='aptamer') #(40, 4) for a 40-base aptamer
    pep = one_hot(pep, seq_type='peptide') #(8, 20) for an 8-residue peptide
    apt = torch.FloatTensor(np.reshape(apt, (1, 1, apt.shape[0], apt.shape[1]))).to(device) #(1, 1, 40, 4)
    pep = torch.FloatTensor(np.reshape(pep, (1, 1, pep.shape[0], pep.shape[1]))).to(device) #(1, 1, 8, 20)
    return apt, pep

def update(x, y):
    """Run the notebook-level `model` on one encoded pair.

    Args:
        x: aptamer tensor.
        y: peptide tensor.

    Returns:
        (pmf, out): the value of get_y_pmf(y) and the model output.

    NOTE(review): get_y_pmf walks `y[1:]` and looks every element up in
    aa_dict, i.e. it is written for a peptide *string*. The callers in this
    notebook pass the tensor produced by convert(), whose leading dimension is
    1, so that slice is empty and `pmf` is always 1 here.
    """
    pmf = get_y_pmf(y)
    x.requires_grad=True
    y.requires_grad=True
    # Use the notebook-level `device` (CPU fallback) instead of a hard-coded
    # .cuda(), which raises on machines without a GPU.
    x = x.to(device)
    y = y.to(device)
    out = model(x, y)
    return pmf, out

def get_log_out(k):
    """Average log model output over the first k pairs of full_dataset.

    Used only for monitoring, so the forward passes run under
    torch.no_grad() and no autograd graph is built.
    """
    log_outs = []
    with torch.no_grad():
        for (apt, pep) in full_dataset[:k]:
            x, y = convert(apt, pep)
            _, out = update(x, y)
            log_outs.append(torch.log(out).cpu().detach().numpy().flatten()[0])
    return np.average(log_outs)

def get_out_prime(k):
    """Average model output over the first k sampled pairs of S_new.

    Used only for monitoring, so the forward passes run under
    torch.no_grad() and no autograd graph is built.
    """
    outs = []
    with torch.no_grad():
        for (apt, pep) in S_new[:k]:
            x, y = convert(apt, pep)
            _, out = update(x, y)
            outs.append(out.cpu().detach().numpy().flatten()[0])
    return np.average(outs)
    
def apply_param_grad(grads1, grads2, fn):
    """Combine two gradient sequences element-wise.

    Returns the list [fn(a, b), ...] over pairs zipped from grads1 and grads2
    (stopping at the shorter sequence).
    """
    return [fn(first, second) for first, second in zip(grads1, grads2)]

def sgd(t=1, #num of iter over the training set
        lamb=1e-1, #hyperparam
        gamma=1e-2): #step size
    """Train the notebook-level `model` with plain SGD, one pair per step.

    For the i-th training pair (x, y) and the i-th entry of S_prime
    ((x', y'), indicator) the per-step objective is

        lamb * indicator * model(x', y') * p(y') * p(x') * 2n  -  log(model(x, y))

    where p(x') = get_x_pmf() and p(y') = get_y_pmf(y').

    Args:
        t: number of passes over train_loader.
        lamb: weight of the importance-sampled term.
        gamma: SGD learning rate.

    Every 500 steps a monitoring loss, lamb*get_out_prime(1000) -
    get_log_out(1000), is printed.
    """
    optim = SGD(model.parameters(), lr=gamma)
    model.train()
    for _ in range(t):
        for i, (apt, pep) in enumerate(tqdm.tqdm(train_loader)):
            optim.zero_grad()

            # Log-likelihood term on an observed pair (loader yields batches of 1).
            x, y = convert(apt[0], pep[0])
            _, out = update(x, y)
            log_out = torch.log(out)

            # Importance-sampled term on the i-th element of S'.
            (apt_prime, pep_prime), const = S_prime[i] #const is the indicator
            x_prime, y_prime = convert(apt_prime, pep_prime)
            _, out_prime = update(x_prime, y_prime)
            # The pmf must come from the peptide *string*: update() hands
            # get_y_pmf the encoded tensor, for which the pmf degenerates to 1.
            y_pmf = get_y_pmf(pep_prime)
            out_prime = out_prime*y_pmf*get_x_pmf()*2*n

            # Single backward pass per step, so the graph need not be retained.
            (lamb*const*out_prime - log_out).backward()
            optim.step()
            if i % 500 == 0:
                loss = lamb*get_out_prime(1000) - get_log_out(1000)
                print("Iteration: ", i, " Loss: %.3f" % loss)

In [8]:
# Recall on train set of size k to test for overfitting
def recall_train(k):
    """Score the first k pairs of recall_train_samples with the current model.

    A pair counts as recalled when the model output exceeds 0.75.

    Returns:
        (train_recall, train_recall_outputs): recall in percent over the pairs
        actually evaluated (0.0 if there were none) and the list of outputs.
    """
    model.eval()
    correct = 0
    train_recall_outputs = []
    with torch.no_grad(): #evaluation only, no autograd graph needed
        for (apt, pep) in recall_train_samples[:k]:
            apt, pep = convert(apt, pep)
            out = model(apt, pep).cpu().detach().numpy().flatten()[0]
            train_recall_outputs.append(out)
            if out > 0.75:
                correct += 1
    # Divide by the number of pairs evaluated, not by k itself, so the rate is
    # still right when fewer than k samples are available.
    evaluated = len(train_recall_outputs)
    train_recall = 100*correct/evaluated if evaluated else 0.0
    return train_recall, train_recall_outputs #list of outputs


# Recall on test set of size k
def recall_test(k):
    """Score at most k pairs from test_loader with the current model.

    A pair counts as recalled when the model output exceeds 0.75.

    Returns:
        (test_recall, test_recall_outputs): recall in percent over the pairs
        actually evaluated (0.0 if there were none) and the list of outputs.
    """
    model.eval()
    correct = 0
    test_recall_outputs = []
    with torch.no_grad(): #evaluation only, no autograd graph needed
        for i, (aptamer, peptide) in enumerate(tqdm.tqdm(test_loader)):
            if i >= k: #stop after exactly k samples (was `i > k`, i.e. k+1)
                break
            apt, pep = convert(aptamer[0], peptide[0])
            output = model(apt, pep).cpu().detach().numpy().flatten()[0]
            test_recall_outputs.append(output)
            if output > 0.75:
                correct += 1
    # Divide by the number of pairs evaluated so the rate stays correct when
    # the loader holds fewer than k samples.
    evaluated = len(test_recall_outputs)
    test_recall = 100*correct/evaluated if evaluated else 0.0
    return test_recall, test_recall_outputs #list of outputs


# Eval on m new unseen pairs in S_new (not in our dataset)
def eval_unknown(m):
    """Return the model outputs for the first m sampled pairs of S_new.

    If S_new holds fewer than m pairs, all of them are scored.
    """
    model.eval()
    eval_unknown_outputs = []
    with torch.no_grad(): #evaluation only, no autograd graph needed
        for (x, y) in S_new[:m]:
            apt, pep = convert(x, y)
            output = model(apt, pep).cpu().detach().numpy().flatten()[0]
            eval_unknown_outputs.append(output)
    return eval_unknown_outputs #list of outputs


# Cumulative histogram (empirical CDF) of scores
def cdf(scores, i): # i is the index
    """Plot a normalised cumulative step histogram (100 bins) of `scores`.

    `i` is a flat index into the gamma x lambda grid (gamma-major); it only
    determines the legend text.
    """
    plt.hist(scores, 100, histtype='step', density=True, cumulative=True)
    gamma_val = gammas[i//len(lambdas)]
    lambda_val = lambdas[i%len(lambdas)]
    legend_text = ('lambda =%.5f' % lambda_val) + (' gamma =%.5f' % gamma_val)
    plt.legend([legend_text])
    plt.show()

In [25]:
# Hyper-parameter grid for the sweep: step sizes (gammas) x lambdas
gammas = [1e-2]
#lambdas = [1e-1, 1e-3, 10, 100, 1e-5]
lambdas = [1e-1]

# Per-configuration results, filled in by the sweep loop
train_recalls, train_scores, train_cdfs = [], [], []
test_recalls, test_scores, test_cdfs = [], [], []

m = 10**6 # number of unknown samples
k = m//10 # number of binding samples (test set size is 118262, k is just some limit we set)
recall_train_samples = get_xy(k) #use for eval

In [None]:
# Sweep over every (gamma, lambda) combination: train a fresh model, then
# measure recall on train/test samples and score the sampled unknown pairs.
for gamma in gammas:
    for lamb in lambdas:
        # `model` must stay a notebook-level name: sgd(), update() and the
        # evaluation helpers all read it as a global.
        model = SimpleConvNet()
        model.apply(weights_init)
        model.to(device) #honours the CPU fallback chosen for `device`

        print("=============Training=======================")
        sgd(t=1, gamma=gamma, lamb=lamb)

        print("=============Evaluating train===============")
        train_recall, train_recall_outputs = recall_train(k)
        print("Gamma: ", "%.5f" % gamma, "Lambda: ", "%.5f" % lamb, \
              "Train recall: ", "%.2f" % train_recall)

        print("=============Evaluating test================")
        test_recall, test_recall_outputs = recall_test(k)
        print("Gamma: ", "%.5f" % gamma, "Lambda: ", "%.5f" % lamb, \
              "Test recall: ", "%.2f" % test_recall)

        print("=============Evaluating unknown=============")
        eval_unknown_outputs = eval_unknown(m)

        # Scores of unknown pairs followed by scores of known binding pairs.
        train_score = np.asarray(eval_unknown_outputs + train_recall_outputs)
        train_scores.append(train_score)

        test_score = np.asarray(eval_unknown_outputs + test_recall_outputs)
        test_scores.append(test_score)

        # Summary statistic: sum of the running totals, normalised by
        # (total score * number of scores).
        train_cdf = np.sum(np.cumsum(train_score), dtype=float)/(np.sum(train_score)*len(train_score))
        test_cdf = np.sum(np.cumsum(test_score), dtype=float)/(np.sum(test_score)*len(test_score))
        print("G: ", "%.5f" % gamma, "L: ", "%.5f" % lamb, \
              "Train CDF: ", "%.3f" % train_cdf, "Test CDF: ", "%.3f" % test_cdf)

        train_recalls.append(train_recall)
        test_recalls.append(test_recall)
        train_cdfs.append((gamma, lamb, train_cdf))
        test_cdfs.append((gamma, lamb,  test_cdf))

print("Train CDFs: ", train_cdfs)
print("Test CDFs: ", test_cdfs)








  0%|          | 0/473047 [00:00<?, ?it/s][A[A[A[A[A[A[A










  0%|          | 1/473047 [01:00<7930:29:58, 60.35s/it][A[A[A[A[A[A[A






  0%|          | 4/473047 [01:00<5552:42:20, 42.26s/it][A[A[A[A[A[A[A

Iteration:  0  Loss: 0.787









  0%|          | 7/473047 [01:00<3888:50:54, 29.60s/it][A[A[A[A[A[A[A






  0%|          | 10/473047 [01:00<2724:04:29, 20.73s/it][A[A[A[A[A[A[A






  0%|          | 12/473047 [01:00<1909:00:38, 14.53s/it][A[A[A[A[A[A[A






  0%|          | 14/473047 [01:00<1338:52:48, 10.19s/it][A[A[A[A[A[A[A






  0%|          | 17/473047 [01:01<939:05:20,  7.15s/it] [A[A[A[A[A[A[A






  0%|          | 19/473047 [01:01<659:25:00,  5.02s/it][A[A[A[A[A[A[A






  0%|          | 21/473047 [01:01<463:40:16,  3.53s/it][A[A[A[A[A[A[A






  0%|          | 24/473047 [01:01<326:33:51,  2.49s/it][A[A[A[A[A[A[A






  0%|          | 26/473047 [01:01<231:02:26,  1.76s/it][A[A[A[A[A[A[A






  0%|          | 28/473047 [01:01<164:06:02,  1.25s/it][A[A[A[A[A[A[A






  0%|          | 30/473047 [01:01<116:58:54,  1.12it/s][A[A[A[A[A[A[A






  0%|          | 32/473047 [01:01<83:54:47,  1.57it/s] [A[A[A[A[A[A[A

  0%|          | 277/473047 [01:13<6:16:37, 20.92it/s][A[A[A[A[A[A[A






  0%|          | 280/473047 [01:14<6:10:35, 21.26it/s][A[A[A[A[A[A[A






  0%|          | 283/473047 [01:14<5:52:19, 22.36it/s][A[A[A[A[A[A[A






  0%|          | 286/473047 [01:14<6:19:32, 20.76it/s][A[A[A[A[A[A[A






  0%|          | 289/473047 [01:14<6:28:19, 20.29it/s][A[A[A[A[A[A[A






  0%|          | 292/473047 [01:14<6:20:35, 20.70it/s][A[A[A[A[A[A[A






  0%|          | 295/473047 [01:14<6:16:53, 20.91it/s][A[A[A[A[A[A[A






  0%|          | 298/473047 [01:14<6:03:09, 21.70it/s][A[A[A[A[A[A[A






  0%|          | 301/473047 [01:15<5:45:59, 22.77it/s][A[A[A[A[A[A[A






  0%|          | 304/473047 [01:15<6:02:23, 21.74it/s][A[A[A[A[A[A[A






  0%|          | 307/473047 [01:15<6:35:01, 19.95it/s][A[A[A[A[A[A[A






  0%|          | 310/473047 [01:15<6:20:48, 20.69it/s][A[A[A[A[A[A[A






  0%|          |

Iteration:  500  Loss: 0.103









  0%|          | 506/473047 [02:26<603:17:17,  4.60s/it][A[A[A[A[A[A[A






  0%|          | 508/473047 [02:26<424:20:43,  3.23s/it][A[A[A[A[A[A[A






  0%|          | 510/473047 [02:26<299:25:02,  2.28s/it][A[A[A[A[A[A[A






  0%|          | 512/473047 [02:27<211:33:56,  1.61s/it][A[A[A[A[A[A[A






  0%|          | 514/473047 [02:27<150:06:57,  1.14s/it][A[A[A[A[A[A[A






  0%|          | 516/473047 [02:27<107:52:19,  1.22it/s][A[A[A[A[A[A[A






  0%|          | 519/473047 [02:27<77:19:13,  1.70it/s] [A[A[A[A[A[A[A






  0%|          | 521/473047 [02:27<56:38:26,  2.32it/s][A[A[A[A[A[A[A






  0%|          | 523/473047 [02:27<41:39:45,  3.15it/s][A[A[A[A[A[A[A






  0%|          | 525/473047 [02:27<31:13:26,  4.20it/s][A[A[A[A[A[A[A






  0%|          | 527/473047 [02:27<24:23:59,  5.38it/s][A[A[A[A[A[A[A






  0%|          | 529/473047 [02:27<19:17:37,  6.80it/s][A[A[A[A[A[A

  0%|          | 795/473047 [02:40<5:50:55, 22.43it/s][A[A[A[A[A[A[A






  0%|          | 798/473047 [02:40<5:56:22, 22.09it/s][A[A[A[A[A[A[A






  0%|          | 801/473047 [02:40<5:59:50, 21.87it/s][A[A[A[A[A[A[A






  0%|          | 804/473047 [02:40<6:08:05, 21.38it/s][A[A[A[A[A[A[A






  0%|          | 807/473047 [02:40<6:23:10, 20.54it/s][A[A[A[A[A[A[A






  0%|          | 810/473047 [02:40<6:19:19, 20.75it/s][A[A[A[A[A[A[A






  0%|          | 813/473047 [02:40<6:10:08, 21.26it/s][A[A[A[A[A[A[A






  0%|          | 816/473047 [02:41<6:02:03, 21.74it/s][A[A[A[A[A[A[A






  0%|          | 819/473047 [02:41<6:02:12, 21.73it/s][A[A[A[A[A[A[A






  0%|          | 822/473047 [02:41<5:40:56, 23.08it/s][A[A[A[A[A[A[A






  0%|          | 825/473047 [02:41<6:04:37, 21.58it/s][A[A[A[A[A[A[A






  0%|          | 828/473047 [02:41<5:55:08, 22.16it/s][A[A[A[A[A[A[A






  0%|          |

Iteration:  1000  Loss: 0.101









  0%|          | 1006/473047 [03:49<581:10:25,  4.43s/it][A[A[A[A[A[A[A






  0%|          | 1009/473047 [03:49<408:11:13,  3.11s/it][A[A[A[A[A[A[A






  0%|          | 1013/473047 [03:49<286:48:53,  2.19s/it][A[A[A[A[A[A[A






  0%|          | 1016/473047 [03:49<202:21:51,  1.54s/it][A[A[A[A[A[A[A






  0%|          | 1019/473047 [03:49<143:16:42,  1.09s/it][A[A[A[A[A[A[A






  0%|          | 1022/473047 [03:49<102:21:54,  1.28it/s][A[A[A[A[A[A[A






  0%|          | 1025/473047 [03:50<73:49:45,  1.78it/s] [A[A[A[A[A[A[A






  0%|          | 1028/473047 [03:50<54:02:47,  2.43it/s][A[A[A[A[A[A[A






  0%|          | 1030/473047 [03:50<40:38:37,  3.23it/s][A[A[A[A[A[A[A






  0%|          | 1032/473047 [03:50<30:42:13,  4.27it/s][A[A[A[A[A[A[A






  0%|          | 1034/473047 [03:50<23:53:21,  5.49it/s][A[A[A[A[A[A[A






  0%|          | 1036/473047 [03:50<19:12:20,  6.83it/s][A[A

  0%|          | 1284/473047 [04:02<6:07:11, 21.41it/s][A[A[A[A[A[A[A






  0%|          | 1287/473047 [04:02<6:03:34, 21.63it/s][A[A[A[A[A[A[A






  0%|          | 1290/473047 [04:02<6:18:00, 20.80it/s][A[A[A[A[A[A[A






  0%|          | 1293/473047 [04:02<6:06:12, 21.47it/s][A[A[A[A[A[A[A






  0%|          | 1296/473047 [04:03<6:08:56, 21.31it/s][A[A[A[A[A[A[A






  0%|          | 1299/473047 [04:03<6:26:01, 20.37it/s][A[A[A[A[A[A[A






  0%|          | 1302/473047 [04:03<6:35:21, 19.89it/s][A[A[A[A[A[A[A






  0%|          | 1305/473047 [04:03<6:11:11, 21.18it/s][A[A[A[A[A[A[A






  0%|          | 1308/473047 [04:03<6:18:57, 20.75it/s][A[A[A[A[A[A[A






  0%|          | 1311/473047 [04:03<6:14:29, 20.99it/s][A[A[A[A[A[A[A






  0%|          | 1315/473047 [04:03<5:48:43, 22.55it/s][A[A[A[A[A[A[A






  0%|          | 1318/473047 [04:04<5:35:39, 23.42it/s][A[A[A[A[A[A[A






  0%

Iteration:  1500  Loss: 0.101









  0%|          | 1505/473047 [05:15<409:45:57,  3.13s/it][A[A[A[A[A[A[A






  0%|          | 1508/473047 [05:16<288:36:25,  2.20s/it][A[A[A[A[A[A[A






  0%|          | 1510/473047 [05:16<204:00:05,  1.56s/it][A[A[A[A[A[A[A






  0%|          | 1512/473047 [05:16<145:01:53,  1.11s/it][A[A[A[A[A[A[A






  0%|          | 1515/473047 [05:16<103:25:49,  1.27it/s][A[A[A[A[A[A[A






  0%|          | 1518/473047 [05:16<74:16:57,  1.76it/s] [A[A[A[A[A[A[A






  0%|          | 1521/473047 [05:16<53:48:23,  2.43it/s][A[A[A[A[A[A[A






  0%|          | 1524/473047 [05:16<39:23:04,  3.33it/s][A[A[A[A[A[A[A






  0%|          | 1527/473047 [05:16<29:24:43,  4.45it/s][A[A[A[A[A[A[A






  0%|          | 1530/473047 [05:17<22:42:58,  5.77it/s][A[A[A[A[A[A[A






  0%|          | 1533/473047 [05:17<17:52:26,  7.33it/s][A[A[A[A[A[A[A






  0%|          | 1535/473047 [05:17<14:45:25,  8.88it/s][A[A

  0%|          | 1773/473047 [05:28<5:51:18, 22.36it/s][A[A[A[A[A[A[A






  0%|          | 1776/473047 [05:29<5:55:08, 22.12it/s][A[A[A[A[A[A[A






  0%|          | 1779/473047 [05:29<5:56:43, 22.02it/s][A[A[A[A[A[A[A






  0%|          | 1782/473047 [05:29<6:17:40, 20.80it/s][A[A[A[A[A[A[A






  0%|          | 1785/473047 [05:29<6:09:39, 21.25it/s][A[A[A[A[A[A[A






  0%|          | 1788/473047 [05:29<6:03:54, 21.58it/s][A[A[A[A[A[A[A






  0%|          | 1791/473047 [05:29<6:06:46, 21.41it/s][A[A[A[A[A[A[A






  0%|          | 1794/473047 [05:29<6:00:19, 21.80it/s][A[A[A[A[A[A[A






  0%|          | 1797/473047 [05:30<6:11:34, 21.14it/s][A[A[A[A[A[A[A






  0%|          | 1800/473047 [05:30<6:25:55, 20.35it/s][A[A[A[A[A[A[A






  0%|          | 1803/473047 [05:30<6:01:00, 21.76it/s][A[A[A[A[A[A[A






  0%|          | 1806/473047 [05:30<6:23:37, 20.47it/s][A[A[A[A[A[A[A






  0%

Iteration:  2000  Loss: 0.100









  0%|          | 2005/473047 [06:41<405:40:52,  3.10s/it][A[A[A[A[A[A[A






  0%|          | 2007/473047 [06:42<286:03:44,  2.19s/it][A[A[A[A[A[A[A






  0%|          | 2010/473047 [06:42<202:08:13,  1.54s/it][A[A[A[A[A[A[A






  0%|          | 2012/473047 [06:42<144:03:26,  1.10s/it][A[A[A[A[A[A[A






  0%|          | 2014/473047 [06:42<102:54:09,  1.27it/s][A[A[A[A[A[A[A






  0%|          | 2017/473047 [06:42<73:41:40,  1.78it/s] [A[A[A[A[A[A[A






  0%|          | 2019/473047 [06:42<53:36:59,  2.44it/s][A[A[A[A[A[A[A






  0%|          | 2021/473047 [06:42<39:29:47,  3.31it/s][A[A[A[A[A[A[A






  0%|          | 2024/473047 [06:42<29:35:47,  4.42it/s][A[A[A[A[A[A[A






  0%|          | 2026/473047 [06:42<22:57:50,  5.70it/s][A[A[A[A[A[A[A






  0%|          | 2028/473047 [06:43<18:21:53,  7.12it/s][A[A[A[A[A[A[A






  0%|          | 2030/473047 [06:43<15:03:51,  8.69it/s][A[A

  0%|          | 2297/473047 [06:55<6:49:51, 19.14it/s][A[A[A[A[A[A[A






  0%|          | 2299/473047 [06:55<7:04:23, 18.49it/s][A[A[A[A[A[A[A






  0%|          | 2301/473047 [06:55<6:58:38, 18.74it/s][A[A[A[A[A[A[A






  0%|          | 2304/473047 [06:55<6:50:17, 19.12it/s][A[A[A[A[A[A[A






  0%|          | 2307/473047 [06:55<6:49:53, 19.14it/s][A[A[A[A[A[A[A






  0%|          | 2309/473047 [06:56<7:20:04, 17.83it/s][A[A[A[A[A[A[A






  0%|          | 2311/473047 [06:56<7:49:02, 16.73it/s][A[A[A[A[A[A[A






  0%|          | 2314/473047 [06:56<7:39:39, 17.07it/s][A[A[A[A[A[A[A






  0%|          | 2317/473047 [06:56<6:54:22, 18.93it/s][A[A[A[A[A[A[A






  0%|          | 2319/473047 [06:56<7:15:03, 18.03it/s][A[A[A[A[A[A[A






  0%|          | 2322/473047 [06:56<6:59:23, 18.71it/s][A[A[A[A[A[A[A






  0%|          | 2324/473047 [06:56<6:56:23, 18.84it/s][A[A[A[A[A[A[A






  0%

Iteration:  2500  Loss: 0.100









  1%|          | 2506/473047 [08:07<596:53:47,  4.57s/it][A[A[A[A[A[A[A






  1%|          | 2508/473047 [08:07<419:54:21,  3.21s/it][A[A[A[A[A[A[A






  1%|          | 2511/473047 [08:07<295:42:28,  2.26s/it][A[A[A[A[A[A[A






  1%|          | 2513/473047 [08:07<209:00:27,  1.60s/it][A[A[A[A[A[A[A






  1%|          | 2516/473047 [08:07<148:09:25,  1.13s/it][A[A[A[A[A[A[A






  1%|          | 2519/473047 [08:07<105:22:37,  1.24it/s][A[A[A[A[A[A[A






  1%|          | 2521/473047 [08:07<75:46:04,  1.73it/s] [A[A[A[A[A[A[A






  1%|          | 2524/473047 [08:07<54:59:46,  2.38it/s][A[A[A[A[A[A[A






  1%|          | 2527/473047 [08:08<40:16:35,  3.25it/s][A[A[A[A[A[A[A






  1%|          | 2530/473047 [08:08<30:05:27,  4.34it/s][A[A[A[A[A[A[A






  1%|          | 2533/473047 [08:08<22:48:35,  5.73it/s][A[A[A[A[A[A[A






  1%|          | 2539/473047 [08:08<16:39:03,  7.85it/s][A[A

  1%|          | 2792/473047 [08:20<5:41:08, 22.97it/s][A[A[A[A[A[A[A






  1%|          | 2795/473047 [08:20<5:53:56, 22.14it/s][A[A[A[A[A[A[A






  1%|          | 2798/473047 [08:20<6:18:30, 20.71it/s][A[A[A[A[A[A[A






  1%|          | 2801/473047 [08:21<5:53:49, 22.15it/s][A[A[A[A[A[A[A






  1%|          | 2804/473047 [08:21<6:31:03, 20.04it/s][A[A[A[A[A[A[A






  1%|          | 2813/473047 [08:21<5:11:30, 25.16it/s][A[A[A[A[A[A[A






  1%|          | 2817/473047 [08:21<5:17:10, 24.71it/s][A[A[A[A[A[A[A






  1%|          | 2821/473047 [08:21<5:27:58, 23.90it/s][A[A[A[A[A[A[A






  1%|          | 2824/473047 [08:21<5:42:16, 22.90it/s][A[A[A[A[A[A[A






  1%|          | 2827/473047 [08:21<5:38:19, 23.16it/s][A[A[A[A[A[A[A






  1%|          | 2831/473047 [08:22<5:04:22, 25.75it/s][A[A[A[A[A[A[A






  1%|          | 2835/473047 [08:22<4:41:49, 27.81it/s][A[A[A[A[A[A[A






  1%

Iteration:  3000  Loss: 0.100









  1%|          | 3007/473047 [09:32<1196:01:48,  9.16s/it][A[A[A[A[A[A[A






  1%|          | 3010/473047 [09:33<838:43:42,  6.42s/it] [A[A[A[A[A[A[A






  1%|          | 3012/473047 [09:33<589:13:02,  4.51s/it][A[A[A[A[A[A[A






  1%|          | 3015/473047 [09:33<414:26:21,  3.17s/it][A[A[A[A[A[A[A






  1%|          | 3018/473047 [09:33<291:26:06,  2.23s/it][A[A[A[A[A[A[A






  1%|          | 3021/473047 [09:33<205:36:38,  1.57s/it][A[A[A[A[A[A[A






  1%|          | 3024/473047 [09:33<145:56:32,  1.12s/it][A[A[A[A[A[A[A






  1%|          | 3027/473047 [09:33<104:21:00,  1.25it/s][A[A[A[A[A[A[A






  1%|          | 3030/473047 [09:34<75:58:55,  1.72it/s] [A[A[A[A[A[A[A






  1%|          | 3032/473047 [09:34<55:25:11,  2.36it/s][A[A[A[A[A[A[A






  1%|          | 3034/473047 [09:34<41:21:27,  3.16it/s][A[A[A[A[A[A[A






  1%|          | 3036/473047 [09:34<31:29:43,  4.15it/s][

  1%|          | 3288/473047 [09:46<7:02:34, 18.53it/s][A[A[A[A[A[A[A






  1%|          | 3290/473047 [09:46<8:30:08, 15.35it/s][A[A[A[A[A[A[A






  1%|          | 3292/473047 [09:46<8:15:20, 15.81it/s][A[A[A[A[A[A[A






  1%|          | 3294/473047 [09:46<7:52:24, 16.57it/s][A[A[A[A[A[A[A






  1%|          | 3296/473047 [09:46<7:48:18, 16.72it/s][A[A[A[A[A[A[A






  1%|          | 3299/473047 [09:46<7:07:38, 18.31it/s][A[A[A[A[A[A[A






  1%|          | 3302/473047 [09:47<6:49:08, 19.14it/s][A[A[A[A[A[A[A






  1%|          | 3304/473047 [09:47<6:56:00, 18.82it/s][A[A[A[A[A[A[A






  1%|          | 3306/473047 [09:47<6:55:20, 18.85it/s][A[A[A[A[A[A[A






  1%|          | 3308/473047 [09:47<6:56:39, 18.79it/s][A[A[A[A[A[A[A






  1%|          | 3311/473047 [09:47<6:44:35, 19.35it/s][A[A[A[A[A[A[A






  1%|          | 3314/473047 [09:47<6:42:06, 19.47it/s][A[A[A[A[A[A[A






  1%

Iteration:  3500  Loss: 0.100









  1%|          | 3506/473047 [10:59<601:23:54,  4.61s/it][A[A[A[A[A[A[A






  1%|          | 3508/473047 [10:59<422:57:21,  3.24s/it][A[A[A[A[A[A[A






  1%|          | 3511/473047 [10:59<297:46:48,  2.28s/it][A[A[A[A[A[A[A






  1%|          | 3514/473047 [10:59<210:13:04,  1.61s/it][A[A[A[A[A[A[A






  1%|          | 3517/473047 [10:59<148:59:14,  1.14s/it][A[A[A[A[A[A[A






  1%|          | 3520/473047 [10:59<106:06:24,  1.23it/s][A[A[A[A[A[A[A






  1%|          | 3523/473047 [11:00<76:01:19,  1.72it/s] [A[A[A[A[A[A[A






  1%|          | 3526/473047 [11:00<54:58:12,  2.37it/s][A[A[A[A[A[A[A






  1%|          | 3529/473047 [11:00<40:31:32,  3.22it/s][A[A[A[A[A[A[A






  1%|          | 3532/473047 [11:00<29:52:50,  4.36it/s][A[A[A[A[A[A[A






  1%|          | 3535/473047 [11:00<22:43:56,  5.74it/s][A[A[A[A[A[A[A






  1%|          | 3538/473047 [11:00<17:35:07,  7.42it/s][A[A

  1%|          | 3795/473047 [11:12<6:49:23, 19.10it/s][A[A[A[A[A[A[A






  1%|          | 3798/473047 [11:12<6:25:14, 20.30it/s][A[A[A[A[A[A[A






  1%|          | 3801/473047 [11:13<6:23:53, 20.37it/s][A[A[A[A[A[A[A






  1%|          | 3804/473047 [11:13<6:38:42, 19.62it/s][A[A[A[A[A[A[A






  1%|          | 3806/473047 [11:13<7:37:45, 17.08it/s][A[A[A[A[A[A[A






  1%|          | 3808/473047 [11:13<7:29:29, 17.40it/s][A[A[A[A[A[A[A






  1%|          | 3810/473047 [11:13<7:15:20, 17.96it/s][A[A[A[A[A[A[A






  1%|          | 3812/473047 [11:13<7:09:09, 18.22it/s][A[A[A[A[A[A[A






  1%|          | 3815/473047 [11:13<6:46:04, 19.26it/s][A[A[A[A[A[A[A






  1%|          | 3817/473047 [11:13<6:41:56, 19.46it/s][A[A[A[A[A[A[A






  1%|          | 3820/473047 [11:14<6:55:30, 18.82it/s][A[A[A[A[A[A[A






  1%|          | 3823/473047 [11:14<6:40:43, 19.52it/s][A[A[A[A[A[A[A






  1%

In [None]:
# Cumulative histogram of the score arrays collected by the sweep; index 0
# labels the plot with gammas[0] / lambdas[0].
cdf(train_scores, 0)