## BitViT - Example Training
---

In [None]:
# Execute train.py inside this notebook's own namespace (`%run -i`), so the
# names it defines or imports (presumably torch, nn, os, transforms, the
# EfficientNetV2 teacher constructors and DistillableBitViT) are available to
# the config cell below. Run this cell before that one.
%run -i train.py

In [None]:
# Training configuration: a DeiT-Small-sized BitViT trained on ImageNet-1k with
# distillation from an EfficientNetV2-S teacher. torch, nn, os, transforms,
# efficientnet_v2_s, EfficientNet_V2_S_Weights and DistillableBitViT must be in
# scope from the `%run -i train.py` cell.
num_classes = 1000
config = {
    "project": "distill-vit-pretrain",
    "experiment": "bit-vit",
    "epochs": 50,
    "batch_size": 256,
    "accum_steps": 4,  # gradient accumulation steps
    "warmup": 5,  # in epochs
    "distill": {
        "enabled": True,
        "args": {
            # torchvision EfficientNetV2-S with its IMAGENET1K_V1 pretrained weights
            "teacher": efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1),
            "hard": True,  # use hard or soft labels
            "alpha": 0.5,  # trade-off between main loss and distillation loss
            "temperature": 1.0,  # only used for soft distillation
        },
    },
    "architecture": {
        "model": DistillableBitViT,
        "args": {  # based on DeiT-Small
            "image_size": 224,
            "patch_size": 16,
            "num_classes": num_classes,
            "dim": 384,
            "depth": 12,
            "heads": 6,
            "mlp_dim": 1536,
            "spt": True,  # from vits for small datasets
            "sincos2d": True,  # from SimpleViT
        },
    },
    "optimizer": {
        "type": torch.optim.AdamW,
        "args": {  # based on SimpleViT
            "lr": 1e-3,
            "weight_decay": 0.05,
        },
    },
    "scheduler": {
        "type": torch.optim.lr_scheduler.CosineAnnealingLR,
        "args": {},  # T_max is derived from config["epochs"] right after this dict
    },
    "criterion": {  # disabled for distillation
        "type": nn.CrossEntropyLoss,
        "args": {},
    },
    "data": {
        # Fails fast with KeyError if TMPDIR is not set in the environment.
        "path": f"{os.environ['TMPDIR']}/datasets/imagenet-1k-rgb-256",
        "num_classes": num_classes,
        "preprocess": {
            "train": transforms.Compose(
                [
                    transforms.RandomCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandAugment(2, 10),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            ),
            "validation": transforms.Compose(
                [
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            ),
        },
        "cutmix_or_mixup": {
            "enabled": True,
            "mixup": 0.2,  # alpha
            "cutmix": 1.0,  # alpha
            # NOTE(review): "only mixup" holds only if train.py orders the
            # choices as [cutmix, mixup]; with [mixup, cutmix] this selects
            # CutMix only -- confirm against train.py.
            "p": [0.0, 1.0],  # only mixup
        },
        "loader": {
            # NOTE(review): training data is not shuffled here; confirm that
            # train.py shuffles elsewhere (e.g. via a sampler) before relying
            # on this setting.
            "shuffle": False,
            "num_workers": 16,
            "prefetch_factor": 4,
            "pin_memory": True,
            "drop_last": True,
            "persistent_workers": True,
        },
    },
    "wandb": {
        "enabled": True,
        "log_interval": 100,
    },
}

# Anneal over the whole run. The previous hard-coded T_max=90 (SimpleViT's
# epoch count) left the LR at ~41% of its peak after 50 epochs, assuming the
# scheduler is stepped once per epoch -- TODO confirm against train.py, and
# whether the warmup epochs should be subtracted (epochs - warmup).
config["scheduler"]["args"]["T_max"] = config["epochs"]