# In [None]:  (notebook cell marker left over from export)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision import transforms
from tqdm import tqdm

# Example of a simple U-Net architecture (you can replace it with a more advanced model if needed)
class UNet(nn.Module):
    """Small encoder/decoder CNN that produces per-pixel class logits.

    The encoder applies two 3x3 convolutions and a 2x2 max-pool (halving the
    spatial size); the features are then upsampled back to the input size and
    passed through two 3x3 convolutions. There are no skip connections, so
    this is a simplified stand-in for a full U-Net.

    Args:
        in_channels: Number of channels in the input images.
        out_channels: Number of output channels (one logit per class).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # No activation after the last convolution: the outputs are raw
        # logits, and clamping them with ReLU would stop the network from
        # expressing negative evidence for a class. The convolutions keep
        # their positions (0 and 2), so parameter names are unchanged.
        self.decoder = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        """Return logits of shape (N, out_channels, H, W) for input (N, C, H, W)."""
        input_size = tuple(x.shape[-2:])
        features = self.encoder(x)
        # Undo the max-pool downsampling so that the logits line up
        # pixel-for-pixel with the input (also handles odd H/W).
        features = F.interpolate(
            features, size=input_size, mode='bilinear', align_corners=False
        )
        return self.decoder(features)

# Custom Dataset to handle images and masks
class SegmentationDataset(Dataset):
    """Pairs each image in ``X`` with the mask at the same index in ``y``.

    Args:
        X: Indexable, sized collection of images.
        y: Indexable, sized collection of masks, aligned with ``X``.
        transform: Optional callable applied to each image only (the mask is
            never passed through it).

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """

    def __init__(self, X, y, transform=None):
        # Fail early instead of hitting an IndexError mid-epoch (or silently
        # ignoring surplus masks).
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(image, mask)``; the mask is a new ``torch.long`` tensor."""
        image = self.X[idx]
        mask = self.y[idx]

        # If you need to apply transformations (like resizing, normalization, etc.)
        if self.transform:
            image = self.transform(image)

        # Class-index masks must be int64 for nn.CrossEntropyLoss.
        if isinstance(mask, torch.Tensor):
            # torch.tensor(tensor) emits a copy-construct UserWarning; this
            # makes the same independent copy without it.
            mask = mask.detach().to(dtype=torch.long, copy=True)
        else:
            mask = torch.tensor(mask, dtype=torch.long)

        return image, mask

# Helpers and entry point for training the model
def _resize_mask(mask, size):
    """Resize an integer label mask tensor to ``size`` = (height, width).

    Nearest-neighbour sampling is used so that no new label values are
    invented. The last two dimensions of ``mask`` are treated as (H, W).
    """
    if mask.ndim < 2:
        raise ValueError(f"mask must have at least 2 dimensions, got {mask.ndim}")
    if tuple(mask.shape[-2:]) == tuple(size):
        return mask
    leading_shape = tuple(mask.shape[:-2])
    # F.interpolate needs a float (N, C, H, W) tensor.
    batched = mask.reshape(1, -1, *mask.shape[-2:]).to(torch.float32)
    resized = F.interpolate(batched, size=tuple(size), mode='nearest')
    return resized.reshape(*leading_shape, *size).to(mask.dtype)


class _ResizedMaskDataset(Dataset):
    """Wrap a dataset of (image, mask) pairs, resizing each mask on access."""

    def __init__(self, base, size):
        self.base = base
        self.size = tuple(size)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, mask = self.base[idx]
        return image, _resize_mask(mask, self.size)


def train_segmentation_model(X, y, epochs=10, batch_size=8, learning_rate=1e-4, device='cuda',
                             in_channels=3, num_classes=2, image_size=(256, 256)):
    """Train a ``UNet`` on image/mask pairs and return the trained model.

    Images are converted with ``ToPILImage`` -> ``Resize`` -> ``ToTensor``;
    masks are resized to the same ``image_size`` with nearest-neighbour
    sampling so that images and targets stay aligned.

    Args:
        X: Sized, indexable collection of images accepted by
            ``transforms.ToPILImage`` (per torchvision: NumPy arrays as
            (H, W, C), tensors as (C, H, W)).
        y: Collection of integer class-index masks aligned with ``X``.
        epochs: Number of passes over the data.
        batch_size: Samples per batch.
        learning_rate: Adam learning rate.
        device: Requested device; a CUDA request falls back to CPU when CUDA
            is not available.
        in_channels: Channels of the images after the transform.
        num_classes: Number of segmentation classes (model output channels).
        image_size: ``(height, width)`` every image and mask is resized to.

    Returns:
        The trained ``UNet``, left on the device that was used.

    Raises:
        ValueError: If ``X`` is empty, ``X`` and ``y`` differ in length, or
            ``image_size`` is not a pair.
    """
    if len(X) == 0:
        raise ValueError("X is empty; there is nothing to train on")
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}"
        )
    target_size = tuple(image_size)
    if len(target_size) != 2:
        raise ValueError(f"image_size must be (height, width), got {image_size!r}")

    # Only a CUDA request is downgraded; other devices are used as requested.
    requested = torch.device(device)
    if requested.type == 'cuda' and not torch.cuda.is_available():
        requested = torch.device('cpu')
    device = requested

    # Image transformations
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    # The base dataset only transforms images, so masks are resized by a
    # thin wrapper to keep them the same size as the images.
    dataset = _ResizedMaskDataset(
        SegmentationDataset(X, y, transform=transform), target_size
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize model, loss function, and optimizer
    model = UNet(in_channels=in_channels, out_channels=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()  # Cross entropy loss for segmentation
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(epochs):
        model.train()  # Set model to training mode
        running_loss = 0.0

        for images, masks in tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}'):
            images, masks = images.to(device), masks.to(device)

            logits = model(images)
            # CrossEntropyLoss needs logits (N, C, H, W) to match masks
            # (N, H, W); align them if the model changes the resolution.
            if logits.shape[-2:] != masks.shape[-2:]:
                logits = F.interpolate(
                    logits, size=tuple(masks.shape[-2:]),
                    mode='bilinear', align_corners=False,
                )

            loss = criterion(logits, masks)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean batch loss for the epoch (the dataloader is non-empty here)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}")

    print("Training finished!")
    return model

# Example usage.
# X and y must already be defined (e.g. in an earlier cell) as aligned
# collections of images and segmentation masks:
# X = [image1, image2, ...]  # NumPy images must be (H, W, C) for transforms.ToPILImage
# y = [mask1, mask2, ...]    # Each mask is a numpy array of shape (H, W)

# Run a short training session and keep the resulting model
trained_model = train_segmentation_model(
    X,
    y,
    epochs=5,
    batch_size=4,
    learning_rate=1e-4,
    device='cuda',
)