In [None]:
# ! pip install segmentation-models-pytorch albumentations scikit-learn

In [None]:
# !pip3.11 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1

In [None]:
import glob
import os
import random
import pandas as pd
import albumentations as A
import cv2
import torch
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

from losses import ComboLoss, FocalLoss2d

## 4 класса на изображении - плита, духовка, микроволновка и чайник.

In [None]:
SEED = 60


def set_seed(seed=None):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed to apply. When omitted, the current value of the
            notebook-wide ``SEED`` constant is used.
    """
    if seed is None:
        seed = SEED

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # CUDA / cuDNN specific determinism switches are not applied here.

    # NOTE(review): hash randomisation is normally fixed at interpreter start-up, so this
    # presumably only influences processes launched afterwards - confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

# Prefer Apple's Metal backend, then CUDA, and finally fall back to the CPU so the
# notebook still runs on machines that have no MPS device.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

BATCH_SIZE = 4
# NOTE(review): `_MODEL_ARCHITECTURES` is a private attribute of segmentation_models_pytorch
# and may differ between releases - confirm it exists in the installed version.
DECODERS = smp._MODEL_ARCHITECTURES

In [None]:
# Display the value of DECODERS (taken from smp._MODEL_ARCHITECTURES).
DECODERS

In [None]:
def plot_images_side_by_side(image1, image2, title1='', title2=''):
    """Draw two images next to each other in one figure with the axes hidden.

    Args:
        image1: Image shown in the left panel.
        image2: Image shown in the right panel.
        title1: Title of the left panel.
        title2: Title of the right panel.
    """
    _, panels = plt.subplots(1, 2, figsize=(12, 6))

    for panel, picture, caption in zip(panels, (image1, image2), (title1, title2)):
        panel.imshow(picture)
        panel.set_title(caption)
        panel.axis('off')

    plt.show()


def dice_channels(prob, truth, threshold=0.5, eps = 1E-9):
    """Mean Dice score over every (image, channel) pair of a batch.

    Args:
        prob: Tensor whose first two dimensions index images and channels; values
            above ``threshold`` are treated as foreground.
        truth: Tensor with the same layout as ``prob``; values above 0.5 are treated
            as foreground.
        threshold: Cut-off applied to ``prob``.
        eps: Smoothing term added to numerator and denominator, so an empty
            prediction compared with an empty target scores 1.

    Returns:
        Zero-dimensional tensor holding the mean Dice score.
    """
    num_imgs = prob.size(0)
    num_channels = prob.size(1)
    # reshape (unlike view) also accepts non-contiguous inputs such as permuted tensors.
    prob = (prob > threshold).float().reshape(num_imgs, num_channels, -1)
    truth = (truth > 0.5).float().reshape(num_imgs, num_channels, -1)
    intersection = (prob * truth).sum(2)
    score = (2. * intersection + eps) / (prob.sum(2) + truth.sum(2) + eps)
    # Cap at 1, as the original masked assignment did.
    return score.clamp(max=1).mean()


def train_epoch(loader, model, loss_function_seg, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        loader: Sized iterable yielding ``(image, mask)`` batches.
        model: Network to train; moved to ``device`` and put into training mode.
        loss_function_seg: Callable taking ``(prediction, target)`` and returning the loss.
        optimizer: Optimizer stepped once per batch.
        device: Device on which the batches are processed.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model = model.to(device)
    model.train()
    optimizer.zero_grad()

    batch_losses = []
    for image, mask in loader:
        inputs, targets = image.to(device), mask.to(device)
        loss = loss_function_seg(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Clear gradients straight away so the next batch starts from zero.
        optimizer.zero_grad()
        batch_losses.append(loss.item())

    return sum(batch_losses, 0.) / len(loader)


def valid_epoch(loader, model, device):
    """Evaluate ``model`` on ``loader`` and return the mean per-batch Dice score.

    Args:
        loader: Iterable yielding ``(image, mask)`` batches.
        model: Network to evaluate; moved to ``device`` and put into eval mode.
        device: Device on which the batches are processed.

    Returns:
        Python float: the mean of the ``dice_channels`` scores of all batches.
    """
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        # Logits go through a sigmoid before being scored against the target masks.
        batch_scores = [
            dice_channels(torch.sigmoid(model(image.to(device))), mask.to(device))
            for image, mask in loader
        ]
    return torch.stack(batch_scores).mean().item()


def mask2rle(mask):
    '''
    Run-length encode a binary mask, one encoding per channel.

    mask: numpy array of shape (H, W, C) or (H, W); 1 - pixel classified as
          target, 0 - background. A 2-D mask is treated as a single channel.
    Returns the per-channel run lengths as "start length start length ..."
    strings (1-based starts, pixels numbered column by column), joined with ';'.
    A channel without any target pixel contributes an empty string.
    '''
    mask = np.asarray(mask)
    if mask.ndim == 2:
        # Give a plain 2-D mask a channel axis so the loop below handles it as well.
        mask = mask[:, :, np.newaxis]

    rles = []
    for channel in range(mask.shape[2]):
        # Transposing before flattening numbers the pixels column by column.
        pixels = mask[:, :, channel].T.flatten()
        # Zero padding guarantees that every run has both a start and an end.
        padded = np.concatenate([[0], pixels, [0]])
        runs = np.flatnonzero(padded[1:] != padded[:-1]) + 1
        # Turn (start, end) pairs into (start, length) pairs.
        runs[1::2] -= runs[::2]
        rles.append(' '.join(str(x) for x in runs))
    return ';'.join(rles)
 

def get_submission_df(img_paths, transforms, model, device='mps'):
    '''
    Run the model on every test image and collect the RLE-encoded masks.

    img_paths: list of paths to test images
    transforms: albumentation test transforms
    model: the trained model
    device: device on which inference is executed
    Returns submission dataframe with the columns image_id, rle and batch_number
    Raises FileNotFoundError when an image cannot be read
    '''
    model = model.to(device)
    model.eval()
    records = []
    with torch.no_grad():
        for row_number, img_path in enumerate(img_paths):
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f'Could not read image: {img_path}')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            transformed_img = transforms(image=img)['image']
            model_input = torch.from_numpy(transformed_img).permute(2, 0, 1).unsqueeze(0)
            pred = model(model_input.to(device))
            # Remove only the batch axis so a single-class model keeps its channel axis.
            mask = torch.sigmoid(pred).squeeze(0).cpu() > 0.5
            # Scale the predicted mask back to the size of the source image.
            resized_mask = cv2.resize(
                mask.numpy().transpose((1, 2, 0)).astype(int),
                (img.shape[1], img.shape[0]),
                interpolation=cv2.INTER_NEAREST
            )
            if resized_mask.ndim == 2:
                # cv2.resize drops a trailing axis of size one; mask2rle needs it back.
                resized_mask = resized_mask[:, :, None]
            records.append({
                'image_id': os.path.basename(img_path),
                'rle': mask2rle(resized_mask),
                # The earlier row-by-row .loc assignment upcast this column to float;
                # kept as float so the written CSV stays the same.
                # NOTE(review): confirm whether the grader would rather have integers.
                'batch_number': float(row_number),
            })
    # Build the frame once; explicit columns keep the header present for empty input.
    return pd.DataFrame(records, columns=['image_id', 'rle', 'batch_number'])

In [None]:
# Training-time augmentation: colour jitter, random horizontal flip, resize, random crop
# and normalisation.
train_transforms = A.Compose([
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.05, hue=0.05, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.Resize(500, 500),
    # The crop has to run for every sample: with p < 1 some samples stay 500x500, a batch
    # can then mix 500x500 and 384x384 tensors, and the default collation cannot stack them.
    A.RandomCrop(384, 384, p=1.0),
    A.Normalize(),
])



# train_transforms = A.Compose([
#     A.GaussNoise(mean_range=(0.0, 0.0), 
#                  std_range=(0.01, 0.05)),
#     A.RandomToneCurve()
#     A.Rotate(limit=(-20, 20)),
#     A.HorizontalFlip(p=0.5),
#     A.Resize(384, 384),
#     A.Normalize(),
#     A.RandomCrop()
# ])

# Evaluation pipeline without random augmentation: resize to 384x384, then normalise.
test_transforms = A.Compose([
    A.Resize(height=384, width=384),
    A.Normalize(),
])

In [None]:
class SegmentationDataset(Dataset):
    """Dataset of images paired with masks stored as ``.npy`` files.

    Each item is a ``(image, mask)`` pair of channel-first tensors produced by
    running the albumentations pipeline on the loaded image and mask.
    """

    def __init__(self, img_paths, masks_paths, transforms):
        """Store the file lists and the transform pipeline.

        Args:
            img_paths: Paths of the input images.
            masks_paths: Paths of the masks; entry ``i`` belongs to image ``i``.
            transforms: Callable invoked as ``transforms(image=..., mask=...)`` that
                returns a mapping with the keys ``'image'`` and ``'mask'``.

        Raises:
            ValueError: If the two path lists differ in length.
        """
        if len(img_paths) != len(masks_paths):
            # Unequal lists would silently pair images with the wrong masks.
            raise ValueError(
                f'Got {len(img_paths)} images but {len(masks_paths)} masks'
            )
        self.img_paths = img_paths
        self.masks_paths = masks_paths
        self.transforms = transforms

    def __getitem__(self, item):
        """Load, transform and return sample ``item`` as ``(image, mask)`` tensors.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_path = self.img_paths[item]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # assumes the stored mask is laid out as H x W x C - TODO confirm
        mask = np.load(self.masks_paths[item])

        transformed_data = self.transforms(image=img, mask=mask)
        # Move the channel axis to the front for both tensors.
        image_tensor = torch.from_numpy(transformed_data['image']).permute(2, 0, 1)
        mask_tensor = torch.from_numpy(transformed_data['mask']).permute(2, 0, 1).float()
        return image_tensor, mask_tensor

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.img_paths)

In [None]:
# Load the first image/mask pair (in sorted order) and pass it once through the
# training augmentations.
img_path = sorted(glob.glob('data/train/images/*.jpg'))[0]
mask_path = sorted(glob.glob('data/train/masks/*.npy'))[0]
img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
mask = np.load(mask_path)

transformed_data = train_transforms(image=img, mask=mask)
transformed_img, transformed_mask = transformed_data['image'], transformed_data['mask']

In [None]:
# Show the untransformed image next to its mask, with the mask summed over axis 2.
plot_images_side_by_side(img, np.sum(mask, axis=2))

In [None]:
# Rebind img_path / mask_path to the full sorted lists of training files
# (earlier they each held a single path).
img_path = sorted(glob.glob('data/train/images/*.jpg'))
mask_path = sorted(glob.glob('data/train/masks/*.npy'))
# A bare None as the last expression, so the cell displays nothing.
None

In [None]:
# Preview the training augmentations on every training image.
for i in range(len(img_path)):
    img = cv2.cvtColor(cv2.imread(img_path[i]), cv2.COLOR_BGR2RGB)
    mask = np.load(mask_path[i])

    transformed_data = train_transforms(image=img, mask=mask)
    transformed_img, transformed_mask = transformed_data['image'], transformed_data['mask']
    # Sum over the last axis to get a single 2-D picture of the mask.
    plot_images_side_by_side(transformed_img, np.sum(transformed_mask, axis=-1))

In [None]:
# Display the most recently produced augmented mask, cast to float32.
transformed_mask.astype(np.float32)

In [None]:
# Seed the random number generators before the datasets, loaders and model are created.
set_seed()

In [None]:
# Collect the training images and masks.
# NOTE(review): pairing relies on both sorted lists lining up index by index - confirm
# that the image and mask files sort into the same order.
all_train_images = sorted(glob.glob('data/train/images/*.jpg'))
all_train_masks = sorted(glob.glob('data/train/masks/*.npy'))

# Hold out 20% of the pairs for validation; the split uses its own fixed random_state.
train_images, valid_images, train_masks, valid_masks = train_test_split(
    all_train_images,
    all_train_masks,
    test_size=0.2,
    random_state=42
)

all_test_images = sorted(glob.glob('data/test/images/*.jpg'))

# Show the size of every subset as the cell output.
len(train_images), len(train_masks), len(valid_images), len(valid_masks), len(all_test_images)

In [None]:
# Training samples go through the augmenting pipeline, validation samples through the
# plain resize/normalise pipeline.
train_dataset = SegmentationDataset(
    img_paths=train_images, masks_paths=train_masks, transforms=train_transforms
)
valid_dataset = SegmentationDataset(
    img_paths=valid_images, masks_paths=valid_masks, transforms=test_transforms
)

# Only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(len(train_dataset), len(valid_dataset))

# Baseline

In [None]:
# UnetPlusPlus on an 'efficientnet-b2' encoder initialised with 'imagenet' weights and
# four output classes, moved to the selected device.
model = smp.UnetPlusPlus(
    encoder_name='efficientnet-b2',
    encoder_weights='imagenet',
    classes=4,
).to(device)

# NOTE(review): presumably a weighted sum of BCE and Dice with equal weight for each of
# the four channels - confirm against losses.ComboLoss.
loss_fn = ComboLoss(weights={'bce': 0.65, 'dice': 0.35}, channel_weights=[1]*4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# NOTE(review): 'initial_lr' is presumably set because the scheduler below is created
# with last_epoch != -1 and expects that key in the param group - confirm.
optimizer.param_groups[0]['initial_lr'] = optimizer.param_groups[0]['lr']
# NOTE(review): every scheduler.step() call visible in this notebook is commented out,
# and eta_min (5e-5) is larger than the optimizer's lr (1e-5) - confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       eta_min=5e-5,
                                                       T_max=40,
                                                       last_epoch=39)

In [None]:
# Keep the best score from an earlier run of this cell (or from the checkpoint
# evaluation cell) so training can be resumed. On a fresh kernel the name does not
# exist yet, so start from 0 instead of failing with a NameError.
if 'best_score' not in globals():
    best_score = 0

# Override the learning rate used for this training run.
optimizer.param_groups[0]['lr'] = 5e-6
for epoch in range(40):
    train_loss = train_epoch(train_loader, model, loss_fn, optimizer, device)
    valid_score = valid_epoch(valid_loader, model, device)
    # Write a checkpoint only when the validation score improves.
    if valid_score > best_score:
        torch.save(model.state_dict(), 'best_model_unetPlusPlus_efnB2_2.pth')
        best_score = valid_score
    print(f'Epoch: {epoch}, train_loss: {train_loss:.4f}, valid_score: {valid_score:.3f}\n')

In [None]:
# Restore the best checkpoint. map_location places the tensors on the active device even
# if the file was written from another one; weights_only restricts unpickling to plain
# tensor data.
model.load_state_dict(
    torch.load(
        'best_model_unetPlusPlus_efnB2_2.pth',
        map_location=device,
        weights_only=True,
    )
)

# Re-evaluate the restored weights; this value also becomes the score to beat when the
# training cell is run again.
best_score = valid_epoch(valid_loader, model, device)
print(f'Best validation score: {best_score}')

In [None]:
# Pick the validation image at index 7 (in sorted order) for a visual check.
valid_image_path = sorted(valid_images)[7]

# NOTE(review): this cell does not call model.eval() itself; it presumably relies on
# valid_epoch having switched the model to eval mode earlier - confirm.
with torch.no_grad():
    img = cv2.imread(valid_image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    transformed_data = test_transforms(image=img)
    transformed_img = transformed_data['image']
    # Move channels to the front and add a batch axis before the forward pass.
    model_input = torch.from_numpy(transformed_img).permute(2, 0, 1).unsqueeze(0)
    pred = model(model_input.to(device))
    # Threshold the sigmoid outputs at 0.5 to obtain a boolean mask.
    mask = torch.sigmoid(pred).squeeze().cpu() > 0.5
    # Scale the mask back to the size of the source image.
    resized_mask = cv2.resize(
        mask.numpy().transpose((1, 2, 0)).astype(int),
        (img.shape[1], img.shape[0]),
        interpolation=cv2.INTER_NEAREST
    )

In [None]:
# Show the image next to the predicted mask, summed over axis 2.
plot_images_side_by_side(img, resized_mask.sum(2, keepdims=True))

In [None]:
# Preview predictions for at most the first 37 test images (in sorted order). Slicing
# avoids an IndexError when fewer images are available, and the list is sorted once.
for test_image_path in sorted(all_test_images)[:37]:
    with torch.no_grad():
        img = cv2.imread(test_image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        transformed_data = test_transforms(image=img)
        transformed_img = transformed_data['image']
        # Move channels to the front and add a batch axis before the forward pass.
        model_input = torch.from_numpy(transformed_img).permute(2, 0, 1).unsqueeze(0)
        pred = model(model_input.to(device))
        # Threshold the sigmoid outputs at 0.5 to obtain a boolean mask.
        mask = torch.sigmoid(pred).squeeze().cpu() > 0.5
        # Scale the mask back to the size of the source image.
        resized_mask = cv2.resize(
            mask.numpy().transpose((1, 2, 0)).astype(int),
            (img.shape[1], img.shape[0]),
            interpolation=cv2.INTER_NEAREST
        )
    plot_images_side_by_side(img, resized_mask.sum(2, keepdims=True))

In [None]:
# Run inference on the same device as the rest of the notebook instead of relying on the
# function's own default, then write the submission file.
submission_df = get_submission_df(all_test_images, test_transforms, model, device=device)
submission_df.to_csv(
    'solution.csv',
    header=['image_id', 'rle', 'batch_number'],
    index=False
)

In [None]:
# Display the submission table.
submission_df