In [1]:
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

In [2]:
# Dataset: loads paired image/mask files (matched by filename) for segmentation
class NotesDataset(Dataset):
    """Paired image/mask segmentation dataset backed by two folders.

    Every regular, non-hidden file in ``image_folder`` is a sample; its mask
    is the file with the same name in ``mask_folder``. Images are loaded as
    RGB, masks as single-channel and binarised to a float tensor of 0.0/1.0.

    Args:
        image_folder: Directory containing the input images.
        mask_folder: Directory containing one mask per image, same filenames.
        transform: Optional callable applied to the image (and to the mask
            unless ``mask_transform`` is given). It must return a tensor.
            When ``None``, ``T.ToTensor()`` is used so samples still collate.
        mask_transform: Optional callable applied to the mask instead of
            ``transform`` (e.g. a Resize with ``InterpolationMode.NEAREST``).
            Defaults to whatever is used for the image.
    """

    def __init__(self, image_folder, mask_folder, transform=None, mask_transform=None):
        self.image_folder = image_folder
        self.mask_folder = mask_folder
        # Skip hidden entries (.DS_Store, .ipynb_checkpoints, ...) and
        # sub-directories: they are not images and would crash __getitem__.
        self.filenames = sorted(
            name
            for name in os.listdir(image_folder)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(image_folder, name))
        )
        self.transform = transform
        self.mask_transform = mask_transform

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index``; mask values are 0.0 or 1.0."""
        filename = self.filenames[index]
        image_path = os.path.join(self.image_folder, filename)
        mask_path = os.path.join(self.mask_folder, filename)

        # ``with`` closes the underlying file; convert() returns an independent copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')
        # Hard-binarise before any resizing so soft/anti-aliased masks become 0/255.
        mask = mask.point(lambda pixel: 255 if pixel > 128 else 0)

        # Without a transform, still produce tensors: comparing a PIL image with
        # ``> 0`` raises TypeError, and PIL images cannot be batched by DataLoader.
        image_transform = self.transform if self.transform is not None else T.ToTensor()
        mask_transform = self.mask_transform if self.mask_transform is not None else image_transform

        image = image_transform(image)
        mask = mask_transform(mask)
        # NOTE: if the mask transform resizes with bilinear interpolation (torchvision's
        # Resize default), edge pixels become fractional and ``> 0`` marks every
        # partially covered pixel as foreground (slight dilation). Pass a
        # mask_transform with NEAREST interpolation to avoid that.
        mask = (mask > 0).float()

        return image, mask

    def __len__(self):
        """Number of image files found in ``image_folder``."""
        return len(self.filenames)

In [3]:
# U-Net segmentation model
class DoubleConvBlock(nn.Module):
    """(3x3 conv -> ReLU) applied twice; the building block used by UNet.

    Both convolutions use padding=1, so the spatial size is preserved and
    only the channel count changes from ``in_channels`` to ``out_channels``.
    The layers live in ``self.conv`` (indices 0 and 2 are the convolutions);
    keeping that layout keeps ``state_dict`` key names stable for checkpoints.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first conv changes the width, the second keeps it; each feeds a ReLU.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv+ReLU stages to ``x``."""
        out = self.conv(x)
        return out

class UNet(nn.Module):
    """Two-level U-Net that returns per-pixel logits.

    Encoder: two DoubleConvBlocks separated by 2x max-pooling, then a
    bottleneck block. Decoder: two 2x transposed-conv upsamplings, each
    concatenated with the matching encoder feature map (skip connection)
    and refined by a DoubleConvBlock. A final 1x1 conv produces raw logits
    (no sigmoid is applied).

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the output logit map (default 1).
        base_channels: Width of the first level; deeper levels use 2x and 4x
            this value (default 64, matching the original fixed widths).

    Height and width need not be divisible by 4 (but must be at least 4):
    upsampled features are zero-padded to the size of their skip
    connection, so the output has the same spatial size as the input.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super().__init__()
        c1, c2, c3 = base_channels, 2 * base_channels, 4 * base_channels
        self.encoder1 = DoubleConvBlock(in_channels, c1)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = DoubleConvBlock(c1, c2)
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConvBlock(c2, c3)
        self.upconv2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        # Decoder inputs are [upsampled, skip] concatenated -> twice the width.
        self.decoder2 = DoubleConvBlock(2 * c2, c2)
        self.upconv1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.decoder1 = DoubleConvBlock(2 * c1, c1)
        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` so its last two dims (H, W) match those of ``ref``."""
        # Pooling floors odd sizes, so the 2x upsample can be one pixel short.
        diff_h = ref.size(-2) - x.size(-2)
        diff_w = ref.size(-1) - x.size(-1)
        if diff_h == 0 and diff_w == 0:
            return x
        # pad tuple order is (left, right, top, bottom) for the last two dims.
        return nn.functional.pad(
            x,
            (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
        )

    def forward(self, x):
        """Return logits with ``out_channels`` channels and the input's H x W."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        bottom = self.bottleneck(self.pool2(skip2))

        up = self._pad_to(self.upconv2(bottom), skip2)
        up = self.decoder2(torch.cat([up, skip2], dim=1))
        up = self._pad_to(self.upconv1(up), skip1)
        up = self.decoder1(torch.cat([up, skip1], dim=1))

        return self.final_conv(up)

In [4]:
# Model training
def _train_one_epoch(model, dataloader, loss_function, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = loss_function(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)


def train_model(
    image_folder="data/images",
    mask_folder="data/masks",
    checkpoint_dir="checkpoints",
    num_epochs=7,
    batch_size=4,
    lr=1e-3,
    image_size=(512, 512),
    seed=None,
    verbose=False,
):
    """Train a UNet on paired image/mask folders, checkpointing after every epoch.

    All arguments default to the original hard-coded settings, so calling
    ``train_model()`` with no arguments behaves as before.

    Args:
        image_folder: Directory of input images.
        mask_folder: Directory of masks with the same filenames as the images.
        checkpoint_dir: Where ``epoch{N}.pth`` state_dicts are written (created if missing).
        num_epochs: Number of passes over the dataset.
        batch_size: Samples per optimisation step.
        lr: Adam learning rate.
        image_size: ``(height, width)`` every image and mask is resized to.
        seed: If not None, passed to ``torch.manual_seed`` for reproducible
            initialisation and shuffling.
        verbose: If True, print the mean loss after each epoch.

    Returns:
        list[float]: Mean BCE-with-logits loss per epoch (averaged over batches).

    Raises:
        ValueError: If no training images are found in ``image_folder``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A fixed size lets the DataLoader stack samples into batches.
    transform = T.Compose([
        T.Resize(image_size),
        T.ToTensor(),
    ])

    dataset = NotesDataset(
        image_folder=image_folder,
        mask_folder=mask_folder,
        transform=transform,
    )
    if len(dataset) == 0:
        # Fail with a clear message instead of RandomSampler's generic error.
        raise ValueError(f"No training images found in {image_folder!r}")

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = UNet().to(device)
    # The model's final layer is a plain Conv2d, so this loss receives raw logits.
    loss_function = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    os.makedirs(checkpoint_dir, exist_ok=True)

    epoch_losses = []
    for epoch in range(1, num_epochs + 1):
        mean_loss = _train_one_epoch(model, dataloader, loss_function, optimizer, device)
        epoch_losses.append(mean_loss)
        if verbose:
            print(f"Epoch {epoch}/{num_epochs} - mean loss: {mean_loss:.4f}")

        # Save the weights after every epoch so any epoch can be restored.
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"epoch{epoch}.pth"))

    return epoch_losses

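# NOTE: obsolete scratch version of the training loop, superseded by train_model()
# above and kept for reference only. It is not a working loop: it builds a fresh,
# untrained Conv2d on every batch and never calls optimizer.zero_grad().
# def training():
#     for images, masks in dataloader:
#         out = nn.Conv2d(3,1,3)(images)
#         loss = torch.abs(out - masks).mean()
#         loss.backward()
#         optimizer.step()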

In [5]:
# Run program
# Inside a Jupyter kernel __name__ is "__main__", so executing this cell always
# starts training. If this code is imported as a module, training does not
# start automatically.
if __name__ == "__main__":
    train_model()