In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from dataset import train_dataloader, test_dataloader
import torchvision.utils as vutils
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda")
ngf = 64
ndf = 64
nc = 3
lr = 1
beta1 = 0.5

In [30]:
x = torch.rand(32,3,32)
y = torch.rand(32,3,32)

z = torch.cat((x,y),0)
z = torch.mean(z,axis=-1)
print(z.shape)

torch.Size([64, 3])


In [35]:
def multilossD(pred_concat, target_concat):
    mean = torch.mean(torch.square(pred_concat - target_concat), axis = 0)
    return torch.exp(torch.mean(torch.log(mean), axis = -1))

In [56]:
def multilossG(pred_nor,pred_disp,pred_rough,data):
    mse_nor_mean = torch.mean(torch.square(pred_nor - data["nor"].to(device)), axis = 0)
    mse_disp_mean = torch.mean(torch.square(pred_disp - data["disp"].to(device)), axis = 0)
    mse_rough_mean = torch.mean(torch.square(pred_rough - data["rough"].to(device)), axis = 0)
    mse_concat = torch.cat((mse_nor_mean, mse_disp_mean, mse_rough_mean),0)
    mse_concat = torch.mean(torch.log(mse_concat),axis = -1)
    mse_concat = torch.mean(mse_concat,axis = -1)
    return torch.exp(torch.mean(mse_concat,axis=-1))

In [41]:
print(multilossD(x,y))

tensor([0.1601, 0.1741, 0.1610])


In [63]:
#DCGAN - Disp
from network import GeneratorSkipMultitask, Discriminator,BasicBlock, weights_init

netG = GeneratorSkipMultitask(ngf,BasicBlock).to(device)
netG.apply(weights_init)
netD_nor = Discriminator(ndf,3).to(device)
netD_disp = Discriminator(ndf,1).to(device)
netD_rough = Discriminator(ndf,1).to(device)
netD_nor.apply(weights_init)
netD_disp.apply(weights_init)
netD_rough.apply(weights_init)

criterion = nn.BCELoss()
mse = nn.MSELoss()
def lr_schedule(epoch):
    lr = 0
    if(epoch < 8):
        lr = 0.001 * ((epoch+1)/8)
    else:
        lr = 0.001*(8/epoch)
    return lr

optimizerD_nor = optim.Adam(netD_nor.parameters(), lr=lr, betas=(beta1, 0.999))
DScheduler_nor = torch.optim.lr_scheduler.LambdaLR(optimizerD_nor, lr_schedule)
optimizerD_disp = optim.Adam(netD_disp.parameters(), lr=lr, betas=(beta1, 0.999))
DScheduler_disp = torch.optim.lr_scheduler.LambdaLR(optimizerD_disp, lr_schedule)
optimizerD_rough = optim.Adam(netD_rough.parameters(), lr=lr, betas=(beta1, 0.999))
DScheduler_rough = torch.optim.lr_scheduler.LambdaLR(optimizerD_rough, lr_schedule)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
GScheduler = torch.optim.lr_scheduler.LambdaLR(optimizerG, lr_schedule)

In [None]:
# Training Loop
real_label = 1.0
fake_label = 0.0
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
test_MSE_nor = []
test_MSE_disp = []
test_MSE_rough = []
iters = 0
num_epochs = 30
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(train_dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD_nor.zero_grad()
        netD_disp.zero_grad()
        netD_rough.zero_grad()
        # Format batch
        real_nor = data["nor"].to(device)
        real_disp = data["disp"].to(device)
        real_rough = data["rough"].to(device)
        b_size = real_nor.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output_nor = netD_nor(real_nor).view(-1)
        output_disp = netD_disp(real_disp).view(-1)
        output_rough = netD_rough(real_rough).view(-1)
        # Calculate loss on all-real batch
        errD_real_nor = criterion(output_nor, label)
        errD_real_disp = criterion(output_disp, label)
        errD_real_rough = criterion(output_rough, label)
        errD_real = errD_real_nor + errD_real_disp + errD_real_rough
        # Calculate gradients for D in backward pass
        errD_real_nor.backward()
        errD_real_disp.backward()
        errD_real_rough.backward()
        D_x_nor = output_nor.mean().item()
        D_x_disp = output_disp.mean().item()
        D_x_rough = output_rough.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = data["diff"].to(device)
        # Generate fake image batch with G
        fake_nor, fake_disp, fake_rough = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output_nor = netD_nor(fake_nor.detach()).view(-1)
        output_disp = netD_disp(fake_disp.detach()).view(-1)
        output_rough = netD_rough(fake_rough.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake_nor = criterion(output_nor, label)
        errD_fake_disp = criterion(output_disp, label)
        errD_fake_rough = criterion(output_rough, label)
        errD_fake = errD_fake_nor + errD_fake_disp + errD_fake_rough
        # Calculate the gradients for this batch
        errD_fake_nor.backward()
        errD_fake_disp.backward()
        errD_fake_rough.backward()
        D_G_z1_nor = output_nor.mean().item()
        D_G_z1_disp = output_disp.mean().item()
        D_G_z1_rough = output_rough.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD_nor.step()
        optimizerD_disp.step()
        optimizerD_rough.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output_nor = netD_nor(fake_nor).view(-1)
        output_disp = netD_disp(fake_disp).view(-1)
        output_rough = netD_rough(fake_rough).view(-1)  
        output = torch.cat((output_nor, output_disp, output_rough),0)
        # Calculate G's loss based on this output
        errG = (multilossG(fake_nor,fake_disp,fake_rough, data)  + multilossD(output,torch.cat((label,label,label),0)))/2
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        loss_fn = nn.MSELoss()
        data_test = next(iter(test_dataloader))
        fake_nor, fake_disp, fake_rough = netG(data_test["diff"].to(device))
        test_MSE_nor.append(mse(fake_nor.detach().cpu(),data_test["nor"]))
        test_MSE_disp.append(mse(fake_disp.detach().cpu(),data_test["disp"]))
        test_MSE_rough.append(mse(fake_rough.detach().cpu(),data_test["rough"]))
        # Output training stats
        if i % 50 == 0:
            print(epoch,"errD: ",errD.item(), " errG: ",errG.item())
#             print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tMSE_LossG: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
#                     % (epoch, num_epochs, i, len(train_dataloader),
#                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            print("TEST MSE_nor: ", test_MSE_nor[-1])
            print("TEST MSE_disp: ", test_MSE_disp[-1])
            print("TEST MSE_rough: ", test_MSE_rough[-1])


        
        # Check how the generator is doing by saving G's output on fixed_noise
#         if iters % 10 == 0:
#             with torch.no_grad():
#                 loss_fn = nn.MSELoss()
#                 data = next(iter(test_dataloader))
#                 fake = netG(data["diff"].to(device)).detach().cpu()
#                 test_MSE_losses.append(mse(fake,data["nor"]))
#                 print("TEST MSE: ", mse(fake,data["nor"]))
#             img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

#         iters += 1
    DScheduler_nor.step()
    DScheduler_disp.step()
    DScheduler_rough.step()
    GScheduler.step()

Starting Training Loop...
0 errD:  5.751616954803467  errG:  0.4495849609375
TEST MSE_nor:  tensor(0.0499)
TEST MSE_disp:  tensor(0.0673)
TEST MSE_rough:  tensor(0.1205)


In [57]:
errG = (multilossG(fake_nor,fake_disp,fake_rough, data)  + multilossD(output,torch.cat((label,label,label),0)))/2

In [62]:
multilossD(output,torch.cat((label,label,label),0))

tensor(1.2233e-07, device='cuda:0', grad_fn=<ExpBackward>)

In [None]:
plt.figure(figsize=(10,5))
plt.title("Multitask MSE Loss")
plt.plot(test_MSE_nor,label="nor_MSE")
plt.plot(test_MSE_disp,label="disp_MSE")
plt.plot(test_MSE_rough,label="rough_MSE")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()