In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch import amp 
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import confusion_matrix
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
sys.path.append(r"C:\Users\jashw\Desktop\Video Surveillance")
from datasets.custom_dataset import get_data_loader  # Your custom dataset loader
from TimeSformer.timesformer.models.vit import TimeSformer  # Import TimeSformer class


# -------------------
# Hyperparameters and Paths
# -------------------
_PROJECT_ROOT = r"C:\Users\jashw\Desktop\Video Surveillance"

# Pretrained Kinetics-600 TimeSformer checkpoint and the train/val clip folders.
MODEL_PATH = _PROJECT_ROOT + r"\data\trained_models\TimeSformer_divST_96x4_224_K600.pyth"
TRAIN_DIR = _PROJECT_ROOT + r"\data\train"
VAL_DIR = _PROJECT_ROOT + r"\data\val"

# Optimisation settings.
BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE = 6, 5, 1e-4
NUM_CLASSES = 8  # size of the replacement classifier head


# -------------------
# Function to load the pretrained TimeSformer model
# -------------------
def _drop_classifier_head(state_dict):
    """Return a copy of *state_dict* without classifier-head entries.

    A key counts as a head parameter when one of its dot-separated components
    is exactly ``"head"``. This catches both a bare ``head.*`` prefix and the
    nested ``model.head.*`` path (the wrapper exposes its head as
    ``model.model.head``), which a plain ``startswith("head")`` check misses.
    """
    return {k: v for k, v in state_dict.items() if "head" not in k.split(".")}


def load_model(pretrained_path, device):
    """Build a 600-class TimeSformer and load backbone weights from a checkpoint.

    Args:
        pretrained_path: Path to the checkpoint file. If the loaded object has a
            ``'model_state'`` entry that entry is used as the state dict,
            otherwise the loaded object itself is.
        device: Used as ``map_location`` for ``torch.load`` and for ``model.to``.

    Returns:
        The model on *device*. Classifier-head weights from the checkpoint are
        intentionally skipped; the head is expected to be replaced afterwards.
    """
    print(f"Loading pre-trained model from: {pretrained_path}")

    # num_classes=600 mirrors the checkpoint so construction matches its layout.
    model = TimeSformer(
        img_size=224,
        num_classes=600,
        num_frames=96,
        attention_type='divided_space_time',
        pretrained_model=pretrained_path,  # This prevents external download if file exists
    )

    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(pretrained_path, map_location=device)
    model_state = checkpoint.get('model_state', checkpoint)
    model_state = _drop_classifier_head(model_state)
    result = model.load_state_dict(model_state, strict=False)

    # strict=False would otherwise hide mismatches; report anything beyond the
    # head keys we dropped on purpose.
    missing = [k for k in result.missing_keys if "head" not in k.split(".")]
    if missing or result.unexpected_keys:
        print(
            f"Warning: {len(missing)} missing and {len(result.unexpected_keys)} "
            "unexpected keys while loading the checkpoint."
        )

    model.to(device)
    print("Model loaded successfully!")
    return model


# -------------------
# Function to modify model for partial fine-tuning
# -------------------
def modify_for_partial_training(model, device, num_classes=8, freeze_layers=True):
    """Replace the classifier head and optionally freeze every other parameter.

    Args:
        model: Wrapper whose classifier is ``model.model.head`` (an ``nn.Linear``;
            its ``in_features`` sizes the replacement).
        device: Device the new head is moved to.
        num_classes: Output size of the new head.
        freeze_layers: When True, ``requires_grad`` is set to False on every
            parameter that does not belong to the new head.

    Returns:
        The same *model* object, modified in place.
    """
    in_features = model.model.head.in_features
    model.model.head = nn.Linear(in_features, num_classes).to(device)

    if freeze_layers:
        # Identify head parameters by identity rather than by name matching.
        head_param_ids = {id(p) for p in model.model.head.parameters()}
        for param in model.parameters():
            if id(param) not in head_param_ids:
                param.requires_grad = False
        print("Model modified for partial fine-tuning (only classifier head will be trained).")
    else:
        print("Classifier head replaced; all layers remain trainable.")
    return model


# -------------------
# Training and Validation Functions
# -------------------
from torch.optim.lr_scheduler import ReduceLROnPlateau

def _run_train_epoch(model, train_loader, criterion, optimizer, scaler, device, use_amp, log_shape=False):
    """Run one optimisation pass over *train_loader*; return the per-sample mean loss.

    Mixed precision is applied only when *use_amp* is True; otherwise a disabled
    CPU autocast context is used, so nothing is cast. Returns 0.0 when the
    loader yields no samples.
    """
    model.train()
    amp_device = "cuda" if use_amp else "cpu"
    running_loss = 0.0
    seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        if log_shape:
            # One-off sanity check instead of printing on every batch.
            print("Input shape to model:", inputs.shape)
            log_shape = False
        optimizer.zero_grad()
        with amp.autocast(amp_device, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_size = labels.size(0)
        # Weight by batch size so a short final batch does not skew the mean.
        running_loss += loss.item() * batch_size
        seen += batch_size
    return running_loss / seen if seen else 0.0


def train_partial(
    model, train_loader, val_loader, device,
    epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    target_accuracy=95.0,
    checkpoint_dir="checkpoints",
    log_dir="runs/TimeSformer",
    class_list=None,
):
    """Train only ``model.model.head``, checkpointing on best validation accuracy.

    Args:
        model: Model whose classifier is ``model.model.head``; only that head's
            parameters are given to the optimizer.
        train_loader, val_loader: Iterables of ``(inputs, labels)`` batches.
        device: Device batches are moved to; AMP is enabled only for CUDA.
        epochs: Maximum number of epochs.
        lr: Initial Adam learning rate (halved by ReduceLROnPlateau on val loss).
        target_accuracy: Stop early once validation accuracy (%) reaches this.
        checkpoint_dir: Directory for ``best_model_epoch*_acc*.pth`` files.
        log_dir: TensorBoard log directory.
        class_list: Class names forwarded to ``validate`` for confusion-matrix
            and report logging; ``None`` skips that logging.

    Returns:
        Dict with per-epoch ``"train_loss"``, ``"val_loss"`` and ``"val_acc"`` lists.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    use_amp = torch.device(device).type == "cuda"
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.model.head.parameters(), lr=lr)
    # LR reductions are printed below rather than via the deprecated `verbose` flag.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    scaler = amp.GradScaler("cuda", enabled=use_amp)

    history = {"train_loss": [], "val_loss": [], "val_acc": []}
    best_accuracy = 0.0

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(epochs):
            avg_loss = _run_train_epoch(
                model, train_loader, criterion, optimizer, scaler, device, use_amp,
                log_shape=(epoch == 0),
            )
            val_loss, val_accuracy = validate(
                model, val_loader, device,
                writer=writer,
                epoch=epoch,
                class_names=class_list,
            )

            history["train_loss"].append(avg_loss)
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_accuracy)
            print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

            # 🟢 TensorBoard Logging
            writer.add_scalar("Loss/train", avg_loss, epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Accuracy/val", val_accuracy, epoch)
            writer.add_scalar("LearningRate", optimizer.param_groups[0]['lr'], epoch)

            # 🧠 Step the LR scheduler
            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                print(f"Reducing learning rate to {current_lr:.2e}")

            # 🏁 Save if better
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_model_path = os.path.join(checkpoint_dir, f"best_model_epoch{epoch+1}_acc{val_accuracy:.2f}.pth")
                torch.save(model.state_dict(), best_model_path)
                print(f"✅ Saved better model at {best_model_path}")

            if val_accuracy >= target_accuracy:
                print(f"🎯 Target accuracy of {target_accuracy}% reached. Stopping training early.")
                break
    finally:
        # Close even on KeyboardInterrupt so buffered TensorBoard events are flushed.
        writer.close()

    print("Training complete!")
    return history


def _row_normalize(cm):
    """Return *cm* as floats with each row divided by its row sum.

    Rows summing to zero (a class with no true samples) stay all-zero instead
    of turning into NaN.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


def plot_confusion_matrix(cm, class_names, normalize=False, title="Confusion Matrix"):
    """Draw *cm* as an annotated seaborn heatmap and return the figure.

    Args:
        cm: Square confusion matrix (rows = true labels, columns = predictions).
        class_names: Tick labels for both axes, in label-id order.
        normalize: If True, show per-row fractions instead of raw counts.
        title: Axes title.

    Returns:
        The matplotlib ``Figure``. The caller owns it (e.g.
        ``SummaryWriter.add_figure`` closes it by default).
    """
    if normalize:
        cm = _row_normalize(cm)
        fmt = ".2f"
    else:
        fmt = "d"

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        square=True,
        linewidths=0.5,
        cbar=True,
        ax=ax,  # draw on the axes we created, not whatever pyplot considers current
    )

    ax.set_xlabel("Predicted Labels", fontsize=12)
    ax.set_ylabel("True Labels", fontsize=12)
    ax.set_title(title, fontsize=14)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    return fig
def validate(model, val_loader, device, writer=None, epoch=None, class_names=None):
    """Evaluate *model* on *val_loader*, optionally logging diagnostics to TensorBoard.

    Args:
        model: Classifier producing logits; predictions are ``argmax`` over dim 1.
        val_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device batches are moved to.
        writer: Optional ``SummaryWriter``. When given together with *epoch* and
            *class_names*, raw/normalized confusion matrices and a
            classification report are logged.
        epoch: Global step for TensorBoard logging.
        class_names: Class names indexed by label id.

    Returns:
        ``(avg_val_loss, accuracy)``: per-sample mean cross-entropy and accuracy
        in percent. Both are 0.0 when the loader yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total = 0
    correct = 0
    running_val_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            running_val_loss += criterion(outputs, labels).item() * batch_size
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            total += batch_size
            correct += (predicted == labels).sum().item()

    avg_val_loss = running_val_loss / total if total > 0 else 0.0
    accuracy = 100 * correct / total if total > 0 else 0.0
    print(f"\n✅ Validation Loss: {avg_val_loss:.4f} | Accuracy: {accuracy:.2f}%")

    # ➕ Confusion Matrix
    if total > 0 and writer is not None and epoch is not None and class_names is not None:
        # Pin the label set to every known class. Otherwise sklearn only sees the
        # classes present this epoch, and target_names no longer lines up.
        label_ids = list(range(len(class_names)))
        cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

        # Raw Confusion Matrix
        raw_fig = plot_confusion_matrix(cm, class_names, normalize=False, title="Confusion Matrix (Raw Counts)")
        writer.add_figure("ConfusionMatrix/Raw", raw_fig, global_step=epoch)

        # Normalized Confusion Matrix
        norm_fig = plot_confusion_matrix(cm, class_names, normalize=True, title="Confusion Matrix (Normalized)")
        writer.add_figure("ConfusionMatrix/Normalized", norm_fig, global_step=epoch)

        # ➕ Classification Report
        report_str = classification_report(
            all_labels, all_preds,
            labels=label_ids,
            target_names=class_names,
            zero_division=0,
        )
        print(f"\n📊 Classification Report (Epoch {epoch}):\n{report_str}")

        # Log the text report as TensorBoard text
        writer.add_text("ClassificationReport", f"```\n{report_str}\n```", global_step=epoch)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return avg_val_loss, accuracy

# -------------------
# Main Function
# -------------------
def main():
    """Fine-tune the classifier head of the Kinetics-600 TimeSformer and save it.

    Loads the pretrained checkpoint, swaps in a NUM_CLASSES head with the
    backbone frozen, trains via ``train_partial`` and saves the final-epoch
    weights. The best-accuracy checkpoint is written separately by
    ``train_partial`` under its checkpoint directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the pretrained model (with 600 classes) from the checkpoint
    model = load_model(MODEL_PATH, device)
    print(model)
    # Swap in a NUM_CLASSES head and freeze the backbone for partial fine-tuning
    model = modify_for_partial_training(model, device, num_classes=NUM_CLASSES, freeze_layers=True)
    print(model)

    train_loader = get_data_loader(TRAIN_DIR, batch_size=BATCH_SIZE, clip_len=96, shuffle=True)
    val_loader = get_data_loader(VAL_DIR, batch_size=BATCH_SIZE, clip_len=96, shuffle=False)

    # TODO: consider class weights in CrossEntropyLoss if the data is imbalanced.
    # NOTE(review): index i is presumably label id i from custom_dataset -- confirm.
    class_list = ['Abuse', 'Explosion', 'Fighting', 'RoadAccidents', 'Robbery', 'Shooting', 'Vandalism', 'z_Normal_Videos_event']

    # Train the model partially (only the classifier head will be updated)
    train_partial(
        model,
        train_loader,
        val_loader,
        device,
        epochs=NUM_EPOCHS,
        lr=LEARNING_RATE,
        target_accuracy=96.0,
        checkpoint_dir="data/trained_models/checkpoints",
        log_dir="runs/TimeSformer",
        class_list=class_list,
    )

    # os.path.join keeps the path valid off Windows (hard-coded backslashes would
    # otherwise become part of a single file name).
    save_path = os.path.join("data", "trained_models", "timesformer_partial_finetuned.pth")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Partially fine-tuned model saved to {save_path}")


if __name__ == "__main__":
    main()



In [None]:
import os
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch import amp 
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_fscore_support
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Append project root to sys.path if needed
sys.path.append(r"C:\Users\jashw\Desktop\Video Surveillance")
from datasets.custom_dataset import get_data_loader  # Your custom dataset loader
from TimeSformer.timesformer.models.vit import TimeSformer  # Import TimeSformer class

# -------------------
# Hyperparameters and Paths
# -------------------
_PROJECT_ROOT = r"C:\Users\jashw\Desktop\Video Surveillance"

# Pretrained Kinetics-600 TimeSformer checkpoint and the train/val clip folders.
MODEL_PATH = _PROJECT_ROOT + r"\data\trained_models\TimeSformer_divST_96x4_224_K600.pyth"
TRAIN_DIR = _PROJECT_ROOT + r"\data\train"
VAL_DIR = _PROJECT_ROOT + r"\data\val"

# Optimisation settings (one clip per batch).
BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE = 1, 50, 1e-4
NUM_CLASSES = 8  # size of the replacement head; expected to equal len(CLASS_NAMES)

# Class names indexed by label id -- the order must match the dataset's labels.
CLASS_NAMES = [
    'Abuse',
    'Explosion',
    'Fighting',
    'RoadAccidents',
    'Robbery',
    'Shooting',
    'Vandalism',
    'z_Normal_Videos_event',
]

# -------------------
# Function to load the pretrained TimeSformer model
# -------------------
def _drop_classifier_head(state_dict):
    """Return a copy of *state_dict* without classifier-head entries.

    A key counts as a head parameter when one of its dot-separated components
    is exactly ``"head"``. This catches both a bare ``head.*`` prefix and the
    nested ``model.head.*`` path (the wrapper exposes its head as
    ``model.model.head``), which a plain ``startswith("head")`` check misses.
    """
    return {k: v for k, v in state_dict.items() if "head" not in k.split(".")}


def load_model(pretrained_path, device):
    """Build a 600-class TimeSformer and load backbone weights from a checkpoint.

    Args:
        pretrained_path: Path to the checkpoint file. If the loaded object has a
            ``'model_state'`` entry that entry is used as the state dict,
            otherwise the loaded object itself is.
        device: Used as ``map_location`` for ``torch.load`` and for ``model.to``.

    Returns:
        The model on *device*. Classifier-head weights from the checkpoint are
        intentionally skipped; the head is expected to be replaced afterwards.
    """
    print(f"Loading pre-trained model from: {pretrained_path}")

    # num_classes=600 mirrors the checkpoint so construction matches its layout.
    model = TimeSformer(
        img_size=224,
        num_classes=600,
        num_frames=96,
        attention_type='divided_space_time',
        pretrained_model=pretrained_path,  # Prevents external download if file exists
    )

    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(pretrained_path, map_location=device)
    model_state = checkpoint.get('model_state', checkpoint)
    model_state = _drop_classifier_head(model_state)
    result = model.load_state_dict(model_state, strict=False)

    # strict=False would otherwise hide mismatches; report anything beyond the
    # head keys we dropped on purpose.
    missing = [k for k in result.missing_keys if "head" not in k.split(".")]
    if missing or result.unexpected_keys:
        print(
            f"Warning: {len(missing)} missing and {len(result.unexpected_keys)} "
            "unexpected keys while loading the checkpoint."
        )

    model.to(device)
    print("Model loaded successfully!")
    return model

# -------------------
# Function to modify model for partial fine-tuning
# -------------------
def modify_for_partial_training(model, device, num_classes=8, freeze_layers=True):
    """Replace the classifier head and optionally freeze every other parameter.

    Args:
        model: Wrapper whose classifier is ``model.model.head`` (an ``nn.Linear``;
            its ``in_features`` sizes the replacement).
        device: Device the new head is moved to.
        num_classes: Output size of the new head.
        freeze_layers: When True, ``requires_grad`` is set to False on every
            parameter that does not belong to the new head.

    Returns:
        The same *model* object, modified in place.
    """
    in_features = model.model.head.in_features
    model.model.head = nn.Linear(in_features, num_classes).to(device)

    if freeze_layers:
        # Identify head parameters by identity rather than by name matching.
        head_param_ids = {id(p) for p in model.model.head.parameters()}
        for param in model.parameters():
            if id(param) not in head_param_ids:
                param.requires_grad = False
        print("Model modified for partial fine-tuning (only classifier head will be trained).")
    else:
        print("Classifier head replaced; all layers remain trainable.")
    return model

# -------------------
# Helper function to plot confusion matrix
# -------------------
def _row_normalize(cm):
    """Return *cm* as floats with each row divided by its row sum.

    Rows summing to zero (a class with no true samples) stay all-zero instead
    of turning into NaN.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


def plot_confusion_matrix(cm, class_names, normalize=False, title="Confusion Matrix"):
    """Draw *cm* as an annotated seaborn heatmap and return the figure.

    Args:
        cm: Square confusion matrix (rows = true labels, columns = predictions).
        class_names: Tick labels for both axes, in label-id order.
        normalize: If True, show per-row fractions instead of raw counts.
        title: Axes title.

    Returns:
        The matplotlib ``Figure``. The caller owns it (e.g.
        ``SummaryWriter.add_figure`` closes it by default).
    """
    if normalize:
        cm = _row_normalize(cm)
        fmt = ".2f"
    else:
        fmt = "d"

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        square=True,
        linewidths=0.5,
        cbar=True,
        ax=ax,  # draw on the axes we created, not whatever pyplot considers current
    )

    ax.set_xlabel("Predicted Labels", fontsize=12)
    ax.set_ylabel("True Labels", fontsize=12)
    ax.set_title(title, fontsize=14)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    return fig

# -------------------
# Validation Function with logging and classification report
# -------------------
def validate(model, val_loader, device, writer=None, epoch=None, class_names=None):
    """Evaluate *model* on *val_loader*, optionally logging diagnostics to TensorBoard.

    Args:
        model: Classifier producing logits; predictions are ``argmax`` over dim 1.
        val_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device batches are moved to.
        writer: Optional ``SummaryWriter``. When given together with *epoch* and
            *class_names*, confusion matrices, a classification report and
            per-class precision/recall/F1 scalars are logged.
        epoch: Global step for TensorBoard logging.
        class_names: Class names indexed by label id.

    Returns:
        ``(avg_val_loss, accuracy)``: per-sample mean cross-entropy and accuracy
        in percent. Both are 0.0 when the loader yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total = 0
    correct = 0
    running_val_loss = 0.0
    all_preds = []
    all_labels = []

    print("Validating The Model....")
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            running_val_loss += criterion(outputs, labels).item() * batch_size
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            total += batch_size
            correct += (predicted == labels).sum().item()

    avg_val_loss = running_val_loss / total if total > 0 else 0.0
    accuracy = 100 * correct / total if total > 0 else 0.0
    print(f"\n✅ Validation Loss: {avg_val_loss:.4f} | Accuracy: {accuracy:.2f}%")

    # Log confusion matrix and classification report to TensorBoard
    if total > 0 and writer is not None and epoch is not None and class_names is not None:
        # Pin the label set to every known class. Without it sklearn only sees
        # classes present this epoch: classification_report rejects the
        # target_names length and precision[i] below can go out of range.
        label_ids = list(range(len(class_names)))
        cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

        # Raw Confusion Matrix
        raw_fig = plot_confusion_matrix(cm, class_names, normalize=False, title="Confusion Matrix (Raw Counts)")
        writer.add_figure("ConfusionMatrix/Raw", raw_fig, global_step=epoch)

        # Normalized Confusion Matrix
        norm_fig = plot_confusion_matrix(cm, class_names, normalize=True, title="Confusion Matrix (Normalized)")
        writer.add_figure("ConfusionMatrix/Normalized", norm_fig, global_step=epoch)

        # Classification Report
        report_str = classification_report(
            all_labels, all_preds,
            labels=label_ids,
            target_names=class_names,
            zero_division=0,
        )
        print(f"\n📊 Classification Report (Epoch {epoch}):\n{report_str}")
        writer.add_text("ClassificationReport", f"```\n{report_str}\n```", global_step=epoch)

        # Per-class precision, recall and F1 as scalars
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds, labels=label_ids, zero_division=0
        )
        for i, cname in enumerate(class_names):
            writer.add_scalar(f'Precision/{cname}', precision[i], epoch)
            writer.add_scalar(f'Recall/{cname}', recall[i], epoch)
            writer.add_scalar(f'F1-Score/{cname}', f1[i], epoch)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return avg_val_loss, accuracy

# -------------------
# Training Function with time measurement and estimated time remaining
# -------------------
def _dataset_size(loader):
    """Return ``len(loader.dataset)``, or None when the loader exposes no sized dataset."""
    try:
        return len(loader.dataset)
    except (AttributeError, TypeError):
        return None


def _train_epoch_with_progress(model, train_loader, criterion, optimizer, scaler, device, use_amp, epoch):
    """Run one optimisation pass, printing progress after each batch.

    Returns the per-sample mean training loss (0.0 if no samples were seen).
    """
    model.train()
    amp_device = "cuda" if use_amp else "cpu"
    total_videos = _dataset_size(train_loader)
    running_loss = 0.0
    processed = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with amp.autocast(amp_device, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Count the samples actually in this batch (the last one may be short)
        # and report after the step, so "Processed" reflects finished work.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        processed += batch_size
        if total_videos is None:
            print(f"In Epoch {epoch} Processed {processed} videos")
        else:
            print(f"In Epoch {epoch} Processed {processed} videos remaining {total_videos - processed}")

    return running_loss / processed if processed else 0.0


def train_partial(
    model, train_loader, val_loader, device,
    epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    target_accuracy=95.0,
    checkpoint_dir="checkpoints",
    log_dir="runs/TimeSformer",
    class_list=CLASS_NAMES
):
    """Train only ``model.model.head`` with timing, ETA and TensorBoard logging.

    Args:
        model: Model whose classifier is ``model.model.head``; only that head's
            parameters are given to the optimizer.
        train_loader, val_loader: Iterables of ``(inputs, labels)`` batches.
        device: Device batches are moved to; AMP is enabled only for CUDA.
        epochs: Maximum number of epochs.
        lr: Initial Adam learning rate (halved by ReduceLROnPlateau on val loss).
        target_accuracy: Stop early once validation accuracy (%) reaches this.
        checkpoint_dir: Directory for ``best_model_epoch*_acc*.pth`` files.
        log_dir: TensorBoard log directory.
        class_list: Class names forwarded to ``validate`` for per-class logging.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    use_amp = torch.device(device).type == "cuda"
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.model.head.parameters(), lr=lr)
    # LR reductions are printed below rather than via the deprecated `verbose` flag.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    scaler = amp.GradScaler("cuda", enabled=use_amp)

    best_accuracy = 0.0
    epochs_run = 0
    epoch_times = []  # full epoch durations (train + validation), used for the ETA
    total_start_time = time.time()

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(epochs):
            epoch_start_time = time.time()
            avg_train_loss = _train_epoch_with_progress(
                model, train_loader, criterion, optimizer, scaler, device, use_amp, epoch
            )
            epoch_train_time = time.time() - epoch_start_time

            # Validate and capture validation loss and accuracy
            val_loss, val_accuracy = validate(
                model, val_loader, device,
                writer=writer,
                epoch=epoch,
                class_names=class_list,
            )

            epoch_total_time = time.time() - epoch_start_time
            epoch_times.append(epoch_total_time)
            epochs_run = epoch + 1
            # Base the ETA on whole epochs; training time alone ignores validation.
            estimated_remaining_time = float(np.mean(epoch_times)) * (epochs - epochs_run)

            print(f"\n📦 Epoch [{epoch+1}/{epochs}] Summary:")
            print(f"🔺 Train Loss: {avg_train_loss:.4f} | Train Time: {epoch_train_time:.2f}s")
            print(f"🔹 Val Loss  : {val_loss:.4f} | Accuracy: {val_accuracy:.2f}%")
            print(f"⏱ Epoch Time: {epoch_total_time:.2f}s | Estimated Time Remaining: {estimated_remaining_time:.2f}s")

            # Log metrics to TensorBoard
            writer.add_scalar("Loss/train", avg_train_loss, epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Accuracy/val", val_accuracy, epoch)
            writer.add_scalar("LearningRate", optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalar("Epoch/Time", epoch_total_time, epoch)
            writer.add_scalar("EstimatedTimeRemaining", estimated_remaining_time, epoch)

            # Step the LR scheduler based on validation loss
            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                print(f"Reducing learning rate to {current_lr:.2e}")

            # Save checkpoint if improved accuracy
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_model_path = os.path.join(checkpoint_dir, f"best_model_epoch{epoch+1}_acc{val_accuracy:.2f}.pth")
                torch.save(model.state_dict(), best_model_path)
                print(f"✅ Saved better model at {best_model_path}")

            if val_accuracy >= target_accuracy:
                print(f"🎯 Target accuracy of {target_accuracy}% reached. Stopping training early.")
                break
    finally:
        # Runs on early stop and on KeyboardInterrupt; the step is the number of
        # epochs actually completed, not the configured maximum.
        total_training_time = time.time() - total_start_time
        writer.add_text("TrainingSummary", f"Total training time: {total_training_time:.2f}s", global_step=epochs_run)
        writer.close()

    print(f"\nTraining complete! Total time: {total_training_time:.2f} seconds.")

# -------------------
# Main Function
# -------------------
def main():
    """Fine-tune the classifier head of the Kinetics-600 TimeSformer and save it.

    Loads the pretrained checkpoint, swaps in a NUM_CLASSES head with the
    backbone frozen, trains via ``train_partial`` and saves the final-epoch
    weights. The best-accuracy checkpoint is written separately by
    ``train_partial`` under its checkpoint directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the pretrained model (with 600 classes) from the checkpoint
    model = load_model(MODEL_PATH, device)

    # Swap in a NUM_CLASSES head and freeze the backbone for partial fine-tuning
    model = modify_for_partial_training(model, device, num_classes=NUM_CLASSES, freeze_layers=True)

    train_loader = get_data_loader(TRAIN_DIR, batch_size=BATCH_SIZE, clip_len=96, shuffle=True)
    val_loader = get_data_loader(VAL_DIR, batch_size=BATCH_SIZE, clip_len=96, shuffle=False)

    # Train the model partially (only the classifier head will be updated)
    train_partial(
        model,
        train_loader,
        val_loader,
        device,
        epochs=NUM_EPOCHS,
        lr=LEARNING_RATE,
        target_accuracy=96.0,
        checkpoint_dir="data/trained_models/checkpoints",
        log_dir="runs/TimeSformer",
        class_list=CLASS_NAMES,
    )

    # os.path.join keeps the path valid off Windows (hard-coded backslashes would
    # otherwise become part of a single file name).
    save_path = os.path.join("data", "trained_models", "timesformer_partial_finetuned.pth")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Partially fine-tuned model saved to {save_path}")


if __name__ == "__main__":
    main()


Using device: cuda
Loading pre-trained model from: C:\Users\jashw\Desktop\Video Surveillance\data\trained_models\TimeSformer_divST_96x4_224_K600.pyth
Model loaded successfully!
Model modified for partial fine-tuning (only classifier head will be trained).




In Epoch 0 Processed 1 videos remaining 1599
In Epoch 0 Processed 2 videos remaining 1598
In Epoch 0 Processed 3 videos remaining 1597
In Epoch 0 Processed 4 videos remaining 1596
In Epoch 0 Processed 5 videos remaining 1595
In Epoch 0 Processed 6 videos remaining 1594
In Epoch 0 Processed 7 videos remaining 1593
In Epoch 0 Processed 8 videos remaining 1592
In Epoch 0 Processed 9 videos remaining 1591
In Epoch 0 Processed 10 videos remaining 1590
In Epoch 0 Processed 11 videos remaining 1589
In Epoch 0 Processed 12 videos remaining 1588
In Epoch 0 Processed 13 videos remaining 1587
In Epoch 0 Processed 14 videos remaining 1586
In Epoch 0 Processed 15 videos remaining 1585
In Epoch 0 Processed 16 videos remaining 1584
In Epoch 0 Processed 17 videos remaining 1583
In Epoch 0 Processed 18 videos remaining 1582
In Epoch 0 Processed 19 videos remaining 1581
In Epoch 0 Processed 20 videos remaining 1580
In Epoch 0 Processed 21 videos remaining 1579
In Epoch 0 Processed 22 videos remaining 15

KeyboardInterrupt: 

5