In [None]:
import numpy as np
import matplotlib.pyplot as plt
import datetime
from sklearn.model_selection import train_test_split
from tqdm import notebook, tqdm

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchsummary import summary
import torchvision.models as models
from torch.optim import lr_scheduler

#from model_kpn import KPN, LossBasic
#from model_baseline import Unet
from model_gdfn import GDFN

from data import ims, ims_noise

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

# préparation des données

In [None]:
# Number of samples per mini-batch; passed to both DataLoaders.
batch_size = 16

In [None]:
# The unpacking below treats `ims` as channels-last: (N, H, W, C).
N_ims, h, w, color = ims.shape

# Cast to float32 and move the channel axis to position 1, i.e. (N, H, W, C) -> (N, C, H, W).
# `ims_noise` is truncated to the same number of images as `ims`.
# NOTE: this cell rebinds `ims` / `ims_noise`; re-running it on the already
# transposed arrays would unpack the wrong axes into h, w and color.
channels_first = (0, 3, 1, 2)
ims = ims[:N_ims].astype(np.float32).transpose(channels_first)
ims_noise = ims_noise[:N_ims].astype(np.float32).transpose(channels_first)

In [None]:
# Train / test split: a tenth of the pairs is held out, with a fixed seed so the
# split is the same on every run.
test_size = 0.1

# X is taken from `ims_noise`, Y from `ims`; both are split with the same indices.
train_X, test_X, train_Y, test_Y = train_test_split(
    ims_noise, ims, test_size=test_size, random_state=42
)

# Sanity report: shape, dtype and value range of every split.
for label, split in (
    ('Training X: ', train_X),
    ('Training Y: ', train_Y),
    ('Testing X: ', test_X),
    ('Testing Y: ', test_Y),
):
    print(label, split.shape, split.dtype, split.max(), split.min())

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Serve two index-aligned collections as (X[i], Y[i]) samples."""

    def __init__(self, X, Y):
        # Only references are stored; nothing is copied or converted here.
        self.X, self.Y = X, Y

    def __len__(self):
        # Required by the Dataset protocol: the number of samples (taken from X),
        # not to be confused with the number of batches a loader yields.
        return len(self.X)

    def __getitem__(self, index):
        # Required by the Dataset protocol: fetch the sample stored at `index`.
        sample_x = self.X[index]
        sample_y = self.Y[index]
        return sample_x, sample_y
        
# Wrap each split in a Dataset so a DataLoader can index it.
train_set, test_set = MyDataset(train_X, train_Y), MyDataset(test_X, test_Y)

def collate(batch):
    """Merge a list of (input, target) samples into two float32 batch tensors.

    Args:
        batch: sequence of 2-item samples; all inputs must share one shape and
            all targets must share one shape.

    Returns:
        (inputs, target): CPU float32 tensors with the samples stacked along a
        new leading batch axis.
    """
    # Build one contiguous float32 array per side first: handing PyTorch a
    # Python list of ndarrays makes it convert the data element by element,
    # which is very slow for images.
    inputs = torch.from_numpy(np.asarray([item[0] for item in batch], dtype=np.float32))
    target = torch.from_numpy(np.asarray([item[1] for item in batch], dtype=np.float32))
    return inputs, target

# Settings shared by both loaders; only the training loader shuffles its samples.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)

In [None]:
# Choose the model: one of 'kpn', 'unet' or 'gdfn'.
choice = 'gdfn'

# `mode` records the forward-call convention used by the evaluation cell:
# 1 -> model(x, x), 2 -> model(x).
if choice == 'kpn':
    # Imported lazily: the top-level import of this module is commented out,
    # so without it this branch would fail with a NameError.
    from model_kpn import KPN
    model = KPN(color=False, burst_length=1, blind_est=True, kernel_size=[3], sep_conv=False,
                channel_att=False, spatial_att=True, core_bias=False)
    mode = 1
elif choice == 'unet':
    # Imported lazily for the same reason as KPN.
    from model_baseline import Unet
    model = Unet(color=False, blind_est=True, channel_att=False, spatial_att=False, core_bias=False)
    mode = 2
elif choice == 'gdfn':
    model = GDFN(filter_size=(3, 3), color=False, blind_est=True, channel_att=False,
                 spatial_att=False, core_bias=False)
    mode = 2
else:
    # A bare `assert` is a syntax error; fail with a message that names the bad value.
    raise ValueError("unknown model choice {!r}; expected 'kpn', 'unet' or 'gdfn'".format(choice))

model = model.to(device)

print('# model parameters:', sum(param.numel() for param in model.parameters()))

In [None]:
# Set to True to start from the saved weights instead of the fresh initialisation.
if_load = False
if if_load:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can also be loaded on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(r'./model_weights/gdfn.pkl', map_location=device)
    model.load_state_dict(state_dict)

# analyse des filtres dynamiques

In [None]:
# Normalise `color` to a channel count (1 = grayscale, 3 = colour).
# `color` is unpacked from `ims.shape`, so it is normally already an integer
# channel count; comparing it with False turned a 1-channel dataset into 3 and
# made the cell give a different result when re-run. A boolean flag is still
# accepted and mapped to 1 / 3. The assignment is idempotent.
if isinstance(color, (bool, np.bool_)):
    color = 3 if color else 1
else:
    color = int(color)

# Timestamp for naming saved figures.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Run the model once on the first test batch to obtain the predicted filters (`core`).
model.eval()

# NOTE: this rebinds `test_X` / `test_Y` from the full test split to a single
# batch of tensors on `device`.
test_X, test_Y = next(iter(test_loader))
test_X, test_Y = test_X.to(device), test_Y.to(device)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    if mode == 1:
        pred_Y, core = model(test_X, test_X)
    elif mode == 2:
        pred_Y, core = model(test_X)
    else:
        # A bare `assert` is a syntax error; fail with a message that names the bad value.
        raise ValueError('unsupported mode {!r}; expected 1 or 2'.format(mode))

print(core.shape)

## gdfn

In [None]:
# `core` is indexed as (batch, channels, H, W); each filter occupies `color**2`
# consecutive channels. `num_filters` was never defined in this notebook, so it
# is derived here as the number of complete channel groups.
channels_per_filter = color**2
num_filters = core.shape[1] // channels_per_filter
if num_filters == 0:
    raise ValueError(
        'core has {} channels, fewer than the {} needed for one filter'.format(
            core.shape[1], channels_per_filter
        )
    )

# One row per filter; squeeze=False keeps `axes` 2-D even for a single filter.
fig, axes = plt.subplots(num_filters, 1, figsize=(15, 5 * num_filters), squeeze=False)
for i, ax in enumerate(axes[:, 0]):
    cur_core = core[:, i * channels_per_filter:(i + 1) * channels_per_filter, :, :]
    # Average over the batch so each filter is shown as a single map.
    cur_core = cur_core.mean(dim=0)

    ax.imshow(cur_core.detach().cpu().squeeze(), cmap='gray')
    ax.axis('off')

#plt.savefig('./eval/gdfn_'+current_time+'.png')
plt.show()