# Initialize training project

## Checking the GPU in Colab Pro

In [None]:
import torch

# Report the CUDA device assigned to this runtime. Fall back to a message
# instead of raising when no CUDA device is attached (CPU-only session).
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No CUDA device available'
gpu_name

'Tesla P100-PCIE-16GB'

## Mounting the Google Drive data

In [None]:
from google.colab import drive

# Mount Google Drive; force_remount=True remounts even if it is already mounted.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point, force_remount=True)

Mounted at /content/drive


## Seeds for reproducibility


In [None]:
%matplotlib inline
import numpy as np 
import torch
import random
import os
from skimage import io, img_as_float32

In [None]:
def set_seed(seed):
  """Seed every random number generator this notebook relies on.

  Args:
    seed: Integer seed applied to Python's `random`, NumPy and PyTorch.
  """
  random.seed(seed)
  # NumPy drives `np.random.permutation` in the mini-batch partitioning, so it
  # must be seeded as well or the generated batches differ on every run.
  np.random.seed(seed)
  torch.manual_seed(seed)
  # Seed all CUDA devices explicitly; this is ignored when CUDA is unavailable.
  torch.cuda.manual_seed_all(seed)

set_seed(42)

## Partitioning on mini batches
<a href="https://imgbb.com/"><img src="https://i.ibb.co/7C3SV7Q/sampling.jpg" alt="sampling" border="0"></a>


Remember that this script was developed in a Google Colab context. If your data lives elsewhere, change the path that points to the files in the ```data/vox-png``` folder.

The ```batches.txt``` file stores, as numbers, the indexes of the files that will be used in each training group (one group per line).

In [None]:
# Folder that holds the training data. Change the path if necessary.
main_path = '/content/drive/MyDrive/Tesis/first-order-model-6c/data/vox-png'
# Folder where the batch index file is written. It has to be the project
# folder, because the training cells read './batches.txt' after changing into
# '.../Tesis/first-order-model-6c/'.
batch_index_path = '/content/drive/MyDrive/Tesis/first-order-model-6c'
# Name of the file that stores the indexes of each mini batch.
fbatches = 'batches.txt'

In [None]:
def directory_iter(files, batch_size, shuffle=True):
  """Yield index arrays that split `files` into mini batches.

  Args:
    files: Sized collection of entries to partition (only its length is used).
    batch_size: Maximum number of indices per batch; must be >= 1.
    shuffle: If True, indices come from `np.random.permutation`; otherwise
      they are produced in ascending order.

  Yields:
    1-D NumPy integer arrays with positions into `files`. Every batch holds
    `batch_size` indices except possibly the last one, which holds the rest.

  Raises:
    ValueError: If `batch_size` is smaller than 1 (raised when iteration starts).
  """
  if batch_size < 1:
    raise ValueError("batch_size must be a positive integer")

  n = len(files)
  # Always use an ndarray so callers can rely on `.shape` in both modes.
  indices = np.random.permutation(n) if shuffle else np.arange(n)

  for start in range(0, n, batch_size):
    # Slicing clamps at the end of the array, so the last batch is just shorter.
    yield indices[start:start + batch_size]

We will save the indexes in the file ```batches.txt```

In [None]:
# Only the number of directory entries matters here: the batches store
# positions into the listing, not file names.
files = np.array(os.listdir(main_path))
batch_size = 400
total_samples = 0
batches_file = os.path.join(batch_index_path, fbatches)

# Mode 'w' truncates an existing file, so no explicit removal is needed first.
with open(batches_file, 'w') as f:
  for batch_indices in directory_iter(files, batch_size):
    total_samples += len(batch_indices)
    # One batch per line; every index is followed by a single space.
    f.write("".join("%s " % index for index in batch_indices))
    f.write("\n")

# Sanity check: the batches together must cover every entry exactly once.
if total_samples == files.shape[0]:
  print(':) The total number of samples per batch is correct.')
else:
  print(':( The total number of samples per batch differs from the total number of samples.')

# TRAINING

Remember that this script was developed in a Google Colab context. You can change the path passed to ```cd```, but it must still point to the folder of the present model, **first-order-model-6c**.

In [None]:
cd drive/MyDrive/Tesis/first-order-model-6c/

Installing First Order Model repository requirements

In [None]:
#!pip install -r requirements.txt

In [None]:
from shutil import copy
import imageio
import numpy as np
import sys
import uuid
import yaml
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import warnings
from time import gmtime, strftime
from skimage import img_as_ubyte
from ctypes import cdll
from train import train
from modules.generator import OcclusionAwareGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector
from frames_dataset import FramesDataset,FramesDatasetPartitioning
from modules.util import DownBlock2d
from tqdm import trange
from torch.utils.data import DataLoader
from logger import Logger
from modules.model import GeneratorFullModel, DiscriminatorFullModel
from torch.optim.lr_scheduler import MultiStepLR
from sync_batchnorm import DataParallelWithCallback
from frames_dataset import DatasetRepeater
#from demo import load_checkpoints, make_animation, load_checkpoints_Unet_3
from demo import load_checkpoints, make_animation

Reading the ```batches.txt``` file and selecting the **batch number** to train on

In [None]:
# Position (line of batches.txt, starting at 0) of the batch to train on.
num_batch = 0

batchf = './batches.txt'
config = './config/vox-adv-256.yaml'

# Each non-empty line of batches.txt holds the space-separated indexes of one batch.
batch_list = []
with open(batchf, 'r') as f:
  for line in f:
    fields = line.split()
    if fields:  # skip blank lines instead of failing on int('')
      batch_list.append([int(field) for field in fields])

# yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and
# rejected by PyYAML 6; safe_load is enough for a plain configuration file.
with open(config) as f:
  config = yaml.safe_load(f)

dataset = FramesDatasetPartitioning(is_train=1, **config['dataset_params'], batches_list=batch_list[num_batch])

Use predefined train-test split.


Debug the index file and the name of each video

In [None]:
# Indexes that make up the selected batch.
selected_batch = batch_list[num_batch]
print(selected_batch)
# Videos the dataset holds for that batch.
print(dataset.videos)

Training

In [None]:
# Trains generator, discriminator and keypoint detector on `dataset`, which is
# created by the cell that reads batches.txt and must have been run first.
warnings.filterwarnings("ignore")

config = './config/vox-adv-256.yaml'
device_ids = [0]
# Set to a checkpoint path (e.g. './models/vox-adv-cpk.pth.tar') to resume training.
checkpoint = None
log_dir = './logs'
batchf = './batches.txt'

# Each non-empty line of batches.txt holds the space-separated indexes of one batch.
batch_list = []
with open(batchf, 'r') as f:
  for line in f:
    fields = line.split()
    if fields:  # skip blank lines instead of failing on int('')
      batch_list.append([int(field) for field in fields])

if __name__ == "__main__":

    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and
    # rejected by PyYAML 6.
    with open(config) as f:
        config = yaml.safe_load(f)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params'])

    if torch.cuda.is_available():
        generator.to(device_ids[0])

    discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'], **config['model_params']['common_params'])

    if torch.cuda.is_available():
        discriminator.to(device_ids[0])

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params'])

    if torch.cuda.is_available():
        kp_detector.to(device_ids[0])

    print("Training...")

    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))

    if checkpoint is not None:
        # `load_cpk` is not defined or imported in this notebook; the upstream
        # training script calls it as a static method of Logger.
        # NOTE(review): confirm that this project's logger.py still exposes Logger.load_cpk.
        start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
                                      optimizer_generator, optimizer_discriminator,
                                      None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
    else:
        start_epoch = 0

    scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1, last_epoch=start_epoch - 1)

    scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1, last_epoch=start_epoch - 1)

    # The kp detector schedule restarts from scratch when its learning rate is 0.
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1, last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

    # Wrap only when repetition is actually requested; a missing key counts as 1.
    # The wrapped dataset gets its own name so that re-running this cell does not
    # wrap `dataset` again and multiply the repeats.
    # NOTE(review): assumes DatasetRepeater with num_repeats == 1 is a no-op — confirm.
    if train_params.get('num_repeats', 1) != 1:
        train_dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    else:
        train_dataset = dataset

    dataloader = DataLoader(train_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=6, drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)

    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # Generator / keypoint detector update.
                losses_generator, generated = generator_full(x)

                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                # Discriminator update, only when the adversarial loss is enabled.
                if train_params['loss_weights']['generator_gan'] != 0:
                    optimizer_discriminator.zero_grad()
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [val.mean() for val in losses_discriminator.values()]
                    loss = sum(loss_values)

                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                losses_generator.update(losses_discriminator)
                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
                logger.log_iter(losses=losses)

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            # The last mini batch of the epoch (`x`, `generated`) is passed for visualization.
            logger.log_epoch(epoch, {'generator': generator,
                                     'discriminator': discriminator,
                                     'kp_detector': kp_detector,
                                     'optimizer_generator': optimizer_generator,
                                     'optimizer_discriminator': optimizer_discriminator,
                                     'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)

Training...


100%|██████████| 10/10 [3:49:40<00:00, 1378.07s/it]


# CHECKPOINT READING

Showing the current training model

In [None]:
# Model configuration and the checkpoint file inside the log directory.
config_path = './config/vox-adv-256.yaml'
checkpoint_path = './logs/00000009-checkpoint.pth.tar'

# Restore the generator and keypoint detector, then show the detector structure.
restored_models = load_checkpoints(config_path, checkpoint_path)
generator1, kp_detector1 = restored_models
print(kp_detector1)

DataParallelWithCallback(
  (module): KPDetector(
    (predictor): Hourglass(
      (encoder): Encoder(
        (down_blocks): ModuleList(
          (0): DownBlock2d(
            (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (pool): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
          )
          (1): DownBlock2d(
            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): SynchronizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (pool): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
          )
          (2): DownBlock2d(
            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): SynchronizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   

# ANIMATION

To test the model with the generated checkpoint, a source image and a driving video must be saved beforehand (see the paths in the cell below).

In [None]:
from animate import normalize_kp
from skimage.transform import resize
from moviepy import editor
from tqdm import tqdm
import os

In [None]:
# Animates the source image with the motion of the driving video, using the
# `generator1` / `kp_detector1` models restored from the checkpoint.
adapt_movement_scale = True
cpu = False
relative = True

D_VIDEO_PATH = 'logs/animation/video_d.mp4'
S_IMAGE_PATH = 'logs/animation/image_s.png'

G_VIDEO_PATH = 'logs/animation/video_g.mp4'
# Remove the result of a previous run, if any.
if os.path.isfile(G_VIDEO_PATH):
  os.remove(G_VIDEO_PATH)

source_image = imageio.imread(S_IMAGE_PATH)
driving_video = imageio.mimread(D_VIDEO_PATH, memtest=False)

# The reader is only needed for the frame rate; close it once that is read.
reader = imageio.get_reader(D_VIDEO_PATH)
try:
  driving_video_FPS = reader.get_meta_data()['fps']
finally:
  reader.close()

# Resize to 256x256 and keep only the first three channels.
source_image = resize(source_image, (256, 256))[..., :3]
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
# NOTE(review): the audio clip is loaded but not used by any visible cell.
driving_video_audio = editor.AudioFileClip(D_VIDEO_PATH)


with torch.no_grad():
  predictions = []
  source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
  if not cpu:
    source = source.cuda()
  # Everything below must run in both CPU and GPU mode; only the .cuda()
  # transfers depend on `cpu`.
  driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
  kp_source = kp_detector1(source)
  kp_driving_initial = kp_detector1(driving[:, :, 0])

  for frame_idx in tqdm(range(driving.shape[2])):
    driving_frame = driving[:, :, frame_idx]
    if not cpu:
      driving_frame = driving_frame.cuda()
    kp_driving = kp_detector1(driving_frame)
    kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                           kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                           use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
    out = generator1(source, kp_source=kp_source, kp_driving=kp_norm)

    # Move the channel axis last and drop the batch axis before storing the frame.
    predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

100%|██████████| 150/150 [00:06<00:00, 22.27it/s]


Saving the generated video

In [None]:
# Convert every predicted frame with img_as_ubyte, then write the video at the
# frame rate read from the driving video.
frames_ubyte = [img_as_ubyte(frame) for frame in predictions]
imageio.mimsave(G_VIDEO_PATH, frames_ubyte, fps=driving_video_FPS)