### Wrapper for training.py to test function behaviours

#### Import training function and required packages

In [None]:
import sys
import os
# Put the parent of the current working directory on the import path
# (presumably where the project modules imported below live — confirm).
sys.path.append(
    os.path.abspath(os.path.join(os.getcwd(), os.pardir))
)

In [None]:
import torch

from BDQ import BDQEncoder
from action_recognition_model import ActionRecognitionModel
from loss import ActionLoss, PrivacyLoss
from preprocess import KTHBDQDataset, ConsecutiveTemporalSubsample, MultiScaleCrop, NormalizePixelValues
from privacy_attribute_prediction_model import PrivacyAttributePredictor
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop
from training import adverserial_training, load_train_checkpoint, get_sorted_checkpoints

#### Perform adversarial training

In [None]:
# Select the compute device: the current accelerator when one is available,
# otherwise the CPU. `torch.accelerator` only exists in PyTorch >= 2.6, so on
# older installations fall back to the classic CUDA / MPS probes instead of
# failing with an AttributeError.
if hasattr(torch, "accelerator"):
    device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
elif torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = "cpu"

In [None]:
# Export the checkpoint directory through the COLAB_PATH environment variable
# (presumably read by the training helpers to locate checkpoints — confirm).
os.environ.update(COLAB_PATH="../checkpoints")

In [None]:
# Hyperparameters, set according to https://arxiv.org/abs/2208.02459
batch_size = 4
num_epochs = 50
lr = 1e-3
consecutive_frames = 8  # handed to ConsecutiveTemporalSubsample in the transforms
crop_size = (224, 224)  # spatial size used by Resize (train) and CenterCrop (val)

# Where the KTH dataset and its labels file are located
KTH_DATA_DIR = "../KTH"
KTH_LABELS_DIR = "../kth_clips.json"

# Load the KTH training split. The transformation sequence follows Section 4.2
# in https://arxiv.org/abs/2208.02459
train_transform = Compose([
    ConsecutiveTemporalSubsample(consecutive_frames), # first, sample `consecutive_frames` consecutive frames
    MultiScaleCrop(), # then, apply randomized multi-scale crop
    Resize(crop_size), # then, resize to `crop_size`
    NormalizePixelValues(), # (also normalize pixel values for pytorch)
])
train_data = KTHBDQDataset(
    root_dir=KTH_DATA_DIR,
    json_path=KTH_LABELS_DIR,
    transform=train_transform,
    split="train",
)
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    # Reshuffle the training samples every epoch; DataLoader defaults to
    # shuffle=False, which would feed the clips in the same fixed order each
    # epoch. NOTE(review): assumes KTHBDQDataset is a map-style dataset
    # (shuffle is not supported for IterableDataset) — confirm.
    shuffle=True,
    num_workers=4,
)
# Validation split, preprocessed following the same Section 4.2 of
# https://arxiv.org/abs/2208.02459: no random augmentation, only a center crop.
val_transform = Compose(
    [
        # take `consecutive_frames` consecutive frames
        ConsecutiveTemporalSubsample(consecutive_frames),
        # center crop to `crop_size` without scaling (resizing)
        CenterCrop(crop_size),
        # normalize pixel values for pytorch
        NormalizePixelValues(),
    ]
)
val_data = KTHBDQDataset(root_dir=KTH_DATA_DIR, json_path=KTH_LABELS_DIR, transform=val_transform, split="val")
val_dataloader = DataLoader(val_data, batch_size=batch_size, num_workers=4)

# Build the three networks used in adversarial training:
#   E - the BDQ encoder
#   T - the action attribute predictor
#   P - the privacy attribute predictor
# NOTE(review): 6 action classes / 25 privacy classes presumably match the KTH
# labels — confirm against the dataset.
E = BDQEncoder(hardness=5.0)
T = ActionRecognitionModel(fine_tune=True, num_classes=6)
P = PrivacyAttributePredictor(num_privacy_classes=25)

# Move every network onto the selected device
E, T, P = (net.to(device) for net in (E, T, P))

# Optimizers: E and T are updated jointly by one SGD instance, P by its own
optim_ET = SGD(params=[*E.parameters(), *T.parameters()], lr=lr)
optim_P = SGD(params=list(P.parameters()), lr=lr)

# One cosine-annealing schedule per optimizer, spanning the whole run
scheduler_ET = CosineAnnealingLR(optimizer=optim_ET, T_max=num_epochs)
scheduler_P = CosineAnnealingLR(optimizer=optim_P, T_max=num_epochs)

# Resume from the last entry returned by get_sorted_checkpoints(), if any;
# otherwise start from epoch 0 with no checkpoint path.
# NOTE(review): presumably entries are (path, epoch) pairs ordered oldest to
# newest, and load_train_checkpoint tolerates a None path — confirm in training.py.
checkpoints = get_sorted_checkpoints()
last_checkpoint_path, last_epoch = checkpoints[-1] if len(checkpoints) > 0 else (None, 0)
load_train_checkpoint(E, T, P, optim_ET, optim_P, scheduler_ET, scheduler_P, last_checkpoint_path)

# Loss functions for the action and privacy objectives
criterion_action = ActionLoss(alpha=1)
criterion_privacy = PrivacyLoss()

# Launch adversarial training with everything set up above; `last_epoch` is 0
# unless a checkpoint was found.
adverserial_training(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    E=E,
    T=T,
    P=P,
    optimizer_ET=optim_ET,
    optimizer_P=optim_P,
    scheduler_ET=scheduler_ET,
    scheduler_P=scheduler_P,
    action_loss=criterion_action,
    privacy_loss=criterion_privacy,
    last_epoch=last_epoch,
    num_epochs=num_epochs,
)