In [1]:
# For our puffer surver we need to browse via a proxy!!
import os
# Set HTTP and HTTPS proxy
os.environ['http_proxy'] = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'

In [2]:
if not os.path.exists('dataset1'):
    !wget -q https://www.dropbox.com/s/0pigmmmynbf9xwq/dataset1.zip
    !unzip -q dataset1.zip
    !rm dataset1.zip
    !pip install -q pytorch_model_summary
import torch
!pip install torch_snippets
from torch_snippets import *
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torchvision.models import vgg16_bn
import cv2 as cv
from tqdm import tqdm



In [3]:
# Create Params dictionary
class Params(object):
    def __init__(self, batch_size, test_batch_size, epochs, lr, seed, cuda, log_interval):
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.epochs = epochs
        self.lr = lr
        self.seed = seed
        self.cuda = 'cuda' if cuda and torch.cuda.is_available() else 'cpu'
        self.log_interval = log_interval

# Configure args
args = Params(8, 2, 5, 1e-3, 1, True, 10)

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
def get_transforms():
  return transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize(
                                 [0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]
                                 )
                             ])

In [5]:
from torch.utils.data import Dataset, DataLoader

class SegmentationData(Dataset):
    def __init__(self, split):
        self.items = stems(f'dataset1/images_prepped_{split}')
        self.split = split                                     # Store the split information

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        # Retrieve an image and its corresponding mask based on the index 'idx'
        image = read(f'dataset1/images_prepped_{self.split}/{self.items[idx]}.png', 1)
        image = cv.resize(image, (224,224))

        mask = read(f'dataset1/annotations_prepped_{self.split}/{self.items[idx]}.png', 0)
        mask = cv.resize(mask, (224,224))

        return image, mask

    def choose(self): return self[randint(len(self))]  # Randomly select and return one image and mask pair from the dataset

    def collate_fn(self, batch):
        # Custom collate function to combine a batch of images and masks
        # Unzip the batch into images and masks
        ims, masks = list(zip(*batch))
        
        # Transform the images: Normalize and convert to tensor, then stack them into a single tensor

        ims = torch.cat([get_transforms()(im.copy()/255.)[None] for im in ims]).float().to(args.cuda)

        # Convert masks to tensors, stack them into a single tensor and cast to long type

        ce_masks = torch.cat([torch.Tensor(mask[None]) for mask in masks]).long().to(args.cuda)

        return ims, ce_masks

In [6]:
def get_dataloaders():
    trn_ds = SegmentationData('train')
    val_ds = SegmentationData('test')

    trn_dl = DataLoader(trn_ds, batch_size=args.batch_size, shuffle=True, collate_fn=trn_ds.collate_fn)
    val_dl = DataLoader(val_ds, batch_size=args.test_batch_size, shuffle=True, collate_fn=val_ds.collate_fn)

    return trn_dl, val_dl

In [7]:
trn_dl, val_dl = get_dataloaders()

# U Net Architecture

- U-Net is a convolutional neural network architecture primarily used for image segmentation tasks. It consists of a contracting path (encoder) that captures context and a symmetric expanding path (decoder) that enables precise localization.

In [8]:
def conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

In [9]:
def up_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        nn.ReLU(inplace=True)
    )

In [10]:
import torch.nn as nn
class UNet(nn.Module):
    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()

        self.encoder = vgg16_bn(pretrained=pretrained).features
        self.block1 = nn.Sequential(*self.encoder[:6])
        self.block2 = nn.Sequential(*self.encoder[6:13])
        self.block3 = nn.Sequential(*self.encoder[13:20])
        self.block4 = nn.Sequential(*self.encoder[20:27])
        self.block5 = nn.Sequential(*self.encoder[27:34])

        self.bottleneck = nn.Sequential(*self.encoder[34:])
        self.conv_bottleneck = conv(512, 1024)

        self.up_conv6 = up_conv(1024, 512)
        self.conv6 = conv(512 + 512, 512)
        self.up_conv7 = up_conv(512, 256)
        self.conv7 = conv(256 + 512, 256)
        self.up_conv8 = up_conv(256, 128)
        self.conv8 = conv(128 + 256, 128)
        self.up_conv9 = up_conv(128, 64)
        self.conv9 = conv(64 + 128, 64)
        self.up_conv10 = up_conv(64, 32)
        self.conv10 = conv(32 + 64, 32)
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        # Contractive Path
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)

        bottleneck = self.bottleneck(block5)
        x = self.conv_bottleneck(bottleneck)
        # Expansive Path
        x = self.up_conv6(x)
        x = torch.cat([x, block5], dim=1)
        x = self.conv6(x)

        x = self.up_conv7(x)
        x = torch.cat([x, block4], dim=1)
        x = self.conv7(x)

        x = self.up_conv8(x)
        x = torch.cat([x, block3], dim=1)
        x = self.conv8(x)

        x = self.up_conv9(x)
        x = torch.cat([x, block2], dim=1)
        x = self.conv9(x)

        x = self.up_conv10(x)
        x = torch.cat([x, block1], dim=1)
        x = self.conv10(x)

        x = self.conv11(x)

        return x

## Steps

### Initializes the U-Net model.
   
Takes two parameters:

    - pretrained: A boolean indicating whether to use a pretrained VGG16 model.
    - out_channels: The number of output channels for the final segmentation map (e.g., 12 classes for segmentation).
    
### Encoder (Contractive Path)

    - The encoder part uses the features from a VGG16 model with batch normalization (vgg16_bn).
    - The encoder is divided into five blocks, each consisting of several convolutional layers that progressively reduce the spatial dimensions while increasing the number of feature channels.
    
## Bottleneck

    - The bottleneck section takes the deepest layers of the encoder.
    - A convolution layer `conv_bottleneck` is applied to increase the number of feature channels from 512 to 1024, allowing the network to learn more complex features.
    

## Decoder (Expansive Path)

The decoder consists of up-convolution (or transposed convolution) layers followed by concatenation with corresponding encoder features to retain spatial information:

    - Each up_conv layer increases the spatial dimensions (upsampling).
    - The output of each up-convolution is concatenated with the corresponding feature map from the encoder (skip connections).
    - This helps the model learn both high-level features from deeper layers and low-level features from shallower layers.
    - Finally, conv11 reduces the number of channels to out_channels (e.g., for multi-class segmentation).

In [11]:
vgg16_bn

<function torchvision.models.vgg.vgg16_bn(*, weights: Optional[torchvision.models.vgg.VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> torchvision.models.vgg.VGG>

In [12]:
ce = nn.CrossEntropyLoss()   # Applies softmax to output logits --> converts into class probabilities --> calculates neg. log likelihood loss between pred and true class label!!

def UnetLoss(preds, targets):
    ce_loss = ce(preds, targets)
    acc = (torch.max(preds, 1)[1] == targets).float().mean()
    #  (torch.max(preds, 1)[1] returns the indices of the maximum values along the class dimension (i.e., the predicted class for each pixel). 
    #  The 1 indicates that we're looking along the columns (the class dimension).
    #  if preds class == targets return 1 --> change to float --> take mean to keep score between 0 and 1
    return ce_loss, acc

In [13]:
class TrainEngine():
    def train_batch(model, data, optimizer, criterion):
        model.train()

        ims, ce_masks = data
        _masks = model(ims)
        optimizer.zero_grad()

        loss, acc = criterion(_masks, ce_masks)
        loss.backward()
        optimizer.step()

        return loss.item(), acc.item()

    @torch.no_grad()
    def validate_batch(model, data, criterion):
        model.eval()

        ims, masks = data
        _masks = model(ims)

        loss, acc = criterion(_masks, masks)

        return loss.item(), acc.item()

In [14]:
from torch import optim
def make_model():
    model = UNet().to(args.cuda)
    criterion = UnetLoss
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    return model, criterion, optimizer

In [15]:
model, criterion, optimizer = make_model()
# Total num. of parametes
num_params = sum(p.numel() for p in model.parameters())
# Total num. of "trainable" parameters
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total num. of parametes: {num_params}')
print(f'Total num. of Trainable parametes: {num_trainable_params}')



Total num. of parametes: 29311308
Total num. of Trainable parametes: 29311308


In [18]:
def run_model():
    for epoch in range(args.epochs):
        print("####################")
        print(f"       Epoch: {epoch}   ")
        print("####################")

        for batch_idx, data in tqdm(enumerate(trn_dl), total=len(trn_dl), leave=False):
            train_loss, train_acc = TrainEngine.train_batch(model, data, optimizer, criterion)
            if batch_idx % args.log_interval == 0:
                # Print training information inline instead of calling a function
                step = epoch * len(trn_dl) + batch_idx
                print(f'Epoch [{epoch+1}/{args.epochs}], Step [{batch_idx}/{len(trn_dl)}], '
                      f'Train Loss: {train_loss:.6f}, Accuracy: {train_acc:.6f}')

        avg_val_acc = avg_val_loss = 0.0
        for batch_idx, data in tqdm(enumerate(val_dl), total=len(val_dl)):
            val_loss, val_acc = TrainEngine.validate_batch(model, data, criterion)

            avg_val_loss += val_loss
            avg_val_acc += val_acc

        step = (epoch + 1) * len(trn_dl)
        avg_val_loss /= len(val_dl)
        avg_val_acc /= len(val_dl)
        print(f'Val: Average loss: {avg_val_loss:.4f}, Accuracy: {avg_val_acc:.4f}')
        print()

    # Save the model and optimizer states after training is complete
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, 'unet.pt')


In [19]:
# Train the model
run_model()

####################
       Epoch: 0   
####################


  2%|▏         | 1/46 [00:09<06:47,  9.07s/it]

Epoch [1/5], Step [0/46], Train Loss: 2.339851, Accuracy: 0.232003


 24%|██▍       | 11/46 [01:32<04:51,  8.34s/it]

Epoch [1/5], Step [10/46], Train Loss: 1.637340, Accuracy: 0.680151


 46%|████▌     | 21/46 [02:53<03:22,  8.12s/it]

Epoch [1/5], Step [20/46], Train Loss: 1.454079, Accuracy: 0.719433


 67%|██████▋   | 31/46 [04:15<02:01,  8.13s/it]

Epoch [1/5], Step [30/46], Train Loss: 1.238068, Accuracy: 0.769078


 89%|████████▉ | 41/46 [05:36<00:41,  8.20s/it]

Epoch [1/5], Step [40/46], Train Loss: 1.158401, Accuracy: 0.747935


100%|██████████| 51/51 [00:53<00:00,  1.04s/it]


Val: Average loss: 1.1736, Accuracy: 0.7735

####################
       Epoch: 1   
####################


  2%|▏         | 1/46 [00:08<06:00,  8.01s/it]

Epoch [2/5], Step [0/46], Train Loss: 0.993044, Accuracy: 0.824734


 24%|██▍       | 11/46 [01:29<04:50,  8.30s/it]

Epoch [2/5], Step [10/46], Train Loss: 0.983688, Accuracy: 0.782227


 46%|████▌     | 21/46 [02:49<03:23,  8.13s/it]

Epoch [2/5], Step [20/46], Train Loss: 0.811047, Accuracy: 0.844283


 67%|██████▋   | 31/46 [04:11<02:04,  8.28s/it]

Epoch [2/5], Step [30/46], Train Loss: 0.749451, Accuracy: 0.848154


 89%|████████▉ | 41/46 [05:31<00:40,  8.14s/it]

Epoch [2/5], Step [40/46], Train Loss: 0.802150, Accuracy: 0.810504


100%|██████████| 51/51 [00:53<00:00,  1.06s/it]


Val: Average loss: 0.7339, Accuracy: 0.8509

####################
       Epoch: 2   
####################


  2%|▏         | 1/46 [00:08<06:06,  8.14s/it]

Epoch [3/5], Step [0/46], Train Loss: 0.664351, Accuracy: 0.866664


 24%|██▍       | 11/46 [01:27<04:40,  8.03s/it]

Epoch [3/5], Step [10/46], Train Loss: 0.643855, Accuracy: 0.852806


 46%|████▌     | 21/46 [02:49<03:31,  8.44s/it]

Epoch [3/5], Step [20/46], Train Loss: 0.643532, Accuracy: 0.859283


 67%|██████▋   | 31/46 [04:17<02:01,  8.12s/it]

Epoch [3/5], Step [30/46], Train Loss: 0.535793, Accuracy: 0.882314


 89%|████████▉ | 41/46 [05:41<00:40,  8.03s/it]

Epoch [3/5], Step [40/46], Train Loss: 0.489861, Accuracy: 0.896796


100%|██████████| 51/51 [00:47<00:00,  1.08it/s]


Val: Average loss: 0.7112, Accuracy: 0.8244

####################
       Epoch: 3   
####################


  2%|▏         | 1/46 [00:08<06:06,  8.15s/it]

Epoch [4/5], Step [0/46], Train Loss: 0.587825, Accuracy: 0.854114


 24%|██▍       | 11/46 [01:34<04:49,  8.28s/it]

Epoch [4/5], Step [10/46], Train Loss: 0.627285, Accuracy: 0.842118


 46%|████▌     | 21/46 [03:08<04:11, 10.08s/it]

Epoch [4/5], Step [20/46], Train Loss: 0.637420, Accuracy: 0.846715


 67%|██████▋   | 31/46 [04:35<02:05,  8.34s/it]

Epoch [4/5], Step [30/46], Train Loss: 0.532278, Accuracy: 0.867467


 89%|████████▉ | 41/46 [05:54<00:39,  8.00s/it]

Epoch [4/5], Step [40/46], Train Loss: 0.618013, Accuracy: 0.836618


100%|██████████| 51/51 [00:52<00:00,  1.03s/it]


Val: Average loss: 0.6775, Accuracy: 0.8267

####################
       Epoch: 4   
####################


  2%|▏         | 1/46 [00:08<06:10,  8.24s/it]

Epoch [5/5], Step [0/46], Train Loss: 0.463302, Accuracy: 0.889277


 24%|██▍       | 11/46 [01:30<04:47,  8.21s/it]

Epoch [5/5], Step [10/46], Train Loss: 0.618043, Accuracy: 0.836665


 46%|████▌     | 21/46 [02:56<03:36,  8.67s/it]

Epoch [5/5], Step [20/46], Train Loss: 0.488171, Accuracy: 0.868954


 67%|██████▋   | 31/46 [04:15<01:59,  7.98s/it]

Epoch [5/5], Step [30/46], Train Loss: 0.461211, Accuracy: 0.882486


 89%|████████▉ | 41/46 [05:49<00:49,  9.89s/it]

Epoch [5/5], Step [40/46], Train Loss: 0.449941, Accuracy: 0.883234


100%|██████████| 51/51 [01:10<00:00,  1.38s/it]


Val: Average loss: 0.5263, Accuracy: 0.8690



In [None]:
# Visualize and save results
for bx, data in tqdm(enumerate(val_dl), total=len(val_dl)):
    im, mask = data
    _mask = model(im)
    _, _mask = torch.max(_mask, dim=1)

    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(im[0].permute(1, 2, 0).detach().cpu()[:, :, 0])
    plt.savefig(f'original_image_{bx}.jpg')
    plt.close()

    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(mask.permute(1, 2, 0).detach().cpu()[:, :, 0])
    plt.savefig(f'original_mask_{bx}.jpg')
    plt.close()

    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(_mask.permute(1, 2, 0).detach().cpu()[:, :, 0])
    plt.savefig(f'predicted_mask_{bx}.jpg')
    plt.close()