In [1]:
import os
import numpy as np
import argparse
from datetime import datetime


import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


from bams.data import KeypointsDataset
from bams.models import BAMS
from bams import HoALoss

In [7]:
# Customized for Alice dataset
def load_data(path, f1, f2):
    segment = 60 # in seconds
    fz = 500
    sample_period = int(f1.split(".")[0].split("samp")[-1])
    step = fz // sample_period * 10 # 10 second as a step
    
    # load raw train data (with annotations for 2 tasks)
    data_train = np.load(
        os.path.join(path, f1), allow_pickle=True
    )
    data_submission = np.load(
        os.path.join(path, f2), allow_pickle=True
    )

    print("Subject ids in training data: ", data_train.keys())
    print("Subject ids in submission data: ", data_submission.keys())

    train_values = list(data_train.values())
    submission_values = list(data_submission.values())
    all_values = train_values + submission_values

    min_len = min(map(lambda x: x.shape[0], all_values))
    print("Minimum sequence length: ", min_len)

    total_sample = segment * sample_period

    keypoints_train = np.array([[data[start * step : start * step + total_sample] 
                                 for start in range((min_len - total_sample) // step)] 
                                for data in train_values])
    keypoints_submission = np.array([[data[start * step : start * step + total_sample] 
                                      for start in range((min_len - total_sample) // step)] 
                                     for data in submission_values])
    num_subject_train, num_sequence, sequence_len, num_channel = keypoints_train.shape
    num_subject_submission, _, _, _ = keypoints_submission.shape
    keypoints_train = keypoints_train.reshape((-1, sequence_len, num_channel))
    keypoints_submission = keypoints_submission.reshape((-1, sequence_len, num_channel))
    keypoints = np.concatenate([keypoints_train, keypoints_submission], axis=0)
    
    split_mask = np.ones(len(keypoints), dtype=bool)
    split_mask[-num_subject_submission*num_sequence:] = False

    print("Shape of keypoints: ", keypoints.shape)
    print("Shape of split mask: ", split_mask.shape)



    return keypoints, split_mask

In [13]:
input_train = "train_24chans_fmin10_fmax25000_rwin40_samp20.pkl"
input_submission = "test_24chans_fmin10_fmax25000_rwin40_samp20.pkl"
data_root = "../data/alice"
cache_path = "../data/alice/custom_dataset"
hoa_bins = 32
batch_size = 32
num_workers = 4
epochs = 500
lr = 1e-3
weight_decay = 4e-5
log_every_step = 50
ckpt_path = "../bams-custom-2024-03-21-13-57-16.pt"
job = "compute_representations" # or "train"

#if job == "compute_representations":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dataset
if not KeypointsDataset.cache_is_available(cache_path, hoa_bins):
    print("Processing data...")
    keypoints, split_mask = load_data(data_root, input_train, input_submission)
else:
    print("No need to process data")

# only use

dataset = KeypointsDataset(
    keypoints=keypoints,
    hoa_bins=hoa_bins,
    cache_path=cache_path,
    cache=False,
)

print("Number of sequences:", len(dataset))

# build model
model = BAMS(
    input_size=dataset.input_size,
    short_term=dict(num_channels=(64, 64, 64, 64), kernel_size=3),
    long_term=dict(num_channels=(64, 64, 64, 64, 64), kernel_size=3, dilation=4),
    predictor=dict(
        hidden_layers=(-1, 256, 512, 512, dataset.target_size * hoa_bins)
    ),
).to(device)

if ckpt_path is None:
    raise ValueError("Please specify a checkpoint path")

# load checkpoint
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

loader = DataLoader(
    dataset,
    shuffle=False,
    drop_last=False,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

# compute representations
short_term_emb, long_term_emb = [], []

Processing data...
Subject ids in training data:  dict_keys([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25])
Subject ids in submission data:  dict_keys([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
Minimum sequence length:  18167
Shape of keypoints:  (3283, 1200, 24)
Shape, of split mask:  (3283,)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:05<00:00,  4.34it/s]


Number of sequences: 3283




KeyboardInterrupt: 

In [None]:
for data in loader:
    input = data["input"].float().to(device)  # (B, N, L)

    with torch.inference_mode():
        embs, hoa_pred, byol_pred = model(input)

        print("Output: ")
        print(embs.shape)
        print(hoa_pred.shape)
        print(byol_pred.shape)

        short_term_emb.append(embs["short_term"].detach().cpu())
        long_term_emb.append(embs["long_term"].detach().cpu())

short_term_emb = torch.cat(short_term_emb)
long_term_emb = torch.cat(long_term_emb)

embs = torch.cat([short_term_emb, long_term_emb], dim=2)

# embs: (B, L, N)
batch_size, seq_len, num_feats = embs.size()

embs_mean = embs.mean(1)
embs_max = embs.max(1).values
embs_min = embs.min(1).values

embs = torch.cat([embs_mean, embs_max - embs_min], dim=-1)

# normalize embeddings
mean, std = embs.mean(0, keepdim=True), embs.std(0, unbiased=False, keepdim=True)
embs = (embs - mean) / std