# DAAI Project 3 - Train on Cityscapes (Step 2a)

## Data Preparation

### Clone the professor's repository

In [None]:
# Clone the course starter code into ./AML_Semantic_DA; later cells run its
# train.py and import its model/ and utils modules.
!git clone https://github.com/ClaudiaCuttano/AML_Semantic_DA.git

### Mount Google Drive to access files

In [None]:
from google.colab import drive

# Mount Google Drive so the archive, pretrained weights and checkpoint folder
# under drive/MyDrive/ are reachable (the relative 'drive/...' paths used in
# later cells presumably rely on the working directory being /content).
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

### Replace the empty 'cityscapes.py' in the cloned repository with the one from this repo

In [None]:
# Copy this project's CityScapes dataset implementation into the cloned repo;
# the visualization cells import it as AML_Semantic_DA.cityscapes.
!cp datasets/cityscapes.py AML_Semantic_DA/cityscapes.py

### Extract the Cityscapes dataset

In [None]:
import zipfile

# Dataset archive stored on the mounted Google Drive. This is a plain string:
# the original f-prefix had no placeholders and served no purpose.
CITYSCAPES_ZIP = 'drive/MyDrive/cityscapes.zip'

# Unpack every member of the archive into the current working directory.
with zipfile.ZipFile(CITYSCAPES_ZIP, 'r') as zip_ref:
    zip_ref.extractall()

### Copy the pre-trained model to Colab

In [None]:
# Copy the STDCNet813 pretrained weights from Drive into the working directory;
# train.py reads them via --pretrain_path and BiSeNet via pretrain_model.
!cp drive/MyDrive/STDCNet813M_73.91.tar STDCNet813M_73.91.tar

### Install tensorboardX

In [None]:
# Install tensorboardX, presumably required by train.py for logging — confirm
# against the training script's imports.
!pip install tensorboardX

## Train the model

In [None]:
# Train BiSeNet with an STDCNet813 backbone on Cityscapes for 50 epochs
# (batch size 2, initial LR 0.01), running validation after every epoch.
# Checkpoints are written to drive/MyDrive/cityscapes_checkpoints/ on Drive.
# NOTE(review): the model-loading cell reads 'best.pth' from the working
# directory, not from this folder — confirm it is copied over before evaluating.
# (No comments may go between the lines below: each ends in a shell
# line-continuation backslash.)
!python AML_Semantic_DA/train.py \
--save_model_path drive/MyDrive/cityscapes_checkpoints/ \
--backbone STDCNet813 \
--mode train \
--pretrain_path STDCNet813M_73.91.tar \
--num_workers 2 \
--num_epochs 50 \
--batch_size 2 \
--validation_step 1 \
--learning_rate 0.01

## Visualize the Results

### Create dataset and dataloader

In [None]:
from pathlib import Path
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from AML_Semantic_DA.cityscapes import CityScapes


BATCH_SIZE = 1
# os.cpu_count() returns None when the core count cannot be determined, and
# DataLoader rejects num_workers=None; fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0
# Number of semantic classes the model predicts (also passed to BiSeNet).
num_classes = 19

val_dataset = CityScapes(mode='val')
# shuffle=True so every fresh iter(val_loader) yields a different random sample.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
)

In [None]:
# Draw one random (image, label) batch from the validation loader for inspection.
batch = next(iter(val_loader))
data, label = batch

In [None]:
# First image of the batch, with its leading (presumably channel) axis moved
# last because matplotlib expects channels-last input.
image = data[0].permute(1, 2, 0).cpu()
plt.imshow(image)

In [None]:
# Ground-truth map for the first sample. The permute assumes label[0] is 3-D
# (presumably a leading singleton channel) — TODO confirm against CityScapes.
gt_map = label[0].permute(1, 2, 0).cpu()
plt.imshow(gt_map)

### Load the trained model

In [None]:
from AML_Semantic_DA.model.model_stages import BiSeNet
import torch

# Rebuild the architecture trained above (STDCNet813 backbone).
net = BiSeNet(
    backbone='STDCNet813',
    n_classes=num_classes,
    pretrain_model='STDCNet813M_73.91.tar',
    use_conv_last=False,
)
# Restore the trained weights on the CPU, then wrap in DataParallel and move
# to the GPU.
# NOTE(review): training saved checkpoints under drive/MyDrive/cityscapes_checkpoints/;
# confirm 'best.pth' exists in the working directory.
trained_weights = torch.load('best.pth', map_location='cpu')
net.load_state_dict(trained_weights)
model = torch.nn.DataParallel(net).cuda()

### Get an image and its label

In [None]:
# A fresh iterator over the shuffled loader gives a new random batch to evaluate.
val_iter = iter(val_loader)
data, label = next(val_iter)

### Generate a prediction and compute metrics

In [None]:
import numpy as np
from AML_Semantic_DA.utils import poly_lr_scheduler
from AML_Semantic_DA.utils import reverse_one_hot, compute_global_accuracy, fast_hist, per_class_iu

# Evaluate the trained model on the single batch drawn above and report pixel
# accuracy plus (per-class) mIoU for it.
model.eval()
with torch.no_grad():
    hist = np.zeros((num_classes, num_classes))
    data = data.cuda()
    label = label.long().cuda()

    # Forward pass: only the first of the model's three outputs is used.
    predict, _, _ = model(data)
    # Drop the batch axis (BATCH_SIZE is 1), then collapse class scores into a
    # label map (reverse_one_hot presumably takes the argmax over classes).
    predict = predict.squeeze(0)
    predict = reverse_one_hot(predict)
    predict = np.array(predict.cpu())

    # Ground truth as a plain array with singleton axes removed.
    label = label.squeeze()
    label = np.array(label.cpu())

    # Per-pixel accuracy of this single prediction. The original code averaged
    # an undefined `precision_record` here and crashed with a NameError.
    precision = compute_global_accuracy(predict, label)
    hist += fast_hist(label.flatten(), predict.flatten(), num_classes)

    # NOTE(review): with a single image some classes may be absent; how
    # per_class_iu scores those (0 vs NaN) affects this mean — confirm.
    miou_list = per_class_iu(hist)
    miou = np.mean(miou_list)
    print('precision per pixel for test: %.3f' % precision)
    print('mIoU for validation: %.3f' % miou)
    print(f'mIoU per class: {miou_list}')

### Show the pictures

In [None]:
# Display the evaluated input image. `data` now lives on the GPU, so bring it
# back to the CPU after moving the leading axis last for matplotlib.
eval_image = data[0].permute(1, 2, 0).cpu()
plt.imshow(eval_image)

In [None]:
# Remap pixels labelled 255 (presumably the ignore/void index) to 0 in place,
# so imshow's colour scaling is not stretched by them, then show the ground truth.
IGNORE_INDEX = 255
ignore_mask = label == IGNORE_INDEX
label[ignore_mask] = 0
plt.imshow(label)

In [None]:
# Display the predicted label map computed in the metrics cell.
plt.imshow(X=predict)