In [None]:
!pip install h5py
!pip install typing-extensions
!pip install wheel

In [None]:
!pip install fastai

In [None]:
!pip install --upgrade torch==1.8.0           

In [None]:
import torch
import torchvision

# Print both versions so a mismatched torch/torchvision pair is visible early.
print(torch.__version__)
print(torchvision.__version__)

In [None]:
import sys

# Interpreter version, useful when matching package wheels to this runtime.
print(f"{sys.version}")

In [None]:
import torch
torch.cuda.empty_cache()
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import numpy as np 
from matplotlib import pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.optim as optim
from torch.utils.data import random_split
from torchvision import models,datasets
import os
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from PIL import Image
from tqdm import tqdm
from torch.optim import lr_scheduler
from torchvision.utils import make_grid

In [None]:
USE_GPU = True

# Select CUDA only when it is both requested and actually available, so the
# printed message and the device really used always agree.
use_cuda = USE_GPU and torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(f"using device: {device.type}")

In [None]:
# Hyperparameters
BATCH_SIZE: int = 4  # images per training batch

In [None]:
import torchvision.transforms as transforms

# Training preprocessing: resize to 256x256, convert to a float tensor, then
# normalize each RGB channel as (x - 0.5) / 0.5.
# (RandomResizedCrop / RandomHorizontalFlip / ColorJitter augmentations were
# tried and are currently disabled.)
train_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)



In [None]:
class Flare(Dataset):
    """Paired dataset of (flare image, flare-free image).

    A flare file is paired with the file in ``wf_dir`` whose name, up to the
    first ``.``, equals the flare file's name up to its first ``.`` with the
    first four characters removed.

    Args:
        flare_dir: directory with the flare images (model inputs).
        wf_dir: directory with the without-flare images (targets).
        transform: optional callable applied to both PIL images. When None,
            the RGB PIL images are returned unchanged.

    Raises:
        FileNotFoundError: from ``__getitem__`` when a flare image has no
            matching file in ``wf_dir``.
    """

    def __init__(self, flare_dir, wf_dir, transform=None):
        self.flare_dir = flare_dir
        self.wf_dir = wf_dir
        self.transform = transform
        # Sorted so that index -> file is reproducible across filesystems.
        self.flare_img = sorted(os.listdir(flare_dir))
        self.wf_img = os.listdir(wf_dir)
        # Build the stem -> filename lookup once instead of scanning wf_img on
        # every item. setdefault keeps the first file in listing order when
        # several targets share a stem.
        self._wf_by_stem = {}
        for name in self.wf_img:
            self._wf_by_stem.setdefault(name.split('.')[0], name)

    def __len__(self):
        return len(self.flare_img)

    def __getitem__(self, idx):
        flare_name = self.flare_img[idx]
        stem = flare_name.split('.')[0][4:]
        wf_name = self._wf_by_stem.get(stem)
        if wf_name is None:
            raise FileNotFoundError(
                f"No without-flare image with stem {stem!r} in {self.wf_dir!r} "
                f"for flare image {flare_name!r}"
            )
        # Context managers close the underlying files once converted.
        with Image.open(os.path.join(self.flare_dir, flare_name)) as img:
            f_img = img.convert("RGB")
        with Image.open(os.path.join(self.wf_dir, wf_name)) as img:
            wf_img = img.convert("RGB")
        if self.transform is not None:
            f_img = self.transform(f_img)
            wf_img = self.transform(wf_img)
        return f_img, wf_img

In [None]:
flare_dir = '../input/flaredataset/Flare/Flare_img'
wf_dir = '../input/flaredataset/Flare/Without_Flare_'

# Sorted listings, used here only to eyeball the first target file name.
flare_img = sorted(os.listdir(flare_dir))
wf_img = sorted(os.listdir(wf_dir))
print(wf_img[0])

train_ds = Flare(flare_dir, wf_dir, train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

print(train_ds)
print(train_loader)

In [None]:
# Pull one batch to sanity-check the value range, loader length and shapes.
i, l = next(iter(train_loader))
for value in (i.min(), len(train_loader), i.shape, l.shape):
    print(value)

In [None]:
import matplotlib.pyplot as plt
import numpy
# next(iter(...)) works on every torch version (the iterator's .next() alias
# was removed in newer releases).
samples, labels = next(iter(train_loader))
# make_grid returns a (channels, height, width) tensor while matplotlib's
# imshow expects (height, width, channels), hence the (1, 2, 0) transpose.
for title, batch in (("Images With Flare", samples), ("Ground Truth", labels)):
    plt.figure(figsize=(16, 24))
    plt.title(title)
    grid_imgs = torchvision.utils.make_grid(batch, nrow=4, normalize=True)
    np_grid_imgs = grid_imgs.numpy()
    plt.imshow(numpy.transpose(np_grid_imgs, (1, 2, 0)))
    plt.axis('off')
    plt.show()

In [None]:
# Model
# U-Net generator. Compared with the original U-Net paper, two small changes:
# convolutions use "same" padding (feature maps are never cropped) and every
# convolution is followed by BatchNorm.
class DoubleConv(nn.Module):
    """Two 3x3 same-padded Conv -> BatchNorm -> LeakyReLU(0.2) stages.

    Keeps H and W; maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            # bias=False: the following BatchNorm supplies its own shift.
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNET(nn.Module):
    """U-Net encoder/decoder with skip connections.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output image.
        features: channel widths of the encoder levels (the bottleneck uses
            twice the last width). Any sequence of ints.

    ``forward`` returns the raw 1x1-conv output (no final activation), with the
    same H and W as the input.
    """

    def __init__(self, in_channels=3, out_channels=3, features=(32, 64, 128, 256)):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level; pooling happens in forward().
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: per level, a ConvTranspose2d that doubles H and W, then a
        # DoubleConv over [skip connection, upsampled] concatenated on channels.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottom of the "U".
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        # 1x1 conv: change channel count without changing H and W.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)  # remembered for the decoder
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # If H or W is not divisible by 2**len(features), max-pooling floors
            # odd sizes and the upsampled map is smaller than its skip
            # connection; resize it to the skip's spatial size before concat.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

In [None]:
# Discriminator: scores (image, condition) pairs and outputs a 1-channel
# spatial map of logits rather than a single number.
class FeatureMapBlock(nn.Module):
    """1x1 convolution that changes the channel count, keeping H and W."""

    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class ContractingBlock(nn.Module):
    """Two 3x3 convs that double the channels, then a 2x2 max-pool.

    Each conv is followed by optional BatchNorm, optional Dropout, then
    LeakyReLU(0.2). Note that a single BatchNorm2d (and Dropout) module is
    applied after both convs, so they share normalization parameters and
    running statistics. Changing that would alter the state_dict layout of
    saved checkpoints.
    """

    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(ContractingBlock, self).__init__()
        doubled = input_channels * 2
        self.conv1 = nn.Conv2d(input_channels, doubled, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(doubled, doubled, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.use_bn = use_bn
        self.use_dropout = use_dropout
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(doubled)
        if use_dropout:
            self.dropout = nn.Dropout()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = conv(x)
            if self.use_bn:
                x = self.batchnorm(x)
            if self.use_dropout:
                x = self.dropout(x)
            x = self.activation(x)
        return self.maxpool(x)


class Discriminator(nn.Module):
    """Conditional discriminator.

    Args:
        input_channels: channels of ``x`` plus channels of ``y`` (the two are
            concatenated along the channel axis).
        hidden_channels: width after the initial 1x1 projection; each of the
            four contracting blocks doubles it and halves H and W.
    """

    def __init__(self, input_channels, hidden_channels=8):
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

    def forward(self, x, y):
        h = self.upfeature(torch.cat([x, y], dim=1))
        for block in (self.contract1, self.contract2, self.contract3, self.contract4):
            h = block(h)
        return self.final(h)

In [None]:
fe = models.resnet18(pretrained=True)
for param in fe.parameters():
    param.requires_grad = True
num_ftrs = fe.fc.in_features


class ResNet18(nn.Module):
    """ResNet-18 trunk plus a small head that outputs one score per image.

    ``features`` reuses every child of the global ``fe`` except its final
    avgpool and fc layers. The 512-channel map is reduced to 3 channels by a
    3x3 conv and globally average-pooled. It then passes through a dropout/MLP
    head that ends in a sigmoid, so ``forward`` returns values in (0, 1) of
    shape (batch, 1).

    Presumably a flare / no-flare classifier (cf. the commented-out
    Flare_Classifier checkpoint path) — confirm.
    """

    def __init__(self):
        super(ResNet18, self).__init__()
        self.features = torch.nn.Sequential(*list(fe.children())[:-2])
        self.conv1 = nn.Conv2d(512, 3, 3, 1, 1)
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.drop1 = nn.Dropout(0.3)
        self.fc1 = nn.Linear(3, 128)
        self.drop2 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(128, 1)
        # The sigmoid is applied inside forward() and the paired criterion is
        # MSELoss (below), so outputs are already in (0, 1); no extra
        # activation is needed when testing.

    def forward(self, x):
        x = self.features(x)
        x = self.conv1(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)  # (batch, 3, 1, 1) -> (batch, 3)
        x = self.drop1(x)
        x = F.relu(self.fc1(x))
        x = self.drop2(x)
        # torch.sigmoid: F.sigmoid is deprecated and computes the same values.
        return torch.sigmoid(self.fc2(x))


criterion_cl = nn.MSELoss()

resnet_model = ResNet18()
resnet_model = resnet_model.to(device)

# Sub-modules reused elsewhere: c1 is the convolutional trunk (feature
# extractor for the generator's reconstruction loss), c2 is the 512->3 conv.
c1 = resnet_model.features
c2 = resnet_model.conv1

In [None]:
!pip install torchsummary
from torchsummary import summary
summary(efficientnet_b0,(3,256,256))

In [None]:
    def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon, c1, c2):
        """Generator objective: adversarial term + lambda_recon * reconstruction.

        The reconstruction term is 0.8 * a feature-space loss (``c1`` features
        of real vs. generated images) + 0.2 * a pixel-space loss. ``c2`` is
        accepted for interface compatibility but not used.

        Returns:
            The scalar loss tensor.
        """
        fake = gen(condition)
        # The generator is rewarded when the discriminator labels its output real (1).
        fake_logits = disc(fake, condition)
        adversarial = adv_criterion(fake_logits, torch.ones_like(fake_logits))

        # Order kept: fake features first, then real (matters if c1 has BatchNorm in train mode).
        fake_features = c1(fake)
        real_features = c1(real)
        reconstruction = 0.8 * recon_criterion(real_features, fake_features) + 0.2 * recon_criterion(fake, real)
        return adversarial + lambda_recon * reconstruction


    def show_tensor_images(image_tensor, num_images=5, size=(1, 28, 28)):
        '''
        Plot the first ``num_images`` images of ``image_tensor`` as a grid with
        5 images per row (rescaled by make_grid's normalize=True) and show it.

        ``size`` is the per-image shape used to reshape the (detached, CPU) tensor.
        '''
        images = image_tensor.detach().cpu().view(-1, *size)
        grid = make_grid(images[:num_images], nrow=5, normalize=True)
        # make_grid yields (C, H, W); imshow wants (H, W, C). squeeze() drops
        # any singleton axes.
        plt.imshow(grid.permute(1, 2, 0).squeeze())
        plt.show()

In [None]:
import torch.nn.functional as F
# New parameters
adv_criterion = nn.BCEWithLogitsLoss()  # Discriminator outputs raw logits
recon_criterion = nn.L1Loss()
lambda_recon = 50  # weight of the reconstruction term in the generator loss

n_epochs = 20
input_dim = 3  # channels of the conditioning (flare) image
real_dim = 3   # channels of the target (flare-free) image
display_step = 2000
batch_size = BATCH_SIZE  # kept in sync with the DataLoader's batch size
lrg = 1e-5  # generator learning rate
lrd = 5e-8  # discriminator learning rate; NOTE(review): 200x below lrg — confirm
image_dim = 256
# Enable the AMP grad scalers only when CUDA exists; on CPU a default
# GradScaler warns and disables itself.
# NOTE(review): d_scaler / g_scaler are not used by train().
amp_enabled = torch.cuda.is_available()
d_scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
g_scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
gen = UNET(input_dim, real_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lrg)
disc = Discriminator(input_dim + real_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lrd)

In [None]:
#gen.load_state_dict(torch.load('../input/model-1/FM_UNET_P2P_40_epochs_256px.pt'))

#disc.load_state_dict(torch.load('../input/model-2/FM_Disc_P2P_40_epochs_256px.pt'))

#resnet_model.load_state_dict(torch.load('../input/flare-resnet-classifier/Flare_Classifier_25_epochs.pt'))

In [None]:
import matplotlib.pyplot as plt
import numpy
samples, labels = next(iter(train_loader))
samples = samples.to(device)

# `efficientnet_model` is not defined in this notebook; c1/c2 below are parts
# of resnet_model, so run that model. Inspect in eval mode under no_grad (no
# dropout noise in `out`, no BatchNorm running-stat updates, no autograd
# graph), then restore the previous train/eval mode.
was_training = resnet_model.training
resnet_model.eval()
with torch.no_grad():
    out = resnet_model(samples)
    i = c1(samples)
    j = c2(i)
resnet_model.train(was_training)

# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.figure(figsize=(16, 24))
grid_imgs = torchvision.utils.make_grid(samples.cpu(), normalize=True)
plt.imshow(numpy.transpose(grid_imgs, (1, 2, 0)))

plt.figure(figsize=(16, 24))
grid_imgs = torchvision.utils.make_grid(j.cpu(), normalize=True)
plt.imshow(numpy.transpose(grid_imgs, (1, 2, 0)))

print(out)

In [None]:
from skimage import color
import numpy as np

def train(save_model=False):
    """Train the global ``gen`` / ``disc`` pair for ``n_epochs`` epochs.

    Per batch: one discriminator step on (fake, condition) vs. (real,
    condition), then one generator step on ``get_gen_loss``. Every
    ``display_step`` steps, the mean losses since the last display are printed
    and the current batch's condition / real / fake images are plotted.

    Args:
        save_model: accepted for interface compatibility; currently unused.
    """
    # c1 serves only as a fixed feature extractor inside get_gen_loss. In
    # train mode its BatchNorm layers would keep updating running statistics
    # from generated images. Its weights would also accumulate gradients that
    # no optimizer ever applies. Freeze it; gradients still flow *through* c1
    # to the generator because `fake` requires grad.
    c1.eval()
    for p in c1.parameters():
        p.requires_grad_(False)

    mean_generator_loss = 0
    mean_discriminator_loss = 0
    cur_step = 0

    for epoch in range(n_epochs):
        for condition, real in tqdm(train_loader):
            condition = condition.to(device)
            real = real.to(device)
            gen.train()
            disc.train()

            ### Update discriminator ###
            disc_opt.zero_grad()
            with torch.no_grad():
                fake = gen(condition)
            disc_fake_hat = disc(fake.detach(), condition)
            disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
            disc_real_hat = disc(real, condition)
            disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            # No retain_graph: this graph is never backpropagated through again.
            disc_loss.backward()
            disc_opt.step()

            ### Update generator ###
            gen_opt.zero_grad()
            gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion,
                                    recon_criterion, lambda_recon, c1, c2)
            gen_loss.backward()
            gen_opt.step()

            # Running means over the current display window.
            mean_discriminator_loss += disc_loss.item() / display_step
            mean_generator_loss += gen_loss.item() / display_step

            ### Visualization ###
            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(f"Epoch {epoch+1}: Step {cur_step}: Generator (U-Net) loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
                else:
                    print("Pretrained initial state")
                gen.eval()
                disc.eval()
                show_tensor_images(condition, size=condition.shape[1:])
                show_tensor_images(real, size=real.shape[1:])
                show_tensor_images(fake, size=fake.shape[1:])
                mean_generator_loss = 0
                mean_discriminator_loss = 0

            cur_step += 1


train()

In [None]:
# Batch-size-1 loader over the training set, used to inspect single predictions.
train_loader_ = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=1,
    shuffle=True,
)

In [None]:
l, k = next(iter(train_loader_))
l = l.to(device).float()
k = k.to(device).float()
i = gen(l).cpu().detach()

# Target, input, then generator output for one randomly drawn training pair.
for title, batch in (("Ground Truth", k), ("Images With Flare", l), ("Generated_Images", i)):
    plt.figure(figsize=(5, 5))
    plt.title(title)
    img_grid = torchvision.utils.make_grid(batch.cpu(), normalize=True)
    plt.imshow(np.transpose(img_grid, (1, 2, 0)))
    plt.show()

In [None]:
l, k = next(iter(train_loader_))
l = l.to(device).float()
k = k.to(device).float()
i = gen(l).cpu().detach()

# Target, input, then generator output for one randomly drawn training pair.
for title, batch in (("Ground Truth", k), ("Images With Flare", l), ("Generated_Images", i)):
    plt.figure(figsize=(5, 5))
    plt.title(title)
    img_grid = torchvision.utils.make_grid(batch.cpu(), normalize=True)
    plt.imshow(np.transpose(img_grid, (1, 2, 0)))
    plt.show()

In [None]:
l, k = next(iter(train_loader_))
l = l.to(device).float()
k = k.to(device).float()
i = gen(l).cpu().detach()

# Target, input, then generator output for one randomly drawn training pair.
for title, batch in (("Ground Truth", k), ("Images With Flare", l), ("Generated_Images", i)):
    plt.figure(figsize=(5, 5))
    plt.title(title)
    img_grid = torchvision.utils.make_grid(batch.cpu(), normalize=True)
    plt.imshow(np.transpose(img_grid, (1, 2, 0)))
    plt.show()

In [None]:
# Save only the discriminator's weights (state_dict), not the module object.
# NOTE(review): the name says 70 epochs while n_epochs is 20 here — confirm.
model_save_name = 'Disc_P2P_70_epochs_FM_256px.pt'
path = ".//" + model_save_name
torch.save(disc.state_dict(), path)

In [None]:
# Save only the generator's weights (state_dict), not the module object.
# NOTE(review): the name says 70 epochs while n_epochs is 20 here — confirm.
model_save_name = 'UNET_P2P_70_epochs_FM_256px.pt'
path = ".//" + model_save_name
torch.save(gen.state_dict(), path)

In [None]:
class FlareTest(Dataset):
    """Unpaired dataset of flare images for inference (no targets).

    Items are served in sorted filename order. ``transform`` is applied to the
    RGB PIL image when given; otherwise the PIL image is returned unchanged.
    """

    def __init__(self, flare_dir, transform=None):
        self.flare_dir = flare_dir
        self.transform = transform
        # Sort once at construction instead of on every __getitem__ call.
        self.flare_img = sorted(os.listdir(flare_dir))

    def __len__(self):
        return len(self.flare_img)

    def __getitem__(self, idx):
        path = os.path.join(self.flare_dir, self.flare_img[idx])
        with Image.open(path) as img:
            f_img = img.convert("RGB")
        if self.transform is not None:
            f_img = self.transform(f_img)
        return f_img

In [None]:
# Evaluation preprocessing: resize to 256x256 and convert to a float tensor.
# NOTE(review): unlike train_transform there is no Normalize here, so inputs
# stay in ToTensor's [0, 1] range while the generator was trained on
# normalized inputs. The piq SSIM cell relies on [0, 1] data, so confirm the
# intent before changing this.
test_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

In [None]:
# Flare images to run the generator on (no targets needed).
# Alternative source: '../input/flaredetection/Flare/Flare_img'
# Named flare_test_dir so the builtin `dir` is not shadowed.
flare_test_dir = '../input/flaredataset/Flare/Flare_img'
ds = FlareTest(flare_test_dir, test_transform)
loader = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=32,
    shuffle=True,
)

In [None]:
images = next(iter(loader))
images = images.to(device)
# Inference only: no autograd graph needed.
with torch.no_grad():
    out = gen(images).cpu()

# First figure: the flare inputs; second: the generator's output. (These are
# unpaired images, so there is no ground truth to show.)
plt.figure(figsize=(20, 20))
plt.title("Images With Flare")
img_grid = torchvision.utils.make_grid(images.cpu(), normalize=True)
plt.imshow(np.transpose(img_grid, (1, 2, 0)))
plt.show()

plt.figure(figsize=(20, 20))
plt.title("Generated_Images")
img_grid = torchvision.utils.make_grid(out, normalize=True)
plt.imshow(np.transpose(img_grid, (1, 2, 0)))
plt.show()


In [None]:
# Paired test set: flare inputs plus their flare-free ground truth.
# Note: this rebinds the global wf_dir used for the training set.
testf_dir = '../input/testvvvvv/Test_Flare/Test_Flare'
wf_dir = '../input/testvvvvv/Test_/Test_'
test_ds = Flare(testf_dir, wf_dir, test_transform)
test_loader = torch.utils.data.DataLoader(
    dataset=test_ds,
    batch_size=32,
    shuffle=True,
)

In [None]:
def PSNR(original, compressed, max_pixel=255.0):
    """Peak signal-to-noise ratio, in dB, between two equally shaped arrays.

    Args:
        original: reference image (array-like).
        compressed: reconstructed / generated image (array-like).
        max_pixel: largest possible pixel value. Defaults to 255 (8-bit
            images); pass 1.0 for data in [0, 1].

    Returns:
        20 * log10(max_pixel / sqrt(MSE)). Returns 100 when the arrays are
        identical (MSE == 0), where PSNR would be infinite.
    """
    # Local import: this cell must not depend on a later cell's imports.
    from math import log10, sqrt

    # Subtract in float64: uint8 subtraction would silently wrap around.
    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * log10(max_pixel / sqrt(mse))

In [None]:
l, k = next(iter(test_loader))
l = l.to(device).float()
k = k.to(device).float()
i = np.array(gen(l).cpu().detach())

# Flare inputs, then the generator's output for the same test batch.
panels = (("Images With Flare", l.cpu()), ("Generated_Images", torch.tensor(i)))
for title, grid_source in panels:
    plt.figure(figsize=(20, 20))
    plt.title(title)
    img_grid = torchvision.utils.make_grid(grid_source, normalize=True)
    plt.imshow(np.transpose(img_grid, (1, 2, 0)))
    plt.show()

In [None]:
from math import log10, sqrt
import cv2
import numpy as np

l, k = next(iter(test_loader))
l = l.to(device).float()
k = k.to(device).float()
with torch.no_grad():
    i = np.array(gen(l).cpu())

# Mean PSNR over this single batch. Loop over the real batch length (a batch
# may hold fewer than 32 images) and use a name that does not shadow the
# builtin `sum`.
# NOTE(review): PSNR's default peak value is 255 while these tensors come
# from test_transform (roughly [0, 1]); confirm the intended data range.
psnr_total = 0
for h in range(len(k)):
    psnr_total += PSNR(np.array(k[h].cpu()), i[h])

plt.figure(figsize=(20, 20))
plt.title("Ground Truth")
img_grid = torchvision.utils.make_grid(k.cpu(), normalize=True)
plt.imshow(np.transpose(img_grid, (1, 2, 0)))
plt.show()

plt.figure(figsize=(20, 20))
plt.title("Images With Flare")
img_grid = torchvision.utils.make_grid(l.cpu(), normalize=True)
plt.imshow(np.transpose(img_grid, (1, 2, 0)))
plt.show()

plt.figure(figsize=(20, 20))
plt.title("Generated_Images")
img_grid = torchvision.utils.make_grid(torch.tensor(i), normalize=True)
plt.imshow(np.transpose(img_grid, (1, 2, 0)))
plt.show()
print("PSNR: ", psnr_total / len(k))

In [None]:
# Mean PSNR over the whole paired test set: one no-grad generator pass per
# batch. The running mean is printed every 50 batches using the true number of
# images seen so far (the last batch may be smaller than 32).
psnr_total = 0
n_images = 0
with torch.no_grad():
    for batch_idx, (m, n) in enumerate(test_loader, start=1):
        m = m.to(device)
        n = n.to(device)
        f = np.array(gen(m).cpu())
        for j in range(m.shape[0]):
            psnr_total += PSNR(np.array(n[j].cpu()), f[j])
        n_images += m.shape[0]
        if batch_idx % 50 == 0:
            print(psnr_total / n_images)
print("PSNR: ", psnr_total / len(test_ds))

In [None]:
# One paired test batch (l: flare inputs, k: targets) and the generator's
# output on it (i, still attached to the autograd graph).
l, k = next(iter(test_loader))
l, k = l.to(device), k.to(device)
i = gen(l)
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    """1-D Gaussian kernel of length ``window_size`` centred at
    ``window_size // 2``, as a float32 tensor normalized to sum to 1."""
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-((pos - center) ** 2) / denom) for pos in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel):
    """Gaussian SSIM window of shape (channel, 1, window_size, window_size).

    The 2-D kernel is the outer product of ``gaussian(window_size, 1.5)`` with
    itself, repeated per channel for use as a grouped (per-channel) conv2d
    weight.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable is deprecated; plain tensors carry autograd
    # state directly, so no wrapper is needed.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    """Structural similarity (SSIM) between two image batches.

    Local means, variances and covariance are computed by convolving with the
    Gaussian ``window`` channel by channel (``groups=channel``). The padding
    ``window_size // 2`` keeps the spatial size.

    Args:
        img1, img2: image batches; assumes (N, C, H, W) with C == ``channel``
            (TODO confirm at call sites).
        window: conv2d weight, e.g. from ``create_window(window_size, channel)``.
        window_size: side length of ``window``; used here only for padding.
        channel: number of channels / conv groups.
        size_average: if True, return the mean of the whole SSIM map; otherwise
            average over dims 1, 2, 3 to get one value per image.

    NOTE: C1 = 0.01**2 and C2 = 0.03**2 presumably correspond to the usual
    K1=0.01, K2=0.03 with a data range of 1, i.e. inputs in [0, 1] — confirm.
    """
    # Local (Gaussian-weighted) means.
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    # Local variances and covariance via E[xy] - E[x]E[y].
    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    # Stabilizing constants that keep the ratio finite when denominators are ~0.
    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    """Module wrapper around ``_ssim`` that caches its Gaussian window.

    The window is rebuilt, and moved to img1's device/dtype, whenever the
    input's channel count or tensor type differs from the cached one.
    """

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        _, channel, _, _ = img1.size()

        cache_hit = channel == self.channel and self.window.data.type() == img1.data.type()
        if not cache_hit:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            self.window = fresh.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM: builds a fresh Gaussian window on img1's device and
    dtype, then delegates to ``_ssim``.

    Note: a later cell rebinds the name ``ssim`` via ``from piq import ssim``.
    """
    _, channel, _, _ = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    return _ssim(img1, img2, window.type_as(img1), window_size, channel, size_average)



In [None]:
print(l.size())  # torch.Size of the current flare-input batch

In [None]:
!pip install piq

In [None]:
# Smaller (batch of 4) loader over the paired test set for the SSIM check.
test_loader = torch.utils.data.DataLoader(
    dataset=test_ds,
    batch_size=4,
    shuffle=True,
)
l, k = next(iter(test_loader))
l, k = l.to(device).float(), k.to(device).float()
i = gen(l)
print(l.shape)



In [None]:
import torch
from piq import ssim, SSIMLoss
import piq

# SSIM between the flare inputs (l) and their ground truth (k), both produced
# by test_transform, i.e. in [0, 1] as data_range=1. expects.
# NOTE(review): this measures input vs. target, not the generator's output.
ssim_index: torch.Tensor = ssim(l, k, data_range=1.)

loss = SSIMLoss(data_range=1.)
# l and k come straight from the DataLoader and do not require grad, so
# calling output.backward() here would raise a RuntimeError. This is an
# evaluation, so the loss is only computed and reported.
output: torch.Tensor = loss(l, k)



In [None]:
print("SSIM index: {:0.4f}, loss: {:0.4f}".format(ssim_index.item(), output.item()))