<a href="https://colab.research.google.com/github/firekind/project-fox/blob/master/yolo_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Downloading dependencies and code

In [None]:
%load_ext tensorboard

In [None]:
%%shell

if [[ ! -d project-fox ]]; then
    git clone https://github.com/firekind/project-fox --recurse-submodules &> /dev/null
    echo "Cloned repo."
else
    cd project-fox && git pull --recurse-submodules
    echo "Pulled repo."
fi

pip install --upgrade \
    git+https://github.com/firekind/athena \
    git+https://github.com/longcw/RoIAlign.pytorch \
    pytorch-lightning~=1.0.8 \
    &> /dev/null
echo "Downloaded dependencies."

Cloned repo.
Downloaded dependencies.
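The pin `pytorch-lightning~=1.0.8` is a PEP 440 compatible-release specifier. It accepts 1.0.8 and any later 1.0.x patch release, but not 1.1.0. That matters because the `Trainer` and `ModelCheckpoint` arguments used later in this notebook were written against the 1.0.x API. Below is a minimal pure-Python sketch of the rule. It assumes plain dotted numeric versions, with no pre-release, post-release or local segments and no zero-padding normalisation. `parse_version` and `is_compatible_release` are hypothetical helpers, not part of pip.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Split a plain numeric version such as "1.0.8" into (1, 0, 8)."""
    return tuple(int(part) for part in version.split("."))


def is_compatible_release(candidate: str, pinned: str) -> bool:
    """Return True if `candidate` satisfies the specifier `~=pinned`."""
    cand = parse_version(candidate)
    pin = parse_version(pinned)
    if len(pin) < 2:
        raise ValueError("~= requires at least two release segments")
    prefix = pin[:-1]  # ~=1.0.8 means >=1.0.8 and ==1.0.*
    return cand >= pin and cand[: len(prefix)] == prefix


for v in ["1.0.7", "1.0.8", "1.0.10", "1.1.0"]:
    print(v, is_compatible_release(v, "1.0.8"))
# 1.0.7 False
# 1.0.8 True
# 1.0.10 True
# 1.1.0 False
```

Inside the notebook, you can read the installed version with `importlib.metadata.version("pytorch-lightning")` and pass it as `candidate`.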




## Mounting drive, extracting dataset and weights

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%%shell
DATA_DIR=/content/project-fox/data
if [[ ! -d $DATA_DIR ]]; then
    mkdir $DATA_DIR
    echo "extracting dataset..."
    unzip "/content/drive/My Drive/project-fox/data.zip" -d $DATA_DIR &> /dev/null
    echo "done."
else
    echo "dataset already extracted."
fi

extracting dataset...
done.




In [None]:
%%shell
WEIGHTS_DIR=/content/project-fox/weights
if [[ ! -d $WEIGHTS_DIR ]]; then
    mkdir $WEIGHTS_DIR
    echo "extracting weights..."
    tar -xf "/content/drive/My Drive/project-fox/weights.tar.gz" -C $WEIGHTS_DIR &> /dev/null
    echo "done."
else
    echo "weights already extracted."
fi

extracting weights...
done.
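Both shell cells treat the existence of the target directory as proof that extraction finished. Because `unzip`/`tar` output (including errors) is also sent to `/dev/null`, an interrupted or failed extraction would be silently skipped on every later run. Below is a minimal Python sketch of the same extract-once step that writes a marker file only after the archive unpacks successfully. The `extract_once` helper and the `.extracted` marker name are hypothetical. `shutil.unpack_archive` is used because it infers both the `.zip` and `.tar.gz` formats from the file extension.

```python
import shutil
from pathlib import Path


def extract_once(archive: str, target_dir: str, marker_name: str = ".extracted") -> bool:
    """Unpack `archive` into `target_dir` unless a completion marker exists.

    Returns True if the archive was extracted on this call, False if skipped.
    """
    target = Path(target_dir)
    marker = target / marker_name
    if marker.exists():
        print(f"{target} already extracted.")
        return False
    target.mkdir(parents=True, exist_ok=True)
    print(f"extracting {archive}...")
    # unpack_archive picks the format (zip, gztar, ...) from the extension
    shutil.unpack_archive(archive, str(target))
    marker.touch()  # only reached if unpacking did not raise
    print("done.")
    return True


# In the notebook, using the paths from the shell cells above:
# extract_once("/content/drive/My Drive/project-fox/data.zip", "/content/project-fox/data")
# extract_once("/content/drive/My Drive/project-fox/weights.tar.gz", "/content/project-fox/weights")
```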




In [None]:
import os
os.chdir("/content/project-fox")

## Training YOLO

In [None]:
from fox.config import Config
from fox.dataset import ComboDataset
from fox.model import Model
import torch
import pytorch_lightning as pl
from athena.utils.progbar import ProgbarCallback
from fox.utils import parse_data_cfg
from fox.yolov3.utils.datasets import LoadImagesAndLabels

In [None]:
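config = Config(
    USE_PLANERCNN=False,    # PlaneRCNN branch disabled for this run
    DATA_DIR="data",
    IMG_SIZE=640,
    MIN_IMG_SIZE=320,       # the run name "img-320-640" refers to this size range
    BATCH_SIZE=10,
    MIDAS_LOSS_WEIGHT=0,    # MiDaS depth loss weighted out: YOLO-only training
    YOLO_LOSS_WEIGHT=1,
    EPOCHS=10
)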
dataset = ComboDataset(config)
val_dataset = ComboDataset(config, train=False)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    collate_fn=dataset.collate_fn
)
loader_val = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    collate_fn=val_dataset.collate_fn
)

Caching labels data/yolo/labels.npy (3161 found, 0 missing, 42 empty, 0 duplicate, for 3203 images): 100%|██████████| 3203/3203 [00:00<00:00, 12814.76it/s]
Caching labels data/yolo/labels.npy (311 found, 0 missing, 7 empty, 0 duplicate, for 318 images): 100%|██████████| 318/318 [00:00<00:00, 1208.17it/s]
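Both loaders pass the dataset's own `collate_fn` because detection targets vary in length. Each image has a different number of boxes, and some have none (42 empty label files in the training set above), so the default collate cannot stack them. The internals of `ComboDataset.collate_fn` are not shown here. The sketch below follows the convention of the Ultralytics YOLOv3 `LoadImagesAndLabels.collate_fn`, which is imported above from `fox.yolov3`. It assumes each label row is `[batch_index, class, x, y, w, h]`, with column 0 left free for the image index. `yolo_collate` is a hypothetical name.

```python
import torch


def yolo_collate(batch):
    """Collate (image, labels) pairs where labels has shape (n_boxes, 6)."""
    images, labels = zip(*batch)
    stamped = []
    for i, lbl in enumerate(labels):
        lbl = lbl.clone()  # avoid mutating the dataset's tensors
        lbl[:, 0] = i      # record which image in the batch each box belongs to
        stamped.append(lbl)
    # images share a shape and can be stacked; boxes are concatenated row-wise
    return torch.stack(images, 0), torch.cat(stamped, 0)


batch = [
    (torch.zeros(3, 320, 320), torch.tensor([[0, 1, 0.5, 0.5, 0.2, 0.2]])),
    (torch.zeros(3, 320, 320), torch.zeros((0, 6))),  # image with no boxes
    (torch.zeros(3, 320, 320), torch.tensor([[0, 2, 0.3, 0.3, 0.1, 0.1],
                                             [0, 0, 0.7, 0.7, 0.1, 0.1]])),
]
imgs, targets = yolo_collate(batch)
print(imgs.shape)     # torch.Size([3, 3, 320, 320])
print(targets[:, 0])  # tensor([0., 2., 2.])
```

Concatenating instead of padding keeps empty images free of cost. The loss code then uses column 0 to route each box back to its image.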


In [None]:
model = Model(config, len(loader), 4, dataset.yolo_dataset.yolo_labels)

Loading weights:  weights/midas.pt


Using cache found in /root/.cache/torch/hub/facebookresearch_WSL-Images_master


Model Summary: 147 layers, 5.93877e+07 parameters, 5.93877e+07 gradients
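The summary reports 5.93877e+07 (about 59.4 million) parameters and the same number of gradients. The line appears to come from the bundled YOLOv3 code's `model_info` helper, which sums `numel()` over all parameters, and separately over those with `requires_grad=True`. The `Model` class itself is not shown here, so the minimal sketch below reproduces the two counts on a toy module. `count_parameters` is a hypothetical helper.

```python
from typing import Tuple

import torch.nn as nn


def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """Return (total, trainable) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
toy[0].weight.requires_grad_(False)  # freeze one 5x10 weight matrix
print(count_parameters(toy))         # (67, 17)
```

Freezing layers, such as a pretrained encoder, is what makes the two counts diverge. In the summary above they are equal, so every counted parameter was trainable at construction time.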


In [None]:
log_dir = "/content/drive/My Drive/project-fox/logs"
name = "Yolo-only-img-320-640"
tensorboard_logger = pl.loggers.TensorBoardLogger(
    log_dir, name="", version=name, default_hp_metric=False
)
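checkpoint_callback = pl.callbacks.ModelCheckpoint(
    # `filepath` is deprecated in Lightning 1.0; dirpath + filename is equivalent.
    # Checkpoints go to <log_dir>/<name>/checkpoints with base name "last".
    dirpath=os.path.join(log_dir, name, "checkpoints"),
    filename="last",
)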
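With no `monitor` set, the checkpoint callback is expected to keep the latest checkpoint under `<log_dir>/<name>/checkpoints`. Storing it on Drive is useful on Colab, where a session can disconnect mid-run. Below is a small stdlib sketch that finds the newest `.ckpt` file in that directory. `latest_checkpoint` is a hypothetical helper. It picks by modification time rather than assuming the exact file name Lightning writes.

```python
import os
from pathlib import Path
from typing import Optional


def latest_checkpoint(ckpt_dir: str) -> Optional[str]:
    """Return the most recently modified *.ckpt file in ckpt_dir, or None."""
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    ckpts = sorted(directory.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return str(ckpts[-1]) if ckpts else None


# In the notebook:
# ckpt = latest_checkpoint(os.path.join(log_dir, name, "checkpoints"))
```

If a checkpoint is found, you can pass it to `pl.Trainer(resume_from_checkpoint=ckpt, ...)`, which Lightning 1.0.x accepts, to continue an interrupted run.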

In [None]:
%tensorboard --logdir "/content/drive/My Drive/project-fox/logs"

In [None]:
pl.seed_everything(0)
trainer = pl.Trainer(
    max_epochs=config.EPOCHS,
    gpus=1,
    logger=tensorboard_logger,
    checkpoint_callback=checkpoint_callback,
    callbacks=[ProgbarCallback()],
    progress_bar_refresh_rate=20,
    automatic_optimization=False
)

trainer.fit(
    model,
    train_dataloader=loader,
    val_dataloaders=loader_val
)

Epoch: 1 / 10
Validation set: avg yolo val loss: 5.5626, yolo mAP: 0.4911, avg yolo loss: 0.8946, avg midas loss: 3683.1284, avg total loss: 0.8946

Epoch: 2 / 10
Validation set: avg yolo val loss: 5.5213, yolo mAP: 0.5016, avg yolo loss: 0.7407, avg midas loss: 3687.0597, avg total loss: 0.7407

Epoch: 3 / 10
Validation set: avg yolo val loss: 5.9675, yolo mAP: 0.4022, avg yolo loss: 0.6831, avg midas loss: 3683.4561, avg total loss: 0.6831

Epoch: 4 / 10
Validation set: avg yolo val loss: 5.5015, yolo mAP: 0.5285, avg yolo loss: 0.6350, avg midas loss: 3688.5775, avg total loss: 0.6350

Epoch: 5 / 10
Validation set: avg yolo val loss: 5.9646, yolo mAP: 0.4501, avg yolo loss: 0.6031, avg midas loss: 3692.1846, avg total loss: 0.6031

Epoch: 6 / 10
Validation set: avg yolo val loss: 5.4038, yolo mAP: 0.5204, avg yolo loss: 0.5660, avg midas loss: 3690.4208, avg total loss: 0.5660

Epoch: 7 / 10
Validation set: avg yolo val loss: 5.5190, yolo mAP: 0.5406, avg yolo loss: 0.5388, avg mida
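Validation mAP fluctuates from epoch to epoch (0.5016 at epoch 2, then 0.4022 at epoch 3), so the final epoch is not necessarily the best one. Because `MIDAS_LOSS_WEIGHT=0`, the avg total loss equals the avg yolo loss in every line. Below is a minimal sketch that parses lines in this log format and returns the epoch with the highest mAP. The regular expressions are written against the lines shown above and are an assumption about the format, not part of Athena or Lightning. `best_epoch` is a hypothetical helper.

```python
import re
from typing import Optional, Tuple

EPOCH_RE = re.compile(r"Epoch: (\d+) / \d+")
MAP_RE = re.compile(r"yolo mAP: ([\d.]+)")


def best_epoch(log_text: str) -> Optional[Tuple[int, float]]:
    """Return (epoch, mAP) for the highest validation mAP found in the log."""
    best = None
    epoch = None
    for line in log_text.splitlines():
        header = EPOCH_RE.search(line)
        if header:
            epoch = int(header.group(1))
            continue
        found = MAP_RE.search(line)
        if found and epoch is not None:
            score = float(found.group(1))
            if best is None or score > best[1]:
                best = (epoch, score)
    return best


# The six complete epochs from the log above
LOG = """\
Epoch: 1 / 10
Validation set: avg yolo val loss: 5.5626, yolo mAP: 0.4911, avg yolo loss: 0.8946, avg midas loss: 3683.1284, avg total loss: 0.8946
Epoch: 2 / 10
Validation set: avg yolo val loss: 5.5213, yolo mAP: 0.5016, avg yolo loss: 0.7407, avg midas loss: 3687.0597, avg total loss: 0.7407
Epoch: 3 / 10
Validation set: avg yolo val loss: 5.9675, yolo mAP: 0.4022, avg yolo loss: 0.6831, avg midas loss: 3683.4561, avg total loss: 0.6831
Epoch: 4 / 10
Validation set: avg yolo val loss: 5.5015, yolo mAP: 0.5285, avg yolo loss: 0.6350, avg midas loss: 3688.5775, avg total loss: 0.6350
Epoch: 5 / 10
Validation set: avg yolo val loss: 5.9646, yolo mAP: 0.4501, avg yolo loss: 0.6031, avg midas loss: 3692.1846, avg total loss: 0.6031
Epoch: 6 / 10
Validation set: avg yolo val loss: 5.4038, yolo mAP: 0.5204, avg yolo loss: 0.5660, avg midas loss: 3690.4208, avg total loss: 0.5660
"""

print(best_epoch(LOG))  # (4, 0.5285)
```

The partially visible epoch 7 line reports a higher mAP (0.5406). Tracking the best epoch this way only makes sense if that epoch's weights are kept, which the single "last" checkpoint above does not guarantee.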
