In [5]:
from google.colab import drive
# Mount Google Drive so the dataset directories below are reachable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
import torch
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda') if use_gpu else torch.device('cpu')
print(device)

cpu


In [24]:
import os
from PIL import Image
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor

# Only conversion to tensor (ToTensor); no resizing or normalisation is applied.
train_transformer = torchvision.transforms.Compose([ToTensor()])

class ImageDataset(Dataset):
  """Paired image dataset yielding ``(input_image, clear_image)`` tuples.

  Pairs are formed by position after sorting the full ``.jpg`` paths of each
  directory, so both directories must hold the same number of images and
  their sorted orders must correspond (e.g. identical filenames).

  Args:
    input_dir: directory holding the input images.
    clear_img_dir: directory holding the matching clear images.
    transform: callable applied to every loaded PIL image (both sides).

  Raises:
    ValueError: if the two directories contain a different number of
      ``.jpg`` files (index-based pairing would otherwise be misaligned).
  """

  def __init__(self, input_dir, clear_img_dir, transform):
    self.filenames = self._list_jpgs(input_dir)
    self.clear_img_file = self._list_jpgs(clear_img_dir)
    if len(self.filenames) != len(self.clear_img_file):
      raise ValueError(
          f"Unpaired dataset: {len(self.filenames)} .jpg files in {input_dir!r} "
          f"but {len(self.clear_img_file)} in {clear_img_dir!r}")
    self.transform = transform

  @staticmethod
  def _list_jpgs(directory):
    """Return the sorted full paths of the ``.jpg`` files in ``directory``."""
    return sorted(os.path.join(directory, f)
                  for f in os.listdir(directory) if f.endswith('.jpg'))

  def _load(self, path):
    """Open ``path`` as 3-channel RGB, release the file, and apply the transform."""
    # convert('RGB') guarantees 3 channels even for grayscale/palette JPEGs, and
    # the context manager closes the file handle instead of waiting for GC.
    with Image.open(path) as img:
      return self.transform(img.convert('RGB'))

  def __getitem__(self, index):
    return self._load(self.filenames[index]), self._load(self.clear_img_file[index])

  def __len__(self):
    return len(self.filenames)


In [25]:

# Paired data: trainA holds the model inputs, trainB the matching clear targets.
DATA_ROOT = '/content/drive/MyDrive/Colab Notebooks/Paired/underwater_imagenet'
SEED = 0  # fixes the shuffle order so the sampled pair is reproducible

imagedir = os.path.join(DATA_ROOT, 'trainA')
clearimagedir = os.path.join(DATA_ROOT, 'trainB')

dataset = ImageDataset(imagedir, clearimagedir, train_transformer)

# A dedicated seeded generator makes shuffle=True deterministic across runs
# without touching torch's global RNG.
loader_rng = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(dataset=dataset,
                          batch_size=1,
                          shuffle=True,
                          num_workers=0,
                          generator=loader_rng)

# Pull one (input, clear) pair to sanity-check the pipeline.
input_img, clear_img = next(iter(train_loader))

In [26]:
def PreProcessing(input_img):
  """Center an image batch on its per-image, per-channel mean.

  Args:
    input_img: image tensor. Statistics are taken over dims 2 and 3, so it is
      presumably laid out as (batch, channels, height, width).

  Returns:
    Tuple ``(I_centered, mu_i, mu_sigma_cat)``:
      I_centered: ``input_img - mu_i`` as a new tensor (the input is left untouched).
      mu_i: mean over dims (2, 3), with those dims kept as size 1.
      mu_sigma_cat: ``mu_i`` and the std over dims (2, 3) concatenated on dim 1.
  """
  mu_i = torch.mean(input_img, dim=(2, 3), keepdim=True)
  sigma_i = torch.std(input_img, dim=(2, 3), keepdim=True)

  # Out-of-place on purpose: an in-place ``sub_`` would overwrite the caller's
  # tensor (e.g. the DataLoader batch) and raises on a leaf that requires grad.
  I_centered = input_img - mu_i
  mu_sigma_cat = torch.cat((mu_i, sigma_i), dim=1)

  return I_centered, mu_i, mu_sigma_cat

In [27]:
import torch.nn as nn

class TotalNet(nn.Module):
  """Two-branch image network: a global MLP on channel means plus a local CNN.

  The input is centered by ``PreProcessing``.
  - Global branch: maps the per-image channel means ``mu_i`` through three
    16-unit ReLU layers, then predicts ``delta_mu_i`` from their concatenation.
    ``mu_final = sigmoid(mu_i + delta_mu_i)``.
  - Local branch: three bias-free 3x3 convs (16 channels, ReLU), each shifted
    by a per-image bias derived from ``delta_mu_i`` via ``bias_layer``. A final
    conv over the concatenated features gives ``J_centered``.
  The output is ``J_centered + mu_final`` broadcast over height/width.

  The first conv takes 3 channels, so ``x`` is presumably (B, 3, H, W).
  Any batch size B is supported.

  Args:
    debug: when True (default, preserving the original behaviour) ``forward``
      prints section banners, shows intermediate images via ``imshow`` and
      prints the mean statistics. Pass ``debug=False`` for training/inference.
  """

  def __init__(self, debug=True):
    super(TotalNet, self).__init__()
    self.debug = debug
    # Maps delta_mu (3 values) to one bias per hidden conv channel (16).
    self.bias_layer = nn.Linear(3, 16)

    #GLOBAL NET
    input_size = 3
    hidden_size = 16
    output_size = 3
    self.gl1 = nn.Linear(input_size, hidden_size)
    self.fc1 = nn.ReLU()
    self.gl2 = nn.Linear(hidden_size, hidden_size)
    self.fc2 = nn.ReLU()
    self.gl3 = nn.Linear(hidden_size, hidden_size)
    self.fc3 = nn.ReLU()
    # Consumes h1, h2, h3 concatenated, hence hidden_size * 3 inputs.
    self.gl4 = nn.Linear(hidden_size * 3, output_size)
    self.fc4 = nn.Sigmoid()

    #LOCAL_NET
    input_channels = 3
    num_channels = 16
    output_channels = 3
    # bias=False: the per-image bias from bias_layer is added in forward instead.
    self.ll1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, bias=False)
    self.conv1 = nn.ReLU()
    self.ll2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
    self.conv2 = nn.ReLU()
    self.ll3 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
    self.conv3 = nn.ReLU()
    self.ll4 = nn.Conv2d(num_channels * 3, output_channels, kernel_size=3, padding=1)

  def _show(self, title, tensor):
    """Print ``title`` and display ``tensor`` as an image grid (debug only)."""
    print(title)
    # NOTE(review): ``imshow`` is not defined in this cell; presumably a
    # plotting helper from another notebook cell — confirm.
    imshow(torchvision.utils.make_grid(tensor.detach().cpu()))

  def forward(self, x):
    batch = x.shape[0]
    if self.debug:
      self._show("======================= Image ====================", x)
    # The mean/std concatenation (third output) is not used by this model.
    I_centered, mu_i, _ = PreProcessing(x)
    if self.debug:
      self._show("======================= I-centered ====================", I_centered)

    # ---- GLOBAL NET: per-image channel means laid out as (B, 1, C) ----
    mu_i = torch.reshape(mu_i, (batch, 1, -1))
    h1 = self.fc1(self.gl1(mu_i))
    h2 = self.fc2(self.gl2(h1))
    # gl3 followed by its own ReLU, mirroring h1/h2 and the local branch.
    h3 = self.fc3(self.gl3(h2))
    delta_mu_i = self.gl4(torch.cat((h1, h2, h3), dim=-1))
    mu_final = self.fc4(delta_mu_i + mu_i)

    # Per-image bias shared by the three hidden convs. Still exposed on
    # ``self.bias`` (shape (B, 1, 16)) for anyone inspecting it after forward.
    self.bias = self.bias_layer(delta_mu_i)
    conv_bias = torch.reshape(self.bias, (batch, -1, 1, 1))

    # ---- LOCAL NET ----
    H1 = self.conv1(self.ll1(I_centered) + conv_bias)
    H2 = self.conv2(self.ll2(H1) + conv_bias)
    H3 = self.conv3(self.ll3(H2) + conv_bias)
    J_centered = self.ll4(torch.cat((H1, H2, H3), dim=1))
    if self.debug:
      self._show("======================= J_Centered ====================", J_centered)

    # Add back the refined global mean, broadcast over height and width.
    J_final = J_centered + torch.reshape(mu_final, (batch, -1, 1, 1))
    if self.debug:
      self._show("======================= JFINAL ====================", J_final)
      print("mu_i", mu_i)
      print("delta_mui", delta_mu_i)
      print("mu_final", mu_final)

    return J_final

