In [1]:
from dataclasses import dataclass, field
from typing import List

@dataclass
class Config:
    """Data-pipeline settings shared by the cells of this notebook.

    Attributes:
        data_dir: Dataset location handed to ``MVTecDataset``. The default is a
            machine-specific absolute path; pass ``data_dir=...`` when running
            on another machine.
        categories: Category names to load; must not be empty.
        img_size: Image size handed to ``get_transforms``; must be positive.
        batch_size: Batch size of the training dataloader; must be positive.
        valid_ratio: Share of the training split held out for validation,
            in the half-open range ``[0, 1)``.
        seed: Seed handed to ``split_train_valid``.

    Raises:
        ValueError: If a field is outside the range documented above.
    """
    # Data configuration
    data_dir: str = '/mnt/d/datasets/mvtec'
    # default_factory gives every instance its own list instead of a shared one
    categories: List[str] = field(default_factory=lambda: ['bottle'])
    img_size: int = 256
    batch_size: int = 32
    valid_ratio: float = 0.2
    seed: int = 42

    def __post_init__(self):
        """Fail fast on values that would otherwise only break during data loading."""
        if not self.categories:
            raise ValueError("categories must contain at least one category name")
        if self.img_size <= 0:
            raise ValueError(f"img_size must be positive, got {self.img_size}")
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}")
        # A ratio of 1.0 would leave nothing to train on, so the upper bound is exclusive.
        if not 0.0 <= self.valid_ratio < 1.0:
            raise ValueError(f"valid_ratio must be in [0, 1), got {self.valid_ratio}")

# Instantiate with the defaults; pass keyword overrides (e.g. Config(data_dir=...)) to change a run.
config = Config()
# Bare expression: the notebook shows the dataclass repr as this cell's output.
config

Config(data_dir='/mnt/d/datasets/mvtec', categories=['bottle'], img_size=256, batch_size=32, valid_ratio=0.2, seed=42)

In [2]:
from data import MVTecDataset, get_transforms, get_dataloader, split_train_valid

# get_transforms returns a (train, test) pair built for the configured image size
train_transform, test_transform = get_transforms(img_size=config.img_size)


def _open_split(split, transform):
    """Build an MVTecDataset for `split` from the notebook-wide data settings."""
    return MVTecDataset(config.data_dir, config.categories, split, transform=transform)


# Batch size shared by the validation and test loaders
EVAL_BATCH_SIZE = 16

# The 'train' split is opened twice, once per transform, so the part held out
# for validation can use the test-time transform.
train_dataset = _open_split('train', train_transform)
valid_dataset = _open_split('train', test_transform)
test_dataset = _open_split('test', test_transform)

# Carve the validation part out of the training data (seeded via the config)
train_dataset, valid_dataset = split_train_valid(
    train_dataset,
    valid_dataset,
    valid_ratio=config.valid_ratio,
    seed=config.seed,
)

# One loader per split; only the training loader uses the configurable batch size
train_loader = get_dataloader(train_dataset, config.batch_size, 'train')
valid_loader = get_dataloader(valid_dataset, EVAL_BATCH_SIZE, 'valid')
test_loader = get_dataloader(test_dataset, EVAL_BATCH_SIZE, 'test')

# Report how many samples ended up in each split
print(f'Train size: {len(train_loader.dataset)}')
print(f'Valid size: {len(valid_loader.dataset)}')
print(f'Test  size: {len(test_loader.dataset)}')

Train size: 168
Valid size: 41
Test  size: 83


In [3]:
import torch
from modeler import Modeler, get_model, get_loss_fn, get_metric
from trainer import Trainer, get_optimizer, get_scheduler, get_logger
from stopper import get_stopper

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Experiment: 'vanilla_ae' model trained with the 'combined' loss, tracking SSIM and PSNR.

# Seed PyTorch's global RNG from the shared config so that anything drawing from it
# (e.g. weight initialisation) repeats on a fresh-kernel re-run.
torch.manual_seed(config.seed)

# Take the image size from the config instead of repeating the literal 256, so the
# model cannot drift from the img_size that was passed to get_transforms().
model = get_model('vanilla_ae', latent_dim=512, img_size=config.img_size)
loss_fn = get_loss_fn('combined')
metrics = {"ssim": get_metric('ssim'), "psnr": get_metric('psnr')}
modeler = Modeler(model, loss_fn, metrics, device)

optimizer = get_optimizer(model, "adamw", lr=0.001, weight_decay=1e-5)
scheduler = get_scheduler(optimizer, "plateau")
# NOTE(review): "./experments" looks like a misspelling of "./experiments"; left
# untouched because it is a runtime path.
logger = get_logger("./experments")
# In the recorded run the stopper ended training at epoch 20, so num_epochs=100
# below only acts as an upper bound.
stopper = get_stopper("stop", max_epoch=20)
trainer = Trainer(modeler, optimizer, scheduler, logger, stopper)

# NOTE(review): the valid_loader built earlier is not used - training runs without
# validation. Confirm the "plateau" scheduler still has a value to monitor in that case.
history = trainer.fit(train_loader, num_epochs=100, valid_loader=None)

Train [1/100]: 100%|██████████| 5/5 [00:03<00:00,  1.33it/s, loss=0.472, ssim=0.191, psnr=8.711]
Train [2/100]: 100%|██████████| 5/5 [00:03<00:00,  1.38it/s, loss=0.368, ssim=0.390, psnr=9.043]
Train [3/100]: 100%|██████████| 5/5 [00:03<00:00,  1.38it/s, loss=0.354, ssim=0.412, psnr=9.180]
Train [4/100]: 100%|██████████| 5/5 [00:03<00:00,  1.38it/s, loss=0.345, ssim=0.428, psnr=9.287]
Train [5/100]: 100%|██████████| 5/5 [00:03<00:00,  1.38it/s, loss=0.346, ssim=0.428, psnr=9.249]
Train [6/100]: 100%|██████████| 5/5 [00:03<00:00,  1.38it/s, loss=0.345, ssim=0.429, psnr=9.228]
Train [7/100]: 100%|██████████| 5/5 [00:03<00:00,  1.40it/s, loss=0.344, ssim=0.432, psnr=9.210]
Train [8/100]: 100%|██████████| 5/5 [00:03<00:00,  1.39it/s, loss=0.340, ssim=0.437, psnr=9.334]
Train [9/100]: 100%|██████████| 5/5 [00:03<00:00,  1.40it/s, loss=0.343, ssim=0.432, psnr=9.293]
Train [10/100]: 100%|██████████| 5/5 [00:03<00:00,  1.37it/s, loss=0.340, ssim=0.435, psnr=9.399]
Train [11/100]: 100%|████████

Training stopped by stopper at epoch 20
Training completed!





In [8]:
# Experiment: 'vae' model trained with the 'vae' loss, tracking SSIM and PSNR.

# Seed PyTorch's global RNG from the shared config so that anything drawing from it
# (e.g. weight initialisation) repeats regardless of which cells ran before this one.
torch.manual_seed(config.seed)

# Take the image size from the config instead of repeating the literal 256, so the
# model cannot drift from the img_size that was passed to get_transforms().
model = get_model('vae', latent_dim=512, img_size=config.img_size)
loss_fn = get_loss_fn('vae')
metrics = {"ssim": get_metric('ssim'), "psnr": get_metric('psnr')}
modeler = Modeler(model, loss_fn, metrics, device)

optimizer = get_optimizer(model, "adamw", lr=0.001, weight_decay=1e-5)
scheduler = get_scheduler(optimizer, "plateau")
# NOTE(review): "./experments" looks like a misspelling of "./experiments"; left
# untouched because it is a runtime path.
logger = get_logger("./experments")
# In the recorded run the stopper ended training at epoch 5, so num_epochs=100 below
# only acts as an upper bound.
# NOTE(review): the 'vanilla_ae' run earlier in this notebook uses max_epoch=20, so
# the final metrics of the two runs are not a like-for-like comparison.
stopper = get_stopper("stop", max_epoch=5)
trainer = Trainer(modeler, optimizer, scheduler, logger, stopper)

# NOTE(review): the valid_loader built earlier is not used - training runs without
# validation. Confirm the "plateau" scheduler still has a value to monitor in that case.
history = trainer.fit(train_loader, num_epochs=100, valid_loader=None)

Train [1/100]: 100%|██████████| 5/5 [00:03<00:00,  1.37it/s, loss=5.784, ssim=0.060, psnr=9.017] 
Train [2/100]: 100%|██████████| 5/5 [00:03<00:00,  1.40it/s, loss=2.053, ssim=0.160, psnr=9.674]
Train [3/100]: 100%|██████████| 5/5 [00:03<00:00,  1.38it/s, loss=0.944, ssim=0.242, psnr=9.444]
Train [4/100]: 100%|██████████| 5/5 [00:03<00:00,  1.39it/s, loss=0.607, ssim=0.261, psnr=9.514]
Train [5/100]: 100%|██████████| 5/5 [00:03<00:00,  1.40it/s, loss=0.383, ssim=0.292, psnr=9.911]

Training stopped by stopper at epoch 5
Training completed!



