In [1]:
# Starts the autoreload extension, which allows editing the .py files with the notebook running and automatically imports the latest changes

%load_ext autoreload
%autoreload 2

import trim_duplicates, model, network, gradcam, plots
from dataset import Dataset
import haiku as hk
import jax.numpy as jnp
import numpy as np
import jax
import sklearn
import wandb

# Fail fast when fewer than 8 local JAX devices are visible.
# NOTE(review): presumably the project code shards work across 8 accelerator
# cores (e.g. one TPU v3-8 host) — confirm against network.py.
assert jax.local_device_count() >= 8

NUM_CLASSES = 4    # both datasets loaded below report four class names
SEED = 12          # seeds the jax PRNG key used when loading the datasets
BATCH_SIZE = 128   # passed to the net/optimizer setup, network.create and the W&B run config

def basemodel_process(x):
    """Identity preprocessing hook: return the input exactly as received.

    Passed as the `process_fn` argument when training the baseline model,
    i.e. the baseline applies no extra processing of its own.
    """
    return x

In [2]:
# Build a single PRNG key from SEED and hand the *same* key to both loaders.
# NOTE(review): what Dataset.load does with `rng` (presumably shuffling and/or
# splitting) is not visible from this notebook — confirm in dataset.py.
rng = jax.random.PRNGKey(SEED)
dataset_mendeley = Dataset.load("mendeley", rng=rng)
dataset_tawsifur = Dataset.load("tawsifur", rng=rng)
# Print the class names of each dataset as a quick sanity check of the load.
print("Loaded mendeley", dataset_mendeley.classnames)
print("Loaded tawsifur", dataset_tawsifur.classnames)

tcmalloc: large alloc 7241465856 bytes == 0x864ce000 @  0x7f24f935d680 0x7f24f937e824 0x7f24ef1a44ce 0x7f24ef1fa00e 0x7f24ef1fac4f 0x7f24ef29c924 0x5f5db9 0x5f698e 0x57195c 0x56a0ba 0x5f6343 0x56cf2a 0x56a0ba 0x5f6343 0x570e46 0x56a0ba 0x5f6343 0x56cf2a 0x56a0ba 0x68d5b7 0x600f54 0x5c5530 0x56bddd 0x5004f8 0x56d80c 0x5004f8 0x56d80c 0x5004f8 0x5042c6 0x56bf09 0x5f6166
tcmalloc: large alloc 7241465856 bytes == 0x23fca0000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba
tcmalloc: large alloc 7241465856 bytes == 0x3f1d9a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7

Loaded mendeley ['Normal', 'Viral pneumonia', 'COVID-19', 'Pneumonia bacterial']
Loaded tawsifur ['Normal', 'Viral pneumonia', 'COVID-19', 'Lung opacity']


In [3]:
# Build the network and its optimizer from the Mendeley training images, the
# number of classes and the batch size.
# NOTE(review): presumably x_train is only used to size the optimizer/schedule
# or to infer input shapes — confirm in model.py.
net, optim = model.init_net_and_optim(dataset_mendeley.x_train, NUM_CLASSES, BATCH_SIZE)

# Gets functions for the model
# `net_container` is the object every later cell uses for training, prediction
# and embedding extraction.
# NOTE(review): shape=(10, 256, 256, 3) looks like a dummy (batch, H, W, C)
# input used for initialisation — confirm in network.py.
net_container = network.create(net, optim, BATCH_SIZE, shape = (10, 256, 256, 3))

In [4]:
def five_fold_cross_validation(model_name, original_dataset, process_fn, dup_thresh, seed=12, num_epochs=30):
    """Run the train -> de-duplicate -> re-train experiment on five folds.

    For each fold index 0..4 three separate Weights & Biases runs are created
    (all in the group ``model_name + '_CV'``):

    1. ``train_and_eval``: train on ``original_dataset.five_fold(i)``.
    2. ``duplicate_removal``: compute similarities with the model trained in
       step 1, drop samples above ``dup_thresh`` and log the train/test sizes
       before and after.
    3. ``train_and_eval_nodups``: train again on the curated fold.

    Every run is used as a context manager, so it is closed even if training,
    de-duplication or logging raises; a failed fold therefore does not leave
    a dangling W&B run behind in the notebook kernel.

    Args:
        model_name: Base name used for the W&B group/run names and for the
            names handed to ``model.train_model``.
        original_dataset: Dataset object exposing ``name`` and
            ``five_fold(i)``.
        process_fn: Processing function forwarded unchanged to
            ``model.train_model``.
        dup_thresh: Similarity threshold forwarded to
            ``trim_duplicates.remove_duplicates``.
        seed: Only recorded in the W&B config as ``random_seed``; nothing in
            this function is reseeded with it.
        num_epochs: Number of epochs forwarded to ``model.train_model``.

    Returns:
        None. Results are reported through W&B and whatever
        ``model.train_model`` persists.

    Relies on the notebook-level globals ``BATCH_SIZE`` and ``net_container``.
    """
    config = { 'dataset' : original_dataset.name,
               'random_seed' : seed,
               'batch_size' : BATCH_SIZE,
               'resolution' : 256 }

    group = model_name + '_CV'
    job_type = 'train_and_eval'

    def start_run(run_job_type, run_name):
        # The three runs of a fold differ only in job type and display name.
        return wandb.init(project='xrays', entity='usp-covid-xrays',
                          group=group,
                          job_type=run_job_type,
                          name=run_name,
                          reinit=True, config=config)

    for i in range(5):
        # W&B run names use "_CV_<i>" while the names given to train_model use
        # "_CV<i>"; both spellings are kept so existing runs/files still match.
        run_label = model_name + "_CV_" + str(i)
        train_label = model_name + "_CV" + str(i)

        # RUN 1: Train and eval
        with start_run(job_type, run_label) as run:
            cv_dataset = original_dataset.five_fold(i)

            trained = model.train_model(train_label,
                                        net_container, process_fn,
                                        cv_dataset, masks=None,
                                        num_epochs=num_epochs,
                                        wandb_run=run)

        # RUN 2: Remove duplicates
        with start_run('duplicate_removal', 'duprem_' + run_label) as run2:
            sims = trim_duplicates.compute_similarities(cv_dataset, net_container, trained)
            cv_dataset_curated = trim_duplicates.remove_duplicates(cv_dataset, sims, threshold=dup_thresh)

            train_before = cv_dataset.x_train.shape[0]
            test_before = cv_dataset.x_test.shape[0]

            train_after = cv_dataset_curated.x_train.shape[0]
            test_after = cv_dataset_curated.x_test.shape[0]

            run2.log({
                "train-before": train_before,
                "test-before": test_before,
                "train-duplicates": train_before - train_after,
                "test-duplicates": test_before - test_after,
                "train-after": train_after,
                "test-after": test_after
            })

        # RUN 3: Re-train and re-eval
        with start_run(job_type + '_nodups', 'nodups_' + run_label) as run3:
            model.train_model('nodups_' + train_label,
                              net_container, process_fn,
                              cv_dataset_curated, masks=None,
                              num_epochs=num_epochs,
                              wandb_run=run3)

In [5]:
# Per-dataset similarity threshold handed to trim_duplicates.remove_duplicates
# (via `dup_thresh`) inside five_fold_cross_validation.
# NOTE(review): the rationale for 0.99 vs 0.998 is not recorded in this
# notebook — presumably chosen by visually inspecting borderline pairs.
dup_thresholds = {
    'mendeley': 0.99,
    'tawsifur': 0.998
}

# Only the Mendeley experiment is executed; the Tawsifur call is kept
# commented out below.
five_fold_cross_validation("base_mendeley", dataset_mendeley, basemodel_process,
                            dup_thresh=dup_thresholds['mendeley'], num_epochs=30)
#five_fold_cross_validation("base_tawsifur", dataset_tawsifur, basemodel_process,
#                            dup_thresh=dup_thresholds['tawsifur'], num_epochs=30)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: pedromartelleto (use `wandb login --relogin` to force relogin)


tcmalloc: large alloc 5793644544 bytes == 0xebe05a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba
tcmalloc: large alloc 5793644544 bytes == 0x101759a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x5f6166
tcmalloc: large alloc 7241465856 bytes == 0x129f21a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166

Model saved to models/base_mendeley_CV0.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▂▂▃▄▅▅▆▆▆▆▆▇▇▇▇▇▆▆▇▇▆▇▇█▇▇██▇██████████
loss,█▇▇█▆▅▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▃▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
val_loss,██▆▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▂▁▂▁▂▂▁▂▂▂▂▂

0,1
acc,1.0
loss,0.01312
val_acc,0.88281
val_loss,0.47781


Calculating embeddings...


100%|███████████████████████████████████████████████████████████████████████████████████| 71/71 [00:28<00:00,  2.49it/s]
tcmalloc: large alloc 9529458688 bytes == 0x154fc1a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x5f6343 0x50b291 0x5f56c7
tcmalloc: large alloc 9529458688 bytes == 0x178841a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba


Computing cosine similarities...


100%|██████████| 9208/9208 [00:36<00:00, 251.61it/s]


trim_duplicates.remove_duplicates - Removed images: 3373 (36.6%)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test-after,▁
test-before,▁
test-duplicates,▁
train-after,▁
train-before,▁
train-duplicates,▁

0,1
test-after,1630
test-before,1841
test-duplicates,211
train-after,4205
train-before,7367
train-duplicates,3162


100%|█████████████████████████████████| 32/32 [00:04<00:00,  7.55it/s, loss=3.00, acc=0.35, val_loss=1.31, val_acc=0.37]
100%|█████████████████████████████████| 32/32 [00:03<00:00,  8.69it/s, loss=1.31, acc=0.35, val_loss=1.30, val_acc=0.37]
100%|█████████████████████████████████| 32/32 [00:04<00:00,  7.89it/s, loss=1.31, acc=0.36, val_loss=1.31, val_acc=0.37]
100%|█████████████████████████████████| 32/32 [00:05<00:00,  5.91it/s, loss=1.31, acc=0.36, val_loss=1.30, val_acc=0.37]
100%|█████████████████████████████████| 32/32 [00:03<00:00,  8.32it/s, loss=1.30, acc=0.36, val_loss=1.30, val_acc=0.37]
100%|█████████████████████████████████| 32/32 [00:04<00:00,  7.77it/s, loss=1.30, acc=0.36, val_loss=1.30, val_acc=0.37]
100%|█████████████████████████████████| 32/32 [00:04<00:00,  7.91it/s, loss=1.30, acc=0.36, val_loss=1.30, val_acc=0.37]
100%|█████████████████████████████████| 32/32 [00:04<00:00,  7.98it/s, loss=1.29, acc=0.39, val_loss=1.28, val_acc=0.37]
100%|███████████████████████████

Model saved to models/nodups_base_mendeley_CV0.pickle


VBox(children=(Label(value=' 0.03MB of 1023.79MB uploaded (0.00MB deduped)\r'), FloatProgress(value=2.93454877…

0,1
acc,▂▁▂▃▂▃▂▃▂▃▃▄▃▄▅▆▅▆▆▆▆▅▆▆▆▆▆▇▇▇▇█▇█▇█████
loss,█▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▁▁▁▁▁▁▂▃▄▅▅▅▅▅▅▅▆▆▇▇█▇██████
val_loss,█████████▇▇▆▅▅▅▄▄▄▄▃▂▂▁▂▁▁▁▁▁▁

0,1
acc,0.83594
loss,0.39412
val_acc,0.79883
val_loss,0.46025


100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.84it/s, loss=2.10, acc=0.35, val_loss=1.30, val_acc=0.35]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.20it/s, loss=1.30, acc=0.36, val_loss=1.30, val_acc=0.35]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.19it/s, loss=1.30, acc=0.37, val_loss=1.31, val_acc=0.34]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.20it/s, loss=1.26, acc=0.44, val_loss=1.23, val_acc=0.48]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.13it/s, loss=1.15, acc=0.54, val_loss=1.13, val_acc=0.57]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.11it/s, loss=0.93, acc=0.63, val_loss=0.82, val_acc=0.66]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.03it/s, loss=0.76, acc=0.69, val_loss=0.74, val_acc=0.70]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.08it/s, loss=0.64, acc=0.73, val_loss=0.65, val_acc=0.72]
100%|███████████████████████████

Model saved to models/base_mendeley_CV1.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▂▂▂▂▃▄▅▆▇▅▆▇▇▇▇▇▆▇▆▇▇▇▇▇▇▇█▇▇█▇████████
loss,█▇▇██▇▆▅▄▄▄▃▂▂▃▃▃▃▂▃▂▃▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▂▁▁
val_acc,▁▁▁▃▄▆▆▆▇▇▇███████████████████
val_loss,███▇▆▄▃▃▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.90625
loss,0.18189
val_acc,0.83036
val_loss,0.43905


Calculating embeddings...


100%|███████████████████████████████████████████████████████████████████████████████████| 71/71 [00:24<00:00,  2.95it/s]
tcmalloc: large alloc 9529458688 bytes == 0x178841a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x5f6343 0x50b291 0x5f56c7
tcmalloc: large alloc 9529458688 bytes == 0x1b70e1a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba


Computing cosine similarities...


100%|██████████| 9208/9208 [00:32<00:00, 285.56it/s]


trim_duplicates.remove_duplicates - Removed images: 3097 (33.6%)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test-after,▁
test-before,▁
test-duplicates,▁
train-after,▁
train-before,▁
train-duplicates,▁

0,1
test-after,1669
test-before,1841
test-duplicates,172
train-after,4442
train-before,7367
train-duplicates,2925


100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.96it/s, loss=2.67, acc=0.35, val_loss=1.31, val_acc=0.33]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.40it/s, loss=1.31, acc=0.36, val_loss=1.32, val_acc=0.36]
100%|█████████████████████████████████| 34/34 [00:03<00:00,  8.64it/s, loss=1.30, acc=0.37, val_loss=1.31, val_acc=0.34]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.45it/s, loss=1.29, acc=0.39, val_loss=1.30, val_acc=0.40]
100%|█████████████████████████████████| 34/34 [00:06<00:00,  5.44it/s, loss=1.24, acc=0.48, val_loss=1.17, val_acc=0.52]
100%|█████████████████████████████████| 34/34 [00:03<00:00,  8.81it/s, loss=1.05, acc=0.59, val_loss=0.95, val_acc=0.61]
100%|█████████████████████████████████| 34/34 [00:03<00:00,  8.75it/s, loss=0.93, acc=0.62, val_loss=0.91, val_acc=0.62]
100%|█████████████████████████████████| 34/34 [00:03<00:00,  8.84it/s, loss=0.89, acc=0.62, val_loss=0.89, val_acc=0.62]
100%|███████████████████████████

Model saved to models/nodups_base_mendeley_CV1.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▁▂▂▁▂▄▄▄▄▅▅▃▆▆▇▆▆▇▆▆▆▇▆▇▆▆▇▇██▆▇▆████▇█
loss,█▃▄▄▄▄▃▃▃▃▂▃▃▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁▂▂▁▁▁▁▁▁
val_acc,▁▁▁▂▄▅▅▅▅▆▇▇▇▇▇▇▇▇▇▇██████████
val_loss,████▇▅▅▅▄▄▂▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.82031
loss,0.41252
val_acc,0.8095
val_loss,0.50125


100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.48it/s, loss=2.10, acc=0.36, val_loss=1.31, val_acc=0.36]
100%|█████████████████████████████████| 57/57 [00:05<00:00, 10.20it/s, loss=1.30, acc=0.36, val_loss=1.31, val_acc=0.36]
100%|█████████████████████████████████| 57/57 [00:05<00:00, 10.39it/s, loss=1.29, acc=0.39, val_loss=1.26, val_acc=0.50]
100%|█████████████████████████████████| 57/57 [00:05<00:00, 10.16it/s, loss=1.21, acc=0.50, val_loss=1.17, val_acc=0.50]
100%|█████████████████████████████████| 57/57 [00:05<00:00, 10.19it/s, loss=1.12, acc=0.56, val_loss=1.21, val_acc=0.50]
100%|█████████████████████████████████| 57/57 [00:05<00:00, 10.45it/s, loss=1.04, acc=0.60, val_loss=0.94, val_acc=0.62]
100%|█████████████████████████████████| 57/57 [00:05<00:00, 10.29it/s, loss=0.85, acc=0.66, val_loss=0.64, val_acc=0.74]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.78it/s, loss=0.60, acc=0.74, val_loss=0.62, val_acc=0.74]
100%|███████████████████████████

Model saved to models/base_mendeley_CV2.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▂▁▂▂▃▃▃▄▆▅▆▆▇▇▆▇▆▆▆▇▆▇▇▆▇▇█▇▇▇▆▇▇▇██▇██
loss,█▇▇██▇▆▇▆▄▄▃▃▃▂▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁
val_acc,▁▁▃▃▃▅▆▇▇▇▇▇▇▇▇▇██████████████
val_loss,███▇▇▅▃▃▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.94531
loss,0.19463
val_acc,0.83929
val_loss,0.41881


Calculating embeddings...


100%|███████████████████████████████████████████████████████████████████████████████████| 71/71 [00:26<00:00,  2.69it/s]
tcmalloc: large alloc 9529458688 bytes == 0x1b70e1a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x5f6343 0x50b291 0x5f56c7
tcmalloc: large alloc 9529458688 bytes == 0x1da961a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba


Computing cosine similarities...


100%|██████████| 9208/9208 [00:34<00:00, 269.84it/s]


trim_duplicates.remove_duplicates - Removed images: 3102 (33.7%)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test-after,▁
test-before,▁
test-duplicates,▁
train-after,▁
train-before,▁
train-duplicates,▁

0,1
test-after,1672
test-before,1841
test-duplicates,169
train-after,4434
train-before,7367
train-duplicates,2933


100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.35it/s, loss=2.59, acc=0.36, val_loss=1.35, val_acc=0.34]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.82it/s, loss=1.30, acc=0.36, val_loss=1.33, val_acc=0.35]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.79it/s, loss=1.30, acc=0.37, val_loss=1.32, val_acc=0.35]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.89it/s, loss=1.29, acc=0.38, val_loss=1.32, val_acc=0.35]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.83it/s, loss=1.28, acc=0.40, val_loss=1.28, val_acc=0.46]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.57it/s, loss=1.23, acc=0.50, val_loss=1.23, val_acc=0.49]
100%|█████████████████████████████████| 34/34 [00:03<00:00,  8.66it/s, loss=1.15, acc=0.55, val_loss=1.07, val_acc=0.56]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.26it/s, loss=0.96, acc=0.62, val_loss=0.92, val_acc=0.62]
100%|███████████████████████████

Model saved to models/nodups_base_mendeley_CV2.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▃▁▁▁▁▂▃▃▂▅▅▅▅▅▆▅▇▅▆▇▇▆▆▆▇▇▆▇█▇▇▇▇▇▇████▇
loss,█▄▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁
val_acc,▁▁▁▁▃▃▄▅▅▅▆▆▇▇▇▇▇█████████████
val_loss,████▇▇▆▅▅▄▄▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▂

0,1
acc,0.82812
loss,0.33808
val_acc,0.80649
val_loss,0.49594


100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.86it/s, loss=2.11, acc=0.36, val_loss=1.30, val_acc=0.35]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.96it/s, loss=1.30, acc=0.36, val_loss=1.29, val_acc=0.36]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.19it/s, loss=1.30, acc=0.36, val_loss=1.30, val_acc=0.36]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.74it/s, loss=1.29, acc=0.41, val_loss=1.21, val_acc=0.51]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.18it/s, loss=1.21, acc=0.51, val_loss=1.17, val_acc=0.52]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.75it/s, loss=1.02, acc=0.60, val_loss=0.89, val_acc=0.65]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.59it/s, loss=0.79, acc=0.68, val_loss=0.76, val_acc=0.70]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.81it/s, loss=0.66, acc=0.73, val_loss=0.65, val_acc=0.72]
100%|███████████████████████████

Model saved to models/base_mendeley_CV3.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▂▂▁▂▃▃▃▅▆▆▆▇▇▆▇▆▆▇▇▇▇▇▇▇▇▇█▇▇▇███▇█▇███
loss,█▇███▇▇▇▅▄▄▃▃▃▂▂▃▃▃▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▁▃▃▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████
val_loss,███▇▇▅▄▃▃▂▂▂▂▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁

0,1
acc,0.9375
loss,0.16354
val_acc,0.85491
val_loss,0.40848


Calculating embeddings...


100%|███████████████████████████████████████████████████████████████████████████████████| 71/71 [00:30<00:00,  2.34it/s]
tcmalloc: large alloc 9529458688 bytes == 0x1b70e1a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x5f6343 0x50b291 0x5f56c7
tcmalloc: large alloc 9529458688 bytes == 0x1da961a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba


Computing cosine similarities...


100%|██████████| 9208/9208 [00:32<00:00, 284.96it/s]


trim_duplicates.remove_duplicates - Removed images: 3177 (34.5%)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test-after,▁
test-before,▁
test-duplicates,▁
train-after,▁
train-before,▁
train-duplicates,▁

0,1
test-after,1665
test-before,1841
test-duplicates,176
train-after,4366
train-before,7367
train-duplicates,3001


100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.04it/s, loss=2.55, acc=0.35, val_loss=1.32, val_acc=0.35]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.32it/s, loss=1.32, acc=0.35, val_loss=1.31, val_acc=0.35]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.34it/s, loss=1.31, acc=0.35, val_loss=1.30, val_acc=0.35]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.26it/s, loss=1.27, acc=0.40, val_loss=1.19, val_acc=0.50]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.25it/s, loss=1.12, acc=0.53, val_loss=1.04, val_acc=0.58]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.30it/s, loss=0.93, acc=0.64, val_loss=0.94, val_acc=0.62]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.35it/s, loss=0.85, acc=0.66, val_loss=0.83, val_acc=0.67]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.31it/s, loss=0.79, acc=0.68, val_loss=0.77, val_acc=0.68]
100%|███████████████████████████

Model saved to models/nodups_base_mendeley_CV3.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▃▂▁▁▁▂▃▄▄▅▆▄▆▇▆▆▇▇▇▆▇▆▇▇▇▇▆▆▆▇▇▇▇▇█▇▇▇██
loss,█▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▁▃▅▅▆▆▆▇▇▇██████████████████
val_loss,███▇▆▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁

0,1
acc,0.83594
loss,0.36891
val_acc,0.8155
val_loss,0.51056


100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.64it/s, loss=2.11, acc=0.35, val_loss=1.28, val_acc=0.36]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.75it/s, loss=1.31, acc=0.35, val_loss=1.28, val_acc=0.38]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.74it/s, loss=1.30, acc=0.36, val_loss=1.28, val_acc=0.38]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.77it/s, loss=1.29, acc=0.40, val_loss=1.22, val_acc=0.49]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.75it/s, loss=1.21, acc=0.50, val_loss=1.15, val_acc=0.55]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.77it/s, loss=1.12, acc=0.56, val_loss=1.07, val_acc=0.59]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.52it/s, loss=1.00, acc=0.61, val_loss=0.95, val_acc=0.61]
100%|█████████████████████████████████| 57/57 [00:09<00:00,  6.05it/s, loss=0.82, acc=0.67, val_loss=0.76, val_acc=0.69]
100%|███████████████████████████

Model saved to models/base_mendeley_CV4.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▁▁▁▂▃▄▂▄▄▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇████▇██▇██
loss,█████▇▇▇▆▆▅▄▃▃▃▃▂▂▃▃▂▂▂▂▂▂▂▁▂▂▁▂▂▂▂▁▁▂▁▁
val_acc,▁▁▁▃▄▄▅▆▇▇▇▇▇▇████████████████
val_loss,████▇▆▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.94531
loss,0.20292
val_acc,0.84152
val_loss,0.38745


Calculating embeddings...


100%|███████████████████████████████████████████████████████████████████████████████████| 71/71 [00:29<00:00,  2.41it/s]
tcmalloc: large alloc 9529458688 bytes == 0x1da961a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x5f6343 0x50b291 0x5f56c7
tcmalloc: large alloc 9529458688 bytes == 0x14cfc1a000 @  0x7f24f935d680 0x7f24f937e824 0x7f24f937eb8a 0x7f22d18696b7 0x7f22cc560790 0x7f22cc56f414 0x7f22cc572287 0x7f22cc4bdf0f 0x7f22cc22fbe8 0x7f22cc21c166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba


Computing cosine similarities...


100%|██████████| 9208/9208 [00:33<00:00, 277.86it/s]


trim_duplicates.remove_duplicates - Removed images: 3110 (33.8%)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
test-after,▁
test-before,▁
test-duplicates,▁
train-after,▁
train-before,▁
train-duplicates,▁

0,1
test-after,1671
test-before,1841
test-duplicates,170
train-after,4427
train-before,7367
train-duplicates,2940


100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.44it/s, loss=2.58, acc=0.35, val_loss=1.30, val_acc=0.37]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.05it/s, loss=1.31, acc=0.35, val_loss=1.29, val_acc=0.40]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.14it/s, loss=1.29, acc=0.39, val_loss=1.27, val_acc=0.42]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.00it/s, loss=1.27, acc=0.40, val_loss=1.25, val_acc=0.46]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  7.97it/s, loss=1.20, acc=0.48, val_loss=1.09, val_acc=0.57]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  6.96it/s, loss=1.09, acc=0.55, val_loss=1.04, val_acc=0.59]
100%|█████████████████████████████████| 34/34 [00:04<00:00,  8.21it/s, loss=0.96, acc=0.63, val_loss=0.87, val_acc=0.65]
100%|█████████████████████████████████| 34/34 [00:05<00:00,  5.89it/s, loss=0.81, acc=0.69, val_loss=0.73, val_acc=0.70]
100%|███████████████████████████

Model saved to models/nodups_base_mendeley_CV4.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▃▁▁▃▂▂▃▃▅▄▆▆▆▇▇▇▆▇▇▆▇▆▇▇▇▇▆█▆███▇▇▇███▇▇
loss,█▄▄▄▄▄▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▂▂▄▄▅▆▆▆▆▇▇▇▇▇▇█████████████
val_loss,████▆▆▅▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.83594
loss,0.39764
val_acc,0.82212
val_loss,0.44632


In [None]:
# Train a baseline on the Tawsifur dataset and compute the similarities that
# the following cells inspect and threshold.
# The previous version of this cell referenced `model_name`, `i`, `process_fn`
# and `cv_dataset`, which only exist as locals of five_fold_cross_validation
# and therefore raise NameError on a fresh kernel; notebook-level names are
# used instead.
basemodel_tawsifur = model.train_model("basemodel_tawsifur", net_container, basemodel_process, dataset_tawsifur)
sims = trim_duplicates.compute_similarities(dataset_tawsifur, net_container, basemodel_tawsifur)

In [None]:
# Show image pairs whose best similarity sits just above the threshold, to
# judge by eye whether the cut-off is sensible.
thresh = 0.998  # same value as dup_thresholds['tawsifur']
# Per-row maximum similarity, shifted so that values >= 0 are at/above `thresh`.
max_sims = sims.max(axis=1) - thresh
# Class index per sample (argmax over axis 1 of the label array); not used
# further in this cell.
# NOTE(review): presumably y_all is one-hot encoded — confirm in dataset.py.
y_classes = dataset_tawsifur.y_all[:sims.shape[0]].argmax(1)
# Column index of the most similar entry for every row.
max_sims_index = sims.argmax(axis=1)
# Keep only rows whose best match lies in [thresh, thresh + 0.0005].
mask = (max_sims >= 0) & (max_sims <= 0.0005)
indices = np.where(mask)[0]
# Compare each selected image with its most similar counterpart.
plots.compare_images(dataset_tawsifur.x_all[indices], dataset_tawsifur.x_all[max_sims_index[indices]], rows=10)
#trim_duplicates.plot_similarities(dataset_tawsifur, sims, threshold=0.99)

In [None]:
# Plot the similarities against the chosen threshold, then build the curated
# dataset by dropping everything remove_duplicates flags at that threshold.
# Depends on `sims` and `thresh` from the preceding cells.
trim_duplicates.plot_similarities(dataset_tawsifur, sims, threshold=thresh)
dataset_tawsifur_curated = trim_duplicates.remove_duplicates(dataset_tawsifur, sims, threshold=thresh)

In [None]:
# Training-set shape after vs. before duplicate removal (curated first).
print(dataset_tawsifur_curated.x_train.shape, dataset_tawsifur.x_train.shape)

In [None]:
# Train the baseline on the curated Tawsifur data, predict on its test split
# and plot the resulting confusion matrix.
# NOTE(review): train_model is called here without num_epochs/wandb_run, so it
# falls back to whatever defaults model.py defines — confirm they are intended.
basemodel_tawsifur_curated = model.train_model("basemodel_tawsifur_curated", net_container, basemodel_process, dataset_tawsifur_curated)
y_pred_tawsifur_curated = net_container.predict(basemodel_tawsifur_curated.params, basemodel_tawsifur_curated.state, dataset_tawsifur_curated.x_test)
plots.confusion_matrix(dataset_tawsifur_curated, y_pred_tawsifur_curated, "Tawsifur - Curated")

# Transfer learning test

In [None]:
# Re-create the PRNG key, the Mendeley dataset and the network container.
# This repeats the setup from the cells at the top of the notebook and
# rebinds `rng`, `dataset_mendeley`, `net`, `optim` and `net_container`.
# NOTE(review): presumably kept so this section can be run on its own after
# the first cell — confirm, otherwise it can be dropped.
rng = jax.random.PRNGKey(SEED)
dataset_mendeley = Dataset.load("mendeley", rng=rng)

net, optim = model.init_net_and_optim(dataset_mendeley.x_train, NUM_CLASSES, BATCH_SIZE)

# Gets functions for the model
net_container = network.create(net, optim, BATCH_SIZE, shape = (10, 256, 256, 3))

In [None]:
# Test of tawsifur on mendeley
# NOTE(review): the model is named "basemodel_tawsifurCV0" and the plot is
# titled as a tawsifur -> mendeley transfer, but train_model is called with
# dataset_mendeley. Whether train_model resumes from a checkpoint of that name
# or trains from scratch is not visible here — confirm the intended setup.

# `import sklearn` alone does not guarantee that the `metrics` submodule is
# loaded, so import it explicitly before using sklearn.metrics below.
import sklearn.metrics

basemodel_tawsifur = model.train_model("basemodel_tawsifurCV0", net_container, basemodel_process, dataset_mendeley)
y_test_pred = net_container.predict(basemodel_tawsifur.params, basemodel_tawsifur.state, dataset_mendeley.x_test)
# The labels are truncated to the number of predictions returned, since
# predict may return fewer rows than there are test samples.
# normalize='true' scales each row (true class) of the matrix to sum to 1.
matrix = sklearn.metrics.confusion_matrix(
        dataset_mendeley.y_test[0:y_test_pred.shape[0],].argmax(1),
        y_test_pred.argmax(1), normalize = 'true'
    )
plots.heatmatrix(matrix, "Transfer learning from tawsifur to mendeley")