In [1]:
# Load IPython's autoreload extension so edits to the local .py modules are picked up automatically while the notebook is running, without restarting the kernel

%load_ext autoreload
%autoreload 2

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import sklearn
import sklearn.metrics
import wandb

import trim_duplicates, model, network, utils, gradcam, plots
from dataset import Dataset

# Fail fast, with a readable message, if the runtime exposes fewer local JAX devices
# than this notebook expects (presumably because training is spread over 8 devices — TODO confirm).
MIN_LOCAL_DEVICES = 8
assert jax.local_device_count() >= MIN_LOCAL_DEVICES, (
    f"expected at least {MIN_LOCAL_DEVICES} local JAX devices, "
    f"found {jax.local_device_count()}"
)

# Experiment-wide constants.
NUM_CLASSES = 4   # number of classes passed to the network builder
SEED = 12         # seed for the PRNG key used when loading datasets
BATCH_SIZE = 128  # batch size for network creation/training (also logged to W&B)


def basemodel_process(x):
    """Identity pre-processing for the baseline model: hand ``x`` back untouched."""
    return x

In [2]:
# A single PRNG key derived from SEED is shared by both dataset loaders
# (presumably so their shuffles/splits are reproducible across runs — TODO confirm).
rng = jax.random.PRNGKey(SEED)
dataset_mendeley, dataset_tawsifur = (
    Dataset.load(dataset_name, rng=rng) for dataset_name in ("mendeley", "tawsifur")
)

tcmalloc: large alloc 7241465856 bytes == 0x867b2000 @  0x7f98d1c7f680 0x7f98d1ca0824 0x7f98c70464ce 0x7f98c709c00e 0x7f98c709cc4f 0x7f98c713e924 0x5f5db9 0x5f698e 0x57195c 0x56a0ba 0x5f6343 0x56cf2a 0x56a0ba 0x5f6343 0x570e46 0x56a0ba 0x5f6343 0x56cf2a 0x56a0ba 0x68d5b7 0x600f54 0x5c5530 0x56bddd 0x5004f8 0x56d80c 0x5004f8 0x56d80c 0x5004f8 0x5042c6 0x56bf09 0x5f6166
tcmalloc: large alloc 7241465856 bytes == 0x23ffc8000 @  0x7f98d1c7f680 0x7f98d1ca0824 0x7f98d1ca0b8a 0x7f96a970b6b7 0x7f96a4402790 0x7f96a4411414 0x7f96a4414287 0x7f96a435ff0f 0x7f96a40d1be8 0x7f96a40be166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba
tcmalloc: large alloc 7241465856 bytes == 0x3f274e000 @  0x7f98d1c7f680 0x7f98d1ca0824 0x7f98d1ca0b8a 0x7f96a970b6b7 0x7f96a4402790 0x7f96a4411414 0x7f96a4414287 0x7f96a435ff0f 0x7f96a40d1be8 0x7f96a40be166 0x5f5db9 0x5f698e 0x50b4c7

In [3]:
# Build the network and optimiser from the mendeley training images, then bundle them
# with the batch size and a (10, 256, 256, 3) shape into the `net_container` whose helpers
# (e.g. `predict`) are used by the rest of the notebook.
net, optim = model.init_net_and_optim(
    dataset_mendeley.x_train,
    NUM_CLASSES,
    BATCH_SIZE,
)
net_container = network.create(
    net,
    optim,
    BATCH_SIZE,
    shape=(10, 256, 256, 3),
)

In [4]:
def five_fold_cross_validation(model_name, original_dataset, process_fn, class_names, seed=12, num_epochs=30):
    """Train and evaluate a model on each of the 5 cross-validation folds, one W&B run per fold.

    Args:
        model_name: Prefix for the W&B group/run names and for the name passed to
            ``model.train_model`` (which uses it when saving the model).
        original_dataset: Dataset exposing ``name`` (logged to W&B) and ``five_fold(i)``,
            which presumably returns the i-th train/validation split — TODO confirm.
        process_fn: Input-processing function forwarded to ``model.train_model``.
        class_names: Class labels forwarded to ``model.train_model``.
        seed: Only recorded in the W&B config as ``random_seed``; it is NOT passed to
            training or to the fold split.
        num_epochs: Number of training epochs per fold.

    Returns:
        A list with the value returned by ``model.train_model`` for each fold, in fold
        order, so per-fold results can be analysed afterwards.
    """
    config = {'dataset': original_dataset.name,
              'random_seed': seed,
              'batch_size': BATCH_SIZE,
              # Matches the 256x256 image shape used when building net_container.
              'resolution': 256}

    group = model_name + '_CV'
    fold_results = []

    for i in range(5):
        run = wandb.init(project='xrays', entity='usp-covid-xrays',
                         group=group,
                         job_type='train_and_eval',
                         name=model_name + "_CV_" + str(i),
                         reinit=True, config=config)
        try:
            cv_dataset = original_dataset.five_fold(i)
            # The saved-model name ("..._CV<i>") intentionally differs from the W&B run
            # name ("..._CV_<i>"); changing it would change the saved file path.
            trained = model.train_model(model_name + "_CV" + str(i),
                                        net_container, process_fn,
                                        cv_dataset, masks=None,
                                        class_names=class_names,
                                        num_epochs=num_epochs,
                                        wandb_run=run)
        except BaseException:
            # Close the run as failed (also on KeyboardInterrupt) instead of leaving it open.
            run.finish(exit_code=1)
            raise
        run.finish()
        fold_results.append(trained)

    return fold_results


In [5]:
# Class labels for the mendeley dataset, in label-index order (presumably matching the one-hot columns — TODO confirm).
mendeley_class_names = [
    'Normal',
    'Pneumonia Bacteriana',
    'COVID-19',
    'Pneumonia Viral',
]
# Trailing ';' keeps the cell from displaying the call's return value.
five_fold_cross_validation(
    "base_mendeley",
    dataset_mendeley,
    process_fn=basemodel_process,
    class_names=mendeley_class_names,
);

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: pedromartelleto (use `wandb login --relogin` to force relogin)


tcmalloc: large alloc 5793644544 bytes == 0xebea0e000 @  0x7f98d1c7f680 0x7f98d1ca0824 0x7f98d1ca0b8a 0x7f96a970b6b7 0x7f96a4402790 0x7f96a4411414 0x7f96a4414287 0x7f96a435ff0f 0x7f96a40d1be8 0x7f96a40be166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x56a0ba
tcmalloc: large alloc 5793644544 bytes == 0x1017f4e000 @  0x7f98d1c7f680 0x7f98d1ca0824 0x7f98d1ca0b8a 0x7f96a970b6b7 0x7f96a4402790 0x7f96a4411414 0x7f96a4414287 0x7f96a435ff0f 0x7f96a40d1be8 0x7f96a40be166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x50adf0 0x56cf2a 0x5f6166
tcmalloc: large alloc 7241465856 bytes == 0x129fbce000 @  0x7f98d1c7f680 0x7f98d1ca0824 0x7f98d1ca0b8a 0x7f96a970b6b7 0x7f96a4402790 0x7f96a4411414 0x7f96a4414287 0x7f96a435ff0f 0x7f96a40d1be8 0x7f96a40be166

Model saved to models/base_mendeley_CV0.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▁▂▃▄▆▅▄▆▆▆▆▆▆▆▅▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇██████
loss,▇█▇▇▅▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
val_acc,▁▃▅▆▆▇▇▇▇▇▇▇▇▇████████████████
val_loss,█▇▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.97656
loss,0.07913
val_acc,0.85658
val_loss,0.43877


tcmalloc: large alloc 7241465856 bytes == 0x15a9b0e000 @  0x7f98d1c7f680 0x7f98d1ca0824 0x7f98d1ca0b8a 0x7f96a970b6b7 0x7f96a4402790 0x7f96a4411414 0x7f96a4414287 0x7f96a435ff0f 0x7f96a40d1be8 0x7f96a40be166 0x5f5db9 0x5f698e 0x50b4c7 0x570e46 0x56a0ba 0x5f70bb 0x66600d 0x5f574e 0x56d5f6 0x56a0ba 0x5f6343 0x5f70f7 0x66600d 0x5f574e 0x56d5f6 0x5f6166 0x56bf09 0x56a0ba 0x5f6343 0x50ad7c 0x5f56c7
100%|█████████████████████████████████| 57/57 [00:07<00:00,  7.99it/s, loss=1.81, acc=0.35, val_loss=1.33, val_acc=0.34]
100%|█████████████████████████████████| 57/57 [00:07<00:00,  8.11it/s, loss=1.30, acc=0.36, val_loss=1.31, val_acc=0.34]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.11it/s, loss=1.20, acc=0.41, val_loss=1.12, val_acc=0.49]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.03it/s, loss=0.91, acc=0.63, val_loss=0.76, val_acc=0.69]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.94it/s, loss=0.69, acc=0.72, val_loss=0.65, val_acc=0.73

Model saved to models/base_mendeley_CV1.pickle


VBox(children=(Label(value=' 11.25MB of 1023.80MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.0109899…

0,1
acc,▁▂▁▁▄▅▅▅▅▅▅▆▆▆▆▅▆▆▇▇▆▆▇▇▆▇▇▇▆▇▇▇▇▇███▇██
loss,███▇▆▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▃▃▂▂▃▂▂▂▃▂▂▂▂▁▁▁▁▁▁▁
val_acc,▁▁▃▆▆▇▇▇▇▇▇▇██████████████████
val_loss,██▆▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁

0,1
acc,1.0
loss,0.04744
val_acc,0.8471
val_loss,0.40586


100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.32it/s, loss=1.81, acc=0.35, val_loss=1.31, val_acc=0.35]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.69it/s, loss=1.30, acc=0.36, val_loss=1.28, val_acc=0.35]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.81it/s, loss=1.16, acc=0.45, val_loss=1.09, val_acc=0.54]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.89it/s, loss=0.87, acc=0.66, val_loss=0.76, val_acc=0.70]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.91it/s, loss=0.65, acc=0.74, val_loss=0.65, val_acc=0.72]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.89it/s, loss=0.57, acc=0.76, val_loss=0.54, val_acc=0.77]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.95it/s, loss=0.52, acc=0.79, val_loss=0.52, val_acc=0.78]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.95it/s, loss=0.49, acc=0.80, val_loss=0.52, val_acc=0.79]
100%|███████████████████████████

Model saved to models/base_mendeley_CV2.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▂▁▂▄▆▅▆▅▅▅▆▆▆▆▇▇▆▇▆▇▇▇▇▆▇▇▇▇▇▆▇▇▇██████
loss,███▇▆▄▄▃▄▄▄▄▃▃▃▃▂▃▃▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁
val_acc,▁▁▄▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████
val_loss,██▆▄▃▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.97656
loss,0.07789
val_acc,0.85714
val_loss,0.4379


100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.04it/s, loss=1.80, acc=0.35, val_loss=1.32, val_acc=0.36]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.26it/s, loss=1.30, acc=0.35, val_loss=1.29, val_acc=0.36]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.22it/s, loss=1.17, acc=0.45, val_loss=1.05, val_acc=0.56]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.47it/s, loss=0.92, acc=0.63, val_loss=0.83, val_acc=0.64]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.22it/s, loss=0.74, acc=0.70, val_loss=0.68, val_acc=0.71]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.07it/s, loss=0.61, acc=0.75, val_loss=0.56, val_acc=0.76]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.70it/s, loss=0.54, acc=0.78, val_loss=0.58, val_acc=0.76]
100%|█████████████████████████████████| 57/57 [00:06<00:00,  8.81it/s, loss=0.50, acc=0.80, val_loss=0.51, val_acc=0.78]
100%|███████████████████████████

Model saved to models/base_mendeley_CV3.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▂▂▁▃▅▅▅▆▅▆▅▆▆▇▆▇▇▆▆▆▇▆▇▇▆▇▇▇▇▇█▆█▇██████
loss,███▇▆▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▃▂▂▃▂▂▂▂▂▁▃▁▂▁▁▁▁▁▁
val_acc,▁▁▄▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████
val_loss,██▆▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,1.0
loss,0.06714
val_acc,0.85435
val_loss,0.39558


100%|█████████████████████████████████| 57/57 [00:06<00:00,  9.42it/s, loss=1.80, acc=0.35, val_loss=1.31, val_acc=0.37]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.69it/s, loss=1.30, acc=0.35, val_loss=1.27, val_acc=0.38]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.58it/s, loss=1.17, acc=0.46, val_loss=1.02, val_acc=0.61]
100%|█████████████████████████████████| 57/57 [00:08<00:00,  6.48it/s, loss=0.93, acc=0.61, val_loss=0.76, val_acc=0.69]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.67it/s, loss=0.70, acc=0.71, val_loss=0.59, val_acc=0.76]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.71it/s, loss=0.56, acc=0.77, val_loss=0.49, val_acc=0.80]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.76it/s, loss=0.50, acc=0.79, val_loss=0.48, val_acc=0.81]
100%|█████████████████████████████████| 57/57 [00:05<00:00,  9.81it/s, loss=0.48, acc=0.80, val_loss=0.46, val_acc=0.81]
100%|███████████████████████████

Model saved to models/base_mendeley_CV4.pickle


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▂▂▁▄▄▅▅▆▅▆▆▆▇▇▆▇▇▆▆▆▇▇▇▇▆▇▇█▇▇▇▇▇███████
loss,███▇▆▅▄▃▄▃▄▃▃▃▃▃▃▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
val_acc,▁▁▄▆▇▇▇▇▇▇▇███████████████████
val_loss,██▆▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.98438
loss,0.06906
val_acc,0.85993
val_loss,0.45859


In [None]:
# Aggregate per-fold confusion matrices into a mean heatmap and a standard-deviation heatmap.
# FIXME(review): `matrices` is not defined anywhere in this notebook — nothing collects
# per-fold confusion matrices — so this cell raises NameError on a fresh kernel.
# It presumably expects a list of equally-shaped confusion matrices, one per CV fold — TODO confirm.
matrices_array = np.asarray(matrices)
plots.heatmatrix(matrices_array.mean(axis=0), "Mean heatmap")
plots.heatmatrix(matrices_array.std(axis=0), "Std heatmap")

In [None]:
# Train a baseline model on the tawsifur dataset and use it to compute image-to-image
# similarities (sims[i, j] is compared against a threshold below to find near-duplicates).
# Only module-level names are used here; the CV loop's locals (`model_name`, `i`,
# `process_fn`, `cv_dataset`) do not exist at this scope.
basemodel_tawsifur = model.train_model("basemodel_tawsifur", net_container, basemodel_process, dataset_tawsifur)
sims = trim_duplicates.compute_similarities(dataset_tawsifur, net_container, basemodel_tawsifur)

In [None]:
# Similarity cutoff above which an image and its best match are treated as near-duplicates.
thresh = 0.998

# For every image: index of its most similar image, and how far that similarity
# sits above `thresh` (negative means below the cutoff).
max_sims_index = sims.argmax(axis=1)
max_sims = sims.max(axis=1) - thresh
y_classes = dataset_tawsifur.y_all[:len(sims)].argmax(1)

# Borderline pairs, with similarity in [thresh, thresh + 0.0005], shown side by side
# to sanity-check the threshold by eye.
above_cutoff = max_sims >= 0
barely_above = max_sims <= 0.0005
mask = above_cutoff & barely_above
indices = np.nonzero(mask)[0]
plots.compare_images(
    dataset_tawsifur.x_all[indices],
    dataset_tawsifur.x_all[max_sims_index[indices]],
    rows=10,
)

In [None]:
# Plot the similarities at the chosen cutoff, then build a curated copy of the tawsifur
# dataset with trim_duplicates.remove_duplicates using the same `sims` and `thresh`
# (presumably dropping images whose best-match similarity exceeds `thresh` — TODO confirm).
trim_duplicates.plot_similarities(dataset_tawsifur, sims, threshold=thresh)
dataset_tawsifur_curated = trim_duplicates.remove_duplicates(dataset_tawsifur, sims, threshold=thresh)

In [None]:
# Training-set shapes after vs. before duplicate removal.
print(*(ds.x_train.shape for ds in (dataset_tawsifur_curated, dataset_tawsifur)))

In [None]:
# Retrain the baseline on the de-duplicated tawsifur data and plot its test confusion matrix.
basemodel_tawsifur_curated = model.train_model(
    "basemodel_tawsifur_curated",
    net_container,
    basemodel_process,
    dataset_tawsifur_curated,
)
y_pred_tawsifur_curated = net_container.predict(
    basemodel_tawsifur_curated.params,
    basemodel_tawsifur_curated.state,
    dataset_tawsifur_curated.x_test,
)
plots.confusion_matrix(
    dataset_tawsifur_curated,
    y_pred_tawsifur_curated,
    "Tawsifur - Curated",
)

# Transfer learning test

In [None]:
# Start the transfer-learning experiment from a clean slate: recreate the PRNG key,
# reload mendeley, and rebuild the network, optimiser and helper container.
rng = jax.random.PRNGKey(SEED)
dataset_mendeley = Dataset.load("mendeley", rng=rng)

net, optim = model.init_net_and_optim(
    dataset_mendeley.x_train,
    NUM_CLASSES,
    BATCH_SIZE,
)
net_container = network.create(
    net,
    optim,
    BATCH_SIZE,
    shape=(10, 256, 256, 3),
)

In [None]:
# Test of tawsifur on mendeley: train a baseline, predict on the mendeley test split and
# plot the row-normalised confusion matrix.
# NOTE(review): despite its name, `basemodel_tawsifur` is trained here on `dataset_mendeley`,
# so this evaluates on the same dataset it was trained on rather than a transfer from tawsifur.
# Unless `model.train_model` reuses a previously saved "basemodel_tawsifurCV0" model, the
# dataset argument probably should be `dataset_tawsifur` — TODO confirm.
basemodel_tawsifur = model.train_model("basemodel_tawsifurCV0", net_container, basemodel_process, dataset_mendeley)
y_test_pred = net_container.predict(basemodel_tawsifur.params, basemodel_tawsifur.state, dataset_mendeley.x_test)

# `predict` can return fewer rows than x_test, so align the labels with the predictions.
y_test_true = dataset_mendeley.y_test[:len(y_test_pred)]
# Uses sklearn.metrics, which must be imported explicitly: `import sklearn` alone does not load it on older releases.
matrix = sklearn.metrics.confusion_matrix(
    y_test_true.argmax(1),
    y_test_pred.argmax(1),
    normalize='true',
)
plots.heatmatrix(matrix, "Transfer learning from tawsifur to mendeley")