# New STEP4: MOVING TOWARDS FFreeDA - PRE-TRAINING PHASE

## TODO: remove the following cell and add `ToTensor` to the transform pipeline

In [None]:
import torchvision.transforms.functional as F
import collections
import numbers

# Maps each PIL interpolation constant to the string used by Resize.__repr__.
# Built from the constant names so that key and label cannot drift apart.
_pil_interpolation_to_str = {
    getattr(Image, _name): 'PIL.Image.' + _name
    for _name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS', 'HAMMING', 'BOX')
}

class Resize(object):
    """Resize an image and, optionally, its segmentation label to a fixed size.

    The image is resized with the configured interpolation. The label, when
    given, is always resized with ``Image.NEAREST``, which copies existing
    pixel values instead of blending neighbouring ones.

    Args:
        size: target size forwarded to ``F.resize``; either an ``int`` or an
            iterable of exactly two elements.
        interpolation: PIL interpolation constant used for the image.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``. Imported locally so this cell does not depend on
        # the submodule having been loaded elsewhere.
        from collections.abc import Iterable

        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, lbl=None):
        """Return the resized image, or ``(image, label)`` when ``lbl`` is given."""
        resized_img = F.resize(img, self.size, self.interpolation)
        if lbl is None:
            return resized_img
        return resized_img, F.resize(lbl, self.size, Image.NEAREST)

    def __repr__(self):
        # Fall back to the raw value for interpolation modes missing from the
        # lookup table instead of raising KeyError while printing.
        interpolate_str = _pil_interpolation_to_str.get(self.interpolation, repr(self.interpolation))
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


class ToTensor(object):
    """Convert an image (and optionally its label) to tensors.

    The image goes through ``F.to_tensor``; the label, when given, is turned
    into a ``uint8`` NumPy array and wrapped with ``torch.from_numpy``.
    """

    def __call__(self, pic, lbl=None):
        """Return the image tensor, or ``(image, label)`` tensors when ``lbl`` is given."""
        image_tensor = F.to_tensor(pic)
        if lbl is None:
            return image_tensor
        label_array = np.array(lbl, dtype=np.uint8)
        return image_tensor, torch.from_numpy(label_array)

    def __repr__(self):
        return f'{type(self).__name__}()'


## Settings

In [None]:
# Resize transform shared by the dataset cells below.
transform = Resize(size=(512, 1024))

# --- Client / style-transfer configuration ---
UNIFORM = True               # forwarded as `uniform` to CityscapesClient
MAX_SAMPLE_PER_CLIENT = 20   # forwarded as `n_images_per_style` to StyleAugment
N_STYLE = 5                  # number of clients sampled to build the style bank
TOT_CLIENT = 20              # client ids are sampled from range(TOT_CLIENT)
FDA = True                   # enables the StyleAugment cells below

# --- Checkpointing ---
CKPT_DIR = 'checkpoints'
use_checkpoint = True        # resume from `ckpt_path` when that file exists

ckpt_path = "model_step4.tar"            # checkpoint written after every epoch
best_ckpt_path = "best_model_step4.tar"  # checkpoint with the best MIoU so far

## GTA Dataset


In [None]:
from torchvision.datasets import VisionDataset

class GTA5(VisionDataset):
    """GTA5 segmentation dataset with labels remapped to the target train ids.

    ``<root>/train.txt`` lists one file name per line; the same name is looked
    up under ``<root>/images`` and ``<root>/labels``.

    Args:
        root: dataset root directory.
        transform: optional callable invoked as ``transform(image, label)`` and
            expected to return the transformed pair.
        mean, std, cv2: stored on the instance; not used by this class.
        target_dataset: key of ``GTA5.labels2train`` selecting the id mapping.
    """

    # Raw label id -> train id, per target dataset. Ids that are not listed are
    # mapped to 255 by __map_labels.
    labels2train = {
        'cityscapes': {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
                       26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18},
    }

    def __init__(self, root, transform=None, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), cv2=False, target_dataset='cityscapes'):
        assert target_dataset in GTA5.labels2train, f'Class mapping missing for {target_dataset}, choose from: {GTA5.labels2train.keys()}'
        self.labels2train = GTA5.labels2train[target_dataset]

        # NOTE(review): VisionDataset.__init__ is deliberately not called here
        # (the call was commented out in the original); confirm this is intended.

        self.root = root
        self.transform = transform
        self.mean = mean
        self.std = std
        self.cv2 = cv2

        self.target_transform = self.__map_labels()

        # When True, __getitem__ returns only the untouched PIL image.
        self.return_unprocessed_image = False
        self.style_tf_fn = None

        with open(os.path.join(self.root, 'train.txt'), "r") as f:
            # One file name per line; blank lines (e.g. a trailing newline)
            # are skipped so they do not become empty, unreadable paths.
            self.paths_images = [line.strip() for line in f if line.strip()]

        self.len = len(self.paths_images)

    def set_style_tf_fn(self, style_tf_fn):
        """Register a callable applied to every image before ``transform``."""
        self.style_tf_fn = style_tf_fn

    def reset_style_tf_fn(self):
        """Remove the style callable set with ``set_style_tf_fn``."""
        self.style_tf_fn = None

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)``, or the raw PIL image when
        ``return_unprocessed_image`` is set."""
        file_name = self.paths_images[index]
        x = Image.open(os.path.join(self.root, 'images', file_name))

        if self.return_unprocessed_image:
            # Only the image is requested: do not open the label file at all.
            return x

        y = Image.open(os.path.join(self.root, 'labels', file_name))

        if self.style_tf_fn is not None:
            x = self.style_tf_fn(x)
        if self.transform is not None:
            x, y = self.transform(x, y)
        y = self.target_transform(y)

        # TODO: insert ToTensor directly in the transform pipeline?
        x, y = ToTensor()(x, y)

        return x, y

    def __len__(self):
        return self.len

    def __map_labels(self):
        """Build the callable that converts raw label ids to train ids."""
        # Imported locally so the class does not rely on a later notebook cell
        # having already imported ``from_numpy``.
        from torch import from_numpy

        # 256-entry lookup table; every id without a mapping becomes 255.
        mapping = np.full((256,), 255, dtype=np.int64)
        for raw_id, train_id in self.labels2train.items():
            mapping[raw_id] = train_id
        return lambda x: from_numpy(mapping[x])
        

## Cityscapes Client Dataset 

In [None]:
from tqdm import tqdm
from torch import from_numpy
from torch.utils import data
      
class CityscapesClient(data.Dataset):
    """Cityscapes samples assigned to one client of a federated split.

    The split is read from ``<root>/uniformA.json`` or
    ``<root>/heterogeneuosA.json``: a JSON object indexed by the client id (as
    a string) whose value is a list of ``[image_path, label_path]`` entries,
    relative to ``<root>/images`` and ``<root>/labels`` respectively.

    Args:
        root: dataset root directory.
        uniform: selects the uniform split file when truthy, the heterogeneous
            one otherwise.
        id_client: client whose samples are loaded.
        transform: optional callable invoked as ``transform(image, label)`` and
            expected to return the transformed pair.
        mean, std: stored on the instance; not used by this class.
    """

    # Raw Cityscapes label id -> train id. Ids that are not listed are mapped
    # to 255 by __map_labels.
    labels2train = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13,
                    27: 14, 28: 15, 31: 16, 32: 17, 33: 18}

    def __init__(self, root, uniform, id_client, transform=None, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):

        self.root = root

        filename = 'uniformA.json' if uniform else 'heterogeneuosA.json'
        with open(os.path.join(root, filename)) as f:
            dict_data = json.load(f)

        # Look the client entry up once instead of once per list.
        samples = dict_data[str(id_client)]
        self.paths_images = [sample[0] for sample in samples]
        self.paths_tagets = [sample[1] for sample in samples]

        self.len = len(self.paths_images)

        self.transform = transform
        self.mean = mean
        self.std = std
        self.test = False
        # Stored for parity with GTA5; __getitem__ below never applies it.
        self.style_tf_fn = None
        # When True, __getitem__ returns only the untouched PIL image.
        self.return_unprocessed_image = False

        self.target_transform = self.__map_labels()

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)``, or the raw PIL image when
        ``return_unprocessed_image`` is set."""
        x = Image.open(os.path.join(self.root, "images", self.paths_images[index]))

        if self.return_unprocessed_image:
            # Only the image is requested: do not open the label file at all.
            return x

        y = Image.open(os.path.join(self.root, "labels", self.paths_tagets[index]))

        if self.transform is not None:
            x, y = self.transform(x, y)

        y = self.target_transform(y)

        # TODO: insert ToTensor directly in the transform pipeline?
        x, y = ToTensor()(x, y)

        return x, y

    def __len__(self):
        return self.len

    def __map_labels(self):
        """Build the callable that converts raw label ids to train ids."""
        # 256-entry lookup table; every id without a mapping becomes 255.
        mapping = np.full((256,), 255, dtype=np.int64)
        for raw_id, train_id in self.labels2train.items():
            mapping[raw_id] = train_id
        return lambda x: from_numpy(mapping[x])



## Style Augment

In [None]:
import numpy as np
import random
from PIL import Image
import cv2
from tqdm import tqdm


class StyleAugment:
    """Bank of low-frequency amplitude "styles" that can be pasted onto images.

    A style is the central window of the fft-shifted amplitude spectrum of an
    image. ``add_style`` extracts styles from a dataset; ``apply_style``
    overwrites the same window of another image's amplitude spectrum with a
    randomly chosen stored style while keeping that image's phase.

    Args:
        n_images_per_style: maximum number of images read per ``add_style``
            call; a negative value disables both extraction and application.
        L: fraction of the smaller spatial side used as the window half-width
            when ``b`` is None.
        size: ``(width, height)`` every image is resized to before the FFT.
        b: explicit window half-width (window side is ``2 * b + 1``); takes
            precedence over ``L`` when not None.
    """

    def __init__(self, n_images_per_style=10, L=0.1, size=(1024, 512), b=None):
        self.styles = []
        self.styles_names = []
        self.n_images_per_style = n_images_per_style
        self.L = L
        self.size = size
        # (h1, h2, w1, w2) bounds of the style window; computed on first use.
        self.sizes = None
        # Set to True once a NumPy image has been preprocessed; deprocess()
        # then returns NumPy arrays instead of PIL images.
        self.cv2 = False
        self.b = b

    def preprocess(self, x):
        """Resize ``x`` to ``self.size`` and return it as a float32 array with
        the channel order reversed and the channel axis first."""
        if isinstance(x, np.ndarray):
            x = cv2.resize(x, self.size, interpolation=cv2.INTER_CUBIC)
            self.cv2 = True
        else:
            x = x.resize(self.size, Image.BICUBIC)
        x = np.asarray(x, np.float32)
        x = x[:, :, ::-1]           # reverse the channel order
        x = x.transpose((2, 0, 1))  # (H, W, C) -> (C, H, W)
        return x.copy()

    def deprocess(self, x, size):
        """Undo ``preprocess``: restore axis/channel order, cast to uint8 and
        resize to ``size``."""
        if self.cv2:
            x = cv2.resize(np.uint8(x).transpose((1, 2, 0))[:, :, ::-1], size, interpolation=cv2.INTER_CUBIC)
        else:
            x = Image.fromarray(np.uint8(x).transpose((1, 2, 0))[:, :, ::-1])
            x = x.resize(size, Image.BICUBIC)
        return x

    def add_style(self, loader, multiple_styles=False, name=None):
        """Extract styles from the first ``n_images_per_style`` images of ``loader``.

        ``loader`` must be iterable, support ``len()`` and honour a
        ``return_unprocessed_image`` attribute (it is set to True while the
        images are read). With ``multiple_styles`` every image contributes its
        own style; otherwise the styles are averaged into a single entry.
        """
        if self.n_images_per_style < 0:
            return

        if name is not None:
            self.styles_names.append([name] * self.n_images_per_style if multiple_styles else [name])

        loader.return_unprocessed_image = True
        styles = []
        try:
            for sample in tqdm(loader, total=min(len(loader), self.n_images_per_style)):
                # Check the budget before preprocessing so the extra image is
                # not resized just to be thrown away.
                if len(styles) >= self.n_images_per_style:
                    break
                styles.append(self._extract_style(self.preprocess(sample)))
        finally:
            # Restore the loader even if reading an image failed.
            loader.return_unprocessed_image = False

        if not styles:
            # Nothing was read (empty loader or zero budget): np.stack would
            # raise on an empty list, so there is simply nothing to store.
            return

        if multiple_styles or self.n_images_per_style == 1:
            self.styles += styles
        else:
            self.styles.append(np.mean(np.stack(styles, axis=0), axis=0))

    def _extract_style(self, img_np):
        """Return the central window of the shifted amplitude spectrum."""
        fft_np = np.fft.fft2(img_np, axes=(-2, -1))
        amp = np.abs(fft_np)
        amp_shift = np.fft.fftshift(amp, axes=(-2, -1))
        if self.sizes is None:
            self.sizes = self.compute_size(amp_shift)
        h1, h2, w1, w2 = self.sizes
        style = amp_shift[:, h1:h2, w1:w2]
        return style

    def compute_size(self, amp_shift):
        """Return ``(h1, h2, w1, w2)``: a ``(2b+1) x (2b+1)`` window centred on
        the spectrum of shape ``(C, H, W)``."""
        _, h, w = amp_shift.shape
        b = (np.floor(np.amin((h, w)) * self.L)).astype(int) if self.b is None else self.b
        c_h = np.floor(h / 2.0).astype(int)
        c_w = np.floor(w / 2.0).astype(int)
        h1 = c_h - b
        h2 = c_h + b + 1
        w1 = c_w - b
        w2 = c_w + b + 1
        return h1, h2, w1, w2

    def apply_style(self, image):
        """Return ``image`` restyled with a randomly chosen stored style."""
        return self._apply_style(image)

    def _swap_amplitude(self, img_np, style):
        """Replace the central amplitude window of ``img_np`` with ``style``
        and return the reconstructed image, rounded and clipped to [0, 255]."""
        fft_np = np.fft.fft2(img_np, axes=(-2, -1))
        amp, pha = np.abs(fft_np), np.angle(fft_np)
        amp_shift = np.fft.fftshift(amp, axes=(-2, -1))
        h1, h2, w1, w2 = self.sizes
        amp_shift[:, h1:h2, w1:w2] = style
        amp_ = np.fft.ifftshift(amp_shift, axes=(-2, -1))

        # Recombine the patched amplitude with the original phase.
        fft_ = amp_ * np.exp(1j * pha)
        img_np_ = np.real(np.fft.ifft2(fft_, axes=(-2, -1)))
        return np.clip(np.round(img_np_), 0., 255.)

    def _apply_style(self, img):
        # Style transfer disabled, or no style registered yet: return the image
        # untouched (the original indexed an empty list and raised IndexError).
        if self.n_images_per_style < 0 or not self.styles:
            return img

        style = self.styles[random.randint(0, len(self.styles) - 1)]

        # Remember the input size so the output can be resized back to it.
        if isinstance(img, np.ndarray):
            H, W = img.shape[0:2]
        else:
            W, H = img.size
        img_np = self.preprocess(img)

        return self.deprocess(self._swap_amplitude(img_np, style), (W, H))

    def test(self, images_np, images_target_np=None, size=None):
        """Visual check: show the source, the target and the restyled target.

        Both arrays are expected in the layout produced by ``preprocess``.
        When ``images_target_np`` is None the source is used as the target.
        ``size`` is accepted for interface compatibility and is unused.
        """
        # The original dereferenced None here when no target was supplied.
        target_np = images_np if images_target_np is None else images_target_np

        Image.fromarray(np.uint8(images_np.transpose((1, 2, 0)))[:, :, ::-1]).show()
        style = self._extract_style(images_np)
        img_np__ = self._swap_amplitude(target_np, style)
        Image.fromarray(np.uint8(target_np.transpose((1, 2, 0)))[:, :, ::-1]).show()
        Image.fromarray(np.uint8(img_np__).transpose((1, 2, 0))[:, :, ::-1]).show()


## GTA dataset visualization 

In [None]:
# Build the GTA5 training set with the resize transform and wrap it in a
# shuffled DataLoader that drops the last incomplete batch.
transform = Resize(size=(512, 1024))
train_dataset = GTA5(root=ROOT_DIR_GTA5, transform=transform)
print("Dataset dimension: ", len(train_dataset))
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

# Each figure is a 1x2 grid: image on the left, decoded label map on the right.
rows = 1
columns = 2
for i, (imgs, targets) in enumerate(train_dataloader):
    # Only the first sample of every batch is inspected and plotted.
    print(imgs[0].shape,targets[0].shape)

    figure = plt.figure(figsize=(10,20))
    figure.add_subplot(rows, columns,1)
    print("img type:",type(imgs[0]), " target:",type(targets[0]))
    print("img:", imgs[0].squeeze().shape, " target:",targets[0].squeeze().shape)
  

    # Move the first axis last before handing the image to imshow.
    plt.imshow(imgs[0].permute((1, 2, 0)))
    plt.axis('off')
    plt.title("Image")
    figure.add_subplot(rows, columns,2)
    
    # decode_segmap is defined elsewhere; presumably it maps class ids to
    # colours for display -- TODO confirm.
    plt.imshow(decode_segmap(targets[0]))
    plt.axis('off')
    plt.title("Groundtruth")
    plt.show()
    
    # Stop after the second batch.
    if i == 1: break

## Dataset creation

In [None]:
if FDA:
    # Candidate values for L: 0.01, 0.05, 0.09
    # b == 0 --> 1x1, b == 1 --> 3x3, b == 2 --> 5x5, ...
    SA = StyleAugment(n_images_per_style=MAX_SAMPLE_PER_CLIENT, L=0.01, size=(1024, 512), b=1)

    # Draw N_STYLE distinct client ids and add one style per sampled client.
    clients = random.sample(range(TOT_CLIENT), N_STYLE)
    for c in clients:
        client_dataset = CityscapesClient(root=ROOT_DIR, uniform=UNIFORM, id_client=c, transform=transform)
        SA.add_style(client_dataset)

In [None]:
# Rebuild the GTA5 training set; `transform` is the resize transform defined
# in an earlier cell.
train_dataset = GTA5(root=ROOT_DIR_GTA5, transform=transform)

# With FDA enabled every image goes through SA.apply_style before `transform`
# (see GTA5.__getitem__). SA only exists when the previous cell ran with FDA set.
if FDA:
  train_dataset.set_style_tf_fn(SA.apply_style)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
print("Dataset dimension: ", len(train_dataset))

Dataset dimension:  500


## GTA images with Cityscapes style visualization

In [None]:
# Save `n_images` figures (one per batch) showing the first sample of the
# batch next to its decoded label map.
n_images = 2
rows = 1
columns = 2
for i, (imgs, targets) in enumerate(train_dataloader):
    print(imgs[0].shape,targets[0].shape)

    figure = plt.figure(figsize=(50,50))
    figure.add_subplot(rows, columns,1)
    print("img type:",type(imgs[0]), " target:",type(targets[0]))
    print("img:", imgs[0].squeeze().shape, " target:",targets[0].squeeze().shape)

    # Move the first axis last before handing the image to imshow.
    plt.imshow(imgs[0].permute((1, 2, 0)))
    plt.axis('off')
    plt.title("Image")
    figure.add_subplot(rows, columns,2)
    
    # decode_segmap is defined elsewhere; presumably it maps class ids to
    # colours for display -- TODO confirm.
    plt.imshow(decode_segmap(targets[0]))
    plt.axis('off')
    plt.title("Groundtruth")
    #plt.show()
    # The figure is written to style_<i>.png instead of being shown explicitly.
    plt.savefig(f"style_{i}.png", transparent = True)
    # Stop once n_images figures have been saved.
    if i+1 == n_images: break

## Training

Validation Dataset

In [None]:
# Loader over the samples of client 0 (uniform split); the training loop
# passes it to compute_moiu.
# NOTE(review): shuffle=True and drop_last=True mean the incomplete last batch
# is dropped and the batches differ between calls -- confirm this is intended
# for evaluation.
transform  = Resize(size=(512, 1024))
client_dataset = CityscapesClient(root=ROOT_DIR, uniform=True, id_client=0, transform=transform)
val_dataloader = DataLoader(client_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

Checkpoint loading

In [None]:
# Create the checkpoint directory when missing; no error if it already exists.
os.makedirs(CKPT_DIR, exist_ok=True)

if use_checkpoint and ckpt_path in os.listdir(CKPT_DIR):
  # Map every stored tensor to the CPU so that a checkpoint written on a GPU
  # machine can also be read where no GPU is available. The next cell moves
  # the model to DEVICE after loading the weights.
  checkpoint = torch.load(os.path.join(CKPT_DIR,ckpt_path), map_location='cpu')
  print('Epoch {}, Loss {}, MIoU {}'.format(checkpoint['epoch'], checkpoint['loss'], checkpoint['miou']))

In [None]:
# Build the network and, when resuming, restore its weights before moving it
# to the training device. `checkpoint` comes from the previous cell and only
# exists when the same condition held there.
model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)
if use_checkpoint and ckpt_path in os.listdir(CKPT_DIR):
  model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE)

# Pixels labelled 255 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index= 255) 

# All model parameters are optimized with SGD.
parameters_to_optimize = model.parameters()
optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# When resuming, also restore the optimizer state saved with the checkpoint.
if use_checkpoint and ckpt_path in os.listdir(CKPT_DIR):
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Enable the cuDNN auto-tuner. The original line only read the attribute,
# which has no effect; the flag has to be assigned.
cudnn.benchmark = True

losses = []
losses_val = []
epochs = []

# wandb.watch(model, log='all')

# Best scores seen so far; they stay None until a best checkpoint exists.
best_loss = None
best_miou = None
if best_ckpt_path in os.listdir(CKPT_DIR):
    # Only the stored metrics are read here, so keep the tensors on the CPU.
    best_checkpoint = torch.load(os.path.join(CKPT_DIR,best_ckpt_path), map_location='cpu')
    best_loss = best_checkpoint['loss']
    best_miou = best_checkpoint['miou']


current_step = 0
start_epoch = 0
if use_checkpoint and ckpt_path in os.listdir(CKPT_DIR):
  # The checkpoint stores the number of completed epochs (epoch + 1).
  start_epoch = checkpoint['epoch']

# Start iterating over the epochs
for epoch in range(start_epoch, NUM_EPOCHS):
  print('Starting epoch {}/{}'.format(epoch+1, NUM_EPOCHS))
  epochs.append(epoch+1)

  # Iterate over the dataset
  for images, labels in train_dataloader:

    images = images.to(DEVICE, dtype=torch.float32)
    labels = labels.to(DEVICE, dtype=torch.long)
    #print("images:",images.shape,"labels:",labels.shape)

    # Re-enabled at every step: compute_moiu is defined elsewhere and may
    # switch the model to evaluation mode -- TODO confirm.
    model.train()
    optimizer.zero_grad()

    predictions = model(images)
    #print("predictions:",predictions.shape,"labels:",labels.shape)

    # Drop a singleton channel axis when present. A bare squeeze() would also
    # drop the batch axis when the batch holds a single sample.
    if labels.dim() == 4 and labels.size(1) == 1:
      labels = labels.squeeze(1)
    loss = criterion(predictions, labels)

    #wandb.log({"train/loss":loss})

    # Log loss
    if current_step % LOG_FREQUENCY == 0:
      miou = compute_moiu(model, val_dataloader)
      print('Step {}, Loss {}, MIoU {}'.format(current_step, loss.item(),miou))
      # wandb.log({"train/loss": loss})
    # Compute gradients for each layer and update weights
    loss.backward()
    optimizer.step()

    current_step += 1

  # State saved at the end of the epoch; `loss` and `miou` are the most
  # recent values computed inside the loop above.
  epoch_state = {'epoch': epoch+1, 'model_state_dict': model.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, 'miou': miou}

  # save intermediate checkpoint
  torch.save(epoch_state, os.path.join(CKPT_DIR,ckpt_path))

  # save/update best checkpoint: when none exists yet or the MIoU improved
  if best_miou is None or best_ckpt_path not in os.listdir(CKPT_DIR) or best_miou < miou:
    torch.save(epoch_state, os.path.join(CKPT_DIR,best_ckpt_path))
    best_loss = loss
    best_miou = miou

## Validation

In [None]:
if best_ckpt_path in os.listdir(CKPT_DIR):
  # Read the checkpoint on the CPU so that it also loads on a machine without
  # the device it was saved from.
  best_checkpoint = torch.load(os.path.join(CKPT_DIR,best_ckpt_path), map_location='cpu')
  model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)
  model.load_state_dict(best_checkpoint['model_state_dict'])
  # Same placement as in the training cell: the freshly built model would
  # otherwise stay on the CPU.
  model = model.to(DEVICE)
  # Set dropout and batch normalization layers to evaluation mode before running inference.
  model.eval()
  print()
else:
  print('There is no model to load')




Load Validation Dataset

In [None]:
# Evaluation set: the samples of client 10 from the uniform split, resized
# like the training images.
transform = Resize(size=(512, 1024))
val_dataset = CityscapesClient(root=ROOT_DIR, uniform=True, id_client=10, transform=transform)
# NOTE(review): drop_last=True discards the final incomplete batch, so not
# every sample is necessarily evaluated -- confirm this is intended.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
print("Dataset dimension: ", len(val_dataset))

Dataset dimension:  19


In [None]:
# validation_plot and compute_moiu are defined elsewhere in the notebook.
# validation_plot presumably displays predictions for up to n_image samples
# -- TODO confirm.
validation_plot(net= model, val_dataloader=val_dataloader, n_image=20)
# Release cached GPU memory before computing the metric.
torch.cuda.empty_cache()
miou = compute_moiu(net=model, val_dataloader=val_dataloader)
print('Validation MIoU: {}'.format(miou))