<a href="https://colab.research.google.com/github/byrkbrk/unet-implementation/blob/main/unet_cell_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# clone repository
!git clone https://github.com/byrkbrk/unet-implementation.git
!ls ./unet-implementation
!cat ./unet-implementation/README.md

# add directory to path
import sys
sys.path.append("./unet-implementation/")

In [None]:
# import modules
from tqdm.auto import tqdm
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn

from model import UNet
from utils import crop, show_tensor_images

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Seed torch's global RNG so that randomness drawn from it (e.g. DataLoader(shuffle=True)
# and default layer initialisation) is reproducible across Restart & Run All.
# NOTE: some cuDNN kernels can still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# read input images as np array: the training image stack and its label stack
# NOTE(review): `dir` shadows the builtin of the same name; it is kept as-is because
# the test-set cell reads from `dir` as well.
dir = "./unet-implementation/"
volumes, labels = (io.imread(dir + fname) for fname in ("train-volume.tif", "train-labels.tif"))

# print shapes
print("volumes shape:", volumes.shape)
print("labels shape:", labels.shape)

# plot the first slice of each stack, one figure per stack
for stack in (volumes, labels):
    plt.imshow(stack[0], cmap="gray")
    plt.show()

# list the distinct pixel values of the first slice (e.g. to confirm the label
# masks are two-valued before they are scaled by 1/255 later on)
for header, stack in (("volumes unique: ", volumes), ("labels unique: ", labels)):
    print(header)
    print(np.unique(stack[0]))


In [None]:
# Convert the (N, H, W) numpy stacks to float tensors of shape (N, 1, H, W), scaled by 1/255.
# New names are used instead of overwriting `volumes` / `labels`, so re-running this
# cell starts again from the numpy arrays rather than adding another channel axis
# to an already-converted tensor.
volumes_tensor = torch.Tensor(volumes)[:, None, :, :] / 255
labels_tensor = torch.Tensor(labels)[:, None, :, :] / 255

# crop labels to the spatial size of the network output so they line up with the
# predictions in the loss
target_dim = 373  # i.e. unet output shape: (-1, 1, 373, 373)
labels_tensor = crop(labels_tensor, (labels_tensor.shape[0], labels_tensor.shape[1], target_dim, target_dim))

# BCEWithLogitsLoss expects targets in [0, 1]; fail early if the label scaling is off
assert labels_tensor.min() >= 0 and labels_tensor.max() <= 1, "label values outside [0, 1]"

# construct dataset of (image, label) pairs
dataset = torch.utils.data.TensorDataset(volumes_tensor, labels_tensor)


In [None]:
# set dataloader: mini-batches of 4 (image, label) pairs, reshuffled every epoch
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# instantiate model and place it on the selected device
# (UNet(1, 1): presumably 1 input channel and 1 output channel — see model.py)
unet = UNet(1, 1).to(device)

# loss on raw logits; the sigmoid is folded into the loss
criterion = nn.BCEWithLogitsLoss()

# Adam over all UNet parameters
optimizer = optim.Adam(params=unet.parameters(), lr=2e-4)


In [None]:
# train
n_epochs = 200
display_step = 40  # print the loss and plot a batch every `display_step` mini-batches
cur_step = 0

# Re-assert placement and mode before training: an evaluation cell that ran earlier
# may have switched the model to eval mode or moved it to another device, which
# would otherwise cause a device mismatch with `mini_images` below.
unet.to(device)
unet.train()


def _show_batch(images):
    """Plot up to 4 images of a batch at the network-output size (1, target_dim, target_dim)."""
    show_tensor_images(images, 4, size=(1, target_dim, target_dim))


for epoch in range(n_epochs):
    for mini_images, mini_labels in tqdm(dataloader, desc=f"epoch {epoch}", leave=False):
        mini_images = mini_images.to(device)
        mini_labels = mini_labels.to(device)

        # `preds` are raw logits: BCEWithLogitsLoss applies the sigmoid internally
        preds = unet(mini_images)
        loss = criterion(preds, mini_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if cur_step % display_step == 0:
            print(f"Epoch: {epoch}, Steps: {cur_step}, Loss: {loss.item():.3f}")

            # plot input images, cropped to the label/prediction size so the rows line up
            _show_batch(crop(mini_images, (mini_images.shape[0], 1, target_dim, target_dim)))

            # plot labels
            _show_batch(mini_labels)

            # plot predictions as probabilities; detach so plotting never touches the autograd graph
            _show_batch(torch.sigmoid(preds.detach()))

        cur_step += 1


In [None]:
# check performance on test dataset

# read test volumes and preprocess them like the training volumes: (N, H, W) -> (N, 1, H, W), scaled by 1/255
test_volumes = torch.Tensor(io.imread(dir + "test-volume.tif"))[:, None, :, :] / 255

# get small test dataset portion (4 slices, one per image shown in each plot)
mini_test_volumes = test_volumes[4:8]

# get test predictions: eval mode and no autograd graph. Inference runs on the
# training device, and the logits are brought back to CPU for plotting. The model
# is not moved off `device`, so the training cell can be re-run afterwards.
unet.eval()
with torch.no_grad():
    test_preds = unet(mini_test_volumes.to(device)).cpu()

# plot images (cropped to the prediction size) and the predicted probabilities
show_tensor_images(
    crop(mini_test_volumes, (mini_test_volumes.shape[0], 1, target_dim, target_dim)),
    4,
    size=(1, target_dim, target_dim),
)
show_tensor_images(torch.sigmoid(test_preds), 4, size=(1, target_dim, target_dim))

