In [1]:
import sys

sys.path.insert(0, "..")

from dinov2.data.loaders import make_dataset
import matplotlib.pyplot as plt
from dinov2.data import DataAugmentationDINO

import numpy as np
import torch

from dinov2.data.masking import MaskingGenerator
from dinov2.data.collate import collate_data_and_cast
from dinov2.data import make_data_loader

from functools import partial

from dinov2.models.vision_transformer import DinoVisionTransformer

from dinov2.layers import (
    MemEffAttention,
    Mlp,
    PatchEmbed,
    SwiGLUFFNFused,
)
from dinov2.layers import (
    NestedTensorBlock as Block,
)


import os

# Root directory of the LMDB dataset. The default is the original workstation
# path; set PLANKTON_LMDB_ROOT to run the notebook on another machine without
# editing this cell.
root = os.environ.get(
    "PLANKTON_LMDB_ROOT", "/home/jluesch/Documents/data/plankton/nat_lmdb/"
)
# Keep the trailing separator that the default has, whatever form the override
# was given in (joining with "" appends a separator only when one is missing).
root = os.path.join(root, "")

# Dataset descriptor string that is handed to make_dataset.
ds_path = f"LMDBDataset:split=TRAIN:root={root}:extra=*"



In [2]:
# The first two arguments of DataAugmentationDINO are passed positionally.
# NOTE(review): presumably the global / local crop scale ranges - confirm
# against the DataAugmentationDINO signature.
global_scale = [0.32, 0.8]
local_scale = [0.05, 0.32]

# Keyword options for the augmentation. Both crop sizes are whole multiples of
# the patch size (224 = 16 * 14, 98 = 7 * 14).
augmentation_kwargs = dict(
    local_crops_number=6,
    global_crops_size=224,
    local_crops_size=98,
    use_kornia=True,
    use_native_res=False,
    do_seg_crops=False,
    patch_size=14,
)

data_transform_cpu = DataAugmentationDINO(
    global_scale, local_scale, **augmentation_kwargs
)

def _discard_target(_target):
    """Ignore the incoming target and return an empty tuple instead."""
    return ()


# Build the dataset from the descriptor string, applying the augmentation to
# each sample and replacing every target with an empty tuple.
dataset_options = {
    "dataset_str": ds_path,
    "transform": data_transform_cpu,
    "target_transform": _discard_target,
    "with_targets": True,
    "cache_dataset": False,
}
dataset = make_dataset(**dataset_options)

Dataset kwargs {'split': <_Split.TRAIN: 'train'>, 'root': '/home/jluesch/Documents/data/plankton/nat_lmdb/', 'extra': '*'}
extra_path /home/jluesch/Documents/data/plankton/nat_lmdb/*-TRAIN_*
Datasets labels file list:  ['/home/jluesch/Documents/data/plankton/nat_lmdb/2007-TRAIN_labels', '/home/jluesch/Documents/data/plankton/nat_lmdb/2008-TRAIN_labels', '/home/jluesch/Documents/data/plankton/nat_lmdb/2009-TRAIN_labels', '/home/jluesch/Documents/data/plankton/nat_lmdb/2010-TRAIN_labels', '/home/jluesch/Documents/data/plankton/nat_lmdb/2011-TRAIN_labels', '/home/jluesch/Documents/data/plankton/nat_lmdb/2012-TRAIN_labels', '/home/jluesch/Documents/data/plankton/nat_lmdb/2013-TRAIN_labels', '/home/jluesch/Documents/data/plankton/nat_lmdb/2014-TRAIN_labels']
Datasets imgs file list:  ['/home/jluesch/Documents/data/plankton/nat_lmdb/2007-TRAIN_imgs', '/home/jluesch/Documents/data/plankton/nat_lmdb/2008-TRAIN_imgs', '/home/jluesch/Documents/data/plankton/nat_lmdb/2009-TRAIN_imgs', '/home/jlue

In [3]:

def list_collate(batch):
    """Collate (data, target) pairs without stacking the data.

    Args:
        batch: iterable of indexable items, where ``item[0]`` is the data and
            ``item[1]`` is the target.

    Returns:
        A two-element list ``[data, target]``: ``data`` is a plain Python list
        of the first entries, ``target`` is a ``torch.LongTensor`` built from
        the second entries.
    """
    samples, labels = [], []
    for entry in batch:
        samples.append(entry[0])
        labels.append(entry[1])
    return [samples, torch.LongTensor(labels)]

# Options for the training data loader. No custom collate function is passed;
# list_collate (defined in this cell) is currently not plugged in.
loader_options = dict(
    dataset=dataset,
    batch_size=16,
    num_workers=8,
    shuffle=True,
    seed=0,
    sampler_type=None,
    sampler_advance=0,
    drop_last=True,
    collate_fn=None,
)
data_loader = make_data_loader(**loader_options)

img_size = 224
patch_size = 14
# Number of patches along one side of a square img_size crop (224 // 14 = 16).
grid_size = img_size // patch_size
mask_generator = MaskingGenerator(
    input_size=(grid_size, grid_size),
    # Upper bound: half of the patch grid (0.5 * 16 * 16 = 128.0 here).
    # The former spelling `0.5 * img_size // patch_size * img_size // patch_size`
    # is evaluated left to right, so it floors 0.5 * img_size / patch_size first
    # and only equals half the grid when the grid side is even (e.g. 210 / 14
    # gave 105.0 instead of 112.5). Using the grid size makes the intent explicit.
    max_num_patches=0.5 * grid_size * grid_size,
)

In [4]:
# Block factory passed as block_fn, with the attention class pre-bound.
attention_block = partial(Block, attn_class=MemEffAttention)

# Constructor arguments for the vision transformer, collected in one place.
vit_kwargs = {
    "patch_size": 14,
    "embed_dim": 384,
    "depth": 12,
    "num_heads": 6,
    "mlp_ratio": 4,
    "block_fn": attention_block,
    "num_register_tokens": 0,
    "img_size": 224,
    "in_chans": 3,
    "drop_path_rate": 0.0,
    "drop_path_uniform": True,
    # for layerscale: None or 0 => no layerscale
    "init_values": 1.0e-05,
    "embed_layer": PatchEmbed,
    "ffn_layer": "mlp",
    "block_chunks": 1,
    "interpolate_antialias": False,
    "interpolate_offset": 0.1,
}
model = DinoVisionTransformer(**vit_kwargs)

In [5]:
# Smoke test of the data pipeline: push a few batches through the collate
# function and the model. `out` keeps the model output of the last processed
# batch so it can be inspected in the following cell.
max_batches = 3

out = None
torch.cuda.empty_cache()
# Moving the model to the GPU does not depend on the batch, so do it once
# instead of on every iteration.
model.cuda()
for batch_idx, batch in enumerate(data_loader):
    # Element 0 of a loader item is a dict of already batched crop tensors
    # (it is indexed by key and each entry exposes a .shape).
    # NOTE(review): element 1 is presumably the target - confirm against the dataset.
    crops = batch[0]
    print(crops["global_crops"].shape)
    print(crops["local_crops"].shape)
    collated_dict = collate_data_and_cast(
        crops,
        mask_ratio_tuple=(0.1, 0.5),
        mask_probability=0.5,
        dtype=torch.half,
        n_tokens=200,
        mask_generator=mask_generator,
        free_shapes=True,
    )

    # torch.autocast(device_type="cuda") is the current spelling of the
    # deprecated torch.cuda.amp.autocast(); no gradients are needed here.
    with torch.autocast(device_type="cuda"), torch.no_grad():
        out = model(
            [
                collated_dict["collated_global_crops"].cuda(),
                collated_dict["collated_local_crops"].cuda(),
            ],
            # Masks are only supplied for the global crops.
            masks=[collated_dict["collated_masks"].cuda(), None],
            is_training=True,
            # attn_masks=[collated_dict["attn_mask_gc"].cuda(),collated_dict["attn_mask_lc"].cuda()]
        )
        print(out[0]["x_norm_clstoken"].shape)

    # Stop right after the last wanted batch. Checking at the top of the next
    # iteration would make the loader produce one more batch that is never used.
    if batch_idx + 1 >= max_batches:
        break

torch.Size([16, 2, 3, 224, 224])
torch.Size([16, 6, 3, 92, 92])
collated_masks torch.Size([32, 256])


AssertionError: Input image height 92 is not a multiple of patch height 14

In [None]:
# Print the key and tensor shape of every entry that is not None, for each
# output dict held in `out`.
for output in out:
    for key, value in output.items():
        if value is not None:
            print(key, value.shape)

x_norm_clstoken torch.Size([32, 384])
x_norm_regtokens torch.Size([32, 0, 384])
x_norm_patchtokens torch.Size([32, 288, 384])
x_prenorm torch.Size([32, 289, 384])
masks torch.Size([32, 288])
x_norm_clstoken torch.Size([16, 384])
x_norm_regtokens torch.Size([16, 0, 384])
x_norm_patchtokens torch.Size([16, 288, 384])
x_prenorm torch.Size([16, 289, 384])
