In [9]:
# Starts the autoreload extension, which allows editing the .py files with the notebook running and automatically imports the latest changes

%load_ext autoreload
%autoreload 2

# Imports all the libraries used

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn
from sklearn.metrics import accuracy_score, confusion_matrix
from umap import UMAP
from tqdm import tqdm

from functools import partial

import jax
from jax import numpy as jnp
import haiku as hk
import optax

import resnet
import data
import train
import pickle

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Fail fast unless at least 8 local accelerator devices are visible. The message
# reports what was actually found, so a wrong runtime (CPU/GPU, or a smaller TPU
# slice) is easy to diagnose.
# NOTE(review): 8 presumably matches how train.get_network_fns splits BATCH_SIZE
# across devices — confirm in train.py.
assert jax.local_device_count() >= 8, (
    "TPUs not detected: expected at least 8 local devices, "
    f"found {jax.local_device_count()} on backend '{jax.default_backend()}'"
)

In [11]:
# Label names indexed by integer class id.
# NOTE(review): the order presumably matches the label encoding produced by
# data.load_data — confirm there before interpreting predictions.
classes = ['Normal', 'Pneumonia-Bacterial', 'COVID-19', 'Pneumonia-Viral']
NUM_CLASSES = len(classes)  # derived so the head size can never drift from the label list
SEED = 12
BATCH_SIZE = 128

# NaN checking stays off (JAX's default); set True while debugging to raise
# as soon as a NaN is produced.
jax.config.update("jax_debug_nans", False)

rng = jax.random.PRNGKey(SEED)

In [12]:
# Train/test split (10% held out), seeded by the notebook's PRNG key.
(x_train, y_train), (x_test, y_test) = data.load_data('.', rng, test_size=0.1)

# Whole dataset with the test split first, followed by the training split.
# NOTE(review): this materializes a second full copy of the images; some of the
# multi-GB tcmalloc allocations in this cell's output may come from here.
x_all = np.concatenate((x_test, x_train), axis=0)
y_all = np.concatenate((y_test, y_train), axis=0)

tcmalloc: large alloc 7241465856 bytes == 0x89c78000 @  0x7f4a49985680 0x7f4a499a6824 0x7f4a3ef174ce 0x7f4a3ef6d00e 0x7f4a3ef6dc4f 0x7f4a3f00f924 0x5f2cc9 0x5f30ff 0x5705f6 0x568d9a 0x5f5b33 0x56bc9b 0x568d9a 0x5f5b33 0x56fb87 0x568d9a 0x5f5b33 0x56bc9b 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956
tcmalloc: large alloc 7241465856 bytes == 0x2438a0000 @  0x7f4a49985680 0x7f4a499a6824 0x7f4a499a6b8a 0x7f4733d3fc37 0x7f4733cf25eb 0x7f4733d037a6 0x7f4733d046dd 0x7f472fb6db09 0x7f472fb711a0 0x7f4733a1f302 0x7f473111a220 0x7f473111a910 0x7f47310f5e95 0x7f47310fc686 0x7f47310fe434 0x7f472ecfe3bf 0x7f472ea779a8 0x7f472ea67540 0x5f2cc9 0x5f3010 0x50bf55 0x56fb87 0x568d9a 0x5f5b33 0x5f5308 0x6655bd 0x5f28fe 0x56c332 0x568d9a 0x5f5b33 0x5f5369
tcmalloc: large alloc 7241465856 bytes == 0x9065a6000 @  0x7f4a49985680 0x7f4a499a6824 0x7f4a499a6b8a 0x7f4733d3fc37 0x7f4733cf25eb 0x7f4733d037a6 0x7f4733d046dd 0x7f472fb6db09 0x7f472f

In [13]:
def forward(batch, is_training, return_representation = False, return_gradcam = False):
    """Haiku forward function for a ResNet18 (v2) with NUM_CLASSES outputs.

    Args:
        batch: input images (shape/dtype not visible here — TODO confirm).
        is_training: forwarded unchanged to the network call.
        return_representation: if True, return ``embedding(..., embedding_depth=0)``.
        return_gradcam: if True and ``return_representation`` is False, return
            ``gradcam(..., gradcam_depth=0)``.

    Returns:
        The embedding, the Grad-CAM output, or the plain network output
        (presumably logits), selected by the flags above in that priority.
    """
    model = resnet.ResNet18(num_classes=NUM_CLASSES, resnet_v2=True)

    if return_representation:
        return model.embedding(batch, is_training, embedding_depth=0)

    if return_gradcam:
        return model.gradcam(batch, is_training, gradcam_depth=0)

    return model(batch, is_training)

# transform_with_state because the network carries mutable state alongside its
# params (init_fn returns `state` separately from `params`).
net = hk.transform_with_state(forward)

# Cosine-decayed learning rate spanning the whole training run.
# NOTE(review): a peak LR of 1e-1 is unusually high for AdamW; kept unchanged
# because the saved checkpoint was presumably trained with it.
NUM_EPOCHS = 30
STEPS_PER_EPOCH = len(x_train) // BATCH_SIZE
schedule = optax.cosine_decay_schedule(1e-1, NUM_EPOCHS * STEPS_PER_EPOCH)
optim = optax.adamw(schedule, weight_decay = 1e-3)

In [14]:
# Helper functions for this model/optimizer pair, unpacked in the exact order
# that train.get_network_fns returns them.
(
    init_fn,
    loss_fn,
    grad_fn,
    update,
    predict,
    evaluate,
    train_epoch,
) = train.get_network_fns(net, optim, BATCH_SIZE)

# Fresh parameters, network state and optimizer state.
# NOTE(review): `rng` was already passed to data.load_data; JAX recommends
# splitting keys rather than reusing one for independent purposes.
params, state, optim_state = init_fn(rng)

In [15]:
# Either train from scratch or restore a saved checkpoint.
# Set TRAIN_FROM_SCRATCH = True to run the 30-epoch loop (the cosine LR schedule
# is sized for 30 epochs); by default the pickled checkpoint is restored.
TRAIN_FROM_SCRATCH = False
CHECKPOINT_PATH = "checkpoints/checkpoint.npy"


def load_model(filename):
    """Load a checkpoint object that was written with ``pickle``.

    The file is read with ``pickle.load`` despite its ``.npy`` extension, so it is
    not a NumPy array file.

    SECURITY: ``pickle.load`` can execute arbitrary code; only load checkpoint
    files you produced yourself or otherwise trust.

    Args:
        filename: path of the checkpoint file.

    Returns:
        The unpickled object (this notebook expects ``(params, state, optim_state)``).
    """
    with open(filename, "rb") as f:
        return pickle.load(f)


if TRAIN_FROM_SCRATCH:
    for _ in range(30):
        params, state, optim_state = train_epoch(params, state, optim_state, x_train, y_train, x_test, y_test)
else:
    loaded_model = load_model(CHECKPOINT_PATH)
    # Strict unpacking: fails loudly if the checkpoint does not hold exactly
    # the three items this notebook saves.
    params, state, optim_state = loaded_model

In [16]:
import pickle

def save_model(filename, model=None):
    """Pickle a checkpoint to ``filename``.

    Args:
        filename: destination path; missing parent directories are created.
        model: object to pickle. Defaults to the notebook-global
            ``(params, state, optim_state)`` tuple, which keeps the original
            one-argument call working.
    """
    import os

    if model is None:
        model = (params, state, optim_state)
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filename, "wb") as f:
        pickle.dump(model, f)

# Uncomment to write a checkpoint to the same path the loading cell reads from.
# save_model("checkpoints/checkpoint.npy")

In [22]:
# Grad-CAM outputs for the training set, computed in inference mode.
# FIXME: this call raised MemoryError (1.00 TiB, shape (524288, 524288)) before the
# first batch completed. 524288 == 16 * 8 * 8 * 512, the flattened size of the
# (16, 8, 8, 512) array printed in the output, so something inside train.predict /
# the gradcam path appears to build a square matrix over flattened activations —
# investigate there.
cam = predict(
    params,
    state,
    x_train,
    training=False,
    return_gradcam=True,
    verbose=True,
)

  0%|                                                                                            | 0/65 [00:00<?, ?it/s]

(16, 8, 8, 512) (16, 4)


tcmalloc: large alloc 1099511627776 bytes == (nil) @  0x7f4a49985680 0x7f4a499a52ec 0x7f4a3ef175af 0x7f4a3ef68b58 0x7f4a3ef6ccb7 0x7f4a3f00b7b3 0x5f2cc9 0x5f30ff 0x5705f6 0x568d9a 0x5f5b33 0x56bc9b 0x5f5956 0x56aadf 0x568d9a 0x5f5b33 0x56aadf 0x568d9a 0x5f5b33 0x56fb87 0x568d9a 0x5f5b33 0x5f5369 0x6655bd 0x5f28fe 0x56c332 0x568d9a 0x5f5b33 0x5f5369 0x6655bd 0x5f28fe
  0%|                                                                                            | 0/65 [00:00<?, ?it/s]


MemoryError: Unable to allocate 1.00 TiB for an array with shape (524288, 524288) and data type float32

In [None]:
# Inspect the Grad-CAM result.
# NOTE(review): this cell has never been executed (In [None]) and depends on
# `cam` from the previous cell, which raised a MemoryError; it cannot run
# until that is fixed.
cam.shape