In [None]:
# Mount Google Drive; the dataset archive, checkpoints and the training log live under it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [None]:
import shutil

# Extract the dataset archive from Drive onto the local disk under /content/cocomy.
shutil.unpack_archive(
    filename="/content/drive/MyDrive/COCO2017/cocomy.zip",
    extract_dir="/content/cocomy",
)

In [None]:
import os
# Sends signal 9 (SIGKILL on POSIX) to the kernel's own process, i.e. hard-kills the kernel.
# NOTE(review): presumably a manual "restart runtime" step; it aborts "Run All" at this
# cell, so the cells below must be run separately afterwards.
os.kill(os.getpid(), 9)

In [None]:
import pandas

# Sanity check: show which pandas installation the kernel actually imported.
pandas_init_path = pandas.__file__
print(pandas_init_path)

/usr/local/lib/python3.12/dist-packages/pandas/__init__.py


In [None]:
import math
import os
import torch
import torch.nn as nn

from torchvision.transforms import v2, functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.optim import AdamW

from PIL import Image
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import matplotlib.pyplot as plt
from datetime import datetime
from zoneinfo import ZoneInfo

# Let cuDNN benchmark and cache the fastest convolution algorithms. This pays off when
# input shapes stay constant between iterations (images are resized to a fixed
# img_size x img_size in this notebook).
# NOTE: per the PyTorch reproducibility notes, autotuning can make results differ run-to-run.
torch.backends.cudnn.benchmark = True

In [None]:
class CocoDenoisingDataset(Dataset):
    """(input, target) image pairs built from a single flat folder of images.

    Every file in ``root_dir`` whose name ends in .jpg/.jpeg/.png (case-insensitive)
    is loaded as RGB and passed through two transforms: ``input_img_transform``
    produces the model input and ``target_img_transform`` produces the target.
    """

    def __init__(self, root_dir, input_img_transform, target_img_transform, noise_factor=0.2):
        """
        Args:
            root_dir (string): Path to the folder with the images (not searched recursively).
            input_img_transform (callable): Applied to the PIL image to build the model input
                (the notebook passes a noise-adding transform here).
            target_img_transform (callable): Applied to the same PIL image to build the target.
            noise_factor (float): Unused; kept only so existing call sites keep working.
        """
        self.root_dir = root_dir
        self.input_transform = input_img_transform
        self.target_transform = target_img_transform

        # os.listdir order depends on the filesystem; sort so that index -> file is reproducible.
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(input_img, target_img)`` for the ``idx``-th file (sorted by name)."""
        img_path = os.path.join(self.root_dir, self.image_files[idx])

        # Close the file handle right away; convert() returns a fully loaded copy,
        # so the image stays usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        input_img = self.input_transform(image)
        target_img = self.target_transform(image)

        return input_img, target_img


In [None]:
class LayerNorm2d(nn.Module):
    """Channel-wise LayerNorm for channels-first tensors of shape (N, C, H, W).

    Mean and (biased) variance are taken over dim 1 independently at every spatial
    position, followed by a learnable per-channel scale and shift. This avoids the
    permute to channels-last that ``nn.LayerNorm`` would need.
    """

    def __init__(self, num_features, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)

        # (C,) -> (C, 1, 1) so the affine parameters broadcast over H and W
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift

class SimpleGate(nn.Module):
    """NAFNet's parameter-free replacement for ReLU/GELU.

    Splits dim 1 (channels) into two chunks and returns their element-wise product,
    so for an even channel count the output has half as many channels as the input.
    """

    def forward(self, x):
        gate, value = torch.chunk(x, 2, dim=1)
        return gate * value

class NAFBlock(nn.Module):
    """One NAFNet block: a spatial-mixing branch followed by a gated feed-forward branch.

    Each branch is pre-normalized with LayerNorm2d and added back to its input through
    a residual connection scaled by a learnable per-channel factor (``beta``, ``gamma``).
    Both factors start at zero, so a freshly constructed block returns its input unchanged.
    Input and output have shape (N, c, H, W).

    Args:
        c: number of input/output channels.
        DW_Expand: channel expansion before the depthwise conv.
        FFN_Expand: channel expansion of the feed-forward branch.
        drop_out_rate: dropout probability after each branch (0 disables dropout).
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        gated_channel = dw_channel // 2  # SimpleGate halves the channel count

        # NOTE: layers are created in the original order on purpose - that order decides
        # how the global RNG is consumed during weight initialisation.
        self.conv1 = nn.Conv2d(c, dw_channel, 1)
        self.conv2 = nn.Conv2d(dw_channel, dw_channel, 3, padding=1, groups=dw_channel)  # depthwise 3x3
        self.conv3 = nn.Conv2d(gated_channel, c, 1)

        # Simplified Channel Attention: global average pool followed by a 1x1 conv.
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(gated_channel, gated_channel, 1),
        )

        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(c, ffn_channel, 1)
        self.conv5 = nn.Conv2d(ffn_channel // 2, c, 1)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        def make_dropout():
            return nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.dropout1 = make_dropout()
        self.dropout2 = make_dropout()

        # Learnable layer-scale factors of the two residual branches (zero-initialised).
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # Branch 1 - spatial mixing: 1x1 expand, depthwise 3x3, gate, channel attention, 1x1 project.
        x = self.conv2(self.conv1(self.norm1(inp)))
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.dropout1(self.conv3(x))
        y = inp + x * self.beta  # residual 1

        # Branch 2 - channel mixing: gated point-wise feed-forward network.
        x = self.sg(self.conv4(self.norm2(y)))
        x = self.dropout2(self.conv5(x))
        return y + x * self.gamma  # residual 2

class NAFNet(nn.Module):
    """U-Net of NAFBlocks that predicts a residual correction for the input image.

    Pipeline: intro 3x3 conv -> [encoder stage, then 2x2 stride-2 conv that halves H/W
    and doubles the channels] per entry of ``enc_blk_nums`` -> middle blocks ->
    [1x1 conv + PixelShuffle(2) upsample, add the matching encoder output, decoder
    stage] per entry of ``dec_blk_nums`` -> ending 3x3 conv, added to the input.

    Inputs whose H or W is not a multiple of ``2 ** len(enc_blk_nums)`` are zero-padded
    on the bottom/right and the output is cropped back. Therefore an (N, img_channel, H, W)
    input always yields an output of the same shape. Already-divisible inputs are
    processed exactly as without padding.

    Args:
        img_channel: channels of the input/output image.
        width: channels after the intro conv (doubled at every downsampling step).
        middle_blk_num: number of NAFBlocks at the bottleneck.
        enc_blk_nums: NAFBlocks per encoder stage (one stage per entry).
        dec_blk_nums: NAFBlocks per decoder stage; expected to have the same length as
            ``enc_blk_nums`` so every decoder stage receives a matching skip connection.
    """

    def __init__(self, img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=(1, 1, 1), dec_blk_nums=(1, 1, 1)):
        super().__init__()

        # Module creation order is kept as-is: it decides how the global RNG is
        # consumed during weight initialisation.
        self.intro = nn.Conv2d(img_channel, width, 3, padding=1)
        self.ending = nn.Conv2d(width, img_channel, 3, padding=1)

        # Encoder
        self.encoders = nn.ModuleList()
        self.downs = nn.ModuleList()
        chan = width

        for num in enc_blk_nums:
            self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))  # downsampling (stride 2)
            chan = chan * 2

        # Middle
        self.middle_blks = nn.Sequential(*[NAFBlock(chan) for _ in range(middle_blk_num)])

        # Decoder: 1x1 conv to 2*chan channels + PixelShuffle(2) doubles H/W and leaves
        # chan // 2 channels, matching the corresponding encoder output.
        self.decoders = nn.ModuleList()
        self.ups = nn.ModuleList()

        for num in dec_blk_nums:
            self.ups.append(nn.Sequential(
                nn.Conv2d(chan, chan * 2, 1, bias=False),
                nn.PixelShuffle(2),
            ))
            chan = chan // 2
            self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))

        # H and W must be divisible by this for every skip connection to line up.
        self.padder_size = 2 ** len(self.encoders)

    def _pad_to_multiple(self, x):
        """Zero-pad H and W (bottom/right) up to the next multiple of ``self.padder_size``."""
        h, w = x.shape[-2:]
        pad_h = (-h) % self.padder_size
        pad_w = (-w) % self.padder_size
        if pad_h == 0 and pad_w == 0:
            return x
        return nn.functional.pad(x, (0, pad_w, 0, pad_h))

    def forward(self, inp):
        orig_h, orig_w = inp.shape[-2:]
        inp = self._pad_to_multiple(inp)

        # Basic extraction
        x = self.intro(inp)

        # Encoder (keep each stage's output for the skip connections)
        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        # Middle
        x = self.middle_blks(x)

        # Decoder (skips consumed deepest-first)
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        # Ending + global residual, then drop any padding again.
        x = self.ending(x) + inp
        return x[..., :orig_h, :orig_w]

In [None]:
!pip install pytorch-msssim
from pytorch_msssim import ms_ssim

Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-1.0.0


In [None]:
class BatchSaltAndPepper(v2.Transform):
    """Salt-and-pepper noise for image tensors.

    Accepts a batch (B, C, H, W), a single image (C, H, W) or a single-channel image
    (H, W). Each spatial position becomes "salt" with probability ``salt_prob``:
    1.0 for floating-point images, or the dtype's maximum (e.g. 255 for uint8).
    With probability ``pepper_prob`` it becomes "pepper" (0) instead. The decision
    is shared by all channels of a pixel.

    The input tensor is never modified; a noisy copy is returned.

    Args:
        salt_prob: probability of a pixel being set to the maximum value.
        pepper_prob: probability of a pixel being set to 0.

    Raises:
        ValueError: if either probability is negative or their sum exceeds 1.
    """

    def __init__(self, salt_prob: float = 0.01, pepper_prob: float = 0.01):
        super().__init__()
        self.salt_prob = salt_prob
        self.pepper_prob = pepper_prob
        self.total_prob = salt_prob + pepper_prob

        if salt_prob < 0.0 or pepper_prob < 0.0:
            raise ValueError("salt_prob and pepper_prob must be non-negative")
        if not (0.0 <= self.total_prob <= 1.0):
            raise ValueError("salt_prob + pepper_prob must be between 0.0 and 1.0")

    def _apply_batch(self, batch: torch.Tensor) -> torch.Tensor:
        B, _, H, W = batch.shape

        # Draw in the default float dtype rather than batch.dtype: torch.rand does not
        # support integer dtypes (e.g. uint8 images).
        rand_tensor = torch.rand(B, 1, H, W, device=batch.device)

        salt_mask = rand_tensor < self.salt_prob
        pepper_mask = (rand_tensor >= self.salt_prob) & (rand_tensor < self.total_prob)

        salt_value = 1.0 if batch.is_floating_point() else torch.iinfo(batch.dtype).max

        # Work on a copy so the caller's tensor (or the tensor a view came from) stays intact.
        noisy = batch.clone()
        noisy[salt_mask.expand_as(noisy)] = salt_value  # Apply Salt
        noisy[pepper_mask.expand_as(noisy)] = 0         # Apply Pepper

        return noisy

    def __call__(self, inpt):
        # Tensors of rank 2-4 are handled directly; lower ranks are viewed as a batch of one.
        if isinstance(inpt, torch.Tensor):
            if inpt.dim() == 4:
                return self._apply_batch(inpt)
            if inpt.dim() == 3:
                return self._apply_batch(inpt.unsqueeze(0)).squeeze(0)
            if inpt.dim() == 2:
                return self._apply_batch(inpt[None, None])[0, 0]

        # default v2.Transform behavior for other types
        return super().__call__(inpt)

In [None]:


img_size = 256     # images are resized to img_size x img_size
rand_state = 44    # global seed for the python / numpy / torch RNGs
num_workers = 2    # DataLoader worker processes

# Seed every RNG the pipeline touches (weight init, shuffling, synthetic noise) so runs
# are repeatable; cuDNN autotuning may still cause small run-to-run differences.
import random
random.seed(rand_state)
np.random.seed(rand_state)
torch.manual_seed(rand_state)

def relative_data_path(path: str):
    """Return the absolute path of ``path`` inside the extracted local dataset folder."""
    data_root = "/content/cocomy/data/"
    return data_root + path

def relative_drive_path(path: str):
    """Return the absolute path of ``path`` inside the project's Google Drive folder."""
    drive_root = "/content/drive/MyDrive/COCO2017/"
    return drive_root + path

def curr_time():
    """Current wall-clock time as a timezone-aware datetime in the 'Europe/Kiev' zone."""
    local_zone = ZoneInfo('Europe/Kiev')
    return datetime.now(tz=local_zone)


def printshare(msg, logfile=relative_drive_path("training_log.txt")):
    """Print ``msg`` to stdout and append the same line to ``logfile``.

    The default log file path is resolved once, when this function is defined.
    """
    print(msg)

    # Explicit encoding: the default is locale-dependent, so non-ASCII messages
    # could otherwise be written differently (or fail) on another machine.
    with open(logfile, "a", encoding="utf-8") as f:
        print(msg, file=f)


def cosannealing_decay_warmup(warmup_steps, T_0, T_mult, decay_factor, base_lr, eta_min):
    """Build an ``lr_lambda`` for ``LambdaLR``.

    The schedule is a linear warmup followed by cosine annealing with warm restarts,
    where the peak LR decays by ``decay_factor`` after every restart. ``LambdaLR``
    multiplies the optimizer's base LR by the value returned here, so the lambda
    returns a *multiplier*; ``base_lr`` must equal the optimizer's ``lr``.

    Args:
        warmup_steps: number of initial epochs whose multiplier ramps linearly
            from 1/warmup_steps to 1 (0 disables warmup).
        T_0: length in epochs of the first annealing cycle.
        T_mult: growth factor of the cycle length after each restart (>= 1; may be a float).
        decay_factor: the peak LR of cycle k is ``base_lr * decay_factor ** k``.
        base_lr: the optimizer's base learning rate.
        eta_min: absolute LR reached at the end of every cycle.

    Returns:
        ``lr_lambda(epoch) -> float`` for a 0-based epoch index.

    Raises:
        ValueError: if ``T_mult < 1``.
    """
    if T_mult < 1:
        raise ValueError("T_mult must be >= 1")

    def cycle_start(cycle):
        # First annealing step of `cycle` (0-based); floored so cycles start on whole epochs.
        return T_0 * (T_mult ** cycle - 1) // (T_mult - 1)

    def lr_lambda(epoch):  # 0-based epoch
        if epoch < warmup_steps:
            # Already a multiplier (LambdaLR applies base_lr itself); reaches exactly 1.0
            # on the last warmup epoch, matching the first annealing value.
            return (epoch + 1) / warmup_steps

        annealing_step = epoch - warmup_steps

        # Which cycle (zero-based) we are in, its length (T_current) and the position t inside it.
        if T_mult == 1:
            cycle, t = divmod(annealing_step, T_0)
            T_current = T_0
        else:
            # Closed-form estimate of the cycle index ...
            cycle = int(math.log((annealing_step * (T_mult - 1)) / T_0 + 1, T_mult))
            # ... corrected for floating-point error at cycle boundaries
            # (e.g. math.log(1000, 10) == 2.9999999999999996).
            while cycle > 0 and cycle_start(cycle) > annealing_step:
                cycle -= 1
            while cycle_start(cycle + 1) <= annealing_step:
                cycle += 1
            t = annealing_step - cycle_start(cycle)
            T_current = T_0 * (T_mult ** cycle)

        # Peak LR decays after every restart.
        eta_max = base_lr * (decay_factor ** cycle)

        # Cosine from eta_max (t = 0) down towards eta_min at the end of the cycle.
        lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / T_current))
        return lr / base_lr

    return lr_lambda






def _resume_training_state(net, optimizer, scheduler, w_decay, checkpoint_path):
    """Load a checkpoint into net/optimizer/scheduler and return the epoch to continue from."""
    printshare("Loading pretrained model, optimizer & scheduler state dicts...")
    checkpoint = torch.load(checkpoint_path)
    start_epoch = 0

    if 'model' not in checkpoint:
        # Bare model state dict: weights only, optimizer and scheduler start fresh.
        missing, unexpected = net.load_state_dict(checkpoint, strict=False)
        printshare("got no optimizer & scheduler state dicts. model state dict set up successfully.")

    else:
        missing, unexpected = net.load_state_dict(checkpoint['model'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The current run's weight decay overrides the stored one.
        for group in optimizer.param_groups:
            group['weight_decay'] = w_decay

        # The scheduler's own state dict is not loaded (presumably so a changed
        # lr_lambda takes effect on resume). Checkpoints are written after
        # scheduler.step(), so the scheduler had already advanced to epoch + 1; resume
        # there. Resuming at `epoch` would reuse one LR value and shift the rest of
        # the schedule by one epoch.
        start_epoch = checkpoint['epoch'] + 1
        scheduler.last_epoch = start_epoch
        # Take this epoch's LR from the current schedule rather than from the stored optimizer state.
        for group, base_lr, lr_fn in zip(optimizer.param_groups, scheduler.base_lrs, scheduler.lr_lambdas):
            group['lr'] = base_lr * lr_fn(start_epoch)

        printshare("all the dicts set up successfully.")

    printshare(f"[DEBUG] model missing statedict vals: {missing};")
    printshare(f"[DEBUG] model unexpected statedict vals: {unexpected}")
    return start_epoch


def _save_epoch_artifacts(net, optimizer, scheduler, stats):
    """Write the full training checkpoint and the epoch's metrics dict under checkpoints/ on Drive."""
    tag = (f"ep_{stats['epoch'] + 1}_ts_{round(100 * stats['output_train_msssim'], 1)}"
           f"_vs_{round(100 * stats['output_val_msssim'], 1)}")

    torch.save({  # model
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': stats['epoch'],
    }, relative_drive_path(f'checkpoints/{tag}_model.pth'))

    torch.save(stats, relative_drive_path(f'checkpoints/stats/{tag}_stats.pth'))  # stats


def perform_training(net,
                     training_set,
                     validation_set,
                     epochs, w_decay, batch_size, sub_batch_size,
                     lr, lr_lambda,
                     pretrained: bool | str = False):
    """Train ``net`` with AdamW + AMP, validating and checkpointing after every epoch.

    Args:
        net: model already placed on the CUDA device.
        training_set, validation_set: datasets yielding (input_img, target_img) pairs.
        epochs: total number of epochs; a resumed run continues until this count.
        w_decay: AdamW weight decay (also overrides the value of a resumed optimizer).
        batch_size: effective batch size reached through gradient accumulation.
        sub_batch_size: per-step DataLoader batch size; must divide ``batch_size``.
        lr: base learning rate handed to AdamW; ``lr_lambda`` returns multipliers of it.
        lr_lambda: callable ``epoch -> multiplier`` for ``LambdaLR``
            (e.g. the result of ``cosannealing_decay_warmup``).
        pretrained: False (or "") to start from scratch, or the path to either a bare
            model state dict or a full checkpoint written by this function.

    Returns:
        The trained ``net``.

    Side effects: appends to the training log and writes model/stats checkpoints to
    ``checkpoints/`` and ``checkpoints/stats/`` on Drive after every epoch.
    """
    assert batch_size % sub_batch_size == 0  # screws up gradient accumulation otherwise

    def timestamp():
        return curr_time().strftime('%Y-%m-%d %H:%M:%S')

    printshare("training preparation...")

    scaler = torch.amp.GradScaler('cuda')

    train_loader = DataLoader(training_set, batch_size=sub_batch_size, shuffle=True, num_workers=num_workers)
    # Validation only averages metrics, so shuffling is unnecessary.
    val_loader = DataLoader(validation_set, batch_size=sub_batch_size, shuffle=False, num_workers=num_workers)

    # ========= optimizer, scheduler and (optionally) checkpoint =========
    criterion = nn.L1Loss()
    optimizer = AdamW(
        params=[p for p in net.parameters() if p.requires_grad],
        lr=lr, weight_decay=w_decay)

    # LambdaLR implements warmup + cosine annealing with warm restarts and decay:
    # the LR of every epoch is lr * lr_lambda(epoch).
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

    curr_epoch = 0
    if isinstance(pretrained, str) and pretrained:
        curr_epoch = _resume_training_state(net, optimizer, scheduler, w_decay, pretrained)

    os.makedirs(relative_drive_path("checkpoints"), exist_ok=True)
    os.makedirs(relative_drive_path("checkpoints/stats"), exist_ok=True)
    printshare("done.")

    # ========== training itself ==========
    while curr_epoch < epochs:
        printshare(f"[{timestamp()}] epoch {curr_epoch + 1}/{epochs} processing...")
        output_train_msssim, input_train_msssim, train_loss = perform_training_epoch(
            net=net,
            full_batch_size=batch_size, sub_batch_size=sub_batch_size,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler
        )

        printshare(f"training done. input ms-ssim: {round(100*input_train_msssim, 3)}%;" +
                   f" output ms-ssim: {round(100*output_train_msssim, 3)}%")

        printshare(f"[{timestamp()}] processing validation phase...")
        output_val_msssim, input_val_msssim, val_loss = perform_validation_epoch(
            net=net,
            val_loader=val_loader,
            criterion=criterion
        )

        printshare(f"validation done. input ms-ssim: {round(100*input_val_msssim, 3)}%;" +
                   f" output ms-ssim: {round(100*output_val_msssim, 3)}%")

        _save_epoch_artifacts(net, optimizer, scheduler, {
            'epoch': curr_epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'input_train_msssim': input_train_msssim,
            'output_train_msssim': output_train_msssim,
            'input_val_msssim': input_val_msssim,
            'output_val_msssim': output_val_msssim,
        })

        curr_epoch += 1

    printshare(f"[{timestamp()}] training successfully finished.")
    return net


def perform_training_epoch(net, full_batch_size, sub_batch_size,
                           train_loader, criterion, optimizer, scheduler,
                           scaler):
    """Run one AMP training epoch with gradient accumulation, then step the scheduler once.

    Gradients of ``ceil(full_batch_size / sub_batch_size)`` consecutive sub-batches are
    accumulated before each optimizer step. A trailing incomplete group at the end of
    the epoch is still applied rather than discarded.

    Returns:
        (mean MS-SSIM of model output vs target, mean MS-SSIM of input vs target,
         mean per-sub-batch loss), each averaged over sub-batches; all 0 for an empty loader.
    """
    batch_losses = []
    model_output_msssim_vals = []
    input_msssim_vals = []

    net.train()

    accum_steps = math.ceil(full_batch_size / sub_batch_size)
    optimizer.zero_grad()
    has_pending_grads = False  # gradients accumulated since the last optimizer step

    def apply_accumulated_gradients():
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    for i, (input_imgs, target_imgs) in enumerate(train_loader):
        input_imgs, target_imgs = input_imgs.cuda(), target_imgs.cuda()

        with torch.amp.autocast('cuda'):
            outputs = net(input_imgs)
            # NOTE(review): clamping before the loss gives zero gradient to pixels the model
            # pushes outside [0, 1]; consider clamping only for the metrics.
            outputs = torch.clamp(outputs, 0.0, 1.0)

            loss = criterion(outputs, target_imgs)
            # scale so the accumulated gradient matches one full-batch step
            loss = loss / accum_steps

        # The autocast output is cast to float32 before computing MS-SSIM.
        with torch.no_grad():
            outputs_f32 = outputs.detach().float()

            model_output_batch_msssim = ms_ssim(outputs_f32,
                                                target_imgs,
                                                data_range=1.0, size_average=True)
            model_output_msssim_vals.append(model_output_batch_msssim.item())

            input_batch_msssim = ms_ssim(input_imgs,
                                         target_imgs,
                                         data_range=1.0, size_average=True)
            input_msssim_vals.append(input_batch_msssim.item())

        scaler.scale(loss).backward()
        has_pending_grads = True

        batch_losses.append(loss.item() * accum_steps)  # report the unscaled loss

        if (i + 1) % accum_steps == 0:
            apply_accumulated_gradients()
            has_pending_grads = False

    # Apply the trailing partial accumulation instead of letting the next epoch's
    # zero_grad() discard it. Its gradient is proportionally smaller because every
    # loss was divided by accum_steps.
    if has_pending_grads:
        apply_accumulated_gradients()

    scheduler.step()

    def mean(vals):
        return sum(vals) / len(vals) if len(vals) > 0 else 0

    epoch_loss = mean(batch_losses)
    avg_model_msssim = mean(model_output_msssim_vals)
    avg_input_msssim = mean(input_msssim_vals)

    return avg_model_msssim, avg_input_msssim, epoch_loss


def perform_validation_epoch(net, val_loader, criterion):
    """Evaluate ``net`` on ``val_loader`` in eval mode without gradients (AMP forward).

    Returns:
        (mean MS-SSIM of model output vs target, mean MS-SSIM of input vs target,
         mean loss), each averaged over batches; all 0 for an empty loader.
    """
    net.eval()
    batch_losses = []
    model_output_msssim_vals = []
    input_msssim_vals = []

    with torch.no_grad():
        for input_imgs, target_imgs in val_loader:
            input_imgs, target_imgs = input_imgs.cuda(), target_imgs.cuda()

            with torch.amp.autocast('cuda'):
                outputs = net(input_imgs)
                outputs = torch.clamp(outputs, 0.0, 1.0)

                loss_val = criterion(outputs, target_imgs)

            # MS-SSIM is computed on the float32 version of the autocast output.
            outputs_f32 = outputs.float()

            model_batch_msssim = ms_ssim(outputs_f32,
                                         target_imgs,
                                         data_range=1.0, size_average=True)

            input_batch_msssim = ms_ssim(input_imgs,
                                         target_imgs,
                                         data_range=1.0, size_average=True)

            model_output_msssim_vals.append(model_batch_msssim.item())
            input_msssim_vals.append(input_batch_msssim.item())
            batch_losses.append(loss_val.item())

    def mean(vals):
        return sum(vals) / len(vals) if len(vals) > 0 else 0

    # Empty loader -> 0 for every metric (previously the loss raised ZeroDivisionError).
    return mean(model_output_msssim_vals), mean(input_msssim_vals), mean(batch_losses)





def custom_loader(path):
    """Open ``path`` with Pillow, attempting only the JPEG decoder; the caller owns the image."""
    jpeg_only = ["JPEG"]
    return Image.open(path, formats=jpeg_only)




if __name__ == '__main__':

    net = NAFNet().cuda(0)

    # Clean targets: PIL image -> tensor image -> fixed square size -> float32 scaled to [0, 1].
    base_transform = v2.Compose([
        v2.ToImage(),
        v2.Resize(size=(img_size, img_size)),
        v2.ToDtype(torch.float32, scale=True),
    ])

    # Noisy inputs: the same preprocessing, then Gaussian noise, salt & pepper and a blur.
    noise_transform = v2.Compose([
        v2.ToImage(),
        v2.Resize(size=(img_size, img_size)),
        v2.ToDtype(torch.float32, scale=True),
        v2.GaussianNoise(mean=0, sigma=0.08),
        BatchSaltAndPepper(salt_prob=0.05, pepper_prob=0.05),
        v2.GaussianBlur(kernel_size=5, sigma=(0.5, 1.5)),
    ])
    # NOTE(review): val/test reuse the random noise_transform, so their inputs are
    # re-corrupted on every access and the reported metrics can vary between runs.

    def make_split(split_name):
        return CocoDenoisingDataset(relative_data_path(split_name),
                                    input_img_transform=noise_transform,
                                    target_img_transform=base_transform)

    train_set, val_set, test_set = (make_split(name) for name in ("train", "val", "test"))

    printshare(f"len_train: {len(train_set)}; len_val: {len(val_set)}; len_test: {len(test_set)}.")

    # Uncomment to (re)start training; `pretrained` resumes from a checkpoint on Drive.
    # perform_training(net, train_set, val_set,
    #                  epochs=600, w_decay=1e-3, batch_size=128, sub_batch_size=32,
    #                  lr=1e-3, lr_lambda=cosannealing_decay_warmup(
    #                    warmup_steps=0, T_0=10, T_mult=1.1, decay_factor=0.9, base_lr=1e-3, eta_min=1e-8),
    #                  pretrained=relative_drive_path('checkpoints/ep_5_ts_0.9_vs_0.9_model.pth'))

len_train: 25000; len_val: 5000; len_test: 40670.


In [None]:
def perform_testing(net, test_set, bs=128, weights_file=""):
    """Evaluate ``net`` on ``test_set`` and report the mean MS-SSIM of input and output.

    Args:
        net: model on the CUDA device.
        test_set: dataset yielding (input_img, target_img) pairs.
        bs: DataLoader batch size.
        weights_file: path to a bare model state dict or to a training checkpoint with
            a 'model' entry; "" (the default) or None evaluates ``net`` as it is.

    Returns:
        (mean output MS-SSIM, mean input MS-SSIM) averaged over batches; 0 each for an
        empty test set.
    """
    # Evaluation only averages metrics, so shuffling is unnecessary.
    test_loader = DataLoader(test_set, batch_size=bs, shuffle=False, num_workers=num_workers)

    # An empty string used to pass the isinstance check and crash in torch.load("").
    if isinstance(weights_file, str) and weights_file:
        printshare("Loading pretrained model, optimizer & scheduler state dicts...")
        checkpoint = torch.load(weights_file)

        if 'model' not in checkpoint:
            missing, unexpected = net.load_state_dict(checkpoint, strict=False)
            printshare("got no optimizer & scheduler state dicts. model state dict set up successfully.")

        else:
            missing, unexpected = net.load_state_dict(checkpoint['model'], strict=False)
            printshare("all the dicts set up successfully.")

        # strict=False tolerates mismatched keys silently; surface them as perform_training does.
        printshare(f"[DEBUG] model missing statedict vals: {missing};")
        printshare(f"[DEBUG] model unexpected statedict vals: {unexpected}")

    net.eval()
    model_output_msssim_vals = []
    input_msssim_vals = []

    with torch.no_grad():
        for input_imgs, target_imgs in test_loader:
            input_imgs, target_imgs = input_imgs.cuda(), target_imgs.cuda()

            with torch.amp.autocast('cuda'):
                outputs = net(input_imgs)
                outputs = torch.clamp(outputs, 0.0, 1.0)

            outputs_f32 = outputs.float()

            model_batch_msssim = ms_ssim(outputs_f32,
                                         target_imgs,
                                         data_range=1.0, size_average=True)

            input_batch_msssim = ms_ssim(input_imgs,
                                         target_imgs,
                                         data_range=1.0, size_average=True)

            model_output_msssim_vals.append(model_batch_msssim.item())
            input_msssim_vals.append(input_batch_msssim.item())

    avg_output_msssim = sum(model_output_msssim_vals) / len(model_output_msssim_vals) if len(
        model_output_msssim_vals) > 0 else 0
    avg_input_msssim = sum(input_msssim_vals) / len(input_msssim_vals) if len(input_msssim_vals) > 0 else 0

    printshare(f"testing done. input ms-ssim: {round(100*avg_input_msssim, 3)}%;" +
               f" output ms-ssim: {round(100*avg_output_msssim, 3)}%")

    return avg_output_msssim, avg_input_msssim
test_checkpoint = relative_drive_path('checkpoints/ep_11_ts_95.1_vs_95.1_model.pth')
perform_testing(net, test_set, weights_file=test_checkpoint)


Loading pretrained model, optimizer & scheduler state dicts...
all the dicts set up successfully.
testing done. input ms-ssim: 78.256%; output ms-ssim: 95.072%
