### Imports

In [None]:
import os.path as osp
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.backends import cudnn

from configs import *
from tester import *


### Configurations

In [None]:
# Load the experiment configuration and set up cuDNN and RNG seeding.
import random

configs = Configs()
SEED = configs.seed

# cuDNN autotuning is on and determinism is off: faster, but GPU results are
# not bit-for-bit reproducible even with a fixed seed.
cudnn.enabled = True
cudnn.benchmark = True
cudnn.deterministic = False

# Seed every RNG the pipeline may draw from (Python, NumPy, torch CPU and all
# CUDA devices), not only CUDA. This makes weight init, DataLoader shuffling
# and (presumably) the random augmentations repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and every CUDA device

### Preprocessing
Run the following command to create the ground-truth masks:
```shell
python preprocess.py
```

### Dataset

In [None]:
from face_dataset import *

### Data augmentation

In [None]:
from augmentation import *

In [None]:
# Training-time augmentation pipeline.
# The transforms are given as a *list* so they run in the written order.
# A set literal has no defined order for these objects (iteration follows
# their hashes, i.e. memory addresses), which made the crop/flip/colour-jitter
# order arbitrary from run to run.
train_tranform = Compose([
    RandomCrop(448),
    RandomHorizontallyFlip(p=0.5),
    AdjustBrightness(bf=0.1),
    AdjustContrast(cf=0.1),
    AdjustHue(hue=0.1),
    AdjustSaturation(saturation=0.1),
])

### Train/Val/Test Split

In [None]:
# Build the train / validation / test index lists.
# train.txt holds one integer image index per line. Every image index not
# listed there becomes a test index.
ROOT_DIR = configs.root_dir
image_dir = os.path.join(ROOT_DIR, 'CelebA-HQ-img')

indices_file_pth = os.path.join(ROOT_DIR, 'train.txt')
with open(indices_file_pth, 'r') as file:
    # Skip blank lines (e.g. a trailing empty line) instead of crashing on int('').
    train_indices = {int(line) for line in file if line.strip()}

# Assumes the images are named 0..N-1, so the index range is 0..N-1.
sample_indices = list(range(len(os.listdir(image_dir))))
test_indices = [idx for idx in sample_indices if idx not in train_indices]

# Sort explicitly so the debug subset and the seeded split below do not depend
# on set iteration order.
train_indices = sorted(train_indices)
if configs.debug:
    train_indices = train_indices[:100]  # small training subset for quick debugging
VAL_SIZE = configs.val_size
train_indices, valid_indices = train_test_split(train_indices, test_size=VAL_SIZE, random_state=SEED)
print(len(train_indices))
print(len(valid_indices))
print(len(test_indices))

In [None]:
### dataset ###
# The training split is built with the augmentation pipeline. The validation
# split uses the dataset's 'val' mode and no explicit transform.
trainset = CelebAMask_HQ_Dataset(
    root_dir=ROOT_DIR,
    sample_indices=train_indices,
    mode='train',
    tr_transform=train_tranform,
)
validset = CelebAMask_HQ_Dataset(
    root_dir=ROOT_DIR,
    sample_indices=valid_indices,
    mode='val',
)

# Report split sizes (train first, then validation).
print(len(trainset))
print(len(validset))



### Visualize data

In [None]:
from torch.utils.data import DataLoader

# Preview one shuffled training batch: images on the top row, masks below.
dataloader = DataLoader(trainset, batch_size=4, shuffle=True)

images, masks = next(iter(dataloader))  # only the first batch is shown
print(f'b_img_size: {images.shape}')
print(f'b_mask_size: {masks.shape}')

# The grid width follows the actual batch size. With the old hard-coded
# 2x4 grid, any batch larger than 4 produced invalid subplot indices.
n_samples = images.shape[0]
for col in range(n_samples):
    # (C, H, W) tensor -> (H, W, C) array, the layout imshow expects.
    image = images[col].permute(1, 2, 0).numpy()
    mask = masks[col].numpy()  # assumes a 2-D (H, W) label map — TODO confirm

    plt.subplot(2, n_samples, col + 1)
    plt.imshow(image)
    plt.title('Image')

    plt.subplot(2, n_samples, n_samples + col + 1)
    plt.imshow(mask, cmap='gray')
    plt.title('Mask')
plt.show()

### DataLoader

In [None]:
### dataloader ###
BATCH_SIZE = configs.batch_size
N_WORKERS = configs.n_workers

# sampler = torch.utils.data.distributed.DistributedSampler(trainset)

# Training drops the final incomplete batch to avoid a tiny last batch.
train_loader = DataLoader(trainset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          drop_last=True)

# Validation keeps every sample. With drop_last=True, up to BATCH_SIZE-1
# validation images were silently excluded from the reported metrics.
valid_loader = DataLoader(validset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          drop_last=False)
print(f"training data: {len(train_indices)} and validation data: {len(valid_indices)} loaded succesfully ...")

### Unet

In [None]:
from unet import *

#### Check the model's input/output shapes

In [None]:
# Optional smoke test (disabled): uncomment to push a random (1, 3, 448, 448)
# batch through a fresh Unet and print the output shape
# (presumably (1, 19, 448, 448) — confirm against unet.py).
# model = Unet(n_channels=3, n_classes=19)
# batch = torch.randn(1,3,448,448)    
# result = model(batch) #It is your img input
# print(result.shape)

### Loss and metrics

In [None]:
from criterion import *
from metrics import *

## Training

#### Trainer class

In [None]:
from trainer import *

#### Clear cache

In [None]:
# Free unreachable Python objects, then return PyTorch's cached-but-unused
# CUDA memory before the model is built.
# `gc` is imported here because it is not explicitly imported anywhere else
# in this notebook.
import gc

gc.collect()
torch.cuda.empty_cache()

#### Model Initialization

In [None]:
### Init model ###
# U-Net with 3 input channels and 19 output classes (one per face-parsing
# label), moved onto the configured device.
DEVICE = configs.device
model = Unet(n_channels=3, n_classes=19)
model = model.to(DEVICE)
print("Model Initialized !")

#### Training Hyperparameters

In [None]:
### hyper params ###
EPOCHS = configs.epochs
LR = configs.lr
optimizer = torch.optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999), eps=1e-08,
                             weight_decay=0.0001, amsgrad=False)
# Lower the LR when the monitored value stops improving for 5 epochs.
# mode='max' because the monitored value is meant to be maximized (mIoU).
# NOTE(review): confirm the Trainer calls scheduler.step() with mIoU.
# The fully qualified torch.optim path is used because the bare name `optim`
# is never explicitly imported in this notebook.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'max', patience=5, min_lr=1e-6, verbose=True)
criterion = DiceLoss()
SAVEPATH = configs.model_path
SAVENAME = configs.model_weight

#### Training loop

In [None]:
### training ###
# Hand every component configured above to the Trainer and run it.
# Weights are presumably written under SAVEPATH/SAVENAME — see trainer.py.
Trainer(
    model=model,
    device=DEVICE,
    trainloader=train_loader,
    validloader=valid_loader,
    epochs=EPOCHS,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    savepath=SAVEPATH,
    savename=SAVENAME,
).run()

### Testing

#### Tester Class

In [None]:
from tester import *

#### Test dataset and testloader

In [None]:
### dataloader ###
BATCH_SIZE = configs.batch_size
N_WORKERS = configs.n_workers

# Held-out images: every index not listed in train.txt.
testset = CelebAMask_HQ_Dataset(root_dir=ROOT_DIR,
                                sample_indices=test_indices,
                                mode='test')

# drop_last=False so every test image is evaluated. Dropping the final
# partial batch would silently leave up to BATCH_SIZE-1 images out of the
# test metrics.
test_loader = DataLoader(testset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=N_WORKERS,
                         pin_memory=True,
                         drop_last=False)

#### Load model weight

In [None]:
# Paths and device for the evaluation stage.
DEVICE = configs.device
SAVEPATH = configs.model_path
OUTPUT_DIR = configs.cmp_result_dir  # presumably where comparison figures go
MODEL_WEIGHT = configs.model_weight

# exist_ok makes this a no-op when the directory already exists. It also
# avoids the race in a separate exists()-then-makedirs() check.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Rebuild the architecture and load the trained weights.
# map_location places the tensors directly on DEVICE, so a checkpoint saved on
# a GPU also loads on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file. Only load checkpoints you trust
# (consider weights_only=True on torch >= 1.13).
model = Unet(n_channels=3, n_classes=19).to(DEVICE)
state_dict = torch.load(os.path.join(SAVEPATH, MODEL_WEIGHT), map_location=DEVICE)
model.load_state_dict(state_dict)

#### Testing loop

In [None]:
### testing
# Evaluate the loaded model on the held-out test loader.
# This reuses `criterion` from the hyper-parameter cell, which must have run
# in this session.
Tester(
    model=model,
    device=DEVICE,
    testloader=test_loader,
    criterion=criterion,
).run()

### Generate 60 comparison samples (image / prediction / ground truth)

In [None]:
### visualize
# For every 100th test image (every 100th training image in debug mode), plot
# the input image, the predicted mask and the ground-truth mask side by side.
# The code is at top level: the original indented body raised IndentationError.

# One RGB colour per class label (19 entries). Indexing this table with an
# integer label map yields an (H, W, 3) colour image.
cmap = np.array([(0,  0,  0), (204, 0,  0), (76, 153, 0),
                 (204, 204, 0), (51, 51, 255), (204, 0, 204), (0, 255, 255),
                 (51, 255, 255), (102, 51, 0), (255, 0, 0), (102, 204, 0),
                 (255, 255, 0), (0, 0, 153), (0, 0, 204), (255, 51, 153),
                 (0, 204, 204), (0, 51, 0), (255, 153, 51), (0, 204, 0)],
                dtype=np.uint8)

# Model-input preprocessing. NOTE(review): must match the training-time
# normalization in face_dataset — confirm.
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
image_dir = os.path.join(ROOT_DIR, 'CelebA-HQ-img')
mask_dir = os.path.join(ROOT_DIR, 'mask')

SAVE_RESULTS = False  # set True to also write each figure into OUTPUT_DIR

# In debug mode only train_indices is shrunk (~100 entries). Indexing it with
# test-split positions used to raise IndexError, so the pool is chosen first
# and then strided.
sample_pool = train_indices if configs.debug else test_indices

model.eval()  # inference behaviour for dropout / batch-norm layers, if any
with torch.no_grad():  # no autograd graph is needed for plotting predictions
    for idx in tqdm(sample_pool[::100]):
        # Images and masks are stored as <index>.jpg / <index>.png.
        img_pth = osp.join(image_dir, f'{idx}.jpg')
        mask_pth = osp.join(mask_dir, f'{idx}.png')

        # Keep the un-normalized resized image for plotting. Only the
        # normalized tensor is fed to the model.
        with Image.open(img_pth) as im:
            image = im.convert('RGB').resize((512, 512), Image.BILINEAR)
        with Image.open(mask_pth) as im:
            gt_mask = np.array(im.convert('L'))  # integer class labels

        logits = model(to_tensor(image).unsqueeze(0).to(DEVICE))     # (1, 19, h, w)
        pred_mask = logits.argmax(dim=1).squeeze(0).cpu().numpy()   # (h, w)

        # Convert label maps to colour images.
        color_gt_mask = cmap[gt_mask]
        color_pr_mask = cmap[pred_mask]

        fig = plt.figure(figsize=(13, 6))
        panels = (('Image', image), ('Prediction', color_pr_mask), ('Ground truth', color_gt_mask))
        for col, (title, panel) in enumerate(panels, start=1):
            plt.subplot(1, 3, col)
            plt.imshow(panel)
            plt.title(title)
        if SAVE_RESULTS:
            # Save before show(): the inline backend may close the figure on show().
            fig.savefig(osp.join(OUTPUT_DIR, f'result_{idx}.jpg'))
        plt.show()
        plt.close(fig)  # release the figure so many plots don't accumulate