# What is this Notebook?

Hallo Leute,

das ist der Versuch, das Convolutional Neural Network (CNN) des *Filter Network* des *Waggle Dance Detectors* zum Laufen zu bringen.

**Source code**: [GitHub: BioroboticsLab/bb_wdd_filter](https://github.com/BioroboticsLab/bb_wdd_filter)

# Implementation 02 - Clone from GitHub

### Train Model

In [1]:
#%pip install git+https://github.com/linusb20/bb_wdd_filter.git
import bb_wdd_filter
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# List the installed packages and keep only the bb-wdd-filter and wandb entries,
# to confirm which versions are active in this kernel's environment.
%pip list |  grep -E 'bb-wdd-filter|wandb'

bb-wdd-filter                 0.1        /srv/data/joeh97/github/bb_wdd_filter
wandb                         0.15.2
Note: you may need to restart the kernel to use updated packages.


In [3]:
import argparse

import pickle
import numpy as np
import os
import torch.nn

import bb_wdd_filter.dataset
import bb_wdd_filter.models_supervised
import bb_wdd_filter.trainer_supervised
import bb_wdd_filter.visualization


def run_wdd(
    gt_data_path,
    num_workers=0,
    checkpoint_path=None,
    continue_training=True,
    epochs=1000,
    remap_wdd_dir=None,
    image_size=32,
    images_in_archives=True,
    multi_gpu=False,
    image_scale=0.5,
    batch_size="auto",
    max_lr=0.002 * 8,
    wandb_entity=None,
    wandb_project="wdd-image-classification",
):
    """
    Train the supervised WDD classification model on a ground-truth pickle file.

    Every 10th ground-truth sample (indices 0, 10, 20, ...) is held out as the
    test set; all remaining samples form the training set. The model is moved
    to the GPU unconditionally, so a CUDA device is required.

    Arguments:
        gt_data_path (string)
            Path to the .pickle file containing the ground-truth labels and paths.
            The file is read with pickle.load, so it must come from a trusted source.
        num_workers (int)
            Forwarded unchanged to the trainer as 'num_workers'
            (presumably the number of data loading workers - verify in SupervisedTrainer).
        remap_wdd_dir (string, optional)
            Prefix of the path where the image data is saved. The paths in gt_data_path
            will be changed to point to this directory instead.
        images_in_archives (bool)
            Whether the images of the single waggle frames are saved within an images.zip
            file in each WDD subdirectory.
        checkpoint_path (string, optional)
            Filename to which the model will be saved regularly during training.
            The model will be saved on every epoch AND every X batches.
        continue_training (bool)
            Whether to try to continue training from last checkpoint. Will use the same
            wandb run ID. Auto set to "false" in case no checkpoint is found.
        epochs (int)
            Number of epochs to train for.
            As the model is saved after every epoch in 'checkpoint_path' and as the logs are
            streamed live to wandb.ai, it's safe to interrupt the training after any epoch.
        image_size (int)
            Width and height of images that are passed to the model.
        multi_gpu (bool)
            If true, the model is wrapped in torch.nn.DataParallel and the automatically
            chosen batch size is doubled.
        image_scale (float)
            Scale factor for the data. E.g. 0.5 will scale the images to half resolution.
            That allows for a wider FoV for the model by sacrificing some resolution.
        batch_size ("auto" or int)
            Batch size used for training. With "auto", the size is derived from
            'image_size' and 'multi_gpu' (448 for 32x32 images on a single GPU).
            Any other value is converted with int().
        max_lr (float)
            The training uses a learning rate scheduler (OneCycleLR) for each epoch
            where max_lr constitutes the peak learning rate.
        wandb_entity (string, optional)
            User name for wandb.ai that the training will log data to.
            If not given, an empty wandb configuration is passed to the trainer.
        wandb_project (string)
            Project name for wandb.ai. Only used when 'wandb_entity' is given.

    """

    # NOTE(security): pickle.load can execute arbitrary code contained in the file.
    # Only use ground-truth files from a trusted source.
    with open(gt_data_path, "rb") as f:
        wdd_gt_data = pickle.load(f)

    # Flatten the mapping into one tuple per sample: (key, *values).
    gt_samples = [(key,) + values for key, values in wdd_gt_data.items()]

    # Deterministic split: every 10th sample goes to the test set, the rest is used
    # for training. The modulo check avoids testing each index against an array.
    test_samples = gt_samples[::10]
    train_samples = [
        sample for index, sample in enumerate(gt_samples) if index % 10 != 0
    ]

    print("Train set:")
    dataset = bb_wdd_filter.dataset.SupervisedDataset(
        train_samples,
        images_in_archives=images_in_archives,
        image_size=image_size,
        load_wdd_vectors=True,
        load_wdd_durations=True,
        remap_paths_to=remap_wdd_dir,
    )

    print("Test set:")
    # The evaluator's job is to regularly evaluate the training progress on the test dataset.
    # It will calculate additional statistics that are logged over the wandb connection.
    evaluator = bb_wdd_filter.dataset.SupervisedValidationDatasetEvaluator(
        test_samples,
        images_in_archives=images_in_archives,
        image_size=image_size,
        remap_paths_to=remap_wdd_dir,
        default_image_scale=image_scale,
    )

    model = bb_wdd_filter.models_supervised.WDDClassificationModel(
        image_size=image_size
    )

    if multi_gpu:
        model = torch.nn.DataParallel(model)

    # Requires a CUDA device; there is no CPU fallback here.
    model = model.cuda()

    if batch_size == "auto":
        # The batch size here is calculated so that it fits on two RTX 2080 Ti in multi-GPU mode.
        # Note that a smaller batch size might also need a smaller learning rate.
        gpu_factor = 2 if multi_gpu else 1
        pixel_ratio = (image_size * image_size) / (32 * 32)
        # Clamp to 1 so that very large images do not truncate to a batch size of 0.
        batch_size = max(1, int((64 * 7 * gpu_factor) / pixel_ratio))
    else:
        batch_size = int(batch_size)

    print(
        "N pars: ",
        str(sum(p.numel() for p in model.parameters() if p.requires_grad)),
        "batch size: ",
        batch_size,
    )

    # Without an entity, keep passing the empty configuration this function has
    # always used. With an entity, forward the documented project/entity pair.
    # NOTE(review): assumes SupervisedTrainer accepts these keys in wandb_config
    # (presumably forwarded to wandb.init) - confirm against the trainer.
    wandb_config = dict()
    if wandb_entity:
        wandb_config = dict(project=wandb_project, entity=wandb_entity)

    trainer = bb_wdd_filter.trainer_supervised.SupervisedTrainer(
        dataset,
        model,
        wandb_config=wandb_config,
        save_path=checkpoint_path,
        batch_size=batch_size,
        num_workers=num_workers,
        continue_training=continue_training,
        image_size=image_size,
        batch_sampler_kwargs=dict(
            image_scale_factor=image_scale,
            inflate_dataset_factor=1000,
            augmentation_per_image=False,
        ),
        test_set_evaluator=evaluator,
        eval_test_set_every_n_samples=2000,
        save_every_n_samples=200000,
        max_lr=max_lr,
        batches_to_reach_maximum_augmentation=1000,
    )

    trainer.run_epochs(epochs)


In [4]:
# Short trial run: one epoch, started fresh rather than resumed from a checkpoint.
run_wdd(
    # Ground-truth labels and the directory the image paths are remapped to.
    gt_data_path="../../../data/wdd_ground_truth/ground_truth_wdd_angles.pickle",
    remap_wdd_dir="../../../data/wdd_ground_truth/",
    images_in_archives=True,
    # Checkpoint file written during training.
    checkpoint_path="./wdd_filtering_supervised_model.pt",
    continue_training=False,
    num_workers=4,
    epochs=1,
)

Train set:
Found 908 waggle folders.
Test set:
Found 101 waggle folders.
N pars:  121431 batch size:  448
SupervisedTrainer:init 1
SupervisedTrainer:init 2
Hello 1
Hello 2
1
3
4


  0%|          | 0/2026 [00:00<?, ?it/s]

5
6


  0%|          | 1/2026 [01:57<66:20:09, 117.93s/it]

7
8
17
18
5
6


  0%|          | 2/2026 [02:01<28:35:43, 50.86s/it] 

7
8
17
18
5
6


  0%|          | 3/2026 [02:05<16:32:46, 29.44s/it]

7
8
17
18
5
6


  0%|          | 4/2026 [02:09<10:53:32, 19.39s/it]

7
8
11
12
15
17
18
5
6


  0%|          | 5/2026 [03:46<26:34:40, 47.34s/it]

7
8
17
18
5
6


  0%|          | 6/2026 [03:50<18:17:08, 32.59s/it]

7
8
17
18
5
6


  0%|          | 7/2026 [03:54<13:01:43, 23.23s/it]

7
8
17
18
5
6


  0%|          | 8/2026 [03:58<9:35:15, 17.10s/it] 

7
8
17
18
5
6


  0%|          | 9/2026 [04:52<16:01:57, 28.62s/it]

7
8
11
12
15
17
18
5
6


  0%|          | 10/2026 [04:56<11:45:42, 21.00s/it]

7
8
17
18
5
6


  1%|          | 11/2026 [05:00<8:50:17, 15.79s/it] 

7
8
17
18
5
6


  1%|          | 12/2026 [05:04<6:49:31, 12.20s/it]

7
8
17
18
5
6


  1%|          | 13/2026 [05:57<13:47:45, 24.67s/it]

7
8
17
18
5
6


  1%|          | 14/2026 [06:01<10:17:37, 18.42s/it]

7
8
11
12
15
17
18
5
6


  1%|          | 15/2026 [06:05<7:51:29, 14.07s/it] 

7
8
17
18
5
6


  1%|          | 16/2026 [06:09<6:09:51, 11.04s/it]

7
8
17
18
5
6


  1%|          | 17/2026 [07:03<13:19:36, 23.88s/it]

7
8
17
18
5
6


  1%|          | 18/2026 [07:07<9:59:05, 17.90s/it] 

7
8
17
18
5
6


  1%|          | 19/2026 [07:11<7:38:56, 13.72s/it]

7
8
11
12
15
17
18
5
6


  1%|          | 20/2026 [07:15<6:01:09, 10.80s/it]

7
8
17
18
5
6


  1%|          | 21/2026 [08:10<13:23:31, 24.05s/it]

7
8
17
18
5
6


  1%|          | 22/2026 [08:14<10:01:53, 18.02s/it]

7
8
17
18
5
6


  1%|          | 23/2026 [08:18<7:41:03, 13.81s/it] 

7
8
17
18
5
6


  1%|          | 24/2026 [08:22<6:02:35, 10.87s/it]

7
8
11
12
15
17
18
5
6


  1%|          | 25/2026 [09:15<13:10:16, 23.70s/it]

7
8
17
18
5
6


  1%|▏         | 26/2026 [09:19<9:52:40, 17.78s/it] 

7
8
17
18
5
6


  1%|▏         | 27/2026 [09:23<7:34:33, 13.64s/it]

7
8
17
18
5
6


  1%|▏         | 28/2026 [09:27<5:57:47, 10.74s/it]

7
8
17
18
5
6


  1%|▏         | 29/2026 [10:22<13:11:16, 23.77s/it]

7
8
11
12
15
17
18
5
6


  1%|▏         | 30/2026 [10:26<9:53:17, 17.83s/it] 

7
8
17
18
5
6


  2%|▏         | 31/2026 [10:30<7:34:49, 13.68s/it]

7
8
17
18
5
6


  2%|▏         | 32/2026 [10:34<5:58:12, 10.78s/it]

7
8
17
18
5
6


  2%|▏         | 33/2026 [11:27<13:06:56, 23.69s/it]

7
8
17
18
5
6


  2%|▏         | 34/2026 [11:31<9:50:08, 17.78s/it] 

7
8
11
12
15
17
18
5
6


  2%|▏         | 35/2026 [11:35<7:32:39, 13.64s/it]

7
8
17
18
5
6


  2%|▏         | 36/2026 [11:39<5:56:32, 10.75s/it]

7
8
17
18
5
6


  2%|▏         | 37/2026 [12:33<13:03:55, 23.65s/it]

7
8
17
18
5
6


  2%|▏         | 38/2026 [12:37<9:47:57, 17.75s/it] 

7
8
17
18
5
6


  2%|▏         | 39/2026 [12:41<7:30:59, 13.62s/it]

7
8
11
12
15
17
18
5
6


  2%|▏         | 40/2026 [12:45<5:55:17, 10.73s/it]

7
8
17
18
5
6


  2%|▏         | 41/2026 [13:42<13:33:15, 24.58s/it]

7
8
17
18
5
6


  2%|▏         | 42/2026 [13:46<10:08:20, 18.40s/it]

7
8
17
18
5
6


  2%|▏         | 43/2026 [13:50<7:45:22, 14.08s/it] 

7
8
17
18
5
6


  2%|▏         | 44/2026 [13:54<6:05:11, 11.06s/it]

7
8
11
12
15
17
18
5
6


  2%|▏         | 45/2026 [14:52<13:50:18, 25.15s/it]

7
8
17
18
5
6


  2%|▏         | 46/2026 [14:56<10:20:13, 18.79s/it]

7
8
17
18
5
6


  2%|▏         | 47/2026 [15:00<7:53:21, 14.35s/it] 

7
8
17
18
5
6


  2%|▏         | 48/2026 [15:04<6:10:47, 11.25s/it]

7
8
17
18
5
6


  2%|▏         | 49/2026 [15:59<13:21:37, 24.33s/it]

7
8
11
12
15
17
18
5
6


  2%|▏         | 50/2026 [16:03<10:00:15, 18.23s/it]

7
8
17
18
5
6


  3%|▎         | 51/2026 [16:07<7:39:27, 13.96s/it] 

7
8
17
18
5
6


  3%|▎         | 52/2026 [16:11<6:00:53, 10.97s/it]

7
8
17
18
5
6


  3%|▎         | 53/2026 [17:07<13:25:02, 24.48s/it]

7
8
17
18
5
6


  3%|▎         | 54/2026 [17:11<10:02:29, 18.33s/it]

7
8
11
12
15
17
18
5
6


  3%|▎         | 55/2026 [17:15<7:40:56, 14.03s/it] 

7
8
17
18
5
6


  3%|▎         | 56/2026 [17:19<6:01:55, 11.02s/it]

7
8
17
18
5
6


  3%|▎         | 57/2026 [18:13<13:03:05, 23.86s/it]

7
8
17
18
5
6


  3%|▎         | 58/2026 [18:17<9:47:04, 17.90s/it] 

7
8
17
18
5
6


  3%|▎         | 59/2026 [18:21<7:30:02, 13.73s/it]

7
8
11
12
15
17
18
5
6


  3%|▎         | 60/2026 [18:25<5:54:16, 10.81s/it]

7
8
17
18
5
6


  3%|▎         | 61/2026 [19:18<12:57:20, 23.74s/it]

7
8
17
18
5
6


  3%|▎         | 62/2026 [19:22<9:42:49, 17.81s/it] 

7
8
17
18
5
6


  3%|▎         | 63/2026 [19:29<7:47:35, 14.29s/it]

7
8
17
18
5
6


  3%|▎         | 64/2026 [19:33<6:06:28, 11.21s/it]

7
8
11
12
15
17
18
5
6


  3%|▎         | 65/2026 [20:26<12:56:45, 23.77s/it]

7
8
17
18
5
6


  3%|▎         | 66/2026 [20:30<9:42:40, 17.84s/it] 

7
8
17
18
5
6


  3%|▎         | 67/2026 [20:35<7:41:08, 14.12s/it]

7
8
17
18
5
6


  3%|▎         | 68/2026 [20:39<6:01:48, 11.09s/it]

7
8
17
18
5
6


  3%|▎         | 69/2026 [21:31<12:45:10, 23.46s/it]

7
8
11
12
15
17
18
5
6


  3%|▎         | 70/2026 [21:35<9:34:27, 17.62s/it] 

7
8
17
18
5
6


  4%|▎         | 71/2026 [21:41<7:39:26, 14.10s/it]

7
8
17
18
5
6


  4%|▎         | 72/2026 [21:45<6:00:35, 11.07s/it]

7
8
17
18
5
6


  4%|▎         | 73/2026 [22:37<12:39:12, 23.32s/it]

7
8
17
18
5
6


  4%|▎         | 74/2026 [22:45<10:03:31, 18.55s/it]

7
8
11
12
15
17
18
5
6


  4%|▎         | 75/2026 [22:49<7:41:11, 14.18s/it] 

7
8
17
18
5
6


  4%|▍         | 76/2026 [22:53<6:01:49, 11.13s/it]

7
8
17
18
5
6


  4%|▍         | 77/2026 [23:48<13:11:44, 24.37s/it]

7
8
17
18
5
6


  4%|▍         | 78/2026 [23:52<9:52:38, 18.25s/it] 

7
8
17
18
5
6


  4%|▍         | 79/2026 [23:58<7:55:06, 14.64s/it]

7
8
11
12
15
17
18
5
6


  4%|▍         | 80/2026 [24:02<6:11:18, 11.45s/it]

7
8
17
18
5
6


  4%|▍         | 81/2026 [24:53<12:37:33, 23.37s/it]

7
8
17
18
5
6


  4%|▍         | 82/2026 [24:57<9:28:46, 17.55s/it] 

7
8
17
18
5
6


  4%|▍         | 83/2026 [25:04<7:46:36, 14.41s/it]

7
8
17
18
5
6


  4%|▍         | 84/2026 [25:08<6:05:23, 11.29s/it]

7
8
11
12
15
17
18
5
6


  4%|▍         | 85/2026 [26:04<13:19:47, 24.72s/it]

7
8
17
18
5
6


  4%|▍         | 86/2026 [26:08<9:58:10, 18.50s/it] 

7
8
17
18
5
6


  4%|▍         | 87/2026 [26:12<7:37:12, 14.15s/it]

7
8
17
18
5
6


  4%|▍         | 88/2026 [26:16<5:58:44, 11.11s/it]

7
8
17
18
5
6


  4%|▍         | 89/2026 [27:16<13:46:18, 25.60s/it]

7
8
11
12
15
17
18
5
6


  4%|▍         | 90/2026 [27:20<10:16:33, 19.11s/it]

7
8
17
18
5
6


  4%|▍         | 91/2026 [27:24<7:49:55, 14.57s/it] 

7
8
17
18
5
6


  5%|▍         | 92/2026 [27:28<6:07:30, 11.40s/it]

7
8
17
18
5
6


  5%|▍         | 93/2026 [28:27<13:49:56, 25.76s/it]

7
8
17
18
5
6


  5%|▍         | 94/2026 [28:31<10:18:58, 19.22s/it]

7
8
11
12
15
17
18
5
6


  5%|▍         | 95/2026 [28:35<7:51:31, 14.65s/it] 

7
8
17
18
5
6


  5%|▍         | 96/2026 [28:39<6:08:29, 11.46s/it]

7
8
17
18


                                                   

KeyboardInterrupt: 