In [None]:
import os
import cv2
import rootutils
from dotenv import load_dotenv
from pathlib import Path
from IPython.display import display

import torch
import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Locate the project root with rootutils (a directory containing .git or pyproject.toml,
# searched from the notebook's working directory) and add it to the Python path.
project_root = rootutils.setup_root(
    os.path.abspath(''),
    indicator=['.git', 'pyproject.toml'],
    pythonpath=True,
)

# Project-local imports: these must come after setup_root, which makes `src` importable.
from src.models.components.base_model import BaseModel
from src.models.components.nn_utils import weight_load
from src.data.components.utils import list_files

# Load variables from a .env file into os.environ (presumably where
# `lear_wrinkles_data_path` is defined — confirm).
load_dotenv()

#### Read seat image

In [None]:
# Seat images live in the folder named by the `lear_wrinkles_data_path` environment variable.
data_dir = os.environ.get('lear_wrinkles_data_path')
if not data_dir:
    # Path(None) would otherwise fail with an unhelpful TypeError.
    raise RuntimeError("Environment variable 'lear_wrinkles_data_path' is not set (e.g. via .env).")
source_path = Path(data_dir)

image_paths = list_files(source_path, file_extensions=['.bmp', '.jpg', '.png'])
if not image_paths:
    raise FileNotFoundError(f'No .bmp/.jpg/.png images found in {source_path}')

# Preview the first image. cv2.imread returns None (rather than raising) when a file
# cannot be read or decoded, so check explicitly before converting.
image = cv2.imread(str(image_paths[0]), cv2.IMREAD_COLOR)
if image is None:
    raise IOError(f'Could not read image: {image_paths[0]}')
# OpenCV decodes colour images in BGR channel order; convert to RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Preprocessing pipeline:
# - resize to the model's input resolution;
# - scale pixel values by 1/255 (ToFloat with max_value=255);
# - convert the HWC numpy array to a CHW torch tensor.
INPUT_HEIGHT, INPUT_WIDTH = 768, 640
transform = A.Compose([
    A.Resize(height=INPUT_HEIGHT, width=INPUT_WIDTH),
    A.ToFloat(max_value=255),
    ToTensorV2(),
])

transformed = transform(image=image)
# Prepend a batch dimension of size 1 and move to the inference device.
image_tensor = transformed['image'][None].to(device)

# Visual sanity check of exactly what the model receives.
image_tensor_pil = T.ToPILImage()(image_tensor[0].detach().cpu())
display(image_tensor_pil)
print('Shape: ', image_tensor.shape, 'Type: ', image_tensor.dtype, 'max: ', image_tensor.max(), 'min: ', image_tensor.min())

In [None]:
# The architecture must match the checkpoint: UNet++ decoder with a MobileNetV2 encoder.
model = BaseModel(
    model_name='segmentation_models_pytorch/UnetPlusPlus',
    encoder_name='mobilenet_v2',
).to(device)

# NOTE(review): this relative path is resolved against the current working directory,
# which rootutils.setup_root may have changed (its `cwd` option) — confirm it points where intended.
CKPT_PATH = '../trained_models/unet++.ckpt'
if not Path(CKPT_PATH).is_file():
    # Fail early with the resolved location instead of an error from deep inside weight_load.
    raise FileNotFoundError(
        f'Checkpoint not found: {Path(CKPT_PATH).resolve()} (cwd: {Path.cwd()})'
    )

weights = weight_load(
    ckpt_path=CKPT_PATH,
    weights_only=True,
)
model.load_state_dict(weights)
# Switch to evaluation mode (affects layers such as dropout / batch-norm).
# The trailing ';' suppresses the very long model repr in the cell output.
model.eval();

In [None]:
%%timeit
with torch.no_grad():
    model(image_tensor)

In [None]:
# Preview a single prediction next to the resized input image.
with torch.no_grad():
    out = model(image_tensor)

# First batch item: map raw outputs through a sigmoid and binarise at 0.5 (1.0 / 0.0).
mask = out[0].sigmoid().gt(0.5).float()
mask = T.ToPILImage()(mask.detach().cpu())
display(image_tensor_pil, mask)

#### Run inference on all images and save the binary masks to `<data dir>/masks`

In [None]:
# Predict a binary mask for every image and write it to <source_path>/masks.
MASK_THRESHOLD = 0.5  # probability cut-off applied after the sigmoid
output_path = source_path / 'masks'
output_path.mkdir(exist_ok=True)
masks_dir = output_path.resolve()

# Loop-local names (img_rgb, batch, logits, ...) avoid overwriting the `image` /
# `image_tensor` globals used by the earlier cells.
for path in image_paths:
    path = Path(path)
    # If list_files walks sub-directories, masks from a previous run would be read
    # back as inputs and overwritten; skip anything inside the output folder.
    # NOTE(review): recursion behaviour of list_files is not visible here — confirm.
    if masks_dir in path.resolve().parents:
        continue

    # cv2.imread expects a str path and returns None (instead of raising) on unreadable files.
    img_bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if img_bgr is None:
        print(f'Skipping unreadable image: {path}')
        continue
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    batch = transform(image=img_rgb)['image'].unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Upsample to the original (H, W) so the mask aligns with the source image.
    logits = torch.nn.functional.interpolate(
        logits, size=img_rgb.shape[:2], mode='bilinear', align_corners=False
    )
    foreground = torch.sigmoid(logits[0]) > MASK_THRESHOLD

    # Keep the first output channel and map {False, True} -> {0, 255} as an 8-bit image.
    mask_u8 = (foreground[0].to(torch.uint8) * 255).cpu().numpy()

    # JPEG is lossy and would smear the binary 0/255 values, so JPEG inputs get PNG masks.
    # Other formats keep the input file name.
    if path.suffix.lower() in ('.jpg', '.jpeg'):
        mask_path = output_path / path.with_suffix('.png').name
    else:
        mask_path = output_path / path.name
    if not cv2.imwrite(str(mask_path), mask_u8):
        print(f'Failed to write mask: {mask_path}')