In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os 
import os
import cv2
from torch.utils.data import Dataset
import torch
from torchvision import transforms
from collections import Counter
from torchvision import transforms
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from tqdm import tqdm
from collections import Counter
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from transformers import (
    AutoConfig,
    AutoModelForVideoClassification,
    VideoMAEImageProcessor,
    get_scheduler,
)

In [None]:

class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    Per sample:  loss_i = alpha[y_i] * (1 - p_i) ** gamma * (-log p_i),
    where p_i is the softmax probability assigned to the true class y_i.
    """

    def __init__(self, alpha=None, gamma=2, reduction='mean'):
        """
        Focal loss for mitigating class imbalance.

        :param alpha: Optional per-class weights (1-D tensor or sequence indexed by
            class id). If None, all classes are weighted equally.
        :param gamma: Focusing parameter. Larger values down-weight well-classified
            examples more strongly; gamma=0 yields per-sample (alpha-weighted)
            cross-entropy.
        :param reduction: 'mean' (average over the N samples), 'sum', or any other
            value (e.g. 'none') to return the per-sample loss tensor.
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        :param inputs: Raw (unnormalized) logits, presumably shaped (N, C).
        :param targets: Integer class indices, presumably shaped (N,).
        :return: Scalar for 'mean'/'sum', otherwise the per-sample loss tensor.
        """
        # Unweighted cross-entropy is exactly -log(p_t), so exp(-ce) recovers p_t.
        # Passing weight=alpha here would make exp(-ce) equal p_t ** alpha_t and
        # distort the modulating factor (1 - p_t) ** gamma.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability of the correct class
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            # Apply class weights after the focusing term; follow the logits' device/dtype.
            alpha = torch.as_tensor(self.alpha, dtype=focal_loss.dtype, device=focal_loss.device)
            focal_loss = focal_loss * alpha[targets]

        if self.reduction == 'mean':
            # Plain mean over samples (not normalized by the sum of weights).
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [None]:
class VideoDataset(Dataset):
    """Video classification dataset producing ``VideoMAEImageProcessor`` inputs.

    Two construction modes:

    * ``root_dir`` scan: every sub-directory of ``root_dir`` is a class (indices
      assigned in sorted name order). Every file inside it with a video extension
      is a sample. Files are visited in sorted order, so the sample list, and
      therefore any seeded train/val split built from it, is reproducible.
    * explicit lists: ``video_files``/``video_labels`` plus the ``class_to_idx``
      mapping they were labelled with. This is used to build splits that share
      one label mapping.

    ``__getitem__`` returns ``(inputs, label)``. ``inputs`` is the processor
    output for the first ``max_frames`` decoded frames of the clip; the last
    frame is repeated when the clip is shorter. ``label`` is the class index
    stored for that file.
    """

    # Extensions (compared case-insensitively) accepted when scanning root_dir.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')

    def __init__(self, root_dir, model_name, max_frames=8, transform=None, video_files=None, video_labels=None, class_to_idx=None):
        """
        :param root_dir: Directory with one sub-directory per class. Ignored when
            both ``video_files`` and ``video_labels`` are given.
        :param model_name: Hugging Face model id/path used to load the image processor.
        :param max_frames: Number of frames per clip (must be >= 1).
        :param transform: Optional callable applied to each RGB frame before the processor.
        :param video_files: Optional list of video paths (requires ``video_labels`` and ``class_to_idx``).
        :param video_labels: Optional list of integer labels aligned with ``video_files``.
        :param class_to_idx: Class-name -> index mapping used together with the explicit lists.
        :raises ValueError: If ``max_frames < 1``, if ``class_to_idx`` is missing for the
            explicit lists, or if the explicit lists have different lengths.
        """
        if max_frames < 1:
            raise ValueError(f"max_frames должен быть >= 1, получено: {max_frames}")

        self.root_dir = root_dir
        self.model_name = model_name
        self.max_frames = max_frames
        self.transform = transform

        if video_files is not None and video_labels is not None:
            if class_to_idx is None:
                raise ValueError("Если предоставлены video_files и video_labels, необходимо также предоставить class_to_idx.")
            if len(video_files) != len(video_labels):
                raise ValueError(
                    f"video_files и video_labels должны иметь одинаковую длину: "
                    f"{len(video_files)} != {len(video_labels)}"
                )
            self.video_files = video_files
            self.video_labels = video_labels
            self.class_to_idx = class_to_idx
            self.class_names = sorted(class_to_idx, key=class_to_idx.get)
        else:
            self._scan_root_dir(root_dir)

        self.feature_extractor = VideoMAEImageProcessor.from_pretrained(self.model_name)

    def _scan_root_dir(self, root_dir):
        """Populate class names, ``class_to_idx`` and the file/label lists from ``root_dir``."""
        self.class_names = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.class_names)}

        self.video_files = []
        self.video_labels = []
        for class_name in self.class_names:
            class_dir = os.path.join(root_dir, class_name)
            # Sorted so the sample order does not depend on filesystem listing order.
            for video_file in sorted(os.listdir(class_dir)):
                if video_file.lower().endswith(self.VIDEO_EXTENSIONS):
                    self.video_files.append(os.path.join(class_dir, video_file))
                    self.video_labels.append(self.class_to_idx[class_name])

    def _read_frames(self, video_path):
        """Decode up to ``self.max_frames`` leading frames of ``video_path`` as RGB arrays."""
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while len(frames) < self.max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the decoder even if conversion fails mid-clip.
            cap.release()
        return frames

    @staticmethod
    def _pad_frames(frames, max_frames, video_path):
        """Return ``frames`` padded to ``max_frames`` by repeating the last frame.

        :raises RuntimeError: If no frame could be decoded (missing/corrupt file).
        """
        if not frames:
            raise RuntimeError(f"Не удалось прочитать ни одного кадра из видео: {video_path}")
        return list(frames) + [frames[-1]] * (max_frames - len(frames))

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path = self.video_files[idx]
        label = self.video_labels[idx]

        frames = self._pad_frames(self._read_frames(video_path), self.max_frames, video_path)

        if self.transform:
            frames = [self.transform(frame) for frame in frames]

        # The processor expects array-like frames; transforms may return PIL images.
        frames = [np.array(frame) for frame in frames]

        inputs = self.feature_extractor(frames, return_tensors="pt")

        return inputs, label


In [None]:
# ---- Configuration ----
model_name = "facebook/timesformer-base-finetuned-k400"
batch_size = 4
epochs = 5
learning_rate = 5e-5
max_frames = 8        # frames per clip; presumably should match the checkpoint's config.num_frames — verify
num_workers = 0
save_path = "best_model.pt"
data_dir = r"D:\sports activities"
seed = 42             # seeds torch/numpy (head init, shuffling) and the train/val split

torch.manual_seed(seed)
np.random.seed(seed)

train_dir = data_dir

if not os.path.isdir(train_dir):
    raise ValueError(f"Тренировочная директория не найдена: {train_dir}")

# OpenCV frames (HxWx3 RGB arrays) -> PIL -> 224x224 before the image processor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
])

# ---- Discover classes first so the classification head matches the data ----
print("Сбор всех видеофайлов и меток из train_dir...")
full_dataset = VideoDataset(root_dir=train_dir, model_name=model_name, max_frames=max_frames, transform=transform)
all_video_files = full_dataset.video_files
all_video_labels = full_dataset.video_labels
class_to_idx = full_dataset.class_to_idx
class_names = full_dataset.class_names
num_classes = len(class_names)

print("Загрузка конфигурации и модели...")
config = AutoConfig.from_pretrained(model_name, num_labels=num_classes)
# Store human-readable label names in the saved config.
config.id2label = {idx: name for name, idx in class_to_idx.items()}
config.label2id = dict(class_to_idx)

model = AutoModelForVideoClassification.from_pretrained(
    model_name,
    config=config,
    ignore_mismatched_sizes=True,  # the checkpoint's classifier has a different size; re-initialize it
)

print("Разделение данных на тренировочную и валидационную выборки...")
train_files, val_files, train_labels, val_labels = train_test_split(
    all_video_files,
    all_video_labels,
    test_size=0.2,
    stratify=all_video_labels,
    random_state=seed,
)

print(f"Количество тренировочных видео: {len(train_files)}")
print(f"Количество валидационных видео: {len(val_files)}")
print(f"Число классов: {len(class_names)}")

print("Вычисление весов классов для FocalLoss...")
# Inverse-frequency weights: total / (num_classes * count); classes absent from train get 1.0.
label_counts = Counter(train_labels)
total_counts = sum(label_counts.values())
class_weights = [
    total_counts / (num_classes * label_counts[i]) if label_counts[i] > 0 else 1.0
    for i in range(num_classes)
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
print(f"Вес каждого класса: {class_weights}")

# ======= Создание Тренировочного и Валидационного Dataset =======
print("Создание тренировочного и валидационного Dataset...")
train_dataset = VideoDataset(
    root_dir=None,
    model_name=model_name,
    max_frames=max_frames,
    video_files=train_files,
    video_labels=train_labels,
    class_to_idx=class_to_idx,
    transform=transform
)
val_dataset = VideoDataset(
    root_dir=None,
    model_name=model_name,
    max_frames=max_frames,
    video_files=val_files,
    video_labels=val_labels,
    class_to_idx=class_to_idx,
    transform=transform
)

print("Создание DataLoader...")
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Используемое устройство: {device}")
model.to(device)

optimizer = AdamW(model.parameters(), lr=learning_rate)

num_training_steps = epochs * len(train_dataloader)

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

loss_fn = FocalLoss(alpha=class_weights, gamma=2, reduction='mean')


def run_epoch(model, dataloader, loss_fn, device, optimizer=None, lr_scheduler=None, desc=""):
    """Run one pass over ``dataloader``; trains when ``optimizer`` is given, otherwise evaluates.

    Batches that raise are reported and skipped (best effort). Loss and metrics
    are computed over the successfully processed batches only.

    :return: (mean loss, accuracy, weighted F1) over processed batches.
    :raises RuntimeError: If no batch could be processed.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    processed_batches = 0
    all_preds = []
    all_labels = []

    progress = tqdm(dataloader, desc=desc, leave=False)
    with torch.set_grad_enabled(training):
        for inputs, labels in progress:
            try:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                labels = labels.to(device)
                # Default collate stacks the per-sample processor output (presumably
                # (1, T, C, H, W)) into (B, 1, T, C, H, W); drop the singleton axis.
                inputs['pixel_values'] = inputs['pixel_values'].squeeze(1)

                if training:
                    # Clear grads up front so a batch that failed mid-backward
                    # cannot leak partial gradients into the next optimizer step.
                    optimizer.zero_grad(set_to_none=True)

                logits = model(**inputs).logits
                loss = loss_fn(logits, labels)

                if training:
                    loss.backward()
                    optimizer.step()
                    if lr_scheduler is not None:
                        lr_scheduler.step()
            except Exception as e:
                print(f"Ошибка при обработке батча ({desc}): {e}")
                continue

            total_loss += loss.item()
            processed_batches += 1
            preds = torch.argmax(logits, dim=-1)
            all_preds.extend(preds.detach().cpu().numpy())
            all_labels.extend(labels.detach().cpu().numpy())
            progress.set_postfix({"loss": loss.item()})

    if processed_batches == 0:
        raise RuntimeError(f"Ни один батч не был успешно обработан ({desc}).")

    avg_loss = total_loss / processed_batches
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    return avg_loss, accuracy, f1


# -inf guarantees a checkpoint is written after the first epoch even if F1 is 0.
best_val_f1 = float("-inf")

for epoch in range(epochs):
    print(f"\nЭпоха {epoch + 1}/{epochs}")
    print("-" * 20)

    avg_train_loss, train_accuracy, train_f1 = run_epoch(
        model, train_dataloader, loss_fn, device,
        optimizer=optimizer, lr_scheduler=lr_scheduler, desc="Обучение",
    )
    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.4f} | Train F1: {train_f1:.4f}")

    avg_val_loss, val_accuracy, val_f1 = run_epoch(
        model, val_dataloader, loss_fn, device, desc="Валидация",
    )
    print(f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f} | Val F1: {val_f1:.4f}")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), save_path)
        print(f"Сохранена лучшая модель с Val F1: {best_val_f1:.4f}")

print("\nОбучение завершено.")
print(f"Лучший Val F1: {best_val_f1:.4f}")

