In [1]:
import os
import random
import torch
import numpy as np


def seed_everything(seed: int) -> None:
    """
    Seed every random number generator used in this notebook for reproducibility.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, and PyTorch's
    CPU and CUDA generators, then switches cuDNN to deterministic algorithm
    selection.

    Arguments:
        seed {int} -- Number of the seed

    Note:
        Writing ``PYTHONHASHSEED`` into ``os.environ`` does not change the hash
        seed of the interpreter that is already running (that is fixed at
        startup); it only affects Python subprocesses launched afterwards.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device. PyTorch silently
    # ignores this call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Single source of truth for the notebook's random seed.
SEED = 8888
seed_everything(SEED)

In [2]:
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from omegaconf import DictConfig

from dataset.cub import CUB200
from model.vit import VisionTransformer

In [3]:
# Model and training hyper-parameters, wrapped in an OmegaConf DictConfig so
# they can be read with attribute access (e.g. ``config.image_size``).
config = DictConfig(dict(
    # Patch extraction; how ``split`` / ``slide_step`` are interpreted lives in
    # model.vit (not visible here).
    patch_size=16,
    split="non-overlap",
    slide_step=12,
    # Backbone width, classification head and regularisation.
    hidden_size=768,
    classifier="token",
    dropout=0.1,
    max_len=100,
    transformer=dict(
        mlp_dim=3072,
        num_heads=12,
        num_layers=12,
        attention_dropout_rate=0.0,
    ),
    # Data / optimisation settings.
    batch_size=16,
    image_size=448,
    lr=2e-5,
    momentum=0.9,
    epoch=10,
))

In [4]:
def scale_keep_ar_min_fixed(img, fixed_min, resample=None):
    """
    Resize an image so its shorter side equals ``fixed_min``, keeping aspect ratio.

    Arguments:
        img {PIL.Image.Image} -- image to resize (anything exposing ``.size``
            as ``(width, height)`` and a PIL-style ``.resize``)
        fixed_min {int} -- target length, in pixels, of the shorter side
        resample -- PIL resampling filter; ``None`` (default) means
            ``Image.BICUBIC``

    Returns:
        The resized image. The longer side is scaled proportionally and
        truncated to an integer (floor division).
    """
    if resample is None:
        # PIL's Image.resize only accepts PIL filter constants; torchvision's
        # InterpolationMode enum is rejected with a ValueError.
        resample = Image.BICUBIC

    ow, oh = img.size

    if ow < oh:
        nw = fixed_min
        nh = nw * oh // ow
    else:
        nh = fixed_min
        nw = nh * ow // oh
    return img.resize((nw, nh), resample)

# Shared preprocessing constants. The resize is larger than the crop so that
# RandomCrop has room to jitter during training. The crop size is read from the
# config so that changing ``config.image_size`` updates both pipelines.
RESIZE_SIZE = (600, 600)
CROP_SIZE = (config.image_size, config.image_size)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random crop + horizontal flip (augmentation), normalise.
train_transform = transforms.Compose([
    transforms.Resize(RESIZE_SIZE, interpolation=InterpolationMode.BILINEAR),
    transforms.RandomCrop(CROP_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: same resize, deterministic centre crop, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(RESIZE_SIZE, interpolation=InterpolationMode.BILINEAR),
    transforms.CenterCrop(CROP_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [5]:
model = VisionTransformer(config)

# Smoke test: push one augmented training image through the model and inspect
# the output shape.
# NOTE(review): CUB200 items unpack to three values here; the middle one is
# unused — confirm its meaning in dataset.cub.
img, _, target = CUB200(root="./data", train=True, transform=train_transform)[0]
img = img.unsqueeze(0)  # prepend a batch dimension of size 1
# no_grad() avoids storing activations for a backward pass that never happens.
with torch.no_grad():
    logits = model(img)
logits.shape

torch.Size([1, 200])