# Data Augmentation

In [None]:
# Data Augmentation
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode


class DataAugmentation:
    """Create several randomly augmented views (crops) of a single image.

    Calling an instance returns a list with two "global" views followed by
    ``n_local_crops`` "local" views. Every view goes through a random resized
    crop to ``output_size``, a random horizontal flip, a random rotation of up
    to 10 degrees, an optional Gaussian blur and finally tensor conversion and
    normalization. The second global view may additionally be solarized.

    Parameters
    ----------
    global_crops_scale : tuple of float
        ``scale`` range passed to ``RandomResizedCrop`` for the global views.
    local_crops_scale : tuple of float
        ``scale`` range passed to ``RandomResizedCrop`` for the local views.
    n_local_crops : int
        Number of local views produced per call.
    output_size : int
        Output size passed to ``RandomResizedCrop`` for every view (global and
        local views share the same size).
    blur_kernel_size : int
        Kernel size handed to ``transforms.GaussianBlur``. The value that used
        to be hard-coded here was 1, which yields a single-tap kernel and
        therefore leaves the image unblurred; pass 1 to get that old behaviour.
    """

    def __init__(self, global_crops_scale=(0.4, 1), local_crops_scale=(0.05, 0.4), n_local_crops=2,
                 output_size=112, blur_kernel_size=5):
        self.n_local_crops = n_local_crops

        def random_gaussian_blur(p):
            # Apply the blur with probability p.
            blur = transforms.GaussianBlur(kernel_size=blur_kernel_size, sigma=(0.1, 2))
            return transforms.RandomApply([blur], p=p)

        def random_resized_crop(scale):
            return transforms.RandomResizedCrop(
                output_size, scale=scale, interpolation=InterpolationMode.BICUBIC)

        flip_and_rotation = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=10),
        ])
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        # First global view: always blurred.
        self.global_1 = transforms.Compose([
            random_resized_crop(global_crops_scale),
            flip_and_rotation,
            random_gaussian_blur(1.0),
            normalize,
        ])
        # Second global view: rarely blurred, sometimes solarized.
        self.global_2 = transforms.Compose([
            random_resized_crop(global_crops_scale),
            flip_and_rotation,
            random_gaussian_blur(0.1),
            transforms.RandomSolarize(170, p=0.2),
            normalize,
        ])
        # Local views: smaller crop scale, blurred half of the time.
        self.local = transforms.Compose([
            random_resized_crop(local_crops_scale),
            flip_and_rotation,
            random_gaussian_blur(0.5),
            normalize,
        ])

    def __call__(self, image):
        """Return the augmented views of ``image``.

        Returns
        -------
        list of torch.Tensor
            ``[global_1, global_2, local_1, ..., local_n]`` with
            ``n = n_local_crops``; each entry is an independently augmented
            version of the input image.
        """
        all_crops = [self.global_1(image), self.global_2(image)]
        all_crops.extend(self.local(image) for _ in range(self.n_local_crops))
        return all_crops

# ViT

In [None]:
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Stochastic depth: zero out whole samples of ``x`` with probability ``drop_prob``.

    ``x`` is returned untouched when ``drop_prob`` is 0 or ``training`` is false.
    Otherwise one keep/drop decision is drawn per entry of the first dimension,
    and the kept samples are divided by the keep probability.
    """
    stochastic = training and drop_prob != 0.
    if not stochastic:
        return x

    survive_prob = 1 - drop_prob
    # One random draw per sample, broadcastable over all remaining dimensions.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    # rand is in [0, 1): adding survive_prob and flooring gives 1 with
    # probability survive_prob and 0 otherwise.
    mask = (mask + survive_prob).floor_()
    return x.div(survive_prob) * mask


class DropPath(nn.Module):
    """
    Module wrapper around ``drop_path`` (stochastic depth per sample).

    ``drop_prob`` is the probability of dropping a sample's path. The default
    ``None`` is stored as given but is treated as 0 (no dropping) in
    ``forward``; previously it reached ``1 - None`` in training mode and
    raised a TypeError.
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Dropping is only active while the module is in training mode.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)


class PatchEmbed(nn.Module):
    """
    Turn a square 2D image into a sequence of patch embeddings.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size`` patches
    by a strided convolution; each patch becomes one ``embed_dim`` vector.
    ``img_size`` and ``patch_size`` are single ints and are stored as
    ``(size, size)`` tuples.
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols

        # kernel == stride, so every patch is projected independently.
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Map ``[B, C, H, W]`` to ``[B, num_patches, embed_dim]``."""
        _, _, height, width = x.shape
        expected_h, expected_w = self.img_size
        assert height == expected_h and width == expected_w, \
            f"Input image size ({height}*{width}) doesn't match model ({expected_h}*{expected_w})."

        patches = self.proj(x)                        # [B, embed_dim, grid_h, grid_w]
        tokens = patches.flatten(2).transpose(1, 2)   # [B, grid_h * grid_w, embed_dim]
        return self.norm(tokens)

class Attention(nn.Module):
    """
    Multi-head self-attention that returns both the output and the attention map.

    ``dim`` is the total embedding width and is split evenly across
    ``num_heads``. Scores are scaled by ``qk_scale`` when given, otherwise by
    ``head_dim ** -0.5``.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop_ratio=0., proj_drop_ratio=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        self.scale = qk_scale or per_head_dim ** -0.5
        # A single linear layer produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        """Return ``(output, attention)`` for ``x`` of shape ``[B, N, C]``.

        ``output`` has shape ``[B, N, C]``; ``attention`` (taken after softmax
        and attention dropout) has shape ``[B, num_heads, N, N]``.
        """
        batch, n_tokens, channels = x.shape
        per_head_dim = channels // self.num_heads

        # [B, N, 3C] -> [B, N, 3, heads, head_dim] -> [3, B, heads, N, head_dim]
        packed = self.qkv(x).reshape(batch, n_tokens, 3, self.num_heads, per_head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Explicit indexing instead of tuple-unpacking the tensor (TorchScript friendly).
        q, k, v = packed[0], packed[1], packed[2]

        # [B, heads, N, head_dim] @ [B, heads, head_dim, N] -> [B, heads, N, N]
        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        # [B, heads, N, head_dim] -> [B, N, heads, head_dim] -> [B, N, C]
        merged = (attn @ v).transpose(1, 2).reshape(batch, n_tokens, channels)
        out = self.proj_drop(self.proj(merged))
        return out, attn

class Mlp(nn.Module):
    """
    Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are not given. The same dropout module is applied after both layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        out_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Block(nn.Module):
    """
    Pre-norm transformer encoder block: attention and MLP, each with a residual connection.

    Both residual branches pass through the same ``drop_path`` module, which is
    a ``DropPath`` when ``drop_path_ratio > 0`` and an identity otherwise.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_ratio=attn_drop_ratio,
            proj_drop_ratio=drop_ratio,
        )

        # Stochastic depth on the residual branches (disabled when the ratio is 0).
        if drop_path_ratio > 0.:
            self.drop_path = DropPath(drop_path_ratio)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop_ratio,
        )

    def forward(self, x, return_attention=False):
        """Run the block; with ``return_attention=True`` return only the attention map."""
        attn_out, attn_map = self.attn(self.norm1(x))
        if return_attention:
            return attn_map
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class VisionTransformer(nn.Module):
    """
    Vision Transformer backbone whose ``forward`` returns the normalized [CLS] token.

    The constructor also builds ``pre_logits``, ``head`` and ``head_dist``
    modules, but ``forward`` does not apply them: it returns ``norm(x)[:, 0]``.

    NOTE(review): ``prepare_tokens`` only prepends the class token, while
    ``pos_embed`` reserves ``num_tokens`` extra slots. With ``distilled=True``
    the two no longer line up, so that mode looks unsupported by ``forward``;
    confirm before using it.
    """
    def __init__(self, img_size=112, patch_size=7, in_c=3, num_classes=0,
                 embed_dim=588, depth=6, num_heads=7, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int): input image size, forwarded to ``embed_layer``
            patch_size (int): patch size, forwarded to ``embed_layer``
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head (0 -> identity head)
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer (defaults to LayerNorm with eps=1e-6)
            act_layer: (nn.Module): activation layer (defaults to GELU)
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        # Stochastic depth decay rule: drop rate grows linearly from 0 to drop_path_ratio.
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_vit_weights)

    def _init_vit_weights(self, m):
        """
        ViT weight initialization, applied to every sub-module via ``self.apply``.
        :param m: module
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def interpolate_pos_encoding(self, x, w, h):
        """
        Return a positional embedding matching the token count of ``x``.

        :param x: token tensor ``[B, 1 + npatch, dim]`` (class token first)
        :param w: size of the input image along its first spatial dimension
        :param h: size of the input image along its second spatial dimension
        :return: ``self.pos_embed`` unchanged when the patch count matches and
            ``w == h``; otherwise the class embedding concatenated with a
            bicubically resized patch grid.
        """
        # Imported here so the method does not depend on a module-level `import math`.
        import math

        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]

        # PatchEmbed stores patch_size as a (size, size) tuple; other embed
        # layers may expose a plain int. Support both.
        patch_size = self.patch_embed.patch_size
        if isinstance(patch_size, (tuple, list)):
            patch_w, patch_h = patch_size[0], patch_size[1]
        else:
            patch_w = patch_h = patch_size
        w0 = w // patch_w
        h0 = h // patch_h
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        side = int(math.sqrt(N))  # the stored patch grid is assumed to be square
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        """Embed patches, prepend the [CLS] token, add positions and apply dropout."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        """Return the normalized [CLS] embedding, shape ``[B, embed_dim]``."""
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        """Return the attention map of the last block (``None`` if there are no blocks)."""
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        """Return the normalized output tokens of the ``n`` last blocks as a list."""
        x = self.prepare_tokens(x)
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output





# new Head

In [None]:
class DINOHead(nn.Module):
    """Projection head applied to the CLS token embedding.

    An MLP maps ``in_dim`` to ``bottleneck_dim``; the result is L2-normalized
    and fed through a weight-normalized linear layer to ``out_dim``.

    Parameters:
    in_dim : int
        Dimensionality of the incoming token embedding.
    out_dim : int
        Dimensionality of the final output.
    use_bn : bool
        If True, a ``BatchNorm1d`` follows every hidden linear layer.
    norm_last_layer : bool
        If True, ``weight_g`` of the last layer (initialised to 1) is frozen.
    nlayers : int
        Number of linear layers in the MLP (values below 1 are treated as 1;
        with a single layer the MLP is one ``in_dim -> bottleneck_dim`` linear).
    hidden_dim : int
        Width of the hidden layers.
    bottleneck_dim : int
        Width of the MLP output, i.e. the input of the last layer.

    Attributes:
    mlp : nn.Linear or nn.Sequential
        The multi-layer perceptron.
    last_layer : nn.Linear
        Bias-free linear layer wrapped with ``nn.utils.weight_norm``, so it
        carries ``weight_g`` and ``weight_v`` instead of a single ``weight``.
    """
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            stack = []
            # nlayers - 1 hidden stages (Linear [+ BatchNorm] + GELU), then the bottleneck projection.
            for stage in range(nlayers - 1):
                stage_in = in_dim if stage == 0 else hidden_dim
                stack.append(nn.Linear(stage_in, hidden_dim))
                if use_bn:
                    stack.append(nn.BatchNorm1d(hidden_dim))
                stack.append(nn.GELU())
            stack.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*stack)
        # Custom init runs before last_layer exists, so it only touches the MLP.
        self.apply(self._init_weights)

        final = nn.Linear(bottleneck_dim, out_dim, bias=False)
        self.last_layer = nn.utils.weight_norm(final)
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for linear layers."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run forward pass.

        Parameters:
        x : torch.Tensor
            Of shape `(n_samples, in_dim)`.

        return: torch.Tensor
            Of shape `(n_samples, out_dim)`.
        """
        bottleneck = self.mlp(x)
        unit = nn.functional.normalize(bottleneck, dim=-1, p=2)
        return self.last_layer(unit)

# MultiCropWrapper

In [None]:
class MultiCropWrapper(nn.Module):
    """
    Run a backbone and a head over several crops in one forward pass.

    All crops are concatenated along the batch dimension, pushed through the
    backbone and the head together, and the result is split back into one
    tensor per crop. Because of the concatenation, every crop must have the
    same shape apart from the batch dimension.

    Parameters:
    backbone : nn.Module
        Feature extractor (e.g. the Vision Transformer). Its ``fc`` and
        ``head`` attributes are replaced with ``nn.Identity``.
    head : nn.Module
        Head (e.g. ``DINOHead``) applied to the backbone output.
    """
    def __init__(self, backbone, head):
        super().__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        '''
        Parameters:
        x : torch.Tensor or list/tuple of torch.Tensor
            A single crop batch, or one batch per crop.

        return: tuple of len=n_crops, each of shape (batch, out_dim)
        '''
        # A bare tensor is treated as a single crop; tuples are accepted like lists.
        if not isinstance(x, (list, tuple)):
            x = [x]
        crops = list(x)
        batch_sizes = [crop.shape[0] for crop in crops]

        concatenated = torch.cat(crops, dim=0)
        cls_embedding = self.backbone(concatenated)
        logits = self.head(cls_embedding)
        # Split by each crop's own batch size so that output i always belongs
        # to crop i (identical to an even chunking when all sizes are equal).
        return logits.split(batch_sizes)

# simCLR loss

In [None]:
import torch.nn.functional as F

class Loss(nn.Module):
    """
    Contrastive (NCE-style) loss between teacher and student outputs.

    Per-crop outputs are concatenated along the feature dimension,
    L2-normalized, and compared with a temperature-scaled similarity matrix;
    row ``i`` of the teacher is trained to match row ``i`` of the student.
    A running ``center`` of the teacher outputs is maintained as a buffer but
    is not used inside the loss computation here.
    """
    def __init__(self, tau=0.07, out_dim=1024, center_momentum=0.995):
        super().__init__()
        self.tau = tau
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output_ori, teacher_output_ori):
        """
        Compute the loss and update the running center.
        student_output_ori / teacher_output_ori: sequences of len=n_crops,
        each entry of shape (batch, out_dim).
        """
        # (batch, n_crops * out_dim) after joining the crops feature-wise.
        teacher_feats = F.normalize(torch.cat(teacher_output_ori, dim=1), dim=-1)
        student_joined = torch.cat(student_output_ori, dim=1)
        n_examples, _ = student_joined.size()
        student_feats = F.normalize(student_joined, dim=-1)

        # Similarity of every teacher row to every student row, sharpened by tau.
        scores = torch.mm(teacher_feats, student_feats.t()) / self.tau
        # The positive for row i is column i.
        target = torch.arange(n_examples, dtype=torch.long, device=scores.device)
        loss = F.cross_entropy(scores, target)

        self.update_center(teacher_output_ori)
        return loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Update the running center of the teacher output.

        Takes an exponential moving average of the batch mean.

        Parameters
        ----------
        teacher_output : tuple
            Tuple of tensors of shape `(n_samples, out_dim)` where each
            tensor represents a different crop.
        """
        stacked = torch.cat(teacher_output)
        batch_center = stacked.mean(dim=0, keepdim=True)  # (1, out_dim)
        momentum = self.center_momentum
        self.center = self.center * momentum + batch_center * (1 - momentum)

    
def clip_gradients(model, clip=2.0):
    """Rescale the gradient of each parameter so its L2 norm is at most ``clip``.

    Used to avoid exploding gradients. Clipping is done per parameter tensor
    (not on the global norm) and modifies ``p.grad`` in place. Parameters
    without a gradient are skipped.

    Parameters
    ----------
    model : nn.Module
        Module.
    clip : float
        Maximum norm.

    Returns
    -------
    list of float
        The L2 norm of every existing gradient, measured before clipping, in
        ``model.parameters()`` order.
    """
    norms = []
    for p in model.parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.data.norm(2)
        norms.append(param_norm.item())
        # The small epsilon guards against division by zero for all-zero gradients.
        clip_coef = clip / (param_norm + 1e-6)
        if clip_coef < 1:
            p.grad.data.mul_(clip_coef)
    return norms

In [None]:
def get_params_groups(model):
    """Split trainable parameters into weight-decayed and non-decayed optimizer groups.

    Parameters whose name ends in ``.bias`` or that are 1-D (e.g. norm
    scales) go into the second group, which carries ``weight_decay=0``.
    Parameters with ``requires_grad=False`` are left out entirely.
    """
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        skip_decay = name.endswith(".bias") or param.ndim == 1
        bucket = undecayed if skip_decay else decayed
        bucket.append(param)
    return [{'params': decayed}, {'params': undecayed, 'weight_decay': 0.}]

# MetricLogger

In [None]:
from collections import defaultdict,deque
import datetime
import time
class SmoothedValue(object):
    """Keep a sliding window of recent values plus running totals.

    The window (``deque``) backs ``median``, ``avg``, ``max`` and ``value``;
    ``total`` and ``count`` back ``global_avg`` over the whole series.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.6f} ({global_avg:.6f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the totals."""
        self.deque.append(value)
        self.total += value * n
        self.count += n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        synced_count, synced_total = packed.tolist()
        self.count = int(synced_count)
        self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Mean over every value ever recorded (weighted by ``n``)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)

class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print periodic progress lines.

    Meters are created on first use by ``update`` and can be read back as
    attributes (``logger.loss``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup has already failed. Read
        # `meters` through __dict__ so that an instance without it (e.g. while
        # being copied or unpickled) cannot recurse into __getattr__ forever.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronize every meter's totals across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing timing and meter stats.

        A line is printed every ``print_freq`` items and for the last item;
        a total-time summary is printed once the iterable is exhausted.
        ``iterable`` must support ``len()``.
        """
        header = header or ''
        n_total = len(iterable)
        cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.6f}')
        data_time = SmoothedValue(fmt='{avg:.6f}')
        # Pad the running index to the width of the total count.
        space_fmt = ':' + str(len(str(n_total))) + 'd'
        fields = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if cuda:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # Time spent waiting for the item vs. time for the full iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == n_total - 1:
                eta_seconds = iter_time.global_avg * (n_total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                extra = {'memory': torch.cuda.max_memory_allocated() / MB} if cuda else {}
                print(log_msg.format(
                    i, n_total, eta=eta_string,
                    meters=str(self),
                    time=str(iter_time), data=str(data_time),
                    **extra))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids a ZeroDivisionError when the iterable is empty.
        print('{} Total time: {} ({:.6f} s / it)'.format(
            header, total_time_str, total_time / max(n_total, 1)))

# others

In [None]:
import numpy as np
import torch.distributed as dist

def get_params_groups(model):
    """Build two optimizer parameter groups: regularized and not regularized.

    Only parameters with ``requires_grad=True`` are considered. Biases (name
    ending in ``.bias``) and 1-D parameters such as norm weights are placed
    in the second group with ``weight_decay`` fixed to 0.
    """
    def exempt_from_decay(name, param):
        return name.endswith(".bias") or len(param.shape) == 1

    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
    regularized = [param for name, param in trainable if not exempt_from_decay(name, param)]
    not_regularized = [param for name, param in trainable if exempt_from_decay(name, param)]
    return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Return a per-iteration schedule: optional linear warmup, then cosine decay.

    The result has ``epochs * niter_per_ep`` entries. During the first
    ``warmup_epochs`` epochs the value rises linearly from
    ``start_warmup_value`` to ``base_value``; afterwards it follows a
    half-cosine from ``base_value`` towards ``final_value``.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))

    schedule = np.concatenate((warmup_part, cosine_part))
    assert len(schedule) == total_iters
    return schedule

def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()

def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1

def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
    
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0

def save_on_master(*args, **kwargs):
    """Forward the arguments to ``torch.save``, but only on the main process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


# training process

In [None]:
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """Drop gradients of ``last_layer`` parameters during the first epochs.

    While ``epoch < freeze_last_layer``, every parameter whose name contains
    ``"last_layer"`` gets ``grad = None``; the optimizer then has no gradient
    to apply to it. From ``freeze_last_layer`` onwards nothing is changed.
    """
    if epoch >= freeze_last_layer:
        return
    frozen_params = (param for name, param in model.named_parameters() if "last_layer" in name)
    for param in frozen_params:
        param.grad = None

In [None]:
def train_one_epoch(student,teacher,dino_loss,data_loader,optimizer,lr_schedule,wd_schedule,momentum_schedule,epoch,output_dir
                    ,total_epochs,clip_grad,freeze_last_layer):
    """Train the student for one epoch and EMA-update the teacher after every step.

    Parameters
    ----------
    student, teacher : nn.Module
        Networks called on the list of crops. Their parameters are paired up
        with ``zip`` for the EMA update.
    dino_loss : callable
        Called as ``dino_loss(student_output, teacher_output)``.
    data_loader : iterable
        Yields ``(images, _)`` where ``images`` is a list of crop batches.
    optimizer : torch.optim.Optimizer
        Its first param group receives the scheduled weight decay; every
        group receives the scheduled learning rate.
    lr_schedule, wd_schedule, momentum_schedule : sequence
        Indexed by the global iteration ``len(data_loader) * epoch + step``.
    epoch, total_epochs : int
        Current epoch and the total shown in the log header.
    output_dir
        Not used inside this function.
    clip_grad : float
        Per-parameter gradient-norm limit passed to ``clip_gradients``.
    freeze_last_layer : int
        Passed to ``cancel_gradients_last_layer``.

    Returns
    -------
    dict
        Global average of every logged meter (``loss``, ``lr``, ``wd``).
    """
    metric_logger = MetricLogger(delimiter="  ")
    header = 'Epoch: [{}/{}]'.format(epoch, total_epochs)
    iters_per_epoch = len(data_loader)
    for step, (images, _) in enumerate(metric_logger.log_every(data_loader, 1000, header)):
        # update weight decay and learning rate according to their schedule
        it = iters_per_epoch * epoch + step  # global training iteration
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        # move images to gpu
        images = [im.cuda() for im in images]
        # The teacher is only updated through the EMA below, never by
        # backpropagation, so its forward pass needs no autograd graph.
        with torch.no_grad():
            teacher_output = teacher(images)
        student_output = student(images)
        loss = dino_loss(student_output, teacher_output)

        optimizer.zero_grad()
        loss.backward()
        clip_gradients(student, clip=clip_grad)
        cancel_gradients_last_layer(epoch, student, freeze_last_layer)
        optimizer.step()

        # EMA update teacher: teacher = m * teacher + (1 - m) * student
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
            for param_q, param_k in zip(student.parameters(), teacher.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}



        


In [None]:
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Restore objects and run variables from a checkpoint file.

    Does nothing (returns None) when `ckp_path` is not an existing file.
    Each keyword argument maps a checkpoint key to the object that should
    receive it through `load_state_dict`, e.g. `state_dict=model`.
    Every name in `run_variables` that is also present in the checkpoint is
    overwritten in place with the stored value.
    """
    if not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))

    # the checkpoint is always deserialized on the cpu
    checkpoint = torch.load(ckp_path, map_location="cpu")

    for name, target in kwargs.items():
        if name not in checkpoint or target is None:
            print("=> key '{}' not found in checkpoint: '{}'".format(name, ckp_path))
            continue

        stored = checkpoint[name]
        try:
            # first attempt: non-strict loading
            msg = target.load_state_dict(stored, strict=False)
            print("=> loaded '{}' from checkpoint '{}' with msg {}".format(name, ckp_path, msg))
        except TypeError:
            # fall back to load_state_dict without the `strict` keyword
            try:
                msg = target.load_state_dict(stored)
                print("=> loaded '{}' from checkpoint: '{}'".format(name, ckp_path))
            except ValueError:
                print("=> failed to load '{}' from checkpoint: '{}'".format(name, ckp_path))

    # reload the variables that matter for resuming the run
    if run_variables is not None:
        restored = {var: checkpoint[var] for var in run_variables if var in checkpoint}
        run_variables.update(restored)

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DistributedSampler,DataLoader
import os
import time
from pathlib import Path
import json

def train_dino(data_path,batch_size,lr,weight_decay,weight_decay_end,min_lr,out_dim,tau,total_epochs,warmup_epochs,momentum_teacher,output_dir,saveckp_freq,
            clip_grad,freeze_last_layer):
    """Build the student/teacher pair and run the full training loop.

    Args:
        data_path: root folder handed to `ImageFolder`.
        batch_size: batch size of the data loader (incomplete last batch is dropped).
        lr: base learning rate; scaled by `batch_size * get_world_size() / 256`.
        weight_decay, weight_decay_end: start / end values of the weight-decay schedule.
        min_lr: final value of the learning-rate schedule.
        out_dim: output dimension of both heads and of the loss.
        tau: temperature passed to `Loss`.
        total_epochs: number of epochs the schedules are built for.
        warmup_epochs: warm-up epochs of the learning-rate schedule.
        momentum_teacher: initial teacher EMA momentum (scheduled towards 1).
        output_dir: folder for checkpoints and `log.txt`; created if missing.
        saveckp_freq: save a checkpoint every `saveckp_freq` epochs; a falsy
            value disables checkpointing.
        clip_grad, freeze_last_layer: forwarded to `train_one_epoch`.

    Training resumes from the highest-numbered `checkpoint<N>.pth` found in
    `output_dir`; when there is none it starts from epoch 0.
    """
    # local import: `datetime` is not among the imports of this cell
    import datetime

    # the log file and the checkpoints are written here, so it must exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    transform = DataAugmentation()
    dataset = ImageFolder(data_path, transform=transform)
    # ImageFolder lists its samples class folder by class folder, so without
    # shuffling every batch would contain (almost) a single class
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    print(f"Data loaded: there are {len(dataset)} images.")

    student = VisionTransformer(patch_size=7, drop_path_ratio=0.1)  # stochastic depth
    teacher = VisionTransformer(patch_size=7)
    embed_dim = student.embed_dim

    student = MultiCropWrapper(student, DINOHead(embed_dim, out_dim=out_dim, use_bn=False, norm_last_layer=True))
    teacher = MultiCropWrapper(teacher, DINOHead(embed_dim, out_dim=out_dim, use_bn=False))
    # move networks to gpu
    student = student.cuda()
    teacher = teacher.cuda()
    params_groups = get_params_groups(student)
    optimizer = torch.optim.AdamW(params_groups)
    # teacher and student start with the same weights
    teacher.load_state_dict(student.state_dict())
    dino_loss = Loss(tau=tau, out_dim=out_dim).cuda()
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print("Student and Teacher are built: they are both vit network.")

    steps_per_epoch = len(data_loader)
    lr_schedule = cosine_scheduler(lr * (batch_size * get_world_size()) / 256., min_lr, total_epochs, steps_per_epoch, warmup_epochs=warmup_epochs)
    wd_schedule = cosine_scheduler(weight_decay, weight_decay_end, total_epochs, steps_per_epoch)
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = cosine_scheduler(momentum_teacher, 1, total_epochs, steps_per_epoch)
    print("Loss, optimizer and schedulers ready.")

    # Resume from the most recent checkpoint, if any. Checkpoints are written
    # below as 'checkpoint<epoch>.pth', so the numeric suffix orders them.
    latest = None
    for ckpt in Path(output_dir).glob('checkpoint*.pth'):
        tag = ckpt.stem[len('checkpoint'):]
        if tag.isdecimal() and (latest is None or int(tag) > latest[0]):
            latest = (int(tag), ckpt)

    # default to epoch 0 so that a run without any checkpoint trains from scratch
    to_restore = {"epoch": 0}
    if latest is not None:
        restart_from_checkpoint(
            str(latest[1]),
            run_variables=to_restore,
            student=student,
            teacher=teacher,
            optimizer=optimizer,
            # must match the key used in `save_dict` below, otherwise the
            # loss state is reported as "not found" and never restored
            loss=dino_loss,
        )

    start_epoch = to_restore["epoch"]
    start_time = time.time()
    print("Starting DINO training !")

    for epoch in range(start_epoch, total_epochs):
        train_stats = train_one_epoch(student, teacher, dino_loss, data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch, output_dir,
                                      total_epochs, clip_grad, freeze_last_layer)

        save_dict = {'student': student.state_dict(), 'teacher': teacher.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch + 1, 'loss': dino_loss.state_dict()}
        # guard against saveckp_freq == 0 (modulo by zero)
        if saveckp_freq and epoch % saveckp_freq == 0:
            save_on_master(save_dict, os.path.join(output_dir, 'checkpoint{}.pth'.format(epoch)))

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch}
        with (Path(output_dir) / 'log.txt').open("a") as f:
            f.write(json.dumps(log_stats) + '\n')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    





In [None]:
# ---- data -----------------------------------------------------------------
data_path = 'data/MNIST - JPG - training'
#data_path='../input/mnistasjpg'
output_dir = 'log'  # checkpoints and log.txt are written here
batch_size = 32

# ---- optimisation schedules -------------------------------------------------
lr = 0.0005              # base value, rescaled inside train_dino by the batch size
min_lr = 1e-6            # final value of the learning-rate schedule
weight_decay = 0.04      # initial weight decay
weight_decay_end = 0.4   # final weight decay
total_epochs = 100
warmup_epochs = 10       # warm-up epochs of the learning-rate schedule
clip_grad = 3.0          # forwarded to clip_gradients
freeze_last_layer = 1    # forwarded to cancel_gradients_last_layer

# ---- teacher / loss ---------------------------------------------------------
momentum_teacher = 0.995  # initial EMA momentum of the teacher
out_dim = 16384           # output dimension of the heads and of the loss
tau = 0.1                 # temperature passed to the loss

# ---- checkpointing ----------------------------------------------------------
saveckp_freq = 5          # a checkpoint is saved when epoch % saveckp_freq == 0

train_dino(
    data_path=data_path,
    batch_size=batch_size,
    lr=lr,
    weight_decay=weight_decay,
    weight_decay_end=weight_decay_end,
    min_lr=min_lr,
    out_dim=out_dim,
    tau=tau,
    total_epochs=total_epochs,
    warmup_epochs=warmup_epochs,
    momentum_teacher=momentum_teacher,
    output_dir=output_dir,
    saveckp_freq=saveckp_freq,
    clip_grad=clip_grad,
    freeze_last_layer=freeze_last_layer,
)

Data loaded: there are 60000 images.
Student and Teacher are built: they are both vit network.
Loss, optimizer and schedulers ready.
Found checkpoint at /content/drive/MyDrive/MNIST/log/checkpoint10.pth
=> loaded 'student' from checkpoint '/content/drive/MyDrive/MNIST/log/checkpoint10.pth' with msg <All keys matched successfully>
=> loaded 'teacher' from checkpoint '/content/drive/MyDrive/MNIST/log/checkpoint10.pth' with msg <All keys matched successfully>
=> loaded 'optimizer' from checkpoint: '/content/drive/MyDrive/MNIST/log/checkpoint10.pth'
=> key 'dino_loss' not found in checkpoint: '/content/drive/MyDrive/MNIST/log/checkpoint10.pth'
Starting DINO training !
Epoch: [11/100]  [   0/1875]  eta: 0:44:27  loss: 0.002687 (0.002687)  lr: 0.000062 (0.000062)  wd: 0.050641 (0.050641)  time: 1.422853  data: 0.187907  max mem: 10421
Epoch: [11/100]  [1000/1875]  eta: 0:20:23  loss: 0.002478 (0.002676)  lr: 0.000062 (0.000062)  wd: 0.051676 (0.051160)  time: 1.403273  data: 0.172565  max me

KeyboardInterrupt: ignored

# visualize attention map

In [None]:
import ipywidgets
import matplotlib.pyplot as plt

import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import torch.nn.functional as F

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Patch size given to the VisionTransformer and used to size the attention maps.
patch_size = 7

In [None]:
# Build a bare ViT backbone for inference: frozen parameters, eval mode.
model = VisionTransformer(patch_size=patch_size)
for p in model.parameters():
    p.requires_grad = False

model.eval()
model.to(device)

# Load the teacher weights saved by the training loop.
state_dict = torch.load('log/checkpoint20.pth', map_location='cpu')
state_dict = state_dict['teacher']
# remove the `module.` prefix, then the `backbone.` prefix induced by the
# multicrop wrapper, so that the keys match the bare backbone
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
# Loading is non-strict, so mismatching keys do not raise. Print the report to
# make missing / unexpected keys visible instead of silently keeping random weights.
msg = model.load_state_dict(state_dict, strict=False)
print('Pretrained weights loaded with msg: {}'.format(msg))

# Images are loaded untransformed; `transform` is applied explicitly when an
# image is fed to the model.
dataset = ImageFolder("data/MNIST - JPG - testing")

transform = transforms.Compose([
        transforms.Resize(112),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])


In [None]:
@ipywidgets.interact
def _(
    i=ipywidgets.IntSlider(min=0, max=len(dataset) - 1, continuous_update=False),
    k=ipywidgets.IntSlider(min=0, max=195, value=10, continuous_update=False),

):
    """Show image `i` of `dataset` and one last-layer attention map per head.

    Args:
        i: index of the image in `dataset`.
        k: currently unused; kept so that the widget interface is unchanged.
    """
    img0 = dataset[i][0]
    img_ori = transforms.Resize(112)(img0)  # original image, resized for display
    img = transform(img0)

    # make the image divisible by the patch size
    h = img.shape[1] - img.shape[1] % patch_size
    w = img.shape[2] - img.shape[2] % patch_size
    img = img[:, :h, :w].unsqueeze(0)  # add the batch dimension
    # number of patches along each spatial dimension
    h_featmap = img.shape[-2] // patch_size
    w_featmap = img.shape[-1] // patch_size

    # attention coefficients of the last layer; per the original notes the
    # expected shape is [1, num_head, n_patches + 1, n_patches + 1]
    with torch.no_grad():
        attentions = model.get_last_selfattention(img.to(device))

    nh = attentions.shape[1]  # number of heads

    # keep, for every head, row 0 without its first column (per the original
    # notes: the cls-token attention over the patches) and lay it out on the
    # patch grid -> [num_head, h_featmap, w_featmap]
    attentions = attentions[0, :, 0, 1:].reshape(nh, h_featmap, w_featmap)
    # upsample each map by the patch size so it matches the image resolution
    attentions = F.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()

    # original image
    plt.imshow(img_ori)
    plt.axis("off")
    plt.show()

    # one panel per head; the grid is sized from the real number of heads
    n_cols = 3
    n_rows = (nh + n_cols - 1) // n_cols
    # squeeze=False keeps `axs` two-dimensional even for a single row
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)

    for head, ax in enumerate(axs.flat):
        if head < nh:
            ax.imshow(attentions[head])
        # also hides the axes of the unused trailing panels
        ax.axis("off")

    plt.tight_layout()

    plt.show()

interactive(children=(IntSlider(value=0, continuous_update=False, description='i', max=9999), IntSlider(value=…