In [None]:
# Install optional dependencies on demand: each package is imported first and
# only installed through pip when that import fails.
from IPython.display import clear_output
try:
    import albumentations
except ImportError:
    # `!` is an IPython shell escape, so this cell only runs under IPython/Jupyter.
    !pip install albumentations

try:
    import Cython
except ImportError:
    !pip install Cython
# Wipe this cell's output (e.g. pip's install log) once both checks are done.
clear_output()

In [None]:
import json
from torch.utils.data import Dataset, DataLoader
from skimage import io
import pandas as pd
import torch
import os
import numpy as np
from pathlib import Path
import cv2, zlib, base64
from PIL import Image
import torch.nn as nn
import torchvision.transforms.functional as TF
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision

In [None]:
# dataset.py

class FootballBannerDataset(Dataset):
    """Football advertising banner images paired with binary segmentation masks.

    Every entry of ``image_dir`` is treated as one sample. Its annotation is
    read from ``<mask_dir>/<image file name>.json`` and must provide ``size``
    (with ``height`` and ``width``) plus a list of ``objects`` whose ``bitmap``
    entries carry a base64 ``data`` payload and an ``origin`` given as
    ``[x, y]``.
    """

    def __init__(self, image_dir: str, mask_dir: str, transform=None):
        """
        Args:
            image_dir (string): Directory with all the images.
            mask_dir (string): Directory with all the annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample. It is called as ``transform(image=..., mask=...)``
                and must return a mapping with "image" and "mask" keys.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def extract_mask(self, labels):
        """Build a full-size binary mask from a parsed annotation.

        Args:
            labels (dict): Parsed annotation JSON (see the class docstring).

        Returns:
            np.ndarray: float32 array of shape (height, width) holding 1.0 on
            annotated pixels and 0.0 elsewhere; all zeros when the annotation
            has no objects.
        """
        mask = np.zeros(
            (labels["size"]["height"], labels["size"]["width"]), dtype=np.float32
        )
        # Merge every annotated object (not only the first one) so images that
        # show several banners get a complete target mask.
        for obj in labels["objects"]:
            bitmap = obj["bitmap"]
            # Normalise to bool so any non-zero encoding (1, 255, ...) maps to 1.0.
            patch = np.asarray(base64_2_mask(bitmap["data"]), dtype=bool)
            left, top = bitmap["origin"][0], bitmap["origin"][1]
            rows = slice(top, top + patch.shape[0])
            cols = slice(left, left + patch.shape[1])
            # Union with what is already there so overlapping objects never
            # erase each other.
            mask[rows, cols] = np.maximum(mask[rows, cols], patch)
        return mask

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            tuple: ``(image, mask)``. Without a transform the image is a uint8
            RGB array of shape (H, W, 3) and the mask is the float32 array
            produced by ``extract_mask``; with a transform, whatever it returns
            under the "image" and "mask" keys.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_name = self.images[idx]
        img_path = os.path.join(self.image_dir, file_name)
        # Annotations are named after the full image file name plus ".json".
        mask_path = os.path.join(self.mask_dir, file_name + ".json")

        # Context manager closes the underlying file as soon as pixels are read.
        with Image.open(img_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))

        with open(mask_path, "r", encoding="utf-8") as annot_reader:
            labels = json.load(annot_reader)
        mask = self.extract_mask(labels)

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [None]:
# model.py

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    The convolutions use stride 1 and padding 1, and carry no bias term.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend(
                [
                    nn.Conv2d(
                        stage_in,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/batch-norm/ReLU stages."""
        features = self.conv(x)
        return features


class UNET(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    Encoder: one ``DoubleConv`` per entry of ``features``, each followed by a
    2x2 max pool. A bottleneck ``DoubleConv`` doubles the last feature count.
    Decoder: for each feature (in reverse) a stride-2 transposed convolution,
    concatenation with the matching encoder output, and a ``DoubleConv``.
    A final 1x1 convolution maps to ``out_channels``; no activation is applied,
    so the output holds raw logits.

    Args:
        in_channels (int): Channels of the input image.
        out_channels (int): Channels of the produced map.
        features (sequence of int): Channel count of each encoder level.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # The default is an immutable tuple so it can never be shared and
        # mutated between instances; any sequence of ints is still accepted.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: `ups` alternates [upsample, DoubleConv, upsample, ...]
        for feature in features[::-1]:
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input is feature * 2 because the skip connection is concatenated.
            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottleneck layer
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Final conv
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the logits map for the batch ``x``."""
        # Down: remember each level's output before pooling for the skip paths.
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Up: deepest skip connection is consumed first.
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # The upsampled map can differ in size from the skip connection
            # (e.g. when pooling an odd-sized map); resize it to match before
            # concatenating.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)


In [None]:
# utils.py

def base64_2_mask(s):
    """Decode a base64-encoded, zlib-compressed image string into a boolean mask.

    Args:
        s: Base64 text (or bytes) wrapping a zlib-compressed encoded image.

    Returns:
        np.ndarray: 2-D boolean array, True wherever the fourth channel of the
        decoded image is non-zero.
    """
    compressed = base64.b64decode(s)
    encoded_image = np.frombuffer(zlib.decompress(compressed), np.uint8)
    decoded = cv2.imdecode(encoded_image, cv2.IMREAD_UNCHANGED)
    # The mask lives in channel index 3 (the alpha channel of a 4-channel image).
    fourth_channel = decoded[:, :, 3]
    return fourth_channel.astype(bool)


def mask_2_base64(mask):
    """Encode a binary mask as a base64 string of a zlib-compressed PNG.

    Args:
        mask: Array-like of 0/1 values; converted to a uint8 array.

    Returns:
        str: UTF-8 text holding base64(zlib(PNG bytes)). The PNG is saved with
        a two-entry palette (index 0 black, index 1 white) and
        ``transparency=0``.
    """
    # Imported locally: the module-level name `io` in this notebook is
    # `skimage.io`, which does not provide BytesIO.
    from io import BytesIO

    img_pil = Image.fromarray(np.array(mask, dtype=np.uint8))
    img_pil.putpalette([0, 0, 0, 255, 255, 255])
    buffer = BytesIO()
    img_pil.save(buffer, format="PNG", transparency=0, optimize=0)
    png_bytes = buffer.getvalue()
    return base64.b64encode(zlib.compress(png_bytes)).decode("utf-8")


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then write ``state`` to ``filename`` via ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then copy ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


def get_loaders(
    train_dir,
    train_maskdir,
    batch_size,
    train_transform,
    num_workers=4,
    pin_memory=True,
    train_fraction=7000 / 8851,
):
    """Build the training and validation ``DataLoader`` objects.

    A single ``FootballBannerDataset`` is created and randomly split in two.
    NOTE: both splits share ``train_transform``, so validation samples go
    through the same (possibly random) augmentations as training samples.

    Args:
        train_dir: Directory with the images.
        train_maskdir: Directory with the JSON annotations.
        batch_size: Batch size used by both loaders.
        train_transform: Transform handed to the dataset.
        num_workers: Worker processes per loader.
        pin_memory: Forwarded to both loaders.
        train_fraction: Share of the samples assigned to the training split.
            The default reproduces the original 7000 / 1851 split on an
            8851-sample dataset, while any other dataset size is split
            proportionally instead of raising.

    Returns:
        tuple: ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError("train_fraction must be between 0 and 1")

    footballBannerDataset = FootballBannerDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )

    # Derive the split sizes from the actual dataset length: random_split
    # raises unless the lengths add up exactly to len(dataset).
    total = len(footballBannerDataset)
    train_len = min(max(int(round(total * train_fraction)), 0), total)
    val_len = total - train_len
    train_ds, val_ds = torch.utils.data.random_split(
        footballBannerDataset, [train_len, val_len]
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader


def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.

    Logits are passed through a sigmoid and thresholded at 0.5 before being
    compared with the targets. The model is put in eval mode for the
    evaluation and switched to train mode before returning.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches that supports ``len``.
        model: Network producing logits.
        device: Device the batches are moved to.
    """
    correct_total = 0
    pixel_total = 0
    dice_total = 0
    model.eval()

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            # Add a channel axis so targets line up with the model output.
            targets = targets.to(device).unsqueeze(1)
            binary = (torch.sigmoid(model(inputs)) > 0.5).float()
            correct_total += (binary == targets).sum()
            pixel_total += torch.numel(binary)
            overlap = (binary * targets).sum()
            # 1e-8 keeps the ratio finite when both masks are empty.
            dice_total += (2 * overlap) / ((binary + targets).sum() + 1e-8)

    accuracy = correct_total / pixel_total * 100
    print(f"Got {correct_total}/{pixel_total} with acc {accuracy:.2f}")
    print(f"Dice score: {dice_total/len(loader)}")
    model.train()


def save_predictions_as_imgs(loader, model, folder="saved_images", device="cuda"):
    """Write thresholded predictions and their ground-truth masks as PNG files.

    For batch number ``idx`` the prediction grid goes to
    ``<folder>/pred_<idx>.png`` and the target grid to ``<folder>/<idx>.png``.
    The model is put in eval mode while predicting and switched to train mode
    before returning.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches.
        model: Network producing logits.
        folder: Output directory; created (including parents) when missing.
        device: Device the inputs are moved to.
    """
    # Pure-Python directory creation, done once: portable, works for nested
    # paths and does not depend on an IPython shell escape.
    os.makedirs(folder, exist_ok=True)

    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
        # Targets get a channel axis so they are saved as single-channel images.
        torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"{idx}.png"))

    model.train()


In [None]:
# train.py


def train_fn(loader, model, optimizer, loss_fn, scaler, scheduler=None):
    """Run a single training epoch over ``loader`` with mixed precision.

    Batches are moved to the module-level ``DEVICE``. ``scheduler`` is accepted
    but not used by this function.
    """
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=DEVICE)
        # Float targets with a channel axis, matching the model output layout.
        labels = labels.float().unsqueeze(1).to(device=DEVICE)

        # Forward pass under autocast.
        with torch.cuda.amp.autocast():
            logits = model(inputs)
            batch_loss = loss_fn(logits, labels)

        # Backward pass: the scaler scales the loss and drives the optimizer step.
        optimizer.zero_grad()
        scaled_loss = scaler.scale(batch_loss)
        scaled_loss.backward()
        scaler.step(optimizer)
        scaler.update()

        # Show the latest loss in the progress bar.
        progress.set_postfix(loss=batch_loss.item())


def main():
    """Train the UNET on the football banner dataset.

    Reads its configuration from module-level constants (IMAGE_HEIGHT,
    IMAGE_WIDTH, DEVICE, LEARNING_RATE, TRAIN_IMG_DIR, TRAIN_MASK_DIR,
    BATCH_SIZE, NUM_WORKERS, PIN_MEMORY, LOAD_MODEL, NUM_EPOCHS). After every
    epoch it saves a checkpoint, prints validation metrics and writes
    prediction images to ``saved_images/``.
    """
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            # mean 0 / std 1 with max_pixel_value 255 only rescales to [0, 1].
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()  # binary cross entropy on raw logits
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        # map_location remaps stored tensors onto the active device, so a
        # checkpoint saved on a GPU also loads on a CPU-only machine.
        load_checkpoint(
            torch.load(
                "../input/unet-football-banner-image-segmentation/my_checkpoint.pth.tar",
                map_location=DEVICE,
            ),
            model,
        )

    # Baseline metrics before any (further) training.
    check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE
        )

In [None]:
# Sponsor/brand name -> integer class id, numbered in listing order.
meta_class_data = {
    brand: class_id
    for class_id, brand in enumerate(
        [
            "mastercard",
            "nissan",
            "playstation",
            "unicredit",
            "pepsi",
            "adidas",
            "gazprom",
            "heineken",
        ]
    )
}

# --- Training configuration -------------------------------------------------
LEARNING_RATE = 1e-4  # learning rate handed to the optimizer
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # GPU when available
BATCH_SIZE = 32  # samples per batch for both loaders
NUM_EPOCHS = 100  # number of passes over the training split
NUM_WORKERS = 2  # DataLoader worker processes
IMAGE_HEIGHT = 160  # height images are resized to before entering the network
IMAGE_WIDTH = 240  # width images are resized to before entering the network
PIN_MEMORY = True  # forwarded to the DataLoaders
LOAD_MODEL = True  # when True, main() restores weights from a checkpoint first

# --- Data locations ---------------------------------------------------------
TRAIN_IMG_DIR = "../input/football-advertising-banners-detection/football/images"
TRAIN_MASK_DIR = "../input/football-advertising-banners-detection/football/annotations"

In [None]:
# Train Model: main() (defined above) trains for NUM_EPOCHS epochs and, after
# every epoch, saves a checkpoint, prints validation metrics and writes
# prediction images to disk.
main()

In [None]:
# Rebuild the network and restore the weights from the local checkpoint file.
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
# map_location remaps stored tensors onto the active device, so a checkpoint
# saved on a GPU also loads on a CPU-only machine.
load_checkpoint(torch.load("./my_checkpoint.pth.tar", map_location=DEVICE), model)
# Switch to inference mode so BatchNorm uses its running statistics instead of
# the statistics of the single image fed in below. The trailing `;` keeps the
# notebook from echoing the model repr.
model.eval();

# Test on Example Case

In [None]:
import matplotlib.pyplot as plt
from torchvision import transforms

def pred(x, model):
    """Return the binary (0.0 / 1.0) prediction of ``model`` for the batch ``x``.

    The batch is moved to the module-level ``DEVICE``; logits go through a
    sigmoid and are thresholded at 0.5. Inference runs in eval mode so
    BatchNorm layers use their running statistics rather than the statistics
    of ``x`` (and do not update them); the model's previous train/eval mode is
    restored afterwards.

    Args:
        x: Input batch tensor accepted by ``model``.
        model: Network producing logits.

    Returns:
        torch.Tensor: Float tensor of zeros and ones with the model's output shape.
    """
    x = x.to(device=DEVICE)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return preds

ex_img_path1 = "../input/football-advertising-banners-detection/football/images/00bhhxx56ft6rtq.png"
ex_img_path2 = "../input/football-advertising-banners-detection/football/images/00fkgxlxff8hd3z.png"
fig, axes = plt.subplots(2, 1, figsize=(8, 20))
ax1, ax2 = axes


fig.suptitle("Example Model Predictions", fontsize=24)

# One panel per example image: the photo with the predicted mask drawn on top.
for ax, img_path in zip(axes, (ex_img_path1, ex_img_path2)):
    # Force 3-channel RGB, as the dataset does when loading training images,
    # so the input always has the 3 channels the model was built with.
    img = Image.open(img_path).convert("RGB")
    ax.set_title("Example Image from GAL - BRU", fontsize=20)
    ax.imshow(img)
    # ToTensor gives (C, H, W); unsqueeze_ adds the batch axis the model expects.
    mask_tensor = pred(transforms.ToTensor()(img).unsqueeze_(0), model)
    # Drop the batch axis again before converting the mask to a PIL image.
    mask = transforms.ToPILImage()(mask_tensor.squeeze_(0))
    ax.imshow(mask, cmap='jet', alpha=0.5)
plt.tight_layout()