In [34]:
from contextlib import nullcontext
import os
import torch
import numpy as np
from datetime import datetime
import time
from prettytable import PrettyTable

class Timer:
    def __init__(self, start_msg = "", end_msg = ""):
    
        self.start_msg = start_msg
        self.end_msg = end_msg
        
    def __enter__(self):
        if self.start_msg != "":
            print(self.start_msg)
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed_time = time.time() - self.start_time
        print(self.end_msg, f"{elapsed_time:.3f} sec")


def count_parameters(model, print_table = False):
    
    total_params = 0
    
    if(print_table):
        table = PrettyTable(["Modules", "Parameters", "dtype", "Required Grad", "Device"]) 
    
    for name, parameter in model.named_parameters():
        params = parameter.numel()
        
        if(print_table):
            table.add_row([name, parameter.shape, parameter.dtype, parameter.requires_grad, parameter.device ])
            
        total_params += params
        
    if(print_table):
        print(table)
        
    if total_params/1e9 > 1:
        print(f"Total Trainable Params: {total_params/1e9} B")
    else:
        print(f"Total Trainable Params: {total_params/1e6} M")
        
    return total_params

In [40]:
import torch
import torch.nn as nn


class ResnetBlock(nn.Module):
    def __init__(self, out_c,padding=1):
        super().__init__()

        self.conv1 = nn.Conv2d(out_c,out_c,3,1,padding)
        self.conv2 = nn.Conv2d(out_c,out_c,3,1,padding)
        self.conv3 = nn.Conv2d(out_c,out_c,3,1,padding)
        self.norm = nn.BatchNorm2d(out_c)
        self.silu = nn.SiLU()

    def forward(self, x):

        x = self.conv3(self.conv2(self.conv1(x)))
        x = self.norm(x) + x
        x = self.silu(x)

        return x



class Encoder(nn.Module):
    def __init__(self, in_c, out_c,  padding=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_c,out_c,3,1,padding)
        self.conv2 = nn.Conv2d(out_c,out_c,3,1,padding)
        self.conv3 = nn.Conv2d(out_c,out_c,3,1,padding)

        self.resBlocks = nn.ModuleList([ResnetBlock(out_c, padding) for i in range(3)])

        self.silu = nn.SiLU()
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):

        x = self.conv3(self.conv2(self.conv1(x)))
        x = self.silu(x)

        for block in self.resBlocks:
            x = block(x)

        return self.pool(x), x


class Decoder(nn.Module):
    def __init__(self, in_c, out_c, img_sizes, padding=1):
        super().__init__()

        self.upsample = nn.Upsample(size=img_sizes, mode="bilinear")


        self.conv1 = nn.Conv2d(in_c,in_c,3,1, padding)
        self.conv2 = nn.Conv2d(in_c,in_c,3,1, padding)
        self.conv3 = nn.Conv2d(in_c,in_c//2,3,1, padding)


        self.conv4 = nn.Conv2d(in_c,out_c,3,1, padding)
        self.conv5 = nn.Conv2d(out_c,out_c,3,1, padding)
        self.conv6 = nn.Conv2d(out_c,out_c,3,1, padding)


        self.resBlocks = nn.ModuleList([ResnetBlock(out_c, padding) for i in range(3)])

        self.silu = nn.SiLU()

    def forward(self, x, skip):

        x = self.upsample(x)
        x = self.conv3(self.conv2(self.conv1(x)))
        x = torch.cat((x, skip), dim = 1)
        x = self.conv6(self.conv5(self.conv4(x)))
        x = self.silu(x)

        for block in self.resBlocks:

            x = block(x)

        return x


class Unet(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoders = nn.ModuleList([Encoder(*i) for i in [(3,32), (32,64), (64, 128), (128, 256), (256, 512)]])

        self.decoders = nn.ModuleList([Decoder(*i) for i in [(512,256, (21,37)), (256,128,(42,74)), (128, 64,(84,149)), (64, 32,(168,298))]])

        self.imgDecoder = nn.Conv2d(32,3,1,1)
        self.tanh = nn.Tanh()

        self.maskDecoder = nn.Conv2d(32,1,1,1)
        # self.sigmoid = nn.Sigmoid()


    def forward(self, x):

        skips = []
        for enc in self.encoders:
            x, skip = enc(x)
            print(x.shape, skip.shape)
            skips.append(skip)

        x = skips[-1]

        for idx, dec in enumerate(self.decoders):
            x = dec(x, skips[3-idx])

        image = self.tanh(self.imgDecoder(x))
        mask = self.maskDecoder(x)
        return image, mask


device = "cuda" if torch.cuda.is_available() else "cpu"
unet = Unet().to(device)
img = torch.ones(2,3,168,298).to(device)
out = unet(img)
print(out[0].shape, out[1].shape)


torch.Size([2, 32, 84, 149]) torch.Size([2, 32, 168, 298])
torch.Size([2, 64, 42, 74]) torch.Size([2, 64, 84, 149])
torch.Size([2, 128, 21, 37]) torch.Size([2, 128, 42, 74])
torch.Size([2, 256, 10, 18]) torch.Size([2, 256, 21, 37])
torch.Size([2, 512, 5, 9]) torch.Size([2, 512, 10, 18])
torch.Size([2, 3, 168, 298]) torch.Size([2, 1, 168, 298])


In [41]:
nparams = count_parameters(unet, print_table=False)

Total Trainable Params: 54.183108 M


In [46]:
lr = 16-4
optimizer = torch.optim.AdamW(unet.parameters(), lr)

bce = nn.BCEWithLogitsLoss

In [None]:
config={"epochs": epochs, "batch_size": bs,"lr": lr}

wandb.init(project='linum', entity='basujindal123', config=config)

In [None]:
log_iter = 200
log = True

img_losses = 0
mask_losses = 0
iter = 0


for i in (range(epochs)):
    for data in tqdm(train_loader):
        unet.train()

        optimizer.zero_grad()
        imgs = data.to(device)

        iter+=1
        img_pred, mask_pred = unet(imgs)

        img_loss = torch.mean(torch.abs(mask*(img_pred-imgs)))
        mask_loss = bce(mask, mask_pred)
        loss = 2*mask_loss + img_loss
        loss.backward()
        optimizer.step()

        img_losses+=img_loss.item()
        mask_losses+=mask_loss.item()



        if (iter+1)%log_iter == 0:

            # unet.eval()
            # with torch.no_grad():
            #     fixed_fake_imgs = G(fixed_noise[:16]).detach()

            if wandb_log:
                wandb.log({
                    'loss': (mask_losses+img_losses)/log_iter,
                    'mask_loss': mask_loss/log_iter,
                    'img_loss': img_loss/log_iter,
                    # 'Corrupted Images': [wandb.Image(i) for i in fixed_fake_imgs],
                    # 'Reconstructed Images' : [wandb.Image(i) for i in real_imgs[:16].detach()]
                    })

            print((mask_losses+img_losses)/log_iter, mask_loss/log_iter,img_loss/log_iter)
            mask_losses = 0
            img_losses = 0