In [None]:
import torch; torch.manual_seed(0)
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import DataLoader
from function.VAE import VAE
from function.Dir import Dir
from function.Dataset import ImageDataset
from function.Loss import Custom_criterion
from function.Log import log

# ---- Hyperparameters ----
NUM_TO_LEARN = 5_000   # number of image pairs put into the training set
EPOCHS = 1_000         # parameter 1: number of training epochs
BATCH_SIZE = 128       # parameter 2: mini-batch size
LATENTDIM = 256        # parameter 3: VAE latent dimension
LR_MAX, LR_MIN = 5e-4, 5e-6  # upper / lower bounds of the cosine learning-rate schedule

# Data source, passed to ImageDataset as its `mode` argument:
#   0 -> train on the STED_HC files, 1 -> train on the STED files.
# (Author's observation: models trained on STED generalise poorly; models trained
#  on STED_HC reconstruct the training set with some distortion.)
mode = 1

DEVICE = 'cuda'
LOSS_PLOT, EPOCH_PLOT = [], []  # per-epoch loss history, appended to by train()

In [None]:
# Run tag used in the file names of the saved loss figure, loss data and model.
# NOTE(review): `mode` is not part of the tag, so a mode-0 run and a mode-1 run with
# otherwise identical hyperparameters write to the same output files.
name = f'{EPOCHS}epo_{BATCH_SIZE}bth_{LATENTDIM}latn'

# Load the dataset; `mode` selects the data source (see the config cell).
dataset = ImageDataset(NUM_TO_LEARN, mode)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Build the VAE and wrap it in DataParallel so each batch is split across all visible GPUs.
vae = VAE(LATENTDIM).to(DEVICE)
vae = nn.DataParallel(vae)

# Loss functions and optimizer.
# criterion1 (MSE) and criterion2 (Custom_criterion) are chosen per epoch inside train().
criterion1 = nn.MSELoss()
# Move to the configured device rather than hard-coding .cuda(), so DEVICE is the single switch.
criterion2 = Custom_criterion().to(DEVICE)
optimizer = torch.optim.AdamW(vae.parameters(), lr=LR_MAX)

In [None]:
def train(dataloader, num_epochs, switch_epoch=500, checkpoint_every=300):
    """Train the module-level ``vae`` for ``num_epochs`` epochs.

    Uses the global ``vae``, ``optimizer``, ``criterion1``, ``criterion2``,
    ``DEVICE``, ``LR_MAX`` and ``LR_MIN`` defined in earlier cells. The learning
    rate follows a cosine schedule from LR_MAX (epoch 0) towards LR_MIN over
    ``num_epochs`` epochs.

    Args:
        dataloader: iterable of ``(img_LR, img_HR)`` batches; must be non-empty.
        num_epochs: number of epochs to run; also the length of the cosine schedule.
        switch_epoch: epochs with index ``<= switch_epoch`` use ``criterion1``;
            later epochs use ``criterion2``.
        checkpoint_every: save ``vae.state_dict()`` to ``Dir.TEMP()/checkpoint.pth``
            whenever ``epoch % checkpoint_every == 0``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Side effects:
        Truncates ``training.log``, writes one line per epoch via ``print`` and
        ``log``, and appends the epoch index / average loss to the global
        ``EPOCH_PLOT`` / ``LOSS_PLOT`` lists.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('dataloader is empty; nothing to train on')

    # Truncate the log left over from a previous run.
    with open('training.log', 'w'):
        pass

    for epoch in range(num_epochs):
        vae.train()  # switch to training mode
        total_loss = 0.0

        # Cosine-annealed learning rate. Update the existing optimizer in place:
        # building a new AdamW every epoch would throw away Adam's running
        # first/second-moment estimates.
        current_lr = LR_MIN + 0.5 * (LR_MAX - LR_MIN) * (1 + np.cos(np.pi * epoch / num_epochs))
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        criterion = criterion1 if epoch <= switch_epoch else criterion2

        for img_LR, img_HR in dataloader:
            # NOTE(review): squeeze(dim=1) drops axis 1 only if it has size 1 —
            # confirm the tensor layout ImageDataset produces.
            img_LR = torch.squeeze(img_LR, dim=1).to(DEVICE)
            img_HR = torch.squeeze(img_HR, dim=1).to(DEVICE)
            img_SR, _, _ = vae(img_LR)
            img_SR = img_SR.to(DEVICE)  # no-op if DataParallel already gathered onto DEVICE
            loss = criterion(img_SR, img_HR)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / num_batches  # mean of per-batch losses over the epoch
        message = f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.6f}, Current_LR:{current_lr:.8f}"
        print(message)
        log(message)

        # Record the same average that is printed/logged (not the batch-count-dependent sum).
        LOSS_PLOT.append(avg_loss)
        EPOCH_PLOT.append(epoch)
        if epoch % checkpoint_every == 0:
            torch.save(vae.state_dict(), Dir.TEMP() + '/checkpoint.pth')

In [None]:
# Launch training on the configured device and report when it finishes.
print(DEVICE)
print('start!')
train(dataloader, EPOCHS)
print('successfully done!')

In [None]:
# Plot the per-epoch loss history recorded by train(), then save the figure,
# the raw (epoch, loss) data and the trained weights under Dir.models().
fig, ax = plt.subplots()
ax.plot(EPOCH_PLOT, LOSS_PLOT)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title(f'Training loss ({name})')
# Jupyter renders the figure inline without an explicit plt.show().
fig.savefig(f'{Dir.models()}/lossfig_{name}.png', dpi=300)

# Row 0: epoch indices, row 1: the recorded loss values.
LOSS_DATA = np.stack((np.array(EPOCH_PLOT), np.array(LOSS_PLOT)), axis=0)
np.save(f'{Dir.models()}/lossdata_{name}.npy', LOSS_DATA)

# NOTE(review): vae is wrapped in nn.DataParallel, so the saved state_dict keys carry
# a 'module.' prefix; loading into a bare VAE requires stripping it (or vae.module).
torch.save(vae.state_dict(), f'{Dir.models()}/model_{name}.pth')

import torch; torch.manual_seed(0)
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import DataLoader
from function.VAE import VAE
from function.Dir import Dir
from function.Dataset import ImageDataset
from function.Loss import Custom_criterion
from function.Log import log

# 第0步：定义网络参数
NUM_TO_LEARN = 5000  # number of image pairs put into the training set
EPOCHS = 2000        # parameter 1: number of training epochs
BATCH_SIZE = 128     # parameter 2: mini-batch size
LATENTDIM = 256      # parameter 3: VAE latent dimension
LR_MAX = 5e-4
LR_MIN = 5e-6
mode = 0  # ImageDataset data source: 0 -> STED_HC files, 1 -> STED files. (Open question: which trains better?)

DEVICE = 'cuda'
LOSS_PLOT = []
EPOCH_PLOT = []

# Run tag used in the file names of the saved loss figure, loss data and model.
name = f'{EPOCHS}epo_{BATCH_SIZE}bth_{LATENTDIM}latn'

# Load the dataset using the configured `mode` (not a hard-coded value), so the
# setting above actually takes effect.
dataset = ImageDataset(NUM_TO_LEARN, mode)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Build the VAE and wrap it in DataParallel so each batch is split across all visible GPUs.
vae = VAE(LATENTDIM).to(DEVICE)
vae = nn.DataParallel(vae)

# Loss functions and optimizer.
criterion1 = nn.MSELoss()
criterion2 = Custom_criterion().to(DEVICE)
optimizer = torch.optim.AdamW(vae.parameters(), lr=LR_MAX)

# 第三步：定义Train函数
def train(dataloader, num_epochs, switch_epoch=500, checkpoint_every=300):
    """Train the module-level ``vae`` for ``num_epochs`` epochs.

    Uses the global ``vae``, ``optimizer``, ``criterion1``, ``criterion2``,
    ``DEVICE``, ``LR_MAX`` and ``LR_MIN`` defined above. The learning rate
    follows a cosine schedule from LR_MAX (epoch 0) towards LR_MIN over
    ``num_epochs`` epochs.

    Args:
        dataloader: iterable of ``(img_LR, img_HR)`` batches; must be non-empty.
        num_epochs: number of epochs to run; also the length of the cosine schedule.
        switch_epoch: epochs with index ``<= switch_epoch`` use ``criterion1``;
            later epochs use ``criterion2``.
        checkpoint_every: save ``vae.state_dict()`` to ``Dir.TEMP()/checkpoint.pth``
            whenever ``epoch % checkpoint_every == 0``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Side effects:
        Truncates ``training.log``, writes one line per epoch via ``print`` and
        ``log``, and appends the epoch index / average loss to the global
        ``EPOCH_PLOT`` / ``LOSS_PLOT`` lists.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('dataloader is empty; nothing to train on')

    # Truncate the log left over from a previous run.
    with open('training.log', 'w'):
        pass

    for epoch in range(num_epochs):
        vae.train()
        total_loss = 0.0

        # Cosine-annealed learning rate, applied to the existing optimizer so that
        # Adam's running moment estimates are kept across epochs.
        current_lr = LR_MIN + 0.5 * (LR_MAX - LR_MIN) * (1 + np.cos(np.pi * epoch / num_epochs))
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        criterion = criterion1 if epoch <= switch_epoch else criterion2

        for img_LR, img_HR in dataloader:
            # NOTE(review): squeeze(dim=1) drops axis 1 only if it has size 1 —
            # confirm the tensor layout ImageDataset produces.
            img_LR = torch.squeeze(img_LR, dim=1).to(DEVICE)
            img_HR = torch.squeeze(img_HR, dim=1).to(DEVICE)
            img_SR, _, _ = vae(img_LR)
            img_SR = img_SR.to(DEVICE)
            loss = criterion(img_SR, img_HR)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / num_batches  # mean of per-batch losses over the epoch
        message = f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.6f}, Current_LR:{current_lr:.8f}"
        print(message)
        log(message)

        LOSS_PLOT.append(avg_loss)
        EPOCH_PLOT.append(epoch)
        if epoch % checkpoint_every == 0:
            torch.save(vae.state_dict(), Dir.TEMP() + '/checkpoint.pth')

# 第四步：开始训练，输出可视化，保存训练log，保存训练后网络
print(DEVICE)
print('start!')
train(dataloader, EPOCHS)
print('successfully done!')

# Plot the per-epoch loss history and save the figure, raw data and trained weights.
fig, ax = plt.subplots()
ax.plot(EPOCH_PLOT, LOSS_PLOT)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title(f'Training loss ({name})')
fig.savefig(f'{Dir.models()}/lossfig_{name}.png', dpi=300)  # save the loss figure

# Row 0: epoch indices, row 1: the recorded loss values.
LOSS_DATA = np.stack((np.array(EPOCH_PLOT), np.array(LOSS_PLOT)), axis=0)
np.save(f'{Dir.models()}/lossdata_{name}.npy', LOSS_DATA)  # save the loss data

# NOTE(review): vae is wrapped in nn.DataParallel, so state_dict keys carry a 'module.' prefix.
torch.save(vae.state_dict(), f'{Dir.models()}/model_{name}.pth')  # save the model weights