In [1]:
# %%

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision.models as models
# !pip install monai
import monai
from PIL import Image
import torchvision.transforms.functional as F

torch.manual_seed(42)

device = torch.device("cpu")

if torch.backends.mps.is_available():
    device = torch.device("mps")
if torch.cuda.is_available():
    device = torch.device("cuda")

print(f"Using device: {device}")

Using device: mps


In [2]:
# %%
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

In [3]:
# %%
# Model pretrained on imagenet

In [4]:
# %%
# Larger transformation pipeline for imagenet
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(20),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Resize((128, 128)),
    ]
)

In [5]:
# %%
try:
    datasets.SBDataset(
        root="./data/sbd",
        image_set="train",
        download=True,
        mode="segmentation",
    )
    datasets.SBDataset(
        root="./data/sbd",
        image_set="val",
        download=True,
        mode="segmentation",
    )
except Exception as e:
    print(e)

train_data = datasets.SBDataset(
    root="./data/sbd",
    image_set="train",
    transforms=lambda x, y: [transform(x), transform(y)],
    # download=True,
    mode="segmentation",
)

test_data = datasets.SBDataset(
    root="./data/sbd",
    image_set="val",
    transforms=lambda x, y: [transform(x), transform(y)],
    # download=True,
    mode="segmentation",
)


train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)


Using downloaded and verified file: ./data/sbd/benchmark.tgz
Extracting ./data/sbd/benchmark.tgz to ./data/sbd
Destination path './data/sbd/cls' already exists


In [6]:
# %%
X_test = F.to_pil_image(train_data[0][0]).save("X_test.png")
y_test = F.to_pil_image(train_data[0][1]).save("y_test.png")

In [7]:
# %%
learning_rate = 1e-4
epochs = 100
# Dice is a log loss function so negative values are expected
loss_function = monai.losses.DiceLoss(softmax=True)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = monai.losses.DiceLoss(sigmoid=True)

In [8]:
# %%
assert model(train_data[0][0].unsqueeze(0)).shape == (
    1,
    1,
    128,
    128,
), "Model output shape is correct"

In [9]:
# %%
model = model.to(device)

In [10]:
# %%
for epoch in tqdm(range(epochs)):
    size = len(train_dataloader.dataset)
    for batch, (X, y) in enumerate(train_dataloader):
        X = X.to(device)
        y = y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"Epoch: {epoch+1}, Loss: {loss:.6f}, Progress: [{current}/{size}]")

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1, Loss: -3.176144, Progress: [0/8498]
Epoch: 1, Loss: -3.245847, Progress: [640/8498]
Epoch: 1, Loss: -3.271580, Progress: [1280/8498]
Epoch: 1, Loss: -3.309801, Progress: [1920/8498]
Epoch: 1, Loss: -3.318891, Progress: [2560/8498]
Epoch: 1, Loss: -3.431382, Progress: [3200/8498]
Epoch: 1, Loss: -3.422415, Progress: [3840/8498]
Epoch: 1, Loss: -3.478830, Progress: [4480/8498]


In [None]:
# %%