In [None]:
import os
import sys

import pandas as pd
import numpy as np

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms, utils
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Данные и модель

In [None]:
from dataset import ArchDataset, GS_NEGATIVE_PATHS, GS_POSITIVE_PATHS, GS_NEGATIVE_COORDS, GS_POSITIVE_COORDS

In [None]:
from model import VQVAE

# Параметры модели и обучения

In [None]:
# Hyperparameters forwarded to the VQVAE constructor through `params`;
# both values are also encoded in the results directory name.
EMBED_DIM = 64  # passed on as `embedding_dim`
NUM_EMBEDS = 4  # passed on as `num_embeddings`

In [None]:
# Use the first CUDA GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [None]:
# Output directory for sample images and checkpoints; the name embeds the
# zero-padded (two-digit) embedding dimension and number of embeddings.
results_dir = (
    '../vq_vae_training_sasgis_results/'
    f'v2_256pix_gray_{EMBED_DIM:02d}embdim_{NUM_EMBEDS:02d}embeds'
)

In [None]:
# Create the output directory together with any missing parent directories.
# `exist_ok=True` keeps this cell idempotent: re-running the notebook no longer
# fails with FileExistsError when the directory is already there.
os.makedirs(results_dir, exist_ok=True)

In [None]:
# Single place for the data, model and training settings used by the cells below.
params = dict(
    batch_size=32,              # DataLoader batch size
    img_size=256,               # forwarded to VQVAE as `img_size`
    channels=1,                 # forwarded to VQVAE as `in_channels`
    embedding_dim=EMBED_DIM,    # forwarded to VQVAE as `embedding_dim`
    num_embeddings=NUM_EMBEDS,  # forwarded to VQVAE as `num_embeddings`
    beta=0.25,                  # forwarded to VQVAE as `beta`
    n_epochs=10,                # number of passes over the training loader
)

# Датасет

In [None]:
# Preprocessing pipeline: single-channel grayscale -> tensor -> normalisation with
# mean 0.5 / std 0.5. No cropping is applied (a CenterCrop((256, 256)) step was
# deliberately left disabled), so inputs are presumably already 256x256 -- TODO confirm.
transform_nocrop_256_grayscale = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
# Fix PyTorch's global RNG seed for reproducible runs; the trailing `;` hides the
# returned generator object from the cell output.
torch.random.manual_seed(seed=42);

In [None]:
# Training data: built from the "negative" paths/coordinates with anomalies=False.
# Batches are reshuffled on every pass over the loader.
train_dataset = ArchDataset(
    img_paths=GS_NEGATIVE_PATHS, 
    coords=GS_NEGATIVE_COORDS, 
    anomalies=False, 
    transform=transform_nocrop_256_grayscale
)
train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)

# Anomaly data: built from the "positive" paths/coordinates with anomalies=True.
# Kept in dataset order (shuffle=False); it is not passed to the training loop in this notebook.
anomalies_dataset = ArchDataset(
    img_paths=GS_POSITIVE_PATHS, 
    coords=GS_POSITIVE_COORDS, 
    anomalies=True, 
    transform=transform_nocrop_256_grayscale
)
anomalies_loader = DataLoader(anomalies_dataset, batch_size=params['batch_size'], shuffle=False)

In [None]:
# Report how many images each split holds and how many batches its loader yields.
for split_name, split_dataset, split_loader in (
    ('train', train_dataset, train_loader),
    ('anomalies', anomalies_dataset, anomalies_loader),
):
    print(f'{split_name}: {len(split_dataset)} images, {len(split_loader)} batches')

In [None]:
# Preview the first few samples of each split:
# anomalies on the top row, training images on the bottom row.
ncols = 6

fig, axs = plt.subplots(nrows=2, ncols=ncols, figsize=(16, 8))
for row_axes, preview_dataset, row_label in (
    (axs[0], anomalies_dataset, 'anomaly'),
    (axs[1], train_dataset, 'train'),
):
    for i, ax in enumerate(row_axes):
        ax.axis('off')
        # The first element of a dataset item is the image tensor; permute moves its
        # leading (channel) axis last, which is the layout matplotlib expects.
        ax.imshow(preview_dataset[i][0].permute(1, 2, 0), cmap='gray')
        # Label every panel so the two rows can be told apart at a glance.
        ax.set_title(f'{row_label} #{i}')
fig.tight_layout()
# plt.show() renders with the notebook's inline backend; Figure.show() only emits a
# "non-GUI backend" warning there.
plt.show()

# Создание модели

In [None]:
# Instantiate the VQ-VAE from the shared `params` dict so that the architecture
# settings stay in sync with the values encoded in the results directory name.
model = VQVAE(
    in_channels=params['channels'],
    img_size=params['img_size'],
    
    embedding_dim=params['embedding_dim'],
    num_embeddings=params['num_embeddings'],
    beta=params['beta']
)

In [None]:
# model

In [None]:
# Count the scalar entries across every parameter tensor registered on the model.
total_params = 0
for weight in model.parameters():
    total_params += weight.numel()
total_params

# Цикл обучения

In [None]:
def train_vqvae(pbar, model, dataloader, optimizer, scheduler, n_epochs, device, logged_indices,
                save_dir=None, log_every=1000, n_log_images=8):
    """Train a VQ-VAE, saving sample reconstructions and one checkpoint per epoch.

    Args:
        pbar: progress bar exposing ``set_description`` and ``update`` (e.g. a tqdm bar).
        model: module whose forward pass returns ``(reconstructions, original_input,
            vq_loss)`` and which provides ``loss_function(reconstructions,
            original_input, vq_loss)`` returning a dict with a ``'loss'`` entry.
        dataloader: sized iterable yielding 4-tuples ``(data, dataset_indices, _, _)``;
            the last two items are ignored.
        optimizer: optimizer updating the model parameters.
        scheduler: scheduler stepped once per epoch with that epoch's mean loss
            (e.g. ``ReduceLROnPlateau``).
        n_epochs: number of passes over ``dataloader``.
        device: device the model and every batch are moved to.
        logged_indices: dict updated in place; maps ``(epoch, batch_idx)`` to the dataset
            indices of the items written to the matching sample image.
        save_dir: existing directory for sample images and checkpoints. ``None`` (default)
            falls back to the module-level ``results_dir``.
        log_every: write a sample grid every ``log_every`` batches (batch 0 included);
            ``0`` or ``None`` disables sample logging.
        n_log_images: number of leading batch items shown in each sample grid.

    Returns:
        List with the mean training loss of every epoch, in order.

    Raises:
        ValueError: if ``dataloader`` contains no batches.
    """
    if save_dir is None:
        # Backward-compatible default: the notebook-level output directory.
        save_dir = results_dir

    n_batches = len(dataloader)
    if n_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError at the end of the epoch.
        raise ValueError('dataloader yields no batches; nothing to train on')

    model = model.to(device)
    model.train()

    epoch_losses = []
    for epoch in range(n_epochs):

        running_loss = 0.0
        # The scheduler only steps at the end of an epoch, so the LR is constant within it.
        current_lr = optimizer.param_groups[0]['lr']

        for batch_idx, (data, data_dataset_indices, _, _) in enumerate(dataloader):

            data = data.to(device)
            optimizer.zero_grad()
            reconstructions, original_input, vq_loss = model(data)

            loss_dict = model.loss_function(reconstructions, original_input, vq_loss)
            total_loss = loss_dict['loss']
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            # Running mean of the loss over the batches seen so far in this epoch.
            curr_loss = running_loss / (1 + batch_idx)

            pbar.set_description(
                f'Epoch [{epoch + 1}/{n_epochs}] Batch [{batch_idx + 1}/{n_batches}] Loss [{curr_loss:.4f}] [LR {current_lr:.6f}]'
            )
            pbar.update(1)

            if log_every and batch_idx % log_every == 0:

                # Remember which dataset items appear in the saved sample image.
                logged_indices[(epoch, batch_idx)] = data_dataset_indices[:n_log_images].detach().cpu().numpy()

                with torch.no_grad():
                    shown_inputs = original_input[:n_log_images]
                    log_imgs = torch.cat([shown_inputs, reconstructions[:n_log_images]])
                    utils.save_image(
                        log_imgs.cpu(),
                        f'{save_dir}/sample_{epoch:02d}_{batch_idx:04d}.png',
                        normalize=True,
                        # One row of inputs above one row of reconstructions, even when
                        # the batch holds fewer than n_log_images items.
                        nrow=shown_inputs.shape[0]
                    )

        avg_loss = running_loss / n_batches
        epoch_losses.append(avg_loss)
        scheduler.step(avg_loss)

        # Only the model weights are checkpointed; the optimizer state is not saved.
        torch.save(model.state_dict(), f'{save_dir}/checkpoint_{epoch}.pth')

    return epoch_losses


In [None]:
# Adam with an initial learning rate of 1e-4.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
)
# Halve the learning rate (factor=0.5, patience=2) when the metric passed to
# `scheduler.step` stops decreasing, never going below 1e-6.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
)

In [None]:
# Filled in by train_vqvae: (epoch, batch_idx) -> dataset indices of the saved sample images.
logged_indices = {}

# The progress bar advances once per batch, across all epochs.
total_steps = params['n_epochs'] * len(train_loader)
with tqdm(total=total_steps, desc='[Epoch ?] [Batch ?] [Loss ?] [LR ?]', leave=False) as pbar:
    train_vqvae(
        pbar,
        model,
        train_loader,
        optimizer,
        scheduler,
        params['n_epochs'],
        device,
        logged_indices,
    )

In [None]:
# List which dataset items were recorded at each logged training step.
for step_key, sample_indices in logged_indices.items():
    logged_epoch, logged_batch = step_key
    print(f'epoch {logged_epoch} batch {logged_batch}: indices {sample_indices}')

In [None]:
# Hand-picked dataset index to inspect; presumably taken from the logged indices
# printed above -- TODO confirm. Must be a valid index into train_dataset.
idx = 84195
# Display the raw dataset item (the image tensor is its first element).
train_dataset[idx]

In [None]:
# Show the first channel of the selected sample's image tensor in grayscale.
sample_image = train_dataset[idx][0]
plt.imshow(sample_image[0], cmap='gray');

In [None]:
import glob, shutil

In [None]:
# Merge results stored in a local directory into the shared results directory.
# The destination is the same path as `results_dir`, which already exists by the time
# this cell runs (created earlier in the notebook and filled during training), so
# `dirs_exist_ok=True` (Python 3.8+) is needed to avoid FileExistsError.
shutil.copytree(
    'v2_256pix_gray_64embdim_04embeds',
    '../vq_vae_training_sasgis_results/v2_256pix_gray_64embdim_04embeds',
    dirs_exist_ok=True,
)