In [None]:
# Download the KITTI 2015 scene-flow archive (about 1.6 GB).
# `-c` resumes a partial download and does nothing when the file is already
# complete, so re-running this cell does not fetch a second copy.
!wget -c https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow.zip

--2025-10-31 13:25:31--  https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow.zip
Resolving s3.eu-central-1.amazonaws.com (s3.eu-central-1.amazonaws.com)... 3.5.135.193, 3.5.135.93, 3.5.139.216, ...
Connecting to s3.eu-central-1.amazonaws.com (s3.eu-central-1.amazonaws.com)|3.5.135.193|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1681488619 (1.6G) [application/zip]
Saving to: ‘data_scene_flow.zip’


2025-10-31 13:30:16 (5.65 MB/s) - ‘data_scene_flow.zip’ saved [1681488619/1681488619]



In [None]:
# Extract the archive into ./data.
#   mkdir -p : do not fail when the directory already exists (cell is re-runnable)
#   unzip -n : never overwrite existing files, so a re-run does not stop at an
#              interactive "replace?" prompt
#   unzip -q : keep the per-file listing out of the cell output
!mkdir -p data
!unzip -q -n data_scene_flow.zip -d data

Archive:  data_scene_flow.zip
   creating: data/training/
   creating: data/training/disp_noc_1/
 extracting: data/training/disp_noc_1/000033_10.png  
 extracting: data/training/disp_noc_1/000076_10.png  
 extracting: data/training/disp_noc_1/000194_10.png  
 extracting: data/training/disp_noc_1/000137_10.png  
 extracting: data/training/disp_noc_1/000110_10.png  
 extracting: data/training/disp_noc_1/000188_10.png  
 extracting: data/training/disp_noc_1/000015_10.png  
 extracting: data/training/disp_noc_1/000155_10.png  
 extracting: data/training/disp_noc_1/000128_10.png  
 extracting: data/training/disp_noc_1/000162_10.png  
 extracting: data/training/disp_noc_1/000144_10.png  
 extracting: data/training/disp_noc_1/000052_10.png  
 extracting: data/training/disp_noc_1/000121_10.png  
 extracting: data/training/disp_noc_1/000193_10.png  
 extracting: data/training/disp_noc_1/000171_10.png  
 extracting: data/training/disp_noc_1/000101_10.png  
 extracting: data/training/disp_noc_1/0

In [None]:
import os, cv2, numpy as np
import time
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

## Load & Process data

In [None]:
class Kitti15Stereo(Dataset):
    """KITTI 2015 stereo pairs with ground-truth disparity from ``<root>/training``.

    Every ``*.png`` in ``training/image_2`` (left camera) is paired with the file
    of the same name in ``training/image_3`` (right camera) and in the
    ground-truth folder ``training/<disp_dir>``. Samples whose disparity map or
    right image is missing are skipped with a printed warning.

    Each item is a dict of NumPy arrays:
      * ``'left'`` / ``'right'``: float32 RGB in [0, 1], shape ``(3, H, W)``
      * ``'disp'``: float32 disparity, shape ``(H, W)``, expressed in pixels of
        the *returned* (possibly resized) image; 0 where the PNG stores 0.

    Args:
        root: Dataset directory that contains the ``training`` folder.
        split: Accepted for interface compatibility; currently ignored (all
            usable samples under ``training`` are returned).
        resize: ``(height, width)`` to resize every sample to, or a falsy value
            to keep the native resolution.
        augment: Stored on the instance; no augmentation is applied here.
        disp_dir: Ground-truth folder name. Defaults to ``'disp_occ_0'``, the
            disparity of the first-frame (``*_10``) stereo pair. ``'disp_occ_1'``
            holds the second time step warped into the reference frame (a
            scene-flow target) and does not match the stereo pair loaded here.
    """

    def __init__(self, root, split='train', resize=(256, 512), augment=False,
                 disp_dir='disp_occ_0'):
        train_dir = os.path.join(root, 'training')
        left_dir = os.path.join(train_dir, 'image_2')
        right_dir = os.path.join(train_dir, 'image_3')
        gt_dir = os.path.join(train_dir, disp_dir)

        # Build sibling paths from the file name rather than str.replace on the
        # full path, which would also rewrite an 'image_2' occurring in `root`.
        names = sorted(f for f in os.listdir(left_dir) if f.endswith('.png'))

        self.samples = []
        for name in names:
            l = os.path.join(left_dir, name)
            r = os.path.join(right_dir, name)
            d = os.path.join(gt_dir, name)
            if not os.path.exists(d):
                print(f"Warning: Skipping sample due to missing disparity image: {d}")
            elif not os.path.exists(r):
                print(f"Warning: Skipping sample due to missing right image: {r}")
            else:
                self.samples.append({'left': l, 'right': r, 'disp': d})

        self.resize = resize
        self.augment = augment

    def __len__(self):
        """Number of usable (left, right, disparity) triplets."""
        return len(self.samples)

    def _read_img(self, p):
        """Read an image as an RGB uint8 HxWx3 array, resized if ``self.resize`` is set.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``p``.
        """
        img = cv2.imread(p)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {p}")
        img = img[:, :, ::-1]  # BGR -> RGB

        if self.resize:
            img = cv2.resize(img, (self.resize[1], self.resize[0]),
                             interpolation=cv2.INTER_AREA)
        return img

    def _read_disp(self, p):
        """Read a KITTI disparity PNG as a float32 HxW map in pixels.

        When ``self.resize`` is set the map is resized with nearest-neighbour
        sampling (so valid and zero pixels are never blended) and the values
        are multiplied by ``new_width / original_width``: disparity is a
        horizontal pixel offset, so it shrinks or grows with the image width.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``p``.
        """
        raw = cv2.imread(p, cv2.IMREAD_UNCHANGED)
        if raw is None:
            raise FileNotFoundError(f"Could not read disparity map: {p}")

        # KITTI stores disparity as 256 * disparity
        disp = raw.astype(np.float32) / 256.0
        if self.resize:
            new_h, new_w = self.resize[0], self.resize[1]
            orig_w = disp.shape[1]
            disp = cv2.resize(disp, (new_w, new_h),
                              interpolation=cv2.INTER_NEAREST)
            disp = disp * np.float32(new_w / orig_w)
        return disp

    @staticmethod
    def _to_chw(img):
        """Scale a uint8 HWC image to float32 in [0, 1] and return it as contiguous CHW."""
        scaled = (img / 255.0).astype(np.float32)
        return np.ascontiguousarray(np.moveaxis(scaled, -1, 0))

    def __getitem__(self, i):
        """Load sample ``i`` and return ``{'left', 'right', 'disp'}`` as NumPy arrays."""
        sample = self.samples[i]
        L = self._to_chw(self._read_img(sample['left']))
        R = self._to_chw(self._read_img(sample['right']))
        D = self._read_disp(sample['disp'])
        return {'left': L, 'right': R, 'disp': D}

## Define Model

In [None]:
class Feature(nn.Module):
    """Per-image feature extractor: three 3x3 convolutions, each followed by ReLU.

    Every convolution uses stride 1 and padding 1, so the spatial size is kept:
    an input of shape (B, 3, H, W) produces features of shape (B, channels, H, W).

    Args:
        channels: Number of output channels of every convolution.
    """

    def __init__(self, channels=32):
        super().__init__()
        widths = [3, channels, channels, channels]
        stages = []
        # Layers are appended in conv, relu, conv, relu, ... order so the
        # Sequential indices (0..5) stay the same.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                    kernel_size=3, stride=1, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the feature map computed from image batch ``x``."""
        features = self.net(x)
        return features


class Cost3D(nn.Module):
    """Concatenation cost volume, 3-D conv aggregation and soft-argmin regression.

    For every candidate disparity ``d`` in ``[0, max_disp)`` the left feature at
    column ``x`` is paired with the right feature at column ``x - d``; columns
    ``x < d`` (no counterpart inside the right image) stay zero. The resulting
    (B, 2C, D, H, W) volume is reduced to one cost per disparity by three 3-D
    convolutions, turned into a distribution with ``softmax(-cost)`` over the
    disparity axis, and collapsed to its expected value.

    Args:
        max_disp: Number of disparity hypotheses (0 .. max_disp - 1).
        c: Channel count of each input feature map.
    """

    def __init__(self, max_disp=96, c=32):
        super().__init__()
        self.max_disp = max_disp
        self.agg = nn.Sequential(
            nn.Conv3d(in_channels=2 * c, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv3d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, fL, fR):
        """Regress a disparity map from left/right feature maps.

        Args:
            fL: Left features, shape (B, C, H, W).
            fR: Right features, same shape as ``fL``.

        Returns:
            Tensor of shape (B, H, W) with values in ``[0, max_disp - 1]``.
        """
        B, C, H, W = fL.shape

        # Pre-allocating the volume avoids holding every slice twice (list +
        # stack) and leaves the out-of-image columns at zero.
        cost = fL.new_zeros(B, 2 * C, self.max_disp, H, W)

        # A disparity of d >= W has no valid column at all; those slices stay
        # zero, which also keeps max_disp > W from failing.
        for d in range(min(self.max_disp, W)):
            # The left feature stays at its own column x (the output is in the
            # left image's frame); the right feature is shifted right by d so
            # column x holds fR[x - d].
            cost[:, :C, d, :, d:] = fL[:, :, :, d:]
            cost[:, C:, d, :, d:] = fR[:, :, :, :W - d]

        out = self.agg(cost).squeeze(1)  # B x D x H x W
        prob = F.softmax(-out, dim=1)    # low cost -> high probability
        disp_values = torch.arange(self.max_disp, device=prob.device, dtype=prob.dtype)
        disp = torch.sum(prob * disp_values.view(1, -1, 1, 1), dim=1)
        return disp


class MiniPSM(nn.Module):
    """Small stereo network: one shared feature extractor plus a cost-volume head.

    Args:
        max_disp: Number of disparity hypotheses passed to ``Cost3D``.
        c: Feature channel count used by both ``Feature`` and ``Cost3D``.
    """

    def __init__(self, max_disp=96, c=32):
        super().__init__()
        self.feat = Feature(channels=c)
        self.cost = Cost3D(max_disp=max_disp, c=c)

    def forward(self, L, R):
        """Run the same extractor on the left and right images and regress disparity."""
        left_feat = self.feat(L)
        right_feat = self.feat(R)
        disparity = self.cost(left_feat, right_feat)
        return disparity

In [None]:
def epe_loss(pred, gt, mask=None):
    """Mean absolute disparity error (end-point error) over the selected pixels.

    Args:
        pred: Predicted disparity tensor.
        gt: Ground-truth disparity tensor, same shape as ``pred``.
        mask: Optional boolean tensor selecting the pixels to score. Defaults
            to ``gt > 0``.

    Returns:
        Scalar tensor. When the mask selects no pixel the result is 0 (still
        attached to ``pred``'s graph) instead of the NaN that the mean of an
        empty selection would give.
    """
    if mask is None:
        mask = gt > 0
    if not mask.any():
        # Nothing to supervise: return a zero that backward() can still
        # traverse, yielding zero gradients rather than a NaN loss.
        return pred.sum() * 0.0
    return torch.mean(torch.abs(pred[mask] - gt[mask]))


# Seed the RNG so weight initialisation and the shuffling order are repeatable
# from one "Restart & Run All" to the next.
torch.manual_seed(0)

# Relative path: the same directory the archive is extracted into
# (`unzip ... -d data`), so the notebook does not depend on running in /content.
train_ds = Kitti15Stereo('data')
# Fail early with a clear message; an empty dataset otherwise surfaces as a
# sampler error inside DataLoader.
assert len(train_ds) > 0, "No training samples found under 'data/training'"
train_ld = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MiniPSM(max_disp=96, c=32).to(device)
opt = optim.Adam(model.parameters(), lr=2e-4)
print(model)





MiniPSM(
  (feat): Feature(
    (net): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ReLU(inplace=True)
    )
  )
  (cost): Cost3D(
    (agg): Sequential(
      (0): Conv3d(64, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU(inplace=True)
      (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): ReLU(inplace=True)
      (4): Conv3d(32, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    )
  )
)


## Training

In [None]:
# Train for a fixed number of epochs, reporting the mean EPE and wall time per epoch.
num_epoch = 5
for epoch in range(num_epoch):
    start_time = time.perf_counter()
    model.train()
    losses = []
    for b in train_ld:
        # The DataLoader's default collate already returns tensors, so move
        # them directly; torch.tensor(<tensor>) makes an extra copy and emits
        # a UserWarning.
        L = b['left'].to(device)
        R = b['right'].to(device)
        D = b['disp'].to(device)

        # Score only pixels that have ground truth and whose disparity the
        # network can express: its output is an expectation over
        # 0 .. max_disp - 1, so larger targets are unreachable.
        valid = (D > 0) & (D < model.cost.max_disp)
        if not valid.any():
            continue  # nothing to learn from in this batch

        opt.zero_grad()
        P = model(L, R)
        loss = epe_loss(P, D, valid)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    end_time = time.perf_counter()
    # Guard the average: every batch may have been skipped.
    mean_loss = sum(losses) / len(losses) if losses else float('nan')
    print(f"Epoch {epoch + 1} / {num_epoch} | "
          f"Loss: {mean_loss:.4f} | "
          f"Training time: {end_time - start_time}")

  L = torch.tensor(b['left']).to(device)
  R = torch.tensor(b['right']).to(device)
  D = torch.tensor(b['disp']).to(device)


Epoch 1 / 5 | Loss: 18.0483 | Training time: 664.706567427
Epoch 2 / 5 | Loss: 16.3127 | Training time: 664.0912839550001
Epoch 3 / 5 | Loss: 15.3379 | Training time: 664.50224495
Epoch 4 / 5 | Loss: 15.0644 | Training time: 664.876488418
Epoch 5 / 5 | Loss: 15.0418 | Training time: 663.6434453010002
