In [7]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from torchvision import transforms

from dataset import SatelliteImageDataset
from utils import save_model

from adunet.adunet import ADUNet


In [8]:
# Training configuration: every value a reader might want to tune lives here.
BATCH_SIZE = 4          # mini-batch size
EPOCHS = 100            # number of passes the training loop makes over the data
LEARNING_RATE = 1e-4    # learning rate handed to the optimizer
SEED = 42               # fixed seed so weight init and shuffling repeat across runs

# Seed PyTorch's global RNG before any model or DataLoader is created, so a
# Restart & Run All reproduces the same initial weights and batch order.
torch.manual_seed(SEED)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)


Using device: cpu


In [9]:
from dataset import SatelliteImageDataset
from torchvision import transforms

# Preprocessing pipeline: resize to 512x512, then convert to a tensor.
# How the dataset applies `transform` (image only, or image and mask) is
# decided inside dataset.py and is not visible from this notebook.
preprocessing_steps = [transforms.Resize((512, 512)), transforms.ToTensor()]
transform = transforms.Compose(preprocessing_steps)

# Dataset over the image/mask folders, with the pipeline above attached.
dataset_kwargs = dict(
    image_dir='data/images',
    mask_dir='data/masks',
    transform=transform,
)
dataset = SatelliteImageDataset(**dataset_kwargs)


In [10]:
# Build the ADUNet network (3 input channels, 6 output channels), then move
# it onto the configured device.
model = ADUNet(in_channels=3, out_channels=6)
model = model.to(DEVICE)

# Cross-entropy loss on the network outputs; Adam updates every model
# parameter using the configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


In [11]:
from torch.utils.data import DataLoader, random_split
from dataset import SatelliteImageDataset

# Build the dataset.
# NOTE(review): this rebinds `dataset` WITHOUT the `transform` that the earlier
# cell passes, so that Resize/ToTensor pipeline is not used by the loaders
# below. Confirm this is intended; otherwise reuse the earlier `dataset`.
dataset = SatelliteImageDataset(
    image_dir='data/images',
    mask_dir='data/masks'
)

# 80/20 train/validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
if train_size == 0:
    # With fewer than two samples the training split is empty and the shuffling
    # DataLoader below would fail with a much less obvious error.
    raise ValueError(
        f"Dataset has {len(dataset)} sample(s); at least 2 are needed for an 80/20 split."
    )

# A dedicated, seeded generator keeps split membership identical across kernel
# restarts, so validation samples never drift into the training set between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Build the DataLoaders; the batch size comes from the config cell instead of
# a duplicated literal. Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


In [12]:
# Training loop: each epoch runs one optimisation pass over train_loader and
# then a gradient-free evaluation pass over val_loader, so the held-out split
# built earlier is actually used to monitor generalisation.
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        images = images.to(DEVICE)
        # Cast to int64 before the loss; CrossEntropyLoss needs integer targets
        # when they are class indices.
        masks = masks.long().to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)

    # ---- validation pass ----
    # eval() switches the module to inference mode; model.train() at the top of
    # the next epoch switches it back. no_grad() skips autograd bookkeeping.
    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for val_images, val_masks in val_loader:
            val_images = val_images.to(DEVICE)
            val_masks = val_masks.long().to(DEVICE)
            val_running_loss += criterion(model(val_images), val_masks).item()

    # max(..., 1) guards against an empty validation loader.
    avg_val_loss = val_running_loss / max(len(val_loader), 1)

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")


KeyboardInterrupt: 

In [1]:
# Save trained model
# NOTE(review): this cell needs `model`, `optimizer` and `EPOCHS` from the
# cells above; on a fresh kernel it raises NameError until those cells are run.

# Derive the file name from EPOCHS so it cannot drift from the configured
# value (with EPOCHS = 100 this is the same path as before).
checkpoint_path = f'adunet/adunet_{EPOCHS}ep_checkpoint.pth'

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # NOTE(review): this is the configured epoch count, not a count of epochs
    # actually completed; if training was interrupted the two differ.
    'epoch': EPOCHS
}

torch.save(checkpoint, checkpoint_path)
print("Model saved successfully!")


NameError: name 'model' is not defined