In [1]:
from __future__ import print_function, division
import os
import torchvision
import torch
from skimage import io, transform
import numpy as np
from torchvision import transforms, utils
import torch.nn as nn
import cv2
import utils
import matplotlib.pyplot as plt
import random

from torchvision.utils import save_image

c:\users\adams\appdata\local\programs\python\python39\lib\site-packages\numpy\.libs\libopenblas.EL2C6PLE4ZYW3ECEVIV3OXXGRN2NRFM2.gfortran-win_amd64.dll
c:\users\adams\appdata\local\programs\python\python39\lib\site-packages\numpy\.libs\libopenblas.GK7GX5KEQ4F6UYO3P26ULGBQYHGQO7J4.gfortran-win_amd64.dll


In [2]:
class MFFNet(torch.nn.Module):
    """Encoder / residual-trunk / decoder network with attention-gated skips.

    Pipeline:
      * encoder: three reflection-padded convs (3->32 stride 1, 32->64 stride 2,
        64->128 stride 2), each followed by ReLU;
      * trunk: sixteen 128-channel residual blocks (``res1`` .. ``res16``);
      * decoder: two 2x-upsampling convs and a final 9x9 conv to 3 channels.
        Before each decoder stage the matching encoder feature map is gated
        by an ``AttentionBlock`` and concatenated with the current features
        (hence the ``*2`` input channel counts).

    The input must have 3 channels; the output has 3 channels and is not
    clamped. H and W should be multiples of 4 so that the two stride-2 encoder
    stages and the two 2x upsamplings produce matching skip shapes for
    ``torch.cat``.

    ``in1``..``in5`` are constructed but never applied in ``forward``. They are
    affine InstanceNorms, so their parameters are part of the state_dict and
    must stay for strict ``load_state_dict`` of existing checkpoints.
    """

    def __init__(self):
        super(MFFNet, self).__init__()

        # Encoder (registration order kept identical to preserve init order
        # and state_dict key order).
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)

        # Residual trunk: registered as attributes res1..res16 so checkpoint
        # keys ("res1.conv1.conv2d.weight", ...) stay unchanged.
        for idx in range(1, 17):
            setattr(self, 'res%d' % idx, ResidualBlock(128))

        # Decoder: inputs are [features, gated skip] concatenated on channels.
        self.deconv1 = UpsampleConvLayer(128 * 2, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64 * 2, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32 * 2, 3, kernel_size=9, stride=1)

        self.relu = torch.nn.ReLU()

        # Attention gates, one per skip connection (deepest first).
        self.att1 = AttentionBlock(128, 128, 64)
        self.att2 = AttentionBlock(64, 64, 32)
        self.att3 = AttentionBlock(32, 32, 16)

    def forward(self, X):
        # Encoder; the three feature maps double as skip connections.
        feat1 = self.relu(self.conv1(X))
        feat2 = self.relu(self.conv2(feat1))
        feat3 = self.relu(self.conv3(feat2))

        # Residual trunk, applied in order res1 -> res16.
        y = feat3
        for idx in range(1, 17):
            y = getattr(self, 'res%d' % idx)(y)

        # Decoder: gate each skip with the current features, concatenate, decode.
        y = self.relu(self.deconv1(torch.cat((y, self.att1(y, feat3)), 1)))
        y = self.relu(self.deconv2(torch.cat((y, self.att2(y, feat2)), 1)))
        return self.deconv3(torch.cat((y, self.att3(y, feat1)), 1))



class ConvLayer(torch.nn.Module):
    """Reflection-padded 2-D convolution.

    The input is padded by ``kernel_size // 2`` on every side with reflection
    (not zero) padding, then passed through a plain ``Conv2d``. For odd kernel
    sizes with ``stride=1`` this preserves the spatial size.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        kernel_size: square kernel size.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        pad = kernel_size // 2
        # Attribute names form the state_dict keys used by saved checkpoints.
        self.reflection_pad = torch.nn.ReflectionPad2d(pad)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))



class ResidualBlock(torch.nn.Module):
    """Residual unit: ``x + conv2(relu(conv1(x)))``.

    Both convolutions are reflection-padded 3x3 with stride 1, so channel count
    and spatial size are preserved and the identity skip can be added directly.

    ``in1``/``in2`` are constructed but not applied in ``forward``. They are
    affine InstanceNorms whose parameters belong to the state_dict, so they are
    kept for strict loading of existing checkpoints.

    Args:
        channels: number of input (and output) feature maps.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return x + self.conv2(hidden)



class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded conv.

    When ``upsample`` is truthy, the input is first resized by that scale factor
    with ``mode='nearest'``. It is then reflection-padded by ``kernel_size // 2``
    and convolved. This avoids the checkerboard artefacts of transposed convs.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        kernel_size: square kernel size.
        stride: convolution stride.
        upsample: scale factor for the pre-conv resize; ``None`` (or any falsy
            value) skips resizing.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        scaled = (
            torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
            if self.upsample
            else x
        )
        return self.conv2d(self.reflection_pad(scaled))
    


class AttentionBlock(torch.nn.Module):
    """Additive attention gate applied to a skip connection.

    ``g`` (gating features) and ``x`` (skip features) are each projected to
    ``Fint`` channels by a 1x1 conv + BatchNorm, summed, and passed through ReLU.
    The result is then mapped to a single-channel gate by 1x1 conv + BatchNorm
    + Sigmoid. The return value is ``x`` multiplied by that gate, broadcast over
    channels, so it has the same shape as ``x``.

    Args:
        Fg: channels of the gating input ``g``.
        Fl: channels of the skip input ``x``.
        Fint: intermediate channel count.
    """

    def __init__(self, Fg, Fl, Fint):
        super(AttentionBlock, self).__init__()

        def conv_bn(cin, cout):
            # 1x1 projection followed by BatchNorm, as a list for nn.Sequential.
            return [nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0), nn.BatchNorm2d(cout)]

        self.Wg = nn.Sequential(*conv_bn(Fg, Fint))
        self.Wx = nn.Sequential(*conv_bn(Fl, Fint))
        self.relu = nn.ReLU()
        self.psi = nn.Sequential(*conv_bn(Fint, 1), nn.Sigmoid())

    def forward(self, g, x):
        gate = self.psi(self.relu(self.Wg(g) + self.Wx(x)))
        return torch.mul(gate, x)

In [3]:
# Build the network and load the trained weights for inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
imageFilter = MFFNet()
print(device)
model_name = 'MFF-net_all3_a_new'
# map_location lets a checkpoint saved from a GPU be loaded on a CPU-only host.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
imageFilter.load_state_dict(torch.load('%s.ckpt' % (model_name), map_location=device))
# eval() makes the BatchNorm layers inside AttentionBlock use their stored
# running statistics instead of per-batch statistics, and stops inference
# passes from overwriting those statistics.
imageFilter = imageFilter.to(device).float().eval()

cuda


In [None]:
# Batch inference: for each image index, stack three single-channel guides,
# run them through imageFilter, clamp to [0, 1] and save as out_<seq>.png.
data_root = 'training_data'
guide1_root = 'training_data'
guide2_root = '380 train a'
guide3_root = '640 train a'
out_root = 'new model multiple attenuated no modulation'
FIRST_SEQ, LAST_SEQ = 1, 1304  # inclusive range of image indices to process

if os.path.exists(out_root):
    print("out_root exists")
os.makedirs(out_root, exist_ok=True)


def load_guides(seq):
    """Return the (3, H, W) guide stack for image index ``seq``.

    Channel 0: channel 0 of ``rgb_<seq>.png`` in ``guide1_root``.
    Channel 1: mean of channels 0 and 1 of ``training_<seq>_380_a.png``.
    Channel 2: mean of channels 0 and 2 of ``training_<seq>_640_a.png``.
    Values are left in the images' raw pixel range (not rescaled).
    """
    guide1img = io.imread(os.path.join(guide1_root, 'rgb_%s.png' % seq))
    guide2img = io.imread(os.path.join(guide2_root, 'training_%s_380_a.png' % seq))
    guide3img = io.imread(os.path.join(guide3_root, 'training_%s_640_a.png' % seq))

    guide1 = guide1img[:, :, 0]
    # True division promotes to float before the sum, so integer image
    # channels cannot wrap around.
    guide2 = guide2img[:, :, 0] / 2 + guide2img[:, :, 1] / 2
    guide3 = guide3img[:, :, 0] / 2 + guide3img[:, :, 2] / 2
    return np.stack((guide1, guide2, guide3), axis=0)


for seq in range(FIRST_SEQ, LAST_SEQ + 1):
    # NOTE(review): guides are fed in their raw pixel range (no /255) while the
    # output is clamped to [0, 1]; confirm this matches the training pipeline.
    inputs = torch.from_numpy(load_guides(seq)).unsqueeze(0).float().to(device)

    with torch.no_grad():
        outputs = imageFilter(inputs).clamp(0, 1)

    save_image(outputs[0], os.path.join(out_root, 'out_%s.png' % seq))

out_root exists
