In [1]:
# cloned from https://github.com/eriklindernoren/PyTorch-GAN

import os
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 
# os.environ['CUDA_VISIBLE_DEVICES']='4,5,6,7'

import time
import itertools
import requests

import numpy as np
import matplotlib.pyplot as plt
import h5py
import random

import torch
from torch.optim import Adam

import torchvision.models.resnet as resnet

from torch.utils.data import DataLoader, TensorDataset
import torchvision.utils as vutils
import torch.autograd as autograd
# from torch.nn.utils import weight_norm

from cyclegan.model import Discriminator, Generator
from cyclegan.model import weights_init

from sklearn.model_selection import train_test_split

import wandb

# import torch.backends.cudnn as cudnn
# cudnn.benchmark = True

from torch.backends import cudnn
cudnn.benchmark = True

device1 = torch.device("cuda:0")
device2 = torch.device("cuda:1")
device3 = torch.device("cuda:2")
device4 = torch.device("cuda:3")

In [2]:
# import sys
# sys.last_value

In [3]:
SEND = 'https://api.telegram.org/bot'+os.environ['TG']+'/'
def send(text):
    return requests.post(SEND+'sendMessage', json={'chat_id': 80968060, 'text': text}).json()['result']['message_id']

def update_msg(text, msg_id):
    resp = ''
    try:
        resp = requests.post(SEND+'editMessageText', json={'chat_id': 80968060, 'text': text, 'message_id': msg_id})
    except:
        pass
    return resp

In [4]:
class DecayLR:
    def __init__(self, epochs, offset, decay_epochs):
        epoch_flag = epochs - decay_epochs
        assert (epoch_flag > 0), "Decay must start before the training session ends!"
        self.epochs = epochs
        self.offset = offset
        self.decay_epochs = decay_epochs

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_epochs) / (
                self.epochs - self.decay_epochs)
    
class ReplayBuffer:
    def __init__(self, max_size=50):
        assert (max_size > 0), "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return torch.cat(to_return)

In [7]:
# X0_train = np.concatenate([np.load('X0_train_clean_48.npy'), np.load('X0_val_clean_48.npy')])
# X1_train = np.concatenate([np.load('X1_train_clean_48.npy'), np.load('X1_val_clean_48.npy')])

X0_train = np.load('X0_train_clean_48.npy')
X1_train = np.load('X1_train_clean_48.npy')


# n = min(X0_train.shape[0], X1_train.shape[0])

X0_train = torch.from_numpy((X0_train - .5) / .5)
X1_train = torch.from_numpy((X1_train - .5) / .5)



In [8]:
trainloader_A = DataLoader(X0_train, batch_size=128, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
trainloader_B = DataLoader(X1_train, batch_size=128, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

In [9]:
class ResNet(resnet.ResNet):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self._forward_impl(x))
    
@torch.no_grad()
def eval_G(G, clf, validation_loader, g_device, clf_device):
    G.eval()
    acc = .0
    for i, data in enumerate(validation_loader):
        X = data[0].to(g_device)
        y = data[1].to(clf_device)
        X_g = G(X).to(clf_device)
        predicted = torch.round(clf(0.5 * (X_g + 1.0)))
        acc+=(predicted == y).sum()/float(predicted.shape[0])     
#             acc_g+=(predicted_g == y).sum()/float(predicted_g.shape[0])     
    G.train()
    return (acc/(i+1)).detach().item()

@torch.no_grad()
def eval_Clf(model, validation_loader, device):
    acc = .0
    for i, data in enumerate(validation_loader):
        X = data[0].to(device)
        y = data[1].to(device)
        predicted = torch.round(model(0.5 * (X + 1.0)))
        acc+=(predicted == y).sum()/float(predicted.shape[0])       
    return (acc/(i+1)).detach().item()

In [10]:
X0_test = (np.load('X0_val_clean_48.npy') - .5) / .5
y0_test = np.load('y0_val_clean_48.npy')

X1_test = (np.load('X1_val_clean_48.npy') - .5) / .5
y1_test = np.load('y1_val_clean_48.npy')

X0_test = torch.from_numpy(X0_test)
y0_test = torch.from_numpy(y0_test)
X1_test = torch.from_numpy(X1_test)
y1_test = torch.from_numpy(y1_test)


testloader0 = DataLoader(TensorDataset(X0_test, y0_test), batch_size=128, shuffle=True, num_workers=1, pin_memory=True)
testloader1 = DataLoader(TensorDataset(X1_test, y1_test), batch_size=128, shuffle=True, num_workers=1, pin_memory=True)

In [11]:
ClfA = ResNet(resnet.BasicBlock, [2, 2, 2, 2], num_classes=1)
ClfB = ResNet(resnet.BasicBlock, [2, 2, 2, 2], num_classes=1)
ClfA.load_state_dict(torch.load('results/clf_resnet18_48/best_model.pth'))
ClfB.load_state_dict(torch.load('results/clf_resnet18_48/best_model.pth'))
ClfA = ClfA.to(device3)
ClfB = ClfB.to(device4)
ClfA.eval()
ClfB.eval()

print('Acc for A:', eval_Clf(ClfA, testloader0, device3))
print('Acc for B:', eval_Clf(ClfB, testloader1, device4))

Acc for A: 0.8578712940216064
Acc for B: 0.88215172290802


In [12]:
netG_B2A = Generator().to(device1)
netD_A = Discriminator().to(device1)
netG_A2B = Generator().to(device2)
netD_B = Discriminator().to(device2)

netG_A2B.apply(weights_init)
netG_B2A.apply(weights_init)
netD_A.apply(weights_init)
netD_B.apply(weights_init)

cycle_loss1 = torch.nn.L1Loss().to(device1)
cycle_loss2 = torch.nn.L1Loss().to(device2)
identity_loss1 = torch.nn.L1Loss().to(device1)
identity_loss2 = torch.nn.L1Loss().to(device2)
adversarial_loss1 = torch.nn.MSELoss().to(device1)
adversarial_loss2 = torch.nn.MSELoss().to(device2)

# adversarial_loss1 = torch.nn.BCEWithLogitsLoss().to(device1)
# adversarial_loss2 = torch.nn.BCEWithLogitsLoss().to(device2)

def r1(output, imgs, gamma=0.00001):
    grad_real = torch.autograd.grad(outputs=output, inputs=imgs, create_graph=True)[0]
    grad_penalty_real = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
    return (gamma) * grad_penalty_real

def gradient_penalty(D, real_data, generated_data, device):
    batch_size = real_data.shape[0]

    # Calculate interpolation
    alpha = torch.rand(batch_size, 1, 1, 1)
    alpha = alpha.expand_as(real_data).to(device)
    interpolated = alpha * real_data + (1 - alpha) * generated_data

    # Calculate probability of interpolated examples
    dis_interpolated = D(interpolated)
    grad_outputs = torch.ones(dis_interpolated.shape).to(device)

    # Calculate gradients of probabilities with respect to examples
    gradients = autograd.grad(outputs=dis_interpolated, inputs=interpolated,
                           grad_outputs=grad_outputs, create_graph=True, retain_graph=True)[0]

    # Gradients have shape (batch_size, num_channels, img_width, img_height),
    # so flatten to easily take norm per example in batch
    gradients = gradients.view(batch_size, -1)
#         self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().data[0])

    # Derivatives of the gradient close to 0 can cause problems because of
    # the square root, so manually calculate norm and add epsilon
    gradients_norm = ((torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) - 1) ** 2).mean()
    # Return gradient penalty
    return 10*gradients_norm

# lambda_gp = 10
lr = 0.0002
betas = (0.5, 0.999)
# lr = 1e-4
# betas = (0, 0.99)
# itertools.chain takes a series of iterables and return them as one long iterable.
optimizer_G = Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=lr, betas=betas)
optimizer_D_A = Adam(netD_A.parameters(), lr=lr, betas=betas)
optimizer_D_B = Adam(netD_B.parameters(), lr=lr, betas=betas)

epochs = 200
decay_epochs = 100
lr_lambda = DecayLR(epochs, 0, decay_epochs).step
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)

g_losses = []
d_losses_A = []
d_losses_B = []
acc_a = []
acc_b = []

identity_losses = []
gan_losses = []
cycle_losses = []

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

In [None]:
project = 'cyclegan_48'
run_name = 'wgan_r1'
folder=project+'_'+run_name

msg_id = send(folder+': 0')

run = wandb.init(project=project)
wandb.run.name = run_name
wandb.run.save()

path = 'results/'+folder
path_imgs = path +'/samples'
if not os.path.exists('results'):
    os.mkdir('results')
if not os.path.exists(path):
    os.mkdir(path)
if not os.path.exists(path_imgs):
    os.mkdir(path_imgs)
if not os.path.exists(path_imgs+'/A'):
    os.mkdir(path_imgs+'/A')
if not os.path.exists(path_imgs+'/B'):
    os.mkdir(path_imgs+'/B')
    
# print(
#     f" Iter.\t"
#     f"LossD A\t"
#     f"LossD B\t"
#     f" Loss G\t"
#     f"Acc B2A\t"
#     f"Acc A2B")

    
best_acc_A = 0
best_acc_B = 0

early_stop_cnt = 0

iter_A = iter(trainloader_A)
iter_B = iter(trainloader_B)

total_iter = 100000

for i in range(1, total_iter+1):
    try:
        data_A = next(iter_A)
    except:
        iter_A = iter(trainloader_A)
        data_A = next(iter_A)

    try:
        data_B = next(iter_B)
    except:
        iter_B = iter(trainloader_B)
        data_B = next(iter_B)

    # get batch size data
    real_image_A1 = data_A.to(device1)
    real_image_B1 = data_B.to(device1)
    real_image_A2 = data_A.to(device2)
    real_image_B2 = data_B.to(device2)

    real_image_A1.requires_grad_()
    real_image_B2.requires_grad_()
    
    batch_size = real_image_A1.size(0)

    # real data label is 1, fake data label is 0.
    real_label1 = torch.full((batch_size, 1), 1, device=device1, dtype=torch.float32)
    fake_label1 = torch.full((batch_size, 1), 0, device=device1, dtype=torch.float32)
    real_label2 = torch.full((batch_size, 1), 1, device=device2, dtype=torch.float32)
    fake_label2 = torch.full((batch_size, 1), 0, device=device2, dtype=torch.float32)
    
    

    ##############################################
    # (1) Update G network: Generators A2B and B2A
    ##############################################

    # Set G_A and G_B's gradients to zero
    optimizer_G.zero_grad()

    # Identity loss
    # G_B2A(A) should equal A if real A is fed
    identity_image_A = netG_B2A(real_image_A1)
    loss_identity_A = identity_loss1(identity_image_A, real_image_A1) * 5.0
    # G_A2B(B) should equal B if real B is fed
    identity_image_B = netG_A2B(real_image_B2)
    loss_identity_B = identity_loss2(identity_image_B, real_image_B2) * 5.0

    # GAN loss D_A(G_A(A))
    fake_image_A = netG_B2A(real_image_B1)
    fake_output_A = netD_A(fake_image_A)
#     loss_GAN_B2A = adversarial_loss1(fake_output_A, real_label1)
    # WGAN
    loss_GAN_B2A = -fake_output_A.mean()

    # GAN loss D_B(G_B(B))
    fake_image_B = netG_A2B(real_image_A2)
    fake_output_B = netD_B(fake_image_B)
#     loss_GAN_A2B = adversarial_loss2(fake_output_B, real_label2)
#     WGAN
    loss_GAN_A2B = -fake_output_B.mean()

    # Cycle loss
    recovered_image_A = netG_B2A(fake_image_B.to(device1))
    loss_cycle_ABA = cycle_loss1(recovered_image_A, real_image_A1) * 10.0

    recovered_image_B = netG_A2B(fake_image_A.to(device2))
    loss_cycle_BAB = cycle_loss2(recovered_image_B, real_image_B2) * 10.0

    torch.cuda.synchronize()
    # Combined loss and calculate gradients
    errG = loss_identity_A.cpu() + loss_identity_B.cpu() + loss_GAN_A2B.cpu() + loss_GAN_B2A.cpu() + loss_cycle_ABA.cpu() + loss_cycle_BAB.cpu()

    g_losses.append(errG.item())
    # Calculate gradients for G_A and G_B
    errG.backward()
    # Update G_A and G_B's weights
    optimizer_G.step()

    ##############################################
    # (2) Update D network: Discriminator A
    ##############################################

    # Set D_A gradients to zero
    optimizer_D_A.zero_grad()

    # Real A image loss
    real_output_A = netD_A(real_image_A1)
#     errD_real_A = adversarial_loss1(real_output_A, real_label1) + r1(real_output_A.sum(), real_image_A1)
    # WGAN
    errD_real_A = real_output_A.mean() + r1(real_output_A.sum(), real_image_A1)

    # Fake A image loss
    fake_image_A = fake_A_buffer.push_and_pop(fake_image_A)
    fake_output_A = netD_A(fake_image_A.detach().to(device1))
#     errD_fake_A = adversarial_loss1(fake_output_A, fake_label1)
    # WGAN
    errD_fake_A = fake_output_A.mean()

    # Combined loss and calculate gradients
#     errD_A = (errD_real_A + errD_fake_A) / 2
    # WGAN
    gp_A = gradient_penalty(netD_A, real_image_A1, netG_B2A(real_image_B1), device1)
    errD_A = (errD_fake_A - errD_real_A + gp_A)

    d_losses_A.append(errD_A.item())

    # Calculate gradients for D_A
    errD_A.backward()
    # Update D_A weights
    optimizer_D_A.step()

    ##############################################
    # (3) Update D network: Discriminator B
    ##############################################

    # Set D_B gradients to zero
    optimizer_D_B.zero_grad()

    # Real B image loss
    real_output_B = netD_B(real_image_B2)
#     errD_real_B = adversarial_loss2(real_output_B, real_label2) + r1(real_output_B.sum(), real_image_B2)
    # WGAN
    errD_real_B = real_output_B.mean() + r1(real_output_B.sum(), real_image_B2)

    # Fake B image loss
    fake_image_B = fake_B_buffer.push_and_pop(fake_image_B)
    fake_output_B = netD_B(fake_image_B.detach().to(device2))
#     errD_fake_B = adversarial_loss2(fake_output_B, fake_label2)
    # WGAN
    errD_fake_B = fake_output_B.mean()

    # Combined loss and calculate gradients
#     errD_B = (errD_real_B + errD_fake_B) / 2
    # WGAN
    gp_B = gradient_penalty(netD_B, real_image_B2, netG_A2B(real_image_A2), device2)
    errD_B = (errD_fake_B - errD_real_B + gp_B)

    d_losses_B.append(errD_B.item())

    # Calculate gradients for D_B
    errD_B.backward()
    # Update D_B weights
    optimizer_D_B.step()
        
#     torch.cuda.synchronize()
    if i%100 == 0:
        update_msg(folder+': '+str(i/total_iter), msg_id)
        accA = eval_G(netG_B2A, ClfA, testloader1, device1, device3)
        accB = eval_G(netG_A2B, ClfB, testloader0, device2, device4)
        
        acc_a.append(accA)
        acc_b.append(accB)

#         print(
#             f"{i:06d}\t"
#             f"{d_losses_A[-1]:2.4f}\t"
#             f"{d_losses_B[-1]:2.4f}\t"
#             f"{g_losses[-1]:2.4f}\t"
#             f"{accA:2.4f}\t"
#             f"{accB:2.4f}")
        
        wandb.log(
            {
                "d_losses_A": d_losses_A[-1],
                "d_losses_B": d_losses_B[-1],
                "g_losses": g_losses[-1],
                "accA": accA,
                "accB": accB,
            }
        )
    
        if accA > best_acc_A:
            best_acc_A = accA
            # save last check pointing
            torch.save(netG_B2A.state_dict(), path+"/netG_B2A.pth")
            torch.save(netD_A.state_dict(), path+"/netD_A.pth")
            wandb.run.summary["best_acc_A"] = accA
        if accB > best_acc_B:
            best_acc_B = accB
            # save last check pointing
            torch.save(netG_A2B.state_dict(), path+"/netG_A2B.pth")
            torch.save(netD_B.state_dict(), path+"/netD_B.pth")
            wandb.run.summary["best_acc_B"] = accB
    
#     if i%1000 == 0:
#         fake_image_A = 0.5 * (netG_B2A(real_image_B1).data[:16] + 1.0)
#         fake_image_B = 0.5 * (netG_A2B(real_image_A2).data[:16] + 1.0)

#         vutils.save_image(fake_image_A.detach(),
#                         path_imgs+f"/A/{i:06d}_fake.png",
#                         normalize=True)
#         vutils.save_image(fake_image_B.detach(),
#                         path_imgs+f"/B/{i:06d}_fake.png",
#                         normalize=True)

    if i%1000 == 0:
        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

send(folder+' done')

wandb.save(path+"/netG_B2A.pth")
wandb.save(path+"/netD_A.pth")
wandb.save(path+"/netG_A2B.pth")
wandb.save(path+"/netD_B.pth")
    
torch.save(netG_A2B.state_dict(), path+"/netG_A2B_last.pth")
torch.save(netG_B2A.state_dict(), path+"/netG_B2A_last.pth")
torch.save(netD_A.state_dict(), path+"/netD_A_last.pth")
torch.save(netD_B.state_dict(), path+"/netD_B_last.pth")

np.save(path+'/d_losses_A.npy', d_losses_A)
np.save(path+'/d_losses_B.npy', d_losses_B)
np.save(path+'/g_losses.npy', g_losses)
np.save(path+'/acc_a.npy', acc_a)
np.save(path+'/acc_b.npy', acc_b)

run.finish()

[34m[1mwandb[0m: Currently logged in as: [33marray[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.26 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




In [3]:
from torchsummary import summary
from cyclegan.model import Generator as CycleGANGenerator
from cyclegan.model import Discriminator as CycleGANDiscriminator

In [2]:
summary(CycleGANGenerator().cuda(), (3, 48, 48), batch_size=128)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ReflectionPad2d-1           [128, 3, 54, 54]               0
            Conv2d-2          [128, 64, 48, 48]           9,472
    InstanceNorm2d-3          [128, 64, 48, 48]               0
              ReLU-4          [128, 64, 48, 48]               0
            Conv2d-5         [128, 128, 24, 24]          73,856
    InstanceNorm2d-6         [128, 128, 24, 24]               0
              ReLU-7         [128, 128, 24, 24]               0
            Conv2d-8         [128, 256, 12, 12]         295,168
    InstanceNorm2d-9         [128, 256, 12, 12]               0
             ReLU-10         [128, 256, 12, 12]               0
  ReflectionPad2d-11         [128, 256, 14, 14]               0
           Conv2d-12         [128, 256, 12, 12]         590,080
   InstanceNorm2d-13         [128, 256, 12, 12]               0
             ReLU-14         [128, 256,

In [4]:
summary(CycleGANDiscriminator().cuda(), (3, 48, 48), batch_size=128)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [128, 64, 24, 24]           3,136
         LeakyReLU-2          [128, 64, 24, 24]               0
            Conv2d-3         [128, 128, 12, 12]         131,200
    InstanceNorm2d-4         [128, 128, 12, 12]               0
         LeakyReLU-5         [128, 128, 12, 12]               0
            Conv2d-6           [128, 256, 6, 6]         524,544
    InstanceNorm2d-7           [128, 256, 6, 6]               0
         LeakyReLU-8           [128, 256, 6, 6]               0
            Conv2d-9           [128, 512, 5, 5]       2,097,664
   InstanceNorm2d-10           [128, 512, 5, 5]               0
        LeakyReLU-11           [128, 512, 5, 5]               0
           Conv2d-12             [128, 1, 4, 4]           8,193
Total params: 2,764,737
Trainable params: 2,764,737
Non-trainable params: 0
---------------------------