In [1]:
import os, glob, json, pickle, time
from PIL import Image
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from albumentations.augmentations.crops.transforms import CropNonEmptyMaskIfExists
import re
import torch.nn.functional as F

from mine_loss import MineNetwork, compute_mine_loss

In [2]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!ls

drive  mine_loss.py  __pycache__  sample_data


In [5]:
# Paths and annotation loading
DATA_ROOT = '/content/drive/MyDrive/pt_data'
WEIGHTS_PATH = os.path.join(DATA_ROOT, 'model_weights')

In [6]:
AIHUB_ANN = os.path.join(DATA_ROOT, 'aihub_annotations.json')
RDD_ANN   = os.path.join(DATA_ROOT, 'rdd2022_train_annotations.json')

In [7]:
def load_and_consolidate_annotations(path):
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    consolidated = {}
    if isinstance(data, list):
        for item in data:
            fn = item.get('file_name') or item.get('filename')
            if not fn:
                continue
            base = os.path.splitext(fn)[0]
            consolidated[base] = item
    else:
        for key, item in data.items():
            base = key.split('_1.jpg')[0]
            consolidated[base] = {}
            consolidated[base]['filename'] = base + '.jpg'
            if item['label'] == 'Pothole':
                item['label'] = 1
            else:
                item['label'] = 2
            consolidated[base]['annotations'] = item

    return consolidated

In [8]:
ann_dict = {**load_and_consolidate_annotations(AIHUB_ANN)}
            #**load_and_consolidate_annotations(RDD_ANN)}

In [9]:
print(ann_dict.keys())

dict_keys(['G_A_F_01_0822090547728', 'G_A_F_01_0822090548028', 'G_A_F_01_0822090824244', 'G_A_F_01_0822092030144', 'G_A_F_01_0822092030278', 'G_A_F_01_0822092045378', 'G_A_F_01_0822092701811', 'G_A_F_01_0822094459711', 'G_A_F_01_0822094459845', 'G_A_F_01_0822094710278', 'G_A_F_01_0822094710411', 'G_A_F_01_0822094710544', 'G_A_F_01_0822094900378', 'G_A_F_01_0822094901143', 'G_A_F_01_0822100636844', 'G_A_F_01_0822100636978', 'G_A_F_01_0822100741161', 'G_A_F_01_0822100756028', 'G_A_F_01_0822101245228', 'G_A_F_01_0822101608978', 'G_A_F_01_0822101807011', 'G_A_F_01_0822101807278', 'G_A_F_01_0822101807528', 'G_A_F_01_0822101807745', 'G_A_F_01_0822101808011', 'G_A_F_01_0822101846911', 'G_A_F_01_0822101847044', 'G_A_F_01_0822103223661', 'G_A_F_01_0822111313078', 'G_A_F_01_0822111505778', 'G_A_F_01_0822111506611', 'G_A_F_01_0822111506811', 'G_A_F_01_0822111509161', 'G_A_F_01_0822112222484', 'G_A_F_01_0822112359317', 'G_A_F_01_0822113121544', 'G_A_F_01_0822114224567', 'G_A_F_01_0822121621544', '

In [10]:
print(ann_dict['G_A_F_01_0822090547728'])

{'file_name': 'G_A_F_01_0822090547728.jpg', 'annotations': [{'label': 1, 'bbox': [125.94421875, 7.472499999999999, 7.2209375, 11.1321875]}]}


In [11]:
AIHUB_ANN = os.path.join(DATA_ROOT, 'aihub_PC_annotations.json')
ann_dict = {**load_and_consolidate_annotations(AIHUB_ANN)}

In [12]:
print(ann_dict.keys())

dict_keys(['G_A_F_01_0822090547728', 'G_A_F_01_0822090548028', 'G_A_F_01_0822090824244', 'G_A_F_01_0822092030144', 'G_A_F_01_0822092030278', 'G_A_F_01_0822092045378', 'G_A_F_01_0822092701811', 'G_A_F_01_0822094459711', 'G_A_F_01_0822094459845', 'G_A_F_01_0822094710278', 'G_A_F_01_0822094710411', 'G_A_F_01_0822094710544', 'G_A_F_01_0822094900378', 'G_A_F_01_0822094901143', 'G_A_F_01_0822100636844', 'G_A_F_01_0822100636978', 'G_A_F_01_0822100741161', 'G_A_F_01_0822100756028', 'G_A_F_01_0822101245228', 'G_A_F_01_0822101608978', 'G_A_F_01_0822101807011', 'G_A_F_01_0822101807011_2.jpg', 'G_A_F_01_0822101807011_3.jpg', 'G_A_F_01_0822101807011_4.jpg', 'G_A_F_01_0822101807278', 'G_A_F_01_0822101807278_2.jpg', 'G_A_F_01_0822101807278_3.jpg', 'G_A_F_01_0822101807278_4.jpg', 'G_A_F_01_0822101807278_5.jpg', 'G_A_F_01_0822101807528', 'G_A_F_01_0822101807528_2.jpg', 'G_A_F_01_0822101807528_3.jpg', 'G_A_F_01_0822101807745', 'G_A_F_01_0822101807745_2.jpg', 'G_A_F_01_0822101808011', 'G_A_F_01_082210184

In [13]:
print(ann_dict['G_A_F_01_0822090547728'])

{'filename': 'G_A_F_01_0822090547728.jpg', 'annotations': {'label': 1, 'bbox': [125, 7, 7, 11]}}


In [14]:
# PKL file lists
train_pkls = glob.glob(os.path.join(DATA_ROOT, 'AIhub_Road/training_image_batch_*.pkl'))
             #glob.glob(os.path.join(DATA_ROOT, 'RDD2022/rdd2022_train_image_batch_*.pkl'))
val_pkls   = glob.glob(os.path.join(DATA_ROOT, 'AIhub_Road/validation_image_batch_*.pkl'))

In [15]:
rdd_train_pkls = glob.glob(os.path.join(DATA_ROOT, 'RDD2022/rdd2022_train_image_batch_*.pkl'))

In [16]:
# Transforms and augmentation
seg_tf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
crop_fn = CropNonEmptyMaskIfExists(height=224, width=224)

In [17]:
# Mask creation
def create_mask_binary(annotations, shape):
    H, W = shape
    mask = np.zeros((H, W), dtype=np.int64)
    for ann in annotations:
        x, y, w, h = ann['bbox']
        x0, y0 = int(round(x)), int(round(y))
        x1 = x0 + int(round(w)) - 1
        y1 = y0 + int(round(h)) - 1
        if x1 > x0 and y1 > y0:
            mask[y0:y1+1, x0:x1+1] = 1
    return mask

In [18]:
# Model setup
def get_deeplab_model(num_classes):
    model = deeplabv3_resnet101(pretrained=True)
    model.classifier = DeepLabHead(2048, num_classes)
    return model

In [19]:
device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = get_deeplab_model(3).to(device)
mine_net = MineNetwork(2048, 2048).to(device)

Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth
100%|██████████| 233M/233M [00:01<00:00, 193MB/s]


In [20]:
# Shared batch processing
def process_slice(slice_batch):
    imgs, masks = [], []
    for entry in slice_batch:
        fn, arr = entry.get('filename'), entry.get('image')
        anns = ann_dict.get(os.path.splitext(fn)[0], {}).get('annotations', [])
        if not anns: continue
        if isinstance(anns, dict):
          anns = [anns]
        img_pil = Image.fromarray(arr[..., ::-1].astype(np.uint8))
        mask_np = create_mask_binary(anns, arr.shape[:2])
        aug = crop_fn(image=np.array(img_pil), mask=mask_np)
        img_crop = Image.fromarray(aug['image'])
        mask_crop = aug['mask']
        imgs.append(seg_tf(img_crop))
        masks.append(torch.from_numpy(mask_crop).long())
    return imgs, masks

In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
mine_optimizer = torch.optim.Adam(mine_net.parameters(), lr=1e-4)
scaler    = GradScaler()
criterion = nn.CrossEntropyLoss()

In [22]:
seg_bs    = 16 # batch_size
epochs    = 10

In [23]:
def run_epoch(pkls, mode='train'):
    is_train = (mode == 'train')
    model.train() if is_train else model.eval()
    total_loss, total_correct, total_pixels = 0.0, 0, 0

    n_batches = sum((len(pickle.load(open(f, 'rb'))) + seg_bs - 1) // seg_bs for f in pkls)

    with tqdm(total=n_batches, desc=f"{mode.capitalize()} Epoch", unit='batch', ncols=100, leave=False) as pbar:
        with torch.set_grad_enabled(is_train):
            for fpath in pkls:
                batch = pickle.load(open(fpath, 'rb'))

                for i in range(0, len(batch), seg_bs):
                    slice_batch = batch[i:i + seg_bs]
                    imgs, masks = process_slice(slice_batch)
                    if len(imgs) < 2:
                        pbar.update()
                        continue

                    loader = DataLoader(TensorDataset(torch.stack(imgs), torch.stack(masks)),
                                        batch_size=len(imgs), shuffle=is_train)

                    for x, y in loader:
                        x, y = x.to(device), y.to(device)

                        if is_train:
                            optimizer.zero_grad()
                            mine_optimizer.zero_grad()

                        with autocast():
                            out = model(x)['out']
                            seg_loss = criterion(out, y)

                            if is_train:
                                with torch.no_grad():
                                    z = model.backbone(x)['out']
                                    y_ds = F.interpolate(y.unsqueeze(1).float(), size=z.shape[2:], mode='nearest').long().squeeze(1)

                                mine_l = compute_mine_loss(mine_net, z, y_ds, num_classes=3)
                                loss = seg_loss + 0.1 * mine_l
                            else:
                                loss = seg_loss

                        if is_train:
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            mine_optimizer.step()
                            scaler.update()

                        total_loss += loss.item()
                        preds = out.argmax(dim=1)
                        total_correct += (preds == y).sum().item()
                        total_pixels += y.numel()

                        pbar.update()

    avg_loss = total_loss / (total_pixels / (224 * 224)) if total_pixels else 0
    acc = total_correct / total_pixels if total_pixels else 0
    return avg_loss, acc

In [None]:
for epoch in range(1, epochs+1):
    tr_loss, tr_acc = run_epoch(train_pkls, 'train')
    val_loss, val_acc = run_epoch(val_pkls, 'val')
    torch.save(model.state_dict(), f"deeplabv3_ep{epoch}.pth")
    tqdm.write(f"Epoch {epoch:02d} | Train L:{tr_loss:.6f} A:{tr_acc:.4f} | Val L:{val_loss:.6f} A:{val_acc:.4f}")