In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
import itertools
import random
from pprint import pprint as pp
from time import time as tt
import inspect
import importlib
import yaml

# External imports
import matplotlib.pyplot as plt
import matplotlib.colors
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch_geometric.data import DataLoader
from mpl_toolkits.mplot3d import Axes3D

from itertools import chain
from random import shuffle, sample

from torch.nn import Linear
import torch.nn.functional as F
from torch_scatter import scatter, segment_csr, scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_cluster import knn_graph, radius_graph
import trackml.dataset
import torch_geometric
# pytorch3d's `ops.knn_points` is used by the first k-NN cell below; importing it
# here (instead of mid-notebook) keeps Restart & Run All from hitting a NameError.
from pytorch3d import ops

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import LightningModule, Trainer

sys.path.append('..')
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook draws from (stdlib `random`, numpy, torch) so the
# shuffled neighbours and sampled negatives are reproducible across re-runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Triplet Construction Exploration

## Really Dumb Method

In [470]:
x = torch.rand((100000,3)).to(device)

In [478]:
k = 5
r = 0.01

In [484]:
%%time
knn_object = ops.knn_points(x.unsqueeze(0), x.unsqueeze(0), K=k, return_sorted=False)
I, D = knn_object.idx[0], knn_object.dists[0]

CPU times: user 325 µs, sys: 269 µs, total: 594 µs
Wall time: 411 µs


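For each row, we want the neighbour indices that lie within radius $r$ (squared distance $\le r^2$) and have a different PID from the query point, i.e. candidate negatives.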

In [486]:
%%time
ind = torch.Tensor.repeat(torch.arange(I.shape[0], device=device), (I.shape[1], 1), 1).T
shuffled_index =  torch.randperm(I.shape[1])
shuffled_I = I[:,shuffled_index]
shuffled_D = D[:, shuffled_index]
shuffled_I[shuffled_D> r**2] = -1
# edge_list = torch.stack([ind[D <= r**2], I[D <= r**2]])
# edge_list = edge_list[:, edge_list[0] != edge_list[1]]

CPU times: user 671 µs, sys: 556 µs, total: 1.23 ms
Wall time: 801 µs


In [487]:
%%time
shuffled_I

CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 5.25 µs


tensor([[   -1,     0,    -1, 54537,    -1],
        [   -1,    -1,    -1,    -1,     1],
        [    2,    -1,    -1,    -1,    -1],
        ...,
        [   -1, 99997,    -1,    -1,    -1],
        [   -1,    -1,    -1,    -1, 99998],
        [99999,    -1,    -1,    -1,    -1]], device='cuda:0')

In [None]:
pid

In [446]:
%%time
e = edge_list.cpu().numpy()
e_sp = sp.sparse.coo_matrix((np.ones(e.shape[1]), e))
e_sp = e_sp.tolil()
row_indices = e_sp.rows

CPU times: user 362 ms, sys: 3.89 ms, total: 366 ms
Wall time: 364 ms


In [448]:
import itertools

In [449]:
%%time
rect_indices = np.column_stack((itertools.zip_longest(*row_indices, fillvalue=0)))

CPU times: user 301 ms, sys: 11.8 ms, total: 313 ms
Wall time: 312 ms


  """Entry point for launching an IPython kernel.


In [450]:
%%time
shuffled_rect = rect_indices[:, np.random.permutation(rect_indices.shape[1])]

CPU times: user 897 µs, sys: 715 µs, total: 1.61 ms
Wall time: 1.09 ms


In [451]:
%%time
shuffled_rect = push_all_zeros_back(shuffled_rect)

CPU times: user 5.25 ms, sys: 983 µs, total: 6.24 ms
Wall time: 5.75 ms


In [455]:
num_true = np.random.randint(0, k, row_indices.shape[0])

In [456]:
%%time
selected_negatives = shuffled_rect[num_true[:,None] > np.arange(shuffled_rect.shape[1])]

CPU times: user 4.22 ms, sys: 0 ns, total: 4.22 ms
Wall time: 3.75 ms


In [356]:
%%time
neg_available = [len(row) > num for row, num in zip(row_indices, num_true)]
pos_available = num_true > 0
pass_available = neg_available & pos_available

CPU times: user 111 µs, sys: 81 µs, total: 192 µs
Wall time: 196 µs


In [357]:
%%time
zipped_array = zip(row_indices[pass_available], num_true[pass_available])

CPU times: user 10.6 ms, sys: 0 ns, total: 10.6 ms
Wall time: 10.6 ms


In [358]:
%%time
zipped_array = np.vstack([row_indices[pass_available], num_true[pass_available]]).T

CPU times: user 60 µs, sys: 0 ns, total: 60 µs
Wall time: 62.5 µs


In [359]:
def random_func(row, max_size = None, fill_value = 0):
    """Sample distinct candidates for one row and pad the result to a fixed length.

    Args:
        row: pair ``(candidates, count)``. ``count`` elements are drawn from
            ``candidates`` without replacement.
        max_size: length of the returned array. If ``None``, the array is exactly
            ``count`` long (previously ``None`` produced a 0-d array and crashed).
        fill_value: value used for the padding slots. Defaults to 0 for backward
            compatibility. Note that 0 is also a valid node index, so a sentinel
            such as -1 is safer when the padding must be distinguishable.

    Returns:
        1-D integer numpy array of length ``max_size`` (or ``count``). The sampled
        candidates come first and the rest is ``fill_value``.
    """
    candidates, count = row[0], row[1]
    size = count if max_size is None else max_size
    padded = np.full(size, fill_value, dtype=int)
    random_choices = np.random.choice(candidates, count, replace=False)
    padded[:random_choices.shape[0]] = random_choices
    return padded

In [360]:
%%time
np.apply_along_axis(random_func, 1, zipped_array, max_size = num_true[pass_available].max())

CPU times: user 221 µs, sys: 160 µs, total: 381 µs
Wall time: 359 µs


array([[3],
       [3]])

In [363]:
rand_matrix = np.random.randint(0, 10, (10, 5))

In [369]:
rand_matrix[:]

array([[6, 5, 7, 3, 2],
       [1, 0, 1, 4, 6],
       [8, 6, 6, 7, 4],
       [6, 3, 8, 7, 4],
       [6, 2, 2, 4, 6],
       [8, 4, 4, 1, 6],
       [1, 3, 7, 4, 9],
       [6, 7, 9, 3, 0],
       [9, 9, 4, 8, 3],
       [3, 3, 3, 4, 0]])

In [370]:
a = np.arange(25).reshape([5, 5])
numbers = np.array([3, 2, 0, 1, 2])

In [371]:
a

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])

In [372]:
numbers

array([3, 2, 0, 1, 2])

In [373]:
numbers[:, None]

array([[3],
       [2],
       [0],
       [1],
       [2]])

In [374]:
numbers[:,None] > np.arange(a.shape[1])

array([[ True,  True,  True, False, False],
       [ True,  True, False, False, False],
       [False, False, False, False, False],
       [ True, False, False, False, False],
       [ True,  True, False, False, False]])

In [362]:
num_true

array([4, 1, 0, 2, 4, 1, 3, 0, 1, 2])

In [270]:
row_indices[pass_available]

array([list([7594, 36900, 40127, 41911]),
       list([11814, 22120, 57996, 82412]),
       list([7572, 28477, 44805, 50811]), ...,
       list([5757, 42958, 54359, 71728]),
       list([21794, 67340, 68281, 98877]),
       list([26659, 33479, 57606, 90733])], dtype=object)

In [269]:
np.array(zipped_array)

array(<zip object at 0x2aab7365a0f0>, dtype=object)

In [264]:
%%time
random_negatives = [np.random.choice(row, num, replace=False) for row, num in zip(row_indices[pass_available], num_true[pass_available])]

CPU times: user 3.15 s, sys: 88.2 ms, total: 3.24 s
Wall time: 3.14 s


In [None]:
e_sp

In [133]:
radius_indices = D < 0.3**2

In [134]:
radius_indices

tensor([[ True, False, False, False, False],
        [False,  True, False, False,  True],
        [False, False,  True, False, False],
        [False, False, False,  True, False],
        [False, False, False, False,  True],
        [False, False, False,  True, False],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False,  True, False, False,  True],
        [False, False, False, False,  True]], device='cuda:0')

## Some Smarter Methods

In [8]:
def push_all_negs_back(a):
    """Left-pack the non-sentinel entries of each row, in place.

    Every entry of ``a`` that is not ``-1`` is moved to the front of its row,
    keeping its relative order. The remaining slots of the row are set to ``-1``.
    Adapted from http://stackoverflow.com/a/42859463/3293881.

    Args:
        a: 2-D numpy array using ``-1`` as the "empty" sentinel. Modified in place.

    Returns:
        The same array object ``a``.
    """
    keep = a != -1
    n_keep = keep.sum(axis=1, keepdims=True)
    # Column j of a row is "filled" iff j is smaller than that row's kept count.
    front = np.arange(a.shape[1]) < n_keep
    # a[keep] is materialised (row-major) before the write, so rows get their own values.
    a[front] = a[keep]
    a[~front] = -1
    return a

In [9]:
from pytorch3d import ops

In [10]:
model.setup(stage="fit")

In [76]:
k = 20
r = 0.1

In [77]:
data = model.trainset[0].to(device)

In [141]:
x = data.x
pid = data.pid
true_edges = data.pid_true_edges
# true_edges = torch.cat([data.pid_true_edges, data.pid_true_edges.flip(0)], axis=-1)

In [142]:
%%time
# Number of true edges per *anchor* node (row 0 of true_edges). Negatives are
# later picked row-by-row for each anchor and concatenated with true_edges
# sorted by row 0, so the counts must be keyed on that same row. The previous
# torch.sparse.sum(dim=0) counted occurrences in row 1 instead, which only
# matches when the edge list is symmetric.
num_true_torch = torch.bincount(true_edges[0], minlength=len(x)).int()

CPU times: user 643 µs, sys: 1.12 ms, total: 1.77 ms
Wall time: 1.35 ms


In [143]:
%%time
knn_object = ops.knn_points(x.unsqueeze(0), x.unsqueeze(0), K=k, return_sorted=False)
I, D = knn_object.idx[0], knn_object.dists[0]

CPU times: user 509 µs, sys: 0 ns, total: 509 µs
Wall time: 367 µs


In [144]:
%%time
shuffled_index =  torch.randperm(I.shape[1])
shuffled_I = I[:,shuffled_index]
shuffled_D = D[:, shuffled_index]
ind = torch.Tensor.repeat(torch.arange(shuffled_I.shape[0], device=device), (shuffled_I.shape[1], 1), 1).T

CPU times: user 533 µs, sys: 0 ns, total: 533 µs
Wall time: 387 µs


In [145]:
%%time
shuffled_I[shuffled_D> r**2] = -1
shuffled_I[pid[ind] == pid[shuffled_I]] = -1
shuffled_I[ind == shuffled_I] = -1

CPU times: user 786 µs, sys: 0 ns, total: 786 µs
Wall time: 644 µs


In [146]:
%%time
squished_I = push_all_negs_back(shuffled_I.cpu().numpy())
squished_I = torch.from_numpy(squished_I).to(device)

CPU times: user 4.41 ms, sys: 721 µs, total: 5.13 ms
Wall time: 4.59 ms


In [147]:
num_false = (squished_I > -1).sum(axis=1)

In [148]:
squished_I

tensor([[ 1254,   660,   283,  ...,   259,   992,    -1],
        [  531,   365,  1039,  ...,   182,    -1,    -1],
        [ 4927,  5537,  5470,  ...,    -1,    -1,    -1],
        ...,
        [14327,    -1,    -1,  ...,    -1,    -1,    -1],
        [14032, 14307, 14378,  ...,    -1,    -1,    -1],
        [   -1,    -1,    -1,  ...,    -1,    -1,    -1]], device='cuda:0')

In [149]:
num_true_torch.max()

tensor(18, device='cuda:0', dtype=torch.int32)

In [150]:
num_false

tensor([19, 18, 17,  ...,  1,  4,  0], device='cuda:0')

In [151]:
neg_available = num_false > num_true_torch
pos_available = num_true_torch > 0
pass_available = neg_available & pos_available

In [152]:
squished_I = torch.cat([squished_I, -1*torch.ones(squished_I.shape[0], max(0, num_true_torch.max() - k), dtype=int, device=device)], axis=-1)

In [153]:
squished_I

tensor([[ 1254,   660,   283,  ...,   259,   992,    -1],
        [  531,   365,  1039,  ...,   182,    -1,    -1],
        [ 4927,  5537,  5470,  ...,    -1,    -1,    -1],
        ...,
        [14327,    -1,    -1,  ...,    -1,    -1,    -1],
        [14032, 14307, 14378,  ...,    -1,    -1,    -1],
        [   -1,    -1,    -1,  ...,    -1,    -1,    -1]], device='cuda:0')

In [154]:
num_true_torch[pass_available,None] > torch.arange(squished_I.shape[1], device=device)

tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')

In [155]:
%%time
selected_negatives = squished_I[pos_available][num_true_torch[pos_available,None] > torch.arange(squished_I.shape[1], device=device)]

CPU times: user 261 µs, sys: 462 µs, total: 723 µs
Wall time: 490 µs


In [156]:
selected_negatives.unsqueeze(0)

tensor([[1254,  660,  283,  ...,   -1,   -1,   -1]], device='cuda:0')

In [157]:
sorted_true_indices = torch.argsort(true_edges[0])

In [158]:
num_true_torch.sum()

tensor(156710, device='cuda:0')

In [159]:
sorted_true_edges = true_edges[:, sorted_true_indices]

In [160]:
sorted_true_edges

tensor([[    0,     0,     0,  ..., 14392, 14392, 14392],
        [ 8950, 11941,  6739,  ..., 12768, 13202,  8423]], device='cuda:0')

In [161]:
triplets = torch.cat([sorted_true_edges, selected_negatives.unsqueeze(0)], axis=0)

In [162]:
triplets = triplets[:, triplets[2]!=-1]

In [163]:
triplets

tensor([[    0,     0,     0,  ..., 14391, 14391, 14391],
        [ 8950, 11941,  6739,  ...,  8940, 11707, 13548],
        [ 1254,   660,   283,  ..., 14307, 14378, 14028]], device='cuda:0')

## Validation

In [128]:
reference = x.index_select(0, triplets[2])
neighbors = x.index_select(0, triplets[0])
d = torch.sum((reference - neighbors)**2, dim=-1)

In [131]:
(d < r**2).all()

tensor(True, device='cuda:0')

Good: every negative lies within radius $r$ of its anchor.

In [133]:
(pid[triplets[0]] == pid[triplets[1]]).all()

tensor(True, device='cuda:0')

In [136]:
(pid[triplets[0]] != pid[triplets[2]]).all()

tensor(True, device='cuda:0')

## Full Method

In [140]:
data.pid_true_edges

tensor([[11481, 11481, 11481,  ..., 11345, 11345, 11345],
        [ 5723,  5675,  5664,  ...,  9272,  7997,  9215]], device='cuda:0')

In [138]:
%%time
def mine_triplets(true_edges, spatial, r_max, k_max, pid=None):
    """Build (anchor, positive, negative) triplets by hard-negative mining.

    For every true edge ``(anchor, positive)`` in ``true_edges``, a negative is
    taken from the ``k_max`` nearest neighbours of ``anchor`` in ``spatial``. The
    negative must lie within ``r_max``, have a different ``pid`` label and not be
    the anchor itself. The neighbour columns are shuffled first, so the negatives
    chosen for an anchor are a random subset of the valid candidates.

    Args:
        true_edges: ``(2, E)`` integer tensor of true edges; row 0 is the anchor.
        spatial: ``(N, D)`` tensor of point coordinates.
        r_max: maximum distance between an anchor and its negatives.
        k_max: number of nearest neighbours searched per point.
        pid: ``(N,)`` tensor of per-point labels; points sharing a label are never
            used as negatives for each other. If omitted, the notebook-global
            ``pid`` is used (legacy behaviour kept for existing calls).

    Returns:
        ``(3, T)`` tensor ``[anchor, positive, negative]`` with ``T <= E``. When an
        anchor has fewer valid negatives than true edges, its surplus edges are
        dropped.

    Raises:
        ValueError: if no ``pid`` is passed and none exists in the notebook globals.
    """
    if pid is None:
        # Backward compatibility: the original implementation silently read the
        # notebook-global `pid`. Keep that working, but prefer passing it in.
        pid = globals().get("pid")
        if pid is None:
            raise ValueError("mine_triplets needs per-point `pid` labels; pass pid=...")

    dev = spatial.device
    n_points = len(spatial)

    # -------- TRUTH
    # Count true edges per *anchor* (row 0). Negatives are selected row-by-row for
    # each anchor and aligned with the edges sorted by row 0, so both must use row 0.
    num_true = torch.bincount(true_edges[0], minlength=n_points)
    sorted_true_edges = true_edges[:, torch.argsort(true_edges[0])]

    # --------- HNM: candidate negatives are the k_max nearest neighbours
    knn_object = ops.knn_points(spatial.unsqueeze(0), spatial.unsqueeze(0), K=k_max, return_sorted=False)
    I, D = knn_object.idx[0], knn_object.dists[0]

    # ---------- Shuffle neighbour columns so the chosen negatives are random
    shuffled_index = torch.randperm(I.shape[1], device=I.device)
    shuffled_I = I[:, shuffled_index]
    shuffled_D = D[:, shuffled_index]
    ind = torch.arange(shuffled_I.shape[0], device=dev).unsqueeze(1).expand_as(shuffled_I)

    # ---------- Constraints
    # Build the whole mask from the untouched indices before writing any -1.
    # Otherwise pid[-1] would be looked up, which wraps to the last point.
    invalid = (
        (shuffled_I < 0)
        | (shuffled_D > r_max**2)
        | (pid[ind] == pid[shuffled_I])
        | (ind == shuffled_I)
    )
    shuffled_I = shuffled_I.masked_fill(invalid, -1)

    # ----------- Reshape: valid negatives to the front of each row, -1's to the back
    squished_I = push_all_negs_back(shuffled_I.cpu().numpy())
    squished_I = torch.from_numpy(squished_I).to(dev)

    # ---------- Handle # pos > # neg: pad with -1 so every anchor has at least
    # num_true columns (uses the real column count, i.e. k_max, not a global k).
    pad_width = max(0, int(num_true.max()) - squished_I.shape[1])
    padding = torch.full((squished_I.shape[0], pad_width), -1, dtype=squished_I.dtype, device=dev)
    squished_I = torch.cat([squished_I, padding], dim=-1)

    # ----------- Build Triplets: the first num_true[i] columns of row i pair with
    # anchor i's true edges. A -1 there means not enough negatives, so drop it.
    pos_available = num_true > 0
    columns = torch.arange(squished_I.shape[1], device=dev)
    selected_negatives = squished_I[pos_available][num_true[pos_available, None] > columns]
    triplets = torch.cat([sorted_true_edges, selected_negatives.unsqueeze(0)], dim=0)
    return triplets[:, triplets[2] != -1]

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.48 µs


# Doublet Loss Benchmark

In [3]:
from lightning_modules.Embedding.Models.layerless_embedding import LayerlessEmbedding

In [4]:
with open("../lightning_modules/Embedding/train_coda_small_embedding.yaml") as f:
        hparams = yaml.load(f, Loader=yaml.FullLoader)

In [5]:
model = LayerlessEmbedding(hparams)
wandb_logger = WandbLogger(project='End2End-TripletEmbedding')
wandb_logger.watch(model)
wandb_logger.log_hyperparams({"model": type(model)})
trainer = Trainer(gpus=1, max_epochs=50, logger=wandb_logger)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.21 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [6]:
trainer.fit(model)

Set SLURM handle signals.

  | Name        | Type       | Params
-------------------------------------------
0 | emb_network | Sequential | 203 K 
-------------------------------------------
203 K     Trainable params
0         Non-trainable params
203 K     Total params
0.814     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  eff = torch.tensor(cluster_true_positive / cluster_true)
  pur = torch.tensor(cluster_true_positive / cluster_positive)






HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

# Triplet Loss Embedding

In [3]:
from lightning_modules.Embedding.Models.layerless_embedding import TripletEmbedding

In [4]:
with open("../lightning_modules/Embedding/train_coda_small_embedding.yaml") as f:
        hparams = yaml.load(f, Loader=yaml.FullLoader)

In [5]:
model = TripletEmbedding(hparams)
wandb_logger = WandbLogger(project='End2End-TripletEmbedding')
wandb_logger.watch(model)
wandb_logger.log_hyperparams({"model": type(model)})
trainer = Trainer(gpus=1, max_epochs=50, logger=wandb_logger, num_sanity_val_steps=0)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.21 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [6]:
trainer.fit(model)

Set SLURM handle signals.

  | Name        | Type       | Params
-------------------------------------------
0 | emb_network | Sequential | 203 K 
-------------------------------------------
203 K     Trainable params
0         Non-trainable params
203 K     Total params
0.814     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  eff = torch.tensor(cluster_true_positive / cluster_true)
  pur = torch.tensor(cluster_true_positive / cluster_positive)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1