In [1]:
import os
import time
import visdom
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from models.deephic import Generator, Discriminator
from models.loss import GeneratorLoss
from models.ssim import ssim
from math import log10

from all_parser import root_dir

from torchsummary import summary

from matplotlib import pyplot as plt
import seaborn as sns

In [2]:
cs = np.column_stack

# data_dir: directory storing processed data
data_dir = os.path.join(root_dir, 'data')
# out_dir: directory storing checkpoint files
out_dir = os.path.join(root_dir, 'checkpoints')
os.makedirs(out_dir, exist_ok=True)

datestr = time.strftime('%m_%d_%H_%M')
visdom_str=time.strftime('%m%d')

resos = '10kb40kb' 
chunk = 40
stride = 40
bound = 201
pool = 'nonpool'

upscale = 1
#FIXME
num_epochs = 10
batch_size = 64

# whether using GPU for training
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
#数据
# prepare training dataset
train_file = os.path.join(data_dir, f'deephic_{resos}_c{chunk}_s{stride}_b{bound}_{pool}_train.npz')
train = np.load(train_file)
#data和target的size都是40*40的
train_data = torch.tensor(train['data'], dtype=torch.float)
train_target = torch.tensor(train['target'], dtype=torch.float)
train_inds = torch.tensor(train['inds'], dtype=torch.long) #这个是什么意思

train_set = TensorDataset(train_data, train_target, train_inds)

# prepare valid dataset
valid_file = os.path.join(data_dir, f'deephic_{resos}_c{chunk}_s{stride}_b{bound}_{pool}_valid.npz')
valid = np.load(valid_file)

valid_data = torch.tensor(valid['data'], dtype=torch.float)
valid_target = torch.tensor(valid['target'], dtype=torch.float)
valid_inds = torch.tensor(valid['inds'], dtype=torch.long)

valid_set = TensorDataset(valid_data, valid_target, valid_inds)

# DataLoader for batched training
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, drop_last=True)

print(train_data.shape)
print(valid_data.shape)

torch.Size([53238, 1, 40, 40])
torch.Size([18801, 1, 40, 40])


In [4]:
# load network
netG = Generator(upscale, in_channel=1, resblock_num=5).to(device)
netD = Discriminator(in_channel=1).to(device)

# loss function
#非常复杂的损失函数
criterionG = GeneratorLoss().to(device)
criterionD = torch.nn.BCELoss().to(device)

# optimizer
optimizerG = optim.Adam(netG.parameters(), lr=0.0001)
optimizerD = optim.Adam(netD.parameters(), lr=0.0001)




In [5]:
vis = visdom.Visdom(env=f'{visdom_str}-deephic')



In [8]:
# fruitpunch = sns.blend_palette(['white', 'red'], as_cmap=True)
# for data, target, _ in train_loader:
#     print(data.shape)
#     print(target.shape)
#     print(_.shape)
#     idx = 3
#     low_picture = data[idx].numpy().reshape(40,40)
#     high_picture = target[idx].numpy().reshape(40,40)
    
#     #画low picture和high picture的heatmap
#     fig,ax = plt.subplots(figsize=(20,10),ncols=2)
#     im = ax[0].matshow(
#         low_picture,
#         vmin=0,
#         cmap=fruitpunch)
#     plt.colorbar(im, ax=ax[0] ,fraction=0.046, pad=0.04, label='raw counts');
#     im = ax[1].matshow(
#         high_picture,
#         vmin=0,
#         cmap=fruitpunch)
#     plt.colorbar(im, ax=ax[1] ,fraction=0.046, pad=0.04, label='raw counts');
#     plt.show()
    
#     break

In [6]:
best_ssim = 0
torch.autograd.set_detect_anomaly(True)
for epoch in range(1, num_epochs+1):
    run_result = {'nsamples': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    train_bar = tqdm(train_loader)
    for data, target, _ in train_bar:
        batch_size = data.size(0)
        run_result['nsamples'] += batch_size
        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = target.to(device)
        z = data.to(device)
        fake_img = netG(z)

        ######### Train discriminator #########
        netD.zero_grad()
        real_out = netD(real_img)
        fake_out = netD(fake_img)
        d_loss_real = criterionD(real_out, torch.ones_like(real_out))
        d_loss_fake = criterionD(fake_out, torch.zeros_like(fake_out))
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward(retain_graph=True)
        optimizerD.step()

        ######### Train generator #########
        netG.zero_grad()
        g_loss = criterionG(fake_out.mean(), fake_img, real_img)
        g_loss.backward()
        optimizerG.step()

        run_result['g_loss'] += g_loss.item() * batch_size
        run_result['d_loss'] += d_loss.item() * batch_size
        run_result['d_score'] += real_out.mean().item() * batch_size
        run_result['g_score'] += fake_out.mean().item() * batch_size

        train_bar.set_description(desc=f"[{epoch}/{num_epochs}] Loss_D: {run_result['d_loss']/run_result['nsamples']:.4f} Loss_G: {run_result['g_loss']/run_result['nsamples']:.4f} D(x): {run_result['d_score']/run_result['nsamples']:.4f} D(G(z)): {run_result['g_score']/run_result['nsamples']:.4f}")
    train_gloss = run_result['g_loss']/run_result['nsamples']
    train_dloss = run_result['d_loss']/run_result['nsamples']
    train_dscore = run_result['d_score']/run_result['nsamples']
    train_gscore = run_result['g_score']/run_result['nsamples']

    valid_result = {'g_loss': 0, 'd_loss': 0, 'g_score': 0, 'd_score': 0, 
                    'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'nsamples': 0}
    netG.eval()
    netD.eval()
    valid_bar = tqdm(valid_loader)

    with torch.no_grad():
        for val_lr, val_hr, inds in valid_bar:
            batch_size = val_lr.size(0)
            valid_result['nsamples'] += batch_size
            lr = val_lr.to(device)
            hr = val_hr.to(device) #真的，target

            #生成的，假的
            sr = netG(lr)

            sr_out = netD(sr)
            hr_out = netD(hr)
            d_loss_real = criterionD(hr_out, torch.ones_like(hr_out))
            d_loss_fake = criterionD(sr_out, torch.zeros_like(sr_out))
            d_loss = d_loss_real + d_loss_fake
            g_loss = criterionG(sr_out.mean(), sr, hr)
            
            valid_result['g_loss'] += g_loss.item() * batch_size
            valid_result['d_loss'] += d_loss.item() * batch_size
            valid_result['g_score'] += sr_out.mean().item() * batch_size
            valid_result['d_score'] += hr_out.mean().item() * batch_size

            batch_mse = ((sr - hr) ** 2).mean()
            valid_result['mse'] += batch_mse * batch_size
            #计算sr和hr的ssim
            batch_ssim = ssim(sr, hr)
            valid_result['ssims'] += batch_ssim * batch_size
            valid_result['psnr'] = 10 * log10(1/(valid_result['mse']/valid_result['nsamples']))
            valid_result['ssim'] = valid_result['ssims'] / valid_result['nsamples']
            valid_bar.set_description(desc=f"[Predicting in Test set] PSNR: {valid_result['psnr']:.4f} dB SSIM: {valid_result['ssim']:.4f}")

    valid_gloss = valid_result['g_loss'] / valid_result['nsamples']
    valid_dloss = valid_result['d_loss'] / valid_result['nsamples']
    valid_gscore = valid_result['g_score'] / valid_result['nsamples']
    valid_dscore = valid_result['d_score'] / valid_result['nsamples']
    now_ssim = valid_result['ssim'].item()
    
    if epoch == 1:
        #初始化
        vis_dloss = vis.line(X=cs((epoch, epoch)), Y=cs((train_dloss, valid_dloss)), opts=dict(title='Discriminator Loss', legend=['Train', 'Test']))
        vis_gloss = vis.line(X=cs((epoch, epoch)), Y=cs((train_gloss, valid_gloss)), opts=dict(title='Generator Loss', legend=['Train', 'Test']))
        vis_dscore = vis.line(X=cs((epoch, epoch)), Y=cs((train_dscore, valid_dscore)), opts=dict(title='Discriminator Score of true images', legend=['Train', 'Test']))
        vis_gscore = vis.line(X=cs((epoch, epoch)), Y=cs((train_gscore, valid_gscore)), opts=dict(title='Generator Score of fake images', legend=['Train', 'Test']))
        vis_ssim = vis.line([now_ssim], X=[epoch], opts=dict(title='SSIM scores in test dataset'))
    else:
        #添加
        vis.line(X=cs((epoch, epoch)), Y=cs((train_dloss, valid_dloss)), update='append', win=vis_dloss, opts=dict(legend=['Train', 'Test']))
        vis.line(X=cs((epoch, epoch)), Y=cs((train_gloss, valid_gloss)), update='append', win=vis_gloss, opts=dict(legend=['Train', 'Test']))
        vis.line(X=cs((epoch, epoch)), Y=cs((train_dscore, valid_dscore)), update='append', win=vis_dscore, opts=dict(legend=['Train', 'Test']))
        vis.line(X=cs((epoch, epoch)), Y=cs((train_gscore, valid_gscore)), update='append', win=vis_gscore, opts=dict(legend=['Train', 'Test']))
        vis.line([now_ssim], X=[epoch], update='append', win=vis_ssim)

    if now_ssim > best_ssim:
        best_ssim = now_ssim
        print(f'Now, Best ssim is {best_ssim:.6f}')
        #并没有保存这个啊
        best_ckpt_file = f'{datestr}_bestg_{resos}_c{chunk}_s{stride}_b{bound}_{pool}_deephic.pytorch'
        torch.save(netG.state_dict(), os.path.join(out_dir, best_ckpt_file))

final_ckpt_g = f'{datestr}_finalg_{resos}_c{chunk}_s{stride}_b{bound}_{pool}_deephic.pytorch'
final_ckpt_d = f'{datestr}_finald_{resos}_c{chunk}_s{stride}_b{bound}_{pool}_deephic.pytorch'


#存储文件
torch.save(netG.state_dict(), os.path.join(out_dir, final_ckpt_g))
torch.save(netD.state_dict(), os.path.join(out_dir, final_ckpt_d))

[1/10] Loss_D: 0.4621 Loss_G: 0.0026 D(x): 0.8051 D(G(z)): 0.1955: 100%|██████████| 831/831 [23:15<00:00,  1.68s/it]
[Predicting in Test set] PSNR: 32.5985 dB SSIM: 0.8439: 100%|██████████| 293/293 [02:11<00:00,  2.23it/s]


Now, Best ssim is 0.843892


[2/10] Loss_D: 0.0789 Loss_G: 0.0006 D(x): 0.9618 D(G(z)): 0.0382: 100%|██████████| 831/831 [21:22<00:00,  1.54s/it]
[Predicting in Test set] PSNR: 33.4522 dB SSIM: 0.8763: 100%|██████████| 293/293 [01:38<00:00,  2.96it/s]


Now, Best ssim is 0.876289


[3/10] Loss_D: 0.0281 Loss_G: 0.0005 D(x): 0.9862 D(G(z)): 0.0138: 100%|██████████| 831/831 [19:59<00:00,  1.44s/it]
[Predicting in Test set] PSNR: 33.6064 dB SSIM: 0.8626: 100%|██████████| 293/293 [01:39<00:00,  2.96it/s]
[4/10] Loss_D: 0.0343 Loss_G: 0.0005 D(x): 0.9870 D(G(z)): 0.0133: 100%|██████████| 831/831 [20:14<00:00,  1.46s/it]
[Predicting in Test set] PSNR: 33.8972 dB SSIM: 0.8605: 100%|██████████| 293/293 [01:43<00:00,  2.83it/s]
[5/10] Loss_D: 0.0078 Loss_G: 0.0005 D(x): 0.9960 D(G(z)): 0.0038: 100%|██████████| 831/831 [20:22<00:00,  1.47s/it]
[Predicting in Test set] PSNR: 33.9638 dB SSIM: 0.8712: 100%|██████████| 293/293 [01:43<00:00,  2.83it/s]
[6/10] Loss_D: 0.1592 Loss_G: 0.0004 D(x): 0.9412 D(G(z)): 0.0583: 100%|██████████| 831/831 [20:21<00:00,  1.47s/it]
[Predicting in Test set] PSNR: 33.9591 dB SSIM: 0.8746: 100%|██████████| 293/293 [01:37<00:00,  3.02it/s]
[7/10] Loss_D: 0.0168 Loss_G: 0.0004 D(x): 0.9929 D(G(z)): 0.0095: 100%|██████████| 831/831 [20:21<00:00,  1

Now, Best ssim is 0.879968


[10/10] Loss_D: 0.0165 Loss_G: 0.0004 D(x): 0.9935 D(G(z)): 0.0070: 100%|██████████| 831/831 [21:06<00:00,  1.52s/it]
[Predicting in Test set] PSNR: 34.3463 dB SSIM: 0.8736: 100%|██████████| 293/293 [01:51<00:00,  2.63it/s]


In [7]:
train_inds[0]

tensor([    1, 22102,     0,     0])