In [3]:
import torch
import torchvision
import numpy as np
import cv2
import matplotlib.pyplot as plt 
import matplotlib.animation as manim
import os
import rawpy
import skimage.measure as sK_measure
from torch.utils.data import DataLoader
from torchvision import transforms

cuda = torch.device('cuda')

In [4]:
batch_size = 8
ds_folder = './Sony/'
try: 
    os.makedirs(ds_folder)
except Exception as e:
    pass
ds_exists = False
ds_train_sub = 'sony'
# for folder in os.listdir(ds_folder):
#     if 

In [9]:
low_exp_img = 'lei'
high_exp_img = 'hei'

def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - 512, 0) / (16383 - 512)  # subtract the black level

    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out


class Dataset(torch.utils.data.Dataset):
    def __init__(self, text_path, main_dir='Sony/', res=(512,512), tp='bayer', transforms=None):
        self.main_dir = main_dir
        self.Xs = []
        with open(text_path) as f:
            lines = f.readlines()
            for line in lines:
                idx = line.find(' ')
                # self.Xs[line[:idx]] = line[idx+1:]
                self.Xs.append((line[:idx], line[idx+1:]))
        self.transforms = transforms
        self.ids = np.arange(len(self.Xs))
    def __len__(self):
        return len(self.Xs)
    
    def __getitem__(self, idx):
        if idx == 0: 
            self.ids = np.random.permutation(self.ids)
        sample = {low_exp_img: np.random.rand(1, 1024, 1024,4), high_exp_img: np.random.rand(1, 2048, 2048 ,3)}
        if self.transforms != None:
            sample = self.transforms(sample)
        return sample
    
        # return torch.rand(1024), torch.rand(1024)
        x_img_info, gt_img_info = self.Xs[self.ids[idx]]
        x_img_path = os.path.join(self.main_dir, x_img_info)
        gt_img_path = gt_img_info.split(' ')[0]
        gt_img_path = os.path.join(self.main_dir, gt_img_path)
        # print(x_img_path, gt_img_path)
        base_x, base_gt = x_img_path.split('/')[-1], gt_img_path.split('/')[-1]
        in_exposure = float(base_x[9:-5])
        gt_exposure = float(base_gt[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)
        x, gt = rawpy.imread(x_img_path), rawpy.imread(gt_img_path)
        x = np.expand_dims(pack_raw(x), axis=0) * ratio
        
        gt = gt.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        gt = np.expand_dims(np.float32(gt / 65535.0), axis=0)
        sample = {low_exp_img: x, high_exp_img: gt}
        if self.transforms != None:
            sample = self.transforms(sample)
        return sample
        return str(np.random.randint(100, 1000)) + '--' + str(idx) + ' -- ' + str(self.Xs[idx])
        
lki = None
class RandomCrop(object):
    def __init__(self, output_size) -> None:
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.h, self.w = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.h, self.w = output_size
    def __call__(self, sample):
        global lki
        lli, hli = sample[low_exp_img], sample[high_exp_img]
        lki = lli
        # print(lli.shape, hli.shape)
        or_h, or_w = lli.shape[1:3]
        # print('height width: ', or_h, or_w)
        start_height = torch.randint(0, or_h-self.h, size=(1,))
        start_width = torch.randint(0, or_w-self.w, size=(1,))
        # print(start_height, start_width)
        lli = lli[:, start_height:start_height+self.h, start_width:start_width+self.w]
        hli = hli[:, start_height*2:start_height*2+self.h*2, start_width*2:start_width*2+self.w*2]
        # print(lli.shape, hli.shape)
        return {low_exp_img: lli, high_exp_img:hli}

class RandomFlip(object):
    def __init__(self, probabilty=.3):
        self.probabilty = probabilty
    def __call__(self, sample):
        lli, hli = sample[low_exp_img], sample[high_exp_img]
        hor_prob = torch.rand(1)[0]
        ver_prob = torch.rand(1)[0]
        if hor_prob > self.probabilty:
            lli = np.flip(lli, axis=1)
            hli = np.flip(hli, axis=1)
        if ver_prob > self.probabilty:
            lli = np.flip(lli, axis=2)
            hli = np.flip(hli, axis=2)
        return {low_exp_img: lli, high_exp_img:hli}

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        lli, hli = sample[low_exp_img].copy(), sample[high_exp_img].copy()
        # print('coming_shape: ', lli.shape, hli.shape)
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        lli = lli.transpose((0, 3, 1, 2))
        hli = hli.transpose((0, 3, 1, 2))
        lli = lli[0]
        hhi = hli[0]
        return {low_exp_img: torch.from_numpy(lli),
                high_exp_img: torch.from_numpy(hli)}


# im_size = (512, 512)
# composer = transforms.Compose([
#     RandomCrop(im_size),
#     # torchvision.transforms.RandomEqualize(.2),
#     RandomFlip(.3),
#     # torchvision.transforms.RandomHorizontalFlip(),
#     # torchvision.transforms.RandomVerticalFlip(),
#     ToTensor(),
# ])

# im_size = (512, 512)
# ds = Dataset('./Sony/Sony_train_list.txt', res=im_size, transforms=composer)
# for _, batch in enumerate(ds):
#     sample = batch
#     break
# print(sample)

In [49]:
def collate(batch):
    print(batch.shape)
    
    return torch.utils.data.dataloader.default_collate(batch)

In [10]:
batch_size = 4
workers = 0
im_size = (512, 512)

composer = transforms.Compose([
    RandomCrop(im_size),
    # torchvision.transforms.RandomEqualize(.2),
    RandomFlip(.3),
    # torchvision.transforms.RandomHorizontalFlip(),
    # torchvision.transforms.RandomVerticalFlip(),
    ToTensor(),
])
ds = Dataset('./Sony/Sony_train_list.txt', res=im_size, transforms=composer)
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=workers)

In [8]:
import math
class OriginalModel(torch.nn.Module):
    def __init__(self, block_size, in_channel=4, kernel_size=3, dialation=1) -> None:
        super().__init__()
        self.block_size = block_size 
        
        self.activation = torch.nn.LeakyReLU(.02, True)
        self.max_pool = torch.nn.MaxPool2d(2, 2, 0, ceil_mode=True)
        
        self.convF1 = torch.nn.Conv2d(in_channel, 32, kernel_size, 1, 1, dialation, bias=True)
        self.convF2 = torch.nn.Conv2d(32, 32, kernel_size, 1, 1, dialation, bias=True)
        
        self.convF3 = torch.nn.Conv2d(32, 64, kernel_size, 1, 1, dialation, bias=True)
        self.convF4 = torch.nn.Conv2d(64, 64, kernel_size, 1, 1, dialation, bias=True)
        
        self.convF5 = torch.nn.Conv2d(64, 128, kernel_size, 1, 1, dialation, bias=True)
        self.convF6 = torch.nn.Conv2d(128, 128, kernel_size, 1, 1, dialation, bias=True)
        
        self.convF7 = torch.nn.Conv2d(128, 256, kernel_size, 1, 1, dialation, bias=True)
        self.convF8 = torch.nn.Conv2d(256, 256, kernel_size, 1, 1, dialation, bias=True)
        
        self.convF9 = torch.nn.Conv2d(256, 512, kernel_size, 1, 1, dialation, bias=True)
        self.convF10 = torch.nn.Conv2d(512, 512, kernel_size, 1, 1, dialation, bias=True)
        
        self.conv_upB10 = torch.nn.ConvTranspose2d(512, 256, 2, 2, bias=False)
        self.convB10 = torch.nn.Conv2d(512, 256, kernel_size, 1, 1, dialation, bias=True)
        self.convB9 = torch.nn.Conv2d(256, 256, kernel_size, 1, 1, dialation, bias=True)
        
        self.conv_upB8 = torch.nn.ConvTranspose2d(256, 128, 2, 2, bias=False)
        self.convB8 = torch.nn.Conv2d(256, 128, kernel_size, 1, 1, dialation, bias=True)
        self.convB7 = torch.nn.Conv2d(128, 128, kernel_size, 1, 1, dialation, bias=True)
        
        self.conv_upB6 = torch.nn.ConvTranspose2d(128, 64, 2, 2, bias=False)
        self.convB6 = torch.nn.Conv2d(128, 64, kernel_size, 1, 1, dialation, bias=True)
        self.convB5 = torch.nn.Conv2d(64, 64, kernel_size, 1, 1, dialation, bias=True)
        
        self.conv_upB4 = torch.nn.ConvTranspose2d(64, 32, 2, 2, bias=False)
        self.convB4 = torch.nn.Conv2d(64, 32, kernel_size, 1, 1, dialation, bias=True)
        self.convB3 = torch.nn.Conv2d(32, 32, kernel_size, 1, 1, dialation, bias=True)
        
        self.convB = torch.nn.Conv2d(32, 3 * self.block_size * self.block_size, 1, 1, 0, bias=True)
        
    def init_weights(self, seed=42):
        np.random.seed(seed)
        torch.manual_seed(seed)
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()
            elif isinstance(m, torch.nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
    def forward(self, x):
        x = self.convF1(x)
        x = self.activation(x)
        x = self.convF2(x)
        x = self.activation(x)
        up2 = x
        x = self.max_pool(x)
        
        x = self.convF3(x)
        x = self.activation(x)
        x = self.convF4(x)
        x = self.activation(x)
        up4 = x
        x = self.max_pool(x)
        
        x = self.convF5(x)
        x = self.activation(x)
        x = self.convF6(x)
        x = self.activation(x)
        up6 = x
        x = self.max_pool(x)
        
        x = self.convF7(x)
        x = self.activation(x)
        x = self.convF8(x)
        x = self.activation(x)
        up8 = x
        x = self.max_pool(x)
        
        x = self.convF9(x)
        x = self.activation(x)
        x = self.convF10(x)
        x = self.activation(x)
        
        x = self.conv_upB10(x)
        x = torch.cat((x[:, :, :up8.size(2), :up8.size(3)], up8), 1)
        x = self.convB10(x)
        x = self.activation(x)
        x = self.convB9(x)
        x = self.activation(x)
        
        x = self.conv_upB8(x)
        x = torch.cat((x[:, :, :up6.size(2), :up6.size(3)], up6), 1)
        x = self.convB8(x)
        x = self.activation(x)
        x = self.convB7(x)
        x = self.activation(x)
        
        x = self.conv_upB6(x)
        x = torch.cat((x[:, :, :up4.size(2), :up4.size(3)], up4), 1)
        x = self.convB6(x)
        x = self.activation(x)
        x = self.convB5(x)
        x = self.activation(x)
        
        
        x = self.conv_upB4(x)
        x = torch.cat((x[:, :, :up2.size(2), :up2.size(3)], up2), 1)
        x = self.convB4(x)
        x = self.activation(x)
        x = self.convB3(x)
        x = self.activation(x)
        
        x = self.convB(x)
        x = torch.nn.PixelShuffle(2)(x)
        return x
        
    def pixel_shuffle(x, upscale_factor, depth_first=False):
        pass
model = OriginalModel(2, 4, 3).to(device=cuda)

In [53]:
optim = torch.optim.NAdam(model.parameters(),lr=.003)
mse = torch.nn.MSELoss()
l1 = torch.nn.L1Loss()
epochs = 50
print_every = 500 # steps
iters = 0
running_loss = 0.0
for epoch in range(epochs): 
    for idx, batched in enumerate(dl, 0):
        # model.zero_grad()
        optim.zero_grad()
        llis, hlis = batched[low_exp_img], batched[high_exp_img]
        outputs = model(llis.to(device=cuda))
        outputs = outputs.to('cpu')
        # print(hlis.shape, outputs.shape)
        
        err = l1(hlis, outputs)
        err.backward()
        optim.step()
        running_loss += err.item()
        iters+=1
        if (iters % print_every == 0) or ((epoch == epochs-1) and (idx == len(dl)-1)):
            print('[%d/%d][%d/%d]\tloss::%f\trunning_loss::%f' % (epoch, epochs, idx, len(dl), err, running_loss / (idx+1)))
    running_loss = 0.0
    

  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)


[1/50][32/467]	loss::3613104.000000	running_loss::19.375000
[2/50][65/467]	loss::4018040.750000	running_loss::18.671875
[3/50][98/467]	loss::17288642.000000	running_loss::53.812500
[4/50][131/467]	loss::15243571.000000	running_loss::83.617188
[5/50][164/467]	loss::184733.187500	running_loss::126.988281
[6/50][197/467]	loss::27397.541016	running_loss::110.121094
[7/50][230/467]	loss::488240.125000	running_loss::202.833008
[8/50][263/467]	loss::72211.601562	running_loss::237.835938
[9/50][296/467]	loss::11608.757812	running_loss::233.623535
[10/50][329/467]	loss::114759.156250	running_loss::241.376709
[11/50][362/467]	loss::2668.846191	running_loss::348.927490
[12/50][395/467]	loss::2342.109131	running_loss::336.605835
[13/50][428/467]	loss::2828.647217	running_loss::288.806030
[14/50][461/467]	loss::742392.312500	running_loss::198.055298
[16/50][27/467]	loss::177769.750000	running_loss::2.937500
[17/50][60/467]	loss::251716.171875	running_loss::5.432617
[18/50][93/467]	loss::9956.503906

KeyboardInterrupt: 

In [56]:
import time
torch.save(model.state_dict(), './lsid__'+str(time.time())[:-10])

In [44]:
# outputs = model(llis.to(cuda))
outputs.to('cpu')

tensor([[[[ 0.0986,  0.1367,  0.0976,  ...,  0.1384,  0.0972,  0.1366],
          [ 0.0895, -0.1348,  0.0930,  ..., -0.1365,  0.0916, -0.1350],
          [ 0.0977,  0.1377,  0.0962,  ...,  0.1367,  0.0988,  0.1370],
          ...,
          [ 0.0902, -0.1346,  0.0916,  ..., -0.1365,  0.0915, -0.1353],
          [ 0.0939,  0.1359,  0.0949,  ...,  0.1352,  0.1003,  0.1364],
          [ 0.0892, -0.1375,  0.0879,  ..., -0.1381,  0.0912, -0.1351]],

         [[-0.1606, -0.1566, -0.1600,  ..., -0.1598, -0.1591, -0.1572],
          [-0.1528,  0.0398, -0.1522,  ...,  0.0350, -0.1534,  0.0360],
          [-0.1615, -0.1575, -0.1610,  ..., -0.1607, -0.1593, -0.1576],
          ...,
          [-0.1520,  0.0409, -0.1532,  ...,  0.0366, -0.1544,  0.0376],
          [-0.1619, -0.1581, -0.1620,  ..., -0.1593, -0.1621, -0.1555],
          [-0.1541,  0.0406, -0.1563,  ...,  0.0405, -0.1549,  0.0413]],

         [[ 0.1573, -0.0244,  0.1572,  ..., -0.0239,  0.1585, -0.0254],
          [ 0.0619, -0.1601,  

In [106]:
import glob
input_dir = './Sony/Sony/short/'
gt_dir = './Sony/Sony/long/'
train_fns = glob.glob(gt_dir + '0*.ARW')
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]

ps = 512  # patch size for training
save_freq = 500

gt_images = [None] * 6000
input_images = {}
input_images['300'] = [None] * len(train_ids)
input_images['250'] = [None] * len(train_ids)
input_images['100'] = [None] * len(train_ids)
for ind in np.random.permutation(len(train_ids)):
        # get the path from image id
        train_id = train_ids[ind]
        in_files = glob.glob(input_dir + '%05d_00*.ARW' % train_id)
        in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
        in_fn = os.path.basename(in_path)

        gt_files = glob.glob(gt_dir + '%05d_00*.ARW' % train_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)
        in_exposure = float(in_fn[9:-5])
        gt_exposure = float(gt_fn[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)

        if input_images[str(ratio)[0:3]][ind] is None:
            raw = rawpy.imread(in_path)
            print(raw.sizes)
            input_images[str(ratio)[0:3]][ind] = np.expand_dims(pack_raw(raw), axis=0) * ratio

            gt_raw = rawpy.imread(gt_path)
            print(gt_raw.sizes)
            im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            print(im.shape)
            gt_images[ind] = np.expand_dims(np.float32(im / 65535.0), axis=0)
            print(gt_images[ind].shape)

        # crop
        H = input_images[str(ratio)[0:3]][ind].shape[1]
        W = input_images[str(ratio)[0:3]][ind].shape[2]

        xx = np.random.randint(0, W - ps)
        yy = np.random.randint(0, H - ps)
        input_patch = input_images[str(ratio)[0:3]][ind][:, yy:yy + ps, xx:xx + ps, :]
        # gt_patch = gt_images[ind][:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]
        gt_patch = gt_images[ind][:, yy:yy + ps, xx:xx + ps, :]
        print(input_patch.shape, gt_patch.shape)

        if np.random.randint(2, size=1)[0] == 1:  # random flip
            input_patch = np.flip(input_patch, axis=1)
            gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=2)
            gt_patch = np.flip(gt_patch, axis=2)
        if np.random.randint(2, size=1)[0] == 1:  # random transpose
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))
            gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))

        # input_patch = np.minimum(input_patch, 1.0)
        break

  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]


ImageSizes(raw_height=2848, raw_width=4288, height=2848, width=4256, top_margin=0, left_margin=0, iheight=2848, iwidth=4256, pixel_aspect=1.0, flip=0)
ImageSizes(raw_height=2848, raw_width=4288, height=2848, width=4256, top_margin=0, left_margin=0, iheight=2848, iwidth=4256, pixel_aspect=1.0, flip=0)
(2848, 4256, 3)
(1, 2848, 4256, 3)
(1, 512, 512, 4) (1, 512, 512, 3)


In [107]:
cv2.imshow('a', np.concatenate((input_patch[0, :, :, 1:], np.minimum(input_patch, 1.0)[0, :, :, 1:]), axis=1))
cv2.waitKey(0)
cv2.imshow('b', gt_patch[0])
cv2.waitKey(0)
cv2.destroyAllWindows()

In [89]:
cv2.destroyAllWindows()

In [None]:
with open('./Sony/Sony_train_list.txt') as f:
    lines = f.readlines()b
    for line in lines:
        idx = line.find(' ')
        print(line[:idx], '--', line[idx+1:])
        img_idx = line[:idx].split('/')[-1][:5]
        
        break

./Sony/short/00001_00_0.04s.ARW -- ./Sony/long/00001_00_10s.ARW ISO200 F8

00001


In [None]:
np.expand_dims(np.float32(im / 65535.0), axis=0).shape

(1, 2848, 4256, 3)

In [None]:
(np.expand_dims(pack_raw(raw), axis=0)).shape

(1, 1424, 2128, 4)

In [None]:
ratio

300.0

In [31]:
torch.rand(1)[0]

tensor(0.6056)

In [59]:
class Dataset2(torch.utils.data.Dataset):
    def __init__(self, text_path, main_dir='Sony/', res=(512,512), tp='bayer', transforms=None):
        self.main_dir = main_dir
        self.Xs = []
        self.X_images = {}
        self.GT_images = {}
        with open(text_path) as f:
            lines = f.readlines()
            for line in lines:
                idx = line.find(' ')
                # self.Xs[line[:idx]] = line[idx+1:]
                self.Xs.append((line[:idx], line[idx+1:]))
        self.transforms = transforms
        self.ids = np.arange(len(self.Xs))
    def __len__(self):
        return len(self.Xs)
    
    def __getitem__(self, idx):
        # if idx == 0: 
        #     self.ids = np.random.permutation(self.ids)
        # return torch.rand(1024), torch.rand(1024)
        x_img_info, gt_img_info = self.Xs[self.ids[idx]]
        x_img_path = os.path.join(self.main_dir, x_img_info)
        gt_img_path = gt_img_info.split(' ')[0]
        gt_img_path = os.path.join(self.main_dir, gt_img_path)
        x, gt = None, None 
        if x_img_path in self.X_images.keys():
            x = self.X_images[x_img_path]
        else:
            base_x, base_gt = x_img_path.split('/')[-1], gt_img_path.split('/')[-1]
            in_exposure = float(base_x[9:-5])
            gt_exposure = float(base_gt[9:-5])
            ratio = min(gt_exposure / in_exposure, 300)
            x = rawpy.imread(x_img_path)
            x = np.expand_dims(pack_raw(x), axis=0) * ratio
            self.X_images[x_img_path] = x
        if gt_img_path in self.GT_images.keys():
            gt = self.GT_images[gt_img_path]
        else: 
            gt = rawpy.imread(gt_img_path)
            gt = gt.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            gt = np.expand_dims(np.float32(gt / 65535.0), axis=0)
            self.GT_images[gt_img_path] = gt
            
        # print(x_img_path, gt_img_path)
        
        # x, gt = rawpy.imread(x_img_path), rawpy.imread(gt_img_path)
        # gt = np.expand_dims(np.float32(gt / 65535.0), axis=0)
        sample = {low_exp_img: x, high_exp_img: gt}
        if self.transforms != None:
            sample = self.transforms(sample)
        return sample

In [71]:
ds = Dataset2('./Sony/Sony_train_list.txt', res=im_size, transforms=composer)
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=workers)

In [72]:
optim = torch.optim.NAdam(model.parameters(),lr=.003)
mse = torch.nn.MSELoss()
l1 = torch.nn.L1Loss()
epochs = 50
print_every = 250 # steps
iters = 0
running_loss = 0.0
for epoch in range(epochs): 
    for idx, batched in enumerate(dl, 0):
        # model.zero_grad()
        optim.zero_grad()
        llis, hlis = batched[low_exp_img], batched[high_exp_img]
        outputs = model(llis.to(device=cuda))
        outputs = outputs.to('cpu')
        # print(hlis.shape, outputs.shape)
        
        err = l1(hlis, outputs)
        err.backward()
        optim.step()
        running_loss += err.item()
        iters+=1
        if (iters % print_every == 0) or ((epoch == epochs-1) or (idx == len(dl)-1)):
            print('[%d/%d][%d/%d]\tloss::%f\trunning_loss::%f' % (epoch, epochs, idx, len(dl), err, running_loss / (idx+1)))
    running_loss = 0.0
    

  return F.l1_loss(input, target, reduction=self.reduction)


KeyboardInterrupt: 

In [77]:
for epoch in range(epochs): 
    for idx, batched in enumerate(dl, 0):
        iters+=1
        
        if (iters % print_every == 0) or ((epoch == epochs-1) or (idx == len(dl)-1)):
            print("ckck")
        
    break

ckck


MemoryError: Unable to allocate 46.2 MiB for an array with shape (2848, 4256) and data type float32

In [None]:
optim = torch.optim.NAdam(model.parameters(),lr=.003)
sch = torch.optim.lr_scheduler.StepLR(
        optim, step_size  = 10 , gamma = 0.5)
lr = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, 'min', .5, patience=10)

In [22]:
low_exp_img = 'lei'
high_exp_img = 'hei'

def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - 512, 0) / (16383 - 512)  # subtract the black level

    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out


class Dataset(torch.utils.data.Dataset):
    def __init__(self, text_path, main_dir='Sony/', res=(512,512), tp='bayer', transforms=None):
        self.main_dir = main_dir
        self.Xs = []
        with open(text_path) as f:
            lines = f.readlines()
            for line in lines:
                idx = line.find(' ')
                # self.Xs[line[:idx]] = line[idx+1:]
                self.Xs.append((line[:idx], line[idx+1:]))
        self.transforms = transforms
        self.ids = np.arange(len(self.Xs))
    def __len__(self):
        return len(self.Xs)
    
    def __getitem__(self, idx):
        if idx == 0: 
            self.ids = np.random.permutation(self.ids)
        # sample = {low_exp_img: np.random.rand(1, 1024, 1024,4), high_exp_img: np.random.rand(1, 2048, 2048 ,3)}
        # if self.transforms != None:
        #     sample = self.transforms(sample)
        # return sample
    
        # return torch.rand(1024), torch.rand(1024)
        x_img_info, gt_img_info = self.Xs[self.ids[idx]]
        x_img_path = os.path.join(self.main_dir, x_img_info)
        gt_img_path = gt_img_info.split(' ')[0]
        gt_img_path = os.path.join(self.main_dir, gt_img_path)
        # print(x_img_path, gt_img_path)
        base_x, base_gt = x_img_path.split('/')[-1], gt_img_path.split('/')[-1]
        in_exposure = float(base_x[9:-5])
        gt_exposure = float(base_gt[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)
        x, gt = rawpy.imread(x_img_path), rawpy.imread(gt_img_path)
        x = np.expand_dims(pack_raw(x), axis=0) * ratio
        
        gt = gt.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        gt = np.expand_dims(np.float32(gt / 65535.0), axis=0)
        print(x.shape, gt.shape)
        sample = {low_exp_img: x, high_exp_img: gt}
        if self.transforms != None:
            sample = self.transforms(sample)
        print(x.dtype)
        return sample
        return str(np.random.randint(100, 1000)) + '--' + str(idx) + ' -- ' + str(self.Xs[idx])
        
lki = None
class RandomCrop(object):
    def __init__(self, output_size) -> None:
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.h, self.w = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.h, self.w = output_size
    def __call__(self, sample):
        global lki
        lli, hli = sample[low_exp_img], sample[high_exp_img]
        lki = lli
        # print(lli.shape, hli.shape)
        or_h, or_w = lli.shape[1:3]
        # print('height width: ', or_h, or_w)
        start_height = torch.randint(0, or_h-self.h, size=(1,))
        start_width = torch.randint(0, or_w-self.w, size=(1,))
        # print(start_height, start_width)
        lli = lli[:, start_height:start_height+self.h, start_width:start_width+self.w]
        hli = hli[:, start_height*2:start_height*2+self.h*2, start_width*2:start_width*2+self.w*2]
        # print(lli.shape, hli.shape)
        return {low_exp_img: lli, high_exp_img:hli}

class RandomFlip(object):
    def __init__(self, probabilty=.3):
        self.probabilty = probabilty
    def __call__(self, sample):
        lli, hli = sample[low_exp_img], sample[high_exp_img]
        hor_prob = torch.rand(1)[0]
        ver_prob = torch.rand(1)[0]
        if hor_prob > self.probabilty:
            lli = np.flip(lli, axis=1)
            hli = np.flip(hli, axis=1)
        if ver_prob > self.probabilty:
            lli = np.flip(lli, axis=2)
            hli = np.flip(hli, axis=2)
        return {low_exp_img: lli, high_exp_img:hli}

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        lli, hli = sample[low_exp_img].copy(), sample[high_exp_img].copy()
        # print('coming_shape: ', lli.shape, hli.shape)
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        lli = lli.transpose((0, 3, 1, 2))
        hli = hli.transpose((0, 3, 1, 2))
        lli = lli[0]
        hhi = hli[0]
        return {low_exp_img: torch.from_numpy(lli),
                high_exp_img: torch.from_numpy(hli)}


batch_size = 4
workers = 0
im_size = (512, 512)

composer = transforms.Compose([
    RandomCrop(im_size),
    # torchvision.transforms.RandomEqualize(.2),
    RandomFlip(.3),
    # torchvision.transforms.RandomHorizontalFlip(),
    # torchvision.transforms.RandomVerticalFlip(),
    ToTensor(),
])
ds = Dataset('./Sony/Sony_train_list.txt', res=im_size, transforms=composer)
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=workers)

In [23]:
model.to(cuda)
optim = torch.optim.NAdam(model.parameters(),lr=.003)
sch = torch.optim.lr_scheduler.StepLR(optim, step_size  = 10 , gamma = 0.5)
lr = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, 'min', .5, patience=10)
l1 = torch.nn.L1Loss()
epochs = 50
print_every = 500 # steps
iters = 0
running_loss = 0.0

for epoch in range(epochs): 
    for idx, batched in enumerate(dl, 0):
        # model.zero_grad()
        optim.zero_grad()
        llis, hlis = batched[low_exp_img], batched[high_exp_img]
        print(llis.shape)
        outputs = model(llis.to(device=cuda))
        outputs = outputs.to('cpu')
        # print(hlis.shape, outputs.shape)
        
        err = l1(hlis, outputs)
        err.backward()
        optim.step()
        running_loss += err.item()
        iters+=1
        if (iters % print_every == 0) or ((epoch == epochs-1) and (idx == len(dl)-1)):
            print('[%d/%d][%d/%d]\tloss::%f\trunning_loss::%f' % (epoch, epochs, idx, len(dl), err, running_loss / (idx+1)))
        break
    lr.step()
    running_loss = 0.0
    

(1, 1424, 2128, 4) (1, 2848, 4256, 3)
float32
(1, 1424, 2128, 4) (1, 2848, 4256, 3)
float32
(1, 1424, 2128, 4) (1, 2848, 4256, 3)
float32
(1, 1424, 2128, 4) (1, 2848, 4256, 3)
float32
torch.Size([4, 4, 512, 512])


  return F.l1_loss(input, target, reduction=self.reduction)


TypeError: step() missing 1 required positional argument: 'metrics'