In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import torch.nn.functional as F
import copy
from sklearn.utils import shuffle

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
class base_encoder(nn.Module):
  """Convolutional encoder that turns a 3-channel image batch into two codes.

  A shared trunk of three 3x3 convolutions (16 channels, stride 1, padding 1,
  each followed by ReLU) keeps the spatial size. Two independent heads then
  each apply a 2x2 stride-2 convolution (halving height and width), batch
  normalisation and a sigmoid, so both outputs have 3 channels and values
  in (0, 1).
  """

  def __init__(self):
    super().__init__()
    # Shared trunk: 3 -> 16 -> 16 -> 16 channels.
    self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
    self.conv3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
    # Two parallel down-sampling heads, each 16 -> 3 channels.
    self.convsplit1 = nn.Conv2d(16, 3, kernel_size=2, stride=2, padding=0)
    self.convsplit2 = nn.Conv2d(16, 3, kernel_size=2, stride=2, padding=0)
    # One batch-norm layer per head.
    self.bn1 = nn.BatchNorm2d(3)
    self.bn2 = nn.BatchNorm2d(3)

  def forward(self, x):
    """Encode ``x`` and return the tuple ``(code1, code2)``.

    ``x`` must have 3 channels (``conv1`` has ``in_channels=3``). Both
    returned tensors come from the same trunk features but from different
    heads.
    """
    features = x
    for trunk_conv in (self.conv1, self.conv2, self.conv3):
      features = F.relu(trunk_conv(features))
    code1 = torch.sigmoid(self.bn1(self.convsplit1(features)))
    code2 = torch.sigmoid(self.bn2(self.convsplit2(features)))
    return (code1, code2)

In [None]:
class decoder(nn.Module):
  """Two-branch decoder that merges a pair of codes into one output.

  ``decoder1`` and ``decoder2`` share the same layer layout but have separate
  weights. Each branch doubles height and width with a 2x2 stride-2 transposed
  convolution, refines with three 3x3 convolutions and ends in batch
  normalisation. The branch outputs are summed and passed through a sigmoid.
  """

  def __init__(self):
    super().__init__()
    self.decoder1 = self._make_branch()
    self.decoder2 = self._make_branch()

  @staticmethod
  def _make_branch():
    """Build one decoder branch: 3 -> 16 -> 16 -> 16 -> 3 channels."""
    return nn.Sequential(
        nn.ConvTranspose2d(3, 16, kernel_size=2, stride=2, padding=0),
        nn.ReLU(),
        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(3),
    )

  def forward(self, x):
    """Decode the pair ``x`` into a single tensor with values in (0, 1).

    ``x[0]`` is fed to ``decoder1`` and ``x[1]`` to ``decoder2``; both must
    have 3 channels.
    """
    first_code, second_code = x[0], x[1]
    combined = self.decoder1(first_code) + self.decoder2(second_code)
    return torch.sigmoid(combined)
    



In [None]:
# Create one object of each class (e.g. enc = base_encoder() and dec = decoder()), then make sure both models are moved to the preferred device:
# enc=enc.to(device)
# dec=dec.to(device)
# Finally, assign the optimizer of your choice to each of the three networks above (the encoder and the two decoder branches):
# eoptim = torch.optim.Adam(enc.parameters())
# doptim1 = torch.optim.Adam(dec.decoder1.parameters())
# doptim2 = torch.optim.Adam(dec.decoder2.parameters())

In [None]:
def wavelet_loss(x_true,x,x0):
  """Return the reconstruction loss and the norm penalty as separate terms.

  Args:
    x_true: target tensor for the reconstruction; values must lie in [0, 1].
    x: reconstructed tensor with the same shape as ``x_true``; values must
      lie in [0, 1], as binary cross-entropy requires.
    x0: tensor to penalise. Its L2 norm is taken along dimension 1 and then
      averaged over all remaining elements.

  Returns:
    Tuple ``(rec, l2)`` of scalar tensors: the mean binary cross-entropy
    between ``x`` and ``x_true``, and the mean dim-1 L2 norm of ``x0``.
    The caller is expected to combine them into a weighted sum.
  """
  # torch.norm is deprecated; vector_norm with ord=2 along dim 1 computes the
  # same value that torch.norm(x0, dim=1) produced.
  l2 = torch.linalg.vector_norm(x0, ord=2, dim=1).mean()
  # Functional form avoids building a throw-away nn.BCELoss module per call;
  # the default reduction is the mean, exactly as with nn.BCELoss().
  rec = F.binary_cross_entropy(x, x_true)
  return (rec,l2)

In [None]:
# Note on using the loss function: pass the batch of images through the encoder and assign either element of the returned tuple as x0. This will become the higher wavelet encoding.
# In the loss function, x is the reconstruction generated by passing the encoder's output tuple through the decoder, and x_true is the original image.
# Use a weighted sum of the rec and l2 values returned by the loss function. The weights are hyperparameters that need to be tuned; rec+0.01*l2 is what I found worked well for me.