In [8]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file available under the read-only Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/wikiart-damage-masks/square.h5
/kaggle/input/wikiart-damage-masks/irregular.h5
/kaggle/input/wikiart-inpainting-full-set-weights/best_weights_epoch_28_val_loss_3.7235-2348.pth
/kaggle/input/wikiart-cluster-annotations-split/combined_clusters.csv
/kaggle/input/wikiart-clean-without-split/annotations.csv
/kaggle/input/wikiart-clean-without-split/dataset.h5


In [9]:
# !pip -q install lightning
# !pip -q install comet-ml

In [10]:
import h5py
import torch
import numpy as np
from PIL import Image
import math
import os

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import timm

import torchvision.utils as vutils
from torchvision.transforms import v2

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

In [11]:
# GROUP_ID selects which group-specific model is trained.
# A previous clustering step produced combined_clusters.csv; its `cluster_label`
# column assigns each image to a group. Groups were assigned to the train and
# validation sets only.
# Set GROUP_ID to None to train on the whole dataset instead of a single group.
GROUP_ID = 5

In [12]:
from lightning.pytorch.loggers import CometLogger
from kaggle_secrets import UserSecretsClient

# NOTE(review): this logger is imported from the unified `lightning` package, while the
# Trainer, callbacks and LightningModule come from `pytorch_lightning`. These are separate
# distributions; consider importing everything from one of them.

# The Comet API key is read from Kaggle secrets instead of being hard-coded.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("COMET_API_KEY")

# One Comet project per group so runs for different clusters stay separate.
comet_logger = CometLogger(
    workspace='wikiart-inpainting',
    project_name=f'UNet_Inpainting-{GROUP_ID}',
    api_key=secret_value_0,
)

INFO: CometLogger will be initialized in online mode


In [13]:
class WikiArtDataset(Dataset):
    """WikiArt images paired with randomly drawn damage masks, for inpainting.

    Image pixels are read from the HDF5 dataset ``'image'`` at each row's ``h5_index``.
    A mask is drawn uniformly at random from the HDF5 dataset ``'mask'`` on every access,
    so repeated accesses to the same image generally get different masks.

    Args:
        h5_path: Path to the HDF5 file containing the ``'image'`` dataset.
        mask_h5_path: Path to the HDF5 file containing the ``'mask'`` dataset.
        csv_path: Annotation CSV with ``set_type``, ``h5_index`` and ``label_col`` columns
            (plus ``group_col`` when ``group_id`` is given).
        set_type: Keep only rows whose ``set_type`` equals this value.
        group_id: If not None, keep only rows whose ``group_col`` equals this value.
        label_col: Column returned as the third element of each sample.
        transform: Optional callable applied to the float32 image tensor.
        group_col: Column used for ``group_id`` filtering. Defaults to
            ``'cluster_label'``, the column that was previously hard-coded.

    Each item is ``(image, mask, label)``. ``image`` is the (optionally transformed)
    float32 tensor and ``mask`` is a float32 tensor.

    Raises:
        ValueError: if the mask file contains no masks.
    """

    def __init__(self, h5_path: str, mask_h5_path: str, csv_path: str, set_type: str, group_id=None, label_col='cluster_label', transform=None, group_col='cluster_label'):
        self.h5_path = h5_path
        self.mask_h5_path = mask_h5_path

        self.df = self._select_rows(pd.read_csv(csv_path), set_type, group_col, group_id)

        self.label_col = label_col
        self.transform = transform
        self.length = len(self.df)

        # Only the mask count is needed here; the file is closed right away so no open
        # handle is kept on the instance at construction time.
        with h5py.File(self.mask_h5_path, 'r') as mask_h5f:
            self.num_masks = mask_h5f['mask'].shape[0]
        if self.num_masks == 0:
            raise ValueError(f'No masks found in {self.mask_h5_path!r}')

        # Handles are opened lazily by `_open_hdf5` in whichever process reads items.
        self._hf = None
        self._mask_hf = None

    @staticmethod
    def _select_rows(df, set_type, group_col, group_id):
        """Return the rows of ``df`` in split ``set_type`` (and group ``group_id`` if given)."""
        selected = df[df['set_type'] == set_type]
        if group_id is not None:
            selected = selected[selected[group_col] == group_id]
        return selected

    def __len__(self):
        return self.length

    def __getstate__(self):
        # h5py File objects cannot be pickled. Drop any open handles so the dataset can
        # still be sent to DataLoader worker processes after it was indexed in this process.
        state = self.__dict__.copy()
        state['_hf'] = None
        state['_mask_hf'] = None
        return state

    def close(self):
        """Close any HDF5 handles opened by this process."""
        for attr in ('_hf', '_mask_hf'):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
            setattr(self, attr, None)

    def _open_hdf5(self):
        # Lazy opening means each process that reads items (e.g. every DataLoader
        # worker) gets its own handles.
        if getattr(self, '_hf', None) is None:
            self._hf = h5py.File(self.h5_path, 'r')
        if getattr(self, '_mask_hf', None) is None:
            self._mask_hf = h5py.File(self.mask_h5_path, 'r')

    def _get_random_mask(self):
        mask_idx = np.random.randint(0, self.num_masks)
        return self._mask_hf['mask'][mask_idx]

    def __getitem__(self, idx):
        self._open_hdf5()

        row = self.df.iloc[idx]
        image_idx = row['h5_index']
        label = row[self.label_col]

        image = torch.from_numpy(self._hf['image'][image_idx]).float()
        mask = torch.from_numpy(self._get_random_mask()).float()

        # NOTE(review): the image is already float32 at this point. A
        # `v2.ToDtype(torch.float32, scale=True)` transform therefore does not rescale it
        # (ToDtype only rescales integer inputs), and pixels keep their stored range.
        if self.transform:
            image = self.transform(image)

        return image, mask, label

In [14]:
class UNetInpainting(nn.Module):
    """Small U-Net that fills the masked region of an image.

    The masked image ``x * (1 - mask)`` is concatenated with ``mask`` on the channel axis
    and passed through a 4-level encoder/decoder with additive skip connections. The
    network prediction is blended as ``output * mask + x * (1 - mask)``: where the mask
    is 1 the prediction is used, and where it is 0 the input pixels are kept.

    Args:
        in_channels: Channel count of the concatenated (image, mask) input.
        out_channels: Channels produced by the final 1x1 convolution.
        use_dropout: Append ``Dropout(0.5)`` after every encoder/bottleneck block.
        base_channels: Width of the first encoder block; each deeper level doubles it.
            The default of 16 gives the original fixed widths 16/32/64/128/256.
    """

    # encoder2, encoder3, encoder4 and the bottleneck each halve H and W with a 2x max-pool.
    _DOWNSAMPLE_FACTOR = 16

    def __init__(self, in_channels=4, out_channels=3, use_dropout=False, base_channels=16):
        super().__init__()
        # Must be set before the blocks are built, because conv_block reads it.
        self.use_dropout = use_dropout

        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        self.encoder1 = self.conv_block(in_channels, c1)
        self.encoder2 = self.conv_block(c1, c2, pool=True)
        self.encoder3 = self.conv_block(c2, c3, pool=True)
        self.encoder4 = self.conv_block(c3, c4, pool=True)

        self.bottleneck = self.conv_block(c4, c5, pool=True)

        self.decoder4 = self.upconv_block(c5, c4)
        self.decoder3 = self.upconv_block(c4, c3)
        self.decoder2 = self.upconv_block(c3, c2)
        self.decoder1 = self.upconv_block(c2, c1)

        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels, pool=False):
        """Optional 2x max-pool, two (Conv3x3 -> ReLU -> BatchNorm) stages, optional dropout."""
        # The layer order (and thus the Sequential indices used in state_dict keys) is kept
        # as-is so that previously saved weights still load.
        layers = []
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2))
        layers.extend([
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        ])
        if self.use_dropout:
            layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels):
        """2x transposed-conv upsampling followed by ReLU and BatchNorm."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x, mask):
        """Inpaint ``x`` where ``mask`` is 1.

        Args:
            x: Image batch. It is concatenated with ``mask`` on dim 1, so their channel
                counts must sum to ``in_channels``.
            mask: Mask batch broadcastable against ``x`` (e.g. ``(B, 1, H, W)``);
                1 marks pixels to fill.

        Returns:
            ``output * mask + x * (1 - mask)``.

        Raises:
            ValueError: if H or W is not divisible by 16. Otherwise the additive skip
                connections would fail with an opaque shape mismatch.
        """
        height, width = x.shape[-2:]
        factor = self._DOWNSAMPLE_FACTOR
        if height % factor or width % factor:
            raise ValueError(
                f'UNetInpainting expects height and width divisible by {factor}, '
                f'got {height}x{width}'
            )

        x_with_mask = x * (1 - mask)
        x_with_mask_and_mask = torch.cat([x_with_mask, mask], dim=1)

        e1 = self.encoder1(x_with_mask_and_mask)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        b = self.bottleneck(e4)

        d4 = self.decoder4(b) + e4
        d3 = self.decoder3(d4) + e3
        d2 = self.decoder2(d3) + e2
        d1 = self.decoder1(d2) + e1

        output = self.final_conv(d1)

        # Keep the known pixels exactly and take the prediction only inside the mask.
        return output * mask + x * (1 - mask)

In [15]:
def log_images_to_comet(logger, images, mask, output, epoch, step, name='sample'):
    """Log an [original | masked input | output] grid of the first sample to Comet.

    Args:
        logger: Lightning logger whose ``experiment`` provides ``log_image`` (e.g. a
            CometLogger). Nothing is logged if ``logger`` is None or its experiment has
            no ``log_image``.
        images: Batch of clean images; only ``images[0]`` is used.
        mask: Batch of masks broadcastable against ``images`` (1 = masked pixel).
        output: Batch of model outputs with the same layout as ``images``.
        epoch: Epoch number embedded in the logged image name.
        step: Global step embedded in the logged image name.
        name: Prefix of the logged image name.
    """
    if logger is None:
        return
    log_image = getattr(logger.experiment, 'log_image', None)
    if log_image is None:
        return

    # Detach (training outputs still carry autograd history), cast to a common float32
    # dtype, and move to CPU before stacking the panels into one grid.
    panels = [
        panel.detach().float().cpu()
        for panel in (images[0], images[0] * (1 - mask[0]), output[0])
    ]
    grid = vutils.make_grid(panels, nrow=3, normalize=True)

    # normalize=True rescales the grid to [0, 1]; convert CxHxW -> HxWxC uint8 for PIL.
    grid_np = grid.permute(1, 2, 0).numpy()
    grid_image = Image.fromarray((grid_np * 255).clip(0, 255).astype('uint8'))

    log_image(grid_image, name=f'{name}_epoch_{epoch}_step_{step}.png')

In [16]:
class UNetInpaintingLightning(pl.LightningModule):
    """LightningModule that trains a wrapped inpainting model with manual optimization.

    Each training step runs the forward pass under CUDA autocast (on GPU batches) and
    minimizes the MSE between the model output and the clean image. Every 50 batches it
    logs a sample grid to Comet.

    After each real (non-sanity-check) validation epoch, the ReduceLROnPlateau scheduler
    is stepped with the epoch ``val_loss``. When that loss improves, the wrapped model's
    ``state_dict`` is saved to the working directory and logged to Comet as
    ``best_weights``.

    Args:
        model: ``nn.Module`` called as ``model(x, mask)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = nn.MSELoss()
        self.best_val_loss = float('inf')
        # The optimizer and the plateau scheduler are both stepped explicitly below.
        self.automatic_optimization = False

    def forward(self, x, mask):
        return self.model(x, mask)

    def training_step(self, batch, batch_idx):
        x, mask, _ = batch
        # Add a channel axis so the mask can be concatenated with / broadcast over the image.
        mask = mask.unsqueeze(1)

        optimizer = self.optimizers()
        optimizer.zero_grad(set_to_none=True)

        # CUDA autocast only affects CUDA ops. Enabling it only for GPU batches avoids the
        # "CUDA is not available" warning on CPU-only runs.
        # NOTE(review): float16 autocast without gradient scaling can underflow small
        # gradients; consider Trainer(precision='16-mixed') instead, which lets
        # manual_backward apply a GradScaler.
        with torch.amp.autocast('cuda', enabled=x.is_cuda):
            output = self(x, mask)
            loss = self.criterion(output, x)

        self.manual_backward(loss)
        optimizer.step()

        current_lr = optimizer.param_groups[0]['lr']
        self.log('learning_rate', current_lr, prog_bar=True, on_epoch=True, on_step=True)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True)

        if batch_idx % 50 == 0:
            log_images_to_comet(
                self.logger,
                x,
                mask,
                output,
                epoch=self.current_epoch,
                step=self.global_step,
                name='train_sample'
            )

        return loss

    def validation_step(self, batch, batch_idx):
        x, mask, _ = batch
        mask = mask.unsqueeze(1)
        output = self(x, mask)
        loss = self.criterion(output, x)

        self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=True)

        if batch_idx == 0:
            log_images_to_comet(
                self.logger,
                x,
                mask,
                output,
                epoch=self.current_epoch,
                step=self.global_step,
                name='val_sample'
            )

        return loss

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.model.parameters(), lr=1e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            }
        }

    def on_validation_epoch_end(self):
        # The pre-training sanity check runs only a few validation batches. Its loss must
        # not step the scheduler or become the "best" loss; otherwise later full
        # validation epochs would be compared against an unrepresentative value.
        if self.trainer.sanity_checking:
            return

        val_loss = self.trainer.callback_metrics.get('val_loss')
        if val_loss is None:
            return
        val_loss_value = float(val_loss)

        # With manual optimization, Lightning does not step schedulers; do it here.
        lr_scheduler = self.lr_schedulers()
        if lr_scheduler is not None:
            lr_scheduler.step(val_loss_value)

        if val_loss_value < self.best_val_loss:
            self.best_val_loss = val_loss_value

            weights_path = f'best_weights_epoch_{self.current_epoch}_val_loss_{val_loss_value:.4f}.pth'
            torch.save(self.model.state_dict(), weights_path)

            if self.logger is not None:
                self.logger.experiment.log_model(
                    name='best_weights',
                    file_or_folder=weights_path,
                    overwrite=True
                )
                print(f'Weights saved and logged to CometML: {weights_path}')
            else:
                print(f'Weights saved: {weights_path}')

    def on_save_checkpoint(self, checkpoint):
        # Persist the best loss so a reloaded module does not treat worse weights as "best".
        checkpoint['best_val_loss'] = self.best_val_loss

    def on_load_checkpoint(self, checkpoint):
        # Older checkpoints do not have this key; fall back to "no best yet".
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))

In [17]:
# NOTE(review): WikiArtDataset already converts each image to float32 before this transform
# runs, so `scale=True` never rescales (ToDtype only rescales integer inputs), and pixel
# values keep their stored range. The pretrained weights loaded later depend on the input
# range, so confirm the stored dtype/range before changing this.
transforms = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),
    # v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

h5_path = '/kaggle/input/wikiart-clean-without-split/dataset.h5'
mask_path = '/kaggle/input/wikiart-damage-masks/square.h5'
csv_path = '/kaggle/input/wikiart-cluster-annotations-split/combined_clusters.csv'

# Both splits share every constructor argument except `set_type`.
shared_dataset_kwargs = dict(
    h5_path=h5_path,
    mask_h5_path=mask_path,
    csv_path=csv_path,
    group_id=GROUP_ID,
    transform=transforms,
)
train_dataset = WikiArtDataset(set_type='train', **shared_dataset_kwargs)
valid_dataset = WikiArtDataset(set_type='valid', **shared_dataset_kwargs)

In [18]:
batch_size = 128

# Same batching and worker count for both splits; only the training split is shuffled.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

In [None]:
PRETRAINED_WEIGHTS_PATH = '/kaggle/input/wikiart-inpainting-full-set-weights/best_weights_epoch_28_val_loss_3.7235-2348.pth'

model = UNetInpainting(in_channels=4)

# map_location='cpu' lets the weights load even on a CPU-only session if they were saved
# from GPU tensors (Lightning moves the model to the training device in fit).
# weights_only=True refuses to unpickle arbitrary objects; a plain state_dict is all
# that is expected here.
model.load_state_dict(torch.load(PRETRAINED_WEIGHTS_PATH, map_location='cpu', weights_only=True))

# Leave empty to start from the pretrained weights above. Set it to a Lightning .ckpt path
# to load the module's weights from that checkpoint instead.
checkpoint_path = ''

if checkpoint_path == '':
    lightning_model = UNetInpaintingLightning(model)
else:
    lightning_model = UNetInpaintingLightning.load_from_checkpoint(checkpoint_path, model=model)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    dirpath='./checkpoints',
    filename='best_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    verbose=True
)

early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=8,
    verbose=True,
    mode='min',
    check_on_train_epoch_end=False
)

trainer = pl.Trainer(
    logger=comet_logger,
    max_epochs=30,
    # Lightning rejects devices=0 (the previous CPU fallback). Let accelerator='auto' pick
    # GPU or CPU and use a single device either way.
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback, early_stopping_callback]
)

In [None]:
# Train and validate; the callbacks handle checkpointing and early stopping.
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

In [None]:
# Explicitly mark the Comet experiment as finished. In a notebook the kernel keeps running
# after training, so the run is not closed by process exit.
comet_experiment = comet_logger.experiment
comet_experiment.end()