# Cloning the repository

In [1]:
# Fetch the project sources from GitHub into ./images.
# NOTE(review): presumably this checkout provides the `models`, `utils` and `scheduler`
# packages imported further down — confirm.
!git clone 'https://github.com/edwardhan925192/images.git'
# Make the checkout the working directory (the captured output shows /content/images).
%cd images

Cloning into 'images'...
remote: Enumerating objects: 269, done.[K
remote: Counting objects: 100% (125/125), done.[K
remote: Compressing objects: 100% (119/119), done.[K
remote: Total 269 (delta 48), reused 0 (delta 0), pack-reused 144[K
Receiving objects: 100% (269/269), 67.63 KiB | 4.51 MiB/s, done.
Resolving deltas: 100% (107/107), done.
/content/images


# Loading the YAML configuration

In [2]:
import yaml

# Parse the experiment configuration; safe_load only constructs plain Python objects.
with open('/content/ijepa_image1k.yaml', 'r') as file:
    args = yaml.safe_load(file)

# -- MASK
mask_cfg = args['mask']
allow_overlap = mask_cfg['allow_overlap']  # whether to allow overlap b/w context and target blocks
patch_size = mask_cfg['patch_size']  # patch-size for model training
num_enc_masks = mask_cfg['num_enc_masks']  # number of context blocks
min_keep = mask_cfg['min_keep']  # min number of patches in context block
enc_mask_scale = mask_cfg['enc_mask_scale']  # scale of context blocks
num_pred_masks = mask_cfg['num_pred_masks']  # number of target blocks
pred_mask_scale = mask_cfg['pred_mask_scale']  # scale of target blocks
aspect_ratio = mask_cfg['aspect_ratio']  # aspect ratio of target blocks
# --

# -- OPTIMIZATION
opt_cfg = args['optimization']
ema = opt_cfg['ema']  # indexed as ema[0] / ema[1] by the momentum schedule
ipe_scale = opt_cfg['ipe_scale']  # scheduler scale factor (def: 1.0)
# PyYAML reads exponent literals written without a dot (e.g. 1e-3) as strings,
# so numeric hyper-parameters handed to the optimizers are cast explicitly.
wd = float(opt_cfg['weight_decay'])
num_epochs = opt_cfg['epochs']
lr = float(opt_cfg['lr'])  # cast for the same reason as weight_decay

# Loading checkpoints

In [3]:
def load_checkpoint(
    device,
    r_path,
    x_encoder,
    predictor,
    y_encoder,
    x_optimizer,
    pred_optimizer

):
    """Restore model and optimizer state from a checkpoint file (best effort).

    The checkpoint is read onto the CPU and copied into the given objects
    with ``load_state_dict``; the objects are updated in place and returned.

    Args:
        device: Accepted for interface compatibility; not used by this function.
        r_path: Path of the checkpoint file to read.
        x_encoder: Module restored from the ``'x_encoder'`` entry.
        predictor: Module restored from the ``'predictor'`` entry.
        y_encoder: Module restored from the ``'y_encoder'`` entry; skipped when None.
        x_optimizer: Optimizer restored from the ``'x_optimizer'`` entry.
        pred_optimizer: Optimizer restored from the ``'pred_optimizer'`` entry.

    Returns:
        ``(x_encoder, predictor, y_encoder, x_optimizer, pred_optimizer, epoch)``
        where ``epoch`` is the checkpoint's ``'epoch'`` value, or 0 when the
        checkpoint could not be loaded.

    Notes:
        Failures do not raise: the reason is printed and epoch 0 is returned.
        All required keys are checked before anything is modified, so a
        checkpoint with missing entries leaves every object untouched. An
        error raised by ``load_state_dict`` itself (e.g. a shape mismatch)
        can still occur after earlier objects were already restored.
    """
    required = ['epoch', 'x_encoder', 'predictor', 'x_optimizer', 'pred_optimizer']
    if y_encoder is not None:
        required.append('y_encoder')

    try:
        # -- saved dir
        checkpoint = torch.load(r_path, map_location=torch.device('cpu'))

        # -- validate first so an incomplete file cannot half-restore the run
        missing = [key for key in required if key not in checkpoint]
        if missing:
            raise KeyError(f'checkpoint is missing keys: {missing}')

        epoch = checkpoint['epoch']

        # -- loading models
        x_encoder.load_state_dict(checkpoint['x_encoder'])
        predictor.load_state_dict(checkpoint['predictor'])
        if y_encoder is not None:
            y_encoder.load_state_dict(checkpoint['y_encoder'])

        # -- loading optimizer
        pred_optimizer.load_state_dict(checkpoint['pred_optimizer'])
        x_optimizer.load_state_dict(checkpoint['x_optimizer'])

    except Exception as e:
        # Deliberate best effort: fall back to a fresh start, but say why.
        print(f'Could not load checkpoint from {r_path!r} '
              f'({type(e).__name__}: {e}); starting from epoch 0')
        epoch = 0

    return x_encoder, predictor, y_encoder, x_optimizer, pred_optimizer,  epoch

# Main: training setup and loop

In [9]:
# -- Example usage
use_bfloat16 = False # bool
num_epochs = 110
checkpoint_freq = 5
base_directory = '/content/drive/MyDrive/data/ijepa weights'  # where the check points are stored and loading from

# -- loading path
load_model = True
load_path = '/content/drive/MyDrive/data/ijepa weights/-latest.pth'

In [10]:
import torch
from models.ijepa import VisionTransformerPredictor, VisionTransformer
import torch.nn.functional as F
from utils.masks.mask_application import apply_masks
from utils.tensors import repeat_interleave_batch
import os
# -- 0. datasets
from utils.masks.maskcollator_vit import MaskCollator
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Images are only converted to tensors here; no resizing or normalisation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Folder holding the training images.
root_dir = '/content/drive/MyDrive/data/vegi picture'

# -- loading
dataset = datasets.ImageFolder(root=root_dir, transform=transform)

# -- 1. dataloader
# NOTE(review): MaskCollator is built with its defaults; the mask settings read
# from the YAML config are not passed in — confirm this is intended.
collator = MaskCollator()
unsupervised_loader = DataLoader(
    dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collator,
)

In [None]:
from torch.cuda.amp import GradScaler, autocast
from images.scheduler.scheduler import SchedulerManager

# -- Check CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Gradient scaler; the training step only routes through it when use_bfloat16 is True.
scaler = GradScaler()

# -- model init
# x_encoder is trained by gradient descent (x_optimizer); y_encoder is only
# updated by the momentum (EMA) rule inside the training loop.
# NOTE(review): the two encoders are initialised independently here, not as
# copies of each other — confirm this is intended.
x_encoder = VisionTransformer()
y_encoder = VisionTransformer()
# NOTE(review): presumably 14x14 patches (224px image / 16px patch) — confirm
# against the VisionTransformer defaults.
num_patches = 196
predictor = VisionTransformerPredictor(num_patches = num_patches)

# -- move models to the selected device
x_encoder.to(device)
y_encoder.to(device)
predictor.to(device)

# -- optimizer
# lr and wd come from the YAML configuration cell.
x_optimizer = torch.optim.Adam(x_encoder.parameters(), lr=lr, weight_decay = wd)
# NOTE(review): y_optimizer is never stepped in the visible code (y_encoder is
# updated by EMA only) — confirm whether it is needed.
y_optimizer = torch.optim.Adam(y_encoder.parameters(), lr=lr, weight_decay = wd)
pred_optimizer = torch.optim.Adam(predictor.parameters(), lr=lr, weight_decay = wd)

# -- lr scheduler
# One scheduler per trained module; both are stepped once per epoch in the training loop.
scheduler_manager = SchedulerManager()
x_scheduler = scheduler_manager.initialize_scheduler(x_optimizer, 'CosineAnnealingWarmRestarts')
pred_scheduler = scheduler_manager.initialize_scheduler(pred_optimizer, 'CosineAnnealingWarmRestarts')

# -- momentum scheduler
# ipe = iterations (batches) per epoch.
ipe = len(unsupervised_loader)
# Lazy generator of EMA momentum values, linearly interpolated from ema[0] to
# ema[1] over ipe*num_epochs*ipe_scale steps; one value is consumed per training
# iteration. Being a generator expression, ema/ipe/num_epochs/ipe_scale inside
# the element expression are read when each value is requested, not here.
momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale)
                      for i in range(int(ipe*num_epochs*ipe_scale)+1))

def save_checkpoint(epoch, checkpoint_freq, base_directory='/content/drive/MyDrive/data/ijepa weights'):
    """Save the current training state to disk.

    Reads the module-level ``x_encoder``, ``y_encoder``, ``predictor``,
    ``x_optimizer``, ``pred_optimizer`` and ``scaler`` objects and writes
    their state dicts, together with ``epoch``, using ``torch.save``.

    Args:
        epoch: Number of completed epochs. It is stored under ``'epoch'``
            (the value a resumed run starts from) and is used to name the
            periodic snapshot.
        checkpoint_freq: Write an extra ``-ep{epoch}.pth`` snapshot whenever
            ``epoch`` is a multiple of this value. Values <= 0 disable the
            periodic snapshots.
        base_directory: Folder the checkpoint files are written to.

    The file ``-latest.pth`` is overwritten on every call.
    """
    # Latest checkpoint path
    latest_path = os.path.join(base_directory, '-latest.pth')

    # Save dictionary
    save_dict = {
        'x_encoder': x_encoder.state_dict(),
        'y_encoder': y_encoder.state_dict(),
        'predictor': predictor.state_dict(),
        'x_optimizer': x_optimizer.state_dict(),
        'pred_optimizer': pred_optimizer.state_dict(),
        'scaler': None if scaler is None else scaler.state_dict(),
        'epoch': epoch
    }

    # -- Always update the latest checkpoint
    torch.save(save_dict, latest_path)

    # -- Checkpoint frequency updates
    # `epoch` already counts completed epochs (the training loop passes
    # epoch + 1), so it is used as-is: adding one more here would write the
    # snapshot one epoch early under the next epoch's name.
    if checkpoint_freq > 0 and epoch % checkpoint_freq == 0:

        # -- Checkpoint path for specific epoch
        checkpoint_path = os.path.join(base_directory, f'-ep{epoch}.pth')
        torch.save(save_dict, checkpoint_path)

# ------------- Begin loading ------------- #
# Default for a fresh run, so the training loop can always read start_epoch
# (it was previously undefined when load_model was False).
start_epoch = 0

if load_model:
    # load_checkpoint returns
    #   (x_encoder, predictor, y_encoder, x_optimizer, pred_optimizer, epoch).
    # Unpack into those same names: binding the fifth value to `scaler` would
    # replace the GradScaler with the predictor's optimizer.
    x_encoder, predictor, y_encoder, x_optimizer, pred_optimizer, start_epoch = load_checkpoint(
        device=device,
        r_path=load_path,
        x_encoder=x_encoder,
        predictor=predictor,
        y_encoder=y_encoder,
        x_optimizer=x_optimizer,
        pred_optimizer=pred_optimizer
        )
    # NOTE(review): the 'scaler' entry written by save_checkpoint is not
    # restored by load_checkpoint — confirm whether that matters when
    # use_bfloat16 is enabled.

    # -- momentum scheduler: skip the values consumed by the finished epochs
    # (one value per training iteration, ipe iterations per epoch)
    for _ in range(start_epoch*ipe):
        next(momentum_scheduler)

    # -- lr schedulers are stepped once per epoch, so replay one step per finished epoch
    for _ in range(start_epoch):
        x_scheduler.step()
        pred_scheduler.step()
# ------------- End loading ------------- #


# ------------- training step ------------- #
# Each iteration: encode the targets with y_encoder (no gradients), predict them
# from the masked context with x_encoder + predictor, minimise the smooth-L1
# loss, then move y_encoder towards x_encoder with the scheduled momentum.
for epoch in range(start_epoch, num_epochs):
    x_encoder.train()
    predictor.train()

    # -- loader
    for itr, (udata, masks_enc, masks_pred) in enumerate(unsupervised_loader):

        # -- unsupervised imgs and their masks, moved to the training device
        imgs = udata[0].to(device, non_blocking=True)
        masks_enc = [u.to(device, non_blocking=True) for u in masks_enc]
        masks_pred = [u.to(device, non_blocking=True) for u in masks_pred]

        # -- clear the gradients of the previous iteration; without this,
        # backward() keeps adding to them and every step uses the running sum
        x_optimizer.zero_grad()
        pred_optimizer.zero_grad()

        # -- step 1. Forward
        with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16):
            # -- masked target tokens (no gradient flows into y_encoder)
            with torch.no_grad():
                h = y_encoder(imgs)
                h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
                B = len(h)
                # -- create targets (masked regions of h)
                h = apply_masks(h, masks_pred)
                h = repeat_interleave_batch(h, B, repeat=len(masks_enc))

            # -- masked encoded tokens, then prediction of the target tokens
            z = x_encoder(imgs, masks_enc)
            z = predictor(z, masks_enc, masks_pred)

            loss = F.smooth_l1_loss(z, h)

        # -- step 2. Backward & step
        # -- scaler
        if use_bfloat16:
            scaler.scale(loss).backward()
            scaler.step(x_optimizer)
            scaler.step(pred_optimizer)
            scaler.update()
        else:
            loss.backward()
            x_optimizer.step()
            pred_optimizer.step()

        # -- momentum update of y_encoder
        with torch.no_grad():
            m = next(momentum_scheduler)
            for param_q, param_k in zip(x_encoder.parameters(), y_encoder.parameters()):
                param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)

        print(f'LOSS OF {epoch+1}: {loss}')

    # -- lr scheduler step
    x_scheduler.step()
    pred_scheduler.step()

    # -- saving every epoch (epoch + 1 = number of completed epochs)
    save_checkpoint(epoch + 1, checkpoint_freq, base_directory)