# Core analysis

### Set up the run

In [None]:
import tensorflow as tf
import neptune

from core_analysis.dataset import Dataset
from core_analysis.architecture import Model
from core_analysis.utils.visualize import (
    Figure,
    Image,
    Mask,
    Loss,
)
from core_analysis.utils.constants import (
    MODEL_FILENAME,
    LABELS_PATH,
    TODAY,
    NEPTUNE_PROJECT,
    NEPTUNE_API_TOKEN,
)

##### Check the number of available GPUs and enable memory growth

In [None]:
# Count the visible GPUs and enable memory growth on each one, so TensorFlow
# allocates GPU memory as it is needed instead of reserving it all up front.
# The device list is queried once and reused for both the report and the loop.
physical_devices = tf.config.list_physical_devices("GPU")
print("Num GPUs Available: ", len(physical_devices))
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

##### Set up arguments

In [None]:
# Options handed to `Model` when it is built in the setup cell.
WEIGHTS_FILENAME = MODEL_FILENAME
RUN_EAGERLY = False

# NOTE(review): DO_AUGMENT is not referenced by any cell visible in this
# notebook; confirm whether it is meant to reach the dataset or training.
DO_AUGMENT = False

##### Set up the Neptune run, model, and dataset

In [None]:
# Start a Neptune run using the project name and API token defined in
# `core_analysis.utils.constants`; it is stopped in the final cell.
neptune_kwargs = {
    "project": NEPTUNE_PROJECT,
    "api_token": NEPTUNE_API_TOKEN,
}
run = neptune.init_run(**neptune_kwargs)

# Build the model from the configured weights filename / eager flag, and the
# dataset from the labels path.
model = Model(WEIGHTS_FILENAME, RUN_EAGERLY)
dataset = Dataset(LABELS_PATH)

### Train

In [None]:
# Training and validation splits of the dataset.
train_subset = dataset.subset("train")
val_subset = dataset.subset("val")

# Visual sanity check: the first training image (with boxes drawn) beside its
# mask channels 0-2, then the same image overlaid with mask channel 1.
image = next(iter(train_subset.imgs.values()))
image_mask_panels = [Image(image, draw_boxes=True)] + [
    Mask(image.masks[..., channel]) for channel in range(3)
]
Figure(filename="image_masks", subplots=image_mask_panels)
Figure(subplots=[Image(image=image, mask=image.masks[..., 1], draw_boxes=True)])

# First item yielded by the training subset (presumably a batch of tiles):
# plot its first patch next to the matching mask channels 0-2.
patches, masks = next(iter(train_subset))
tile_panels = [Image(patches[0])] + [
    Mask(masks[0, ..., channel]) for channel in range(3)
]
Figure(filename="tiles", subplots=tile_panels)

# Train the model, then plot the returned history with `Loss`.
history = model.train(train_subset, val_subset)

Figure(filename=f"graph_losses_{TODAY}", subplots=[Loss(history)])

### Test

In [None]:
# Build the test subset once and reuse it for both evaluation and sampling,
# rather than calling `dataset.subset("test")` a second time.
test_subset = dataset.subset("test")
results = model.test(test_subset)

# Predict on the first test image and plot it beside its mask channel 1 and
# predicted channels 0-2.
image = next(iter(test_subset.imgs.values()))
pred = model.predict([image])
# NOTE(review): `predict` receives a one-element list; confirm whether `pred`
# carries a leading batch axis, since it is sliced as `pred[..., i]` below.
Figure(
    filename="predictions",
    subplots=[
        Image(image),
        Mask(image.masks[..., 1]),
        *(Mask(pred[..., i]) for i in range(3)),
    ],
)
# Overlay each predicted channel on `image.without_background()`.
Figure(
    filename="predictions_with_images",
    subplots=[Image(image.without_background(), mask=pred[..., i]) for i in range(3)],
)

### Stop Neptune logging

In [None]:
# Close the Neptune run that was opened with `neptune.init_run` during setup.
run.stop()