In [None]:
import torch
from torch import nn
from scipy.linalg import sqrtm
import numpy as np
from model import  RotationType
from pathlib import Path
from rotation_conversions import axis_angle_to_matrix, matrix_to_rotation_6d
from model import Encoder, FeedFowardBlock

In [None]:
# --- Global notebook configuration -------------------------------------

# Pose representation.
rotation_type = RotationType.ZHOU_6D
num_joints = 22

# Sequence windowing.
timesteps = 300
block_size = 75

# Width of one frame's feature vector. `load` concatenates 'trans' with
# 22 * 6 rotation values, so 135 presumably means 3 translation values
# + 132 rotation values -- TODO confirm against the prepared data.
feature_length = 135

# NOTE(review): `batch_size` is assigned again (to 32) in a later cell.
batch_size = 64

In [None]:
def load(path):
    """Load a prepared motion file and flatten it into per-frame features.

    The file is read with ``torch.load`` and must provide the keys
    ``'trans'`` and ``'poses'``. Poses are converted from axis-angle to
    rotation matrices and then to the 6D rotation representation.

    Args:
        path: Location of the ``.pt`` file to read.

    Returns:
        A tensor whose first two dimensions match ``'trans'`` and whose last
        dimension is the ``'trans'`` width plus ``22 * 6`` rotation values.
    """
    raw = torch.load(path)
    translation = raw['trans']

    # axis-angle -> rotation matrix -> 6D rotation representation
    rotations_6d = matrix_to_rotation_6d(axis_angle_to_matrix(raw['poses']))

    # Collapse the per-joint rotation values into one flat vector per frame.
    leading_dims = translation.shape[:2]
    flat_rotations = rotations_6d.reshape(*leading_dims, 22 * 6)

    return torch.cat([translation, flat_rotations], dim=-1)


In [None]:
# Load every prepared motion collection, in the same order as before.
# NOTE(review): 'BLMrub' may be a transposition of 'BMLrub'; the string is
# kept exactly as it was because it has to match the file on disk -- confirm.
cmu, bml, ddb, mpi, sfu = map(load, (
    'data_prepared/CMU.pt',
    'data_prepared/BLMrub.pt',
    'data_prepared/DanceDB.pt',
    'data_prepared/MPI_Limits.pt',
    'data_prepared/SFU.pt',
))

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# --- Feature-extractor hyper-parameters ---------------------------------
extracted_features = 256   # width of the feature vector returned by encode()
sequence_length = 75       # frames per motion window
input_features = 135       # values per frame

# NOTE(review): this replaces the earlier batch_size = 64, yet the
# DataLoader in this cell is built with a literal 64 -- confirm which
# value is intended.
batch_size = 32

class FeatureExtractor(nn.Module):
    """Auto-encoder whose encoder output is used as a motion feature vector.

    The encoder path is ``proj_in`` + positional embedding -> ``encoder2`` ->
    ``proj_out``; the vector at the last sequence position is the feature.
    The decoder maps that single vector back to a full sequence so the model
    can be trained with a reconstruction loss.

    Args:
        input_features: Number of values per frame of the input sequence.
        sequence_length: Number of frames per input sequence. Also sizes the
            positional embedding, so inputs must have exactly this many
            frames.
        extracted_features: Width of the feature vector returned by
            ``encode``.
    """

    def __init__(self, input_features, sequence_length, extracted_features):
        super().__init__()

        self.feature_length = input_features
        self.sequence_length = sequence_length

        # The sub-modules keep their original attribute names and creation
        # order so existing checkpoints still load.
        # The literal arguments (256, 4, 0.1) follow the signatures in
        # `model`; their meaning is not visible here -- TODO confirm.
        self.proj_in = FeedFowardBlock(input_features, 256, extracted_features, 0.1)
        self.encoder2 = Encoder(4, extracted_features, 256, extracted_features, 4, 0.1)
        # One embedding row per frame. Sized from the constructor argument
        # instead of a notebook-level global.
        self.positional_embedding = nn.Embedding(sequence_length, extracted_features)
        self.proj_out = FeedFowardBlock(extracted_features, 256, extracted_features, 0.1)

        quarter_length = sequence_length // 4
        self.decoder = nn.Sequential(
            nn.Linear(extracted_features, 64 * quarter_length),
            nn.Unflatten(1, (64, quarter_length)),
            nn.ConvTranspose1d(64, 32, 1),
            nn.ReLU(),
            nn.ConvTranspose1d(32, input_features, 1),
            nn.ReLU(),
            nn.Flatten(),
            # Input width is inferred on first use or when a checkpoint loads.
            nn.LazyLinear(input_features * sequence_length)
        )

    def _embed(self, x):
        """Run the encoder path and return the last position's vector."""
        positions = torch.arange(self.sequence_length, device=x.device)
        x = self.proj_in(x) + self.positional_embedding(positions)
        x = self.encoder2(x)
        x = self.proj_out(x)
        return x[:, -1, :]

    @torch.no_grad()
    def encode(self, x):
        """Return the feature vector for each sequence, without gradients.

        Args:
            x: Batch of sequences, expected as
                (batch, sequence_length, input_features).

        Returns:
            Tensor of shape (batch, extracted_features).
        """
        return self._embed(x)

    def forward(self, x):
        """Encode ``x`` and reconstruct the whole sequence from the feature.

        Args:
            x: Batch of sequences, expected as
                (batch, sequence_length, input_features).

        Returns:
            Reconstruction of shape (batch, sequence_length, input_features).
        """
        latent = self._embed(x)

        reconstructed = self.decoder(latent)

        # The decoder emits a flat vector per sample: view it as
        # (batch, features, time), then swap to (batch, time, features).
        reconstructed = reconstructed.reshape(reconstructed.shape[0], self.feature_length, self.sequence_length)
        return reconstructed.permute(0, 2, 1)

model = FeatureExtractor(input_features, sequence_length, extracted_features).to('cuda')

# Resume from an earlier checkpoint when one exists. On a fresh run the file
# has not been written yet (a later cell saves it), so fall back to a dry run.
# Either path materialises the decoder's LazyLinear parameters, which must
# happen before they are handed to the optimizer.
checkpoint_file = 'feature_extractor.pt'
if Path(checkpoint_file).exists():
    model.load_state_dict(torch.load(checkpoint_file, map_location='cpu'))
else:
    model.eval()  # keep the dry run free of dropout / running-stat updates
    with torch.no_grad():
        model(torch.zeros(1, sequence_length, input_features, device='cuda'))
    model.train()  # restore the default mode a freshly built module has

criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)

# All motion collections are pooled into one training set that lives on the GPU.
dataset = TensorDataset(torch.cat([cmu, bml, ddb, mpi, sfu]).to('cuda'))

# Shuffle so that batches mix the pooled collections instead of visiting them
# in concatenation order every epoch.
# NOTE(review): the literal 64 is kept; it differs from `batch_size` (32).
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


In [None]:
# Train the auto-encoder to reconstruct its own input.
num_epochs = 200
model.train()
for epoch in range(num_epochs):
    # Sample-weighted running sum, so the reported value is the mean over the
    # whole epoch rather than whatever the final batch happened to produce.
    running_loss = 0.0
    samples_seen = 0

    for batch in dataloader:
        data = batch[0]  # TensorDataset yields 1-tuples
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, data)
        loss.backward()
        optimizer.step()

        # Accumulate on-device; converting once per epoch avoids a host sync
        # on every batch.
        running_loss += loss.detach() * data.shape[0]
        samples_seen += data.shape[0]

    # An empty loader reports nan instead of hitting an unbound (or stale) `loss`.
    epoch_loss = float(running_loss) / samples_seen if samples_seen else float('nan')
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

In [None]:
# Persist the trained weights so the evaluation cells can reload them.
trained_weights = model.state_dict()
torch.save(trained_weights, 'feature_extractor.pt')

In [None]:
# Build a separate extractor instance for evaluation and restore the saved weights.
feature_extractor = FeatureExtractor(input_features, sequence_length, extracted_features)
feature_extractor = feature_extractor.to('cuda')
saved_weights = torch.load('feature_extractor.pt', map_location='cpu')
feature_extractor.load_state_dict(saved_weights)

In [None]:

@torch.no_grad()
def extract_features(x):
    """Encode a batch of motions with the notebook-level ``feature_extractor``.

    The extractor is switched to eval mode on every call, and the input is
    moved to the GPU before encoding.

    Args:
        x: Batch of motion sequences accepted by ``feature_extractor.encode``.

    Returns:
        Whatever ``feature_extractor.encode`` returns for the batch.
    """
    feature_extractor.eval()
    batch_on_gpu = x.to('cuda')
    return feature_extractor.encode(batch_on_gpu)

def calculate_activation_statistics(motions: torch.Tensor):
    """Return the mean and covariance of the features extracted from ``motions``.

    Args:
        motions: Batch of motion sequences; one feature vector is extracted
            per sequence.

    Returns:
        ``(mean, cov)``: the mean over the batch dimension and the covariance
        between feature dimensions.
    """
    features = extract_features(motions)
    # torch.cov treats rows as variables, hence the transpose.
    return features.mean(dim=0), torch.cov(features.T)

def _matrix_sqrt(matrix):
    """Return the matrix square root of ``matrix`` as a single array.

    Depending on the SciPy version, ``sqrtm`` either accepts ``disp=False``
    and returns ``(root, error_estimate)``, or rejects ``disp`` and returns
    the root directly. Both conventions are handled here.
    """
    try:
        result = sqrtm(matrix, disp=False)
    except TypeError:
        result = sqrtm(matrix)
    # Never tuple-unpack blindly: a bare 2x2 array would unpack into its rows.
    return result[0] if isinstance(result, tuple) else result


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Computes ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr(sqrt(sigma1 sigma2))``.

    Args:
        mu1: Mean vector of the first distribution (1-D numpy array).
        sigma1: Covariance matrix of the first distribution (2-D numpy array).
        mu2: Mean vector of the second distribution.
        sigma2: Covariance matrix of the second distribution.
        eps: Value added to the diagonals of both covariances when the square
            root of their product is not finite (e.g. near-singular
            covariances estimated from few samples).

    Returns:
        The Frechet distance as a numpy scalar.
    """
    diff = mu1 - mu2

    covmean = _matrix_sqrt(sigma1.dot(sigma2))

    # The product of (near-)singular covariances may have no finite square
    # root; regularise the diagonals and retry instead of returning nan/inf.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = _matrix_sqrt((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a small imaginary component; only the real
    # part is meaningful for the trace.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
def fid(real: torch.Tensor, generated: torch.Tensor) -> float:
    """Frechet distance between the feature statistics of two motion batches.

    Args:
        real: Batch of reference motion sequences.
        generated: Batch of generated motion sequences.

    Returns:
        The Frechet distance between the Gaussians fitted to each batch's
        extracted features.
    """
    real_mu, real_sigma = calculate_activation_statistics(real)
    gen_mu, gen_sigma = calculate_activation_statistics(generated)

    # The distance itself is computed with numpy/scipy on the CPU.
    as_numpy = [stat.cpu().numpy() for stat in (real_mu, real_sigma, gen_mu, gen_sigma)]
    return calculate_frechet_distance(*as_numpy)


In [None]:
# Generated motions to score; they go through the same `load` preprocessing
# as the reference data.
prediction_file = 'prediction_mpi_20.pt'
prediction = load(prediction_file)

In [None]:
# FID: Frechet distance between features of real and generated motions,
# estimated over 20 independent random subsamples.
scores = []

data = sfu # real source

for i in range(20):
    # NOTE(review): 200 samples is fewer than the 256 feature dimensions set by
    # `extracted_features`, so each covariance estimate is rank-deficient --
    # confirm this sample size is intended.
    size = 200
    real = data[torch.randperm(len(data))[:size]]
    # NOTE(review): `real2` is not used in this cell, but its randperm call
    # still advances the RNG before `generated` is drawn.
    real2 = data[torch.randperm(len(data))[:size]]
    generated = prediction[torch.randperm(len(prediction))[:size]]

    scores.append(fid(real, generated))

scores = torch.tensor(scores)

# Cell output: mean and standard deviation over the 20 rounds.
scores.mean(), scores.std()

In [None]:
# Diversity: mean L2 distance between the features of two independently drawn
# random subsets of the generated motions, estimated over 20 rounds.
scores = []

data = cmu

for i in range(20):
    size = 1000
    # NOTE(review): `real` and `real2` are not used in this cell; their
    # randperm calls still advance the RNG before the generated subsets are drawn.
    real = data[torch.randperm(len(data))[:size]]
    real2 = data[torch.randperm(len(data))[:size]]
    generated = prediction[torch.randperm(len(prediction))[:size]]
    generated2 = prediction[torch.randperm(len(prediction))[:size]]
    
    # Row k of the first subset is compared with row k of the second subset.
    out = torch.norm(extract_features(generated) - extract_features(generated2),p=2, dim=1)

    scores.append(out.mean())

scores = torch.tensor(scores)

# Cell output: mean and standard deviation over the 20 rounds.
scores.mean(), scores.std()

In [None]:
# Multimodality: mean L2 distance between features of the generated motions
# under two independent shuffles of their 10 chunks, estimated over 20 rounds.
scores = []

# NOTE(review): `data` is not used in this cell.
data = cmu

for i in range(20):
    shape = prediction.shape
    # Split the predictions into 10 chunks along dim 0 and stack them;
    # torch.stack needs equally sized chunks, so this presumably relies on
    # len(prediction) being divisible by 10 -- TODO confirm.
    # NOTE(review): this value does not change between iterations.
    generated = torch.stack(prediction.chunk(10, dim=0))

    # Two independent chunk orderings, flattened back to the original layout.
    # NOTE(review): wherever both permutations place the same chunk in the same
    # slot, rows are compared with themselves and contribute zero distance --
    # confirm this is intended.
    first = generated[torch.randperm(10)].reshape(shape)
    second = generated[ torch.randperm(10)].reshape(shape)
    out = torch.norm(extract_features(first) - extract_features(second), p=2, dim=-1)

    scores.append(out.mean())

scores = torch.tensor(scores)

# Cell output: mean and standard deviation over the 20 rounds.
scores.mean(), scores.std()