In [76]:
%pip install torch torchvision opencv-python tqdm scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [77]:
import torch
from torch import nn

In [78]:
class UNetContractingBlock(nn.Module):
    def __init__(self, input_channels: int, output_channels: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # two back to back convolutional layers
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=output_channels,
                  kernel_size=(3, 3))
        self.conv2 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                               kernel_size=(3, 3))
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)

        return x


In [79]:
class UNetExpandingBlock(nn.Module):

    def __init__(self, input_channels: int, output_channels: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.upsample = nn.ConvTranspose2d(in_channels=input_channels, out_channels=output_channels,
                                        kernel_size=(2, 2), stride=2)
        self.conv1 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                               kernel_size=(3, 3))
        self.conv2 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                               kernel_size=(3, 3))
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        x = self.upsample(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)

        return x

In [80]:
class UNet(nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

        depth = 4

        # contracting path
        self.contracting_blocks = []
        start_out_channels = 64
        for i in range(depth):
            if i != 0:
                next_out_channels = start_out_channels * 2
                self.contracting_blocks.append(
                    UNetContractingBlock(input_channels=start_out_channels, output_channels=next_out_channels)
                )
                start_out_channels = next_out_channels
            else:
                self.contracting_blocks.append(
                    UNetContractingBlock(input_channels=3, output_channels=start_out_channels)
                )

        # intermediate conv block, no maxpool
        self.intermediate_conv = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(3, 3)),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=(3, 3)),
        )

        # expanding path
        self.expanding_blocks = []
        start_in_channels = 1024
        for _ in range(depth):
            next_in_channels = start_in_channels // 2
            self.expanding_blocks.append(
                UNetExpandingBlock(input_channels=start_in_channels, output_channels=next_in_channels)
            )
            start_in_channels = next_in_channels

        # last convolution
        self.conv_last = nn.Conv2d(in_channels=64, out_channels=2,
                                   kernel_size=(1, 1,))

    def forward(self, x) -> torch.Tensor:
        for block in self.contracting_blocks:
            x = block(x)

        x = self.intermediate_conv(x)

        for block in self.expanding_blocks:
            x = block(x)

        x = self.conv_last(x)

        return x

In [81]:
model = UNet()

In [82]:
t = torch.rand(1, 3, 572, 572)  # batch size, channels, height, width
out = model(t)

type(out), out.shape, out

[UNetContractingBlock(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padd

(torch.Tensor,
 torch.Size([1, 2, 388, 388]),
 tensor([[[[ 0.0646,  0.0650,  0.0647,  ...,  0.0650,  0.0647,  0.0650],
           [ 0.0644,  0.0656,  0.0644,  ...,  0.0656,  0.0644,  0.0656],
           [ 0.0646,  0.0650,  0.0646,  ...,  0.0650,  0.0646,  0.0649],
           ...,
           [ 0.0644,  0.0656,  0.0644,  ...,  0.0656,  0.0644,  0.0656],
           [ 0.0646,  0.0650,  0.0646,  ...,  0.0650,  0.0646,  0.0649],
           [ 0.0644,  0.0655,  0.0644,  ...,  0.0655,  0.0644,  0.0656]],
 
          [[-0.1277, -0.1275, -0.1278,  ..., -0.1275, -0.1278, -0.1275],
           [-0.1276, -0.1277, -0.1277,  ..., -0.1277, -0.1277, -0.1277],
           [-0.1278, -0.1276, -0.1277,  ..., -0.1276, -0.1277, -0.1275],
           ...,
           [-0.1276, -0.1277, -0.1277,  ..., -0.1277, -0.1277, -0.1277],
           [-0.1278, -0.1276, -0.1277,  ..., -0.1276, -0.1277, -0.1275],
           [-0.1276, -0.1277, -0.1277,  ..., -0.1277, -0.1277, -0.1277]]]],
        grad_fn=<ConvolutionBackward0>))

## Training

In [83]:
import os

In [84]:
DATASET_BASE = "../../dataset/cellpose/"
train_path = os.path.join(DATASET_BASE, "train")
test_path = os.path.join(DATASET_BASE, "test")

In [85]:
from torchvision import transforms
from torchvision.io import read_image
import cv2

In [86]:
class CellPoseDataset(torch.utils.data.Dataset):

    def __init__(self, image_path, transforms) -> None:
        self.image_path = image_path
        self.transforms = transforms

        self.images = []
        self.masks = []
        self.count = 0
        for img in sorted(os.listdir(self.image_path)):
            self.count += 1
            if img.endswith("img.png"):
                self.images.append(img)
            else:
                self.masks.append(img)

    def __len__(self) -> int:
        return self.count

    def __getitem__(self, index):

        # return tuple of image and mask: (image, mask)
        image = cv2.imread(os.path.join(train_path, self.images[index]))
        mask = cv2.imread(os.path.join(train_path, self.masks[index]))

        transformed_image = self.transforms(image)
        transformed_mask = self.transforms(mask)

        return transformed_image, transformed_mask



In [87]:
IMG_HEIGHT = 388
IMG_WIDTH = 388

In [88]:
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor()
])

In [89]:
from torch.utils.data import random_split
from torch.utils.data import DataLoader

In [90]:
batch_size = 8

In [91]:
dataset = CellPoseDataset(train_path, transform)

test_ratio = 0.15
dataset_size = len(dataset)
test_size = int(dataset_size*0.15)

train_size = dataset_size - test_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size,)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size,)

In [92]:
dataset[0][0].shape, dataset[0]

(torch.Size([3, 388, 388]),
 (tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
  
          [[0.0000, 0.0000, 0.0000,  ..., 0.0863, 0.1608, 0.1725],
           [0.0118, 0.0118, 0.0118,  ..., 0.0510, 0.1373, 0.1725],
           [0.0118, 0.0118, 0.0118,  ..., 0.0353, 0.1098, 0.1686],
           ...,
           [0.0157, 0.0157, 0.0157,  ..., 0.0196, 0.0235, 0.0275],
           [0.0157, 0.0157, 0.0157,  ..., 0.0196, 0.0235, 0.0275],
           [0.0157, 0.0157, 0.0157,  ..., 0.0196, 0.0235, 0.0235]],
  
          [[0.0078, 0.0039, 0.0039,  ..., 0.0078, 0.0157, 0.0118],
           [0.0235, 0.0275, 0.0235,  ..., 0.0196, 0.0235, 0.0

In [93]:
len(dataset)

1080

In [94]:
model

UNet(
  (intermediate_conv): Sequential(
    (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
  )
  (conv_last): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)

In [95]:
model(dataset[0][0])

[UNetContractingBlock(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padd

tensor([[[ 0.0646,  0.0650,  0.0647,  ...,  0.0650,  0.0647,  0.0650],
         [ 0.0644,  0.0656,  0.0644,  ...,  0.0656,  0.0644,  0.0656],
         [ 0.0646,  0.0650,  0.0646,  ...,  0.0650,  0.0646,  0.0649],
         ...,
         [ 0.0644,  0.0656,  0.0644,  ...,  0.0656,  0.0644,  0.0656],
         [ 0.0646,  0.0650,  0.0646,  ...,  0.0650,  0.0646,  0.0649],
         [ 0.0644,  0.0655,  0.0644,  ...,  0.0655,  0.0644,  0.0656]],

        [[-0.1277, -0.1275, -0.1278,  ..., -0.1275, -0.1278, -0.1275],
         [-0.1276, -0.1277, -0.1277,  ..., -0.1277, -0.1277, -0.1277],
         [-0.1278, -0.1276, -0.1277,  ..., -0.1276, -0.1277, -0.1275],
         ...,
         [-0.1276, -0.1277, -0.1277,  ..., -0.1277, -0.1277, -0.1277],
         [-0.1278, -0.1276, -0.1277,  ..., -0.1276, -0.1277, -0.1275],
         [-0.1276, -0.1277, -0.1277,  ..., -0.1277, -0.1277, -0.1277]]],
       grad_fn=<SqueezeBackward1>)

In [96]:
from torch import optim
from tqdm import tqdm
import time

In [97]:
lr = 0.03
epochs = 10

In [98]:
loss_func = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [99]:
train_steps = len(train_dataset) // batch_size
test_steps = len(test_dataset) // batch_size

In [100]:
start_time = time.time()

for e in tqdm(range(epochs)):
    model.train()

    total_train_loss = 0
    total_test_loss = 0

    # for i, (x, y) in enumerate(train_loader):
    for i, (x, y) in enumerate(train_dataset):
        pred = model(x)
        loss = loss_func(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss

    with torch.no_grad():
        model.eval()

        # for (x, y) in test_loader:
        for (x, y) in test_dataset:
            pred = model(x)
            total_test_loss += loss_func(pred, y)

    avg_train_loss = total_train_loss / train_steps
    avg_test_loss = total_test_loss / test_steps

    print(f"[INFO] EPOCH: {e + 1}/{epochs}")
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(avg_train_loss, avg_test_loss))

end_time = time.time()
print(f"[INFO] total time taken to train the model: {end_time-start_time:.2f}s")


  0%|          | 0/10 [00:00<?, ?it/s]


[UNetContractingBlock(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
), UNetContractingBlock(
  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padd

ValueError: Target size (torch.Size([3, 388, 388])) must be the same as input size (torch.Size([2, 196, 196]))