In [24]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from model import TwoResAutoEncoder
from data import FaceDataset
from train import create_fg_masks
import matplotlib.pyplot as plt
from feather import stitch
import os

In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.load('./boiNet250,750.pt')
model.to(device)

TwoResAutoEncoder(
  (fg_ae): FGAutoencoder(
    (encoder): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU()
      (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (9): Flatten(start_dim=1, end_dim=-1)
      (10): Linear(in_features=7040, out_features=250, bias=True)
      (11): ReLU()
    )
    (decoder): Sequential(
      (0): Linear(in_features=250, out_features=7040, bias=True)
      (1): ReLU()
      (2): Unflatten(dim=1, unflattened_size=(64, 11, 10))
      (3): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 

In [26]:
dataset = FaceDataset('test_set_bb.csv', '../celeba/img_align_celeba', return_im_num=True)
# dataset = Subset(dataset, range(2200, 2400))

output_folder = 'compressed-250,750'
os.makedirs(output_folder, exist_ok=True)

In [27]:
loader = DataLoader(dataset, batch_size=32)

for img_num, images, faces, bboxs in loader:
    images, faces = images.to(device), faces.to(device)
    filenames = []
    
    fg_masks = create_fg_masks(bboxs).to(device)
    fg_output, bg_output = model(images * (~fg_masks), faces)
    
    rec = stitch(fg_output, bg_output, bboxs.to(device), feather_size=20, device=device)
    
    # Save images
    for i in range(img_num.shape[0]):
        torchvision.utils.save_image(rec[i], os.path.join(output_folder, f'{img_num[i]}.png'))