In [None]:
import os, io, shutil
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

from dataloader.handhygiene import HandHygiene
from openpose_transforms import CropTorso, MultiScaleTorsoRandomCrop
from spatial_transforms import (
    Compose, Normalize, Scale, CenterCrop, CornerCrop, MultiScaleCornerCrop,
    MultiScaleRandomCrop, RandomHorizontalFlip, ToTensor)
from temporal_transforms import (
    MirrorLoopPadding, LoopPadding, TemporalBeginCrop,
    TemporalRandomCrop, TemporalCenterCrop, TemporalRandomChoice)

In [None]:
# --- Experiment configuration ------------------------------------------------
model_name = 'i3d'
batch_size = 4      # used for both the train and val DataLoader
clip_len = 64       # forwarded to HandHygiene as clip_len
num_classes = 1     # forwarded to get_models

# Seed every global RNG, not only torch's: without this, anything that draws
# from Python's `random` or from NumPy differs between runs.
# NOTE(review): the transform modules presumably draw from `random` — confirm.
seed = 100
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

data_name = 'anesthesia'
# Data is expected in ./data relative to the directory the kernel runs in.
dataset_path = os.path.join(os.getcwd(), 'data')

In [None]:
# --- Spatial / temporal sampling configuration -------------------------------
sample_duration = 64    # length argument given to every temporal transform
sample_size = 224       # size argument given to Scale / CenterCrop / torso crop
# Per-channel statistics handed to Normalize.
# NOTE(review): magnitudes look like 0-255 pixel statistics — confirm that
# ToTensor(1) leaves pixel values unscaled.
mean = [110.63666788, 103.16065604, 96.29023126]
std = [38.7568578, 37.88248729, 40.02898126]

# Four evenly spaced scale factors from 1.0 to 1.75 for the torso crop.
# (A previous `scales = 1.0` placeholder was overwritten before use and has
# been removed.)
scales = np.linspace(1, 1.75, num=4)
openpose_transform = MultiScaleTorsoRandomCrop(scales, sample_size)

# Per-split spatial pipelines; training only adds a random horizontal flip.
spatial_transform = {
    'train': Compose([
        Scale(sample_size),
        CenterCrop(sample_size),
        RandomHorizontalFlip(),
        ToTensor(1),
        Normalize(mean, std)]),
    'val': Compose([
        Scale(sample_size),
        CenterCrop(sample_size),
        ToTensor(1),
        Normalize(mean, std)])}

# Training picks among several temporal strategies via TemporalRandomChoice;
# validation always uses the centre crop.
temporal_transform = {
    'train': TemporalRandomChoice([
        TemporalBeginCrop(sample_duration),
        TemporalRandomCrop(sample_duration),
        TemporalCenterCrop(sample_duration),
        LoopPadding(sample_duration),
        MirrorLoopPadding(sample_duration)]),
    'val': TemporalCenterCrop(sample_duration)}

In [None]:
# One HandHygiene dataset per split. Only the training split receives the
# openpose (torso crop) transform; validation is built without it.
dataset = dict(
    train=HandHygiene(
        dataset_path, split='train', clip_len=clip_len,
        spatial_transform=spatial_transform['train'],
        openpose_transform=openpose_transform,
        temporal_transform=temporal_transform['train'],
        num_workers=16),
    val=HandHygiene(
        dataset_path, split='val', clip_len=clip_len,
        spatial_transform=spatial_transform['val'],
        temporal_transform=temporal_transform['val'],
        num_workers=16),
)

# Matching loaders: shuffle the training split only.
dataloaders = {
    split: DataLoader(dataset[split], batch_size=batch_size,
                      shuffle=(split == 'train'), num_workers=16)
    for split in ('train', 'val')
}

In [None]:
from train import get_models
from train import train
from torchsummary import summary

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the RGB and flow streams.
# NOTE(review): the original inline comment read "unfreeze last mix 170" while
# the value passed is 152 — confirm what the third argument selects.
i3d_rgb, i3d_flow = get_models(num_classes, True, 152)

# Wrap both streams for multi-GPU execution when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    i3d_rgb, i3d_flow = nn.DataParallel(i3d_rgb), nn.DataParallel(i3d_flow)

# Module.to moves the module in place and returns it, so rebinding is a no-op.
i3d_rgb = i3d_rgb.to(device)
i3d_flow = i3d_flow.to(device)

# Loss function plus per-stream optimizer / scheduler slots, which are filled
# in later by trainable_params().
criterion = F.binary_cross_entropy
optims = dict.fromkeys(('rgb', 'flow'))
schedulers = dict.fromkeys(('rgb', 'flow'))
feature_extract = True

In [None]:
def trainable_params(model, mode='rgb', lr=1e-6, momentum=0.9,
                     weight_decay=1e-7, step_size=20, gamma=0.1):
    """Create the SGD optimizer and StepLR scheduler for one stream.

    Prints the name of every parameter with ``requires_grad`` set, then stores
    the new optimizer in the module-level ``optims[mode]`` and its scheduler
    in ``schedulers[mode]``.

    When the module-level ``feature_extract`` flag is true, only the
    parameters that require gradients are handed to the optimizer. Previously
    that list was built but ignored, and ``model.parameters()`` was always
    passed. When the flag is false, all parameters are passed.

    Args:
        model: the ``nn.Module`` whose parameters are to be optimized.
        mode: key under which the optimizer and scheduler are stored.
        lr, momentum, weight_decay: forwarded to ``optim.SGD``.
        step_size, gamma: forwarded to ``lr_scheduler.StepLR``.

    Returns:
        None. Results are written to ``optims`` and ``schedulers``.

    Raises:
        ValueError: raised by ``optim.SGD`` if ``feature_extract`` is true and
            no parameter of ``model`` requires gradients.
    """
    print("Params to learn:")
    trainable = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable.append(param)
            print("\t", name)

    params_to_update = trainable if feature_extract else model.parameters()
    optims[mode] = optim.SGD(params_to_update, lr=lr, momentum=momentum,
                             weight_decay=weight_decay)
    schedulers[mode] = lr_scheduler.StepLR(optims[mode], step_size, gamma)
    
# Fill optims / schedulers for each of the two streams.
trainable_params(i3d_rgb, mode='rgb')
trainable_params(i3d_flow, mode='flow')

In [None]:
# Layer summary of the RGB stream. The input shape (channels, frames, height,
# width) is taken from the sampling configuration rather than repeated as
# literals, so it cannot drift from what the transforms produce.
summary(i3d_rgb, (3, sample_duration, sample_size, sample_size))

In [None]:
# Train both streams with the per-stream optimizers and schedulers built above.
train(
    (i3d_rgb, i3d_flow),
    dataloaders,
    optims,
    criterion,
    schedulers,
    device,
    num_epochs=100,
)