# Import

In [1]:
import torch 
import torch.nn as nn
import os
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from torchvision import transforms, utils, datasets 
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.utils.data.sampler import SubsetRandomSampler
import gc


# Drive

In [2]:
# Colab-only: mount Google Drive at /content/drive so the dataset, checkpoint
# and output folders configured in the "Paths" cell are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Dataset

In [3]:
class SourceTargetDataset(Dataset):
  """Unpaired two-domain image dataset (source folder + target folder).

  Item ``index`` is the pair ``(source[index % len_source],
  target[index % len_target])``, so the dataset is as long as the larger
  folder and the smaller folder wraps around.

  Args:
    root_source: directory containing the source-domain images.
    root_target: directory containing the target-domain images.
    transform: optional callable applied independently to each image
      (which is handed over as an RGB ``numpy`` array).

  Raises:
    ValueError: if either directory contains no entries.
  """

  def __init__(self, root_source, root_target, transform = None):
    self.root_source = root_source
    self.root_target = root_target
    self.transform = transform

    # os.listdir returns entries in arbitrary, filesystem-dependent order;
    # sorting makes the index -> file mapping (and therefore the
    # source/target pairing) reproducible across runs and machines.
    self.images_source = sorted(os.listdir(self.root_source))
    self.images_target = sorted(os.listdir(self.root_target))

    self.len_source = len(self.images_source)
    self.len_target = len(self.images_target)

    # Fail fast: an empty folder would otherwise surface much later as a
    # ZeroDivisionError from the modulo in __getitem__.
    if self.len_source == 0 or self.len_target == 0:
      raise ValueError(
          f"Both image folders must be non-empty, got {self.len_source} file(s) in "
          f"{self.root_source!r} and {self.len_target} file(s) in {self.root_target!r}"
      )

    self.length_dataset = max(self.len_source, self.len_target)

  def __len__(self):
    """Return the size of the larger of the two folders."""
    return self.length_dataset

  def __getitem__(self, index):
    """Load and return the ``(source, target)`` image pair for ``index``."""
    # Modulo lets the shorter domain wrap around instead of running out.
    img_source = self.images_source[index % self.len_source]
    img_target = self.images_target[index % self.len_target]

    path_source = os.path.join(self.root_source, img_source)
    path_target = os.path.join(self.root_target, img_target)

    # Force three channels so grayscale / RGBA files load consistently.
    img_source = np.array(Image.open(path_source).convert("RGB"))
    img_target = np.array(Image.open(path_target).convert("RGB"))

    if self.transform:
      img_source = self.transform(img_source)
      img_target = self.transform(img_target)

    return img_source, img_target

class TestDataset(Dataset):
  """Single-folder image dataset used at inference time.

  Args:
    root_source: directory containing the images to translate.
    transform: optional callable applied to each image (which is handed
      over as an RGB ``numpy`` array).
  """

  def __init__(self, root_source, transform = None):
    self.root_source = root_source
    self.transform = transform
    # Sorted so that item ``i`` always refers to the same file; os.listdir
    # alone gives an arbitrary order, which would make index-based output
    # names impossible to trace back to their inputs.
    self.images_source = sorted(os.listdir(self.root_source))
    self.length_dataset = len(self.images_source)

  def __len__(self):
    """Return the number of entries found in ``root_source``."""
    return self.length_dataset

  def __getitem__(self, index):
    """Load image ``index`` as RGB and apply ``transform`` if one was given."""
    img_source = self.images_source[index]
    path_source = os.path.join(self.root_source, img_source)
    img_source = np.array(Image.open(path_source).convert("RGB"))

    if self.transform:
      img_source = self.transform(img_source)

    return img_source

In [4]:
def show_images(images):
    """Plot a batch of images as a square grid of axis-less tiles.

    Every image is flattened and then reshaped to a square of side
    ``ceil(sqrt(num_values))``, so each image must hold a perfect-square
    number of values (e.g. a single-channel square image).

    Args:
        images: array whose first axis indexes the images.
    """
    flat = np.reshape(images, [images.shape[0], -1])
    grid_side = int(np.ceil(np.sqrt(flat.shape[0])))
    tile_side = int(np.ceil(np.sqrt(flat.shape[1])))

    plt.figure(figsize=(grid_side, grid_side))
    grid = gridspec.GridSpec(grid_side, grid_side)
    grid.update(wspace=0.05, hspace=0.05)

    for position in range(len(flat)):
        tile = flat[position]
        ax = plt.subplot(grid[position])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(tile.reshape([tile_side, tile_side]))
    return

# Discriminator

In [None]:
class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        stride: stride of the convolution.
    """

    def __init__(self,in_channels,out_channels,stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        normalisation = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        # Attribute name `conv` is part of the state_dict keys; keep it.
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self,x):
        """Apply the conv/norm/activation stack to ``x``."""
        return self.conv(x)

In [None]:
class Discriminator(nn.Module):
  """Convolutional patch discriminator built from ``Block`` stages.

  Layout: a stride-2 4x4 convolution to ``features[0]`` channels, one
  ``Block`` per remaining entry of ``features`` (stride 2, except the last
  block which uses stride 1), and a final 4x4 convolution down to a single
  channel. ``forward`` squashes that map with a sigmoid, so the result is a
  grid of per-patch scores in (0, 1).

  Args:
    in_channels: number of channels of the input image.
    features: channel widths of the successive stages.

  NOTE(review): the first convolution is not followed by an activation,
  unlike the reference PatchGAN (C64 = Conv + LeakyReLU). Inserting one would
  shift the ``nn.Sequential`` indices and invalidate saved checkpoints, so it
  is deliberately left untouched here.
  """

  def __init__(self, in_channels=3, features = (64, 128, 256, 512)):
    super().__init__()
    # Tuple default avoids a shared mutable default argument; any sequence
    # (including the lists existing callers may pass) is accepted.
    features = list(features)

    layers = [
        nn.Conv2d(in_channels, features[0], kernel_size = 4, stride= 2, padding=1, padding_mode= "reflect")
    ]

    in_channels = features[0]
    last_position = len(features) - 1
    for position, feature in enumerate(features[1:], start=1):
      # Pick the stride-1 block by position, not by comparing widths: a value
      # comparison would also match earlier stages whenever widths repeat
      # (e.g. [64, 128, 256, 256]).
      stride = 1 if position == last_position else 2
      layers.append(Block(in_channels, feature, stride = stride))
      in_channels = feature

    layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect" ))
    self.model = nn.Sequential(*layers)

  def forward(self,x):
    """Return the sigmoid-activated patch score map for ``x``."""
    return torch.sigmoid(self.model(x))



# Generator

## Generator architecture: c7s1-64, d128, d256, R256, R256, R256, R256, R256, R256, u128, u64, c7s1-3

In [None]:
class ConvBlock(nn.Module):
  """Generator stage: (transposed) convolution -> InstanceNorm -> optional ReLU.

  Args:
    in_channels: number of input feature maps.
    out_channels: number of output feature maps.
    down: when truthy use a reflect-padded ``nn.Conv2d``; otherwise use
      ``nn.ConvTranspose2d``.
    use_act: when truthy finish with an in-place ReLU, otherwise with an
      identity layer.
    **kwargs: forwarded unchanged to the chosen convolution layer
      (kernel size, stride, padding, ...).
  """

  def __init__(self, in_channels, out_channels, down = True, use_act = True, **kwargs):
    super().__init__()

    if down:
      convolution = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
    else:
      convolution = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

    if use_act:
      activation = nn.ReLU(inplace=True)
    else:
      activation = nn.Identity()

    # Attribute name `conv` is part of the state_dict keys; keep it.
    self.conv = nn.Sequential(
        convolution,
        nn.InstanceNorm2d(out_channels),
        activation,
    )

  def forward(self, x):
    """Run ``x`` through the conv/norm/activation stack."""
    return self.conv(x)

In [None]:
class ResidualBlock(nn.Module):
  """Residual unit: two 3x3 ``ConvBlock`` layers plus a skip connection.

  The first ``ConvBlock`` keeps its activation; the second is built with
  ``use_act=False``. Channel count is unchanged, and ``kernel_size=3`` with
  ``padding=1`` keeps the spatial size, so the input can be added back.

  Args:
    channels: number of feature maps going in and coming out.
  """

  def __init__(self, channels):
    super().__init__()

    activated = ConvBlock(channels, channels, kernel_size = 3, padding=1)
    linear = ConvBlock(channels, channels, use_act=False, kernel_size = 3, padding=1)
    # Attribute name `conv` is part of the state_dict keys; keep it.
    self.conv = nn.Sequential(activated, linear)

  def forward(self, x):
    """Return ``x`` plus the output of the two stacked ``ConvBlock`` layers."""
    residual = self.conv(x)
    return x + residual


In [None]:
class Generator(nn.Module):
  """ResNet-style image-to-image generator.

  Layout (for the defaults): c7s1-64, d128, d256, ``num_residuals`` x R256,
  u128, u64, c7s1-out, followed by ``tanh`` in ``forward``.

  Args:
    img_size: number of image *channels* (despite the name, it is used as the
      channel count of the first convolution). The output now has the same
      number of channels, so the L1 cycle/identity losses compare tensors of
      matching shape; with the default of 3 this is identical to the previous
      hard-coded 3-channel output.
    num_features: channel width after the first convolution.
    num_residuals: number of residual blocks in the bottleneck.

  Note:
    Every stage is stored both as an attribute (``initial``, ``down_blocks``,
    ...) and inside ``self.model``. The weights are shared, but the state dict
    therefore carries each tensor under two keys. This layout is kept on
    purpose so that checkpoints saved by earlier runs still load.
  """

  def __init__(self, img_size=3, num_features = 64, num_residuals=6):
    super().__init__()

    # c7s1-64
    self.initial = nn.Sequential(
        nn.Conv2d(img_size, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
        nn.InstanceNorm2d(num_features=num_features),
        nn.ReLU(inplace=True)
    )

    # d128, d256
    self.down_blocks = nn.Sequential(
        ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
        ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1)
    )

    # R256 x num_residuals
    self.res_blocks = nn.Sequential(
        *[ResidualBlock(num_features*4) for _ in range(num_residuals)]
    )

    # u128, u64 (output_padding=1 makes each stride-2 transposed conv exactly double H and W)
    self.up_blocks = nn.Sequential(
        ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
        ConvBlock(num_features*2, num_features, down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
    )

    # c7s1-out: map back to the input channel count (previously hard-coded to 3).
    self.out_block = nn.Sequential(
        nn.Conv2d(num_features, img_size, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
        nn.InstanceNorm2d(num_features=img_size),
    )

    # Same stage order and registration order as before, so state_dict keys are unchanged.
    self.model = nn.Sequential(
        self.initial,
        self.down_blocks,
        self.res_blocks,
        self.up_blocks,
        self.out_block,
    )

  def forward(self, x):
    """Translate ``x`` and squash the result to (-1, 1) with ``tanh``."""
    return torch.tanh(self.model(x))

# Hyperparameters

In [None]:
# Compute device: use the GPU when PyTorch can see one, otherwise the CPU.
# Kept as a plain string because other cells compare it with "cuda".
if torch.cuda.is_available():
  DEVICE = "cuda"
else:
  DEVICE = "cpu"
# To force CPU execution, uncomment:
#DEVICE = "cpu"

# Optimisation settings
LEARNING_RATE = 1e-4   # Adam step size shared by generators and discriminators
BATCH_SIZE = 1         # images per training step
NUM_WORKERS = 2        # DataLoader worker processes
NUM_EPOCHS = 100       # full passes over the training set

# Loss weights
LAMBDA_CYCLE = 10                        # weight of the cycle-consistency L1 term
LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE     # weight of the identity L1 term

# Paths

In [None]:
# Project root on the mounted Google Drive.
DRIVE_MAIN_FOLDER = "/content/drive/MyDrive/COMP411/Project/"

# All directory constants are built with os.path.join and end with a path
# separator (the trailing "" component), so `DIR + filename` keeps working
# without producing the doubled "//" that plain string concatenation gave
# for the models / testOutputs folders.

# Data
DRIVE_DATASET_FOLDER = os.path.join(DRIVE_MAIN_FOLDER, "dataset", "")

FIRE_TRAIN_PATH = os.path.join(DRIVE_DATASET_FOLDER, "Fire", "")
GRASS_TRAIN_PATH = os.path.join(DRIVE_DATASET_FOLDER, "Grass", "")
FIRE_TEST_PATH = os.path.join(DRIVE_DATASET_FOLDER, "testFire", "")
GRASS_TEST_PATH = os.path.join(DRIVE_DATASET_FOLDER, "testGrass", "")

# MODELS
MODEL_CHECK_DIR = os.path.join(DRIVE_MAIN_FOLDER, "models", "")

# Test Outputs
TEST_PATH = os.path.join(DRIVE_MAIN_FOLDER, "testOutputs", "")


# Train 

In [None]:
def train(disc_X, disc_Y, gen_X, gen_Y, loader, optim_disc, optim_gen,loss_histroy):
  """Run one CycleGAN training epoch over ``loader``.

  For every ``(x, y)`` batch the two discriminators are updated first
  (least-squares loss: real -> 1, detached fake -> 0), then the two
  generators (adversarial + weighted cycle-consistency + weighted identity
  loss). Mixed precision is used only when ``DEVICE == "cuda"``.

  Args:
    disc_X: discriminator for domain X.
    disc_Y: discriminator for domain Y.
    gen_X: generator producing domain-X images (called on ``y``).
    gen_Y: generator producing domain-Y images (called on ``x``).
    loader: iterable yielding ``(x, y)`` batches.
    optim_disc: optimizer over both discriminators' parameters.
    optim_gen: optimizer over both generators' parameters.
    loss_histroy: dict of lists with keys ``"D_total"``, ``"G_total"``,
      ``"G_x"``, ``"G_y"``, ``"G_cycle"`` and ``"G_identity"``; one Python
      float per step is appended to each list. (The parameter name is
      misspelled; it is kept unchanged so existing callers are unaffected.)

  Returns:
    None. Models, optimizers and the history dict are updated in place.
  """
  # Use the dict that was actually passed in. Previously the body referred to
  # a global named `loss_history` and silently ignored this argument.
  loss_history = loss_histroy

  loop = tqdm(loader, leave = True, position=0)

  l1 = nn.L1Loss()
  mse = nn.MSELoss()

  # With enabled=False both autocast and GradScaler become pass-throughs, so
  # one code path serves CPU and GPU.
  use_amp = DEVICE == "cuda"
  g_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  d_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

  for x, y in loop:
    x = x.to(DEVICE)
    y = y.to(DEVICE)

    # ------------------------------------------------------------------
    # Discriminator step
    # ------------------------------------------------------------------
    optim_disc.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
      # Discriminator - Y
      D_y_real = disc_Y(y)
      D_y_real_loss = mse(D_y_real, torch.ones_like(D_y_real))

      fake_y = gen_Y(x)
      # detach(): the discriminator update must not back-propagate into the generator.
      D_y_fake = disc_Y(fake_y.detach())
      D_y_fake_loss = mse(D_y_fake, torch.zeros_like(D_y_fake))

      D_y_loss = D_y_real_loss + D_y_fake_loss

      # Discriminator - X
      D_x_real = disc_X(x)
      D_x_real_loss = mse(D_x_real, torch.ones_like(D_x_real))

      fake_x = gen_X(y)
      D_x_fake = disc_X(fake_x.detach())
      D_x_fake_loss = mse(D_x_fake, torch.zeros_like(D_x_fake))

      D_x_loss = D_x_real_loss + D_x_fake_loss

      # Discriminator - Total
      D_loss = D_y_loss + D_x_loss

    # Backward pass and optimizer step run outside autocast, as the AMP
    # documentation recommends.
    d_scaler.scale(D_loss).backward()
    d_scaler.step(optim_disc)
    d_scaler.update()

    # ------------------------------------------------------------------
    # Generator step
    # ------------------------------------------------------------------
    optim_gen.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
      # Adversarial loss (least-squares form of):
      # L_GAN(G_Y, D_Y, X, Y) = E_y[log D_Y(y)] + E_x[log(1 - D_Y(G_Y(x)))]
      # L_GAN(G_X, D_X, Y, X) = E_x[log D_X(x)] + E_y[log(1 - D_X(G_X(y)))]
      # The fakes are reused un-detached so gradients reach the generators.
      D_x_fake = disc_X(fake_x)
      D_y_fake = disc_Y(fake_y)

      G_x_loss = mse(D_x_fake, torch.ones_like(D_x_fake))
      G_y_loss = mse(D_y_fake, torch.ones_like(D_y_fake))

      # Cycle consistency loss:
      # L_cyc(G_Y, G_X) = E_x[|G_X(G_Y(x)) - x|_1] + E_y[|G_Y(G_X(y)) - y|_1]
      cycle_x = gen_X(fake_y)
      cycle_y = gen_Y(fake_x)

      cycle_x_loss = l1(cycle_x, x)
      cycle_y_loss = l1(cycle_y, y)

      cycle_loss = cycle_x_loss + cycle_y_loss

      # Identity loss:
      # L_identity = E_y[|G_Y(y) - y|_1] + E_x[|G_X(x) - x|_1]
      identity_x_loss = l1(gen_X(x), x)
      identity_y_loss = l1(gen_Y(y), y)

      identity_loss = identity_x_loss + identity_y_loss

      # Total loss
      G_loss = G_x_loss + G_y_loss + cycle_loss * LAMBDA_CYCLE + identity_loss * LAMBDA_IDENTITY

    g_scaler.scale(G_loss).backward()
    g_scaler.step(optim_gen)
    g_scaler.update()

    # ------------------------------------------------------------------
    # Bookkeeping
    # ------------------------------------------------------------------
    # .item() stores plain floats: nothing in the history keeps an autograd
    # graph or device memory alive (G_y used to be appended with its graph
    # attached), and the values can be plotted directly.
    step_losses = {
        "D_total": D_loss.item(),
        "G_total": G_loss.item(),
        "G_x": G_x_loss.item(),
        "G_y": G_y_loss.item(),
        "G_cycle": cycle_loss.item(),
        "G_identity": identity_loss.item(),
    }
    for key, value in step_losses.items():
      loss_history[key].append(value)

    # Report on the progress bar instead of printing two lines per step,
    # which floods the cell output and breaks the tqdm bar.
    loop.set_postfix(
        D_total=f"{step_losses['D_total']:.3f}",
        D_x=f"{D_x_loss.item():.3f}",
        D_y=f"{D_y_loss.item():.3f}",
        G_total=f"{step_losses['G_total']:.3f}",
        G_x=f"{step_losses['G_x']:.3f}",
        G_y=f"{step_losses['G_y']:.3f}",
        cycle=f"{step_losses['G_cycle']:.3f}",
        identity=f"{step_losses['G_identity']:.3f}",
    )

# Main

In [None]:
# Release memory left over from a previous run of this cell.
gc.collect()
torch.cuda.empty_cache()

# Seed *before* the networks are created so that weight initialisation is
# reproducible too (it used to be set only after the models were built).
torch.manual_seed(130)

disc_X = Discriminator().to(DEVICE)
disc_Y = Discriminator().to(DEVICE)
gen_Y = Generator().to(DEVICE)
gen_X = Generator().to(DEVICE)

# Uncomment to resume training from the checkpoints written by a previous run.
#disc_X.load_state_dict(torch.load(MODEL_CHECK_DIR+"batchsize1_fire2grass_discriminator_X_param.pkl"))
#disc_Y.load_state_dict(torch.load(MODEL_CHECK_DIR+"batchsize1_fire2grass_discriminator_Y_param.pkl"))
#gen_X.load_state_dict(torch.load(MODEL_CHECK_DIR+"batchsize1_fire2grass_generator_X_param.pkl"))
#gen_Y.load_state_dict(torch.load(MODEL_CHECK_DIR+"batchsize1_fire2grass_generator_Y_param.pkl"))

# One independent list per loss. dict.fromkeys(keys, []) would make every key
# share the *same* list object, mixing all loss curves together.
loss_history = {
    key: [] for key in ['D_total','G_total','G_x','G_y','G_cycle','G_identity']
}

# Images are converted to tensors and centre-cropped to 256x256.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(256),
 ])

opt_disc = optim.Adam(
    list(disc_X.parameters()) + list(disc_Y.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

opt_gen = optim.Adam(
    list(gen_Y.parameters()) + list(gen_X.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

dataset = SourceTargetDataset(
    root_source=FIRE_TRAIN_PATH,
    root_target=GRASS_TRAIN_PATH,
    transform=transform
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory=(DEVICE == "cuda")
)
print(f"lr: {LEARNING_RATE},batch size: {BATCH_SIZE},lambda cycle: {LAMBDA_CYCLE}, lambda identity: {LAMBDA_IDENTITY}")

# torch.save does not create missing folders; make sure the target exists
# before the first (long) epoch finishes.
os.makedirs(MODEL_CHECK_DIR, exist_ok=True)

# Checkpoint name fragment -> network; file names are identical to earlier runs.
checkpoint_models = {
    "generator_X": gen_X,
    "generator_Y": gen_Y,
    "discriminator_X": disc_X,
    "discriminator_Y": disc_Y,
}

for epoch in range(NUM_EPOCHS):
  print(f"\nEpoch #{epoch}\n")
  train(disc_X, disc_Y, gen_X, gen_Y, loader, opt_disc, opt_gen, loss_history)
  # Overwrite the checkpoints after every epoch so an interrupted session can resume.
  for checkpoint_name, network in checkpoint_models.items():
    torch.save(network.state_dict(), MODEL_CHECK_DIR + f"batchsize1_fire2grass_{checkpoint_name}_param.pkl")


Output hidden; open in https://colab.research.google.com to view.

# Test

In [None]:
# Same preprocessing as during training.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(256),
 ])

test_dataset_f = TestDataset(
    root_source =FIRE_TEST_PATH,
    transform = transform,
)

test_loader = DataLoader(
        test_dataset_f,
        batch_size=1,
        shuffle=False,   # keep dataset order so test{idx}.png is traceable
        pin_memory=(DEVICE == "cuda"),
    )

# Load the trained generators onto the active device (falls back to CPU when
# no GPU is available). eval() marks them as inference-only.
gen_X = Generator().to(DEVICE)
gen_X.load_state_dict(torch.load(MODEL_CHECK_DIR+"batchsize1_fire2grass_generator_X_param.pkl",map_location=DEVICE))
gen_X.eval()   # loaded for completeness; not used in the loop below
gen_Y = Generator().to(DEVICE)
gen_Y.load_state_dict(torch.load(MODEL_CHECK_DIR+"batchsize1_fire2grass_generator_Y_param.pkl",map_location=DEVICE))
gen_Y.eval()

# save_image does not create missing folders.
os.makedirs(TEST_PATH, exist_ok=True)

with torch.no_grad():
  loop = tqdm(test_loader, leave=True, position = 0)
  for idx, x in enumerate(loop):
    y = gen_Y(x.to(DEVICE))
    # os.path.join avoids the doubled "/" that TEST_PATH + "/test..." produced.
    save_image(y, os.path.join(TEST_PATH, f"test{idx}.png"))
  

100%|██████████| 49/49 [01:33<00:00,  1.91s/it]
