In [None]:

import itertools
from PIL import Image
from torch_snippets import *
from torchvision.utils import make_grid


In [None]:
from glob import glob

# glob() yields paths in arbitrary directory order; sort them so dataset order
# (and therefore every preview grid and the per-batch loss log) is reproducible
# across runs and filesystems.
val_set = sorted(glob('maps/val/*.jpg'))

In [None]:
# Number of validation image paths collected in val_set.
print(f'Val set size: {len(val_set)}')

In [None]:
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Joint geometric transform: "image0" is registered as an additional image
# target, so the target half receives exactly the same Resize as the input half.
both_transform = A.Compose(
    [A.Resize(height=256, width=256)],
    additional_targets={"image0": "image"},
)


def _normalize_to_tensor():
    """Build a pipeline mapping uint8 pixels [0, 255] to a float CHW tensor in [-1, 1].

    albumentations Normalize computes (x - mean * max_pixel_value) / (std * max_pixel_value),
    so mean = std = 0.5 with max_pixel_value = 255 gives x / 127.5 - 1.
    """
    return A.Compose(
        [
            A.Normalize(mean=[0.5] * 3, std=[0.5] * 3, max_pixel_value=255.0),
            ToTensorV2(),
        ]
    )


# Separate (identically configured) pipelines for the input and target halves.
transform_only_input = _normalize_to_tensor()
transform_only_mask = _normalize_to_tensor()
class MapDataset(Dataset):
    """Paired-image dataset where each file stores (input | target) side by side.

    Every image is split column-wise at ``split_col``: the left part is the
    input and the right part is the target. Both halves are resized together
    with ``both_transform`` and then normalised to tensors with
    ``transform_only_input`` / ``transform_only_mask``.

    Args:
        images: sequence of image file paths.
        split_col: column where the input half ends and the target half
            begins. The default 600 matches the previously hard-coded split
            (presumably images are 1200 px wide — confirm for other data).
    """

    def __init__(self, images, split_col=600):
        self.list_files = images
        self.split_col = split_col

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        """Return the (input_tensor, target_tensor) pair for file ``index``."""
        img_path = self.list_files[index]
        # The context manager closes the underlying file handle. convert("RGB")
        # guarantees an (H, W, 3) array even for grayscale/CMYK/palette JPEGs,
        # which the three-axis slicing below requires.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        input_image = image[:, :self.split_col, :]
        target_image = image[:, self.split_col:, :]

        # Resize both halves with the same joint transform.
        augmentations = both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = transform_only_input(image=input_image)["image"]
        target_image = transform_only_mask(image=target_image)["image"]

        return input_image, target_image

In [None]:
# Deterministic loader yielding one validation pair per batch.
val_dataset = MapDataset(val_set)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=1)

In [None]:
# Preview the first 16 (input, target) validation pairs as a single grid.
# The loader is iterated once: calling next(iter(val_loader)) inside a loop
# builds a fresh iterator each time and, with shuffle=False, would return the
# first sample 16 times.
preview_pairs = [
    torch.cat([samples_input, samples_target], dim=0)
    for samples_input, samples_target in itertools.islice(val_loader, 16)
]
images = torch.cat(preview_pairs, dim=0)
# Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 255], then CHW -> HWC.
images = make_grid((images * 127.5) + 127.5).permute(1, 2, 0)
show(images)

In [None]:
import os

# Restrict this process to GPU 0 (only effective if CUDA is not initialised yet).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
# Generator from the 'BCE+L1+StyleLoss' run (loss configuration inferred from
# the directory name). map_location remaps tensors saved on a GPU onto `device`,
# so this also works on a CPU-only machine.
# NOTE(review): the file appears to be a whole pickled module — only load
# trusted checkpoints. On PyTorch >= 2.6 (weights_only=True by default) this
# call may need weights_only=False.
generator = torch.load('BCE+L1+StyleLoss/generator.pt', map_location=device).to(device)

In [None]:
# Discriminators from four training runs (directory names presumably encode the
# loss configuration). map_location loads GPU-saved checkpoints directly onto
# `device`, so a CPU-only machine can still run this cell.
# NOTE(review): pickled modules — load only trusted files; PyTorch >= 2.6 may
# require weights_only=False. Variable names are kept as-is for compatibility.
discrimintor_bce = torch.load('BCE/discriminator.pt', map_location=device).to(device)
discrimintor_bce_l1 = torch.load('BCE + L1/discriminator.pt', map_location=device).to(device)
discrimintor_bce_l1_st = torch.load('BCE+L1+StyleLoss/discriminator.pt', map_location=device).to(device)
discrimintor_l2_l1 = torch.load('L2+L1/discriminator.pt', map_location=device).to(device)
discriminators = [discrimintor_bce, discrimintor_bce_l1, discrimintor_bce_l1_st, discrimintor_l2_l1]

In [None]:
# Sigmoid + binary cross-entropy on raw logits, averaged over all elements.
BCE = nn.BCEWithLogitsLoss(reduction='mean')

In [None]:
def discriminator_step(real_src, fake_trg, discriminator):
    """Score generated images with one discriminator and return its "fake" loss.

    The discriminator is called on (real_src, fake_trg) and its output logits
    are compared, via the module-level ``BCE`` criterion, with an all-zeros
    target (the label used here for generated images).

    Args:
        real_src: source image batch that was fed to the generator.
        fake_trg: generator output for ``real_src``; detached before scoring.
        discriminator: conditional discriminator called as ``discriminator(src, trg)``.

    Returns:
        The loss tensor produced by ``BCE``.
    """
    logits_fake = discriminator(real_src, fake_trg.detach())
    labels_fake = torch.zeros_like(logits_fake)
    return BCE(logits_fake, labels_fake)

In [None]:
# Evaluate the generator against every loaded discriminator on the validation
# set: per batch, log the discriminators' mean "fake" BCE loss and display the
# (source, generated) pairs in grids of GRID_SIZE.
epochs = 1
log = Report(epochs)

N = len(val_loader)
GRID_SIZE = 16  # number of (source, generated) pairs per displayed grid


def _show_pairs(pairs):
    """Concatenate (source, generated) batches and display them as one image grid."""
    grid = torch.cat(pairs, dim=0)
    # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 255], then CHW -> HWC.
    grid = make_grid((grid * 127.5) + 127.5).permute(1, 2, 0)
    show(grid)


generator.eval()
for discriminator in discriminators:
    discriminator.eval()
with torch.no_grad():
    images = []
    errors = []
    for bx, batch in enumerate(val_loader):
        real_src, real_trg = batch
        real_src, real_trg = real_src.to(device), real_trg.to(device)
        fake_trg = generator(real_src)

        images.append(torch.cat([real_src, fake_trg], dim=0))

        errD = 0
        for discriminator in discriminators:
            errD += discriminator_step(real_src, fake_trg, discriminator)
        mean_errD = errD.item() / len(discriminators)

        # This is a single pass, i.e. epoch index 0 of Report(epochs), so the
        # progress position must run over (0, 1].
        log.record(pos=(bx + 1) / N, errD=mean_errD, end='\r')
        errors.append(mean_errD)
        if len(images) == GRID_SIZE:
            _show_pairs(images)
            images = []
    # Show the trailing pairs that did not fill a complete grid.
    if images:
        _show_pairs(images)

if errors:
    print(f'Average loss: {sum(errors)/len(errors)}')
else:
    print('Average loss: n/a (validation loader is empty)')
