## Train discriminator and generator at each epoch test

## cAvatar architecture with cAvatar dataset starts from here

## Dataset read and prepare

In [93]:
import os
import cv2
import pickle
import numpy as np
import random
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
import math
from torchvision import models
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json
# from smpl_torch import SMPLModel  # Import the SMPL model from your saved file
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# # Define parameters for a smaller dataset
window_size = 20  # Reduced window size for testing

In [94]:
folder_path = '/home/mkeya/cAvatar_dataset/touch_normalized.p'

with open(folder_path, 'rb') as f:
        data = pickle.load(f)  # Load the entire .p file
        

In [95]:
print(f"data type: {type(data)}")

print(f"data shape: {data.shape}")

# frames = data

data type: <class 'numpy.ndarray'>
data shape: (11075, 96, 96)


In [97]:
missing_frames_file = '/home/mkeya/VIBE/missing_frames.txt'

In [98]:
# Read missing frames
with open(missing_frames_file, 'r') as f:
    missing_frames = set(int(line.strip()) for line in f)

# Filter out missing frames from the data
filtered_data = [frame for i, frame in enumerate(data) if i not in missing_frames]

In [99]:
print(f"filtered_data type: {type(filtered_data)}")

filtered_data_array = np.array(filtered_data)

print(f"filtered_data_array shape: {filtered_data_array.shape}")

filtered_data type: <class 'list'>
filtered_data_array shape: (9143, 96, 96)


In [100]:
frames = filtered_data_array

len(frames)

9143

## Teacher Dataset

In [101]:
# Path to the JSON file saved by demo.py
json_file_path = '/home/mkeya/VIBE/merged_vibe_results.json'  # Replace with the actual path to the JSON file

# Load the JSON file with ground truth values
with open(json_file_path, 'r') as f:
    vibe_results = json.load(f)
print('---------------')



ground_truth_pose = torch.tensor(vibe_results['pose'], dtype=torch.float64)
ground_truth_shape = torch.tensor(vibe_results['shape'], dtype=torch.float64)
ground_truth_camera = torch.tensor(vibe_results['camera'], dtype=torch.float64)
ground_truth_joints = torch.tensor(vibe_results['joints'], dtype=torch.float64)
ground_truth_vertices = torch.tensor(vibe_results['vertices'], dtype=torch.float64)

print(f"ground_truth_pose shape: {ground_truth_pose.shape}") # theta
print(f"ground_truth_shape shape: {ground_truth_shape.shape}") # beta
print(f"ground_truth_camera shape: {ground_truth_camera.shape}") # translation
print(f"ground_truth_vertices shape: {ground_truth_vertices.shape}") # vertices
print(f"ground_truth_joints shape: {ground_truth_joints.shape}") # joints taking all joints this time

---------------
ground_truth_pose shape: torch.Size([9145, 72])
ground_truth_shape shape: torch.Size([9145, 10])
ground_truth_camera shape: torch.Size([9145, 3])
ground_truth_vertices shape: torch.Size([9145, 6890, 3])
ground_truth_joints shape: torch.Size([9145, 49, 3])


## Create Training input and target pair

## Sliding window + Splitting the dataset into training, validation and testing set

In [102]:
def extract_from_list(key, sequences): 
    sequences_list = [item[key] for item in sequences]
    return sequences_list



# Step 3: Create overlapping sequences using a reduced sliding window
sequences = []
for i in range(len(frames) - window_size + 1):
    sequence = frames[i:i + window_size]  # Get 20 frames
    

    samples_X_y = {} # it's very important to make it empty here.
    keys = ['sequence', 'pose', 'shape', 'camera', 'vertices', 'coco']
    samples_X_y = dict.fromkeys(keys)
    
    # the following data is the learning target for 20th frame from VIBE output
    # in python a[i] is the (i+1)th element in list a

    samples_X_y['sequence']=np.array(sequence)
    samples_X_y['pose']=ground_truth_pose[i+window_size-1:i+window_size]
    samples_X_y['shape']=ground_truth_shape[i+window_size-1:i+window_size]
    samples_X_y['camera']=ground_truth_camera[i+window_size-1:i+window_size]
    samples_X_y['vertices']=ground_truth_vertices[i+window_size-1:i+window_size]
    samples_X_y['coco']=ground_truth_joints[i+window_size-1:i+window_size]

    sequences.append(samples_X_y) 

# # Step: Shuffle the dataset 
# random.seed(42)
# random.shuffle(sequences)

# Step 4: Split sequences into training, validation, and testing sets (80-10-10 split)
num_sequences = len(sequences)
print(f"num_sequences: {num_sequences}")
train_end = int(num_sequences * 0.8)
val_end = train_end + int(num_sequences * 0.1)

# you are not randomizing the samples, just following the original orders of the continues frames
# each sample in the list, is a dictionary
train_sequences1 = sequences[:train_end]

print(f"train_sequences1 shape: {len(train_sequences1)}")
# print(f"train_sequences1 [0]: {train_sequences1[0]}")
train_sequences = extract_from_list('sequence', train_sequences1)

print(f"train_sequences len: {len(train_sequences)}")

val_sequences1 = sequences[train_end:val_end]
val_sequences = extract_from_list('sequence', val_sequences1)

test_sequences1 = sequences[val_end:]
test_sequences = extract_from_list('sequence', test_sequences1)


# Dynamic extraction using globals()
VIBE_keys = ['pose', 'shape', 'camera', 'vertices', 'coco']
name_list = ['ground_truth_pose1','ground_truth_shape1','ground_truth_camera1','ground_truth_vertices1','ground_truth_coco1']

for index, key in enumerate(VIBE_keys):
    # print(index)
    # print(key)
    name1 = name_list[index]
    name2 = name1 

    # print(name2)
    globals()[name2] = extract_from_list(key, train_sequences1)

new_name_list = [item+'_train' for item in name_list]
# Step 5: Print shapes to verify
print(f"Training set shape: {np.array(train_sequences).shape}") 
print(f"Validation set shape: {np.array(val_sequences).shape}")
print(f"Testing set shape: {np.array(test_sequences).shape}")

# 4. Extract ground truth for training, validation, and testing
ground_truth_pose_train = extract_from_list('pose', train_sequences1)
ground_truth_pose_val = extract_from_list('pose', val_sequences1)
ground_truth_pose_test = extract_from_list('pose', test_sequences1)

ground_truth_shape_train = extract_from_list('shape', train_sequences1)
ground_truth_shape_val = extract_from_list('shape', val_sequences1)
ground_truth_shape_test = extract_from_list('shape', test_sequences1)

ground_truth_camera_train = extract_from_list('camera', train_sequences1)
ground_truth_camera_val = extract_from_list('camera', val_sequences1)
ground_truth_camera_test = extract_from_list('camera', test_sequences1)

ground_truth_vertices_train = extract_from_list('vertices', train_sequences1)
ground_truth_vertices_val = extract_from_list('vertices', val_sequences1)
ground_truth_vertices_test = extract_from_list('vertices', test_sequences1)

ground_truth_coco_train = extract_from_list('coco', train_sequences1)
ground_truth_coco_val = extract_from_list('coco', val_sequences1)
ground_truth_coco_test = extract_from_list('coco', test_sequences1)

# Convert to tensors if needed
ground_truth_pose_train = torch.tensor(np.vstack(ground_truth_pose_train))
ground_truth_pose_val = torch.tensor(np.vstack(ground_truth_pose_val))
ground_truth_pose_test = torch.tensor(np.vstack(ground_truth_pose_test))

ground_truth_shape_train = torch.tensor(np.vstack(ground_truth_shape_train))
ground_truth_shape_val = torch.tensor(np.vstack(ground_truth_shape_val))
ground_truth_shape_test = torch.tensor(np.vstack(ground_truth_shape_test))

ground_truth_camera_train = torch.tensor(np.vstack(ground_truth_camera_train))
ground_truth_camera_val = torch.tensor(np.vstack(ground_truth_camera_val))
ground_truth_camera_test = torch.tensor(np.vstack(ground_truth_camera_test))

ground_truth_vertices_train = torch.tensor(np.vstack(ground_truth_vertices_train))
ground_truth_vertices_val = torch.tensor(np.vstack(ground_truth_vertices_val))
ground_truth_vertices_test = torch.tensor(np.vstack(ground_truth_vertices_test))

ground_truth_coco_train = torch.tensor(np.vstack(ground_truth_coco_train))
ground_truth_coco_val = torch.tensor(np.vstack(ground_truth_coco_val))
ground_truth_coco_test = torch.tensor(np.vstack(ground_truth_coco_test))


num_sequences: 9124
train_sequences1 shape: 7299
train_sequences len: 7299
Training set shape: (7299, 20, 96, 96)
Validation set shape: (912, 20, 96, 96)
Testing set shape: (913, 20, 96, 96)
ground_truth_pose_test shape: torch.Size([913, 72])


In [103]:
print(f"ground_truth_pose_train shape: {ground_truth_pose_train.shape}")

print(f"ground_truth_coco_test shape: {ground_truth_coco_test.shape}")

ground_truth_pose_train shape: torch.Size([7299, 72])
ground_truth_coco_test shape: torch.Size([913, 49, 3])


## Model initialization

## Student Model

In [106]:


# ResNet-18 Modification
class CustomResNet18(nn.Module):
    def __init__(self, output_dim):
        super(CustomResNet18, self).__init__()
        resnet = models.resnet18(pretrained=True)
        # Adjusting the first convolution layer for 1 channel input (grayscale or single channel images)
        resnet.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])  # Removing fully connected layers
        self.fc = nn.Linear(512, output_dim)  # The output from ResNet is reduced to the required output dimension

    def forward(self, x):
        x = self.resnet(x)
        x = torch.mean(x, dim=(2, 3))  # Global average pooling to reduce spatial dimensions
        x = self.fc(x)
        return x

def positional_encoding(seq_len, d_model):
    """
    Generates sinusoidal positional encoding matrix of shape (seq_len, d_model).
    """
    position = torch.arange(0, seq_len).unsqueeze(1).float()  # Shape: (seq_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # Shape: (seq_len, d_model)

# Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

    def forward(self, src):
        src2 = self.norm1(src)
        attn_output, _ = self.self_attn(src2, src2, src2)
        src = src + self.dropout(attn_output)
        src2 = self.norm2(src)
        src = src + self.dropout(self.linear2(F.relu(self.linear1(src2))))
        return src

class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, d_model, nhead, dim_feedforward, dropout):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, src):
        for layer in self.layers:
            src = layer(src)
        return src

# The main Student Model that includes all components
class StudentModel(nn.Module):
    def __init__(self, output_dim=86, num_layers=6, d_model=256, nhead=4, dim_feedforward=2048, dropout=0.1, max_seq_len=20):
        super(StudentModel, self).__init__()
        self.resnet_model = CustomResNet18(d_model)
        self.positional_encoding = positional_encoding(seq_len=max_seq_len, d_model=d_model)
        self.transformer_encoder = TransformerEncoder(num_layers, d_model, nhead, dim_feedforward, dropout)
        self.output_layer = nn.Linear(d_model, output_dim)  # Output layer to generate joint angles, shapes, etc.
        self.fc_layer = nn.Linear(in_features=d_model, out_features=d_model)  # Connect ResNet18 and Transformer

    def forward(self, x):
        # Step 1: ResNet for each frame in the sequence
        # x is of shape (batch_size, seq_len, 96, 96), reshape to (batch_size * seq_len, 1, 96, 96)
        batch_size, seq_len, h, w = x.shape
        x_reshaped = x.view(batch_size * seq_len, 1, h, w)
        resnet_output = self.resnet_model(x_reshaped)  # Output shape: (batch_size * seq_len, 256)
        
        # Reshape the output back to (batch_size, seq_len, 256)
        resnet_output = resnet_output.view(batch_size, seq_len, -1)

        # Step 2: Linear layer after ResNet to match the transformer input size
        final_output = self.fc_layer(resnet_output)  # Shape: (batch_size, seq_len, 256)

        # Step 3: Add positional encoding
        batch_size, seq_len, d_model = final_output.shape
        positional_encoding = self.positional_encoding[:seq_len, :].unsqueeze(0).repeat(batch_size, 1, 1)  # Shape: (batch_size, seq_len, d_model)

        final_output = final_output.to('cuda:0')
        positional_encoding = positional_encoding.to('cuda:0')
        
        final_output_with_pe = final_output + positional_encoding  # Add positional encoding to the final output

        # Step 4: Transformer Encoder
        transformer_output = self.transformer_encoder(final_output_with_pe)  # Shape: (batch_size, seq_len, d_model)

        # Step 5: Use the last time step's embedding to predict the 21st frame
        last_embedding = transformer_output[:, -1, :]  # Shape: (batch_size, d_model)
        output_21st_frame = self.output_layer(last_embedding)  # Shape: (batch_size, output_dim)

        # Step 6: Separate Beta, Theta, and Translation
        beta_student = output_21st_frame[..., :10]
        theta_student = output_21st_frame[..., 10:82]
        translation_student = output_21st_frame[..., 82:85]

        return beta_student, theta_student, translation_student


## Initiate training

In [108]:

# Mean Absolute Error loss function
mae_loss = nn.L1Loss()

# Instantiate the model
student_model = StudentModel()

# Move model to device (e.g., GPU if available)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = torch.device("cuda")

student_model = student_model.to(device)

In [109]:
# Assuming train_sequences is already a tensor
train_sequences = torch.tensor(train_sequences, dtype=torch.float32) # (num_windows, window_size, height, width) = 7299, 20, 96, 96
test_sequences = torch.tensor(test_sequences, dtype=torch.float32)

batch_size = 32
num_samples = train_sequences.shape[0]

print(f"train_sequences shape: {train_sequences.shape}")
print(f"test_sequences shape: {test_sequences.shape}")

print(f"num_samples: {num_samples}")

train_sequences shape: torch.Size([7299, 20, 96, 96])
test_sequences shape: torch.Size([913, 20, 96, 96])
num_samples: 7299


## Discriminator

In [110]:
class Discriminator(nn.Module):
    
    def __init__(self, num_joints=49, hidden_dim=256, dropout_prob=0.4): # num of joints is 49 not 24   
        super(Discriminator, self).__init__()
        self.hidden_layer = nn.Sequential(
            nn.Linear(num_joints * 3, hidden_dim),  # Flattened input size
            nn.Tanh(),
            nn.Dropout(dropout_prob)  # Apply dropout here
        )
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, 1),  # Outputs a single value per input
            nn.Sigmoid()
        )

    def forward(self, joints):
        # Flatten the input tensor
        joints_flat = joints.view(joints.size(0), -1)  # Shape: [batch_size, num_joints * 3]
        hidden_output = self.hidden_layer(joints_flat)
        output = self.output_layer(hidden_output)
        return output



# Instantiate the discriminator
discriminator = Discriminator(num_joints=49, dropout_prob=0.4).to(device) # working with 49 joints

d_lr = 8e-5
# Discriminator training optimizer
d_optimizer = optim.Adam(discriminator.parameters(), lr=d_lr, weight_decay=1e-4)

# Define the Binary Cross-Entropy Loss
criterion = nn.BCELoss()

discriminator hidden layer weights dtype: torch.float32


## Test set

In [111]:
def test_model(model, student_model, discriminator, test_sequences, ground_truth_shape_test, 
               ground_truth_pose_test, ground_truth_camera_test, ground_truth_vertices_test, 
               ground_truth_coco_test, batch_start, batch_end, batch_size, lambda_adv=0.05, lambda_ca=1.0):
    """
    Function to test the model and return total loss and discriminator loss for a single batch.
    """
    model.eval()
    student_model.eval()
    discriminator.eval()

    with torch.no_grad():
        # Extract batch of sliding windows for testing
        batch_windows = test_sequences[batch_start:batch_end].to(device)
        
        # Get the ground truth for the current batch
        ground_truth_shape_batch = ground_truth_shape_test[batch_start:batch_end].to(device)
        ground_truth_pose_batch = ground_truth_pose_test[batch_start:batch_end].to(device)
        ground_truth_camera_batch = ground_truth_camera_test[batch_start:batch_end].to(device)
        ground_truth_vertices_batch = ground_truth_vertices_test[batch_start:batch_end].to(device)
        ground_truth_joints_batch = ground_truth_coco_test[batch_start:batch_end].to(device)

        if batch_windows.shape[0] != batch_size:
            return None, None  # Skip if batch size mismatch

        # Get the beta, theta, and translation from the student model
        beta_student, theta_student, translation_student = student_model(batch_windows)


        # Forward pass through the SMPL model
        all_vertices = []
        all_joints = []

        for b in range(batch_size):
            beta = beta_student[b].to(device).to(torch.float32)          # Shape: (10,)
            theta = theta_student[b].to(device).to(torch.float32)        # Shape: (72,)
            translation = translation_student[b].to(device).to(torch.float32)  # Shape: (3,) 

            # Ensure shapes of inputs are correct
            betas = beta.unsqueeze(0) if beta.dim() == 1 else beta  # Shape: (1, 10)
            body_pose = theta[3:].unsqueeze(0) if theta[3:].dim() == 1 else theta[3:]  # Shape: (1, 69)
            global_orient = theta[:3].unsqueeze(0) if theta[:3].dim() == 1 else theta[:3]  # Shape: (1, 3)
            transl = translation.unsqueeze(0) if translation.dim() == 1 else translation  # Shape: (1, 3)

            
            # Forward pass through SMPL model
            output = smpl_model(
                betas=betas,
                body_pose=body_pose,
                global_orient=global_orient,
                transl=transl
            )
        
            # Extract vertices and joints
            vertices = output.vertices  # Shape: (1, 6890, 3)
            joints = output.joints      # Shape: (1, 49, 3)
            
            # Collect vertices and joints for this frame
            all_vertices.append(vertices.cpu().detach().numpy())
            all_joints.append(joints.cpu().detach().numpy())


        # Convert lists to numpy arrays
        all_vertices = np.array(all_vertices)  # Shape: (batch_size, 1, num_vertices, 3)
        all_joints = np.array(all_joints)      # Shape: (batch_size, 1, 49, 3)
        


        all_vertices = all_vertices.squeeze(1)  # Shape: (batch_size, num_vertices, 3)
        all_joints = all_joints.squeeze(1)      # Shape: (batch_size, 49, 3)


        
        # Convert SMPL output to torch tensors (if needed)
        all_vertices_torch = torch.tensor(all_vertices, device=device)  # Shape: (32, 6890, 3)
        all_joints_torch = torch.tensor(all_joints, device=device)      # Shape: (32, 24, 3)
  

        # Adversarial loss labels
        real_outputs = discriminator(ground_truth_joints_batch.to(device).to(torch.float32))
        fake_outputs = discriminator(all_joints_torch.detach().to(device).to(torch.float32))
        d_real_loss = -torch.mean(torch.log(real_outputs + 1e-8))
        d_fake_loss = -torch.mean(torch.log(1 - fake_outputs + 1e-8))
        d_loss = d_real_loss + d_fake_loss

        discriminator_loss_test.append(d_loss.item())

        # Calculate MAE losses
        beta_loss = mae_loss(beta_student, ground_truth_shape_batch)
        camera_loss = mae_loss(translation_student, ground_truth_camera_batch)
        pose_loss = mae_loss(theta_student, ground_truth_pose_batch)
        vertices_loss = mae_loss(all_vertices_torch, ground_truth_vertices_batch)
        joints_loss = mae_loss(all_joints_torch, ground_truth_joints_batch)

        # Compute total loss for the batch
        loss_ca = (torch.mean(beta_loss) + torch.mean(camera_loss) + 
                   torch.mean(pose_loss) + torch.mean(vertices_loss) + torch.mean(joints_loss))
        

        adv_loss = -torch.mean(torch.log(discriminator(all_joints_torch.to(torch.float32)) + 1e-8))  # L_adv
        
        # total_loss = lambda_ca * loss_ca + lambda_adv * adv_loss
        total_loss = loss_ca

        # return d_loss, total_loss
        return total_loss


## Training

In [112]:
import sys
import os

# Add the VIBE directory to the Python path
vibe_root = os.path.abspath("/home/mkeya/VIBE")
if vibe_root not in sys.path:
    sys.path.append(vibe_root)

In [113]:
from lib.models.smpl import SMPL  # Import SMPL from the VIBE library

# Training parameters
num_windows = train_sequences.shape[0]  # Total sliding windows
print(f"num_windows: {num_windows}")

batch_size = 32  # Batch size for training
num_epochs = 40  # Number of epochs for training
learning_rate = 1.5e-4


lambda_adv = 0.03
lambda_ca = 1.0    # Weight for reconstruction loss

# Gradient Penalty parameters
lambda_gp = 20  # Weight for the gradient penalty

# # SMPL model
# model = SMPLModel(device=device)

# Initialize optimizer 
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

# SMPL model directory
SMPL_MODEL_DIR = "/home/mkeya/VIBE/data/vibe_data"

# Initialize the SMPL model
smpl_model = SMPL(SMPL_MODEL_DIR, batch_size=batch_size, create_transl=True).to(device)

# collecting losses per epoch
epoch_losses = []
epoch_losses_test = []

# collecting discriminator loss per epoch
discriminator_losses = []
discriminator_losses_test = []

ca_losses = []
adv_losses = []


# Training loop
for epoch in range(num_epochs):

    # total loss
    epoch_loss = []
    epoch_loss_test = []

    # discriminator loss
    discriminator_loss = []
    discriminator_loss_test = []

    ca_loss = []
    adver_loss = []

    
    for i in range(0, num_windows, batch_size):
        
        batch_start = i
        batch_end = min(i + batch_size, num_windows)
  
        # Extract batch of sliding windows
        i = epoch*batch_size +i
        
        batch_windows = train_sequences[i:i + batch_size].to(device)  # Shape: (32, 20, 96, 96)

        # Get the ground truth for the 20th frame of each sliding window in the batch
        ground_truth_shape_batch = ground_truth_shape_train[i:i + batch_size].to(device)  # Shape: (32, 10) beta
        ground_truth_pose_batch = ground_truth_pose_train[i:i + batch_size].to(device)    # Shape: (32, 72) theta
        ground_truth_camera_batch = ground_truth_camera_train[i:i + batch_size].to(device)  # Shape: (32, 3) translation
        ground_truth_vertices_batch = ground_truth_vertices_train[i:i + batch_size].to(device)  # Shape: (32, 6890, 3)
        ground_truth_joints_batch = ground_truth_coco_train[i:i + batch_size].to(device)  # Shape: (32, 49, 3) # joints

        # Ensure batch consistency
        if batch_windows.shape[0] != batch_size:
            continue

        beta_student, theta_student, translation_student = student_model(batch_windows)



        # Forward pass: SMPL vertices and joints
        all_vertices = []
        all_joints = []
        
        
        # Loop over the batch (one frame per sequence)
        for b in range(batch_size):
            # Extract beta, theta, and translation for the current frame
            beta = beta_student[b].to(device).to(torch.float32)          # Shape: (10,)
            theta = theta_student[b].to(device).to(torch.float32)        # Shape: (72,)
            translation = translation_student[b].to(device).to(torch.float32)  # Shape: (3,)           
            

            # Forward pass through the SMPL model
            output = smpl_model(
                betas=beta.unsqueeze(0),  # Add batch dimension
                body_pose=theta[3:].unsqueeze(0),  # Exclude global orientation
                global_orient=translation[:3].unsqueeze(0),  # Use only global orientation
                transl=translation.unsqueeze(0)  # Add batch dimension
            )
        
            # Extract vertices and joints
            vertices = output.vertices  # Shape: (1, 6890, 3)
            joints = output.joints      # Shape: (1, 49, 3)
            
            # Collect vertices and joints for this frame
            all_vertices.append(vertices.cpu().detach().numpy())
            all_joints.append(joints.cpu().detach().numpy())
        
        
        # Convert lists to numpy arrays
        all_vertices = np.array(all_vertices)  # Shape: (batch_size, 1, num_vertices, 3)
        all_joints = np.array(all_joints)      # Shape: (batch_size, 1, 49, 3)
        
        # print(f"all_vertices shape before: {all_vertices.shape}")
        # print(f"all_joints shape before: {all_joints.shape}")

        all_vertices = all_vertices.squeeze(1)  # Shape: (batch_size, num_vertices, 3)
        all_joints = all_joints.squeeze(1)      # Shape: (batch_size, 49, 3)

        
        # Convert SMPL output to torch tensors (if needed)
        all_vertices_torch = torch.tensor(all_vertices, device=device)  # Shape: (32, 6890, 3)
        all_joints_torch = torch.tensor(all_joints, device=device)      # Shape: (32, 24, 3)

            
        # unfreeze the discriminator         
        discriminator.train()  # Set discriminator to training mode
        student_model.eval() # Set the student model to eval mode during discriminator training
        
        for param in discriminator.parameters():
            param.requires_grad = True  # Enable gradients for discriminator

            
        # Forward pass through discriminator
        real_outputs = discriminator(ground_truth_joints_batch.to(device).to(torch.float32))  # Real joints
        fake_outputs = discriminator(all_joints_torch.detach().to(device).to(torch.float32))  # Generated joints (detach to prevent grad flow)

        # Theoretical discriminator loss (BCE with logits to match formula)
        d_real_loss = -torch.mean(torch.log(real_outputs + 1e-8))  # Avoid log(0)
        d_fake_loss = -torch.mean(torch.log(1 - fake_outputs + 1e-8))  # Avoid log(0)
        d_loss = d_real_loss + d_fake_loss  # Total discriminator loss


        # Gradient penalty
        batch_size = ground_truth_joints_batch.size(0)
        num_joints = ground_truth_joints_batch.size(1)
        alpha = torch.rand(batch_size, 1, 1, device=device)  # Add extra dimensions for broadcasting
        
        interpolates = alpha * ground_truth_joints_batch + (1 - alpha) * all_joints_torch.detach()
        interpolates.requires_grad_(True)
        
        # Forward pass through discriminator for interpolated samples
        interpolated_outputs = discriminator(interpolates.to(torch.float32).to(device))
        
        # Compute gradients of the outputs w.r.t. inputs
        grad_outputs = torch.ones_like(interpolated_outputs, device=device)
        gradients = torch.autograd.grad(
            outputs=interpolated_outputs,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]
        
        # Calculate the gradient penalty
        gradients = gradients.view(batch_size, -1)
        gradient_norm = gradients.norm(2, dim=1)  # L2 norm
        gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
        
        # Add the gradient penalty to the discriminator loss
        d_loss += gradient_penalty


        discriminator_loss.append(d_loss.item())
        
        # Backpropagation for discriminator
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        # print(f"Discriminator Loss: {d_loss.item():.4f}")
    
        # freeze the discriminator for generator training
        discriminator.eval()
        for param in discriminator.parameters():
            param.requires_grad = False # Freeze the discriminator's weight

        # Student model or generator training
        student_model.train()  # Set to training mode

        
        # Calculate the individual MAE losses for each component
        beta_loss = mae_loss(beta_student, ground_truth_shape_batch)  # Shape: (32, 10)
        camera_loss = mae_loss(translation_student, ground_truth_camera_batch)  # Shape: (32, 3)
        pose_loss = mae_loss(theta_student, ground_truth_pose_batch)  # Shape: (32, 72)
        vertices_loss = mae_loss(all_vertices_torch, ground_truth_vertices_batch)  # Shape: (32, 6890, 3)
        joints_loss = mae_loss(all_joints_torch, ground_truth_joints_batch)  # Shape: (32, 24, 3)
        
        # Now compute the total loss by summing all individual losses
        loss_ca = (torch.mean(beta_loss) + torch.mean(camera_loss) + 
                      torch.mean(pose_loss) + torch.mean(vertices_loss) + torch.mean(joints_loss))

        ca_loss.append(loss_ca.item())

        # Adversarial loss for generator with no grad mode
        with torch.no_grad(): # Discriminator forward pass should not compute gradients
            # adv_loss = -torch.mean(torch.log(discriminator(all_joints_torch.to(torch.float32)) + 1e-8))  # L_adv
            adv_loss = torch.mean((discriminator(all_joints_torch.to(torch.float32)) -1) **2)  # L_adv

        adver_loss.append(adv_loss.item())

        # Log the values of loss_ca and adv_loss
        print(f"Epoch {epoch+1}, Batch {i//batch_size+1}, loss_ca: {loss_ca.item():.4f}, adv_loss: {adv_loss.item():.4f}")
       
        # Compute total loss
        total_loss = lambda_ca * loss_ca + lambda_adv * adv_loss

        # Testing logic for this batch
        total_loss_test, d_loss_test = test_model(
            smpl_model, student_model, discriminator, test_sequences, ground_truth_shape_test,
            ground_truth_pose_test, ground_truth_camera_test, ground_truth_vertices_test,
            ground_truth_coco_test, batch_start, batch_end, batch_size, lambda_adv, lambda_ca
        )
        

        if d_loss_test is not None:
            discriminator_loss_test.append(d_loss_test.item())
        else:
            discriminator_loss_test.append(0.0) 
        

        # BackpropagationL
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        epoch_loss.append(total_loss.item())
        
        if total_loss_test is not None:
            epoch_loss_test.append(total_loss_test)
        else:
            epoch_loss_test.append(0.0)


    #avg total loss per epoch
    avg_loss_per_epoch = sum(epoch_loss)/len(epoch_loss)
    epoch_losses.append(avg_loss_per_epoch)
    

    # Filter out 0.0 values
    filtered_epoch_loss_test = [loss for loss in epoch_loss_test if loss != 0.0]
    
    # Calculate average if the filtered list is not empty
    if filtered_epoch_loss_test:
        avg_loss_per_epoch_test = sum(filtered_epoch_loss_test) / len(filtered_epoch_loss_test)
    else:
        avg_loss_per_epoch_test = 0.0  # Handle case when all values are 0.0 or the list is empty
    
    epoch_losses_test.append(avg_loss_per_epoch_test)

    
    #avg discriminator loss per epoch
    if len(discriminator_loss) > 0:
        avg_discriminator_loss_per_epoch = sum(discriminator_loss)/len(discriminator_loss)
        discriminator_losses.append(avg_discriminator_loss_per_epoch)
        
    else:
        discriminator_losses.append(None)


    # Filter out 0.0 values
    filtered_discriminator_loss_test = [loss for loss in discriminator_loss_test if loss != 0.0]
    
    # Calculate average if the filtered list is not empty
    if filtered_discriminator_loss_test:
        avg_discriminator_loss_per_epoch_test = sum(filtered_discriminator_loss_test) / len(filtered_discriminator_loss_test)
    else:
        avg_discriminator_loss_per_epoch_test = 0.0  # Handle case when all values are 0.0 or the list is empty
    
    discriminator_losses_test.append(avg_discriminator_loss_per_epoch_test)

    ca_loss_per_epoch = sum(ca_loss)/len(ca_loss)
    ca_losses.append(ca_loss_per_epoch)

    adver_loss_per_epoch = sum(adver_loss)/len(adver_loss)
    adv_losses.append(adver_loss_per_epoch)



# Save loss trend to file
with open("loss_log.txt", "w") as f:
    for epoch, (loss, d_loss, loss_test, d_loss_test, adv_l) in enumerate(zip(epoch_losses, discriminator_losses, epoch_losses_test, discriminator_losses_test, adv_losses)):
        d_loss_str = f"{d_loss:.4f}" if d_loss is not None else "N/A"
        f.write(f"Epoch {epoch+1}, Total Loss: {loss:.4f}, Discriminator Loss: {d_loss_str}, Test Loss: {loss_test:.4f}, Test Discrimination Loss: {d_loss_test:.4f}, Adv Loss: {adv_l:.4f}\n")

print("Training complete. Loss log saved to 'loss_log.txt'.")


num_windows: 7299
Epoch 1, Batch 1, loss_ca: 3.2931, adv_loss: 0.2464
Epoch 1, Batch 2, loss_ca: 9.8247, adv_loss: 0.2368
Epoch 1, Batch 3, loss_ca: 9.2984, adv_loss: 0.1594
Epoch 1, Batch 4, loss_ca: 5.4381, adv_loss: 0.2189
Epoch 1, Batch 5, loss_ca: 4.5873, adv_loss: 0.2260
Epoch 1, Batch 6, loss_ca: 4.6332, adv_loss: 0.2717
Epoch 1, Batch 7, loss_ca: 3.8030, adv_loss: 0.2275
Epoch 1, Batch 8, loss_ca: 3.9623, adv_loss: 0.2216
Epoch 1, Batch 9, loss_ca: 3.3026, adv_loss: 0.2185
Epoch 1, Batch 10, loss_ca: 3.5805, adv_loss: 0.2532
Epoch 1, Batch 11, loss_ca: 3.4051, adv_loss: 0.2518
Epoch 1, Batch 12, loss_ca: 3.0822, adv_loss: 0.2451
Epoch 1, Batch 13, loss_ca: 2.5051, adv_loss: 0.2614
Epoch 1, Batch 14, loss_ca: 2.5838, adv_loss: 0.2705
Epoch 1, Batch 15, loss_ca: 2.8796, adv_loss: 0.2717
Epoch 1, Batch 16, loss_ca: 3.0299, adv_loss: 0.2490
Epoch 1, Batch 17, loss_ca: 3.2745, adv_loss: 0.2570
Epoch 1, Batch 18, loss_ca: 2.7380, adv_loss: 0.2874
Epoch 1, Batch 19, loss_ca: 2.5439, a

In [116]:
import matplotlib.pyplot as plt

# Initialize lists to store epochs and losses
epochs = []
adv_loss_plot = []
total_losses = []
discriminator_losses = []
test_losses = []
test_discriminator_losses = []

# Read the log file and parse the values
with open("loss_log.txt", "r") as file:
    for line in file:
        parts = line.strip().split(", ")
        if len(parts) == 6:  # Ensure proper formatting
            
            epoch = int(parts[0].split(" ")[1])  # Extract epoch number
            total_loss = float(parts[1].split(": ")[1])  # Extract Total Loss
            
            # Extract and handle Discriminator Loss
            d_loss_str = parts[2].split(": ")[1]
            d_loss = float(d_loss_str) if d_loss_str != "N/A" else None
            
            # Extract Test Loss
            test_loss = float(parts[3].split(": ")[1])  # Extract Test Loss
            
            # Extract and handle Test Discriminator Loss
            d_loss_test_str = parts[4].split(": ")[1]
            test_d_loss = float(d_loss_test_str) if d_loss_test_str != "N/A" else None
            
            # Extract Adv Loss
            adv_loss = float(parts[5].split(": ")[1])

            # Append values to lists
            epochs.append(epoch)
            total_losses.append(total_loss)
            discriminator_losses.append(d_loss)
            test_losses.append(test_loss)
            test_discriminator_losses.append(test_d_loss)
            adv_loss_plot.append(adv_loss)

# Plot the losses
fig, ax1 = plt.subplots(figsize=(12, 8))

# Plot total loss
# ax1.plot(epochs, total_losses, label="Generator Loss", color="blue", marker="o")

ax1.plot(epochs, total_losses, label=f"Generator Loss, LR = {learning_rate}, L_ca = {lambda_ca}, l_adv = {lambda_adv}", color="blue", marker="o")

# Plot discriminator loss, skipping epochs with no discriminator loss
valid_d_losses = [(epoch, d_loss) for epoch, d_loss in zip(epochs, discriminator_losses) if d_loss is not None]
if valid_d_losses:
    d_epochs, d_losses = zip(*valid_d_losses)
    ax1.plot(d_epochs, d_losses, label="Discriminator Loss", color="red", marker="o")

# Plot test loss
ax1.plot(epochs, test_losses, label="Test Loss", color="green", linestyle="--", marker="x")

# Plot test discriminator loss, skipping epochs with no discriminator test loss
valid_test_d_losses = [(epoch, d_loss) for epoch, d_loss in zip(epochs, test_discriminator_losses) if d_loss is not None]
if valid_test_d_losses:
    d_test_epochs, d_test_losses = zip(*valid_test_d_losses)
    ax1.plot(d_test_epochs, d_test_losses, label="Test Discrimination Loss", color="purple", linestyle="--", marker="x")

# Plot adv loss
ax1.plot(epochs, adv_loss_plot, label="Adv Loss", color="black", linestyle="--", marker="x")

# Add labels, title, and legend for the loss curves
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title(f"Losses Over {len(epochs)} Epochs")
ax1.legend(loc="upper right")
ax1.grid(True)

# Save the plot as an image (optional)
plt.savefig("loss_dis_train_first_5epoch_test_dataset.png")

# Display the plot
plt.show()
