In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np

In [2]:
from sklearn.preprocessing import LabelEncoder
from data.dataset import VideoDataSet

# Build the training split and pull out the labels it exposes.
train_dataset = VideoDataSet('train')
label = train_dataset.get_label()
print(label)

# Fit the encoder that maps label values to integer class indices;
# `le` is reused later when batches are collated.
le = LabelEncoder()
le.fit(label)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\evilr/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

['Punch', 'ShavingBeard', 'CricketShot', 'PlayingCello', 'TennisSwing']


In [3]:

def collate_fn(batch):
    """Collate variable-length ``(frames, label)`` samples into one padded batch.

    Args:
        batch: sequence of samples. Item 0 of each sample is a tensor whose
            first dimension is the sequence length; item 1 is a label that the
            module-level encoder ``le`` can transform.

    Returns:
        ``((frames, masks), labels)`` where
        * ``frames`` is zero-padded to the longest sequence, laid out
          batch-first: ``(batch, max_len, ...)``;
        * ``masks`` is ``(batch, max_len)`` with 1.0 on real steps and 0.0 on
          padded steps;
        * ``labels`` is a tensor of the encoded class indices.
    """
    sequences = [sample[0] for sample in batch]
    # One vector of ones per sample; after padding the zeros mark the padding.
    step_masks = [torch.ones(seq.shape[0]) for seq in sequences]

    # batch_first=True pads directly into (batch, time, ...) layout, so no
    # transpose is needed and the resulting tensors are contiguous.
    frames = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    masks = torch.nn.utils.rnn.pad_sequence(step_masks, batch_first=True)

    labels = torch.tensor(le.transform([sample[1] for sample in batch]))
    return (frames, masks), labels

# Training loader: samples are padded per batch by collate_fn.
# shuffle=True reshuffles the data every epoch so the optimizer does not see
# the samples in the same (possibly class-grouped) dataset order each time.
data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

In [59]:

import Model
from tqdm import tqdm

# Training hyperparameters.
NUM_CLASSES = 5        # the label list printed earlier in this notebook has five classes
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# Use the GPU when one is present, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

size = len(data_loader)  # number of batches per epoch
model = Model.CNNRNN(NUM_CLASSES)
model.to(device)
model.train()

# NOTE(review): the logged loss levels off near 0.905, which equals
# -log(e / (e + 4)) -- the value CrossEntropyLoss gives for a one-hot
# *probability* vector over 5 classes. That suggests Model.CNNRNN may already
# apply softmax; CrossEntropyLoss expects raw logits. Confirm in Model.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    correct = 0  # correctly classified samples so far in this epoch
    seen = 0     # samples processed so far in this epoch

    # The context manager closes the bar at the end of every epoch.
    with tqdm(data_loader, total=size) as pbar:
        for X, y in pbar:
            # NOTE(review): `mask` is unpacked but never given to the model,
            # so padded frames are processed like real ones -- confirm whether
            # Model.CNNRNN is supposed to receive it.
            frames, mask = X
            frames = frames.to(device)
            y = y.to(device).type(torch.int64)

            pred = model(frames)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            correct += (pred.argmax(1) == y).sum().item()
            # Count the real batch size: the last batch can be smaller than
            # data_loader.batch_size, and dividing by batch_size * n_batches
            # under-reports the accuracy.
            seen += y.size(0)
            pbar.set_description(f'loss:{loss.item():<.3f} acc:{correct / seen * 100 :.3f}%')


loss:1.005 acc:73.355%: 100%|██████████| 19/19 [00:00<00:00, 50.56it/s]
loss:0.963 acc:88.980%: 100%|██████████| 19/19 [00:00<00:00, 78.92it/s]
loss:0.906 acc:94.243%: 100%|██████████| 19/19 [00:00<00:00, 75.40it/s]
loss:0.906 acc:94.901%: 100%|██████████| 19/19 [00:00<00:00, 73.64it/s]
loss:0.906 acc:97.204%: 100%|██████████| 19/19 [00:00<00:00, 76.92it/s]
loss:0.905 acc:97.204%: 100%|██████████| 19/19 [00:00<00:00, 78.51it/s]
loss:0.905 acc:97.204%: 100%|██████████| 19/19 [00:00<00:00, 78.51it/s] 
loss:0.905 acc:97.204%: 100%|██████████| 19/19 [00:00<00:00, 80.67it/s] 
loss:0.905 acc:97.697%: 100%|██████████| 19/19 [00:00<00:00, 76.30it/s] 
loss:0.905 acc:97.697%: 100%|██████████| 19/19 [00:00<00:00, 71.70it/s] 
loss:0.905 acc:97.697%: 100%|██████████| 19/19 [00:00<00:00, 77.87it/s] 
loss:0.905 acc:97.697%: 100%|██████████| 19/19 [00:00<00:00, 73.64it/s] 
loss:0.905 acc:97.697%: 100%|██████████| 19/19 [00:00<00:00, 73.64it/s] 
loss:0.905 acc:97.697%: 100%|██████████| 19/19 [00:00<00: