In [3]:
%pip install torchinfo torchvision tqdm numpy matplotlib pandas scikit-learn torch

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.0 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [7]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

# Preprocessing applied to every MNIST sample: convert the image to a tensor
# first, then resize that tensor to 128x128.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(128, 128)),  # Resize for U-Net compatibility
    ]
)

# MNIST training split stored under ./data (download=True fetches it when it
# is not already there); `transform` runs on each image as it is read.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Convert digit images to binary masks (digit pixels = 1, background = 0)
def create_mask(image, threshold=0.0):
    """Turn an intensity tensor into a binary foreground mask.

    Args:
        image: Tensor of pixel intensities; any shape is accepted because the
            comparison is element-wise.
        threshold: Pixels strictly greater than this value are labelled as
            foreground. The default of 0.0 marks every positive pixel, which
            is the original behaviour. A higher value can be passed to drop
            faint pixels (for example ones introduced by interpolation when
            the image was resized).

    Returns:
        A float tensor with the same shape as ``image`` holding 1.0 where the
        pixel exceeds ``threshold`` and 0.0 elsewhere.
    """
    return (image > threshold).float()

class MNISTSegmentationDataset(torch.utils.data.Dataset):
    """Adapt a labelled image dataset so it yields ``(image, mask)`` pairs.

    Each item of the wrapped dataset is unpacked as ``(image, label)``; the
    label is discarded and replaced by the mask that ``create_mask`` derives
    from the image.
    """

    def __init__(self, dataset):
        """Store the dataset to wrap; it must support ``len()`` and indexing."""
        self.dataset = dataset

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the image at ``idx`` together with its binary mask."""
        image, _label = self.dataset[idx]
        return image, create_mask(image)

# Wrap MNIST so each sample is an (image, mask) pair, then serve it in
# shuffled mini-batches of 32.
seg_dataset = MNISTSegmentationDataset(dataset=mnist)

# num_workers=0 is the DataLoader default: batches are loaded in the main
# process. NOTE(review): the "DataLoader worker ... exited unexpectedly"
# traceback recorded under the training cell suggests an earlier run used
# worker processes -- confirm before raising this value.
loader = DataLoader(
    dataset=seg_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

In [8]:
import torch.nn as nn

class UNet(nn.Module):
    """Small two-level U-Net mapping a 1-channel image to a 1-channel map.

    The encoder has two conv stages (64 and 128 channels), each followed by
    2x2 max pooling; the bottleneck uses 256 channels. The decoder upsamples
    twice with stride-2 transposed convolutions and concatenates the matching
    encoder output before each conv stage. A final 1x1 convolution followed
    by a sigmoid produces per-pixel values in [0, 1].

    Input height and width should be multiples of 4 so that the upsampled
    feature maps line up with the encoder outputs they are concatenated with.
    """

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build two 3x3, padding-1 convolutions, each followed by a ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()
        # Encoder: each stage is followed by a pool that halves H and W.
        self.enc1 = self._double_conv(1, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.enc2 = self._double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # Deepest stage, at a quarter of the input resolution.
        self.bottleneck = self._double_conv(128, 256)

        # Decoder: the conv stages take twice the upsampled channel count
        # because the encoder output is concatenated along the channel axis.
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self._double_conv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self._double_conv(128, 64)

        # 1x1 convolution collapsing the 64 feature channels to one output.
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid-activated outputs."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        deepest = self.bottleneck(self.pool2(skip2))

        merged2 = torch.cat([self.up2(deepest), skip2], dim=1)
        decoded2 = self.dec2(merged2)

        merged1 = torch.cat([self.up1(decoded2), skip1], dim=1)
        decoded1 = self.dec1(merged1)

        logits = self.final(decoded1)
        return torch.sigmoid(logits)

In [14]:

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = UNet().to(device)
# UNet.forward already ends in a sigmoid, so its outputs are probabilities
# and BCELoss (rather than a logits-based loss) is the matching criterion.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    # Set training mode explicitly so the loop does not depend on the mode
    # the model happened to be left in.
    model.train()

    # Accumulate a sample-weighted sum so the reported value is the mean loss
    # over the whole epoch rather than the loss of whichever batch came last.
    loss_sum = 0.0
    samples_seen = 0

    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        preds = model(images)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()

        # The last batch may be smaller than the rest, so weight each batch
        # by its actual size.
        n_in_batch = images.size(0)
        loss_sum += loss.item() * n_in_batch
        samples_seen += n_in_batch

    # An empty loader yields NaN instead of failing on an undefined `loss`.
    epoch_loss = loss_sum / samples_seen if samples_seen else float("nan")
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")



Using device: cpu




RuntimeError: DataLoader worker (pid(s) 13040, 17972, 34856, 35324) exited unexpectedly

In [None]:
import matplotlib.pyplot as plt

# Put the model in evaluation mode and predict a mask for the first sample
# of the dataset without tracking gradients.
model.eval()
sample_img, sample_mask = seg_dataset[0]
with torch.no_grad():
    sample_batch = sample_img.unsqueeze(0).to(device)  # add a batch dimension
    pred_mask = model(sample_batch).squeeze().cpu()

# Show input, ground-truth mask and model output side by side in one row.
panels = [
    (sample_img, "Input"),
    (sample_mask, "True Mask"),
    (pred_mask, "Predicted Mask"),
]
for slot, (panel, caption) in enumerate(panels, start=1):
    plt.subplot(1, 3, slot)
    plt.imshow(panel.squeeze(), cmap='gray')
    plt.title(caption)
plt.show()