In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from PIL import Image
import os
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large as deeplab

In [3]:
# Define a simple custom dataset
class CustomSegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel lists of file paths.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths; ``mask_paths[i]`` is taken to
            be the mask for ``image_paths[i]``.
        transform: Optional callable applied to the RGB image.
        mask_transform: Optional callable applied to the mask. When omitted,
            ``transform`` is applied to the mask as well, which is only safe
            if ``transform`` neither interpolates nor rescales pixel values
            (either would alter the labels stored in the mask).

    Raises:
        ValueError: If the two path sequences differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform=None, mask_transform=None):
        # Fail early: a length mismatch would otherwise surface later as an
        # IndexError or as silently unused masks.
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths and mask_paths differ in length: "
                f"{len(image_paths)} vs {len(mask_paths)}"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, mask)`` after the optional transforms."""
        # Decode inside ``with`` so the underlying file handles are closed
        # deterministically; convert()/copy() return fully loaded images.
        with Image.open(self.image_paths[idx]) as image_file:
            image = image_file.convert("RGB")
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = mask_file.copy()

        if self.transform is not None:
            image = self.transform(image)

        # Prefer the dedicated mask transform; fall back to the shared one so
        # callers that only pass ``transform`` see the previous behaviour.
        mask_fn = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_fn is not None:
            mask = mask_fn(mask)

        return image, mask

In [None]:
# Data preparation
image_dir = 'path_to_images'  # Replace with the path to your images
mask_dir = 'path_to_masks'  # Replace with the path to your masks

# os.listdir returns entries in arbitrary order, so sort both listings to keep
# image i paired with mask i.
# NOTE(review): assumes an image and its mask sort to the same position
# (e.g. identical base names) - confirm against the data layout.
image_paths = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir))
mask_paths = sorted(os.path.join(mask_dir, f) for f in os.listdir(mask_dir))
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f'Found {len(image_paths)} images but {len(mask_paths)} masks; they must pair up one-to-one'
    )
if len(image_paths) < 2:
    raise ValueError('Need at least two image/mask pairs to build a train/validation split')

# Seeded 80/20 split. Built with torch because `train_test_split` was never
# imported in this notebook; one shared permutation keeps pairs aligned.
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(image_paths), generator=split_generator).tolist()
num_val = max(1, int(0.2 * len(image_paths)))
val_indices, train_indices = shuffled_indices[:num_val], shuffled_indices[num_val:]

train_images = [image_paths[i] for i in train_indices]
train_masks = [mask_paths[i] for i in train_indices]
val_images = [image_paths[i] for i in val_indices]
val_masks = [mask_paths[i] for i in val_indices]

# Transforms
image_size = (520, 520)

# Image pipeline: default (interpolating) resize, then ToTensor.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

# Mask pipeline: nearest-neighbour resize never blends neighbouring labels, and
# PILToTensor keeps the raw integer pixel values instead of rescaling them.
mask_resize = transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST)
mask_to_tensor = transforms.PILToTensor()


def collate_segmentation_batch(samples):
    """Turn a list of (PIL image, PIL mask) pairs into batched tensors.

    Returns:
        images: float tensor stacked from `transform(image)`.
        masks: long tensor stacked from the resized masks with the channel
            dimension squeezed out, as CrossEntropyLoss expects class-index
            targets without a channel axis.
    """
    # NOTE(review): assumes single-channel masks whose pixel values are class
    # indices in [0, num_classes) - confirm against the mask files.
    images = torch.stack([transform(image) for image, _ in samples])
    masks = torch.stack([
        mask_to_tensor(mask_resize(mask)).squeeze(0).long() for _, mask in samples
    ])
    return images, masks


# Passing `transform` to the dataset would run the image pipeline on the masks
# too, so the datasets return raw PIL pairs and conversion happens in collate.
train_dataset = CustomSegmentationDataset(train_images, train_masks)
val_dataset = CustomSegmentationDataset(val_images, val_masks)

batch_size = 4

# drop_last: a trailing batch of size 1 is expected to break BatchNorm in
# training mode (the head pools features down to 1x1), so it is skipped.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_segmentation_batch,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_segmentation_batch,
)
if len(train_loader) == 0:
    raise ValueError(f'Need at least {batch_size} training samples to form one batch')

# Load the model
num_classes = 21  # Assuming 21 classes for segmentation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Use the GPU only if available

model = models.segmentation.deeplabv3_mobilenet_v3_large(weights='DEFAULT')
# Swap the final 1x1 conv for a fresh one sized for `num_classes`, reading the
# input width from the layer being replaced rather than hard-coding it.
head_in_channels = model.classifier[4].in_channels
model.classifier[4] = nn.Conv2d(head_in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def compute_batch_loss(images, masks):
    """Move one batch to `device` and return the loss on the model's 'out' logits."""
    images = images.to(device)
    masks = masks.to(device)
    return criterion(model(images)['out'], masks)


# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        optimizer.zero_grad()

        loss = compute_batch_loss(images, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}')

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            val_loss += compute_batch_loss(images, masks).item()

    print(f'Validation Loss: {val_loss/len(val_loader)}')

print('Training completed')

In [3]:
# Build a DeepLabV3 / MobileNetV3-Large model with a 4-class output and the
# backbone weights set to "DEFAULT". Note: this rebinds the name `model`.
model = deeplab(
    num_classes=4,
    weights_backbone="DEFAULT",
)

In [4]:
# Forward a dummy batch of ones purely to inspect the output structure.
# no_grad skips building an autograd graph that is never used for backward,
# which saves memory on a batch of this size.
with torch.no_grad():
    a = model(torch.ones(8, 3, 520, 520))

In [5]:
# Display the 'out' entry for the first sample of the batch.
a["out"][0]

tensor([[[ 0.2251,  0.2251,  0.2251,  ...,  0.3271,  0.3271,  0.3271],
         [ 0.2251,  0.2251,  0.2251,  ...,  0.3271,  0.3271,  0.3271],
         [ 0.2251,  0.2251,  0.2251,  ...,  0.3271,  0.3271,  0.3271],
         ...,
         [-0.3126, -0.3126, -0.3126,  ...,  0.3252,  0.3252,  0.3252],
         [-0.3126, -0.3126, -0.3126,  ...,  0.3252,  0.3252,  0.3252],
         [-0.3126, -0.3126, -0.3126,  ...,  0.3252,  0.3252,  0.3252]],

        [[ 2.9444,  2.9444,  2.9444,  ...,  0.3030,  0.3030,  0.3030],
         [ 2.9444,  2.9444,  2.9444,  ...,  0.3030,  0.3030,  0.3030],
         [ 2.9444,  2.9444,  2.9444,  ...,  0.3030,  0.3030,  0.3030],
         ...,
         [-0.0179, -0.0179, -0.0179,  ...,  0.3096,  0.3096,  0.3096],
         [-0.0179, -0.0179, -0.0179,  ...,  0.3096,  0.3096,  0.3096],
         [-0.0179, -0.0179, -0.0179,  ...,  0.3096,  0.3096,  0.3096]],

        [[-0.9134, -0.9134, -0.9134,  ...,  0.1181,  0.1181,  0.1181],
         [-0.9134, -0.9134, -0.9134,  ...,  0

In [10]:
# Display the shape of the 'out' tensor for the whole batch.
a["out"].shape

torch.Size([8, 4, 520, 520])