In [1]:
from contextlib import nullcontext
import os
import torch
import numpy as np
from datetime import datetime
import time
from prettytable import PrettyTable
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from img_utils import *
import wandb
import torch
import torch.nn as nn


class Timer:
    def __init__(self, start_msg = "", end_msg = ""):
    
        self.start_msg = start_msg
        self.end_msg = end_msg
        
    def __enter__(self):
        if self.start_msg != "":
            print(self.start_msg)
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed_time = time.time() - self.start_time
        print(self.end_msg, f"{elapsed_time:.3f} sec")


def count_parameters(model, print_table = False):
    
    total_params = 0
    
    if(print_table):
        table = PrettyTable(["Modules", "Parameters", "dtype", "Required Grad", "Device"]) 
    
    for name, parameter in model.named_parameters():
        params = parameter.numel()
        
        if(print_table):
            table.add_row([name, parameter.shape, parameter.dtype, parameter.requires_grad, parameter.device ])
            
        total_params += params
        
    if(print_table):
        print(table)
        
    if total_params/1e9 > 1:
        print(f"Total Trainable Params: {total_params/1e9} B")
    else:
        print(f"Total Trainable Params: {total_params/1e6} M")
        
    return total_params



class ImageMaskDataset(Dataset):
    def __init__(self, image_dir, mask_dir, src_dir,transform=None, maskTransform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.src_dir = src_dir
        self.transform = transform
        self.maskTransform = maskTransform
        self.image_names = sorted(os.listdir(image_dir))
        self.src_names = sorted(os.listdir(src_dir))
        self.mask_names = sorted(os.listdir(mask_dir))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        mask_name = self.mask_names[idx]
        src_name = self.src_names[idx]
        
        image_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, mask_name)
        src_path = os.path.join(self.src_dir, src_name)
        
        image = Image.open(image_path).convert('RGB')
        src = Image.open(src_path).convert('RGB')
        mask = np.load(mask_path)
        
        
        if self.transform:
            image = self.transform(image)
            src = self.transform(src)
            mask = self.maskTransform(mask)
        
        return image, src,mask

# Define the transforms
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL Image to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize the image to [-1,1]
])
maskTransform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL Image to tensor
])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ResnetBlock(nn.Module):
    def __init__(self, out_c,padding=1):
        super().__init__()

        self.conv1 = nn.Conv2d(out_c,out_c,3,1,padding)
        self.conv2 = nn.Conv2d(out_c,out_c,3,1,padding)
        self.conv3 = nn.Conv2d(out_c,out_c,3,1,padding)
        self.norm = nn.BatchNorm2d(out_c)
        self.silu = nn.SiLU()

    def forward(self, x):

        x = self.conv3(self.conv2(self.conv1(x)))
        x = self.norm(x) + x
        x = self.silu(x)

        return x



class Encoder(nn.Module):
    def __init__(self, in_c, out_c,  padding=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_c,out_c,3,1,padding)
        self.conv2 = nn.Conv2d(out_c,out_c,3,1,padding)
        self.conv3 = nn.Conv2d(out_c,out_c,3,1,padding)

        self.resBlocks = nn.ModuleList([ResnetBlock(out_c, padding) for i in range(3)])

        self.silu = nn.SiLU()
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):

        x = self.conv3(self.conv2(self.conv1(x)))
        x = self.silu(x)

        for block in self.resBlocks:
            x = block(x)

        return self.pool(x), x


class Decoder(nn.Module):
    def __init__(self, in_c, out_c, img_sizes, padding=1):
        super().__init__()

        self.upsample = nn.Upsample(size=img_sizes, mode="bilinear")


        self.conv1 = nn.Conv2d(in_c,in_c,3,1, padding)
        self.conv2 = nn.Conv2d(in_c,in_c,3,1, padding)
        self.conv3 = nn.Conv2d(in_c,in_c//2,3,1, padding)


        self.conv4 = nn.Conv2d(in_c,out_c,3,1, padding)
        self.conv5 = nn.Conv2d(out_c,out_c,3,1, padding)
        self.conv6 = nn.Conv2d(out_c,out_c,3,1, padding)


        self.resBlocks = nn.ModuleList([ResnetBlock(out_c, padding) for i in range(3)])

        self.silu = nn.SiLU()

    def forward(self, x, skip):

        x = self.upsample(x)
        x = self.conv3(self.conv2(self.conv1(x)))
        x = torch.cat((x, skip), dim = 1)
        x = self.conv6(self.conv5(self.conv4(x)))
        x = self.silu(x)

        for block in self.resBlocks:

            x = block(x)

        return x


class Unet(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoders = nn.ModuleList([Encoder(*i) for i in [(3,32), (32,64), (64, 128), (128, 256), (256, 512)]])

        self.imageDecoders = nn.ModuleList([Decoder(*i) for i in [(512,256, (21,37)), (256,128,(42,74)), (128, 64,(84,149)), (64, 32,(168,298))]])

        self.maskDecoders = nn.ModuleList([Decoder(*i) for i in [(512,256, (21,37)), (256,128,(42,74)), (128, 64,(84,149)), (64, 32,(168,298))]])

        self.imgDecoder = nn.Conv2d(32,3,1,1)
        self.tanh = nn.Tanh()

        self.maskDecoder = nn.Conv2d(32,1,1,1)
        self.sigmoid = nn.Sigmoid()


    def forward(self, x):

        skips = []
        for enc in self.encoders:
            x, skip = enc(x)
            skips.append(skip)

        x = skips[-1]

        for idx, dec in enumerate(self.imageDecoders):
            x = dec(x, skips[3-idx])

        image = self.tanh(self.imgDecoder(x))

        x = skips[-1]

        for idx, dec in enumerate(self.maskDecoders):
            x = dec(x, skips[3-idx])
            
        mask = self.sigmoid(self.maskDecoder(x))
        return image, mask

In [3]:
lr = 1e-4
# bce = nn.BCEWithLogitsLoss()
bce = nn.BCELoss()
bs = 48


device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
unet = Unet().to(device)

optimizer = torch.optim.AdamW(unet.parameters(), lr)

image_dir = '/root/data/linum/train/corrupted_imgs/'
mask_dir = '/root/data/linum/train/binary_masks/'
src_dir = '/root/data/linum/train/src_imgs/'
nparams = count_parameters(unet, print_table=False)
dataset = ImageMaskDataset(image_dir=image_dir, mask_dir=mask_dir, src_dir = src_dir,transform=transform, maskTransform = maskTransform)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=24)


Total Trainable Params: 72.211428 M


In [4]:
log = False
log = True
log_iter = 20
img_losses = 0
mask_losses = 0
iter = 0
epochs = 25
num_imgs = 3

if log:
    config={"epochs": epochs, "batch_size": bs,"lr": lr}
    wandb.init(project='linum', entity='basujindal123', config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbasujindal[0m ([33mbasujindal123[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# convert_img_tensor_to_pil_img((masks.bool()*images)[0])
# convert_img_tensor_to_pil_img((images)[0])
# # convert_mask_tensor_to_pil_img(masks[0])
# unet = torch.compile(unet)

In [6]:
for epoch in range(epochs):

    for images, src, masks in tqdm(dataloader):

        unet.train()
        
       
        images = images.to(device)
        masks = masks.to(device)
        src = src.to(device)
        m = torch.sum(masks)
        unet.zero_grad()

        iter+=1
        img_pred, mask_pred = unet(images)

        img_loss = 2*torch.sum(torch.abs(masks.bool()*(img_pred-src)))/m
        mask_loss = bce(mask_pred, masks)
        loss = mask_loss + img_loss

        loss.backward()
        optimizer.step()

        img_losses+=img_loss.item()
        mask_losses+=mask_loss.item()

        if (iter+1)%log_iter == 0:

            if log:
                wandb.log({
                    'loss': (mask_losses+img_losses)/log_iter,
                    'mask_loss': mask_losses/log_iter,
                    'img_loss': img_losses/log_iter,
                    'Corrupted Images': [wandb.Image(i) for i in images[:num_imgs].detach()],
                    'Reconstructed Images' : [wandb.Image(i) for i in img_pred[:num_imgs].detach()],
                    'Reconstructed Images' : [wandb.Image(i) for i in img_pred[:num_imgs].detach()],
                    'Src Images' : [wandb.Image(i) for i in src[:num_imgs].detach()],
                    'Masks' : [wandb.Image(i) for i in masks[:num_imgs].detach()],
                    'Predicted Masks' : [wandb.Image(i) for i in (mask_pred[:num_imgs]).detach()],
                    })

            print(epoch, iter, (mask_losses+img_losses)/log_iter, mask_losses/log_iter,img_losses/log_iter)
            mask_losses = 0
            img_losses = 0

  1%|          | 4/648 [00:14<32:04,  2.99s/it]  

0 4 613.0632619261742 0.4833302855491638 612.579931640625


  1%|▏         | 9/648 [00:22<19:36,  1.84s/it]

0 9 645.6492048740387 0.5533552646636963 645.095849609375


  2%|▏         | 14/648 [00:31<18:04,  1.71s/it]

0 14 594.9862143814564 0.5001792252063751 594.48603515625


  3%|▎         | 19/648 [00:39<18:12,  1.74s/it]

0 19 509.4382229685783 0.4368740916252136 509.0013488769531


  4%|▎         | 24/648 [00:47<17:29,  1.68s/it]

0 24 502.4263049423695 0.3585681259632111 502.06773681640624


  4%|▍         | 29/648 [00:56<17:16,  1.67s/it]

0 29 506.44429592490195 0.2614345967769623 506.182861328125


  5%|▌         | 34/648 [01:04<17:05,  1.67s/it]

0 34 478.1281729608774 0.1664297968149185 477.9617431640625


  6%|▌         | 39/648 [01:12<17:00,  1.68s/it]

0 39 491.39682470932604 0.08378759995102883 491.313037109375


  7%|▋         | 44/648 [01:20<16:52,  1.68s/it]

0 44 504.4131069464609 0.02347681950777769 504.38963012695314


  8%|▊         | 49/648 [01:29<16:45,  1.68s/it]

0 49 499.7596044704318 0.004843728244304657 499.7547607421875


  8%|▊         | 54/648 [01:37<16:43,  1.69s/it]

0 54 491.6873776607681 0.00305759240873158 491.6843200683594


  9%|▉         | 59/648 [01:45<16:27,  1.68s/it]

0 59 457.94624592959883 0.0029720038175582884 457.9432739257812


 10%|▉         | 64/648 [01:53<16:19,  1.68s/it]

0 64 484.3196352182422 0.0024294076953083276 484.31720581054685


 11%|█         | 69/648 [02:02<16:13,  1.68s/it]

0 69 468.49187536505053 0.002605345519259572 468.48927001953126


 11%|█▏        | 74/648 [02:10<16:05,  1.68s/it]

0 74 457.7945683060214 0.0027470169588923454 457.7918212890625


 12%|█▏        | 79/648 [02:18<15:57,  1.68s/it]

0 79 466.00155042619446 0.004608287522569299 465.9969421386719


 13%|█▎        | 84/648 [02:26<15:43,  1.67s/it]

0 84 463.33385441736317 0.00430729822255671 463.3295471191406


 14%|█▎        | 89/648 [02:35<15:37,  1.68s/it]

0 89 464.0975162688177 0.004370516864582896 464.0931457519531


 15%|█▍        | 94/648 [02:43<15:28,  1.68s/it]

0 94 446.4652971137315 0.0038225043565034865 446.461474609375


 15%|█▌        | 99/648 [02:51<15:23,  1.68s/it]

0 99 436.689015124226 0.003321764850988984 436.685693359375


 16%|█▌        | 104/648 [02:59<15:13,  1.68s/it]

0 104 437.1658110151999 0.0030180464498698713 437.16279296875


 17%|█▋        | 109/648 [03:08<15:05,  1.68s/it]

0 109 458.99180139070376 0.003343138750642538 458.98845825195315


 18%|█▊        | 114/648 [03:16<14:55,  1.68s/it]

0 114 451.1339710639324 0.003581659635528922 451.13038940429686


 18%|█▊        | 119/648 [03:24<14:52,  1.69s/it]

0 119 447.7798511540983 0.0038257634732872248 447.776025390625


 19%|█▊        | 120/648 [03:32<15:34,  1.77s/it]


KeyboardInterrupt: 