In [1]:
#from models import *
import torch
import pytorchvideo
from pytorchvideo.data import labeled_video_dataset
from pytorchvideo.data.clip_sampling import make_clip_sampler
from pytorchvideo.transforms import ApplyTransformToKey, RandomResizedCrop, Normalize, RandomShortSideScale, RemoveKey, ShortSideScale, UniformTemporalSubsample
import torchvision.transforms as transforms 
from torch.utils.data import DataLoader, ConcatDataset
from tqdm import tqdm
import pytorchvideo.models as models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.models.video as v_model
import numpy as np
import matplotlib.pyplot as plt
import os
#from models import *


device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
leg = 'R'

In [2]:
train_transform = transforms.Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform = transforms.Compose(
                  [
                    #transforms.RandomRotation(20),
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                    transforms.Resize((192,108)),
                  ]
                ),
              ),
            ]
        )
add_transform = transforms.Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform = transforms.Compose(
                  [
                    
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                    transforms.Resize((192,108)),
                    transforms.RandomRotation(20),
                  ]
                ),
              ),
            ]
        )

batch_size = 3
num_sec = 5

In [3]:
train_sampler = make_clip_sampler('uniform', num_sec)
#valid_sampler = make_clip_sampler('uniform', num_sec)
dataset = labeled_video_dataset(f'video_dataset_{leg}', train_sampler, transform=train_transform)
add_dataset = labeled_video_dataset(f'video_dataset_{leg}', train_sampler, transform=add_transform)

#train_set = ConcatDataset([dataset, add_dataset])

valid_set = labeled_video_dataset(f'valid_set_{leg}', train_sampler, transform=train_transform)
#train_set, valid_set = random_split(dataset, [int(len(dataset) * 0.8), len(dataset) - int(len(dataset) * 0.8)])
train_loader = DataLoader(dataset, batch_size = 3)
add_loader = DataLoader(add_dataset, batch_size = 3)

valid_loader = DataLoader(valid_set, batch_size = 3)






In [4]:
model = v_model.r3d_18(weights = 'R3D_18_Weights.DEFAULT')
model = model.to(device)

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01,momentum=0.9)
sched = lr_scheduler.MultiStepLR(optimizer, milestones=[20,40], gamma=0.1)

#batch_loss = 0

if not os.path.isdir(f'hint_model_{leg}'):
    os.mkdir(f'hint_model_{leg}')

best_loss = np.Inf
epoch_mem = 0

epochs = 50
train_mem, valid_mem = [], []

for epoch in range(epochs):

    train_loss = 0
    accuracy = 0
    total = 0
    
    # start training
    model.train()
    
    # read training data
    for batch in tqdm(train_loader):
        imgs, labels = batch['video'], batch['label']
        imgs, labels = imgs.to(device), labels.to(device)

        # feed the group of images into model and get probability
        outputs = model(imgs)
        
        # make a prediction
        _, preds = torch.max(outputs, 1)
        
        # calculate loss by CrossEntropy
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        
        # update model weight
        loss.backward()
        optimizer.step()
        sched.step()

        # add training loss
        train_loss += loss.item()
        total += imgs.shape[0]
        
        # add training accuracy
        accuracy += torch.sum(preds == labels.data)
        
        # release memory space to prevent out of memory
        torch.cuda.empty_cache()
     
    for batch in tqdm(add_loader):
        imgs, labels = batch['video'], batch['label']
        imgs, labels = imgs.to(device), labels.to(device)

        # feed the group of images into model and get probability
        outputs = model(imgs)
        
        # make a prediction
        _, preds = torch.max(outputs, 1)
        
        # calculate loss by CrossEntropy
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        
        # update model weight
        loss.backward()
        optimizer.step()
        sched.step()

        # add training loss
        train_loss += loss.item()
        total += imgs.shape[0]
        
        # add training accuracy
        accuracy += torch.sum(preds == labels.data)
        
        # release memory space to prevent out of memory
        torch.cuda.empty_cache()
    print(f'Epoch {epoch} | train_accuracy = {accuracy / total:03f}, train_loss = {train_loss / total:03f}')
    train_mem.append(train_loss / total)
    
    valid_loss = 0
    total = 0
    accuracy = 0
    
    model.eval()
    
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            imgs, labels = batch['video'], batch['label']
            imgs, labels = imgs.to(device), labels.to(device)

            # feed the group of images into model and get probability
            outputs = model(imgs)

            # make a prediction
            _, preds = torch.max(outputs, 1)

            # calculate loss by CrossEntropy
            loss = criterion(outputs, labels)

            optimizer.zero_grad()

            # add training loss
            valid_loss += loss.item()
            total += imgs.shape[0]

            # add training accuracy
            accuracy += torch.sum(preds == labels.data)

            # release memory space to prevent out of memory
            torch.cuda.empty_cache()
    if valid_loss / total < best_loss:
        best_loss = valid_loss / total
        epoch_mem = epoch
        torch.save(model.state_dict(), f'./hint_model_{leg}/right_model_{epoch}.pt' if leg == 'R' else f'./hint_model_{leg}/left_model_{epoch}.pt')
    valid_mem.append(valid_loss / total)
    print(f'Epoch {epoch} | valid_accuracy = {accuracy / total:03f}, valid_loss = {valid_loss / total:03f}')

59it [02:06,  2.15s/it]
59it [02:05,  2.12s/it]


Epoch 0 | train_accuracy = 0.560000, train_loss = 1.080941


23it [00:36,  1.59s/it]


Epoch 0 | valid_accuracy = 0.507246, valid_loss = 0.551438


59it [02:01,  2.06s/it]
59it [02:05,  2.13s/it]


Epoch 1 | train_accuracy = 0.642857, train_loss = 0.363132


23it [00:37,  1.61s/it]


Epoch 1 | valid_accuracy = 0.463768, valid_loss = 0.475687


59it [02:02,  2.07s/it]
59it [02:05,  2.14s/it]


Epoch 2 | train_accuracy = 0.660000, train_loss = 0.339885


23it [00:37,  1.61s/it]


Epoch 2 | valid_accuracy = 0.492754, valid_loss = 0.578480


59it [02:02,  2.08s/it]
59it [02:06,  2.15s/it]


Epoch 3 | train_accuracy = 0.674286, train_loss = 0.333192


23it [00:37,  1.62s/it]


Epoch 3 | valid_accuracy = 0.608696, valid_loss = 0.454444


59it [02:03,  2.09s/it]
59it [02:07,  2.16s/it]


Epoch 4 | train_accuracy = 0.657143, train_loss = 0.330161


23it [00:37,  1.65s/it]


Epoch 4 | valid_accuracy = 0.478261, valid_loss = 0.393666


59it [02:04,  2.10s/it]
59it [02:06,  2.15s/it]


Epoch 5 | train_accuracy = 0.674286, train_loss = 0.309609


23it [00:37,  1.62s/it]


Epoch 5 | valid_accuracy = 0.492754, valid_loss = 0.412941


59it [02:02,  2.08s/it]
59it [02:07,  2.16s/it]


Epoch 6 | train_accuracy = 0.680000, train_loss = 0.298746


23it [00:37,  1.64s/it]


Epoch 6 | valid_accuracy = 0.492754, valid_loss = 0.379587


59it [02:02,  2.08s/it]
59it [02:06,  2.15s/it]


Epoch 7 | train_accuracy = 0.688571, train_loss = 0.287381


23it [00:37,  1.62s/it]


Epoch 7 | valid_accuracy = 0.536232, valid_loss = 0.397432


59it [02:03,  2.09s/it]
59it [02:07,  2.16s/it]


Epoch 8 | train_accuracy = 0.705714, train_loss = 0.272613


23it [00:37,  1.64s/it]


Epoch 8 | valid_accuracy = 0.492754, valid_loss = 0.398751


59it [02:03,  2.10s/it]
59it [02:06,  2.15s/it]


Epoch 9 | train_accuracy = 0.680000, train_loss = 0.284132


23it [00:37,  1.62s/it]


Epoch 9 | valid_accuracy = 0.449275, valid_loss = 0.397485


59it [02:02,  2.07s/it]
59it [02:07,  2.17s/it]


Epoch 10 | train_accuracy = 0.677143, train_loss = 0.277768


23it [00:37,  1.63s/it]


Epoch 10 | valid_accuracy = 0.463768, valid_loss = 0.416819


59it [02:03,  2.10s/it]
59it [02:07,  2.16s/it]


Epoch 11 | train_accuracy = 0.745714, train_loss = 0.236561


23it [00:37,  1.64s/it]


Epoch 11 | valid_accuracy = 0.304348, valid_loss = 0.439546


59it [02:03,  2.10s/it]
59it [02:07,  2.16s/it]


Epoch 12 | train_accuracy = 0.728571, train_loss = 0.247951


23it [00:37,  1.64s/it]


Epoch 12 | valid_accuracy = 0.304348, valid_loss = 0.418310


59it [02:02,  2.08s/it]
59it [02:07,  2.16s/it]


Epoch 13 | train_accuracy = 0.725714, train_loss = 0.234855


23it [00:37,  1.63s/it]


Epoch 13 | valid_accuracy = 0.478261, valid_loss = 0.450457


59it [02:04,  2.11s/it]
59it [02:07,  2.17s/it]


Epoch 14 | train_accuracy = 0.700000, train_loss = 0.253566


23it [00:37,  1.64s/it]


Epoch 14 | valid_accuracy = 0.449275, valid_loss = 0.501964


59it [02:03,  2.10s/it]
59it [02:07,  2.16s/it]


Epoch 15 | train_accuracy = 0.748571, train_loss = 0.224529


23it [00:37,  1.63s/it]


Epoch 15 | valid_accuracy = 0.347826, valid_loss = 0.490296


59it [02:03,  2.09s/it]
59it [02:07,  2.16s/it]


Epoch 16 | train_accuracy = 0.745714, train_loss = 0.219476


23it [00:38,  1.65s/it]


Epoch 16 | valid_accuracy = 0.391304, valid_loss = 0.442867


59it [02:03,  2.09s/it]
59it [02:08,  2.17s/it]


Epoch 17 | train_accuracy = 0.751429, train_loss = 0.214472


23it [00:37,  1.63s/it]


Epoch 17 | valid_accuracy = 0.347826, valid_loss = 0.459937


59it [02:02,  2.08s/it]
59it [02:07,  2.17s/it]


Epoch 18 | train_accuracy = 0.791429, train_loss = 0.207018


23it [00:37,  1.65s/it]


Epoch 18 | valid_accuracy = 0.463768, valid_loss = 0.504902


59it [02:03,  2.09s/it]
59it [02:06,  2.15s/it]


Epoch 19 | train_accuracy = 0.820000, train_loss = 0.182437


23it [00:37,  1.63s/it]


Epoch 19 | valid_accuracy = 0.333333, valid_loss = 0.461105


59it [02:03,  2.09s/it]
59it [02:08,  2.18s/it]


Epoch 20 | train_accuracy = 0.791429, train_loss = 0.195778


23it [00:37,  1.65s/it]


Epoch 20 | valid_accuracy = 0.333333, valid_loss = 0.525021


59it [02:03,  2.09s/it]
59it [02:07,  2.16s/it]


Epoch 21 | train_accuracy = 0.788571, train_loss = 0.192447


23it [00:37,  1.64s/it]


Epoch 21 | valid_accuracy = 0.333333, valid_loss = 0.497185


59it [02:04,  2.11s/it]
59it [02:08,  2.17s/it]


Epoch 22 | train_accuracy = 0.817143, train_loss = 0.185146


23it [00:37,  1.63s/it]


Epoch 22 | valid_accuracy = 0.289855, valid_loss = 0.537392


59it [02:03,  2.10s/it]
59it [02:09,  2.19s/it]


Epoch 23 | train_accuracy = 0.848571, train_loss = 0.161679


23it [00:37,  1.65s/it]


Epoch 23 | valid_accuracy = 0.376812, valid_loss = 0.582186


59it [02:04,  2.11s/it]
59it [02:06,  2.14s/it]


Epoch 24 | train_accuracy = 0.845714, train_loss = 0.153215


23it [00:37,  1.63s/it]


Epoch 24 | valid_accuracy = 0.260870, valid_loss = 0.575824


59it [02:04,  2.10s/it]
59it [02:06,  2.15s/it]


Epoch 25 | train_accuracy = 0.845714, train_loss = 0.164524


23it [00:37,  1.62s/it]


Epoch 25 | valid_accuracy = 0.260870, valid_loss = 0.558703


59it [02:03,  2.09s/it]
59it [02:06,  2.15s/it]


Epoch 26 | train_accuracy = 0.862857, train_loss = 0.161219


23it [00:37,  1.61s/it]


Epoch 26 | valid_accuracy = 0.318841, valid_loss = 0.594712


59it [02:03,  2.09s/it]
59it [02:07,  2.16s/it]


Epoch 27 | train_accuracy = 0.885714, train_loss = 0.147546


23it [00:37,  1.62s/it]


Epoch 27 | valid_accuracy = 0.260870, valid_loss = 0.571709


59it [02:03,  2.10s/it]
59it [02:08,  2.18s/it]


Epoch 28 | train_accuracy = 0.885714, train_loss = 0.134701


23it [00:37,  1.64s/it]


Epoch 28 | valid_accuracy = 0.275362, valid_loss = 0.560101


59it [02:03,  2.10s/it]
59it [02:07,  2.16s/it]


Epoch 29 | train_accuracy = 0.857143, train_loss = 0.145877


23it [00:37,  1.63s/it]


Epoch 29 | valid_accuracy = 0.405797, valid_loss = 0.616401


59it [02:04,  2.11s/it]
59it [02:08,  2.18s/it]


Epoch 30 | train_accuracy = 0.920000, train_loss = 0.110931


23it [00:37,  1.63s/it]


Epoch 30 | valid_accuracy = 0.333333, valid_loss = 0.661831


59it [02:04,  2.10s/it]
59it [02:07,  2.16s/it]


Epoch 31 | train_accuracy = 0.917143, train_loss = 0.112076


23it [00:37,  1.63s/it]


Epoch 31 | valid_accuracy = 0.434783, valid_loss = 0.590067


59it [02:02,  2.08s/it]
59it [02:07,  2.15s/it]


Epoch 32 | train_accuracy = 0.908571, train_loss = 0.119881


23it [00:37,  1.63s/it]


Epoch 32 | valid_accuracy = 0.347826, valid_loss = 0.621815


59it [02:03,  2.09s/it]
59it [02:08,  2.17s/it]


Epoch 33 | train_accuracy = 0.940000, train_loss = 0.108169


23it [00:37,  1.62s/it]


Epoch 33 | valid_accuracy = 0.362319, valid_loss = 0.607896


59it [02:02,  2.07s/it]
59it [02:06,  2.14s/it]


Epoch 34 | train_accuracy = 0.931429, train_loss = 0.101465


23it [00:37,  1.64s/it]


Epoch 34 | valid_accuracy = 0.347826, valid_loss = 0.609209


59it [02:03,  2.09s/it]
59it [02:08,  2.17s/it]


Epoch 35 | train_accuracy = 0.917143, train_loss = 0.097326


23it [00:37,  1.63s/it]


Epoch 35 | valid_accuracy = 0.304348, valid_loss = 0.703505


59it [02:02,  2.08s/it]
59it [02:08,  2.18s/it]


Epoch 36 | train_accuracy = 0.928571, train_loss = 0.097290


23it [00:39,  1.74s/it]


Epoch 36 | valid_accuracy = 0.362319, valid_loss = 0.656219


59it [02:12,  2.25s/it]
59it [02:17,  2.33s/it]


Epoch 37 | train_accuracy = 0.940000, train_loss = 0.090207


23it [00:40,  1.75s/it]


Epoch 37 | valid_accuracy = 0.362319, valid_loss = 0.641356


59it [02:10,  2.21s/it]
59it [02:16,  2.32s/it]


Epoch 38 | train_accuracy = 0.931429, train_loss = 0.089432


23it [00:40,  1.74s/it]


Epoch 38 | valid_accuracy = 0.333333, valid_loss = 0.700489


59it [02:09,  2.19s/it]
59it [02:15,  2.30s/it]


Epoch 39 | train_accuracy = 0.948571, train_loss = 0.089196


23it [00:40,  1.77s/it]


Epoch 39 | valid_accuracy = 0.405797, valid_loss = 0.609915


59it [02:10,  2.20s/it]
59it [02:13,  2.26s/it]


Epoch 40 | train_accuracy = 0.957143, train_loss = 0.079191


23it [00:40,  1.75s/it]


Epoch 40 | valid_accuracy = 0.333333, valid_loss = 0.735696


59it [02:11,  2.22s/it]
59it [02:15,  2.30s/it]


Epoch 41 | train_accuracy = 0.957143, train_loss = 0.070325


23it [00:38,  1.69s/it]


Epoch 41 | valid_accuracy = 0.463768, valid_loss = 0.612863


59it [02:09,  2.20s/it]
59it [02:14,  2.28s/it]


Epoch 42 | train_accuracy = 0.974286, train_loss = 0.064590


23it [00:40,  1.76s/it]


Epoch 42 | valid_accuracy = 0.449275, valid_loss = 0.634079


59it [02:07,  2.15s/it]
59it [02:13,  2.25s/it]


Epoch 43 | train_accuracy = 0.980000, train_loss = 0.062977


23it [00:40,  1.75s/it]


Epoch 43 | valid_accuracy = 0.405797, valid_loss = 0.632305


59it [02:08,  2.17s/it]
59it [02:13,  2.26s/it]


Epoch 44 | train_accuracy = 0.974286, train_loss = 0.062362


23it [00:39,  1.70s/it]


Epoch 44 | valid_accuracy = 0.449275, valid_loss = 0.649060


59it [02:07,  2.16s/it]
59it [02:14,  2.27s/it]


Epoch 45 | train_accuracy = 0.977143, train_loss = 0.057582


23it [00:40,  1.75s/it]


Epoch 45 | valid_accuracy = 0.362319, valid_loss = 0.631561


59it [02:10,  2.21s/it]
59it [02:16,  2.31s/it]


Epoch 46 | train_accuracy = 0.982857, train_loss = 0.058366


23it [00:40,  1.77s/it]


Epoch 46 | valid_accuracy = 0.434783, valid_loss = 0.578713


59it [02:10,  2.21s/it]
59it [02:14,  2.29s/it]


Epoch 47 | train_accuracy = 0.980000, train_loss = 0.056602


23it [00:40,  1.75s/it]


Epoch 47 | valid_accuracy = 0.449275, valid_loss = 0.657042


59it [02:08,  2.17s/it]
59it [02:12,  2.24s/it]


Epoch 48 | train_accuracy = 0.982857, train_loss = 0.049823


23it [00:40,  1.76s/it]


Epoch 48 | valid_accuracy = 0.478261, valid_loss = 0.667662


59it [02:09,  2.20s/it]
59it [02:13,  2.27s/it]


Epoch 49 | train_accuracy = 0.991429, train_loss = 0.043381


23it [00:40,  1.75s/it]

Epoch 49 | valid_accuracy = 0.434783, valid_loss = 0.605077





In [6]:
n_epochs = np.arange(epochs)
plt.figure(0)
plt.plot(n_epochs, train_mem, color = 'b', label = 'training loss')
plt.plot(n_epochs, valid_mem, color = 'r', label = 'Evaluation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
#plt.show()
plt.savefig(f'./hint_model_{leg}/{leg}_loss_surface.png')
plt.clf()

<Figure size 432x288 with 0 Axes>