In [1]:
%load_ext autoreload

%autoreload 2
import sys
%cd noise2self

/home/aruba19th/noise2self


In [None]:
!git clone https://github.com/deepskies/noise2self.git

In [2]:
from glob import glob
import os
from collections import defaultdict, Counter
from skimage.io import imread
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from astropy.io import fits
from torch import Tensor
import math

In [3]:
from util import show, plot_images, plot_tensors, plot_grid
from util import clean_two_channel_data, clean_three_channel_data, clamp_data
from models.babyunet import BabyUnet
from models.dncnn import DnCNN
from models.dnfcn8 import DnFCN8
from models.singleconv import SingleConvolution
from util import getbestgpu

In [4]:
cd ..

/home/aruba19th


In [8]:
!git clone https://github.com/sksq96/pytorch-summary.git

Cloning into 'pytorch-summary'...
remote: Enumerating objects: 44, done.[K
remote: Counting objects: 100% (44/44), done.[K
remote: Compressing objects: 100% (29/29), done.[K
remote: Total 186 (delta 13), reused 33 (delta 7), pack-reused 142[K
Receiving objects: 100% (186/186), 38.50 KiB | 0 bytes/s, done.
Resolving deltas: 100% (62/62), done.


In [9]:
%cd pytorch-summary
from torchsummary import summary

/home/aruba19th/pytorch-summary


In [10]:
cd ..

/home/aruba19th


In [11]:
class ThreeChannelGalaxyDataset(Dataset):
    def __init__(self, type):
        self.fnames = glob("galfit_final_sims/*snr100.0*.gz") #5.0 10.0 20.0 50.0 100.0
        #self.fnames = glob("cutouts_v1/*.gz")
        self.files_length = len(self.fnames)

        self.train_len = math.floor(self.files_length * .8)
        self.val_len = math.floor(self.files_length * .1)
        self.test_len = math.floor(self.files_length* .1 )

        self.data = self._load_data(type)
        self.data = self._clean()
        self.type = type

    
    def __len__(self):
        return len(self.data)

    def _load_data(self, type):
        galaxies = []
        if type == "train":
            for i in range(self.train_len):
                galaxies.append(fits.open(self.fnames[i])[0].data[:3])

                if i % 50 == 0:
                    print(f'{i/self.train_len} Done')

        elif type == "validate":
            for i in range(self.train_len, self.train_len + self.val_len):
                galaxies.append(fits.open(self.fnames[i])[0].data[:3])
                if i % 50 == 0:
                    print(f'{(i - self.train_len)/self.val_len} Done')

        elif type == "test":
            for i in range(self.train_len + self.val_len, self.train_len + self.val_len+self.test_len):
                galaxies.append(fits.open(self.fnames[i])[0].data[:3])
                if i % 50 == 0:
                     print(f'{(i - (self.train_len + self.val_len))/self.test_len} Done')

        galaxies = np.stack(galaxies)
        print(galaxies.shape)
        return galaxies

    def _clean(self):
        return clamp_data(clean_three_channel_data(self.data))

    def __getitem__(self, idx):
        channels = self.data[idx]
        idx1 = np.random.randint(0, 3)
        idx2 = idx1
        while idx2 == idx1:
              idx2 = np.random.randint(0, 3)

        pair = np.stack([channels[idx1], channels[idx2]])

        return np.array(pair[:,:128,:128]/255)

    def get_full_batch(self, idx):
        return self.data[idx]  

In [12]:
train_data = ThreeChannelGalaxyDataset("train")
val_data = ThreeChannelGalaxyDataset("validate")
test_data = ThreeChannelGalaxyDataset("test")

0.0 Done
0.06944444444444445 Done
0.1388888888888889 Done
0.20833333333333334 Done
0.2777777777777778 Done
0.3472222222222222 Done
0.4166666666666667 Done
0.4861111111111111 Done
0.5555555555555556 Done
0.625 Done
0.6944444444444444 Done
0.7638888888888888 Done
0.8333333333333334 Done
0.9027777777777778 Done
0.9722222222222222 Done
(720, 3, 128, 128)
0.3333333333333333 Done
0.8888888888888888 Done
(90, 3, 128, 128)
0.4444444444444444 Done
(90, 3, 128, 128)


In [13]:
import torch
# device = getbestgpu()
device = 'cuda'
from torch.nn import MSELoss, L1Loss, SmoothL1Loss
from torch.optim import Adam
from torch.utils.data import DataLoader

In [14]:
#model = BabyUnet()
#model = DnCNN(1)
model = DnFCN8(1)
#model = SingleConvolution()
model.to(device)
#loss_function = MSELoss()
loss_function = L1Loss()
optimizer = Adam(model.parameters(), lr= .00005)

train_loader = DataLoader(train_data, batch_size=10, shuffle=False)
val_loader = DataLoader(val_data, batch_size=10, shuffle=False)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

In [15]:
summary(model, input_size=(1, 128, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm2d-1          [-1, 1, 128, 128]               2
            Conv2d-2         [-1, 64, 128, 128]             640
            Conv2d-3         [-1, 64, 128, 128]          36,928
            Conv2d-4          [-1, 1, 128, 128]              65
         MaxPool2d-5           [-1, 64, 64, 64]               0
       BatchNorm2d-6           [-1, 64, 64, 64]             128
            Conv2d-7          [-1, 128, 64, 64]          73,856
            Conv2d-8          [-1, 128, 64, 64]         147,584
         MaxPool2d-9          [-1, 128, 32, 32]               0
      BatchNorm2d-10          [-1, 128, 32, 32]             256
           Conv2d-11          [-1, 256, 32, 32]         295,168
           Conv2d-12          [-1, 256, 32, 32]         590,080
           Conv2d-13          [-1, 256, 32, 32]         590,080
        MaxPool2d-14          [-1, 256,

(tensor(34613607), tensor(34613607))

In [None]:
n_epochs = 20 #1200
best_loss = 100000

# Keeps track of losses
train_losses = []
val_losses = []

best_model={}

for epoch in range(n_epochs):
  
    train_loss = 0
    model = model.train()
    for i, batch in enumerate(train_loader):
        batch = batch.to(torch.float)
        noisy_images_1, noisy_images_2 = batch[:, 0:1], batch[:, 1:2]
        noisy_images_1 = noisy_images_1.to(device)
        noisy_images_2 = noisy_images_2.to(device)

        net_output = model(noisy_images_1)

        loss = loss_function(net_output, noisy_images_2)
        train_loss += loss.cpu().item()
        
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()
        
    train_losses.append(train_loss/len(train_loader))
    # Keeps track of loss over 10 epochs
    if epoch % 1 == 0:
        print("Loss (", epoch, "): \t", round(train_loss/len(train_loader), 6))
        
    val_loss = 0
    with torch.no_grad():
        model = model.eval()
    
    for i, batch in enumerate(val_loader):
        batch = batch.to(torch.float)
        noisy_images_1, noisy_images_2 = batch[:, 0:1], batch[:, 1:2]
        noisy_images_1 = noisy_images_1.to(device)
        noisy_images_2 = noisy_images_2.to(device)

        net_output = model(noisy_images_1)

        loss = loss_function(net_output, noisy_images_2)  
        val_loss += loss.cpu().item()
        
    val_losses.append(val_loss/len(val_loader))
    if val_loss < best_loss:
        best_loss = val_loss
        best_model = model.state_dict()
        
    optimizer.zero_grad()
    #torch.save(model,'Track_training/DnCnn/epoch'+str(epoch)+'SNR_100.pt')
    torch.save(model,'Track_training/DnFcn8/epoch'+str(epoch)+'SNR_100.pt')

Loss ( 0 ): 	 0.124152
Loss ( 1 ): 	 0.012248
Loss ( 2 ): 	 0.009631
Loss ( 3 ): 	 0.008939
Loss ( 4 ): 	 0.008623
Loss ( 5 ): 	 0.008485
Loss ( 6 ): 	 0.008461
Loss ( 7 ): 	 0.008327
Loss ( 8 ): 	 0.008291


In [None]:
# Plot losses
import matplotlib.pyplot as plt

plt.plot(train_losses,color='b',label='train_losses')
plt.plot(val_losses,color='r',label='val_losses')
plt.xlabel('training_epoch')
plt.ylabel('Loss')
#plt.title('loss history of DnCnn')
plt.title('loss history of DnFcn8')
plt.legend()

In [None]:
test_data_real = ThreeChannelGalaxyDataset("test")
test_loader = DataLoader(test_data_real, batch_size=1, shuffle=False)

In [None]:
files = glob("galfit_final_sims/*snr100.0*.gz")
files_length = len(files)
train_len = math.floor(files_length * .8)
val_len = math.floor(files_length* .1 )

    
for i, batch in enumerate(test_loader):
    frame = np.zeros((6,128,128))
    for j in range(0,18,3):  
        #model = torch.load('Track_training/DnCnn/epoch'+str(j)+'SNR_100.pt')
        model = torch.load('Track_training/DnFcn8/epoch'+str(j)+'SNR_100.pt')
        with torch.no_grad():
            model = model.eval()
        batch = batch.to(torch.float)
        noisy_images_1, noisy_images_2 = batch[:, 0:1], batch[:, 1:2]
        noisy_images_1 = noisy_images_1.to(device)
        noisy_images_2 = noisy_images_2.to(device)

        output = model(noisy_images_1)
        
        img = fits.open(files[train_len+val_len+i])
        gtr = img[2].data #ground truth
        img = img[0].data #noisy image
        
        np_array = output.cpu().detach().numpy()
        
        frame[int(j/3)] = img[0] - np_array[0,0]*255 
    
    fig, axes = plt.subplots(nrows=1, ncols=6)
    for u, ax in enumerate(axes):
        im = ax.imshow(frame[u])

    fig.subplots_adjust(right=4)
    cbar_ax = fig.add_axes([0.8, 0.1, 0.1, 0.8])
    fig.colorbar(im, cax=cbar_ax)
    plt.show()
        