# Model
The model is a hybrid of Mask R-CNN and U-Net, as shown in the figure below:

![image](../data/seg_mask_rcnn.png)

Zoomed-in view of the architecture's backbone:

![image](../data/maskrcnn.png)

# Train
To train on a single node with multiple GPUs, use the following command:

In [None]:
# Launch multi-GPU training on a single node (uncomment to run). Use `!`, the notebook
# shell escape: `%torchrun` is not an IPython line magic and would fail as unknown.
# --nproc_per_node sets the number of worker processes on this node (typically one per GPU).
# !torchrun --standalone --nnodes=1 --nproc_per_node=2 /dbfs/imseg/train.py

# Visualize
Plot a randomly chosen validation image with its ground-truth annotations, then the model's predictions for the same image.

In [None]:
import os
import sys
sys.path.insert(0, "../src")

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor

from model.data_loader import FashionpediaDataset, get_transform
from utils.vis_utils import vis_data
from utils.utils import filter_data

# Seed the NumPy and PyTorch RNGs so the random choices made later in the notebook
# (e.g. the sampled validation index) repeat across runs when cells run in order.
np.random.seed(0)
torch.manual_seed(0)

# Dark matplotlib theme for every figure produced below.
plt.style.use("dark_background")

In [None]:
# Checkpoint directory; the model-loading cell reads "segm.pth" from here.
model_path = "../experiments/base"

# Fashionpedia validation split: the dataset root, plus an annotation file and an image
# folder that are presumably resolved relative to the root inside FashionpediaDataset.
data_path = "../../../datasets/Fashionpedia"
val_json = "annotations/val.json"
# NOTE(review): validation annotations are read against the "test" image folder,
# presumably because the val images ship in that folder; confirm.
val_img_path = "images/test"

In [None]:
# Per-class thresholds consumed by filter_data (presumably minimum confidence scores,
# indexed by predicted label id; confirm whether index 0 is background). Every one of
# the 46 slots defaults to 0.15; the listed label ids override that default.
_thr_default = 0.15
_thr_overrides = {
    1: 0.75, 4: 0.2, 8: 0.5, 10: 0.25, 15: 0.2, 18: 0.4, 19: 0.3,
    24: 0.5, 28: 0.5, 29: 0.2, 31: 0.85, 33: 0.5, 34: 0.2, 43: 0.5,
}
thr = [_thr_overrides.get(label, _thr_default) for label in range(46)]

In [None]:
# Load the whole pickled model object and map all its tensors onto the CPU. The
# result is used directly as a module below (.eval(), calls), so the checkpoint is not
# just a state_dict.
# weights_only=False: PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# refuses to unpickle a full module. Unpickling runs arbitrary code, so only load
# trusted checkpoints. (The keyword itself requires PyTorch >= 1.13.)
model = torch.load(
    os.path.join(model_path, "segm.pth"),
    map_location=torch.device('cpu'),
    weights_only=False,
)
model.eval();  # inference mode; the trailing ";" suppresses the notebook's echo of the module

In [None]:
# Build the validation dataset. get_transform(False) presumably selects the
# non-training (no augmentation) transform pipeline; confirm in model.data_loader.
eval_transforms = get_transform(False)
val_dataset = FashionpediaDataset(data_path, val_json, val_img_path, transforms=eval_transforms)
# Category metadata, passed to vis_data below to label the plots.
categories = val_dataset.categories

In [None]:
# Pick one validation sample uniformly from [0, len(val_dataset)); with the NumPy seed
# set earlier, the choice repeats when the cells are run in order.
i = np.random.randint(len(val_dataset))
val_im, val_target = val_dataset[i]

In [None]:
# Plot the ground-truth annotations for the sampled image. The tensor's leading axis
# is moved last (channels-first -> channels-last, assuming CHW), presumably the layout
# vis_data expects.
gt_image = val_im.numpy().transpose(1, 2, 0)
vis_data(gt_image, val_target, categories, num_cols=3, figsize=(15, 10))

In [None]:
# Run the model on the single sampled validation image (wrapped in a one-element list).
# Inference only: torch.no_grad() keeps autograd from recording a graph, which saves
# memory and leaves the output tensors without grad tracking.
with torch.no_grad():
    pred = model([val_im])

In [None]:
# Filter the first (only) prediction with the per-class thresholds in `thr`; this
# presumably drops low-confidence detections inside filter_data. Then plot the result
# on the same validation image.
out = filter_data(pred[0], thr)
pred_image = val_im.numpy().transpose(1, 2, 0)  # leading axis moved last for plotting
vis_data(pred_image, out, categories, num_cols=2, figsize=(15, 50))

# H&M data
Test the model on H&M images.

In [None]:
# Folder holding the H&M sample images, and the PIL-image -> tensor conversion
# applied to each of them.
hm_data_path = "../data"
trans = ToTensor()

In [None]:
# Collect the JPEG files in the H&M image folder. The extension check is
# case-insensitive and also accepts ".jpeg". The listing is sorted because os.listdir
# returns entries in arbitrary order, while a later cell selects an image by its
# position in this list.
file_list = sorted(
    file for file in os.listdir(hm_data_path)
    if file.lower().endswith((".jpg", ".jpeg"))
)

file_list

In [None]:
# Load every listed H&M image, in file_list order, as a tensor via `trans`.
# Each image is converted to RGB first. ToTensor emits one channel per image band, so
# a greyscale, palette or CMYK JPEG would otherwise not give the 3-channel input the
# model presumably expects. convert() also forces the pixel data to load before the
# `with` block closes the file.
im_list = []
for file in file_list:
    with Image.open(os.path.join(hm_data_path, file), "r") as im:
        im_list.append(trans(im.convert("RGB")))

In [None]:
# Run the model on all loaded H&M images in a single call.
# Inference only: torch.no_grad() avoids building an autograd graph for the whole
# batch, which would otherwise hold extra memory for every image.
with torch.no_grad():
    pred = model(im_list)

In [None]:
# Show the filtered predictions for one H&M image. Index 1 is the second entry of
# im_list, which follows file_list order.
i = 1
out = filter_data(pred[i], thr)
im_t = im_list[i]
vis_data(
    im_t.numpy().transpose(1, 2, 0),  # leading (channel) axis moved last for plotting
    out,
    categories,
    num_cols=2,
    figsize=(15, 50),
)