In [1]:
%cd /home/yokoyama/research
from types import SimpleNamespace
import sys
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

sys.path.append(".")
from modules.utils.video import Capture, Writer
from modules.pose import PoseDataHandler


/raid6/home/yokoyama/research


In [2]:
from submodules.i3d.pytorch_i3d import InceptionI3d
from torchvision.ops import RoIAlign
from torchvision.transforms import ToTensor


In [3]:
video_num = 1
cap = Capture(f"/raid6/home/yokoyama/datasets/dataset01/train/{video_num:02d}.mp4")
pose_data = PoseDataHandler.load(f"data/dataset01/train/{video_num:02d}", ["bbox"])
flows_raw = np.load("data/dataset01/train/01/bin/flow.npy")


In [4]:
from typing import List, Dict, Any
import time


class Dataset(Dataset):
    def __init__(self, caps: List[Capture], flows_lst: List[NDArray], pose_data: List[Dict[str, Any]] , seq_len: int, resize_ratio: float):
        self._default_float_dtype = torch.get_default_dtype()
        self._seq_len = seq_len
        self._resize_ratio = resize_ratio
        self._frames = []
        self._flows = []
        self._bboxs = []
        self._max_bboxs_num = 0
        self._create_dataset(caps, flows_lst, pose_data)

        self._frames = self._transform_imgs(self._frames)
        self._flows = self._transform_imgs(self._flows)

    def _create_dataset(self, caps: List[Capture], flows_lst: List[NDArray], pose_datas: List[List[Dict[str, Any]]]):
        for cap, flows, pose_data in zip(tqdm(caps, ncols=100), flows_lst, pose_datas):
            self._load_frames(cap)
            self._resize_flows(flows)
            self._load_bbox(pose_data, cap.frame_count)

        # calc max number of bboxs in each frame
        for bboxs in self._bboxs:
            if len(bboxs) > self._max_bboxs_num:
                self._max_bboxs_num = len(bboxs)

    def _load_frames(self, cap):
        frames = []
        for _ in tqdm(range(cap.frame_count), ncols=100, leave=False):
        # for _ in tqdm(range(100), ncols=100):
            frame = cap.read()[1]
            frame = cv2.resize(frame, None, fx=self._resize_ratio, fy=self._resize_ratio)
            frames.append(frame)
        self._frames += frames

    def _resize_flows(self, flows):
        flows_resized = []
        for flow in tqdm(flows, ncols=100, leave=False):
            flows_resized.append(cv2.resize(flow, None, fx=self._resize_ratio, fy=self._resize_ratio))
        self._flows += flows_resized

    def _load_bbox(self, pose_data: List[Dict[str, Any]], frame_count: int):
        for frame_num in tqdm(range(1, frame_count + 1), ncols=100, leave=False):
        # for frame_num in tqdm(range(1, 100 + 1), ncols=100, leave=False):
            bboxs = [
                np.array(data["bbox"]) * self._resize_ratio for data in pose_data
                if data["frame"] == frame_num
            ]
            bboxs = np.array(bboxs)
            self._bboxs.append(bboxs)

    def _transform_imgs(self, imgs):
        # imgs = np.array(imgs)
        imgs = torch.tensor(np.array(imgs).transpose((0, 3, 1, 2)), dtype=self._default_float_dtype).contiguous()
        # imgs = imgs.to(dtype=self._default_float_dtype)
        if isinstance(imgs, torch.ByteTensor):
            return (imgs /255.) * 2 - 1
        else:
            return imgs

    @property
    def n_samples(self):
        return len(self) * self._max_bboxs_num

    @property
    def n_samples_batch(self):
        return self._max_bboxs_num

    def __len__(self):
        return len(self._frames) - self._seq_len + 1

    def __getitem__(self, idx):
        frames_seq = self._frames[idx:idx + self._seq_len].transpose(1, 0)
        flows_seq = self._flows[idx:idx + self._seq_len].transpose(1, 0)
        bboxs = self._bboxs[idx + (self._seq_len) // 2 + 1]
        # append dmy bboxs
        if len(bboxs) < self._max_bboxs_num:
            diff_num = self._max_bboxs_num - len(bboxs)
            dmy_bboxs = [np.full((4,), np.nan) for _ in range(diff_num)]
            bboxs = np.append(bboxs, dmy_bboxs, axis=0)
        bboxs = torch.Tensor(bboxs)
        return frames_seq, flows_seq, bboxs, idx


In [5]:
seq_len = 30
resize_ratio = 0.5
device = "cuda:9"
batch_size = 64
dataset = Dataset([cap], [flows_raw], [pose_data], seq_len, resize_ratio)
del flows_raw, pose_data, cap


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:27<00:00, 27.91s/it]


In [6]:
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [7]:
batch = next(iter(dataloader))


In [8]:
for data in batch:
    print(data.shape)


torch.Size([64, 3, 10, 470, 640])
torch.Size([64, 2, 10, 470, 640])
torch.Size([64, 7, 4])
torch.Size([64])


In [9]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self._i3d_frame = InceptionI3d(in_channels=3, final_endpoint="Mixed_3c")
        self._i3d_frame.build()
        self._i3d_flow = InceptionI3d(in_channels=2, final_endpoint="Mixed_3c")
        self._i3d_flow.build()
        self._roi_align = RoIAlign(5, 0.125, 1, aligned=True)

    def forward(self, frames, flows, bboxs):
        # forward i3d
        for end_point in self._i3d_frame.VALID_ENDPOINTS:
            if end_point in self._i3d_frame.end_points:
                frames = self._i3d_frame._modules[end_point](frames)
                flows = self._i3d_flow._modules[end_point](flows)
        f = frames + flows

        # format bbox
        h, w = frames.shape[3:5]
        fy, fx = f.shape[3:5]
        b = bboxs.shape[0]
        bboxs = bboxs.view(-1, 2, 2)
        bboxs *= torch.Tensor((fx / w, fy / h))
        bboxs = bboxs.view(b, -1, 4)
        bboxs = self._convert_bboxes_to_roi_format(bboxs)
        bboxs = bboxs.to(torch.float32)

        # roi align
        return self._roi_align(f, bboxs)

    def _convert_bboxes_to_roi_format(self, boxes: torch.Tensor) -> torch.Tensor:
        concat_boxes = torch.cat([b for b in boxes], dim=0)
        temp = []
        for i, b in enumerate(boxes):
            temp.append(torch.full_like(b[:, :1], i))
        ids = torch.cat(temp, dim=0)
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

class Decoder(nn.Module):
    def __init__(self, ngf=64, nc=5):
        super().__init__()
        self.net = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(480, ngf * 8, 4, 3, (2, 0), bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.LeakyReLU(0.1, True),
            # state size. ``(ngf*8) x 12 x 16``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.LeakyReLU(0.1, True),
            # state size. ``(ngf*4) x 24 x 32``
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.LeakyReLU(0.1, True),
            # state size. ``(ngf*2) x 48 x 64``
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.LeakyReLU(0.1, True),
            # state size. ``(ngf) x 96 x 128``
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. ``(nc) x 192 x 256``
        )

    def forward(self, z):
        n = z.shape[0]
        out = self.net(z)
        out = out.view(n, 5, 192, 256)
        return out[:, :3,], out[:, 3:]  # frame, flow


class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self._encoder = Encoder()
        self._decoder = Decoder()

    @property
    def E(self):
        return self._encoder

    @property
    def D(self):
        return self._decoder

    def forward(self, frames, flows, bboxs):
        z = self._encoder(frames, flows, bboxs)
        frames_d, flows_d = self._decoder(z)

        # adjust shapes
        b, n = bboxs.shape[:2]
        c, sy, sx = z.shape[1:]
        z = z.view(b, n, c, sy, sx)
        frames_d = frames_d.view(b, n, 3, 192, 256)
        flows_d = flows_d.view(b, n, 2, 192, 256)

        return z, frames_d, flows_d


class ClusteringModule(nn.Module):
    def __init__(self, n_clusters, n_samples):
        super().__init__()
        self._n_clusters = n_clusters
        self._n_samples = n_samples
        self._t_alpha = 1
        self._dz = 20
        self._centroids = nn.ParameterList(
            [nn.Parameter(torch.randn((self._dz), dtype=torch.float32)) for _ in range(n_clusters)]
        )
        self._target_distribution = None
        self.clear_target_disribution()

        self._emb = nn.Sequential(
            nn.Flatten(),
            nn.Linear(480 * 5 * 5, self._dz),
        )

    @property
    def centroids(self):
        return self._centroids

    @property
    def target_distribution(self):
        return self._target_distribution

    def forward(self, z):
        b, sn = z.shape[:2]
        z = z.view(b * sn, -1)
        z = self._emb(z)
        s = self._student_t(z)
        s = s.view(b, -1, self._n_clusters)
        c = s.argmax(dim=2)
        return s, c

    def _student_t(self, z):
        sn = z.shape[0]
        norm = torch.full((sn, self._n_clusters), torch.nan, dtype=torch.float32)
        for j in range(self._n_clusters):
            norm[:, j] = torch.linalg.vector_norm(z - self._centroids[j], dim=1)

        s = torch.full((sn, self._n_clusters), torch.nan, dtype=torch.float32)
        for i in range(sn):
            s[i] = ((1 + norm[i]) / self._t_alpha)**-((self._t_alpha + 1) / 2)
        s = (s.T / s.sum(dim=1)).T

        return s

    def clear_target_disribution(self):
        self._target_distribution = torch.full((self._n_samples, self._n_clusters), torch.nan)

    def update_target_distribution(self, s, batch_idxs):
        sample_nums = s.shape[1]
        s = s.view(-1, self._n_clusters)
        s_sums = s.nan_to_num(0).sum(dim=0)  # Sigma_i s_ij (n_clusters,)

        for i, batch_idx in enumerate(batch_idxs):
            for sn in range(sample_nums):
                ti = batch_idx * self._n_clusters + sn  # target idx
                si = i * self._n_clusters + sn  # soft idx
                for j in range(self._n_clusters):
                    sij = s[si, j]
                    self._target_distribution[ti, j] = sij**2 / s_sums[j]
                self._target_distribution[ti, j] /= self._target_distribution[ti, j].sum(dim=0)


In [10]:
n_clusters = 5
ae = Autoencoder().to(device)
cm = ClusteringModule(n_clusters, dataset.n_samples).to(device)

loss_r = nn.MSELoss().to(device)
loss_c = nn.KLDivLoss(reduction="sum").to(device)

optim_e = torch.optim.Adam(ae.E.parameters(), 0.0001)
optim_d = torch.optim.Adam(ae.D.parameters(), 0.0001)
optim_c = torch.optim.Adam(cm.parameters(), 0.0001)


In [11]:
epoch_num = 100
update_interval = 10

history = {
    "lr": [],
    "lc": [],
    "le": [],
    "c": [],
}
for epoch in tqdm(range(epoch_num), ncols=100):
    ae.train()
    cm.train()
    c_epoch = torch.full((dataset.n_samples,), torch.nan).to(device)
    for batch_idx, (frames_batch, flows_batch, bboxs_batch, batch_idxs) in enumerate(tqdm(dataloader, ncols=100, leave=False)):
        frames_batch, flows_batch = frames_batch.to(device), flows_batch.to(device)
        bboxs_batch, batch_idxs = bboxs_batch.to(device), batch_idxs.to(device)

        optim_e.zero_grad()
        optim_d.zero_grad()
        optim_c.zero_grad()

        z, frames_out, flows_out = ae(frames_batch, flows_batch, bboxs_batch)
        s, c = cm(z)

        for i, batch_idx in enumerate(batch_idxs):
            for j in range(dataset.n_samples_batch):
                idx = batch_idx * dataset.n_samples_batch + j
                c_epoch[idx] = c[i, j]

        lr_total = 0
        for i in range(batch_size):
            for j in range(dataset.n_samples_batch):
                bx = bboxs_batch[i, j]
                x1, y1, x2, y2 = bx
                frame_bbox = frames_batch[i, seq_len // 2 + 1, y1:y2, x1:x2]
                lr_total += loss_r(frames_out[i, j], frame_bbox)
        lr_total.backward()
        optim_d.step()

        lc_total = 0
        for i, batch_idx in enumerate(batch_idxs):
            idx = batch_idx * dataset.n_samples_batch
            tmp_target = cm.target_distribution[idx:idx + dataset.n_samples_batch]
            tmp_target = torch.nan_to_num(tmp_target, 0)
            s_tmp = torch.nan_to_num(s[i], 0)
            lc_total += loss_c(s_tmp.log(), tmp_target)
        lc_total.backward()

        optim_c.step()

        le = lr_total + lc_total
        le.backward()
        optim_e.step()

        history["lr"].append(lr_total.cpu())
        history["lc"].append(lc_total.cpu())
        history["le"].append(le.cpu())

    if epoch % update_interval == 0:
        cm.update_target_distribution()
        history["c"].append(c)
    tqdm.write(f"epoch:{epoch}, lr:{lr_total:04f}, lc:{lc_total:04f}, le:{le:04f}")


  0%|                                                                       | 0/100 [00:01<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 5.74 GiB (GPU 9; 10.75 GiB total capacity; 6.91 GiB already allocated; 2.93 GiB free; 6.93 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF