In [1]:
import torch
# Environment report: CUDA availability, PyTorch version, and the CUDA version this
# PyTorch build was compiled against.
# NOTE(review): querying CUDA here may initialise the CUDA runtime for this kernel,
# after which later changes to CUDA_VISIBLE_DEVICES may be ignored.
for report_value in (torch.cuda.is_available(), torch.__version__, torch.version.cuda):
    print(report_value)

True
2.3.1
12.1


In [2]:
from torch.utils.data import Dataset
import os
import torch
import torch.nn as nn
import math
from torch.autograd import Function
import scipy.io as scio
from scipy import io

class Imgdataset(Dataset):
    """Dataset of ground-truth grayscale clips stored as .mat files.

    Every entry of ``<path>/gt_gray`` becomes one sample, in sorted file-name order
    so that indices are reproducible across file systems. ``__getitem__`` loads the
    file, picks the first known key and returns the array with axes permuted
    ``(2, 0, 1)`` (assumes the stored layout is (H, W, T) -- TODO confirm).

    Raises:
        FileNotFoundError: if ``path`` or ``path/gt_gray`` does not exist.
    """

    # (key, divisor) pairs tried in order; a divisor of None leaves the array as stored.
    _GT_KEYS = (('patch_save_gray', None), ('p1', 255), ('p2', 255), ('p3', 255))

    def __init__(self, path):
        super(Imgdataset, self).__init__()
        self.data = []
        if not os.path.exists(path):
            raise FileNotFoundError('path doesnt exist!')
        groung_truth_path = path + '/gt_gray'
        #groung_truth_path = path + '/gt_gray_flip-20240919T013623Z-001/gt_gray_flip'
        if not os.path.exists(groung_truth_path):
            raise FileNotFoundError('path doesnt exist!')
        # The 'groung_truth' key spelling is kept for compatibility with existing callers.
        self.data = [{'groung_truth': groung_truth_path + '/' + file_name}
                     for file_name in sorted(os.listdir(groung_truth_path))]

    def __getitem__(self, index):
        """Load sample ``index`` and return it as a tensor with axes permuted (2, 0, 1).

        Raises:
            KeyError: if the .mat file contains none of the expected keys.
        """
        file_path = self.data[index]["groung_truth"]
        contents = scio.loadmat(file_path)
        for key, divisor in self._GT_KEYS:
            if key in contents:
                array = contents[key]
                if divisor is not None:
                    array = array / divisor
                gt = torch.from_numpy(array)
                break
        else:
            # Previously this fell through with `gt` still a dict and failed on .permute.
            raise KeyError('{}: none of {} found'.format(
                file_path, [key for key, _ in self._GT_KEYS]))

        return gt.permute(2, 0, 1)

    def __len__(self):
        return len(self.data)


In [3]:
import torch
import torch.nn as nn

class double_conv(nn.Module):
    """Two 3x3 convolutions (padding 1, so H and W are preserved), each followed by ReLU."""

    def __init__(self, in_channels, out_channels):
        super(double_conv, self).__init__()
        # Indices 0..3 inside `d_conv` are kept so state_dict keys stay unchanged.
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.d_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.d_conv(x)

class Unet(nn.Module):
    """Two-level 2-D U-Net whose output is added back to its input.

    Input (B, in_ch, H, W) -> output (B, out_ch, H, W). H and W must be divisible
    by 4 so the two 2x max-pools and two 2x transposed convolutions line up with the
    skip connections, and the residual ``tanh(...) + x`` requires ``out_ch`` to
    broadcast against ``in_ch`` (ADMM_net uses equal values).
    """

    def __init__(self,in_ch, out_ch):
        super(Unet, self).__init__()

        # Encoder: 32 -> 64 -> 128 feature maps.
        self.dconv_down1 = double_conv(in_ch, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)

        self.maxpool = nn.MaxPool2d(2)
        # Decoder: learned 2x upsampling, then fusion with the matching skip connection.
        self.upsample2 = self._up_block(128, 64)
        self.upsample1 = self._up_block(64, 32)
        self.dconv_up2 = double_conv(64 + 64, 64)
        self.dconv_up1 = double_conv(32 + 32, 32)

        # 1x1 projection to out_ch, squashed into (-1, 1) before the residual add.
        self.conv_last = nn.Conv2d(32, out_ch, 1)
        self.afn_last = nn.Tanh()

    @staticmethod
    def _up_block(in_channels, out_channels):
        """2x learned upsampling: ConvTranspose2d(kernel 2, stride 2) followed by ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        skip1 = self.dconv_down1(x)
        skip2 = self.dconv_down2(self.maxpool(skip1))
        bottom = self.dconv_down3(self.maxpool(skip2))

        merged2 = torch.cat([self.upsample2(bottom), skip2], dim=1)
        up2 = self.dconv_up2(merged2)
        merged1 = torch.cat([self.upsample1(up2), skip1], dim=1)
        up1 = self.dconv_up1(merged1)

        # Residual connection: the network predicts a bounded correction to its input.
        return self.afn_last(self.conv_last(up1)) + x
    
class double_conv3d(nn.Module):
    """Two 3x3x3 convolutions (padding 1, so D/H/W are preserved), each followed by ReLU."""

    def __init__(self, in_channels, out_channels):
        super(double_conv3d, self).__init__()
        # Indices 0..3 inside `d_conv` are kept so state_dict keys stay unchanged.
        layers = [
            nn.Conv3d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.d_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.d_conv(x)

class Unet3d(nn.Module):
    """Two-level 3-D U-Net whose output is added back to its input.

    Input (B, in_ch, D, H, W) -> output (B, out_ch, D, H, W). D, H and W must be
    divisible by 4 (two 2x max-pools / two 2x transposed convolutions), and
    ``out_ch`` must broadcast against ``in_ch`` for the residual add.
    """

    def __init__(self,in_ch, out_ch):
        super(Unet3d, self).__init__()

        # Encoder: 32 -> 64 -> 128 feature maps.
        self.dconv_down1 = double_conv3d(in_ch, 32)
        self.dconv_down2 = double_conv3d(32, 64)
        self.dconv_down3 = double_conv3d(64, 128)

        self.maxpool = nn.MaxPool3d(2)
        # Decoder: learned 2x upsampling, then fusion with the matching skip connection.
        self.upsample2 = self._up_block(128, 64)
        self.upsample1 = self._up_block(64, 32)
        self.dconv_up2 = double_conv3d(64 + 64, 64)
        self.dconv_up1 = double_conv3d(32 + 32, 32)

        # 1x1x1 projection to out_ch, squashed into (-1, 1) before the residual add.
        self.conv_last = nn.Conv3d(32, out_ch, 1)
        self.afn_last = nn.Tanh()

    @staticmethod
    def _up_block(in_channels, out_channels):
        """2x learned upsampling: ConvTranspose3d(kernel 2, stride 2) followed by ReLU."""
        return nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        skip1 = self.dconv_down1(x)
        skip2 = self.dconv_down2(self.maxpool(skip1))
        bottom = self.dconv_down3(self.maxpool(skip2))

        merged2 = torch.cat([self.upsample2(bottom), skip2], dim=1)
        up2 = self.dconv_up2(merged2)
        merged1 = torch.cat([self.upsample1(up2), skip1], dim=1)
        up1 = self.dconv_up1(merged1)

        # Residual connection: the network predicts a bounded correction to its input.
        return self.afn_last(self.conv_last(up1)) + x

In [4]:
import torch
import scipy.io as scio
import numpy as np
from torch.utils.data import Dataset
import os
import torch.nn as nn
import math
from torch.autograd import Function
from scipy import io


# シャッタパターン読み込み
def generate_masks(mask_path,mask_name):
    mask = scio.loadmat(mask_path + '/' + mask_name)
    mask = mask['ExpPtn']
    #print(mask.shape)
    mask = np.transpose(mask, [2, 0, 1])
    mask_s = np.sum(mask, axis=0)
    index = np.where(mask_s == 0)
    mask_s[index] = 1
    mask_s = mask_s.astype(np.uint8)
    mask = torch.from_numpy(mask)
    mask = mask.float()
    mask = mask.cuda()
    mask_s = torch.from_numpy(mask_s)
    mask_s = mask_s.float()
    mask_s = mask_s.cuda()
    return mask, mask_s

def time2file_name(time):
    """Turn a ``str(datetime)`` timestamp into a file-system friendly name.

    ``'2024-10-23 17:21:26.123456'`` becomes ``'2024_10_23_17_21_26'``; anything after
    the seconds is dropped. Fixed character offsets are used, so the input must follow
    the ``YYYY-MM-DD HH:MM:SS`` layout.
    """
    # (start, stop) offsets of year, month, day, hour, minute, second.
    field_spans = ((0, 4), (5, 7), (8, 10), (11, 13), (14, 16), (17, 19))
    return '_'.join(time[start:stop] for start, stop in field_spans)


  

In [5]:
class BinarizeHadamardFunction(torch.autograd.Function):
    """Element-wise product of ``input`` with the binarized ``weight``.

    Forward computes ``input * (weight > 0)``. Backward is a straight-through
    estimator: the gradient w.r.t. ``weight`` treats the binarization as identity,
    i.e. ``grad_output * input``. Both gradients are summed back to the shape of the
    tensor they belong to, since ``input`` and ``weight`` may broadcast against each
    other (e.g. a (B, T, H, W) input with a (T, H, W) weight).
    """

    @staticmethod
    def _binarize(weight):
        # 1 where weight > 0 else 0, in weight's dtype/device, without allocating
        # separate full-size ones/zeros tensors.
        return (weight > 0).to(weight.dtype)

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input * BinarizeHadamardFunction._binarize(weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_input = (grad_output * BinarizeHadamardFunction._binarize(weight)).sum_to_size(input.shape)
        if ctx.needs_input_grad[1]:
            # Straight-through: d(output)/d(weight) approximated by input; reduce over
            # broadcast (e.g. batch) dimensions explicitly.
            grad_weight = (grad_output * input).sum_to_size(weight.shape)
        return grad_input, grad_weight

class LearnableMask(nn.Module):
    """Learnable (t, s, s) exposure mask that is binarized by thresholding at zero.

    ``forward`` multiplies its input by ``weight > 0`` through
    ``BinarizeHadamardFunction`` (straight-through gradient to ``weight``).
    """

    def __init__(self, t=16, s=256):
        super().__init__()
        self.t, self.s = t, s
        self.weight = nn.Parameter(torch.empty(t, s, s))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``weight`` uniformly from [-stdv, stdv], stdv = sqrt(1.5 / (s * s * t))."""
        bound = torch.tensor(1.5 / (self.s * self.s * self.t)).sqrt()
        self.stdv = bound
        self.weight.data.uniform_(-bound, bound)

    def forward(self, input):
        return BinarizeHadamardFunction.apply(input, self.weight)

    def get_binary_mask(self):
        """Return a detached {0, 1} tensor with the same shape, dtype and device as ``weight``."""
        with torch.no_grad():
            return (self.weight > 0).to(self.weight.dtype)

In [6]:
import torch
import torch.nn as nn

# Reconstruction model: unrolled (deep-unfolded) ADMM network.
class ADMM_net(nn.Module):
    """Learnable-mask compressive video model with a 9-stage unrolled ADMM decoder.

    ``forward(x)`` simulates the coded-exposure measurement ``y = sum_t(mask_t * x_t)``
    with the learnable binary mask, then runs ``NUM_STAGES`` ADMM iterations, each
    using its own U-Net denoiser (``unet1``..``unet9``) and step parameter
    (``gamma1``..``gamma9``). It returns the list of per-stage estimates, last stage
    last. ``x`` is presumably (B, t, s, s) -- the binary mask (t, s, s) is broadcast
    against it.
    """

    NUM_STAGES = 9

    def __init__(self, t=16, s=256):
        super(ADMM_net, self).__init__()
        self.mask = LearnableMask(t=t, s=s)
        # Attribute names unet1..unet9 / gamma1..gamma9 and their registration order are
        # kept so existing checkpoints and code referencing them keep working.
        for stage in range(1, self.NUM_STAGES + 1):
            setattr(self, 'unet{}'.format(stage), Unet(t, t))
        for stage in range(1, self.NUM_STAGES + 1):
            setattr(self, 'gamma{}'.format(stage), torch.nn.Parameter(torch.Tensor([0])))

    def forward(self, x):
        # Generate the measurement with the learnable mask (straight-through gradient).
        maskt = self.mask(x)
        y = torch.sum(maskt, dim=1)

        # Binary mask used by the reconstruction operators A / At.
        binary_mask = self.mask.get_binary_mask()
        batch = x.shape[0]
        Phi = binary_mask.expand(batch, -1, -1, -1)
        Phi_s = torch.sum(binary_mask, dim=0)
        # Pixels never exposed give Phi_s == 0; with gamma == 0 the data step would be
        # 0/0 = NaN, which then spreads through the U-Nets. Use 1 there instead.
        Phi_s = torch.where(Phi_s == 0, torch.ones_like(Phi_s), Phi_s)
        Phi_s = Phi_s.expand(batch, -1, -1)

        theta = self.At(y, Phi)
        b = torch.zeros_like(Phi)
        x_list = []
        for stage in range(1, self.NUM_STAGES + 1):
            unet = getattr(self, 'unet{}'.format(stage))
            gamma = getattr(self, 'gamma{}'.format(stage))
            # Data-consistency (projection) step.
            yb = self.A(theta + b, Phi)
            x = theta + b + self.At(torch.div(y - yb, Phi_s + gamma), Phi)
            # Denoising step, then dual-variable update.
            x1 = x - b
            theta = unet(x1)
            b = b - (x - theta)
            x_list.append(theta)

        return x_list

    def A(self, x,Phi):
        """Forward operator: mask each frame and sum over dim 1."""
        return torch.sum(x * Phi, 1)

    def At(self, y,Phi):
        """Adjoint operator: broadcast ``y`` over dim 1 and apply the mask."""
        # Broadcasting gives the same values as repeating y along dim 1, without the copy.
        return torch.unsqueeze(y, 1) * Phi

In [7]:
# Variant with the mask fixed (frozen) to a bayer_random pattern (counterpart to the RGB-image version)
import datetime
import json
import os
import time
import weakref

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as scio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

# NOTE(review): cell 1 already calls torch.cuda.is_available(), which may have
# initialised CUDA in this kernel; if so, setting CUDA_VISIBLE_DEVICES here does not
# change which GPU is used. Confirm, or set it before any torch.cuda call.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Abort early: the model and the evaluation code move tensors to the GPU (.cuda()).
if not torch.cuda.is_available():
    raise Exception('NO GPU!')

# Training data root; Imgdataset reads the ground-truth clips from <data_path>/gt_gray.
data_path = r"./data/"

# Evaluation sets. test() compares its test_path argument against test_path1 to decide
# whether to record per-file PSNR and 'test_avg' (otherwise it records 'val_avg').
test_path1 = r"./data/ktestdata2_gray"
#test_path1 = r"./data/test_gray"
val_path1 = r"./data/valdata_gray"
# Root directory for the JSON training logs written by main().
result_path1 = r"./result/ADMM_gray"
# Only referenced by the commented-out generate_masks() call below; main() builds its
# own './mask/bayer_<name>.mat' path.
mask_path = r"./mask"

# Alternative fixed masks for generate_masks() (currently disabled):
#mask_name = 'bayer_random256x256'  # fully random pattern
# mask_name = 'bayer_hitomirandom256x256'  # constrained pattern ("hitomi")
# mask_name = 'bayer_hamaphotorandom256x256'  # constrained pattern for the HamaPhoto (presumably Hamamatsu Photonics) camera
#mask_name ='allOne_gray'
#mask, mask_s = generate_masks(mask_path,mask_name)

# main() trains epochs last_train+1 .. last_train+max_iter.
last_train = 0
# Only used by the commented-out resume code at the end of this cell.
model_save_filename = ''
max_iter = 100
batch_size = 4
# NOTE(review): learning_rate and mode are not referenced by the visible code;
# main() receives its learning rates as arguments.
learning_rate = 0.0016
# Used only in the saved prediction file names; ADMM_net fixes its own stage count.
stage_num = 9
mode = 'train'  # train or test

dataset = Imgdataset(data_path)
train_data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# MSE used by test() to compute PSNR with a peak value of 1.
criterion = nn.MSELoss()

#if last_train != 0:
#    network = torch.load(
#        './model/' + 'monochrome/' + model_save_filename + "/model_epoch_{}.pth".format(last_train))

def compute_smoothness(tensor):
    """Mean smoothness (total-variation style) penalty over the last two dimensions.

    Absolute differences between neighbours along the last and second-to-last axes
    are zero-padded back to the input size, summed, and averaged over all elements.
    Works for any tensor with at least two dimensions; for a 3-D tensor this is
    identical to the original ``[:, :, 1:]`` / ``[:, 1:, :]`` formulation.
    """
    # Neighbour differences along the last axis (columns).
    diff_last = torch.abs(tensor[..., 1:] - tensor[..., :-1])
    # Neighbour differences along the second-to-last axis (rows).
    diff_second_last = torch.abs(tensor[..., 1:, :] - tensor[..., :-1, :])

    # Zero-pad so both difference maps match the input size again.
    diff_last = torch.nn.functional.pad(diff_last, (0, 1, 0, 0))
    diff_second_last = torch.nn.functional.pad(diff_second_last, (0, 0, 0, 1))

    return torch.mean(diff_last + diff_second_last)

def _load_test_video(file_path):
    """Load one evaluation clip from a .mat file and return it with axes [2, 0, 1] applied.

    The first key present wins, in the order 'patch_save_gray', 'y', 'p1', 'p2', 'p3'.
    Arrays stored under 'p1'/'p2'/'p3' are divided by 255, the same normalisation
    Imgdataset applies to those keys for training, so PSNR (peak value 1) is measured
    on the scale the network was trained on.

    Raises:
        KeyError: if the file contains none of the expected keys.
    """
    contents = scio.loadmat(file_path)
    for key, divisor in (('patch_save_gray', None), ('y', None),
                         ('p1', 255), ('p2', 255), ('p3', 255)):
        if key in contents:
            video = contents[key]
            if divisor is not None:
                video = video / divisor
            return np.transpose(video, [2, 0, 1])
    raise KeyError('{}: none of the expected keys found'.format(file_path))


def test(network, test_path, epoch, recon_path, psnr_epoch, psnr):
    """Evaluate ``network`` on every .mat clip in ``test_path``.

    Args:
        network: model returning a list of per-stage estimates; the last one is scored.
        test_path: directory of evaluation clips (processed in sorted file-name order).
        epoch, recon_path: unused; kept for call compatibility.
        psnr_epoch: list to which this call appends a tensor of per-clip PSNRs.
        psnr: dict updated in place. When ``test_path == test_path1`` it receives one
            entry per file name plus 'test_avg'; otherwise only 'val_avg'.

    Returns:
        (pred, psnr_epoch): list of final-stage predictions as numpy arrays, and the
        updated ``psnr_epoch`` list.

    Uses the notebook globals ``criterion`` (MSE) and ``test_path1``.
    """
    network.eval()
    # Evaluate on whatever device the network lives on instead of hard-coding .cuda().
    device = next(network.parameters()).device
    is_test_set = test_path == test_path1
    test_list = sorted(os.listdir(test_path))
    psnr_sample = torch.zeros(len(test_list))
    pred = []

    for i, file_name in enumerate(test_list):
        video = _load_test_video(test_path + '/' + file_name)
        pic0 = torch.from_numpy(video).to(device).float().unsqueeze(0)

        with torch.no_grad():
            out_pic = network(pic0)[-1]
            # PSNR with peak value 1 (data assumed normalised to [0, 1]).
            psnr_1 = 10 * torch.log10(1 / criterion(out_pic, pic0))

        psnr_sample[i] = psnr_1
        if is_test_set:
            psnr[file_name] = float(psnr_1)
        pred.append(out_pic.cpu().numpy())

    psnr_epoch.append(psnr_sample)
    psnr['test_avg' if is_test_set else 'val_avg'] = float(torch.mean(psnr_sample))

    return pred, psnr_epoch




# Adam optimisers for the reconstruction parameters, one per network. They are kept
# across calls (i.e. across epochs) so Adam's moment estimates are not thrown away
# every epoch; weak keys let a network and its optimiser be garbage-collected together.
RECON_OPTIMIZER_CACHE = weakref.WeakKeyDictionary()


def train_with_explicit_mask_optimization(network, train_loader, epoch, mask_lr, recon_lr, device):
    """Run one training epoch of the reconstruction stages with the mask held fixed.

    Args:
        network: model with ``unet1``..``unet9``, ``gamma1``..``gamma9`` and a ``mask``
            exposing ``get_binary_mask()``; its forward returns per-stage estimates.
        train_loader: iterable of ground-truth batches with ``len()``.
        epoch: epoch number, used only for logging.
        mask_lr: currently unused -- mask updates are disabled in this variant.
        recon_lr: learning rate applied to the reconstruction optimiser this epoch.
        device: device the batches are moved to.

    Returns:
        (average loss over batches, wall-clock seconds, current binary mask).

    The loss is RMSE of the last stage plus 0.5 x RMSE of each of the two preceding
    stages.
    """
    criterion = nn.MSELoss()

    recon_optimizer = RECON_OPTIMIZER_CACHE.get(network)
    if recon_optimizer is None:
        param_groups = (
            [{'params': getattr(network, 'unet{}'.format(i)).parameters()} for i in range(1, 10)]
            + [{'params': getattr(network, 'gamma{}'.format(i))} for i in range(1, 10)]
        )
        recon_optimizer = optim.Adam(param_groups, lr=recon_lr)
        RECON_OPTIMIZER_CACHE[network] = recon_optimizer
    else:
        # Reuse the optimiser state; only apply the (possibly decayed) learning rate.
        for group in recon_optimizer.param_groups:
            group['lr'] = recon_lr

    network.train()
    epoch_loss = 0
    start_time = time.time()

    for batch_idx, gt in enumerate(train_loader):
        gt = gt.to(device).float()

        recon_optimizer.zero_grad()
        outputs = network(gt)

        # Reconstruction error on the last three stages (deep supervision).
        recon_loss = (torch.sqrt(criterion(outputs[-1], gt))
                      + 0.5 * torch.sqrt(criterion(outputs[-2], gt))
                      + 0.5 * torch.sqrt(criterion(outputs[-3], gt)))
        # Optional mask regularisers (e.g. compute_smoothness) would be added here.
        total_loss = recon_loss

        total_loss.backward()
        recon_optimizer.step()

        epoch_loss += total_loss.item()

    avg_epoch_loss = epoch_loss / len(train_loader)
    time_taken = time.time() - start_time
    print("====================================")
    print(f"Epoch {epoch}")
    print(f"Loss: {avg_epoch_loss:.6f}")
    print(f"Time: {time_taken:.2f}s")

    return avg_epoch_loss, time_taken, network.mask.get_binary_mask()
    

def checkpoint(network, epoch, model_path):
    """Pickle the whole ``network`` module to ``./<model_path>/pre-train_epoch_<epoch>.pth``.

    The directory must already exist. Because the entire module (not only its
    ``state_dict``) is saved, loading it needs the model classes to be importable.
    """
    out_file = "./" + model_path + f"/pre-train_epoch_{epoch}.pth"
    torch.save(network, out_file)
    print(f"Checkpoint saved to {out_file}")
    
def savemask(network, epoch, binary_mask, datetime):
    """Report the mean per-frame exposure and save the mask to a .mat file.

    Writes ``./mask_pre-train/<datetime>/mask_epoch_<epoch>.mat`` with keys ``'mask'``
    (``binary_mask``) and ``'raw_weights'`` (``network.mask.weight``). ``datetime`` is
    the run's timestamp string and shadows the module of the same name locally.
    """
    per_frame_exposure = binary_mask.sum(dim=(1, 2))
    print(f"Average exposure per frame: {per_frame_exposure.mean().item():.2f}")

    # Save the mask.
    out_dir = f'./mask_pre-train/{datetime}'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    payload = {
        'mask': binary_mask.cpu().numpy(),
        'raw_weights': network.mask.weight.detach().cpu().numpy(),
    }
    scio.savemat(out_dir + f"/mask_epoch_{epoch}.mat", payload)
    



def main(mask_lr, recon_lr, mask_name):
    """Train the unrolled ADMM reconstruction for one fixed exposure mask.

    Loads ``./mask/bayer_<mask_name>.mat`` (key ``'ExpPtn'``) into the network's mask
    weights, freezes the mask, and trains only the U-Nets and gamma1..gamma9 for
    ``max_iter`` epochs. After every epoch it evaluates on the test and validation
    sets and rewrites a JSON log. Every 10th epoch it saves the final-stage
    predictions and a full-model checkpoint, and (while epoch < 300) decays both
    learning rates by 0.95.

    Relies on the notebook globals ``last_train``, ``max_iter``, ``train_data_loader``,
    ``test_path1``, ``val_path1``, ``result_path1`` and ``stage_num``.
    """
    date_time = time2file_name(str(datetime.datetime.now()))
    run_suffix = 'monochrome' + '/' + mask_name + '/' + date_time
    recon_path = 'recon' + '/' + run_suffix
    model_path = 'model' + '/' + run_suffix
    result_path = result_path1 + '/' + mask_name + '/' + date_time
    for directory in (recon_path, model_path, result_path):
        os.makedirs(directory, exist_ok=True)

    psnr_epoch = []
    val_psnr_epoch = []
    result_dict = []

    mask = scio.loadmat('./mask/bayer_' + mask_name + '.mat')['ExpPtn']
    mask = torch.from_numpy(mask).float()

    network = ADMM_net().cuda()
    # NOTE(review): copy_ broadcasts, so this assumes 'ExpPtn' is stored as (16, 256, 256)
    # (generate_masks transposes [2, 0, 1] for the same files) -- confirm the layout.
    network.mask.weight.data.copy_(mask)

    # Freeze the mask; only the reconstruction stages are trained.
    network.mask.requires_grad_(False)
    for stage in range(1, 10):
        getattr(network, 'unet{}'.format(stage)).requires_grad_(True)
        getattr(network, 'gamma{}'.format(stage)).requires_grad_(True)

    for epoch in range(last_train + 1, last_train + max_iter + 1):
        psnr = {}

        # Named epoch_time (not `time`) so the `time` module is not shadowed.
        loss, epoch_time, binary_mask = train_with_explicit_mask_optimization(
            network, train_data_loader, epoch, mask_lr, recon_lr, device='cuda')

        pred, psnr_epoch = test(network, test_path1, epoch, recon_path, psnr_epoch, psnr)
        psnr_mean = torch.mean(psnr_epoch[-1])

        val_pred, val_psnr_epoch = test(network, val_path1, epoch, recon_path, val_psnr_epoch, psnr)
        val_psnr_mean = torch.mean(val_psnr_epoch[-1])

        print("Test result: {:.4f}".format(psnr_mean))
        print("Validation result: {:.4f}".format(val_psnr_mean))

        result_dict.append({
            'epoch': epoch,
            'lr': recon_lr,
            'loss': loss,
            'time': epoch_time,
            'psnr': psnr,
        })
        # Rewritten every epoch so the log survives an interrupted run.
        with open(result_path + '/' + 'admm_gray_pre-train_result.json', 'w') as file:
            json.dump(result_dict, file, indent=4)

        if epoch % 10 == 0:
            name = recon_path + '/S{}'.format(stage_num) + '_pred_' + '{}_{:.4f}'.format(epoch, psnr_mean) + '.mat'
            scio.savemat(name, {'pred': pred})
            # savemask(network, epoch, binary_mask, date_time) is disabled while the mask is frozen.
            checkpoint(network, epoch, model_path)

            if epoch < 300:
                mask_lr = mask_lr * 0.95
                recon_lr = recon_lr * 0.95
    

if __name__ == '__main__':
    # Each base pattern is trained as-is and with its _rl, _ud and _udrl variants
    # (presumably left-right / up-down / both flips -- confirm against ./mask).
    mask_name_list = [
        base + suffix
        for base in ("hitomirandom256x256", "hamaphotorandom256x256", "random256x256")
        for suffix in ("", "_rl", "_ud", "_udrl")
    ]
    for mask_name in mask_name_list:
        main(0.0001, 0.0016, mask_name)

        

Epoch 1
Loss: 0.109909
Time: 169.80s
Test result: 21.2430
Validation result: 21.0099


KeyboardInterrupt: 