In [None]:
# for TPU
# Download PyTorch/XLA's env-setup script from the `master` branch and run it,
# also installing the apt packages libomp5 and libopenblas-dev.
# NOTE(review): the script is unpinned (fetched from `master`), and nothing in
# this notebook imports torch_xla — the device cell only chooses CUDA or CPU.
# Confirm this cell is still needed.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import numpy as np 
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import random_split
from torchvision import models,datasets
import os
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from PIL import Image

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"# Using device: {device}")

In [None]:
# Release GPU memory cached by PyTorch's allocator (memory still held by live tensors is not freed).
torch.cuda.empty_cache()

In [None]:
# Hyperparameters
BATCH_SIZE: int = 8  # number of image pairs per training batch

In [None]:
# Shared preprocessing constants so the train and test pipelines cannot drift apart.
IMG_SIZE = 256
# Per-channel (RGB) normalization: maps ToTensor's [0, 1] range to [-1, 1].
NORM_MEAN = [0.5, 0.5, 0.5]
NORM_STD = [0.5, 0.5, 0.5]

# Training pipeline: resize, random horizontal flip, tensor conversion, normalization.
# NOTE: RandomHorizontalFlip makes its own random decision on every call. If this
# pipeline is applied separately to an input image and its target, both calls must
# use the same random draw, or the pair ends up mirrored relative to each other.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Deterministic evaluation pipeline: same resize and the same per-channel
# normalization as train_transform, without augmentation.
test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

In [None]:
class Flare(Dataset):
    """Paired dataset of flare images and their flare-free ground truth.

    Item ``idx`` is the image ``flare_img_[idx]`` from ``flare_dir`` together
    with the image from ``wf_dir`` whose file stem (text before the first ``.``)
    equals the flare stem with its first 4 characters removed. For example, a
    flare file ``flr_abc.png`` pairs with ``abc.jpg``. Both images are loaded as
    RGB.

    Args:
        flare_dir: directory holding the flare images.
        wf_dir: directory holding the flare-free (ground-truth) images.
        flare_img_: file names inside ``flare_dir``; defines the dataset order.
        wf_img_: candidate file names inside ``wf_dir``. If several share a stem,
            the first one in this list wins.
        transform: optional callable applied to both images of a pair. The torch
            RNG state is replayed for the second image, so random transforms that
            draw from torch's global RNG (e.g. ``RandomHorizontalFlip``) make the
            same choice for input and target.

    Raises (from ``__getitem__``):
        FileNotFoundError: if no file in ``wf_img_`` matches the flare image.
    """

    def __init__(self, flare_dir, wf_dir, flare_img_, wf_img_, transform = None):
        self.flare_dir = flare_dir
        self.wf_dir = wf_dir
        self.transform = transform
        self.flare_l = flare_img_
        self.wf_l = wf_img_
        # stem -> file name, built once instead of scanning wf_l for every item.
        # setdefault keeps the first name per stem ("first match wins").
        # NOTE: later in-place edits to self.wf_l are not reflected here.
        self._wf_by_stem = {}
        for name in wf_img_:
            self._wf_by_stem.setdefault(name.split('.')[0], name)

    def __len__(self):
        return len(self.flare_l)

    def _find_wf_name(self, idx):
        """Return the flare-free file name paired with flare image ``idx``."""
        flare_name = self.flare_l[idx]
        key = flare_name.split('.')[0][4:]
        try:
            return self._wf_by_stem[key]
        except KeyError:
            raise FileNotFoundError(
                f"no flare-free image in {self.wf_dir!r} matches {flare_name!r} (stem {key!r})"
            ) from None

    def __getitem__(self, idx):
        f_img = Image.open(os.path.join(self.flare_dir, self.flare_l[idx])).convert("RGB")
        wf_img = Image.open(os.path.join(self.wf_dir, self._find_wf_name(idx))).convert("RGB")
        if self.transform is not None:
            # Replay the same torch RNG state for the target so random transforms
            # (flips, crops, ...) treat input and target identically.
            rng_state = torch.get_rng_state()
            f_img = self.transform(f_img)
            torch.set_rng_state(rng_state)
            wf_img = self.transform(wf_img)
        return f_img, wf_img

In [None]:
# Build the paired training set: flare images and their flare-free counterparts,
# each listed in sorted file-name order.
flare_dir = '../input/dehazer/Flare/Flare_img'
wf_dir = '../input/dehazer/Flare/Without_Flare_'
flare_img = sorted(os.listdir(flare_dir))
wf_img = sorted(os.listdir(wf_dir))
print(wf_img[0])  # quick look at the ground-truth naming scheme
train_ds = Flare(flare_dir, wf_dir, flare_img, wf_img, train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
print(len(train_ds))

In [None]:
# Sanity check for the file-name pairing rule used by the Flare dataset: drop the
# first 4 characters of one flare image's stem and print every ground-truth stem
# that matches it. Both listings are sorted so the probed index is reproducible
# (os.listdir order is arbitrary). The ground-truth list is not truncated, so a
# real match cannot be hidden.
probe_index = 34
fn = sorted(os.listdir(flare_dir))
f = sorted(os.listdir(wf_dir))
probe_stem = fn[probe_index].split('.')[0][4:]
for i in f:
    if i.split('.')[0] == probe_stem:
        print(i.split('.')[0])

In [None]:
# Pull one batch straight from the loader to check the tensor shape it yields.
i, l = next(iter(train_loader))
print(i.size())

In [None]:
import matplotlib.pyplot as plt
import numpy
# Preview one training batch of flare images as a single grid.
# DataLoader iterators no longer provide `.next()` in current PyTorch; use the builtin next().
# `labels` holds the paired flare-free images (the dataset yields image pairs), not class labels.
samples, labels = next(iter(train_loader))
plt.figure(figsize=(16,24))
grid_imgs = torchvision.utils.make_grid(samples.cpu(),normalize = True)
np_grid_imgs = grid_imgs.numpy()
# make_grid returns a (channels, height, width) tensor; imshow expects
# (height, width, channels), hence the (1, 2, 0) transpose.
plt.imshow(numpy.transpose(np_grid_imgs, (1,2,0)))
plt.show()



In [None]:
def model_eval(dataloader,model):
    """Return the top-1 classification accuracy of ``model`` on ``dataloader``, in percent.

    Each batch is unpacked as ``(inputs, labels)``. Both are moved to the
    global ``device``. Predictions are the index of the maximum model output
    along dim 1, compared element-wise with ``labels``. Gradients are not
    tracked. The model's train/eval mode is not changed.

    NOTE(review): this is a classification metric; nothing in this notebook
    calls it, and the flare model outputs images rather than class scores.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches.
        model: callable returning scores whose dim 1 indexes classes.

    Returns:
        ``100 * correct / total`` over all batches, as a float.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    total = 0
    correct = 0
    # Evaluation only: skip building autograd graphs.
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            out = model(images)
            _, preds = torch.max(out, dim=1)

            total += labels.shape[0]
            correct += (preds == labels).sum().item()
    if total == 0:
        raise ValueError("model_eval: dataloader yielded no samples")
    return (100 * correct) / total

In [None]:
# Scratch check: `==` on two tensors gives an element-wise bool tensor, and `.sum()`
# counts its True entries as a 0-d tensor, so `correct` ends up a tensor, not an int.
a, b = torch.randn(2, 3), torch.randn(2, 3)
correct = 0 + a.eq(b).sum()
print(correct)

In [None]:
class ResidualBlock(nn.Module):
    """Four chained residual stages of same-padded 3x3 convolutions.

    Stage k is ``get_res_block(n)`` with n = 2, 2, 3, 4: a stack of n + 1
    convs with a ReLU between consecutive convs. Each stage's output is added
    to its input. Stages 1-3 apply ReLU after the addition; stage 4 returns the
    raw sum, so any final activation is left to the caller. Channel count and
    spatial size are preserved.

    Args:
        channels: number of input and output channels. Defaults to 128, the
            width that was previously hard-coded.
    """

    def __init__(self, channels=128):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        self.res_b1 = self.get_res_block(2, channels, channels)
        self.res_b2 = self.get_res_block(2, channels, channels)
        self.res_b3 = self.get_res_block(3, channels, channels)
        self.res_b4 = self.get_res_block(4, channels, channels)
        self.relu = nn.ReLU(inplace = True)

    def get_res_block(self, block_size = 1, in_dim = 128, out_dim = 128):
        """Return ``block_size + 1`` 3x3 convs (padding 1) with ReLU between them.

        The first conv maps ``in_dim`` -> ``out_dim`` and the rest ``out_dim`` ->
        ``out_dim``, so stacks with ``in_dim != out_dim`` are well-formed. There
        is no ReLU after the last conv.
        """
        layers = []
        for i in range(block_size + 1):
            conv_in = in_dim if i == 0 else out_dim
            layers.append(nn.Conv2d(conv_in, out_dim, 3, padding = 1))
            if i != block_size:
                layers.append(nn.ReLU(inplace = True))
        return nn.Sequential(*layers)

    def forward(self, image):
        x = image
        # Stages 1-3: residual add followed by ReLU.
        for stage in (self.res_b1, self.res_b2, self.res_b3):
            x = self.relu(x + stage(x))
        # Stage 4: residual add only (no activation).
        return x + self.res_b4(x)

In [None]:
class GMAN(nn.Module):
    """Encoder / residual / decoder CNN whose output is added back to its input.

    Layer sequence:
    - Encode: two full-resolution 3x3 convs to ``hidden_dim`` channels, then two
      stride-2 3x3 convs to ``hidden_dim * 2`` channels at 1/4 of the spatial size.
    - Refine: a ``ResidualBlock``.
    - Decode: two stride-2 2x2 transposed convs back to full size.
    - Project: two 3x3 convs back to ``in_dim`` channels.

    ``forward`` returns ``image + correction``, passed through ReLU by default.
    Input height and width must be multiples of 4 so the decoded map matches
    the input for the addition.

    Args:
        in_dim: number of input/output channels.
        hidden_dim: base feature width. ``ResidualBlock()`` is constructed with
            its default width of 128 channels, so ``hidden_dim * 2`` must be 128.
        final_relu: apply ReLU to the output (default, the original behaviour).
            Pass ``False`` when targets can be negative, e.g. images normalized
            to [-1, 1] with ``Normalize(0.5, 0.5)``; a ReLU output never goes
            below 0.

    Raises:
        ValueError: if ``hidden_dim * 2 != 128``. Such a network would build but
            fail with a channel mismatch on the first forward pass.
    """

    def __init__(self, in_dim = 3, hidden_dim = 64, final_relu = True):
        super(GMAN, self).__init__()
        if hidden_dim * 2 != 128:
            raise ValueError(
                f"GMAN: hidden_dim * 2 must equal ResidualBlock()'s 128 channels, got hidden_dim={hidden_dim}"
            )
        self.final_relu = final_relu
        self.relu = nn.ReLU(inplace = True)
        # Layer order and indices are unchanged, so existing state_dicts still load.
        self.gman = nn.Sequential(
            nn.Conv2d(in_dim, hidden_dim, 3, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(hidden_dim, hidden_dim * 2, 3, padding = 1, stride = 2),
            nn.ReLU(inplace = True),
            nn.Conv2d(hidden_dim * 2, hidden_dim * 2, 3, padding = 1, stride = 2),
            nn.ReLU(inplace = True),
            ResidualBlock(),
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, 2, stride = 2),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, 2, stride = 2),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.Conv2d(hidden_dim, in_dim, 3, padding = 1),
        )

    def forward(self, image):
        out = image + self.gman(image)
        return self.relu(out) if self.final_relu else out
    


In [None]:
def init_weights(m):
    """Re-initialise a conv or transposed-conv layer in place; intended for ``Module.apply``.

    Weights are drawn from a normal distribution with mean 0 and std 0.008.
    Biases are set to 0.01. Layers built with ``bias=False`` only get their
    weights initialised. All other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.008)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)



In [None]:
# Instantiate the model, move it to the chosen device, then re-initialise its conv layers.
net = GMAN()
net = net.to(device)
net.apply(init_weights)

In [None]:
def show_image(hazy_image, gt_image, predicted_image):
    """Plot the input, ground truth and prediction side by side in one row.

    Each argument is expected to be a single 3-D image tensor laid out
    (C, H, W); it is permuted to (H, W, C) for ``imshow``. Only the prediction
    is detached, so the first two must not require grad.
    """
    captions = ['Flare Image', 'Ground Truth Image', 'Predicted']

    plt.figure(figsize=(15, 15))

    arrays = [
        tensor.cpu().permute(1, 2, 0).numpy()
        for tensor in (hazy_image, gt_image, predicted_image.detach())
    ]

    for pos, (caption, arr) in enumerate(zip(captions, arrays), start=1):
        plt.subplot(1, 3, pos)
        plt.title(caption)
        plt.imshow(arr)
        plt.axis('off')

    plt.show()

In [None]:
from tqdm import tqdm


def _show_grid(images):
    """Plot up to the first 32 images of a batch as one min-max-normalised grid."""
    grid = torchvision.utils.make_grid(images[:32].cpu(), normalize=True).numpy()
    plt.figure(figsize=(16, 16))
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.show()


def train(net, criterion, optimizer, num_epochs, scaler, sch, loader=None):
    """Train ``net`` with mixed precision and plot a preview after every epoch.

    Per batch: forward pass and loss under autocast, scaled backward pass, and
    an optimizer step through ``scaler``. After each epoch the LR scheduler is
    stepped and three plots are shown: a fresh input batch, the network's
    output for that batch, and the epoch's per-batch loss curve.

    Args:
        net: model to train; moved to the global ``device`` once, up front.
        criterion: loss, called as ``criterion(output, target)``.
        optimizer: optimizer over ``net``'s parameters.
        num_epochs: number of passes over the loader.
        scaler: AMP gradient scaler (e.g. ``torch.cuda.amp.GradScaler``).
        sch: LR scheduler, stepped once per epoch.
        loader: iterable of ``(input, target)`` batches; defaults to the global
            ``train_loader``.
    """
    if loader is None:
        loader = train_loader
    net.to(device)
    # Autocast is enabled only on CUDA. Inputs are not cast to float16 by hand:
    # autocast chooses per-op precision, and a manual cast breaks the float32
    # model whenever autocast is inactive (e.g. on CPU).
    use_amp = device == "cuda"
    for epoch in range(num_epochs):
        print("Epoch:", epoch + 1)
        print("LR: ", sch.get_last_lr())
        loss_h = []
        loop = tqdm(loader, leave=True)
        for flare, gt in loop:
            flare = flare.to(device)
            gt = gt.to(device)

            optimizer.zero_grad()
            # The loss is computed inside autocast too (standard AMP recipe),
            # so MSE runs in float32 rather than on float16 tensors.
            with torch.cuda.amp.autocast(enabled=use_amp):
                output = net(flare)
                loss = criterion(output, gt)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_value = loss.item()  # single host sync per step
            loop.set_postfix(loss=loss_value)
            loss_h.append(loss_value)

        preview_in, _ = next(iter(loader))
        sch.step()
        with torch.no_grad():
            print("LR:", sch.get_last_lr())
            _show_grid(preview_in)
            # Network output for the same batch (the input grid was shown twice before).
            _show_grid(net(preview_in.to(device)))

            plt.plot(loss_h)
            plt.title(f"Training loss, epoch {epoch + 1}")
            plt.show()

In [None]:
# Visual check of one batch: the flare inputs, a detached copy of them, and the
# flare-free targets (up to 32 images per grid).
# NOTE(review): the middle grid is titled "Fake Images", but `i` is only the
# detached input batch, not a model prediction — confirm what was intended.
f, w = next(iter(train_loader))
f = f.to(device).float()
w = w.to(device).float()
i = f.detach()

for batch, caption in ((f, None), (i, "Fake Images"), (w, None)):
    plt.figure(figsize=(16,16))
    if caption is not None:
        plt.title(caption)
    img_grid = torchvision.utils.make_grid(batch[:32].cpu(), normalize=True)
    plt.imshow(np.transpose(img_grid, (1, 2, 0)))
    plt.show()

In [None]:
# Training configuration.
num_epochs = 30
LEARNING_RATE = 1e-5
# StepLR halves the LR after every epoch (step_size=1, gamma=0.5), so during the
# last of 30 epochs it is 1e-5 * 0.5**29 ≈ 1.9e-14.
# NOTE(review): the LR is negligible after roughly 10 epochs — confirm this
# schedule is intended.
LR_STEP_SIZE = 1
LR_GAMMA = 0.5

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
# Loss scaling only matters for CUDA AMP; disable it explicitly elsewhere.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
sch = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

train(net, criterion, optimizer, num_epochs, scaler, sch)


In [None]:
# Save only the learned parameters (state_dict), not the pickled module object.
# NOTE(review): the name says "UNET" and "40epochs", but this notebook builds a
# GMAN and sets num_epochs = 30 — confirm the intended file name.
net_save_name = 'flare_UNET_40epochs.pt'
path = ".//" + net_save_name
torch.save(net.state_dict(), path)