In [None]:
# Reload edited modules (e.g. the local great_barrier_reef package) before each
# cell executes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Credentials must never live in the notebook: read the Neptune token from the
# environment instead. If the variable is unset this is None, and NeptuneLogger
# is expected to fall back to the NEPTUNE_API_TOKEN env var on its own
# (confirm for the installed pytorch-lightning version).
# NOTE(review): the token previously hard-coded in this cell is in version-control
# history — revoke/rotate it in the Neptune UI.
NEPTUNE_API_TOKEN = os.environ.get('NEPTUNE_API_TOKEN')

In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, NeptuneLogger
from pytorch_lightning.utilities.seed import seed_everything

from great_barrier_reef import (
    StarfishDataset, StarfishDatasetAdapter,
    StarfishDataModule, StarfishEfficientDetModel,
    get_train_transforms, get_valid_transforms,
    get_train_transforms_pad, get_valid_transforms_pad,
    compare_bboxes_for_image
)
from great_barrier_reef.dataset.starfish_dataset import draw_pascal_voc_bboxes

In [None]:
# Order GPUs by PCI bus id (so index "0" matches nvidia-smi) and expose only
# that device to this process.
# NOTE(review): torch is already imported via pytorch_lightning in the imports
# cell; these variables only take effect if CUDA has not been initialised yet in
# this kernel — restart the kernel after changing them.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Seed the python / numpy / torch RNGs before the model and dataloaders are
# created so training runs are repeatable.
# NOTE(review): dataloader worker seeding may additionally need `workers=True`
# depending on the pytorch-lightning version — confirm.
SEED = 42
seed_everything(SEED)

In [None]:
# Full training index: one row per frame, including frames with no annotations.
TRAIN_CSV_PATH = '../data/train.csv'
data_df = pd.read_csv(TRAIN_CSV_PATH)

In [None]:
# Keep only frames that have at least one box (empty annotation lists are
# stored as the literal string '[]'), then hold out one whole video for
# validation — presumably so correlated frames from the same video never land
# in both splits.
VAL_VIDEO_ID = 2
has_annotations = data_df['annotations'] != '[]'
non_empty_df = data_df.loc[has_annotations, :]

is_val_video = non_empty_df['video_id'] == VAL_VIDEO_ID
train_df = non_empty_df.loc[~is_val_video, :]
val_df = non_empty_df.loc[is_val_video, :]

# Sanity checks: both splits are populated and together cover every kept frame.
assert len(train_df) > 0 and len(val_df) > 0, "empty train/val split"
assert len(train_df) + len(val_df) == len(non_empty_df), "split lost rows"

In [None]:
# Wrap each split's DataFrame in the project adapter (train first, then val).
adapter_dataset_train, adapter_dataset_val = (
    StarfishDatasetAdapter(train_df),
    StarfishDatasetAdapter(val_df),
)

In [None]:
# Visual spot-check of one training sample.
TRAIN_PREVIEW_IDX = 920
adapter_dataset_train.show_image(TRAIN_PREVIEW_IDX)

In [None]:
# Visual spot-check of the first validation sample.
VAL_PREVIEW_IDX = 0
adapter_dataset_val.show_image(VAL_PREVIEW_IDX)

In [None]:
# Train and validation pipelines share one target size, which must equal the
# `img_size` the model is constructed with in the next cell.
DATAMODULE_IMG_SIZE = 1280
datamodule = StarfishDataModule(
    adapter_dataset_train,
    adapter_dataset_val,
    train_transforms=get_train_transforms_pad(target_img_size=DATAMODULE_IMG_SIZE),
    valid_transforms=get_valid_transforms_pad(target_img_size=DATAMODULE_IMG_SIZE),
    num_workers=8,
    batch_size=4,
)

In [None]:
# Single-class EfficientDet-D1 trained on the non-empty frames.
# MODEL_IMG_SIZE must match the target_img_size used by the datamodule transforms.
MODEL_IMG_SIZE = 1280
RUN_NAME = 'd1_all_non_empty'        # shared by the CSV and Neptune logs
MONITOR_METRIC = 'valid_loss_epoch'  # drives both early stopping and checkpointing

model = StarfishEfficientDetModel(
    num_classes=1,
    img_size=MODEL_IMG_SIZE,
    inference_transforms=get_valid_transforms_pad(MODEL_IMG_SIZE),
    model_architecture='tf_efficientdet_d1_ap',
)
callbacks = [
    # Stop after 15 epochs without improvement of the monitored metric.
    EarlyStopping(monitor=MONITOR_METRIC, patience=15),
    ModelCheckpoint(verbose=True, monitor=MONITOR_METRIC),
]
loggers = [
    CSVLogger(save_dir='csv_logs', name=RUN_NAME),
    NeptuneLogger(
        api_key=NEPTUNE_API_TOKEN,
        project_name="azkalot1/reef",
        experiment_name=RUN_NAME,
    ),
]
trainer = Trainer(
    callbacks=callbacks,
    logger=loggers,
    gpus=[0],
    max_epochs=100,
    num_sanity_val_steps=1,
    precision=16,  # 16-bit mixed-precision training
)
trainer.fit(model, datamodule)

In [None]:
# Switch to inference mode (disables dropout, uses batch-norm running stats).
# eval() returns the module itself; the trailing ';' suppresses its very long repr.
# NOTE(review): `model` holds the weights from the *last* training epoch, not the
# best `valid_loss_epoch` checkpoint written by ModelCheckpoint
# (trainer.checkpoint_callback.best_model_path). Load that checkpoint if the
# evaluation below should reflect the best model.
model.eval();

In [None]:
# Fetch two validation samples with their ground-truth boxes and run batched
# inference on them. The validation dataset is built once instead of per sample.
# NOTE(review): assumes the indices address samples of the validation adapter
# (`.ds`) — confirm against get_image_and_labels_by_idx.
SAMPLE_VAL_IDXS = (327, 328)
val_adapter = datamodule.val_dataset().ds
image1, truth_bboxes1, _, _, _ = val_adapter.get_image_and_labels_by_idx(SAMPLE_VAL_IDXS[0])
image2, truth_bboxes2, _, _, _ = val_adapter.get_image_and_labels_by_idx(SAMPLE_VAL_IDXS[1])
images = [image1, image2]
predicted_bboxes, predicted_class_confidences, predicted_class_labels = model.predict(images)

In [None]:
# Raw predicted boxes, one entry per input image (index 0 -> image1, 1 -> image2).
predicted_bboxes

In [None]:
# Predicted vs. ground-truth boxes for the second sample (predictions index 1).
image2_predicted_bboxes = predicted_bboxes[1]
compare_bboxes_for_image(
    image2,
    predicted_bboxes=image2_predicted_bboxes,
    actual_bboxes=truth_bboxes2,
)

In [None]:
# Predicted vs. ground-truth boxes for the first sample (predictions index 0).
image1_predicted_bboxes = predicted_bboxes[0]
compare_bboxes_for_image(
    image1,
    predicted_bboxes=image1_predicted_bboxes,
    actual_bboxes=truth_bboxes1,
)

In [None]:
# Inspect the input size the model was constructed with (img_size=1280 in the
# training cell); presumably it should equal the transforms' target size.
model.img_size