In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn

from torch.utils.tensorboard import SummaryWriter
from dataset import MyDataset
from dataloader import PatchDataLoader
from torchvision import transforms
from PIL import Image
from model import CNN
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Raise Pillow's decompression-bomb guard (it warns above MAX_IMAGE_PIXELS and
# refuses above twice that); presumably the source images exceed the default.
Image.MAX_IMAGE_PIXELS = 933_120_000

In [None]:
# TensorBoard run directory; the training loop also writes checkpoints here.
save_log_dir = "2_images_training"
writer = SummaryWriter(save_log_dir)

# Release cached, unused GPU memory (e.g. from a previous run in this kernel).
torch.cuda.empty_cache()

In [None]:
def save_checkpoint(state, checkpoint, filename='checkpoint.pth.tar'):
    """Serialize a training-state object to disk with ``torch.save``.

    Args:
        state: object to save; the training loop passes a dict holding the
            epoch number plus the model and optimizer ``state_dict``s.
        checkpoint: directory to write into. It is created (with parents) if
            it does not exist yet; an empty string means the current directory.
        filename: file name inside ``checkpoint``. Re-using the same name
            overwrites the previous checkpoint.

    Returns:
        The path of the file that was written.
    """
    # torch.save does not create missing directories, so make sure the target exists.
    if checkpoint:
        os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    return filepath

In [None]:
def draw_graph(loss, psnr, ssim, epoch, summary_writer=None):
    """Log one epoch's training metrics as TensorBoard scalars.

    Args:
        loss: training loss for the epoch, as a Python number or a 0-d tensor.
            Tensors are converted with ``.item()`` so no autograd graph is
            passed to the writer.
        psnr: PSNR value, logged under ``"PSNR/Epoch"``.
        ssim: SSIM value, logged under ``"SSIM/Epoch"``.
        epoch: global step used for all three scalars.
        summary_writer: object with an ``add_scalar(tag, value, step)`` method;
            defaults to the notebook-level ``writer``.
    """
    target = writer if summary_writer is None else summary_writer
    if isinstance(loss, torch.Tensor):
        loss = loss.detach().item()

    target.add_scalar("Train Loss", loss, epoch)
    target.add_scalar("PSNR/Epoch", psnr, epoch)
    target.add_scalar("SSIM/Epoch", ssim, epoch)

In [None]:
""" def plot_feature(image, model):
    no_of_layers= 0
    layers = []
    weights = []
    outputs = []
    processed = []
    names = []
    type_list = [nn.Conv2d, nn.AvgPool2d, nn.ConvTranspose2d, nn.Upsample]
    model_children=list(model.children())

    for child in model_children:
        if type(child) in type_list:
            no_of_layers+=1
            layers.append(child)
            if type(child)==nn.Conv2d:
                weights.append(child.weight)

    for layer in layers[0:]:
        image = layer(image)
        outputs.append(image)

    for feature_map in outputs:
        feature_map = feature_map.squeeze(0)
        gray_scale = torch.sum(feature_map,0)
        gray_scale = gray_scale / feature_map.shape[0]
        #for gray_scale in feature_map:
        processed.append(gray_scale.detach().cpu().numpy())
        names.append(str(gray_scale.shape))
        
    fig = plt.figure(figsize=(20, 700))
    for i in range(len(processed)):
        a = fig.add_subplot(1000, 10, i+1)
        imgplot = plt.imshow(processed[i], cmap="gray")
        a.axis("on")
        a.set_title(names[i], fontsize=8)
    plt.savefig(str(f'{save_log_dir}/feature_maps.jpg'), bbox_inches='tight')

    return fig """

In [None]:
# Patch extraction / batching configuration, named instead of inlined so it can
# be tuned in one place and read back from the run summary.
PATCH_SIZE = 64    # passed as kernel_size; presumably the patch side length in pixels — confirm in PatchDataLoader
PATCH_STRIDE = 64  # passed as stride; equal to PATCH_SIZE, presumably giving non-overlapping patches
BATCH_SIZE = 80    # patches per batch

# ToTensor converts each PIL image to a float tensor (scaled to [0, 1] for 8-bit input).
transform = transforms.Compose([transforms.ToTensor()])
dataset = MyDataset("image")
loader = PatchDataLoader(
    dataset=dataset,
    transform=transform,
    kernel_size=PATCH_SIZE,
    stride=PATCH_STRIDE,
    batch_size=BATCH_SIZE,
)

In [None]:
""" dataiter = iter(loader)
images = next(dataiter)
img_grid = torchvision.utils.make_grid(images)
writer.add_image('Example input image.', img_grid)
del img_grid """

In [None]:
# Use the GPU when present, otherwise fall back to CPU instead of failing on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN()
model.to(device)
# Training mode (affects layers such as dropout / batch-norm, if CNN has any).
model.train()
# Mean absolute error; the training loop compares the reconstruction to its input.
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
# Graph export stays disabled: it needs the example batch `images`, which only
# the commented-out preview cell defines.
#writer.add_graph(model, images)

In [None]:
epochs = 2000

# Snapshot of the run's hyper-parameters and a few architecture pieces, written
# once as TensorBoard text so runs can be told apart later. Keyword order below
# is the key order of the rendered dict.
# NOTE(review): the encoder/decoder indices assume a particular CNN layout —
# confirm against model.CNN if the architecture changes.
summary = dict(
    BSIZE=loader.batch_size,
    EPOCH=epochs,
    LFUNC=criterion,
    OPTIM=optimizer,
    SIZE=loader.size,
    STRIDE=loader.stride,
    BOTTLENECK="NONE",
    RELU=model.encoder[1],
    DOWNSAMPLE=model.encoder[2],
    UPSAMPLE=model.decoder[2],
)
writer.add_text("Summary", str(summary))

In [None]:
# Batch indices whose input/reconstruction grids are logged every epoch, with
# their TensorBoard tags. Tags are kept unchanged so existing runs stay
# comparable. NOTE(review): index 100 is the 101st batch, so the "6000." label
# looks stale for the current batch size — confirm before renaming.
IMAGE_LOG_TAGS = {
    0: ('First batch input image.', 'First batch output image.'),
    100: ('6000. batch input image.', '6000. batch output image.'),
}


def _log_image_pair(input_tag, output_tag, inputs, outputs, step):
    """Write image grids of one batch's inputs and reconstructions to TensorBoard."""
    # detach() so make_grid does not extend the reconstruction's autograd graph.
    writer.add_image(input_tag, torchvision.utils.make_grid(inputs.detach()), global_step=step)
    writer.add_image(output_tag, torchvision.utils.make_grid(outputs.detach()), global_step=step)


def _to_hwc_numpy(image):
    """Return a detached CPU NumPy copy of one (C, H, W) image tensor as (H, W, C)."""
    return image.detach().permute(1, 2, 0).cpu().numpy()


# Train on whichever device the model cell placed the model on.
train_device = next(model.parameters()).device

for epoch in range(epochs):
    loss_sum = 0.0
    n_batches = 0
    for batch_idx, inputs in enumerate(loader):
        # PatchDataLoader appears to yield None once it runs out of patches.
        if inputs is None:
            break
        inputImage = inputs.to(train_device)
        outputImage = model(inputImage)
        optimizer.zero_grad()
        # Reconstruction objective: the target is the input itself.
        loss = criterion(outputImage, inputImage)
        loss.backward()
        optimizer.step()
        # Accumulate on-device to avoid a host sync on every batch.
        loss_sum = loss_sum + loss.detach()
        n_batches += 1

        if batch_idx in IMAGE_LOG_TAGS:
            _log_image_pair(*IMAGE_LOG_TAGS[batch_idx], inputImage, outputImage, epoch)

    if n_batches == 0:
        # Otherwise the metrics below would raise NameError (first epoch) or
        # silently reuse the previous epoch's last batch.
        raise RuntimeError(f"Epoch {epoch}: loader produced no batches")

    # Mean loss over the epoch rather than just the last batch.
    epoch_loss = float(loss_sum) / n_batches

    save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, save_log_dir)

    # Quality metrics on the first patch of the epoch's last batch.
    # data_range=1.0 assumes ToTensor-scaled [0, 1] inputs; recent scikit-image
    # versions refuse to infer it for floating-point SSIM input.
    reference = _to_hwc_numpy(inputImage[0])
    reconstruction = _to_hwc_numpy(outputImage[0])
    psnr_val = psnr(reference, reconstruction, data_range=1.0)
    ssim_val = ssim(reference, reconstruction, channel_axis=2, data_range=1.0)
    print(f"Epoch: {epoch}/{epochs}, Loss: {epoch_loss}, PSNR: {psnr_val}, SSIM: {ssim_val}")
    draw_graph(loss=epoch_loss, psnr=psnr_val, ssim=ssim_val, epoch=epoch)

In [None]:
""" predict_model = CNN()
checkpoint = torch.load("deneme/3-1296/90*/checkpoint.pth.tar")
predict_model.load_state_dict(checkpoint['state_dict'])
predict_model.to("cuda")
predict_model.eval() """

In [None]:
""" predict_log_dir = save_log_dir + "/predict"
predict_dataset = MyDataset("predict_image")
predict_loader = PatchDataLoader(predict_dataset, transform=transform, kernel_size=64, stride=64, batch_size=80) """

In [None]:
""" for batch_idx, predict_input in enumerate(predict_loader,0):
    with torch.no_grad():
        predict_image = predict_input.to("cuda")
        predict_output = predict_model(predict_image)
    pred_grid = torchvision.utils.make_grid(predict_output)
    writer.add_image('First batch predict image.', pred_grid)
    break """