In [1]:
import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

# from shared.collator import zero_pad_collator
# from shared.tokenizers import HamNoSysTokenizer
# from data import get_dataset
# from model import IterativeTextGuidedPoseGenerationModel

  from .autonotebook import tqdm as notebook_tqdm


## Collator

In [2]:
from typing import Dict, List, Tuple, Union, TypedDict

import numpy as np
import torch
from pose_format.torch.masked import MaskedTensor, MaskedTorch


In [3]:
def collate_tensors(batch: List, pad_value=0):
    """Collate one field of a batch, padding variable-length tensors along dim 0.

    Dispatch is on the type of the first element:
      * ``dict`` -> recurse through ``zero_pad_collator``;
      * Python / NumPy integer -> 1-D ``torch.long`` tensor;
      * ``torch.Tensor`` / ``MaskedTensor`` -> every element is padded with ``pad_value``
        along dim 0 up to the longest one, then all are stacked into a new leading
        batch dimension;
      * anything else -> the list is returned unchanged.

    Args:
        batch: Non-empty list of same-kind values (only ``batch[0]`` is type-inspected).
        pad_value: Fill value for padded rows. Not forwarded when recursing into dicts.

    Returns:
        A stacked tensor (presumably a ``MaskedTensor`` when ``MaskedTorch`` is used),
        a dict of collated fields, or ``batch`` itself.
    """
    datum = batch[0]

    if isinstance(datum, dict):  # Recurse over dictionaries
        return zero_pad_collator(batch)

    # np.integer covers every NumPy integer width (int64 is NumPy's default on most
    # platforms); checking only np.int32 let int64 scalars fall through as a plain list.
    if isinstance(datum, (int, np.integer)):
        return torch.tensor(batch, dtype=torch.long)

    if isinstance(datum, (MaskedTensor, torch.Tensor)):
        max_len = max(len(t) for t in batch)
        # Fast path only when every element already has exactly one row; a mix of
        # 1-row and 0-row tensors cannot be stacked directly and must be padded.
        if max_len == 1 and all(len(t) == 1 for t in batch):
            return torch.stack(batch)

        torch_cls = MaskedTorch if isinstance(datum, MaskedTensor) else torch

        new_batch = []
        for tensor in batch:
            missing = list(tensor.shape)
            missing[0] = max_len - tensor.shape[0]

            if missing[0] > 0:
                padding_tensor = torch.full(missing, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device)
                tensor = torch_cls.cat([tensor, padding_tensor], dim=0)

            new_batch.append(tensor)

        return torch_cls.stack(new_batch, dim=0)

    return batch

def zero_pad_collator(batch) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor]]:
    """Collate a list of samples field by field.

    * Strings: the list is returned as-is.
    * Tuples: position ``i`` of every sample is collated with ``collate_tensors`` and the
      results are returned as a tuple.
    * Otherwise the samples are treated as mappings: each key of the first sample is
      collated across the batch and a dict with the same keys is returned.
    """
    first = batch[0]

    if isinstance(first, str):
        return batch

    if isinstance(first, tuple):
        columns = []
        for position in range(len(first)):
            columns.append(collate_tensors([sample[position] for sample in batch]))
        return tuple(columns)

    collated = {}
    for key in first.keys():
        collated[key] = collate_tensors([sample[key] for sample in batch])
    return collated

In [4]:
from pathlib import Path

from fontTools.ttLib import TTFont

class BaseTokenizer:
    """Abstract token <-> id mapping with JoeyNMT-style special tokens.

    Ids 0-3 are reserved for the UNK, PAD, BOS and EOS tokens, and the vocabulary ``tokens``
    is numbered consecutively from ``starting_index`` (default 4). Subclasses implement
    ``text_to_tokens`` / ``tokens_to_text`` to decide how text is split and re-joined.
    """

    def __init__(self,
                 tokens: List[str],
                 starting_index=None,
                 init_token="[CLS]",
                 eos_token="[SEP]",
                 pad_token="[PAD]",
                 unk_token="[UNK]"):
        """
        Args:
            tokens: Vocabulary entries, numbered consecutively from ``starting_index``.
            starting_index: First id assigned to ``tokens``; defaults to 4.
                NOTE(review): a value below 4 makes vocabulary ids collide with the
                special-token ids 0-3, and those vocabulary entries are overwritten.
            init_token: Surface string of the BOS token.
            eos_token: Surface string of the EOS token.
            pad_token: Surface string of the padding token.
            unk_token: Surface string of the unknown token.
        """
        if starting_index is None:
            starting_index = 4

        self.pad_token = pad_token
        self.bos_token = init_token
        self.eos_token = eos_token
        self.unk_token = unk_token

        self.i2s = {(i + starting_index): c for i, c in enumerate(tokens)}
        # Following the same ID scheme as JoeyNMT
        self.i2s[0] = self.unk_token
        self.i2s[1] = self.pad_token
        self.i2s[2] = self.bos_token
        self.i2s[3] = self.eos_token
        self.s2i = {c: i for i, c in self.i2s.items()}

        self.pad_token_id = self.s2i[self.pad_token]
        self.bos_token_id = self.s2i[self.bos_token]
        self.eos_token_id = self.s2i[self.eos_token]
        self.unk_token_id = self.s2i[self.unk_token]

    def __len__(self):
        """Number of ids in the vocabulary, special tokens included."""
        return len(self.i2s)

    def vocab(self):
        """All token strings, in id-insertion order (vocabulary first, then specials)."""
        return list(self.i2s.values())

    def text_to_tokens(self, text: str) -> List[str]:
        """Split ``text`` into token strings. Implemented by subclasses."""
        raise NotImplementedError()

    def tokens_to_text(self, tokens: List[str]) -> str:
        """Join token strings back into text. Implemented by subclasses."""
        raise NotImplementedError()

    def tokenize(self, text: str, bos=False, eos=False):
        """Convert ``text`` to a list of ids, optionally wrapped with BOS and/or EOS.

        Raises:
            KeyError: if ``text_to_tokens`` yields a token that is not in the vocabulary.
        """
        tokens = [self.s2i[c] for c in self.text_to_tokens(text)]
        if bos:
            tokens.insert(0, self.bos_token_id)
        if eos:
            tokens.append(self.eos_token_id)

        return tokens

    def detokenize(self, tokens: List[int]):
        """Convert ids back to text.

        Drops one leading BOS and one trailing EOS, then truncates at the first PAD.
        """
        if len(tokens) == 0:
            return ""
        if tokens[0] == self.bos_token_id:
            tokens = tokens[1:]
        # Stripping BOS from [BOS] leaves nothing; guard before looking at tokens[-1].
        if len(tokens) > 0 and tokens[-1] == self.eos_token_id:
            tokens = tokens[:-1]

        try:
            padding_index = tokens.index(self.pad_token_id)
            tokens = tokens[:padding_index]
        except ValueError:
            pass  # no padding present

        return self.tokens_to_text([self.i2s[t] for t in tokens])

    def __call__(self, texts, is_tokenized=False, device=None):
        """Tokenize a batch and pad it into tensors.

        Args:
            texts: Iterable of strings, or, when ``is_tokenized`` is True, an object with
                ``.tolist()`` (e.g. a tensor) that yields one id list per row.
            is_tokenized: Skip tokenization and use ``texts`` as ids directly.
            device: Device on which the tensors are created.

        Returns:
            Dict with ``tokens_ids`` (long), ``attention_mask`` (bool, True at padded
            positions, the convention of ``nn.TransformerEncoder``'s
            ``src_key_padding_mask``) and ``positions`` (int), each of shape
            (batch, longest sequence). Padded ``tokens_ids`` slots hold ``pad_token_id``.
        """
        if not is_tokenized:
            all_tokens = [self.tokenize(text) for text in texts]
        else:
            all_tokens = texts.tolist()

        tokens_batch = zero_pad_collator([{
            "tokens_ids": torch.tensor(tokens, dtype=torch.long, device=device),
            "attention_mask": torch.ones(len(tokens), dtype=torch.bool, device=device),
            "positions": torch.arange(0, len(tokens), dtype=torch.int, device=device)
        } for tokens in all_tokens])
        # In transformers, 1 is mask, not 0: the collator pads the "real token" flags with
        # False, so inverting leaves True exactly at the padded positions.
        tokens_batch["attention_mask"] = torch.logical_not(tokens_batch["attention_mask"])
        # The collator pads ids with 0, which is unk_token_id; rewrite padded slots to
        # pad_token_id so they hit the embedding's padding_idx instead of the UNK vector.
        tokens_batch["tokens_ids"] = tokens_batch["tokens_ids"].masked_fill(
            tokens_batch["attention_mask"], self.pad_token_id)

        return tokens_batch

In [5]:
class HamNoSysTokenizer(BaseTokenizer):
    """Character-level tokenizer whose vocabulary is every code point in the HamNoSys font.

    The vocabulary order follows the key order of the font's ``getBestCmap()``, so token
    ids depend on the specific font file that is loaded.
    """

    # Resolved relative to the current working directory, not to this file; pass
    # ``font_path`` when running from a different directory.
    DEFAULT_FONT_PATH = "./shared/tokenizers/hamnosys/HamNoSysUnicode.ttf"

    def __init__(self, starting_index=None, *, font_path=None, **kwargs):
        """
        Args:
            starting_index: Forwarded to ``BaseTokenizer``.
            font_path: Path to the HamNoSys ``.ttf`` file; defaults to ``DEFAULT_FONT_PATH``.
            **kwargs: Remaining ``BaseTokenizer`` keyword arguments (special-token strings).
        """
        self.font_path = self.DEFAULT_FONT_PATH if font_path is None else font_path

        with TTFont(self.font_path) as font:
            tokens = [chr(key) for key in font["cmap"].getBestCmap().keys()]

        super().__init__(tokens=tokens, starting_index=starting_index, **kwargs)

    def text_to_tokens(self, text: str) -> List[str]:
        """One token per character."""
        return list(text)

    def tokens_to_text(self, tokens: List[str]) -> str:
        """Concatenate character tokens."""
        return "".join(tokens)

In [6]:
# from text_to_pose.data import get_dataset
# from text_to_pose.model import IterativeTextGuidedPoseGenerationModel

## Data Preparation

In [7]:
from pose_format import Pose
from torch.utils.data import Dataset
from shared.tfds_dataset import ProcessedPoseDatum, get_tfds_dataset

In [8]:
# One processed dataset record:
#   id     - example identifier
#   text   - HamNoSys string
#   pose   - pose_format Pose for the example
#   length - max(frame count, text length), as produced by process_datum
TextPoseDatum = TypedDict("TextPoseDatum", {
    "id": str,
    "text": str,
    "pose": Pose,
    "length": int,
})


class TextPoseDataset(Dataset):
    """Map-style dataset over a list of ``TextPoseDatum`` records.

    Each item exposes the id and text, the original ``Pose`` object, and torch tensors for
    the first person (index 0 on axis 1) of the pose body.
    """

    def __init__(self, data: List[TextPoseDatum]):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data[index]
        source_pose = record["pose"]

        body = source_pose.body.torch()
        frame_count = len(body.data)

        # Axis 1 of the body tensors is indexed with 0, so only one person is kept.
        # NOTE(review): axes presumably (frames, people, points, dims) - confirm in pose_format.
        pose_fields = {
            "obj": source_pose,
            "data": body.data.tensor[:, 0, :, :],
            "confidence": body.confidence[:, 0, :],
            # 1-element float tensor; the model's step() regresses its length prediction on it
            "length": torch.tensor([frame_count], dtype=torch.float),
            # All ones for real frames; collator padding shows up as 0
            "inverse_mask": torch.ones(frame_count, dtype=torch.int8),
        }
        return {"id": record["id"], "text": record["text"], "pose": pose_fields}


def process_datum(datum: ProcessedPoseDatum) -> TextPoseDatum:
    """Turn one loaded TFDS record into a ``TextPoseDatum``.

    Decodes and strips the HamNoSys string, then trims leading frames whose confidences
    sum to zero. The trim mutates ``datum["pose"]`` in place. ``length`` is the larger of
    the remaining frame count and the text length.
    NOTE(review): ``datum["tf_datum"]["hamnosys"]`` is presumably a TF string tensor
    (it must provide ``.numpy()`` returning bytes).
    """
    hamnosys_text = datum["tf_datum"]["hamnosys"].numpy().decode('utf-8').strip()
    pose: Pose = datum["pose"]

    # Index of the first frame with any non-zero confidence; None when every frame is empty
    first_visible = next(
        (frame for frame in range(len(pose.body.data)) if pose.body.confidence[frame].sum() != 0),
        None,
    )
    # Slice only when there is something to drop; an all-empty pose is left untouched
    if first_visible is not None and first_visible > 0:
        pose.body.data = pose.body.data[first_visible:]
        pose.body.confidence = pose.body.confidence[first_visible:]

    frame_count = len(pose.body.data)
    return {"id": datum["id"], "text": hamnosys_text, "pose": pose, "length": max(frame_count, len(hamnosys_text))}


In [9]:
def get_dataset(name="dicta_sign",
                poses="holistic",
                fps=25,
                split="train",
                components: List[str] = None,
                data_dir=None,
                max_seq_size=1000):
    """Load a TFDS pose dataset and wrap it as a ``TextPoseDataset``.

    Every record goes through ``process_datum``. Records whose ``length`` is not strictly
    below ``max_seq_size`` are discarded.
    """
    raw_records = get_tfds_dataset(name=name, poses=poses, fps=fps, split=split, components=components, data_dir=data_dir)

    processed = (process_datum(record) for record in raw_records)
    kept = [record for record in processed if record["length"] < max_seq_size]

    return TextPoseDataset(kept)

In [10]:
from typing import List

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn

In [11]:
from shared.models.pose_encoder import PoseEncoderModel

In [12]:
def masked_loss(loss_type, pose: torch.Tensor, pose_hat: torch.Tensor, confidence: torch.Tensor):
    """Confidence-weighted pose reconstruction loss.

    The per-joint error is summed over the last axis: absolute differences for ``'l1'``,
    squared differences for ``'l2'``. It is multiplied by ``confidence`` and averaged over
    all elements. Zero-confidence joints therefore add no gradient and low-confidence joints
    add less, but every joint still counts in the mean's denominator.

    Raises:
        NotImplementedError: if ``loss_type`` is neither ``'l1'`` nor ``'l2'``.
    """
    if loss_type not in ('l1', 'l2'):
        raise NotImplementedError()

    difference = pose - pose_hat
    per_coordinate = difference.abs() if loss_type == 'l1' else difference.pow(2)
    weighted = per_coordinate.sum(-1) * confidence
    return weighted.mean()


class DistributionPredictionModel(nn.Module):
    """Predict one scalar per input row as a Gaussian.

    Two independent linear heads map ``input_size`` features to a mean and a log-variance.
    In eval mode only the mean is returned. In train mode a reparameterized sample is drawn,
    so gradients flow through both heads.
    """

    def __init__(self, input_size: int):
        super().__init__()
        # Creation order (mean head first) kept so parameter init and state_dict keys match
        self.fc_mu = nn.Linear(input_size, 1)
        self.fc_var = nn.Linear(input_size, 1)

    def forward(self, x: torch.Tensor):
        """Return a (..., 1) tensor: the mean in eval mode, a differentiable sample otherwise."""
        mean = self.fc_mu(x)
        if self.training:
            scale = (self.fc_var(x) / 2).exp()  # std = exp(log_var / 2)
            return torch.distributions.Normal(mean, scale).rsample()
        return mean


class IterativeTextGuidedPoseGenerationModel(pl.LightningModule):
    """Generate a pose sequence from text by iteratively refining an initial sequence.

    As implemented below:
    1. ``encode_text`` embeds the tokenized text (token + learned positional embeddings),
       runs a Transformer encoder, and predicts the output length from the mean encoding.
    2. ``refine_pose_sequence`` encodes the current pose sequence together with the text
       encoding and projects every frame to a predicted change of the pose.
    3. ``forward`` (inference) and ``step`` (training / validation) apply those changes
       repeatedly, starting from the first frame repeated over the whole sequence.
    """

    def __init__(self,
                 tokenizer,
                 pose_dims: (int, int) = (137, 2),
                 hidden_dim: int = 128,
                 text_encoder_depth=2,
                 pose_encoder_depth=4,
                 encoder_heads=2,
                 encoder_dim_feedforward=2048,
                 max_seq_size: int = 1000,
                 loss_type='l1'):
        """
        Args:
            tokenizer: Callable returning a dict with ``tokens_ids``, ``attention_mask`` and
                ``positions``; must also support ``len()`` and expose ``pad_token_id``.
            pose_dims: Per-frame pose shape, forwarded to ``PoseEncoderModel``; presumably
                (keypoints, coordinates) - confirm.
            hidden_dim: Width shared by the embeddings, encoders and projection head.
            text_encoder_depth: Number of Transformer layers in the text encoder.
            pose_encoder_depth: Number of layers in the pose encoder.
            encoder_heads: Attention heads for both encoders.
            encoder_dim_feedforward: Feed-forward width for both encoders.
            max_seq_size: Number of text positional embeddings; also forwarded to the pose encoder.
            loss_type: ``'l1'`` or ``'l2'``, forwarded to ``masked_loss``.
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.max_seq_size = max_seq_size

        # Embedding layers
        self.positional_embeddings = nn.Embedding(num_embeddings=max_seq_size, embedding_dim=hidden_dim)

        self.embedding = nn.Embedding(
            num_embeddings=len(tokenizer),
            embedding_dim=hidden_dim,
            padding_idx=tokenizer.pad_token_id,
        )

        self.pose_encoder = PoseEncoderModel(pose_dims=pose_dims,
                                             hidden_dim=hidden_dim,
                                             encoder_depth=pose_encoder_depth,
                                             encoder_heads=encoder_heads,
                                             encoder_dim_feedforward=encoder_dim_feedforward,
                                             max_seq_size=max_seq_size,
                                             dropout=0)

        # Text encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim,
                                                   nhead=encoder_heads,
                                                   dim_feedforward=encoder_dim_feedforward,
                                                   batch_first=True)
        self.text_encoder = nn.TransformerEncoder(encoder_layer, num_layers=text_encoder_depth)

        # Predict sequence length
        self.seq_length = DistributionPredictionModel(hidden_dim)

        # Predict pose difference (flat output, reshaped per frame in refine_pose_sequence)
        self.pose_diff_projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.pose_encoder.pose_dim),
        )

        # Loss
        self.loss_type = loss_type

    def encode_text(self, texts: List[str]):
        """Encode a batch of strings.

        Returns:
            ``({"data": encoded, "mask": padding_mask}, seq_length)``. ``encoded`` is the
            batch-first text-encoder output, ``padding_mask`` is the tokenizer's
            ``attention_mask`` (used as ``src_key_padding_mask``, so True = ignore), and
            ``seq_length`` is the predicted pose length with shape (batch, 1).
        """
        tokenized = self.tokenizer(texts, device=self.device)
        positional_embedding = self.positional_embeddings(tokenized["positions"])
        embedding = self.embedding(tokenized["tokens_ids"]) + positional_embedding
        encoded = self.text_encoder(embedding, src_key_padding_mask=tokenized["attention_mask"])
        seq_length = self.seq_length(torch.mean(encoded, dim=1))
        return {"data": encoded, "mask": tokenized["attention_mask"]}, seq_length

    def refine_pose_sequence(self, pose_sequence, text_encoding):
        """Predict a per-frame change for ``pose_sequence``.

        ``pose_sequence["data"]`` must be 4-D (batch, frames, *pose_dims); the result has
        shape (batch, frames, *pose_encoder.pose_dims).
        """
        batch_size, seq_length, _, _ = pose_sequence["data"].shape
        pose_encoding = self.pose_encoder(pose=pose_sequence, additional_sequence=text_encoding)
        # Keep only the first `seq_length` outputs; presumably the encoder output also covers
        # the appended text sequence - confirm in PoseEncoderModel.
        pose_encoding = pose_encoding[:, :seq_length, :]

        # Predict desired change
        flat_pose_projection = self.pose_diff_projection(pose_encoding)
        return flat_pose_projection.reshape(batch_size, seq_length, *self.pose_encoder.pose_dims)

    def forward(self, text: str, first_pose: torch.Tensor, step_size: float = 0.5):
        """Infinite generator of refined pose sequences for a single ``text``.

        The first yielded value is ``first_pose`` repeated over the predicted length. Each
        following value adds ``step_size`` times the predicted change. The caller decides
        when to stop iterating.
        """
        text_encoding, sequence_length = self.encode_text([text])
        # The predicted length is an unconstrained float (it can be <= 0 or very large,
        # especially early in training); clamp it so expand() and the positional
        # embeddings get a valid size.
        sequence_length = min(max(round(float(sequence_length)), 1), self.max_seq_size)

        pose_sequence = {
            "data": first_pose.expand(1, sequence_length, *self.pose_encoder.pose_dims),
            # No padding at inference; build the mask on the model's device (not CPU).
            "mask": torch.zeros([1, sequence_length], dtype=torch.bool, device=self.device),
        }
        while True:
            yield pose_sequence["data"][0]

            step = self.refine_pose_sequence(pose_sequence, text_encoding)
            pose_sequence["data"] = pose_sequence["data"] + step_size * step

    def training_step(self, batch, *unused_args, steps=100):
        """Lightning training hook; delegates to ``step`` with ``name="train"``."""
        return self.step(batch, *unused_args, steps=steps, name="train")

    def validation_step(self, batch, *unused_args, steps=100):
        """Lightning validation hook; delegates to ``step`` with ``name="validation"``."""
        return self.step(batch, *unused_args, steps=steps, name="validation")

    def step(self, batch, *unused_args, steps: int, name: str):
        """Shared train/validation loss computation.

        Starts from the first gold frame repeated over the gold length. For ``steps``
        iterations it predicts the remaining difference to the gold pose and accumulates
        the ``masked_loss`` between predicted and gold differences. It then advances by
        ``1 / steps`` of the predicted difference (train) or the gold difference
        (validation). The total loss adds a down-scaled MSE on the predicted length.
        """
        text_encoding, sequence_length = self.encode_text(batch["text"])
        pose = batch["pose"]

        # Calculate sequence length loss (scaled by 1e-4, presumably to keep it on the
        # refinement loss's scale)
        sequence_length_loss = F.mse_loss(sequence_length, pose["length"]) / 10000

        # Repeat the first frame for initial prediction
        batch_size, pose_seq_length, _, _ = pose["data"].shape
        pose_sequence = {
            "data": torch.stack([pose["data"][:, 0]] * pose_seq_length, dim=1),
            "mask": torch.logical_not(pose["inverse_mask"])
        }

        refinement_loss = 0
        for _ in range(steps):
            pose_sequence["data"] = pose_sequence["data"].detach()  # Detach from graph
            l1_gold = pose["data"] - pose_sequence["data"]
            l1_predicted = self.refine_pose_sequence(pose_sequence, text_encoding)
            refinement_loss += masked_loss(self.loss_type, l1_gold, l1_predicted, confidence=pose["confidence"])

            step_size = 1 / steps
            l1_step = l1_gold if name == "validation" else l1_predicted
            pose_sequence["data"] = pose_sequence["data"] + step_size * l1_step

            if name == "train":  # add just a little noise while training
                pose_sequence["data"] = pose_sequence["data"] + torch.randn_like(pose_sequence["data"]) * 1e-4

        self.log(name + "_seq_length_loss", sequence_length_loss, batch_size=batch_size)
        self.log(name + "_refinement_loss", refinement_loss, batch_size=batch_size)
        loss = refinement_loss + sequence_length_loss
        self.log(name + "_loss", loss, batch_size=batch_size)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters, lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [13]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

In [14]:
# Weights & Biases experiment logger (online mode: the first run prompts for a W&B API key,
# as seen in this cell's output). NOTE(review): per the Lightning WandbLogger docs,
# log_model=False means checkpoints are not uploaded as W&B artifacts.
LOGGER = WandbLogger(
    project="text-to-pose",
    offline=False,
    log_model=False,
)

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Nguyen Thanh Loc/.netrc


In [None]:
from types import SimpleNamespace

# Run configuration. `args` is not defined in any earlier cell, so it is created here to let
# this cell run on a fresh kernel. Values match get_dataset's own defaults where it has one.
args = SimpleNamespace(
    pose="holistic",
    fps=25,
    pose_components=None,  # forwarded as get_dataset(components=...); None is its default
    max_seq_size=1000,     # matches the model's default max_seq_size
    batch_size=32,         # TODO(review): not specified anywhere in the notebook - tune for your GPU
)

# "train[10:]" skips the first 10 training examples - presumably held out for validation; confirm.
train_dataset = get_dataset(poses=args.pose,
                            fps=args.fps,
                            components=args.pose_components,
                            max_seq_size=args.max_seq_size,
                            split="train[10:]")
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=zero_pad_collator)