# What is this Notebook?

Hallo Leute,

das ist der Versuch das Conv. Neuronale Netz (CNN) vom *Filter Network* des *Waggle Dance Detectors* zum Laufen zu bringen.

**Source code**: [GitHub: BioroboticsLab/bb_wdd_filter](https://github.com/BioroboticsLab/bb_wdd_filter)

# Implementation 02 - Clone from GitHub

### Train Model

In [3]:
#%pip install git+https://github.com/linusb20/bb_wdd_filter.git
import bb_wdd_filter
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%pip list |  grep -E 'bb-wdd-filter|wandb'

bb-wdd-filter                 0.1        /srv/data/joeh97/github/bb_wdd_filter
wandb                         0.15.2
Note: you may need to restart the kernel to use updated packages.


In [4]:
import argparse

import pickle
import numpy as np
import os
import torch.nn

import bb_wdd_filter.dataset
import bb_wdd_filter.models_supervised
import bb_wdd_filter.trainer_supervised
import bb_wdd_filter.visualization


def run_wdd(
    gt_data_path,
    checkpoint_path=None,
    continue_training=True,
    epochs=1000,
    remap_wdd_dir=None,
    image_size=32,
    images_in_archives=True,
    multi_gpu=False,
    image_scale=0.5,
    batch_size="auto",
    max_lr=0.002 * 8,
    wandb_entity=None,
    wandb_project="wdd-image-classification",
):
    """
    Arguments:
        gt_data_path (string)
            Path to the .pickle file containing the ground-truth labels and paths.
        remap_wdd_dir (string, optional)
            Prefix of the path where the image data is saved. The paths in gt_data_path
            will be changed to point to this directory instead.
        images_in_archives (bool)
            Whether the images of the single waggle frames are saved withing an images.zip
            file in each WDD subdirectory.
        checkpoint_path (string, optional)
            Filename to which the model will be saved regularly during training.
            The model will be saved on every epoch AND every X batches.
        continue_training (bool)
            Whether to try to continue training from last checkpoint. Will use the same
            wandb run ID. Auto set to "false" in case no checkpoint is found.
        epochs (int)
            Number of epochs to train for.
            As the model is saved after every epoch in 'checkpoint_path' and as the logs are
            streamed live to wandb.ai, it's save to interrupt the training after any epoch.
        image_size (int)
            Width and height of images that are passed to the model.
        image_scale (float)
            Scale factor for the data. E.g. 0.5 will scale the images to half resolution.
            That allows for a wider FoV for the model by sacrificing some resolution.
        max_lr (float)
            The training uses a learning rate scheduler (OneCycleLR) for each epoch
            where max_lr constitutes the peak learning rate.
        wandb_entity (string, optional)
            User name for wandb.ai that the training will log data to.
        wandb_project (string)
            Project name for wandb.ai.

    """

    with open(gt_data_path, "rb") as f:
        wdd_gt_data = pickle.load(f)
        gt_data_df = [(key,) + v for key, v in wdd_gt_data.items()]

    all_indices = np.arange(len(gt_data_df))
    test_indices = all_indices[::10]
    train_indices = [idx for idx in all_indices if not (idx in test_indices)]

    print("Train set:")
    dataset = bb_wdd_filter.dataset.SupervisedDataset(
        [gt_data_df[idx] for idx in train_indices],
        images_in_archives=images_in_archives,
        image_size=image_size,
        load_wdd_vectors=True,
        load_wdd_durations=True,
        remap_paths_to=remap_wdd_dir,
    )

    print("Test set:")
    # The evaluator's job is to regularly evaluate the training progress on the test dataset.
    # It will calculate additional statistics that are logged over the wandb connection.
    evaluator = bb_wdd_filter.dataset.SupervisedValidationDatasetEvaluator(
        [gt_data_df[idx] for idx in test_indices],
        images_in_archives=images_in_archives,
        image_size=image_size,
        remap_paths_to=remap_wdd_dir,
        default_image_scale=image_scale,
    )

    model = bb_wdd_filter.models_supervised.WDDClassificationModel(
        image_size=image_size
    )

    if multi_gpu:
        model = torch.nn.DataParallel(model)

    model = model.cuda()

    if batch_size == "auto":
        # The batch size here is calculated so that it fits on two RTX 2080 Ti in multi-GPU mode.
        # Note that a smaller batch size might also need a smaller learning rate.
        factor = 1
        if multi_gpu:
            factor = 2
        batch_size = int((64 * 7 * factor) / ((image_size * image_size) / (32 * 32)))
    else:
        batch_size = int(batch_size)

    print(
        "N pars: ",
        str(sum(p.numel() for p in model.parameters() if p.requires_grad)),
        "batch size: ",
        batch_size,
    )

    wandb_config = None
    if False:
        # Project name is fixed so far.
        # This provides a logging interface to wandb.ai.
        wandb_config = (dict(project=wandb_project, entity=wandb_entity),)

    trainer = bb_wdd_filter.trainer_supervised.SupervisedTrainer(
        dataset,
        model,
        wandb_config=dict(),
        save_path=checkpoint_path,
        batch_size=batch_size,
        num_workers=0,
        continue_training=continue_training,
        image_size=image_size,
        batch_sampler_kwargs=dict(
            image_scale_factor=image_scale,
            inflate_dataset_factor=1000,
            augmentation_per_image=False,
        ),
        test_set_evaluator=evaluator,
        eval_test_set_every_n_samples=2000,
        save_every_n_samples=200000,
        max_lr=max_lr,
        batches_to_reach_maximum_augmentation=1000,
    )

    trainer.run_epochs(epochs)


In [5]:
#import wandb
#wandb.init()

In [6]:
run_wdd(
    epochs=1,
    continue_training=False,
    gt_data_path=    "  ../../../data/wdd_ground_truth/ground_truth_wdd_angles.pickle",
    remap_wdd_dir=      "../../../data/wdd_ground_truth/" ,
    checkpoint_path= "./wdd_filtering_supervised_model.pt",
    images_in_archives=True,
)

FileNotFoundError: [Errno 2] No such file or directory: './wdd_ground_truth/wdd_ground_truth/ground_truth_wdd_angles.pickle'

# Implementation 01 - Copy & Paste Code into Notebook

### Additional Installations

In [None]:
!pip install madgrad
!pip install joblib
!pip install numba
!pip install imgaug

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting madgrad
  Downloading madgrad-1.3.tar.gz (7.9 kB)
  Installing build dependencies ... [?25l[?25hcanceled
[31mERROR: Operation cancelled by user[0m[31m
[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### Visualization

In [None]:
# visualization.py
import matplotlib.pyplot as plt
import matplotlib.cm
import numpy as np
import sklearn.decomposition
import sklearn.manifold
import seaborn as sns
import torch
import tqdm.auto


def sample_embeddings(model, dataset, N=1000, test_batch_size=64, seed=42):
    n_batches = N // test_batch_size
    state = np.random.get_state()
    try:
        np.random.seed(seed)
        random_samples = list(
            np.random.choice(
                len(dataset), size=n_batches * test_batch_size, replace=False
            )
        )
    finally:
        np.random.set_state(state)
    all_embeddings = []
    all_indices = []

    model.eval()
    with torch.no_grad():
        for _ in tqdm.auto.tqdm(range(n_batches), total=n_batches, leave=False):
            batch_images = []
            for _ in range(test_batch_size):
                idx = random_samples.pop()
                all_indices.append(idx)
                images, _ = dataset.__getitem__(
                    idx, return_just_one=True, normalize_to_float=True
                )
                batch_images.append(images)
            batch_images = np.stack(batch_images, axis=0)
            batch_images = torch.from_numpy(batch_images).cuda()

            embeddings = model.embed(batch_images)
            all_embeddings.append(embeddings.detach().cpu().numpy())

    embeddings = np.concatenate(all_embeddings, axis=0)[:, :, 0, 0]
    return embeddings, all_indices


def plot_embeddings(
    embeddings,
    indices,
    dataset,
    images=None,
    labels=None,
    scatterplot=False,
    display=True,
):
    embeddings = sklearn.decomposition.PCA(16).fit_transform(embeddings)
    embeddings = sklearn.manifold.TSNE(2, init="pca", perplexity=50).fit_transform(
        embeddings
    )

    label_colormap = dict()
    if labels is not None:
        unique_labels = np.unique(labels)
        for idx, label in enumerate(unique_labels):
            color = 255.0 * np.array(
                matplotlib.cm.tab10(idx / (len(unique_labels) + 1))
            )
            # Convert to PIL color string..
            color = "rgb({:d}, {:d},{:d})".format(*list(map(int, color)))
            label_colormap[label] = color

    from PIL import Image, ImageOps

    min_x, max_x = embeddings[:, 0].min(), embeddings[:, 0].max()
    min_y, max_y = embeddings[:, 1].min(), embeddings[:, 1].max()

    W, H = 4000, 3000
    scale_x = W / (max_x - min_x)
    scale_y = H / (max_y - min_y)
    fig, ax = plt.subplots(figsize=(10, 10))
    if not scatterplot:
        embedding_image = Image.new("RGBA", (W, H))
        for idx, ((x, y), img_idx) in enumerate(zip(embeddings, indices)):
            if images is not None:
                small = (images[idx] + 1.0) * (255.0 / 2.0)
            else:
                small = dataset.__getitem__(img_idx, return_just_one=True)[0][0]
            small = np.clip(small, 0, 255)

            small = Image.fromarray(small.astype(np.uint8))
            small = small.resize((128, 128))
            small = small.convert("RGBA")

            if labels is not None:
                small = ImageOps.expand(
                    small, border=8, fill=label_colormap[labels[img_idx]]
                )
            embedding_image.paste(
                small, (int((x - min_x) * scale_x), int((y - min_y) * scale_y))
            )
        ax.imshow(embedding_image)
    else:
        sns.scatterplot(x=embeddings[:, 0], y=embeddings[:, 1], alpha=0.5)

    ax.set_axis_off()

    if display:
        plt.show()
    else:
        fig.canvas.draw()
        image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        plt.close()
        return image


### Dataset

In [None]:
# dataset.PY
from locale import normalize
import imgaug.augmenters as iaa
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
import pandas
import pickle
import PIL
import scipy.spatial.distance
import skimage.transform
import sklearn.metrics
import torchvision.transforms
import torch
import torch.utils.data
import tqdm.auto
import zipfile
import sklearn.preprocessing

class ImageNormalizer:
    def __init__(self, image_size, scale_factor):
        self.image_size = image_size
        self.scale_factor = scale_factor

        self.crop = iaa.Sequential(
            [
                iaa.Resize(scale_factor),
                iaa.CenterCropToFixedSize(image_size, image_size),
            ]
        )

        self.normalize_to_float = iaa.Sequential(
            [
                # Scale to range -1, +1
                iaa.Multiply(2.0 / 255.0),
                iaa.Add(-1.0),
            ]
        )

    def crop_images(self, images):
        images = self.crop.augment_images(images)
        return images

    def floatify_image(self, img):

        if not np.issubdtype(img.dtype, np.floating):
            assert img.max() > 1
            img = img.astype(np.float32)
        else:
            img = 255.0 * img

        img = self.normalize_to_float.augment_image(img)
        return img

    def floatify_images(self, images):
        images = [self.floatify_image(img) for img in images]
        return images

    def normalize_images(self, images):
        return self.floatify_images(self.crop_images(images))
        

class WDDDataset:
    def __init__(
        self,
        paths,
        temporal_dimension=15,
        n_targets=3,
        target_offset=2,
        images_in_archives=True,
        remap_wdd_dir=None,
        image_size=128,
        silently_skip_invalid=True,
        load_wdd_vectors=False,
        load_wdd_durations=False,
        wdd_angles_for_samples=None,
        default_image_scale=0.5,
    ):

        self.load_wdd_vectors = load_wdd_vectors
        self.load_wdd_durations = load_wdd_durations
        self.silently_skip_invalid = silently_skip_invalid
        self.images_in_archives = images_in_archives
        self.sample_gaps = False
        self.all_meta_files = []
        self.wdd_angles_for_samples = wdd_angles_for_samples

        # Count and index waggle information.
        if isinstance(paths, str):
            if paths.endswith(".pickle"):
                with open(paths, "rb") as f:
                    self.all_meta_files = pickle.load(f)["json_files"]
            else:
                paths = [paths]

        if isinstance(paths, list) and str(paths[0]).endswith(".json"):
            self.all_meta_files += paths
        else:
            if not self.all_meta_files:
                for path in paths:
                    self.all_meta_files += list(pathlib.Path(path).glob("**/*.json"))

        print("Found {} waggle folders.".format(len(self.all_meta_files)))

        if remap_wdd_dir:
            for i, path in enumerate(self.all_meta_files):
                path = str(path).replace("/mnt/thekla/", remap_wdd_dir)
                path = pathlib.Path(path)
                self.all_meta_files[i] = path

        self.temporal_dimension = temporal_dimension
        self.n_targets = n_targets
        self.target_offset = target_offset

        self.default_normalizer = ImageNormalizer(image_size=image_size,
                scale_factor=default_image_scale)
        

    def load_and_normalize_image(self, filename):
        img = WDDDataset.load_image(filename)

        img = self.default_normalizer.crop_images(img)
        img = self.default_normalizer.floatify_image(img)

        return img

    @staticmethod
    def load_image(filename):
        img = PIL.Image.open(filename)
        img = np.asarray(img)
        assert img.dtype is np.dtype(np.uint8)
        return img

    @staticmethod
    def load_images(filenames, parent=""):
        return [WDDDataset.load_image(os.path.join(parent, f)) for f in filenames]

    @staticmethod
    def load_images_from_archive(filenames, archive):
        images = []
        for fn in filenames:
            with archive.open(fn, "r") as f:
                images.append(WDDDataset.load_image(f))
        return images

    @staticmethod
    def load_metadata_for_waggle(
        waggle_metadata_path,
        temporal_dimension,
        load_images=True,
        images_in_archives=False,
        gap_factor=1,
        n_targets=0,
        target_offset=1,
        return_center_images=False,
    ):

        waggle_dir = waggle_metadata_path.parent

        with open(waggle_metadata_path, "r") as f:
            waggle_metadata = json.load(f)

        available_frames_length = len(waggle_metadata["frame_timestamps"])
        try:
            waggle_angle = waggle_metadata["waggle_angle"]
            assert np.abs(waggle_angle) < np.pi * 2.0
            waggle_duration = waggle_metadata["waggle_duration"]
        except:
            waggle_angle = np.nan
            waggle_duration = np.nan

        if temporal_dimension is not None:
            target_sequence_length = n_targets * target_offset
            sequence_length = int(
                gap_factor * temporal_dimension + target_sequence_length
            )

            if not return_center_images:
                sequence_start = np.random.randint(
                    0, available_frames_length - sequence_length
                )
            else:
                sequence_center = available_frames_length // 2
                sequence_start = sequence_center - sequence_length // 2

            assert available_frames_length >= target_sequence_length + sequence_length

        def select_images_from_list(images):

            if temporal_dimension is None:
                if return_center_images:
                    n_available_images = len(images)
                    if n_available_images > 32:
                        images = images[
                            (n_available_images // 4) : -(n_available_images // 4)
                        ]
                return images

            if len(images) != available_frames_length:
                print(
                    "N images: {}, available_frames_length: {}".format(
                        len(images), available_frames_length
                    )
                )

            assert len(images) == available_frames_length
            images = images[sequence_start : (sequence_start + sequence_length)]

            targets_start = sequence_length - target_sequence_length

            if n_targets != 0:
                targets = images[targets_start:][::target_offset]
            else:
                targets = []

            if temporal_dimension == sequence_length - target_sequence_length:
                images = images[:targets_start]
            else:
                if return_center_images:
                    mid = len(images) // 2
                    margin = temporal_dimension // 2
                    images = images[(mid - margin) : (mid + margin + 1)]
                else:
                    images = [
                        images[idx]
                        for idx in sorted(
                            np.random.choice(
                                sequence_length - target_sequence_length,
                                size=temporal_dimension,
                                replace=False,
                            )
                        )
                    ]
            return images + targets

        if images_in_archives:
            zip_file_path = os.path.join(waggle_dir, "images.zip")
            if not os.path.exists(zip_file_path):
                print("{} does not exist.".format(zip_file_path))
                return None, None

            try:
                with zipfile.ZipFile(zip_file_path, "r") as zf:
                    images = list(sorted(zf.namelist()))
                    images = select_images_from_list(images)

                    if load_images:
                        images = WDDDataset.load_images_from_archive(images, zf)
            except zipfile.BadZipFile:
                print("ZipFile corrupt: {}".format(zip_file_path))
                return None, None

        else:
            images = list(
                sorted([f for f in os.listdir(waggle_dir) if f.endswith("png")])
            )
            if len(images) == 0:
                print("No images found in folder {}.".format(waggle_dir))
            assert len(images) > 0

            images = select_images_from_list(images)
            if load_images:
                images = WDDDataset.load_images(images, waggle_dir)

        return images, waggle_angle, waggle_duration

    def __len__(self):
        return len(self.all_meta_files)

    def __getitem__(
        self,
        i,
        aug=None,
        return_just_one=False,
        normalize_to_float=False,
        return_center_images=False,
    ):
        waggle_metadata_path = self.all_meta_files[i]

        images, waggle_angle, waggle_duration = WDDDataset.load_metadata_for_waggle(
            waggle_metadata_path,
            self.temporal_dimension,
            images_in_archives=self.images_in_archives,
            n_targets=self.n_targets,
            target_offset=self.target_offset,
            return_center_images=return_center_images,
        )

        if self.wdd_angles_for_samples is not None:
            waggle_angle = self.wdd_angles_for_samples[i]

        if images is None:
            if self.silently_skip_invalid:
                return self[i + 1]
            else:
                return None, None, None
        if return_just_one:
            images = images[:1]
        # images = WDDDataset.load_images(image_filenames, parent=waggle_metadata_path.parent)
        if aug is not None:
            images, waggle_angle = aug(images, waggle_angle)
        else:
            images = self.default_normalizer.crop_images(images)

        if normalize_to_float:
            images = self.default_normalizer.floatify_images(images)

        images = np.stack(images, axis=0)  # Stack over channels.

        if self.load_wdd_vectors:
            waggle_vector = np.zeros(shape=(2,), dtype=np.float32)
            if np.isfinite(waggle_angle):
                waggle_vector[0] = np.cos(waggle_angle)
                waggle_vector[1] = np.sin(waggle_angle)
        else:
            waggle_vector = None

        if not self.load_wdd_durations:
            waggle_duration = None

        return images, waggle_vector, np.float32(waggle_duration)


class BatchSampler:
    def __init__(
        self,
        dataset,
        batch_size,
        image_size=32,
        inflate_dataset_factor=1,
        image_scale_factor=0.5,
        augmentation_per_image=True,
    ):
        self.batch_size = batch_size
        self.dataset = dataset
        self.total_length = len(dataset)
        self.inflate_dataset_factor = int(inflate_dataset_factor)
        self.image_scale_factor = image_scale_factor
        self.augmentation_per_image = augmentation_per_image

        self.augmenters = None
        self.image_size = image_size

    def init_augmenters(self, current_epoch=1, total_epochs=1):

        p = np.clip(
            0.1 + np.log1p(2 * current_epoch / (max(1, total_epochs - 1))), 0, 1
        )

        # p = 0.0

        # These are applied to each image individually and must not rotate e.g. the images.
        self.quality_augmenters = iaa.Sequential(
            [
                iaa.Sometimes(0.55 * p, iaa.GammaContrast((0.9, 1.1))),
                iaa.Sometimes(0.25 * p, iaa.SaltAndPepper(0.01)),
                iaa.Sometimes(0.5 * p, iaa.AdditiveGaussianNoise(scale=(0, 0.1))),
                iaa.Sometimes(0.25 * p, iaa.GaussianBlur(sigma=(0.0, 0.5))),
                iaa.Sometimes(0.25 * p, iaa.Add(value=(-5, 5))),
            ]
        )
        self.rescale = iaa.Sequential(
            [
                # Scale to range -1, +1
                iaa.Multiply(2.0 / 255.0),
                iaa.Add(-1.0),
            ]
        )

        # These are sampled for each batch and applied to all images.
        self.augmenters = iaa.Sequential(
            [
                iaa.Affine(
                    translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)},
                    rotate=0.0,
                    shear=(-5, 5),
                    scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
                ),
                iaa.CropToFixedSize(
                    self.image_size * int(1.0 / self.image_scale_factor),
                    self.image_size * int(1.0 / self.image_scale_factor),
                    position="center",
                ),
                iaa.Resize(self.image_scale_factor),
                iaa.Sometimes(
                    0.25 * p,
                    iaa.Sequential(
                        [
                            iaa.Crop(
                                percent=(0.1, 0.25),
                                sample_independently=False,
                                keep_size=False,
                            ),
                            iaa.PadToFixedSize(
                                self.image_size,
                                self.image_size,
                                position="center",
                            ),
                        ]
                    ),
                ),
            ]
        )

        # self.augmenters = iaa.Sequential([iaa.CropToFixedSize(128, 128, position="center")])

    def __len__(self):
        return (self.total_length * self.inflate_dataset_factor) // self.batch_size

    def __getitem__(self, _):

        if self.augmenters is None:
            self.init_augmenters()

        aug = self.augmenters.to_deterministic()

        def augment_fn(images, *args):
            nonlocal aug
            img_aug = self.quality_augmenters
            if not self.augmentation_per_image:
                # Apply the same augmentation to the whole sequence.
                img_aug = img_aug.to_deterministic()
            images = img_aug.augment_images(images)
            images = self.rescale.augment_images(
                [img.astype(np.float32) for img in images]
            )
            images, angles = BatchSampler.augment_sequence(aug, images, *args)
            return images, angles

        samples, angles, durations = [], [], []
        has_labels = False
        labels = []

        for _ in range(self.batch_size):
            idx = np.random.randint(self.total_length)
            sample_data = self.dataset.__getitem__(idx, aug=augment_fn)
            label = None
            if len(sample_data) == 2:
                images, angle, duration = sample_data
            else:
                images, angle, duration, label = sample_data
                has_labels = True

            samples.append(images)
            angles.append(angle)
            durations.append(duration)
            labels.append(label)

        samples = np.stack(samples, axis=0)
        angles = np.stack(angles, axis=0)
        durations = np.stack(durations, axis=0)

        if not has_labels:
            return samples, angles, durations

        labels = np.stack(labels, axis=0)
        return samples, angles, durations, labels

    @classmethod
    def augment_sequence(self, aug, images, angle, rotate=True):

        rotation = np.random.randint(0, 360)

        for idx, img in enumerate(images):
            if rotate:
                img = skimage.transform.rotate(img, rotation)
            images[idx] = aug.augment_image(img)

        return images, angle + rotation / 180.0 * np.pi


class ValidationDatasetEvaluator:
    def __init__(
        self,
        gt_data_path,
        remap_paths_to="/mnt/thekla/",
        images_in_archives=False,
        image_size=128,
        raw_paths=None,
        temporal_dimension=None,
        return_indices=False,
    ):

        if raw_paths is None:
            self.gt_data_df, paths = ValidationDatasetEvaluator.load_ground_truth_data(
                gt_data_path, remap_paths_to=remap_paths_to
            )
        else:
            paths = raw_paths

        self.dataset = WDDDataset(
            paths,
            images_in_archives=images_in_archives,
            temporal_dimension=temporal_dimension,
            image_size=image_size,
            n_targets=0,
            silently_skip_invalid=False,
        )

        self.return_indices = return_indices

    @staticmethod
    def load_ground_truth_data(gt_data_path, remap_paths_to=""):
        if isinstance(gt_data_path, str):
            with open(gt_data_path, "rb") as f:
                wdd_gt_data = pickle.load(f)
                gt_data_df = [(key,) + v for key, v in wdd_gt_data.items()]
        else:
            gt_data_df = gt_data_path

        gt_data_df = pandas.DataFrame(
            gt_data_df, columns=["waggle_id", "label", "gt_angle", "path"]
        )
        paths = list(gt_data_df.path.values)

        if remap_paths_to:

            def rewrite(p):
                p = str(p).replace("/mnt/curta/storage/beesbook/wdd/", remap_paths_to)
                p = pathlib.Path(p)
                return p

            paths = [rewrite(p) for p in paths]

        return gt_data_df, paths

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        batch_images, vectors, durations = self.dataset.__getitem__(
            i, normalize_to_float=True, return_center_images=True
        )

        if not self.return_indices:
            return batch_images
        return i, batch_images

    def get_images_and_embeddings(
        self,
        model,
        use_last_state=False,
        show_progress=True,
        get_sample_images=True,
        augment_images=False,
    ):

        augmentations = [None]
        if augment_images:
            augmentations = [
                None,
                iaa.Fliplr(1.0),
                iaa.Flipud(1.0),
                iaa.Sequential([iaa.Fliplr(1.0), iaa.Flipud(1.0)]),
            ]

        model.eval()
        with torch.no_grad():
            embeddings = []
            sample_images = []
            labels = []

            trange = range(len(self.dataset))
            if show_progress:
                trange = tqdm.auto.tqdm(trange)

            for i in trange:
                original_batch_images = self[i]

                for aug in augmentations:
                    # Collapse batch dimension.
                    batch_images = original_batch_images.copy()

                    if aug is not None:
                        batch_images = aug.augment_images(batch_images)

                    # Add batch dimension.
                    batch_images = batch_images[None, :, :, :]

                    if get_sample_images:
                        temp_dimension = batch_images.shape[0]
                        sample_images.append(batch_images[0, temp_dimension // 2])

                    batch_images = torch.from_numpy(batch_images).cuda()

                    _, embedding = model.embed_sequence(
                        batch_images,
                        return_full_state=not use_last_state,
                        check_length=False,
                    )

                    if not use_last_state and model.lstm is not None:
                        embedding = torch.mean(embedding[:, 0], dim=0)

                    embedding = embedding.detach().cpu().numpy().flatten()

                    embeddings.append(embedding)
                    labels.append(self.gt_data_df.label.iloc[i])

        embeddings = np.array(embeddings)

        return sample_images, embeddings, labels

    def plot_embeddings(self, sample_images, embeddings, labels, **kwargs):

        from bb_wdd_filter.visualization import plot_embeddings

        return plot_embeddings(
            embeddings=embeddings,
            indices=np.arange(len(self.dataset)),
            dataset=self.dataset,
            images=sample_images,
            labels=labels,
            **kwargs,
        )

    def calculate_scores(self, embeddings, labels):

        import sklearn.linear_model
        import sklearn.preprocessing
        import sklearn.dummy

        unique_labels = list(sorted(np.unique(labels)))
        label_encoder = lambda l: np.array([unique_labels.index(x) for x in l])
        reg_model = sklearn.linear_model.LogisticRegression()

        X = embeddings
        Y = label_encoder(labels)

        scores = dict()

        from sklearn.metrics import make_scorer
        import sklearn.metrics

        scorers = dict(
            accuracy=make_scorer(sklearn.metrics.accuracy_score),
            f1=make_scorer(sklearn.metrics.f1_score, average="macro"),
            roc_auc_score=make_scorer(
                sklearn.metrics.roc_auc_score, multi_class="ovr", needs_proba=True
            ),
        )

        for label in ("all", "waggle"):
            _Y = Y
            if label != "all":
                target_label = unique_labels.index(label)
                _Y = (Y == target_label).astype(int)

            cv_results = sklearn.model_selection.cross_validate(
                reg_model, X, _Y, scoring=scorers, cv=10
            )

            for metric_name, metric_results in cv_results.items():
                if not metric_name.startswith("test_"):
                    continue
                scores[f"{label}_{metric_name}"] = np.mean(metric_results)

        return scores

    def evaluate(self, model, embed_kwargs={}, plot_kwargs={}):

        images, embeddings, labels = self.get_images_and_embeddings(
            model, **embed_kwargs
        )
        scores = self.calculate_scores(embeddings, labels)
        plot = self.plot_embeddings(images, embeddings, labels, **plot_kwargs)

        return scores, plot


class SupervisedDataset:
    def __init__(
        self,
        gt_paths,
        image_size=32,
        temporal_dimension=40,
        remap_paths_to="/mnt/thekla/",
        images_in_archives=False,
        **kwargs,
    ):

        self.gt_data_df, self.paths = ValidationDatasetEvaluator.load_ground_truth_data(
            gt_paths, remap_paths_to=remap_paths_to
        )
        self.dataset = WDDDataset(
            self.paths,
            images_in_archives=images_in_archives,
            temporal_dimension=temporal_dimension,
            image_size=image_size,
            n_targets=0,
            silently_skip_invalid=False,
            wdd_angles_for_samples=self.gt_data_df.gt_angle.values,
            **kwargs,
        )

        labels = self.gt_data_df.label.copy()
        labels[labels == "trembling"] = "other"
        self.all_labels = ["other", "waggle", "ventilating", "activating"]
        label_mapper = {s: i for i, s in enumerate(self.all_labels)}
        self.Y = np.array([label_mapper[l] for l in labels])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i, **kwargs):
        images, vector, duration = self.dataset.__getitem__(i, **kwargs)
        label = self.Y[i]

        # Add empty channel dimension.
        images = np.expand_dims(images, 0)

        return images, vector, duration, label


class SupervisedValidationDatasetEvaluator:
    def __init__(
        self,
        gt_data_path,
        remap_paths_to="/mnt/thekla/",
        images_in_archives=False,
        image_size=128,
        temporal_dimension=None,
        return_indices=False,
        default_image_scale=0.25,
        class_labels=["other", "waggle", "ventilating", "activating"]
    ):

        self.dataset = SupervisedDataset(
            gt_data_path,
            images_in_archives=images_in_archives,
            image_size=image_size,
            load_wdd_vectors=True,
            load_wdd_durations=True,
            remap_paths_to=remap_paths_to,
            default_image_scale=default_image_scale,
        )

        self.return_indices = return_indices
        self.class_labels = class_labels

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):

        item = self.dataset.__getitem__(
            i, normalize_to_float=True, return_center_images=True
        )

        if self.return_indices:
            return i, item
        return item

    def evaluate(self, model, plot_kwargs=dict()):

        dataloader = torch.utils.data.DataLoader(
            self, num_workers=0, batch_size=16, shuffle=False, drop_last=False
        )

        all_classes_hat = []
        all_vectors_hat = []
        all_durations_hat = []
        all_classes = self.dataset.Y
        all_vectors = []
        all_durations = []

        for (images, vectors, durations, _) in dataloader:
            predictions = model(images.cuda())
            assert predictions.shape[2] == 1
            assert predictions.shape[3] == 1
            assert predictions.shape[4] == 1
            predictions = predictions[:, :, 0, 0, 0]

            n_classes = 4
            classes_hat = predictions[:, :n_classes]
            vectors_hat = predictions[:, n_classes : (n_classes + 2)]
            durations_hat = predictions[:, (n_classes + 2) : (n_classes + 3)]

            vectors_hat = torch.tanh(vectors_hat)
            durations_hat = torch.relu(durations_hat)

            classes_hat = torch.nn.functional.softmax(classes_hat, dim=1)

            classes_hat = classes_hat.detach().cpu().numpy()
            vectors_hat = vectors_hat.detach().cpu().numpy()
            durations_hat = durations_hat.detach().cpu().numpy()

            all_classes_hat.append(classes_hat)
            all_vectors_hat.append(vectors_hat)
            all_durations_hat.append(durations_hat)
            all_vectors.append(vectors)
            all_durations.append(durations)

        all_classes_hat = np.concatenate(all_classes_hat, axis=0)
        all_classes_hat_argmax = np.argmax(all_classes_hat, axis=1)
        all_vectors_hat = np.concatenate(all_vectors_hat, axis=0)
        all_durations_hat = np.concatenate(all_durations_hat, axis=0)
        all_vectors = np.concatenate(all_vectors, axis=0)
        all_durations = np.concatenate(all_durations, axis=0)

        metrics = dict()
        metrics["test_balanced_accuracy"] = sklearn.metrics.balanced_accuracy_score(
            all_classes, all_classes_hat_argmax, adjusted=True
        )
        try:
            metrics["test_roc_auc_score"] = sklearn.metrics.roc_auc_score(
                all_classes, all_classes_hat, multi_class="ovr"
            )
        except ValueError as e:
            metrics["test_roc_auc_score"] = np.nan
            
        metrics["test_matthews"] = sklearn.metrics.matthews_corrcoef(
            all_classes, all_classes_hat_argmax
        )
        metrics["test_f1_weighted"] = sklearn.metrics.f1_score(
            all_classes, all_classes_hat_argmax, average="weighted"
        )

        metrics["test_angle_cosine"] = 1.0 - np.mean(
            [
                scipy.spatial.distance.cosine(a, b)
                for (a, b) in zip(all_vectors, all_vectors_hat)
            ]
        )

        for i in range(1, len(self.class_labels)):
            label = self.class_labels[i]
            Y_hat = all_classes_hat_argmax == i
            Y = all_classes == i

            metrics[f"test_precision_{label}"] = sklearn.metrics.precision_score(Y, Y_hat)
            metrics[f"test_recall_{label}"] = sklearn.metrics.recall_score(Y, Y_hat)

            

        idx = ~pandas.isnull(all_durations)
        all_durations = all_durations[idx]
        all_durations_hat = all_durations_hat[idx]
        metrics["test_duration_mse"] = sklearn.metrics.mean_squared_error(
            all_durations, all_durations_hat
        )

        return metrics


class WDDDatasetWithIndicesAndNormalized:
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        item = self.dataset.__getitem__(
            i,
            return_just_one=False,
            normalize_to_float=True,
            return_center_images=True,
        )
        return i, item


### Loss

In [None]:
# loss.py

import torch


def calculate_cpc_loss(encodings, predictions, detach_accuracies=True):

    assert isinstance(encodings, list)
    assert isinstance(predictions, list)

    n_timesteps = len(encodings)
    batch_size = encodings[0].shape[0]

    nce_loss = 0.0
    accuracies = []

    for i in range(n_timesteps):
        encoding = encodings[i]
        prediction = predictions[i]
        prediction = torch.swapaxes(prediction, 0, 1)

        projections = torch.mm(encoding, prediction)
        assert projections.shape[0] == batch_size
        assert projections.shape[1] == batch_size

        logs_projections = torch.nn.functional.log_softmax(projections, dim=1)

        # Count the number of times the highest element is on the diagonal.
        hits = logs_projections.argmax(dim=0) == torch.arange(
            batch_size, device=logs_projections.device
        )
        hits = hits.float().mean()
        if detach_accuracies:
            hits = hits.detach()

        accuracies.append(hits)

        # Now the InfoNCE loss.
        nce = torch.diag(logs_projections).mean()

        nce_loss += -1.0 * nce / n_timesteps

    losses = {f"acc_t{i}": acc for (i, acc) in enumerate(accuracies)}
    losses["nce_loss"] = nce_loss

    return losses


### Models Supervised

In [None]:
# models_supervised.py

import numpy as np
import sklearn.metrics
import torch
import torch.nn
import torch.utils
import torchvision.transforms.functional

DEFAULT_CLASS_LABELS = ["other", "waggle", "ventilating", "activating"]

class TensorView(torch.nn.Module):
    def __init__(self, *shape):
        self.shape = shape

    def forward(self, t):
        return t.view(*self.shape)


class WDDClassificationModel(torch.nn.Module):
    def __init__(
        self,
        n_outputs=7,
        temporal_dimension=40,
        image_size=32,
        scaledown_factor=4,
        inplace=False,
    ):

        super().__init__()

        center_stride = image_size // 32
        center_padding = 2 if image_size == 32 else 0

        if temporal_dimension == 60:
            center_temporal_stride = 2
            center_temporal_kernel_size = 3
        else:
            assert temporal_dimension == 40
            center_temporal_stride = 1
            center_temporal_kernel_size = 5

        s = scaledown_factor

        self.seq = [
            torch.nn.Conv3d(1, 128 // s, kernel_size=5, stride=1, padding=0),
            torch.nn.BatchNorm3d(128 // s),
            torch.nn.Mish(inplace=inplace),
            # 36/56 x 28 - 56 x 60
            torch.nn.Conv3d(
                128 // s, 64 // s, kernel_size=3, stride=1, padding=0, dilation=2
            ),
            torch.nn.BatchNorm3d(64 // s),
            torch.nn.Mish(inplace=inplace),
            # 32/52 x 24 - 52 x 56
            torch.nn.Conv3d(
                64 // s,
                64 // s,
                kernel_size=(5, 3, 3),
                stride=2,
                padding=(3, 1, 1),
                dilation=(2, 1, 1),
            ),
            torch.nn.BatchNorm3d(64 // s),
            torch.nn.Mish(inplace=inplace),
            # 15/25 x 12 - 25 x 28
            torch.nn.Conv3d(
                64 // s,
                64 // s,
                kernel_size=5,
                stride=(1, center_stride, center_stride),
                padding=(2, center_padding, center_padding),
                dilation=1,
            ),
            torch.nn.BatchNorm3d(64 // s),
            torch.nn.Mish(inplace=inplace),
            # 15/25 x 12 - 25 x 12
            torch.nn.Conv3d(
                64 // s,
                128 // s,
                kernel_size=(center_temporal_kernel_size, 3, 3),
                stride=(center_temporal_stride, 2, 2),
                padding=(0, 1, 1),
                dilation=1,
            ),
            torch.nn.BatchNorm3d(128 // s),
            torch.nn.FeatureAlphaDropout(),
            torch.nn.Mish(inplace=inplace),
            # 12 x 6 - 12 x 6
            torch.nn.Conv3d(
                128 // s,
                128 // s,
                kernel_size=3,
                stride=(2, 1, 1),
                padding=(0, 1, 1),
                dilation=(1, 2, 2),
            ),
            torch.nn.BatchNorm3d(128 // s),
            torch.nn.GLU(dim=1),
            torch.nn.Mish(inplace=inplace),
            # 5 x 4 - 5 x 4
            torch.nn.Conv3d(64 // s, n_outputs, kernel_size=(5, 4, 4)),
        ]

        self.seq = torch.nn.Sequential(*self.seq)

    def postprocess_predictions(self, all_outputs, return_raw=False, as_numpy=False):
        
        n_classes = 4

        classes_hat = all_outputs[:, :n_classes]
        vectors_hat = all_outputs[:, n_classes : (n_classes + 2)]
        durations_hat = all_outputs[:, (n_classes + 2)]

        confidences = None

        if not return_raw:
            probabilities = torch.nn.functional.softmax(classes_hat, 1)
            classes_hat = torch.argmax(probabilities, 1)
            confidences = probabilities[np.arange(probabilities.shape[0]), classes_hat]
            vectors_hat = torch.tanh(vectors_hat)
            durations_hat = torch.relu(durations_hat)
        
        if as_numpy:
            classes_hat = classes_hat.detach().cpu().numpy()
            vectors_hat = vectors_hat.detach().cpu().numpy()
            durations_hat = durations_hat.detach().cpu().numpy()
            if confidences is not None:
                confidences = confidences.detach().cpu().numpy()

        return classes_hat, vectors_hat, durations_hat, confidences

    def forward(self, images):
        if self.training:
            images.requires_grad = True
            output = torch.utils.checkpoint.checkpoint_sequential(self.seq, 4, images)
        else:
            output = self.seq(images)

        if self.training:
            shape_correct = (
                output.shape[2] == 1 and output.shape[3] == 1 and output.shape[4] == 1
            )
            if not shape_correct:
                raise ValueError(
                    "Incorrect output shape: {} [input shape was {}]".format(
                        output.shape, images.shape
                    )
                )
            output = output[:, :, 0, 0, 0]

        return output

    def load_state_dict(self, d):
        try:
            return super().load_state_dict(d)
        except Exception as e:
            print("Failed to load. Trying without DataParallel prefix.")
        # Strip off Wrapper & DataParallel prefix.
        d = {key.replace("model.module.", ""): v for key, v in d.items()}
        return super().load_state_dict(d)


# To support DataParallel.
class SupervisedModelTrainWrapper(torch.nn.Module):
    def __init__(
        self, model, class_labels=DEFAULT_CLASS_LABELS
    ):

        super().__init__()

        self.vector_similarity = torch.nn.CosineSimilarity(dim=1)
        self.mse = torch.nn.MSELoss()
        self.classification_loss = torch.nn.CrossEntropyLoss()
        self.class_labels = class_labels

        self.model = model
    
    def forward(self, x):
        return self.model(x)

    def calc_additional_metrics(
        self, predictions, labels, vectors_hat, vectors, durations_hat, durations
    ):

        import wandb

        results = dict()

        predictions = torch.nn.functional.softmax(predictions, dim=1)

        predictions = predictions.detach().cpu().numpy()
        predicted_labels = np.argmax(predictions, axis=1)

        labels = labels.detach().cpu().numpy()

        results["train_conf"] = wandb.plot.confusion_matrix(
            probs=None,
            y_true=labels,
            preds=predicted_labels,
            class_names=self.class_labels,
        )

        try:
            results["train_roc_auc"] = sklearn.metrics.roc_auc_score(
                labels,
                predictions,
                multi_class="ovr",
                labels=np.arange(len(self.class_labels)),
            )
        except:
            pass

        results["train_matthews"] = sklearn.metrics.matthews_corrcoef(
            labels, predicted_labels
        )

        results["train_balanced_accuracy"] = sklearn.metrics.balanced_accuracy_score(
            labels, predicted_labels, adjusted=True
        )

        results["train_f1_weighted"] = sklearn.metrics.f1_score(
            labels, predicted_labels, average="weighted"
        )

        if vectors_hat is not None:
            divergence = self.vector_similarity(vectors_hat, vectors)
            results["vector_cossim"] = torch.mean(divergence)

            vectors_hat = torch.tanh(vectors_hat)
            divergence = self.mse(vectors_hat, vectors)
            results["vector_mse"] = torch.mean(divergence)

        return results

    def run_batch(self, images, vectors, durations, labels):
        # print(images.dtype, vectors.dtype, durations.dtype)
        batch_size, temp_dimension = images.shape[:2]
        model = self.model

        all_outputs = model(images)
        classes_hat, vectors_hat, durations_hat, _ = model.postprocess_predictions(all_outputs, return_raw=True)        

        losses = dict()

        losses["classification_loss"] = self.classification_loss(classes_hat, labels)
        # losses["classification_loss"] = self.mse(classes_hat, labels)

        if vectors is not None:
            other_target = 0
            valid_indices = labels != other_target

            if torch.any(valid_indices):
                vectors_hat = vectors_hat[valid_indices]
                vectors = vectors[valid_indices]

                vectors_hat = torch.tanh(vectors_hat)
                # divergence = 1.0 - self.vector_similarity(vectors, vectors_hat)
                divergence = self.mse(vectors, vectors_hat)
                losses["vector_loss"] = 1.5 * torch.mean(divergence)
            else:
                vectors_hat = None
                vectors = None

        if durations is not None:
            waggle_target = 1
            valid_indices = (labels == waggle_target) & (~torch.isnan(durations))

            if torch.any(valid_indices):
                durations_hat = durations_hat[valid_indices]
                durations = durations[valid_indices]

                durations_hat = torch.relu(durations_hat)
                divergence = self.mse(durations, durations_hat)
                losses["duration_loss"] = 1.0 * torch.mean(divergence)
            else:
                durations_hat = None
                durations = None

        with torch.no_grad():
            losses["additional"] = self.calc_additional_metrics(
                classes_hat, labels, vectors_hat, vectors, durations_hat, durations
            )

        return losses


### Models

In [None]:
# models.py

import numpy as np
import torch
import torch.nn
import torch.utils
import torchvision.transforms.functional

# from .loss import calculate_cpc_loss


class SubsampleBlock(torch.nn.Module):
    def __init__(
        self, n_channels, n_mid_channels=64, n_out_channels=64, subsample=True
    ):
        super().__init__()

        stride = 1 if not subsample else 2

        norm = torch.nn.utils.spectral_norm
        self.seq = torch.nn.Sequential(
            norm(
                torch.nn.Conv2d(
                    n_channels, n_mid_channels, kernel_size=3, padding=1, dilation=1
                )
            ),
            torch.nn.GroupNorm(8, n_mid_channels),
            torch.nn.GLU(dim=1),
            torch.nn.Mish(),
            # torch.nn.BatchNorm2d(n_mid_channels // 2),
            norm(
                torch.nn.Conv2d(
                    n_mid_channels // 2,
                    n_out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                )
            ),
            torch.nn.GroupNorm(8, n_out_channels),
            torch.nn.Mish(),
            # torch.nn.BatchNorm2d(n_out_channels),
        )

    def forward(self, x):
        return self.seq(x)


class EmbeddingModel(torch.nn.Module):
    def __init__(self, n_channels=1, temporal_length=15, n_targets=3, image_size=128):
        super().__init__()

        self.temporal_length = temporal_length
        self.n_targets = n_targets

        n_mid_channels = 64
        norm = torch.nn.utils.spectral_norm

        embedding_size = 256
        hidden_state_size = 64
        f = 2
        self.embedding = torch.nn.Sequential(
            # 128
            SubsampleBlock(
                n_channels, 32, 96 // 2 // f, subsample=image_size >= 128
            ),  # 64
            SubsampleBlock(
                96 // 2 // f, 128 // f, 128 // f, subsample=image_size >= 64
            ),  # 32
            SubsampleBlock(
                128 // f, 128 // f, 256 // f, subsample=image_size >= 32
            ),  # 16
            SubsampleBlock(256 // f, 256 // f, 512 // f),  # 8
            SubsampleBlock(512 // f, embedding_size, 2 * embedding_size),  # 4
            norm(
                torch.nn.Conv2d(
                    2 * embedding_size,
                    embedding_size,
                    kernel_size=4,
                )
            ),
        )

        # Input size: b, embedding_size,  temporal_length
        if False:
            self.lstm = None
            self.sequential_embedding = []
            current_length = self.temporal_length
            current_hidden_size = embedding_size
            while current_length >= 4:
                self.sequential_embedding += [
                    torch.nn.Conv1d(
                        current_hidden_size,
                        current_hidden_size * 2,
                        kernel_size=3,
                        stride=1,
                    ),
                    torch.nn.GroupNorm(8, current_hidden_size * 2),
                    torch.nn.Mish(),
                ]
                current_length -= 2
                current_length //= 2

                out_size = (
                    current_hidden_size * 2
                    if current_length >= 4
                    else hidden_state_size
                )
                self.sequential_embedding += [
                    torch.nn.Conv1d(
                        current_hidden_size * 2,
                        out_size,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    ),
                    torch.nn.GroupNorm(8, out_size),
                    torch.nn.Mish(),
                ]

                current_hidden_size = out_size

            print(current_length)
            self.sequential_embedding += [torch.nn.AvgPool1d(current_length)]
            self.sequential_embedding = torch.nn.Sequential(*self.sequential_embedding)
        else:
            self.lstm = torch.nn.LSTM(
                input_size=embedding_size,
                hidden_size=hidden_state_size,
                batch_first=False,
            )

        self.predictors = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    # torch.nn.Linear(hidden_state_size, hidden_state_size // 2),
                    # torch.nn.GroupNorm(1, hidden_state_size // 2),
                    # torch.nn.Mish(),
                    torch.nn.Linear(hidden_state_size, embedding_size),
                    # torch.nn.LeakyReLU(),
                )
                for _ in range(self.n_targets)
            ]
        )

        self.direction_vector_regressor = torch.nn.Sequential(
            torch.nn.Linear(hidden_state_size, hidden_state_size * 2),
            torch.nn.GroupNorm(8, hidden_state_size * 2),
            torch.nn.Mish(),
            torch.nn.Linear(hidden_state_size * 2, 2),
            torch.nn.Tanh(),
        )

        self.vector_similarity = torch.nn.CosineSimilarity(dim=1)

    def predict_waggle_direction(self, hidden_state):
        directions = self.direction_vector_regressor(hidden_state)
        lengths = torch.linalg.vector_norm(directions, dim=1) + 1e-3
        directions = directions / lengths.unsqueeze(1)
        return directions

    def embed(self, images):
        if self.training:
            embedding = torch.utils.checkpoint.checkpoint(self.embedding, images)
        else:
            embedding = self.embedding(images)
        return embedding

    def calculate_image_embeddings_for_image_sequences(self, images):

        temporal_length = self.temporal_length or images.shape[1]

        embeddings = []
        for i in range(temporal_length):
            e = self.embed(images[:, i : (i + 1), :, :])
            embedding_size = e.shape[1]
            assert e.shape[2] == 1 and e.shape[3] == 1
            e = e[:, :, 0, 0]
            embeddings.append(e)

        embeddings = torch.stack(embeddings, dim=0)

        return embeddings

    def embed_sequence(self, images, return_full_state=False, check_length=True):

        assert (
            (not check_length)
            or (self.temporal_length is None)
            or (images.shape[1] == self.temporal_length)
        )

        embeddings = self.calculate_image_embeddings_for_image_sequences(images)

        if self.lstm is not None:
            out, hidden_states = self.lstm(embeddings)

            if not return_full_state:
                out = out[-1]  # Last sequence state.

        else:
            e = torch.transpose(embeddings, 0, 1)
            e = torch.transpose(e, 1, 2)

            out = self.sequential_embedding(e)

            if not return_full_state:
                out = out[:, :, -1]

        return embeddings, out

    def forward(self, images):

        image_embeddings, sequential_embeddings = self.embed_sequence(images)
        predictions = [
            predictor(sequential_embeddings) for predictor in self.predictors
        ]

        return image_embeddings, sequential_embeddings, predictions

    def run_batch(self, images, vectors, durations=None, labels=None):
        batch_size, L = images.shape[:2]
        base = images[:, : -len(self.predictors), :, :]

        assert base.shape[1] == self.temporal_length

        target_embeddings = []
        targets = images[:, -len(self.predictors) :, :, :]
        for idx, predictor in enumerate(self.predictors):

            target_embedding = self.embed(targets[:, idx : (idx + 1)])
            target_embedding = target_embedding[
                :, :, 0, 0
            ]  # Collapse dimensions of length 1.
            target_embeddings.append(target_embedding)

        image_embeddings, sequential_embeddings, predictions = self(base)
        losses = calculate_cpc_loss(target_embeddings, predictions)

        # Add rotation invariance losses.
        angles = np.arange(0, 360 - 45, 15) / 180 * np.pi
        angles = np.random.choice(angles, 2, replace=False)

        n_angles = angles.shape[0]
        rotation_loss = 0.0

        for angle in angles:
            rotated = torchvision.transforms.functional.rotate(images, angle=angle)
            rotated_embeddings = self.calculate_image_embeddings_for_image_sequences(
                rotated
            )

            difference = 1.0 - self.vector_similarity(
                image_embeddings, rotated_embeddings
            )
            # difference = torch.abs(image_embeddings - rotated_embeddings).mean()
            rotation_loss += difference.mean()

        losses["rotation_inv_loss"] = 100 * rotation_loss / n_angles

        if vectors is not None:
            valid_indices = ~torch.all(torch.abs(vectors) < 1e-4, dim=1)
            if torch.any(valid_indices):
                vectors_hat = self.predict_waggle_direction(
                    sequential_embeddings[valid_indices]
                )
                divergence = 1.0 - self.vector_similarity(
                    vectors[valid_indices], vectors_hat
                )
                losses["vector_loss"] = torch.mean(divergence)

        return losses


### Trainer

In [None]:
### trainer.py

import madgrad
import numpy as np
import pandas
import pathlib
import shutil
import torch
import tqdm.auto

#from .dataset import BatchSampler
#from .visualization import plot_embeddings, sample_embeddings


class Trainer:
    def __init__(
        self,
        dataset,
        model,
        batch_size=32,
        use_wandb=None,
        wandb_config=dict(),
        save_path="warn",
        save_every_n_batches=None,
        save_every_n_samples=25000,
        eval_test_set_every_n_samples=None,
        num_workers=16,
        continue_training=True,
        image_size=128,
        test_set_evaluator=None,
        batch_sampler_kwargs=dict(),
        max_lr=0.001,
        batches_to_reach_maximum_augmentation=2000,
        run_batch_fn=None,
    ):
        def init_worker(ID):

            import torch
            import numpy as np

            np.random.seed(torch.initial_seed() % 2 ** 32)

            import imgaug

            imgaug.seed((torch.initial_seed() + 1) % 2 ** 32)

        self.dataset = dataset
        self.batch_sampler = BatchSampler(
            dataset, batch_size, image_size=image_size, **batch_sampler_kwargs
        )
        self.dataloader = torch.utils.data.DataLoader(
            self.batch_sampler,
            num_workers=num_workers,
            batch_size=None,
            batch_sampler=None,
            pin_memory=True,
            shuffle=True,
            worker_init_fn=init_worker,
        )
        self.model = model

        self.optimizer = madgrad.MADGRAD(self.model.parameters(), lr=0.001)
        self.max_lr = max_lr
        self.batches_to_reach_maximum_augmentation = (
            batches_to_reach_maximum_augmentation
        )
        self.run_batch_fn = run_batch_fn

        self.use_wandb = use_wandb
        self.wandb_config = wandb_config
        if self.use_wandb is None:
            self.use_wandb = len(wandb_config) > 0

        if self.use_wandb:
            import wandb

            self.id = wandb.util.generate_id()
            self.wandb_initialized = False
        else:
            self.id = None

        if save_path == "warn":
            print("Warning: No model save path given. Model will not be saved.")
            save_path = None

        if save_every_n_batches is None:
            save_every_n_batches = save_every_n_samples // batch_size
        if eval_test_set_every_n_samples is None:
            self.eval_test_set_every_n_batches = save_every_n_batches
        else:
            self.eval_test_set_every_n_batches = (
                eval_test_set_every_n_samples // batch_size
            )

        self.save_path = save_path
        self.save_every_n_batches = save_every_n_batches
        self.test_set_evaluator = test_set_evaluator
        self.total_batches = 0
        self.total_epochs = 0

        self.continue_training = continue_training

        if continue_training:
            self.load_checkpoint()

    def run_batch(self, images, vectors, durations=None, labels=None):

        current_state = dict()

        self.model.train()
        images = images.cuda(non_blocking=True)
        if vectors is not None and not np.any(pandas.isnull(vectors)):
            vectors = vectors.cuda(non_blocking=True)
        else:
            vectors = None

        if durations is not None and not np.all(pandas.isnull(durations)):
            durations = durations.cuda(non_blocking=True)
        else:
            durations = None

        if labels is not None:
            labels = labels.cuda(non_blocking=True)

        self.optimizer.zero_grad()
        if self.run_batch_fn is not None:
            losses = self.run_batch_fn(self.model, images, vectors, durations, labels)
        else:
            losses = self.model.run_batch(images, vectors, durations, labels)

        total_loss = 0.0
        for loss_name, value in losses.items():
            if isinstance(value, dict):
                current_state = {**current_state, **value}
                continue

            if value.requires_grad:
                total_loss += value

            current_state[loss_name] = float(value.detach().cpu().numpy())

        total_loss.backward()
        self.optimizer.step()

        return current_state

    def check_init_wandb(self):
        if self.use_wandb:
            import wandb

            if not self.wandb_initialized:
                self.wandb_initialized = True
                wandb.init(
                    id=self.id,
                    resume="allow" if self.continue_training else False,
                    **self.wandb_config
                )

                config = wandb.config
                config["optimizer"] = type(self.optimizer).__name__

    def check_scale_augmenters(self):
        if self.total_batches % 100 == 0:
            # Scale augmentation.
            self.batch_sampler.init_augmenters(
                current_epoch=self.total_batches,
                total_epochs=self.batches_to_reach_maximum_augmentation,
            )

    def save_at_n_batches(self):
        if self.save_path is not None:
            tqdm.auto.tqdm.write(
                "Saving model state at batch {}..".format(self.total_batches)
            )
            self.save_state()

    def sample_and_save_embedding(self):
        import wandb

        self.model.eval()

        loss_info = dict()

        if self.test_set_evaluator is not None:
            scores, plot = self.test_set_evaluator.evaluate(
                self.model, plot_kwargs=dict(display=False)
            )
            loss_info = {**loss_info, **scores}
            loss_info["embedding"] = wandb.Image(plot)

        else:
            e, idx = sample_embeddings(self.model, self.dataset)
            img = plot_embeddings(
                e, idx, self.dataset, scatterplot=False, display=False
            )
            loss_info["embedding"] = wandb.Image(img)

        self.model.train()

        return loss_info

    def run_epoch(self):

        if self.use_wandb:
            import wandb

        self.check_init_wandb()

        n_batches = len(self.batch_sampler)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer, self.max_lr, total_steps=n_batches
        )

        for _, batch in enumerate(tqdm.auto.tqdm(self.dataloader, leave=False)):

            self.check_scale_augmenters()

            loss_info = self.run_batch(*batch)

            self.total_batches += 1

            if (self.total_batches + 1) % (self.save_every_n_batches + 1) == 0:
                self.save_at_n_batches()

            if (self.total_batches + 1) % (self.eval_test_set_every_n_batches + 1) == 0:
                additional_vars = None

                if self.use_wandb:
                    with torch.no_grad():
                        additional_vars = self.sample_and_save_embedding()

                if additional_vars is not None:
                    loss_info = {**loss_info, **additional_vars}

            scheduler.step()

            if self.use_wandb:
                loss_info["learning_rate"] = scheduler._last_lr
                wandb.log(loss_info)

    def run_epochs(self, n):
        for i in range(n):

            self.run_epoch()
            self.total_epochs += 1

            if self.save_path is not None:
                print("Saving model state after epoch {}..".format(i), flush=True)
                self.save_state(copy_suffix="_epoch{:03d}".format(self.total_epochs))

    def save_state(self, copy_suffix=None):

        model_state_dict = self.model.state_dict()

        state = dict(
            model=model_state_dict,
            wandb_id=self.id,
            total_batches=self.total_batches,
            total_epochs=self.total_epochs,
        )
        torch.save(state, self.save_path)

        if copy_suffix:
            ext = pathlib.Path(self.save_path).suffix
            copy_path = self.save_path[: -len(ext)] + str(copy_suffix) + ext
            shutil.copy(self.save_path, copy_path)

    def load_checkpoint(self):
        print("Loading last checkpoint...")
        state = torch.load(self.save_path)

        self.id = state["wandb_id"]
        self.total_batches = state["total_batches"]
        self.total_epochs = state["total_epochs"]
        self.model.load_state_dict(state["model"])


### Trainer Supervised

In [None]:
# trainer_supervised.py
import madgrad
import numpy as np
import pandas
import pathlib
import shutil
import torch
import tqdm.auto

# from .trainer import Trainer
# from .dataset import BatchSampler
# from .visualization import plot_embeddings, sample_embeddings
# from .models_supervised import SupervisedModelTrainWrapper


class SupervisedTrainer(Trainer):
    def __init__(self, dataset, model, *args, batch_sampler_kwargs=dict(), **kwargs):

        model = SupervisedModelTrainWrapper(model)

        super().__init__(
            dataset,
            model,
            *args,
            batch_sampler_kwargs={
                **dict(inflate_dataset_factor=50),
                **batch_sampler_kwargs,
            },
            **kwargs
        )

    def sample_and_save_embedding(self):

        self.model.eval()

        loss_info = dict()

        if self.test_set_evaluator is not None:
            scores = self.test_set_evaluator.evaluate(
                self.model, plot_kwargs=dict(display=False)
            )
            loss_info = {**loss_info, **scores}

        self.model.train()

        return loss_info


In [None]:
from matplotlib import get_data_path
### train_model.py

import argparse

import pickle
import numpy as np
import os
import torch.nn

#import bb_wdd_filter.dataset
#import bb_wdd_filter.models_supervised
#import bb_wdd_filter.trainer_supervised
#import bb_wdd_filter.visualization


def run(
    gt_data_path,
    checkpoint_path=None,
    continue_training=True,
    epochs=1000,
    remap_wdd_dir=None,
    image_size=32,
    images_in_archives=True,
    multi_gpu=False,
    image_scale=0.5,
    batch_size="auto",
    max_lr=0.002 * 8,
    wandb_entity=None,
    wandb_project="wdd-image-classification",
):
    """
    Arguments:
        gt_data_path (string)
            Path to the .pickle file containing the ground-truth labels and paths.
        remap_wdd_dir (string, optional)
            Prefix of the path where the image data is saved. The paths in gt_data_path
            will be changed to point to this directory instead.
        images_in_archives (bool)
            Whether the images of the single waggle frames are saved withing an images.zip
            file in each WDD subdirectory.
        checkpoint_path (string, optional)
            Filename to which the model will be saved regularly during training.
            The model will be saved on every epoch AND every X batches.
        continue_training (bool)
            Whether to try to continue training from last checkpoint. Will use the same
            wandb run ID. Auto set to "false" in case no checkpoint is found.
        epochs (int)
            Number of epochs to train for.
            As the model is saved after every epoch in 'checkpoint_path' and as the logs are
            streamed live to wandb.ai, it's save to interrupt the training after any epoch.
        image_size (int)
            Width and height of images that are passed to the model.
        image_scale (float)
            Scale factor for the data. E.g. 0.5 will scale the images to half resolution.
            That allows for a wider FoV for the model by sacrificing some resolution.
        max_lr (float)
            The training uses a learning rate scheduler (OneCycleLR) for each epoch
            where max_lr constitutes the peak learning rate.
        wandb_entity (string, optional)
            User name for wandb.ai that the training will log data to.
        wandb_project (string)
            Project name for wandb.ai.

    """

    with open(gt_data_path, "rb") as f:
        wdd_gt_data = pickle.load(f)
        gt_data_df = [(key,) + v for key, v in wdd_gt_data.items()]

    all_indices = np.arange(len(gt_data_df))
    test_indices = all_indices[::10]
    train_indices = [idx for idx in all_indices if not (idx in test_indices)]

    print("Train set:")
    dataset = SupervisedDataset(
        [gt_data_df[idx] for idx in train_indices],
        images_in_archives=images_in_archives,
        image_size=image_size,
        load_wdd_vectors=True,
        load_wdd_durations=True,
        remap_paths_to=remap_wdd_dir,
    )

    print("Test set:")
    # The evaluator's job is to regularly evaluate the training progress on the test dataset.
    # It will calculate additional statistics that are logged over the wandb connection.
    evaluator = SupervisedValidationDatasetEvaluator(
        [gt_data_df[idx] for idx in test_indices],
        images_in_archives=images_in_archives,
        image_size=image_size,
        remap_paths_to=remap_wdd_dir,
        default_image_scale=image_scale,
    )

    model = WDDClassificationModel(
        image_size=image_size
    )

    if multi_gpu:
        model = torch.nn.DataParallel(model)

    model = model.cuda()

    if batch_size == "auto":
        # The batch size here is calculated so that it fits on two RTX 2080 Ti in multi-GPU mode.
        # Note that a smaller batch size might also need a smaller learning rate.
        factor = 1
        if multi_gpu:
            factor = 2
        batch_size = int((64 * 7 * factor) / ((image_size * image_size) / (32 * 32)))
    else:
        batch_size = int(batch_size)

    print(
        "N pars: ",
        str(sum(p.numel() for p in model.parameters() if p.requires_grad)),
        "batch size: ",
        batch_size,
    )

    wandb_config = None
    if wandb_entity:
        # Project name is fixed so far.
        # This provides a logging interface to wandb.ai.
        wandb_config = (dict(project=wandb_project, entity=wandb_entity),)

    trainer = SupervisedTrainer(
        dataset,
        model,
        wandb_config=wandb_config,
        save_path=checkpoint_path,
        batch_size=batch_size,
        num_workers=8,
        continue_training=continue_training,
        image_size=image_size,
        batch_sampler_kwargs=dict(
            image_scale_factor=image_scale,
            inflate_dataset_factor=1000,
            augmentation_per_image=False,
        ),
        test_set_evaluator=evaluator,
        eval_test_set_every_n_samples=2000,
        save_every_n_samples=200000,
        max_lr=max_lr,
        batches_to_reach_maximum_augmentation=1000,
    )

    trainer.run_epochs(epochs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--index-path",
        type=str,
        default="./ground_truth_wdd_angles.pickle",
    )
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default="./wdd_filtering_supervised_model.pt",
    )
    parser.add_argument("--remap-wdd-dir", type=str, default="")
    parser.add_argument("--continue-training", action="store_true")
    parser.add_argument("--images-in-archives", action="store_true")
    parser.add_argument("--multi-gpu", action="store_true")
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--batch-size", type=float, default=0.002 * 8)
    parser.add_argument("--max-lr", default="auto")
    parser.add_argument("--wandb-entity", type=str, default="d_d")
    args = parser.parse_args()

    continue_training = args.continue_training
    if continue_training and args.checkpoint_path:
        if not os.path.exists(args.checkpoint_path):
            print("Can not continue training, as no file found at checkpoint location.")
            continue_training = False

    run(gt_data_path='./ground_truth_wdd_angles.pickle', epochs=2)

    #run(
    #    gt_data_path=args.index_path,
    #    checkpoint_path=args.checkpoint_path,
    #    epochs=args.epochs,
    #    continue_training=continue_training,
    #    remap_wdd_dir=args.remap_wdd_dir,
    #    images_in_archives=args.images_in_archives,
    #    multi_gpu=args.multi_gpu,
    #    batch_size=args.batch_size,
    #    max_lr=args.max_lr,
    #    wandb_entity=args.wandb_entity,
    #)

   # run(
   #     # gt_data_path (string) - Path to the .pickle file containing the ground-truth labels and paths.
   #     gt_data_path = "./ground_truth_wdd_angles.pickle",
#
   #     # remap_wdd_dir (string, optional) -- Prefix of the path where the image data is saved. The paths in gt_data_path will be changed to point to this directory instead.
   #     remap_wdd_dir = "",
   #     
   #     # images_in_archives (bool) - Whether the images of the single waggle frames are saved withing an images.zip file in each WDD subdirectory.
   #     images_in_archives = True,
#
   #     # checkpoint_path (string, optional) - Filename to which the model will be saved regularly during training. The model will be saved on every epoch AND every X batches.
   #     
   #     
   #     # continue_training (bool) - Whether to try to continue training from last checkpoint. Will use the same wandb run ID. Auto set to "false" in case no checkpoint is found.
   #     continue_training = False,
   #     
   #     # epochs (int) - Number of epochs to train for. As the model is saved after every epoch in 'checkpoint_path' and as the logs are streamed live to wandb.ai, it's save to interrupt the training after any epoch.
   #     epochs = 2,
#
   #     batch_size = 0.002 * 8,
   #     
   #     # image_size (int) - Width and height of images that are passed to the model.
   #     # image_size = ,
   #     
   #     # image_scale (float) - Scale factor for the data. E.g. 0.5 will scale the images to half resolution. That allows for a wider FoV for the model by sacrificing some resolution.
   #     image_scale = 0.5 ,
#
   #     # max_lr (float) - The training uses a learning rate scheduler (OneCycleLR) for each epoch where max_lr constitutes the peak learning rate.
   #     max_lr = 0.1
#
   #     # wandb_entity (string, optional) - User name for wandb.ai that the training will log data to.
   #     
   #     # wandb_project (string) - Project name for wandb.ai.
   # )


usage: ipykernel_launcher.py [-h] [--index-path INDEX_PATH]
                             [--checkpoint-path CHECKPOINT_PATH]
                             [--remap-wdd-dir REMAP_WDD_DIR]
                             [--continue-training] [--images-in-archives]
                             [--multi-gpu] [--epochs EPOCHS]
                             [--batch-size BATCH_SIZE] [--max-lr MAX_LR]
                             [--wandb-entity WANDB_ENTITY]
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-4bfd9151-c6d2-4797-bcf2-3d9c0e0f0843.json


SystemExit: ignored

In [None]:
!pip install git+https://github.com/BioroboticsLab/bb_wdd_filter.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/BioroboticsLab/bb_wdd_filter.git
  Cloning https://github.com/BioroboticsLab/bb_wdd_filter.git to /tmp/pip-req-build-zoypgyqp
  Running command git clone --filter=blob:none --quiet https://github.com/BioroboticsLab/bb_wdd_filter.git /tmp/pip-req-build-zoypgyqp
  Resolved https://github.com/BioroboticsLab/bb_wdd_filter.git to commit 1010371f08ee8608469c60e2ffcb6ba7a463d3c7
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pre-commit
  Downloading pre_commit-3.3.1-py2.py3-none-any.whl (202 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m202.5/202.5 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting black
  Downloading black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m35.1 MB/s[0m eta [36m0:00:00[0

In [None]:
pip list

Package                       Version
----------------------------- --------------------
absl-py                       1.4.0
alabaster                     0.7.13
albumentations                1.2.1
altair                        4.2.2
anyio                         3.6.2
appdirs                       1.4.4
argon2-cffi                   21.3.0
argon2-cffi-bindings          21.2.0
arviz                         0.15.1
astropy                       5.2.2
astunparse                    1.6.3
attrs                         23.1.0
audioread                     3.0.0
autograd                      1.5
Babel                         2.12.1
backcall                      0.2.0
beautifulsoup4                4.11.2
bleach                        6.0.0
blis                          0.7.9
blosc2                        2.0.0
bokeh                         2.4.3
branca                        0.6.0
CacheControl                  0.12.11
cached-property               1.5.2
cachetools                    5.3.0
cata

In [None]:
# load data from onedrive (ALTERNATIVE: CALL wget so that jupyter downloads it automatically)
# remap_wdd_dir = '/content/drive/MyDrive/wdd_ground_truth'


In [None]:
from matplotlib import get_data_path
### train_model.py

import argparse

import pickle
import numpy as np
import os
import torch.nn

import bb_wdd_filter.dataset
import bb_wdd_filter.models_supervised
import bb_wdd_filter.trainer_supervised
import bb_wdd_filter.visualization


def run(
    gt_data_path,
    checkpoint_path=None,
    continue_training=True,
    epochs=1000,
    remap_wdd_dir=None,
    image_size=32,
    images_in_archives=True,
    multi_gpu=False,
    image_scale=0.5,
    batch_size="auto",
    max_lr=0.002 * 8,
    wandb_entity=None,
    wandb_project="wdd-image-classification",
):
    """
    Arguments:
        gt_data_path (string)
            Path to the .pickle file containing the ground-truth labels and paths.
        remap_wdd_dir (string, optional)
            Prefix of the path where the image data is saved. The paths in gt_data_path
            will be changed to point to this directory instead.
        images_in_archives (bool)
            Whether the images of the single waggle frames are saved withing an images.zip
            file in each WDD subdirectory.
        checkpoint_path (string, optional)
            Filename to which the model will be saved regularly during training.
            The model will be saved on every epoch AND every X batches.
        continue_training (bool)
            Whether to try to continue training from last checkpoint. Will use the same
            wandb run ID. Auto set to "false" in case no checkpoint is found.
        epochs (int)
            Number of epochs to train for.
            As the model is saved after every epoch in 'checkpoint_path' and as the logs are
            streamed live to wandb.ai, it's save to interrupt the training after any epoch.
        image_size (int)
            Width and height of images that are passed to the model.
        image_scale (float)
            Scale factor for the data. E.g. 0.5 will scale the images to half resolution.
            That allows for a wider FoV for the model by sacrificing some resolution.
        max_lr (float)
            The training uses a learning rate scheduler (OneCycleLR) for each epoch
            where max_lr constitutes the peak learning rate.
        wandb_entity (string, optional)
            User name for wandb.ai that the training will log data to.
        wandb_project (string)
            Project name for wandb.ai.

    """

    with open(gt_data_path, "rb") as f:
        wdd_gt_data = pickle.load(f)
        gt_data_df = [(key,) + v for key, v in wdd_gt_data.items()]

    all_indices = np.arange(len(gt_data_df))
    test_indices = all_indices[::10]
    train_indices = [idx for idx in all_indices if not (idx in test_indices)]

    print("Train set:")
    dataset = SupervisedDataset(
        [gt_data_df[idx] for idx in train_indices],
        images_in_archives=images_in_archives,
        image_size=image_size,
        load_wdd_vectors=True,
        load_wdd_durations=True,
        remap_paths_to=remap_wdd_dir,
    )

    print("Test set:")
    # The evaluator's job is to regularly evaluate the training progress on the test dataset.
    # It will calculate additional statistics that are logged over the wandb connection.
    evaluator = SupervisedValidationDatasetEvaluator(
        [gt_data_df[idx] for idx in test_indices],
        images_in_archives=images_in_archives,
        image_size=image_size,
        remap_paths_to=remap_wdd_dir,
        default_image_scale=image_scale,
    )

    model = WDDClassificationModel(
        image_size=image_size
    )

    if multi_gpu:
        model = torch.nn.DataParallel(model)

    model = model.cuda()

    if batch_size == "auto":
        # The batch size here is calculated so that it fits on two RTX 2080 Ti in multi-GPU mode.
        # Note that a smaller batch size might also need a smaller learning rate.
        factor = 1
        if multi_gpu:
            factor = 2
        batch_size = int((64 * 7 * factor) / ((image_size * image_size) / (32 * 32)))
    else:
        batch_size = int(batch_size)

    print(
        "N pars: ",
        str(sum(p.numel() for p in model.parameters() if p.requires_grad)),
        "batch size: ",
        batch_size,
    )

    wandb_config = None
    if wandb_entity:
        # Project name is fixed so far.
        # This provides a logging interface to wandb.ai.
        wandb_config = (dict(project=wandb_project, entity=wandb_entity),)

    trainer = SupervisedTrainer(
        dataset,
        model,
        wandb_config=wandb_config,
        save_path=checkpoint_path,
        batch_size=batch_size,
        num_workers=8,
        continue_training=continue_training,
        image_size=image_size,
        batch_sampler_kwargs=dict(
            image_scale_factor=image_scale,
            inflate_dataset_factor=1000,
            augmentation_per_image=False,
        ),
        test_set_evaluator=evaluator,
        eval_test_set_every_n_samples=2000,
        save_every_n_samples=200000,
        max_lr=max_lr,
        batches_to_reach_maximum_augmentation=1000,
    )

    trainer.run_epochs(epochs)


In [None]:
# Einfach: Github fork und von da installieren und die prints dort reinhauen 
# oder
# in Colab das Modell tunen

run(
  gt_data_path=    './ground_truth_wdd_angles.pickle', 
  epochs=2, 
  wandb_entity=None,
  remap_wdd_dir =  '/content/drive/wdd_ground_truth',
  checkpoint_path = '/content/drive/',
  images_in_archives = False
)


    #run(
    #    gt_data_path=args.index_path,
    #    checkpoint_path=args.checkpoint_path,
    #    epochs=args.epochs,
    #    continue_training=continue_training,
    #    remap_wdd_dir=args.remap_wdd_dir,
    #    images_in_archives=args.images_in_archives,
    #    multi_gpu=args.multi_gpu,
    #    batch_size=args.batch_size,
    #    max_lr=args.max_lr,
    #    wandb_entity=args.wandb_entity,
    #)
