In [76]:
import cv2
import glob
import os
import torch
import math
import re

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torch_data
import torchvision as tv
import numpy as np
import matplotlib.pyplot as plt

from lpips import LPIPS
from PIL import Image
Image.MAX_IMAGE_PIXELS = None

In [77]:
class Dataset(torch_data.Dataset):
    """Random windows of consecutive latent frames plus the frame that follows.

    The latent directory is expected to hold one file per frame, named
    ``<video>_<frame>.png``. Despite the suffix, the files are read with
    ``np.load``. The image directory holds the matching RGB frames, read
    with ``cv2.imread``.

    ``__getitem__(idx)`` uses ``idx`` only to choose a *video*: the window
    start inside that video is drawn at random on every call.

    NOTE(review): the frame list of a video is globbed from disk, so a window
    may contain frames whose files belong to the other train/test split.

    Args:
        img_size: stored on the instance; not used by the visible code.
        seq_len: number of consecutive latent frames in the input window.
        train: ``True`` keeps the first 80 % of the sorted file list,
            ``False`` the remaining 20 %.
        transform: optional callable applied to the returned dict.
    """

    # Class-level so the locations can be overridden (e.g. in a subclass)
    # without changing the constructor signature.
    BASE_PATH_IMG = "/mnt/data/youtube/img_latent"
    BASE_PATH_LATENT = "/mnt/data/youtube/latent"

    def __init__(self, img_size, seq_len, train=True, transform=None):
        self.transform = transform
        self.img_size = img_size
        self.seq_len = seq_len

        self.elements = sorted(glob.glob(os.path.join(self.BASE_PATH_LATENT, "*")))

        n_train = int(len(self.elements) * 0.8)
        if train:
            self.elements = self.elements[:n_train]
        else:
            self.elements = self.elements[n_train:]

    def __len__(self):
        return len(self.elements)

    @staticmethod
    def _split_name(path):
        """Split ``.../<video>_<frame>.<ext>`` into ``(video, frame)`` strings."""
        base = os.path.basename(path).split(".")[0]
        video_name, _, frame = base.rpartition("_")
        return video_name, frame

    def _frame_indices(self, video_name):
        """Return the sorted, de-duplicated frame numbers available for one video."""
        pattern = os.path.join(self.BASE_PATH_LATENT, f"{glob.escape(video_name)}_*")
        frames = []
        for path in glob.glob(pattern):
            name, frame = self._split_name(path)
            # The glob still matches "<video>_<suffix>_<frame>"; keep exact matches only.
            if name == video_name:
                frames.append(int(frame))
        return np.unique(np.asarray(frames, dtype=int))

    def _load_latent(self, video_name, frame_idx):
        """Load one latent array (as stored, channels last)."""
        path = os.path.join(self.BASE_PATH_LATENT, f"{video_name}_{frame_idx}.png")
        with open(path, "rb") as f:
            return np.load(f)

    def __getitem__(self, idx):
        """Return ``{"img", "latent_seq", "latent"}`` for a random window.

        * ``latent_seq``: float32 tensor ``(seq_len, C, H, W)`` - the input window.
        * ``latent``: tensor ``(C, H, W)`` - the frame right after the window.
        * ``img``: tensor ``(3, H_img, W_img)`` in RGB order, scaled by
          ``x / 128 - 1``, for the same frame as ``latent``.

        Raises:
            ValueError: the video has fewer than ``seq_len + 1`` frames.
            FileNotFoundError: the RGB frame could not be read.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        video_name, _ = self._split_name(self.elements[idx])
        frame_indices = self._frame_indices(video_name)

        # A window needs seq_len inputs plus one target frame, so valid start
        # positions are 0 .. len - seq_len - 1 (randint's upper bound is exclusive).
        n_windows = len(frame_indices) - self.seq_len
        if n_windows < 1:
            raise ValueError(
                f"video '{video_name}' has {len(frame_indices)} frames, "
                f"need at least {self.seq_len + 1}"
            )
        start = np.random.randint(0, n_windows)

        window = frame_indices[start:start + self.seq_len]
        latent_imgs = np.stack(
            [self._load_latent(video_name, frame) for frame in window]
        ).astype(np.float32)

        target_frame = frame_indices[start + self.seq_len]
        latent = self._load_latent(video_name, target_frame)

        img_path = os.path.join(self.BASE_PATH_IMG, f"{video_name}_{target_frame}.png")
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {img_path}")

        # BGR -> RGB and HWC -> CHW; made contiguous before handing to torch.
        img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))
        img = torch.from_numpy(img / 128 - 1)

        # Channels-last on disk -> channels-first for the model.
        latent_imgs = torch.from_numpy(latent_imgs.transpose(0, 3, 1, 2))
        latent = torch.from_numpy(latent.transpose(2, 0, 1))

        element = {
            "img": img,
            "latent_seq": latent_imgs,
            "latent": latent,
        }

        if self.transform:
            element = self.transform(element)

        return element

In [78]:
# --- Data loading -----------------------------------------------------------
BATCH_SIZE = 8       # samples per DataLoader batch
BATCH_SUM = 16       # batches whose gradients train() accumulates per optimizer step
NUM_WORKERS = 4      # intended number of DataLoader worker processes

# --- Transformer ------------------------------------------------------------
# Note: Block uses head_size = N_EMB // N_HEADS = 42, so the concatenated heads
# are 504 wide before MultiHeadAttention.proj maps them back to N_EMB.
N_HEADS = 12
N_LAYER = 12
N_EMB = 512
BLOCK_SIZE = 64      # side length of the causal mask buffer in Head (max sequence length)
DROPOUT = 0.1
IMG_SIZE = 512       # passed to Dataset; also the preview resolution in show()
F_IN = 4 * 64 * 64   # size of one flattened latent frame (4 channels, 64 x 64)
SEQ_LEN = 4          # latent frames per input window

# --- Optimisation -----------------------------------------------------------
DEVICE = "cuda"
LR = 1e-3
WEIGHT_DECAY = 1e-5
N_EPOCHS = 100
GRAD_CLIP = 1.0      # max gradient norm passed to clip_grad_norm_

# --- Checkpoints ------------------------------------------------------------
FILE_PATH = "weights/video.pth"
FILE_NAME_DECODER = f"weights/decoder_{IMG_SIZE}.pth"
FILE_NAME_REFINER = f"weights/refiner_{IMG_SIZE}.pth"

In [97]:
# Same constructor arguments for both splits; only the `train` flag differs.
train_dataset = Dataset(img_size=IMG_SIZE, seq_len=SEQ_LEN, train=True)
test_dataset = Dataset(img_size=IMG_SIZE, seq_len=SEQ_LEN, train=False)

In [98]:
# Both loaders share one configuration. The worker count comes from the
# NUM_WORKERS constant instead of a repeated literal. The test loader is
# shuffled too, so taking the first batch of a fresh iterator (as show() does)
# yields a different sample each time.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)
train_loader = torch_data.DataLoader(train_dataset, **_loader_kwargs)
test_loader = torch_data.DataLoader(test_dataset, **_loader_kwargs)

In [99]:
# Number of elements in the train and test split, one per line.
print(len(train_dataset), len(test_dataset), sep="\n")

8892
2224


In [82]:
class SelfAttention(nn.Module):
    """Multi-head self-attention with a single fused q/k/v projection.

    Args:
        n_heads: number of attention heads.
        d_embed: embedding width; each head works on ``d_embed // n_heads`` features.
        in_proj_bias: whether the fused q/k/v projection has a bias.
        out_proj_bias: whether the output projection has a bias.
    """

    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        """Attend over ``x`` of shape (batch, seq, d_embed); returns the same shape.

        With ``causal_mask=True`` every position only attends to itself and
        earlier positions.
        """
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq, d_embed) -> (batch, heads, seq, d_head)
            return t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.in_proj(x).chunk(3, dim=-1))

        scores = q @ k.transpose(-1, -2)
        if causal_mask:
            # Strict upper triangle = positions in the future.
            future = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores = scores.masked_fill(future, -torch.inf)
        attn = F.softmax(scores / math.sqrt(self.d_head), dim=-1)

        # (batch, heads, seq, d_head) -> (batch, seq, d_embed)
        merged = (attn @ v).transpose(1, 2).reshape(x.shape)
        return self.out_proj(merged)
class AttentionBlock(nn.Module):
    """Group-normalised single-head self-attention over the spatial positions of a feature map.

    The residual connection adds the un-normalised input back to the
    attention output.
    """

    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        """Map a (n, c, h, w) tensor to a tensor of the same shape."""
        n, c, h, w = x.shape

        # Flatten the spatial grid into a sequence of h*w tokens with c features each.
        tokens = self.groupnorm(x).view(n, c, h * w).transpose(-1, -2)

        # Attend, then restore the (n, c, h, w) layout.
        attended = self.attention(tokens).transpose(-1, -2).view(n, c, h, w)

        attended += x
        return attended
class ResidualBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> 3x3 conv) stages with a skip connection.

    When the channel count changes, the skip path uses a 1x1 convolution to
    match ``out_channels``; otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        needs_projection = in_channels != out_channels
        self.residual_layer = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        """Map (n, in_channels, h, w) to (n, out_channels, h, w)."""
        shortcut = self.residual_layer(x)

        h = self.conv_1(F.silu(self.groupnorm_1(x)))
        h = self.conv_2(F.silu(self.groupnorm_2(h)))

        return h + shortcut
class Decoder(nn.Sequential):
    """Convolutional decoder from a 4-channel latent to a 3-channel image.

    Three ``Upsample(scale_factor=2)`` stages enlarge the spatial size by a
    factor of 8 overall. The layer order is kept exactly as it was, so the
    positional ``state_dict`` keys of existing checkpoints still match.
    """

    # The input is divided by this constant before decoding.
    # NOTE(review): presumably the Stable Diffusion VAE latent scaling factor - confirm.
    LATENT_SCALE = 0.18215

    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 4, kernel_size=1, padding=0),
            nn.Conv2d(4, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 512),
            AttentionBlock(512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 256),
            ResidualBlock(256, 256),
            ResidualBlock(256, 256),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ResidualBlock(256, 128),
            ResidualBlock(128, 128),
            ResidualBlock(128, 128),
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Decode ``x`` of shape (n, 4, h, w) into (n, 3, 8h, 8w).

        The rescaling is done out of place. The previous ``x /= ...`` silently
        modified the caller's tensor (so decoding the same latent twice gave
        different results) and fails on leaf tensors that require grad.
        """
        x = x / self.LATENT_SCALE
        for module in self:
            x = module(x)
        return x
class Refiner(nn.Sequential):
    """Small residual CNN mapping a 3-channel image to a 3-channel image.

    The first convolution (kernel 2, padding 1) grows each spatial dimension
    by one; the last one (kernel 2, no padding) shrinks it by one again, so
    the output has the input's spatial size.
    """

    def __init__(self):
        layers = [
            nn.Conv2d(3, 128, kernel_size=2, padding=1),
            ResidualBlock(128, 128),
            ResidualBlock(128, 128),
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, kernel_size=2, padding=0),
        ]
        super().__init__(*layers)

    def forward(self, x):
        """Apply the layers in order."""
        # nn.Sequential.forward already chains the registered modules.
        return super().forward(x)

In [83]:
class Head(nn.Module):
    """One causal self-attention head.

    Args:
        head_size: output width of the key/query/value projections.
        n_embd: input embedding width.
        block_size: largest sequence length the causal mask buffer supports.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: entry (t, s) is 1 when position t may attend to s.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` (B, T, n_embd) and return (B, T, head_size).

        ``mask``, if given, is combined with the causal mask: entries equal
        to 0 are excluded from attention.
        """
        _, seq_len, _ = x.shape
        keys, queries, values = self.key(x), self.query(x), self.value(x)

        # Scaled dot-product scores, (B, T, T).
        scores = (queries @ keys.transpose(-2, -1)) * keys.shape[-1] ** -0.5

        # Union of "future position" and caller-supplied exclusions.
        hidden = self.tril[:seq_len, :seq_len] == 0
        if mask is not None:
            hidden = hidden | (mask == 0)
        scores = scores.masked_fill(hidden, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values
class MultiHeadAttention(nn.Module):
    """Runs several ``Head`` modules in parallel and projects their concatenation to ``n_embd``."""

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Return the projected multi-head attention output for ``x``; ``mask`` is forwarded to every head."""
        per_head = [head(x, mask) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class FeedFoward(nn.Module):
    """Position-wise MLP: Linear (x4 expansion) -> ReLU -> Linear -> Dropout."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)
class Block(nn.Module):
    """Pre-norm transformer block: self-attention then feed-forward, each with a residual connection."""

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        # Integer division: any remainder of n_embd / n_head is dropped.
        per_head = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head, n_embd, block_size, dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, mask):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.sa(self.ln1(x), mask)
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed
class Model(nn.Module):
    """Causal transformer that predicts the next latent frame from a sequence of frames.

    Each frame is flattened to ``f_in`` values and linearly embedded to
    ``n_embd``. The sequence then passes through ``n_layer`` blocks, and the
    representation of the *last* position is projected back to ``f_in``
    values and reshaped to a frame. No positional encoding is added.

    Args:
        f_in: number of values in one flattened frame (C * H * W).
        n_embd: embedding width.
        block_size: maximum sequence length (size of the causal mask).
        n_head: attention heads per block.
        n_layer: number of transformer blocks.
        dropout: dropout probability.
    """

    def __init__(self, f_in, n_embd, block_size, n_head, n_layer, dropout):
        super().__init__()
        self.block_size = block_size
        # The width follows n_embd; it used to be hard-coded to 512, which
        # broke every other n_embd value.
        self.embedding = nn.Linear(f_in, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, f_in)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """Predict the frame following the sequence ``idx``.

        Args:
            idx: tensor of shape (B, SEQ, C, H, W) with C * H * W == f_in.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        B, SEQ, C, H, W = idx.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        frames = idx.reshape(B, SEQ, -1)

        # No extra attention mask is used; the heads apply their own causal mask.
        mask = None

        x = self.dropout(self.embedding(frames))
        # The blocks take (x, mask), so they are called one by one rather
        # than through nn.Sequential.forward.
        for block in self.blocks:
            x = block(x, mask)

        x = self.ln_f(x[:, -1])
        return self.lm_head(x).view(B, C, H, W)

    def generate(self, idx, max_new_tokens=1):
        """Autoregressively append ``max_new_tokens`` predicted frames to ``idx``.

        The previous implementation sampled token ids with softmax and
        multinomial, which cannot work on the (B, C, H, W) frame returned by
        ``forward``. Here each predicted frame is appended along the sequence
        dimension and fed back in.

        Args:
            idx: tensor of shape (B, SEQ, C, H, W).
            max_new_tokens: number of frames to generate.

        Returns:
            Tensor of shape (B, SEQ + max_new_tokens, C, H, W).
        """
        for _ in range(max_new_tokens):
            # Only the most recent block_size frames fit into the causal mask.
            idx_cond = idx[:, -self.block_size:]
            next_frame = self(idx_cond)
            idx = torch.cat((idx, next_frame.unsqueeze(1)), dim=1)
        return idx

In [84]:
def criterion(pred, img):
    """Return the training loss: the mean squared error between ``pred`` and ``img``.

    The loss is built as a sum over a list of terms, so further terms can be
    appended; at the moment the list holds only the MSE.
    """
    terms = [F.mse_loss(pred, img)]
    return sum(terms)
def train():
    """Run one epoch over ``train_loader`` with gradient accumulation.

    Uses the notebook globals ``model``, ``train_loader``, ``optimizer``,
    ``scaler``, ``criterion``, ``show`` and the configuration constants.
    Gradients of ``BATCH_SUM`` consecutive batches are accumulated before each
    optimizer step; batches that do not fill a complete group are skipped.
    Every 10 optimizer steps the weights are saved to ``FILE_PATH``, the
    running loss is reset and a preview is shown.

    NOTE(review): no autocast context is used here, so the GradScaler only
    scales the loss; the forward pass runs in the model's own precision.
    """
    model.train()
    # Drop gradients left over from an interrupted or failed previous call so
    # they do not leak into the first optimizer step.
    optimizer.zero_grad(set_to_none=True)

    iterator = iter(train_loader)
    n_steps = len(train_loader) // BATCH_SUM
    sum_loss = 0
    count = 0
    for i in range(n_steps):
        for j in range(BATCH_SUM):
            batch = next(iterator)

            latent_seq = batch["latent_seq"].to(DEVICE)
            latent = batch["latent"].to(DEVICE)

            pred = model(latent_seq)

            # Divide so the accumulated gradient is the mean over the group.
            loss = criterion(pred, latent) / BATCH_SUM
            scaler.scale(loss).backward()

            sum_loss += loss.item()
            count += 1

            # Multiply back by BATCH_SUM to report the un-divided mean loss.
            print(f"\r{i + 1:06}/{n_steps:06} | {j + 1:03}/{BATCH_SUM:03} loss: {(sum_loss / count) * BATCH_SUM}", end="")

        # Gradients must be unscaled before clipping by norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        if ((i + 1) % 10) == 0:
            torch.save(model.state_dict(), FILE_PATH)
            sum_loss = 0
            count = 0
            print()

            show(N=1, M=1)
            # In case show() left the model in eval mode: without this,
            # dropout would stay disabled for the rest of the epoch.
            model.train()

    print()
def _to_unit_rgb(arr):
    """Min-max normalise ``arr`` over all channels and keep the first three channels."""
    arr = arr - arr.min()
    peak = arr.max()
    # A constant array has peak 0; skip the division instead of producing NaNs.
    if peak > 0:
        arr = arr / peak
    return arr[:3]


def _pad_to_width(row, width):
    """Zero-pad a (C, H, W) array on the right so its last axis has length ``width``."""
    missing = width - row.shape[-1]
    if missing <= 0:
        return row
    return np.pad(row, ((0, 0), (0, 0), (0, missing)))


@torch.no_grad()
def show(N=1, M=1):
    """Plot an ``N`` x ``M`` grid of prediction previews from one test batch.

    Each cell has two rows. The top row is the target RGB frame, the predicted
    latent and the target latent. The bottom row is the input latent sequence.
    Latents are bilinearly resized to ``IMG_SIZE`` and shown through their
    first three channels after min-max normalisation. The narrower row is
    padded with black, so any sequence length can be displayed.

    Uses the notebook globals ``model``, ``test_loader``, ``DEVICE`` and
    ``IMG_SIZE``. The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    try:
        batch = next(iter(test_loader))

        latent_seq = batch["latent_seq"].to(DEVICE)
        latent = batch["latent"].to(DEVICE)
        # The image is only displayed, so it stays on the CPU.
        img = batch["img"].numpy()

        pred = model(latent_seq)
    finally:
        # Give the model back in the mode the caller had it in.
        model.train(was_training)

    BS, SEQ, D, H, W = latent_seq.shape
    size = (IMG_SIZE, IMG_SIZE)
    latent = F.interpolate(latent, size, mode="bilinear")
    pred = F.interpolate(pred, size, mode="bilinear")
    latent_seq = F.interpolate(
        latent_seq.view(BS * SEQ, D, H, W), size, mode="bilinear"
    ).view(BS, SEQ, D, IMG_SIZE, IMG_SIZE)

    latent_seq = latent_seq.cpu().numpy()
    latent = latent.cpu().numpy()
    pred = pred.cpu().numpy()

    # Matplotlib's figsize is (width, height).
    fig_width = 20
    fig_height = int((fig_width / 2.5) * (N / M))
    # squeeze=False always yields a 2-D array of axes, whatever N and M are.
    fig, axes = plt.subplots(N, M, figsize=(fig_width, fig_height), squeeze=False)

    for n in range(N):
        for m in range(M):
            idx = n * M + m
            if idx >= BS:
                # Fewer samples in the batch than grid cells: leave the cell empty.
                axes[n, m].axis("off")
                continue

            top = np.concatenate(
                [(img[idx] + 1) / 2, _to_unit_rgb(pred[idx]), _to_unit_rgb(latent[idx])],
                axis=-1,
            )
            bottom = np.concatenate(
                [_to_unit_rgb(frame) for frame in latent_seq[idx]], axis=-1
            )

            width = max(top.shape[-1], bottom.shape[-1])
            tile = np.concatenate(
                [
                    _pad_to_width(top, width).transpose(1, 2, 0),
                    _pad_to_width(bottom, width).transpose(1, 2, 0),
                ],
                axis=0,
            )

            axes[n, m].imshow(tile)
    plt.show()
    plt.close(fig)

In [87]:
# Build the model and report its number of trainable parameters
# (thousands separated by dots).
model = Model(F_IN, N_EMB, BLOCK_SIZE, N_HEADS, N_LAYER, DROPOUT).to(DEVICE)
# model.load_state_dict(torch.load(FILE_PATH))
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"model: {n_trainable:,}".replace(",", "."))
print()

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler(enabled=True)

# Smoke test: push one training batch through the untrained model and print
# the resulting loss.
batch = next(iter(train_loader))
latent_seq, latent, img = (
    batch[key].to(DEVICE) for key in ("latent_seq", "latent", "img")
)

with torch.no_grad():
    pred = model(latent_seq)

loss = criterion(pred, latent)
print(loss)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x78c7290651c0>
Traceback (most recent call last):
  File "/home/henning/tmp/experiments/ai/stable_diffusion/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/henning/tmp/experiments/ai/stable_diffusion/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/connection.py", line 947, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/s

KeyboardInterrupt: 

In [None]:
# Single preview using show()'s defaults (N=1, M=1).
show()

In [None]:
# Train for N_EPOCHS epochs, saving a checkpoint after every epoch.
#
# The previous `while True` loop never terminated. After the last epoch, or
# after any exception, it started again from epoch 0, and a persistent error
# (missing file, out of memory, ...) was printed and retried forever. Now a
# failed epoch is retried at most MAX_RETRIES times in a row and training
# resumes at the epoch that failed.
MAX_RETRIES = 3

n_epoch = 0
n_failures = 0
while n_epoch < N_EPOCHS:
    try:
        print(f"{n_epoch + 1}|{N_EPOCHS}")
        show(N=1, M=1)
        train()
        torch.save(model.state_dict(), FILE_PATH)
        n_epoch += 1
        n_failures = 0
    except Exception as e:
        n_failures += 1
        print(e)
        # Keep whatever progress was made before the failure.
        torch.save(model.state_dict(), FILE_PATH)
        if n_failures >= MAX_RETRIES:
            # Surface the full traceback instead of looping on the same error.
            raise

In [None]:
# Write the current weights to the checkpoint file.
torch.save(obj=model.state_dict(), f=FILE_PATH)