In [None]:
import os, glob, json, pickle, random, time
from PIL import Image
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from albumentations.augmentations.crops.transforms import CropNonEmptyMaskIfExists

In [None]:
# Paths and annotation loading
# DATA_ROOT defaults to the Colab Google-Drive location; set the PT_DATA_ROOT
# environment variable to run the notebook elsewhere without editing it.
DATA_ROOT = os.environ.get('PT_DATA_ROOT', '/content/drive/MyDrive/pt_data')
WEIGHTS_PATH = os.path.join(DATA_ROOT, 'model_weights')

In [None]:
# Annotation JSON files of the two datasets, both stored directly under DATA_ROOT.
AIHUB_ANN, RDD_ANN = (
    os.path.join(DATA_ROOT, json_name)
    for json_name in ('aihub_annotations.json', 'rdd2022_train_annotations.json')
)

In [None]:
def load_and_consolidate_annotations(path):
    """Load an annotation JSON file and index its per-image entries by base name.

    Two layouts are accepted:

    * a list of per-image dicts, each carrying ``file_name`` or ``filename``;
    * a dict mapping file names to per-image dicts, optionally containing an
      ``annotations`` list of per-image dicts that carry ``file_name`` or
      ``filename``.

    Args:
        path: path to a UTF-8 encoded JSON file.

    Returns:
        dict mapping the file name without its extension to the per-image
        dict. When a name occurs more than once, the entry seen last wins
        (in the dict layout, entries from the ``annotations`` list are applied
        after the top-level keys).
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    consolidated = {}

    def _add(file_name, item):
        # Only per-image dicts with a usable name are indexed; this keeps
        # non-dict top-level values (e.g. the 'annotations' list itself in the
        # dict layout) out of the lookup table.
        if not file_name or not isinstance(item, dict):
            return
        consolidated[os.path.splitext(file_name)[0]] = item

    def _name_of(item):
        return item.get('file_name') or item.get('filename')

    if isinstance(data, list):
        for item in data:
            if isinstance(item, dict):
                _add(_name_of(item), item)
    else:
        for key, item in data.items():
            _add(key, item)
        for item in data.get('annotations', []):
            if isinstance(item, dict):
                _add(_name_of(item), item)
    return consolidated

In [None]:
# Per-image annotation lookup keyed by file name without extension.
# Flip the flag to also merge the RDD2022 labels; on name clashes the RDD
# entry overwrites the AIHub one.
INCLUDE_RDD_ANNOTATIONS = False
ann_dict = load_and_consolidate_annotations(AIHUB_ANN)
if INCLUDE_RDD_ANNOTATIONS:
    ann_dict.update(load_and_consolidate_annotations(RDD_ANN))

In [None]:
# PKL file lists
# sorted(): glob returns files in filesystem order, which would make the
# per-epoch data order (and therefore training) non-reproducible.
# The RDD2022 training pickles are collected separately in `rdd_train_pkls`.
train_pkls = sorted(glob.glob(os.path.join(DATA_ROOT, 'AIhub_Road/training_image_batch_*.pkl')))
val_pkls   = sorted(glob.glob(os.path.join(DATA_ROOT, 'AIhub_Road/validation_image_batch_*.pkl')))
assert train_pkls, f"no training pickles found under {os.path.join(DATA_ROOT, 'AIhub_Road')}"

In [None]:
# RDD2022 training pickles, sorted so their order does not depend on the filesystem.
rdd_train_pkls = sorted(glob.glob(os.path.join(DATA_ROOT, 'RDD2022/rdd2022_train_image_batch_*.pkl')))

In [None]:
# Transforms and augmentation
# Image pipeline: resize to 224x224, convert to a float tensor in [0, 1], then
# normalise with the standard ImageNet per-channel mean/std.
seg_tf = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
# 224x224 crop applied jointly to image and mask. Its size matches the Resize
# above, so cropped images pass through Resize unchanged and stay aligned with
# their (un-resized) masks — keep the two sizes in sync.
crop_fn = CropNonEmptyMaskIfExists(height=224, width=224)

In [None]:
# Mask creation
def create_mask_binary(annotations, shape):
    """Rasterise bounding boxes into a binary segmentation mask.

    Args:
        annotations: iterable of dicts with a ``'bbox'`` entry given as
            ``[x, y, w, h]`` (top-left corner plus width/height, in pixels).
        shape: ``(H, W)`` of the target image.

    Returns:
        ``np.int64`` array of shape ``(H, W)``: 1 inside any box, 0 elsewhere.
        Boxes are clipped to the image; boxes with no visible pixels are
        ignored.
    """
    H, W = shape
    mask = np.zeros((H, W), dtype=np.int64)
    for ann in annotations:
        x, y, w, h = ann['bbox']
        x0, y0 = int(round(x)), int(round(y))
        # Inclusive end coordinates.
        x1 = x0 + int(round(w)) - 1
        y1 = y0 + int(round(h)) - 1
        # Clip to the image: a negative start would otherwise be read by NumPy
        # as an index from the far edge and silently drop the box.
        x0, y0 = max(x0, 0), max(y0, 0)
        x1, y1 = min(x1, W - 1), min(y1, H - 1)
        # >= so that 1-pixel-wide/tall boxes are kept.
        if x1 >= x0 and y1 >= y0:
            mask[y0:y1 + 1, x0:x1 + 1] = 1
    return mask

In [None]:
# Model setup
def get_deeplab_model(num_classes):
    """Return a pretrained DeepLabV3-ResNet101 whose main head predicts ``num_classes``.

    The model is built with torchvision's pretrained weights, then its
    ``classifier`` is swapped for a freshly initialised ``DeepLabHead``, so only
    the backbone (and any other untouched submodule) keeps pretrained weights.

    Args:
        num_classes: number of output channels of the new segmentation head.

    Returns:
        the model as an ``nn.Module`` (still on the default device).
    """
    resnet101_out_channels = 2048  # channel count of ResNet-101's last stage
    net = deeplabv3_resnet101(pretrained=True)
    net.classifier = DeepLabHead(resnet101_out_channels, num_classes)
    return net

In [None]:
# Seed the RNGs before building the model so the new head's initialisation
# (torch) and the random crops (albumentations draws from Python's `random`
# and/or NumPy depending on its version) are repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = get_deeplab_model(2).to(device)  # 2 classes: background / damage

In [None]:
# Adam over all model parameters; pixel-wise cross-entropy on class indices.
# GradScaler only does anything together with CUDA autocast, so it is enabled
# explicitly on CUDA only (on CPU it would otherwise warn and disable itself).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler    = GradScaler(enabled=(device.type == 'cuda'))
criterion = nn.CrossEntropyLoss()

In [None]:
# Training hyper-parameters:
#   seg_bs — pickle entries per slice (one optimiser step per slice, before
#            un-annotated entries are filtered out)
#   epochs — full passes over the training pickles
seg_bs, epochs = 16, 10

In [None]:
# Shared batch processing
def _to_rgb_uint8(arr):
    """Return a C-contiguous HxWx3 uint8 RGB copy of a pickled image array.

    NOTE(review): colour arrays have their first three channels reversed, i.e.
    they are assumed to be stored in OpenCV BGR(A) order — confirm against the
    script that produced the pickles.
    """
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    if arr.ndim == 2:
        # Grayscale: replicate to 3 channels. Reversing the last axis of a 2-D
        # array would instead mirror the image horizontally.
        arr = np.repeat(arr[..., None], 3, axis=-1)
    else:
        arr = arr[..., 2::-1]  # channels 2,1,0: BGR(A) -> RGB, alpha dropped
    return np.ascontiguousarray(arr.astype(np.uint8))


def process_slice(slice_batch):
    """Turn a slice of raw pickle entries into cropped, normalised training pairs.

    Each entry is expected to be a dict with ``'filename'`` and ``'image'``
    (a NumPy array). Entries with no filename, no image, or no annotations in
    ``ann_dict`` are skipped.

    Args:
        slice_batch: sequence of entry dicts.

    Returns:
        ``(imgs, masks)``: equally long lists; ``imgs`` holds the tensors
        produced by ``seg_tf`` and ``masks`` holds ``torch.long`` tensors that
        are 1 inside any annotated bounding box and 0 elsewhere.
    """
    imgs, masks = [], []
    for entry in slice_batch:
        fn, arr = entry.get('filename'), entry.get('image')
        if not fn or arr is None:
            continue
        record = ann_dict.get(os.path.splitext(fn)[0])
        anns = record.get('annotations', []) if isinstance(record, dict) else []
        if not anns:
            continue
        img_rgb = _to_rgb_uint8(arr)
        mask_np = create_mask_binary(anns, img_rgb.shape[:2])
        # NOTE(review): crop_fn is configured for 224x224; images smaller than
        # that are presumably rejected by albumentations — confirm input sizes.
        aug = crop_fn(image=img_rgb, mask=mask_np)
        imgs.append(seg_tf(Image.fromarray(aug['image'])))
        masks.append(torch.from_numpy(aug['mask']).long())
    return imgs, masks

In [None]:
def run_epoch(pkls, mode='train'):
    """Run one pass of the global ``model`` over every pickle in ``pkls``.

    Each pickle holds a list of raw entries. It is processed in slices of
    ``seg_bs`` entries through ``process_slice``, and each resulting batch is
    fed to the model. In ``'train'`` mode the weights are updated with
    mixed-precision gradients via ``optimizer``/``scaler``; any other mode only
    evaluates.

    Args:
        pkls: iterable of pickle file paths.
        mode: ``'train'`` to update weights, anything else to evaluate.

    Returns:
        ``(loss, acc)``: the mean per-image loss (each batch's mean
        ``criterion`` value weighted by its number of images) and the pixel
        accuracy over all processed pixels. Both are 0 if no batch was run.
    """
    is_train = (mode == 'train')
    model.train(is_train)
    use_amp = device.type == 'cuda'
    # In train mode BatchNorm needs more than one value per channel, and
    # DeepLabV3's image-pooling branch reduces each sample to 1x1, so a
    # single-image batch would fail. Evaluation uses running statistics and
    # can keep such batches.
    min_batch = 2 if is_train else 1

    total_loss, total_images = 0.0, 0
    total_correct, total_pixels = 0, 0
    # The bar's total grows as each pickle is opened, instead of loading every
    # file an extra time up front just to count its entries.
    pbar = tqdm(total=0, desc=f"{mode.capitalize()} Epoch", unit='batch', ncols=80)
    with torch.set_grad_enabled(is_train):
        for fpath in pkls:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(fpath, 'rb') as fh:
                batch = pickle.load(fh)
            pbar.total += (len(batch) + seg_bs - 1) // seg_bs
            pbar.refresh()
            for i in range(0, len(batch), seg_bs):
                imgs, masks = process_slice(batch[i:i + seg_bs])
                if len(imgs) < min_batch:
                    pbar.update()
                    continue
                x = torch.stack(imgs).to(device)
                y = torch.stack(masks).to(device)
                if is_train:
                    optimizer.zero_grad()
                with autocast(enabled=use_amp):
                    out = model(x)['out']
                    loss = criterion(out, y)
                if is_train:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                n_imgs = x.size(0)
                # criterion returns a batch mean; weight it by batch size so the
                # epoch figure is a true per-image mean.
                total_loss += loss.item() * n_imgs
                total_images += n_imgs
                total_correct += (out.argmax(dim=1) == y).sum().item()
                total_pixels += y.numel()
                pbar.update()
            del batch  # release this file's images before loading the next one
    pbar.close()
    loss = total_loss / total_images if total_images else 0
    acc = total_correct / total_pixels if total_pixels else 0
    return loss, acc

In [None]:
# Train and validate for `epochs` epochs, writing a checkpoint after each one
# into WEIGHTS_PATH (under DATA_ROOT on Drive) rather than the current working
# directory, which on Colab is typically the ephemeral /content.
os.makedirs(WEIGHTS_PATH, exist_ok=True)
for epoch in range(1, epochs + 1):
    tr_loss, tr_acc = run_epoch(train_pkls, 'train')
    val_loss, val_acc = run_epoch(val_pkls, 'val')
    ckpt_path = os.path.join(WEIGHTS_PATH, f"deeplabv3_ep{epoch}.pth")
    torch.save(model.state_dict(), ckpt_path)
    print(f"Epoch {epoch} | Train L:{tr_loss:.6f} A:{tr_acc:.4f} | Val L:{val_loss:.6f} A:{val_acc:.4f}")