# Denoising Autoencoder Training

In [None]:
import os
import sys

# Put the thesis project root at the front of sys.path so the `src.*` imports
# below resolve. The default is the original author's checkout; set the
# THESIS_ROOT environment variable to run the notebook from another location.
PROJECT_ROOT = os.environ.get('THESIS_ROOT', 'F:/Thesis')
if sys.path[:1] != [PROJECT_ROOT]:  # don't stack duplicate entries when the cell is re-run
    sys.path.insert(0, PROJECT_ROOT)

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

from src.config import DEVICE, MODELS_DIR, FIGURES_DIR, ensure_dirs
from src.data import MVTecDataset
from src.data.transforms import denormalize
from src.models import create_denoising_ae
from src.training import get_optimizer, get_scheduler, EarlyStopping

# NOTE(review): plt, np, FIGURES_DIR, denormalize, get_scheduler and
# EarlyStopping are not used in the cells shown here.
# Project helper from src.config; presumably creates the output directories
# (e.g. MODELS_DIR used for the checkpoint below) — confirm in src.config.
ensure_dirs()

In [None]:
# Experiment settings: MVTec category, loader batch size, optimisation
# hyper-parameters, and the noise_factor handed to the denoising-AE factory.
CONFIG = {
    'category': 'bottle',
    'batch_size': 16,
    'num_epochs': 100,
    'learning_rate': 1e-3,
    'noise_factor': 0.3,
}
experiment_name = 'denoising_ae_{}'.format(CONFIG['category'])

# Training data: shuffled every epoch.
train_dataset = MVTecDataset(category=CONFIG['category'], split='train')
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
)

# Test data (with masks requested): read in a fixed order.
test_dataset = MVTecDataset(category=CONFIG['category'], split='test', return_mask=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
)

# Model on the configured device, its optimiser, and the pixel-wise MSE loss.
model = create_denoising_ae(noise_factor=CONFIG['noise_factor']).to(DEVICE)
optimizer = get_optimizer(model, lr=CONFIG['learning_rate'])
criterion = nn.MSELoss()

In [None]:
# Train the autoencoder to reconstruct the training images.
# NOTE(review): clean images are fed in and used as the target; this is only a
# *denoising* objective if the model corrupts its input internally in train
# mode (create_denoising_ae receives noise_factor) — confirm in src.models.
history = []  # per-epoch training loss, averaged over samples
for epoch in tqdm(range(1, CONFIG['num_epochs'] + 1)):
    model.train()
    epoch_loss = 0.0
    n_samples = 0
    for batch in train_loader:
        clean = batch[0].to(DEVICE)
        optimizer.zero_grad()
        recon, _ = model(clean)
        loss = criterion(recon, clean)
        loss.backward()
        optimizer.step()
        # MSELoss already averages within the batch; weight by batch size so a
        # short final batch does not skew the epoch mean.
        epoch_loss += loss.item() * clean.size(0)
        n_samples += clean.size(0)
    history.append(epoch_loss / n_samples)

if history:  # num_epochs may be 0
    print(f'Final loss: {history[-1]:.6f}')

In [None]:
# Score every test image by its reconstruction error, report image-level
# ROC-AUC, and write a checkpoint with the weights, config and AUC.
model.eval()
scores, labels = [], []
with torch.no_grad():
    for img, mask, label in test_loader:  # mask is unused for image-level AUC
        # NOTE(review): assumes reduction='mean' yields one error per image,
        # i.e. shape (batch,); the length check below catches other shapes.
        error = model.get_reconstruction_error(img.to(DEVICE), reduction='mean')
        scores.extend(error.reshape(-1).cpu().tolist())
        labels.extend(label.reshape(-1).tolist())

if len(scores) != len(labels):
    raise ValueError(
        f'Got {len(scores)} anomaly scores for {len(labels)} labels; '
        'get_reconstruction_error must return one score per image.'
    )

# Depending on the scikit-learn version, roc_auc_score may return a NumPy scalar.
# Store a plain float so the checkpoint loads under torch.load(weights_only=True),
# which is the default from PyTorch 2.6 onward.
auc = float(roc_auc_score(labels, scores))
print(f'ROC-AUC: {auc:.4f}')

torch.save({'model_state_dict': model.state_dict(), 'config': CONFIG, 'auc': auc}, MODELS_DIR / f'{experiment_name}.pth')