In [None]:
# Load the autoreload extension so that edits to the imported .py modules are reloaded automatically while the notebook kernel keeps running

%load_ext autoreload
%autoreload 2

import trim_duplicates, model, network, utils, gradcam, plots
from dataset import Dataset
import haiku as hk
import jax.numpy as jnp
import numpy as np
import jax
import wandb

In [None]:
# Fail fast when fewer than 8 local JAX devices are visible, and report how many were found.
# NOTE(review): the requirement of 8 presumably comes from how the project code
# distributes work across devices — confirm before relaxing it.
assert jax.local_device_count() >= 8, (
    f"Expected at least 8 local JAX devices, found {jax.local_device_count()}"
)

In [None]:
# Experiment-wide settings shared by the cells below.
SEED = 12  # seed used to build the JAX PRNG key
BATCH_SIZE = 128  # batch size handed to the model/optimizer setup and to the network container
NUM_CLASSES = 4  # number of classes handed to the model setup

In [None]:
# Build the JAX PRNG key from the notebook-wide SEED so the run is reproducible.
rng = jax.random.PRNGKey(SEED)
# Load the "tawsifur_kaggle" dataset, handing the key to the loader.
# NOTE(review): presumably the key drives shuffling/splitting inside Dataset.load — confirm.
dataset_tawsifur = Dataset.load("tawsifur_kaggle", rng=rng)

In [None]:
# Create the network and its optimizer from the training images, the class count and the batch size.
net, optim = model.init_net_and_optim(
    dataset_tawsifur.x_train,
    NUM_CLASSES,
    BATCH_SIZE,
)

# Bundle network and optimizer into a container that exposes the model functions
# (its `predict` is used by the training cells further down).
# NOTE(review): the hard-coded shape looks like (batch, height, width, channels)
# for 256x256 three-channel images — confirm against network.create.
net_container = network.create(
    net,
    optim,
    BATCH_SIZE,
    shape=(10, 256, 256, 3),
)

In [None]:
# Metadata describing this experiment; attached to the Weights & Biases run below.
config = {
    'dataset': 'tawsifur',
    'image_resolution': 256,
}


def basemodel_process(x):
    """Identity preprocessing step: return the input unchanged."""
    return x


# Pass `config` to the run so the dataset name and image resolution are recorded with it.
wandb.init(project='xrays', entity='usp-covid-xrays', reinit=True, config=config)
try:
    # Train the baseline model on the non-curated dataset, predict on its test split
    # and plot the resulting confusion matrix.
    basemodel_tawsifur = model.train_model("basemodel_tawsifur", net_container, basemodel_process, dataset_tawsifur)
    y_test_pred_tawsifur = net_container.predict(
        basemodel_tawsifur.params,
        basemodel_tawsifur.state,
        dataset_tawsifur.x_test,
    )
    plots.confusion_matrix(dataset_tawsifur, y_test_pred_tawsifur, "Tawsifur - Not curated")
finally:
    # Close the W&B run even when training or plotting raises, so no run is left open.
    wandb.finish()

In [None]:
# Compute the similarities for the dataset using the network container and the trained baseline model.
# NOTE(review): later cells treat `sims` as a 2-D array (reductions along axis=1) whose
# row i corresponds to dataset_tawsifur.x_all[i] — confirm against trim_duplicates.compute_similarities.
sims = trim_duplicates.compute_similarities(dataset_tawsifur, net_container, basemodel_tawsifur)

In [None]:
# Visual sanity check of the duplicate threshold: look at the pairs that sit just above it.
thresh = 0.998

# For every row of `sims`: the index of its highest entry, and how far that entry lies above `thresh`.
max_sims_index = sims.argmax(axis=1)
max_sims = sims.max(axis=1) - thresh

# Class index of each sample covered by `sims`.
# NOTE(review): assumes y_all holds one-hot labels or per-class scores — confirm.
y_classes = dataset_tawsifur.y_all[:sims.shape[0]].argmax(1)

# Keep the rows whose best similarity is at or above `thresh` but at most 0.0005 above it.
mask = (max_sims <= 0.0005) & (max_sims >= 0)
indices = np.asarray(mask).nonzero()[0]

# Show each selected image next to the image it is most similar to.
plots.compare_images(
    dataset_tawsifur.x_all[indices],
    dataset_tawsifur.x_all[max_sims_index[indices]],
    rows=10,
)

In [None]:
# Plot the similarities at the chosen threshold, then build the curated dataset from the same
# similarities and threshold.
# NOTE(review): presumably samples whose similarity exceeds `thresh` are treated as duplicates
# and dropped — confirm against trim_duplicates.remove_duplicates.
trim_duplicates.plot_similarities(dataset_tawsifur, sims, threshold=thresh)
dataset_tawsifur_curated = trim_duplicates.remove_duplicates(dataset_tawsifur, sims, threshold=thresh)

In [None]:
# Compare training-set shapes: curated dataset first, original dataset second.
print(dataset_tawsifur_curated.x_train.shape, dataset_tawsifur.x_train.shape)

In [None]:
# Repeat the baseline experiment on the curated dataset: train, predict on its test split,
# and plot the confusion matrix.
# NOTE(review): the W&B run opened for the non-curated training was already finished; if
# model.train_model logs to wandb, this cell needs its own wandb.init/finish — confirm.
basemodel_tawsifur_curated = model.train_model(
    "basemodel_tawsifur_curated",
    net_container,
    basemodel_process,
    dataset_tawsifur_curated,
)
y_pred_tawsifur_curated = net_container.predict(
    basemodel_tawsifur_curated.params,
    basemodel_tawsifur_curated.state,
    dataset_tawsifur_curated.x_test,
)
plots.confusion_matrix(
    dataset_tawsifur_curated,
    y_pred_tawsifur_curated,
    "Tawsifur - Curated",
)