In [None]:
import os
import torch
import numpy as np
import tqdm.auto as tqdm
import imageio.v3 as iio
import matplotlib.pyplot as plt
import glob
import os
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.transforms import v2

In [None]:
# Two-level U-Net used for per-pixel classification of video frames.
class UNet(nn.Module):
    """Small two-level U-Net producing per-pixel scores.

    Encoder: two stages of 3x3 convolutions, each followed by a 2x2 max-pool,
    feeding a 256-channel bottleneck. Decoder: two transposed-conv upsampling
    stages, each concatenated with the matching encoder features (skip
    connections) before further 3x3 convolutions. A 1x1 convolution maps to
    ``out_channels`` scores per pixel.

    The submodule attribute names (``conv1`` .. ``conv11``, ``bn*``, ``up*``)
    define the ``state_dict`` layout of saved checkpoints and must not change.

    Args:
        in_channels: channels of the input tensor.
        out_channels: per-pixel output channels (e.g. number of classes).

    Note:
        Input height and width must be divisible by 4; otherwise the upsampled
        feature maps do not match the skip connections in ``torch.cat``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # ---- encoder, level 1 (full resolution) ----
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # Shared 2x2 down-sampler applied after each encoder level.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- encoder, level 2 (1/2 resolution) ----
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # ---- bottleneck (1/4 resolution) ----
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # ---- decoder, level 2: upsample to 1/2, input = 128 up + 128 skip ----
        self.up6 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv7 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        self.conv8 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # ---- decoder, level 1: upsample to full res, input = 64 up + 64 skip ----
        self.up8 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(64)
        self.conv10 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # 1x1 projection to the output channels.
        self.conv11 = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Map a (batch, in_channels, H, W) tensor to (batch, out_channels, H, W).

        The final ReLU means every returned score is >= 0.
        """
        # Encoder level 1; its output is kept as the full-resolution skip.
        h = F.relu(self.bn1(self.conv1(x)))
        skip_full = F.relu(self.conv2(h))

        # Encoder level 2; its output is kept as the half-resolution skip.
        h = F.relu(self.bn3(self.conv3(self.pool(skip_full))))
        skip_half = F.relu(self.conv4(h))

        # Bottleneck at quarter resolution.
        h = F.relu(self.bn5(self.conv5(self.pool(skip_half))))
        h = F.relu(self.conv6(h))

        # Decoder level 2: upsample, then fuse with the half-resolution skip.
        h = torch.cat([self.up6(h), skip_half], dim=1)
        h = F.relu(self.conv8(F.relu(self.bn7(self.conv7(h)))))

        # Decoder level 1: upsample, then fuse with the full-resolution skip.
        h = torch.cat([self.up8(h), skip_full], dim=1)
        h = F.relu(self.conv10(F.relu(self.bn9(self.conv9(h)))))

        # NOTE(review): ReLU on the output clamps negative scores to 0; when all
        # scores of a pixel are <= 0, argmax ties resolve to the first index (0).
        return F.relu(self.conv11(h))

In [None]:
# Rebuild the architecture (3 input channels, 41 output channels) and restore
# the trained weights.
# - map_location="cpu": a checkpoint saved from GPU tensors still loads on a
#   CPU-only machine (the model here is never moved to a GPU).
# - weights_only=True: restrict unpickling to tensors/containers, which is all a
#   state_dict needs, instead of allowing arbitrary pickled objects.
model = UNet(3, 41)
state_dict = torch.load(
    "models/submission_model.pth", map_location="cpu", weights_only=True
)
model.load_state_dict(state_dict)  # load the best model

In [None]:
# Path to the FutureGAN output folder (one sub-folder per item, presumably per video).
P = "path_to_futuregan_outputs"

dir_list = os.listdir(P)
# Keep only 9-character entry names, the expected sub-folder naming (TODO confirm
# against the FutureGAN output layout). '.DS_Store' is itself 9 characters long,
# which is why it has to be excluded explicitly.
lst2 = [x for x in dir_list if len(x) == 9 and x != '.DS_Store']
# Deterministic (lexicographic) order: row i of the mask tensor written later
# corresponds to dirs[i]. Note: list.sort() sorts in place and returns None,
# so `dirs = lst2.sort()` would leave `dirs` as None; sorted() returns the list.
dirs = sorted(lst2)

In [None]:
# Fail fast: an empty (or None) folder list would otherwise run inference on
# nothing and save an all-zero mask tensor.
assert dirs, f"No prediction folders matched the 9-character filter under {P!r}"
len(dirs)

In [None]:
from torch.utils.data import Dataset

class PredDataset(Dataset):
    """Map-style dataset yielding one normalized RGB frame per prediction folder.

    Item ``i`` is read from ``<root>/<dir_list[i]>/<frame_name>`` and returned as
    a float32 tensor of shape (3, H, W) with values ``(pixel / 255 - 0.5) / 2``,
    i.e. in ``[-0.25, 0.25]``. This normalization presumably matches the one used
    to train the model (TODO confirm against the training code).

    Args:
        dir_list: sequence of folder names, one per item. Required in practice;
            the ``None`` default is kept only for backward compatibility.
        root: parent directory of the folders. ``None`` means "use the global
            ``P``", looked up each time an item is read.
        frame_name: file to read inside each folder.

    Raises:
        TypeError: if ``dir_list`` is None (e.g. the result of ``list.sort()``).
    """

    def __init__(self, dir_list=None, root=None, frame_name="image_22.png"):
        if dir_list is None:
            raise TypeError("PredDataset needs a sequence of folder names, got None")
        self.data_files = dir_list
        self.root = root
        self.frame_name = frame_name

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, index):
        root = P if self.root is None else self.root
        frame = iio.imread(f"{root}/{self.data_files[index]}/{self.frame_name}")
        if frame.ndim == 3:
            # Drop an alpha channel if present; the model takes 3 input channels.
            frame = frame[..., :3]
        # Writable, contiguous uint8 copy so torch.from_numpy can share it safely.
        pixels = np.array(frame, dtype=np.uint8)
        img = torch.from_numpy(pixels).permute(2, 0, 1).to(torch.float) / 255
        img = (img - 0.5) / 2
        return img

# Items are read lazily, one image per __getitem__ call. batch_size=1 and
# shuffle=False are the DataLoader defaults, spelled out because the inference
# loop fills the mask tensor in `dirs` order.
hid_dataset = PredDataset(dir_list=dirs)
hid_loader = torch.utils.data.DataLoader(
    hid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)

In [None]:
# One (1, H, W) mask slot per dataset item. This was hard-coded to 5000, which
# breaks (IndexError) or leaves zero-filled rows if the folder count differs.
# 160x240 is assumed to be the frame size; TODO confirm against the PNGs.
hid_masks = torch.zeros([len(hid_dataset), 1, 160, 240])
hid_masks.shape

In [None]:
# Predict a class id per pixel (argmax over output channels) for every frame
# and write the result to disk.
# - torch.no_grad(): no autograd graph is recorded, since we never backprop.
# - Rows are filled by slicing with the actual batch size, so the loop stays
#   aligned with `dirs` order for any DataLoader batch_size (not just 1).
count = 0
model.eval()
with torch.no_grad():
    for batch in tqdm.tqdm(hid_loader):
        masks = model(batch)
        n = batch.shape[0]
        hid_masks[count:count + n] = masks.argmax(dim=1, keepdim=True)
        count += n

# Refuse to save a submission with unfilled (all-zero) rows.
assert count == len(hid_masks), f"filled {count} of {len(hid_masks)} mask rows"
torch.save(hid_masks, 'Leaderboard_Team14.pt')

In [None]:
# Reload the mask tensor saved by the inference cell, for the visual check that follows.
hm = torch.load("Leaderboard_Team14.pt")

In [None]:
# Visual sanity check: one hidden-set frame next to its predicted mask.
idx = len(hid_dataset) - 1  # last item (4999 when there are 5000 folders)
pred_mask = hm[idx].cpu().squeeze(0)

# Undo PredDataset's normalization ((x - 0.5) / 2) to recover 0..255 RGB.
# Round (rather than truncate) and clamp so float error cannot shift or overflow pixels.
example_image = (
    ((hid_dataset[idx].permute(1, 2, 0) * 2) + 0.5) * 255
).round().clamp(0, 255).to(torch.uint8)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
# vmin/vmax are ignored by imshow for RGB data, so they are only passed for the mask.
axes[0].imshow(example_image)
axes[0].set_title(f"Input frame ({hid_dataset.data_files[idx]})")
# Fixed 0..48 color range kept from the original so mask colors stay comparable.
axes[1].imshow(pred_mask, vmin=0, vmax=48)
axes[1].set_title("Predicted mask")
for ax in axes:
    ax.axis("off")

plt.show()