In [None]:
from dataclasses import asdict

import pytorch_lightning as pl
from lightning.pytorch import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor
from timm.data import create_dataset, create_loader

from model import DeMansia
from model_config import DeMansia_tiny_config
from modules.data import create_token_label_dataset, create_token_label_loader
from modules.ema import EMA, EMAModelCheckpoint

In [None]:
# Dataset locations, relative to the notebook's working directory:
#   imagenet_root    - ImageNet-1k root; "/train" and "/val" are appended below.
#   token_label_root - directory handed to create_token_label_dataset alongside
#                      the train split.
imagenet_root, token_label_root = (
    "datasets/ImageNet 1k",
    "datasets/ImageNet 1k token label",
)

In [None]:
# Instantiate the tiny-variant config dataclass, flatten it to a plain dict
# (the DataModule below also reads config["img_size"] from it), and build the
# model by expanding that dict into keyword arguments.
_tiny_config = DeMansia_tiny_config()
config = asdict(_tiny_config)
model = DeMansia(**config)

In [None]:
class dataset(pl.LightningDataModule):
    """ImageNet-1k DataModule: token-labelled training split plus a plain val split.

    Reads the module-level ``imagenet_root``, ``token_label_root`` and
    ``config["img_size"]`` defined in earlier cells.

    Args:
        batch_size: Batch size used by both the training and validation loaders.
        num_workers: DataLoader worker processes per loader (default 16, the
            value that was previously hard-coded).
    """

    def __init__(self, batch_size: int, num_workers: int = 16):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: str):
        """Build the datasets needed for ``stage``.

        The training set (images + token labels) is only needed when fitting, so
        it is skipped for validate/test/predict runs instead of re-scanning the
        whole training tree. ``None`` is accepted for older Lightning versions
        that call ``setup`` without a stage. The validation set is always built.
        """
        if stage in ("fit", None):
            self.train_set = create_token_label_dataset(
                imagenet_root + "/train", token_label_root
            )
        # timm's create_dataset with name="" reads the folder at `root`.
        self.valid_set = create_dataset(
            name="", root=imagenet_root + "/val", batch_size=self.batch_size
        )

    def train_dataloader(self):
        """Return the training loader over the token-labelled train set."""
        return create_token_label_loader(
            self.train_set,
            input_size=config["img_size"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # NOTE: the prefetcher misbehaved in this setup (cause not tracked
            # down); with a normal SSD the plain loader is fast enough that
            # loading speed is not an issue.
            use_prefetcher=False,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the evaluation-mode validation loader."""
        return create_loader(
            self.valid_set,
            input_size=config["img_size"],
            batch_size=self.batch_size,
            is_training=False,
            num_workers=self.num_workers,
            use_prefetcher=False,
            pin_memory=True,
        )


data = dataset(batch_size=768)

In [None]:
# Take WandbLogger from the same package as `pl.Trainer` (`pytorch_lightning`).
# The file-level `pl_loggers` comes from the unified `lightning.pytorch`
# namespace, which is a separate class hierarchy; Lightning does not support
# mixing the two packages in one training run.
from pytorch_lightning.loggers import WandbLogger

# To resume an interrupted run, set this to a checkpoint path (e.g. one written
# under models/). None (Trainer.fit's default) starts training from scratch.
ckpt_path = None

trainer = pl.Trainer(
    callbacks=[
        # EMA callback from modules.ema; presumably keeps an exponential moving
        # average of the weights with this decay - confirm in modules/ema.py.
        EMA(decay=0.9999),
        # save_top_k=-1 keeps every saved checkpoint (ModelCheckpoint
        # semantics, assuming EMAModelCheckpoint follows them - verify).
        EMAModelCheckpoint(
            dirpath="models/",
            save_top_k=-1,
        ),
        LearningRateMonitor(logging_interval="step"),
    ],
    logger=WandbLogger(project="DeMansia Tiny", name="Pretrain"),
    precision="bf16-mixed",
    max_epochs=310,
)

trainer.fit(model, data, ckpt_path=ckpt_path)