In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.utils import make_grid
from torchvision.utils import save_image
import torchvision.transforms as transforms
import torch
import numpy as np
from torch.utils.data import DataLoader
import functools
import itertools
from utils import ReplayBuffer

In [None]:
def get_data_loader(image_size=256, batch_size=1, num_workers=0,
                    low_dir='/home/ahonts/Desktop/US/Low_Images',
                    high_dir='/home/ahonts/Desktop/US/High_Images'):
    """
    Build the training and test DataLoaders for both image domains.

    Domain A is read from ``low_dir`` and domain B from ``high_dir`` using
    ``torchvision.datasets.ImageFolder``. Every image is converted to
    grayscale, turned into a tensor and normalized with mean 0.5 / std 0.5.
    The training pipeline additionally applies a random horizontal flip.

    The train and test datasets of a domain are built from the SAME
    directory, so the "test" loaders are not a held-out split.

    Args:
        image_size: Currently unused -- no resize or crop is applied here.
            Kept in the signature so existing callers keep working.
        batch_size: Batch size of the two training loaders. The test loaders
            always use a batch size of 1.
        num_workers: Number of worker processes passed to every DataLoader.
        low_dir: ImageFolder root directory for domain A.
        high_dir: ImageFolder root directory for domain B.

    Returns:
        Tuple ``(train_A_loader, test_A_loader, train_B_loader, test_B_loader)``.
    """
    # Shared preprocessing: single channel, tensor, values mapped to [-1, 1].
    to_normalized_tensor = [transforms.ToTensor(),
                            transforms.Normalize([0.5,], [0.5,])]
    # Training images are additionally flipped at random; test images are not.
    transform = transforms.Compose([transforms.Grayscale(),
                                    transforms.RandomHorizontalFlip()]
                                   + to_normalized_tensor)
    test_transform = transforms.Compose([transforms.Grayscale()]
                                        + to_normalized_tensor)

    # get training and test directories (train and test share a directory)
    train_A_path = low_dir
    test_A_path = low_dir
    train_B_path = high_dir
    test_B_path = high_dir

    # define datasets using ImageFolder
    train_A_dataset = datasets.ImageFolder(train_A_path, transform)
    test_A_dataset = datasets.ImageFolder(test_A_path, test_transform)
    train_B_dataset = datasets.ImageFolder(train_B_path, transform)
    test_B_dataset = datasets.ImageFolder(test_B_path, test_transform)

    # create DataLoaders: shuffled for training, fixed order for testing
    train_A_loader = DataLoader(dataset=train_A_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_A_loader = DataLoader(dataset=test_A_dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    train_B_loader = DataLoader(dataset=train_B_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_B_loader = DataLoader(dataset=test_B_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    return train_A_loader,test_A_loader,train_B_loader,test_B_loader

In [None]:
# Build the four DataLoaders with get_data_loader's defaults
# (batch_size=1, num_workers=0).
train_A_loader,test_A_loader,train_B_loader,test_B_loader = get_data_loader()

In [None]:
class UnNormalize(object):
    """Undo a per-channel normalization by computing ``t * std + mean``."""

    def __init__(self, mean, std):
        # One mean / std entry per channel, consumed in channel order.
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """
        Reverse the normalization of ``tensor`` IN PLACE.

        Args:
            tensor (Tensor): Image tensor whose first axis is iterated
                channel by channel (e.g. size (C, H, W)).
        Returns:
            Tensor: The same tensor object, now un-normalized.
        """
        per_channel_stats = zip(self.mean, self.std)
        for channel, (mu, sigma) in zip(tensor, per_channel_stats):
            # Inverse of Normalize's t.sub_(m).div_(s)
            channel.mul_(sigma)
            channel.add_(mu)
        return tensor

In [None]:
# Custom weight initialization, meant to be passed to ``Module.apply``.
def weights_init(m):
    """
    Re-initialise one module in place, selected by its class name.

    * class name contains ``'Conv'``: weights drawn from N(0.0, 0.02)
    * class name contains ``'BatchNorm'``: weights drawn from N(1.0, 0.02),
      bias filled with 0
    * any other module is left untouched
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [None]:
class Generator2(nn.Module):
    """
    Two-branch encoder/decoder generator for single-channel images.

    The input is encoded twice in parallel with stride-2 convolutions: once
    with 8x8 kernels (d1..d5) and once with 4x4 kernels (d1_1..d4_1). The
    decoder (u1..u5) upsamples with stride-2 transposed convolutions. After
    each of the first four decoder stages the matching feature maps of BOTH
    encoder branches are concatenated along the channel axis. A final 3x3
    convolution followed by tanh produces the single-channel output.

    The InstanceNorm2d modules are created with default arguments and are
    reused by every stage that has the same channel count.
    """

    def __init__(self):
        super(Generator2, self).__init__()

        def down(c_in, c_out, kernel, pad):
            # Stride-2 downsampling convolution without bias.
            return nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=2,
                             padding=pad, bias=False)

        def up(c_in, c_out):
            # Stride-2 upsampling transposed convolution without bias.
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=8, stride=2,
                                      padding=3, bias=False)

        # Encoder branch with 8x8 kernels.
        self.d1 = down(1, 32, 8, 3)
        self.d2 = down(32, 64, 8, 3)
        self.d3 = down(64, 128, 8, 3)
        self.d4 = down(128, 256, 8, 3)
        self.d5 = down(256, 512, 8, 3)
        # Encoder branch with 4x4 kernels (one stage shorter).
        self.d1_1 = down(1, 32, 4, 1)
        self.d2_1 = down(32, 64, 4, 1)
        self.d3_1 = down(64, 128, 4, 1)
        self.d4_1 = down(128, 256, 4, 1)
        # Decoder. From u2 on, the input width is the previous decoder output
        # plus the two skip tensors: 256*3, 128*3, 64*3, 32*3.
        self.u1 = up(512, 256)
        self.u2 = up(768, 128)
        self.u3 = up(384, 64)
        self.u4 = up(192, 32)
        self.u5 = up(96, 64)
        self.out = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # Activations (both operate in place).
        self.LeakyRelu = nn.LeakyReLU(0.2, inplace=True)
        self.Relu = nn.ReLU(True)
        # One normalization layer per channel width.
        self.d1Norm = nn.InstanceNorm2d(32)
        self.d2Norm = nn.InstanceNorm2d(64)
        self.d3Norm = nn.InstanceNorm2d(128)
        self.d4Norm = nn.InstanceNorm2d(256)
        self.d5Norm = nn.InstanceNorm2d(512)

    def forward(self, input):
        """Map a single-channel image batch to a tanh-bounded single-channel batch."""

        def encode(conv, norm, features):
            # conv -> LeakyReLU -> InstanceNorm
            return norm(self.LeakyRelu(conv(features)))

        def decode(deconv, norm, features, *skips):
            # transposed conv -> ReLU -> InstanceNorm, then append the skips
            upsampled = norm(self.Relu(deconv(features)))
            return torch.cat([upsampled, *skips], 1)

        # Encoder: the 8x8 and 4x4 branches run side by side.
        wide1 = encode(self.d1, self.d1Norm, input)
        narrow1 = encode(self.d1_1, self.d1Norm, input)
        wide2 = encode(self.d2, self.d2Norm, wide1)
        narrow2 = encode(self.d2_1, self.d2Norm, narrow1)
        wide3 = encode(self.d3, self.d3Norm, wide2)
        narrow3 = encode(self.d3_1, self.d3Norm, narrow2)
        wide4 = encode(self.d4, self.d4Norm, wide3)
        narrow4 = encode(self.d4_1, self.d4Norm, narrow3)
        bottleneck = encode(self.d5, self.d5Norm, wide4)

        # Decoder with skip connections from both encoder branches.
        merged = decode(self.u1, self.d4Norm, bottleneck, wide4, narrow4)
        merged = decode(self.u2, self.d3Norm, merged, wide3, narrow3)
        merged = decode(self.u3, self.d2Norm, merged, wide2, narrow2)
        merged = decode(self.u4, self.d1Norm, merged, wide1, narrow1)

        # Last upsampling stage has no activation or norm before the output conv.
        return torch.tanh(self.out(self.u5(merged)))

In [None]:
# Rebuild the loaders and pull one sample batch from each for inspection.
train_A_loader,test_A_loader,train_B_loader,test_B_loader = get_data_loader()
train_A_loader_iterator = iter(train_A_loader)
train_B_loader_iterator = iter(train_B_loader)
test_A_loader_iterator = iter(test_A_loader)
# Use the built-in next(): the iterator method `.next()` was a Python 2
# compatibility alias that recent PyTorch DataLoader iterators no longer have.
# ImageFolder yields (image, class_index) pairs; the class index is ignored.
real_A, _ = next(train_A_loader_iterator)
real_B, _ = next(train_B_loader_iterator)
test_A, _ = next(test_A_loader_iterator)

**TODO:** Test how to apply the filter to the whole image. A test image contains the whole image, while the training images only cover the filter (crop) size.

In [None]:
import PIL
import matplotlib.pyplot as plt
import torchvision

to_pil = torchvision.transforms.ToPILImage()
# Preview the first image of the test batch.
# * test_A is a batch (N, C, H, W); ToPILImage only accepts 2-D / 3-D input,
#   so the batch axis is dropped by indexing.
# * Normalize(0.5, 0.5) is undone on a copy instead of in place, so re-running
#   this cell does not rescale test_A a second time.
test_A_preview = test_A[0] * 0.5 + 0.5
img = to_pil(test_A_preview)
plt.imshow(img,cmap='gray')

In [None]:
def filter_image(img,G,crop_h = 128,crop_w = 256,stride_divisor = 19):
    """
    Run ``G`` over a whole image with overlapping tiles and average the results.

    The image is covered with ``crop_h`` x ``crop_w`` windows. Window origins
    advance by ``crop // stride_divisor`` pixels (at least 1) along each axis,
    and one extra window is always placed flush with the bottom and right
    borders so that every pixel is covered. Each pixel of the result is the
    mean of all tile outputs that overlap it.

    Args:
        img: Tensor indexed as (batch, channel, height, width). Only batch
            element 0 / channel 0 of each generator output is used.
        G: Callable mapping a (.., crop_h, crop_w) tile to a tensor of the
            same spatial size.
        crop_h: Tile height in pixels.
        crop_w: Tile width in pixels.
        stride_divisor: The tile size is divided by this value to obtain the
            step between neighbouring tiles (default 19, as before).

    Returns:
        numpy.ndarray of shape (height, width) with the averaged output.

    Raises:
        ValueError: If the image is smaller than one tile or
            ``stride_divisor`` is not positive.
    """
    height,width = img.shape[2],img.shape[3]
    if height < crop_h or width < crop_w:
        raise ValueError(
            "image of size %dx%d is smaller than the %dx%d tile"
            % (height, width, crop_h, crop_w))
    if stride_divisor <= 0:
        raise ValueError("stride_divisor must be positive")

    # A step of 0 (tile smaller than stride_divisor) would never advance.
    row_step = max(1, int(crop_h // stride_divisor))
    col_step = max(1, int(crop_w // stride_divisor))
    # Regular grid of origins plus one window flush with the far border.
    row_starts = list(range(0, height - crop_h, row_step)) + [height - crop_h]
    col_starts = list(range(0, width - crop_w, col_step)) + [width - crop_w]

    comp_img = np.zeros((height,width))   # running sum of tile outputs
    avg_cnt = np.zeros((height,width))    # number of tiles covering each pixel
    # Inference only: skip building autograd graphs for every tile.
    with torch.no_grad():
        for top in row_starts:
            for left in col_starts:
                tile_out = G(img[:,:,top:top+crop_h,left:left+crop_w])
                comp_img[top:top+crop_h,left:left+crop_w] += tile_out.detach().cpu().numpy()[0,0,:,:]
                avg_cnt[top:top+crop_h,left:left+crop_w] += 1
    return(comp_img/avg_cnt)

    

In [None]:
class Discriminator(nn.Module):
    """
    Fully convolutional discriminator for single-channel images.

    Five stride-2 8x8 convolutions are followed by a 3x3 convolution and a
    sigmoid, so the output is a map of per-location scores in [0, 1] with a
    single channel.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Each InstanceNorm2d is sized to the channel count of the convolution
        # directly before it. The norms carry no parameters (default
        # arguments), so layer indices and state_dict keys are unaffected.
        self.main = nn.Sequential(
            # 1 -> 64 channels
            nn.Conv2d(1, 64, 8, 2, 3, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 -> 128 channels
            nn.Conv2d(64, 64*2, 8, 2, 3, bias=False),
            nn.InstanceNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 -> 256 channels
            nn.Conv2d(64*2, 64*4, 8, 2, 3, bias=False),
            nn.InstanceNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 -> 256 channels
            nn.Conv2d(64*4, 64 * 4, 8, 2, 3, bias=False),
            nn.InstanceNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 -> 512 channels
            nn.Conv2d(64*4, 64 * 8, 8, 2, 3, bias=False),
            nn.InstanceNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 -> 1 channel score map, same spatial size
            nn.Conv2d(64 * 8, 1, 3, 1, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return the sigmoid score map for ``input``."""
        output = self.main(input)
        return output

In [None]:
class Discriminator2(nn.Module):
    """
    Discriminator producing ONE score per image.

    The convolutional trunk (five stride-2 8x8 convolutions) is followed by
    two fully connected layers and a sigmoid, giving an output of shape
    (batch, 1) with values in [0, 1].

    ``fc1`` expects exactly 16384 flattened features, i.e. 512 channels times
    32 spatial positions after the trunk. A 128 x 256 input satisfies this
    (it is reduced to 4 x 8); other input sizes fail in ``fc1``.
    """

    def __init__(self):
        super(Discriminator2, self).__init__()
        # Each InstanceNorm2d is sized to the channel count of the convolution
        # directly before it. The norms carry no parameters (default
        # arguments), so layer indices and state_dict keys are unaffected.
        self.main = nn.Sequential(
            # 1 -> 64 channels
            nn.Conv2d(1, 64, 8, 2, 3, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 -> 128 channels
            nn.Conv2d(64, 64*2, 8, 2, 3, bias=False),
            nn.InstanceNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 -> 256 channels
            nn.Conv2d(64*2, 64*4, 8, 2, 3, bias=False),
            nn.InstanceNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 -> 256 channels
            nn.Conv2d(64*4, 64 * 4, 8, 2, 3, bias=False),
            nn.InstanceNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 -> 512 channels
            nn.Conv2d(64*4, 64 * 8, 8, 2, 3, bias=False),
            nn.InstanceNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),

            #nn.Conv2d(64 * 8, 1, 3, 1, 1, bias=False),
            #nn.Sigmoid()
        )
        # Classification head on the flattened trunk output.
        self.fc1 = nn.Linear(16384,1000)
        self.fc2 = nn.Linear(1000,1)
        self.Lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return a (batch, 1) tensor of sigmoid scores for ``input``."""
        output = self.main(input)
        # Flatten everything except the batch axis for the linear layers.
        output = output.view(output.size(0), -1)
        output = self.Lrelu(self.fc1(output))
        output = self.sigmoid(self.fc2(output))
        return output

In [None]:
# Select the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create the two generator networks (G_AB is applied to domain-A images,
# G_BA to domain-B images in the training cells)
G_AB = Generator2()
G_BA = Generator2()
# Move the generators onto the selected device
G_AB.to(device)
G_BA.to(device)
# Re-initialise the generators with weights_init. That function only touches
# modules whose class name contains 'Conv' or 'BatchNorm'.
G_AB.apply(weights_init)
G_BA.apply(weights_init)
# Create the two discriminator networks (one per image domain)
D_A = Discriminator2()
D_B = Discriminator2()
# Move the discriminators onto the selected device
D_A.to(device)
D_B.to(device)
# Re-initialise the discriminators; their Linear layers are not matched by
# weights_init and therefore keep PyTorch's default initialisation.
D_A.apply(weights_init)
D_B.apply(weights_init)

In [None]:
import torch.optim as optim
# Create the optimizers.
# A single Adam instance updates the parameters of BOTH generators together.
optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                                lr=0.0002, betas=(0.5, 0.999))
# Alternative kept for reference: one optimizer per generator.
#Optimizer_G_AB = optim.Adam(G_AB.parameters(), lr=0.0002, betas=(0.5, 0.999))
#Optimizer_G_BA = optim.Adam(G_BA.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Each discriminator gets its own Adam optimizer with the same settings.
optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))
# NOTE(review): the training cells in this notebook create their own
# criterion_GAN / criterion_cycle / criterion_identity objects and do not
# reference the three names below -- confirm whether they are still needed.
adversarial_loss = nn.BCELoss()
cycle_consis_loss = nn.L1Loss()
identity_loss = nn.MSELoss()

In [None]:
# Rebuild the DataLoaders and the device handle. This repeats earlier cells;
# presumably it lets the training cell be started from here -- TODO confirm.
train_A_loader,test_A_loader,train_B_loader,test_B_loader = get_data_loader()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import os
import random


def _random_train_crop(batch, max_top, crop_h=128, crop_w=256):
    """Cut one random crop_h x crop_w window out of a (N, C, H, W) batch.

    The top edge is drawn from [0, min(max_top, H - crop_h)] and the left
    edge from [0, W - crop_w], so the window always lies inside the image.

    Raises:
        ValueError: If the batch is smaller than the crop.
    """
    height, width = batch.shape[2], batch.shape[3]
    if height < crop_h or width < crop_w:
        raise ValueError(
            "batch of size %dx%d is smaller than the %dx%d training crop"
            % (height, width, crop_h, crop_w))
    # random.randint is inclusive on both ends, so the largest valid origin
    # is (size - crop), not (size - crop - 1).
    top = random.randint(0, min(max_top, height - crop_h))
    left = random.randint(0, width - crop_w)
    return batch[:, :, top:top + crop_h, left:left + crop_w]


def _next_random_crop_pair(iterator_A, iterator_B):
    """Fetch the next A and B batches, crop them at random, move them to `device`.

    Domain A is cropped anywhere in the image; domain B only with a top edge
    inside the first sixth of the image height (as in the original code).
    StopIteration from an exhausted iterator propagates to the caller.
    """
    batch_A, _ = next(iterator_A)
    batch_B, _ = next(iterator_B)
    crop_A = _random_train_crop(batch_A, batch_A.shape[2] - 128)
    crop_B = _random_train_crop(batch_B, batch_B.shape[2] // 6)
    return crop_A.to(device), crop_B.to(device)


fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
criterion_GAN = torch.nn.BCELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# One fixed test image per domain, used for the per-epoch preview.
test_A_loader_iterator = iter(test_A_loader)
test_B_loader_iterator = iter(test_B_loader)
test_real_A, _ = next(test_A_loader_iterator)
test_real_B, _ = next(test_B_loader_iterator)
test_real_A = test_real_A.to(device)
test_real_B = test_real_B.to(device)

# BCELoss needs FLOAT targets with the same shape as the prediction. The
# (1, 1) base tensors are expanded to each prediction's shape when used.
real_label = 1
fake_label = 0
target_real = torch.full((1, 1), float(real_label), device=device)
target_fake = torch.full((1, 1), float(fake_label), device=device)

# Make sure the output folders exist before an epoch of work is spent.
for out_dir in ('./paths3', './fake_imgs3', './real_imgs3'):
    os.makedirs(out_dir, exist_ok=True)

# NOTE(review): epochs are numbered from 11, but nothing in this cell loads
# an earlier checkpoint -- confirm that the models in memory are the ones
# that should be continued.
for epoch in range(11,50):
    train_A_loader_iterator = iter(train_A_loader)
    train_B_loader_iterator = iter(train_B_loader)
    brake = False
    print(epoch)
    for cnt in range(1000):
        try:
            real_A, real_B = _next_random_crop_pair(train_A_loader_iterator,
                                                    train_B_loader_iterator)
        except StopIteration:
            # Training data exhausted: the epoch is over.
            brake = True
            break
        if cnt%100 == 0:
            print(cnt)
        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss
        # G_A2B(B) should equal B if real B is fed
        same_B = G_AB(real_B)
        loss_identity_B = criterion_identity(same_B, real_B)*10.0
        # G_B2A(A) should equal A if real A is fed
        same_A = G_BA(real_A)
        loss_identity_A = criterion_identity(same_A, real_A)*10.0

        # GAN loss: the generators want their fakes to be scored as real
        fake_B = G_AB(real_A)
        pred_fake = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, target_real.expand_as(pred_fake))

        fake_A = G_BA(real_B)
        pred_fake = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, target_real.expand_as(pred_fake))

        # Cycle loss
        recovered_A = G_BA(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

        recovered_B = G_AB(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

        #MINE Loss
        retAB = torch.mean(real_B) - torch.log(torch.mean(torch.exp(fake_B)))
        retBA = torch.mean(real_A) - torch.log(torch.mean(torch.exp(fake_A)))

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB - retAB - retBA
        loss_G.backward()

        optimizer_G.step()
        ###################################
        # Two discriminator updates per generator update, each on a freshly
        # drawn pair of real crops.
        for i in range(2):
            try:
                real_A, real_B = _next_random_crop_pair(train_A_loader_iterator,
                                                        train_B_loader_iterator)
            except StopIteration:
                brake = True
                break
            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = D_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real.expand_as(pred_real))

            # Fake loss
            # NOTE(review): fake_A is overwritten with the buffer's return
            # value, so the second pass pushes that value back into the
            # buffer -- confirm this is intended.
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = D_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake.expand_as(pred_fake))

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake)*0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = D_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real.expand_as(pred_real))

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = D_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake.expand_as(pred_fake))

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake)*0.5
            loss_D_B.backward()

            optimizer_D_B.step()
        ###################################
        if brake:
            # Data ran out during the discriminator updates.
            break

    # Checkpoint all four networks after every epoch.
    torch.save(G_AB.state_dict(), './paths3/netG_A2B_%d.pth' % epoch)
    torch.save(G_BA.state_dict(), './paths3/netG_B2A_%d.pth' % epoch)
    torch.save(D_A.state_dict(), './paths3/netD_A_%d.pth' % epoch)
    torch.save(D_B.state_dict(), './paths3/netD_B_%d.pth' % epoch)

    # Preview: translate the fixed test image tile by tile and map both the
    # result and the input from [-1, 1] back to [0, 1] for saving.
    fake_imgs_B = filter_image(test_real_A,G_AB) * 0.5 + 0.5
    real_preview = test_real_A * 0.5 + 0.5
    save_image(torch.tensor(fake_imgs_B),'./fake_imgs3/test%d.png' % epoch)
    save_image(real_preview,'./real_imgs3/test%d.png' % epoch)

In [None]:
import os

# Make sure the output folders exist.
for out_dir in ('./video2', './real_video2'):
    os.makedirs(out_dir, exist_ok=True)

# Translate up to 30 further test images. The iterator created by the
# training cell is continued, so its first image is not repeated here.
for i in range(30):
    try:
        test_real_A, _ = next(test_A_loader_iterator)
    except StopIteration:
        # Fewer than 30 images were left in the test loader.
        print('test loader exhausted after %d images' % i)
        break
    test_real_A = test_real_A.to(device)
    # Tile-wise translation, then map [-1, 1] back to [0, 1].
    fake_imgs_B = filter_image(test_real_A,G_AB) * 0.5 + 0.5
    real_frame = test_real_A * 0.5 + 0.5
    # Saved frame: 10 % original image blended with 90 % generator output.
    blended_frame = 0.1*real_frame.detach().cpu() + 0.9*torch.from_numpy(fake_imgs_B)
    save_image(blended_frame,'./video2/test%d.png' % i)
    save_image(real_frame,'./real_video2/test%d.png' % i)

In [None]:
# Inspect the shape of the most recent fake_B tensor left over from training.
fake_B.shape

In [None]:
import os

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
# Least-squares GAN objective: MSE between discriminator scores and targets.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# One fixed test image per domain, used for the per-epoch preview.
test_A_loader_iterator = iter(test_A_loader)
test_B_loader_iterator = iter(test_B_loader)
test_real_A, _ = next(test_A_loader_iterator)
test_real_B, _ = next(test_B_loader_iterator)
test_real_A = test_real_A.to(device)
test_real_B = test_real_B.to(device)

# Make sure the output folders exist before an epoch of work is spent.
for out_dir in ('./paths', './fake_imgs', './real_imgs'):
    os.makedirs(out_dir, exist_ok=True)

for epoch in range(12):
    train_A_loader_iterator = iter(train_A_loader)
    train_B_loader_iterator = iter(train_B_loader)
    brake = False
    print(epoch)
    for cnt in range(1000):
        # Smoothed labels: a fresh value near 1 (real) / near 0 (fake) per step.
        # Targets are built with the prediction's own shape, so they fit
        # whichever discriminator is in use.
        real_value = float(np.random.uniform(.95,1.05))
        fake_value = float(np.random.uniform(-0.05,.05))
        try:
            real_A, _ = next(train_A_loader_iterator)
            real_B, _ = next(train_B_loader_iterator)
        except StopIteration:
            # Training data exhausted: the epoch is over.
            brake = True
            break
        # Generator step trains on the fixed row band 64:192 of each image.
        real_A = real_A[:,:,64:192,:]
        real_B = real_B[:,:,64:192,:]
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        if cnt%100 == 0:
            print(cnt)
        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss
        # G_A2B(B) should equal B if real B is fed
        same_B = G_AB(real_B)
        loss_identity_B = criterion_identity(same_B, real_B)*10.0
        # G_B2A(A) should equal A if real A is fed
        same_A = G_BA(real_A)
        loss_identity_A = criterion_identity(same_A, real_A)*10.0

        # GAN loss: the generators want their fakes to be scored as real
        fake_B = G_AB(real_A)
        pred_fake = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.full_like(pred_fake, real_value))

        fake_A = G_BA(real_B)
        pred_fake = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.full_like(pred_fake, real_value))

        # Cycle loss
        recovered_A = G_BA(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

        recovered_B = G_AB(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

        #MINE Loss
        retAB = torch.mean(real_B) - torch.log(torch.mean(torch.exp(fake_B)))
        retBA = torch.mean(real_A) - torch.log(torch.mean(torch.exp(fake_A)))

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB - retAB - retBA
        loss_G.backward()

        optimizer_G.step()
        ###################################
        # Two discriminator updates per generator update, each on a freshly
        # drawn pair of real batches.
        for i in range(2):
            try:
                real_A, _ = next(train_A_loader_iterator)
                real_B, _ = next(train_B_loader_iterator)
            except StopIteration:
                brake = True
                break
            # NOTE(review): unlike the generator step, the real batches are
            # NOT cropped to rows 64:192 here (the crop is commented out) --
            # confirm the discriminators are meant to see full images.
            #real_A = real_A[:,:,64:192,:]
            #real_B = real_B[:,:,64:192,:]
            real_A = real_A.to(device)
            real_B = real_B.to(device)
            ###### Discriminator A ######
            real_value = float(np.random.uniform(.95,1.05))
            fake_value = float(np.random.uniform(-0.05,.05))
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = D_A(real_A)
            loss_D_real = criterion_GAN(pred_real, torch.full_like(pred_real, real_value))

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = D_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.full_like(pred_fake, fake_value))

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake)*0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = D_B(real_B)
            loss_D_real = criterion_GAN(pred_real, torch.full_like(pred_real, real_value))

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = D_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.full_like(pred_fake, fake_value))

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake)*0.5
            loss_D_B.backward()

            optimizer_D_B.step()
        ###################################
        if brake:
            # Data ran out during the discriminator updates.
            break

    # Checkpoint all four networks after every epoch.
    torch.save(G_AB.state_dict(), './paths/netG_A2B_%d.pth' % epoch)
    torch.save(G_BA.state_dict(), './paths/netG_B2A_%d.pth' % epoch)
    torch.save(D_A.state_dict(), './paths/netD_A_%d.pth' % epoch)
    torch.save(D_B.state_dict(), './paths/netD_B_%d.pth' % epoch)

    # Preview: translate the fixed test image tile by tile and map both the
    # result and the input from [-1, 1] back to [0, 1] for saving.
    fake_imgs_B = filter_image(test_real_A,G_AB) * 0.5 + 0.5
    real_preview = test_real_A * 0.5 + 0.5
    save_image(torch.tensor(fake_imgs_B),'./fake_imgs/test%d.png' % epoch)
    save_image(real_preview,'./real_imgs/test%d.png' % epoch)
    #save_image(fake_imgs_A.detach().cpu(),'C:/Users/ajhon/Desktop/Industrial Math Presentation/Low_freq_256_2/test%d.png' % epoch)