# U-Net+ Change Detection Training Notebook
This notebook trains a U-Net+ model for change detection using A/B/label folders for train, val, and test. Here, U-Net+ is a U-Net variant whose encoder skip connections are filtered by attention gates before they are merged into the decoder.

## Assignment Compliance (Segmentation)
- Problem: Change detection (binary segmentation of change mask)
- Model: U-Net+ (variant with enhanced skip connections)
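- Epochs: Up to 200, with early stopping on validation loss (patience 10). Note: the 50-epoch minimum is not enforced by the training loop, so early stopping can end a run before epoch 50.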
- Data: Using existing train / val / test folders exactly as provided (no re-splitting enforced).
- Metrics tracked: IoU, Dice, Precision, Recall, F1, Accuracy, Loss + confusion matrix (pixel-wise)
- Outputs: Metric plots, sample predictions, parameter count, saved best weights.
- Saved artifacts: best_unetplus.pth, training_history_unetplus.csv, test_metrics_unetplus.csv, confusion_matrix_unetplus.txt, training_curves_unetplus.png, prediction PNGs (test_predictions_unetplus/pred_*.png).

In [None]:
# Install all required packages
!pip install torch torchvision scikit-learn pandas tqdm matplotlib seaborn pillow --quiet

In [None]:
# Imports & Setup
import os, random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Device & Reproducibility
SEED = 42
_HAS_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _HAS_CUDA else 'cpu')

# Seed the Python, NumPy and PyTorch (CPU) generators with the same value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

if _HAS_CUDA:
    # Seed every visible GPU and ask cuDNN for deterministic kernels;
    # benchmark mode is disabled because it may pick different algorithms
    # from run to run.
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print(f"Using device: {DEVICE}")

# Loss components (Dice + BCE)
class DiceLoss(nn.Module):
    """Soft Dice loss computed on probability maps.

    A Dice score is computed for every (sample, channel) pair by summing
    over dimensions 2 and 3; the loss is one minus the mean of those scores.

    Args:
        smooth: Constant added to both numerator and denominator so that an
            all-zero prediction paired with an all-zero target yields a Dice
            score of 1 rather than a division by zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return ``1 - mean(dice)`` for ``preds`` against ``targets``.

        Both tensors are reduced over dims (2, 3), so they must have at
        least four dimensions (presumably N x C x H x W).
        """
        spatial = (2, 3)
        preds, targets = preds.contiguous(), targets.contiguous()
        overlap = torch.sum(preds * targets, dim=spatial)
        mass = torch.sum(preds, dim=spatial) + torch.sum(targets, dim=spatial)
        score = (2 * overlap + self.smooth) / (mass + self.smooth)
        return 1 - score.mean()

def combined_loss(logits, targets, bce_w=0.6, dice_w=0.4):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    Args:
        logits: Raw (pre-sigmoid) model outputs.
        targets: Ground-truth mask with the same shape as ``logits``.
        bce_w: Weight applied to the BCE-with-logits term.
        dice_w: Weight applied to the Dice term.

    Returns:
        Scalar tensor ``bce_w * BCE + dice_w * Dice``.
    """
    # BCE works on logits directly; Dice needs probabilities.
    bce_term = F.binary_cross_entropy_with_logits(logits, targets)
    dice_term = DiceLoss()(torch.sigmoid(logits), targets)
    return bce_w * bce_term + dice_w * dice_term

@torch.no_grad()
def batch_metrics(logits, targets, thresh=0.5):
    """Compute pixel-wise binary confusion counts and derived scores.

    The counts are taken directly on the tensors' device, so no copy to
    NumPy is needed, and ``reshape`` is used so non-contiguous inputs are
    accepted.

    Args:
        logits: Raw model outputs; a sigmoid is applied before thresholding.
        targets: Ground-truth mask with the same number of elements as
            ``logits``. Values above 0.5 count as the positive class.
        thresh: Probability at or above which a pixel is predicted positive.

    Returns:
        dict with integer counts ``tp``, ``fp``, ``fn``, ``tn`` and float
        scores ``iou``, ``dice``, ``precision``, ``recall``, ``f1``, ``acc``.
        Precision, recall and F1 are 0.0 when their denominator is empty.
    """
    pred_pos = (torch.sigmoid(logits) >= thresh).reshape(-1)
    true_pos = (targets > 0.5).reshape(-1)

    tp = int((pred_pos & true_pos).sum())
    fp = int((pred_pos & ~true_pos).sum())
    fn = int((~pred_pos & true_pos).sum())
    # Whatever is left over is a true negative.
    tn = int(pred_pos.numel()) - tp - fp - fn

    eps = 1e-8  # keeps the ratios finite when a denominator is zero
    iou = tp / (tp + fp + fn + eps)
    dice = (2 * tp) / (2 * tp + fp + fn + eps)
    precision = tp / (tp + fp + eps) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn + eps) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall + eps) if (precision + recall) > 0 else 0.0
    acc = (tp + tn) / (tp + tn + fp + fn + eps)
    return dict(tp=tp, fp=fp, fn=fn, tn=tn, iou=float(iou), dice=float(dice),
                precision=float(precision), recall=float(recall), f1=float(f1), acc=float(acc))

class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Args:
        patience: Number of consecutive non-improving calls after which the
            instance reports that training should stop.
        min_delta: Amount by which the loss must drop below the best value
            seen so far to count as an improvement.
        restore_best: When true, a CPU copy of the model's ``state_dict`` is
            kept at every improvement and loaded back into the model at the
            moment stopping is signalled.
    """

    def __init__(self, patience=10, min_delta=1e-4, restore_best=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best = restore_best
        self.best_loss = None   # lowest loss accepted as an improvement
        self.counter = 0        # consecutive calls without improvement
        self.best_state = None  # snapshot of the weights at best_loss

    def __call__(self, epoch, current_loss, model):
        """Record ``current_loss`` and return True when training should stop.

        ``epoch`` is accepted for the caller's convenience but is not used.
        """
        improved = (
            self.best_loss is None
            or self.best_loss - current_loss > self.min_delta
        )
        if not improved:
            self.counter += 1
        else:
            self.best_loss, self.counter = current_loss, 0
            if self.restore_best:
                # Clone onto the CPU so later optimizer steps cannot
                # modify the snapshot.
                self.best_state = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }

        if self.counter < self.patience:
            return False
        if self.restore_best and self.best_state is not None:
            model.load_state_dict(self.best_state)
        return True

In [None]:
# Dataset
DATA_ROOT = '/kaggle/input/finaldatasetnew/earthquakedatasetnew'  # Kaggle dataset path
IMG_SIZE = (256, 256)
TRAIN_BATCH, VAL_BATCH, TEST_BATCH = 6, 2, 1

# Per-channel normalization statistics (the standard ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# RGB inputs: resize, convert to a tensor, then normalize each channel.
_image_steps = [
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform_img = transforms.Compose(_image_steps)

# Masks: resize and convert to a tensor only; no normalization is applied.
_mask_steps = [
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
]
transform_mask = transforms.Compose(_mask_steps)

class ChangeDataset(Dataset):
    """Paired image A / image B / change-mask dataset for one split.

    File names are taken from the split's A folder (``*.png`` only, sorted);
    the B image and the label are looked up under the same file name in
    their own folders.

    Args:
        root: Dataset root containing ``train``, ``val`` and ``test`` folders.
        split: ``'train'`` or ``'val'``; any other value selects the test
            folders.

    Raises:
        FileNotFoundError: If an A image has no same-named file in the B or
            label folder. Checking at construction time reports the problem
            immediately instead of part-way through an epoch.
    """

    # (split folder, A sub-folder, B sub-folder, label sub-folder)
    _SPLIT_DIRS = {
        'train': ('train', 'A_train_aug', 'B_train_aug', 'label_train_aug'),
        'val': ('val', 'A_val', 'B_val', 'label_val'),
        'test': ('test', 'A_test', 'B_test', 'label_test'),
    }

    def __init__(self, root, split='train'):
        # Unknown split names fall back to the test folders, which keeps the
        # behaviour of the former if/elif/else chain.
        folder, a_name, b_name, m_name = self._SPLIT_DIRS.get(split, self._SPLIT_DIRS['test'])
        a_dir = os.path.join(root, folder, a_name)
        b_dir = os.path.join(root, folder, b_name)
        m_dir = os.path.join(root, folder, m_name)

        self.a_files = sorted(f for f in os.listdir(a_dir) if f.endswith('.png'))
        self.a_dir, self.b_dir, self.m_dir = a_dir, b_dir, m_dir

        missing = [
            os.path.join(directory, name)
            for name in self.a_files
            for directory in (b_dir, m_dir)
            if not os.path.isfile(os.path.join(directory, name))
        ]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} paired file(s) missing for split '{split}', e.g. {missing[0]}"
            )

    def __len__(self):
        return len(self.a_files)

    @staticmethod
    def _load(path, mode):
        """Open ``path`` and return a copy converted to ``mode``.

        The context manager closes the underlying file as soon as the
        converted copy exists instead of leaving that to garbage collection.
        """
        with Image.open(path) as im:
            return im.convert(mode)

    def __getitem__(self, idx):
        """Return ``(x, m)``: A and B concatenated along dim 0, and the mask."""
        name = self.a_files[idx]
        a = transform_img(self._load(os.path.join(self.a_dir, name), 'RGB'))
        b = transform_img(self._load(os.path.join(self.b_dir, name), 'RGB'))
        m = transform_mask(self._load(os.path.join(self.m_dir, name), 'L'))
        # Re-binarize the transformed mask so the target is strictly 0/1.
        m = (m > 0.5).float()
        x = torch.cat([a, b], dim=0)
        return x, m

# One dataset per split, built from the same root.
train_ds, val_ds, test_ds = (
    ChangeDataset(DATA_ROOT, split_name) for split_name in ('train', 'val', 'test')
)

# Only the training loader shuffles; all loaders load in the main process.
_loader_common = {'num_workers': 0}
train_loader = DataLoader(train_ds, batch_size=TRAIN_BATCH, shuffle=True, **_loader_common)
val_loader = DataLoader(val_ds, batch_size=VAL_BATCH, shuffle=False, **_loader_common)
test_loader = DataLoader(test_ds, batch_size=TEST_BATCH, shuffle=False, **_loader_common)

print(f"Train {len(train_ds)} | Val {len(val_ds)} | Test {len(test_ds)}")

In [None]:
# U-Net+ Implementation (Enhanced Skip Connections)
class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> BatchNorm -> ReLU stages with optional dropout.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both convolutions.
        dropout_rate: Probability for the ``Dropout2d`` layer placed after
            each stage; a value of 0 or less replaces it with ``Identity``.
    """

    def __init__(self, in_ch, out_ch, dropout_rate=0.15):
        super().__init__()

        def make_dropout():
            if dropout_rate > 0:
                return nn.Dropout2d(dropout_rate)
            return nn.Identity()

        # padding=1 keeps the spatial size; the conv bias is omitted because
        # the BatchNorm that follows has its own shift term.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = make_dropout()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = make_dropout()

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.relu1, self.dropout1),
            (self.conv2, self.bn2, self.relu2, self.dropout2),
        )
        for conv, norm, act, drop in stages:
            x = drop(act(norm(conv(x))))
        return x

class AttentionGate(nn.Module):
    """Additive attention gate applied to a skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected
    to ``F_int`` channels with a 1x1 convolution plus BatchNorm, summed,
    passed through ReLU, and reduced to a single-channel sigmoid map that
    rescales ``x``.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(in_channels, out_channels):
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
                nn.BatchNorm2d(out_channels),
            ]

        self.W_g = nn.Sequential(*project(F_g, F_int))
        self.W_x = nn.Sequential(*project(F_l, F_int))
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention map derived from ``g`` and ``x``.

        The two projections are added element-wise, so ``g`` and ``x`` must
        have matching batch and spatial sizes.
        """
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        return x * attention

class UNetPlus(nn.Module):
    """U-Net+ with attention gates on every skip connection.

    A four-level encoder and a bottleneck are followed by four decoder
    stages. In each stage the coarser feature map is upsampled, used as the
    gating signal for the matching encoder features, concatenated with the
    gated features and refined by a ``ConvBlock``. A final 1x1 convolution
    produces raw logits (no sigmoid is applied).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels (logit maps).
        filters: Channel widths for the four encoder levels and the
            bottleneck, in that order; five values are read.
    """

    def __init__(self, in_ch=6, out_ch=1, filters=(32,64,128,256,512)):
        super().__init__()
        f = filters

        # Encoder
        self.enc1 = ConvBlock(in_ch, f[0], dropout_rate=0.0)
        self.enc2 = ConvBlock(f[0], f[1], dropout_rate=0.0)
        self.enc3 = ConvBlock(f[1], f[2], dropout_rate=0.1)
        self.enc4 = ConvBlock(f[2], f[3], dropout_rate=0.1)
        self.bottleneck = ConvBlock(f[3], f[4], dropout_rate=0.15)

        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Attention gates: F_g is the channel count of the upsampled decoder
        # features, F_l that of the encoder skip being gated.
        self.att4 = AttentionGate(F_g=f[4], F_l=f[3], F_int=f[3]//2)
        self.att3 = AttentionGate(F_g=f[3], F_l=f[2], F_int=f[2]//2)
        self.att2 = AttentionGate(F_g=f[2], F_l=f[1], F_int=f[1]//2)
        self.att1 = AttentionGate(F_g=f[1], F_l=f[0], F_int=f[0]//2)

        # Decoder: each block receives upsampled features plus the gated skip.
        self.dec4 = ConvBlock(f[4] + f[3], f[3], dropout_rate=0.15)
        self.dec3 = ConvBlock(f[3] + f[2], f[2], dropout_rate=0.15)
        self.dec2 = ConvBlock(f[2] + f[1], f[1], dropout_rate=0.15)
        self.dec1 = ConvBlock(f[1] + f[0], f[0], dropout_rate=0.15)

        self.final = nn.Conv2d(f[0], out_ch, 1)

    def _decode(self, coarse, skip, gate, block):
        """Run one decoder stage: upsample, gate the skip, concatenate, refine."""
        up = self.up(coarse)
        # Max-pooling floors odd sizes, so a plain 2x upsample can come out
        # smaller than the skip tensor. Resize to the skip's spatial size in
        # that case; when the sizes already match this branch is not taken.
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(up, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return block(torch.cat([up, gate(up, skip)], dim=1))

    def forward(self, x):
        """Return logits with the same spatial size as ``x``."""
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))

        # Decoder with attention, from the coarsest level to the finest
        d4 = self._decode(b, e4, self.att4, self.dec4)
        d3 = self._decode(d4, e3, self.att3, self.dec3)
        d2 = self._decode(d3, e2, self.att2, self.dec2)
        d1 = self._decode(d2, e1, self.att1, self.dec1)

        return self.final(d1)

model = UNetPlus(in_ch=6, out_ch=1, filters=(32,64,128,256,512)).to(DEVICE)
print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-3)
# `verbose=True` is intentionally not passed: the argument is deprecated since
# PyTorch 2.2 and newer releases may reject it. The current learning rate is
# printed in the per-epoch log line instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=7, min_lr=1e-6)
early_stop = EarlyStopping(patience=10, min_delta=1e-4)

EPOCHS = 200
history = []
best_val_loss = float('inf')  # lowest validation loss so far; drives checkpointing

for epoch in range(EPOCHS):
    # ---- Training pass ----
    model.train()
    train_loss = 0.0
    for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} Train", leave=False):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        optimizer.zero_grad()
        logits = model(xb)
        loss = combined_loss(logits, yb)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Weight by batch size so the epoch value is a per-sample mean even
        # when the last batch is smaller.
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)

    # ---- Validation pass ----
    model.eval()
    val_loss = 0.0
    agg = dict(tp=0,fp=0,fn=0,tn=0)
    with torch.no_grad():
        for xb, yb in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} Val", leave=False):
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = model(xb)
            loss = combined_loss(logits, yb)
            val_loss += loss.item() * xb.size(0)
            mets = batch_metrics(logits, yb, thresh=0.5)
            for k in agg: agg[k] += mets[k]

    val_loss /= len(val_loader.dataset)

    # Scores come from confusion counts summed over the whole validation
    # set rather than from an average of per-batch scores.
    eps=1e-8
    tp,fp,fn,tn = agg['tp'],agg['fp'],agg['fn'],agg['tn']
    iou = tp / (tp+fp+fn+eps)
    dice = (2*tp)/(2*tp+fp+fn+eps)
    precision = tp/(tp+fp+eps) if (tp+fp)>0 else 0
    recall = tp/(tp+fn+eps) if (tp+fn)>0 else 0
    f1 = 2*precision*recall/(precision+recall+eps) if (precision+recall)>0 else 0
    acc = (tp+tn)/(tp+tn+fp+fn+eps)
    history.append(dict(epoch=epoch+1, train_loss=train_loss, val_loss=val_loss, IoU=iou, Dice=dice, Precision=precision, Recall=recall, F1=f1, Accuracy=acc))

    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1}: TL {train_loss:.4f} VL {val_loss:.4f} IoU {iou:.4f} Dice {dice:.4f} F1 {f1:.4f} LR {current_lr:.2e}")

    # Checkpoint whenever the validation loss matches or beats the best so
    # far; a running minimum replaces rescanning the whole history.
    if epoch==0 or val_loss <= best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_unetplus.pth')

    if early_stop(epoch, val_loss, model):
        print(f"Early stopping at epoch {epoch+1}")
        break

pd.DataFrame(history).to_csv('training_history_unetplus.csv', index=False)
print('Training complete.')

In [None]:
# Test evaluation
# Rebuild the architecture and load the checkpoint file from disk.
model = UNetPlus(in_ch=6, out_ch=1, filters=(32,64,128,256,512)).to(DEVICE)
best_weights = torch.load('best_unetplus.pth', map_location=DEVICE)
model.load_state_dict(best_weights)
model.eval()

agg = {'tp': 0, 'fp': 0, 'fn': 0, 'tn': 0}
all_preds = []

with torch.no_grad():
    for xb, yb in tqdm(test_loader, desc="Test", leave=False):
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        logits = model(xb)
        mets = batch_metrics(logits, yb, thresh=0.5)
        for k in agg:
            agg[k] += mets[k]
        # Keep the thresholded masks on the CPU for export below.
        probs = torch.sigmoid(logits)
        preds = (probs >= 0.5).float().cpu()
        all_preds.append(preds)

all_preds = torch.cat(all_preds, dim=0)

# Scores are derived from confusion counts summed over the whole test set.
eps = 1e-8
tp, fp, fn, tn = (agg[key] for key in ('tp', 'fp', 'fn', 'tn'))
iou = tp/(tp+fp+fn+eps)
dice = (2*tp)/(2*tp+fp+fn+eps)
precision = tp/(tp+fp+eps) if (tp+fp)>0 else 0
recall = tp/(tp+fn+eps) if (tp+fn)>0 else 0
f1 = 2*precision*recall/(precision+recall+eps) if (precision+recall)>0 else 0
acc = (tp+tn)/(tp+tn+fp+fn+eps)

# Rows are the true class, columns the predicted class (negative first).
cm = np.array([[tn, fp],[fn, tp]])
metrics = dict(IoU=iou, Dice=dice, Precision=precision, Recall=recall, F1=f1, Accuracy=acc, TP=tp, FP=fp, FN=fn, TN=tn)

print('\nTest Metrics:')
for _label, _value in (('IoU', iou), ('Dice', dice), ('Precision', precision), ('Recall', recall), ('F1', f1)):
    print(f'{_label}: {_value:.4f}')

pd.DataFrame([metrics]).to_csv('test_metrics_unetplus.csv', index=False)
np.savetxt('confusion_matrix_unetplus.txt', cm, fmt='%d')

# Export at most the first ten predicted masks as 8-bit PNGs (0 or 255).
os.makedirs('test_predictions_unetplus', exist_ok=True)
_num_to_save = min(10, all_preds.shape[0])
for i in range(_num_to_save):
    img = (all_preds[i,0].numpy()*255).astype('uint8')
    Image.fromarray(img).save(f'test_predictions_unetplus/pred_{i}.png')
print('Saved prediction samples.')

In [None]:
# Visualization
hist_df = pd.read_csv('training_history_unetplus.csv')

fig, ((ax1, ax2, ax3),(ax4, ax5, ax6)) = plt.subplots(2,3, figsize=(16,8))

# One entry per panel: (axis, title, y-label, [(column, label, color), ...], draw legend?)
_panels = [
    (ax1, 'Loss Curves', 'Loss',
     [('train_loss', 'Train Loss', 'blue'), ('val_loss', 'Val Loss', 'red')], True),
    (ax2, 'Validation IoU', 'IoU',
     [('IoU', 'IoU', 'green')], False),
    (ax3, 'Validation Dice', 'Dice',
     [('Dice', 'Dice', 'orange')], False),
    (ax4, 'Precision & Recall', 'Score',
     [('Precision', 'Precision', 'purple'), ('Recall', 'Recall', 'brown')], True),
    (ax5, 'F1 & Accuracy', 'Score',
     [('F1', 'F1', 'red'), ('Accuracy', 'Accuracy', 'blue')], True),
]

for _ax, _title, _ylabel, _curves, _with_legend in _panels:
    for _column, _label, _color in _curves:
        _ax.plot(hist_df['epoch'], hist_df[_column], label=_label, color=_color)
    _ax.set_title(_title)
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    if _with_legend:
        _ax.legend()
    _ax.grid(True, alpha=0.3)

# The sixth panel carries a text summary instead of a plot.
ax6.axis('off')
if len(hist_df) > 0:
    # "Best epoch" is the one with the lowest validation loss; the other
    # figures are the column-wise best values and may come from other epochs.
    best_epoch = hist_df.loc[hist_df['val_loss'].idxmin(), 'epoch']
    best_val_loss = hist_df['val_loss'].min()
    best_iou = hist_df['IoU'].max()
    best_dice = hist_df['Dice'].max()
    best_f1 = hist_df['F1'].max()

    summary_text = "\n".join([
        "",
        "    TRAINING SUMMARY",
        "    ================",
        "    Model: U-Net+",
        f"    Total Epochs: {len(hist_df)}",
        f"    Best Epoch: {best_epoch}",
        "    ",
        "    Best Metrics:",
        f"    Val Loss: {best_val_loss:.4f}",
        f"    IoU: {best_iou:.4f}",
        f"    Dice: {best_dice:.4f}",
        f"    F1: {best_f1:.4f}",
        "    ",
    ])
    ax6.text(0.1, 0.5, summary_text, fontsize=10, fontfamily='monospace',
             verticalalignment='center', bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))

plt.tight_layout()
plt.savefig('training_curves_unetplus.png', dpi=150, bbox_inches='tight')
plt.show()
print("Visualization complete!")