In [None]:
# Mount Google Drive so files under /content/drive (data archive, checkpoints) are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the Weights & Biases client; later cells import `wandb` and call wandb.init()/wandb.log().
!pip install wandb

In [None]:
# Extract the challenge data archive stored on Drive into the current working directory.
!unzip /content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/Data.zip

In [None]:
# !rm -r /content/datasets/SEM

In [None]:
import os
from itertools import product

# Pre-create the directory tree that the copy step writes into:
#   /content/datasets/SEM/{train,test}/{A,B}/Case_{1..4}/{80..84}
# `exist_ok=True` already makes re-running this cell harmless, so no try/except
# is needed; any other OSError (permissions, full disk, ...) now surfaces
# instead of being silently swallowed by a bare `except`.
for a, i, j, k in product(["train", "test"], ['A', 'B'], [1, 2, 3, 4], [80, 81, 82, 83, 84]):
    os.makedirs(f'/content/datasets/SEM/{a}/{i}/Case_{j}/{k}', exist_ok=True)

In [None]:
import shutil
import os
from glob import glob
from tqdm.auto import tqdm
from collections import defaultdict

os.chdir('/content')
sem_list_iter0 = sorted(glob('./simulation_data/SEM/*/*/*itr0.png'))
depth_list = sorted(glob('./simulation_data/Depth/*/*/*'))
print(len(sem_list_iter0))
print(len(depth_list))

# The two sorted lists are paired purely by position. `zip` would silently drop
# the surplus entries of the longer list, so fail loudly if the counts disagree.
if len(sem_list_iter0) != len(depth_list):
    raise ValueError(
        f"SEM/depth file counts differ ({len(sem_list_iter0)} vs {len(depth_list)}); "
        "position-based pairing would be wrong")

# A : SEM simulator
# B : Depth map
# The SEM / depth data should probably be split in a stratified way.
train_dir_A = '/content/datasets/SEM/train/A'
train_dir_B = '/content/datasets/SEM/train/B'

# Per-destination copy counters; a value above 1 means two sources mapped to the same target.
sem_path_dict = defaultdict(int)
depth_path_dict = defaultdict(int)
# `zip` has no length, so give tqdm the total explicitly to get a real progress bar.
for sem_path, depth_path in tqdm(zip(sem_list_iter0, depth_list), total=len(sem_list_iter0)):
    # Keep the last three path components (Case_x/<nn>/<file>) below the target root.
    sem_file_name = '/'.join(sem_path.split('/')[-3:])
    depth_file_name = '/'.join(depth_path.split('/')[-3:])
    sem_target = os.path.join(train_dir_A, sem_file_name)      # root/train/A/file_name
    depth_target = os.path.join(train_dir_B, depth_file_name)  # root/train/B/file_name
    shutil.copy(sem_path, sem_target)
    shutil.copy(depth_path, depth_target)
    sem_path_dict[sem_target] += 1
    depth_path_dict[depth_target] += 1

# for sem_path, depth_path in tqdm(zip(sem_list_iter0[int(len(sem_list_iter0)* 0.8):], depth_list[int(len(depth_list)*0.8):])): 
#     # sem_file_path = sem_path.split('/')[:-1]
#     sem_file_name = '/'.join(sem_path.split('/')[-3:])
#     depth_file_name = '/'.join(depth_path.split('/')[-3:])
#     shutil.copy(sem_path, os.path.join('/content/datasets/SEM/test/A', sem_file_name)) # root/train/A/file_name
#     shutil.copy(depth_path, os.path.join('/content/datasets/SEM/test/B', depth_file_name)) # root/train/A/file_name

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import argparse
import torchvision.transforms as transforms
import torchvision.utils as vutils

import torch.backends.cudnn as cudnn
from PIL import Image

from tqdm.auto import tqdm
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2, ToTensor

import itertools
import os
import random
import time
import numpy as np
import cv2
from glob import glob
import wandb

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator for single-channel images.

    The network is a stack of 4x4 convolutions: four stride-2 stages
    (1 -> 64 -> 128 -> 256 -> 512 channels, the last three followed by
    InstanceNorm) and a final convolution down to one channel. ``forward``
    averages that one-channel map over its spatial extent, so each input
    image yields a single score.
    """

    def __init__(self):
        super().__init__()

        # Layer order is kept exactly as before so the `main.<idx>` state-dict
        # keys of existing checkpoints still match.
        layers = [
            nn.Conv2d(1, 64, 4, stride=2, padding=1),  # channel = 1 or gray-scale * 3 => 3
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per image, shape ``(batch, 1)`` for a 4-D input."""
        score_map = self.main(x)
        # Average the score map over all spatial positions.
        # TODO: think about and verify why the pooling is done this way.
        spatial_size = score_map.size()[2:]
        pooled = F.avg_pool2d(score_map, spatial_size)
        return torch.flatten(pooled, 1)

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv + InstanceNorm stages with a skip connection.

    The channel count is unchanged by the block (``in_channels`` in and out),
    and the reflection padding of 1 offsets the 3x3 kernel so the spatial
    size is kept as well.
    """

    def __init__(self, in_channels):
        super().__init__()

        def conv_stage():
            # pad -> 3x3 conv -> instance norm
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, in_channels, 3),
                nn.InstanceNorm2d(in_channels),
            ]

        # Same module order as before: stage, ReLU, stage.
        self.res = nn.Sequential(*conv_stage(), nn.ReLU(inplace=True), *conv_stage())

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        residual = self.res(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Single-channel image-to-image generator.

    Layout: 7x7 stem convolution, two stride-2 downsampling convolutions,
    nine ``ResidualBlock(256)`` units, two stride-2 transposed convolutions,
    and a 7x7 output convolution followed by ``Tanh``.
    """

    def __init__(self):
        super().__init__()

        # Initial convolution block (ReflectionPad2d pads the 2-D plane).
        stem = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(1, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling
        down = []
        for c_in, c_out in ((64, 128), (128, 256)):
            down += [
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Residual blocks
        trunk = [ResidualBlock(256) for _ in range(9)]

        # Upsampling
        up = []
        for c_in, c_out in ((256, 128), (128, 64)):
            up += [
                nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Output layer
        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 1, 7),
            nn.Tanh(),
        ]

        # Module order (and so the `main.<idx>` state-dict keys) is unchanged.
        self.main = nn.Sequential(*stem, *down, *trunk, *up, *head)

    def forward(self, x):
        """Run ``x`` through the full layer stack."""
        return self.main(x)

In [None]:
class DecayLR:
    """Learning-rate multiplier: constant 1.0, then a linear ramp down.

    ``step(epoch)`` returns 1.0 until ``epoch + offset`` exceeds
    ``decay_epochs`` and then decreases linearly, reaching 0.0 when
    ``epoch + offset == epochs``.
    """

    def __init__(self, epochs, offset, decay_epochs):
        assert (epochs - decay_epochs > 0), "Decay must start before the training session ends!"
        self.epochs, self.offset, self.decay_epochs = epochs, offset, decay_epochs

    def step(self, epoch):
        """Return the multiplier to apply to the base learning rate at ``epoch``."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_epochs)
        decay_span = self.epochs - self.decay_epochs
        return 1.0 - epochs_into_decay / decay_span

In [None]:
class ReplayBuffer:
    """Fixed-capacity pool of previously seen samples.

    While the pool is not full, every incoming sample is stored and passed
    through unchanged. Once full, each incoming sample is, with a coin flip,
    either passed through as-is or swapped with a randomly chosen stored
    sample (the stored one is returned, the new one takes its slot).
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same length."""
        batch_out = []
        for sample in data.data:
            sample = torch.unsqueeze(sample, 0)  # restore the batch dimension

            if len(self.data) < self.max_size:
                # Still filling up: remember the sample and hand it straight back.
                self.data.append(sample)
                batch_out.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                # Swap: return an old sample and store the new one in its place.
                slot = random.randint(0, self.max_size - 1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)

        return torch.cat(batch_out)

In [None]:
# Custom weight initialisation, applied to netG and netD via `module.apply`.
def weights_init(m):
    """Initialise ``m`` in place, chosen by its class name.

    Classes whose name contains "Conv" get weights drawn from N(0, 0.02);
    classes whose name contains "BatchNorm" get weights from N(1, 0.02) and
    zero bias. Any other module is left untouched.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        return
    if "BatchNorm" in layer_type:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

In [None]:
class ImageDataset(Dataset):
    """Dataset yielding ``{"A": ..., "B": ...}`` grayscale image pairs.

    Files are discovered with glob patterns below ``root``:

    * A: ``{root}/{mode}/A/*/*/*itr{itr}.*``
    * B: ``{root}/{mode}/B/*/*/*.*``

    Both lists are sorted, and item ``index`` pairs the entries at that
    position. The length is the larger of the two list lengths; the shorter
    list is cycled with a modulo so every index below ``len(self)`` is valid.

    Args:
        root: dataset root directory.
        transform: optional callable applied to each PIL image; when ``None``
            the PIL image itself is returned.
        mode: sub-directory of ``root`` to read (default ``"train"``).
        itr: iteration tag that A file names must end with (before the extension).
    """

    def __init__(self, root, transform=None, mode="train", itr=0):
        self.transform = transform

        self.files_A = sorted(glob(os.path.join(root, f"{mode}/A") + f"/*/*/*itr{itr}.*"))
        print(len(self.files_A), "-len(A)")
        self.files_B = sorted(glob(os.path.join(root, f"{mode}/B") + "/*/*/*.*"))
        print(len(self.files_B), "-len(B)")

    def _load(self, path):
        """Open ``path`` as a grayscale image and apply the transform, if any."""
        image = Image.open(path).convert("L")  # RGB : "RGB", Grayscale : "L"
        if self.transform is None:
            return image
        return self.transform(image)

    def __getitem__(self, index):
        length = len(self)
        if index < 0:
            index += length
        # Explicit bounds check: with modulo indexing below, out-of-range
        # indices would otherwise wrap around instead of raising.
        if not 0 <= index < length:
            raise IndexError(f"index out of range for dataset of length {length}")
        if not self.files_A or not self.files_B:
            raise IndexError("one of the A/B file lists is empty; cannot build a pair")

        # __len__ is the max of both lists, so wrap the index on the shorter one
        # (previously this raised IndexError whenever the list lengths differed).
        item_A = self._load(self.files_A[index % len(self.files_A)])
        item_B = self._load(self.files_B[index % len(self.files_B)])

        return {"A": item_A, "B": item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [None]:
class TestDataset(Dataset):
    """Dataset over the files matching a glob pattern.

    Each item is ``(transformed_grayscale_image, file_name)``, where
    ``file_name`` is the part of the path after the last ``/``.

    Args:
        root: glob pattern selecting the image files.
        transform: callable applied to each PIL image.
    """

    def __init__(self, root, transform=None):
        self.transform = transform
        self.files_test = sorted(glob(root))

    def __getitem__(self, index):
        path = self.files_test[index]
        grayscale = Image.open(path).convert("L")
        return self.transform(grayscale), path.split('/')[-1]

    def __len__(self):
        return len(self.files_test)

In [None]:
# Run configuration. In a notebook there is no command line, so
# `parse_args('')` below simply materialises the defaults declared here;
# the help strings now state the defaults that are actually used.
parser = argparse.ArgumentParser(
    description="PyTorch implements `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks`")
parser.add_argument("--dataroot", type=str, default="./datasets",
                    help="path to datasets. (default: ./datasets)")
parser.add_argument("--dataset", type=str, default="SEM",
                    help="dataset name. (default:`SEM`)"
                         "Option: [apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, "
                         "cezanne2photo, ukiyoe2photo, vangogh2photo, maps, facades, selfie2anime, "
                         "iphone2dslr_flower, ae_photos, ]")
parser.add_argument("--epochs", default=200, type=int, metavar="N", 
                    help="number of total epochs to run")
parser.add_argument("--decay_epochs", type=int, default=100,
                    help="epoch to start linearly decaying the learning rate to 0. (default:100)")
parser.add_argument("-b", "--batch-size", default=80, type=int,
                    metavar="N",
                    help="mini-batch size (default: 80), this is the total "
                         "batch size of all GPUs on the current node when "
                         "using Data Parallel or Distributed Data Parallel")
parser.add_argument("--lr", type=float, default=0.0002,
                    help="learning rate. (default:0.0002)")
parser.add_argument("-p", "--print-freq", default=100, type=int,
                    metavar="N", help="print frequency. (default:100)")
parser.add_argument("--cuda", action="store_true", help="Enables cuda")
# Checkpoint bundle (models, optimizers, schedulers) to resume from; "" disables resuming.
parser.add_argument("--resume_from", default=f'/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/all_(max200_None_Resize)_{55}_all.tar', type=str, help="path to All_state_dict() for continue training")
# parser.add_argument("--netG_A2B", default="", help="path to netG_A2B (to continue training)")
# parser.add_argument("--netG_B2A", default="", help="path to netG_B2A (to continue training)")
# parser.add_argument("--netD_A", default="", help="path to netD_A (to continue training)")
# parser.add_argument("--netD_B", default="", help="path to netD_B (to continue training)")
# parser.add_argument("--netG_A2B", default=f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netG_A2B_itr0_epoch(max200_None_Resize)_{30}.pth", help="path to netG_A2B (to continue training)")
# parser.add_argument("--netG_B2A", default= f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netG_B2A_itr0_epoch(max200_None_Resize)_{30}.pth", help="path to netG_B2A (to continue training)")
# parser.add_argument("--netD_A", default=f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netD_A_itr0_epoch(max200_None_Resize)_{30}.pth", help="path to netD_A (to continue training)")
# parser.add_argument("--netD_B", default= f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netD_B_itr0_epoch(max200_None_Resize)_{30}.pth", help="path to netD_B (to continue training)")

parser.add_argument("--image-size", type=int, default=128,
                    help="size of the data crop (squared assumed). (default: 128)")
parser.add_argument("--outf", default="./outputs",
                    help="folder to output images. (default:`./outputs`).")
parser.add_argument("--manualSeed", type=int,
                    help="Seed for initializing training. (default:none)")

# Empty argument list: take every default (there is no real argv in a notebook).
args = parser.parse_args('')
print(args)

In [None]:
# Output folders. `exist_ok=True` tolerates re-running the cell while still
# surfacing real failures (e.g. permission errors) that the previous
# `except OSError: pass` would have hidden.
os.makedirs(args.outf, exist_ok=True)
os.makedirs("weights", exist_ok=True)

# Fall back to a fixed seed so runs are repeatable when none is supplied.
if args.manualSeed is None:
    args.manualSeed = 1627
print("Random Seed: ", args.manualSeed)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)

# Let cuDNN benchmark and pick convolution algorithms for the observed input sizes.
cudnn.benchmark = True

In [None]:
# Dataset
# NOTE(review): the transform is called separately for A and B, so the random
# horizontal flip is drawn independently for the two images of a pair.
# NOTE(review): no Resize/Normalize is applied, so inputs stay in [0, 1] and
# `args.image_size` is not used here - confirm this is intended.
dataset = ImageDataset(root=os.path.join(args.dataroot, args.dataset),
                       transform=transforms.Compose([
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor(),
                           ]))

test_dataset = TestDataset('/content/test/SEM/*', transform=transforms.Compose([
                           transforms.ToTensor(),
                           ]))

print(len(dataset))
print(len(test_dataset))
# Show a compact summary of the first sample instead of dumping full tensors.
first_sample = dataset[0]
print({key: (tuple(value.shape), value.dtype) for key, value in first_sample.items()})

dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True)
# Evaluation data does not need reshuffling; a fixed order keeps inference runs repeatable.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)

In [None]:
# Per-dataset output folders for sample images (A and B) and for weights.
# Each directory is created independently: previously, an already-existing
# "A" folder raised inside the shared try-block and "B" was never created.
for sub_dir in ("A", "B"):
    os.makedirs(os.path.join(args.outf, args.dataset, sub_dir), exist_ok=True)

os.makedirs(os.path.join("weights", args.dataset), exist_ok=True)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# create model: two generators (A->B, B->A) and one discriminator per domain
netG_A2B = Generator().to(device)
netG_B2A = Generator().to(device)
netD_A = Discriminator().to(device)
netD_B = Discriminator().to(device)

netG_A2B.apply(weights_init)
netG_B2A.apply(weights_init)
netD_A.apply(weights_init)
netD_B.apply(weights_init)

# define loss function (adversarial_loss) and optimizer
cycle_loss = torch.nn.L1Loss().to(device)
identity_loss = torch.nn.L1Loss().to(device)
adversarial_loss = torch.nn.MSELoss().to(device)

# Both generators share one optimizer; each discriminator has its own.
optimizer_G = torch.optim.AdamW(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                lr=args.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.AdamW(netD_A.parameters(), lr=args.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.AdamW(netD_B.parameters(), lr=args.lr, betas=(0.5, 0.999))

# Same multiplier schedule (constant, then linear decay) for all three optimizers.
lr_lambda = DecayLR(args.epochs, 0, args.decay_epochs).step
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)

print(args.resume_from)
if args.resume_from != "":
    # map_location makes a checkpoint saved on GPU loadable on a CPU-only
    # runtime (and vice versa) instead of failing while deserialising.
    checkpoint = torch.load(args.resume_from, map_location=device)
    netG_A2B.load_state_dict(checkpoint['netG_A2B'])
    netG_B2A.load_state_dict(checkpoint['netG_B2A'])
    netD_A.load_state_dict(checkpoint['netD_A'])
    netD_B.load_state_dict(checkpoint['netD_B'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G'])
    optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A'])
    optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B'])
    lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
    lr_scheduler_D_A.load_state_dict(checkpoint['lr_scheduler_D_A'])
    lr_scheduler_D_B.load_state_dict(checkpoint['lr_scheduler_D_B'])
# if args.netG_A2B != "":
#     netG_A2B.load_state_dict(torch.load(args.netG_A2B))
# if args.netG_B2A != "":
#     netG_B2A.load_state_dict(torch.load(args.netG_B2A))
# if args.netD_A != "":
#     netD_A.load_state_dict(torch.load(args.netD_A))
# if args.netD_B != "":
#     netD_B.load_state_dict(torch.load(args.netD_B))

In [None]:
# Loss histories, appended to by the training loop every `print_freq` iterations.
g_losses, d_losses = [], []
identity_losses, gan_losses, cycle_losses = [], [], []

# Pools of previously generated fakes that feed the two discriminators.
fake_A_buffer, fake_B_buffer = ReplayBuffer(), ReplayBuffer()


In [None]:
# CycleGAN training loop: per batch, update both generators, then discriminator A,
# then discriminator B; log samples/losses every `print_freq` iterations, checkpoint
# every 5th epoch and step the LR schedulers once per epoch.
wandb.init()
step = 0  # NOTE(review): not used anywhere in this cell
# NOTE(review): the start epoch 55 is hard-coded; it presumably has to match the
# checkpoint named by --resume_from - confirm before changing either one.
for epoch in range(55, args.epochs):
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, data in progress_bar:
        # get batch size data
        # NOTE(review): backward() and optimizer.step() also run inside this autocast
        # context and no GradScaler is used in this cell - confirm this is intended.
        with torch.cuda.amp.autocast():
            real_image_A = data["A"].to(device)
            real_image_B = data["B"].to(device)
            batch_size = real_image_A.size(0)

            # real data label is 1, fake data label is 0.
            real_label = torch.full((batch_size, 1), 1, device=device, dtype=torch.float32)
            fake_label = torch.full((batch_size, 1), 0, device=device, dtype=torch.float32)

            ##############################################
            # (1) Update G network: Generators A2B and B2A
            ##############################################

            # Set G_A and G_B's gradients to zero
            optimizer_G.zero_grad()

            # Identity loss
            # G_B2A(A) should equal A if real A is fed
            identity_image_A = netG_B2A(real_image_A)
            loss_identity_A = identity_loss(identity_image_A, real_image_A) * 5.0 # L1 loss between G_B2A(A) and the real A
            # G_A2B(B) should equal B if real B is fed
            identity_image_B = netG_A2B(real_image_B)
            loss_identity_B = identity_loss(identity_image_B, real_image_B) * 5.0 # L1 loss between G_A2B(B) and the real B

            # GAN loss
            # GAN loss D_A(G_B2A(B))
            fake_image_A = netG_B2A(real_image_B)
            fake_output_A = netD_A(fake_image_A)
            loss_GAN_B2A = adversarial_loss(fake_output_A, real_label) # generator wants D_A to score the fake A as real
            # GAN loss D_B(G_A2B(A))
            fake_image_B = netG_A2B(real_image_A)
            fake_output_B = netD_B(fake_image_B)
            loss_GAN_A2B = adversarial_loss(fake_output_B, real_label) # generator wants D_B to score the fake B as real

            # Cycle loss: A -> fake B -> recovered A should match the real A (and likewise for B)
            recovered_image_A = netG_B2A(fake_image_B)
            loss_cycle_ABA = cycle_loss(recovered_image_A, real_image_A) * 10.0

            recovered_image_B = netG_A2B(fake_image_A)
            loss_cycle_BAB = cycle_loss(recovered_image_B, real_image_B) * 10.0

            # Combined loss and calculate gradients
            errG = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

            # Calculate gradients for G_A and G_B
            errG.backward()
            # Update G_A and G_B's weights
            optimizer_G.step()

            ##############################################
            # (2) Update D network: Discriminator A
            ##############################################

            # Set D_A gradients to zero
            optimizer_D_A.zero_grad()

            # Real A image loss
            real_output_A = netD_A(real_image_A)
            errD_real_A = adversarial_loss(real_output_A, real_label)

            # Fake A image loss (the replay buffer may substitute earlier fakes for current ones)
            fake_image_A = fake_A_buffer.push_and_pop(fake_image_A)
            fake_output_A = netD_A(fake_image_A.detach())
            errD_fake_A = adversarial_loss(fake_output_A, fake_label)

            # Combined loss and calculate gradients
            errD_A = (errD_real_A + errD_fake_A) / 2

            # Calculate gradients for D_A
            errD_A.backward()
            # Update D_A weights
            optimizer_D_A.step()

            ##############################################
            # (3) Update D network: Discriminator B
            ##############################################

            # Set D_B gradients to zero
            optimizer_D_B.zero_grad()
        
            # Real B image loss
            real_output_B = netD_B(real_image_B)
            errD_real_B = adversarial_loss(real_output_B, real_label)

            # Fake B image loss (the replay buffer may substitute earlier fakes for current ones)
            fake_image_B = fake_B_buffer.push_and_pop(fake_image_B)
            fake_output_B = netD_B(fake_image_B.detach())
            errD_fake_B = adversarial_loss(fake_output_B, fake_label)

            # Combined loss and calculate gradients
            errD_B = (errD_real_B + errD_fake_B) / 2

            # Calculate gradients for D_B
            errD_B.backward()
            # Update D_B weights
            optimizer_D_B.step()

            # Show the current losses on the progress bar.
            progress_bar.set_description(
                f"[{epoch}/{args.epochs - 1}][{i}/{len(dataloader) - 1}] "
                f"Loss_D: {(errD_A + errD_B).item():.4f} "
                f"Loss_G: {errG.item():.4f} "
                f"Loss_G_identity: {(loss_identity_A + loss_identity_B).item():.4f} "
                f"loss_G_GAN: {(loss_GAN_A2B + loss_GAN_B2A).item():.4f} "
                f"loss_G_cycle: {(loss_cycle_ABA + loss_cycle_BAB).item():.4f}")

            # Every `print_freq` iterations: save real/fake sample grids to disk,
            # log them with the losses to wandb and append to the in-memory histories.
            if i % args.print_freq == 0:
                vutils.save_image(real_image_A,
                                f"{args.outf}/{args.dataset}/A/real_samples_epoch_{epoch}_{i}.png",
                                normalize=True)
                vutils.save_image(real_image_B,
                                f"{args.outf}/{args.dataset}/B/real_samples_epoch_{epoch}_{i}.png",
                                normalize=True)
                # Regenerate the fakes with the just-updated generators for the saved samples.
                fake_image_A = netG_B2A(real_image_B).data
                fake_image_B = netG_A2B(real_image_A).data

                vutils.save_image(fake_image_A.detach(),
                                f"{args.outf}/{args.dataset}/A/fake_samples_epoch_{epoch}_{i}.png",
                                normalize=True)
                vutils.save_image(fake_image_B.detach(),
                                f"{args.outf}/{args.dataset}/B/fake_samples_epoch_{epoch}_{i}.png",
                                normalize=True)
                wandb.log({
                    "real_image_A" : wandb.Image(real_image_A),
                    "real_image_B" : wandb.Image(real_image_B),
                    "fake_image_A" : wandb.Image(fake_image_A),
                    "fake_image_B" : wandb.Image(fake_image_B),
                    "g_losses": errG.item(),
                    "d_losses": (errD_A + errD_B).item(),
                    "identity_losses": (loss_identity_A + loss_identity_B).item(),
                    "gan_losses":(loss_GAN_A2B + loss_GAN_B2A).item(),
                    "cycle_losses": (loss_cycle_ABA + loss_cycle_BAB).item()
                })
                g_losses.append(errG.item())
                d_losses.append((errD_A + errD_B).item())

                identity_losses.append((loss_identity_A + loss_identity_B).item())
                gan_losses.append((loss_GAN_A2B + loss_GAN_B2A).item())
                cycle_losses.append((loss_cycle_ABA + loss_cycle_BAB).item())


    # do check pointing: every 5th epoch, bundle all model / optimizer / scheduler
    # states into one file (saved before this epoch's scheduler step)
    if epoch % 5 == 0:
        torch.save({
            'netG_A2B' : netG_A2B.state_dict(),
            'netG_B2A' : netG_B2A.state_dict(),
            'netD_A' : netD_A.state_dict(),
            'netD_B' : netD_B.state_dict(),
            'optimizer_G' : optimizer_G.state_dict(),
            'optimizer_D_A' : optimizer_D_A.state_dict(),
            'optimizer_D_B' : optimizer_D_B.state_dict(),
            'lr_scheduler_G' : lr_scheduler_G.state_dict(),
            'lr_scheduler_D_A' : lr_scheduler_D_A.state_dict(),
            'lr_scheduler_D_B' : lr_scheduler_D_B.state_dict(),
        },
        f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/all_(max200_None_Resize)_{epoch}_all.tar")
        # torch.save(netG_A2B.state_dict(), f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netG_A2B_itr0_epoch(max200_None_Resize)_{epoch}.pth")
        # torch.save(netG_B2A.state_dict(), f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netG_B2A_itr0_epoch(max200_None_Resize)_{epoch}.pth")
        # torch.save(netD_A.state_dict(), f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netD_A_itr0_epoch(max200_None_Resize)_{epoch}.pth")
        # torch.save(netD_B.state_dict(), f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netD_B_itr0_epoch(max200_None_Resize)_{epoch}.pth")

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()
    # If needed, log the ground truth together with A/real_sample and B/fake_sample as pairs so they can be inspected with matplotlib.

# save last check pointing: final weights of all four networks.
# All files go to the same Drive folder; the previous netG_A2B path contained a
# typo ("Samsung_AI_Challengehts") that pointed at a different directory than the other saves.
final_weights_dir = "/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN"
for net_name, net in (("netG_A2B", netG_A2B),
                      ("netG_B2A", netG_B2A),
                      ("netD_A", netD_A),
                      ("netD_B", netD_B)):
    torch.save(net.state_dict(), f"{final_weights_dir}/{net_name}_itr0_(max200_None_Resize).pth")

# inference

In [None]:
# Re-evaluate the generator-side losses on the tensors left over from the last
# training batch (real_image_*, identity_image_*), so this cell only works after
# the training cell has run.
# Labels are sized to that batch so they match the discriminator outputs
# (a fixed batch size of 1 only worked through broadcasting).
batch_size = real_image_A.size(0)
real_label = torch.full((batch_size, 1), 1, device=device, dtype=torch.float32)
fake_label = torch.full((batch_size, 1), 0, device=device, dtype=torch.float32)
loss_identity_A = identity_loss(identity_image_A, real_image_A) * 5.0 # L1 loss between G_B2A(A) and the real A
loss_identity_B = identity_loss(identity_image_B, real_image_B) * 5.0 # L1 loss between G_A2B(B) and the real B

# GAN loss D_A(G_B2A(B))
fake_image_A = netG_B2A(real_image_B)
fake_output_A = netD_A(fake_image_A)
loss_GAN_B2A = adversarial_loss(fake_output_A, real_label) # D_A's score for the fake A, compared against the "real" label
# GAN loss D_B(G_A2B(A))
fake_image_B = netG_A2B(real_image_A)
fake_output_B = netD_B(fake_image_B)
loss_GAN_A2B = adversarial_loss(fake_output_B, real_label) # D_B's score for the fake B, compared against the "real" label

In [None]:
 # Cycle loss
# Map each fake back to its source domain, then score the round trip with
# the cycle loss, weighted by 10.
recovered_image_A = netG_B2A(fake_image_B)
recovered_image_B = netG_A2B(fake_image_A)

loss_cycle_ABA = 10.0 * cycle_loss(recovered_image_A, real_image_A)
loss_cycle_BAB = 10.0 * cycle_loss(recovered_image_B, real_image_B)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Visual comparison across saved checkpoints (every 5th epoch up to `max_epoch`).
# One figure row per checkpoint: SEM input | A2B prediction | B2A prediction | depth map.
max_epoch = 30
y = max_epoch // 5 + 1
img_idx = 20000
fig, axes = plt.subplots(y, 4, figsize=(10, y * 5))
colors = ['red', 'brown', 'yellow', 'green', 'blue']
cmap = LinearSegmentedColormap.from_list('name', colors)
norm = plt.Normalize(0, 255)

ckpt_dir = "/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN"

# Generators that receive the per-epoch weights. They are created here so the
# cell also runs on a fresh kernel (they were not defined anywhere before).
loaded_netG_A2B = Generator().to(device)
loaded_netG_B2A = Generator().to(device)

# The inspected image pair does not change per checkpoint: read and convert it once.
test_image = cv2.imread(sem_list_iter0[img_idx], cv2.IMREAD_GRAYSCALE)
depth_image = cv2.imread(depth_list[img_idx], cv2.IMREAD_GRAYSCALE)
# NOTE(review): these tensors hold raw 0-255 values with shape (1, H, W), while the
# training data went through ToTensor() ([0, 1]) - confirm the scale is intended.
sem_tensor = torch.from_numpy(test_image).unsqueeze(0).float().to(device)
depth_tensor = torch.from_numpy(depth_image).unsqueeze(0).float().to(device)

real_label = torch.full((1, 1), 1, device=device, dtype=torch.float32)
fake_label = torch.full((1, 1), 0, device=device, dtype=torch.float32)

for idx, epoch in enumerate(range(0, max_epoch + 1, 5)):
    # map_location lets GPU-saved weights load on whatever device is active.
    loaded_netG_A2B.load_state_dict(torch.load(f"{ckpt_dir}/netG_A2B_itr0_epoch(max200_None_Resize)_{epoch}.pth", map_location=device))
    loaded_netG_B2A.load_state_dict(torch.load(f"{ckpt_dir}/netG_B2A_itr0_epoch(max200_None_Resize)_{epoch}.pth", map_location=device))

    identity_image_A = loaded_netG_B2A(sem_tensor)
    identity_image_B = loaded_netG_A2B(depth_tensor)

    # NOTE(review): the GAN and cycle losses below use the live netG_*/netD_* models,
    # not the checkpoint just loaded, so they do not vary per row - confirm intent.
    fake_image_A = netG_B2A(depth_tensor)
    fake_output_A = netD_A(fake_image_A)
    loss_GAN_B2A = adversarial_loss(fake_output_A, real_label)

    fake_image_B = netG_A2B(sem_tensor)
    fake_output_B = netD_B(fake_image_B)
    loss_GAN_A2B = adversarial_loss(fake_output_B, real_label)

    recovered_image_A = netG_B2A(fake_image_B)
    loss_cycle_ABA = cycle_loss(recovered_image_A, sem_tensor) * 10.0

    recovered_image_B = netG_A2B(fake_image_A)
    loss_cycle_BAB = cycle_loss(recovered_image_B, depth_tensor) * 10.0

    print(test_image[0][0].item(), end=", ")
    axes[idx][0].imshow(test_image, norm=norm, cmap='rainbow')

    predicted_image_A2B = loaded_netG_A2B(sem_tensor).unsqueeze(0)
    print(predicted_image_A2B[0, 0, 0, 0].item(), end=", ")

    predicted_image_B2A = loaded_netG_B2A(depth_tensor).unsqueeze(0)
    print(predicted_image_B2A[0, 0, 0, 0].item(), end=", ")

    axes[idx][1].imshow(predicted_image_A2B.squeeze().detach().cpu() * 255, norm=norm, cmap='rainbow')
    axes[idx][2].imshow(predicted_image_B2A.squeeze().detach().cpu() * 255, norm=norm, cmap='rainbow')
    axes[idx][3].imshow(depth_image.astype(int), norm=norm, cmap='rainbow')
    print(depth_image[0][0].item())
    print(type(identity_loss))
    print(f"loss_identity_A : {identity_loss(identity_image_A, sem_tensor) * 5.0}")
    print(f"loss_identity_B : {identity_loss(identity_image_B, depth_tensor) * 5.0},")
    print(f"loss_GAN_B2A : {loss_GAN_B2A}")
    # Fixed: this line used to print loss_GAN_B2A under the A2B label.
    print(f"loss_GAN_A2B : {loss_GAN_A2B}")
    print(f"loss_cycle_ABA : {loss_cycle_ABA}")
    print(f"loss_cycle_BAB : {loss_cycle_BAB}")
# fig.colorbar(orientation='horizontal')


In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# Visual sweep over the itr0 CycleGAN checkpoints: one figure row per loaded
# epoch (every 5th epoch from 0 to max_epoch), with columns
# [SEM input | A2B prediction | B2A prediction | depth image].
max_epoch = 45
y = max_epoch // 5 + 1
img_idx = 200
fig, axes = plt.subplots(y, 4, figsize=(10, y * 5))
# NOTE(review): colors / cmap / norm are not used by the imshow calls in this
# cell; they are kept only because other cells may read them from kernel state.
colors = ['red', 'brown', 'yellow', 'green', 'blue']
cmap = LinearSegmentedColormap.from_list('name', colors)
norm = plt.Normalize(0, 255)

CKPT_DIR = "/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN"

# The sample pair and the network inputs do not depend on the epoch, so they
# are read and resized once instead of once per checkpoint.
test_image = cv2.imread(sem_list_iter0[img_idx], cv2.IMREAD_GRAYSCALE)
depth_image = cv2.imread(depth_list[img_idx], cv2.IMREAD_GRAYSCALE)
sem_input = torch.from_numpy(
    cv2.resize(test_image, (256, 256), interpolation=cv2.INTER_CUBIC)
).unsqueeze(0).float().to(device)
depth_input = torch.from_numpy(
    cv2.resize(depth_image, (256, 256), interpolation=cv2.INTER_CUBIC)
).unsqueeze(0).float().to(device)

for idx, epoch in enumerate(range(0, max_epoch + 1, 5)):
    # map_location lets checkpoints saved on another device load on `device`.
    loaded_netG_A2B.load_state_dict(
        torch.load(f"{CKPT_DIR}/netG_A2B_itr0_epoch_{epoch}.pth", map_location=device))
    loaded_netG_B2A.load_state_dict(
        torch.load(f"{CKPT_DIR}/netG_B2A_itr0_epoch_{epoch}.pth", map_location=device))

    # Pure inference: no autograd graph is needed for plotting.
    with torch.no_grad():
        # The generator output gets a leading dim so F.interpolate sees a 4-D
        # tensor, then is resized to 72x48.
        predicted_image_A2B = F.interpolate(
            loaded_netG_A2B(sem_input).unsqueeze(0), size=(72, 48), mode='bicubic')
        predicted_image_B2A = F.interpolate(
            loaded_netG_B2A(depth_input).unsqueeze(0), size=(72, 48), mode='bicubic')

    # Top-left pixel of: SEM input, A2B prediction (x255), B2A prediction (x255), depth.
    print(
        test_image[0][0].item(),
        predicted_image_A2B[0, 0, 0, 0].item() * 255.0,
        predicted_image_B2A[0, 0, 0, 0].item() * 255.0,
        depth_image[0][0].item(),
        sep=", ",
    )

    axes[idx][0].imshow(test_image, cmap='gray')
    axes[idx][1].imshow(predicted_image_A2B.squeeze().detach().cpu(), cmap='gray')
    axes[idx][2].imshow(predicted_image_B2A.squeeze().detach().cpu())
    axes[idx][3].imshow(depth_image)


In [None]:
import matplotlib.pyplot as plt
# Visual sweep over the (non-itr0) CycleGAN checkpoints for one training SEM
# image: one row per loaded epoch (every 5th epoch from 0 to max_epoch), with
# columns [input resized to 72x48 | B2A output | A2B output resized to 72x48].
max_epoch = 25
y = max_epoch // 5 + 1
img_idx = 100
fig, axes = plt.subplots(y, 3, figsize=(15, y * 5))

CKPT_DIR = "/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN"

# The input does not depend on the epoch: read, scale to [0, 1] and resize once.
# img_idx is used here instead of a second hard-coded 100.
test_image = cv2.resize(
    cv2.imread(train_sem_list[img_idx], cv2.IMREAD_GRAYSCALE) / 255.0,
    (256, 256),
    interpolation=cv2.INTER_CUBIC,
)
test_preview = cv2.resize(test_image, (48, 72), interpolation=cv2.INTER_CUBIC)
test_tensor = torch.from_numpy(test_image).unsqueeze(0).float().to(device)

for idx, epoch in enumerate(range(0, max_epoch + 1, 5)):
    # map_location lets checkpoints saved on another device load on `device`.
    loaded_netG_B2A.load_state_dict(
        torch.load(f"{CKPT_DIR}/netG_B2A_epoch_{epoch}.pth", map_location=device))
    loaded_netG_A2B.load_state_dict(
        torch.load(f"{CKPT_DIR}/netG_A2B_epoch_{epoch}.pth", map_location=device))

    # Pure inference: no autograd graph is needed for plotting.
    with torch.no_grad():
        # NOTE(review): both generators receive the same SEM image here, and only
        # the A2B output is resized to 72x48 -- confirm this is intended.
        predicted_image_B2A = loaded_netG_B2A(test_tensor)
        predicted_image_A2B = F.interpolate(
            loaded_netG_A2B(test_tensor).unsqueeze(0), size=(72, 48), mode='bicubic')

    axes[idx][0].imshow(test_preview, 'gray')
    axes[idx][1].imshow(predicted_image_B2A.squeeze().detach().cpu(), 'gray')
    axes[idx][2].imshow(predicted_image_A2B.squeeze().detach().cpu(), 'gray')

In [None]:
# Loader over test_dataset: batches of 64, dataset order preserved
# (shuffle=False), pinned host memory enabled.
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Debug check: creates a fresh iterator over test_dataloader and prints its repr
# (no batch is actually fetched).
print(iter(test_dataloader))

In [None]:
# Build a fresh Generator on `device` and restore the A2B weights from the
# epoch-30 "max200_None_Resize" checkpoint. map_location keeps the load working
# when the checkpoint was saved on a different device than the current runtime.
inference_model = Generator().to(device)
inference_model.load_state_dict(
    torch.load(
        f"/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN/netG_A2B_itr0_epoch(max200_None_Resize)_{30}.pth",
        map_location=device,
    )
)

In [None]:
import zipfile
def inference(model, test_loader, device, size=(72, 48)):
    """Run `model` over `test_loader` and collect resized predictions.

    Args:
        model: network to evaluate; it is moved to `device` and put in eval mode.
        test_loader: iterable yielding `(sem, name)` batches, where `sem` is a
            tensor batch and `name` is the matching sequence of file names.
        device: device the model and each input batch are moved to.
        size: `(height, width)` each prediction is resized to with bicubic
            interpolation. Defaults to the previously hard-coded `(72, 48)`.

    Returns:
        `(result_name_list, result_list)`: the file names (as strings) and, in
        the same order, the predictions as NumPy arrays of shape
        `(height, width, channels)` multiplied by 255.0.

    Side effects:
        Creates `./submission` relative to the current working directory; the
        images themselves are not written here.
    """
    model.to(device)
    model.eval()

    result_name_list = []
    result_list = []
    with torch.no_grad():
        for sem, names in tqdm(test_loader):
            model_pred = model(sem.float().to(device))

            for pred, img_name in zip(model_pred, names):
                # F.interpolate needs a batch dimension, so add one per sample.
                resized = F.interpolate(pred.unsqueeze(0), size=size, mode='bicubic')
                # (1, C, H, W) -> (H, W, C), scaled by 255 for image writing.
                resized = resized.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 255.0
                result_name_list.append(f'{img_name}')
                result_list.append(resized)

    # Kept from the original: ensures the output directory exists for the caller.
    os.makedirs('./submission', exist_ok=True)

    return result_name_list, result_list


In [None]:
# Count the entries directly under the triple-nested submission directory.
# NOTE(review): this nesting presumably comes from creating ./submission while
# the working directory was already inside a submission folder -- confirm.
nested_submission_entries = glob('/content/submission/submission/submission/*')
print(len(nested_submission_entries))

In [None]:
# Delete the previous /content/submission tree, then make /content the working
# directory again.
!rm -r /content/submission
%cd /content

In [None]:
# Restore both generators from the epoch-30 "max200_None_Resize" checkpoints,
# move them to `device` and switch them to eval mode. map_location keeps the
# load working when the checkpoints were saved on a different device.
CKPT_DIR = "/content/drive/MyDrive/Colab_Notebooks/Dacon/Samsung_AI_Challenge/CycleGAN"
for generator, direction in ((loaded_netG_B2A, "B2A"), (loaded_netG_A2B, "A2B")):
    weights_path = f"{CKPT_DIR}/netG_{direction}_itr0_epoch(max200_None_Resize)_{30}.pth"
    generator.load_state_dict(torch.load(weights_path, map_location=device))
    generator.to(device)
    generator.eval()

# Only the A2B generator is used to produce the predictions.
result_name_list, result_list = inference(loaded_netG_A2B, test_dataloader, device)

In [None]:
# Show the current working directory (the next cell changes it with os.chdir).
!pwd

In [None]:
# Write every prediction into /content/submission under its own file name, then
# pack those files into /content/submission.zip. The working directory is
# switched first so the archive stores bare file names.
os.chdir('/content/submission')
sub_imgs = []
for name, pred_img in zip(result_name_list, result_list):
    # cv2.imwrite signals failure through its return value; fail loudly here
    # instead of later, when the missing file is added to the archive.
    if not cv2.imwrite(name, pred_img):
        raise IOError(f"cv2.imwrite failed for {name}")
    sub_imgs.append(name)

# The context manager closes the archive even if adding a file raises.
with zipfile.ZipFile("../submission.zip", 'w') as submission:
    for path in sub_imgs:
        submission.write(path)

In [None]:
# Count the entries directly under /content/submission.
submission_entries = glob('/content/submission/*')
print(len(submission_entries))

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Alternative inline inference: run the A2B generator over the test loader and
# write each prediction straight to /content/submission/<name>.
SUBMISSION_DIR = '/content/submission'

# Kept from the original cell: both lists are reset here and not filled below.
result_name_list = []
result_list = []

# Create the directory the images are actually written to (the original created
# ./submission relative to the working directory, which may be somewhere else).
os.makedirs(SUBMISSION_DIR, exist_ok=True)

last_saved = None
with torch.no_grad():
    for img, names in tqdm(test_dataloader):
        img = img.to(device).float()
        img = img.unsqueeze(1)  # insert a dimension at position 1 before the generator call
        prediction_image = loaded_netG_A2B(img)
        for pred_img, name in zip(prediction_image, names):
            # Fix: the resized tensor used to be stored in a misspelled variable
            # and discarded, so the un-resized prediction was saved.
            pred_img = F.interpolate(pred_img.unsqueeze(0), (72, 48), mode='bilinear')
            # Scaled by 255 before writing, consistent with inference().
            pred_img = pred_img.squeeze().cpu().numpy() * 255.0
            save_img_path = f'/content/submission/{name}'
            cv2.imwrite(save_img_path, pred_img)
            last_saved = pred_img

# Preview only the final prediction instead of calling imshow once per image.
if last_saved is not None:
    plt.imshow(last_saved, 'gray')

In [None]:
# Recursively archive everything under /content/submission into
# /content/submission.zip.
# NOTE(review): the inputs are given as absolute paths, so the archive entries
# presumably keep a content/submission/ prefix -- confirm the expected layout.
!zip -r /content/submission.zip /content/submission/*

In [None]:
# Copy the generated archive to Google Drive (MyDrive/Data).
!cp /content/submission.zip /content/drive/MyDrive/Data/submission.zip

In [None]:
# Copy sample_submission.zip from /content to Google Drive (MyDrive/Data).
!cp /content/sample_submission.zip /content/drive/MyDrive/Data/sample_submission.zip