In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from movinets import MoViNet
from movinets.config import _C

import cv2
import numpy as np
from tqdm.notebook import tqdm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")



Using device: cuda


In [34]:
import os

def prepare_dataset_paths_and_labels_recursive(root_dir, class_map):
    video_folders = []
    labels = []
    for class_name, label_id in class_map.items():
        class_path = os.path.join(root_dir, class_name)
        if not os.path.exists(class_path):
            continue
        # Walk recursively starting from class_path
        for subdir, dirs, files in os.walk(class_path):
            # If this folder contains image frames, treat it as a video folder
            if any(f.lower().endswith(('.jpg', '.png')) for f in files):
                video_folders.append(subdir)
                labels.append(label_id)
    return video_folders, labels



In [35]:
binary_class_map = {
    'Tiger_Normal': 0,
    'Tiger_Abnormal': 1,
}

multi_class_map = {
    'Dehydration or Heat Stroke': 0,
    'Digestive Issues': 1,
    'Eye Injury': 2,
    'Injured_Tiger': 3,
    'Lethargy, Apathy, Unresponsive, and Listless Tiger': 4,
    'Neurological Issues': 5,
    'Nutritional_Deficiencies': 6,
    'Oral or Dental Issues or Respiratory distress': 7,
    'Skin Desease or irritation_Tiger': 8,
    'Sress_Frustation': 9,
    'Tremors or Seizures': 10,
    'underweightness or emaciation': 11,
    'Weakness': 12,
    'Zoochosis_stereotypic behavior': 13,
    'Zoonotic Disease Behavior': 14

}

multi_class_id_to_name = {v: k for k, v in multi_class_map.items()}



In [36]:
train_root = 'frames_dataset/train'
test_root = 'frames_dataset/test'

train_video_folders_bin, train_labels_bin = prepare_dataset_paths_and_labels_recursive(train_root, binary_class_map)
test_video_folders_bin, test_labels_bin = prepare_dataset_paths_and_labels_recursive(test_root, binary_class_map)

train_abnormal_root = os.path.join(train_root, 'Tiger_Abnormal')
test_abnormal_root = os.path.join(test_root, 'Tiger_Abnormal')
train_video_folders_multi, train_labels_multi = prepare_dataset_paths_and_labels_recursive(train_abnormal_root, multi_class_map)
test_video_folders_multi, test_labels_multi = prepare_dataset_paths_and_labels_recursive(test_abnormal_root, multi_class_map)



In [37]:
def estimate_3d_pose(frames):
    """
    Dummy example function: in reality, load and run your 3D pose estimator model here.
    Input: frames as list of RGB numpy arrays or tensor of shape [clip_len, C, H, W]
    Output: 3D keypoints tensor of shape [clip_len, num_keypoints * 3] (x, y, z coordinates)
    """
    clip_len = len(frames)                                                                              
    
    num_keypoints = 34
    # Dummy random 3D keypoints to illustrate; replace with your model inference
    pose_3d = torch.rand(clip_len, num_keypoints * 3)
    return pose_3d

In [38]:
import torch
from torch.utils.data import Dataset
import cv2
import numpy as np

class TigerBehaviorDataset(Dataset):
    def __init__(self, video_folders, labels, pose_folder_root=None,
                 clip_len=8, frame_size=(224, 224), transform=None,
                 use_3d_pose=False):
        self.video_folders = video_folders
        self.labels = labels
        self.pose_folder_root = pose_folder_root
        self.clip_len = clip_len
        self.frame_size = frame_size
        self.transform = transform
        self.use_3d_pose = use_3d_pose

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        video_folder = self.video_folders[idx]
        label = self.labels[idx]

        frame_files = sorted([
            os.path.join(video_folder, f)
            for f in os.listdir(video_folder)
            if f.lower().endswith(('.jpg','.png'))
        ])
        if len(frame_files) == 0:
            raise RuntimeError(f"No frames found in folder {video_folder}")

        if len(frame_files) < self.clip_len:
            frame_files += [frame_files[-1]] * (self.clip_len - len(frame_files))
        else:
            frame_files = frame_files[:self.clip_len]

        frames = []
        for fpath in frame_files:
            img = cv2.imread(fpath)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, self.frame_size)
            if self.transform:
                img = self.transform(img)
            else:
                img = torch.FloatTensor(img / 255.0).permute(2, 0, 1)
            frames.append(img)
        frames_tensor = torch.stack(frames, dim=1)  # Shape: (C, T, H, W)

        if self.use_3d_pose:
            # Estimate 3D pose dynamically
            # Convert frames to list or tensor shape suitable for your estimator
            frames_for_pose = [frame.permute(1, 2, 0).numpy() for frame in frames]  # list of HWC images
            pose_3d_tensor = estimate_3d_pose(frames_for_pose)  # shape: [clip_len, num_keypoints*3]
        elif self.pose_folder_root:
            # Load precomputed 2D/3D poses from file
            pose_filename = os.path.basename(video_folder) + '.npy'
            pose_path = os.path.join(self.pose_folder_root, pose_filename)
            pose_seq = np.load(pose_path)
            pose_3d_tensor = torch.from_numpy(pose_seq).float()
        else:
            # If no poses available, provide zeros (adjust input dimension accordingly)
            pose_3d_tensor = torch.zeros(self.clip_len, 34 * 3)  # assuming 3D with 34 keypoints

        return frames_tensor, pose_3d_tensor, label

    


In [39]:
from torch.utils.data import DataLoader

# Binary datasets and loaders
train_dataset_bin = TigerBehaviorDataset(train_video_folders_bin, train_labels_bin, pose_folder_root=None, clip_len=8)
test_dataset_bin = TigerBehaviorDataset(test_video_folders_bin, test_labels_bin, pose_folder_root=None, clip_len=8)

import multiprocessing as mp
train_loader_bin = DataLoader(
    train_dataset_bin,
    batch_size=32,
    pin_memory=True,
)

test_loader_bin = DataLoader(test_dataset_bin, batch_size=6, shuffle=False, num_workers=0)

# Multi-class datasets and loaders (only abnormal)
train_dataset_multi = TigerBehaviorDataset(train_video_folders_multi, train_labels_multi, pose_folder_root=None, clip_len=8)
test_dataset_multi = TigerBehaviorDataset(test_video_folders_multi, test_labels_multi, pose_folder_root=None, clip_len=8)
train_loader_multi = DataLoader(train_dataset_multi, batch_size=8, shuffle=True, num_workers=0)
test_loader_multi = DataLoader(test_dataset_multi, batch_size=8, shuffle=False, num_workers=0)



In [40]:
import torch.nn as nn
import torch

class DummyMoViNet(nn.Module):
    def __init__(self, feat_dim=600):
        super().__init__()
        self.feat_dim = feat_dim
    def forward(self, x):
        # Dummy output: random tensor for demo
        return torch.randn(x.size(0), self.feat_dim).to(x.device)

class PoseTransformer(nn.Module):
    def __init__(self, input_dim=34, num_classes=2):
        super().__init__()
        self.transformer_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=2)
        self.transformer = nn.TransformerEncoder(self.transformer_layer, num_layers=2)
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # x shape: (seq_len, batch, input_dim)
        x = self.transformer(x)
        x = x.mean(dim=0)
        return self.fc(x)

class FusionClassifier(nn.Module):
    def __init__(self, movinet_feat_dim=600, pose_feat_dim=34, num_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(movinet_feat_dim + pose_feat_dim, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, movinet_feat, pose_feat):
        x = torch.cat([movinet_feat, pose_feat], dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)



In [41]:

movinet_bin = DummyMoViNet()
pose_model_bin = PoseTransformer(input_dim=102, num_classes=2).to(device)
fusion_model_bin = FusionClassifier(movinet_feat_dim=600, pose_feat_dim=2, num_classes=2).to(device)

movinet_bin.to(device)
pose_model_bin.to(device)
fusion_model_bin.to(device)

# Multi-class models
movinet_multi = DummyMoViNet()
pose_model_multi = PoseTransformer(input_dim=102, num_classes=len(multi_class_map)).to(device)
fusion_model_multi = FusionClassifier(movinet_feat_dim=600, pose_feat_dim=2, num_classes=len(multi_class_map)).to(device)

movinet_multi.to(device)
pose_model_multi.to(device)
fusion_model_multi.to(device)



FusionClassifier(
  (fc1): Linear(in_features=602, out_features=512, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=15, bias=True)
)

In [42]:
def infer_video(frames_tensor, pose_tensor,
                movinet_bin, pose_model_bin, fusion_model_bin,
                movinet_multi, pose_model_multi, fusion_model_multi,
                multi_class_id_to_name, device):

    frames_tensor = frames_tensor.unsqueeze(0).to(device)  # Add batch dim
    pose_tensor = pose_tensor.unsqueeze(0).to(device)      # Add batch dim
    pose_input = pose_tensor.permute(1,0,2)  # (seq_len, batch, pose_dim)

    movinet_bin.eval()
    pose_model_bin.eval()
    fusion_model_bin.eval()

    movinet_multi.eval()
    pose_model_multi.eval()
    fusion_model_multi.eval()

    with torch.no_grad():
        # Binary prediction
        movinet_feats_bin = movinet_bin(frames_tensor)
        pose_feats_bin = pose_model_bin(pose_input)
        logits_bin = fusion_model_bin(movinet_feats_bin, pose_feats_bin)
        pred_bin = torch.argmax(logits_bin, dim=1).item()

        if pred_bin == 0:
            # Normal class
            return "Normal"
        else:
            # Abnormal class → run multi-class inference
            movinet_feats_multi = movinet_multi(frames_tensor)
            pose_feats_multi = pose_model_multi(pose_input)
            logits_multi = fusion_model_multi(movinet_feats_multi, pose_feats_multi)
            pred_multi = torch.argmax(logits_multi, dim=1).item()
            subclass_name = multi_class_id_to_name.get(pred_multi, "Unknown Abnormal Class")
            return f"Abnormal - Subclass: {subclass_name}"
        


In [43]:
# Example: load one sample from your dataset (binary or multi-class does not matter)
frames_tensor, pose_tensor, label = test_dataset_bin[0]

# Run inference
result = infer_video(frames_tensor, pose_tensor,
                     movinet_bin, pose_model_bin, fusion_model_bin,
                     movinet_multi, pose_model_multi, fusion_model_multi,
                     multi_class_id_to_name,
                     device)

print("Inference result:", result)



Inference result: Normal


In [44]:
# Assume frames_tensor and pose_tensor are already loaded and prepared inputs
movinet_bin.eval()
pose_model_bin.eval()
fusion_model_bin.eval()

with torch.no_grad():
    movinet_feats_bin = movinet_bin(frames_tensor.unsqueeze(0).to(device))  # batch dimension added
    pose_feats_bin = pose_model_bin(pose_tensor.unsqueeze(0).permute(1, 0, 2).to(device))
    


In [None]:
print("MoViNet feature shape:", movinet_feats_bin.shape)
print("Pose feature shape:", pose_feats_bin.shape)


MoViNet feature shape: torch.Size([1, 600])
Pose feature shape: torch.Size([1, 2])


In [46]:
for frames, poses, labels in train_loader_bin:
    print("frames shape:", frames.shape)       # [batch_size, 3, seq_len, 224, 224]
    print("poses shape before permute:", poses.shape)  # [batch_size, seq_len, pose_dim]
    poses = poses.permute(1, 0, 2)  # Permute if needed
    print("poses shape after permute:", poses.shape)  
    print("labels shape:", labels.shape)

    # Break early if you want to just check one batch
    break



frames shape: torch.Size([32, 3, 8, 224, 224])
poses shape before permute: torch.Size([32, 8, 102])
poses shape after permute: torch.Size([8, 32, 102])
labels shape: torch.Size([32])


In [47]:
import torch
import torch.optim as optim
from tqdm import tqdm
from torchmetrics.classification import Accuracy

pose_model_bin = pose_model_bin.to(device)
fusion_model_bin = fusion_model_bin.to(device)
movinet_bin = movinet_bin.to(device)

num_epochs = 30

optimizer = torch.optim.SGD(
    list(pose_model_bin.parameters()) +
    list(fusion_model_bin.parameters()) +
    list(movinet_bin.parameters()),
    lr=0.02
)

criterion = torch.nn.CrossEntropyLoss()

train_acc_metric = Accuracy(task="binary").to(device)
val_acc_metric = Accuracy(task="binary").to(device)

for epoch in range(num_epochs):
    pose_model_bin.train()
    fusion_model_bin.train()
    train_acc_metric.reset()
    running_train_loss = 0.0
    train_samples = 0

    with tqdm(train_loader_bin, unit="batch") as train_iter:
        train_iter.set_description(f"Epoch {epoch+1}/{num_epochs} [Train]")

        for frames, poses, labels in train_iter:
            frames = frames.to(device)
            labels = labels.to(device)
            poses = poses.permute(1, 0, 2).to(device)

            optimizer.zero_grad()

            with torch.no_grad():
                movinet_feats = movinet_bin(frames)

            pose_feats = pose_model_bin(poses)
            outputs = fusion_model_bin(movinet_feats, pose_feats)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            preds = torch.argmax(outputs, dim=1)
            train_acc_metric.update(preds, labels)

            batch_size = labels.size(0)
            running_train_loss += loss.item() * batch_size
            train_samples += batch_size

            train_acc = train_acc_metric.compute().item()
            train_iter.set_postfix(loss=f"{running_train_loss/train_samples:.4f}", accuracy=f"{train_acc:.4f}")

    epoch_train_loss = running_train_loss / train_samples
    epoch_train_acc = train_acc_metric.compute().item()

    # Validation loop needs val_loader_bin similarly defined and active, else skip

    print(
        f"Epoch {epoch+1}/{num_epochs} - "
        f"Train loss: {epoch_train_loss:.4f}, Train accuracy: {epoch_train_acc:.4f}"
    )
    


Epoch 1/30 [Train]: 100%|██████████| 3/3 [00:19<00:00,  6.52s/batch, accuracy=0.3830, loss=0.7162]


Epoch 1/30 - Train loss: 0.7162, Train accuracy: 0.3830


Epoch 2/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.37s/batch, accuracy=0.3191, loss=0.7262]


Epoch 2/30 - Train loss: 0.7262, Train accuracy: 0.3191


Epoch 3/30 [Train]: 100%|██████████| 3/3 [00:09<00:00,  3.32s/batch, accuracy=0.3830, loss=0.7156]


Epoch 3/30 - Train loss: 0.7156, Train accuracy: 0.3830


Epoch 4/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.58s/batch, accuracy=0.2447, loss=0.7378]


Epoch 4/30 - Train loss: 0.7378, Train accuracy: 0.2447


Epoch 5/30 [Train]: 100%|██████████| 3/3 [00:09<00:00,  3.31s/batch, accuracy=0.3830, loss=0.7208]


Epoch 5/30 - Train loss: 0.7208, Train accuracy: 0.3830


Epoch 6/30 [Train]: 100%|██████████| 3/3 [00:09<00:00,  3.31s/batch, accuracy=0.3085, loss=0.7237]


Epoch 6/30 - Train loss: 0.7237, Train accuracy: 0.3085


Epoch 7/30 [Train]: 100%|██████████| 3/3 [00:09<00:00,  3.30s/batch, accuracy=0.2447, loss=0.7323]


Epoch 7/30 - Train loss: 0.7323, Train accuracy: 0.2447


Epoch 8/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.37s/batch, accuracy=0.3191, loss=0.7254]


Epoch 8/30 - Train loss: 0.7254, Train accuracy: 0.3191


Epoch 9/30 [Train]: 100%|██████████| 3/3 [00:09<00:00,  3.16s/batch, accuracy=0.2766, loss=0.7243]


Epoch 9/30 - Train loss: 0.7243, Train accuracy: 0.2766


Epoch 10/30 [Train]: 100%|██████████| 3/3 [00:09<00:00,  3.18s/batch, accuracy=0.2766, loss=0.7258]


Epoch 10/30 - Train loss: 0.7258, Train accuracy: 0.2766


Epoch 11/30 [Train]: 100%|██████████| 3/3 [00:09<00:00,  3.24s/batch, accuracy=0.2766, loss=0.7295]


Epoch 11/30 - Train loss: 0.7295, Train accuracy: 0.2766


Epoch 12/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.39s/batch, accuracy=0.3617, loss=0.7194]


Epoch 12/30 - Train loss: 0.7194, Train accuracy: 0.3617


Epoch 13/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.40s/batch, accuracy=0.2979, loss=0.7267]


Epoch 13/30 - Train loss: 0.7267, Train accuracy: 0.2979


Epoch 14/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.39s/batch, accuracy=0.3617, loss=0.7205]


Epoch 14/30 - Train loss: 0.7205, Train accuracy: 0.3617


Epoch 15/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.43s/batch, accuracy=0.3298, loss=0.7170]


Epoch 15/30 - Train loss: 0.7170, Train accuracy: 0.3298


Epoch 16/30 [Train]: 100%|██████████| 3/3 [00:09<00:00,  3.30s/batch, accuracy=0.2872, loss=0.7237]


Epoch 16/30 - Train loss: 0.7237, Train accuracy: 0.2872


Epoch 17/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.42s/batch, accuracy=0.3298, loss=0.7198]


Epoch 17/30 - Train loss: 0.7198, Train accuracy: 0.3298


Epoch 18/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.41s/batch, accuracy=0.2979, loss=0.7191]


Epoch 18/30 - Train loss: 0.7191, Train accuracy: 0.2979


Epoch 19/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.60s/batch, accuracy=0.4362, loss=0.7100]


Epoch 19/30 - Train loss: 0.7100, Train accuracy: 0.4362


Epoch 20/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.58s/batch, accuracy=0.3298, loss=0.7231]


Epoch 20/30 - Train loss: 0.7231, Train accuracy: 0.3298


Epoch 21/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.51s/batch, accuracy=0.3191, loss=0.7166]


Epoch 21/30 - Train loss: 0.7166, Train accuracy: 0.3191


Epoch 22/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.51s/batch, accuracy=0.2766, loss=0.7261]


Epoch 22/30 - Train loss: 0.7261, Train accuracy: 0.2766


Epoch 23/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.39s/batch, accuracy=0.3511, loss=0.7232]


Epoch 23/30 - Train loss: 0.7232, Train accuracy: 0.3511


Epoch 24/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.39s/batch, accuracy=0.2553, loss=0.7265]


Epoch 24/30 - Train loss: 0.7265, Train accuracy: 0.2553


Epoch 25/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.38s/batch, accuracy=0.3617, loss=0.7162]


Epoch 25/30 - Train loss: 0.7162, Train accuracy: 0.3617


Epoch 26/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.38s/batch, accuracy=0.3511, loss=0.7292]


Epoch 26/30 - Train loss: 0.7292, Train accuracy: 0.3511


Epoch 27/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.42s/batch, accuracy=0.3936, loss=0.7158]


Epoch 27/30 - Train loss: 0.7158, Train accuracy: 0.3936


Epoch 28/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.41s/batch, accuracy=0.3830, loss=0.7168]


Epoch 28/30 - Train loss: 0.7168, Train accuracy: 0.3830


Epoch 29/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.41s/batch, accuracy=0.3830, loss=0.7108]


Epoch 29/30 - Train loss: 0.7108, Train accuracy: 0.3830


Epoch 30/30 [Train]: 100%|██████████| 3/3 [00:10<00:00,  3.66s/batch, accuracy=0.3298, loss=0.7168]

Epoch 30/30 - Train loss: 0.7168, Train accuracy: 0.3298



