In [None]:
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
from data import ImgDataset, get_dataloader
from data import get_default_transforms, get_default_training_transforms
from timm import scheduler as timm_schedulers
from timm.models.vision_transformer import LayerScale
from timm.optim import optim_factory
from tqdm.notebook import tqdm

class TokenReconstructor(nn.Module):
    """
    Feature-regression wrapper around a trainable encoder and a reference encoder.

    Flow: images -> optional additive pixel noise -> patch embedding -> random patch tokens replaced
    by ``corruption_token`` -> remaining stages of ``model`` -> mean squared error against the tokens
    that ``original_encoder.forward_features`` returns for the unmodified images.

    Args:
        original_encoder: reference network; only its ``forward_features`` is called, without autograd.
        model: trainable network; must expose ``embed_dim``, ``patch_embed``, ``_pos_embed``,
            ``norm_pre``, ``blocks`` and ``norm``.
        corruption_ratio: probability that a patch token is replaced (upper bound when sampled).
        sample_corruption_ratio: draw a separate ratio in ``[0, corruption_ratio)`` for every sample.
        project: attach a LayerScale-gated residual MLP (``model.projector`` / ``model.ls_projector``)
            that is applied to the output tokens.
        last_sample_clean: leave the last sample of every batch free of both corruptions.
        pixel_space_corruption: add Gaussian noise to the input images.
        pixel_space_corruption_scale: noise scale, relative to the per-sample, per-channel pixel std.
    """

    def __init__(self, original_encoder, model, corruption_ratio: float = 1 / 3,
                 sample_corruption_ratio: bool = True, project: bool = True, last_sample_clean: bool = True,
                 pixel_space_corruption: bool = True, pixel_space_corruption_scale: float = 0.2):
        super().__init__()
        self.original_encoder = original_encoder
        self.model = model
        width = self.model.embed_dim

        # Patch-level corruption settings; the replacement token starts out as all zeros.
        self.corruption_ratio = corruption_ratio
        self.corruption_token = nn.Parameter(torch.zeros(1, 1, width))
        self.sample_corruption_ratio = sample_corruption_ratio

        # Optional residual head, stored on the wrapped model itself.
        self.project = project
        if self.project:
            head_layers = [
                nn.LayerNorm(width),
                nn.Linear(width, width),
                nn.GELU(),
                nn.Linear(width, width),
            ]
            self.model.projector = nn.Sequential(*head_layers)
            self.model.ls_projector = LayerScale(width, init_values=1e-5)

        self.last_sample_clean = last_sample_clean

        # Pixel-level corruption settings.
        self.pixel_space_corruption = pixel_space_corruption
        self.pixel_space_corruption_scale = pixel_space_corruption_scale

    def _pixel_noise(self, imgs):
        """Gaussian noise shaped like ``imgs``, scaled by ``pixel_space_corruption_scale`` * spatial std."""
        spatial_std = imgs.std(dim=(2, 3), keepdim=True)
        noise = torch.randn_like(imgs) * self.pixel_space_corruption_scale * spatial_std
        if self.last_sample_clean:
            noise[-1] = 0
        return noise

    def _patch_mask(self, tokens):
        """Boolean (B, N) mask for ``tokens`` of shape (B, N, D); True marks a token to replace."""
        batch, n_tokens = tokens.shape[0], tokens.shape[1]
        rand_opts = dict(device=tokens.device, dtype=tokens.dtype)

        threshold = self.corruption_ratio
        if self.sample_corruption_ratio:
            # one ratio per sample, broadcast over that sample's tokens: (B,) -> (B, 1)
            threshold = (torch.rand(batch, **rand_opts) * self.corruption_ratio).unsqueeze(1)

        mask = torch.rand(batch, n_tokens, **rand_opts) < threshold
        if self.last_sample_clean:
            mask[-1] = False
        return mask

    def forward(self, imgs):
        """Return the scalar MSE between the corrupted-input tokens and the reference tokens."""
        # Regression target comes from the clean images.
        with torch.inference_mode():
            target = self.original_encoder.forward_features(imgs)

        if self.pixel_space_corruption:
            imgs = imgs + self._pixel_noise(imgs)

        tokens = self.model.patch_embed(imgs)
        replace = self._patch_mask(tokens)
        tokens[replace] = self.corruption_token.to(tokens.dtype)

        for stage in (self.model._pos_embed, self.model.norm_pre, self.model.blocks, self.model.norm):
            tokens = stage(tokens)

        if self.project:
            tokens = tokens + self.model.ls_projector(self.model.projector(tokens))

        return (tokens - target).pow(2).mean()


def moving_average(x, w):
    """Return the moving average of ``x`` over every full window of ``w`` consecutive samples.

    Args:
        x: 1-D sequence of numbers.
        w: window length in samples; must be at least 1.

    Returns:
        Array of length ``len(x) - w + 1``; empty when ``x`` holds fewer than ``w`` samples.

    Raises:
        ValueError: if ``w`` is smaller than 1.
    """
    if w < 1:
        raise ValueError(f'window size must be >= 1, got {w}')
    values = np.asarray(x)
    # np.convolve(mode='valid') swaps its operands when the kernel is longer than the data and would
    # return w - len(x) + 1 meaningless values, so the "no full window fits" case is handled here.
    if values.size < w:
        return np.empty(0, dtype=float)
    return np.convolve(values, np.ones(w), 'valid') / w


def plot_losses(losses, iters_per_epoch, max=None, min=None, left=None, as_ep=True):
    """Plot the raw per-iteration losses plus three moving averages on a log-scaled y axis.

    The moving-average windows are ``iters_per_epoch``, ``iters_per_epoch // 2`` and
    ``iters_per_epoch * 5`` iterations. A window that is empty or longer than ``losses`` is skipped,
    so the function can be called from the very first epoch on.

    Args:
        losses: sequence of per-iteration loss values.
        iters_per_epoch: number of iterations per epoch; sets the windows and the epoch x-axis scaling.
        max: upper y-axis limit (``None`` lets matplotlib choose). The name shadows the builtin but is
            part of the existing keyword interface.
        min: lower y-axis limit (``None`` lets matplotlib choose); same naming caveat as ``max``.
        left: left x-axis limit (``None`` lets matplotlib choose).
        as_ep: label and scale the x axis in epochs when True, in iterations otherwise.
    """
    ep_div = iters_per_epoch if as_ep else 1
    n_points = len(losses)

    plt.figure(dpi=80, figsize=(8, 5))
    plt.plot(np.arange(n_points) / ep_div, losses, lw=0.05, color='C0')

    # (window, line width, colour). Colours are pinned to the default-cycle entries each curve
    # would get if all were drawn, so skipping one does not recolour the others.
    smoothed_curves = (
        (iters_per_epoch, 1, 'C1'),
        (iters_per_epoch // 2, 0.5, 'C2'),
        (iters_per_epoch * 5, 0.1, 'C3'),
    )
    for window, line_width, colour in smoothed_curves:
        if window < 1 or window > n_points:
            continue  # no full window fits; x and y would not line up
        # each averaged value is drawn at the last iteration of its window
        x_plot = np.arange(window - 1, n_points)
        plt.plot(x_plot / ep_div, moving_average(losses, window), lw=line_width, color=colour)

    plt.yscale('log')
    plt.ylim(bottom=min, top=max)
    plt.xlim(left=left)
    plt.xlabel('Epoch' if as_ep else 'Iter')
    plt.ylabel('Loss (log scale)')
    plt.show()


# ---- Run configuration ----
res = 392              # square input resolution, used for the transforms and timm's img_size
max_grad_norm = 0.1    # gradient-norm clipping threshold; the training loop skips clipping when <= 0
ep_to_train = 120      # total number of training epochs
cooldown_eps = 20      # subtracted from ep_to_train to obtain the scheduler's t_initial

# ---- Data ----
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load this file from a
# trusted source.
imgs_to_use = np.load('imgs_all.npy', allow_pickle=True)
dataset = ImgDataset(list(imgs_to_use), get_default_training_transforms(res))
dataloader = get_dataloader(dataset, batch_size=128, num_workers=30)
iters_per_epoch = len(dataloader)
print(len(dataset), iters_per_epoch)

# ---- Models ----
# Two copies of the same pretrained checkpoint: `og_encoder` (put in eval mode) provides the
# regression targets, `encoder` is the copy being trained.
og_encoder = timm.create_model('vit_small_patch14_reg4_dinov2.lvd142m', pretrained=True, img_size=(res, res),
                               num_classes=0).eval()
encoder = timm.create_model('vit_small_patch14_reg4_dinov2.lvd142m', pretrained=True, img_size=(res, res),
                            num_classes=0)
model = TokenReconstructor(og_encoder, encoder,
                           corruption_ratio=1 / 3, sample_corruption_ratio=True,
                           last_sample_clean=True,
                           project=True,
                           pixel_space_corruption_scale=0.2, pixel_space_corruption=True,
                           ).cuda()

name = f'reconstruct_{res}'  # prefix of the checkpoint file names

torch.backends.cudnn.benchmark = True
scaler = torch.cuda.amp.GradScaler()

# ---- Optimizer / schedule ----
param_groups = optim_factory.param_groups_weight_decay(model.model, weight_decay=1e-4)
# `corruption_token` is registered on the TokenReconstructor wrapper, not on `model.model`, so the
# groups above do not contain it. It gets its own group (no weight decay) so that it is trained and
# its gradient is zeroed and unscaled together with everything else.
param_groups = list(param_groups) + [{'params': [model.corruption_token], 'weight_decay': 0.0}]
optim = torch.optim.AdamW(param_groups, lr=5e-5, betas=(0.9, 0.99))
scheduler = timm_schedulers.CosineLRScheduler(optim, t_initial=ep_to_train - cooldown_eps,
                                              warmup_prefix=False, warmup_t=10,
                                              warmup_lr_init=5e-9, lr_min=5e-9)

losses = []  # one entry per optimisation step, across all epochs

# Clip exactly the tensors the optimizer owns: those are the ones zero_grad() resets and
# scaler.unscale_() unscales. Gradients of parameters outside the optimizer are neither reset nor
# unscaled and would otherwise distort the total norm used for clipping.
clip_params = [p for group in optim.param_groups for p in group['params']]

for ep_idx in range(ep_to_train):
    # Periodic checkpoint, written before epoch `ep_idx` starts, i.e. after `ep_idx` finished epochs.
    if ep_idx % 5 == 0 and ep_idx > 0:
        recent_losses = np.mean(losses[-iters_per_epoch:])
        torch.save(model, f'/datastorage/justin/foundcheckpoints/{name}_{ep_idx:03}eps_{recent_losses:.6f}.pth')

    pbar = tqdm(dataloader)
    if scheduler is not None:
        scheduler.step(epoch=ep_idx)
    for batch_idx, imgs in enumerate(pbar):
        optim.zero_grad()
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
            imgs = imgs.cuda(non_blocking=True)
            loss = model(imgs)
        # loss.backward() equivalent for autocast
        scaler.scale(loss).backward()
        if max_grad_norm > 0:
            # gradients must be unscaled before their norm is measured
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(clip_params, max_grad_norm)
        # equivalent to optim.step()
        scaler.step(optim)
        scaler.update()

        loss_value = loss.item()  # read the loss back from the device once per step
        losses.append(loss_value)

        pbar.set_description(f'L:{loss_value:.6f} ({np.mean(losses[-iters_per_epoch:]):.6f}'
                             f'/ {np.mean(losses[-iters_per_epoch * 2:]):.6f}) - E:{ep_idx}, B:{batch_idx}, lr:{optim.param_groups[0]["lr"]:.2e}')

    torch.cuda.empty_cache()
    # Plotting is best-effort and must never abort training, but failures are reported and
    # KeyboardInterrupt / SystemExit are no longer swallowed.
    try:
        plot_losses(losses, iters_per_epoch)
    except Exception as exc:
        print(f'plot_losses failed: {exc!r}')
        plt.show()  # still display whatever was drawn before the failure


# Final checkpoint, named by the number of completed epochs like the periodic ones above.
torch.save(model, f'/datastorage/justin/foundcheckpoints/{name}_{ep_to_train:03}eps.pth')

