In [1]:
# Automatically reload edited project modules (data.*, networks.*, utils.*, ...) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy

import os
import sys

In [3]:
from data.datasets.random_dataset import RandomDataset
from data.datasets.golden_panels import GoldenPanelsDataset
from data.augment import get_PIL_image

from networks.panel_encoder.cnn_embedder import CNNEmbedder
from networks.sort_sequence_network import SortSequenceNetwork
from training.sort_sequence_trainer import SortSequenceTrainer
from utils.config_utils import read_config, Config
from utils.logging_utils import *
from utils.plot_utils import *
from utils import pytorch_util as ptu

from configs.base_config import *
from functional.losses.elbo import elbo

In [5]:
# Enable GPU mode in the project's pytorch utility module (presumably selects CUDA as ptu's device — confirm in utils.pytorch_util).
ptu.set_gpu_mode(True)
# Sort-sequence model settings; embed_dim, seq_size, parallel and image_dim are read in the cells below.
config = read_config(Config.SORT_SEQUENCE)
# Golden Age dataset settings; panel_dim, panel_path, sequence_path and train_test_ratio are read below.
golden_age_config = read_config(Config.GOLDEN_AGE)

In [6]:
# Build the sort-sequence network on an EfficientNet-B5 panel embedder and
# restore trained weights from a checkpoint for evaluation.
panel_dim = golden_age_config.panel_dim[0]
cnn_embedder = CNNEmbedder("efficientnet-b5", embed_dim=config.embed_dim)
net = SortSequenceNetwork(embedder=cnn_embedder,
                          num_elements_in_sequence=config.seq_size,
                          pairwise_extraction_in_size=(panel_dim ** 2) * 4).cuda()

# Wrap *before* loading the state dict. NOTE(review): this assumes the checkpoint
# was saved from a model wrapped the same way (DataParallel prefixes every key with
# "module."), so the wrapping here must mirror training — confirm.
if config.parallel:
    net = nn.DataParallel(net)

# Checkpoint location. Set the SORT_SEQUENCE_CKPT environment variable to point at a
# different checkpoint instead of editing this cluster-specific default path.
CKPT_PATH = os.environ.get(
    "SORT_SEQUENCE_CKPT",
    "/scratch/users/gsoykan20/projects/AF-GAN/playground/sort_sequence/ckpts/"
    "sort_sequence_10-06-2021-16-04-30-checkpoint-epoch1.pth",
)
# Load tensors onto the CPU first; load_state_dict then copies them into the model's
# existing CUDA parameters, so no second copy of the weights is allocated on the GPU.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
net.load_state_dict(checkpoint["model_state_dict"])
# The model is already on the GPU (see .cuda() above); only switch to inference mode.
net = net.eval()

Loaded pretrained weights for efficientnet-b5


In [7]:
# Golden Age panel dataset for evaluation: no augmentation, and each sample comes
# back together with its mask and the mask's coordinates.
# NOTE(review): train_mode=False presumably selects the held-out side of
# train_test_ratio — confirm in data.datasets.golden_panels.
dataset = GoldenPanelsDataset(
    golden_age_config.panel_path,
    golden_age_config.sequence_path,
    golden_age_config.panel_dim,
    config.image_dim,
    # --- split selection ---
    train_mode=False,
    train_test_ratio=golden_age_config.train_test_ratio,
    # --- sampling ---
    augment=False,
    shuffle=False,
    limit_size=-1,
    # --- masking ---
    mask_val=1,  # 1 masks with white, 0 masks with black
    mask_all=False,  # True would mask faces in every panel and return all faces
    return_mask=True,
    return_mask_coordinates=True,
)

# One sample per batch, in dataset order. The iterator is kept at notebook scope so
# each run of the next cell continues from where the previous run stopped.
data_loader = DataLoader(dataset, shuffle=False, batch_size=1)
dl_iter = iter(data_loader)

# Qualitative Check: Model Outputs on Test Examples

In [None]:
# Run the trained network on the next `limit` test examples and show its output.
# `dl_iter` comes from the dataset cell, so re-running this cell moves on to the
# following examples rather than repeating the same ones.
limit = 6  # number of examples to run per execution of this cell

# zip stops after `limit` items, or earlier if the loader is exhausted, instead of
# raising StopIteration the way a bare next(dl_iter) would.
for i, (x, _, gt) in zip(range(limit), dl_iter):
    with torch.no_grad():
        x = x.cuda()
        gt = gt.cuda()
        # Overwrite the last element along dim 1 of `x` with `gt`.
        # NOTE(review): intent carried over unchanged from the original cell — confirm.
        x[:, -1] = gt
        output = net(x=x)
    print("[INFO] Example:", i)
    print(output)
    # FIXME: the original cell then called
    #     generated, gt = create_global_pred_gt_images(x, y, y_recon, c)
    #     plot_panels_and_faces(x, y, y_recon.cpu(), generated.cpu())
    # but `y`, `y_recon` and `c` are never defined in this notebook, and the network
    # output was never bound to a name (a syntax error). Reinstate the plotting once
    # SortSequenceNetwork's output is unpacked into the variables those helpers expect.