# import

In [1]:
import sys
import os
import os.path as osp
import math
import json

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision.models import vgg19
from torchvision.utils import save_image
import torchvision.transforms as transforms

In [3]:
from glob import glob

In [4]:
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython import display
from IPython.display import clear_output

In [5]:
from PIL import Image

In [6]:
import time
import datetime

# args

In [7]:
class opts():
    pass

In [8]:
opt = opts()

In [9]:
opt.channles = 3
opt.hr_height = 128
opt.residual_blocks = 23
opt.lr = 0.0002
opt.b1 = 0.9
opt.b2 = 0.999
opt.batch_size = 4
opt.n_cpu = 8
opt.n_epoch = 200
# opt.warmup_batches = 500
opt.warmup_batches = 5
opt.lambda_adv = 5e-3
opt.lambda_pixel = 1e-2

opt.pretrained = False

opt.dataset_name = 'cat'
# opt.dataset_name = 'img_align_celeba_resize'
# opt.dataset_name = 'img_align_celeba_resize'

opt.sample_interval = 50
opt.checkpoint_interval = 100

In [10]:
args = [arg for arg in dir(opt) if not arg.startswith('__')]

In [11]:
opt_dict = {arg: getattr(opt, arg) for arg in args}

In [12]:
hr_shape = (opt.hr_height, opt.hr_height)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# model

# Generator

In [14]:
class DenseResidualBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale
        
        def block(in_features, non_linearity=True):
            layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
            if non_linearity:
                layers += [nn.LeakyReLU()]
            return nn.Sequential(*layers)
    
        self.b1 = block(in_features=1 * filters)
        self.b2 = block(in_features=2 * filters)
        self.b3 = block(in_features=3 * filters)
        self.b4 = block(in_features=4 * filters)
        self.b5 = block(in_features=5 * filters, non_linearity=False)
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]
    
    def forward(self, x):
        inputs = x
        for block in self.blocks:
            out = block(inputs)
            inputs = torch.cat([inputs, out], 1)
        return out.mul(self.res_scale) + x

In [15]:
class ResidualInResidualDenseBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        self.dense_blocks = nn.Sequential(
            DenseResidualBlock(filters), DenseResidualBlock(filters), DenseResidualBlock(filters)
        )
    
    def forward(self, x):
        return self.dense_blocks(x).mul(self.res_scale) + x

In [16]:
class GeneratorPRDB(nn.Module):
    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super(GeneratorPRDB, self).__init__()
        
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        
        self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        
        upsample_layers = []
        
        for _ in range(num_upsample):
            upsample_layers += [
                nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.PixelShuffle(upscale_factor=2),
            ]
        self.upsampling = nn.Sequential(*upsample_layers)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
        )
    
    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out

# Discriminator

In [17]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])

    def forward(self, img):
        return self.vgg19_54(img)

In [18]:
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
                
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)
    
        def descriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            print(descriminator_block(in_filters, out_filters, first_block=(i == 0)))
            layers.extend(descriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        
        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, img):
        return self.model(img)

## Dataset

In [19]:
def denormalize(tensors):
    for c in range(3):
        tensors[:, c].mul_(std[c]).add_(mean[c])
    return torch.clamp(tensors, 0, 255)

In [20]:
class ImageDataset(Dataset):
    def __init__(self, dataset_dir, hr_shape):
        hr_height, hr_width = hr_shape
        
        self.lr_transform = transforms.Compose([
            transforms.Resize((hr_height // 4, hr_height // 4), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])

        self.hr_transform = transforms.Compose([
            transforms.Resize((hr_height, hr_height), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])
        
        self.files = sorted(glob(osp.join(dataset_dir, '*')))
    
    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)
        
        return {'lr': img_lr, 'hr': img_hr}
    
    def __len__(self):
        return len(self.files)

In [21]:
class TestImageDataset(Dataset):
    def __init__(self, dataset_dir):
        # TODO: 入力に対して1/4
        self.hr_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])
        self.files = sorted(glob(osp.join(dataset_dir, '*')))
    
    def lr_transform(self, img, img_size):
        img_width, img_height = img_size
        self.__lr_transform = transforms.Compose([
            transforms.Resize((img_height // 4, img_width // 4), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])
        img = self.__lr_transform(img)
        return img
            
    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        img_size = img.size
        img_lr = self.lr_transform(img, img_size)
        img_hr = self.hr_transform(img)
        
        return {'lr': img_lr, 'hr': img_hr}
    
    def __len__(self):
        return len(self.files)

In [22]:
def save_json(label, save_path):
    f = open(save_path, "w")
    json.dump(label, f, ensure_ascii=False, indent=4, 
              sort_keys=True, separators=(',', ': '))

# path

In [23]:
ROOT = '../'

In [24]:
input_dir = osp.join(ROOT, 'input')
output_dir = osp.join(ROOT, 'output', str(datetime.datetime.fromtimestamp(time.time())))
weight_dir = osp.join(ROOT, 'weight')

In [25]:
image_train_save_dir = osp.join(output_dir, 'image', 'train')
image_test_save_dir = osp.join(output_dir, 'image', 'test')
weight_save_dir = osp.join(output_dir, 'weight')
plot_save_dir = osp.join(output_dir, 'plot')

save_dirs = [image_train_save_dir, image_test_save_dir, weight_save_dir, plot_save_dir]
for save_dir in save_dirs:
    os.makedirs(save_dir, exist_ok=True)

In [26]:
train_data_dir = osp.join(input_dir, '{}_train'.format(opt.dataset_name))
test_data_dir = osp.join(input_dir, '{}_test_sub2'.format(opt.dataset_name))
g_weight_path = osp.join(weight_dir, 'generator.pth')
d_weight_path = osp.join(weight_dir, 'discriminator.pth')

In [27]:
opt_save_path = osp.join(output_dir, 'opt.json')

In [28]:
save_json(opt_dict, opt_save_path)

## set_model

In [29]:
generator = GeneratorPRDB(opt.channles, filters=64, num_res_blocks=opt.residual_blocks).to(device)
discriminator = Discriminator(input_shape=(opt.channles, *hr_shape)).to(device)

if opt.pretrained:
    generator.load_state_dict(torch.load(g_weight_path))
    discriminator.load_state_dict(torch.load(d_weight_path))

feature_extractor = FeatureExtractor().to(device)
feature_extractor.eval()

[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), LeakyReLU(negative_slope=0.2, inplace=True), Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), LeakyReLU(negative_slope=0.2, inplace=True)]
[Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), LeakyReLU(negative_slope=0.2, inplace=True), Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), LeakyReLU(negative_slope=0.2, inplace=True)]
[Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), LeakyReLU(negative_slope=0.2, inplace=True), Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), BatchNorm2d(256, eps=1e-05, momentum=

FeatureExtractor(
  (vgg19_54): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride

# Loss

In [30]:
criterion_GAN = nn.BCEWithLogitsLoss().to(device)
criterion_content = nn.L1Loss().to(device)
criterion_pixel = nn.L1Loss().to(device)

# Optimizer

In [31]:
optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Tensor

In [32]:
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

# dataset

In [33]:
train_dataloader = DataLoader(
    ImageDataset(train_data_dir, hr_shape=hr_shape),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

test_dataloader = DataLoader(
    TestImageDataset(test_data_dir),
    batch_size=1,
    shuffle=False,
    num_workers=opt.n_cpu,
)

# main

In [34]:
from copy import copy

In [48]:
loss_names = ['batch_num', 'loss_pixel', 'loss_D', 'loss_G', 'loss_content', 'loss_GAN']
train_infos = []

plt.figure(figsize=(16,9))
low_image_save = False

for epoch in range(1, opt.n_epoch + 1):
    for batch_num, imgs in enumerate(train_dataloader):
        batches_done = (epoch - 1) * len(train_dataloader) + batch_num
        
        # preprocess
        imgs_lr = Variable(imgs['lr'].type(Tensor))
        imgs_hr = Variable(imgs['hr'].type(Tensor))
        
        # ground truth
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), 
                         requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), 
                        requires_grad=False)
        
        # バックプロパゲーションの前に勾配を０にする
        optimizer_G.zero_grad()
        
        # 低解像度の画像から高解像度の画像を生成
        gen_hr = generator(imgs_lr)
        
        loss_pixel = criterion_pixel(gen_hr, imgs_hr)
        
        # 画素単位の損失であるloss_pixelで事前学習を行う
        if batches_done <= opt.warmup_batches:
            loss_pixel.backward()
            optimizer_G.step()
            train_info = {
                'epoch': epoch, 
                'batch_num': batch_num,
                'loss_pixel': loss_pixel.item()
            }
        
            sys.stdout.write('\r{}'.format('\t'*10))
            sys.stdout.write('\r {}'.format(train_info))            
        # loss_pixel以外のLossも含めて学習
        else:
            # prediction
            pred_real = discriminator(imgs_hr).detach()
            pred_fake = discriminator(gen_hr)
            
            # Aeversarial loss
            loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

            # content loss(perceptual loss)
            # 特徴抽出機で抽出した特徴を用いて生成画像と本物画像のL1距離を算出
            gen_feature = feature_extractor(gen_hr)
            real_feature = feature_extractor(imgs_hr).detach()
            loss_content = criterion_content(gen_feature, real_feature)

            # Total generator loss
            loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel
            loss_G.backward()
            optimizer_G.step()
            
            optimizer_D.zero_grad()

            # pred_real = discriminator(imgs_hr)                        
            # pred_fake = discriminator(gen_hr.detach())

            # adversarial loss
            loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)            
            loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

            loss_D = (loss_real + loss_fake) / 2

            loss_D.backward()
            optimizer_D.step()

            train_info = {
                'epoch': epoch,
                'epoch_total': opt.n_epoch,
                'batch_num': batch_num, 
                'batch_total': len(train_dataloader),
                'loss_D': loss_D.item(),
                'loss_G': loss_G.item(),
                'loss_content': loss_content.item(),
                'loss_GAN': loss_GAN.item(),
                'loss_pixel': loss_pixel.item(),
            }

            if batch_num == 1:
                sys.stdout.write('\n{}'.format(train_info))
            else:
                sys.stdout.write('\r{}'.format('\t'*20))
                sys.stdout.write('\r{}'.format(train_info))
            sys.stdout.flush()
        
        train_infos.append(train_info)
        
        if batches_done % opt.sample_interval == 0:
            # Save image grid with upsampled inputs and ESRGAN outputs
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            img_grid = denormalize(torch.cat((imgs_lr, gen_hr), -1))

            image_batch_save_dir = osp.join(image_train_save_dir, '{:07}'.format(batches_done))
            os.makedirs(osp.join(image_batch_save_dir, "hr_image"), exist_ok=True)
            save_image(img_grid, osp.join(image_batch_save_dir, "hr_image", "%d.png" % batches_done), nrow=1, normalize=False)

            with torch.no_grad():
                for i, imgs in enumerate(test_dataloader):
                    # Save image grid with upsampled inputs and outputs
                    imgs_lr = Variable(imgs["lr"].type(Tensor))
                    gen_hr = generator(imgs_lr)
                    imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)

                    imgs_lr = denormalize(imgs_lr)
                    gen_hr = denormalize(gen_hr)

                    image_batch_save_dir = osp.join(image_test_save_dir, '{:03}'.format(i))
                    os.makedirs(osp.join(image_batch_save_dir, "hr_image"), exist_ok=True)
                    save_image(gen_hr, osp.join(image_batch_save_dir, "hr_image", "{:09}.png".format(batches_done)), nrow=1, normalize=False)
                    if not low_image_save:
                        save_image(imgs_lr, osp.join(image_batch_save_dir, "lr_image.jpg"), nrow=1, normalize=False)
            low_image_save = True
                    

        if batches_done % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(generator.state_dict(), osp.join(weight_save_dir, "generator_%d.pth" % batches_done))
            torch.save(discriminator.state_dict(), osp.join(weight_save_dir, "discriminator_%d.pth" % batches_done))
        
            log_df = pd.DataFrame(train_infos)
            log_df = log_df.set_index('batch_num')
            cols = log_df.columns[log_df.columns.isin(loss_names)]
            log_df = log_df[cols]

            for num, loss_name in enumerate(log_df.columns, 1):
                plt.subplot(2, 3, num)
                plt.plot(log_df.index.values, log_df[loss_name].values, marker='o', color='b', alpha=0.8)
                plt.title(loss_name)

            plt.savefig(osp.join(plot_save_dir, "plot.png"))

            # display.clear_output(wait=True)
            # display.display(plt.gcf())

{'epoch': 1, 'epoch_total': 200, 'batch_num': 16, 'batch_total': 2474, 'loss_D': 0.0037132117431610823, 'loss_G': 2.423321485519409, 'loss_content': 2.3848416805267334, 'loss_GAN': 6.364528656005859, 'loss_pixel': 0.6657178997993469}}}

KeyboardInterrupt: 

Error in callback <function flush_figures at 0x7f13a8ed1bf8> (for post_execute):


KeyboardInterrupt: 

In [63]:
pred_real - pred_fake.mean(0, keepdim=True)

tensor([[[[5.1151, 4.7412, 4.8801, 4.3034, 5.1104, 5.0542, 4.8478, 3.2613],
          [5.6587, 5.5451, 6.4305, 5.4977, 5.7008, 5.3491, 5.1457, 4.6151],
          [5.8661, 6.2379, 5.5242, 4.6496, 5.3447, 5.4010, 4.1703, 4.8150],
          [5.9776, 7.0561, 5.0861, 4.3053, 5.5784, 5.7994, 5.2240, 5.6542],
          [5.5899, 5.7084, 4.3085, 4.6929, 5.2552, 5.1549, 6.1259, 5.6568],
          [5.6813, 5.4188, 4.2845, 5.3099, 5.4558, 5.5985, 6.7705, 5.2581],
          [5.5976, 5.0727, 4.0672, 4.7075, 4.0101, 3.5757, 5.6195, 4.3422],
          [4.1934, 5.0714, 5.2192, 5.9648, 5.6372, 5.0232, 5.3218, 4.0180]]],


        [[[4.1441, 4.5443, 5.1429, 3.9682, 5.1021, 5.5031, 5.1086, 2.9576],
          [3.9886, 3.9338, 4.6858, 4.3956, 6.2974, 5.8514, 5.5533, 3.5103],
          [5.0201, 5.8700, 5.7982, 2.0678, 1.5862, 2.8274, 4.1654, 4.0825],
          [4.3586, 5.1218, 5.1490, 3.7467, 3.8828, 3.3920, 2.7131, 3.3123],
          [3.1030, 4.4346, 3.7951, 3.4932, 4.0567, 5.9183, 5.1662, 5.2851],
        

In [62]:
valid

tensor([[[[1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1

In [55]:
criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)

tensor(0.0129, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [53]:
loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)            

In [54]:
loss_real

tensor(0.0129, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [51]:
pred_real

tensor([[[[ 2.3010,  1.5029,  1.7281,  1.3221,  1.8486,  1.8571,  1.5488,
            1.0924],
          [ 2.5043,  1.8834,  2.6218,  1.3253,  0.8288,  0.6705,  0.1675,
            1.3442],
          [ 2.6464,  2.4710,  2.0465,  1.4035,  1.7311,  1.9205,  0.1164,
            1.7478],
          [ 2.7813,  2.9469,  1.6732,  1.1989,  2.4951,  3.0790,  1.3781,
            2.5669],
          [ 2.6280,  2.0827,  1.4554,  1.2345,  1.0891,  0.9773,  1.7779,
            2.5194],
          [ 2.6146,  1.6435,  0.3558,  0.5772,  0.8507,  1.3171,  2.4758,
            2.1537],
          [ 2.2080,  1.6525,  0.4850,  0.5251, -0.1229, -0.5465,  1.4804,
            1.9481],
          [ 1.3893,  1.5334,  0.7116,  1.0888,  0.8325,  0.2879,  0.7735,
            1.0936]]],


        [[[ 1.3300,  1.3060,  1.9909,  0.9868,  1.8403,  2.3059,  1.8097,
            0.7887],
          [ 0.8342,  0.2722,  0.8772,  0.2232,  1.4254,  1.1728,  0.5751,
            0.2394],
          [ 1.8005,  2.1031,  2.3206, -1.1782,

In [52]:
pred_fake.mean(0, keepdim=True)

tensor([[[[-2.8141, -3.2383, -3.1520, -2.9813, -3.2618, -3.1971, -3.2989,
           -2.1689],
          [-3.1544, -3.6617, -3.8086, -4.1725, -4.8720, -4.6786, -4.9782,
           -3.2709],
          [-3.2196, -3.7669, -3.4777, -3.2460, -3.6136, -3.4805, -4.0538,
           -3.0672],
          [-3.1963, -4.1092, -3.4129, -3.1064, -3.0832, -2.7204, -3.8459,
           -3.0873],
          [-2.9618, -3.6257, -2.8532, -3.4584, -4.1661, -4.1775, -4.3479,
           -3.1374],
          [-3.0667, -3.7753, -3.9288, -4.7327, -4.6051, -4.2814, -4.2947,
           -3.1045],
          [-3.3896, -3.4202, -3.5822, -4.1823, -4.1330, -4.1221, -4.1392,
           -2.3941],
          [-2.8041, -3.5379, -4.5076, -4.8759, -4.8047, -4.7353, -4.5484,
           -2.9243]]]], device='cuda:0', grad_fn=<MeanBackward1>)

In [50]:
loss_real

tensor(0.0039, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [49]:
loss_fake

tensor(0.0035, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [45]:
assert (pred_real_1 == pred_real_2).all()

tensor([[[ 0.1627,  0.0858, -0.1988,  0.0166,  0.2525,  0.1275,  0.0035,
          -0.0720],
         [ 0.2663,  0.0456,  0.5502,  0.1471,  0.5895,  0.3599,  0.1650,
           0.1863],
         [ 0.2412,  0.3566,  0.4380,  0.0029, -0.2379, -0.3139,  0.2480,
          -0.1008],
         [ 0.0772,  0.3343, -0.7597, -0.2592,  0.1534,  0.0226,  0.3417,
           0.5543],
         [ 0.6315,  0.2903,  0.0907,  0.4565,  0.0854,  0.3103, -0.2989,
           0.2373],
         [ 0.2643, -0.3885, -0.3389,  0.0766,  0.1265,  0.4819,  0.1319,
          -0.2840],
         [-0.2802, -0.4577, -0.3373,  0.2746, -0.2266,  0.1437, -0.2373,
          -0.5435],
         [-0.4513, -0.7434, -0.6575, -0.0604, -0.0369, -0.4630, -0.2580,
          -0.1887]]], device='cuda:0', grad_fn=<SelectBackward>)

In [None]:
pre