In [3]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## Подготовка датасета

In [4]:
!pip install segmentation-models-pytorch

In [23]:
import matplotlib.pyplot as plt
import cv2
import warnings

from PIL import Image
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

import torch
from torch.nn import functional as F

# Suppress all Python warnings for the rest of the session.
# NOTE(review): this is a blanket filter, so deprecation notices from the
# libraries used below will not be shown either.
warnings.simplefilter("ignore")

In [24]:
# Root folder of the dataset; the trailing slash matters because later cells
# build paths by plain string concatenation.
DATA_ROOT = (
    '/kaggle/input/'
    'makeup-lips-segmentation-28k-samples/'
    'set-lipstick-original/'
)

In [25]:
# Show one mask / photo pair side by side as a first look at the data.
sample_files = ('mask/mask00000777.png', '720p/image00000777.jpg')

fig, ax = plt.subplots(1, 2, figsize=(15, 7))
for axis, rel_path in zip(ax, sample_files):
    axis.imshow(Image.open(DATA_ROOT + rel_path))

plt.show()

In [26]:
# Load the dataset index shipped with the images and display its last rows.
df = pd.read_csv(f'{DATA_ROOT}list.csv')
df.tail()

In [27]:
# Full path of every file found under 720p/.
lips_pics = [
    os.path.join(folder, fname)
    for folder, _, fnames in os.walk(DATA_ROOT + '720p/')
    for fname in fnames
]

# (full path, bare file name) of every file found under mask/, gathered in a
# single directory walk and then split into the two parallel lists.
mask_entries = [
    (os.path.join(folder, fname), fname)
    for folder, _, fnames in os.walk(DATA_ROOT + 'mask/')
    for fname in fnames
]
lips_mask = [path for path, _ in mask_entries]
masks_names = [fname for _, fname in mask_entries]

len(lips_pics), len(lips_mask), len(masks_names)

In [28]:
# Image and mask folders, derived from DATA_ROOT so the dataset location is
# spelled out in exactly one place.
IMG_PATH = DATA_ROOT + '720p/'
MASK_PATH = DATA_ROOT + 'mask/'

In [29]:
# Sanity check: first five entries of each list built by the directory walk.
(lips_pics[:5], lips_mask[:5], masks_names[:5])

In [30]:
# Keep only the CSV rows whose mask file name was found on disk.
# .copy() detaches the result from `df`, so column assignments made on df_2 in
# later cells modify an independent frame instead of a slice of `df`.
df_2 = df.loc[df['mask'].isin(masks_names)].copy()
print(f"Было : {df.shape}, \nСтало: {df_2.shape}")

In [31]:
def _prefixed(paths, prefix):
    """Return `paths` with `prefix` prepended to every entry that lacks it.

    Entries that already start with `prefix` are left untouched, so running
    this cell a second time does not produce doubled directory prefixes.
    """
    already_prefixed = paths.str.startswith(prefix, na=False)
    return paths.where(already_prefixed, prefix + paths)


# .assign returns a new frame, so nothing is written into a slice of `df`.
df_2 = df_2.assign(
    filename=_prefixed(df_2['filename'], IMG_PATH),
    mask=_prefixed(df_2['mask'], MASK_PATH),
)

df_2.head(2)

In [32]:
class LipsDataset(Dataset):
    """Image / mask pairs for lip segmentation, read from disk on access.

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide a ``'filename'`` column (image paths) and a ``'mask'``
        column (mask paths). Rows are addressed by position, so the frame's
        index does not need to be ``0..n-1``.
    preprocessing : callable, optional
        Applied to the resized float image array instead of the default
        division by 255.
    size : tuple of int, optional
        ``(width, height)`` handed to ``cv2.resize`` for both image and mask.
        Defaults to ``(256, 256)``.
    """

    def __init__(self, data, preprocessing=None, size=(256, 256)):
        self.data = data

        # Plain arrays make __getitem__ positional. Indexing the Series itself
        # with an int is label-based and fails (or returns the wrong row) when
        # the frame was filtered or shuffled without resetting its index.
        self.image_arr = self.data['filename'].to_numpy()
        self.label_arr = self.data['mask'].to_numpy()

        self.data_len = len(self.data.index)

        self.preprocessing = preprocessing
        self.size = tuple(size)

    @staticmethod
    def _read_rgb(path):
        """Read the file at `path` as an RGB array; raise if it cannot be read."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        # OpenCV loads channels in BGR order; swap them to RGB.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Return ``(image, mask)`` as channel-first tensors for row `index`.

        The image is returned as float32. The mask is divided by 255 and keeps
        the three channels it was loaded with.
        """
        img = self._read_rgb(self.image_arr[index])
        img = cv2.resize(img, self.size).astype('float')

        if self.preprocessing:
            img = torch.as_tensor(self.preprocessing(img))
        else:
            # Scale pixel values to the [0, 1] range
            img = torch.as_tensor(img) / 255.0

        img = img.permute(2, 0, 1)  # HWC -> CHW

        # NOTE(review): the mask is loaded as a 3-channel colour image while the
        # model in this notebook is created with classes=1 - confirm that the
        # loss is meant to broadcast over the mask channels.
        mask = self._read_rgb(self.label_arr[index])
        mask = cv2.resize(mask, self.size).astype('float')
        mask = torch.as_tensor(mask) / 255.0
        mask = mask.permute(2, 0, 1)  # HWC -> CHW

        return (img.float(), mask)

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return self.data_len

In [33]:
# Smoke test: fetch one sample without encoder preprocessing and display it.
dataset = LipsDataset(df_2)
img, masks = dataset[42]
print(img.shape, masks.shape)

fig, ax = plt.subplots(1, 2, figsize=(15, 7))
for axis, tensor in zip(ax, (img, masks)):
    # tensors are channel-first; matplotlib expects channel-last
    axis.imshow(tensor.permute(1, 2, 0))
plt.show()

In [34]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [35]:
# Build the model
BACKBONE = 'resnet34'

# Input preprocessing function matching the ImageNet-pretrained encoder.
preprocess_input = smp.encoders.get_preprocessing_fn(BACKBONE, pretrained='imagenet')

# U-Net with a single output channel passed through a sigmoid.
segmodel = smp.Unet(encoder_name=BACKBONE, classes=1, activation='sigmoid')
segmodel = segmodel.to(device)

In [36]:
# 70/30 split with a fixed seed; each part gets a fresh 0..n-1 index.
train, test = (
    part.reset_index(drop=True)
    for part in train_test_split(df_2, test_size=0.3, random_state=42)
)

# Both datasets apply the encoder's preprocessing function.
df_train, df_test = (
    LipsDataset(part, preprocessing=preprocess_input) for part in (train, test)
)

# Shuffle only the training batches.
train_data_loader = DataLoader(df_train, batch_size=30, shuffle=True)
test_data_loader = DataLoader(df_test, batch_size=10, shuffle=False)

In [37]:
# Inspect the first training batch only: shapes, value ranges and one sample.
for img, target in train_data_loader:
    first_img, first_mask = img[0], target[0]

    print(img.shape, target.shape)
    print(first_img.min(), first_img.max())
    print(first_mask.min(), first_mask.max())

    fig, ax = plt.subplots(1, 2, figsize=(15, 6))
    ax[0].imshow(first_img.permute(1, 2, 0))
    # show a single channel of the mask
    ax[1].imshow(first_mask.permute(1, 2, 0)[..., 0])
    break

In [38]:
# Dice loss for optimisation, IoU as the reported metric.
criterion = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU()]

# Adam over all model parameters.
optimizer = torch.optim.Adam(segmodel.parameters(), lr=1e-3)

In [39]:
# Settings shared by the training and validation runners.
runner_kwargs = dict(
    loss=criterion,
    metrics=metrics,
    device=device,
    verbose=True,
)

# The training runner additionally receives the optimizer.
train_epoch = smp.utils.train.TrainEpoch(segmodel, optimizer=optimizer, **runner_kwargs)

valid_epoch = smp.utils.train.ValidEpoch(segmodel, **runner_kwargs)

In [40]:
# Train the model, keeping the checkpoint with the best validation IoU.
epoch = 5
max_score = 0

for i in range(epoch):
    print(f'Epoch: {i + 1}')
    train_logs = train_epoch.run(train_data_loader)
    valid_logs = valid_epoch.run(test_data_loader)

    # Only a strictly better IoU than the best so far triggers a save.
    current_score = valid_logs['iou_score']
    if not (max_score < current_score):
        continue

    max_score = current_score
    torch.save(segmodel, './best_model.pth')
    print('Model saved!')

In [44]:
class_idx = 1

# Number of samples (figure rows) shown for each batch.
N_SAMPLES_PER_BATCH = 3

# Put the model in inference mode before predicting.
segmodel.eval()

for i, data in enumerate(test_data_loader):
    images, labels = data
    images = images.to(device)
    labels = labels.to(device)

    # No gradients are needed for visualisation; skip building the autograd graph.
    with torch.no_grad():
        outputs = segmodel(images)

    # A batch may hold fewer samples than requested rows.
    n_rows = min(N_SAMPLES_PER_BATCH, images.shape[0])
    # squeeze=False keeps `ax` two-dimensional even when n_rows == 1.
    fig, ax = plt.subplots(n_rows, 3, figsize=(16, 5 * n_rows), squeeze=False)

    # Row j shows sample j of the current batch. The sample index is the inner
    # counter, not the batch counter `i`, so each row is a different sample.
    for j in range(n_rows):
        image = images[j].permute(1, 2, 0)
        label = labels[j]

        ax[j, 0].imshow(image.cpu())
        ax[j, 0].set_title('Image')

        # The model has one output channel; plot it as a 2-D map.
        ax[j, 1].imshow(outputs[j, 0].cpu())
        ax[j, 1].set_title('Pred_mask')

        ax[j, 2].imshow(label.cpu().permute(1, 2, 0))
        ax[j, 2].set_title('True_mask')

    plt.show()

    # Stop after the first three batches.
    if i > 1:
        break