In [None]:
# Cell 1: imports, image preprocessing pipelines, and the train/val datasets.
import os
import sys

import torch
import torchvision
import torchvision.transforms.v2 as transforms
from torch.backends import cudnn
from torchvision.datasets import ImageFolder
from transformers.utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Make the project-local `ic` package (imported in later cells) findable.
# Assumes the notebook is run from two directory levels below the repo root.
# The membership check keeps re-running this cell from piling up duplicate entries.
if '../..' not in sys.path:
    sys.path.append('../..')

# Dataset root holding `train/` and `val/` ImageFolder trees.
# Set the DATASET_DIR environment variable instead of editing the notebook.
DATA_DIR = os.environ.get("DATASET_DIR", "/path/to/dataset")
IMAGE_SIZE = 224   # spatial size of the crops fed to the model
RESIZE_SIZE = 256  # resize applied before the evaluation center crop

# The beta v2 API (torchvision 0.15) called this transform `ToImageTensor`;
# later releases renamed it to `ToImage`. Resolve whichever one is available.
to_image_cls = getattr(transforms, "ToImage", None) or transforms.ToImageTensor

# NOTE(review): ImageNet default statistics are used for every backbone. If the
# ViT backbone (cell 3) is enabled, confirm which mean/std its checkpoint expects.
normalize = transforms.Normalize(
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
)


def to_normalized_tensor_steps():
    """Return the shared tail of both pipelines.

    The steps convert the input to an image tensor, convert it to a float
    dtype (``ConvertImageDtype`` default), then apply ``normalize``.
    """
    return [to_image_cls(), transforms.ConvertImageDtype(), normalize]


# Training: random scale/aspect crop plus horizontal flip for augmentation.
train_augs = transforms.Compose(
    [
        transforms.RandomResizedCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        *to_normalized_tensor_steps(),
    ]
)

# Evaluation: deterministic resize followed by a center crop.
test_augs = transforms.Compose(
    [
        transforms.Resize(RESIZE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        *to_normalized_tensor_steps(),
    ]
)

train_set = ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_augs)
test_set = ImageFolder(os.path.join(DATA_DIR, "val"), transform=test_augs)
train_set.class_to_idx

In [None]:
# Cell 2: training configuration and data loaders.
from ic.train import utils


# Hyper-parameters consumed by the `ic.train.utils` helpers.
# "weights" names a pretrained checkpoint; a falsy value keeps the base learning rate.
config = dict(
    n_epochs=100,
    data_loader=dict(batch_size=200, num_workers=16, pin_memory=True),
    optimizer="Adam",
    optim_hparas=dict(lr=1e-2, weight_decay=1e-4),
    weights="google/vit-base-patch16-224-in21k",
)

# When pretrained weights are configured, replace the base learning rate with 1e-4.
if config["weights"]:
    config["optim_hparas"]["lr"] = 1e-4

train_loader, test_loader = utils.create_data_loaders(config, train_set, test_set)
len(train_set), len(test_set)

In [None]:
# Cell 3: feature extractor, deep-kernel GP model, likelihood, and training objective.
import gpytorch
from ic.models import DKLModel, ResNetFeatureExtractor


# Derive the class count from the data instead of hard-coding it.
num_classes = len(train_set.classes)

# train_set and test_set are separate ImageFolder instances, each with its own
# class_to_idx mapping. If the two trees differ, labels would silently disagree.
assert test_set.class_to_idx == train_set.class_to_idx, "train/val class folders differ"

# NOTE(review): args presumably mean ResNet depth 101 and "use pretrained weights"; confirm.
# NOTE(review): config["weights"] names a ViT checkpoint and drives the lr override in
# cell 2, even though the ResNet backbone is the one in use here.
fe = ResNetFeatureExtractor(101, True)
# Alternative ViT backbone (also add HuggingFaceViTFeatureExtractor to the import above):
# fe = HuggingFaceViTFeatureExtractor(config["weights"])
model = DKLModel(fe, utils.get_feature_dim(fe))
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(
    num_features=model.num_dim, num_classes=num_classes
)
# Variational ELBO over the model's GP layer; num_data is the training-set size.
config["criterion"] = gpytorch.mlls.VariationalELBO(
    likelihood, model.gp_layer, num_data=len(train_loader.dataset)
)

In [None]:
# Cell 4: launch training.
# Let cuDNN benchmark and cache convolution algorithms; all inputs are fixed-size crops.
cudnn.benchmark = True

# The __main__ guard is presumably here so that worker processes re-importing this
# code (e.g. under a "spawn" start method) do not launch training recursively.
if __name__ == "__main__":
    utils.train_distributed_dkl(
        "tcp://localhost:23456",  # presumably the process-group init URL (single host)
        config,
        model,
        likelihood,
        train_set,
        test_set,
        "cuda",
    )