In [11]:
# Enable IPython's autoreload extension in mode 2 so that edits to the
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
from image_captioner import ImageEncoder, ImageDecoder
from coco_loader import get_coco_loader
from paths import paths
import yaml
import torch
import torch.nn as nn
from torch.optim import AdamW
from image_transforms import image_transform_index
from tqdm import tqdm

In [13]:
# Read the YAML configuration once and expose the values used by the
# following cells as module-level constants.
with open(paths["config"]) as config_file:
    config = yaml.safe_load(config_file)

# Image / patch geometry and masking ratio.
IMAGE_SIZE = config["image_size"]
PATCH_SIZE = config["patch_size"]
MASKING_RATIO = config["masking_ratio"]

# The image is cut into a square grid of PATCH_SIZE x PATCH_SIZE patches;
# int() truncates the masked-patch count towards zero.
NUM_PATCHES = pow(IMAGE_SIZE // PATCH_SIZE, 2)
NUM_MASKED_PATCHES = int(NUM_PATCHES * MASKING_RATIO)

# Data-loading settings.
NUM_WORKERS = config["num_workers"]
BATCH_SIZE = config["batch_size"]

# Sub-configurations handed to the encoder / decoder constructors.
image_encoder_config = config["image_encoder"]
image_decoder_config = config["image_decoder"]

In [14]:
# Choose the compute device. An explicit "device" entry in the config takes
# precedence; otherwise probe for CUDA first, then Apple MPS, and finally
# fall back to the CPU.
if "device" not in config:
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
else:
    device = config["device"]
print(f"You are using {device}.")

You are using cuda.


In [None]:
# Masked-patch reconstruction training: the encoder sees only the unmasked
# patches, the decoder reconstructs patches, and the loss is computed on the
# masked patches only.

# Build the encoder/decoder pair on the selected device, in training mode.
image_encoder = ImageEncoder(IMAGE_SIZE, PATCH_SIZE, image_encoder_config).to(device)
image_encoder.train()
image_decoder = ImageDecoder(PATCH_SIZE, NUM_PATCHES, image_decoder_config).to(device)
image_decoder.train()

# Mean-squared-error reconstruction loss; a single AdamW optimizer updates
# the parameters of both networks jointly.
criterion = nn.MSELoss()
optimizer = AdamW(
    list(image_encoder.parameters()) + list(image_decoder.parameters()),
    lr=1e-4,
    weight_decay=0.05,
)

# NOTE(review): the positional arguments 100 and 0 are opaque from this cell;
# confirm their meaning against get_coco_loader's signature.
train_batches = get_coco_loader(
    "train", BATCH_SIZE, 100, image_transform_index["train"], 0, NUM_WORKERS
)
pbar = tqdm(train_batches, desc=f"Training epoch {1}:", leave=True)
# Unfold with kernel_size == stride extracts non-overlapping patches.
patch_extracter = nn.Unfold(kernel_size=PATCH_SIZE, stride=PATCH_SIZE)

for image, _ in pbar:
    optimizer.zero_grad()

    image = image.to(device)

    # Draw an independent random *permutation* of the patch indices for each
    # image by arg-sorting uniform noise. Sampling indices with
    # torch.randint would draw with replacement, so the masked and unmasked
    # slices could overlap (leaking target patches to the encoder), contain
    # duplicates, and leave some patches in neither set.
    positions = torch.rand(
        image.shape[0], NUM_PATCHES, device=image.device
    ).argsort(dim=1)
    # Split each permutation into disjoint masked / unmasked index sets.
    masked_positions = positions[:, :NUM_MASKED_PATCHES]
    unmasked_positions = positions[:, NUM_MASKED_PATCHES:]

    # Unfold yields (batch, patch_dim, num_patches) for a batched image;
    # transpose so that patches lie along dim 1.
    # Assumes the loader resizes images to IMAGE_SIZE x IMAGE_SIZE so that
    # the patch count equals NUM_PATCHES — TODO confirm.
    image_patches = patch_extracter(image).transpose(-1, -2)
    # Broadcast the masked indices over the patch dimension so gather picks
    # whole patches along dim 1.
    ground_inds = masked_positions.unsqueeze(-1).expand(-1, -1, image_patches.shape[-1])
    ground_masked_patches = torch.gather(image_patches, dim=1, index=ground_inds)

    encoded_unmasked_patches = image_encoder(image, positions=unmasked_positions)
    reconstructed_image_patches = image_decoder(
        encoded_unmasked_patches, positions=unmasked_positions
    )
    # Assumes the decoder returns one reconstructed patch per position,
    # laid out like image_patches — TODO confirm against ImageDecoder.
    pred_masked_patches = torch.gather(
        reconstructed_image_patches, dim=1, index=ground_inds
    )

    # Only the masked patches contribute to the loss.
    loss = criterion(pred_masked_patches, ground_masked_patches)
    loss.backward()
    optimizer.step()
    pbar.set_postfix({"loss": loss.item()})

Training epoch 1::  18%|█▊        | 811/4624 [02:33<12:01,  5.29it/s, loss=0.681]


KeyboardInterrupt: 