In [None]:

!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# After installation, restart your kernel, then run:
import torch
print(torch.cuda.is_available())   
print(torch.cuda.get_device_name(0))  

In [None]:
!pip install imgaug

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
from dataset import dataset
from UNetModela import UNet
import matplotlib.pyplot as plt

# Set paths
image_dir = r"C:\Users\edmun\Downloads\forintern\dataset\images"
mask_dir = r"C:\Users\edmun\Downloads\forintern\dataset\masks"
model_save_dir = r"C:\Users\edmun\Downloads\forintern\dataset\model"
os.makedirs(model_save_dir, exist_ok=True)
model_name = "unet_model.pth"

# Load data
test_dataset = dataset()
test_dataset.train_images = test_dataset.load_image(image_dir)
test_dataset.train_masks = test_dataset.load_image(mask_dir)


In [None]:
# Run the project's augmentation; it is expected to populate
# test_dataset.aug_images / test_dataset.aug_masks, which are read below.
# NOTE(review): if augment_images() appends rather than rebuilds, re-running this
# cell grows the dataset. TODO confirm it is idempotent.
test_dataset.augment_images()

# Stack the per-sample arrays into one float32 batch.
image_np = np.stack(test_dataset.aug_images, axis=0).astype(np.float32)
mask_np = np.stack(test_dataset.aug_masks, axis=0).astype(np.float32)
if image_np.shape[0] != mask_np.shape[0]:
    raise ValueError(
        f"Got {image_np.shape[0]} augmented images but {mask_np.shape[0]} masks; "
        "every image needs exactly one mask."
    )

# Drop the trailing singleton channel axis: (N, H, W, 1) -> (N, H, W).
image_np = np.squeeze(image_np, axis=-1)  # (N, 512, 512)
mask_np = np.squeeze(mask_np, axis=-1)

# nn.BCELoss needs targets in [0, 1], and binary_accuracy compares against 0/1 labels.
# So bring both arrays into [0, 1] and binarise the masks (augmentation may interpolate).
# NOTE(review): the /255 rescale assumes 8-bit data whenever values exceed 1.
# TODO confirm what dataset.load_image returns.
if image_np.max() > 1.0:
    image_np /= 255.0
if mask_np.max() > 1.0:
    mask_np /= 255.0
mask_np = (mask_np > 0.5).astype(np.float32)

# Add the channel axis PyTorch expects: (N, H, W) -> (N, 1, H, W).
# from_numpy shares memory with the array instead of copying it.
image_tensor = torch.from_numpy(image_np).unsqueeze(1)  # (N, 1, 512, 512)
mask_tensor = torch.from_numpy(mask_np).unsqueeze(1)    # (N, 1, 512, 512)


In [None]:
VAL_FRACTION = 0.1  # share of samples held out for validation
BATCH_SIZE = 2
SPLIT_SEED = 42     # fixed so the train/val split is identical on every run

# Named `full_dataset` (not `dataset`) so it does not shadow the `dataset` class
# imported from the project's dataset module.
full_dataset = TensorDataset(image_tensor, mask_tensor)
n_total = len(full_dataset)
val_split = int(n_total * VAL_FRACTION)
if n_total > 1:
    # int() truncation leaves no validation data for < 10 samples; keep at least one.
    val_split = max(1, val_split)
# NOTE(review): the split runs on already-augmented samples. Augmented variants of one
# source image may therefore land in both train and val, inflating val scores.
# Splitting before augmentation avoids this. TODO confirm what augment_images produces.
train_set, val_set = random_split(
    full_dataset,
    [n_total - val_split, val_split],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)


In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Single-channel input -> single-channel mask output, moved onto the chosen device.
model = UNet(in_ch=1, out_ch=1)
model = model.to(device)

# SGD with Nesterov momentum and a small L2 weight penalty.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-6,
    nesterov=True,
)

# NOTE(review): nn.BCELoss expects probabilities in [0, 1]. It is only correct if UNet
# ends in a sigmoid (TODO confirm). If UNet returns raw logits, switch to
# nn.BCEWithLogitsLoss and apply torch.sigmoid before thresholding for accuracy.
criterion = nn.BCELoss()

def binary_accuracy(preds, targets, threshold=0.5, target_threshold=0.5):
    """Fraction of elements where the thresholded prediction matches the binarised target.

    Args:
        preds: tensor of predicted scores (compared against ``threshold``).
        targets: ground-truth tensor with the same shape as ``preds``.
        threshold: predictions strictly greater than this count as positive.
        target_threshold: targets strictly greater than this count as positive.
            Binarising the targets keeps the metric meaningful when masks hold
            non-0/1 values (e.g. interpolated by augmentation). Exact float
            equality would otherwise count such pixels as wrong.

    Returns:
        A 0-dim float tensor in [0, 1] (callers use ``.item()``); NaN for empty input.
    """
    pred_labels = preds > threshold
    target_labels = targets > target_threshold
    return (pred_labels == target_labels).float().mean()


In [None]:
num_epochs = 15
model_path = os.path.join(model_save_dir, model_name)


def _run_epoch(loader, train):
    """Run one pass over `loader` and return (mean loss, mean accuracy) per sample.

    With ``train=True`` the model is in train mode and updated through the global
    ``optimizer``. With ``train=False`` it runs in eval mode with autograd disabled.
    Returns (nan, nan) for an empty loader instead of dividing by zero.
    """
    model.train(train)
    total_loss = 0.0
    total_acc = 0.0
    n_seen = 0
    # Build an autograd graph only when training; validation doesn't need one.
    with torch.set_grad_enabled(train):
        for imgs, masks in loader:
            imgs, masks = imgs.to(device), masks.to(device)
            if train:
                optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            if train:
                loss.backward()
                optimizer.step()
            # Weight by batch size so a smaller final batch does not skew the mean.
            batch_n = imgs.size(0)
            total_loss += loss.item() * batch_n
            total_acc += binary_accuracy(outputs.detach(), masks).item() * batch_n
            n_seen += batch_n
    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_seen, total_acc / n_seen


for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(train_loader, train=True)
    val_loss, val_acc = _run_epoch(val_loader, train=False)
    print(
        f"Epoch [{epoch+1}/{num_epochs}]  Loss: {train_loss:.4f}  Acc: {train_acc:.4f}"
        f"  Val Loss: {val_loss:.4f}  Val Acc: {val_acc:.4f}"
    )

# Save the weights from the final epoch.
torch.save(model.state_dict(), model_path)
print(f"Model saved to: {model_path}")
