In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from aikfm.dataset import AikfmDataset
from aikfm.models import CAN8, UCAN64, discriminator
from torchsummary import summary

In [3]:
max_epoch_num = 30
mini_batch_size = 2
lambda1 = 100
lambda2 = 10

In [4]:
dataset = AikfmDataset("~/DKLabs/AI-KFM/AI-KFM/data")
dataloader = DataLoader(dataset, batch_size=mini_batch_size, shuffle=True)

# device = torch.device('cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Generater 1
g1 = CAN8()
g1.to(device)

# # Generator 2
# g2 = UCAN64()
# g2.to(device)

# # Discriminator
# dis = discriminator()
# dis.to(device)

CAN8(
  (leakyrelu1): LeakyReLU(negative_slope=0.2)
  (leakyrelu2): LeakyReLU(negative_slope=0.2)
  (leakyrelu3): LeakyReLU(negative_slope=0.2)
  (leakyrelu4): LeakyReLU(negative_slope=0.2)
  (leakyrelu5): LeakyReLU(negative_slope=0.2)
  (leakyrelu6): LeakyReLU(negative_slope=0.2)
  (leakyrelu7): LeakyReLU(negative_slope=0.2)
  (leakyrelu8): LeakyReLU(negative_slope=0.2)
  (g1_conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (g1_conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (g1_conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  (g1_conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  (g1_conv5): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))
  (g1_conv6): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
  (g1_conv7): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), paddin

In [6]:
summary(g1, (3, 200, 600))

RuntimeError: CUDA out of memory. Tried to allocate 470.00 MiB (GPU 0; 3.82 GiB total capacity; 1.39 GiB already allocated; 109.25 MiB free; 1.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [8]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 2         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    1423 MB |    1423 MB |    1601 MB |  182553 KB |
|       from large pool |    1422 MB |    1422 MB |    1599 MB |  181441 KB |
|       from small pool |       0 MB |       1 MB |       1 MB |    1112 KB |
|---------------------------------------------------------------------------|
| Active memory         |    1423 MB |    1423 MB |    1601 MB |  182553 KB |
|       from large pool |    1422 MB |    1422 MB |    1599 MB |  181441 KB |
|       from small pool |       0 MB |       1 MB |       1 MB |    1112 KB |
|---------------------------------------------------------------

In [6]:
# Define optimizers
optim_g1 = optim.AdamW(g1.parameters(), lr=1e-4, weight_decay=1e-5)
optim_g2 = optim.AdamW(g2.parameters(), lr=1e-4, weight_decay=1e-5)
optim_dis = optim.AdamW(dis.parameters(), lr=1e-5, weight_decay=1e-6)

In [7]:
# Loss function
loss1 = nn.BCEWithLogitsLoss()

In [8]:
imgs, masks = iter(dataloader).next()

In [9]:
imgs, masks = imgs.to(device), masks.to(device) # Move data to compute Device

### Discriminator training

In [None]:
###############################
# Train the discriminator first
dis.train()
g1.eval()
g2.eval()
optim_g1.zero_grad()
optim_g2.zero_grad()
optim_dis.zero_grad()

# Get generator outputs
g1_out = g1(imgs) # [B, 1, 1200, 1600]
g1_out = torch.clamp(g1_out, 0.0, 1.0)

g2_out = g2(imgs) # [B, 1, 1200, 1600]
g2_out = torch.clamp(g2_out, 0.0, 1.0)

pos1 = torch.cat([imgs, 2 * masks - 1], dim = 1) # [B, 4, H, W]
neg1 = torch.cat([imgs, 2 * g1_out - 1], dim = 1) # [B, 4, H, W]
neg2 = torch.cat([imgs, 2 * g2_out - 1], dim = 1) # [B, 4, H, W]

dis_input = torch.cat([pos1, neg1, neg2], dim=0) # # [3*B, 4, H, W]

# Get discriminator output
logits_real, logits_fake1, logits_fake2, Lgc = dis(dis_input)

const1 = torch.ones(imgs.size(0), 1, device=device, dtype=torch.float32)
const0 = torch.zeros(imgs.size(0), 1, device=device, dtype=torch.float32)

gen_gt = torch.cat([const1, const0, const0], dim=1)
gen_gt1 = torch.cat([const0, const1, const0], dim=1)
gen_gt2 = torch.cat([const0, const0, const1], dim=1)

ES0 = torch.mean(loss1(logits_real, gen_gt))
ES1 = torch.mean(loss1(logits_fake1, gen_gt1))
ES2 = torch.mean(loss1(logits_fake2, gen_gt2))

dis_loss = ES0 + ES1 + ES2 # Discriminator loss
print(f'Discriminator loss : {dis_loss}')

dis_loss.backward() # Compute gradients
optim_dis.step() # Apply gradients