In [None]:
# 模型基本参数
import os

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch;torch.manual_seed(0)
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F


# Training hyper-parameters
BATCH_SIZE = 32
EPOCHS = 1500
LATENTDIM = 192
# Bounds of the cosine-annealed learning-rate schedule
LR_MAX = 1.0e-4
LR_MIN = 8.0e-5

# Input data directories
path_Confocal = './datasets/Confocal/'
path_STED = './datasets/STED'
path_STED_HC = './datasets/STED_HC/'

# Output artefacts share one run tag so model, loss figure and loss data stay matched.
_run_tag = f'{EPOCHS}epo_{BATCH_SIZE}bth_{LATENTDIM}latn'
savepath_model = f'./models/vae_model_{_run_tag}.pth'
savepath_fig = f'./models/vae_lossfig_{_run_tag}.png'
savepath_data = f'./models/vae_lossdata_{_run_tag}.npy'

In [None]:
# 数据集
# 32*160
num_to_learn = 2560

class ImageDataset(Dataset):
    """Paired low-/high-resolution image dataset, fully preloaded into memory.

    For every index ``i`` in ``range(num_images)`` it loads
    ``{path_LR}/{i}_Confocal.png`` and ``{path_HR}/{i}_STED.png``. Each image is
    converted with ``transforms.ToTensor()`` and given an extra leading dim of
    size 1 (``unsqueeze(0)``). Downstream code removes that dim with
    ``squeeze(1)`` after batching.

    Args:
        path_LR: directory containing the ``*_Confocal.png`` inputs.
        path_HR: directory containing the ``*_STED.png`` targets.
        num_images: number of pairs to load. Defaults to the notebook-level
            ``num_to_learn``, read at construction time.
    """

    def __init__(self, path_LR, path_HR, num_images=None):
        if num_images is None:
            num_images = num_to_learn
        self.path_LR = path_LR
        self.path_HR = path_HR
        # Historical attribute name: it holds the HR *directory*, not image data.
        # Kept so existing code reading `data_HR` keeps working.
        self.data_HR = path_HR
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.data = [self._load_pair(i) for i in range(num_images)]

    def _load_image(self, path):
        # Load one image and return its tensor with an extra leading dim.
        # `with` closes the file handle deterministically instead of leaving it to GC.
        with Image.open(path) as img:
            return self.transform(img).unsqueeze(0)

    def _load_pair(self, i):
        # Return the (LR, HR) tensor pair for sample index `i`.
        return (
            self._load_image(f"{self.path_LR}/{i}_Confocal.png"),
            self._load_image(f"{self.path_HR}/{i}_STED.png"),
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

# Pair Confocal inputs with STED targets and serve them in shuffled mini-batches.
dataset = ImageDataset(path_LR=path_Confocal, path_HR=path_STED)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Sanity-check the dataset: item count, per-item tensor shape, and the first three LR/HR pairs.
num = len(dataset)
img_LR_0, img_HR_0 = dataset[0]
img_shape = img_LR_0.shape
print("检查dataset参数：")
print("图像数量:%s"%num)
print("图像规格:[%d, %d, %d, %d] [批次图像数，通道数，高度，宽度]"%tuple(img_shape[k] for k in range(4)))

for i in range(3):
    # Drop all singleton dims so imshow receives a plain 2-D image.
    img_LR_i, img_HR_i = (t.squeeze() for t in dataset[i])
    fig, ax = plt.subplots(1, 2)
    panels = ((img_LR_i, 'Low_Resolution[%s]'), (img_HR_i, 'High_Resolution[%s]'))
    for axis, (image, title_fmt) in zip(ax, panels):
        axis.imshow(image, cmap='hot')
        axis.set_title(title_fmt % (i + 1))
        axis.axis('off')
plt.show()

In [None]:
# Sanity-check the DataLoader: batch count, configured batch size, and the pairs in one batch.
num_batches = len(dataloader)
batch_size = dataloader.batch_size
first_batch = next(iter(dataloader))
img_LR_batch, img_HR_batch = first_batch[0], first_batch[1]
num_images_in_first_batch = len(img_LR_batch)  # equals len(img_HR_batch)
print("检查dataloader参数：")
print("批次数量:", num_batches)
print("批次大小:", batch_size)
print("第一个批次的图像数量:", num_images_in_first_batch)

# Iterate over the images actually present. A batch can hold fewer than
# `batch_size` items (e.g. when the dataset is smaller than one batch), and
# indexing up to `batch_size` would then raise IndexError.
for i in range(num_images_in_first_batch):
    img_LR_i, img_HR_i = img_LR_batch[i].squeeze(), img_HR_batch[i].squeeze()
    fig, ax = plt.subplots(1, 2)
    for axis, image in zip(ax, (img_LR_i, img_HR_i)):
        axis.imshow(image, cmap='hot')
        axis.set_title(i + 1)
        axis.axis('off')
    # Show each pair as it is drawn instead of keeping a whole batch of figures
    # open at once (matplotlib warns past `figure.max_open_warning`, default 20).
    plt.show()

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of a diagonal Gaussian.

    Three 3x3, stride-2 convolutions (1 -> 4 -> 16 -> 64 channels, ReLU after
    each) halve the spatial size three times. The flattened feature map then
    feeds two linear heads. The heads are sized for 64 * 32 * 32 features,
    i.e. a single-channel 256 x 256 input.

    Args:
        latent_dim: length of the latent vector.

    ``forward(x)`` returns ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Attribute names become state_dict keys; renaming them breaks saved checkpoints.
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(4, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 64, kernel_size=3, stride=2, padding=1)
        flat_features = 64 * 32 * 32
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        flat = features.view(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

class Decoder(nn.Module):
    """Maps a latent vector back to a single-channel image with values in (0, 1).

    A linear layer expands the latent vector to a 64 x 32 x 32 feature map.
    Three transposed convolutions (kernel 4, stride 2, padding 1) then each
    double the spatial size (32 -> 64 -> 128 -> 256) while reducing channels
    64 -> 16 -> 4 -> 1. A final sigmoid bounds the output.

    Args:
        latent_dim: length of the latent vector.

    ``forward(x)`` takes ``(batch, latent_dim)`` and returns ``(batch, 1, 256, 256)``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Attribute names become state_dict keys; keep them stable for saved checkpoints.
        self.fc = nn.Linear(latent_dim, 64 * 32 * 32)
        self.conv3 = nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.conv2 = nn.ConvTranspose2d(16, 4, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.conv1 = nn.ConvTranspose2d(4, 1, kernel_size=4, stride=2, padding=1, output_padding=0)

    def forward(self, x):
        expanded = self.fc(x)
        grid = expanded.view(expanded.size(0), 64, 32, 32)
        grid = F.relu(self.conv3(grid))
        grid = F.relu(self.conv2(grid))
        return torch.sigmoid(self.conv1(grid))

class VAE(nn.Module):
    """Encoder/decoder pair joined by the reparameterisation trick.

    Args:
        latent_dim: length of the latent vector shared by encoder and decoder.

    ``forward(x)`` returns ``(x_recon, mu, logvar)``. A latent sample is drawn
    on every call; nothing here checks ``self.training``, so eval mode samples too.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``.

        Writing the sample this way keeps it differentiable w.r.t. ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        mu, logvar = self.encoder(x)
        reconstruction = self.decoder(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar

In [None]:
# Training loop. Per-epoch history is appended to these module-level lists,
# which the plotting cells read afterwards.
LOSS_PLOT = []
EPOCH_PLOT = []
LR_PLOT = []


def _cosine_lr(epoch, num_epochs, lr_min, lr_max):
    """Return the cosine-annealed learning rate for 0-based ``epoch``.

    Equals ``lr_max`` at epoch 0 and follows half a cosine down towards
    ``lr_min``, which would be reached at ``epoch == num_epochs``.
    """
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * epoch / num_epochs))


def train(vae, dataloader, criterion, optimizer, num_epochs, device,
          lr_max=None, lr_min=None, kl_weight=0.0):
    """Train ``vae`` to map low-resolution batches to high-resolution targets.

    Args:
        vae: model whose forward returns ``(prediction, mu, logvar)``.
        dataloader: sized iterable of ``(img_LR, img_HR)`` batches. A singleton
            dim at index 1 is squeezed out of both before use.
        criterion: reconstruction loss, called as ``criterion(prediction, target)``.
        optimizer: optimizer over ``vae``'s parameters. The ``lr`` of every
            param group is overwritten at the start of each epoch.
        num_epochs: number of epochs; also the length of the cosine LR schedule.
        device: device the model and batches are moved to.
        lr_max, lr_min: schedule bounds. ``None`` means the notebook's
            ``LR_MAX`` / ``LR_MIN``.
        kl_weight: weight of the KL(q(z|x) || N(0, I)) term added to the loss.
            The default 0.0 keeps the reconstruction-only objective.

    Side effects:
        Appends to ``LOSS_PLOT`` (per-epoch *sum* of batch losses),
        ``EPOCH_PLOT`` (0-based epoch index) and ``LR_PLOT`` (LR used that epoch).

    Raises:
        ValueError: if ``dataloader`` has no batches (the average would divide by zero).
    """
    if lr_max is None:
        lr_max = LR_MAX
    if lr_min is None:
        lr_min = LR_MIN
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader has no batches")
    vae.to(device)
    for epoch in range(num_epochs):
        vae.train()
        total_loss = 0.0
        # Anneal over `num_epochs` (not the EPOCHS constant) so the schedule
        # matches the epoch count this call actually runs.
        current_lr = _cosine_lr(epoch, num_epochs, lr_min, lr_max)
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr
        for img_LR, img_HR in dataloader:
            img_LR = img_LR.squeeze(1).to(device)
            img_HR = img_HR.squeeze(1).to(device)
            # Forward pass and reconstruction loss
            img_SR, mu, logvar = vae(img_LR)
            loss = criterion(img_SR, img_HR)
            if kl_weight:
                # KL divergence to the standard normal: summed over latent dims, averaged over the batch.
                kld = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
                loss = loss + kl_weight * kld

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / num_batches
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.6f},Current LR: {current_lr:.8f}")
        LOSS_PLOT.append(total_loss)
        EPOCH_PLOT.append(epoch)
        LR_PLOT.append(current_lr)

In [None]:
# Build the model on the best available device and wrap it in nn.DataParallel.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vae = nn.DataParallel(VAE(LATENTDIM).to(device))
criterion = nn.MSELoss()
# AdamW's default lr is only a placeholder: train() overwrites it every epoch.
optimizer = torch.optim.AdamW(vae.parameters())

In [None]:
# Run training; the LR schedule and loss history are handled inside train().
print(device)
print('start!')
train(vae, dataloader, criterion, optimizer, EPOCHS, device)
print('successfully done!')

In [None]:
# Persist the trained weights. Create the output directory first so a long
# training run is not lost to a missing folder.
# `vae` is wrapped in nn.DataParallel, so the saved keys carry a "module."
# prefix: load them into a DataParallel-wrapped VAE, or strip the prefix.
os.makedirs(os.path.dirname(savepath_model) or ".", exist_ok=True)
torch.save(vae.state_dict(), savepath_model)

In [None]:
# Save the loss curve (figure) and the loss table (epochs in row 0, losses in row 1).
# LOSS_PLOT holds the per-epoch *sum* of batch losses, not the printed average.
os.makedirs(os.path.dirname(savepath_fig) or ".", exist_ok=True)
fig, ax = plt.subplots()
ax.plot(EPOCH_PLOT, LOSS_PLOT)
ax.set_xlabel("Epoch")
ax.set_ylabel("Training loss (sum over batches)")
ax.set_title("VAE training loss")
# Save before showing, so the written file does not depend on what the backend
# does to the figure once it has been displayed.
fig.savefig(savepath_fig, dpi=300)
plt.show()
EPOCH_C = np.array(EPOCH_PLOT)
LOSS_C = np.array(LOSS_PLOT)
LOSS_DATA = np.stack((EPOCH_C, LOSS_C), axis=0)
os.makedirs(os.path.dirname(savepath_data) or ".", exist_ok=True)
np.save(savepath_data, LOSS_DATA)

In [None]:
# Plot the learning rate that was applied in each epoch.
fig, ax = plt.subplots()
ax.plot(EPOCH_PLOT, LR_PLOT)
ax.set_xlabel("Epoch")
ax.set_ylabel("Learning rate")
ax.set_title("Cosine-annealed learning rate")
plt.show()

In [None]:
def reconstruct(vae, dataloader):
    """Run the model on one batch from ``dataloader`` without tracking gradients.

    Args:
        vae: model whose forward returns a tuple whose first element is the
            reconstruction, e.g. ``(x_recon, mu, logvar)``.
        dataloader: iterable of ``(x, y)`` batches. A singleton dim at index 1
            of ``x`` is squeezed out before the forward pass.

    Returns:
        ``(x, x_rec, y)``:
        - ``x``: the input batch as yielded (not squeezed), on the model's device.
        - ``x_rec``: the model's full output tuple.
        - ``y``: the target batch, on the model's device.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Use the device the model's parameters live on rather than a global.
    device = next(vae.parameters()).device
    try:
        x, y = next(iter(dataloader))
    except StopIteration:
        raise ValueError("dataloader yielded no batches") from None
    x = x.to(device)
    y = y.to(device)
    # Inference only: one forward pass, no autograd graph (the original looped
    # over every batch with gradients enabled and kept only the last one).
    with torch.no_grad():
        x_rec = vae(x.squeeze(1))  # squeeze: remove the singleton dim
    return x, x_rec, y
# Reconstruct one batch and print tensor shapes at each step.
x, x_rec, y = reconstruct(vae, dataloader)
batch_shape = x.shape
print(batch_shape)
print("一个批次包含", batch_shape[0], "个样本，每个样本的形状为", batch_shape[1:5])
# Drop the singleton dim at index 1 (presumably the one the dataset's unsqueeze(0) adds).
x = torch.squeeze(x, 1)
sr_batch = x_rec[0]
print(sr_batch.shape)
print("2.x.shape"+"运用sqeeze除去第二个维度，放入vae模型进行训练")
spatial_shape = x.shape[2:]
print(spatial_shape)
print("3.要imshow除去x需要除去第一个维度化为", spatial_shape)

print(y.shape)
y = torch.squeeze(y, 1)

In [None]:
# Side-by-side comparison for one sample of the reconstructed batch:
# Confocal input, model output, and the STED target (loaded from path_STED).
img_LR_test = x.detach().cpu().numpy()
img_SR_test = x_rec[0].detach().cpu().numpy()
img_HR_test = y.detach().cpu().numpy()

# Index of the sample to display, clamped so a short batch cannot raise IndexError.
number = min(2, len(img_LR_test) - 1)

fig, ax = plt.subplots(1, 3)
panels = (
    (img_LR_test, "Confocal"),
    (img_SR_test, "Super-resolution"),
    # Targets are the `{i}_STED.png` files from path_STED, not STED_HC.
    (img_HR_test, "STED"),
)
for axis, (images, label) in zip(ax, panels):
    axis.imshow(images[number, 0, :, :], cmap='hot')
    axis.set_xlabel(label)
plt.show()