In [1]:
!pip install transformers==4.40.2

Collecting transformers==4.40.2
  Downloading transformers-4.40.2-py3-none-any.whl.metadata (137 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/138.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m133.1/138.0 kB[0m [31m8.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.0/138.0 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading transformers-4.40.2-py3-none-any.whl (9.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.0/9.0 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.44.2
    Uninstalling transformers-4.44.2:
      Successfully uninstalled transformers-4.44.2
Successfully installed transformers-4.40.2


In [2]:
!pip install wandb



In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
import wandb
import os

In [5]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [6]:
import random
seed = 7
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(seed)

In [7]:
h2o_root = '/content/drive/My Drive/vit_3d/'



In [8]:

def process_images(rgb_images, hand_heatmaps,obj_heatmaps):
    img_size =224

    images = rgb_images.unsqueeze(0).permute(2, 1, 3, 4, 0).squeeze(-1)  # Reshape to (batch_size * sequence_length, 3, H, W)

    interleaved_sequences = torch.zeros((16, 3, 496, 496), device=rgb_images.device)


    # Resize grayscale images
    crop_size = 360
    start = (images.shape[-1]) // 2
    #start = (grayscale_images.shape[-1]) // 2
    cropped_grayscale = images[:,:, :, int(start - crop_size / 2):int(start + crop_size / 2)]#grayscale_

    resized_grayscale = nn.functional.interpolate(cropped_grayscale, size=(img_size, img_size), mode='nearest').squeeze(0)

    # Resize heatmaps
    crop_size = 720
    start = (hand_heatmaps.shape[-1] - crop_size) // 2
    cropped_hand_heatmaps = hand_heatmaps[:,:, start:start + crop_size]
    cropped_hand_heatmaps =cropped_hand_heatmaps.unsqueeze(1)

    resized_hand_heatmaps = nn.functional.interpolate(cropped_hand_heatmaps, size=(img_size, img_size), mode='nearest').squeeze(0)
    resized_hand_heatmaps = resized_hand_heatmaps.float() # Add channel dimension
    resized_hand_heatmaps = resized_hand_heatmaps.view( 1,8, img_size, img_size).squeeze(0).unsqueeze(1)

    cropped_obj_heatmaps =obj_heatmaps[:,:, start:start + crop_size]
    cropped_obj_heatmaps =cropped_obj_heatmaps.unsqueeze(1)

    resized_obj_heatmaps = nn.functional.interpolate(cropped_obj_heatmaps, size=(img_size, img_size), mode='nearest').squeeze(0)
    resized_obj_heatmaps = resized_obj_heatmaps.float() # Add channel dimension
    resized_obj_heatmaps = resized_obj_heatmaps.view( 1,8, img_size, img_size).squeeze(0).unsqueeze(1)


    #print(resized_grayscale.shape)
    return resized_grayscale, resized_hand_heatmaps, resized_obj_heatmaps


class TrainData(torch.utils.data.DataLoader):
    def __init__(self):
        self.data_path = h2o_root + "seq_8_train/"
        self.img_path = self.data_path + "frames_train(1)/"
        self.hand_path = self.data_path + "poses_hand_train/"
        self.obj_poses = self.data_path + "poses_obj_train/"
        self.num_actions = len(os.listdir(self.hand_path))
        self.labels = np.load(h2o_root + "action_labels_train.npy")

    def __len__(self):
        return self.num_actions

    def __getitem__(self, idx):
        img = np.load(self.img_path + format(idx + 1, '03d') + ".npy")
        hand_heatmap = np.load(self.data_path + "heatmaps_train/" + format(idx + 1, '03d') + ".npy")
        obj_heatmap = np.load(self.data_path + "obj_heatmaps_train/" + format(idx + 1, '03d') + ".npy")
        mano_pose = np.load(self.data_path +'mano_8_train/' + format(idx, '03d') + ".npy")
        label = self.labels[idx]

        img = np.moveaxis(img, -1, 0)
        img = torch.from_numpy(img).float()
        #hand_heatmap[hand_heatmap > 0] = 255
        hand_heatmap =torch.from_numpy(hand_heatmap/255.0).float()
        obj_heatmap =torch.from_numpy(obj_heatmap/255.0).float()
        img,hand,obj = process_images(img,hand_heatmap,obj_heatmap)
        mano_pose = torch.from_numpy(mano_pose[:,4:52]).float()
        return img, hand, obj,mano_pose, label


class ValData(torch.utils.data.DataLoader):
    def __init__(self):
        self.data_path = h2o_root + "seq_8_val/"
        self.img_path = self.data_path+ "frames_val(1)/"
        self.hand_path = self.data_path + "poses_hand_val/"
        self.obj_path  = self.data_path + "poses_obj_val/"
        self.num_actions = len(os.listdir(self.hand_path))
        self.labels = np.load(h2o_root + "action_labels_val.npy")

    def __len__(self):
        return self.num_actions

    def __getitem__(self, idx):
        img = np.load(self.img_path + format(idx + 1, '03d') + ".npy")
        hand_poses = np.load(self.hand_path + format(idx + 1, '03d') + ".npy")
        obj_poses = np.load(self.obj_path + format(idx + 1, '03d') + ".npy")
        hand_heatmap = np.load(self.data_path + "heatmaps_val/" + format(idx + 1, '03d') + ".npy")
        obj_heatmap = np.load(self.data_path + "obj_heatmaps_val/" + format(idx + 1, '03d') + ".npy")
        mano_pose = np.load(self.data_path +'mano_8_val/' + format(idx, '03d') + ".npy")
        mano_pose = torch.from_numpy(mano_pose[:,4:52]).float()
        label = self.labels[idx]


        img = np.moveaxis(img, -1, 0)
        img = torch.from_numpy(img).float()
        #hand_heatmap[hand_heatmap > 0] = 255
        hand_heatmap =torch.from_numpy(hand_heatmap/255.0).float()
        obj_heatmap =torch.from_numpy(obj_heatmap/255.0).float()
        img,hand,obj = process_images(img,hand_heatmap,obj_heatmap)
        return img, hand, obj,mano_pose, label



In [9]:
class Cars_Mano_Action(nn.Module):
    def __init__(self, vit_model, num_classes, mano_pose_dim, sequence_length):
        super(Cars_Mano_Action, self).__init__()
        self.vit_model = vit_model
        self.sequence_length = sequence_length
        self.mano_pose_dim = 48

        self.classifier = nn.Sequential(
            nn.Linear(vit_model.config.hidden_size * sequence_length, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        self.mano_pose_predictor = nn.Sequential(
            nn.Linear(vit_model.config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, self.mano_pose_dim)
        )

    def forward(self, pixel_values):
        batch_size = pixel_values.shape[0]
        sequence_length = self.sequence_length

        # Forward pass through ViT
        outputs = self.vit_model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
        last_hidden_state = outputs.last_hidden_state[:, 0, :]  # (batch_size * sequence_length, hidden_size)
        attentions = outputs.attentions[-1]  # Get the attention maps from the last layer
        # Process attentions to match the input image resolution
        num_heads = attentions.shape[1]
        num_tokens = attentions.shape[-1] - 1
        attentions = attentions[:, :, 0, 1:].reshape(batch_size, num_heads, num_tokens)

        w_featmap = pixel_values.shape[-2] // self.vit_model.config.patch_size
        h_featmap = pixel_values.shape[-1] // self.vit_model.config.patch_size
        attentions = attentions.reshape(batch_size, num_heads, w_featmap, h_featmap)
        attentions = F.interpolate(attentions, scale_factor=self.vit_model.config.patch_size, mode="nearest")
        attentions = attentions.view(batch_size, num_heads, pixel_values.shape[-2], pixel_values.shape[-1])
        attentions = (attentions - attentions.min()) / (attentions.max() - attentions.min())
        hand_attention = attentions[:,0:10,:,:]
        obj_attention = attentions[:,10:12,:,:]
        #free_attention = attentions[:,8:12,:,:]
        sum_hand = torch.mean(hand_attention, dim=1)
        sum_obj = torch.mean(obj_attention, dim=1)
        #mean_hand = torch.mean(hand_attention, dim=1)
        #mean_obj = torch.mean(obj_attention, dim=1)
        #mean_free =torch.mean(free_attention, dim=1)
        # Reshape and concatenate embeddings

        concatenated_embeddings = last_hidden_state.reshape(batch_size // sequence_length, sequence_length * self.vit_model.config.hidden_size)
        mano_embeddings = last_hidden_state.reshape(batch_size, self.vit_model.config.hidden_size)
        # Pass through the classifier
        logits = self.classifier(concatenated_embeddings)
        mano_poses = self.mano_pose_predictor(mano_embeddings)

        return logits, mano_poses, sum_hand, sum_obj#mean_hand,mean_obj#, mean_free



In [24]:


feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224', size=224)
#feature_extractor = ViTImageProcessor.from_pretrained('facebook/dino-vitb16',size =496)

#vit_model = torch.load(h2o_root + 'model.pth')
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224', add_pooling_layer = False)
#
#vit_model = ViTModel.from_pretrained('facebook/dino-vitb16', add_pooling_layer=False)


# Define the action prediction model
num_classes = 37 # Change this to the number of action classes in your dataset
mano_pose_dim = 48
sequence_length = 8
model = Cars_Mano_Action(vit_model, num_classes,mano_pose_dim, sequence_length)
#if os.path.exists(h2o_root + "best_both_heat_action_prediction_model.pth"):
 #   print('loading')
  #  model.load_state_dict(torch.load(h2o_root + "best_both_heat_action_prediction_model.pth"))




In [19]:
train_dataset = TrainData()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True,num_workers=4,
    worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2**32 - 1)),
    generator=torch.Generator().manual_seed(seed))
val_dataset = ValData()
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)

# Define the optimizer and loss function





In [None]:
mano_batch = 8
n_comps =48
mano_layer_left = ManoLayer(mano_root='manopth/mano/models/',use_pca=False, flat_hand_mean=True, ncomps=48, side='left')


def compute_mano_loss(pred_mano, target_mano, shape):

    mano_keypoints_3d_pred = mano_layer_left(pred_mano,shape)
    mano_keypoints_3d_target = mano_layer_left(target_mano,shape)
    print('pred_shape',mano_keypoints_3d_pred.shape)
    print('targetshape',mano_keypoints_3d_target.shape)

    pred_joints = mano_keypoints_3d_pred[1]
    target_joints = mano_keypoints_3d_target[1]
    return F.mse_loss(pred_joints, target_joints)

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Cars_Mano_Action(
  (vit_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768,

In [26]:
wandb.init(project="3dvision", entity="debaumann")

VBox(children=(Label(value='0.012 MB of 0.012 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [27]:
from tqdm import tqdm

In [28]:

img_size = 224
# Define the optimizer and loss functionf
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

criterion = nn.CrossEntropyLoss()

image_save_dir =  os.path.join(h2o_root, 'evo_new_new_496/')
os.makedirs(image_save_dir, exist_ok=True)
# Initialize variables for best model saving
best_val_loss = float('inf')
save_model_path = h2o_root + 'models__new496/'
os.makedirs(save_model_path, exist_ok=True)

# Training and validation loop
num_epochs = 15
alpha = 1.0  # Initial weight for classification loss
beta = 30.0   # Initial weight for heatmap loss
gamma = 10.0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_class_loss =0.0
    running_hand_mse = 0.0
    running_obj_mse = 0.0
    train_loss_mano = 0.0

    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):

        optimizer.zero_grad()
        imgs,hand_heatmap, obj_heatmap,mano_pose,label = batch
        imgs =imgs.squeeze(0)
        hand_heatmap = hand_heatmap.squeeze(0)
        obj_heatmap = obj_heatmap.squeeze(0)
        mano_pose = mano_pose.squeeze(0)
        mano_pose = mano_pose.squeeze(1).to(device)
        pixel_values = feature_extractor(images=imgs, return_tensors="pt").pixel_values
        pixel_values =  pixel_values.to(device)#pixel_values.to(device)
        label = label.to(device)

        # Forward pass
        logits,mano_pose_pred, hand_attention, obj_attention = model(pixel_values)#,free_attention
        loss_class = criterion(logits, label)
        loss_manopose = F.mse_loss(mano_pose_pred,mano_pose)
        train_loss_mano += loss_manopose.item()
        running_class_loss += loss_class.item()

        hand_heatmap = hand_heatmap.squeeze(1).to(device)
        obj_heatmap = obj_heatmap.squeeze(1).to(device)

        #heatmap stuff
        obj_mse = F.mse_loss(obj_attention, obj_heatmap)
        hand_mse = F.mse_loss(hand_attention,hand_heatmap)
        running_hand_mse += obj_mse.item()
        running_obj_mse += hand_mse.item()

        loss =  alpha * loss_class + beta* hand_mse + gamma * obj_mse + loss_manopose * 100.0

        running_loss += loss.item()
        loss.backward()
        optimizer.step()



    avg_train_loss = running_loss / len(train_loader)
    avg_class_loss = running_class_loss / len(train_loader)
    avg_hand_mse = running_hand_mse / len(train_loader)
    avg_obj_mse = running_obj_mse / len(train_loader)
    avg_train_mano = train_loss_mano / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_train_loss:.4f} ,class: {avg_class_loss:.4f},hand:{avg_hand_mse:.4f},obj:{avg_obj_mse:.4f}, mano:{avg_train_mano:.4f}")

    # Log training loss to wandb
    wandb.log({"train_loss": avg_train_loss, "class_loss": avg_class_loss,"hand_mse":avg_hand_mse, "obj_mse": avg_obj_mse,"avg_train_mano": avg_train_mano, "epoch": epoch + 1})

    # Validation loop
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    val_loss = 0.0
    val_class_loss = 0.0
    val_hand_mse = 0.0
    val_loss_mano = 0.0
    val_obj_mse = 0.0


    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{num_epochs}")):
            imgs,hand_heatmap, obj_heatmap,mano_pose,label = batch
            imgs =imgs.squeeze(0)
            hand_heatmap = hand_heatmap.squeeze(0)
            obj_heatmap = obj_heatmap.squeeze(0)
            pixel_values = feature_extractor(images=imgs, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device)
            mano_pose = mano_pose.squeeze(0)
            mano_pose = mano_pose.squeeze(1).to(device)
            label = label.to(device)


            # Forward pass
            logits, mano_pose_val,hand_attention, obj_attention = model(pixel_values)#,free_attention
            loss_class = criterion(logits, label)
            val_class_loss += loss_class.item()
            loss_mano_val = F.mse_loss(mano_pose_pred,mano_pose)
            val_loss_mano += loss_mano_val.item()


            hand_heatmap = hand_heatmap.squeeze(1).to(device)
            obj_heatmap = obj_heatmap.squeeze(1).to(device)

            #heatmap stuff
            obj_mse = F.mse_loss(obj_attention, obj_heatmap)
            hand_mse = F.mse_loss(hand_attention,hand_heatmap)

            val_hand_mse += obj_mse.item()
            val_obj_mse += hand_mse.item()
            loss =  alpha * loss_class + beta* hand_mse + gamma * obj_mse + loss_mano_val * 100.0
            val_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(logits.data, 1)
            total_predictions += label.size(0)
            correct_predictions += (predicted == label).sum().item()

            # Calculate attention MSE
            #attentions = attentions.mean(dim=2).squeeze(1)  # Average over heads and remove singleton dimension


            hand_attention = hand_attention.unsqueeze(1)
            obj_attention = obj_attention.unsqueeze(1)
            #free_attention = free_attention.unsqueeze(1)



            # Log images, heatmaps, and attention masks to wandb
            if batch_idx in [0]:#,11,15,63,73
                save_batch_dir = os.path.join(image_save_dir, f'image_{batch_idx + 1}')
                os.makedirs(save_batch_dir, exist_ok=True)
                # Store images, heatmaps, and attentions for logging
                # Log first eight images
                fig, axes = plt.subplots(5, 8, figsize=(20, 10))
                for j in range(8):

                    img_np = pixel_values[j].permute(1, 2, 0).cpu().numpy()  # Transpose to (height, width, channels)
                    img_save = imgs[j].permute(1, 2, 0).cpu().numpy()
                    axes[0,j].imshow(img_np)
                    axes[0,j].set_title(f'Input Image {j+1}')
                    axes[0,j].axis('off')

                    heat_np = hand_heatmap[j].cpu().numpy()  # Transpose to (height, width, channels)
                    axes[1,j].imshow(heat_np, cmap='hot', interpolation='nearest')
                    axes[1,j].set_title(f'hand_heatmap {j+1}')
                    axes[1,j].axis('off')

                    obj_np = obj_heatmap[j].cpu().numpy()  # Attention map

                    axes[2, j].imshow(obj_np, cmap='hot', interpolation='nearest')
                    axes[2, j].set_title('obj heat Map')
                    axes[2, j].axis('off')

                    hand_attention_np = hand_attention[j].squeeze(0).cpu().numpy()  # Attention map

                    axes[3, j].imshow(hand_attention_np, cmap='hot', interpolation='nearest')
                    axes[3, j].set_title('hand attention')
                    axes[3, j].axis('off')

                    obj_attention_np = obj_attention[j].squeeze(0).cpu().numpy()  # Attention map

                    axes[4, j].imshow(obj_attention_np, cmap='hot', interpolation='nearest')
                    axes[4, j].set_title('obj attention')
                    axes[4, j].axis('off')
                    plt.imsave(os.path.join(save_batch_dir, f'{epoch+1}_original_{j}.png'), img_save / 255)
                    plt.imsave(os.path.join(save_batch_dir, f'{epoch+1}_gt_hand_{j}.png'), heat_np, cmap='hot')
                    plt.imsave(os.path.join(save_batch_dir, f'{epoch+1}_gt_obj_{j}.png'), obj_np, cmap='hot')
                    plt.imsave(os.path.join(save_batch_dir, f'{epoch+1}_hand_{j}.png'), hand_attention_np, cmap='hot')
                    plt.imsave(os.path.join(save_batch_dir, f'{epoch+1}_obj_{j}.png'), obj_attention_np, cmap='hot')
                    #free_attention_np = free_attention[j].squeeze(0).cpu().numpy()  # Attention map

                    #axes[5, j].imshow(free_attention_np, cmap='hot', interpolation='nearest')
                    #axes[5, j].set_title('free attention')
                    #axes[5, j].axis('off')


                action_prediction = predicted[0].item()
                ground_truth = label[0].item()
                fig.suptitle(f'Action Prediction: {action_prediction}, Ground Truth: {ground_truth}')
                wandb.log({"val_image_{}": wandb.Image(fig)})
                plt.close(fig)

    accuracy = correct_predictions / total_predictions
    avg_val_class_loss = val_class_loss / len(val_loader)
    avg_val_loss = val_loss / len(val_loader)
    avg_val_hand_mse = val_hand_mse / len(val_loader)
    avg_val_obj_mse = val_obj_mse / len(val_loader)
    avg_val_mano = val_loss_mano / len(val_loader)

    print(f"Validation - Epoch [{epoch + 1}/{num_epochs}], Accuracy: {accuracy:.4f}, Val Loss: {avg_val_loss:.4f},val_class_loss:{avg_val_class_loss:.4f},val_hand_mse:{avg_val_hand_mse:.4f} ,val_obj_mse:{avg_val_obj_mse:.4f}, val_mano:{avg_val_mano:.4f}")

    # Log validation metrics to wandb
    wandb.log({"val_accuracy": accuracy, "val_loss": avg_val_loss,"val_class_loss": avg_val_class_loss,"val_hand_mse": avg_val_hand_mse,"val_obj_mse": avg_obj_mse, "val_mano": avg_val_mano, "epoch": epoch + 1})

    # Save the best model based on validation loss
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_path = save_model_path + f"{epoch:02d}_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_{avg_val_loss:.2f}.pth"
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved best model with validation loss: {avg_val_loss:.4f}")

# Save the final trained model
final_model_path = h2o_root + "final_concatenated_action_prediction_model.pth"
torch.save(model.state_dict(), final_model_path)


Training Epoch 1/15: 100%|██████████| 569/569 [03:16<00:00,  2.90it/s]


Epoch [1/15], Loss: 19.5806 ,class: 3.7118,hand:0.0100,obj:0.0111, mano:0.1543


Validation Epoch 1/15: 100%|██████████| 122/122 [00:42<00:00,  2.85it/s]


Validation - Epoch [1/15], Accuracy: 0.1311, Val Loss: 18.7838,val_class_loss:3.4394,val_hand_mse:0.0087 ,val_obj_mse:0.0098, val_mano:0.1496
Saved best model with validation loss: 18.7838


Training Epoch 2/15: 100%|██████████| 569/569 [03:19<00:00,  2.85it/s]


Epoch [2/15], Loss: 14.8846 ,class: 3.4635,hand:0.0090,obj:0.0105, mano:0.1102


Validation Epoch 2/15: 100%|██████████| 122/122 [00:48<00:00,  2.50it/s]


Validation - Epoch [2/15], Accuracy: 0.1230, Val Loss: 27.0498,val_class_loss:3.2160,val_hand_mse:0.0091 ,val_obj_mse:0.0094, val_mano:0.2346


Training Epoch 3/15: 100%|██████████| 569/569 [03:31<00:00,  2.69it/s]


Epoch [3/15], Loss: 10.5758 ,class: 3.1774,hand:0.0082,obj:0.0100, mano:0.0702


Validation Epoch 3/15: 100%|██████████| 122/122 [00:51<00:00,  2.39it/s]


Validation - Epoch [3/15], Accuracy: 0.2623, Val Loss: 21.5172,val_class_loss:2.8791,val_hand_mse:0.0075 ,val_obj_mse:0.0089, val_mano:0.1829


Training Epoch 4/15: 100%|██████████| 569/569 [03:29<00:00,  2.72it/s]


Epoch [4/15], Loss: 8.2729 ,class: 2.7439,hand:0.0075,obj:0.0098, mano:0.0516


Validation Epoch 4/15: 100%|██████████| 122/122 [00:55<00:00,  2.22it/s]


Validation - Epoch [4/15], Accuracy: 0.3197, Val Loss: 32.2487,val_class_loss:2.2590,val_hand_mse:0.0078 ,val_obj_mse:0.0090, val_mano:0.2964


Training Epoch 5/15: 100%|██████████| 569/569 [03:21<00:00,  2.83it/s]


Epoch [5/15], Loss: 6.9909 ,class: 2.1623,hand:0.0071,obj:0.0096, mano:0.0447


Validation Epoch 5/15: 100%|██████████| 122/122 [00:50<00:00,  2.40it/s]


Validation - Epoch [5/15], Accuracy: 0.3852, Val Loss: 24.6330,val_class_loss:1.7186,val_hand_mse:0.0072 ,val_obj_mse:0.0087, val_mano:0.2258


Training Epoch 6/15: 100%|██████████| 569/569 [03:24<00:00,  2.78it/s]


Epoch [6/15], Loss: 5.6183 ,class: 1.6313,hand:0.0069,obj:0.0092, mano:0.0364


Validation Epoch 6/15: 100%|██████████| 122/122 [00:54<00:00,  2.26it/s]


Validation - Epoch [6/15], Accuracy: 0.5328, Val Loss: 40.6379,val_class_loss:1.3373,val_hand_mse:0.0067 ,val_obj_mse:0.0089, val_mano:0.3897


Training Epoch 7/15: 100%|██████████| 569/569 [03:28<00:00,  2.73it/s]


Epoch [7/15], Loss: 4.4207 ,class: 1.0655,hand:0.0066,obj:0.0090, mano:0.0302


Validation Epoch 7/15: 100%|██████████| 122/122 [00:52<00:00,  2.30it/s]


Validation - Epoch [7/15], Accuracy: 0.6885, Val Loss: 19.7377,val_class_loss:0.9456,val_hand_mse:0.0069 ,val_obj_mse:0.0085, val_mano:0.1847


Training Epoch 8/15: 100%|██████████| 569/569 [03:30<00:00,  2.70it/s]


Epoch [8/15], Loss: 3.9627 ,class: 0.6780,hand:0.0064,obj:0.0088, mano:0.0296


Validation Epoch 8/15: 100%|██████████| 122/122 [00:49<00:00,  2.47it/s]


Validation - Epoch [8/15], Accuracy: 0.6639, Val Loss: 25.9970,val_class_loss:0.8803,val_hand_mse:0.0069 ,val_obj_mse:0.0086, val_mano:0.2479


Training Epoch 9/15: 100%|██████████| 569/569 [03:28<00:00,  2.74it/s]


Epoch [9/15], Loss: 4.9079 ,class: 0.7417,hand:0.0072,obj:0.0094, mano:0.0381


Validation Epoch 9/15: 100%|██████████| 122/122 [00:48<00:00,  2.54it/s]


Validation - Epoch [9/15], Accuracy: 0.7705, Val Loss: 20.6520,val_class_loss:0.7141,val_hand_mse:0.0077 ,val_obj_mse:0.0089, val_mano:0.1959


Training Epoch 10/15:   2%|▏         | 11/569 [00:07<06:06,  1.52it/s]


KeyboardInterrupt: 

In [None]:
# Define the path to the model weights
weight_path = h2o_root + "best_both_heat_action_prediction_model.pth"

# Load the model weights
model.load_state_dict(torch.load(weight_path))

# Move the model to the appropriate device
model.to(device)

In [29]:
save_dir = h2o_root + 'mano_val_out_good/'
os.makedirs(save_dir, exist_ok=True)
with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{num_epochs}")):
            imgs,hand_heatmap, obj_heatmap,mano_pose,label = batch
            imgs =imgs.squeeze(0)
            hand_heatmap = hand_heatmap.squeeze(0)
            obj_heatmap = obj_heatmap.squeeze(0)
            pixel_values = feature_extractor(images=imgs, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device)
            mano_pose = mano_pose.squeeze(0)
            mano_pose = mano_pose.squeeze(1).to(device)
            label = label.to(device)


            # Forward pass
            logits, mano_pose_val,hand_attention, obj_attention = model(pixel_values)#,free_attention
            mano_pose_val = mano_pose_val.cpu().numpy()  # Convert to numpy
            file_path = os.path.join(save_dir, f"{batch_idx:03d}.npy")
            np.save(file_path, mano_pose_val)

Validation Epoch 10/15: 100%|██████████| 122/122 [00:48<00:00,  2.52it/s]


In [None]:
class TestData(torch.utils.data.DataLoader):
    def __init__(self):
        self.img_path = h2o_root + "framesequences_8_test/"
        self.num_actions = int(len(os.listdir(self.img_path)))

    def __len__(self):
        return self.num_actions

    def __getitem__(self, idx):
        img = np.load(self.img_path + format(idx + 1) + ".npy")
        img = np.moveaxis(img, -1, 0)
        img = torch.from_numpy(img).float()

        return img

In [None]:
test_dataset = TestData()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
save_model_path = h2o_root + 'models__new496/'

In [None]:
model_files = [os.path.join(save_model_path, f) for f in os.listdir(save_model_path) if f.endswith('.pth')]
models = []
print(sorted(model_files))
for model_file in sorted(model_files):
    model = InterleaveHeatmapViTActionPredictionModel(vit_model, num_classes, sequence_length)
    model.load_state_dict(torch.load(model_file))
    models.append(model)

['/content/drive/My Drive/vit_3d/models__new496/00_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_3.81.pth', '/content/drive/My Drive/vit_3d/models__new496/01_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_2.70.pth', '/content/drive/My Drive/vit_3d/models__new496/02_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_1.42.pth', '/content/drive/My Drive/vit_3d/models__new496/03_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_0.90.pth', '/content/drive/My Drive/vit_3d/models__new496/04_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_0.51.pth', '/content/drive/My Drive/vit_3d/models__new496/06_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_0.35.pth', '/content/drive/My Drive/vit_3d/models__new496/07_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_0.28.pth', '/content/drive/My Drive/vit_3d/models__new496/14_mean_dual_10_2_attention_496_1_30_10_action_prediction_model_0.25.pth']


In [None]:
import json
json_base = save_model_path + 'action_results/'

os.makedirs(json_base, exist_ok=True)


In [None]:
for i,m in enumerate(models):
    m.to(device)
    m.eval()
    epoch =0
    num_epochs = 1
    img_size =224
    predictions ={}
    predictions["modality"] = "training: rgb + heatmaps, test: rgb"
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader, desc=f"testdation Epoch {epoch + 1}/{num_epochs}")):
            img = batch
            images = img.permute(2, 1, 3, 4, 0).squeeze(-1) # Reshape to (batch_size * sequence_length, 3, H, W)
            # Define the crop size and preprocess images
            crop_size = 360
            start = (images.shape[-1]) // 2
            cropped_images = images[:, :, :, int(start-crop_size/2):int(start+crop_size/2)]
            cropped_images = nn.functional.interpolate(cropped_images, size=(img_size, img_size), mode='nearest')
            pixel_values = feature_extractor(images=cropped_images, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device)

            # Forward pass
            logits, hand,obj = m(pixel_values)#,obj

            # Calculate accuracy
            _, predicted = torch.max(logits.data, 1)

            # Calculate attention MSE
            #attentions = attentions.mean(dim=2).squeeze(1)  # Average over heads and remove singleton dimension
            # Store images, heatmaps, and attentions for logging
          # Log first eight images

            predictions[f'{batch_idx + 1}'] = predicted[0].item()

            action_prediction = predicted[0].item()


    with open(json_base + f'action_labels{i}.json', 'w') as json_file:
        json.dump(predictions, json_file)
    # Log validation metrics to wandb

testdation Epoch 1/1: 100%|██████████| 242/242 [01:40<00:00,  2.40it/s]
testdation Epoch 1/1: 100%|██████████| 242/242 [01:40<00:00,  2.40it/s]
testdation Epoch 1/1: 100%|██████████| 242/242 [01:40<00:00,  2.40it/s]
testdation Epoch 1/1: 100%|██████████| 242/242 [01:40<00:00,  2.40it/s]
testdation Epoch 1/1: 100%|██████████| 242/242 [01:40<00:00,  2.40it/s]
testdation Epoch 1/1: 100%|██████████| 242/242 [01:40<00:00,  2.40it/s]
testdation Epoch 1/1: 100%|██████████| 242/242 [01:40<00:00,  2.40it/s]
testdation Epoch 1/1: 100%|██████████| 242/242 [01:40<00:00,  2.40it/s]


In [None]:
json_files = [os.path.join(json_base, f) for f in os.listdir(json_base) if f.endswith('.json')]
with open(h2o_root + 'action_labels_gt.json', 'r') as file:
    gt_data = json.load(file)

for i,j in enumerate(json_files):
    with open(j, 'r') as file:
        data = json.load(file)
    data.pop('modality', None)
    score = 0
    for key in data:
        if key in gt_data and data[key] == gt_data[key]:

            score += 1

    prediction_score = score / len(data) * 100
    print(f"Prediction Score: {prediction_score}%")

Prediction Score: 87.60330578512396%
Prediction Score: 87.19008264462809%
Prediction Score: 44.62809917355372%
Prediction Score: 59.09090909090909%
Prediction Score: 69.42148760330579%
Prediction Score: 76.03305785123968%
Prediction Score: 85.53719008264463%
Prediction Score: 87.19008264462809%
Prediction Score: 87.60330578512396%
Prediction Score: 88.01652892561982%
