In [None]:
# Clone the mildlyoverfitted repository into the working directory.
# NOTE(review): no visible cell imports from this checkout - confirm it is still needed.
!git clone https://github.com/jankrepl/mildlyoverfitted.git

Cloning into 'mildlyoverfitted'...
remote: Enumerating objects: 356, done.[K
remote: Counting objects: 100% (97/97), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 356 (delta 71), reused 52 (delta 46), pack-reused 259[K
Receiving objects: 100% (356/356), 815.38 KiB | 3.85 MiB/s, done.
Resolving deltas: 100% (125/125), done.


In [None]:
# Clone the DINO_Tf2.x repository into the working directory.
# NOTE(review): no visible cell imports from this checkout - confirm it is still needed.
!git clone https://github.com/TanyaChutani/DINO_Tf2.x.git

In [None]:
import numpy as np
import torch
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier


def compute_knn(backbone, data_loader_train, data_loader_val):
    """Get CLS embeddings and use KNN classifier on them.

    All embeddings are collected in memory as NumPy arrays and handed to
    scikit-learn, so both datasets have to fit in RAM.

    Parameters
    ----------
    backbone : timm.models.vision_transformer.VisionTransformer
        Vision transformer whose head is just an identity
        mapping.

    data_loader_train, data_loader_val : torch.utils.data.DataLoader
        Training and validation dataloader that does not apply any
        augmentations. Just casting to tensor and then normalizing.

    Returns
    -------
    val_accuracy : float
        Validation accuracy.
    """
    # Run on whatever device the backbone's parameters live on.
    device = next(backbone.parameters()).device

    data_loaders = {
        "train": data_loader_train,
        "val": data_loader_val,
    }
    lists = {
        "X_train": [],
        "y_train": [],
        "X_val": [],
        "y_val": [],
    }

    # Pure inference: disabling autograd avoids building a graph for every
    # batch, which would otherwise cost memory and time for nothing.
    with torch.no_grad():
        for name, data_loader in data_loaders.items():
            for imgs, y in data_loader:
                imgs = imgs.to(device)
                lists[f"X_{name}"].append(backbone(imgs).detach().cpu().numpy())
                lists[f"y_{name}"].append(y.detach().cpu().numpy())

    # One array per split, batches concatenated along the sample axis.
    arrays = {k: np.concatenate(l) for k, l in lists.items()}

    estimator = KNeighborsClassifier()
    estimator.fit(arrays["X_train"], arrays["y_train"])
    y_val_pred = estimator.predict(arrays["X_val"])

    acc = accuracy_score(arrays["y_val"], y_val_pred)

    return acc

def compute_embedding(backbone, data_loader):
    """Compute CLS embedding and prepare for TensorBoard.

    Parameters
    ----------
    backbone : timm.models.vision_transformer.VisionTransformer
        Vision transformer. The head should be an identity mapping.

    data_loader : torch.utils.data.DataLoader
        Validation dataloader that does not apply any augmentations. Just
        casting to tensor and then normalizing. Its ``dataset`` must expose
        a ``classes`` sequence that is indexed with the integer labels.

    Returns
    -------
    embs : torch.Tensor
        Embeddings of shape `(n_samples, out_dim)`.

    imgs : torch.Tensor
        Images of shape `(n_samples, 3, height, width)`.

    labels : list
        List of strings representing the classes.
    """
    # Run on whatever device the backbone's parameters live on.
    device = next(backbone.parameters()).device

    embs_l = []
    imgs_l = []
    labels = []

    # Pure inference: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for img, y in data_loader:
            img = img.to(device)
            embs_l.append(backbone(img).detach().cpu())
            # Undo the normalization with one scalar std/mean for all
            # channels. NOTE(review): this is only an approximate inverse if
            # the loader normalizes per channel - confirm that is acceptable
            # for visualisation purposes.
            imgs_l.append(((img * 0.224) + 0.45).cpu())
            labels.extend([data_loader.dataset.classes[i] for i in y.tolist()])

    embs = torch.cat(embs_l, dim=0)
    imgs = torch.cat(imgs_l, dim=0)

    return embs, imgs, labels

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image


class DataAugmentation:
    """Produce several randomly augmented views (crops) of one image.

    Every call yields 2 global crops followed by `n_local_crops` local crops.

    Parameters
    ----------
    global_crops_scale : tuple
        Range of sizes for the global crops.

    local_crops_scale : tuple
        Range of sizes for the local crops.

    n_local_crops : int
        Number of local crops to create.

    size : int
        The size of the final image (used for global and local crops alike).

    Attributes
    ----------
    global_1, global_2 : transforms.Compose
        The two global pipelines; they differ in blur probability and
        `global_2` additionally applies random solarization.

    local : transforms.Compose
        Local pipeline. It is stochastic, so a single instance applied
        repeatedly gives different crops.
    """

    def __init__(
        self,
        global_crops_scale=(0.4, 1),
        local_crops_scale=(0.05, 0.4),
        n_local_crops=8,
        size=224,
    ):
        self.n_local_crops = n_local_crops

        def blur(prob):
            # Gaussian blur that fires with probability `prob`.
            return transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2))],
                p=prob,
            )

        def resized_crop(scale):
            # Random crop covering `scale` of the area, resized to `size`.
            return transforms.RandomResizedCrop(
                size,
                scale=scale,
                interpolation=Image.BICUBIC,
            )

        color_jitter = transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.2,
            hue=0.1,
        )
        # A single instance is shared by all three pipelines below.
        flip_and_jitter = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([color_jitter]),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        # PIL image -> float tensor -> per-channel standardisation.
        to_normalized_tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

        self.global_1 = transforms.Compose(
            [
                resized_crop(global_crops_scale),
                flip_and_jitter,
                blur(1.0),  # always apply
                to_normalized_tensor,
            ],
        )

        self.global_2 = transforms.Compose(
            [
                resized_crop(global_crops_scale),
                flip_and_jitter,
                blur(0.1),
                transforms.RandomSolarize(170, p=0.2),
                to_normalized_tensor,
            ],
        )

        self.local = transforms.Compose(
            [
                resized_crop(local_crops_scale),
                flip_and_jitter,
                blur(0.5),
                to_normalized_tensor,
            ],
        )

    def __call__(self, img):
        """Apply transformation.

        Parameters
        ----------
        img : PIL.Image
            Input image.

        Returns
        -------
        all_crops : list
            List of `torch.Tensor` views of `img`: the two global crops
            first, then the local ones.
        """
        global_views = [self.global_1(img), self.global_2(img)]
        local_views = [self.local(img) for _ in range(self.n_local_crops)]

        return global_views + local_views


class Head(nn.Module):
    """Projection network applied to the CLS token embedding.

    An MLP, followed by L2 normalization and a weight-normalized linear
    layer without bias.

    Parameters
    ----------
    in_dim : int
        The dimensionality of the token embedding.

    out_dim : int
        The dimensionality of the final layer (we compute the softmax over).

    hidden_dim : int
        Dimensionality of the hidden layers.

    bottleneck_dim : int
        Dimensionality of the second last layer.

    n_layers : int
        The number of layers. With `n_layers == 1` the MLP collapses to a
        single linear map from `in_dim` to `bottleneck_dim`.

    norm_last_layer : bool
        If True, then we freeze the norm of the weight of the last linear layer
        to 1.

    Attributes
    ----------
    mlp : nn.Sequential or nn.Linear
        Vanilla multi-layer perceptron.

    last_layer : nn.Linear
        Reparametrized linear layer with weight normalization. That means
        that it will have `weight_g` and `weight_v` as learnable
        parameters instead of a single `weight`.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim=512,
        bottleneck_dim=256,
        n_layers=3,
        norm_last_layer=False,
    ):
        super().__init__()
        if n_layers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            stack = [nn.Linear(in_dim, hidden_dim), nn.GELU()]
            for _ in range(n_layers - 2):
                stack += [nn.Linear(hidden_dim, hidden_dim), nn.GELU()]
            stack.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*stack)

        # Custom init covers only the MLP: the last layer does not exist yet.
        self.apply(self._init_weights)

        projection = nn.Linear(bottleneck_dim, out_dim, bias=False)
        self.last_layer = nn.utils.weight_norm(projection)
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Initialize learnable parameters of linear layers."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Of shape `(n_samples, in_dim)`.

        Returns
        -------
        torch.Tensor
            Of shape `(n_samples, out_dim)`.
        """
        bottleneck = self.mlp(x)  # (n_samples, bottleneck_dim)
        unit_norm = F.normalize(bottleneck, dim=-1, p=2)  # same shape, L2 norm 1

        return self.last_layer(unit_norm)  # (n_samples, out_dim)


class MultiCropWrapper(nn.Module):
    """Run a backbone plus a new head over a list of crops in one pass.

    Parameters
    ----------
    backbone : timm.models.vision_transformer.VisionTransformer
        Instantiated Vision Transformer. Its `head` attribute is replaced
        in place with `nn.Identity`.

    new_head : Head
        New head that is going to be put on top of the `backbone`.
    """

    def __init__(self, backbone, new_head):
        super().__init__()
        # Deactivate the original head (this mutates the passed backbone).
        backbone.head = nn.Identity()
        self.backbone = backbone
        self.new_head = new_head

    def forward(self, x):
        """Run the forward pass.

        The crops are concatenated along the batch dimension, pushed
        through backbone and head in a single forward pass, and the result
        is chunked back into one tensor per crop.

        Parameters
        ----------
        x : list
            List of `torch.Tensor` each of shape `(n_samples, 3, size, size)`.

        Returns
        -------
        tuple
            Tuple of `torch.Tensor` each of shape `(n_samples, out_dim)` where
            `out_dim` is determined by `Head`.
        """
        stacked = torch.cat(x, dim=0)  # (n_samples * n_crops, 3, size, size)
        logits = self.new_head(self.backbone(stacked))  # (n_samples * n_crops, out_dim)

        return logits.chunk(len(x))  # n_crops * (n_samples, out_dim)


class Loss(nn.Module):
    """The loss function.

    It is an `nn.Module` because the running center of the teacher logits
    is stored as a buffer.

    Parameters
    ----------
    out_dim : int
        The dimensionality of the final layer (we computed the softmax over).

    teacher_temp, student_temp : float
        Softmax temperature of the teacher resp. student.

    center_momentum : float
        Hyperparameter for the exponential moving average that determines
        the center logits. The higher the more the running average matters.
    """

    def __init__(
        self, out_dim, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9
    ):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output, teacher_output):
        """Evaluate loss.

        Parameters
        ----------
        student_output, teacher_output : tuple
            Tuple of tensors of shape `(n_samples, out_dim)` representing
            logits. The length is equal to number of crops.
            Note that student processed all crops and that the two initial crops
            are the global ones.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the average loss.
        """
        # Student: temperature-scaled log-probabilities.
        student_log_probs = [
            F.log_softmax(s / self.student_temp, dim=-1) for s in student_output
        ]
        # Teacher: centered, sharpened probabilities; no gradient flows back.
        teacher_probs = [
            F.softmax((t - self.center) / self.teacher_temp, dim=-1).detach()
            for t in teacher_output
        ]

        # Cross-entropy for every (teacher crop, student crop) pair, skipping
        # pairs that share the same crop index.
        pair_losses = [
            torch.sum(-probs * log_probs, dim=-1).mean()  # scalar
            for t_ix, probs in enumerate(teacher_probs)
            for s_ix, log_probs in enumerate(student_log_probs)
            if t_ix != s_ix
        ]

        total_loss = sum(pair_losses) / len(pair_losses)
        self.update_center(teacher_output)

        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Update center used for teacher output.

        Compute the exponential moving average.

        Parameters
        ----------
        teacher_output : tuple
            Tuple of tensors of shape `(n_samples, out_dim)` where each
            tensor represents a different crop.
        """
        momentum = self.center_momentum
        batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)  # (1, out_dim)
        self.center = self.center * momentum + batch_center * (1 - momentum)

def clip_gradients(model, clip=2.0):
    """Rescale the gradient of each parameter to have norm at most `clip`.

    The clipping is applied per parameter tensor, in place; parameters
    without a gradient are left alone.

    Parameters
    ----------
    model : nn.Module
        Module.

    clip : float
        Maximum norm.
    """
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue

        # The small epsilon guards against division by a zero norm.
        scale = clip / (grad.data.norm(2) + 1e-6)
        if scale < 1:
            grad.data.mul_(scale)

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# Pinned to the version recorded in the install log of the last run.
%pip install -q timm==0.9.16

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.16


In [None]:
import tensorflow as tf


class DataAugmentationDino:
    """Multi-crop augmentation for image tensors (TensorFlow).

    Calling an instance returns two "global" views followed by
    `local_crops_number` "local" views of the input image. Each view is
    randomly flipped, cropped, resized and standardized.

    Parameters
    ----------
    global_crops_scale, local_crops_scale : tuple
        Pair of factors multiplied element-wise with the image
        `(height, width)` to obtain the crop size (see `_crop_resize`).

    local_crops_number : int
        Number of local views to produce per image.

    global_image_size, local_image_size : list
        `[height, width]` every global resp. local view is resized to.

    mean, std_dev : list
        Per-channel statistics used to standardize the image after it has
        been scaled by 1/255.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_image_size=[224, 224],
        local_image_size=[96, 96],
        mean=[0.485, 0.456, 0.406],
        std_dev=[0.229, 0.224, 0.225],
    ):
        # NOTE: the list defaults are shared between instances; they are only
        # read here, never mutated.
        self.mean = mean
        self.std_dev = std_dev
        self.local_image_size = local_image_size
        self.global_image_size = global_image_size
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_scale = global_crops_scale

        self.flip_aug = tf.keras.Sequential(
            [tf.keras.layers.RandomFlip(mode="horizontal")]
        )

    def _standardize_normalize(self, image):
        """Scale by 1/255, then standardize with `mean` / `std_dev`."""
        image = image / 255.0
        image -= self.mean
        image /= self.std_dev
        image = tf.cast(image, tf.float32)
        return image

    @staticmethod
    def _color_jitter(image):
        """Randomly perturb brightness, contrast, saturation and hue.

        Declared as a static method: the original definition took no `self`,
        so calling it on an instance raised a `TypeError`. It is not used by
        `_apply_aug`.
        """
        # NOTE(review): contrast/saturation are drawn from [0, 0.4] / [0, 0.2],
        # which mostly *reduces* them; a jitter centred on 1 may have been
        # intended - confirm before wiring this into `_apply_aug`.
        image = tf.image.random_brightness(image, max_delta=0.4)
        image = tf.image.random_contrast(image, lower=0.0, upper=0.4)
        image = tf.image.random_saturation(image, lower=0.0, upper=0.2)
        image = tf.image.random_hue(image, max_delta=0.1)
        return image

    def _crop_resize(self, image, mode="global"):
        """Take a random crop and resize it to the size configured for `mode`.

        Parameters
        ----------
        image : tf.Tensor
            Image whose first two axes are height and width.

        mode : str
            `"global"` selects the global scale/size, anything else the
            local ones.
        """
        is_global = mode == "global"
        scalee = self.global_crops_scale if is_global else self.local_crops_scale
        final_size = self.global_image_size if is_global else self.local_image_size

        image_shape = tf.shape(image)
        # Slicing the shape vector replaces `tf.concat` on two scalars (which
        # raises: scalars cannot be concatenated) and avoids `.numpy()`, so
        # this also works outside eager mode.
        scaling_hw = tf.cast(image_shape[:2], tf.float32)
        # NOTE(review): the crop is always (scale[0] * height, scale[1] * width);
        # the scale pair is not sampled as a range. Kept as in the original -
        # confirm whether a random scale was intended.
        crop_hw = tf.cast(
            tf.convert_to_tensor(scalee, dtype=tf.float32) * scaling_hw, tf.int32
        )
        crop_size = tf.concat([crop_hw, image_shape[2:]], axis=0)

        image = tf.image.random_crop(value=image, size=crop_size)
        image = tf.image.resize(image, final_size, method="bicubic")
        return image

    def _apply_aug(self, image, mode="global"):
        """Flip, crop/resize and standardize one view."""
        image = self.flip_aug(image)
        image = self._crop_resize(image, mode)
        image = self._standardize_normalize(image)

        return image

    def __call__(self, image):
        """Return `[global, global, local, ...]` views of `image` as a list."""
        crops = []
        crops.append(self._apply_aug(image))
        crops.append(self._apply_aug(image))
        for _ in range(self.local_crops_number):
            crops.append(self._apply_aug(image, mode="local"))
        return crops

In [None]:
import tensorflow as tf
import os

class DataGenerator(tf.keras.utils.Sequence):
    """Keras `Sequence` yielding batches of DINO crops read from a directory.

    Parameters
    ----------
    mode : str
        Augmentation is applied only when this equals `"train"`.

    batch_size : int
        Number of images per batch; a trailing partial batch is dropped.

    dataset_path : str
        Directory whose entries are treated as image files.

    local_image_size, global_image_size : int or sequence of two ints
        Output size of the local resp. global crops. An int `s` is treated
        as `[s, s]`.

    shuffle : bool
        Whether the sample order is reshuffled at the end of every epoch.
    """

    def __init__(
        self,
        mode,
        batch_size,
        dataset_path,
        local_image_size,
        global_image_size,
        shuffle=True,
    ):
        super().__init__()
        self.mode = mode
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.dataset_path = dataset_path
        # Sorted so that the un-shuffled order does not depend on the
        # arbitrary order returned by the file system.
        self.dataset = sorted(os.listdir(dataset_path))
        self.local_image_size = local_image_size
        self.global_image_size = global_image_size
        # Built once instead of once per batch, and with the configured sizes
        # (they were previously stored but never used).
        self._augmentation = DataAugmentationDino(
            (0.4, 1.0),
            (0.05, 0.4),
            8,
            global_image_size=self._as_hw(global_image_size),
            local_image_size=self._as_hw(local_image_size),
        )
        self.on_epoch_end()

    @staticmethod
    def _as_hw(size):
        """Normalize an int or a 2-sequence to a `[height, width]` list."""
        if isinstance(size, int):
            return [size, size]
        return list(size)

    def _load_image(self, path, data_augmentation):
        """Read and decode one image; augment it in training mode."""
        image = tf.io.read_file(path)
        image = tf.image.decode_image(image, channels=3)
        image.set_shape([None, None, 3])
        image = data_augmentation(image) if self.mode == "train" else image
        return image

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when requested."""
        self.index = tf.range(len(self.dataset))
        if self.shuffle:
            # `tf.random.shuffle` returns a new tensor; the result has to be
            # kept, otherwise the order never changes.
            self.index = tf.random.shuffle(self.index)

    def __len__(self):
        """Number of full batches per epoch."""
        return len(self.dataset) // self.batch_size

    def __getitem__(self, idx):
        """Return `(global_images, local_images)` for batch number `idx`."""
        indexes = self.index[idx * self.batch_size : (idx + 1) * self.batch_size]
        datset_keys = [self.dataset[int(k)] for k in indexes.numpy()]
        (global_images, local_images) = self.__data_generation(datset_keys)
        return global_images, local_images

    def __data_generation(self, index):
        """Load the files named in `index` and split each into global/local crops.

        Returns two lists with one entry per image.
        """
        batch_global, batch_local = [], []
        for i in index:
            images = self._load_image(
                os.path.join(self.dataset_path, i), self._augmentation
            )
            # NOTE(review): when mode != "train" `images` is the raw image
            # tensor, so these slices cut image rows rather than selecting
            # crops - confirm what validation batches are meant to contain.
            global_images = images[:2]
            # Local crops differ in size from the global ones, so the two
            # groups are kept in separate lists rather than stacked together.
            local_images = images[2:]
            batch_local.append(local_images)
            batch_global.append(global_images)
        return batch_global, batch_local

In [None]:
import tensorflow as tf


class TeacherTemp(tf.keras.callbacks.Callback):
    """Callback that sets the teacher softmax temperature at each epoch start.

    Parameters
    ----------
    temp : float or tf.Variable
        Initial temperature. When a `tf.Variable` is passed it is updated in
        place, so every other holder of that variable sees the schedule.

    nepochs : int
        Total number of epochs covered by the schedule.

    teacher_temp : float
        Temperature used once the warm-up is over.

    warmup_teacher_temp : float
        Temperature at the start of the warm-up.

    warmup_teacher_temp_epochs : int
        Number of epochs over which the temperature moves linearly from
        `warmup_teacher_temp` to `teacher_temp`.
    """

    def __init__(
        self,
        temp,
        nepochs=100,
        teacher_temp=0.04,
        warmup_teacher_temp=0.04,
        warmup_teacher_temp_epochs=30,
    ):
        super(TeacherTemp, self).__init__()
        self.temp = temp
        # Linear warm-up followed by a constant plateau, one value per epoch.
        self.teacher_temp_schedule = tf.concat(
            (
                tf.linspace(
                    warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs
                ),
                tf.ones((nepochs - warmup_teacher_temp_epochs)) * teacher_temp,
            ),
            axis=0,
        )

    def on_epoch_begin(self, epoch, logs=None):
        """Write the scheduled temperature for `epoch` into `self.temp`.

        This deliberately is not a `tf.function`: callbacks run in Python,
        and a traced function may not create a variable on every call.
        """
        # Wrap a plain number once; reuse an existing variable so that the
        # update is visible to whoever shares it.
        if not isinstance(self.temp, tf.Variable):
            self.temp = tf.Variable(self.temp, trainable=False, dtype=tf.float32)

        # Clamp so that training longer than `nepochs` keeps the final value
        # instead of indexing past the end of the schedule.
        last_epoch = int(self.teacher_temp_schedule.shape[0]) - 1
        self.temp.assign(self.teacher_temp_schedule[min(int(epoch), last_epoch)])

        if logs is not None:
            logs["temp"] = float(self.temp.numpy())

In [None]:
import tensorflow as tf

class DinoLoss(tf.keras.losses.Loss):
    """Cross-entropy between centered teacher and student distributions.

    Parameters
    ----------
    nepochs : int
        Number of epochs covered by `teacher_temp_schedule`.

    out_dim : int
        Size of the last axis of the teacher/student outputs; determines the
        shape of the running center.

    ncrops : int
        Number of equal chunks the student output is split into along the
        batch axis.

    warmup_teacher_temp, teacher_temp : float
        Start and end value of the teacher temperature schedule.

    warmup_teacher_temp_epochs : int
        Length of the linear warm-up of the teacher temperature.

    student_temp : float
        Temperature the student logits are divided by.

    center_momentum : float
        Momentum of the exponential moving average of the teacher center.

    Attributes
    ----------
    center : tf.Variable
        Running mean of the teacher outputs, shape `(1, out_dim)`.

    teacher_temp : tf.Variable
        Temperature currently applied to the teacher. It starts at
        `warmup_teacher_temp`; external code (for example a callback) can
        `assign` scheduled values to it.
    """

    def __init__(
        self,
        nepochs=100,
        out_dim=65536,
        ncrops=2,
        warmup_teacher_temp=0.04,
        teacher_temp=0.04,
        warmup_teacher_temp_epochs=30,
        student_temp=0.1,
        center_momentum=0.9,
    ):
        super(DinoLoss, self).__init__()
        self.ncrops = ncrops
        self.student_temp = student_temp
        self.center_momentum = center_momentum

        self.teacher_temp_schedule = tf.concat(
            (
                tf.linspace(
                    warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs
                ),
                tf.ones((nepochs - warmup_teacher_temp_epochs)) * teacher_temp,
            ),
            axis=0,
        )
        # Held as a variable instead of instantiating a callback on every
        # call just to read a constant.
        self.teacher_temp = tf.Variable(
            warmup_teacher_temp, trainable=False, dtype=tf.float32
        )
        # Created once so that the moving average survives between calls; it
        # used to be reset to zeros at the start of every `call`.
        self.center = tf.Variable(
            tf.zeros((1, out_dim), dtype=tf.float32), trainable=False
        )

    def update_center(self, teacher_output):
        """Fold the mean of `teacher_output` into the running center."""
        # reduce_mean avoids len() on a tensor, which fails when the batch
        # size is not known statically.
        batch_center = tf.stop_gradient(
            tf.math.reduce_mean(teacher_output, axis=0, keepdims=True)
        )
        self.center.assign(
            self.center * self.center_momentum
            + batch_center * (1 - self.center_momentum)
        )

    def call(self, student_output, teacher_output):
        """Return the scalar loss and update the running center.

        Note the argument order: the student output comes first.
        """
        teacher_output = tf.cast(teacher_output, tf.float32)
        student_output = tf.cast(student_output, tf.float32)

        student_out = student_output / self.student_temp
        student_out = tf.split(student_out, num_or_size_splits=self.ncrops)

        # Centered and sharpened teacher distribution; no gradient flows back.
        teacher_out = tf.stop_gradient(
            tf.nn.softmax(
                (teacher_output - self.center) / self.teacher_temp, axis=-1
            )
        )
        # NOTE(review): the teacher rows are repeated twice and kept as ONE
        # chunk, so they only line up with a student chunk for particular
        # batch/crop counts. Kept as in the original - confirm the intended
        # pairing of teacher and student crops.
        teacher_out = tf.split(
            tf.tile(teacher_out, tf.constant([2, 1], tf.int32)), num_or_size_splits=1
        )

        total_loss = 0
        n_loss_terms = 0
        for idx, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == idx:
                    # Skip the pair with matching chunk index.
                    continue
                loss = tf.reduce_sum(
                    -q * tf.nn.log_softmax(student_out[v], axis=-1), axis=-1
                )
                total_loss += tf.math.reduce_mean(loss)
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# NOTE(review): versions are not pinned here - pin them once a known-good
# combination has been confirmed.
%pip install -q vit_keras tensorflow



In [None]:
# Pinned to the version recorded in the install log of the last run.
%pip install -q tensorflow-addons==0.23.0

Collecting tensorflow-addons
  Downloading tensorflow_addons-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (611 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m611.8/611.8 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
Collecting typeguard<3.0.0,>=2.7 (from tensorflow-addons)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, tensorflow-addons
Successfully installed tensorflow-addons-0.23.0 typeguard-2.13.3


In [None]:
import tensorflow as tf
from vit_keras import vit, utils
import tensorflow_addons as tfa


class MultiCropWrapper(tf.keras.models.Model):
    """Run a backbone over crops of possibly different sizes, then a head.

    Parameters
    ----------
    backbone : tf.keras.Model
        Feature extractor applied to each group of equally sized crops.

    head : tf.keras.layers.Layer
        Applied to the concatenated backbone outputs.

    weights : str, optional
        If given, weights are restored from this path at construction.
    """

    def __init__(self, backbone, head, weights=None):
        super(MultiCropWrapper, self).__init__()
        self.head = head
        self.backbone = backbone
        if weights:
            print("Restoring model weights from: ", weights)
            try:
                self.load_weights(weights)
            except Exception as err:
                # Same exception type as before, but keep the message and the
                # original cause instead of discarding them.
                raise ValueError(
                    f"Could not restore model weights from {weights!r}"
                ) from err

    @staticmethod
    def unique_consecutive(x):
        """Return the lengths of the runs of equal consecutive values in `x`.

        `x` is a non-empty 1-D integer tensor; for `[96, 96, 224]` the
        result is `[2, 1]`. The previous implementation compared `x` with
        itself, which is never "not equal", and therefore always returned a
        single bogus count.
        """
        x = tf.convert_to_tensor(x)
        # 1 wherever a value differs from its predecessor.
        changed = tf.cast(tf.math.not_equal(x[1:], x[:-1]), tf.int32)
        # The cumulative sum turns the change markers into a run id per element.
        run_ids = tf.concat([[0], tf.math.cumsum(changed, axis=0)], axis=0)
        _, _, count = tf.unique_with_counts(run_ids)
        return count

    def call(self, x, training=None):
        """Forward a tensor or a list of crop tensors.

        Consecutive crops with the same spatial size are concatenated along
        the batch axis and sent through the backbone together; the outputs
        of all groups are concatenated (in input order) and passed to the
        head.

        Parameters
        ----------
        x : tf.Tensor or list of tf.Tensor
            Crop batches. Axes 1 and 2 are compared to decide which crops
            share a resolution - assumes channels-last input, TODO confirm.

        training : bool, optional
            Forwarded to backbone and head.
        """
        if not isinstance(x, (list, tuple)):
            x = [x]
        x = list(x)

        # Grouping uses static shapes only, so it behaves identically in
        # eager and traced execution.
        outputs = []
        start = 0
        while start < len(x):
            resolution = tuple(x[start].shape[1:3])
            end = start + 1
            while end < len(x) and tuple(x[end].shape[1:3]) == resolution:
                end += 1

            _out = self.backbone(tf.concat(x[start:end], axis=0), training=training)
            if isinstance(_out, tuple):
                # Some backbones return (features, extras); keep the features.
                _out = _out[0]
            outputs.append(_out)
            start = end

        # Concatenating a list removes the need for a seed tensor with a
        # hard-coded feature width.
        output = tf.concat(outputs, axis=0)
        return self.head(output, training=training)


def load_base(image_size, include_pretrained=True):
    """Build the `vit_keras` ViT-B/16 backbone without its top.

    Parameters
    ----------
    image_size : int
        Passed to `vit.vit_b16` as `image_size`.

    include_pretrained : bool
        Passed to `vit.vit_b16` as `pretrained`.

    Returns
    -------
    The model returned by `vit.vit_b16`.
    """
    vit_kwargs = {
        "image_size": image_size,
        "pretrained": include_pretrained,
        "pretrained_top": False,
        "include_top": False,
    }
    return vit.vit_b16(**vit_kwargs)


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [None]:
import tensorflow as tf
import tensorflow_addons as tfa


class DinoHead(tf.keras.models.Model):
    """Projection head: MLP, L2 normalization, final dense layer.

    Parameters
    ----------
    in_dim : int
        Size of the input features.

    out_dim : int
        Number of units of the last layer.

    use_bn : bool
        Insert batch normalization after every hidden dense layer.

    norm_last_layer : bool
        Stored only. NOTE(review): the last layer is a plain `Dense`, so this
        flag currently has no effect - confirm whether weight normalization
        is still planned.

    nlayers : int
        Number of dense layers in the MLP. Values below 1 are treated as 1,
        in which case the MLP is a single dense layer.

    hidden_dim : int
        Units of the hidden dense layers.

    bottleneck_dim : int
        Units of the last MLP layer (the input of the final layer).
    """

    def __init__(
        self,
        in_dim=768,
        out_dim=65536,
        use_bn=False,
        norm_last_layer=True,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
    ):
        super(DinoHead, self).__init__()
        self.in_dim = in_dim
        self.use_bn = use_bn
        self.out_dim = out_dim
        self.nlayers = nlayers
        self.hidden_dim = hidden_dim
        self.bottleneck_dim = bottleneck_dim
        self.norm_last_layer = norm_last_layer
        self.last_layer = tf.keras.layers.Dense(self.out_dim)

        self.mlp_block = self.mlp()

    def mlp(self):
        """Build the MLP in front of the final layer.

        Layout: `nlayers - 1` blocks of Dense -> [BatchNorm] -> GELU,
        followed by a Dense bottleneck. For the default `nlayers=3` this is
        identical to the previous layout; for other values the normalization
        and activation used to end up outside the loop, giving stacked dense
        layers without non-linearity (or a doubled GELU).
        """
        if max(self.nlayers, 1) == 1:
            return tf.keras.Sequential(
                [
                    tf.keras.layers.Dense(
                        self.bottleneck_dim, input_shape=(self.in_dim,)
                    )
                ]
            )

        layer = []
        layer.append(tf.keras.layers.Dense(self.hidden_dim, input_shape=(self.in_dim,)))
        if self.use_bn:
            layer.append(tf.keras.layers.BatchNormalization())
        layer.append(tfa.layers.GELU())
        for _ in range(self.nlayers - 2):
            layer.append(tf.keras.layers.Dense(self.hidden_dim))
            if self.use_bn:
                layer.append(tf.keras.layers.BatchNormalization())
            layer.append(tfa.layers.GELU())
        layer.append(tf.keras.layers.Dense(self.bottleneck_dim))
        return tf.keras.Sequential(layer)

    def call(self, input_tensor, training=None):
        """Project `input_tensor` to `out_dim` units."""
        x = self.mlp_block(input_tensor, training=training)
        # Unit-length bottleneck features before the final projection.
        x = tf.nn.l2_normalize(x, axis=-1)
        x = self.last_layer(x)
        return x

In [None]:
import tensorflow as tf

class Dino(tf.keras.models.Model):
    """Teacher/student pair trained with `DinoLoss`.

    Parameters
    ----------
    teacher_model, student_model : tf.keras.Model
        Models applied to the global resp. local crops.

    student_weights, teacher_weights : optional
        Stored on the instance only; nothing in this class loads them.

    Notes
    -----
    NOTE(review): only the student's variables are optimized and the teacher
    is never updated here (no moving-average step) - confirm whether that is
    intended.
    """

    def __init__(
        self, teacher_model, student_model, student_weights=None, teacher_weights=None
    ):
        super(Dino, self).__init__()
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.student_weights = student_weights
        self.teacher_weights = teacher_weights
        self.dino_loss = DinoLoss()

    def compile(self, optimizer):
        """Register the optimizer used for the student."""
        super(Dino, self).compile()
        self.optimizer = optimizer

    @staticmethod
    def _stack_crops(crops):
        """Flatten a nested per-image structure of crops and stack it.

        Works for tuples as well as lists (the previous `sum(x, ())` raised
        for lists); the order of the crops is preserved.
        """
        return tf.stack(tf.nest.flatten(crops))

    def train_step(self, data):
        """Run one optimization step on the student and return the loss."""
        global_image, local_image = data
        local_image = self._stack_crops(local_image)
        global_image = self._stack_crops(global_image)

        # The teacher gets no gradient, so its forward pass does not need to
        # be recorded on the tape.
        teacher_output = self.teacher_model(global_image)

        with tf.GradientTape() as tape:
            student_output = self.student_model(local_image)
            loss = tf.reduce_mean(self.dino_loss(student_output, teacher_output))

        # Differentiate and apply outside the tape context so that these
        # operations are not recorded themselves.
        student_gradients = tape.gradient(
            loss, self.student_model.trainable_variables
        )
        self.optimizer.apply_gradients(
            zip(student_gradients, self.student_model.trainable_variables)
        )
        return {"loss": loss}

    def test_step(self, data):
        """Compute the loss on one validation batch."""
        global_image, local_image = data
        local_image = self._stack_crops(local_image)
        global_image = self._stack_crops(global_image)

        teacher_output = self.teacher_model(global_image, training=False)
        student_output = self.student_model(local_image, training=False)

        loss = tf.reduce_mean(self.dino_loss(student_output, teacher_output))

        return {"loss": loss}

    def call(self, image):
        """Return the teacher output for `image` (inference mode)."""
        # The result used to be computed and then dropped (implicit None).
        return self.teacher_model(image, training=False)

In [None]:
import argparse
import tensorflow as tf

def parse_args(argv=None):
    """Parse the training options.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. Defaults to the process arguments
        (`sys.argv[1:]`).

    Returns
    -------
    argparse.Namespace
        Namespace with `epochs`, `batch_size`, `crop_teacher`,
        `crop_student`, `dataset_train`, `dataset_test`,
        `student_weights_path` and `teacher_weights_path`.

    Notes
    -----
    Unrecognized arguments are ignored rather than fatal. A notebook kernel
    is started with its own arguments (`-f <connection file>`), which made
    the strict parser exit with status 2 as soon as the cell ran.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-epoch", "--epochs", type=int, metavar="", default=100)
    parser.add_argument("-b", "--batch_size", type=int, metavar="", default=2)
    parser.add_argument("-ct", "--crop_teacher", type=int, metavar="", default=224)
    parser.add_argument("-cs", "--crop_student", type=int, metavar="", default=96)
    parser.add_argument(
        "-d_train",
        "--dataset_train",
        type=str,
        metavar="",
        default="VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/JPEGImages",
    )
    parser.add_argument(
        "-d_test",
        "--dataset_test",
        type=str,
        metavar="",
        default="VOCtest_06-Nov-2007/VOCdevkit/VOC2007/JPEGImages",
    )
    parser.add_argument(
        "-s_weights",
        "--student_weights_path",
        type=str,
        metavar="",
        default="student_weights",
    )
    parser.add_argument(
        "-t_weights",
        "--teacher_weights_path",
        type=str,
        metavar="",
        default="teacher_weights",
    )

    # parse_known_args returns (namespace, leftovers); leftovers are dropped.
    args, _unknown = parser.parse_known_args(argv)
    return args


def main():
    """Build student and teacher, train them with DINO and save the weights."""
    args = parse_args()

    # NOTE(review): the same head instance is given to both wrappers, so
    # student and teacher share the head's weights - confirm this is intended.
    head = DinoHead()

    student = load_base(args.crop_student)
    teacher = load_base(args.crop_teacher)

    student = MultiCropWrapper(backbone=student, head=head)
    teacher = MultiCropWrapper(backbone=teacher, head=head)

    model = Dino(teacher, student)

    train_dataset = DataGenerator(
        mode="train",
        dataset_path=args.dataset_train,
        batch_size=args.batch_size,
        local_image_size=args.crop_student,
        global_image_size=args.crop_teacher,
    )

    val_dataset = DataGenerator(
        mode="val",
        dataset_path=args.dataset_test,
        batch_size=args.batch_size,
        local_image_size=args.crop_student,
        global_image_size=args.crop_teacher,
    )

    # The schedule is evaluated on the optimizer step counter, so the
    # boundary "half of the epochs" has to be converted into steps. Passing
    # the epoch count directly dropped the learning rate after only
    # `epochs / 2` batches.
    steps_per_epoch = len(train_dataset)
    learning_rate = tf.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[(args.epochs // 2) * steps_per_epoch],
        values=[0.0001, 0.00001],
    )
    # NOTE(review): this checkpoint stores the whole Dino model under the
    # teacher weights path, which the final save_weights call below writes
    # to as well - confirm the intended file layout.
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            args.teacher_weights_path,
            monitor="loss",
            save_best_only=True,
            save_weights_only=False,
            mode="auto",
        )
    ]
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    model.build(input_shape=(1, args.crop_teacher, args.crop_teacher, 3))
    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=args.epochs,
        callbacks=callbacks,
    )
    model.student_model.save_weights(args.student_weights_path)
    model.teacher_model.save_weights(args.teacher_weights_path)


# Entry point. NOTE(review): the recorded output below shows this guard also
# fires when the cell is executed in a notebook kernel, so running the cell
# starts the full training run.
if __name__ == "__main__":
    main()

usage: colab_kernel_launcher.py [-h] [-epoch] [-b] [-ct] [-cs] [-d_train] [-d_test] [-s_weights]
                                [-t_weights]
colab_kernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-93d5cfa8-e07a-4773-ba1c-f63c6588b3a3.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
