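## Downloading code and dependencies

Clone (or update) the `project-fox` repository together with its submodules, then install `athena`, `RoIAlign.pytorch` and `pytorch-lightning` 1.0.x.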

In [1]:
# Load the TensorBoard notebook extension so the `%tensorboard` magic used further down can display the training logs inline.
%load_ext tensorboard

In [2]:
%%shell
# Clone (or update) the project repo together with its git submodules, then
# install the Python dependencies. Routine output is silenced, but stderr is
# kept and every step's exit status is checked, so a failure stops the cell
# instead of being reported as success.

if [[ ! -d project-fox ]]; then
    if ! git clone --quiet --recurse-submodules https://github.com/firekind/project-fox; then
        echo "Failed to clone repo." >&2
        exit 1
    fi
    echo "Cloned repo."
else
    # `git -C` leaves the working directory of this cell unchanged. The clone
    # above pulls submodules, so an update must refresh them as well.
    if ! { git -C project-fox pull --recurse-submodules \
            && git -C project-fox submodule update --init --recursive; }; then
        echo "Failed to update repo." >&2
        exit 1
    fi
    echo "Pulled repo."
fi

# All sources are fetched over https so they cannot be tampered with in transit.
if ! pip install --quiet --upgrade \
        git+https://github.com/firekind/athena \
        git+https://github.com/longcw/RoIAlign.pytorch \
        "pytorch-lightning~=1.0.8" \
        > /dev/null; then
    echo "Failed to install dependencies." >&2
    exit 1
fi
echo "Downloaded dependencies."

Cloned repo.
Downloaded dependencies.




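## Mounting Google Drive and extracting the dataset and weights

Each archive is read from `My Drive/project-fox` and extracted only if its target directory does not exist yet.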

In [3]:
# Mount Google Drive; the dataset/weights archives and the training logs live
# under "My Drive/project-fox" inside this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [4]:
%%shell
# Extract the dataset archive from Drive into the repo's data/ directory.
# The directory's existence is what marks the extraction as done, so a failed
# unzip removes it again; otherwise every later run would skip extraction and
# train on a missing or partial dataset.
DATA_DIR=/content/project-fox/data
DATA_ZIP="/content/drive/My Drive/project-fox/data.zip"

if [[ -d "$DATA_DIR" ]]; then
    echo "dataset already extracted."
else
    mkdir -p "$DATA_DIR"
    echo "extracting dataset..."
    unzip -q "$DATA_ZIP" -d "$DATA_DIR"
    # unzip exits with 1 for warnings only; anything above 1 is a real failure.
    if (( $? > 1 )); then
        rm -rf "$DATA_DIR"
        echo "failed to extract dataset." >&2
        exit 1
    fi
    echo "done."
fi

extracting dataset...
done.




In [5]:
%%shell
# Extract the pretrained-weights archive from Drive into the repo's weights/
# directory. As with the dataset, the directory's existence marks a finished
# extraction, so it is removed again if tar fails; otherwise later runs would
# skip extraction and load missing or truncated weights.
WEIGHTS_DIR=/content/project-fox/weights
WEIGHTS_TAR="/content/drive/My Drive/project-fox/weights.tar.gz"

if [[ -d "$WEIGHTS_DIR" ]]; then
    echo "weights already extracted."
else
    mkdir -p "$WEIGHTS_DIR"
    echo "extracting weights..."
    if ! tar -xf "$WEIGHTS_TAR" -C "$WEIGHTS_DIR"; then
        rm -rf "$WEIGHTS_DIR"
        echo "failed to extract weights." >&2
        exit 1
    fi
    echo "done."
fi

extracting weights...
done.




In [6]:
import os

# Work from the repository root: the relative paths used below ("data" for the
# dataset, "weights/..." for the pretrained weights) resolve against it.
PROJECT_DIR = "/content/project-fox"
os.chdir(PROJECT_DIR)

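## Training PlaneRCNN

This run trains only the PlaneRCNN branch. YOLO is disabled (`USE_YOLO=False`) and the MiDaS loss weight is 0, so the *midas loss* reported during validation is logged but not optimized. The *total loss* is therefore equal to the PlaneRCNN loss.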

In [7]:
from fox.config import Config
from fox.dataset import ComboDataset
from fox.model import Model
import torch
import pytorch_lightning as pl
from athena.utils.progbar import ProgbarCallback
from fox.utils import parse_data_cfg
from fox.yolov3.utils.datasets import LoadImagesAndLabels

In [13]:
# Seed before anything random can happen. The datasets built here, and the
# model built in the next cell, may draw from the RNGs during construction;
# seeding only right before `trainer.fit` would leave those draws unseeded and
# make runs irreproducible.
pl.seed_everything(0)

# PlaneRCNN-only configuration: YOLO is disabled and the MiDaS loss weight is 0.
config = Config(
    USE_YOLO=False,
    DATA_DIR="data",
    IMG_SIZE=640,
    MIN_IMG_SIZE=480,
    BATCH_SIZE=10,
    MIDAS_LOSS_WEIGHT=0,
    PLANERCNN_LOSS_WEIGHT=1,
    PREDICT_DEPTH=False,
    EPOCHS=10,
)
dataset = ComboDataset(config)
val_dataset = ComboDataset(config, train=False)

# Each dataset supplies its own collate_fn for batching its samples.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    collate_fn=dataset.collate_fn,
)
# Validation order is kept fixed (explicitly unshuffled) so that metrics are
# computed over the same batches every epoch.
loader_val = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    collate_fn=val_dataset.collate_fn,
)

Caching labels data/yolo/labels.npy (3161 found, 0 missing, 42 empty, 0 duplicate, for 3203 images): 100%|██████████| 3203/3203 [00:00<00:00, 9947.54it/s]
Caching labels data/yolo/labels.npy (311 found, 0 missing, 7 empty, 0 duplicate, for 318 images): 100%|██████████| 318/318 [00:00<00:00, 7497.36it/s]


In [14]:
# Number of batches in one pass over the training DataLoader.
n_train_batches = len(loader)
# NOTE(review): the meaning of the positional argument `4` is not visible in
# this notebook; confirm it in fox.model and pass it by keyword.
model = Model(
    config,
    n_train_batches,
    4,
    dataset.yolo_dataset.yolo_labels,
)

Loading weights:  weights/midas.pt


Using cache found in /root/.cache/torch/hub/facebookresearch_WSL-Images_master


In [10]:
# Logs and checkpoints go to the mounted Google Drive so they outlive the
# Colab VM.
log_dir = "/content/drive/My Drive/project-fox/logs"
# Derive the run name from the configured image sizes so it cannot silently go
# stale (and reuse another run's version directory) when the config changes.
# With the current config this is "Planercnn-only-img-480-640".
name = f"Planercnn-only-img-{config.MIN_IMG_SIZE}-{config.IMG_SIZE}"

# name="" + version=name puts this run's TensorBoard files in <log_dir>/<name>.
tensorboard_logger = pl.loggers.TensorBoardLogger(
    log_dir, name="", version=name, default_hp_metric=False
)
# `dirpath` + `filename` replace the deprecated single `filepath` argument;
# this is the same split of <log_dir>/<name>/checkpoints/last as before.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=os.path.join(log_dir, name, "checkpoints"),
    filename="last",
)

In [11]:
%tensorboard --logdir "/content/drive/My Drive/project-fox/logs"

In [15]:
# Fix all RNGs (Python, NumPy, torch) before building the Trainer and fitting.
pl.seed_everything(0)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=config.EPOCHS,
    # Manual optimization: the LightningModule performs its own optimizer steps.
    automatic_optimization=False,
    logger=tensorboard_logger,
    checkpoint_callback=checkpoint_callback,
    callbacks=[ProgbarCallback()],
    # Refresh Lightning's progress bar every 20 batches.
    progress_bar_refresh_rate=20,
)

trainer.fit(model, train_dataloader=loader, val_dataloaders=loader_val)

  "See the documentation of nn.Upsample for details.".format(mode))


Epoch: 1 / 10




Validation set: avg planercnn loss: 3.6413, avg midas loss: 3684.1770, avg total loss: 3.6413

Epoch: 2 / 10
Validation set: avg planercnn loss: 3.5812, avg midas loss: 3683.0701, avg total loss: 3.5812

Epoch: 3 / 10
Validation set: avg planercnn loss: 3.5356, avg midas loss: 3688.6321, avg total loss: 3.5356

Epoch: 4 / 10
Validation set: avg planercnn loss: 3.5092, avg midas loss: 3685.3510, avg total loss: 3.5092

Epoch: 5 / 10
Validation set: avg planercnn loss: 3.4947, avg midas loss: 3686.5210, avg total loss: 3.4947

Epoch: 6 / 10
Validation set: avg planercnn loss: 3.4800, avg midas loss: 3680.2164, avg total loss: 3.4800

Epoch: 7 / 10
Validation set: avg planercnn loss: 3.4657, avg midas loss: 3684.3093, avg total loss: 3.4657

Epoch: 8 / 10
Validation set: avg planercnn loss: 3.4620, avg midas loss: 3685.8238, avg total loss: 3.4620

Epoch: 9 / 10
Validation set: avg planercnn loss: 3.4526, avg midas loss: 3697.8058, avg total loss: 3.4526

Epoch: 10 / 10
Validation set: av

1