In [10]:
import os, glob, pickle, json, time

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import TensorDataset, DataLoader
from torchvision import models
from torchvision import transforms
from tqdm import tqdm

In [11]:
# — 모델 정의 —
def get_deeplab_model(num_classes: int) -> nn.Module:
    """Create a DeepLabV3-ResNet101 network whose final layer emits ``num_classes`` channels.

    The network is built with ``pretrained=True`` and element 4 of its
    ``classifier`` is swapped for a new 1x1 convolution taking 256 input
    channels.

    Args:
        num_classes: Number of output channels of the replacement layer.

    Returns:
        The modified segmentation model.
    """
    net = models.segmentation.deeplabv3_resnet101(pretrained=True)
    # Replace only the last classifier layer; everything before it is kept.
    replacement_head = nn.Conv2d(in_channels=256, out_channels=num_classes, kernel_size=1)
    net.classifier[4] = replacement_head
    return net

class CBAMBlock(nn.Module):
    """Channel-attention gate: rescales each channel of the input by a learned weight.

    The input is globally average-pooled and max-pooled to one value per
    channel, both pooled vectors go through the same two-layer 1x1-conv
    bottleneck, the two results are summed, squashed with a sigmoid and
    multiplied back onto the input.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor applied to ``channels`` to get the bottleneck width.
            The width is clamped to at least 1 so that ``channels < reduction``
            does not create a zero-channel convolution.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        # channels // reduction is 0 for small inputs; a 0-wide bottleneck
        # cannot carry any signal, so never go below one hidden channel.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        """Return ``x`` scaled per channel by the attention weights (same shape as ``x``)."""
        avg = self.fc(self.avg_pool(x))
        maxv = self.fc(self.max_pool(x))
        attn = self.sigmoid(avg + maxv)
        return x * attn

class ResNetCBAM(nn.Module):
    """ResNet-50 classifier with a CBAMBlock appended to ``layer1`` and ``layer2``.

    The backbone is created with ``pretrained=True``; its ``fc`` layer is
    replaced by a new linear layer producing ``num_classes`` outputs.
    """
    def __init__(self, num_classes: int = 3):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        # (stage attribute, name of the appended module, channel width of that stage's output)
        attention_sites = (('layer1', 'cbam1', 256), ('layer2', 'cbam2', 512))
        for stage_attr, module_name, width in attention_sites:
            getattr(resnet, stage_attr).add_module(module_name, CBAMBlock(width))
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
        self.backbone = resnet
    def forward(self, x):
        """Run the modified backbone on ``x`` and return its output."""
        logits = self.backbone(x)
        return logits

In [12]:
# — 유틸리티 함수 —
def load_annotations(path, retries=3, delay=1):
    """Read a UTF-8 JSON file, retrying when the file cannot be opened or read.

    Args:
        path: Path of the JSON file.
        retries: Maximum number of attempts; values below 1 are treated as 1.
        delay: Seconds to wait between two attempts.

    Returns:
        The decoded JSON content.

    Raises:
        OSError: The error of the last attempt, once all attempts have failed.
        ValueError: Propagated immediately if the file content is not valid
            JSON (``json.JSONDecodeError``); this is not retried.
    """
    attempts = max(1, retries)
    last_error = None
    for attempt in range(attempts):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except OSError as exc:
            # Keep the exception: a bare `raise` after the loop has no active
            # exception to re-raise and would fail with RuntimeError instead.
            last_error = exc
            if attempt < attempts - 1:
                time.sleep(delay)
    raise last_error

def create_mask(bboxes, shape):
    """Rasterise boxes into a binary mask.

    Args:
        bboxes: Iterable of 4-item boxes, unpacked as ``(x, y, w, h)``; each
            is drawn as the rectangle from ``(x, y)`` to ``(x + w, y + h)``.
        shape: Sequence whose first two items are used as (height, width).

    Returns:
        An int64 array of that height and width: 1 inside any box, 0 elsewhere.
    """
    height, width = shape[0], shape[1]
    canvas = Image.new('L', (width, height), 0)  # PIL sizes are (width, height)
    pen = ImageDraw.Draw(canvas)
    for left, top, box_w, box_h in bboxes:
        corners = [left, top, left + box_w, top + box_h]
        pen.rectangle(corners, fill=1)
    return np.array(canvas, dtype=np.int64)

def warp_roi(img_t, mask_t, bboxes, size=(224,224)):
    """Warp the largest foreground region of a mask into a fixed-size crop.

    The image is de-normalised with the per-channel mean/std constants below,
    the largest external contour of ``mask_t > 0`` is reduced to four corner
    points, and the quadrilateral they span is perspective-warped to ``size``.
    The crop is normalised again with the same constants before it is returned.

    Args:
        img_t: Image tensor indexed as (3, H, W), normalised with the
            mean/std constants used below.
        mask_t: 2-D tensor; values > 0 are foreground. If its size differs
            from the image's (H, W) it is resized (nearest neighbour) so that
            contour coordinates refer to image pixels.
        bboxes: Unused; kept so existing callers keep working.
        size: Output size passed to ``cv2.warpPerspective`` as (width, height).

    Returns:
        ``(roi, [])`` where ``roi`` is a normalised float tensor on
        ``img_t``'s device, or ``(None, [])`` when the mask has no foreground
        or the four selected corners are not distinct.
    """
    mean = img_t.new_tensor([0.485,0.456,0.406]).view(3,1,1)
    std = img_t.new_tensor([0.229,0.224,0.225]).view(3,1,1)

    # Undo the normalisation and convert to a uint8 HWC array.
    pixels = ((img_t*std+mean)*255).clamp(0,255).byte().cpu().permute(1,2,0).numpy()
    # Reverse the channel axis (presumably RGB -> BGR for OpenCV); copy so the
    # array has ordinary positive strides.
    np_img = np.ascontiguousarray(pixels[...,::-1])

    m = (mask_t.cpu().numpy()>0).astype(np.uint8)
    img_h, img_w = np_img.shape[:2]
    if m.shape[:2] != (img_h, img_w):
        # Contour points are used as image coordinates below, so the mask must
        # live on the same pixel grid as the image.
        m = cv2.resize(m, (img_w, img_h), interpolation=cv2.INTER_NEAREST)

    # [-2] picks the contour list for both the 2-value and the 3-value return
    # signature of cv2.findContours.
    cnts = cv2.findContours(m,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not cnts: return None,[]
    c = max(cnts, key=cv2.contourArea)
    if len(c) >= 4:
        pts = cv2.approxPolyDP(c,0.02*cv2.arcLength(c,True),True).reshape(-1,2)
    else:
        pts = cv2.boxPoints(cv2.minAreaRect(c))

    # Corner selection: x+y is smallest at top-left and largest at
    # bottom-right; y-x is smallest at top-right and largest at bottom-left.
    s,d = pts.sum(1), np.diff(pts,axis=1).ravel()
    tl,br,tr,bl = pts[np.argmin(s)], pts[np.argmax(s)], pts[np.argmin(d)], pts[np.argmax(d)]

    # cv2.getPerspectiveTransform only accepts float32 points.
    src = np.array([tl,tr,br,bl], dtype=np.float32)
    if len(np.unique(src, axis=0)) < 4:
        # Fewer than four distinct corners cannot define a perspective warp.
        return None,[]
    dst = np.array([[0,0],[size[0]-1,0],[size[0]-1,size[1]-1],[0,size[1]-1]], dtype=np.float32)
    H = cv2.getPerspectiveTransform(src, dst)
    wimg = cv2.warpPerspective(np_img, H, tuple(size))

    # torch.from_numpy rejects negative strides, so materialise the
    # channel-reversed view before wrapping it in a tensor.
    restored = np.ascontiguousarray(wimg[...,::-1])
    roi = torch.from_numpy(restored).permute(2,0,1).float().to(img_t.device)/255
    return (roi-mean)/std, []

In [13]:
# — Path configuration and annotation loading —
DATA_ROOT = '/content/drive/MyDrive/pt_data'
AIHUB_ANN = os.path.join(DATA_ROOT, 'aihub_annotations.json')
RDD_ANN   = os.path.join(DATA_ROOT, 'rdd2022_train_annotations.json')
# Single lookup over both annotation files; if a key occurs in both, the
# RDD2022 entry replaces the AIHub entry.
ann_dict  = {**load_annotations(AIHUB_ANN), **load_annotations(RDD_ANN)}

# — Batch file lists —
# glob.glob returns matches in arbitrary (filesystem-dependent) order; sort
# each list so the files are visited in the same order on every run.
train_pkls = sorted(glob.glob(os.path.join(DATA_ROOT, 'AIhub_Road',    'training_image_batch_*.pkl'))) + \
             sorted(glob.glob(os.path.join(DATA_ROOT, 'RDD2022',        'training_image_batch_*.pkl')))
val_pkls   = sorted(glob.glob(os.path.join(DATA_ROOT, 'AIhub_Road',    'validation_image_batch_*.pkl'))) + \
             sorted(glob.glob(os.path.join(DATA_ROOT, 'RDD2022',        'validation_image_batch_*.pkl')))

In [14]:
# — Transform definitions —
# Values shared by both pipelines.
_INPUT_SIZE = (224, 224)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> resized tensor -> per-channel normalisation.
seg_tf = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Same pipeline, preceded by a conversion to a PIL image.
roi_tf = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

In [20]:
# — 훈련 함수 —
def _load_seg_file(pklf, batch_size, shuffle, pin_memory):
    """Turn one pickled batch file into a DataLoader of (image, mask) tensors.

    Returns ``None`` when the file contains no images.
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(pklf,'rb') as f:
        batch = pickle.load(f)
    # NOTE(review): for list-type batches the lookup key is the element's
    # position inside the file; confirm ann_dict is keyed that way, otherwise
    # every mask built from such a file is empty.
    items = batch.items() if isinstance(batch,dict) else [(i,e['image']) for i,e in enumerate(batch)]
    imgs, msks = [], []
    for key, img in items:
        arr = img.numpy() if hasattr(img,'numpy') else np.array(img)
        # convert('RGB') keeps 3-channel images unchanged and lets greyscale /
        # RGBA inputs pass the 3-channel Normalize in seg_tf.
        img_t = seg_tf(Image.fromarray(arr.astype(np.uint8)).convert('RGB'))
        mask = create_mask(ann_dict.get(key,{}).get('bboxes',[]), arr.shape[:2])
        out_h, out_w = img_t.shape[-2:]
        if mask.shape != (out_h, out_w):
            # seg_tf resizes the image, so the mask has to follow; nearest
            # neighbour keeps the values in {0, 1}.
            mask = np.array(
                Image.fromarray(mask.astype(np.uint8)).resize((out_w, out_h), Image.NEAREST),
                dtype=np.int64)
        imgs.append(img_t)
        msks.append(torch.from_numpy(mask).long())
    if not imgs:
        return None
    # The tensors are already in memory, so worker processes would only add
    # start-up cost for every file.
    return DataLoader(TensorDataset(torch.stack(imgs), torch.stack(msks)),
                      batch_size=batch_size, shuffle=shuffle,
                      num_workers=0, pin_memory=pin_memory)


def train_seg(seg_epochs=5, seg_bs=16):
    """Train the 2-class DeepLab segmentation model, one pickle file at a time.

    Every epoch runs a training pass over ``train_pkls`` and a validation pass
    over ``val_pkls``, saves the weights to ``deeplab_ep{epoch}.pth`` and
    prints the mean training / validation loss (``nan`` when a pass produced
    no batches).

    Args:
        seg_epochs: Number of epochs.
        seg_bs: Batch size used for both passes.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Mixed precision and gradient scaling are only meaningful on CUDA.
    use_amp = device.type == 'cuda'
    scaler_s = GradScaler(enabled=use_amp)

    seg_model = get_deeplab_model(2).to(device)
    seg_opt   = torch.optim.Adam(seg_model.parameters(), lr=1e-4)
    seg_loss  = nn.CrossEntropyLoss()

    # — Segmentation loop —
    for epoch in range(1, seg_epochs+1):
        # train
        seg_model.train()
        total_loss, steps = 0.0, 0
        for pklf in tqdm(train_pkls, desc=f"Seg Train Ep{epoch}", unit="file"):
            loader = _load_seg_file(pklf, seg_bs, True, use_amp)
            if loader is None:
                continue
            for x,y in loader:
                if x.size(0) < 2:
                    # BatchNorm after DeepLab's global pooling cannot compute
                    # batch statistics from a single sample in train mode.
                    continue
                x,y = x.to(device), y.to(device)
                seg_opt.zero_grad()
                with autocast(enabled=use_amp):
                    out = seg_model(x)['out']
                    loss = seg_loss(out, y)
                scaler_s.scale(loss).backward()
                scaler_s.step(seg_opt)
                scaler_s.update()
                total_loss += loss.item(); steps += 1
        avg_tr = total_loss/steps if steps else float('nan')

        # val
        seg_model.eval()
        v_loss, v_steps = 0.0, 0
        with torch.no_grad():
            for pklf in tqdm(val_pkls, desc=f"Seg Val Ep{epoch}", unit="file"):
                loader = _load_seg_file(pklf, seg_bs, False, use_amp)
                if loader is None:
                    continue
                for x,y in loader:
                    x,y = x.to(device), y.to(device)
                    with autocast(enabled=use_amp):
                        out = seg_model(x)['out']
                        loss = seg_loss(out, y)
                    v_loss += loss.item(); v_steps += 1
        avg_val = v_loss/v_steps if v_steps else float('nan')

        torch.save(seg_model.state_dict(), f"deeplab_ep{epoch}.pth")
        print(f"Seg Ep{epoch} | Train:{avg_tr:.4f} | Val:{avg_val:.4f}")

In [21]:
def _extract_rois(pkl_files, desc):
    """Collect warped ROI tensors and their labels from pickled batch files.

    For every image a mask is drawn from its annotated boxes and handed to
    ``warp_roi``; images for which no ROI is returned are skipped. The label
    is the first entry of the image's ``labels`` annotation, or 0 when that
    entry is missing or empty.

    Returns:
        ``(rois, labels)``: a list of CPU tensors and a list of ints.
    """
    rois, labels = [], []
    for pklf in tqdm(pkl_files, desc=desc, unit="file"):
        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(pklf,'rb') as f:
            batch = pickle.load(f)
        items = batch.items() if isinstance(batch,dict) else [(i,e['image']) for i,e in enumerate(batch)]
        for key, img in items:
            arr = img.numpy() if hasattr(img,'numpy') else np.array(img)
            ann = ann_dict.get(key,{})
            mask = torch.from_numpy(create_mask(ann.get('bboxes',[]), arr.shape[:2]))
            # warp_roi copies its inputs to the CPU for the OpenCV work, so
            # the tensors are passed on the CPU instead of via the GPU.
            img_t = seg_tf(Image.fromarray(arr.astype(np.uint8)).convert('RGB'))
            roi, _ = warp_roi(img_t, mask, [])
            if roi is None:
                continue
            # `or [0]` also covers an empty labels list, which would
            # otherwise raise IndexError on [0].
            img_labels = ann.get('labels') or [0]
            rois.append(roi.cpu())
            labels.append(int(img_labels[0]))
    return rois, labels


def train_cls(cls_epochs=5, cls_bs=16):
    """Train the 3-class ResNetCBAM classifier on ROIs cut out by ``warp_roi``.

    ROIs are extracted once from ``train_pkls`` and ``val_pkls``. Training
    only runs when both sets are non-empty; each epoch then saves the weights
    to ``resnetcbam_ep{epoch}.pth`` and prints the mean losses.

    Args:
        cls_epochs: Number of epochs.
        cls_bs: Batch size for training and validation.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Mixed precision and gradient scaling are only meaningful on CUDA.
    use_amp = device.type == 'cuda'
    scaler_c = GradScaler(enabled=use_amp)
    cls_model = ResNetCBAM(num_classes=3).to(device)
    cls_opt   = torch.optim.Adam(cls_model.parameters(), lr=1e-4)
    cls_loss  = nn.CrossEntropyLoss()

    # — ROI extraction —
    rois, labels = _extract_rois(train_pkls, "ROI Extract")
    if not rois:
        print("No ROIs extracted for training. Skipping classification training.")
        train_ld = None
    else:
        # warp_roi already returns normalised tensors of the final size.
        # Sending them through roi_tf would feed normalised floats to
        # ToPILImage (which expects values in [0, 1]) and normalise twice.
        X = torch.stack(rois)
        Y = torch.tensor(labels)
        train_ld = DataLoader(TensorDataset(X, Y), batch_size=cls_bs, shuffle=True,
                              num_workers=4, pin_memory=use_amp)

    val_rois, val_labels = _extract_rois(val_pkls, "Val ROI Extract")
    if not val_rois:
        print("No ROIs extracted for validation. Skipping classification validation.")
        val_ld = None
    else:
        val_X = torch.stack(val_rois)
        val_Y = torch.tensor(val_labels)
        val_ld = DataLoader(TensorDataset(val_X, val_Y), batch_size=cls_bs, shuffle=False,
                            num_workers=4, pin_memory=use_amp)

    # — Classification loop —
    if train_ld is None or val_ld is None:
        print("Skipping classification training and validation due to no extracted ROIs.")
        return

    for epoch in range(1, cls_epochs+1):
        cls_model.train()
        total_loss, steps = 0.0, 0
        for xb,yb in tqdm(train_ld, desc=f"Cls Train Ep{epoch}", unit="batch"):
            xb,yb = xb.to(device), yb.to(device)
            cls_opt.zero_grad()
            with autocast(enabled=use_amp):
                logits = cls_model(xb)
                loss = cls_loss(logits, yb)
            scaler_c.scale(loss).backward()
            scaler_c.step(cls_opt)
            scaler_c.update()
            total_loss += loss.item(); steps += 1
        avg_tr = total_loss/steps

        cls_model.eval()
        v_loss, v_steps = 0.0, 0
        with torch.no_grad():
            for xb,yb in tqdm(val_ld, desc=f"Cls Val Ep{epoch}", unit="batch"):
                xb,yb = xb.to(device), yb.to(device)
                with autocast(enabled=use_amp):
                    logits = cls_model(xb)
                    loss = cls_loss(logits, yb)
                v_loss += loss.item(); v_steps += 1
        avg_val = v_loss/v_steps

        torch.save(cls_model.state_dict(), f"resnetcbam_ep{epoch}.pth")
        print(f"Cls Ep{epoch} | Train:{avg_tr:.4f} | Val:{avg_val:.4f}")

In [None]:
# Run segmentation training with the function's default arguments
# (seg_epochs=5, seg_bs=16).
train_seg()

  scaler_s = GradScaler()
  with open(pklf,'rb') as f: batch = pickle.load(f)
  with autocast():
Seg Train Ep1:   2%|▏         | 1/42 [00:32<21:52, 32.01s/file]