In [None]:
# INSTALL MODULES
%pip install pickle5

# For ray.Tune
%pip install ray torch torchvision      # Preferred install command when running PyTorch
#%pip install -U ray                    # Alternate installation for Tune

# For Tensorboard
%pip install tensorboardX

# IMPORTS
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
import os
import pickle5 as pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
from ray import tune
from ray.tune.schedulers import ASHAScheduler           # To use ASHA Scheduler (it's recommended using this over the standard HyperBand scheduler)
from hyperopt import hp
from ray.tune.suggest.hyperopt import HyperOptSearch
import tensorboardX

# Load the TensorBoard notebook extension (used by the `%tensorboard` cell).
%load_ext tensorboard

# Mount Google Drive in Colab; the helper notebooks and the dataset live there.
#from google.colab import files
from google.colab import drive
drive.mount("/content/gdrive")

# Execute the helper notebooks in this kernel so their definitions become
# available here. Must run after the Drive mount, since both paths are on it.
# NOTE(review): presumably these define JesterDatasetOneStream and
# ClassifierOneStream (and possibly the `pd` / `nn` aliases used later, which
# this notebook never imports itself) — confirm.
%run '/content/gdrive/MyDrive/Colab Notebooks/training/dataset.ipynb'
%run '/content/gdrive/MyDrive/Colab Notebooks/training/models.ipynb'

# Fail fast: the training cells move models and batches to `device`, and
# every Tune trial below requests one GPU.
if not torch.cuda.is_available():
    raise RuntimeError("You should enable GPU runtime.")



In [None]:
# To speed up training, it's better to copy dataset from Drive to a Colab folder

# choose a local (colab) directory to store the data.
local_dataset_path = os.path.expanduser('/content/data')
try:
  os.makedirs(local_dataset_path)
except: pass

dataset_path = '/content/gdrive/MyDrive/jester_dataset/'

!cp -avr "{dataset_path}" "{local_dataset_path}"

# Make sure it's there
!ls -lha /content/gdrive/MyDrive/jester_dataset/features

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-extracted features (i3d_resnet50_v1_kinetics400 Gluon pre-trained model)
# for our 9 classes' Jester dataset videos, read from the local copy.
# Per the author's note the dict looks like {n_video: features} with features
# of shape (1, 2048) — presumably; verify against the extraction script.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
features_path = '/content/data/jester_dataset/features/features_RGB.pickle'
with open(features_path, 'rb') as handle:
    features_dict = pickle.load(handle)

# CSV split / label files. If in Google Cloud:
#   csv_dir = '/mnt/disks/disk-1/jester_dataset/dataset/csvs/'
# Note: `labels` is the *path* to labels.csv, not label tensors.
csv_dir = '/content/data/jester_dataset/csv/'
train_csv = os.path.join(csv_dir, 'train.csv')
val_csv = os.path.join(csv_dir, 'validation.csv')
labels = os.path.join(csv_dir, 'labels.csv')

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, epoch):
  """Run one optimisation pass over ``train_loader``.

  Args:
    model: network to train; switched to training mode here.
    train_loader: iterable of ``(features, labels)`` batches.
    optimizer: optimizer stepping ``model``'s parameters.
    criterion: loss called as ``criterion(output, labels)``.
    epoch: unused; kept so existing call sites keep working.

  Returns:
    ``(mean_loss, mean_acc)``: unweighted means of the per-batch loss and
    per-batch accuracy.
  """
  model.train()
  batch_losses = []
  batch_accs = []

  for batch_features, batch_labels in train_loader:
    batch_features = batch_features.to(device)
    batch_labels = batch_labels.to(device)

    optimizer.zero_grad()
    logits = model(batch_features)
    batch_loss = criterion(logits, batch_labels)
    batch_loss.backward()
    optimizer.step()

    batch_accs.append(accuracy(batch_labels, logits))
    batch_losses.append(batch_loss.item())

  return np.mean(batch_losses), np.mean(batch_accs)

In [None]:
def eval_epoch(model, val_loader, criterion, epoch):
  """Evaluate ``model`` on ``val_loader`` without tracking gradients.

  Args:
    model: network to evaluate; switched to eval mode here.
    val_loader: iterable of ``(features, labels)`` batches.
    criterion: loss called as ``criterion(output, labels)``.
    epoch: unused; kept so existing call sites keep working.

  Returns:
    ``(mean_loss, mean_acc)``: unweighted means of the per-batch loss and
    per-batch accuracy.
  """
  model.eval()
  batch_losses, batch_accs = [], []

  with torch.no_grad():
    for batch_features, batch_labels in val_loader:
      batch_features = batch_features.to(device)
      batch_labels = batch_labels.to(device)
      logits = model(batch_features)
      batch_accs.append(accuracy(batch_labels, logits))
      batch_losses.append(criterion(logits, batch_labels).item())

  return np.mean(batch_losses), np.mean(batch_accs)

In [None]:
def accuracy(labels, outputs):
    """Return the fraction of samples whose highest-scoring class equals the label.

    Args:
        labels: tensor of integer class ids; it is reshaped with ``view_as`` to
            match the predictions, so e.g. a trailing singleton dim is fine.
        outputs: tensor of per-class scores with classes on the last dimension.

    Returns:
        A NumPy floating-point scalar in [0, 1].
    """
    predicted = outputs.argmax(dim=-1)
    is_correct = predicted == labels.view_as(predicted)
    return is_correct.detach().cpu().float().numpy().mean()

In [None]:
def class_weights_from_csv(train_csv_path, labels_csv_path):
  """Inverse-frequency class weights for ``nn.CrossEntropyLoss(weight=...)``.

  Each class gets ``max_count / count``: the most frequent class gets 1.0 and
  rarer classes proportionally more. Entry ``i`` of the returned tensor
  belongs to the class on row ``i`` of ``labels_csv_path`` (column
  ``Actions``); counts come from the ``Action`` column of ``train_csv_path``.

  NOTE(review): assumes the dataset encodes each label as its row index in
  labels.csv (the confusion-matrix cell indexes class names the same way) —
  confirm against JesterDatasetOneStream.

  Raises:
    ValueError: if a class listed in labels.csv has no training samples.
  """
  import pandas as pd

  class_names = pd.read_csv(labels_csv_path)['Actions']
  # value_counts() is sorted by frequency, not by class id, so realign it to
  # the label order before turning it into a per-class weight vector.
  counts = pd.read_csv(train_csv_path)['Action'].value_counts().reindex(class_names)
  missing = counts[counts.isna()].index.tolist()
  if missing:
    raise ValueError(f"No training samples for classes: {missing}")
  # Explicit float32 (pandas arithmetic yields float64).
  return torch.tensor((counts.max() / counts).to_numpy(), dtype=torch.float32)


def train_model(config):
  """Ray Tune trainable: train one classifier and report validation metrics.

  Args:
    config: dict with keys 'batch_size', 'hidden_size', 'dropout', 'lr' and
      'epochs' (the number of epochs to train for).

  Each epoch runs one training pass and one validation pass, prints both, and
  reports the validation loss/accuracy to Tune as ``tune_loss``/``accuracy``.
  """
  import torch.nn as nn  # not imported by this notebook's setup cell

  # DATASETS
  train_dataset = JesterDatasetOneStream(features_dict, train_csv, labels)
  validation_dataset = JesterDatasetOneStream(features_dict, val_csv, labels)

  train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
  validation_loader = DataLoader(validation_dataset, batch_size=config['batch_size'], shuffle=False)

  # MODEL
  model = ClassifierOneStream(hidden_sz=config['hidden_size'], dropout=config['dropout']).to(device)
  optimizer = optim.Adam(model.parameters(), lr=config['lr'])

  # Weight each class by inverse training frequency to counter the
  # unbalanced training set; weights are aligned to the label ids.
  weights = class_weights_from_csv(train_csv, labels)
  criterion = nn.CrossEntropyLoss(weight=weights).to(device)

  # Exactly config['epochs'] epochs, numbered from 1.
  for epoch in range(1, config['epochs'] + 1):
    loss, acc = train_epoch(model, train_loader, optimizer, criterion, epoch)
    print(f"Train Epoch {epoch} loss={loss:.2f} acc={acc:.2f}")

    loss, acc = eval_epoch(model, validation_loader, criterion, epoch)
    print(f"Eval Epoch {epoch} loss={loss:.2f} acc={acc:.2f}")
    tune.report(tune_loss=loss, accuracy=acc)       # send report to Tune



In [None]:
if __name__ == "__main__":

    # Values shared by the search space and the warm-start point below.
    DROPOUT = 0.5
    NUM_EPOCHS = 20

    # Hyperopt search space: lr is log-uniform between 1e-4 and 1e-2; batch
    # size and hidden size come from fixed lists; dropout and epochs are fixed.
    config = {
        "lr": hp.loguniform("lr", np.log(1e-4), np.log(1e-2)),
        "batch_size": hp.choice("batch_size", [8, 16, 32, 64, 128, 256]),
        "hidden_size": hp.choice("hidden_size", [128, 256, 512, 1024, 2048]),
        "dropout": DROPOUT,
        "epochs": NUM_EPOCHS,
    }

    # ASHA (AsyncHyperBandScheduler, recommended over standard HyperBand)
    # stops trials whose validation loss lags behind. time_attr, max_t and
    # reduction_factor are left at their defaults.
    asha_scheduler = ASHAScheduler(metric="tune_loss", mode="min", grace_period=1)

    # Evaluate the best configuration found so far before sampling new ones.
    # NOTE(review): some Ray versions expect hp.choice entries in
    # points_to_evaluate as choice indices rather than values — confirm for
    # the installed version.
    current_best_params = [
        dict(lr=0.00023, batch_size=64, hidden_size=1024, dropout=DROPOUT, epochs=NUM_EPOCHS),
    ]
    hyperopt_search = HyperOptSearch(
        config, metric="tune_loss", mode="min", points_to_evaluate=current_best_params
    )

    # 50 trials, one GPU each; each trial's config comes from the search algorithm.
    analysis = tune.run(
        train_model,
        name="tune_RGB",
        num_samples=50,
        resources_per_trial={"gpu": 1},
        scheduler=asha_scheduler,
        search_alg=hyperopt_search,
    )

    print("\nHyperparameter tuning finished")



In [None]:
# Per-trial result histories from this `tune.run` call:
# {trial logdir -> DataFrame with one row per reported result}.
dfs = analysis.trial_dataframes

# Plot every trial's validation-loss curve on one shared, explicit axis
# (works even when there are no trials to plot).
fig, ax = plt.subplots()
for trial_df in dfs.values():
  trial_df.tune_loss.plot(ax=ax, legend=False)
ax.set_xlabel("Epochs")
ax.set_ylabel("Validation loss")
ax.set_title("Validation loss per trial")

# Get a dataframe for the last reported results of all of the trials
df_1 = analysis.results_df

# Get a dataframe with, for each trial, the result where the validation loss
# (`tune_loss`) was lowest.
df_2 = analysis.dataframe(metric="tune_loss", mode="min")

# Same {trial logdir -> DataFrame} mapping as `dfs`; kept under this name for
# later use.
all_dataframes = dfs

# Get a list of trials
trials = analysis.trials

In [None]:
# Browse the Ray Tune logs. tune.run above is given no explicit results
# directory, so they presumably land under Ray's default ~/ray_results
# (experiment "tune_RGB").
%tensorboard --logdir ~/ray_results

In [None]:
# NOTE(review): this confusion-matrix code is disabled — it is wrapped in a bare
# string literal, so it never runs. It would not run as-is either:
# `trained_model` is not defined anywhere in this notebook (train_model() returns
# nothing), `JesterDataset` differs from the `JesterDatasetOneStream` class used
# for training, and `config` here would be the Hyperopt search space rather than
# concrete values. TODO: rebuild the best trial's model (e.g. from a checkpoint)
# before re-enabling.
'''
import seaborn as sns
validation_dataset = JesterDataset(features_dict,val_csv,labels)

 
validation_loader = DataLoader(validation_dataset, batch_size=config['batch_size'], shuffle=False)
nb_classes = 9
confusion_matrix = np.zeros((nb_classes, nb_classes))
trained_model = trained_model.to(device)
with torch.no_grad():
    for i, (inputs, classes) in enumerate(validation_loader):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = trained_model(inputs)
        _, preds = torch.max(outputs, 1)
        for t, p in zip(classes.view(-1), preds.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1

plt.figure(figsize=(12,7))

l = pd.read_csv(labels)
class_names = list(l['Actions'])
l['Accuracy'] = np.diag(confusion_matrix)/confusion_matrix.sum(1)
df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names).astype(int)
heatmap = sns.heatmap(df_cm, annot=True, fmt="d")

heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right',fontsize=12)
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right',fontsize=12)
plt.ylabel('True label')
plt.xlabel('Predicted label')
'''

In [None]:
#l