In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from tqdm import tqdm, trange

from faceio import get_date_directories
from utils import attempt_load_day, load_day_to_batch
from network import create_network, init_weights

# init device and ensure deterministic
SEED = 997
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
device = torch.device('cuda:1')

RATIO = 10
video_by_day = get_date_directories('video', RATIO=RATIO)
model = create_network(device).to(device)

The model has 6,335,784 trainable parameters


In [53]:
def train(model, current_epoch, target_epoch):
    model.train()
    while current_epoch < TARGET_EPOCH:
        epoch_loss = 0
        epoch_slices = 0

        proxies = list(range(len(video_by_day)))
        random.shuffle(proxies)
        
        with trange(len(video_by_day)) as t:
            for day_idx in t:

                t.set_description('Training on day %d (of %d)' % (day_idx+1, len(video_by_day)))
                if day_idx == 0:
                    t.set_postfix(epoch=current_epoch)

                total_slices, src_lens, inputs, label = load_day_to_batch(video_by_day, proxies[day_idx], device)
                src_len = torch.Tensor(src_lens)

                output = model(inputs, src_len).view(-1, len(src_lens), 20, 2).cpu()

                loss = criterion(label, output)
                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                optimizer.step()

                day_loss = loss.item()
                epoch_loss += loss.item()
                epoch_slices += 1

                day_loss = day_loss / total_slices
                t.set_postfix(day_loss=day_loss, epoch_loss=epoch_loss/epoch_slices, epoch=current_epoch)

        current_epoch += 1
    return current_epoch

In [54]:
model.apply(init_weights)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()
clip = 1

In [55]:
current_epoch = 0
TARGET_EPOCH = 100
current_epoch = train(model, current_epoch, TARGET_EPOCH)

Training on day 120 (of 120): 100%|██████████| 120/120 [00:55<00:00,  2.15it/s, day_loss=1.95e+6, epoch=0, epoch_loss=5.68e+5]
Training on day 28 (of 120):  22%|██▎       | 27/120 [00:13<00:46,  2.01it/s, day_loss=4.38e+6, epoch=1, epoch_loss=3.24e+6]


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), 'video2lip_ratio10_epoch100.pt')