Скачиваем содержимое

In [None]:
import requests

yandex_url_test = "https://disk.yandex.ru/d/lgAOpG2O1VAs5w"
yandex_url_train = "https://disk.yandex.ru/d/bg0Gtr4bFnHJDQ"

for j in [yandex_url_test, yandex_url_train]:
    download_url = f"https://cloud-api.yandex.net/v1/disk/public/resources/download?public_key={j}"

    response = requests.get(download_url)
    if response.status_code == 200:
        direct_url = response.json()["href"]
        !wget -O dataset.zip "{direct_url}"
        !unzip dataset.zip
    else:
        print("Ошибка:", response.status_code)

--2025-06-17 14:21:39--  https://downloader.disk.yandex.ru/disk/dff87f6a62bb920049499ac51c6a5e682e9612a6dc6bd02ffe1da396354f67ad/6851b233/hHsHnWApe_ASqMVFpLHBF6_9rCBsLrgaKF5kp2qzYrG5Krh0oqrxe_ldyi5jK0idd-DFUHk89N_oBbBMveda1Q%3D%3D?uid=0&filename=data_test_short.zip&disposition=attachment&hash=FV0wyZboC5sTTzIXVZUjNUU6JbhMopjOAzsTOHYCp9K70U839gY2PzXs50Xtwkjqq/J6bpmRyOJonT3VoXnDag%3D%3D%3A&limit=0&content_type=application%2Fzip&owner_uid=176861886&fsize=12895585234&hid=bd7c4ac68533d5552604210731df433a&media_type=compressed&tknv=v3
Resolving downloader.disk.yandex.ru (downloader.disk.yandex.ru)... 77.88.21.127, 2a02:6b8::2:127
Connecting to downloader.disk.yandex.ru (downloader.disk.yandex.ru)|77.88.21.127|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://s18klg.storage.yandex.net/rdisk/dff87f6a62bb920049499ac51c6a5e682e9612a6dc6bd02ffe1da396354f67ad/6851b233/hHsHnWApe_ASqMVFpLHBF6_9rCBsLrgaKF5kp2qzYrG5Krh0oqrxe_ldyi5jK0idd-DFUHk89N_oBbBMveda1Q==?uid=0&

In [None]:
!rm dataset.zip

In [None]:
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

import cv2
import librosa
import soundfile as sf
import glob

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


Гиперпараметры

In [None]:
# Train and test are intentionally swapped because of the volume of video data involved:
# the TEST_* constants point at the `data_train_short` folder and the TRAIN_* constants
# point at `data_test_short`.
TEST_LABELS_JSON = '/content/data_train_short/labels.json'
TRAIN_LABELS_JSON = '/content/data_test_short/labels.json'
TEST_DIR = '/content/data_train_short'
TRAIN_DIR = '/content/data_test_short'

# Samples per batch and number of passes over the training set.
BATCH_SIZE = 4
NUM_EPOCHS = 20

# SEQ_LEN: number of time steps per sample (IntroDataset reads one frame and one
# audio value per step, stepping one second at a time).
SEQ_LEN = 30
# FPS: default for IntroDataset's `fps` argument (stored on the dataset; not read
# elsewhere in the code visible here).
FPS = 1
# Sample rate (Hz) requested from librosa when loading the audio track.
AUDIO_SAMPLE_RATE = 16000

# Initial learning rate for the optimizer.
LEARNING_RATE = 0.0005
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Show which compute device was selected for this run.
print(DEVICE)

cuda


Перевод времени в секунды

In [None]:
def to_seconds(timestr):
    """Convert a colon-separated time string into a whole number of seconds.

    Accepted layouts are ``"H:M:S"``, ``"M:S"`` and a single integer field.
    For any other number of fields the first field is returned unchanged.

    Args:
        timestr: Time string whose colon-separated fields are integers.

    Returns:
        The time as an integer number of seconds.

    Raises:
        ValueError: If a field cannot be parsed as an integer.
    """
    fields = [int(token) for token in timestr.split(':')]
    if len(fields) not in (2, 3):
        return fields[0]

    # Fold left to right: every earlier field is worth 60x the next one.
    total = 0
    for value in fields:
        total = total * 60 + value
    return total

def seconds_to_time(secs):
    """Format a number of seconds as ``HH:MM:SS`` with each field zero-padded to width 2.

    Args:
        secs: Number of seconds (the callers visible here pass integers).

    Returns:
        The formatted time string; hours are not wrapped at 24.
    """
    hours = secs // 3600
    minutes = (secs % 3600) // 60
    seconds = secs % 60
    return "{:02}:{:02}:{:02}".format(hours, minutes, seconds)

In [None]:
class IntroDataset(Dataset):
    """Fixed-length video/audio windows around labelled intro segments.

    Each item corresponds to one video folder under ``root_dir`` and yields
    ``seq_len`` one-second steps: one transformed frame and one mean absolute
    audio amplitude per step, plus the intro start/end expressed as fractions
    of the window.

    Args:
        root_dir: Directory containing one sub-folder per video id.
        labels_json: Path to a JSON file mapping video id -> {"start", "end"}
            time strings (parsed with ``to_seconds``).
        seq_len: Number of one-second steps per sample.
        fps: Stored on the instance; not used by the sampling code below.
        train: Enables augmentations and random window placement.
    """

    def __init__(self, root_dir, labels_json, seq_len=SEQ_LEN, fps=FPS, train=False):
        # Longest intro end time (seconds) that is kept; also the amount of audio
        # decoded per video.
        self.MAX_DURATION = 40
        self.root_dir = root_dir
        self.seq_len = seq_len
        self.fps = fps
        self.train = train

        with open(labels_json, 'r', encoding='utf-8') as f:
            all_labels = json.load(f)

        self.labels = {}
        for video_id, label in all_labels.items():
            try:
                end_gt = to_seconds(label['end'])
                start_gt = to_seconds(label['start'])
            except (KeyError, TypeError, ValueError, AttributeError):
                # Missing or malformed timestamps: skip this video (best effort),
                # without swallowing unrelated errors such as KeyboardInterrupt.
                continue
            if 0 < end_gt <= self.MAX_DURATION and start_gt < end_gt:
                self.labels[video_id] = {'start': start_gt, 'end': end_gt}

        self.video_ids = list(self.labels.keys())

        # Augmentations are applied only in training mode. The pipeline is built
        # without identity lambdas so the dataset stays picklable.
        steps = [T.ToPILImage(), T.Resize((224, 224))]
        if train:
            steps += [
                T.RandomHorizontalFlip(),
                T.ColorJitter(brightness=0.2, contrast=0.2),
            ]
        steps += [
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = T.Compose(steps)

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, idx):
        """Return ``(video_feats, audio_feats, target, vid)`` for one video.

        ``video_feats`` is ``(seq_len, 3, 224, 224)``, ``audio_feats`` is
        ``(seq_len, 1)`` and ``target`` is a float32 tensor
        ``[start_rel, end_rel]`` clamped to ``[0, 1]``, both measured from the
        window start and divided by ``seq_len``.
        """
        vid = self.video_ids[idx]
        info = self.labels[vid]
        start_gt = info['start']
        end_gt = info['end']

        # Window start in seconds. Training: random offset before the intro.
        # NOTE(review): the non-training branch also places the window using the
        # ground-truth start and a random offset, so evaluation is label-dependent
        # and not deterministic — confirm this is intended.
        if self.train:
            max_start = max(0, min(start_gt, self.MAX_DURATION - self.seq_len - 5))
            start_s = np.random.randint(0, max_start + 1) if max_start > 0 else 0
        else:
            start_s = max(0, start_gt - np.random.randint(0, 5))

        video_folder = os.path.join(self.root_dir, vid)
        mp4_files = glob.glob(os.path.join(video_folder, "*.mp4"))
        if not mp4_files:
            raise FileNotFoundError(f"No .mp4 file found in {video_folder}")
        mp4_path = mp4_files[0]

        video_feats = self._load_frames(mp4_path, start_s)
        audio_feats = self._load_audio(mp4_path, start_s)

        # Normalised labels relative to the window; explicit float dtype so the
        # tensor never silently becomes int64 when both values clamp to 0 and 1.
        start_rel = max(0.0, (start_gt - start_s) / self.seq_len)
        end_rel = min(1.0, (end_gt - start_s) / self.seq_len)
        target = torch.tensor([start_rel, end_rel], dtype=torch.float32)

        return video_feats, audio_feats, target, vid

    def _load_frames(self, mp4_path, start_s):
        """Read one transformed frame per second starting at ``start_s``; missing frames become zeros."""
        cap = cv2.VideoCapture(mp4_path)
        try:
            fps_video = cap.get(cv2.CAP_PROP_FPS) or 25  # fall back when FPS is reported as 0
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            frames = []
            for i in range(self.seq_len):
                frame_pos = int((start_s + i) * fps_video)
                tensor = None
                if frame_pos < total_frames:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
                    ret, frame = cap.read()
                    if ret:
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        tensor = self.transform(frame)
                if tensor is None:
                    # Past the end of the video or unreadable frame.
                    tensor = torch.zeros(3, 224, 224)
                frames.append(tensor)
        finally:
            # Always release the capture, even if decoding or the transform fails.
            cap.release()

        return torch.stack(frames)

    def _load_audio(self, mp4_path, start_s):
        """Return a ``(seq_len, 1)`` tensor of per-second mean absolute amplitudes (zeros on failure)."""
        try:
            y, _ = librosa.load(mp4_path, sr=AUDIO_SAMPLE_RATE, duration=self.MAX_DURATION)
            audio_feats = []
            for i in range(self.seq_len):
                start_sample = int((start_s + i) * AUDIO_SAMPLE_RATE)
                end_sample = int((start_s + i + 1) * AUDIO_SAMPLE_RATE)
                segment = y[start_sample:end_sample]
                # Append exactly one value per step; silent or out-of-range
                # segments contribute 0.0 so audio stays aligned with the frames.
                if segment.any():
                    audio_feats.append(float(np.mean(np.abs(segment))))
                else:
                    audio_feats.append(0.0)
        except Exception:
            # Audio is best-effort: a video without a decodable track gets zeros.
            audio_feats = [0.0] * self.seq_len

        return torch.tensor(audio_feats, dtype=torch.float32).unsqueeze(1)


Модель сети

In [None]:
class IntroDetector(nn.Module):
    """CNN + bidirectional LSTM that localises an intro inside a frame/audio window.

    A ResNet-18 backbone (classification head replaced by identity) embeds every
    frame; the embedding is concatenated with the per-step audio feature and fed
    to a 2-layer bidirectional LSTM. Two heads read the sequence summary:

    * ``regressor``  -> two values in [0, 1] (Sigmoid), trained against the
      dataset's ``[start_rel, end_rel]`` target;
    * ``classifier`` -> one value in [0, 1] (Sigmoid), i.e. a probability, so it
      must be paired with a probability-based loss such as ``nn.BCELoss``.

    Args:
        video_dim: Width of the per-frame CNN feature (must match the backbone
            output; 512 for ResNet-18).
        audio_dim: Width of the per-step audio feature.
        hidden_size: LSTM hidden size per direction.
    """

    def __init__(self, video_dim=512, audio_dim=1, hidden_size=128):
        super().__init__()
        # NOTE: `pretrained=True` is the legacy torchvision argument (newer
        # releases prefer `weights=`); kept so the block works with the installed version.
        self.cnn = models.resnet18(pretrained=True)
        self.cnn.fc = nn.Identity()

        self.lstm = nn.LSTM(
            input_size=video_dim + audio_dim,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )

        self.regressor = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2),
            nn.Sigmoid()  # squash outputs into [0, 1]
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 1),
            nn.Sigmoid()
        )

    def forward(self, video, audio):
        """Run the model.

        Args:
            video: Tensor of shape (B, seq, C, H, W).
            audio: Tensor of shape (B, seq, audio_dim).

        Returns:
            ``(time_output, presence_output)`` with shapes (B, 2) and (B, 1).
        """
        B, seq, C, H, W = video.size()
        # reshape (not view) so non-contiguous inputs are accepted as well.
        frames = video.reshape(B * seq, C, H, W)
        video_feats = self.cnn(frames).reshape(B, seq, -1)

        # Per-step feature: frame embedding + audio feature.
        x = torch.cat([video_feats, audio], dim=2)

        lstm_out, _ = self.lstm(x)

        # Sequence summary for a bidirectional LSTM: the forward direction has
        # seen the whole sequence at the LAST step, the backward direction at the
        # FIRST step. Taking lstm_out[:, -1, :] would use a backward state that
        # has processed only a single step.
        hidden = self.lstm.hidden_size
        summary = torch.cat(
            [lstm_out[:, -1, :hidden], lstm_out[:, 0, hidden:]],
            dim=1,
        )

        time_output = self.regressor(summary)       # relative start / end
        presence_output = self.classifier(summary)  # probability that an intro is present

        return time_output, presence_output

In [None]:
def train_epoch(model, loader, optimizer, criterion_reg, criterion_cls):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The loss is ``0.7 * regression + 0.3 * classification``; the classification
    target is 1 for samples whose mean target value is positive, else 0.

    Args:
        model: Network returning ``(time_pred, presence_pred)``.
        loader: Iterable of ``(video, audio, target, ids)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion_reg: Loss for the time regression head.
        criterion_cls: Loss for the presence head.

    Returns:
        Sum of per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0

    progress = tqdm(loader, desc="Training")
    for frames, sound, target, _ in progress:
        frames = frames.to(DEVICE)
        sound = sound.to(DEVICE)
        target = target.float().to(DEVICE)

        optimizer.zero_grad()
        time_pred, presence_pred = model(frames, sound)

        # Multi-task objective: weighted sum of both heads.
        presence_target = (target.mean(dim=1) > 0).float().view(-1, 1)
        reg_term = criterion_reg(time_pred, target)
        cls_term = criterion_cls(presence_pred, presence_target)
        batch_loss = 0.7 * reg_term + 0.3 * cls_term

        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

    return running_loss / len(loader)

In [None]:
def eval_epoch(model, loader, criterion_reg, criterion_cls):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network returning ``(time_pred, presence_pred)``.
        loader: Iterable of ``(video, audio, target, ids)`` batches.
        criterion_reg: Loss for the time regression head.
        criterion_cls: Loss for the presence head.

    Returns:
        ``(mean_batch_loss, avg_reg_error, cls_accuracy)`` where
        ``avg_reg_error`` is a CPU tensor of shape (2,) holding the mean
        absolute error of each regression output over all samples, and
        ``cls_accuracy`` is the fraction of correct presence decisions at
        threshold 0.5.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    abs_error_sum = torch.zeros(2)
    cls_correct = 0
    total_samples = 0

    with torch.no_grad():
        for v, a, y, _ in tqdm(loader, desc="Evaluating"):
            v, a, y = v.to(DEVICE), a.to(DEVICE), y.float().to(DEVICE)

            time_pred, presence_pred = model(v, a)

            # Same presence target and loss weighting as in training.
            cls_target = (y.mean(dim=1) > 0).float().view(-1, 1)
            loss_reg = criterion_reg(time_pred, y)
            loss_cls = criterion_cls(presence_pred, cls_target)
            loss = 0.7 * loss_reg + 0.3 * loss_cls
            total_loss += loss.item()

            # Accumulate per-sample sums so a smaller final batch is not
            # over-weighted (a mean of batch means would be biased).
            abs_error_sum += torch.abs(time_pred - y).sum(dim=0).cpu()

            cls_pred = (presence_pred > 0.5).float()
            cls_correct += (cls_pred == cls_target).sum().item()
            total_samples += v.size(0)

    if total_samples == 0:
        raise ValueError("eval_epoch received an empty loader")

    avg_reg_error = abs_error_sum / total_samples
    cls_accuracy = cls_correct / total_samples

    return total_loss / len(loader), avg_reg_error, cls_accuracy

In [None]:
def _pad_and_stack(sequences):
    """Zero-pad tensors along dim 0 to the longest length and stack them into one batch tensor."""
    longest = max(seq.shape[0] for seq in sequences)
    padded = []
    for seq in sequences:
        missing = longest - seq.shape[0]
        if missing > 0:
            filler = torch.zeros(missing, *seq.shape[1:], dtype=seq.dtype)
            seq = torch.cat([seq, filler], dim=0)
        padded.append(seq)
    return torch.stack(padded)


def custom_collate(batch):
    """Collate dataset items into batch tensors, zero-padding along the time axis.

    Args:
        batch: Sequence of ``(video_feats, audio_feats, tensor([start_rel, end_rel]), vid)``
            tuples.

    Returns:
        ``(videos, audios, rels, vids)``: videos ``(B, T_max, ...)``, audios
        ``(B, T_max, ...)``, rels ``(B, 2)`` and the list of video ids.
    """
    clips, sounds, targets, ids = zip(*batch)
    return (
        _pad_and_stack(clips),
        _pad_and_stack(sounds),
        torch.stack(targets),
        list(ids),
    )

In [None]:
# Datasets: augmentation and random windowing only for the training split.
train_dataset = IntroDataset(root_dir=TRAIN_DIR, labels_json=TRAIN_LABELS_JSON, train=True)
test_dataset = IntroDataset(root_dir=TEST_DIR, labels_json=TEST_LABELS_JSON, train=False)

# Settings shared by both loaders; data is loaded in the main process.
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=custom_collate, num_workers=0)

# Training batches are shuffled and use pinned memory; evaluation keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, pin_memory=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Model on the selected device, AdamW with light weight decay, and a scheduler
# that multiplies the learning rate by 0.5 when the monitored loss plateaus (patience 3).
model = IntroDetector()
model = model.to(DEVICE)
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 195MB/s]


In [None]:
# Regression head: Huber loss on the relative start/end times.
criterion_reg = nn.HuberLoss()
# Presence head: the classifier already ends in nn.Sigmoid and therefore emits
# probabilities, so the matching loss is BCELoss. BCEWithLogitsLoss would apply a
# second sigmoid on top of the first.
criterion_cls = nn.BCELoss()

In [None]:
def predict_time(model, video_tensor, audio_tensor):
    """Predict the relative intro boundaries for a single, unbatched sample.

    Args:
        model: Network returning ``(time_pred, presence)`` for batched input.
        video_tensor: Frames of one sample, without the batch dimension.
        audio_tensor: Audio features of one sample, without the batch dimension.

    Returns:
        A NumPy array with the model's two time outputs when the presence
        output exceeds 0.5, otherwise ``None``.
    """
    model.eval()

    video = video_tensor.unsqueeze(0)
    audio = audio_tensor.unsqueeze(0)

    # Run on whatever device the model lives on; dataset items are CPU tensors
    # and would otherwise clash with a model that was moved to the GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        video = video.to(first_param.device)
        audio = audio.to(first_param.device)

    with torch.no_grad():
        time_pred, presence = model(video, audio)
        if presence.item() > 0.5:
            return time_pred.squeeze().cpu().numpy()
        return None

In [None]:
# Train for up to NUM_EPOCHS, checkpointing the weights with the lowest validation loss.
best_loss = float('inf')
for epoch in range(1, NUM_EPOCHS + 1):
    # Training (train_epoch puts the model into train mode itself).
    train_loss = train_epoch(model, train_loader, optimizer, criterion_reg, criterion_cls)

    # Validation
    val_loss, val_reg_err, val_cls_acc = eval_epoch(model, test_loader, criterion_reg, criterion_cls)

    # Logging. The regression target is [start_rel, end_rel], so the second
    # error component is the END-time error, not a duration error.
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val Reg Errors - Start: {val_reg_err[0]:.3f}, End: {val_reg_err[1]:.3f}")
    print(f"Val Cls Acc: {val_cls_acc:.4f}")

    # Reduce the learning rate when the validation loss stops improving.
    scheduler.step(val_loss)

    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print("Saved best model!")

    # Early stopping: after epoch 10, stop once validation loss is more than 10% above the best.
    if epoch > 10 and val_loss > best_loss * 1.1:
        print("Early stopping triggered")
        break

# Restore the best checkpoint; map_location makes this work even when the
# checkpoint was written on a different device than the current one.
model.load_state_dict(torch.load('best_model.pth', map_location=DEVICE))

Training: 100%|██████████| 9/9 [02:32<00:00, 16.93s/it]
Evaluating: 100%|██████████| 6/6 [01:01<00:00, 10.17s/it]



Epoch 1/20
Train Loss: 0.1332
Val Loss: 0.1139
Val Reg Errors - Start: 0.137, Duration: 0.205
Val Cls Acc: 1.0000
Saved best model!


Training: 100%|██████████| 9/9 [02:14<00:00, 14.95s/it]
Evaluating: 100%|██████████| 6/6 [01:00<00:00, 10.16s/it]



Epoch 2/20
Train Loss: 0.1016
Val Loss: 0.1029
Val Reg Errors - Start: 0.031, Duration: 0.191
Val Cls Acc: 1.0000
Saved best model!


Training: 100%|██████████| 9/9 [02:14<00:00, 14.89s/it]
Evaluating: 100%|██████████| 6/6 [01:00<00:00, 10.05s/it]



Epoch 3/20
Train Loss: 0.0997
Val Loss: 0.1027
Val Reg Errors - Start: 0.060, Duration: 0.191
Val Cls Acc: 1.0000
Saved best model!


Training: 100%|██████████| 9/9 [02:13<00:00, 14.81s/it]
Evaluating: 100%|██████████| 6/6 [01:02<00:00, 10.36s/it]



Epoch 4/20
Train Loss: 0.0987
Val Loss: 0.1009
Val Reg Errors - Start: 0.050, Duration: 0.167
Val Cls Acc: 1.0000
Saved best model!


Training: 100%|██████████| 9/9 [02:14<00:00, 14.93s/it]
Evaluating: 100%|██████████| 6/6 [01:00<00:00, 10.12s/it]



Epoch 5/20
Train Loss: 0.0984
Val Loss: 0.1061
Val Reg Errors - Start: 0.057, Duration: 0.231
Val Cls Acc: 1.0000


Training: 100%|██████████| 9/9 [02:13<00:00, 14.89s/it]
Evaluating: 100%|██████████| 6/6 [01:01<00:00, 10.20s/it]



Epoch 6/20
Train Loss: 0.0976
Val Loss: 0.1025
Val Reg Errors - Start: 0.061, Duration: 0.192
Val Cls Acc: 1.0000


Training: 100%|██████████| 9/9 [02:10<00:00, 14.46s/it]
Evaluating: 100%|██████████| 6/6 [01:00<00:00, 10.16s/it]



Epoch 7/20
Train Loss: 0.0978
Val Loss: 0.1034
Val Reg Errors - Start: 0.060, Duration: 0.198
Val Cls Acc: 1.0000


Training: 100%|██████████| 9/9 [02:12<00:00, 14.73s/it]
Evaluating: 100%|██████████| 6/6 [01:01<00:00, 10.19s/it]



Epoch 8/20
Train Loss: 0.0974
Val Loss: 0.1043
Val Reg Errors - Start: 0.061, Duration: 0.213
Val Cls Acc: 1.0000


Training: 100%|██████████| 9/9 [02:14<00:00, 14.92s/it]
Evaluating: 100%|██████████| 6/6 [01:01<00:00, 10.17s/it]



Epoch 9/20
Train Loss: 0.0974
Val Loss: 0.1041
Val Reg Errors - Start: 0.063, Duration: 0.206
Val Cls Acc: 1.0000


Training: 100%|██████████| 9/9 [02:12<00:00, 14.74s/it]
Evaluating: 100%|██████████| 6/6 [01:00<00:00, 10.01s/it]



Epoch 10/20
Train Loss: 0.0985
Val Loss: 0.1002
Val Reg Errors - Start: 0.040, Duration: 0.170
Val Cls Acc: 1.0000
Saved best model!


Training:  11%|█         | 1/9 [00:13<01:48, 13.53s/it]


KeyboardInterrupt: 

In [None]:
# Reload the best checkpoint (needed here because training was interrupted before
# the loop's own reload ran). map_location remaps the stored tensors onto DEVICE.
model.load_state_dict(torch.load('best_model.pth', map_location=DEVICE))

<All keys matched successfully>

In [None]:
import pandas as pd
from datetime import timedelta

def save_predictions_to_csv(model, dataset, output_file='predictions.csv'):
    """Run the model over ``dataset`` and write one prediction row per video to CSV.

    The model's two regression outputs are the relative START and END of the
    intro inside the sampled window (this is what the dataset's target holds),
    so both are scaled by ``dataset.seq_len`` independently. The resulting
    times are offsets from the start of the window returned by the dataset,
    not absolute positions in the video.

    Args:
        model: Network returning ``(time_pred, presence_prob)``.
        dataset: Dataset yielding ``(video, audio, target, video_id)`` and
            exposing ``seq_len``.
        output_file: Destination CSV path.

    Returns:
        The DataFrame that was written, with columns ``video_id``,
        ``start_time``, ``end_time`` (``HH:MM:SS``) and ``presence_prob``.
    """
    model.eval()
    results = []
    window_len = dataset.seq_len

    with torch.no_grad():
        for idx in tqdm(range(len(dataset)), desc="Generating predictions"):
            video, audio, _, video_id = dataset[idx]

            time_pred, presence_prob = model(
                video.unsqueeze(0).to(DEVICE),
                audio.unsqueeze(0).to(DEVICE)
            )
            presence = presence_prob.item()

            if presence > 0.5:  # presence threshold
                # Outputs are (start, end) fractions of the window, not (start, duration).
                start_rel, end_rel = time_pred.squeeze().cpu().numpy()
                start_sec = start_rel * window_len
                end_sec = end_rel * window_len

                # Same zero-padded HH:MM:SS format as the "no intro" branch below.
                start_time = seconds_to_time(int(start_sec))
                end_time = seconds_to_time(int(end_sec))
            else:
                start_time = "00:00:00"
                end_time = "00:00:00"

            results.append({
                'video_id': video_id,
                'start_time': start_time,
                'end_time': end_time,
                'presence_prob': presence
            })

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)
    print(f"Predictions saved to {output_file}")
    return df

# Rebuild the evaluation dataset and export its predictions to CSV.
test_dataset = IntroDataset(root_dir=TEST_DIR, labels_json=TEST_LABELS_JSON, train=False)
predictions_df = save_predictions_to_csv(model, test_dataset, output_file='test_predictions.csv')

def save_with_ground_truth(model, dataset, output_file='predictions_with_gt.csv'):
    """Write predictions next to the ground truth and a per-video error to CSV.

    Both the dataset target and the model's regression output hold the
    relative START and END of the intro inside the sampled window, so each
    value is scaled by ``dataset.seq_len`` on its own (the end is not
    ``start + second_value``). All times are offsets from the window start.

    Args:
        model: Network returning ``(time_pred, presence_prob)``.
        dataset: Dataset yielding ``(video, audio, target, video_id)`` and
            exposing ``seq_len``.
        output_file: Destination CSV path.

    Returns:
        The DataFrame that was written. ``error_sec`` is the sum of the
        absolute start and end errors in seconds.
    """
    model.eval()
    results = []
    window_len = dataset.seq_len

    with torch.no_grad():
        for idx in tqdm(range(len(dataset)), desc="Generating predictions"):
            video, audio, target, video_id = dataset[idx]

            time_pred, presence_prob = model(
                video.unsqueeze(0).to(DEVICE),
                audio.unsqueeze(0).to(DEVICE)
            )
            presence = presence_prob.item()

            # Target layout is [start_rel, end_rel].
            gt_start = target[0].item() * window_len
            gt_end = target[1].item() * window_len

            if presence > 0.5:
                start_rel, end_rel = time_pred.squeeze().cpu().numpy()
                pred_start = float(start_rel) * window_len
                pred_end = float(end_rel) * window_len
            else:
                # No intro predicted: report an empty interval at the window start.
                pred_start = pred_end = 0.0

            results.append({
                'video_id': video_id,
                'pred_start': str(timedelta(seconds=int(pred_start))),
                'pred_end': str(timedelta(seconds=int(pred_end))),
                'gt_start': str(timedelta(seconds=int(gt_start))),
                'gt_end': str(timedelta(seconds=int(gt_end))),
                'presence_prob': presence,
                'error_sec': abs(pred_start - gt_start) + abs(pred_end - gt_end)
            })

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)
    print(f"Predictions with GT saved to {output_file}")
    return df

# Export predictions alongside the ground truth (default output file name).
full_results_df = save_with_ground_truth(model=model, dataset=test_dataset)

Generating predictions: 100%|██████████| 21/21 [00:59<00:00,  2.85s/it]


Predictions saved to test_predictions.csv


Generating predictions: 100%|██████████| 21/21 [00:59<00:00,  2.84s/it]

Predictions with GT saved to predictions_with_gt.csv



