In [1]:
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from omegaconf import DictConfig
from torch.utils.data import DataLoader

import torch
import torchmetrics
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np

from dataset.cub import CUB200
from util import WarmupLinearSchedule
from pl_model.xfg import LitXFGConcat

In [2]:
config = DictConfig({
    "patch_size": 32,
    "split": "overlap",
    "slide_step": 24,
    "hidden_size": 768,
    "dropout": 0.1,
    "max_len": 100,
    "classifier": "token",
    "transformer": {
        "mlp_dim": 3072,
        "num_heads": 12,
        "num_layers": 12,
        "num_layers_fusion": 12,
        "attention_dropout_rate": 0.0,
    },
    "num_classes": 200,
    "batch_size": 16,
    "num_workers": 8,
    "image_size": 448,
    "lr": 3e-2,
    "seed": 42,
    "momentum": 0.9,
    "epoch": 30,
    "gpus": [0,1],
    "logger": False,
    "pretrained_dir": "./pretrained/vit/imagenet21k_ViT-B_32.npz",
})

In [3]:
train_transform=transforms.Compose([
    transforms.Resize((600, 600), InterpolationMode.BILINEAR),
    transforms.RandomCrop((448, 448)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
train_set = CUB200(root="./data", train=True, caption=True, transform=train_transform)

In [4]:
# model = XFG(config)
# model.load_from(np.load(config.pretrained_dir))
# imgs, txts, targets = train_set[0]
# imgs = torch.Tensor(imgs).unsqueeze(0)
# txts = torch.Tensor(txts)
# model(imgs, txts).shape

In [5]:
if config.logger:
    from pytorch_lightning.loggers import WandbLogger
    logger = WandbLogger(
        project="xfg",
        name=f"vit"
    )
else:
    logger = pl.loggers.TestTubeLogger(
        "output", name=f"vit")
    logger.log_hyperparams(config)

pl.seed_everything(config.seed)
trainer = pl.Trainer(
    precision=16,
    deterministic=True,
    check_val_every_n_epoch=1,
    gpus=config.gpus,
    logger=logger,
    max_epochs=config.epoch,
    weights_summary="top",
    accelerator='ddp',
)

model = LitXFGConcat(config)
trainer.fit(model)
trainer.test()

Global seed set to 42


MisconfigurationException: Selected distributed backend ddp is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible backends: dp, ddp_spawn, ddp_sharded_spawn, tpu_spawn