In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision import transforms
import torch.optim as optim
import torch
import torch.nn as nn

In [17]:
class UNet(nn.Module):
    """Five-level U-Net for per-pixel prediction.

    The encoder applies ``conv_block`` at 64/128/256/512/1024 channels with 2x2
    max pooling between levels. The decoder upsamples with 2x2 transposed
    convolutions, concatenates the matching encoder output (skip connection)
    and applies another ``conv_block``. A final 1x1 convolution maps to
    ``out_channels``. No activation is applied to the output, so it is raw
    logits (pair it with e.g. ``nn.BCEWithLogitsLoss``).

    Args:
        in_channels: channels of the input tensor, shape (N, in_channels, H, W).
        out_channels: channels of the output tensor, shape (N, out_channels, H, W).

    H and W do not have to be multiples of 16: when pooling floors an odd size,
    the upsampled features are zero-padded to the skip tensor's size before
    concatenation, so the output always matches the input's spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        def conv_block(in_channels, out_channels):
            # Two 3x3 convs with padding=1 (spatial size preserved), each
            # followed by BatchNorm and ReLU.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        # Attribute names are kept stable: they become the state_dict keys that
        # saved checkpoints are loaded by.
        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.encoder4 = conv_block(256, 512)
        self.encoder5 = conv_block(512, 1024)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Each decoder sees upsampled features concatenated with the skip
        # tensor, hence twice the channels at its input.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = conv_block(128, 64)

        self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` along H and W so its spatial size equals ``skip``'s.

        Upsampling ``floor(s / 2)`` by 2 yields at most ``s``, so the difference
        is never negative; any odd leftover pixel goes to the bottom/right edge.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # pad takes (left, right, top, bottom) for the last two dimensions.
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def _decode(self, below, skip, upconv, decoder):
        """Upsample ``below``, align it with ``skip``, concatenate and decode."""
        up = self._pad_to_match(upconv(below), skip)
        # Keep the (upsampled, skip) channel order: the decoder's first conv
        # weights are tied to it.
        return decoder(torch.cat((up, skip), dim=1))

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to logits (N, out_channels, H, W)."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))
        enc5 = self.encoder5(self.pool(enc4))

        dec4 = self._decode(enc5, enc4, self.upconv4, self.decoder4)
        dec3 = self._decode(dec4, enc3, self.upconv3, self.decoder3)
        dec2 = self._decode(dec3, enc2, self.upconv2, self.decoder2)
        dec1 = self._decode(dec2, enc1, self.upconv1, self.decoder1)

        return self.conv_last(dec1)


# Use the current CUDA device when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Single-channel (grayscale) input, single-channel output map.
model = UNet(in_channels=1, out_channels=1)
model = model.to(device)


In [18]:
class CustomDataset(Dataset):
    """Paired image/mask dataset backed by two directories.

    Every regular, non-hidden file in ``image_dir`` is one sample; its mask is
    the file with the same name in ``mask_dir``. Both are loaded as
    single-channel ("L") PIL images.

    Args:
        image_dir: directory holding the input images.
        mask_dir: directory holding masks whose file names match ``image_dir``.
        transform: optional callable applied to the image and, unless
            ``mask_transform`` is given, to the mask as well.
        mask_transform: optional callable applied to the mask instead of
            ``transform`` (e.g. a nearest-neighbour resize so mask values are
            not interpolated). Defaults to ``transform``.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # os.listdir order is filesystem-dependent: sort so indices are
        # reproducible. Skip sub-directories and hidden files such as
        # ".ipynb_checkpoints" or ".DS_Store", which PIL cannot open.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``, transformed if configured."""
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        # ``with`` closes the file handles; convert() returns a new, loaded image.
        with Image.open(img_path) as img:
            image = img.convert("L")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        if self.transform:
            image = self.transform(image)
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf:
            mask = mask_tf(mask)

        return image, mask


# Layout: <split>/images holds inputs, <split>/masks holds same-named masks.
train_image_dir, train_mask_dir = 'train/images', 'train/masks'
valid_image_dir, valid_mask_dir = 'valid/images', 'valid/masks'

# Resize every sample to a fixed 256x256 so the default collate can batch them,
# then convert to a float tensor.
# NOTE(review): this same transform is applied to the masks, and
# transforms.Resize interpolates bilinearly by default, so mask edges can take
# fractional values. Consider a separate nearest-neighbour transform for masks.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

train_dataset = CustomDataset(train_image_dir, train_mask_dir, transform)
valid_dataset = CustomDataset(valid_image_dir, valid_mask_dir, transform)

# Only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False)


In [3]:

# Seed the global torch RNG so weight initialisation and the DataLoader shuffle
# order repeat across kernel restarts (GPU kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Fresh model for this training run: 1 input channel, 1 output logit map.
model = UNet(in_channels=1, out_channels=1).to(device)

# UNet ends in a plain 1x1 conv (no sigmoid), so use the logits form of BCE.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-sample loss.

    Args:
        model: network to optimise; switched to training mode first.
        train_loader: iterable of ``(images, masks)`` batches (e.g. a DataLoader).
        criterion: loss called as ``criterion(outputs, masks)``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        Per-batch losses weighted by batch size, divided by the number of
        samples actually processed this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Assumes the criterion averages over the batch (reduction='mean'),
        # so re-weight by batch size before summing.
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    # Divide by samples actually seen, not len(dataset): the two differ when the
    # loader drops a partial last batch or samples only a subset.
    if n_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / n_samples


def validate(model, valid_loader, criterion, device):
    """Evaluate ``model`` without gradients and return the mean per-sample loss.

    Args:
        model: network to evaluate; switched to eval mode (and left there).
        valid_loader: iterable of ``(images, masks)`` batches (e.g. a DataLoader).
        criterion: loss called as ``criterion(outputs, masks)``.
        device: device each batch is moved to before the forward pass.

    Returns:
        Per-batch losses weighted by batch size, divided by the number of
        samples actually processed.

    Raises:
        ValueError: if ``valid_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for images, masks in valid_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            # Assumes the criterion averages over the batch (reduction='mean'),
            # so re-weight by batch size before summing.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

    # Divide by samples actually seen, not len(dataset): the two differ when the
    # loader drops a partial last batch or samples only a subset.
    if n_samples == 0:
        raise ValueError("valid_loader yielded no samples")
    return total_loss / n_samples


num_epochs = 30
best_loss = float('inf')

# One train + validate pass per epoch. 'best_model.pth' is overwritten each
# time the validation loss reaches a new minimum.
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    valid_loss = validate(model, valid_loader, criterion, device)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}')

    improved = valid_loss < best_loss
    if not improved:
        continue
    best_loss = valid_loss
    torch.save(model.state_dict(), 'best_model.pth')

# NOTE: after the loop `model` holds the final-epoch weights, not necessarily
# the best ones; reload 'best_model.pth' to use the best checkpoint.


Epoch 1/30, Train Loss: 0.2862, Valid Loss: 0.2365
Epoch 2/30, Train Loss: 0.1712, Valid Loss: 0.1290
Epoch 3/30, Train Loss: 0.1207, Valid Loss: 0.1252
Epoch 4/30, Train Loss: 0.1006, Valid Loss: 0.0999
Epoch 5/30, Train Loss: 0.0881, Valid Loss: 0.0925
Epoch 6/30, Train Loss: 0.0813, Valid Loss: 0.0841
Epoch 7/30, Train Loss: 0.0745, Valid Loss: 0.0785
Epoch 8/30, Train Loss: 0.0685, Valid Loss: 0.0980
Epoch 9/30, Train Loss: 0.0632, Valid Loss: 0.0729
Epoch 10/30, Train Loss: 0.0587, Valid Loss: 0.0707
Epoch 11/30, Train Loss: 0.0554, Valid Loss: 0.0700
Epoch 12/30, Train Loss: 0.0517, Valid Loss: 0.0634
Epoch 13/30, Train Loss: 0.0481, Valid Loss: 0.0663
Epoch 14/30, Train Loss: 0.0453, Valid Loss: 0.0691
Epoch 15/30, Train Loss: 0.0429, Valid Loss: 0.0728
Epoch 16/30, Train Loss: 0.0399, Valid Loss: 0.0645
Epoch 17/30, Train Loss: 0.0393, Valid Loss: 0.0682
Epoch 18/30, Train Loss: 0.0364, Valid Loss: 0.0792
Epoch 19/30, Train Loss: 0.0349, Valid Loss: 0.0661
Epoch 20/30, Train Lo