This notebook trains the model for a few steps and then runs inference with a pretrained model.

# Setup

Imports

In [None]:
from pathlib import Path
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from IPython.display import clear_output
from tqdm import tqdm

from utils.img import unnormalize
from utils.training import get_batch, get_dataloaders
from utils.visualization import plot_feats

Load config

In [None]:
# Absolute path of the current working directory, handed to Hydra as the
# project root override.
project_root = str(Path().absolute())

# Hydra must be initialised by hand inside a notebook; skip the call when the
# global Hydra instance is already initialised.
hydra_state = GlobalHydra.instance()
if not hydra_state.is_initialized():
    initialize(config_path="config", version_base=None)

# Compose the base configuration together with the notebook-specific overrides.
overrides = [
    "val_dataloader.batch_size=1",
    f"project_root={project_root}",
]
cfg = compose(config_name="base", overrides=overrides)

# Fix the random seed.
seed = 0
print(f"Seed: {seed}")
torch.manual_seed(seed)

Backbone

In [None]:
# Select the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the backbone described by the config and move it to that device.
backbone = instantiate(cfg.backbone)
backbone.to(device)

# Clear whatever the cell printed so far.
clear_output()

Dataloader

In [None]:
# Build the training and validation dataloaders from the config and backbone,
# with the evaluation flag switched off.
train_dataloader, val_dataloader = get_dataloaders(cfg, backbone, is_evaluation=False)

Model

In [None]:
# Build the model described by the config and place it on the same device as
# the backbone. Using `device` instead of a hard-coded `.cuda()` keeps the
# notebook runnable on CPU-only machines.
model = instantiate(cfg.model)
model.to(device)

# The optimizer is given the model's parameters only.
optimizer = instantiate(cfg.optimizer, params=list(model.parameters()))

# Training

In [None]:
# Short training run: one epoch, stopped early, with periodic feature plots.
epochs = 1
plot_every = 500  # plot the features every `plot_every` steps
last_step = 2000  # leave the epoch once this step index has been processed

model.train()
criterion = instantiate(cfg.loss, dim=backbone.embed_dim)

for epoch in range(epochs):
    # The context manager closes the progress bar even if a step raises.
    with tqdm(
        train_dataloader,
        desc=f"Epoch {epoch}",
        leave=True,
        postfix={"loss_hr": "..."},
    ) as pbar:
        for step, batch in enumerate(pbar):
            batch = get_batch(batch, device)
            image_batch = batch["image"]

            # The optimizer only holds the model's parameters, so no gradient
            # is needed through the backbone; skipping the autograd graph
            # saves memory and compute.
            with torch.no_grad():
                # Target features, computed from the original images.
                hr_feats, _ = backbone(image_batch)

                # Input features, computed from the images downsampled by 2
                # with area interpolation.
                low_res_batch = F.interpolate(image_batch, scale_factor=0.5, mode="area")
                lr_feats, _ = backbone(low_res_batch)

            # Guidance image resized to the last two dimensions of `hr_feats`.
            lr_img_batch = F.interpolate(image_batch, hr_feats.shape[-2:])
            pred = model(lr_img_batch, lr_feats, (hr_feats.shape[2], hr_feats.shape[3]))

            loss_hr = criterion(hr_feats, pred)["total"]
            loss = loss_hr

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update progress bar with current loss
            pbar.set_postfix({"loss_hr": f"{loss_hr.item():.4f}"})

            if step % plot_every == 0:
                unorm_img_batch = unnormalize(
                    image_batch, backbone.config["mean"], backbone.config["std"]
                )
                # Show the first sample: image, target features, prediction.
                plot_feats(
                    unorm_img_batch[0].to(torch.float32),
                    hr_feats[0].to(torch.float32),
                    pred[0].to(torch.float32),
                )
            if step == last_step:
                break

torch.cuda.empty_cache()

# Inference

Simple inference

In [None]:
# Load the pretrained weights stored under the "jafar" key of the checkpoint.
# `map_location=device` remaps the saved tensors onto the device chosen above,
# so a checkpoint saved on GPU also loads on a CPU-only machine.
checkpoint = torch.load(
    "./output/jafar/vit_small_patch14_dinov2.lvd142m/model.pth",
    map_location=device,
)
model.load_state_dict(checkpoint["jafar"])

In [None]:
# Run the model on the first validation batch and plot the result.
model.eval()
SIZE = 448  # side length of the square output requested from the model
with torch.no_grad():
    for step, batch in enumerate(tqdm(val_dataloader)):
        # `get_batch` is given `device`, as in the training loop, so no extra
        # `.cuda()` call is needed (it would fail on CPU-only machines).
        batch = get_batch(batch, device)
        image_batch = batch["image"]

        hr_feats, _ = backbone(image_batch)
        pred = model(image_batch, hr_feats, (SIZE, SIZE))

        unorm_img_batch = unnormalize(
            image_batch, backbone.config["mean"], backbone.config["std"]
        )
        # Show the first sample: image, bilinearly resized backbone features
        # (for comparison at the same size), and the model prediction.
        plot_feats(
            unorm_img_batch[0].to(torch.float32),
            F.interpolate(hr_feats, SIZE, mode="bilinear")[0].to(torch.float32),
            pred[0].to(torch.float32),
        )
        # Only the first batch is visualised.
        break

torch.cuda.empty_cache()