### Обучаем модель

In [1]:
# Импорт библиотек
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.nn.utils import clip_grad_norm_
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import corpus_bleu
from collections import Counter

import torch.optim
import torch.utils.data
from nltk.translate.bleu_score import corpus_bleu

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Проверка доступности GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Используемое устройство: {device}')

Используемое устройство: cuda


In [3]:
# Конфигурация
DATA_DIR = '/kaggle/input/dota-2-comments'
FRAMES_PER_VIDEO = 20

In [4]:
# Словарь для сортировки видео
order_dict = {
    '2396917974_5': 1,
    '2396917974_2': 2,
    '2406178117_1': 3,
    '2406178117_2': 4,
    '2406178117_3': 5,
    '2406178117_4': 6,
    '2406178117_5': 7,
    '2406178117_6': 8,
    '2406178117_7': 9,
    '2406178117_8': 10,
    '2406178117_9': 11,
    '2382667658_1': 12,
    '2382667658_2': 13,
    '2382667658_3': 14,
    '2382667658_4': 15,
    '2382667658_5': 16,
    '2382667658_6': 17,
    '2382667658_7': 18,
    '2382667658_8': 19,
    '2392353636_1': 20,
    '2392353636_2': 21
}

### Датасет

In [5]:
# Датасет
class Dota2Dataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.samples = []
        self.transform = transform
        self.word_freq = Counter()
        
        # Проходим по всем папкам с видео
        for video_id in sorted(os.listdir(data_dir), key=lambda video_id: order_dict[video_id]):
            id_dir = os.path.join(data_dir, video_id)
            for video_dir in sorted(os.listdir(id_dir), key=lambda video_dir: int(video_dir)):
                video_path = os.path.join(id_dir, video_dir)
                if not os.path.isdir(video_path):
                    continue
                    
                # Загружаем кадры
                frames_dir = os.path.join(video_path, "frames")
                frame_paths = [
                    os.path.join(frames_dir, f) 
                    for f in sorted(os.listdir(frames_dir), key=lambda x: int(x.split('.')[0])) if not int(f.split('.')[0]) % 2
                ]
                
                # Загружаем комментарии
                labels_path = os.path.join(video_path, "labels.txt")
                with open(labels_path, 'r', encoding='utf-8') as f:
                    captions = f.read().strip().split('\n')
            
                # Токенизируем комментарии и обновляем частоту слов
                tokenized_captions = []
                for cap in captions:
                    tokens = word_tokenize(cap.lower())
                    self.word_freq.update(tokens)
                    tokenized_captions.append(tokens)
                
                self.samples.append({
                    'video_id': video_dir,
                    'frame_paths': frame_paths,
                    'captions': captions,
                    'tokenized_captions': tokenized_captions
                })
        
        # Создаем словарь
        self.vocab = ['<pad>', '<start>', '<end>', '<unk>'] + \
                     [word for word, freq in self.word_freq.items() if freq >= 5]
        self.word2idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.idx2word = {idx: word for idx, word in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab)
        
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # Загружаем и преобразуем кадры
        frames = []
        for frame_path in sample['frame_paths']:
            img = Image.open(frame_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            frames.append(img)
        frames = torch.stack(frames)
        
        # Преобразуем комментарии в индексы
        caption_indices = []
        for tokens in sample['tokenized_captions']:
            indices = [self.word2idx.get(token, self.word2idx['<unk>']) for token in tokens]
            indices = [self.word2idx['<start>']] + indices + [self.word2idx['<end>']]
            caption_indices.append(torch.tensor(indices, dtype=torch.long))
        
        return {
            'video_id': sample['video_id'],
            'frames': frames,                    # [10, 3, 224, 224] <-> [N, C, H, W]
            'captions': sample['captions'],
            'caption_indices': caption_indices
        }

# Трансформации для изображений
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Загружаем датасет
dataset = Dota2Dataset(DATA_DIR, transform=transform)
vocab_size = dataset.vocab_size
print(f"Размер словаря: {vocab_size}")

Размер словаря: 2663


In [6]:
# Пример объекта выборки
example = dataset[1]

print('Video_id:\t', example['video_id'])
print('Frames.size:\t', example['frames'].size())
print('Captions:\t', example['captions'])
print('Indices:\t', example['caption_indices'])

Video_id:	 2
Frames.size:	 torch.Size([10, 3, 224, 224])
Captions:	 ['В памяти было шесть игр, пять выиграл, в составе было супер-потное против спирит, но не сло. Да, тут, конечно, нужно понимать, что демордж...']
Indices:	 [tensor([ 1, 28,  3, 29,  3, 30,  4, 31, 32,  4, 28,  3, 29,  3, 33, 34,  4, 35,
        20,  3,  8,  9,  4, 36,  4, 37,  4, 38, 39,  4, 40,  3, 41,  2])]


In [7]:
# Разделяем на тренировочную и валидационную выборки
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

### Извлечение признаков из кадров

In [8]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        effnet = models.efficientnet_v2_s(pretrained=True)
        modules = list(effnet.children())[:-2]  # Удаляем avgpool и fc слои
        self.effnet = nn.Sequential(*modules)
        
        for param in self.effnet.parameters():
            param.requires_grad = False
    
    def forward(self, images):
        # images: [batch_size, num_frames, num_channels, height, weight]
        batch_size, num_frames = images.size(0), images.size(1)
        images = images.view(-1, *images.size()[2:])
        out = self.effnet(images)
        return out.view(batch_size, num_frames, 1280, 7, 7).permute(0, 1, 3, 4, 2)

encoder = Encoder().to(device)

Downloading: "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_s-dd5fe13b.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 193MB/s]


In [9]:
encoder(dataset[1]['frames'].unsqueeze(0).to(device)).size()

torch.Size([1, 10, 7, 7, 1280])

### Модель генерации комментариев

In [None]:
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # linear layer to transform encoded image
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim, 1)  # linear layer to calculate values to be softmax-ed
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights

    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)

        return attention_weighted_encoding, alpha

In [None]:
class DecoderWithAttention(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=1280, dropout=0.5):
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim) 

        self.embedding = nn.Embedding(vocab_size, embed_dim) 
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) 
        self.init_h = nn.Linear(encoder_dim, decoder_dim) 
        self.init_c = nn.Linear(encoder_dim, decoder_dim) 
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  
        self.init_weights()  

    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        encoder_out = encoder_out.reshape(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        decode_lengths = (caption_lengths - 1).tolist()

        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)

        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t])) 
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))  
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

### Обучение модели

In [None]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(scores, targets, k):
    batch_size = targets.size(0)
    _, ind = scores.topk(k, 1, True, True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() * (100.0 / batch_size)

def clip_gradient(optimizer, grad_clip):
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

def collate_fn(batch):
    # Обработка кадров
    frames = [item['frames'] for item in batch]
    frames = torch.stack(frames)
    
    # Обработка комментариев
    captions = [cap for item in batch for cap in item['caption_indices']]
    caption_lengths = torch.tensor([len(cap) for cap in captions])
    
    # Дополняем комментарии до одинаковой длины
    captions_padded = pad_sequence(captions, batch_first=True, padding_value=0)
    
    return frames, captions_padded, caption_lengths.unsqueeze(1), [item['captions'] for item in batch]

def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, decoder_optimizer,
                    bleu4, is_best):
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'bleu-4': bleu4,
             'encoder': encoder,
             'decoder': decoder,
             'decoder_optimizer': decoder_optimizer}
    filename = 'checkpoint_' + data_name + '.pth.tar'
    torch.save(state, filename)
    # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
    if is_best:
        torch.save(state, 'BEST_' + filename)

In [None]:
def train(train_loader, encoder, decoder, criterion, decoder_optimizer, epoch):
    decoder.train() 
    losses = AverageMeter() 
    top5accs = AverageMeter() 

    for i, (imgs, caps, caplens, _) in enumerate(train_loader):
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)

        imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

        targets = caps_sorted[:, 1:]

        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        loss = criterion(scores, targets)
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        decoder_optimizer.zero_grad()
        loss.backward()

        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)

        decoder_optimizer.step()

        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          loss=losses,
                                                                          top5=top5accs))

In [None]:
def validate(val_loader, encoder, decoder, criterion):
    decoder.eval()  
    if encoder is not None:
        encoder.eval()

    losses = AverageMeter()
    top5accs = AverageMeter()

    references = list()  
    hypotheses = list()  

    with torch.no_grad():
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):

            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            targets = caps_sorted[:, 1:]

            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            loss = criterion(scores, targets)
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader),
                                                                                loss=losses, top5=top5accs))

    return 0

In [15]:
batch_size = 16

# Создаем DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=1
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=1
)

In [None]:
emb_dim = 512 
attention_dim = 512 
decoder_dim = 512 
dropout = 0.5

epochs = 50  
epochs_since_improvement = 0  
decoder_lr = 4e-4  
grad_clip = 5. 
alpha_c = 1. 
best_bleu4 = 0.  
print_freq = 100  

In [None]:
decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=dataset.vocab_size,
                               dropout=dropout)
decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                     lr=decoder_lr)
encoder = Encoder()

decoder = decoder.to(device)
encoder = encoder.to(device)

criterion = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):
    if epochs_since_improvement == 20:
        break
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)

    train(train_loader=train_loader,
          encoder=encoder,
          decoder=decoder,
          criterion=criterion,
          decoder_optimizer=decoder_optimizer,
          epoch=epoch)

    recent_bleu4 = validate(val_loader=val_loader,
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion)

save_checkpoint('dota', epoch, epochs_since_improvement, encoder, decoder,
                decoder_optimizer, recent_bleu4, True)

Epoch: [0][0/288]	Loss 8.7199 (8.7199)	Top-5 Accuracy 0.169 (0.169)
Epoch: [0][100/288]	Loss 5.9053 (6.2492)	Top-5 Accuracy 39.837 (37.217)
Epoch: [0][200/288]	Loss 5.7019 (6.0022)	Top-5 Accuracy 41.109 (39.896)
Validation: [0/32]	Loss 5.5458 (5.5458)	Top-5 Accuracy 43.617 (43.617)	
Epoch: [1][0/288]	Loss 5.4870 (5.4870)	Top-5 Accuracy 45.534 (45.534)
Epoch: [1][100/288]	Loss 5.6206 (5.5161)	Top-5 Accuracy 41.831 (44.663)
Epoch: [1][200/288]	Loss 5.5148 (5.5026)	Top-5 Accuracy 45.375 (44.775)
Validation: [0/32]	Loss 5.4323 (5.4323)	Top-5 Accuracy 45.035 (45.035)	
Epoch: [2][0/288]	Loss 5.2919 (5.2919)	Top-5 Accuracy 48.596 (48.596)
Epoch: [2][100/288]	Loss 5.2819 (5.3510)	Top-5 Accuracy 46.835 (46.482)
Epoch: [2][200/288]	Loss 5.4263 (5.3441)	Top-5 Accuracy 46.283 (46.527)
Validation: [0/32]	Loss 5.3460 (5.3460)	Top-5 Accuracy 45.567 (45.567)	
Epoch: [3][0/288]	Loss 5.3484 (5.3484)	Top-5 Accuracy 45.794 (45.794)
Epoch: [3][100/288]	Loss 5.4339 (5.2444)	Top-5 Accuracy 45.572 (47.370)
Ep

### Генерация комментариев