In [1]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.utils import make_grid
from torchvision.utils import save_image
import torchvision.transforms as transforms
import torch
import numpy as np
from torch.utils.data import DataLoader
import functools
import itertools
from utils import ReplayBuffer

In [2]:
def get_data_loader(image_size=256, batch_size=1, num_workers=0,
                    data_root='C:/Users/ajhon/Desktop/US Frames/TrainingImages',
                    test_batch_size=8):
    """Build the training and test DataLoaders for both image domains.

    Domain A is the low-frequency set (``<data_root>/lo`` for training,
    ``<data_root>/low_test`` for testing); domain B is the high-frequency set
    (``<data_root>/hi`` and ``<data_root>/hi_test``).  Each directory is read
    with ``datasets.ImageFolder``, so the images must sit inside class
    sub-folders of those directories.

    Every image is resized to ``image_size`` x ``image_size``, converted to a
    single grayscale channel, turned into a tensor and normalized with mean 0.5
    and std 0.5 (values in [-1, 1]).  Training images are additionally flipped
    horizontally at random.

    Args:
        image_size: side length every image is resized to.
        batch_size: batch size of the two (shuffled) training loaders.
        num_workers: worker processes used by every loader.
        data_root: directory that contains the ``lo``, ``low_test``, ``hi`` and
            ``hi_test`` folders.  Defaults to the original author's location.
        test_batch_size: batch size of the two (unshuffled) test loaders.

    Returns:
        tuple: ``(train_A_loader, test_A_loader, train_B_loader, test_B_loader)``.
    """
    # Training pipeline: resize -> grayscale -> random flip -> tensor -> normalize.
    transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                    transforms.Grayscale(),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.5, ], [0.5, ])])
    # Test pipeline: identical but deterministic (no random flip).
    test_transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                         transforms.Grayscale(),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.5, ], [0.5, ])])

    def _make_loader(subdir, tfm, loader_batch_size, shuffle):
        # One ImageFolder dataset per sub-directory of data_root.
        dataset = datasets.ImageFolder('%s/%s' % (data_root, subdir), tfm)
        return DataLoader(dataset=dataset, batch_size=loader_batch_size,
                          shuffle=shuffle, num_workers=num_workers)

    train_A_loader = _make_loader('lo', transform, batch_size, True)
    test_A_loader = _make_loader('low_test', test_transform, test_batch_size, False)
    train_B_loader = _make_loader('hi', transform, batch_size, True)
    test_B_loader = _make_loader('hi_test', test_transform, test_batch_size, False)

    return train_A_loader, test_A_loader, train_B_loader, test_B_loader

In [3]:
class UnNormalize(object):
    """Invert ``transforms.Normalize(mean, std)`` in place: ``x -> x * std + mean``.

    Accepts a single image ``(C, H, W)`` or a batch ``(N, C, H, W)``.  The
    statistics are applied per channel (the third-from-last axis), so every
    image of a batch is un-normalized.  Only the first
    ``min(C, len(mean), len(std))`` channels are changed, which lets a
    3-entry mean/std be used on 1-channel images.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized image ``(C, H, W)`` or batch
                ``(N, C, H, W)``; it is modified in place.
        Returns:
            Tensor: the same tensor object, un-normalized.
        Raises:
            ValueError: if ``tensor`` has fewer than 3 dimensions.
        """
        if tensor.dim() < 3:
            raise ValueError('expected a (C, H, W) or (N, C, H, W) tensor, got shape %s'
                             % (tuple(tensor.shape),))
        mean, std = list(self.mean), list(self.std)
        n = min(tensor.shape[-3], len(mean), len(std))
        if n == 0:
            return tensor
        # (n, 1, 1) broadcasts over H and W (and over N for batched input).
        mean_t = torch.as_tensor(mean[:n], dtype=tensor.dtype, device=tensor.device).view(n, 1, 1)
        std_t = torch.as_tensor(std[:n], dtype=tensor.dtype, device=tensor.device).view(n, 1, 1)
        # Slicing gives a view, so the in-place ops update `tensor` itself.
        # (The inverse of Normalize's t.sub_(m).div_(s).)
        tensor[..., :n, :, :].mul_(std_t).add_(mean_t)
        return tensor

In [4]:
# Custom weight initialisation, applied to the generators and discriminators via net.apply(...).
def weights_init(m):
    """Re-initialise one module in place (intended for ``net.apply(weights_init)``).

    * class name contains ``Conv`` (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02); any bias is left as is.
    * class name contains ``BatchNorm``: weight ~ N(1, 0.02), bias = 0.
    * anything else (e.g. InstanceNorm2d, activations) is untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [5]:
class Generator2(nn.Module):
    """Two-branch encoder with a U-Net-style decoder for 1-channel images.

    * Encoder branch ``d1..d5``: 8x8 convs, stride 2, padding 3.
    * Encoder branch ``d1_1..d4_1``: 4x4 convs, stride 2, padding 1.
      Both halve H and W (for even sizes); each stage is
      conv -> LeakyReLU(0.2) -> InstanceNorm.
    * Decoder ``u1..u4``: 8x8 transposed convs (stride 2) doubling H and W, each
      followed by ReLU -> InstanceNorm and a channel concatenation with the
      same-resolution features of BOTH encoder branches.
    * ``u5`` (no activation/norm) and the 3x3 ``out`` conv give a single
      channel, squashed to (-1, 1) by ``tanh``.

    The output has the input's spatial size when H and W are multiples of 32.
    """

    def __init__(self):
        super(Generator2, self).__init__()
        # Attribute names are the state_dict keys of saved checkpoints, and the
        # construction order fixes the order in which the default init draws
        # from the RNG -- keep both as they are.

        # 8x8 encoder branch: 1 -> 32 -> 64 -> 128 -> 256 -> 512 channels.
        self.d1 = nn.Conv2d(1, 32, 8, 2, 3, bias=False)
        self.d2 = nn.Conv2d(32, 64, 8, 2, 3, bias=False)
        self.d3 = nn.Conv2d(64, 128, 8, 2, 3, bias=False)
        self.d4 = nn.Conv2d(128, 256, 8, 2, 3, bias=False)
        self.d5 = nn.Conv2d(256, 512, 8, 2, 3, bias=False)
        # 4x4 encoder branch: 1 -> 32 -> 64 -> 128 -> 256 channels (four levels).
        self.d1_1 = nn.Conv2d(1, 32, 4, 2, 1, bias=False)
        self.d2_1 = nn.Conv2d(32, 64, 4, 2, 1, bias=False)
        self.d3_1 = nn.Conv2d(64, 128, 4, 2, 1, bias=False)
        self.d4_1 = nn.Conv2d(128, 256, 4, 2, 1, bias=False)
        # Decoder: in_channels = previous decoder output + two encoder skips,
        # e.g. u2 takes 256 (u1) + 256 (d4) + 256 (d4_1) = 768.
        self.u1 = nn.ConvTranspose2d(512, 256, 8, 2, 3, bias=False)
        self.u2 = nn.ConvTranspose2d(768, 128, 8, 2, 3, bias=False)
        self.u3 = nn.ConvTranspose2d(384, 64, 8, 2, 3, bias=False)
        self.u4 = nn.ConvTranspose2d(192, 32, 8, 2, 3, bias=False)
        self.u5 = nn.ConvTranspose2d(96, 64, 8, 2, 3, bias=False)
        self.out = nn.Conv2d(64, 1, 3, 1, 1, bias=False)
        # Shared activations (the LeakyReLU overwrites the conv output in place).
        self.LeakyRelu = nn.LeakyReLU(0.2, inplace=True)
        self.Relu = nn.ReLU(True)
        # InstanceNorm2d defaults to affine=False (no parameters), so one
        # module per channel width is reused by encoder and decoder alike.
        self.d1Norm = nn.InstanceNorm2d(32)
        self.d2Norm = nn.InstanceNorm2d(64)
        self.d3Norm = nn.InstanceNorm2d(128)
        self.d4Norm = nn.InstanceNorm2d(256)
        self.d5Norm = nn.InstanceNorm2d(512)

    def forward(self, input):
        def encode(conv, norm, x):
            # conv -> LeakyReLU -> InstanceNorm
            return norm(self.LeakyRelu(conv(x)))

        def decode(deconv, norm, x, *skips):
            # transposed conv -> ReLU -> InstanceNorm, then concat skips on channels
            return torch.cat([norm(self.Relu(deconv(x))), *skips], 1)

        # feats_8[k] / feats_4[k] hold the level-k features of each branch
        # (index 0 is the raw input).
        feats_8 = [input]
        feats_4 = [input]
        encoder_levels = zip((self.d1, self.d2, self.d3, self.d4),
                             (self.d1_1, self.d2_1, self.d3_1, self.d4_1),
                             (self.d1Norm, self.d2Norm, self.d3Norm, self.d4Norm))
        for conv_8, conv_4, norm in encoder_levels:
            feats_8.append(encode(conv_8, norm, feats_8[-1]))
            feats_4.append(encode(conv_4, norm, feats_4[-1]))
        x = encode(self.d5, self.d5Norm, feats_8[-1])

        decoder_levels = zip((self.u1, self.u2, self.u3, self.u4),
                             (self.d4Norm, self.d3Norm, self.d2Norm, self.d1Norm),
                             (4, 3, 2, 1))
        for deconv, norm, level in decoder_levels:
            x = decode(deconv, norm, x, feats_8[level], feats_4[level])

        return torch.tanh(self.out(self.u5(x)))

In [13]:
train_A_loader, test_A_loader, train_B_loader, test_B_loader = get_data_loader()
# Take one training batch from domain A for the shape checks below.  Use the
# builtin next(): recent PyTorch DataLoader iterators have no .next() method.
train_A_loader_iterator = iter(train_A_loader)
real_A, _ = next(train_A_loader_iterator)

In [18]:
# Shape check of the overlapping-window stitching: an untrained generator is
# run on the top, middle and bottom halves of real_A, the overlapping quarters
# are averaged, and the stitched image must have real_A's shape.
G = Generator2()
quarter = real_A.shape[2] // 4  # 64 rows for 256-row images
with torch.no_grad():  # inference-only check: no autograd graph needed
    top = G(real_A[:, :, :2 * quarter, :])
    mid = G(real_A[:, :, quarter:3 * quarter, :])
    bot = G(real_A[:, :, 2 * quarter:4 * quarter, :])

# half_img holds the two averaged overlap bands (top/mid, then mid/bot).
half_img = torch.zeros_like(top)
half_img[:, :, :quarter, :] = (top[:, :, quarter:, :] + mid[:, :, :quarter, :]) / 2
half_img[:, :, quarter:, :] = (bot[:, :, :quarter, :] + mid[:, :, quarter:, :]) / 2

img = torch.zeros_like(real_A)
img[:, :, :quarter, :] = top[:, :, :quarter, :]
img[:, :, quarter:2 * quarter, :] = half_img[:, :, :quarter, :]
img[:, :, 2 * quarter:3 * quarter, :] = half_img[:, :, quarter:, :]
img[:, :, 3 * quarter:, :] = bot[:, :, quarter:, :]

img.shape

torch.Size([1, 1, 256, 256])

In [15]:
# All-zero tensor with the same shape, dtype and device as real_A.
# NOTE(review): looks like a leftover scratch cell -- the assignment displays
# nothing, and the torch.Size output recorded under it is stale.
img = torch.zeros_like(real_A)


torch.Size([1, 1, 256, 256])

In [6]:
class Discriminator(nn.Module):
    """Patch discriminator for 1-channel images.

    Five 8x8 stride-2 convolutions (1 -> 32 -> 64 -> 128 -> 256 -> 512
    channels, each halving H and W for even sizes), LeakyReLU(0.2) after each
    and InstanceNorm after all but the first, then a 3x3 conv to one channel
    and a sigmoid.  The output is a map of per-patch probabilities of size
    (H/32, W/32), e.g. 4 x 8 for a 128 x 256 input.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        widths = [32, 64, 128, 256, 512]
        # First stage has no normalisation.
        layers = [nn.Conv2d(1, widths[0], 8, 2, 3, bias=False),
                  nn.LeakyReLU(0.2, inplace=True)]
        # Remaining downsampling stages: conv -> InstanceNorm -> LeakyReLU.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 8, 2, 3, bias=False),
                       nn.InstanceNorm2d(c_out),
                       nn.LeakyReLU(0.2, inplace=True)]
        # 1-channel patch head.
        layers += [nn.Conv2d(widths[-1], 1, 3, 1, 1, bias=False),
                   nn.Sigmoid()]
        # Sequential indices (main.0 ... main.15) are the checkpoint keys.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [14]:
# Probe the discriminator on the top 128 rows of one image, to see the size of
# its patch-probability map (reference for sizing the label tensors).
D = Discriminator()
pred = D(real_A[..., :128, :])
pred.shape

torch.Size([1, 1, 4, 8])

In [7]:
#Load the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#Create the Generator Network
G_AB = Generator2()
G_BA = Generator2()
#Load Generator Network onto the GPU
G_AB.to(device)
G_BA.to(device)
#Apply weights to the Generator
G_AB.apply(weights_init)
G_BA.apply(weights_init)
#Create the Discriminator Network
D_A = Discriminator()
D_B = Discriminator()
#Load the Disriminator Network onto the GPU
D_A.to(device)
D_B.to(device)
#Apply weights to the Discriminator
D_A.apply(weights_init)
D_B.apply(weights_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 32, kernel_size=(8, 8), stride=(2, 2), padding=(3, 3), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(32, 64, kernel_size=(8, 8), stride=(2, 2), padding=(3, 3), bias=False)
    (3): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(64, 128, kernel_size=(8, 8), stride=(2, 2), padding=(3, 3), bias=False)
    (6): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(128, 256, kernel_size=(8, 8), stride=(2, 2), padding=(3, 3), bias=False)
    (9): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(256, 512, kernel_size=(8, 8), stride=(2, 2), padding=(3, 3), bias=False)
    (12): InstanceNorm2d(5

In [8]:
import torch.optim as optim

# Adam hyper-parameters shared by all three optimizers.
LR = 0.0002
BETAS = (0.5, 0.999)

# A single optimizer updates both generators jointly; each discriminator has its own.
optimizer_G = optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                         lr=LR, betas=BETAS)
optimizer_D_A = optim.Adam(D_A.parameters(), lr=LR, betas=BETAS)
optimizer_D_B = optim.Adam(D_B.parameters(), lr=LR, betas=BETAS)

# NOTE(review): the training loop in this notebook defines and uses its own
# criterion_GAN / criterion_cycle / criterion_identity; these three objects
# are not referenced there -- kept in case other code relies on them.
adversarial_loss = nn.BCELoss()
cycle_consis_loss = nn.L1Loss()
identity_loss = nn.MSELoss()

In [11]:
# Recreate the four loaders (train A, test A, train B, test B) and re-resolve
# the compute device.
(train_A_loader, test_A_loader,
 train_B_loader, test_B_loader) = get_data_loader()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [21]:
# CycleGAN training on 128-row crops (rows 64-191 of each image).
# Each step: one generator update on a fresh (A, B) pair, then two
# discriminator updates on two further pairs.  After every epoch the four
# networks are checkpointed and a full-height sample image is written.
CKPT_DIR = 'C:/Users/ajhon/Desktop/Industrial Math Presentation/us_output_half_imgs'
SAMPLE_DIR = 'C:/Users/ajhon/Desktop/Industrial Math Presentation/half_test_series'

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()


def next_crop_pair(iter_A, iter_B):
    """Return the next (A, B) batches cropped to rows 64-191 and moved to `device`.

    Returns None once either iterator is exhausted.  Only StopIteration means
    "end of epoch"; any other error (unreadable file, device failure, ...)
    propagates instead of silently ending the epoch.
    """
    try:
        batch_A, _ = next(iter_A)
        batch_B, _ = next(iter_B)
    except StopIteration:
        return None
    return batch_A[:, :, 64:192, :].to(device), batch_B[:, :, 64:192, :].to(device)


def stitch_thirds(generator, images):
    """Translate full-height images with a generator trained on half-height crops.

    The generator runs on three overlapping windows: top half, middle half and
    bottom half (rows 0-127, 64-191 and 128-255 for H=256).  Overlapping
    quarters are averaged, and the result has the same shape as `images`.
    Assumes H is divisible by 4.
    """
    q = images.shape[2] // 4
    top = generator(images[:, :, :2 * q, :])
    mid = generator(images[:, :, q:3 * q, :])
    bot = generator(images[:, :, 2 * q:4 * q, :])
    stitched = torch.zeros_like(images)
    stitched[:, :, :q, :] = top[:, :, :q, :]
    stitched[:, :, q:2 * q, :] = (top[:, :, q:, :] + mid[:, :, :q, :]) / 2
    stitched[:, :, 2 * q:3 * q, :] = (bot[:, :, :q, :] + mid[:, :, q:, :]) / 2
    stitched[:, :, 3 * q:, :] = bot[:, :, q:, :]
    return stitched


def denormalize(images):
    """Invert Normalize([0.5], [0.5]) for a whole batch: map [-1, 1] to [0, 1]."""
    return images * 0.5 + 0.5


# Fixed test batches used for the per-epoch sample images.
test_A_loader_iterator = iter(test_A_loader)
test_B_loader_iterator = iter(test_B_loader)
test_real_A, _ = next(test_A_loader_iterator)
test_real_B, _ = next(test_B_loader_iterator)
test_real_A = test_real_A.to(device)
test_real_B = test_real_B.to(device)

for epoch in range(1, 20):
    train_A_loader_iterator = iter(train_A_loader)
    train_B_loader_iterator = iter(train_B_loader)
    brake = False  # set when the data runs out during the discriminator updates
    for cnt in range(1000):
        batch = next_crop_pair(train_A_loader_iterator, train_B_loader_iterator)
        if batch is None:
            break
        real_A, real_B = batch
        # Smoothed "real" label value for this generator step.
        real_label = np.random.uniform(.95, 1.05)

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss
        # G_A2B(B) should equal B if real B is fed
        same_B = G_AB(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * 10.0
        # G_B2A(A) should equal A if real A is fed
        same_A = G_BA(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * 10.0

        # GAN loss; targets are shaped like the discriminator output, so they
        # follow the batch size and crop size.
        fake_B = G_AB(real_A)
        pred_fake = D_B(fake_B)
        target_real = torch.full_like(pred_fake, real_label)
        loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

        fake_A = G_BA(real_B)
        pred_fake = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

        # Cycle loss
        recovered_A = G_BA(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

        recovered_B = G_AB(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

        # MINE loss: mean(real) - log(mean(exp(fake)))
        retAB = torch.mean(real_B) - torch.log(torch.mean(torch.exp(fake_B)))
        retBA = torch.mean(real_A) - torch.log(torch.mean(torch.exp(fake_A)))

        # Total loss
        loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A
                  + loss_cycle_ABA + loss_cycle_BAB - retAB - retBA)
        loss_G.backward()
        optimizer_G.step()

        ###### Discriminators: two updates on fresh real batches ######
        for i in range(2):
            batch = next_crop_pair(train_A_loader_iterator, train_B_loader_iterator)
            if batch is None:
                brake = True
                break
            real_A, real_B = batch
            real_label = np.random.uniform(.95, 1.05)
            fake_label = np.random.uniform(-0.05, .05)

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = D_A(real_A)
            target_real = torch.full_like(pred_real, real_label)
            target_fake = torch.full_like(pred_real, fake_label)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            # NOTE(review): on the second pass fake_A is the sample the buffer
            # returned on the first pass, so it is pushed back in -- confirm
            # this is intended.
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = D_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = D_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = D_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()
        if brake:
            break  # data exhausted mid-step: end the epoch now

    torch.save(G_AB.state_dict(), '%s/netG_A2B_%d.pth' % (CKPT_DIR, epoch))
    torch.save(G_BA.state_dict(), '%s/netG_B2A_%d.pth' % (CKPT_DIR, epoch))
    torch.save(D_A.state_dict(), '%s/netD_A_%d.pth' % (CKPT_DIR, epoch))
    torch.save(D_B.state_dict(), '%s/netD_B_%d.pth' % (CKPT_DIR, epoch))

    # Full-height samples for the fixed test batches (inference only).
    with torch.no_grad():
        fake_imgs_B = denormalize(stitch_thirds(G_AB, test_real_A))
        fake_imgs_A = denormalize(stitch_thirds(G_BA, test_real_B))
    save_image(fake_imgs_B.cpu(), '%s/test%d.png' % (SAMPLE_DIR, epoch))

In [19]:
    # Re-render the test samples for the current `epoch` (mirrors the sample
    # code at the end of the training loop).
    with torch.no_grad():  # inference only: no autograd graph needed
        fake_imgs_B_top = G_AB(test_real_A[:,:,:128,:])
        fake_imgs_B_mid = G_AB(test_real_A[:,:,64:192,:])
        fake_imgs_B_bot = G_AB(test_real_A[:,:,128:256,:])

        # Average the two overlap bands (rows 64-127 and 128-191).
        half_img_B = torch.zeros_like(fake_imgs_B_top)
        half_img_B[:,:,:64,:] = (fake_imgs_B_top[:,:,64:,:] + fake_imgs_B_mid[:,:,:64,:])/2
        half_img_B[:,:,64:,:] = (fake_imgs_B_bot[:,:,:64,:] + fake_imgs_B_mid[:,:,64:,:])/2

        fake_imgs_B = torch.zeros_like(test_real_A)
        fake_imgs_B[:,:,:64,:] = fake_imgs_B_top[:,:,:64,:]
        fake_imgs_B[:,:,64:128,:] = half_img_B[:,:,:64,:]
        fake_imgs_B[:,:,128:192,:] = half_img_B[:,:,64:,:]
        fake_imgs_B[:,:,192:,:] = fake_imgs_B_bot[:,:,64:,:]

        fake_imgs_A_top = G_BA(test_real_B[:,:,:128,:])
        fake_imgs_A_mid = G_BA(test_real_B[:,:,64:192,:])
        fake_imgs_A_bot = G_BA(test_real_B[:,:,128:256,:])

        half_img_A = torch.zeros_like(fake_imgs_A_top)
        half_img_A[:,:,:64,:] = (fake_imgs_A_top[:,:,64:,:] + fake_imgs_A_mid[:,:,:64,:])/2
        half_img_A[:,:,64:,:] = (fake_imgs_A_bot[:,:,:64,:] + fake_imgs_A_mid[:,:,64:,:])/2

        # The A image is built only from the A-direction pieces.
        fake_imgs_A = torch.zeros_like(test_real_B)
        fake_imgs_A[:,:,:64,:] = fake_imgs_A_top[:,:,:64,:]
        fake_imgs_A[:,:,64:128,:] = half_img_A[:,:,:64,:]
        fake_imgs_A[:,:,128:192,:] = half_img_A[:,:,64:,:]
        fake_imgs_A[:,:,192:,:] = fake_imgs_A_bot[:,:,64:,:]

    # Invert Normalize([0.5], [0.5]) for every image of the batch.
    fake_imgs_B = fake_imgs_B * 0.5 + 0.5
    fake_imgs_A = fake_imgs_A * 0.5 + 0.5
    save_image(fake_imgs_B.cpu(),'C:/Users/ajhon/Desktop/Industrial Math Presentation/us_output_half_imgs/test%d.png' % epoch)

In [16]:
# Shape of the most recent discriminator output on a real batch.
pred_real.size()

torch.Size([1, 1, 8, 8])

In [19]:
# Shape check of the overlapping-window stitching: an untrained generator is
# run on the top, middle and bottom halves of real_A, the overlapping quarters
# are averaged, and the stitched image must have real_A's shape.
G = Generator2()
quarter = real_A.shape[2] // 4  # 64 rows for 256-row images
with torch.no_grad():  # inference-only check: no autograd graph needed
    top = G(real_A[:, :, :2 * quarter, :])
    mid = G(real_A[:, :, quarter:3 * quarter, :])
    bot = G(real_A[:, :, 2 * quarter:4 * quarter, :])

# half_img holds the two averaged overlap bands (top/mid, then mid/bot).
half_img = torch.zeros_like(top)
half_img[:, :, :quarter, :] = (top[:, :, quarter:, :] + mid[:, :, :quarter, :]) / 2
half_img[:, :, quarter:, :] = (bot[:, :, :quarter, :] + mid[:, :, quarter:, :]) / 2

img = torch.zeros_like(real_A)
img[:, :, :quarter, :] = top[:, :, :quarter, :]
img[:, :, quarter:2 * quarter, :] = half_img[:, :, :quarter, :]
img[:, :, 2 * quarter:3 * quarter, :] = half_img[:, :, quarter:, :]
img[:, :, 3 * quarter:, :] = bot[:, :, quarter:, :]

img.shape

torch.Size([1, 1, 256, 256])