In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import random
from sklearn.metrics import jaccard_score
import matplotlib.pyplot as plt
from PIL import Image
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Set device
# The second GPU ('cuda:1') is preferred when the machine exposes more than
# one CUDA device. On a single-GPU machine 'cuda:1' is an invalid ordinal and
# the first `.to(device)` would raise, so fall back to 'cuda:0' there, and to
# the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:


# Define paths
train_images_folder_path = "/kaggle/input/cityscapes/train/img"
train_mask_folder_path = "/kaggle/input/cityscapes/train/label"
test_images_folder_path = "/kaggle/input/cityscapes/val/img"
test_mask_folder_path = "/kaggle/input/cityscapes/val/label"


def _sorted_png_names(folder):
    """Return the alphabetically sorted names of the '.png' entries in *folder*."""
    return sorted(entry for entry in os.listdir(folder) if entry.endswith('.png'))


# Get image and mask names.
# Images and masks are paired purely by their position in these sorted lists.
train_images_names_original = _sorted_png_names(train_images_folder_path)
train_mask_names_original = _sorted_png_names(train_mask_folder_path)
test_images_names = _sorted_png_names(test_images_folder_path)
test_mask_names = _sorted_png_names(test_mask_folder_path)

# Define colors for Cityscapes dataset.
# One (class name, RGB colour) pair per class; the position in this table is
# the class index used for the one-hot channels.
# NOTE: 'car' and 'license plate' share the colour (0, 0, 142), so a pixel of
# that colour matches both entries when masks are one-hot encoded by colour.
_CITYSCAPES_CLASSES = [
    ('unlabeled', (0, 0, 0)),
    ('dynamic', (111, 74, 0)),
    ('ground', (81, 0, 81)),
    ('road', (128, 64, 128)),
    ('sidewalk', (244, 35, 232)),
    ('parking', (250, 170, 160)),
    ('rail track', (230, 150, 140)),
    ('building', (70, 70, 70)),
    ('wall', (102, 102, 156)),
    ('fence', (190, 153, 153)),
    ('guard rail', (180, 165, 180)),
    ('bridge', (150, 100, 100)),
    ('tunnel', (150, 120, 90)),
    ('pole', (153, 153, 153)),
    ('traffic light', (250, 170, 30)),
    ('traffic sign', (220, 220, 0)),
    ('vegetation', (107, 142, 35)),
    ('terrain', (152, 251, 152)),
    ('sky', (70, 130, 180)),
    ('person', (220, 20, 60)),
    ('rider', (255, 0, 0)),
    ('car', (0, 0, 142)),
    ('truck', (0, 0, 70)),
    ('bus', (0, 60, 100)),
    ('caravan', (0, 0, 90)),
    ('trailer', (0, 0, 110)),
    ('train', (0, 80, 100)),
    ('motorcycle', (0, 0, 230)),
    ('bicycle', (119, 11, 32)),
    ('license plate', (0, 0, 142)),
]

names = [class_name for class_name, _ in _CITYSCAPES_CLASSES]

colors = np.array([rgb for _, rgb in _CITYSCAPES_CLASSES], dtype=np.int32)

def one_hot_mask(y, colors_tensor):
    """Convert a batch of colour-coded masks to one-hot encoding.

    Args:
        y: tensor whose last dimension holds the colour components of each
            pixel, e.g. (batch, height, width, 3).
        colors_tensor: (num_classes, 3) palette; row ``i`` is the colour of
            class ``i``. The number of classes is taken from this argument
            (not from any module-level palette), so the output always agrees
            with the palette that was actually passed in.

    Returns:
        float32 tensor of shape ``y.shape[:-1] + (num_classes,)`` on
        ``y.device``; entry ``[..., i]`` is 1.0 where every colour component
        of the pixel equals palette row ``i`` and 0.0 otherwise. Pixels that
        match no palette row are all-zero; pixels whose colour appears in
        several rows are set in each of those channels.
    """
    palette = torch.as_tensor(colors_tensor, device=y.device)
    # Broadcast (..., 1, 3) against (num_classes, 3): one comparison for all
    # classes at once instead of a Python loop over the palette.
    matches = (y.unsqueeze(-2) == palette).all(dim=-1)
    return matches.float()

class CityscapesDataset(Dataset):
    """Dataset of (image, mask) pairs read from two folders of image files.

    Item ``idx`` pairs ``image_names[idx]`` with ``mask_names[idx]``; the
    two lists must therefore be aligned by the caller.

    Args:
        image_folder: directory containing the input images.
        mask_folder: directory containing the colour-coded label masks.
        image_names: file names inside ``image_folder``.
        mask_names: file names inside ``mask_folder``, aligned with
            ``image_names``.
        img_height: height every image and mask is resized to.
        img_width: width every image and mask is resized to.
        transform: optional callable applied to the image tensor only.
        one_hot: if True the mask is returned one-hot encoded against the
            module-level ``colors`` palette (captured at construction time),
            otherwise as the resized RGB mask.
    """

    def __init__(self, image_folder, mask_folder, image_names, mask_names, img_height=96, img_width=256,
                 transform=None, one_hot=True):
        self.image_folder = image_folder
        self.mask_folder = mask_folder
        self.image_names = image_names
        self.mask_names = mask_names
        self.img_height = img_height
        self.img_width = img_width
        self.transform = transform
        self.one_hot = one_hot

        # Convert colors to tensor once; __getitem__ reuses it for every
        # sample instead of rebuilding one tensor per class per sample.
        self.colors_tensor = torch.tensor(colors, dtype=torch.float32)

    def __len__(self):
        """Return the number of images."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load, resize and encode sample ``idx``.

        Returns:
            ``(image, mask)`` where ``image`` is a float32 tensor (3, H, W)
            scaled to [0, 1] and ``mask`` is either a float32 one-hot tensor
            (num_classes, H, W) or, with ``one_hot=False``, an int32 tensor
            (3, H, W) holding the RGB label colours.
        """
        target_size = (self.img_width, self.img_height)  # PIL expects (width, height)

        # Load image
        img_path = os.path.join(self.image_folder, self.image_names[idx])
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB').resize(target_size)
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.tensor(image).permute(2, 0, 1)  # C, H, W

        # Load mask; nearest-neighbour resampling keeps label colours exact
        # (interpolation would invent colours that match no class).
        mask_path = os.path.join(self.mask_folder, self.mask_names[idx])
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('RGB').resize(target_size, Image.NEAREST)
        mask = torch.tensor(np.array(mask).astype(np.int32))  # H, W, C

        if self.one_hot:
            # One-hot encoding: compare every pixel with every palette row in
            # a single broadcast, (H, W, 1, 3) == (num_classes, 3).
            # Colour values are <= 255, so the int/float comparison is exact.
            matches = (mask.unsqueeze(-2) == self.colors_tensor).all(dim=-1)
            mask = matches.float().permute(2, 0, 1)  # C, H, W
        else:
            mask = mask.permute(2, 0, 1)  # C, H, W

        if self.transform:
            image = self.transform(image)

        return image, mask

# GDN Layer
class NonNegConstraint:
    """Callable that projects selected parameters of a module to >= ``min_value``.

    The original version only looked at ``module.weight``, which made it a
    silent no-op on modules that store their parameters under other names
    (the GDN layer in this file uses ``beta`` and ``gamma``).

    Args:
        param_names: attribute names to clamp when present on the module.
        min_value: lower bound applied to every selected parameter.
    """

    def __init__(self, param_names=('weight', 'gamma', 'beta'), min_value=1e-15):
        self.param_names = tuple(param_names)
        self.min_value = min_value

    def __call__(self, module):
        """Clamp, in place, each named parameter that exists on *module*."""
        for name in self.param_names:
            param = getattr(module, name, None)
            # Skip missing attributes and parameters that are None
            # (e.g. normalisation layers built without affine weights).
            if isinstance(param, torch.Tensor):
                # Operate on .data so the projection is not recorded by autograd.
                param.data.clamp_(min=self.min_value)

class GDN(nn.Module):
    """Divisive-normalisation layer.

    Computes ``x / (beta + conv(|x| ** alpha, gamma)) ** epsilon`` where the
    convolution mixes all input channels over a ``filter_size`` x
    ``filter_size`` neighbourhood.

    Args:
        in_channels: number of channels of the input (and output).
        filter_size: spatial size of the normalisation kernel. The padding
            is ``(filter_size - 1) // 2``, which preserves the spatial size
            only for odd values.

    Parameters:
        beta: (C,) learnable bias of the normalisation pool, initialised to 1.
        gamma: (k, k, C, C) learnable kernel, initialised to 0.
        alpha, epsilon: (C,) fixed exponents (``requires_grad=False``), both 1.
    """

    # Lower bound enforced on beta and gamma before every forward pass.
    _MIN_VALUE = 1e-15

    def __init__(self, in_channels, filter_size=3):
        super().__init__()
        self.filter_size = filter_size
        self.padding = (filter_size - 1) // 2

        # Learnable parameters
        self.beta = nn.Parameter(torch.ones(in_channels))
        self.alpha = nn.Parameter(torch.ones(in_channels), requires_grad=False)
        self.epsilon = nn.Parameter(torch.ones(in_channels), requires_grad=False)

        # Gamma weights, stored as (kH, kW, in_channels, out_channels)
        self.gamma = nn.Parameter(torch.zeros(filter_size, filter_size, in_channels, in_channels))

    def _project_nonneg(self):
        """Clamp beta and gamma to be strictly positive.

        The previous implementation delegated to a constraint object that
        only touched an attribute called ``weight``; this module has none,
        so the constraint was never applied. Clamping ``.data`` keeps the
        projection out of the autograd graph, so gradients still reach
        values sitting on the bound.
        """
        self.beta.data.clamp_(min=self._MIN_VALUE)
        self.gamma.data.clamp_(min=self._MIN_VALUE)

    def forward(self, x):
        """Normalise ``x`` of shape (B, C, H, W); the output has the same shape."""
        # Apply constraints: with beta, gamma > 0 and |x|**alpha >= 0 the
        # denominator below is strictly positive.
        self._project_nonneg()

        abs_x_alpha = torch.pow(torch.abs(x), self.alpha.view(1, -1, 1, 1))

        # Convolution for normalization term. The kernel is (C, C, k, k)
        # after the permute, i.e. a full cross-channel convolution. Passing
        # groups=C here (as before) makes conv2d expect a (C, 1, k, k)
        # kernel and raises a RuntimeError for any C > 1.
        norm_conv = F.conv2d(
            abs_x_alpha,
            self.gamma.permute(3, 2, 0, 1),  # (out_c, in_c, kH, kW)
            padding=self.padding,
        )

        norm_term = self.beta.view(1, -1, 1, 1) + norm_conv
        norm_term = torch.pow(norm_term, self.epsilon.view(1, -1, 1, 1))

        return x / norm_term

# UNet Model
class ConvBlock(nn.Module):
    """Two consecutive 3x3 conv -> batch norm -> ReLU -> dropout stages.

    Both convolutions use padding 1, so the spatial size is preserved; the
    channel count goes from ``in_channels`` to ``out_channels`` in the first
    stage and stays there in the second.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both stages.
        dropout_rate: probability used by the shared ``Dropout2d`` layer.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.2):
        super().__init__()
        # First stage
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second stage
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Dropout and activation are stateless and shared by both stages
        self.dropout = nn.Dropout2d(dropout_rate)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run both stages on ``x`` and return the resulting feature map."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.dropout(self.relu(norm(conv(x))))
        return x

class UNetWithGDN(nn.Module):
    """U-Net with a GDN layer applied to the input before the first encoder.

    Four encoder stages (16, 32, 64, 128 channels), each followed by 2x2 max
    pooling, a 256-channel bridge, and four decoder stages that upsample with
    stride-2 transposed convolutions and concatenate the matching encoder
    output. A final 1x1 convolution maps to ``num_classes`` channels; no
    activation is applied to the output.

    Args:
        in_channels: channels of the input image.
        num_classes: channels of the output map.
    """

    def __init__(self, in_channels=3, num_classes=30):
        super().__init__()
        w1, w2, w3, w4, w_bridge = 16, 32, 64, 128, 256

        # Encoder
        self.gdn = GDN(in_channels)
        self.encoder1 = ConvBlock(in_channels, w1)
        self.pool1 = nn.MaxPool2d(2)

        self.encoder2 = ConvBlock(w1, w2)
        self.pool2 = nn.MaxPool2d(2)

        self.encoder3 = ConvBlock(w2, w3)
        self.pool3 = nn.MaxPool2d(2)

        self.encoder4 = ConvBlock(w3, w4)
        self.pool4 = nn.MaxPool2d(2)

        # Bridge
        self.bridge = ConvBlock(w4, w_bridge)

        # Decoder: every ConvBlock receives the upsampled features
        # concatenated with the same-width encoder output, hence 2 * width.
        self.upconv4 = nn.ConvTranspose2d(w_bridge, w4, kernel_size=2, stride=2)
        self.decoder4 = ConvBlock(2 * w4, w4)

        self.upconv3 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.decoder3 = ConvBlock(2 * w3, w3)

        self.upconv2 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.decoder2 = ConvBlock(2 * w2, w2)

        self.upconv1 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.decoder1 = ConvBlock(2 * w1, w1)

        # Output
        self.output = nn.Conv2d(w1, num_classes, kernel_size=1)

    def forward(self, x):
        """Return the ``num_classes``-channel map for the input batch ``x``."""
        features = self.gdn(x)

        # Encoder: remember each stage's output for the skip connections.
        skips = []
        encoder_stages = (
            (self.encoder1, self.pool1),
            (self.encoder2, self.pool2),
            (self.encoder3, self.pool3),
            (self.encoder4, self.pool4),
        )
        for encode, pool in encoder_stages:
            stage_out = encode(features)
            skips.append(stage_out)
            features = pool(stage_out)

        # Bridge
        features = self.bridge(features)

        # Decoder: consume the skips deepest-first.
        decoder_stages = (
            (self.upconv4, self.decoder4),
            (self.upconv3, self.decoder3),
            (self.upconv2, self.decoder2),
            (self.upconv1, self.decoder1),
        )
        for upsample, decode in decoder_stages:
            merged = torch.cat([upsample(features), skips.pop()], dim=1)
            features = decode(merged)

        # Output
        return self.output(features)

# Metrics and visualization functions
def iou_metrics(y_true, y_pred, num_classes=30):
    """Calculate mean IoU over a batch of one-hot segmentation maps.

    For every image the intersection-over-union is computed per class over
    all pixels, averaged over the classes that occur in the ground truth or
    the prediction of that image, and the per-image scores are then averaged
    over the batch.

    The previous implementation called ``jaccard_score(average='samples')``
    on the flattened one-hot pixels. With exactly one active class per
    pixel, a pixel's Jaccard score is 1 when the classes agree and 0
    otherwise, so that value was the pixel accuracy rather than an IoU.

    Args:
        y_true: one-hot ground truth; the first dimension indexes images and
            the last dimension has ``num_classes`` entries.
        y_pred: one-hot predictions with the same layout as ``y_true``.
        num_classes: size of the class dimension.

    Returns:
        The mean IoU as a float in [0, 1]; 0.0 for an empty batch.
    """
    per_image_scores = []

    # One image at a time keeps the temporary boolean maps small.
    for true_img, pred_img in zip(y_true, y_pred):
        true_flat = torch.as_tensor(true_img).detach().cpu().reshape(-1, num_classes) > 0.5
        pred_flat = torch.as_tensor(pred_img).detach().cpu().reshape(-1, num_classes) > 0.5

        intersection = (true_flat & pred_flat).sum(dim=0).double()
        union = (true_flat | pred_flat).sum(dim=0).double()

        # Classes absent from both maps have an undefined IoU (0/0); leave
        # them out of the average instead of counting them as 0 or 1.
        present = union > 0
        if not bool(present.any()):
            per_image_scores.append(0.0)
            continue
        per_image_scores.append(float((intersection[present] / union[present]).mean()))

    if not per_image_scores:
        return 0.0
    return float(np.mean(per_image_scores))

def color_to_one_hot_mask(mask, colors_tensor, img_height=96, img_width=256):
    """Convert a map of class indices to an RGB colour mask.

    Despite its name, this function goes from class indices to colours.

    Args:
        mask: array of class indices that broadcasts to
            (img_height, img_width).
        colors_tensor: (num_classes, 3) palette; row ``c`` is the colour
            painted where ``mask == c``. The palette passed here is the one
            that is used (the previous version ignored this argument and
            read the module-level ``colors`` instead).
        img_height: height of the returned image.
        img_width: width of the returned image.

    Returns:
        uint8 array of shape (img_height, img_width, 3). Pixels whose index
        matches no palette row stay black.
    """
    palette = np.asarray(colors_tensor)
    mask = np.asarray(mask)
    color_mask = np.zeros((img_height, img_width, 3), dtype=np.float32)

    for class_idx, rgb in enumerate(palette):
        # Each pixel holds exactly one index, so at most one class adds to it.
        color_mask += (mask == class_idx)[..., None] * rgb

    return color_mask.astype(np.uint8)

# Training function
def train_model(seed=0):
    """Train a UNetWithGDN on the Cityscapes folders and evaluate it.

    The run seeds Python, NumPy and PyTorch, holds out 300 randomly chosen
    training images for validation, trains for 300 epochs with an L1 loss
    and Adam, keeps the checkpoint with the best validation IoU, and finally
    evaluates that checkpoint on the test folder. All artefacts (best
    checkpoint, history, IoU curve, training plots, prediction figure) are
    written to ``./Good_train/1_gdn/Train_<seed>``.

    Args:
        seed: value used to seed all random number generators; also part of
            the output directory name.

    Returns:
        ``(model, history, best_iou)``: the model with the best checkpoint
        loaded, the dict of per-epoch metric lists, and the best validation
        IoU.

    Raises:
        ValueError: if there are not more than 300 training images, i.e.
            nothing would be left to train on after the validation split.
    """
    print(f'STARTS TRAINING NUMBER {seed}')

    # Set seeds
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Split data
    val_size = 300
    all_indices = list(range(len(train_images_names_original)))
    if len(all_indices) <= val_size:
        raise ValueError(
            f'Need more than {val_size} training images to hold out a validation '
            f'set, found {len(all_indices)}'
        )
    val_indices = random.sample(all_indices, val_size)
    # Keep the remaining indices in ascending order explicitly instead of
    # relying on the iteration order of a set difference.
    val_index_set = set(val_indices)
    train_indices = [i for i in all_indices if i not in val_index_set]

    train_image_names = [train_images_names_original[i] for i in train_indices]
    train_mask_names = [train_mask_names_original[i] for i in train_indices]
    val_image_names = [train_images_names_original[i] for i in val_indices]
    val_mask_names = [train_mask_names_original[i] for i in val_indices]

    print(f'Train images: {len(train_image_names)}')
    print(f'Val images: {len(val_image_names)}')
    print(f'Test images: {len(test_images_names)}')

    # Create datasets
    img_height, img_width = 96, 256
    batch_size = 32

    train_dataset = CityscapesDataset(
        train_images_folder_path, train_mask_folder_path,
        train_image_names, train_mask_names,
        img_height, img_width
    )

    val_dataset = CityscapesDataset(
        train_images_folder_path, train_mask_folder_path,
        val_image_names, val_mask_names,
        img_height, img_width
    )

    test_dataset = CityscapesDataset(
        test_images_folder_path, test_mask_folder_path,
        test_images_names, test_mask_names,
        img_height, img_width
    )

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Initialize model
    model = UNetWithGDN(in_channels=3, num_classes=30).to(device)
    print(model)

    # Loss and optimizer
    # NOTE(review): the L1 loss compares the raw network outputs (no softmax)
    # with the one-hot masks -- confirm this is intended.
    criterion = nn.L1Loss()  # MAE loss
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # mode='max' because the scheduler is stepped with the validation IoU.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=15, min_lr=1e-12)

    # Training metrics storage
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'val_iou': []}
    best_iou = 0.0
    ious = []

    # Create output directory
    output_dir = f'./Good_train/1_gdn/Train_{seed}'
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = os.path.join(output_dir, 'best_model.pth')

    # Training loop
    epochs = 300
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            # Calculate accuracy (pixel-wise argmax agreement)
            with torch.no_grad():
                preds = torch.argmax(outputs, dim=1)
                true_classes = torch.argmax(masks, dim=1)
                correct = (preds == true_classes).sum().item()
                train_correct += correct
                train_total += true_classes.numel()

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        all_preds = []
        all_true = []

        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)

                loss = criterion(outputs, masks)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                true_classes = torch.argmax(masks, dim=1)

                correct = (preds == true_classes).sum().item()
                val_correct += correct
                val_total += true_classes.numel()

                all_preds.append(preds.cpu())
                all_true.append(true_classes.cpu())

        # Calculate metrics (losses are averaged over batches, accuracies over pixels)
        train_loss = train_loss / len(train_loader)
        train_acc = train_correct / train_total
        val_loss = val_loss / len(val_loader)
        val_acc = val_correct / val_total

        # Calculate IoU
        val_preds = torch.cat(all_preds, dim=0)
        val_true = torch.cat(all_true, dim=0)

        # Convert to one-hot for IoU calculation
        val_true_one_hot = F.one_hot(val_true, num_classes=30).float()
        val_preds_one_hot = F.one_hot(val_preds, num_classes=30).float()

        iou = iou_metrics(val_true_one_hot, val_preds_one_hot)
        ious.append(iou)

        # Update learning rate
        scheduler.step(iou)

        # Save best model. The first epoch always writes a checkpoint so the
        # file loaded after training exists even if the IoU never rises
        # above its initial value of 0.
        if epoch == 0 or iou > best_iou:
            best_iou = iou
            torch.save(model.state_dict(), best_model_path)

        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_iou'].append(iou)

        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val IoU: {iou:.4f}')
        print(f'  Best IoU: {best_iou:.4f}')

    # Save history
    # NOTE: `history` is a dict, so NumPy stores it as a pickled object
    # array; reading it back requires np.load(..., allow_pickle=True).
    np.save(os.path.join(output_dir, 'history.npy'), history)
    np.save(os.path.join(output_dir, 'ious.npy'), np.array(ious))

    # Plot training curves
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['val_acc'], label='Validation')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid()

    plt.subplot(1, 3, 2)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['val_loss'], label='Validation')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()

    plt.subplot(1, 3, 3)
    plt.plot(history['val_iou'])
    plt.title('Validation IoU')
    plt.xlabel('Epoch')
    plt.ylabel('IoU')
    plt.grid()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close()

    # Load best model for evaluation; map_location keeps the weights on the
    # device the model already lives on.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.eval()

    # Test evaluation
    test_correct = 0
    test_total = 0
    all_test_preds = []
    all_test_true = []

    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            preds = torch.argmax(outputs, dim=1)
            true_classes = torch.argmax(masks, dim=1)

            correct = (preds == true_classes).sum().item()
            test_correct += correct
            test_total += true_classes.numel()

            all_test_preds.append(preds.cpu())
            all_test_true.append(true_classes.cpu())

    test_acc = test_correct / test_total
    test_preds = torch.cat(all_test_preds, dim=0)
    test_true = torch.cat(all_test_true, dim=0)

    # Calculate test IoU
    test_true_one_hot = F.one_hot(test_true, num_classes=30).float()
    test_preds_one_hot = F.one_hot(test_preds, num_classes=30).float()
    test_iou = iou_metrics(test_true_one_hot, test_preds_one_hot)

    print(f'\nTest Results:')
    print(f'  Test Accuracy: {test_acc:.4f}')
    print(f'  Test IoU: {test_iou:.4f}')
    print(f'  Best Validation IoU: {best_iou:.4f}')

    # Visualization
    visualize_predictions(model, test_dataset, colors, output_dir, img_height, img_width)

    return model, history, best_iou

def visualize_predictions(model, dataset, colors, output_dir, img_height=96, img_width=256, num_samples=5):
    """Visualize predictions on randomly chosen samples of a dataset.

    Draws a three-row figure (input image, ground truth, prediction) with
    one column per sample and saves it as ``predictions.png`` in
    ``output_dir``.

    Args:
        model: network called on a single-image batch; its output is reduced
            with argmax over dimension 1.
        dataset: indexable dataset returning ``(image, one_hot_mask)`` with
            the image as (C, H, W) and the mask as (num_classes, H, W).
        colors: palette passed on to ``color_to_one_hot_mask``.
        output_dir: directory the figure is written to.
        img_height: height of the colour masks.
        img_width: width of the colour masks.
        num_samples: number of samples to draw; capped at ``len(dataset)``
            so small datasets no longer make ``random.sample`` raise.
    """
    model.eval()

    sample_count = min(num_samples, len(dataset))
    if sample_count <= 0:
        print('No samples available to visualize; skipping predictions.png')
        return

    # Get random indices
    indices = random.sample(range(len(dataset)), sample_count)

    plt.figure(figsize=(15, 5))

    for column, idx in enumerate(indices, start=1):
        image, mask = dataset[idx]

        # Add batch dimension
        image_batch = image.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(image_batch)
            # squeeze(0) drops only the batch dimension, so a spatial
            # dimension of size 1 is not removed by accident.
            pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

        # Get ground truth
        true_mask = torch.argmax(mask, dim=0).cpu().numpy()

        # Convert to color
        pred_color = color_to_one_hot_mask(pred, colors, img_height, img_width)
        true_color = color_to_one_hot_mask(true_mask, colors, img_height, img_width)

        # Plot: rows are image / ground truth / prediction
        plt.subplot(3, sample_count, column)
        plt.imshow(image.permute(1, 2, 0).cpu().numpy())
        plt.title('Image')
        plt.axis('off')

        plt.subplot(3, sample_count, column + sample_count)
        plt.imshow(true_color)
        plt.title('Ground Truth')
        plt.axis('off')

        plt.subplot(3, sample_count, column + 2 * sample_count)
        plt.imshow(pred_color)
        plt.title('Prediction')
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'predictions.png'))
    plt.close()

# Main training loop
# One complete training run is launched per entry in `seeds`; train_model
# seeds the random number generators with the value and writes its results
# to ./Good_train/1_gdn/Train_<seed>.
seeds = [0]

for seed in seeds:
    # `model`, `history` and `best_iou` are rebound on every iteration, so
    # after the loop they hold the results of the last seed only.
    model, history, best_iou = train_model(seed)