In [None]:
from pathlib import Path

import torch
import torchvision
from torchvision.transforms.transforms import Lambda
from torchvision.utils import make_grid

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tfrms

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# WSOL

In [None]:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d (eps=1e-3) -> in-place ReLU.

    Any extra keyword arguments (kernel_size, stride, padding, ...) are
    forwarded unchanged to ``nn.Conv2d``. The output is always >= 0.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-3)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)

In [None]:
class Stem(nn.Module):
    """Input stem: the first child module of a pretrained torchvision ResNet-50.

    Only that first child (torchvision registers ``conv1`` first) is kept,
    wrapped in ``nn.Sequential`` as ``self.model``; the rest of the backbone
    is discarded.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): newer torchvision deprecates `pretrained=True` in favour
        # of `weights=...`; confirm against the installed version.
        backbone = torchvision.models.resnet50(pretrained=True)
        first_layer = next(iter(backbone.children()))
        self.model = nn.Sequential(first_layer)

    def forward(self, x):
        return self.model(x)

In [None]:
class A(nn.Module):
    """Convolutional trunk producing multi-stage features and a class map.

    Expects a 64-channel input. ``forward`` returns the 5-tuple
    ``(out_a1, out_a2, out_a3, out_a4, pooled)``:

        out_a1  128 channels          (stage a1)
        out_a2  1024 channels         (stage a2)
        out_a3  2048 channels         (stage a3)
        out_a4  num_classes channels  (stage a4)
        pooled  out_a4 average-pooled to a fixed 16x16

    For a 120x120 input the spatial sizes of out_a1..out_a4 are
    58, 20, 18 and 16.
    """

    def __init__(self, num_classes):
        super().__init__()

        def stack(*specs):
            # Each spec is (in_channels, out_channels, Conv2d keyword args).
            return nn.Sequential(*[ConvBlock(cin, cout, **kw) for cin, cout, kw in specs])

        self.a1 = stack(
            (64, 64, dict(kernel_size=3, stride=2, padding=1)),
            (64, 64, dict(kernel_size=3)),
            (64, 128, dict(kernel_size=3, padding=1)),
            (128, 128, dict(kernel_size=1)),
        )
        self.a2 = stack(
            (128, 128, dict(kernel_size=3, stride=2)),
            (128, 256, dict(kernel_size=3)),
            (256, 256, dict(kernel_size=3)),
            (256, 1024, dict(kernel_size=3)),
            (1024, 1024, dict(kernel_size=3)),
        )
        self.a3 = stack(
            (1024, 1024, dict(kernel_size=3, padding=1)),
            (1024, 2048, dict(kernel_size=3)),
        )
        self.a4 = ConvBlock(2048, num_classes, kernel_size=3)

        # Despite the name this is not global pooling: the output is fixed at 16x16.
        self.gap = nn.AdaptiveAvgPool2d(16)

    def forward(self, x):
        stage_outputs = []
        h = x
        for stage in (self.a1, self.a2, self.a3, self.a4):
            h = stage(h)
            stage_outputs.append(h)
        return (*stage_outputs, self.gap(h))

In [None]:
class B(nn.Module):
    """Two-branch saliency head.

    ``x`` (128 channels) and ``y`` (1024 channels) are each projected to 64
    channels by their own ConvBlock. Both branches then go through the SAME
    ``conv2``/``conv3`` weights and a sigmoid. Returns
    ``(sigmoid_map_x, sigmoid_map_y)``, each with ``num_classes`` channels.

    For a 58x58 ``x`` and a 20x20 ``y``, both outputs are 16x16.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.conv1x = ConvBlock(128, 64, kernel_size=7, stride=3)
        self.conv1y = ConvBlock(1024, 64, kernel_size=3)

        # Shared by both branches.
        self.conv2 = ConvBlock(64, 64, kernel_size=3)
        self.conv3 = ConvBlock(64, num_classes, kernel_size=3, padding=1)

        self.sig = nn.Sigmoid()

    def _head(self, t):
        """Shared conv2 -> conv3 (conv + BN only), giving pre-sigmoid scores."""
        t = self.conv2(t)
        # conv3's trailing ReLU is skipped on purpose: sigmoid(relu(z)) lies in
        # [0.5, 1) everywhere, so the maps could never express "background".
        # Reusing conv3.conv / conv3.bn keeps the parameter names unchanged.
        return self.conv3.bn(self.conv3.conv(t))

    def forward(self, x, y):
        x = self._head(self.conv1x(x))
        y = self._head(self.conv1y(y))

        return (self.sig(x), self.sig(y))

In [None]:
class C(nn.Module):
    """Saliency head: 2048 -> 1024 -> num_classes channels, then a sigmoid.

    For an 18x18 input the output is 16x16.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = ConvBlock(2048, 1024, kernel_size=3)
        self.conv2 = ConvBlock(1024, num_classes, kernel_size=3, padding=1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        # conv2's trailing ReLU is skipped on purpose: sigmoid(relu(z)) lies in
        # [0.5, 1) everywhere. Reusing conv2.conv / conv2.bn keeps the
        # parameter names unchanged.
        x = self.conv2.bn(self.conv2.conv(x))
        return self.sig(x)

In [None]:
class WSOL(nn.Module):
    """Full model: Stem -> trunk A, saliency heads B and C, and a linear classifier.

    ``forward(x)`` returns a dict. Every map is resized to 240x240 with
    ``nn.Upsample`` (default nearest-neighbour mode):

        "a"       A's final class map (out_a4)
        "b1"/"b2" B's sigmoid maps computed from A's stages 1 and 2
        "c"       C's sigmoid map computed from A's stage 3
        "logits"  (N, num_classes) from a linear layer over A's pooled map

    Args:
        num_classes: number of classes, which is also the number of map channels.
        threshold: cutoff used by ``apply_threshold`` to binarize maps.
    """

    def __init__(self, num_classes, threshold=0.3):
        super().__init__()

        self.thres = threshold

        self.stem = Stem()
        self.a = A(num_classes=num_classes)
        self.b = B(num_classes=num_classes)
        self.c = C(num_classes=num_classes)

        self.upsample = nn.Upsample((240, 240))

        # A pools its class map to (num_classes, 16, 16), so the classifier's
        # input size has to scale with num_classes.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=num_classes * 16 * 16, out_features=num_classes),
        )

    def forward(self, x):
        feature_map = self.stem(x)

        # A: trunk; returns every stage plus the pooled class map.
        out_a1, out_a2, out_a3, out_a4, gap = self.a(feature_map)

        # B: saliency maps from stages 1 and 2 (shared head weights).
        out_b1, out_b2 = self.b(out_a1, out_a2)

        # C: saliency map from stage 3.
        out_c = self.c(out_a3)

        logits = self.fc(gap)

        return {
            "a": self.upsample(out_a4),
            "b1": self.upsample(out_b1),
            "b2": self.upsample(out_b2),
            "c": self.upsample(out_c),
            "logits": logits,
        }

    def train_step(self, xb):
        """Run one forward pass and compute the combined training loss.

        Args:
            xb: ``(images, labels)`` pair; ``labels`` are the class-index
                targets passed to ``F.cross_entropy``.

        Returns:
            ``{"loss": scalar tensor, "map": [a, b2]}``, where ``a`` and ``b2``
            are the binarized "a" and "b2" maps.
        """
        images, labels = xb
        out = self(images)

        # apply_threshold returns new tensors, so out["a"] and out["b2"] keep
        # their raw (un-thresholded) values below.
        a = self.apply_threshold(out["a"])
        b1 = out["b1"]
        b2 = self.apply_threshold(out["b2"])
        c = out["c"]

        # NOTE(review): b1, b2 and c have already been through a sigmoid (heads
        # B and C). Here the binarized maps sit in the "logits" slot of a
        # *_with_logits loss, with the raw maps as targets. Confirm the intended
        # argument order, and whether F.binary_cross_entropy on probabilities
        # was meant.
        bce = F.binary_cross_entropy_with_logits
        loss_ab = self.loss_saliency(bce, a, out["b2"])
        loss_bb = self.loss_saliency(bce, b2, b1)
        loss_bc = self.loss_saliency(bce, b2, c)

        loss_logits = F.cross_entropy(out["logits"], labels)

        loss = loss_ab + loss_bb + loss_bc + loss_logits

        return {"loss": loss, "map": [a, b2]}

    def apply_threshold(self, x):
        """Binarize ``x``: 1.0 where ``x > self.thres``, else 0.0.

        Returns a new tensor with ``x``'s dtype and device; ``x`` is not
        modified. Values exactly equal to the threshold map to 0.0. The result
        is produced by a comparison and so carries no gradient.
        """
        return (x > self.thres).to(x.dtype)

    def loss_saliency(self, loss_func, logits, labels):
        """Apply ``loss_func(logits, labels)`` to maps flattened per sample.

        Both tensors are reshaped to ``(batch, -1)`` before the call.
        ``reshape`` also accepts non-contiguous inputs.
        """
        return loss_func(
            logits.reshape(logits.shape[0], -1), labels.reshape(labels.shape[0], -1)
        )

# Execution

In [None]:
# 19-class model; the threshold is passed explicitly at WSOL's default of 0.3.
model = WSOL(num_classes=19, threshold=0.3)

In [None]:
def denormalize(images, means, stds, channels):
    """Undo per-channel normalization: ``images * std + mean``.

    Args:
        images: tensor with ``channels`` entries on dim 1, broadcastable
            against ``(1, channels, 1, 1)`` (e.g. ``(N, channels, H, W)``).
        means: sequence of ``channels`` per-channel means.
        stds: sequence of ``channels`` per-channel standard deviations.
        channels: number of channels.

    Returns:
        A new tensor; ``images`` is not modified.
    """
    # Build the statistics on the images' device so CUDA inputs work too.
    means = torch.tensor(means, device=images.device).reshape(1, channels, 1, 1)
    stds = torch.tensor(stds, device=images.device).reshape(1, channels, 1, 1)
    return images * stds + means

In [None]:
def show_batch(images, pause=True):
    """Display each image of a batch in its own figure.

    Args:
        images: iterable of image tensors. 3-D tensors are assumed to be
            channel-first (C, H, W) and are moved to channel-last for
            ``imshow``; single-channel images are squeezed to 2-D. Values are
            clamped to [0, 1].
        pause: when True (the default), block on ``input()`` after each figure
            so the images can be stepped through one at a time.
    """
    for img in images:
        arr = img.detach().cpu().clamp(0, 1)
        if arr.dim() == 3:
            arr = arr.permute(1, 2, 0)  # CHW -> HWC, the layout imshow expects
            if arr.shape[-1] == 1:
                arr = arr[..., 0]

        # A fresh figure per image: the previous one is closed below.
        fig, ax = plt.subplots(figsize=(24, 24))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(arr.numpy())
        plt.show()
        if pause:
            input()
        plt.close(fig)

In [None]:
def _as_float32(t):
    """Cast a tensor to float32."""
    return t.to(torch.float32)


def _gray_to_rgb(t):
    """Repeat a single-channel tensor to 3 channels; leave others unchanged."""
    return t.repeat(3, 1, 1) if t.shape[0] == 1 else t


# Preprocessing: resize to 240x240, convert to a float32 tensor, expand
# grayscale to 3 channels, then normalize every channel with
# mean 0.445 / std 0.225.
tfrm = tfrms.Compose(
    [
        tfrms.Resize((240, 240)),
        tfrms.ToTensor(),
        Lambda(_as_float32),
        Lambda(_gray_to_rgb),
        tfrms.Normalize(0.445, 0.225),
    ]
)

In [None]:
# Load the sample image as 3-channel RGB. PNGs may be RGBA, palette or
# grayscale, while the pretrained ResNet-50 stem is built for 3-channel input.
# The `with` block closes the file once the pixels have been converted.
with Image.open("./test_images/img1.png") as src:
    img = src.convert("RGB")
sample_inp = tfrm(img).unsqueeze(0)  # add a batch dimension: (1, C, H, W)

In [None]:
# Run one training step on the sample image and save masked copies of the
# image: one per class map, plus one for the combined map.

# Placeholder class-index target for F.cross_entropy (class 0); replace with
# the real label.
sample_label = torch.zeros(1, dtype=torch.long)
sample_out = model.train_step((sample_inp, sample_label))

print("loss", sample_out["loss"].item())
print("a", sample_out["map"][0].shape)
print("b", sample_out["map"][1].shape)

a_maps = sample_out["map"][0]
n_maps = a_maps.shape[1]
# NOTE(review): the "map" entries come from WSOL.apply_threshold, so they are
# already 0/1 masks. Undoing the image normalization rescales them to
# 0.445 / 0.67. Confirm this is intended.
maps = denormalize(a_maps, (0.445,) * n_maps, (0.225,) * n_maps, n_maps)

# masking the input image
masks = maps[0].detach().cpu().numpy()  # (n_maps, H, W) for the first image
img_arr = np.array(img.resize((240, 240)))


def _masked_image(image, weights):
    """Scale ``image`` by the 2-D ``weights`` map and return an RGB PIL image.

    Weights are clipped to [0, 1] and broadcast over the channel axis when
    the image has one. The product is converted to uint8 so PIL can encode it.
    """
    weights = np.clip(weights, 0.0, 1.0)
    if image.ndim == 3:
        weights = weights[..., None]
    scaled = np.clip(image.astype(np.float32) * weights, 0, 255).astype(np.uint8)
    return Image.fromarray(scaled).convert("RGB")


out_dir = Path("./outputs")
out_dir.mkdir(parents=True, exist_ok=True)

# Sum of all class masks. sum() builds a new array, so masks[0] is untouched.
# The sum is rescaled to [0, 1] by its peak so the combined image stays in
# the displayable range.
combined = masks.sum(axis=0)
peak = combined.max()
if peak > 0:
    combined = combined / peak
_masked_image(img_arr, combined).save(out_dir / "masked_combined.png")

for idx, mask in enumerate(masks):
    _masked_image(img_arr, mask).save(out_dir / f"masked_{idx}.png")