In [10]:
from custom_dataset import CustomDataset
import matplotlib.pyplot as plt
import argparse
import os
import time 
import numpy as np

import torch
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.utils import save_image
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms.functional as TF

from model_vae_8_dim import Autoencoder

In [11]:
# def normalize_images(images):
#     """Normalizes images to the range [-1, 1]."""
#     images = images.clone()
#     images -= images.min()
#     images /= images.max()
#     images *= 2.0
#     images -= 1.0
#     return images


In [15]:

def example_images(original_images, reconstructed_images, epoch, batch_idx, log_interval):
    n_samples = 5  # Number of samples to visualize

    original_images = original_images[:n_samples].cpu()
    reconstructed_images = reconstructed_images[:n_samples].cpu().detach() 


    fig, axes = plt.subplots(n_samples, 2, figsize=(10, 10))
    for i in range(n_samples):
        axes[i, 0].imshow(original_images[i].permute(1, 2, 0))  # Rearrange channels
        axes[i, 0].set_title('Original')
        axes[i, 0].axis('off')
    
        axes[i, 1].imshow(reconstructed_images[i].permute(1, 2, 0))  # Rearrange channels
        axes[i, 1].set_title('Reconstructed')
        axes[i, 1].axis('off')

    plt.tight_layout()
    os.makedirs('TRAINING_IMAGES_64_TEST')
    plt.savefig(f'TRAINING_IMAGES_64_TEST/reconstruction_examples_epoch_{epoch}_batch_{batch_idx * log_interval}.png')
    plt.close()



In [14]:
""" This script is an example of VAE training in PyTorch. The code was adapted from:
https://github.com/pytorch/examples/blob/master/vae/main.py """

## Arguments
args = argparse.Namespace(
    batch_size=64,
    input_size=64,
    epochs=15,
    no_cuda=False,
    log_interval=500,
    model='mse_vae',
    log_dir='mse_vae'
)


## Cuda
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda:3" if args.cuda else "cpu")

# ## Data Parallelism
# if args.cuda and torch.cuda.device_count() > 1:
#     print("Using", torch.cuda.device_count(), "GPUs for Data Parallelism!")
#     model = nn.DataParallel(Autoencoder(args.input_size)).to(device)
# # else:
model = Autoencoder(args.input_size).to(device)


kwargs = {'num_workers': 48, 'pin_memory': True} if args.cuda else {}

mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.Resize((args.input_size, args.input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

def random_mask(img_data, device):
    # Generate a binary mask with a random percentage of values set to 1
    mask_percentage = np.random.uniform(0.05, 0.2)
    mask = torch.rand(img_data.shape, device=device) < mask_percentage
    # Apply the mask to the random_image to replace values with -1000
    dummy_value = -1000
    masked_image = torch.where(mask, dummy_value, img_data)
    return masked_image
        
# ## ---------------------------------------------------
train_dataset = CustomDataset(data_path='/mayo_atlas/home/m296984/RESULTS_40x/Liver/test_64x64_patches', transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

## Logging
os.makedirs('vae_logs/{}'.format(args.log_dir), exist_ok=True)
summary_writer = SummaryWriter(log_dir='vae_logs/' + args.log_dir, purge_step=0)

## Build Model
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [16]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        original_shape = data.size()
        data = data.view(data.size(0), -1)
        data = data.to(device)
        masked_image = random_mask(data, device)
        optimizer.zero_grad()
        
        # Run VAE
        recon_data, mu, logvar = model(data)
        data = data.view(original_shape)
        recon_data = recon_data.view(original_shape)
        
        # Compute loss
        rec, kl = model.loss_function(recon_data, data, mu, logvar)
        
        total_loss = rec + kl * 0.1
        # total_loss = rec + kl
        total_loss.backward()
        train_loss += total_loss.item()
        
        optimizer.step()
        
        
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tMSE: {:.6f}\tKL: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                rec.item() / len(data),
                kl.item() / len(data)))
        # Print image reconstruction examples every few batches
        if batch_idx % (args.log_interval * 5) == 0:
            example_images(data, recon_data, epoch, batch_idx, args.log_interval)

                
    train_loss /=  len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss))
   

start_time = time.time()
for epoch in range(1, args.epochs + 1):
    train(epoch)
    
    
    os.makedirs(f'/mayo_atlas/home/m296984/MAIN_CHAIN_LIVER_RESULTS/vae_logs_masked_DELETE/{args.log_dir}', exist_ok=True)
    torch.save(model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(),
            '/mayo_atlas/home/m296984/MAIN_CHAIN_LIVER_RESULTS/vae_logs_masked_DELETE/{}/checkpoint_{}.pt'.format(args.log_dir, str(epoch)))
    
end_time = time.time()
execution_time = end_time - start_time
print(f"Training model time ({args.epochs}): {execution_time:.2f} seconds")




Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).




Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


====> Epoch: 1 Average loss: 0.0093


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


