<a href="https://colab.research.google.com/github/nir6760/DIP_DeepKSVD/blob/main/DeepKSVD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Cells for extracting the images from the zip archive after uploading it to Colab.

In [None]:
import tensorflow as tf

# Display the name of the GPU device TensorFlow can see in this runtime.
gpu_device = tf.test.gpu_device_name()
gpu_device

'/device:GPU:0'

In [5]:
from zipfile import ZipFile

# Archive uploaded to the Colab runtime; its contents are extracted into the
# current working directory.
file_name = "/content/gray.zip"

# Named ``archive`` rather than ``zip`` so the builtin ``zip`` is not shadowed
# for the rest of the notebook session.
with ZipFile(file_name, "r") as archive:
    archive.extractall()
    print("Done")

Done


# **Model Training**

In [None]:
import time

import numpy as np
import torch
from scipy import linalg
from torch.utils.data import DataLoader
from torchvision import transforms

import Deep_KSVD_NirOren

# Every later cell refers to the module as ``Deep_KSVD.<name>``; without this
# alias those cells raise NameError. The original module name stays bound too.
Deep_KSVD = Deep_KSVD_NirOren

In [None]:
# List of the test image names (BSD68), one file name per line.
# The ``with`` block closes the file. Only the line terminator is stripped, so
# a final line without a trailing newline keeps its last character; blank
# lines are skipped so they do not turn into empty image names.
with open("test_gray.txt", "r") as file_test:
    onlyfiles_test = [line.rstrip("\r\n") for line in file_test if line.strip()]

In [None]:
# List of the train image names, one file name per line.
# The ``with`` block closes the file. Only the line terminator is stripped, so
# a final line without a trailing newline keeps its last character; blank
# lines are skipped so they do not turn into empty image names.
with open("train_gray.txt", "r") as file_train:
    onlyfiles_train = [line.rstrip("\r\n") for line in file_train if line.strip()]

In [None]:
# Rescale pixel values to [-1, 1]: the shift and the scale are both half of
# the 8-bit range.
mean = std = 255 / 2
data_transform = transforms.Compose([
    Deep_KSVD.Normalize(mean=mean, std=std),
    Deep_KSVD.ToTensor(),
])

sigma = 25            # noise level passed to the datasets
sub_image_size = 128  # sub-image size passed to SubImagesDataset

In [None]:
# Training set: sub-images cut from the images listed in ``onlyfiles_train``
# under the "gray" directory, with the size, noise level and transform
# configured above.
train_kwargs = dict(
    root_dir="gray",
    image_names=onlyfiles_train,
    sub_image_size=sub_image_size,
    sigma=sigma,
    transform=data_transform,
)
my_Data_train = Deep_KSVD.SubImagesDataset(**train_kwargs)

In [None]:
# Training Dataset:
# This cell repeats the construction done in the previous cell. Building the
# dataset a second time is redundant, so it only runs when ``my_Data_train``
# does not exist yet. To force a rebuild (e.g. after changing ``sigma`` or
# ``sub_image_size``), re-run the previous cell.
if "my_Data_train" not in globals():
    my_Data_train = Deep_KSVD.SubImagesDataset(
        root_dir="gray",
        image_names=onlyfiles_train,
        sub_image_size=sub_image_size,
        sigma=sigma,
        transform=data_transform,
    )

In [None]:
# Test Dataset (full images):
my_Data_test = Deep_KSVD.FullImagesDataset(
    root_dir="gray", image_names=onlyfiles_test, sigma=sigma, transform=data_transform
)

# Dataloader over a random subset of the test set, used for the periodic
# evaluation during training.
# Indices are drawn without replacement from the actual dataset length, so:
# - the subset never evaluates the same image twice;
# - it does not rely on the test list having exactly 68 entries.
num_images_test = 5
indices_test = np.random.choice(
    len(my_Data_test),
    size=min(num_images_test, len(my_Data_test)),
    replace=False,
).tolist()
my_Data_test_sub = torch.utils.data.Subset(my_Data_test, indices_test)
dataloader_test = DataLoader(
    my_Data_test_sub, batch_size=1, shuffle=False, num_workers=0
)


In [None]:
# Shuffled loader over the training sub-images.
batch_size = 1
dataloader_train = DataLoader(
    dataset=my_Data_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Training log. The first line records the device in use; the handle stays
# open because the training loop keeps writing to it.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
file_to_print = open("results_training.csv", "w")
file_to_print.write(f"{device}\n")
file_to_print.flush()

In [None]:
# Initial dictionary produced by Deep_KSVD.init_dct for 8x8 patches.
patch_size = 8
m = 16
Dict_init = Deep_KSVD.init_dct(patch_size, m).to(device)

# Squared spectral norm (ord=2 matrix norm) of the initial dictionary,
# stored as a one-element float32 tensor.
spectral_norm = linalg.norm(Dict_init.cpu(), ord=2)
c_init = torch.tensor([spectral_norm ** 2], dtype=torch.float32, device=device)

# One weight per patch entry, drawn around 1 with standard deviation 0.1.
w_init = torch.normal(mean=1, std=1 / 10 * torch.ones(patch_size ** 2)).float().to(device)

# Network hyper-parameters.
D_in = patch_size ** 2                # flattened patch length (MLP input size)
H_1, H_2, H_3, H_4 = 128, 64, 32, 16  # hidden layer widths
D_out_lam = 1                         # MLP output size
T = 5                                 # passed to DenoisingNet_MLP as T
min_v, max_v = -1, 1                  # value bounds passed to DenoisingNet_MLP

In [None]:
# Positional arguments for Deep_KSVD.DenoisingNet_MLP, in constructor order.
mlp_args = (
    patch_size,
    D_in, H_1, H_2, H_3, H_4, D_out_lam,
    T, min_v, max_v,
    Dict_init, c_init, w_init,
    device,
)
model = Deep_KSVD.DenoisingNet_MLP(*mlp_args)
model.to(device)

DenoisingNet_MLP(
  (unfold): Unfold(kernel_size=(8, 8), dilation=1, padding=0, stride=1)
  (linear1): Linear(in_features=64, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=32, bias=True)
  (linear4): Linear(in_features=32, out_features=16, bias=True)
  (linear5): Linear(in_features=16, out_features=1, bias=True)
)

In [None]:
# Mean-squared-error reconstruction loss, minimized with Adam.
criterion = torch.nn.MSELoss(reduction="mean")
learning_rate = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def _evaluate(net, loader, loss_fn, dev):
    """Return the mean ``loss_fn`` value of ``net`` over ``loader``.

    Runs without gradient tracking and with the model in eval mode. The
    previous train/eval mode is restored afterwards. Returns NaN for an empty
    loader instead of dividing by zero.
    """
    was_training = net.training
    net.eval()
    total = 0.0
    with torch.no_grad():
        for clean_t, noisy_t in loader:
            clean, noisy = clean_t.to(dev), noisy_t.to(dev)
            total += loss_fn(net(noisy), clean).item()
    net.train(was_training)
    return total / len(loader) if len(loader) else float("nan")


def _save_checkpoint():
    """Save the model weights and the train/test loss histories."""
    torch.save(model.state_dict(), "model.pth")
    np.savez("losses.npz", train=np.array(train_losses), test=np.array(test_losses))


start = time.time()
epochs = 3
running_loss = 0.0
print_every = 1
# Stop once the test loss is close to the minimum loss reported in the paper.
target_test_loss = 0.0085
train_losses, test_losses = [], []
stop_training = False

model.train()
for epoch in range(epochs):  # loop over the dataset multiple times
    for i, (sub_images, sub_images_noise) in enumerate(dataloader_train, 0):
        sub_images = sub_images.to(device)
        sub_images_noise = sub_images_noise.to(device)

        # zero the gradients, then forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(sub_images_noise)
        loss = criterion(outputs, sub_images)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % print_every == print_every - 1:  # report every `print_every` mini-batches
            train_loss = running_loss / print_every
            train_losses.append(train_loss)

            # Time spent on training steps since the previous report.
            file_to_print.write("time:" + " " + str(time.time() - start) + "\n")
            start = time.time()

            test_loss = _evaluate(model, dataloader_test, criterion, device)
            test_losses.append(test_loss)

            # Time spent on the evaluation pass.
            file_to_print.write("time:" + " " + str(time.time() - start) + "\n")
            start = time.time()

            s = "[%d, %d] loss_train: %f, loss_test: %f" % (
                epoch + 1,
                (i + 1) * batch_size,
                train_loss,
                test_loss,
            )
            file_to_print.write(s + "\n")
            file_to_print.flush()
            running_loss = 0.0

            # Log the stopping measurement first, then leave BOTH loops.
            if test_loss < target_test_loss:
                stop_training = True
                break

        if i % (10 * print_every) == (10 * print_every) - 1:
            _save_checkpoint()

    if stop_training:
        break

# Always persist the final state. This covers training that ended (or stopped
# early) between two periodic checkpoints.
_save_checkpoint()
file_to_print.write("Finished Training\n")
file_to_print.flush()

17

# **Model loading and Testing**

In [31]:
import pickle

import numpy as np
import torch
from matplotlib.pyplot import cm
from matplotlib.pyplot import imsave
from scipy import linalg
from torch.utils.data import DataLoader
from torchvision import transforms

import Deep_KSVD_NirOren

# Later cells refer to the module as ``Deep_KSVD.<name>``; without this alias
# they raise NameError. The original module name stays bound too.
Deep_KSVD = Deep_KSVD_NirOren

In [32]:
# Evaluation runs on the CPU; the checkpoint is loaded with map_location="cpu" to match.
device = torch.device("cpu")

In [33]:
# Initial dictionary: overcomplete Discrete Cosine Transform.
patch_size = 8
m = 16
Dict_init = Deep_KSVD.init_dct(patch_size, m).to(device)

# Squared spectral norm of the initial dictionary, as a one-element float32 tensor.
spectral_norm = linalg.norm(Dict_init, ord=2)
c_init = torch.tensor([spectral_norm ** 2], dtype=torch.float32, device=device)

# Average weights: one per patch entry, drawn around 1 with std 0.1.
w_init = torch.normal(mean=1, std=1 / 10 * torch.ones(patch_size ** 2)).float().to(device)

# Deep-KSVD network hyper-parameters.
D_in = patch_size ** 2
H_1, H_2, H_3, H_4 = 128, 64, 32, 16
D_out_lam = 1
# NOTE(review): the training cell builds the model with T = 5, but it is
# rebuilt here with T = 7. T is a constructor argument and is presumably not
# restored by load_state_dict, so confirm the mismatch is intentional.
T = 7
min_v, max_v = -1, 1
model = Deep_KSVD.DenoisingNet_MLP(
    patch_size, D_in, H_1, H_2, H_3, H_4, D_out_lam, T, min_v, max_v,
    Dict_init, c_init, w_init, device,
)

# Overwrite the freshly initialised parameters with the trained weights.
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.to(device)

DenoisingNet_MLP(
  (unfold): Unfold(kernel_size=(8, 8), dilation=1, padding=0, stride=1)
  (linear1): Linear(in_features=64, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=32, bias=True)
  (linear4): Linear(in_features=32, out_features=16, bias=True)
  (linear5): Linear(in_features=16, out_features=1, bias=True)
)

In [34]:
# Test image names, one file name per line.
# The ``with`` block closes the file. Only the line terminator is stripped, so
# a final line without a trailing newline keeps its last character; blank
# lines are skipped so they do not turn into empty image names.
with open("test_gray.txt", "r") as file_test:
    onlyfiles_test = [line.rstrip("\r\n") for line in file_test if line.strip()]

# Rescaling in [-1, 1]:
mean = 255 / 2
std = 255 / 2
data_transform = transforms.Compose(
    [Deep_KSVD.Normalize(mean=mean, std=std), Deep_KSVD.ToTensor()]
)
# Noise level:
sigma = 25

# Test Dataset (full images), iterated in order one image at a time:
my_Data_test = Deep_KSVD.FullImagesDataset(
    root_dir="gray", image_names=onlyfiles_test, sigma=sigma, transform=data_transform
)

dataloader_test = DataLoader(my_Data_test, batch_size=1, shuffle=False, num_workers=0)


In [36]:
def _psnr(reference, estimate):
    """Return the PSNR (dB) of ``estimate`` against ``reference`` as a float.

    Images are rescaled to [-1, 1], so the peak-to-peak range is 2 and the
    squared peak is 4.
    """
    mse = torch.mean((reference - estimate) ** 2)
    return (10 * torch.log10(4 / mse)).item()


# Per-image PSNR before (noisy input) and after restoration.
# Values are stored as plain floats, so both the CSV log and the pickle file
# are readable without torch.
list_PSNR = []
list_PSNR_init = []
model.eval()

with open("list_test_PSNR.csv", "w") as file_to_print, open("list_test_PSNR.txt", "wb") as fp:
    file_to_print.write(str(device) + "\n")
    file_to_print.flush()

    with torch.no_grad():
        for k, (image_true, image_noise) in enumerate(dataloader_test, 0):
            # First image / first channel of the batch (the loader uses batch_size=1).
            image_true_t = image_true[0, 0, :, :].to(device)
            image_noise_0 = image_noise[0, 0, :, :].to(device)

            image_restored_t = model(image_noise.to(device))[0, 0, :, :]

            PSNR_init = _psnr(image_true_t, image_noise_0)
            file_to_print.write("Init:" + " " + str(PSNR_init) + "\n")
            file_to_print.flush()
            list_PSNR_init.append(PSNR_init)

            PSNR = _psnr(image_true_t, image_restored_t)
            file_to_print.write("Test:" + " " + str(PSNR) + "\n")
            file_to_print.flush()
            list_PSNR.append(PSNR)

            # imsave expects host memory; .cpu() is a no-op when device is the CPU.
            imsave("im_noisy_" + str(k) + ".pdf", image_noise_0.cpu(), cmap=cm.gray)
            imsave("im_restored_" + str(k) + ".pdf", image_restored_t.cpu(), cmap=cm.gray)

    # Named ``mean_PSNR`` so it does not overwrite the ``mean`` used for normalization.
    mean_PSNR = float(np.mean(list_PSNR)) if list_PSNR else float("nan")
    file_to_print.write("FINAL" + " " + str(mean_PSNR) + "\n")
    file_to_print.flush()
    pickle.dump(list_PSNR, fp)

KeyboardInterrupt: ignored