In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np

In [2]:
from data.dataset import VideoDataSet
from sklearn.preprocessing import LabelEncoder
train_dataset = VideoDataSet('train')
le = LabelEncoder()
label = train_dataset.get_label()
print(label)
le.fit(label)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\evilr/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

['Punch', 'ShavingBeard', 'CricketShot', 'PlayingCello', 'TennisSwing']


In [3]:

def collate_fn(batch):
    frames = [b[0] for b in batch]
    masks = [torch.ones(b[0].shape[0]) for b in batch]
    frames = torch.nn.utils.rnn.pad_sequence(frames)
    masks = torch.nn.utils.rnn.pad_sequence(masks)

    frames = torch.transpose(frames, 0 , 1)
    masks = torch.transpose(masks, 0 , 1)
    item = le.transform([b[1] for b in batch])
    labels = torch.tensor(item)
    return (frames, masks), labels

data_loader = torch.utils.data.DataLoader(train_dataset,batch_size=32, collate_fn = collate_fn)

In [4]:

import Model
size = len(data_loader)
model = Model.CNNRNN(5)
model.to('cuda')
model.train()


loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for i in range(100):
    correct = 0
    for batch, (X, y) in enumerate(data_loader):
        frames, mask = X
        frames = frames.to('cuda')
        y = y.to('cuda').type(torch.int64)
        pred = model(frames)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    print(f'loss:{loss.item():>7f} acc:{correct / len(data_loader.dataset) :>2f}')



loss:1.083769 acc:0.750842
loss:0.923282 acc:0.915825
loss:0.961924 acc:0.959596
loss:1.016810 acc:0.858586
loss:0.905634 acc:0.885522
loss:0.911693 acc:0.941077
loss:0.905300 acc:0.964646
loss:0.905187 acc:0.986532
loss:0.905117 acc:0.986532
loss:0.905069 acc:0.986532
loss:0.905035 acc:0.986532
loss:0.905009 acc:0.986532
loss:0.904988 acc:0.986532
loss:0.904971 acc:0.986532
loss:0.904957 acc:0.986532
loss:0.904945 acc:0.986532
loss:0.904935 acc:0.986532
loss:0.904926 acc:0.986532
loss:0.904919 acc:0.986532
loss:0.904912 acc:0.986532
loss:0.904907 acc:0.986532
loss:0.904901 acc:0.986532
loss:0.904897 acc:0.986532
loss:0.904893 acc:0.986532
loss:0.904889 acc:0.986532
loss:0.904886 acc:0.986532
loss:1.071532 acc:0.981481
loss:1.346508 acc:0.572391
loss:1.344714 acc:0.592593
loss:1.200516 acc:0.686869
loss:1.019852 acc:0.774411
loss:0.905554 acc:0.900673
loss:0.905329 acc:0.983165
loss:0.905101 acc:0.983165
loss:0.960569 acc:0.981481
loss:0.960579 acc:0.924242
loss:0.960516 acc:0.909091
l