In [1]:
from google.colab import drive

# Mount Google Drive; the project code (MLDL/BiseNetv1), classes_info.json and
# checkpoints used below all live under /content/drive/MyDrive.
drive.mount('/content/drive', force_remount=False)

Mounted at /content/drive


In [2]:
# Work from the project directory on Drive; the later `from model...`, `from dataset...`
# and `from utils import ...` imports presumably resolve relative to this directory.
%cd /content/drive/MyDrive/MLDL/BiseNetv1/

/content/drive/.shortcut-targets-by-id/1H1LF-uIDd32OaHXtZjb_qT8ZexZ_Th3G/MLDL/BiseNetv1


In [3]:
%%time
%cd /content
%pwd
!gsutil cp gs://recsys-2021-bucket/CamVid.zip /content/CamVid.zip
!gsutil cp gs://recsys-2021-bucket/IDDA.zip /content/IDDA.zip
!unzip -q CamVid.zip
!unzip -q IDDA.zip
!cp /content/drive/MyDrive/MLDL/classes_info.json /content/IDDA/classes_info.json
%cd /content/drive/MyDrive/MLDL/BiseNetv1/

/content
Copying gs://recsys-2021-bucket/CamVid.zip...
- [1 files][579.3 MiB/579.3 MiB]                                                
Operation completed over 1 objects/579.3 MiB.                                    
Copying gs://recsys-2021-bucket/IDDA.zip...
/ [1 files][  4.9 GiB/  4.9 GiB]   19.8 MiB/s                                   
Operation completed over 1 objects/4.9 GiB.                                      
/content/drive/.shortcut-targets-by-id/1H1LF-uIDd32OaHXtZjb_qT8ZexZ_Th3G/MLDL/BiseNetv1
CPU times: user 1.63 s, sys: 309 ms, total: 1.93 s
Wall time: 4min 43s


In [4]:
import os
import gc

from IPython.core.display import display, HTML

import torch.backends.cudnn as cudnn
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from model.build_BiSeNet import BiSeNet
from dataset.IDDA import IDDA
from dataset.CamVid import CamVid
import matplotlib.pyplot as plt

import numpy as np
import seaborn as sns
sns.set_theme()
%config InlineBackend.figure_format = 'retina'
from utils import reverse_one_hot, colour_code_segmentation
from torchvision import transforms as T

# Standard ImageNet per-channel mean / std, as float32 tensors.
mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)

# Forward transform: x -> (x - mean) / std.
normalize = T.Normalize(mean, std)

# Inverse transform: normalizing with mean' = -mean/std and std' = 1/std gives
# (y + mean/std) * std = y * std + mean, i.e. it undoes `normalize`.
unnormalize = T.Normalize(
    mean=(-mean / std).tolist(),
    std=(1.0 / std).tolist(),
)


In [5]:
# --- Model / input geometry ---
NUM_CLASSES = 12
CROP_HEIGHT, CROP_WIDTH = 720, 960
CONTEXT_PATH = 'resnet101'  # backbone name passed to BiSeNet(NUM_CLASSES, CONTEXT_PATH)

# --- Optimisation: SGD + polynomial learning-rate decay ---
LEARNING_RATE_SEGMENTATION = 2.5e-4
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9
POWER = 0.9          # exponent of the poly schedule (adjust_learning_rate)
NUM_STEPS = 250000   # horizon of the poly schedule; NOTE(review): main() steps the
                     # schedule once per epoch, so with this horizon lr barely decays
ITER_SIZE = 1        # not referenced by main() below

# --- Fourier Domain Adaptation ---
BETA = 0.09  # swapped low-frequency window side = floor(min(H, W) * BETA); also used in checkpoint names

# --- Checkpointing ---
CHECKPOINT_STEP = 15
CHECKPOINT_PATH = './checkpointBisenetFDA/'

# --- Data loading ---
BATCH_SIZE_CAMVID = 2
BATCH_SIZE_IDDA = 2
NUM_WORKERS = 0
LOSS = 'dice'

# CamVid (target domain): train + val splits are used together for training.
CAMVID_PATH = ['/content/CamVid/train/', '/content/CamVid/val/']
CAMVID_LABEL_PATH = ['/content/CamVid/train_labels/', '/content/CamVid/val_labels/']
CAMVID_TEST_PATH = ['/content/CamVid/test/']
CAMVID_TEST_LABEL_PATH = ['/content/CamVid/test_labels/']
CSV_CAMVID_PATH = '/content/CamVid/class_dict.csv'

# IDDA (source domain).
IDDA_PATH = '/content/IDDA/rgb/'
IDDA_LABEL_PATH = '/content/IDDA/labels'
JSON_IDDA_PATH = '/content/IDDA/classes_info.json'

In [6]:
import torch

def low_freq_mutate(amp_src, amp_trg, beta):
    """Copy the four corner squares of ``amp_trg`` into ``amp_src``.

    Expects 4-D tensors (n, c, h, w). For an unshifted FFT spectrum, e.g. the
    ``torch.fft.fftn`` output in FDA_source_to_target (no ``fftshift``), the
    corners hold the lowest frequencies. Each square has side
    ``floor(min(h, w) * beta)``; a side of 0 leaves ``amp_src`` unchanged.

    ``amp_src`` is modified in place and also returned.
    """
    _, _, height, width = amp_src.size()
    side = int(np.floor(min(height, width) * beta))
    row_bands = (slice(0, side), slice(height - side, height))    # top, bottom
    col_bands = (slice(0, side), slice(width - side, width))      # left, right
    # Same order as before: top-left, top-right, bottom-left, bottom-right.
    for rows in row_bands:
        for cols in col_bands:
            amp_src[:, :, rows, cols] = amp_trg[:, :, rows, cols]
    return amp_src
def FDA_source_to_target(src_img, trg_img, beta=1e-2):
    """Give ``src_img`` the low-frequency Fourier amplitude ("style") of ``trg_img``.

    Fourier Domain Adaptation amplitude swap:
      1. 2-D FFT of both images over their last two dimensions;
      2. replace the low-frequency amplitude of the source with the target's
         (``low_freq_mutate``);
      3. recombine with the *source* phase and invert the FFT.

    Args:
        src_img: real tensor of shape (N, C, H, W) providing the content.
        trg_img: real tensor of the same shape providing the style.
        beta: fraction of min(H, W) used as the side of each swapped
            low-frequency corner square.

    Returns:
        Complex tensor of shape (N, C, H, W); take ``.real`` to obtain the
        source image rendered in the target style.

    Raises:
        AssertionError: if the inputs are not 4-D or their shapes differ.
    """
    # Only the two images need to agree; batch size and resolution are no
    # longer tied to BATCH_SIZE_CAMVID / 720x960.
    assert src_img.dim() == 4, f"expected (N, C, H, W) source, got {tuple(src_img.shape)}"
    assert src_img.shape == trg_img.shape, (
        f"source/target shape mismatch: {tuple(src_img.shape)} vs {tuple(trg_img.shape)}")

    # fft2 over (-2, -1) is the same transform as fftn(..., dim=(2, 3)) on 4-D input;
    # the FFT does not modify its input, so no defensive clones are needed.
    fft_src = torch.fft.fft2(src_img, dim=(-2, -1))
    fft_trg = torch.fft.fft2(trg_img, dim=(-2, -1))

    amp_src, pha_src = fft_src.abs(), fft_src.angle()
    amp_trg = fft_trg.abs()

    # low_freq_mutate writes into its first argument; amp_src is a fresh tensor
    # produced by abs() and is not used afterwards, so mutating it is safe.
    amp_src_ = low_freq_mutate(amp_src, amp_trg, beta=beta)

    # amplitude * (cos(phase) + i sin(phase)) with the original source phase.
    fft_src_ = torch.polar(amp_src_, pha_src)

    return torch.fft.ifft2(fft_src_, dim=(-2, -1))
def adjust_learning_rate(optimizer, initial_learning_rate, step, max_num_step, power):
  """Apply polynomial ("poly") learning-rate decay to the optimizer's first param group.

  lr = initial_learning_rate * (1 - step / max_num_step) ** power

  Args:
    optimizer: torch optimizer; only ``param_groups[0]['lr']`` is overwritten.
    initial_learning_rate: learning rate at step 0.
    step: current step (or epoch) index.
    max_num_step: step at which the learning rate reaches 0.
    power: decay exponent.

  Returns:
    The learning rate that was set (a float).
  """
  # Clamp the progress to [0, 1]: past the horizon the base would be negative and
  # Python would raise it to the fractional `power` as a *complex* number.
  progress = min(max(float(step) / max_num_step, 0.0), 1.0)
  lr = initial_learning_rate * (1.0 - progress) ** power
  optimizer.param_groups[0]['lr'] = lr
  return lr

class CrossEntropy2d(nn.Module):
    """Pixel-wise cross-entropy for dense predictions that skips ignored labels.

    Pixels whose target is negative or equal to ``ignore_label`` do not
    contribute to the loss.

    Args:
        size_average: if True, average the loss over the kept pixels;
            otherwise sum it.
        ignore_label: target value excluded from the loss.
    """

    def __init__(self, size_average=True, ignore_label=11):
        super(CrossEntropy2d, self).__init__()
        self.size_average = size_average
        self.ignore_label = ignore_label

    def forward(self, predict, target, weight=None):
        """Compute the loss.

        Args:
            predict: (n, c, h, w) unnormalised class scores (logits).
            target: (n, h, w) integer class indices; must not require grad.
            weight (Tensor, optional): a manual rescaling weight given to each class.
                If given, has to be a Tensor of size "nclasses".

        Returns:
            Scalar loss tensor. When every pixel is ignored, a zero that is still
            connected to ``predict``'s graph, so ``backward()`` works.

        Raises:
            AssertionError: on rank or size mismatch between predict and target.
        """
        assert not target.requires_grad
        assert predict.dim() == 4
        assert target.dim() == 3
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
        assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
        # The message must index target dim 2: target is 3-D, so size(3) would raise IndexError.
        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))

        target_mask = (target >= 0) & (target != self.ignore_label)
        target = target[target_mask]
        # Boolean masking always yields a 1-D tensor, so emptiness must be tested
        # with numel(); a mean over zero pixels would otherwise be NaN.
        if target.numel() == 0:
            return predict.sum() * 0.0
        # (n, c, h, w) -> (n, h, w, c): the (n, h, w) mask then selects whole
        # class-score vectors, giving a (k, c) matrix aligned with `target`.
        predict = predict.permute(0, 2, 3, 1)[target_mask]
        reduction = 'mean' if self.size_average else 'sum'
        return F.cross_entropy(predict, target, weight=weight, reduction=reduction)


def _next_batch(data_iter, data_loader):
  """Return ``(batch, data_iter)``, restarting ``data_loader`` once ``data_iter`` is exhausted."""
  try:
    return next(data_iter), data_iter
  except StopIteration:
    data_iter = iter(data_loader)
    return next(data_iter), data_iter


def main(start_epoch=31, end_epoch=50,
         resume_path='./checkpointBisenetFDA/30_0.09_fda.pth'):
  """Resume FDA fine-tuning of BiSeNet: IDDA (source) images styled like CamVid (target).

  Each iteration draws a labelled IDDA batch and a CamVid batch (its labels are
  discarded), transfers the CamVid low-frequency Fourier amplitude onto the IDDA
  images (``FDA_source_to_target`` with ``BETA``), and trains the network on the
  stylised images with the IDDA labels.

  Args:
    start_epoch: first epoch number to run (inclusive).
    end_epoch: last epoch number to run (inclusive); a checkpoint is always
      written after it.
    resume_path: state_dict loaded before training; ``None`` trains from the
      freshly built model.
  """
  # Call Python's garbage collector, and empty torch's CUDA cache. Just in case
  gc.collect()
  torch.cuda.empty_cache()

  # Enable cuDNN in benchmark mode. For more info see:
  # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
  torch.backends.cudnn.enabled = True
  torch.backends.cudnn.benchmark = True

  # Build BiSeNet and resume from a previous FDA checkpoint.
  model = BiSeNet(NUM_CLASSES, CONTEXT_PATH).cuda()
  if resume_path is not None:
    model.load_state_dict(torch.load(resume_path))
  model.train()

  # Source domain (IDDA): images + labels for the supervised loss.
  source_dataset = IDDA(
      image_path=IDDA_PATH,
      label_path=IDDA_LABEL_PATH,
      classes_info_path=JSON_IDDA_PATH,
      scale=(CROP_HEIGHT, CROP_WIDTH),
      loss=LOSS,
      mode='train'
  )
  source_dataloader = DataLoader(
    source_dataset,
    batch_size=BATCH_SIZE_IDDA,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
    pin_memory=True
  )

  # Target domain (CamVid): only the images are used, as the FDA style reference.
  target_dataset = CamVid(
    image_path=CAMVID_PATH,
    label_path=CAMVID_LABEL_PATH,
    csv_path=CSV_CAMVID_PATH,
    scale=(CROP_HEIGHT, CROP_WIDTH),
    loss=LOSS,
    mode='adversarial_train'
  )
  target_dataloader = DataLoader(
    target_dataset,
    batch_size=BATCH_SIZE_CAMVID,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
    pin_memory=True
  )

  cross_entropy_2d = CrossEntropy2d()
  optimizer_BiSeNet = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE_SEGMENTATION, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

  for epoch in range(start_epoch, end_epoch + 1):
    source_iter = iter(source_dataloader)
    target_iter = iter(target_dataloader)

    print(f'begin epoch {epoch}')

    # NOTE(review): the poly schedule is stepped once per *epoch* but decays over
    # NUM_STEPS (250000) steps, so the learning rate stays ~LEARNING_RATE_SEGMENTATION.
    # Pass the total number of epochs (or iterations) instead if decay is intended.
    adjust_learning_rate(optimizer_BiSeNet, LEARNING_RATE_SEGMENTATION, epoch, NUM_STEPS, POWER)

    # One epoch = one pass over the target loader; the source loader restarts if it runs out.
    for _ in tqdm(range(len(target_dataloader))):
      # Kept from the original run; presumably to keep CUDA memory headroom.
      # It costs time on every iteration.
      gc.collect()
      torch.cuda.empty_cache()

      optimizer_BiSeNet.zero_grad()

      (x_s, y_s), source_iter = _next_batch(source_iter, source_dataloader)
      x_s, y_s = x_s.cuda(), y_s.cuda()
      (x_t, _), target_iter = _next_batch(target_iter, target_dataloader)
      x_t = x_t.cuda()

      # FDA mixes raw pixel spectra: undo the normalization, swap the amplitudes,
      # keep the real part (already on the GPU), then re-normalize for the network.
      x_s2t_unnormalized = FDA_source_to_target(
        unnormalize(x_s),
        unnormalize(x_t),
        beta=BETA
      ).real
      x_s2t = normalize(x_s2t_unnormalized)

      # p holds raw logits; F.cross_entropy inside CrossEntropy2d applies the log-softmax.
      p, _, _ = model(x_s2t)
      # argmax over dim 1 turns the per-class label maps (assumed one-hot for the
      # 'dice' loss setting - confirm in dataset.IDDA) into class indices.
      loss = cross_entropy_2d(p, torch.argmax(y_s, dim=1))
      loss.backward()
      optimizer_BiSeNet.step()

    # Save the segmentation network every CHECKPOINT_STEP epochs and after the last epoch.
    if (epoch != 0 and epoch % CHECKPOINT_STEP == 0) or epoch == end_epoch:
      os.makedirs(CHECKPOINT_PATH, exist_ok=True)
      torch.save(model.state_dict(), os.path.join(CHECKPOINT_PATH, f'{epoch}_{BETA}_fda.pth'))

# Run the FDA fine-tuning. To profile it instead, wrap this call with
# pyinstrument's Profiler(interval=1e-4).start()/.stop() and render
# profiler.output_html() with display(HTML(...)).
main()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


HBox(children=(FloatProgress(value=0.0, max=46830571.0), HTML(value='')))




Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth


HBox(children=(FloatProgress(value=0.0, max=178793939.0), HTML(value='')))


begin epoch 31


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 32


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 33


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 34


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 35


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 36


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 37


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 38


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 39


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 40


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 41


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 42


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 43


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 44


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 45


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 46


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 47


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 48


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 49


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))


begin epoch 50


HBox(children=(FloatProgress(value=0.0, max=234.0), HTML(value='')))




In [7]:
!pip install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/07/84/46421bd3e0e89a92682b1a38b40efc22dafb6d8e3d947e4ceefd4a5fabc7/tensorboardX-2.2-py2.py3-none-any.whl (120kB)
[K     |██▊                             | 10kB 18.5MB/s eta 0:00:01[K     |█████▍                          | 20kB 23.4MB/s eta 0:00:01[K     |████████▏                       | 30kB 27.7MB/s eta 0:00:01[K     |██████████▉                     | 40kB 30.5MB/s eta 0:00:01[K     |█████████████▋                  | 51kB 27.7MB/s eta 0:00:01[K     |████████████████▎               | 61kB 28.2MB/s eta 0:00:01[K     |███████████████████             | 71kB 28.8MB/s eta 0:00:01[K     |█████████████████████▊          | 81kB 28.6MB/s eta 0:00:01[K     |████████████████████████▌       | 92kB 29.0MB/s eta 0:00:01[K     |███████████████████████████▏    | 102kB 29.6MB/s eta 0:00:01[K     |██████████████████████████████  | 112kB 29.6MB/s eta 0:00:01[K     |████████████████████████████████

In [8]:
from dataclasses import dataclass
@dataclass
class MyArgs:
  """Argument bundle handed to ``train.val`` in place of its usual args object.

  Presumably mirrors the attributes ``val`` reads from an argparse namespace - confirm in train.py.
  """
  num_classes: int  # number of segmentation classes
  use_gpu: bool     # passed as True below; assumed to make val run on CUDA
  loss: str         # loss name, e.g. 'dice'
from train import val

import torch
from dataset.CamVid import CamVid
from torch.utils.data import DataLoader
from model.build_BiSeNet import BiSeNet


# Standalone evaluation of a saved FDA checkpoint on the CamVid test split.
CROP_HEIGHT = 720
CROP_WIDTH = 960

NUM_WORKERS = 0

CAMVID_PATH = ['/content/CamVid/train/', '/content/CamVid/val/']
CAMVID_TEST_PATH = ['/content/CamVid/test/']
CAMVID_LABEL_PATH = ['/content/CamVid/train_labels/', '/content/CamVid/val_labels/']
CAMVID_TEST_LABEL_PATH = ['/content/CamVid/test_labels/']
CSV_CAMVID_PATH = '/content/CamVid/class_dict.csv'


BATCH_SIZE_CAMVID = 2
LOSS = 'dice'
NUM_CLASSES = 12
CONTEXT_PATH = 'resnet101'

# NOTE(review): the training cell saves '<epoch>_<BETA>_fda.pth' with BETA = 0.09,
# but this evaluates a beta = 0.05 run - confirm this is the intended checkpoint.
EVAL_CHECKPOINT = '/content/drive/MyDrive/MLDL/BiseNetv1/checkpointBisenetFDA/50_0.05_fda.pth'

generator = BiSeNet(NUM_CLASSES, CONTEXT_PATH).cuda()
generator.load_state_dict(torch.load(EVAL_CHECKPOINT))
generator.eval()

target_dataset_test = CamVid(
  image_path=CAMVID_TEST_PATH,
  label_path=CAMVID_TEST_LABEL_PATH,
  csv_path=CSV_CAMVID_PATH,
  scale=(CROP_HEIGHT, CROP_WIDTH),
  loss=LOSS,
  mode='val'
)
# Evaluation reads the test set in a fixed order (shuffle=False) so runs are
# reproducible; with batch_size=1, drop_last never discards an image.
target_dataloader_test = DataLoader(
  target_dataset_test,
  batch_size=1,
  shuffle=False,
  num_workers=NUM_WORKERS,
  drop_last=True,
  pin_memory=True
)
target_dataloader_test_iter = iter(target_dataloader_test)

# No gradients are needed to score the model; this avoids building autograd graphs.
with torch.no_grad():
  val_results = val(MyArgs(NUM_CLASSES, True, LOSS), generator, target_dataloader_test_iter, CSV_CAMVID_PATH)
val_results

start val!!
precision per pixel for test: 0.578
mIoU for validation: 0.190
mIoU for each class:
Bicyclist:0.0,
Building:0.5393193435798596,
Car:0.1705626120388322,
Column_Pole:0.002262160745076658,
Fence:0.0,
Pedestrian:0.0,
Road:0.5521019428828983,
Sidewalk:2.0262252989614202e-07,
SignSymbol:0.0,
Sky:0.4563222125510053,
Tree:0.3697320793957166,



(0.5780032673164045, 0.19002732307417441)

ciao

