
# SAM: Animation Inference Playground

In [1]:
import os
os.chdir('/content')
CODE_DIR = 'SAM'

In [2]:
!git clone https://github.com/yuval-alaluf/SAM.git $CODE_DIR

Cloning into 'SAM'...
remote: Enumerating objects: 104, done.
remote: Counting objects: 100% (104/104), done.
remote: Compressing objects: 100% (87/87), done.
remote: Total 104 (delta 22), reused 92 (delta 14), pack-reused 0
Receiving objects: 100% (104/104), 3.45 MiB | 39.22 MiB/s, done.
Resolving deltas: 100% (22/22), done.


In [3]:
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 

--2021-02-16 05:33:09--  https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
Resolving github.com (github.com)... 192.30.255.112
Connecting to github.com (github.com)|192.30.255.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-releases.githubusercontent.com/1335132/d2f252e2-9801-11e7-9fbf-bc7b4e4b5c83?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20210216%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20210216T053310Z&X-Amz-Expires=300&X-Amz-Signature=96dec3540f121862c6f515bc02acd6bc81a7b58e062958a90f4f9e88f4fbaec0&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=1335132&response-content-disposition=attachment%3B%20filename%3Dninja-linux.zip&response-content-type=application%2Foctet-stream [following]
--2021-02-16 05:33:10--  https://github-releases.githubusercontent.com/1335132/d2f252e2-9801-11e7-9fbf-bc7b4e4b5c83?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%

In [4]:
os.chdir(f'./{CODE_DIR}')

In [5]:
from argparse import Namespace
import os
import sys
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

sys.path.append(".")
sys.path.append("..")

from datasets.augmentations import AgeTransformer
from utils.common import tensor2im
from models.psp import pSp

In [6]:
EXPERIMENT_TYPE = 'ffhq_aging'

## Step 1: Download Pretrained Model
As part of this repository, we provide our pretrained aging model.
We'll download the model for the selected experiment and save it to the folder `../pretrained_models`.

In [7]:
def get_download_model_command(file_id, file_name):
    """ Get wget download command for downloading the desired model and save to directory ../pretrained_models. """
    current_directory = os.getcwd()
    save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
    return url    

In [8]:
MODEL_PATHS = {
    "ffhq_aging": {"id": "1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC", "name": "sam_ffhq_aging.pt"}
}

path = MODEL_PATHS[EXPERIMENT_TYPE]
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"]) 

In [9]:
!{download_command}

--2021-02-16 05:34:10--  https://docs.google.com/uc?export=download&confirm=i59R&id=1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC
Resolving docs.google.com (docs.google.com)... 74.125.20.138, 74.125.20.101, 74.125.20.100, ...
Connecting to docs.google.com (docs.google.com)|74.125.20.138|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://doc-08-7g-docs.googleusercontent.com/docs/securesc/3288s4o1us02iiims1id6qrftql3lq11/r12uc4jdkq5fsoice2hhek2qiuko5ap2/1613453625000/05457687429326987275/17328170664508509099Z/1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC?e=download [following]
--2021-02-16 05:34:10--  https://doc-08-7g-docs.googleusercontent.com/docs/securesc/3288s4o1us02iiims1id6qrftql3lq11/r12uc4jdkq5fsoice2hhek2qiuko5ap2/1613453625000/05457687429326987275/17328170664508509099Z/1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC?e=download
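A download can fail silently. For example, Google Drive may serve an HTML quota or warning page instead of the file. `wget` still writes whatever it receives to `../pretrained_models`, so the error only surfaces later in `torch.load`.

The helper below is a hypothetical, minimal sanity check and is not part of the SAM repository. It rests on two assumptions:
- A real checkpoint is larger than an arbitrary `min_size_bytes` threshold.
- A real checkpoint does not start with an HTML tag.

```python
import os


def looks_like_valid_checkpoint(path, min_size_bytes=1_000_000):
    """Heuristic check that `path` holds a downloaded checkpoint rather than an error page.

    Hypothetical helper: the size threshold is an arbitrary assumption, and the HTML
    test only catches the common case of an error/quota page saved in place of the file.
    """
    if not os.path.isfile(path):
        return False
    if os.path.getsize(path) < min_size_bytes:
        return False
    with open(path, 'rb') as f:
        head = f.read(512).lstrip().lower()
    # an HTML page starts with a doctype or an <html> tag; a binary checkpoint does not
    return not (head.startswith(b'<!doctype html') or head.startswith(b'<html'))
```

For example, `looks_like_valid_checkpoint('../pretrained_models/sam_ffhq_aging.pt')` should return `True` before you move on to loading the model.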


## Step 3: Define Inference Parameters

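Below, a dictionary defines the parameters for each experiment:
- the path to the pretrained model;
- the transforms applied to the input image (resize to 256x256, conversion to a tensor, and normalization).

The input images themselves are specified later, in the animation cell. Default values are provided so the notebook runs as-is; feel free to change them as needed.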

In [10]:
EXPERIMENT_DATA_ARGS = {
    "ffhq_aging": {
        "model_path": "../pretrained_models/sam_ffhq_aging.pt",
        "transform": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    }
}
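`ToTensor` scales pixel values to [0, 1]. `Normalize` with mean 0.5 and std 0.5 on every channel then computes (x - 0.5) / 0.5 = 2x - 1, so the model sees inputs in [-1, 1].

The plain-Python sketch below illustrates that mapping per value, together with an inverse back to 8-bit pixels. Two caveats apply:
- `normalize` and `to_uint8` are illustrative names, not functions from the repository.
- The rounding and clamping in `to_uint8` are assumptions and may differ from how `tensor2im` converts the model's outputs.

```python
def normalize(values, mean=0.5, std=0.5):
    """Per-value equivalent of transforms.Normalize for one channel: (x - mean) / std."""
    return [(v - mean) / std for v in values]


def to_uint8(values):
    """Map values in [-1, 1] back to 0..255 pixel intensities (clamped, rounded)."""
    out = []
    for v in values:
        v = min(max(v, -1.0), 1.0)  # clamp decoder outputs that overshoot the range
        out.append(int(round((v + 1) / 2 * 255)))
    return out


print(normalize([0.0, 0.5, 1.0]))  # [-1.0, 0.0, 1.0]
```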

In [11]:
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]

## Step 4: Load Pretrained Model
We assume that you have downloaded the pretrained aging model and placed it in the path defined above.

In [12]:
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')

In [13]:
opts = ckpt['opts']
pprint.pprint(opts)

{'aging_lambda': 5.0,
 'batch_size': 6,
 'board_interval': 50,
 'checkpoint_path': None,
 'cycle_lambda': 1.0,
 'dataset_type': 'ffhq_aging',
 'device': 'cuda',
 'exp_dir': '',
 'id_lambda': 0.1,
 'image_interval': 100,
 'input_nc': 4,
 'l2_lambda': 0.25,
 'l2_lambda_aging': 0.25,
 'l2_lambda_crop': 1.0,
 'label_nc': 0,
 'learning_rate': 0.0001,
 'lpips_lambda': 0.1,
 'lpips_lambda_aging': 0.1,
 'lpips_lambda_crop': 0.6,
 'max_steps': 500000,
 'optim_name': 'ranger',
 'output_size': 1024,
 'pretrained_psp_path': '',
 'save_interval': 2500,
 'start_from_encoded_w_plus': True,
 'start_from_latent_avg': False,
 'stylegan_weights': '',
 'target_age': 'uniform_random',
 'test_batch_size': 6,
 'test_workers': 6,
 'train_decoder': False,
 'use_weighted_id_loss': True,
 'val_interval': 2500,
 'w_norm_lambda': 0.005,
 'workers': 6}
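`input_nc` is 4 because the encoder receives the RGB image plus one extra channel encoding the target age. That extra channel is what `AgeTransformer` adds in the inference loop below.

The sketch below assumes the repository's behavior of filling that channel with `target_age / 100` and shows the operation on nested lists. `add_age_channel` is a hypothetical helper, not the repository's API.

```python
def add_age_channel(image, target_age):
    """Append a constant channel holding target_age / 100 to a C x H x W image (nested lists).

    Assumption: this mirrors the channel that AgeTransformer concatenates to the input tensor.
    """
    height = len(image[0])
    width = len(image[0][0])
    age_value = target_age / 100
    age_channel = [[age_value] * width for _ in range(height)]
    return image + [age_channel]  # new list; the input image is left unchanged


rgb = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(3)]  # 3 x 2 x 2 dummy image
with_age = add_age_channel(rgb, 30)
print(len(with_age))  # 4 channels, matching input_nc
```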


In [14]:
# update the training options
opts['checkpoint_path'] = model_path

In [15]:
opts = Namespace(**opts)
net = pSp(opts)
net.eval()
net.cuda()
print('Model successfully loaded!')

Loading SAM from checkpoint: ../pretrained_models/sam_ffhq_aging.pt
Model successfully loaded!


### Utils for Generating MP4 

In [16]:
import imageio
from tqdm import tqdm
import matplotlib

matplotlib.use('module://ipykernel.pylab.backend_inline')
%matplotlib inline

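def generate_mp4(out_name, images, kwargs):
    """Write a list of HxWx3 uint8 frames to `<out_name>.mp4`; kwargs (e.g. fps) go to imageio."""
    writer = imageio.get_writer(out_name + '.mp4', **kwargs)
    for image in images:
        writer.append_data(image)
    writer.close()


def run_on_batch_to_vecs(inputs, net):
    """Encode a batch of 4-channel (RGB + target age) inputs and return the latent codes on the CPU."""
    _, result_batch = net(inputs.to("cuda").float(), return_latents=True, randomize_noise=False, resize=False)
    return result_batch.cpu()


def get_result_from_vecs(vectors_a, vectors_b, alpha):
    """Decode the linear blend (1 - alpha) * a + alpha * b of each pair of latent codes.

    Note: uses the global `net` loaded above.
    """
    results = []
    for i in range(len(vectors_a)):
        cur_vec = vectors_b[i] * alpha + vectors_a[i] * (1 - alpha)
        res = net(cur_vec.cuda(), randomize_noise=False, input_code=True, input_is_full=True, resize=False)
        results.append(res[0])
    return results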
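`get_result_from_vecs` decodes a straight-line blend between two latent codes:
- at `alpha = 0` it reproduces `vectors_a`;
- at `alpha = 1` it reproduces `vectors_b`;
- intermediate values move along the line between them.

The element-wise version below is a minimal sketch of that blend on plain lists. `lerp` is an illustrative name; the notebook applies the same formula to whole latent tensors before passing them to the decoder.

```python
def lerp(vec_a, vec_b, alpha):
    """Linear interpolation b * alpha + a * (1 - alpha), element by element."""
    return [b * alpha + a * (1 - alpha) for a, b in zip(vec_a, vec_b)]


print(lerp([0, 10], [10, 20], 0.5))  # [5.0, 15.0], halfway between the two codes
```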

In [17]:
SEED = 42
np.random.seed(SEED)

img_transforms = EXPERIMENT_ARGS['transform']
n_transition = 25
kwargs = {'fps': 40}
save_path = "notebooks/animations"
os.makedirs(save_path, exist_ok=True)

#################################################################
# TODO: define your image paths here to be fed into the model
#################################################################
root_dir = 'notebooks/images'
ims = ['866', '1287', '2468']
im_paths = [os.path.join(root_dir, im) + '.jpg' for im in ims]

# NOTE: Please make sure the images are pre-aligned!

target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0]
age_transformers = [AgeTransformer(target_age=age) for age in target_ages]

for image_path in im_paths:
    image_name = os.path.basename(image_path)
    print(f'Working on image: {image_name}')
    original_image = Image.open(image_path).convert("RGB")
    input_image = img_transforms(original_image)

    # no gradients are needed at inference time; this also keeps GPU memory usage down
    with torch.no_grad():
        # encode the image once per target age; AgeTransformer appends the target-age
        # channel to the RGB tensor, giving the 4-channel input the model expects (input_nc = 4)
        all_vecs = []
        for age_transformer in age_transformers:
            input_age_batch = torch.stack([age_transformer(input_image.cpu()).to('cuda')])
            result_vec = run_on_batch_to_vecs(input_age_batch, net)
            all_vecs.append([result_vec])

        # decode n_transition linearly interpolated latents between each pair of consecutive ages
        images = []
        alpha_vals = np.linspace(0, 1, n_transition).tolist()
        for i in range(1, len(target_ages)):
            for alpha in tqdm(alpha_vals):
                result_image = get_result_from_vecs(all_vecs[i-1], all_vecs[i], alpha)[0]
                images.append(np.array(tensor2im(result_image)))

    animation_path = os.path.join(save_path, f"{image_name}_animation")
    generate_mp4(animation_path, images, kwargs)

Working on image: 866.jpg


100%|██████████| 25/25 [00:02<00:00,  8.84it/s]
100%|██████████| 25/25 [00:02<00:00,  8.69it/s]
100%|██████████| 25/25 [00:02<00:00,  8.54it/s]
100%|██████████| 25/25 [00:02<00:00,  8.51it/s]
100%|██████████| 25/25 [00:03<00:00,  8.08it/s]
100%|██████████| 25/25 [00:02<00:00,  8.57it/s]
100%|██████████| 25/25 [00:02<00:00,  8.60it/s]
100%|██████████| 25/25 [00:02<00:00,  8.67it/s]
100%|██████████| 25/25 [00:02<00:00,  8.51it/s]
100%|██████████| 25/25 [00:02<00:00,  8.45it/s]
100%|██████████| 25/25 [00:02<00:00,  8.61it/s]
100%|██████████| 25/25 [00:02<00:00,  8.50it/s]
100%|██████████| 25/25 [00:02<00:00,  8.54it/s]
100%|██████████| 25/25 [00:02<00:00,  8.50it/s]
100%|██████████| 25/25 [00:02<00:00,  8.55it/s]
100%|██████████| 25/25 [00:02<00:00,  8.58it/s]
100%|██████████| 25/25 [00:02<00:00,  8.55it/s]
100%|██████████| 25/25 [00:02<00:00,  8.53it/s]
100%|██████████| 25/25 [00:02<00:00,  8.58it/s]
100%|██████████| 25/25 [00:02<00:00,  8.74it/s]


Working on image: 1287.jpg


100%|██████████| 25/25 [00:02<00:00,  9.65it/s]
100%|██████████| 25/25 [00:02<00:00,  9.57it/s]
100%|██████████| 25/25 [00:02<00:00,  9.48it/s]
100%|██████████| 25/25 [00:02<00:00,  9.38it/s]
100%|██████████| 25/25 [00:02<00:00,  9.38it/s]
100%|██████████| 25/25 [00:02<00:00,  9.27it/s]
100%|██████████| 25/25 [00:02<00:00,  9.24it/s]
100%|██████████| 25/25 [00:02<00:00,  9.19it/s]
100%|██████████| 25/25 [00:02<00:00,  9.27it/s]
100%|██████████| 25/25 [00:02<00:00,  9.36it/s]
100%|██████████| 25/25 [00:02<00:00,  9.28it/s]
100%|██████████| 25/25 [00:02<00:00,  9.36it/s]
100%|██████████| 25/25 [00:02<00:00,  9.40it/s]
100%|██████████| 25/25 [00:02<00:00,  9.30it/s]
100%|██████████| 25/25 [00:02<00:00,  9.32it/s]
100%|██████████| 25/25 [00:02<00:00,  9.38it/s]
100%|██████████| 25/25 [00:02<00:00,  9.34it/s]
100%|██████████| 25/25 [00:02<00:00,  9.25it/s]
100%|██████████| 25/25 [00:02<00:00,  9.42it/s]
100%|██████████| 25/25 [00:02<00:00,  9.23it/s]


Working on image: 2468.jpg


100%|██████████| 25/25 [00:02<00:00,  8.79it/s]
100%|██████████| 25/25 [00:02<00:00,  8.40it/s]
100%|██████████| 25/25 [00:02<00:00,  8.40it/s]
100%|██████████| 25/25 [00:02<00:00,  8.47it/s]
100%|██████████| 25/25 [00:02<00:00,  8.43it/s]
100%|██████████| 25/25 [00:03<00:00,  8.31it/s]
100%|██████████| 25/25 [00:02<00:00,  8.45it/s]
100%|██████████| 25/25 [00:02<00:00,  8.41it/s]
100%|██████████| 25/25 [00:02<00:00,  8.86it/s]
100%|██████████| 25/25 [00:02<00:00,  9.28it/s]
100%|██████████| 25/25 [00:02<00:00,  9.31it/s]
100%|██████████| 25/25 [00:02<00:00,  8.89it/s]
100%|██████████| 25/25 [00:02<00:00,  8.41it/s]
100%|██████████| 25/25 [00:03<00:00,  8.27it/s]
100%|██████████| 25/25 [00:02<00:00,  8.45it/s]
100%|██████████| 25/25 [00:02<00:00,  8.42it/s]
100%|██████████| 25/25 [00:02<00:00,  8.35it/s]
100%|██████████| 25/25 [00:02<00:00,  8.36it/s]
100%|██████████| 25/25 [00:02<00:00,  8.47it/s]
100%|██████████| 25/25 [00:02<00:00,  8.70it/s]
