<a href="https://colab.research.google.com/github/lhayana/violence-detection-CV/blob/main/violence_detection_CV.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup
Montando o drive, instalando e importando as bibliotecas necessárias

In [None]:
# Mount Google Drive at /content/drive so files under MyDrive are accessible.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Location of the frame dataset on Google Drive (the copy/unzip cell below redefines it with the same value).
DRIVE_DATASET_PATH = '/content/drive/MyDrive/violence_dataset_frames'

In [None]:
# Hugging Face Transformers (VideoMAE) and Accelerate.
# NOTE(review): versions are unpinned, so a fresh run may pick up different releases.
!pip install transformers accelerate
!pip install torch torchvision opencv-python

In [None]:
# Gradio, installed quietly. NOTE(review): no cell visible here uses it.
!pip install gradio --quiet

In [None]:
# NOTE(review): upgrading torchvision on its own can pull in a torch build that does not match
# the runtime's preinstalled CUDA torch; consider upgrading torch and torchvision together or pinning both.
!pip install torchvision --upgrade

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models.video import r3d_18
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Caminho original no Drive
DRIVE_DATASET_PATH = '/content/drive/MyDrive/violence_dataset_frames'

# Caminho destino no disco local
LOCAL_DATASET_PATH = '/content/violence_dataset_frames'

# Copiar só o zip
!cp /content/drive/MyDrive/violence_dataset_frames.zip /content/

# Descompactar
!unzip /content/violence_dataset_frames.zip -d /content/

Dataset: [A Dataset for Automatic Violence Detection in Videos](https://github.com/airtlab/A-Dataset-for-Automatic-Violence-Detection-in-Videos/tree/master/violence-detection-dataset)

# ResNet3D-18 (pré-treinada no Kinetics-400; apenas a camada final é treinada)

In [None]:
import os, glob, time, numpy as np, cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models.video import r3d_18
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import seaborn as sns
from torch.cuda.amp import autocast, GradScaler

In [None]:
# Dataset
class ViolenceFramesDataset(Dataset):
    """Video-clip dataset of pre-extracted JPEG frames for the R3D-18 baseline.

    Expected layout: ``root_dir/<split>/<label>/<video_folder>/*.jpg`` with
    ``<label>`` in ``{'non-violence': 0, 'violence': 1}``. Each item is
    ``(clip, label)``, where ``clip`` is a float32 tensor of shape
    ``(3, num_frames, 112, 112)`` with RGB values scaled to [0, 1] and ``label``
    is a 0-d int64 tensor.

    NOTE(review): frames are only scaled to [0, 1]; torchvision's Kinetics-400
    R3D-18 weights were trained on mean/std-normalized input, so consider
    normalizing here.
    """

    def __init__(self, root_dir, split='train', num_frames=16):
        self.root_dir = os.path.join(root_dir, split)
        self.num_frames = num_frames
        self.samples = []
        self.labels_map = {'non-violence': 0, 'violence': 1}

        for label in ['non-violence', 'violence']:
            label_folder = os.path.join(self.root_dir, label)
            # Sort for a deterministic sample order (os.listdir order is arbitrary), and keep only
            # directories: stray files (e.g. .DS_Store) would otherwise become zero-frame "videos".
            for video in sorted(os.listdir(label_folder)):
                video_path = os.path.join(label_folder, video)
                if os.path.isdir(video_path):
                    self.samples.append((video_path, self.labels_map[label]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        frames = sorted(glob.glob(os.path.join(video_path, '*.jpg')))
        if not frames:
            # Without this check the repetition below divides by zero.
            raise FileNotFoundError(f"Nenhum frame .jpg encontrado em {video_path}")
        if len(frames) < self.num_frames:
            # Loop short clips until there are at least num_frames paths.
            frames = frames * (self.num_frames // len(frames) + 1)
        selected_frames = frames[:self.num_frames]

        frames_data = []
        for frame_path in selected_frames:
            frame = cv2.imread(frame_path)
            if frame is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise IOError(f"Não foi possível ler o frame: {frame_path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, (112, 112))
            frames_data.append(frame.astype(np.float32) / 255.0)

        clip = np.stack(frames_data, axis=0)     # (T, H, W, C)
        clip = np.transpose(clip, (3, 0, 1, 2))  # (C, T, H, W), as R3D-18 expects
        return torch.tensor(clip, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# Dataloaders
def create_dataloaders(dataset_path, batch_size=32, num_workers=4):
    """Build DataLoaders for the 'train', 'val' and 'test' splits of ``ViolenceFramesDataset``.

    All loaders share ``batch_size``, ``num_workers`` and ``pin_memory=True``; only the
    training loader shuffles. Returns ``(train_loader, val_loader, test_loader)``.
    """
    shared_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    split_names = ('train', 'val', 'test')
    split_datasets = [ViolenceFramesDataset(dataset_path, name) for name in split_names]
    return tuple(
        DataLoader(split_ds, shuffle=(name == 'train'), **shared_kwargs)
        for name, split_ds in zip(split_names, split_datasets)
    )

In [None]:
# Configuração
from torchvision.models.video import R3D_18_Weights

DATASET_PATH = "/content/violence_dataset_frames"
train_loader, val_loader, test_loader = create_dataloaders(DATASET_PATH)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Usando dispositivo: {device}")

# Model: R3D-18 with Kinetics-400 weights. `weights=` replaces the deprecated
# `pretrained=True` (torchvision >= 0.13); DEFAULT resolves to KINETICS400_V1,
# the same weights that `pretrained=True` loaded.
model = r3d_18(weights=R3D_18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 2)  # new 2-class head
# Linear probing: freeze everything, then unfreeze only the new head.
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
scaler = GradScaler()  # loss scaling for mixed precision; disables itself without CUDA

# EarlyStopping Config
patience = 8  # epochs without validation-loss improvement before stopping
best_val_loss = float('inf')
epochs_no_improve = 0

num_epochs = 50
train_losses, val_losses = [], []
start = time.time()

In [None]:
# Treinamento
# Restart the timer here so the reported duration covers only this cell, not the
# time elapsed since the configuration cell ran.
start = time.time()

for epoch in range(num_epochs):
    # NOTE(review): train() also puts the frozen backbone's BatchNorm layers in training mode,
    # so their running statistics keep updating even though their weights are frozen.
    model.train()
    running_loss = 0.0
    for videos, labels in train_loader:
        videos = videos.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast():  # mixed-precision forward pass
            outputs = model(videos)
            loss = criterion(outputs, labels)
        # GradScaler scales the loss to avoid fp16 gradient underflow; step() unscales before updating.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validação
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for videos, labels in val_loader:
            videos = videos.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with autocast():
                outputs = model(videos)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f"Época [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # EarlyStopping: checkpoint on every improvement, stop after `patience` epochs without one.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "best_model.pth")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Parando treino por early stopping na época {epoch+1}")
            break

end = time.time()
print(f"Tempo total de treino: {(end - start)/60:.2f} minutos")

In [None]:
# Avaliação no conjunto de teste
CLASS_NAMES = ["Non-Violence", "Violence"]

# map_location makes the checkpoint load on the current device whatever device saved it.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for videos, labels in test_loader:
        outputs = model(videos.to(device))
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

acc = accuracy_score(all_labels, all_preds)
print(f"\n Acurácia no conjunto de teste: {acc * 100:.2f}%")
print("\n Relatório de Classificação:")
# labels=[0, 1] keeps the report (and the matrix below) two-class even if a class is
# missing from the labels/predictions; otherwise target_names would not match and raise.
print(classification_report(all_labels, all_preds, labels=[0, 1],
                            target_names=CLASS_NAMES, zero_division=0))

# Matriz de confusão
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
ax.set(xlabel='Predito', ylabel='Verdadeiro', title='Matriz de Confusão')
plt.show()

# Curvas de perda
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Val Loss')
ax.set(xlabel='Época', ylabel='Loss', title='Curvas de Perda durante o Treinamento')
ax.grid()
ax.legend()
plt.show()

# VideoMAE (modelo pré-treinado com Kinetics-400)

In [None]:
# Imports
import os, glob, cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
from transformers import VideoMAEModel, VideoMAEImageProcessor, get_cosine_schedule_with_warmup
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix, classification_report
import random
import time
from tqdm.auto import tqdm

# Performance settings (intended for an A100 GPU)
torch.backends.cudnn.benchmark = True  # let cuDNN autotune convolution algorithms (pays off with fixed input shapes)
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 in CUDA matmuls (Ampere GPUs such as the A100)
torch.backends.cudnn.allow_tf32 = True  # allow TF32 inside cuDNN operations

# Path to the locally extracted dataset
DATASET_PATH = "/content/violence_dataset_frames"

# Data-augmentation pipeline (per-frame PIL transforms)
def get_transforms(mode='train'):
    """Return the augmentation pipeline for ``mode``.

    Only ``'train'`` is augmented: color jitter, horizontal flip (p=0.5), vertical
    flip (p=0.2), rotation within +/-15 degrees and translation of up to 10% of the
    width/height. Any other mode returns ``None``, meaning "no transform".
    """
    if mode != 'train':
        return None
    augmentations = [
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.RandomHorizontalFlip(p=0.5),
        # NOTE(review): upside-down scenes are rare in real footage; this flip may hurt more than help.
        transforms.RandomVerticalFlip(p=0.2),
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ]
    return transforms.Compose(augmentations)

In [None]:
# Collate + preprocessing with the VideoMAE image processor
processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-base")

def collate_fn(batch):
    """Merge ``(frames, label)`` samples into a ``(pixel_values, labels)`` batch.

    The per-sample frame lists are handed to ``processor`` as one list of videos;
    the labels are packed with ``torch.tensor``.
    """
    clips, targets = zip(*batch)
    encoded = processor(list(clips), return_tensors="pt")
    pixel_values = encoded["pixel_values"]
    return pixel_values, torch.tensor(targets)

In [None]:
# Dataset adapted for VideoMAE, with data augmentation and in-memory frame caching
class ViolenceDataset(Dataset):
    """Frame-folder video dataset that yields raw RGB frames for the VideoMAE processor.

    Layout: ``root_dir/<split>/<label>/<video_folder>/*.jpg`` with ``<label>`` in
    ``{'non-violence': 0, 'violence': 1}``; missing label folders are skipped.
    ``__getitem__`` returns ``(frames, label)``. ``frames`` is a list of
    ``num_frames`` uint8 RGB arrays of shape ``(size[1], size[0], 3)``, because
    ``cv2.resize`` takes ``(width, height)``.

    Args:
        root_dir: dataset root containing the split folders.
        split: split sub-folder name (e.g. ``'train'``, ``'test'``).
        num_frames: frames sampled per clip, uniformly over the video.
        size: ``(width, height)`` passed to ``cv2.resize``.
        transform: optional per-frame PIL transform. The same random parameters
            are applied to every frame of a clip.
        use_cache: keep decoded frames in memory. NOTE: with DataLoader
            ``num_workers > 0`` each worker process holds its own cache copy.
        balance: oversample minority classes to equal counts. ``None`` (default)
            balances only the ``'train'`` split, so evaluation splits are never
            filled with duplicated clips.
    """

    def __init__(self, root_dir, split='train', num_frames=16, size=(224, 224), transform=None, use_cache=True,
                 balance=None):
        self.root_dir = os.path.join(root_dir, split)
        self.num_frames = num_frames
        self.size = size
        self.transform = transform
        self.use_cache = use_cache
        self.cache = {}  # cache_key -> list of decoded RGB frames

        self.samples = []
        self.labels_map = {'non-violence': 0, 'violence': 1}

        print(f"Carregando dataset {split}...")
        for label, label_id in self.labels_map.items():
            label_path = os.path.join(self.root_dir, label)
            if not os.path.exists(label_path):
                continue
            # Sorted for a deterministic order (os.listdir order is arbitrary). Only
            # sub-directories are videos; stray files would have zero frames.
            for folder in sorted(os.listdir(label_path)):
                video_path = os.path.join(label_path, folder)
                if os.path.isdir(video_path):
                    self.samples.append((video_path, label_id))

        # Oversampling duplicates clips, which would bias evaluation metrics,
        # so by default only the training split is balanced.
        if balance is None:
            balance = split == 'train'
        if balance:
            self._balance_classes()
        else:
            print(f"Distribuição de classes: {self._count_labels()}")
        print(f"Total de amostras no dataset {split}: {len(self.samples)}")

    def _count_labels(self):
        """Return ``{label: number of samples}`` for the current ``self.samples``."""
        counts = {}
        for _, label in self.samples:
            counts[label] = counts.get(label, 0) + 1
        return counts

    def _balance_classes(self):
        """Oversample, with replacement, every class up to the size of the largest one.

        Uses the global ``random`` module, so the result depends on its seed. The
        balanced list is shuffled and replaces ``self.samples``.
        """
        class_counts = self._count_labels()
        print(f"Distribuição de classes: {class_counts}")

        if len(set(class_counts.values())) <= 1:
            return  # already balanced (or empty)

        max_count = max(class_counts.values())
        class_samples = {c: [] for c in class_counts}
        for sample in self.samples:
            class_samples[sample[1]].append(sample)

        balanced_samples = []
        for members in class_samples.values():
            balanced_samples.extend(members)
            if len(members) < max_count:
                balanced_samples.extend(random.choices(members, k=max_count - len(members)))

        random.shuffle(balanced_samples)
        self.samples = balanced_samples
        print(f"Nova distribuição após balanceamento: {self._count_labels()}")

    def _load_frames(self, video_path):
        """Decode ``num_frames`` uniformly sampled RGB frames of one video (cached if enabled).

        Raises:
            FileNotFoundError: if the folder contains no ``.jpg`` frames.
        """
        cache_key = f"{video_path}_{self.num_frames}_{self.size[0]}_{self.size[1]}"
        if self.use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        frame_paths = sorted(glob.glob(os.path.join(video_path, '*.jpg')))
        if not frame_paths:
            # Without this check the repetition below divides by zero.
            raise FileNotFoundError(f"Nenhum frame .jpg encontrado em {video_path}")

        # Short videos: repeat the frame list until it is long enough.
        if len(frame_paths) < self.num_frames:
            frame_paths = frame_paths * (self.num_frames // len(frame_paths) + 1)

        # Uniform sampling so the clip spans the whole video.
        if len(frame_paths) > self.num_frames:
            indices = np.linspace(0, len(frame_paths) - 1, self.num_frames, dtype=int)
            frame_paths = [frame_paths[i] for i in indices]
        else:
            frame_paths = frame_paths[:self.num_frames]

        width, height = self.size
        frames = []
        for fp in frame_paths:
            frame = cv2.imread(fp)
            if frame is None:
                print(f"Erro ao ler frame: {fp}")
                # cv2.resize takes (width, height) but arrays are (height, width, C),
                # so the placeholder must be (height, width, 3) to match real frames.
                frame = np.zeros((height, width, 3), dtype=np.uint8)
            else:
                frame = cv2.resize(frame, self.size)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)

        if self.use_cache:
            self.cache[cache_key] = frames

        return frames

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        frames = self._load_frames(video_path)

        if self.transform:
            # Reseed before *every* frame so the whole clip gets identical random parameters
            # (flip, rotation, jitter). Seeding once per clip let each frame draw fresh
            # parameters, which made the augmentation temporally inconsistent.
            seed = int(np.random.randint(2147483647))
            to_pil = transforms.ToPILImage()
            python_rng_state = random.getstate()
            augmented_frames = []
            try:
                # fork_rng restores the global CPU torch RNG afterwards; only the CPU
                # generator is reseeded, so CUDA RNG state is left untouched.
                with torch.random.fork_rng(devices=[]):
                    for frame in frames:
                        random.seed(seed)
                        torch.default_generator.manual_seed(seed)
                        augmented_frames.append(np.array(self.transform(to_pil(frame))))
            finally:
                random.setstate(python_rng_state)
            frames = augmented_frames

        return frames, label

    def __len__(self):
        return len(self.samples)

# NOTE(review): duplicate of the earlier "Collate + Preprocess" cell. It reloads the same
# processor and redefines an identical collate_fn; consider deleting one of the two copies.
# Collate + preprocessing with the VideoMAE image processor
processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-base")

def collate_fn(batch):
    """Turn a batch of ``(frames, label)`` pairs into ``(pixel_values, labels)``.

    ``pixel_values`` comes from the VideoMAE image processor applied to the list of
    videos; ``labels`` is ``torch.tensor`` of the sample labels.
    """
    videos, labels = zip(*batch)
    processed = processor(list(videos), return_tensors="pt")
    return processed["pixel_values"], torch.tensor(labels)

In [None]:
# VideoMAE backbone + attention pooling + MLP head for binary violence classification
class VideoMAEForViolenceClassification(nn.Module):
    """Pretrained VideoMAE-base encoder with a small 2-class classification head.

    Args:
        dropout_rate: dropout before the first head layer (half of it before the second).
        use_learnable_pooling: pool tokens with learned attention weights; otherwise
            use the first token of ``last_hidden_state``.
        freeze_layers: number of backbone *parameter tensors* to freeze, taken in
            ``named_parameters()`` order (not transformer blocks).
    """

    def __init__(self, dropout_rate=0.3, use_learnable_pooling=True, freeze_layers=80):
        super().__init__()
        self.backbone = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")

        # Freeze the first `freeze_layers` parameter tensors of the backbone.
        if freeze_layers > 0:
            for _, param in list(self.backbone.named_parameters())[:freeze_layers]:
                param.requires_grad = False

        hidden_size = self.backbone.config.hidden_size

        # Learnable attention pooling: one score per token, softmax over the token axis.
        self.use_learnable_pooling = use_learnable_pooling
        if use_learnable_pooling:
            self.attention_pool = nn.Sequential(
                nn.Linear(hidden_size, 256),
                nn.GELU(),
                nn.Linear(256, 1),
                nn.Softmax(dim=1)
            )

        # Two-stage MLP head (plain feed-forward; no skip connection).
        self.pre_classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_size, 512),
            nn.GELU(),
            nn.LayerNorm(512),
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate/2),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.LayerNorm(128),
            nn.Linear(128, 2)
        )

        self._init_weights()

    def _head_modules(self):
        """Return the newly added (non-pretrained) sub-modules."""
        heads = [self.pre_classifier, self.classifier]
        if self.use_learnable_pooling:
            heads.append(self.attention_pool)
        return heads

    def _init_weights(self):
        """Xavier-initialize the Linear layers of the new head only.

        The pretrained backbone is deliberately excluded. Iterating ``self.modules()``
        would also reinitialize every Linear layer inside VideoMAE and throw away the
        pretrained weights (including the frozen ones).
        """
        for head in self._head_modules():
            for m in head.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, pixel_values):
        # last_hidden_state is (batch, tokens, hidden) per the Hugging Face VideoMAEModel output.
        hidden_states = self.backbone(pixel_values=pixel_values).last_hidden_state

        if self.use_learnable_pooling:
            attention_weights = self.attention_pool(hidden_states)           # (batch, tokens, 1)
            feature = torch.sum(attention_weights * hidden_states, dim=1)    # (batch, hidden)
        else:
            # NOTE: VideoMAE has no dedicated CLS token; this is simply the first patch token.
            feature = hidden_states[:, 0]

        return self.classifier(self.pre_classifier(feature))

In [None]:
# Early stopping that checkpoints the best model; tracks either a loss or an accuracy
class EarlyStopping:
    """Stop training when a validation metric stops improving; checkpoint every improvement.

    Args:
        patience: consecutive non-improving calls tolerated before ``early_stop`` is set.
        verbose: print a message on every call.
        delta: minimum change that counts as an improvement.
        path: file that receives ``model.state_dict()`` via ``torch.save``.
        metric: ``'loss'`` (lower is better) or ``'accuracy'`` (higher is better).

    Raises:
        ValueError: if ``metric`` is neither ``'loss'`` nor ``'accuracy'``. Any other
            value was previously treated silently as "higher is better".
    """

    def __init__(self, patience=5, verbose=True, delta=0.0001, path='checkpoint.pt', metric='loss'):
        if metric not in ('loss', 'accuracy'):
            raise ValueError(f"metric deve ser 'loss' ou 'accuracy', recebido: {metric!r}")
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        # Start from the worst possible value for the chosen direction.
        self.best_score = float('inf') if metric == 'loss' else -float('inf')
        self.early_stop = False
        self.delta = delta
        self.path = path
        self.metric = metric

    def __call__(self, score, model):
        """Record this epoch's ``score``; save ``model`` if it improved by more than ``delta``."""
        if self.metric == 'loss':
            improved = score < self.best_score - self.delta
        else:
            improved = score > self.best_score + self.delta

        if improved:
            if self.verbose:
                print(f"{self.metric.capitalize()} melhorou de {self.best_score:.4f} para {score:.4f}. Salvando modelo...")
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping: {self.counter}/{self.patience} sem melhoria.")
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, model):
        """Write ``model.state_dict()`` to ``self.path``."""
        torch.save(model.state_dict(), self.path)

In [None]:
# Helper: accuracy, mean loss and predictions of a model over a DataLoader
def compute_metrics(model, dataloader, device, criterion=None):
    """Evaluate ``model`` on ``dataloader`` without gradients (leaves the model in eval mode).

    Args:
        model: classifier returning logits of shape ``(batch, num_classes)``.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device the batches are moved to.
        criterion: loss returning the batch *mean*; defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        dict with ``'accuracy'`` (fraction in [0, 1]), ``'loss'`` (per-sample mean),
        ``'predictions'`` and ``'labels'`` (lists). For an empty loader, accuracy
        and loss are ``nan`` instead of raising ZeroDivisionError.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    model.eval()
    all_preds, all_labels = [], []
    total_loss, total_samples = 0.0, 0

    with torch.no_grad():
        for pixel_values, labels in dataloader:
            pixel_values = pixel_values.to(device)
            labels = labels.to(device)

            outputs = model(pixel_values)
            batch_size = labels.size(0)
            # Weight each batch-mean loss by its size so a short final batch
            # does not count as much as a full one.
            total_loss += criterion(outputs, labels).item() * batch_size
            total_samples += batch_size

            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total_samples == 0:
        return {'accuracy': float('nan'), 'loss': float('nan'), 'predictions': [], 'labels': []}

    accuracy = np.mean(np.array(all_preds) == np.array(all_labels))
    return {
        'accuracy': accuracy,
        'loss': total_loss / total_samples,
        'predictions': all_preds,
        'labels': all_labels
    }

In [None]:
# Training loop tuned for an A100 (mixed precision, per-group learning rates, cosine warmup)
class _FocalLoss(nn.Module):
    """Focal loss ``alpha * (1 - p_t) ** gamma * CE``, averaged over the batch.

    ``p_t = exp(-CE)`` is the predicted probability of the true class. ``alpha`` is a
    single scalar applied to every class, so it rescales the loss rather than
    re-weighting classes; ``gamma`` down-weights well-classified examples.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        ce_loss = self.ce(inputs, targets)
        pt = torch.exp(-ce_loss)
        return (self.alpha * (1 - pt) ** self.gamma * ce_loss).mean()


def _build_optimizer(model, lr, weight_decay):
    """AdamW with per-group learning rates.

    Group 0 holds the trainable backbone parameters (``lr / 10``). The attention
    pool, if present, gets ``lr / 2``; ``pre_classifier`` and ``classifier`` get ``lr``.
    """
    param_groups = [
        {'params': [p for p in model.backbone.parameters() if p.requires_grad], 'lr': lr / 10},
    ]
    if hasattr(model, 'attention_pool'):
        param_groups.append({'params': model.attention_pool.parameters(), 'lr': lr / 2})
    param_groups += [
        {'params': model.pre_classifier.parameters(), 'lr': lr},
        {'params': model.classifier.parameters(), 'lr': lr},
    ]
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)


def _train_one_epoch(model, loader, optimizer, scheduler, scaler, criterion, device, desc):
    """Run one training epoch with AMP and gradient clipping.

    Returns ``(mean batch loss, accuracy in %)``.
    """
    use_amp = device.type == 'cuda'
    model.train()
    running_loss, correct, seen = 0.0, 0, 0

    progress_bar = tqdm(loader, desc=desc)
    for pixel_values, labels in progress_bar:
        pixel_values = pixel_values.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(pixel_values)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        # Unscale first so clipping sees the true gradient norms.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # the schedule is defined per optimizer step, not per epoch

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)

        progress_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{100 * correct / seen:.2f}%",
            'lr': f"{optimizer.param_groups[0]['lr']:.2e}"
        })

    return running_loss / len(loader), 100 * correct / seen


def train_model(model, train_loader, val_loader, num_epochs=30,
                lr=5e-5 * 8,  # base LR raised for larger batches
                warmup_steps=100,
                weight_decay=2e-5,
                checkpoint_path='best_model.pt',
                patience=6):
    """Train ``model`` with focal loss and early stopping on validation accuracy.

    Uses CUDA when available, otherwise the CPU (mixed precision disabled there).
    The best checkpoint (by validation accuracy) is saved to ``checkpoint_path`` and
    reloaded into ``model`` before returning.

    NOTE: ``train_losses`` are focal losses, while ``val_losses`` come from
    ``compute_metrics``, so the two curves may not share a scale.

    Returns:
        dict with ``'model'``, ``'train_losses'``, ``'val_losses'``, ``'train_accs'``,
        ``'val_accs'`` (accuracies in %) and ``'training_time'`` (seconds).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = _build_optimizer(model, lr, weight_decay)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    criterion = _FocalLoss(alpha=0.25, gamma=2)
    scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')

    early_stopping = EarlyStopping(patience=patience, verbose=True, path=checkpoint_path, metric='accuracy')

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    start_time = time.time()

    print("Iniciando treinamento...")
    for epoch in range(num_epochs):
        epoch_start = time.time()

        avg_train_loss, train_accuracy = _train_one_epoch(
            model, train_loader, optimizer, scheduler, scaler, criterion, device,
            desc=f"Época {epoch+1}/{num_epochs} [Treino]"
        )
        train_losses.append(avg_train_loss)
        train_accs.append(train_accuracy)

        val_metrics = compute_metrics(model, val_loader, device)
        val_losses.append(val_metrics['loss'])
        val_accs.append(100 * val_metrics['accuracy'])

        epoch_time = time.time() - epoch_start
        print(f"Época [{epoch+1}/{num_epochs}] - "
              f"Tempo: {epoch_time:.1f}s - "
              f"Train Loss: {avg_train_loss:.4f}, Acc: {train_accuracy:.2f}% - "
              f"Val Loss: {val_metrics['loss']:.4f}, Acc: {100 * val_metrics['accuracy']:.2f}%")

        early_stopping(val_metrics['accuracy'], model)
        if early_stopping.early_stop:
            print("Parando treino por early stopping.")
            break

    total_time = time.time() - start_time
    print(f"Treinamento completo em {total_time/60:.2f} minutos")

    # Reload the best checkpoint onto the training device.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    return {
        'model': model,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'training_time': total_time
    }

In [None]:
# Main script: data -> model -> training -> test evaluation -> plots -> save
def _pick_batch_size():
    """Choose a batch size from the hardware: 32 on an A100, 8 on other GPUs, 4 on CPU."""
    if not torch.cuda.is_available():
        print("GPU não detectada. O treinamento será muito lento em CPU.")
        return 4
    device_name = torch.cuda.get_device_name()
    print(f"GPU detectada: {device_name}")
    if "A100" in device_name:
        print("GPU A100 detectada! Utilizando configurações otimizadas.")
        return 32
    print(f"GPU {device_name} detectada. Usando configurações padrão.")
    return 8


def _split_by_video(dataset, val_fraction, seed):
    """Split ``dataset`` into train/val ``Subset``s that share no video folder.

    ``ViolenceDataset`` may oversample, listing the same video folder several times.
    A plain ``random_split`` over samples can therefore put copies of one video in
    both subsets (validation leakage). Here whole videos are assigned to one side.
    The validation subset keeps a single copy of each of its videos and reads from
    a shallow copy of the dataset with ``transform=None``, so it is not augmented.
    """
    import copy
    from torch.utils.data import Subset

    video_paths = sorted({path for path, _ in dataset.samples})
    random.Random(seed).shuffle(video_paths)
    val_videos = set(video_paths[:int(val_fraction * len(video_paths))])

    train_indices, val_indices, seen_val = [], [], set()
    for idx, (path, _) in enumerate(dataset.samples):
        if path not in val_videos:
            train_indices.append(idx)
        elif path not in seen_val:
            seen_val.add(path)
            val_indices.append(idx)

    val_view = copy.copy(dataset)  # shares samples and cache, but gets its own transform
    val_view.transform = None
    return Subset(dataset, train_indices), Subset(val_view, val_indices)


def _plot_training_curves(results, path='training_metrics.png'):
    """Plot loss and accuracy curves side by side, save them to ``path`` and show them.

    NOTE: ``train_losses`` come from the focal loss used for training and
    ``val_losses`` from ``compute_metrics``, so the two loss curves may not share a scale.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 6))

    ax_loss.plot(results['train_losses'], label='Train Loss')
    ax_loss.plot(results['val_losses'], label='Validation Loss')
    ax_loss.set(xlabel='Épocas', ylabel='Loss', title='Loss durante o Treinamento')
    ax_loss.legend()
    ax_loss.grid()

    ax_acc.plot(results['train_accs'], label='Train Accuracy')
    ax_acc.plot(results['val_accs'], label='Validation Accuracy')
    ax_acc.set(xlabel='Épocas', ylabel='Acurácia (%)', title='Acurácia durante o Treinamento')
    ax_acc.legend()
    ax_acc.grid()

    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.show()


def run_training():
    """Train VideoMAE on the violence dataset, evaluate on the test split and save the result.

    Steps:
      1. Seed the RNGs and pick a batch size from the GPU.
      2. Build a video-disjoint train/validation split of the ``'train'`` folder.
      3. Train with ``train_model`` and report test metrics.
      4. Plot the curves and save a checkpoint dict to Google Drive.

    Returns:
        ``(model, results, test_metrics)``: the best model, the dict returned by
        ``train_model`` and the dict returned by ``compute_metrics`` on the test set.
    """
    # Reproducibility
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    batch_size = _pick_batch_size()
    print(f"Batch size: {batch_size}")

    # Augmented training data and a held-out, non-augmented validation subset of it
    train_transform = get_transforms(mode='train')
    dataset = ViolenceDataset(DATASET_PATH, split='train', transform=train_transform, use_cache=True)
    train_dataset, val_dataset = _split_by_video(dataset, val_fraction=0.2, seed=seed)
    print(f"Amostras de treino: {len(train_dataset)} | validação: {len(val_dataset)}")

    test_dataset = ViolenceDataset(DATASET_PATH, split='test', transform=None, use_cache=True)

    model = VideoMAEForViolenceClassification(dropout_rate=0.25, use_learnable_pooling=True, freeze_layers=80)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total de parâmetros: {total_params:,}")
    print(f"Parâmetros treináveis: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")

    # persistent_workers keeps the worker processes (and their per-worker frame caches)
    # alive between epochs instead of re-creating them for every pass.
    common = dict(batch_size=batch_size, collate_fn=collate_fn, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, persistent_workers=True, prefetch_factor=4, **common)
    val_loader = DataLoader(val_dataset, shuffle=False, persistent_workers=True, **common)
    test_loader = DataLoader(test_dataset, shuffle=False, **common)

    results = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=50,  # upper bound; early stopping ends training sooner
        lr=1e-4 * (batch_size / 8),  # linear LR scaling relative to batch size 8
        warmup_steps=len(train_loader) // 2,  # half an epoch of warmup
        weight_decay=1e-5,
        checkpoint_path='best_a100_model.pt',
        patience=8
    )
    model = results['model']

    print("\nAvaliando no conjunto de teste...")
    # Evaluate on the device the trained model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    test_metrics = compute_metrics(model, test_loader, device)
    print(f"Acurácia no conjunto de teste: {100 * test_metrics['accuracy']:.2f}%")

    # labels=[0, 1] keeps both outputs two-class even if a class never appears.
    cm = confusion_matrix(test_metrics['labels'], test_metrics['predictions'], labels=[0, 1])
    print("\nMatriz de Confusão:")
    print(cm)

    print("\nRelatório de Classificação:")
    print(classification_report(test_metrics['labels'], test_metrics['predictions'], labels=[0, 1],
                                target_names=['Non-Violence', 'Violence'], zero_division=0))

    _plot_training_curves(results)

    MODEL_DIR = '/content/drive/MyDrive/models'
    os.makedirs(MODEL_DIR, exist_ok=True)
    save_path = os.path.join(MODEL_DIR, 'violence_videomae_a100.pth')
    torch.save({
        'model_state_dict': model.state_dict(),
        'train_losses': results['train_losses'],
        'val_losses': results['val_losses'],
        'train_accs': results['train_accs'],
        'val_accs': results['val_accs'],
        'test_accuracy': test_metrics['accuracy'],
        'training_time': results['training_time'],
        'batch_size': batch_size
    }, save_path)
    print(f"Modelo salvo com sucesso em: {save_path}")

    return model, results, test_metrics

In [None]:
if __name__ == "__main__":
    # Keep the outputs so later cells can inspect the trained model and metrics.
    # The distinct names avoid overwriting `model` from the ResNet3D-18 section.
    videomae_model, videomae_results, videomae_test_metrics = run_training()