In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision







class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3)
    
    def forward(self, x):
        return self.relu(self.conv2(self.relu(self.conv1(x))))

In [10]:
enc_block = Block(1, 64)
x         = torch.randn(1, 1, 572, 572)
enc_block(x).shape

torch.Size([1, 64, 568, 568])

In [11]:
class Encoder(nn.Module):
    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
    
    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs

In [12]:
encoder = Encoder()
# input image
x    = torch.randn(1, 3, 572, 572)
ftrs = encoder(x)
for ftr in ftrs: print(ftr.shape)

torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])


In [13]:
class Decoder(nn.Module):
    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x
    
    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs   = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs

In [14]:
decoder = Decoder()
x = torch.randn(1, 1024, 28, 28)
decoder(x, ftrs[::-1][1:]).shape

torch.Size([1, 64, 388, 388])

In [17]:
import torch
import torch.nn as nn
from load_data import create_data_loaders
import time as time 
import torchvision
import torch.nn.functional as F
import numpy as np 
from collections import OrderedDict

#https://amaarora.github.io/2020/09/13/unet.html
#https://www.pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/

train_path = "datasets-oxpet/train"
validation_path = "datasets-oxpet/val"
test_path = "datasets-oxpet/test"

train_loader, validation_loader, test_loader = create_data_loaders(train_path, validation_path, test_path, batch_size=4)


class Block(nn.Module):
    def __init__(self, ch_input, ch_output):
        super().__init__()
        self.block = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(ch_input, ch_output, 3)),
            ('relu1', nn.ReLU(inplace=True)),
            ('conv2',nn.Conv2d(ch_output, ch_output, 3)),
            ('relu2',nn.ReLU(inplace=True))
        ]))
    
    def forward(self,input):
        output = self.block(input)
        return output


class Encoder(nn.Module):
    def __init__(self, channels=(3,64,128,256,512,1024)):
    #  def __init__(self, chs=(3,64,17,34,68,136)):
        super().__init__()

        self.encoder_blocks = nn.ModuleList()

        block_list = []
        for ch_idx in range(len(channels)-1):
             block = Block(channels[ch_idx],channels[ch_idx+1])
             block_list.append(block)

        self.encoder_blocks.extend(block_list)

        
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

    def forward(self,input):

        filters = []

        for blk in self.encoder_blocks:
            input = blk(input)
            filters.append(input)
            input = self.maxpool1(input)

        return filters



class Decoder(nn.Module):
    def __init__(self, channels=(1024, 512, 256, 128, 64)):
        super().__init__()

        
        self.upsampling = nn.ModuleList()
        self.channels = channels

        up_list = []
        for idx in range(len(channels)-1):
            block = nn.ConvTranspose2d(channels[idx],channels[idx+1],kernel_size=2,stride=2)
            up_list.append(block)
        
        self.upsampling.extend(up_list)

        self.decoder_blocks = nn.ModuleList()

        dec_list = []
        for dec_idx in range(len(channels)-1):
            dec_block = Block(channels[dec_idx], channels[dec_idx+1])
            dec_list.append(dec_block)
        
        self.decoder_blocks.extend(dec_list)


        
    def forward(self, input, enc_channels):
        for idx in range(len(self.channels)-1):
            input = self.upsampling[idx](input)
            enc_chan = self.crop(enc_channels[idx], input)
            input = torch.cat([input, enc_chan], dim=1)
            input = self.decoder_blocks[idx](input)
        return input
    
    def crop(self, enc_chans, input):
        height = input.shape[2]
        width = input.shape[3]
        enc_chans   = torchvision.transforms.CenterCrop([height, width])(enc_chans)
        return enc_chans



class UNet(nn.Module):
    # def __init__(self, enc_chans=(3,64,128,256,512,1024), dec_chans=(1024, 512, 256, 128, 64), num_class=2, retain_dim=True, output_size=(256,256)): 
    # def __init__(self, enc_chans=(3,64,128), dec_chans=(128, 64), num_class=2, retain_dim=True, output_size=(256,256)):
    def __init__(self, enc_chans=(3,56,112), dec_chans=(112, 56), num_class=2, retain_dim=True, output_size=(256,256)): 
    #  

    
    # def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(68,68)):

        super().__init__()
        # self.unet = nn.Sequential(OrderedDict([
        #     ('Encoder', Encoder(enc_chans)),
        #     ('Decoder', Decoder(dec_chans)),
        #     ('Convolution1', nn.Conv2d(dec_chans[-1], num_class, 1))
        # ]))

        # self.unet.add_module('retain_dimension',retain_dim)
        # self.unet.add_module('output_size', output_size)

        self.encoder     = Encoder(enc_chans)
        self.decoder     = Decoder(dec_chans)
        self.head        = nn.Conv2d(dec_chans[-1], num_class, 1)
        self.retain_dim  = retain_dim
        self.out_sz = output_size

    def forward(self, x):
        enc_chans = self.encoder(x)
        out      = self.decoder(enc_chans[::-1][0], enc_chans[::-1][1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out
    
    # def forward(self, input):
    #     input = self.unet(input)
    #     return unet 

        # def forward(self, x):
        #     enc_ftrs = self.encoder(x)
        #     out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        #     out      = self.head(out)
        #     if self.retain_dim:
        #         out = F.interpolate(out, self.out_sz)
        #     return out

# initialize our UNet model
unet = UNet()
# initialize loss function and optimizer
# loss = nn.BCEWithLogitsLoss()
criteria  = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=0.001)


time_train1 = time.time()
for epoch in range(1):

    time_epoch = time.time()
    train_loss = []
    train_accuracy = []
    for i, batch_data in enumerate(train_loader, 1):
        inputs, labels = batch_data

        mask = torch.squeeze(labels['mask'])
        mask = mask.to(torch.long)
        binary = torch.squeeze(labels['classification'])
        binary = binary.to(torch.long)
        bbox = labels['bbox']

        print(binary.size(), "binary shape")

        optimizer.zero_grad()
        outputs= unet(inputs)
        classes = outputs[0]
        # classes = outputs.index_select(0)

        # print(outputs.size(), "outputs shape")
        # print(classes.size(), "class shape") 
        

        # todo: update the loss_criterion in the loss computation.
        print(mask.size(),'mask size')
        loss_seg = criteria(outputs, mask)
        # loss_class = cri_class(classes, classes)
        # loss = loss_seg + loss_class
        loss = loss_seg 
        print(loss)

        # todo: make the weight of losses a hyper-parameter

        # pred_ax=np.argmax(classes.detach().numpy(),axis=1)
        # train_accuracy.append(np.sum((classes.detach().numpy()==pred_ax).astype(int))/len(binary))
        train_loss.append(loss.item())

        loss.backward()
        optimizer.step()

        # todo: include validation loader in training loop (metrics)

    time_epoch_vl = time.time()
    time_train2 = time.time()

    print('----------------------------------------------------------------------------------')
    print(f"Epoch: {epoch + 1} Time taken : {round(time_epoch_vl - time_epoch, 3)} seconds")
    print("-----------------------Training Metrics-------------------------------------------")
    print("Loss: ", round(np.mean(train_loss), 3))
    print("train time: {}".format(time_train2-time_train1))
    

torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.6892, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.6767, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.6905, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.6849, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.7017, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.6893, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.6631, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.6435, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) binary shape
torch.Size([4, 256, 256]) mask size
tensor(0.6643, grad_fn=<NllLoss2DBackward>)
torch.Size([4]) bin

KeyboardInterrupt: 

In [None]:
# unet with classification

class UNet(nn.Module):
    def _init_(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=2, retain_dim=True, out_sz=(256,256)): 

        super()._init_()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim
        self.out_sz = out_sz

        self.linear_c_0=nn.Linear(512*8*8,64) ##
        self.linear_c_1=nn.Linear(64,2) ##

        self.linear_b_0=nn.Linear(512*8*8,64) ##
        self.linear_b_1=nn.Linear(64,4) ##
        self.flat=nn.Flatten() ##

    def forward(self, x):
        enc_ftrs = self.encoder(x)

        #WE NEED TO FLATTEN THE OUPUT OF THE ENCODER
        #print(len(enc_ftrs[1][0])) #5, 4, 64, 252, 252
        flat_list = [item for sublist in enc_ftrs for item in sublist]
        #print(len(flat_list)) 
        #flat_list = [item for sublist in flat_list for item in sublist]
        #print(len(flat_list)) 
        #flat_list = [item for sublist in flat_list for item in sublist]
        #print(len(flat_list)) 
        #flat_list = [item for sublist in flat_list for item in sublist]
        #print(len(flat_list)) 
        #flat_list = [item for sublist in flat_list for item]
        
        enc_ftrs = torch.FloatTensor(flat_list)
        print("tensor")
        #flat = flat=self.flat(enc_ftrs) ##
        c_0=F.relu(self.linear_c_0(enc_ftrs)) ##
        c_= (self.linear_c_1(c_0)) ##
        b_0=F.relu(self.linear_b_0(enc_ftrs)) ##
        b_=F.relu(self.linear_b_1(b_0)) ##
        print("done")

        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out, c_, b_

In [11]:
from torch import nn
class UNET(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = self.contract_block(in_channels, 32, 7, 3)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)

        self.upconv3 = self.expand_block(128, 64, 3, 1)
        self.upconv2 = self.expand_block(64*2, 32, 3, 1)
        self.upconv1 = self.expand_block(32*2, out_channels, 3, 1)

    def __call__(self, x):

        # downsampling part
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        upconv3 = self.upconv3(conv3)

        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):

        contract = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                                 )

        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):

        expand = nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1) 
                            )
        return expand




In [None]:
train_loader, validation_loader, test_loader = create_data_loaders(train_path, validation_path, test_path, batch_size=4)


In [12]:

def acc_metric(predb, yb):
    return (predb.argmax(dim=1) == yb.cuda()).float().mean()


optimizer = torch.optim.Adam(unet.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
acc_fn = acc_metric()
def train(model, train_loader, validation_loader, loss_fn, optimizer, acc_fn, epochs=1):
    start = time.time()
    model.cuda()

    train_loss, valid_loss = [], []

    best_acc = 0.0

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train(True)  # Set trainind mode = true
                dataloader = train_loader
            else:
                model.train(False)  # Set model to evaluate mode
                dataloader = validation_loader

            running_loss = 0.0
            running_acc = 0.0

            step = 0

            # iterate over data
            for x, y in dataloader:
                x = x.cuda()
                y = y.cuda()
                step += 1

                # forward pass
                if phase == 'train':
                    # zero the gradients
                    optimizer.zero_grad()
                    outputs = model(x)
                    loss = loss_fn(outputs, y)

                    # the backward pass frees the graph memory, so there is no 
                    # need for torch.no_grad in this training pass
                    loss.backward()
                    optimizer.step()
                    # scheduler.step()

                else:
                    with torch.no_grad():
                        outputs = model(x)
                        loss = loss_fn(outputs, y.long())

                # stats - whatever is the phase
                acc = acc_fn(outputs, y)

                running_acc  += acc*dataloader.batch_size
                running_loss += loss*dataloader.batch_size 

                if step % 10 == 0:
                    # clear_output(wait=True)
                    print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(step, loss, acc, torch.cuda.memory_allocated()/1024/1024))
                    # print(torch.cuda.memory_summary())

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_acc / len(dataloader.dataset)

            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))

            train_loss.append(epoch_loss) if phase=='train' else valid_loss.append(epoch_loss)

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))    
    
    return train_loss, valid_loss    

def acc_metric(predb, yb):
    return (predb.argmax(dim=1) == yb.cuda()).float().mean()

In [None]:


train(UNET, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs=1)