In [5]:
import os
import os.path
import numpy as np
import logging
import argparse
import pycuda.driver as cuda
import matplotlib.pyplot as plt

import torch
import torchvision

from torch.utils.tensorboard import SummaryWriter

import global_v as glv
from network_parser import parse
from datasets import load_dataset_snn, load_dataset_snn2, load_dataset_snn3
from utils import aboutCudaDevices
from utils import AverageMeter
from utils import CountMulAddSNN, CountMulAddANN
import fsvae_models.fsvae as fsvae
import ann_models.ann_vae as ann_vae
from fsvae_models.snn_layers import LIFSpike
import metrics.inception_score as inception_score
import metrics.clean_fid as clean_fid
import metrics.autoencoder_fid as autoencoder_fid

from tqdm.notebook import trange
from ann_models import *


max_accuracy = 0
min_loss = 1000


In [9]:
  

def plot(network, testloader, index):
    n_steps = glv.network_config['n_steps']
    max_epoch = glv.network_config['epochs']

    network = network.eval()
    with torch.no_grad():
        for batch_idx, (real_img, targets) in enumerate(testloader):   
            if batch_idx == index:
                
                real_img = real_img.to(init_device, non_blocking=True)
                targets = targets.to(init_device, non_blocking=True)
                if glv.network_config['spiking']:
                    # direct spike input
                    spike_input = real_img.unsqueeze(-1).repeat(1, 1, 1, 1, n_steps) # (N,C,H,W,T)
                    x_recon, q_z, p_z, sampled_z = network(spike_input, scheduled=network_config['scheduled']) # sampled_z(B,C,1,1,T)
                else:
                    # direct input
                    x_recon, mu, logvar = network(real_img)

                real_img = np.transpose(real_img.cpu().numpy(), (0,2,3,1))
                # reshape x_recon to use matplotlib
                x_recon = np.transpose(x_recon.cpu().numpy(), (0, 2, 3, 1))
                
                real_img = (real_img+1)/2
                x_recon = (x_recon+1)/2
                
                N = 2*x_recon.shape[0]
                if N>4:
                    cols = 4
                    rows = int(np.ceil(N/4))
                else:
                    cols = N
                    rows = 1
                fig, axes = plt.subplots(rows, cols, figsize=(10,20))
                for i, ax in enumerate(axes.flat):
                    if i < N:
                        if i % 2 == 0:
                            ax.imshow(real_img[int(i//2)])
                            ax.axis('off')
                        else:
                            ax.imshow(x_recon[int(i//2)])
                            ax.axis('off')
                    else:
                        ax.axis('off')
                        
                plt.subplots_adjust(wspace=0.1, hspace=0.1)
                #plt.show()
        # Save the figure in high resolution
        plt.savefig('output_plot.png', dpi=600, bbox_inches='tight')
        plt.close(fig)
    return


In [10]:
class Args:
    def __init__(self, checkpoint=None, config="", device=None, name=""):
        self.checkpoint = checkpoint
        self.config = config
        self.device = device
        self.name = name

checkpoint_ann='checkpoint/ann_model_test/best.pth'
args = Args(checkpoint=checkpoint_ann, config='NetworkConfigs/ToyCelebA_ANN.yaml', device=None, name='ann_model_test')

if args.device is None:
    init_device = torch.device("cuda:0")
else:
    init_device = torch.device(f"cuda:{args.device}")
    
os.makedirs(f'checkpoint/{args.name}', exist_ok=True)
writer = SummaryWriter(log_dir=f'checkpoint/{args.name}/tb')
logging.basicConfig(filename=f'checkpoint/{args.name}/{args.name}.log', level=logging.INFO)
    
logging.info("start parsing settings")
    
params = parse(args.config)
network_config = params['Network']

if network_config.get('out_channels') is None:
    network_config['out_channels'] = network_config['in_channels']
    
logging.info("finish parsing settings")
logging.info(network_config)
print(network_config)
        
# Check whether a GPU is available
if torch.cuda.is_available():
    cuda.init()
    c_device = aboutCudaDevices()
    print(c_device.info())
    print("selected device: ", args.device)
else:
    raise Exception("only support gpu")
    
glv.init(network_config, [args.device])

dataset_name = glv.network_config['dataset']

logging.info("dataset loading...")

if dataset_name == "ToyCelebA":
    input_data_path = glv.network_config['input_data_path']
    output_data_path = glv.network_config['output_data_path']
    input_data_path = os.path.expanduser(input_data_path)
    output_data_path = os.path.expanduser(output_data_path)
    train_loader, test_loader = load_dataset_snn2.load_toyceleba(input_data_path, output_data_path)
else:
    data_path = glv.network_config['data_path']

    if dataset_name == "MNIST":
        data_path = os.path.expanduser(data_path)
        train_loader, test_loader = load_dataset_snn.load_mnist(data_path)
    elif dataset_name == "FashionMNIST":
        data_path = os.path.expanduser(data_path)
        train_loader, test_loader = load_dataset_snn.load_fashionmnist(data_path)

    elif dataset_name == "CIFAR10":
        data_path = os.path.expanduser(data_path)
        train_loader, test_loader = load_dataset_snn.load_cifar10(data_path)

    elif dataset_name == "CelebA":
        data_path = os.path.expanduser(data_path)
        train_loader, test_loader = load_dataset_snn.load_celebA(data_path)

    else:
        raise Exception('Unrecognized dataset name.')
logging.info("dataset loaded")

if network_config['model'] == 'FSVAE':
    net = fsvae.FSVAE()
elif network_config['model'] == 'FSVAE_large':
    net = fsvae.FSVAELarge()
elif network_config['model'] == 'FSVAE_small':
    net = fsvae.FSVAESmall()
elif network_config['model'] == 'FSAE_small':
    net = fsvae.FSAESmall()
elif network_config['model'] == 'VanillaVAE_large':
    net = ann_vae.VanillaVAELarge()
elif network_config['model'] == 'AE':
    net = ann_ae.AE()
elif network_config['model'] == 'AE_large':
    net = ann_ae.AELarge()
else:
    raise Exception('not defined model')

net = net.to(init_device)
    
if args.checkpoint is not None:
    checkpoint_path = args.checkpoint
    checkpoint = torch.load(checkpoint_path)
    net.load_state_dict(checkpoint)    


{'epochs': 50, 'batch_size': 40, 'n_steps': 8, 'dataset': 'ToyCelebA', 'in_channels': 3, 'out_channels': 3, 'input_data_path': 'C:/Users/Brhan/OneDrive/Belgeler/KuOnline/Spring24/Elec491/SNN/dataset/img_align_celeba_sketch', 'output_data_path': 'C:/Users/Brhan/OneDrive/Belgeler/KuOnline/Spring24/Elec491/SNN/dataset/img_align_celeba', 'lr': 0.001, 'latent_dim': 128, 'input_size': 32, 'model': 'VanillaVAE_large', 'k': 20, 'scheduled': False, 'loss_func': 'kld', 'spiking': False}
1 device(s) found:
    1) NVIDIA GeForce GTX 1650 (Id: 0)
          Memory: 4.29 GB

selected device:  None
loading ToyCelebA


In [None]:
# Pretrained model

# Set the model to evaluation mode
net.eval()

#sample(net, e, batch_size=2)
#calc_inception_score(net, e)
#calc_autoencoder_frechet_distance(net, e)
#calc_clean_fid(net, e)

# Plot the first batch from the test_loader
plot(net, test_loader, 0)