In [1]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as opt
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.cnn_3d import CNNModel
from models.kth_set_splitter import KTHDataset
from models.model_saver import SaveBestModel

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Render matplotlib figures inside the notebook output cells.
%matplotlib inline

In [2]:
# Seed every random number generator the notebook touches so that
# "Restart Kernel -> Run All" reproduces the same results.

SEED = 1234


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every random number generator.
    """
    # Setting this at runtime does not change the current interpreter's hash
    # seed; it only affects Python subprocesses started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU (not just the current one) so multi-GPU runs are
    # covered too; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

In [None]:
# subjects.txt holds the Training / Validation / Test subject split, extracted
# from the KTH video dataset download. Each non-empty line is expected to look
# like "<Split>:<tabs><subjects>".
SUBJECTS_FILE = 'subjects.txt'

with open(SUBJECTS_FILE) as f:
    lines = f.readlines()

datasets = {}

for line in lines:
    if not line.strip():
        continue  # tolerate blank lines, e.g. an empty last line
    dataset, subjects = line.split(':', 1)
    # strip() rather than [:-1]: the last line may have no trailing newline,
    # and [:-1] would then cut off the final character of the subject list.
    datasets[dataset.strip()] = subjects.strip()

seq_length = 75    # clip length passed to KTHDataset (presumably frames per clip)
image_shape = 120  # frame size passed to KTHDataset and used as the crop size below

# sequences00.txt holds the per-video sequence configuration, extracted from
# the KTH video dataset download; `video_dir` contains the action videos.
annot_files = 'sequences00.txt'
video_dir = 'actions'

# NOTE(review): `transform` is not passed to any KTHDataset below, so no
# augmentation is applied during training as written. If it is wired in,
# check that (a) random parameters are shared across all frames of a clip and
# (b) RandomVerticalFlip (upside-down people) is really a desired augmentation.
transform = transforms.Compose([
    transforms.RandomRotation(10),
    # `scale` is a fraction of the source image area, so 1.0 is the largest
    # meaningful upper bound.
    transforms.RandomResizedCrop(image_shape, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip()
])

def _make_kth_split(split_name):
    """Build the KTHDataset for one subject split ('Training', 'Validation' or 'Test')."""
    return KTHDataset(
        annotations_file=annot_files,
        data_dir=video_dir,
        subjects=datasets[split_name],
        seq_length=seq_length,
        image_shape=image_shape,
    )


data_train = _make_kth_split('Training')
data_valid = _make_kth_split('Validation')
data_test = _make_kth_split('Test')

In [4]:
# Hyperparameters and data loaders

batch_size = 2
channels = 3      # input channels given to CNNModel (presumably RGB)
num_classes = 6   # number of action classes (matches the 6 plot labels)
num_epochs = 50
num_workers = 3   # DataLoader worker processes

# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# PyTorch warns and ignores it, so request it only when CUDA is available.
loader_kwargs = dict(batch_size=batch_size,
                     num_workers=num_workers,
                     pin_memory=torch.cuda.is_available())

train_loader = DataLoader(dataset=data_train, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(dataset=data_valid, shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset=data_test, shuffle=False, **loader_kwargs)

In [5]:
# Training function

def train_epoch(epoch, num_epochs, model, optimizer, dataloader, criterion):
    """Run one optimisation pass over `dataloader`.

    Args:
        epoch: zero-based index of the current epoch (display only).
        num_epochs: total number of epochs (display only).
        model: network to train; switched to train mode here.
        optimizer: optimizer updating `model`'s parameters.
        dataloader: iterable of (sequence, label) batches supporting len().
        criterion: loss function called as criterion(logits, labels).

    Returns:
        Sum over all batches of the scalar loss returned by `criterion`.
    """
    model.train()
    total_loss = 0.0
    progress = tqdm(dataloader, total=len(dataloader))

    for num_batches, (sequence, label) in enumerate(progress, start=1):
        # Swap dims 1 and 2 so the channel axis lands where the 3D CNN expects
        # it — assumes batches arrive as (B, T, C, H, W); TODO confirm.
        sequence = sequence.permute(0, 2, 1, 3, 4).to(device)
        label = label.to(device)

        # Gradients accumulate across backward() calls, so reset every batch.
        optimizer.zero_grad()

        pred_label = model(sequence)
        train_loss = criterion(pred_label, label)
        train_loss.backward()
        optimizer.step()

        total_loss += train_loss.item()
        # Show the running mean over batches seen so far.
        progress.set_description('[%d/%d] train loss. %.2f'
                                 % (epoch + 1, num_epochs, total_loss / num_batches))

    return total_loss

In [6]:
# Validation function

def validate_epoch(model, dataloader, criterion):
    """Evaluate `model` on `dataloader` without tracking gradients.

    Args:
        model: network to evaluate; switched to eval mode here.
        dataloader: iterable of (sequence, label) batches supporting len().
        criterion: loss function called as criterion(logits, labels).

    Returns:
        (total_loss, true_labels, pred_labels, accuracy):
        total_loss is the sum over batches of the scalar loss; true_labels and
        pred_labels hold one entry per sample (argmax over dim 1 for
        predictions); accuracy is the percentage of correct predictions, or
        0.0 when the dataloader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_seen = 0
    accuracy = 0.0  # defined up front so an empty dataloader does not crash
    true_labels, pred_labels = [], []

    with torch.no_grad():
        progress = tqdm(dataloader, total=len(dataloader))

        for num_batches, (sequence, label) in enumerate(progress, start=1):
            # Same axis swap as in training — assumes (B, T, C, H, W) input; TODO confirm.
            sequence = sequence.permute(0, 2, 1, 3, 4).to(device)
            label = label.to(device)

            logits = model(sequence)
            total_loss += criterion(logits, label).item()

            pred_label = logits.argmax(dim=1)
            vec_label = label.flatten()
            num_correct += (pred_label == vec_label).sum().item()
            num_seen += len(vec_label)
            accuracy = (num_correct / num_seen) * 100

            true_labels.extend(label.cpu().numpy())
            pred_labels.extend(pred_label.cpu().numpy())

            # Running mean loss per batch and running accuracy so far.
            progress.set_description('[VALID]> loss. %.2f > acc. %.2f%%'
                                     % (total_loss / num_batches, accuracy))

    return total_loss, true_labels, pred_labels, accuracy

In [None]:
# Training loop: train for `num_epochs`, checkpoint whenever the validation
# loss improves, and keep the validation labels/predictions of the most
# accurate epoch for the confusion matrix.
# TODO(review): test_loader is never evaluated in this notebook.

acc = -1.0  # below any real accuracy, so the first epoch is always recorded
results = []
best_loss_val = float('inf')
saver = SaveBestModel(best_valid_loss=best_loss_val)
accuracies = []
losses = []
train_losses = []

model = CNNModel(in_dim=channels, num_classes=num_classes)
model = nn.DataParallel(model).to(device)

cross_entropy = nn.CrossEntropyLoss(label_smoothing=0.1)
# This is Adagrad, not Adam; `adam` remains as an alias for the old name.
optimizer = opt.Adagrad(model.parameters(), lr=1e-3, weight_decay=1e-4)
adam = optimizer

for epoch in range(num_epochs):
    train_loss = train_epoch(epoch, num_epochs=num_epochs, model=model, optimizer=optimizer,
                             dataloader=train_loader, criterion=cross_entropy)
    valid_loss, true_labels, pred_labels, accuracy = validate_epoch(
        model=model, dataloader=valid_loader, criterion=cross_entropy)

    # results[-1] always holds the most accurate epoch seen so far.
    if acc < accuracy:
        results.append((true_labels, pred_labels))
        acc = accuracy

    # Save a checkpoint whenever the validation loss improves.
    if valid_loss < best_loss_val:
        best_loss_val = valid_loss
        saver.save(current_valid_loss=best_loss_val, epoch=epoch, model=model,
                   optimizer=optimizer, criterion=cross_entropy)

    train_losses.append(train_loss)
    accuracies.append(accuracy)
    losses.append(valid_loss)

In [None]:
# Row-normalised confusion matrix of the validation predictions stored last in
# `results` (the most accurate epoch).
# NOTE(review): the order of `class_names` must match the integer labels
# produced by KTHDataset — TODO confirm.
class_names = ['boxing', 'clap', 'waving', 'jogging', 'running', 'walking']

if not results:
    raise RuntimeError('No validation results recorded; run the training loop first.')

best_true, best_pred = results[-1]
# Passing `labels` keeps the matrix len(class_names) x len(class_names) even if
# some class never appears in either list, so the tick labels stay aligned.
conf_mat = confusion_matrix(best_true, best_pred,
                            labels=list(range(len(class_names))), normalize='true')

fig, ax = plt.subplots()
sns.heatmap(conf_mat, ax=ax, square=True, annot=True,
            xticklabels=class_names, yticklabels=class_names)
ax.set(xlabel='Predicted label', ylabel='True label',
       title='Validation confusion matrix (row-normalised)')

fig.tight_layout()