In [1]:
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode as IMode
from torch.utils.data import DataLoader
from narutils import *
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
import cv2

from utils.imgproc import *
from model import *



どうやらSRGANの学習は2段階らしい：最初の p_epochs エポックは generator のみを学習し，その後の敵対的学習（adversarial）は epochs エポックだけ行う。

In [2]:
# Collect every PNG found one directory level below data_dir and persist the
# list as path_df.csv, which later cells split into train/valid by row position.
data_dir='/work/dataset/images1024x1024/'
fpaths=[]
# os.listdir and glob return entries in arbitrary order; sort both so that the
# positional train/valid split built from path_df.csv is reproducible.
for d in sorted(os.listdir(data_dir)):
	fpaths+=sorted(glob(opj(data_dir,d,'*.png')))
# DataFrame.to_csv returns None when it is given a path, so keep the frame in
# path_df and write it out as a separate step.
# (dfc/opj come from the star-imported narutils; presumably pd.DataFrame and
# os.path.join -- TODO confirm.)
path_df=dfc(fpaths).rename(columns={0:'file_path'})
path_df.to_csv('path_df.csv',index=False)

In [37]:
class cfg:
	"""Experiment configuration, used as a plain namespace of class attributes.

	The class body runs at definition time: it reads ``path_df.csv`` from the
	working directory and, when ``debug`` is true, truncates both path tables to
	their first 8 rows.
	"""
	# Use the first GPU when available; fall back to the CPU so the notebook
	# still runs (slowly) on a machine without CUDA.
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	# Not referenced anywhere else in the visible notebook.
	seed=19
	upscale_factor=4
	image_size=512 # hr image size(low res image size=image_size/upscale_factor)
	batch_size=4

	# Train epochs.
	p_epochs=50 # The total number of cycles of the generator training phase.
	epochs=10  # The total number of cycles in the training phase of the adversarial network.

	# Directory that receives the generator checkpoints.
	model_dir='./exp-debug'

	exp_name='test'
	# Presumably consumed by ContentLoss(cfg) -- TODO confirm.
	vgg_path='/work/dataset/vgg-pre.pickle'
	# Parse the path table once and slice it, instead of reading the CSV twice.
	_path_df=pd.read_csv('path_df.csv')
	train_path_df=_path_df.iloc[:5000,:]
	valid_path_df=_path_df.iloc[5000:10000,:]

	# Weights of the three terms summed into the generator loss during adversarial training.
	pixel_weight          = 0.01
	content_weight        = 1.0
	adversarial_weight    = 0.001

	debug=True


	# Debug mode keeps only 8 files per split so a full epoch takes seconds.
	if debug:
		train_path_df=train_path_df.iloc[:8,:]
		valid_path_df=valid_path_df.iloc[:8,:]


In [38]:
class CustomDataset(Dataset):
	"""Dataset yielding ``(lr, hr)`` tensor pairs, one pair per image file.

	Both outputs are bicubic resizes of the same decoded image: ``hr`` is
	``cfg.image_size`` square and ``lr`` is ``cfg.image_size // cfg.upscale_factor``
	square. ``lr`` is resized directly from the decoded image, not from ``hr``.

	Args:
		cfg: configuration providing ``image_size``, ``upscale_factor`` and the
			``train_path_df`` / ``valid_path_df`` tables with a ``file_path`` column.
		mode (str): ``'train'`` or ``'valid'``; selects which path table is used.

	Raises:
		NotImplementedError: if ``mode`` is neither ``'train'`` nor ``'valid'``.
	"""
	def __init__(self,cfg,mode='train'):
		lr_image_size=(cfg.image_size//cfg.upscale_factor,cfg.image_size//cfg.upscale_factor)
		hr_image_size=(cfg.image_size,cfg.image_size)
		self.hr_transforms=transforms.Compose([
			transforms.Resize(hr_image_size,interpolation=IMode.BICUBIC),
		])
		self.lr_transforms=transforms.Compose([
			transforms.Resize(lr_image_size,interpolation=IMode.BICUBIC),
		])
		if mode=='train':
			self.filenames=cfg.train_path_df['file_path'].values
		elif mode=='valid':
			self.filenames=cfg.valid_path_df['file_path'].values
		else:
			raise NotImplementedError(f"unknown mode {mode!r}; expected 'train' or 'valid'")

	def __getitem__(self,index):
		"""Load the image at ``index`` and return its ``(lr, hr)`` tensors.

		Raises:
			FileNotFoundError: if OpenCV cannot read the file.
		"""
		path=self.filenames[index]
		# IMREAD_UNCHANGED (== -1) keeps the file's own channel count and bit depth.
		image=cv2.imread(path,cv2.IMREAD_UNCHANGED)
		# cv2.imread signals a missing or unreadable file by returning None rather
		# than raising; fail here with the offending path instead of deeper down.
		if image is None:
			raise FileNotFoundError(f'cv2.imread could not read image: {path}')
		# NOTE(review): OpenCV decodes colour images in BGR order and no conversion
		# is done here -- confirm that image2tensor / ContentLoss expect BGR.
		hr=image2tensor(image)
		lr=self.lr_transforms(hr)
		hr=self.hr_transforms(hr)
		return lr,hr

	def __len__(self):
		"""Return the number of image files in the selected split."""
		return len(self.filenames)


In [43]:
class SRGAN():
	"""Training driver for an SRGAN generator/discriminator pair.

	Expected call order: ``build_dataloader()`` and ``build_model()`` first, then
	``train()``. ``train()`` currently implements only the generator pre-training
	stage; the adversarial optimizers, schedulers and losses are created, but no
	adversarial loop is run yet.

	Args:
		cfg: configuration object. The attributes read here are ``device``,
			``batch_size``, ``epochs``, ``p_epochs`` and ``model_dir``; it is also
			passed through to the dataset class and to ``ContentLoss``.
	"""
	def __init__(self,cfg):
		self.cfg=cfg

	def build_dataloader(self,CustomDataset):
		"""Create the train/valid datasets and their loaders.

		Args:
			CustomDataset: dataset class called as ``CustomDataset(cfg, mode=...)``.
		"""
		self.train_dataset=CustomDataset(self.cfg,mode='train')
		self.valid_dataset=CustomDataset(self.cfg,mode='valid')

		self.train_loader=DataLoader(self.train_dataset,batch_size=self.cfg.batch_size,shuffle=True,pin_memory=True)
		# Validation averages a per-batch PSNR, so the result depends on how samples
		# are grouped into batches; keep the order fixed so epochs are comparable.
		self.valid_loader=DataLoader(self.valid_dataset,batch_size=self.cfg.batch_size,shuffle=False,pin_memory=True)

		decopri('successfully built data loaders!')

	def build_model(self):
		"""Instantiate the discriminator and generator on ``cfg.device``."""
		self.discriminator=Discriminator().to(self.cfg.device)
		self.generator=Generator().to(self.cfg.device)

		decopri('successfully built models!')

	def train_generator(self,epoch):
		"""Run one epoch of generator-only training with the pixel loss.

		Requires ``self.p_optimizer`` and ``self.pixel_criterion``, which are
		created in ``train()``.

		Args:
			epoch (int): epoch index, used only in the progress-bar label.

		"""
		# Put the generative model in training mode.
		self.generator.train()
		pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), desc=f'Train generator|epoch{epoch}',dynamic_ncols=True)
		for index, (lr, hr) in pbar:
			# Copy the data to the device configured on this instance (not the global cfg).
			lr = lr.to(self.cfg.device)
			hr = hr.to(self.cfg.device)
			# Reset the generator gradients.
			self.generator.zero_grad()
			# Generate super-resolution images.
			sr = self.generator(lr)
			# Pixel-level difference between the super-resolution and high-resolution images.
			pixel_loss = self.pixel_criterion(sr, hr)
			# Update the generator weights.
			pixel_loss.backward()
			self.p_optimizer.step()


	def train(self):
		"""Create optimizers, schedulers and losses, then pre-train the generator.

		Every 5th epoch a checkpoint ``p-{epoch}.pth`` is written to
		``cfg.model_dir``; whenever the validation PSNR improves, ``p-best.pth``
		is overwritten. After the loop the best weights are loaded back into the
		generator as the starting point for the adversarial stage.
		"""
		device=self.cfg.device
		self.p_optimizer=optim.Adam(self.generator.parameters(),0.0001,(0.9, 0.999))  # Generator optimizer for the pre-training stage.
		self.d_optimizer=optim.Adam(self.discriminator.parameters(),0.0001,(0.9, 0.999))  # Discriminator optimizer for the adversarial stage.
		self.g_optimizer=optim.Adam(self.generator.parameters(),0.0001,(0.9, 0.999))  # Generator optimizer for the adversarial stage.

		# Schedulers for the adversarial stage: multiply the learning rate by 0.1 every epochs // 2 steps.
		self.d_scheduler=StepLR(self.d_optimizer, self.cfg.epochs // 2, 0.1)
		self.g_scheduler=StepLR(self.g_optimizer, self.cfg.epochs // 2, 0.1)

		# Loss functions
		self.pixel_criterion=nn.MSELoss().to(device)                # Pixel loss.
		self.content_criterion=ContentLoss(self.cfg).to(device)     # Content loss.
		self.adversarial_criterion=nn.BCELoss().to(device)          # Adversarial loss.

		# Create the checkpoint directory up front: the periodic save below may run
		# before any "best" save has happened.
		os.makedirs(self.cfg.model_dir,exist_ok=True)

		# train only generator stage
		decopri('Start train generator stage')
		# Start from -inf so the first finite PSNR (even one <= 0 dB) is stored as the
		# best checkpoint, which is loaded unconditionally after the loop.
		psnr_best=float('-inf')
		for epoch in range(self.cfg.p_epochs):
			self.train_generator(epoch)
			psnr=self.validate(epoch,stage='generator only')
			if (epoch+1)%5==0:
				torch.save(self.generator.state_dict(),opj(self.cfg.model_dir,f'p-{epoch}.pth'))
			if psnr>psnr_best:
				psnr_best=psnr
				torch.save(self.generator.state_dict(),opj(self.cfg.model_dir,'p-best.pth'))

		# train adversarial stage
		# Only the hand-over is implemented so far: restore the best pre-trained generator.
		self.generator.load_state_dict(torch.load(opj(self.cfg.model_dir,'p-best.pth'),map_location=device))


	def validate(self,epoch,stage='adversarial'):
		"""Return the generator's PSNR averaged over validation batches.

		PSNR is computed per batch as ``10 * log10(1 / mse)`` (peak value 1) and
		the per-batch values are then averaged.

		Args:
			epoch (int): epoch index, used in the progress-bar label and the printed summary.
			stage (str): label shown in the progress bar.

		Returns:
			float: mean of the per-batch PSNR values.
		"""
		batches=len(self.valid_loader)
		self.generator.eval()
		total_psnr_value=0.0

		with torch.no_grad():
			pbar=tqdm(enumerate(self.valid_loader), total=len(self.valid_loader), desc=f'valid {stage}|epoch{epoch}',dynamic_ncols=True)
			for index,(lr,hr) in pbar:
				lr = lr.to(self.cfg.device)
				hr = hr.to(self.cfg.device)
				# Generate super-resolution images.
				sr = self.generator(lr)
				# Calculate the PSNR indicator.
				mse_loss = torch.mean((sr - hr) ** 2)
				psnr_value = 10 * torch.log10(1 / mse_loss).item()
				total_psnr_value += psnr_value

			avg_psnr_value=total_psnr_value/batches
			print(f'epoch-{epoch} average psnr:{avg_psnr_value}')

		return avg_psnr_value

		

In [44]:
# Create the training driver; it only stores cfg, nothing is built yet.
srgan=SRGAN(cfg)

In [45]:
# Build the train/valid loaders from CustomDataset, then create the
# discriminator and generator on the configured device.
srgan.build_dataloader(CustomDataset=CustomDataset)
srgan.build_model()

--------------------------------------------------------------------------------
                        successfully built data loaders!                        
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                           successfully built models!                           
--------------------------------------------------------------------------------


In [46]:
# Run the generator pre-training stage. The recorded output below ends with a
# KeyboardInterrupt during epoch 5, i.e. this run was stopped manually.
srgan.train()

--------------------------------------------------------------------------------
                          Start train generator stage                          
--------------------------------------------------------------------------------


Train generator|epoch0: 100%|██████████| 2/2 [00:00<00:00,  3.64it/s]
valid generator only|epoch0: 100%|██████████| 2/2 [00:00<00:00,  5.34it/s]


epoch-0 average psnr:5.834514796733856


Train generator|epoch1: 100%|██████████| 2/2 [00:00<00:00,  2.49it/s]
valid generator only|epoch1: 100%|██████████| 2/2 [00:00<00:00,  3.73it/s]


epoch-1 average psnr:5.868099629878998


Train generator|epoch2: 100%|██████████| 2/2 [00:00<00:00,  3.78it/s]
valid generator only|epoch2: 100%|██████████| 2/2 [00:00<00:00,  5.48it/s]


epoch-2 average psnr:6.0206955671310425


Train generator|epoch3: 100%|██████████| 2/2 [00:00<00:00,  2.32it/s]
valid generator only|epoch3: 100%|██████████| 2/2 [00:00<00:00,  4.99it/s]


epoch-3 average psnr:6.227478086948395


Train generator|epoch4: 100%|██████████| 2/2 [00:00<00:00,  4.10it/s]
valid generator only|epoch4: 100%|██████████| 2/2 [00:00<00:00,  2.27it/s]


epoch-4 average psnr:6.669838130474091


Train generator|epoch5:  50%|█████     | 1/2 [00:00<00:00,  2.78it/s]


KeyboardInterrupt: 

In [11]:
def train_generator(train_dataloader, epoch) -> None:
	"""Run one epoch of generator-only training driven by the pixel loss.

	Works on the notebook-level globals ``generator``, ``pixel_criterion``,
	``p_optimizer`` and ``cfg``.

	Args:
		train_dataloader (torch.utils.data.DataLoader): yields ``(lr, hr)`` batches.
		epoch (int): epoch index; only used to derive the global step counter.

	"""
	steps_per_epoch = len(train_dataloader)
	# Training mode for the generator.
	generator.train()

	for step, batch in enumerate(train_dataloader):
		# Move both images of the pair to the configured device.
		low_res, high_res = (tensor.to(cfg.device) for tensor in batch)
		# Clear gradients left over from the previous step.
		generator.zero_grad()
		# Forward pass, then pixel-wise loss against the high-resolution target.
		restored = generator(low_res)
		loss = pixel_criterion(restored, high_res)
		# Back-propagate and apply the update.
		loss.backward()
		p_optimizer.step()
		# 1-based global step; kept for the TensorBoard / console logging that is
		# currently disabled in this notebook.
		global_step = step + epoch * steps_per_epoch + 1

In [7]:
# Networks, both placed on the configured device.
discriminator = Discriminator().to(cfg.device)
generator = Generator().to(cfg.device)

# Optimizers: p_* drives the generator-only stage, d_* / g_* drive the
# discriminator and the generator during adversarial training.
p_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.9, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.9, 0.999))

# Schedulers for the adversarial stage: scale the learning rate by 0.1 every
# cfg.epochs // 2 scheduler steps.
d_scheduler = StepLR(d_optimizer, step_size=cfg.epochs // 2, gamma=0.1)
g_scheduler = StepLR(g_optimizer, step_size=cfg.epochs // 2, gamma=0.1)

In [8]:
# Loss modules used by the training functions, all moved to the configured device.
pixel_criterion=nn.MSELoss().to(cfg.device)               # Pixel loss: mean squared error between sr and hr.
content_criterion=ContentLoss(cfg).to(cfg.device)              # Content loss, built from cfg (project-defined).
adversarial_criterion=nn.BCELoss().to(cfg.device)  # Adversarial loss: binary cross-entropy on discriminator outputs.

In [9]:
# Datasets over the train / valid path tables held on cfg.
train_dataset=CustomDataset(cfg,mode='train')
valid_dataset=CustomDataset(cfg,mode='valid')

In [10]:
# Shuffle only the training data. validate() averages a per-batch PSNR, so the
# validation order is kept fixed to make the metric repeatable between runs.
train_loader=DataLoader(train_dataset,batch_size=cfg.batch_size,shuffle=True,pin_memory=True)
valid_loader=DataLoader(valid_dataset,batch_size=cfg.batch_size,shuffle=False,pin_memory=True)

In [31]:
# Run one generator-only training epoch (epoch index 0) on the training loader.
train_generator(train_loader,0)

In [12]:

def train_adversarial(train_dataloader, epoch) -> None:
	"""Run one epoch of alternating discriminator and generator updates.

	Works on the notebook-level globals ``generator``, ``discriminator``,
	``d_optimizer``, ``g_optimizer``, ``pixel_criterion``, ``content_criterion``,
	``adversarial_criterion`` and ``cfg``.

	Both adversarial terms compare one discriminator output with the batch mean
	of the opposite kind, clamp negative differences to 0 and feed the result to
	``adversarial_criterion``.

	Args:
		train_dataloader (torch.utils.data.DataLoader): yields ``(lr, hr)`` batches.
		epoch (int): epoch index; currently unused, kept for interface
			compatibility and future logging.

	"""
	# Put the two models in training mode.
	discriminator.train()
	generator.train()

	for lr, hr in train_dataloader:
		# Copy the data to the specified device.
		lr = lr.to(cfg.device)
		hr = hr.to(cfg.device)
		label_size = lr.size(0)
		# Labels: real samples are 1, generated samples are 0.
		# Assumes the discriminator returns one score per sample, shape (N, 1) -- TODO confirm.
		real_label = torch.full([label_size, 1], 1.0, dtype=lr.dtype, device=cfg.device)
		fake_label = torch.full([label_size, 1], 0.0, dtype=lr.dtype, device=cfg.device)

		# ---- Discriminator update ----
		discriminator.zero_grad()
		# Generate super-resolution images; only the detached result is used here,
		# so no gradient reaches the generator in this half of the step.
		sr = generator(lr)

		# Loss on the high-resolution images, relative to the mean score of the generated ones.
		hr_output = discriminator(hr)
		sr_output = discriminator(sr.detach())
		# torch.clamp replaces the in-place masked write `diff[diff<0]=0`: same values
		# and gradients, without mutating a tensor that autograd tracks.
		# NOTE(review): BCELoss needs inputs in [0, 1]; the upper bound relies on the
		# discriminator output already being in [0, 1] -- confirm it ends in a sigmoid.
		diff = torch.clamp(hr_output - torch.mean(sr_output), min=0)
		d_loss_hr = adversarial_criterion(diff, real_label)
		d_loss_hr.backward()

		# Loss on the generated images, relative to the mean score of the real ones.
		# The forward passes are repeated because the previous graph was freed by backward().
		hr_output = discriminator(hr)
		sr_output = discriminator(sr.detach())
		diff = torch.clamp(sr_output - torch.mean(hr_output), min=0)
		d_loss_sr = adversarial_criterion(diff, fake_label)
		d_loss_sr.backward()

		# Both backward passes accumulated into the same gradients; apply them once.
		d_optimizer.step()

		# ---- Generator update ----
		generator.zero_grad()
		# Generate super-resolution images again, this time keeping the graph.
		sr = generator(lr)
		hr_output = discriminator(hr.detach())
		sr_output = discriminator(sr)
		# Generator loss = pixel_weight * pixel loss + content_weight * content loss
		#                + adversarial_weight * adversarial loss (weights come from cfg).
		pixel_loss = cfg.pixel_weight * pixel_criterion(sr, hr.detach())
		content_loss = cfg.content_weight * content_criterion(sr, hr.detach())
		diff = torch.clamp(sr_output - torch.mean(hr_output), min=0)
		adversarial_loss = cfg.adversarial_weight * adversarial_criterion(diff, real_label)
		# Update the generator weights. Gradients this also leaves on the discriminator
		# are cleared by discriminator.zero_grad() at the start of the next iteration.
		g_loss = pixel_loss + content_loss + adversarial_loss
		g_loss.backward()
		g_optimizer.step()

In [13]:
# Run one adversarial training epoch (epoch index 0) on the training loader.
train_adversarial(train_loader,0)

In [23]:
def validate(valid_dataloader, epoch, stage) -> float:
    """Measure the generator's PSNR averaged over the batches of a loader.

    Works on the notebook-level globals ``generator`` and ``cfg``. PSNR is
    computed per batch as ``10 * log10(1 / mse)`` and the per-batch values are
    averaged.

    Args:
        valid_dataloader (torch.utils.data.DataLoader): yields ``(lr, hr)`` batches.
        epoch (int): epoch index; printed 1-based in the summary line.
        stage (str): label printed in the summary line, e.g. `generator` or `adversarial`.

    Returns:
        float: mean of the per-batch PSNR values.

    """
    num_batches = len(valid_dataloader)
    # Evaluation mode for the generator.
    generator.eval()
    psnr_sum = 0.0

    with torch.no_grad():
        for low_res, high_res in valid_dataloader:
            # Move the pair to the configured device.
            low_res = low_res.to(cfg.device)
            high_res = high_res.to(cfg.device)
            # Super-resolve, then score the batch.
            restored = generator(low_res)
            squared_error = (restored - high_res) ** 2
            batch_mse = squared_error.data.mean()
            psnr_sum += 10 * torch.log10(1 / batch_mse).item()

        mean_psnr = psnr_sum / num_batches
        # Report the averaged metric for this epoch.
        print(f"Valid stage: {stage} Epoch[{epoch + 1:04d}] avg PSNR: {mean_psnr:.2f}.\n")

    return mean_psnr

In [24]:
# Evaluate the generator on the validation loader; the returned average PSNR
# is displayed as the cell output.
validate(valid_loader,0,'generator')

Valid stage: generator Epoch[0001] avg PSNR: 5.85.



5.848747342824936