In [1]:
import os

import pytorch_lightning as pl
import wandb
from pytorch_lightning.callbacks import TQDMProgressBar

# Import custom modules
from data.cifar100 import CIFAR100DataModule
from vision_transformer.models.pl_model import ViTModel

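## 1. Set Up Hyperparameter Sweep

Define the W&B sweep configuration: search method, target metric and search space. Note: the single training run further below uses the fixed `model_kwargs`, not values sampled from this sweep.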

In [None]:
# Search strategy for the W&B sweep: draw hyperparameter combinations at random.
sweep_config = dict(method='random')

In [None]:
# Objective the sweep optimizes: drive the logged validation loss down.
metric = dict(name='val_loss', goal='minimize')

sweep_config['metric'] = metric

In [None]:
# Discrete search space: the sweep picks each hyperparameter from its list.
# NOTE(review): sweeping patch_size also changes the number of patches the model
# sees, so whatever consumes these values must recompute num_patches to match.
parameters_dict = {
    hparam: {'values': choices}
    for hparam, choices in {
        'batch_size': [64, 128, 256],
        'num_encoders': [12, 24, 36],
        'patch_size': [4, 8, 16],
        'learning_rate': [0.01, 0.003, 0.001, 0.0003, 0.0001],
    }.items()
}

sweep_config['parameters'] = parameters_dict

In [2]:
# Fixed hyperparameters for the single training run below.
# The patch count must agree with patch_size, so it is derived rather than
# hard-coded. This assumes the data module yields 32x32 images (CIFAR-100's
# native size), consistent with the original 64 patches at patch_size=4.
IMAGE_SIZE = 32
PATCH_SIZE = 4
assert IMAGE_SIZE % PATCH_SIZE == 0, "patch_size must evenly divide the image size"

model_kwargs = {
    "embed_size": 256,
    "hidden_size": 512,
    "hidden_class_size": 512,
    "num_encoders": 24,
    "num_heads": 8,
    "patch_size": PATCH_SIZE,
    "num_patches": (IMAGE_SIZE // PATCH_SIZE) ** 2,  # 8x8 grid -> 64 for patch_size=4
    "dropout": 0.1,
    "batch_size": 256,
    "learning_rate": 0.001,
}


In [3]:
# Dataset location. Set the CIFAR100_DIR environment variable to run on another
# machine; the default is the original author's local path.
CIFAR = os.environ.get("CIFAR100_DIR", "/media/curttigges/project-files/datasets/cifar-100/")
# Passed to the data module; presumably forwarded to its DataLoaders - tune to the
# available CPU cores.
NUM_WORKERS = 12

cifar100 = CIFAR100DataModule(
    batch_size=model_kwargs["batch_size"],
    num_workers=NUM_WORKERS,
    data_dir=CIFAR,
)

In [4]:
# Start the W&B run. The logged config is built from model_kwargs itself instead
# of re-typing each key, so it always matches what the model receives. It is
# extended with two descriptive labels that are not model arguments.
# NOTE(review): "scheduler"/"loss" are only labels here - confirm they match what
# ViTModel actually configures.
wandb.init(
    project="vit-classifier",
    entity="ascendant",
    config={
        **model_kwargs,
        "scheduler": "OneCycleLR",
        "loss": "CrossEntropy",
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33mascendant[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Train the ViT on CIFAR-100, then close the W&B run.
SEED = 42
MAX_EPOCHS = 60

pl.seed_everything(SEED)
model = ViTModel(**model_kwargs)
# NOTE(review): no WandbLogger is passed, so Lightning falls back to its default
# logger (lightning_logs/). The W&B charts presumably come from wandb.log calls
# inside ViTModel - confirm.
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator='gpu',
    devices=1,
    callbacks=[TQDMProgressBar(refresh_rate=10)],
)
try:
    trainer.fit(model, datamodule=cifar100)
finally:
    # Finalize the W&B run even if fit raises, instead of leaving it open.
    wandb.finish()

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


Missing logger folder: /home/curttigges/projects/vit/lightning_logs
Global seed set to 42
Global seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type          | Params
----------------------------------------
0 | model | ViTClassifier | 12.9 M
----------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.451    Total estimated model params size (MB)


Epoch 10:   5%|▌         | 10/196 [05:47<1:47:49, 34.78s/it, loss=3.26, v_num=0, val_loss=3.040, val_acc=0.253]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▇▇▆▆▆▅▆▆▅▅▅▅▄▅▅▄▅▄▅▄▄▄▄▃▄▄▄▃▂▃▃▂▁▂▃▂▁▂▁
val_acc,▁▃▃▂▂▃▃▂▃▄▄▃▃▄▄▃▅▄▄▅▆▅▄▅▆▆▅▆▆▆▆▇▆▇▇█▇█▇▇
val_loss,█▆▆▆▆▅▅▅▆▄▄▄▅▄▄▄▃▄▄▄▃▃▃▃▂▂▂▂▂▂▂▁▂▂▂▁▂▁▁▂

0,1
loss,3.3175
val_acc,0.23529
val_loss,3.3047
