In [None]:
import pickle
import zipfile
from google.colab import drive
drive.mount('/content/drive')

# Make modules stored on the mounted Drive importable. The entry has to be an
# absolute path: a relative 'content/drive/MyDrive' is resolved against the
# current working directory and therefore points at a folder that does not
# exist, so imports from Drive would silently not be found.
import sys
sys.path.append('/content/drive/MyDrive')

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import Dataset, TensorDataset, DataLoader
import torch.optim as optim
import time
from time import sleep
from tqdm import tqdm
import os

# Seed the PyTorch random number generator so runs are repeatable.
seed = 42
torch.manual_seed(seed)

# Use the first GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs on a runtime without CUDA instead of failing at the
# first `.to(device)` call.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

Mounted at /content/drive


In [None]:
folder_path = './'
data_path = folder_path + 'FFPN-Ellipse-TrainingData-0.015IndividualNoise.pkl'

# Unpack the archive stored on Drive only when the data file is not already in
# the working directory; later runs reuse the extracted file.
if os.path.isfile(data_path):
    print("FFPN data .pkl file already exists.")
else:
    print("Extracting data from .pkl file.")
    with zipfile.ZipFile('/content/drive/MyDrive/FixedPointNetworks/FFPN-Ellipse-TrainingData-IndividualNoise.zip', 'r') as zip_ref:
        zip_ref.extractall('./')
    print("Extraction complete.")

    # Absolute paths: a relative 'content/drive/...' entry would be resolved
    # against the working directory and never match the mounted Drive.
    sys.path.append('/content/drive/MyDrive/FixedPointNetworks')
    sys.path.insert(0,'/content/drive/My Drive/FixedPointNetworks')

# NOTE(review): torch.load unpickles the file, so only load archives that come
# from a trusted source.
# map_location='cpu' keeps the datasets in host memory no matter which device
# they were saved from; only A is moved to the compute device here, the image
# and measurement tensors are moved later where they are used.
state = torch.load(data_path, map_location='cpu')
A = state['A'].to(device)
u_train = state['u_true_train']
u_test = state['u_true_test']
data_obs_train = state['data_obs_train']
data_obs_test = state['data_obs_test']

Extracting data from .pkl file.
Extraction complete.


In [None]:
# Show the first training pair side by side: the true image on the left and
# the corresponding observed data on the right (the colorbar belongs to the
# right-hand panel, which is the last image drawn).
plt.figure()
first_pair = (u_train[0,0,:,:], data_obs_train[0,0,:,:])
for panel, image in enumerate(first_pair, start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(image)
plt.colorbar()

In [None]:
print("A size = ", A.size())
# Diagonal scaling matrix: entry (j, j) is 1 / (number of nonzero entries in
# column j of A). A column without any nonzero entry would give an inf here.
nonzeros_per_column = torch.count_nonzero(A, dim=0)
S = torch.diag(nonzeros_per_column ** -1.0).float().to(device)
print(S)
print('S size = ', S.size())

Create training datasets

In [None]:
batch_size  = 15
# Only the first 10000 samples are used for training.
data_train  = TensorDataset(u_train[0:10000,:,:,:], data_obs_train[0:10000,:,:,:])
data_loader = DataLoader(dataset=data_train, batch_size=batch_size, shuffle=True)
# The loader keeps the final, smaller batch (drop_last is left at its default),
# so ask the loader itself for the batch count. Truncating
# len(dataset)/batch_size under-counts by one whenever the dataset size is not
# a multiple of batch_size.
n_batches   = len(data_loader)

print()
print(f'u_train.min(): {u_train.min()}')
print(f'u_train.max(): {u_train.max()}')
print("data_obs_train.shape = ", data_obs_train.shape)
print('n_batches = ', n_batches)

Define Network

In [None]:
class Regularizer_Net(nn.Module):
    ''' Fixed point network combining a learned regularizer with a data
        constraint step.

        One application of the operator T(u, d) consists of
          1. a stack of residual convolution blocks, each applied to a
             finite difference of u, followed by a clamp to [-10, 10], and
          2. a step towards the constraint M u = d,

                 u <-- u - 1.99 * D * M^T (M u - d).

        forward() iterates T without gradients until the iterates stop
        changing (or max_depth is hit) and then applies T once more with
        gradients attached.

        Args:
            D: matrix of shape (n_pixels, n_pixels) that scales the
               constraint step.
            M: matrix of shape (n_measurements, n_pixels). Images are
               square, so n_pixels must be a perfect square; the image side
               length is inferred from it.
            res_net_contraction: target Lipschitz constant gamma used by
               normalize_lip_const.
            res_layers: number of residual convolution blocks.
            num_channels: number of hidden channels in the blocks.

        Raises:
            ValueError: if M.shape[1] is not a perfect square.
    '''
    def __init__(self, D, M, res_net_contraction=0.99, res_layers=4,
                 num_channels = 42):
        super().__init__()
        self.relu = nn.ReLU()
        self.leaky_relu = nn.LeakyReLU(0.05)
        self.gamma = res_net_contraction
        self.D = D
        self.M = M
        self.Mt = M.t()

        # Side length of the (square) images, taken from the number of
        # columns of M instead of being hard-coded.
        n_pixels = M.shape[1]
        self.img_size = int(round(n_pixels ** 0.5))
        if self.img_size ** 2 != n_pixels:
            raise ValueError("M must have a square number of columns "
                             "(one per pixel of a square image), got "
                             + str(n_pixels))

        # The first block reads the single-channel image and the last block
        # maps back to a single channel; everything in between uses
        # num_channels channels.
        in_channels = lambda i: 1 if i == 0 else num_channels
        out_channels = lambda i: 1 if i == res_layers-1 else num_channels
        self.convs = nn.ModuleList([nn.Sequential(
                                        nn.Conv2d(in_channels=in_channels(i),
                                                  out_channels=num_channels,
                                                  kernel_size=3, stride=1,
                                                  padding=(1,1)),
                                        self.leaky_relu,
                                        nn.Conv2d(in_channels=num_channels,
                                                  out_channels=out_channels(i),
                                                  kernel_size=3, stride=1,
                                                  padding=(1,1)),
                                        self.leaky_relu)
                                    for i in range(res_layers)])

    def name(self) -> str:
        ''' Return the name of the model. '''
        return "Regularizer_Net"

    def device(self):
        ''' Return the device holding the first model parameter. '''
        return next(self.parameters()).data.device

    def _T(self, u, d):
        ''' Apply the operator T once.

            Args:
                u: current images, shape (batch, 1, height, width).
                d: measurements for the batch; flattened to
                   (batch, n_measurements) internally.

            Returns:
                Updated images with shape (batch, 1, height, width).
        '''
        batch_size = u.shape[0]
        height, width = u.shape[-2:]
        dev = self.device()

        # Learned Regularization Operator
        for idx, conv in enumerate(self.convs):
            # The last block outputs one channel, so its skip connection
            # uses only the first channel of u.
            is_last = idx + 1 == len(self.convs)
            u_ref = u[:, 0:1, :, :] if is_last else u
            # Alternate the finite difference between the last and the
            # second to last dimension from block to block.
            diff_dim = -1 if idx%2 == 0 else -2
            Du = torch.roll(u, 1, dims=diff_dim) - u
            u = u_ref + conv(Du)
        u = torch.clamp(u, min=-1.0e1, max=1.0e1)

        # Constraints Projection (column vectors: one column per sample)
        u_vec = u.view(batch_size, -1).permute(1,0).to(dev)
        d_vec = d.view(batch_size, -1).to(dev).permute(1,0)
        res = torch.matmul(self.Mt, self.M.matmul(u_vec) - d_vec)
        res = 1.99 * torch.matmul(self.D.to(dev), res)
        res = res.permute(1,0).reshape(batch_size, 1, height, width).to(dev)
        return u - res

    def normalize_lip_const(self, u, d):
        ''' Scale convolutions to make the operator gamma Lipschitz

            It should hold that |T(u,d) - T(w,d)| <= gamma * |u-w| for all u
            and w. This is checked on one random perturbation w = u + noise
            (Gaussian noise with standard deviation 0.05), with norms taken
            over the channel dimension and averaged over all remaining
            entries. If the check fails, ideally we would multiply by

                norm_fact = gamma * |u-w| / |T(u,d) - T(w,d)|.

            The issue is that ResNets include an identity operation, which
            we don't wish to rescale. So, instead only the convolutions are
            rescaled,

                R <-- I + norm_fact * Conv,

            which is accurate up to an identity term scaled by
            (norm_fact - 1). If we do this often enough, then
            norm_fact ~ 1.0 and the identity term is negligible.

            In the code the weights and biases of both convolutions of every
            block are multiplied by norm_fact ** (1 / number of blocks).

            Note: the leaky ReLU activations have slope at most one, so they
            are nonexpansive and need no rescaling.
        '''
        noise_u = 0.05 * torch.randn(u.size(), device=self.device())
        w = (u.clone() + noise_u).to(self.device())
        Twd = self._T(w, d)
        Tud = self._T(u, d)
        T_diff_norm = torch.mean(torch.norm(Twd - Tud, dim=1))
        u_diff_norm = torch.mean(torch.norm(w - u, dim=1))
        R_is_gamma_lip = T_diff_norm <= self.gamma * u_diff_norm
        if not R_is_gamma_lip:
            normalize_factor = (self.gamma * u_diff_norm / T_diff_norm) ** (1.0 / len(self.convs))
            print("normalizing!")
            for block in self.convs:
                # Entries 0 and 2 of each block are the two convolutions.
                for layer in (block[0], block[2]):
                    layer.weight.data *= normalize_factor
                    layer.bias.data *= normalize_factor

    def forward(self, d, eps=1.0e-3, max_depth=100,
                depth_warning=False):
        ''' FPN forward prop

            With gradients detached, find fixed point. During forward
            iteration, u is updated via T(u,d) starting from zeros; in
            training mode the Lipschitz normalization is applied afterwards.
            Gradients are attached by performing one final step.

            Args:
                d: measurements; the first dimension is the batch size.
                eps: the iteration stops once the largest change between
                     consecutive iterates is at most eps.
                max_depth: maximum number of iterations.
                depth_warning: print a warning when max_depth is reached.

            Returns:
                Images of shape (batch, 1, img_size, img_size). The number
                of iterations performed is stored in self.depth.
        '''
        with torch.no_grad():
            self.depth = 0.0
            u = torch.zeros((d.size()[0], 1, self.img_size, self.img_size),
                            device=self.device())
            all_samp_conv = False
            while not all_samp_conv and self.depth < max_depth:
                u_prev = u.clone()
                u = self._T(u, d)
                res_norm = torch.max(torch.norm(u - u_prev, dim=1))
                self.depth += 1.0
                all_samp_conv = res_norm <= eps

            if self.training:
                self.normalize_lip_const(u, d)

        if self.depth >= max_depth and depth_warning:
            print("\nWarning: Max Depth Reached - Break Forward Loop\n")

        return self._T(u, d)

Set up training parameters

In [None]:
# Build the network from the scaling matrix S and the matrix A, everything
# placed on the compute device.
Phi = Regularizer_Net(S.to(device), A.to(device)).to(device)

# Count the entries of all parameters that receive gradients.
trainable_params = [param for param in Phi.parameters() if param.requires_grad]
pytorch_total_params = sum(param.numel() for param in trainable_params)
print(f'Number of trainable parameters: {pytorch_total_params}')

Number of trainable parameters: 96307


In [None]:
# Number of epochs and the settings handed to Phi's fixed point iteration
# (iteration cap and stopping tolerance).
max_epochs = 25
max_depth = 100
eps = 5.0e-3

# Loss, optimizer and a schedule that halves the learning rate every 10 epochs.
learning_rate = 2.5e-5
criterion = nn.MSELoss()
optimizer = optim.Adam(Phi.parameters(), lr=learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Template of the one-line summary printed after every epoch.
fmt = ('[{:2d}/{:2d}]: train_loss = {:7.3e} | '
       'depth = {:5.1f} | lr = {:5.1e} | time = {:4.1f} sec')

# Start from previously saved weights instead of a fresh initialization.
load_weights = True
if load_weights:
    state = torch.load('./drive/MyDrive/FixedPointNetworks/Feasible_FPN_Ellipses_weights.pth',
                       map_location=torch.device(device))
    Phi.load_state_dict(state['Phi_state_dict'])
    print('Loaded Phi from file.')

Loaded Phi from file.


Execute Training

In [None]:
best_loss = 1.0e10  # smallest summed epoch loss seen so far

for epoch in range(max_epochs):
  sleep(0.5)  # slows progress bar so it won't print on multiple lines
  tot = len(data_loader)
  loss_ave = 0.0
  start_time_epoch = time.time()
  with tqdm(total=tot, unit=" batch", leave=False, ascii=True) as tepoch:

    for u_batch, d in data_loader:
      u_batch = u_batch.to(device)
      # Number of samples in this batch (the last batch can be smaller). It
      # has its own name so the notebook-level `batch_size` setting, which
      # is reused to build other loaders, is not overwritten by the loop.
      train_batch_size = d.shape[0]
      Phi.train()
      optimizer.zero_grad()
      u = Phi(d.to(device), max_depth=max_depth, eps=eps)
      output = criterion(u, u_batch)
      train_loss = output.detach().cpu().numpy()
      # Weight the batch loss by the batch size; divided by the dataset
      # size at the end of the epoch.
      loss_ave += train_loss * train_batch_size
      output.backward()
      optimizer.step()

      tepoch.update(1)
      tepoch.set_postfix(train_loss="{:5.2e}".format(train_loss),
                            depth="{:5.1f}".format(Phi.depth))

    if epoch%1 == 0:  # plot every epoch; raise the modulus to plot less often
        # compute test image; no gradients are needed for a preview, so the
        # forward pass runs without building an autograd graph
        Phi.eval()
        with torch.no_grad():
            u_test_approx = Phi(data_obs_test[0,:,:,:], max_depth=max_depth, eps=eps)

        plt.figure()
        plt.subplot(2,2,1)
        plt.imshow(u_batch[0,0,:,:].cpu(), vmin=0, vmax=1)
        plt.title('u true train')
        plt.subplot(2,2,2)
        plt.imshow(u[0,0,:,:].detach().cpu(), vmin=0, vmax=1)
        plt.title('u approx train')
        plt.subplot(2,2,3)
        plt.imshow(u_test[0,0,:,:].cpu(), vmin=0, vmax=1)
        plt.title('u true test')
        plt.subplot(2,2,4)
        plt.imshow(u_test_approx[0,0,:,:].detach().cpu(), vmin=0, vmax=1)
        plt.title('u approx test')
        plt.show()
        Phi.train()

    # ---------------------------------------------------------------------
    # Save weights
    # ---------------------------------------------------------------------
    # loss_ave is still the un-normalized sum here; the dataset size is the
    # same every epoch, so comparing sums is equivalent to comparing means.
    if loss_ave < best_loss:
        best_loss = loss_ave
        # Store state dicts only. Pickling the scheduler object itself would
        # drag the optimizer and all parameters into the file and stop the
        # checkpoint from being loadable with weights-only loading.
        state = {
            'Phi_state_dict': Phi.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict()
        }
        file_name = './drive/MyDrive/FixedPointNetworks/Feasible_FPN_Ellipses_weights.pth'
        torch.save(state, file_name)
        print('\nModel weights saved to ' + file_name)

  loss_ave = loss_ave/len(data_loader.dataset)
  end_time_epoch = time.time()
  time_epoch = end_time_epoch - start_time_epoch
  lr_scheduler.step()
  print(fmt.format(epoch+1, max_epochs, loss_ave, Phi.depth,
                   optimizer.param_groups[0]['lr'],
                   time_epoch))

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Pair every true test image with its measurements and batch them with the
# same batch size setting as the training loader.
data_test = TensorDataset(u_test, data_obs_test)
test_data_loader = DataLoader(data_test, batch_size=batch_size, shuffle=True)

# Number of test samples (size of the first dimension of u_test).
n_samples = len(u_test)


def compute_avg_SSIM_PSNR(u_true, u_gen, n_mesh, data_range):
    """Average SSIM and PSNR over a batch of image pairs.

    Args:
        u_true: reference images; the first dimension is the sample index
            and each sample holds n_mesh * n_mesh values. Must be detached.
        u_gen: images to compare against u_true, same layout. Must be
            detached.
        n_mesh: side length of the square images.
        data_range: value range of the images, forwarded to ssim and psnr.

    Returns:
        Tuple (mean SSIM, mean PSNR) over the samples.
    """
    n_samples = u_true.shape[0]
    stacked_shape = (n_samples, n_mesh, n_mesh)
    true_images = u_true.reshape(stacked_shape).cpu().numpy()
    gen_images = u_gen.reshape(stacked_shape).cpu().numpy()

    ssim_total = 0
    psnr_total = 0
    for true_image, gen_image in zip(true_images, gen_images):
        ssim_total += ssim(true_image, gen_image, data_range=data_range)
        psnr_total += psnr(true_image, gen_image, data_range=data_range)
    return ssim_total/n_samples, psnr_total/n_samples

In [None]:
# Running sums, each weighted by the batch size; divided by the number of
# test samples in the summary below.
test_loss_ave = 0
test_PSNR_ave = 0
test_SSIM_ave = 0

Phi.eval()  # evaluation mode for the whole loop
with torch.no_grad():
    for idx, (u_batch, d) in enumerate(test_data_loader):

        u_batch = u_batch.to(device)
        # Number of samples in this batch (the last batch can be smaller).
        # The notebook-level `batch_size` setting is left untouched.
        test_batch_size = d.shape[0]
        u = Phi(d, max_depth=max_depth, eps=eps)
        output = criterion(u, u_batch)
        test_loss = output.detach().cpu().numpy()
        test_SSIM, test_PSNR = compute_avg_SSIM_PSNR(u_batch, u, 128, 1)
        test_PSNR_ave += test_PSNR * test_batch_size
        test_loss_ave += test_loss * test_batch_size
        test_SSIM_ave += test_SSIM * test_batch_size

        print('test_PSNR = {:7.3e}'.format(test_PSNR))
        print('test_SSIM = {:7.3e}'.format(test_SSIM))
        print('test_loss = {:7.3e}'.format(test_loss))
        if idx%1 == 0:  # plot every batch; raise the modulus to plot less often
            # show the first image of the batch next to its reconstruction
            plt.figure()
            plt.subplot(1,2,1)
            plt.imshow(u_batch[0,0,:,:].cpu(), vmin=0, vmax=1)
            plt.title('u true')
            plt.subplot(1,2,2)
            plt.imshow(u[0,0,:,:].detach().cpu(), vmin=0, vmax=1)
            plt.title('u approx')
            plt.show()

# Average over the actual number of test samples rather than a fixed 1000, so
# the summary stays correct for a test set of any size. max(..., 1) avoids a
# division by zero for an empty test set.
n_test = max(len(test_data_loader.dataset), 1)
print('\n\nSUMMARY')
print('test_loss_ave =  {:7.3e}'.format(test_loss_ave / n_test))
print('test_PSNR_ave =  {:7.3e}'.format(test_PSNR_ave / n_test))
print('test_SSIM_ave =  {:7.3e}'.format(test_SSIM_ave / n_test))

In [None]:
# Index of the test sample that is reconstructed and exported below.
ind_val = 0

# True image of that sample and its reconstruction. Phi is called without
# eps / max_depth here, so its default settings apply.
u_true = u_test[ind_val, 0]
u = Phi(data_obs_test[ind_val]).view(128, 128)
def string_ind(index):
    """Return index as a string, left-padded with zeros to four characters.

    Values of 1000 or more are returned without padding.
    """
    # Each threshold is paired with the zeros needed below it.
    for upper_bound, zeros in ((10, '000'), (100, '00'), (1000, '0')):
        if index < upper_bound:
            return zeros + str(index)
    return str(index)

#------------------------------------------------------------
# RECONSTRUCTION
#------------------------------------------------------------

cmap = 'gray'
fig = plt.figure()
u_rotated = np.rot90(u.detach().cpu().numpy())
plt.imshow(u_rotated, cmap=cmap, vmin=0, vmax=1)
plt.axis('off')

# Save the reconstruction as a PDF named after the zero-padded sample index.
save_loc = ('./drive/MyDrive/FixedPointNetworks/Learned_Feasibility_Ellipse_FFPN_ind_'
            + string_ind(ind_val)
            + '.pdf')
plt.savefig(save_loc, bbox_inches='tight')
plt.show()

# The printed tuple holds (SSIM, PSNR) for this single image pair.
u_true_single = u_true.view(1,128,128)
u_single = u.view(1,128,128).detach()
print("SSIM: ", compute_avg_SSIM_PSNR(u_true_single, u_single, 128, 1))

#------------------------------------------------------------
# TRUE
#------------------------------------------------------------

cmap = 'gray'
fig = plt.figure()
u_true_rotated = np.rot90(u_true.detach().cpu().numpy())
plt.imshow(u_true_rotated, cmap=cmap, vmin=0, vmax=1)
plt.axis('off')

# Save the true image next to the reconstruction, same naming scheme.
save_loc = ('./drive/MyDrive/FixedPointNetworks/Learned_Feasibility_Ellipse_GT_ind_'
            + string_ind(ind_val)
            + '.pdf')
plt.savefig(save_loc, bbox_inches='tight')
plt.show()
