# From README.md

# FaderNetworks

PyTorch implementation of [Fader Networks](https://arxiv.org/pdf/1706.00409.pdf) (NIPS 2017).

<p align="center"><a href=https://github.com/facebookresearch/FaderNetworks/blob/master/images/interpolation.jpg?raw=true><img width="100%" src="./images/interpolation.jpg" /></a></p>

Fader Networks can generate different realistic versions of images by modifying attributes such as gender or age group. They can swap several attributes at once and continuously interpolate between attribute values. In this repository we provide the code to reproduce the results presented in the paper, as well as trained models.

### Single-attribute swap

Below are some examples of different attribute swaps:

<p align="center"><a href=https://github.com/facebookresearch/FaderNetworks/blob/master/images/swap.jpg?raw=true><img width="100%" src="./images/swap.jpg" /></a></p>

### Multi-attribute swap

Fader Networks can also disentangle multiple attributes at once:

<p align="center"><a href=https://github.com/facebookresearch/FaderNetworks/blob/master/images/multi_attr.jpg?raw=true><img width="100%" src="./images/multi_attr.jpg" /></a></p>

## Model

<p align="center"><a href=https://github.com/facebookresearch/FaderNetworks/blob/master/images/v3.png?raw=true><img width="70%" src="./images/v3.png" /></a></p>

The main branch of the model (Inference Model) is an image autoencoder. Given an image `x` and an attribute `y` (e.g. male/female), the decoder is trained to reconstruct the image from the latent state `E(x)` and `y`. The other branch (Adversarial Component) is a discriminator trained to predict the attribute from the latent state. The encoder of the Inference Model is trained to reconstruct the image and also to fool the discriminator. This removes the attribute-related information from `E(x)`. As a result, the decoder must rely on `y` to reconstruct the image properly. During training the model uses the real attribute values, but at test time `y` can be manipulated to generate variations of the original image.
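The two objectives can be written as equations. This is our reading of the paper; the README does not spell them out. The notation is:

- $m$ is the number of training pairs $(x, y)$ per batch.
- $y \in \{0, 1\}$ is a binary attribute.
- $P_{\theta_{\text{dis}}}(y \mid E(x))$ is the probability the discriminator assigns to $y$.

$$
\mathcal{L}_{\text{dis}}(\theta_{\text{dis}}) = -\frac{1}{m} \sum_{(x, y)} \log P_{\theta_{\text{dis}}}\big(y \mid E_{\theta_{\text{enc}}}(x)\big)
$$

$$
\mathcal{L}_{\text{AE}}(\theta_{\text{enc}}, \theta_{\text{dec}}) = \frac{1}{m} \sum_{(x, y)} \Big[ \big\lVert D_{\theta_{\text{dec}}}\big(E_{\theta_{\text{enc}}}(x), y\big) - x \big\rVert_2^2 - \lambda_E \log P_{\theta_{\text{dis}}}\big(1 - y \mid E_{\theta_{\text{enc}}}(x)\big) \Big]
$$

The discriminator minimizes the first loss. The encoder and decoder minimize the second. Its last term rewards latent codes from which the discriminator predicts the wrong attribute value.

In the code, $\lambda_E$ appears to correspond to `--lambda_lat_dis`. For attributes with more than two categories, the flipped target $1 - y$ is replaced by a random different category (see `get_attr_loss(..., flip=True)` in `model.py`).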

## Dependencies
* Python 2/3 with [NumPy](http://www.numpy.org/)/[SciPy](https://www.scipy.org/)
* [PyTorch](http://pytorch.org/)
* OpenCV
* CUDA


## Installation

Simply clone the repository:

```bash
git clone https://github.com/facebookresearch/FaderNetworks.git
cd FaderNetworks
```

## Dataset
Download the aligned and cropped CelebA dataset from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. Extract all images and move them to the `data/img_align_celeba/` folder. There should be 202599 images. The dataset also provides a file `list_attr_celeba.txt` containing the list of the 40 attributes associated with each image. Move it to `data/`. Then simply run:

```bash
cd data
./preprocess.py  # or: python preprocess.py
```

The script resizes the images and creates two files: `images_256_256.pth` and `attributes.pth`. The first contains a tensor of size `(202599, 3, 256, 256)` with all resized images concatenated. You can change the image size in `preprocess.py` to work with other resolutions. The second file is a preprocessed version of the attributes.
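The images are stored as 8-bit values, one byte per channel value; conversion to floats in [-1, 1] only happens at load time in `normalize_images`. `torch.load` reads the whole image file into memory, so its size is worth estimating before running the script. A minimal sketch of that arithmetic, using a hypothetical helper `dataset_bytes` that ignores serialization overhead:

```python
def dataset_bytes(n_images, img_sz, n_channels=3, bytes_per_value=1):
    """Approximate size of the preprocessed (N, C, H, W) uint8 image tensor."""
    return n_images * n_channels * img_sz * img_sz * bytes_per_value


for size in (256, 128):
    n_bytes = dataset_bytes(202599, size)
    print("%d x %d: %d bytes (~%.1f GiB)" % (size, size, n_bytes, n_bytes / 2.0 ** 30))
```

For the full 202599-image set, this gives about 37 GiB at 256x256 and about 9.3 GiB at 128x128. Check these figures against the available RAM.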

## Pretrained models
You can download pretrained classifiers and Fader Networks by running:

```bash
cd models
./download.sh
```

## Train your own models

### Train a classifier
To train your own model, you first need a classifier that lets the model evaluate the swap quality during training. Training a good classifier is relatively simple for most attributes, and a good model can be trained in a few minutes. We provide a trained classifier for all attributes in `models/classifier256.pth`. The classifier does not need to be state-of-the-art. It does not take part in optimizing the Fader Network; it only monitors the swap quality. To train your own classifier, run `classifier.py` with the following parameters:


```bash
python classifier.py

# Main parameters
--img_sz 256                  # image size
--img_fm 3                    # number of feature maps
--attr "*"                    # attributes list. "*" for all attributes

# Network architecture
--init_fm 32                  # number of feature maps in the first layer
--max_fm 512                  # maximum number of feature maps
--hid_dim 512                 # hidden layer size

# Training parameters
--v_flip False                # randomly flip images vertically (data augmentation)
--h_flip True                 # randomly flip images horizontally (data augmentation)
--batch_size 32               # batch size
--optimizer "adam,lr=0.0002"  # optimizer
--clip_grad_norm 5            # clip gradient L2 norm
--n_epochs 1000               # number of epochs
--epoch_size 50000            # number of images per epoch

# Reload
--reload ""                   # reload a trained classifier
--debug False                 # debug mode (if True, load a small subset of the dataset)
```
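The swap quality monitored by this classifier is essentially an accuracy. After an attribute is flipped and the image decoded, the classifier should predict the flipped value. The pure-Python sketch below computes that per-attribute accuracy, mirroring the argmax comparison in `update_predictions` from `model.py`. The name `swap_accuracy` and the toy scores are made up for illustration.

```python
def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def swap_accuracy(preds, targets, attr):
    """Per-attribute accuracy of classifier predictions against the requested attributes.

    preds, targets: one row per image; each row concatenates one block of n_cat
    scores (or one-hot values) per attribute in `attr`, a list of (name, n_cat).
    """
    accuracies = {}
    k = 0
    for name, n_cat in attr:
        correct = [argmax(p[k:k + n_cat]) == argmax(t[k:k + n_cat])
                   for p, t in zip(preds, targets)]
        accuracies[name] = sum(correct) / float(len(correct))
        k += n_cat
    return accuracies


attr = [("Male", 2)]
targets = [[0, 1], [1, 0], [0, 1], [1, 0]]                # attribute values given to the decoder
preds = [[0.2, 0.8], [0.9, 0.1], [0.7, 0.3], [0.4, 0.6]]  # classifier scores on generated images
print(swap_accuracy(preds, targets, attr))  # {'Male': 0.5}
```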


### Train a Fader Network

You can train a Fader Network with `train.py`. The autoencoder can receive feedback from:
- The image reconstruction loss
- The latent discriminator loss
- The PatchGAN discriminator loss
- The classifier loss

In the paper, only the first two losses are used, but the other two may further improve the results. You can tune the impact of each loss with the `lambda_ae`, `lambda_lat_dis`, `lambda_ptc_dis`, and `lambda_clf_dis` coefficients. Below is the complete list of parameters:

```bash
python train.py

# Main parameters
--img_sz 256                      # image size
--img_fm 3                        # number of feature maps
--attr "Male"                     # attributes list. "*" for all attributes

# Networks architecture
--instance_norm False             # use instance normalization instead of batch normalization
--init_fm 32                      # number of feature maps in the first layer
--max_fm 512                      # maximum number of feature maps
--n_layers 6                      # number of layers in the encoder / decoder
--n_skip 0                        # number of skip connections
--deconv_method "convtranspose"   # deconvolution method
--hid_dim 512                     # hidden layer size
--dec_dropout 0                   # dropout in the decoder
--lat_dis_dropout 0.3             # dropout in the latent discriminator

# Training parameters
--n_lat_dis 1                     # number of latent discriminator training steps
--n_ptc_dis 0                     # number of PatchGAN discriminator training steps
--n_clf_dis 0                     # number of classifier training steps
--smooth_label 0.2                # smooth discriminator labels
--lambda_ae 1                     # autoencoder loss coefficient
--lambda_lat_dis 0.0001           # latent discriminator loss coefficient
--lambda_ptc_dis 0                # PatchGAN discriminator loss coefficient
--lambda_clf_dis 0                # classifier loss coefficient
--lambda_schedule 500000          # lambda scheduling (0 to disable)
--v_flip False                    # randomly flip images vertically (data augmentation)
--h_flip True                     # randomly flip images horizontally (data augmentation)
--batch_size 32                   # batch size
--ae_optimizer "adam,lr=0.0002"   # autoencoder optimizer
--dis_optimizer "adam,lr=0.0002"  # discriminator optimizer
--clip_grad_norm 5                # clip gradient L2 norm
--n_epochs 1000                   # number of epochs
--epoch_size 50000                # number of images per epoch

# Reload
--ae_reload ""                    # reload pretrained autoencoder
--lat_dis_reload ""               # reload pretrained latent discriminator
--ptc_dis_reload ""               # reload pretrained PatchGAN discriminator
--clf_dis_reload ""               # reload pretrained classifier
--eval_clf ""                     # evaluation classifier (trained with classifier.py)
--debug False                     # debug mode (if True, load a small subset of the dataset)
```
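The `lambda_*` coefficients above weight the autoencoder's four feedback signals. The sketch below assumes that a positive `--lambda_schedule` ramps the adversarial coefficients linearly from 0 to their target values over that many iterations; `train.py` is not shown here, so this behavior is an assumption. The helpers `scheduled_lambda` and `autoencoder_loss` are hypothetical names.

```python
def scheduled_lambda(target, n_iter, schedule):
    """Hypothetical linear warm-up of a loss coefficient.

    Returns 0 at iteration 0 and `target` from iteration `schedule` onward;
    schedule == 0 disables the warm-up.
    """
    if schedule <= 0:
        return target
    return target * min(n_iter, schedule) / float(schedule)


def autoencoder_loss(losses, n_iter, lambda_ae=1.0, lambda_lat_dis=0.0001,
                     lambda_ptc_dis=0.0, lambda_clf_dis=0.0, lambda_schedule=500000):
    """Weighted sum of the four feedback signals (`losses` is a dict of floats)."""
    total = lambda_ae * losses["ae"]
    total += scheduled_lambda(lambda_lat_dis, n_iter, lambda_schedule) * losses["lat_dis"]
    total += scheduled_lambda(lambda_ptc_dis, n_iter, lambda_schedule) * losses["ptc_dis"]
    total += scheduled_lambda(lambda_clf_dis, n_iter, lambda_schedule) * losses["clf_dis"]
    return total


losses = {"ae": 0.05, "lat_dis": 0.7, "ptc_dis": 0.0, "clf_dis": 0.0}
for n_iter in (0, 250000, 1000000):
    print(n_iter, autoencoder_loss(losses, n_iter))
```

With the default values, the latent discriminator term is switched off at the start. It reaches its full weight of 0.0001 after 500000 iterations.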

## Generate interpolations

Given a trained model, you can use it to swap attributes of images in the dataset. Below are examples using the pretrained models:

```bash
# Narrow Eyes
python interpolate.py --model_path models/narrow_eyes.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path narrow_eyes.png

# Eyeglasses
python interpolate.py --model_path models/eyeglasses.pth --n_images 10 --n_interpolations 10 --alpha_min 2.0 --alpha_max 2.0 --output_path eyeglasses.png

# Age
python interpolate.py --model_path models/young.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path young.png

# Gender
python interpolate.py --model_path models/male.pth --n_images 10 --n_interpolations 10 --alpha_min 2.0 --alpha_max 2.0 --output_path male.png

# Pointy nose
python interpolate.py --model_path models/pointy_nose.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path pointy_nose.png
```

Each command generates an image with 10 rows and 12 columns:

- The first column is the original image.
- The second is the reconstructed image, without altering the attribute.
- The remaining ones are the interpolated images.

`alpha_min` and `alpha_max` set the range of the interpolation. Values greater than 1 produce generations beyond the True / False range of the boolean attribute seen by the model. The variations of some attributes may only be noticeable for high values of alpha. For instance, `alpha_max=2` is usually enough for the "eyeglasses" or "gender" attributes. For the "age" or "narrow eyes" attributes, it is better to go up to `alpha_max=10`.
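The sketch below shows how `alpha_min` and `alpha_max` could turn into decoder inputs. It assumes `interpolate.py` sweeps alpha linearly from `1 - alpha_min` to `alpha_max` and feeds the one-hot-like vector `[1 - alpha, alpha]` for the boolean attribute. That script is not shown here, so treat this as an illustration of the idea. `interpolation_attributes` is a hypothetical helper.

```python
def interpolation_attributes(alpha_min, alpha_max, n_interpolations):
    """Sketch: attribute vectors (1 - alpha, alpha) for a linear alpha sweep.

    The sweep runs from 1 - alpha_min to alpha_max, so values outside [0, 1]
    extrapolate beyond the training values (1, 0) and (0, 1).
    """
    start, stop = 1.0 - alpha_min, float(alpha_max)
    step = (stop - start) / (n_interpolations - 1) if n_interpolations > 1 else 0.0
    alphas = [start + i * step for i in range(n_interpolations)]
    return [(1.0 - a, a) for a in alphas]


grid = interpolation_attributes(alpha_min=2.0, alpha_max=2.0, n_interpolations=10)
n_columns = 2 + len(grid)  # original + reconstruction + interpolations
print(n_columns)  # 12
print(grid[0], grid[-1])
```

With `alpha_min = alpha_max = 2`, the ten vectors run from roughly `(2, -1)` to `(-1, 2)`, pushing the attribute past both of its training values.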


## References

If you find this code useful, please consider citing:

[*Fader Networks: Manipulating Images by Sliding Attributes*](https://arxiv.org/pdf/1706.00409.pdf) - G. Lample, N. Zeghidour, N. Usunier, A. Bordes, L. Denoyer, M'A. Ranzato

```
@inproceedings{lample2017fader,
  title={Fader Networks: Manipulating Images by Sliding Attributes},
  author={Lample, Guillaume and Zeghidour, Neil and Usunier, Nicolas and Bordes, Antoine and DENOYER, Ludovic and others},
  booktitle={Advances in Neural Information Processing Systems},
  pages={5963--5972},
  year={2017}
}
```

Contact: [gl@fb.com](mailto:gl@fb.com), [neilz@fb.com](mailto:neilz@fb.com)


# From preprocess.py

In [1]:
#!/usr/bin/env python
import os
import matplotlib.image as mpimg
import cv2
import numpy as np
import torch
import glob



In [2]:
data_path=r'C:\GitHub\Smart-Education-data\data'

In [7]:
DATASET = 'celeba'  # set to 'celeba-hq' to preprocess CelebA-HQ instead
if DATASET == 'celeba':
    raw_img_path = os.path.join(data_path, 'celeba')
    out_img_path = os.path.join(data_path, 'celeba-out')
    list_attr_path = os.path.join(data_path, 'list_attr_celeba.txt')
elif DATASET == 'celeba-hq':
    raw_img_path = os.path.join(data_path, 'celeba-hq')
    out_img_path = os.path.join(data_path, 'celeba-hq-out')
    # list_attr_path=?

In [9]:
# Count the raw images instead of hard-coding 202599 (the full CelebA set)
N_IMAGES = len(glob.glob(os.path.join(raw_img_path, '*.*')))
print(N_IMAGES)
IMG_SIZE = 128  # same resolution as AttGAN (the README default is 256)
IMG_PATH = os.path.join(out_img_path, 'images_%i_%i.pth' % (IMG_SIZE, IMG_SIZE))  # where the image tensor is saved
ATTR_PATH = os.path.join(data_path, 'attributes.pth')  # where the attribute dict is saved

print(IMG_PATH)

202599
C:\GitHub\Smart-Education-data\data\celeba-out\images_128_128.pth


In [11]:
def preprocess_images(raw_img_path, N_IMAGES, IMG_SIZE, IMG_PATH):
    if os.path.isfile(IMG_PATH):
        print("%s exists, nothing to do." % IMG_PATH)
        return

    # Save next to IMG_PATH (instead of relying on the global out_img_path)
    out_dir = os.path.dirname(IMG_PATH)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # Read and resize images one at a time (CelebA files are named
    # 000001.jpg, 000002.jpg, ...) instead of first loading every raw image
    print("Reading and resizing images from %s ..." % raw_img_path)
    all_images = []
    for i in range(1, N_IMAGES + 1):
        if i % 10000 == 0:
            print(i)
        # Crop 20 rows at the top and bottom: 218x178 aligned images -> 178x178
        image = mpimg.imread(os.path.join(raw_img_path, '%06i.jpg' % i))[20:-20]
        # INTER_AREA when downsampling, INTER_LANCZOS4 when upsampling
        if IMG_SIZE < 178:
            image = cv2.resize(image, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
        elif IMG_SIZE > 178:
            image = cv2.resize(image, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LANCZOS4)
        assert image.shape == (IMG_SIZE, IMG_SIZE, 3)
        all_images.append(image)

    # Stack into a single (N, C, H, W) uint8 tensor
    data = np.concatenate([img.transpose((2, 0, 1))[None] for img in all_images], 0)
    data = torch.from_numpy(data)
    assert data.size() == (N_IMAGES, 3, IMG_SIZE, IMG_SIZE)

    print("Saving images to %s ..." % IMG_PATH)
    # The first 20000 images are also saved separately for debug mode (see load_images)
    torch.save(data[:20000].clone(), os.path.join(out_dir, 'images_%i_%i_20000.pth' % (IMG_SIZE, IMG_SIZE)))
    torch.save(data, IMG_PATH)

In [12]:
def preprocess_attributes(list_attr_path, N_IMAGES, ATTR_PATH):

    if os.path.isfile(ATTR_PATH):
        print("%s exists, nothing to do." % ATTR_PATH)
        return

    # File layout: line 1 = number of images, line 2 = attribute names,
    # then one line per image: "<file name> <40 values in {-1, 1}>"
    with open(list_attr_path, 'r') as f:
        attr_lines = [line.rstrip() for line in f]
    assert len(attr_lines) == N_IMAGES + 2

    attr_keys = attr_lines[1].split()
    # np.bool was removed in NumPy 1.24; the builtin bool is the equivalent dtype
    attributes = {k: np.zeros(N_IMAGES, dtype=bool) for k in attr_keys}

    for i, line in enumerate(attr_lines[2:]):
        image_id = i + 1
        split = line.split()
        assert len(split) == 41
        assert split[0] == ('%06i.jpg' % image_id)
        assert all(x in ['-1', '1'] for x in split[1:])
        for j, value in enumerate(split[1:]):
            attributes[attr_keys[j]][i] = value == '1'

    print("Saving attributes to %s ..." % ATTR_PATH)
    torch.save(attributes, ATTR_PATH)
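
The attribute file format that `preprocess_attributes` expects can be tried on a toy input without the real dataset. The sketch below re-implements the same parsing with plain lists instead of NumPy arrays. The three-attribute excerpt is made up (the real file has 40 attribute columns), and `parse_attr_file` is a hypothetical name.

```python
import io


def parse_attr_file(lines):
    """Parse CelebA-style attribute lines into {attribute_name: [bool, ...]}."""
    n_images = int(lines[0])       # line 1: number of images
    keys = lines[1].split()        # line 2: attribute names
    attributes = {k: [False] * n_images for k in keys}
    for i, line in enumerate(lines[2:2 + n_images]):
        fields = line.split()      # split() also handles repeated spaces
        assert len(fields) == len(keys) + 1
        assert fields[0] == '%06i.jpg' % (i + 1)
        for key, value in zip(keys, fields[1:]):
            assert value in ('-1', '1')
            attributes[key][i] = value == '1'
    return attributes


# Toy excerpt: 2 images, 3 attributes
sample = io.StringIO("2\nEyeglasses Male Young\n000001.jpg -1  1 1\n000002.jpg  1 -1 -1\n")
attrs = parse_attr_file([line.rstrip() for line in sample])
print(attrs['Male'])  # [True, False]
```
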

In [7]:
# preprocess_images(raw_img_path, N_IMAGES, IMG_SIZE, IMG_PATH)
# preprocess_attributes(list_attr_path, N_IMAGES, ATTR_PATH)

# From model.py

In [13]:
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F


def build_layers(img_sz, img_fm, init_fm, max_fm, n_layers, n_attr, n_skip,
                 deconv_method, instance_norm, enc_dropout, dec_dropout):
    """
    Build auto-encoder layers.
    """
    assert init_fm <= max_fm
    assert n_skip <= n_layers - 1
    assert np.log2(img_sz).is_integer()
    assert n_layers <= int(np.log2(img_sz))
    assert type(instance_norm) is bool
    assert 0 <= enc_dropout < 1
    assert 0 <= dec_dropout < 1
    norm_fn = nn.InstanceNorm2d if instance_norm else nn.BatchNorm2d

    enc_layers = []
    dec_layers = []

    n_in = img_fm
    n_out = init_fm

    for i in range(n_layers):
        enc_layer = []
        dec_layer = []
        skip_connection = n_layers - (n_skip + 1) <= i < n_layers - 1
        n_dec_in = n_out + n_attr + (n_out if skip_connection else 0)
        n_dec_out = n_in

        # encoder layer
        enc_layer.append(nn.Conv2d(n_in, n_out, 4, 2, 1))
        if i > 0:
            enc_layer.append(norm_fn(n_out, affine=True))
        enc_layer.append(nn.LeakyReLU(0.2, inplace=True))
        if enc_dropout > 0:
            enc_layer.append(nn.Dropout(enc_dropout))

        # decoder layer
        if deconv_method == 'upsampling':
            dec_layer.append(nn.UpsamplingNearest2d(scale_factor=2))
            dec_layer.append(nn.Conv2d(n_dec_in, n_dec_out, 3, 1, 1))
        elif deconv_method == 'convtranspose':
            dec_layer.append(nn.ConvTranspose2d(n_dec_in, n_dec_out, 4, 2, 1, bias=False))
        else:
            assert deconv_method == 'pixelshuffle'
            dec_layer.append(nn.Conv2d(n_dec_in, n_dec_out * 4, 3, 1, 1))
            dec_layer.append(nn.PixelShuffle(2))
        if i > 0:
            dec_layer.append(norm_fn(n_dec_out, affine=True))
            if dec_dropout > 0 and i >= n_layers - 3:
                dec_layer.append(nn.Dropout(dec_dropout))
            dec_layer.append(nn.ReLU(inplace=True))
        else:
            dec_layer.append(nn.Tanh())

        # update
        n_in = n_out
        n_out = min(2 * n_out, max_fm)
        enc_layers.append(nn.Sequential(*enc_layer))
        dec_layers.insert(0, nn.Sequential(*dec_layer))

    return enc_layers, dec_layers


class AutoEncoder(nn.Module):

    def __init__(self, params):
        super(AutoEncoder, self).__init__()

        self.img_sz = params.img_sz
        self.img_fm = params.img_fm
        self.instance_norm = params.instance_norm
        self.init_fm = params.init_fm
        self.max_fm = params.max_fm
        self.n_layers = params.n_layers
        self.n_skip = params.n_skip
        self.deconv_method = params.deconv_method
        self.dropout = params.dec_dropout
        self.attr = params.attr
        self.n_attr = params.n_attr

        enc_layers, dec_layers = build_layers(self.img_sz, self.img_fm, self.init_fm,
                                              self.max_fm, self.n_layers, self.n_attr,
                                              self.n_skip, self.deconv_method,
                                              self.instance_norm, 0, self.dropout)
        self.enc_layers = nn.ModuleList(enc_layers)
        self.dec_layers = nn.ModuleList(dec_layers)

    def encode(self, x):
        assert x.size()[1:] == (self.img_fm, self.img_sz, self.img_sz)

        enc_outputs = [x]
        for layer in self.enc_layers:
            enc_outputs.append(layer(enc_outputs[-1]))

        assert len(enc_outputs) == self.n_layers + 1
        return enc_outputs

    def decode(self, enc_outputs, y):
        bs = enc_outputs[0].size(0)
        assert len(enc_outputs) == self.n_layers + 1
        assert y.size() == (bs, self.n_attr)

        dec_outputs = [enc_outputs[-1]]
        y = y.unsqueeze(2).unsqueeze(3)
        for i, layer in enumerate(self.dec_layers):
            size = dec_outputs[-1].size(2)
            # attributes
            input = [dec_outputs[-1], y.expand(bs, self.n_attr, size, size)]
            # skip connection
            if 0 < i <= self.n_skip:
                input.append(enc_outputs[-1 - i])
            input = torch.cat(input, 1)
            dec_outputs.append(layer(input))

        assert len(dec_outputs) == self.n_layers + 1
        assert dec_outputs[-1].size() == (bs, self.img_fm, self.img_sz, self.img_sz)
        return dec_outputs

    def forward(self, x, y):
        enc_outputs = self.encode(x)
        dec_outputs = self.decode(enc_outputs, y)
        return enc_outputs, dec_outputs


class LatentDiscriminator(nn.Module):

    def __init__(self, params):
        super(LatentDiscriminator, self).__init__()

        self.img_sz = params.img_sz
        self.img_fm = params.img_fm
        self.init_fm = params.init_fm
        self.max_fm = params.max_fm
        self.n_layers = params.n_layers
        self.n_skip = params.n_skip
        self.hid_dim = params.hid_dim
        self.dropout = params.lat_dis_dropout
        self.attr = params.attr
        self.n_attr = params.n_attr

        self.n_dis_layers = int(np.log2(self.img_sz))
        self.conv_in_sz = self.img_sz // (2 ** (self.n_layers - self.n_skip))  # integer division keeps the size an int in Python 3
        self.conv_in_fm = min(self.init_fm * (2 ** (self.n_layers - self.n_skip - 1)), self.max_fm)
        self.conv_out_fm = min(self.init_fm * (2 ** (self.n_dis_layers - 1)), self.max_fm)

        # discriminator layers are identical to encoder, but convolve until size 1
        enc_layers, _ = build_layers(self.img_sz, self.img_fm, self.init_fm, self.max_fm,
                                     self.n_dis_layers, self.n_attr, 0, 'convtranspose',
                                     False, self.dropout, 0)

        self.conv_layers = nn.Sequential(*(enc_layers[self.n_layers - self.n_skip:]))
        self.proj_layers = nn.Sequential(
            nn.Linear(self.conv_out_fm, self.hid_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(self.hid_dim, self.n_attr)
        )

    def forward(self, x):
        assert x.size()[1:] == (self.conv_in_fm, self.conv_in_sz, self.conv_in_sz)
        conv_output = self.conv_layers(x)
        assert conv_output.size() == (x.size(0), self.conv_out_fm, 1, 1)
        return self.proj_layers(conv_output.view(x.size(0), self.conv_out_fm))


class PatchDiscriminator(nn.Module):
    def __init__(self, params):
        super(PatchDiscriminator, self).__init__()

        self.img_sz = params.img_sz
        self.img_fm = params.img_fm
        self.init_fm = params.init_fm
        self.max_fm = params.max_fm
        self.n_patch_dis_layers = 3

        layers = []
        layers.append(nn.Conv2d(self.img_fm, self.init_fm, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.2, True))

        n_in = self.init_fm
        n_out = min(2 * n_in, self.max_fm)

        for n in range(self.n_patch_dis_layers):
            stride = 1 if n == self.n_patch_dis_layers - 1 else 2
            layers.append(nn.Conv2d(n_in, n_out, kernel_size=4, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            if n < self.n_patch_dis_layers - 1:
                n_in = n_out
                n_out = min(2 * n_out, self.max_fm)

        layers.append(nn.Conv2d(n_out, 1, kernel_size=4, stride=1, padding=1))
        layers.append(nn.Sigmoid())

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        assert x.dim() == 4
        return self.layers(x).view(x.size(0), -1).mean(1).view(x.size(0))


class Classifier(nn.Module):

    def __init__(self, params):
        super(Classifier, self).__init__()

        self.img_sz = params.img_sz
        self.img_fm = params.img_fm
        self.init_fm = params.init_fm
        self.max_fm = params.max_fm
        self.hid_dim = params.hid_dim
        self.attr = params.attr
        self.n_attr = params.n_attr

        self.n_clf_layers = int(np.log2(self.img_sz))
        self.conv_out_fm = min(self.init_fm * (2 ** (self.n_clf_layers - 1)), self.max_fm)

        # classifier layers are identical to encoder, but convolve until size 1
        enc_layers, _ = build_layers(self.img_sz, self.img_fm, self.init_fm, self.max_fm,
                                     self.n_clf_layers, self.n_attr, 0, 'convtranspose',
                                     False, 0, 0)

        self.conv_layers = nn.Sequential(*enc_layers)
        self.proj_layers = nn.Sequential(
            nn.Linear(self.conv_out_fm, self.hid_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(self.hid_dim, self.n_attr)
        )

    def forward(self, x):
        assert x.size()[1:] == (self.img_fm, self.img_sz, self.img_sz)
        conv_output = self.conv_layers(x)
        assert conv_output.size() == (x.size(0), self.conv_out_fm, 1, 1)
        return self.proj_layers(conv_output.view(x.size(0), self.conv_out_fm))


def get_attr_loss(output, attributes, flip, params):
    """
    Compute attributes loss.
    """
    assert type(flip) is bool
    k = 0
    loss = 0
    for (_, n_cat) in params.attr:
        # categorical
        x = output[:, k:k + n_cat].contiguous()
        y = attributes[:, k:k + n_cat].max(1)[1].view(-1)
        if flip:
            # generate different categories
            shift = torch.LongTensor(y.size()).random_(n_cat - 1) + 1
            y = (y + Variable(shift.cuda())) % n_cat
        loss += F.cross_entropy(x, y)
        k += n_cat
    return loss


def update_predictions(all_preds, preds, targets, params):
    """
    Update discriminator / classifier predictions.
    """
    assert len(all_preds) == len(params.attr)
    k = 0
    for j, (_, n_cat) in enumerate(params.attr):
        _preds = preds[:, k:k + n_cat].max(1)[1]
        _targets = targets[:, k:k + n_cat].max(1)[1]
        all_preds[j].extend((_preds == _targets).tolist())
        k += n_cat
    assert k == params.n_attr


def get_mappings(params):
    """
    Create a mapping between attributes and their associated IDs.
    """
    if not hasattr(params, 'mappings'):
        mappings = []
        k = 0
        for (_, n_cat) in params.attr:
            assert n_cat >= 2
            mappings.append((k, k + n_cat))
            k += n_cat
        assert k == params.n_attr
        params.mappings = mappings
    return params.mappings


def flip_attributes(attributes, params, attribute_id, new_value=None):
    """
    Randomly flip a set of attributes.
    """
    assert attributes.size(1) == params.n_attr
    mappings = get_mappings(params)
    attributes = attributes.data.clone().cpu()

    def flip_attribute(attribute_id, new_value=None):
        bs = attributes.size(0)
        i, j = mappings[attribute_id]
        attributes[:, i:j].zero_()
        if new_value is None:
            y = torch.LongTensor(bs).random_(j - i)
        else:
            assert new_value in range(j - i)
            y = torch.LongTensor(bs).fill_(new_value)
        attributes[:, i:j].scatter_(1, y.unsqueeze(1), 1)

    if attribute_id == 'all':
        assert new_value is None
        for attribute_id in range(len(params.attr)):
            flip_attribute(attribute_id)
    else:
        assert type(new_value) is int
        flip_attribute(attribute_id, new_value)

    return Variable(attributes.cuda())
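
The shape bookkeeping in `build_layers`, `LatentDiscriminator`, and `PatchDiscriminator` can be checked by hand. The pure-Python sketch below re-derives it for the README defaults (`--init_fm 32 --max_fm 512 --n_layers 6 --n_skip 0`). It also illustrates the label shift used by `get_attr_loss(..., flip=True)`. All function names are hypothetical helpers, not part of the repository.

```python
import math
import random


def encoder_shapes(img_sz, init_fm=32, max_fm=512, n_layers=6):
    """(channels, height, width) after each encoder layer; each Conv2d(4, 2, 1) halves the size."""
    shapes, n_out, size = [], init_fm, img_sz
    for _ in range(n_layers):
        size //= 2
        shapes.append((n_out, size, size))
        n_out = min(2 * n_out, max_fm)
    return shapes


def latent_dis_layers(img_sz, n_layers=6, n_skip=0):
    """Extra stride-2 convs the latent discriminator applies to E(x) to reach a 1x1 map."""
    return int(math.log2(img_sz)) - (n_layers - n_skip)


def conv_out(size, kernel=4, stride=2, padding=1):
    """Output size of a square Conv2d without dilation."""
    return (size + 2 * padding - kernel) // stride + 1


def patch_map_size(img_sz, n_patch_dis_layers=3):
    """Side of the score map that PatchDiscriminator averages into one value per image."""
    size = conv_out(img_sz, stride=2)  # first conv
    for n in range(n_patch_dis_layers):
        size = conv_out(size, stride=1 if n == n_patch_dis_layers - 1 else 2)
    return conv_out(size, stride=1)    # final 1-channel conv


def flip_labels(labels, n_cat, rng=random):
    """Shift each label by a random amount in [1, n_cat - 1] so it always changes category."""
    return [(y + rng.randint(1, n_cat - 1)) % n_cat for y in labels]


print(encoder_shapes(256)[-1])    # (512, 4, 4): the latent E(x) for 256x256 images
print(latent_dis_layers(256))     # 2: 4x4 -> 2x2 -> 1x1
print(patch_map_size(256))        # 30
print(flip_labels([0, 1, 1], 2))  # [1, 0, 0]: binary labels are simply inverted
```

The first decoder layer therefore receives `512 + n_attr` channels, because `decode` concatenates the attribute maps to `E(x)`. For 128x128 images, the same defaults give a 2x2 latent map.
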


# From logger.py

In [14]:
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import logging
import time
from datetime import timedelta


class LogFormatter():

    def __init__(self):
        self.start_time = time.time()

    def format(self, record):
        elapsed_seconds = round(record.created - self.start_time)

        prefix = "%s - %s - %s" % (
            record.levelname,
            time.strftime('%x %X'),
            timedelta(seconds=elapsed_seconds)
        )
        message = record.getMessage()
        message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
        return "%s - %s" % (prefix, message)


def create_logger(filepath):
    """
    Create a logger.
    """
    # create log formatter
    log_formatter = LogFormatter()

    # create file handler and set level to debug
    if filepath is not None:
        file_handler = logging.FileHandler(filepath, "a")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(log_formatter)

    # create console handler and set level to info
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(log_formatter)

    # create logger and set level to debug
    logger = logging.getLogger()
    logger.handlers = []
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    if filepath is not None:
        logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # reset logger elapsed time
    def reset_time():
        log_formatter.start_time = time.time()
    logger.reset_time = reset_time

    return logger


# From loader.py

In [15]:
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import numpy as np
import torch
from torch.autograd import Variable
from logging import getLogger


logger = getLogger()


AVAILABLE_ATTR = [
    "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes", "Bald",
    "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair",
    "Bushy_Eyebrows", "Chubby", "Double_Chin", "Eyeglasses", "Goatee", "Gray_Hair",
    "Heavy_Makeup", "High_Cheekbones", "Male", "Mouth_Slightly_Open", "Mustache",
    "Narrow_Eyes", "No_Beard", "Oval_Face", "Pale_Skin", "Pointy_Nose",
    "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair",
    "Wavy_Hair", "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick",
    "Wearing_Necklace", "Wearing_Necktie", "Young"
]

# DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
DATA_PATH=r'C:\GitHub\Smart-Education-data\data'

def log_attributes_stats(train_attributes, valid_attributes, test_attributes, params):
    """
    Log attributes distributions.
    """
    k = 0
    for (attr_name, n_cat) in params.attr:
        logger.debug('Train %s: %s' % (attr_name, ' / '.join(['%.5f' % train_attributes[:, k + i].mean() for i in range(n_cat)])))
        logger.debug('Valid %s: %s' % (attr_name, ' / '.join(['%.5f' % valid_attributes[:, k + i].mean() for i in range(n_cat)])))
        logger.debug('Test  %s: %s' % (attr_name, ' / '.join(['%.5f' % test_attributes[:, k + i].mean() for i in range(n_cat)])))
        assert train_attributes[:, k:k + n_cat].sum() == train_attributes.size(0)
        assert valid_attributes[:, k:k + n_cat].sum() == valid_attributes.size(0)
        assert test_attributes[:, k:k + n_cat].sum() == test_attributes.size(0)
        k += n_cat
    assert k == params.n_attr


def load_images(params):
    """
    Load celebA dataset.
    """
    # load data
    images_filename = 'images_%i_%i_20000.pth' if params.debug else 'images_%i_%i.pth'
    images_filename = images_filename % (params.img_sz, params.img_sz)
    # images = torch.load(os.path.join(DATA_PATH, images_filename))
    images = torch.load(os.path.join(DATA_PATH, 'celeba-out', images_filename))
    attributes = torch.load(os.path.join(DATA_PATH, 'attributes.pth'))

    # parse attributes
    attrs = []
    for name, n_cat in params.attr:
        for i in range(n_cat):
            attrs.append(torch.FloatTensor((attributes[name] == i).astype(np.float32)))
    attributes = torch.cat([x.unsqueeze(1) for x in attrs], 1)
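    # split train / valid / test
    # NOTE: both branches below put every image in the test set (train and valid
    # are empty). The original CelebA boundaries are kept in the comments and
    # must be restored before training.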
    if params.debug:
        train_index = 0
        valid_index = 0
        test_index = len(images)
        # train_index = 10000
        # valid_index = 15000
        # test_index = 20000
    else:
        # train_index = 162770
        train_index = 0
        valid_index = 0
        # valid_index = 162770 + 19867
        test_index = len(images)
    train_images = images[:train_index]
    valid_images = images[train_index:valid_index]
    test_images = images[valid_index:test_index]
    train_attributes = attributes[:train_index]
    valid_attributes = attributes[train_index:valid_index]
    test_attributes = attributes[valid_index:test_index]
    # log dataset statistics / return dataset
    logger.info('%i / %i / %i images with attributes for train / valid / test sets'
                % (len(train_images), len(valid_images), len(test_images)))
    log_attributes_stats(train_attributes, valid_attributes, test_attributes, params)
    images = train_images, valid_images, test_images
    attributes = train_attributes, valid_attributes, test_attributes
    return images, attributes


def normalize_images(images):
    """
    Normalize image values.
    """
    return images.float().div_(255.0).mul_(2.0).add_(-1)


class DataSampler(object):

    def __init__(self, images, attributes, params):
        """
        Initialize the data sampler with training data.
        """
        assert images.size(0) == attributes.size(0), (images.size(), attributes.size())
        self.images = images
        self.attributes = attributes
        self.batch_size = params.batch_size
        self.v_flip = params.v_flip
        self.h_flip = params.h_flip

    def __len__(self):
        """
        Number of images in the object dataset.
        """
        return self.images.size(0)

    def train_batch(self, bs):
        """
        Get a batch of random images with their attributes.
        """
        # image IDs
        idx = torch.LongTensor(bs).random_(len(self.images))

        # select images / attributes
        batch_x = normalize_images(self.images.index_select(0, idx).cuda())
        batch_y = self.attributes.index_select(0, idx).cuda()

        # data augmentation
        if self.v_flip and np.random.rand() <= 0.5:
            batch_x = batch_x.index_select(2, torch.arange(batch_x.size(2) - 1, -1, -1).long().cuda())
        if self.h_flip and np.random.rand() <= 0.5:
            batch_x = batch_x.index_select(3, torch.arange(batch_x.size(3) - 1, -1, -1).long().cuda())

        return Variable(batch_x, volatile=False), Variable(batch_y, volatile=False)

    def eval_batch(self, i, j):
        """
        Get a batch of images in a range with their attributes.
        """
        assert i < j
        batch_x = normalize_images(self.images[i:j].cuda())
        batch_y = self.attributes[i:j].cuda()
        return Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)

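`load_images` turns each attribute in `params.attr` into `n_cat` one-hot columns and concatenates them in `params.attr` order. Below is a minimal numpy-only sketch of that loop. The helper name `one_hot_attributes` and the toy attribute values (including the 3-category `Hair` attribute) are hypothetical. The real code builds `torch.FloatTensor` columns rather than numpy arrays.

```python
import numpy as np


def one_hot_attributes(attributes, attr):
    """Hypothetical helper mirroring the parsing loop in `load_images`.

    attributes: dict mapping an attribute name to an array of category ids
    attr: list of (name, n_cat) pairs, as produced by `attr_flag`
    Returns a float32 array of shape (n_images, sum of all n_cat).
    """
    columns = []
    for name, n_cat in attr:
        for i in range(n_cat):
            # one column per category: 1.0 where the image has category i
            columns.append((attributes[name] == i).astype(np.float32))
    return np.stack(columns, axis=1)


# toy data: 3 images, one binary attribute and one 3-way attribute
toy = {'Smiling': np.array([0, 1, 1]), 'Hair': np.array([2, 0, 1])}
y = one_hot_attributes(toy, [('Smiling', 2), ('Hair', 3)])
print(y)
```

Each row then has exactly one 1 per attribute block. Later code relies on this when it recovers a category with an argmax over that attribute's columns.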

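Two preprocessing steps happen at batch time: `normalize_images` rescales pixels to [-1, 1], and `train_batch` implements flips by re-indexing one axis in reverse with `index_select`. The numpy sketch below reproduces both. It assumes an `(N, C, H, W)` batch layout, which is consistent with the code flipping dim 2 for `v_flip` and dim 3 for `h_flip`. The function names are illustrative.

```python
import numpy as np


def normalize(images):
    # same arithmetic as `normalize_images`: [0, 255] -> [0, 1] -> [0, 2] -> [-1, 1]
    return images.astype(np.float32) / 255.0 * 2.0 - 1.0


def flip(batch, dim):
    # reverse one axis by selecting indices n-1, ..., 0, like the
    # `index_select(dim, arange(n - 1, -1, -1))` calls in `train_batch`
    idx = np.arange(batch.shape[dim] - 1, -1, -1)
    return np.take(batch, idx, axis=dim)


# (N, C, H, W) batch: dim 2 is height (vertical flip), dim 3 is width (horizontal flip)
batch = np.arange(2 * 3 * 4 * 5, dtype=np.uint8).reshape(2, 3, 4, 5)
x = normalize(batch)
h_flipped = flip(x, 3)
v_flipped = flip(x, 2)
```

Recent PyTorch releases also provide `torch.flip(batch_x, [3])` for the same horizontal flip, without building an index tensor.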
# From utils.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import re
import pickle
import random
import inspect
import argparse
import subprocess
import torch
from torch import optim
from logging import getLogger

# from .logger import create_logger # above
# from .loader import AVAILABLE_ATTR # above


FALSY_STRINGS = {'off', 'false', '0'}
TRUTHY_STRINGS = {'on', 'true', '1'}

# MODELS_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'models')
MODELS_PATH = r'C:\GitHub\Smart-Education-data\saved_models_fader'  # machine-specific override of the default above

logger = getLogger()


def initialize_exp(params):
    """
    Experiment initialization.
    """
    # dump parameters
    params.dump_path = get_dump_path(params)
    with open(os.path.join(params.dump_path, 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    # create a logger
    logger = create_logger(os.path.join(params.dump_path, 'train.log'))
    logger.info('============ Initialized logger ============')
    logger.info('\n'.join('%s: %s' % (k, str(v)) for k, v
                          in sorted(dict(vars(params)).items())))
    return logger


def bool_flag(s):
    """
    Parse boolean arguments from the command line.
    """
    if s.lower() in FALSY_STRINGS:
        return False
    elif s.lower() in TRUTHY_STRINGS:
        return True
    else:
        raise argparse.ArgumentTypeError("invalid value for a boolean flag. use 0 or 1")


def attr_flag(s):
    """
    Parse attributes parameters.
    """
    if s == "*":
        return s
    attr = s.split(',')
    assert len(attr) == len(set(attr))
    attributes = []
    for x in attr:
        if '.' not in x:
            attributes.append((x, 2))
        else:
            split = x.split('.')
            assert len(split) == 2 and len(split[0]) > 0
            assert split[1].isdigit() and int(split[1]) >= 2
            attributes.append((split[0], int(split[1])))
    return sorted(attributes, key=lambda x: (x[1], x[0]))


def check_attr(params):
    """
    Check attributes validity.
    """
    if params.attr == '*':
        params.attr = attr_flag(','.join(AVAILABLE_ATTR))
    else:
        assert all(name in AVAILABLE_ATTR and n_cat >= 2 for name, n_cat in params.attr)
    params.n_attr = sum([n_cat for _, n_cat in params.attr])


def get_optimizer(model, s):
    """
    Parse optimizer parameters.
    Input should be of the form:
        - "sgd,lr=0.01"
        - "adagrad,lr=0.1,lr_decay=0.05"
    """
    if "," in s:
        method = s[:s.find(',')]
        optim_params = {}
        for x in s[s.find(',') + 1:].split(','):
            split = x.split('=')
            assert len(split) == 2
            assert re.match(r"^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None
            optim_params[split[0]] = float(split[1])
    else:
        method = s
        optim_params = {}

    if method == 'adadelta':
        optim_fn = optim.Adadelta
    elif method == 'adagrad':
        optim_fn = optim.Adagrad
    elif method == 'adam':
        optim_fn = optim.Adam
        optim_params['betas'] = (optim_params.get('beta1', 0.5), optim_params.get('beta2', 0.999))
        optim_params.pop('beta1', None)
        optim_params.pop('beta2', None)
    elif method == 'adamax':
        optim_fn = optim.Adamax
    elif method == 'asgd':
        optim_fn = optim.ASGD
    elif method == 'rmsprop':
        optim_fn = optim.RMSprop
    elif method == 'rprop':
        optim_fn = optim.Rprop
    elif method == 'sgd':
        optim_fn = optim.SGD
        assert 'lr' in optim_params
    else:
        raise Exception('Unknown optimization method: "%s"' % method)

    # check that we give good parameters to the optimizer
    # getargspec was removed in Python 3.11 and rejects keyword-only arguments
    expected_args = inspect.getfullargspec(optim_fn.__init__).args
    assert expected_args[:2] == ['self', 'params']
    if not all(k in expected_args[2:] for k in optim_params.keys()):
        raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
            str(expected_args[2:]), str(optim_params.keys())))

    return optim_fn(model.parameters(), **optim_params)


def clip_grad_norm(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters.
    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.
    Arguments:
        parameters (Iterable[Variable]): an iterable of Variables that will have
            gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
    """
    parameters = list(parameters)
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1:
        return
    for p in parameters:
        p.grad.data.mul_(clip_coef)


def get_dump_path(params):
    """
    Create a directory to store the experiment.
    """
    assert os.path.isdir(MODELS_PATH)

    # create the sweep path if it does not exist
    sweep_path = os.path.join(MODELS_PATH, params.name)
    if not os.path.exists(sweep_path):
        os.makedirs(sweep_path)

    # create a random name for the experiment
    chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
    while True:
        exp_id = ''.join(random.choice(chars) for _ in range(10))
        dump_path = os.path.join(MODELS_PATH, params.name, exp_id)
        if not os.path.isdir(dump_path):
            break

    # create the dump folder
    if not os.path.isdir(dump_path):
        os.makedirs(dump_path)
    return dump_path


def reload_model(model, to_reload, attributes=None):
    """
    Reload a previously trained model.
    """
    # reload the model
    assert os.path.isfile(to_reload)
    to_reload = torch.load(to_reload)

    # check parameters sizes
    model_params = set(model.state_dict().keys())
    to_reload_params = set(to_reload.state_dict().keys())
    assert model_params == to_reload_params, (model_params - to_reload_params,
                                              to_reload_params - model_params)

    # check attributes
    attributes = [] if attributes is None else attributes
    for k in attributes:
        if getattr(model, k, None) is None:
            raise Exception('Attribute "%s" not found in the current model' % k)
        if getattr(to_reload, k, None) is None:
            raise Exception('Attribute "%s" not found in the model to reload' % k)
        if getattr(model, k) != getattr(to_reload, k):
            raise Exception('Attribute "%s" differs between the current model (%s) '
                            'and the one to reload (%s)'
                            % (k, str(getattr(model, k)), str(getattr(to_reload, k))))

    # copy saved parameters
    for k in model.state_dict().keys():
        if model.state_dict()[k].size() != to_reload.state_dict()[k].size():
            raise Exception("Expected tensor {} of size {}, but got {}".format(
                k, model.state_dict()[k].size(),
                to_reload.state_dict()[k].size()
            ))
        model.state_dict()[k].copy_(to_reload.state_dict()[k])


def print_accuracies(values):
    """
    Log accuracies as aligned percentages.
    """
    assert all(len(x) == 2 for x in values)
    for name, value in values:
        logger.info('{:<20}: {:>6}'.format(name, '%.3f%%' % (100 * value)))
    logger.info('')


def get_lambda(l, params):
    """
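    Compute a discriminator loss weight: scaled linearly from 0 to `l` over the
    first `params.lambda_schedule` iterations, or `l` itself if the schedule is 0.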
    """
    s = params.lambda_schedule
    if s == 0:
        return l
    else:
        return l * float(min(params.n_total_iter, s)) / s
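`get_optimizer` accepts strings of the form `"method,key=value,..."`. The sketch below isolates its parsing step so the resulting keyword arguments can be checked without a model. `parse_optimizer_spec` is a hypothetical helper. It skips the optimizer construction and the signature check, and the example strings are illustrative.

```python
import re

# same number pattern as `get_optimizer` (raw string avoids invalid-escape warnings)
NUMBER_RE = re.compile(r"^[+-]?(\d+(\.\d*)?|\.\d+)$")


def parse_optimizer_spec(s):
    """Hypothetical helper: the string-parsing half of `get_optimizer`.

    "adam,lr=0.0002,beta1=0.5" -> ("adam", {"lr": 0.0002, "betas": (0.5, 0.999)})
    """
    method, _, rest = s.partition(',')
    optim_params = {}
    if rest:
        for x in rest.split(','):
            key, sep, value = x.partition('=')
            assert sep == '=' and NUMBER_RE.match(value) is not None, x
            optim_params[key] = float(value)
    if method == 'adam':
        # Adam takes a `betas` tuple; beta1 defaults to 0.5 here, as in `get_optimizer`
        optim_params['betas'] = (optim_params.pop('beta1', 0.5),
                                 optim_params.pop('beta2', 0.999))
    return method, optim_params


print(parse_optimizer_spec("adam,lr=0.0002,beta1=0.5"))
print(parse_optimizer_spec("sgd,lr=0.01"))
```

Note the Adam-specific default of `beta1 = 0.5` instead of PyTorch's usual 0.9. It applies whenever an `adam` spec omits `beta1`.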

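`clip_grad_norm` computes one p-norm over all gradients, as if they were concatenated into a single vector. When that norm exceeds `max_norm`, it rescales every gradient by the same factor. Below is a numpy sketch of the same rule; the gradient values are made up so that their global L2 norm is 5, and `clip_by_global_norm` is an illustrative name.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm, norm_type=2.0):
    """Sketch of the rule in `clip_grad_norm`, on plain numpy arrays.

    Returns (possibly rescaled gradients, global norm before clipping).
    """
    if norm_type == float('inf'):
        total_norm = max(np.abs(g).max() for g in grads)
    else:
        # (sum over all entries of |g|^p)^(1/p), i.e. the p-norm of the concatenation
        total_norm = sum((np.abs(g) ** norm_type).sum() for g in grads) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1:
        return grads, total_norm  # already small enough: left unchanged
    return [g * clip_coef for g in grads], total_norm


grads = [np.array([3.0, 0.0]), np.array([[4.0]])]  # global L2 norm = 5
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
```

Recent PyTorch versions ship `torch.nn.utils.clip_grad_norm_`, which applies this global-norm rule in place and could likely replace the local copy. Check its behavior on the PyTorch version in use.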

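`get_lambda` is a linear warm-up. Each discriminator weight grows from 0 to its final value over the first `params.lambda_schedule` iterations, then stays constant; a schedule of 0 disables the warm-up. The sketch below copies the formula and uses `types.SimpleNamespace` as a stand-in for `params`. The weight `0.0001` and schedule `500000` are illustrative values, not read from this code.

```python
from types import SimpleNamespace


def get_lambda(l, params):
    # same formula as `get_lambda` above: linear warm-up over `lambda_schedule` iterations
    s = params.lambda_schedule
    if s == 0:
        return l
    return l * float(min(params.n_total_iter, s)) / s


# illustrative settings: final weight 0.0001, reached after 500000 iterations
for it in [0, 125000, 250000, 500000, 1000000]:
    params = SimpleNamespace(lambda_schedule=500000, n_total_iter=it)
    print(it, get_lambda(0.0001, params))
```

Since `Trainer.step` increments `params.n_total_iter` once per iteration, the adversarial terms start at zero. During the warm-up the autoencoder is trained mostly on reconstruction.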
# From evaluation.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import json
import numpy as np
from logging import getLogger

# from .model import update_predictions, flip_attributes
# from .utils import print_accuracies


logger = getLogger()


class Evaluator(object):

    def __init__(self, ae, lat_dis, ptc_dis, clf_dis, eval_clf, data, params):
        """
        Evaluator initialization.
        """
        # data / parameters
        self.data = data
        self.params = params

        # modules
        self.ae = ae
        self.lat_dis = lat_dis
        self.ptc_dis = ptc_dis
        self.clf_dis = clf_dis
        self.eval_clf = eval_clf
        assert eval_clf.img_sz == params.img_sz
        assert all(attr in eval_clf.attr for attr in params.attr)

    def eval_reconstruction_loss(self):
        """
        Compute the autoencoder reconstruction loss (mean squared error, averaged over batches).
        """
        data = self.data
        params = self.params
        self.ae.eval()
        bs = params.batch_size

        costs = []
        for i in range(0, len(data), bs):
            batch_x, batch_y = data.eval_batch(i, i + bs)
            _, dec_outputs = self.ae(batch_x, batch_y)
            costs.append(((dec_outputs[-1] - batch_x) ** 2).mean().data[0])

        return np.mean(costs)

    def eval_lat_dis_accuracy(self):
        """
        Compute the latent discriminator prediction accuracy.
        """
        data = self.data
        params = self.params
        self.ae.eval()
        self.lat_dis.eval()
        bs = params.batch_size

        all_preds = [[] for _ in range(len(params.attr))]
        for i in range(0, len(data), bs):
            batch_x, batch_y = data.eval_batch(i, i + bs)
            enc_outputs = self.ae.encode(batch_x)
            preds = self.lat_dis(enc_outputs[-1 - params.n_skip]).data.cpu()
            update_predictions(all_preds, preds, batch_y.data.cpu(), params)

        return [np.mean(x) for x in all_preds]

    def eval_ptc_dis_accuracy(self):
        """
        Compute the patch discriminator prediction accuracy.
        """
        data = self.data
        params = self.params
        self.ae.eval()
        self.ptc_dis.eval()
        bs = params.batch_size

        real_preds = []
        fake_preds = []

        for i in range(0, len(data), bs):
            # batch / encode / decode
            batch_x, batch_y = data.eval_batch(i, i + bs)
            flipped = flip_attributes(batch_y, params, 'all')
            _, dec_outputs = self.ae(batch_x, flipped)
            # predictions
            real_preds.extend(self.ptc_dis(batch_x).data.tolist())
            fake_preds.extend(self.ptc_dis(dec_outputs[-1]).data.tolist())

        return real_preds, fake_preds

    def eval_clf_dis_accuracy(self):
        """
        Compute the classifier discriminator prediction accuracy.
        """
        data = self.data
        params = self.params
        self.ae.eval()
        self.clf_dis.eval()
        bs = params.batch_size

        all_preds = [[] for _ in range(params.n_attr)]
        for i in range(0, len(data), bs):
            # batch / encode / decode
            batch_x, batch_y = data.eval_batch(i, i + bs)
            enc_outputs = self.ae.encode(batch_x)
            # flip all attributes one by one
            k = 0
            for j, (_, n_cat) in enumerate(params.attr):
                # attribute j's columns start after those of the previous
                # attributes (offset k), not at column j
                offset = k
                for value in range(n_cat):
                    flipped = flip_attributes(batch_y, params, j, new_value=value)
                    dec_outputs = self.ae.decode(enc_outputs, flipped)
                    # classify
                    clf_dis_preds = self.clf_dis(dec_outputs[-1])[:, offset:offset + n_cat].max(1)[1].view(-1)
                    all_preds[k].extend((clf_dis_preds.data.cpu() == value).tolist())
                    k += 1
            assert k == params.n_attr

        return [np.mean(x) for x in all_preds]

    def eval_clf_accuracy(self):
        """
        Compute the accuracy of flipped attributes according to the trained classifier.
        """
        data = self.data
        params = self.params
        self.ae.eval()
        bs = params.batch_size

        idx = []
        for j in range(len(params.attr)):
            attr_index = self.eval_clf.attr.index(params.attr[j])
            idx.append(sum([x[1] for x in self.eval_clf.attr[:attr_index]]))

        all_preds = [[] for _ in range(params.n_attr)]
        for i in range(0, len(data), bs):
            # batch / encode / decode
            batch_x, batch_y = data.eval_batch(i, i + bs)
            enc_outputs = self.ae.encode(batch_x)
            # flip all attributes one by one
            k = 0
            for j, (_, n_cat) in enumerate(params.attr):
                for value in range(n_cat):
                    flipped = flip_attributes(batch_y, params, j, new_value=value)
                    dec_outputs = self.ae.decode(enc_outputs, flipped)
                    # classify
                    clf_preds = self.eval_clf(dec_outputs[-1])[:, idx[j]:idx[j] + n_cat].max(1)[1].view(-1)
                    all_preds[k].extend((clf_preds.data.cpu() == value).tolist())
                    k += 1
            assert k == params.n_attr

        return [np.mean(x) for x in all_preds]

    def evaluate(self, n_epoch):
        """
        Evaluate all models / log evaluation results.
        """
        params = self.params
        logger.info('')

        # reconstruction loss
        ae_loss = self.eval_reconstruction_loss()

        # latent discriminator accuracy
        log_lat_dis = []
        if params.n_lat_dis:
            lat_dis_accu = self.eval_lat_dis_accuracy()
            log_lat_dis.append(('lat_dis_accu', np.mean(lat_dis_accu)))
            for accu, (name, _) in zip(lat_dis_accu, params.attr):
                log_lat_dis.append(('lat_dis_accu_%s' % name, accu))
            logger.info('Latent discriminator accuracy:')
            print_accuracies(log_lat_dis)

        # patch discriminator accuracy
        log_ptc_dis = []
        if params.n_ptc_dis:
            ptc_dis_real_preds, ptc_dis_fake_preds = self.eval_ptc_dis_accuracy()
            accu_real = (np.array(ptc_dis_real_preds).astype(np.float32) >= 0.5).mean()
            accu_fake = (np.array(ptc_dis_fake_preds).astype(np.float32) <= 0.5).mean()
            log_ptc_dis.append(('ptc_dis_preds_real', np.mean(ptc_dis_real_preds)))
            log_ptc_dis.append(('ptc_dis_preds_fake', np.mean(ptc_dis_fake_preds)))
            log_ptc_dis.append(('ptc_dis_accu_real', accu_real))
            log_ptc_dis.append(('ptc_dis_accu_fake', accu_fake))
            log_ptc_dis.append(('ptc_dis_accu', (accu_real + accu_fake) / 2))
            logger.info('Patch discriminator accuracy:')
            print_accuracies(log_ptc_dis)

        # classifier discriminator accuracy
        log_clf_dis = []
        if params.n_clf_dis:
            clf_dis_accu = self.eval_clf_dis_accuracy()
            k = 0
            log_clf_dis += [('clf_dis_accu', np.mean(clf_dis_accu))]
            for name, n_cat in params.attr:
                log_clf_dis.append(('clf_dis_accu_%s' % name, np.mean(clf_dis_accu[k:k + n_cat])))
                log_clf_dis.extend([('clf_dis_accu_%s_%i' % (name, j), clf_dis_accu[k + j])
                                    for j in range(n_cat)])
                k += n_cat
            logger.info('Classifier discriminator accuracy:')
            print_accuracies(log_clf_dis)

        # classifier accuracy
        log_clf = []
        clf_accu = self.eval_clf_accuracy()
        k = 0
        log_clf += [('clf_accu', np.mean(clf_accu))]
        for name, n_cat in params.attr:
            log_clf.append(('clf_accu_%s' % name, np.mean(clf_accu[k:k + n_cat])))
            log_clf.extend([('clf_accu_%s_%i' % (name, j), clf_accu[k + j])
                            for j in range(n_cat)])
            k += n_cat
        logger.info('Classifier accuracy:')
        print_accuracies(log_clf)

        # log autoencoder loss
        logger.info('Autoencoder loss: %.5f' % ae_loss)

        # JSON log
        to_log = dict([
            ('n_epoch', n_epoch),
            ('ae_loss', ae_loss)
        ] + log_lat_dis + log_ptc_dis + log_clf_dis + log_clf)
        logger.debug("__log__:%s" % json.dumps(to_log))

        return to_log


def compute_accuracy(classifier, data, params):
    """
    Compute the classifier prediction accuracy.
    """
    classifier.eval()
    bs = params.batch_size

    all_preds = [[] for _ in range(len(classifier.attr))]
    for i in range(0, len(data), bs):
        batch_x, batch_y = data.eval_batch(i, i + bs)
        preds = classifier(batch_x).data.cpu()
        update_predictions(all_preds, preds, batch_y.data.cpu(), params)

    return [np.mean(x) for x in all_preds]

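In `evaluate`, the patch discriminator's outputs are read as the probability that an image is real. Real images count as correct at `>= 0.5`, and decoded images with flipped attributes count as correct at `<= 0.5`. The reported `ptc_dis_accu` averages the two rates. Below is a small numpy sketch with made-up predictions; `patch_dis_accuracy` is a hypothetical name.

```python
import numpy as np


def patch_dis_accuracy(real_preds, fake_preds):
    """Mirror of the patch-discriminator metrics in `Evaluator.evaluate`.

    real_preds / fake_preds: predicted probabilities of being real, for real
    images and for decoded images with flipped attributes respectively.
    """
    real = np.asarray(real_preds, dtype=np.float32)
    fake = np.asarray(fake_preds, dtype=np.float32)
    accu_real = (real >= 0.5).mean()  # real images correctly called real
    accu_fake = (fake <= 0.5).mean()  # decoded images correctly called fake
    return accu_real, accu_fake, (accu_real + accu_fake) / 2


accu_real, accu_fake, accu = patch_dis_accuracy([0.9, 0.7, 0.4, 0.8], [0.1, 0.6])
```

With these inclusive thresholds, a prediction of exactly 0.5 counts as correct in both rates.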

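Attribute labels and predictions hold `n_cat` columns per attribute, concatenated in `params.attr` order (see `load_images`). Attribute `j` therefore starts at the sum of the `n_cat` values before it, not at column `j`. `eval_clf_accuracy` computes the equivalent offsets for the evaluation classifier's own attribute list in `idx`. Below is a sketch of the offset computation; `attribute_offsets` and the 3-category `Hair` attribute are hypothetical.

```python
def attribute_offsets(attr):
    """Start column of each attribute in a concatenated one-hot / prediction vector.

    attr: list of (name, n_cat) pairs, e.g. params.attr.
    Attribute j occupies columns offsets[j] : offsets[j] + n_cat.
    """
    offsets = []
    k = 0
    for _, n_cat in attr:
        offsets.append(k)
        k += n_cat
    return offsets


attr = [('Male', 2), ('Smiling', 2), ('Hair', 3)]  # 'Hair' is a hypothetical 3-way attribute
print(attribute_offsets(attr))  # [0, 2, 4]
```

Any slice that selects one attribute's columns from a prediction tensor has to start at this offset.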
# From training.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import numpy as np
import torch
from torch.autograd import Variable
from torch.nn import functional as F
from logging import getLogger

# from .utils import get_optimizer, clip_grad_norm, get_lambda, reload_model
# from .model import get_attr_loss, flip_attributes


logger = getLogger()


class Trainer(object):

    def __init__(self, ae, lat_dis, ptc_dis, clf_dis, data, params):
        """
        Trainer initialization.
        """
        # data / parameters
        self.data = data
        self.params = params

        # modules
        self.ae = ae
        self.lat_dis = lat_dis
        self.ptc_dis = ptc_dis
        self.clf_dis = clf_dis

        # optimizers
        self.ae_optimizer = get_optimizer(ae, params.ae_optimizer)
        logger.info(ae)
        logger.info('%i parameters in the autoencoder. '
                    % sum([p.nelement() for p in ae.parameters()]))
        if params.n_lat_dis:
            logger.info(lat_dis)
            logger.info('%i parameters in the latent discriminator. '
                        % sum([p.nelement() for p in lat_dis.parameters()]))
            self.lat_dis_optimizer = get_optimizer(lat_dis, params.dis_optimizer)
        if params.n_ptc_dis:
            logger.info(ptc_dis)
            logger.info('%i parameters in the patch discriminator. '
                        % sum([p.nelement() for p in ptc_dis.parameters()]))
            self.ptc_dis_optimizer = get_optimizer(ptc_dis, params.dis_optimizer)
        if params.n_clf_dis:
            logger.info(clf_dis)
            logger.info('%i parameters in the classifier discriminator. '
                        % sum([p.nelement() for p in clf_dis.parameters()]))
            self.clf_dis_optimizer = get_optimizer(clf_dis, params.dis_optimizer)

        # reload pretrained models
        if params.ae_reload:
            reload_model(ae, params.ae_reload,
                         ['img_sz', 'img_fm', 'init_fm', 'n_layers', 'n_skip', 'attr', 'n_attr'])
        if params.lat_dis_reload:
            reload_model(lat_dis, params.lat_dis_reload,
                         ['enc_dim', 'attr', 'n_attr'])
        if params.ptc_dis_reload:
            reload_model(ptc_dis, params.ptc_dis_reload,
                         ['img_sz', 'img_fm', 'init_fm', 'max_fm', 'n_patch_dis_layers'])
        if params.clf_dis_reload:
            reload_model(clf_dis, params.clf_dis_reload,
                         ['img_sz', 'img_fm', 'init_fm', 'max_fm', 'hid_dim', 'attr', 'n_attr'])

        # training statistics
        self.stats = {}
        self.stats['rec_costs'] = []
        self.stats['lat_dis_costs'] = []
        self.stats['ptc_dis_costs'] = []
        self.stats['clf_dis_costs'] = []

        # best reconstruction loss / best accuracy
        self.best_loss = 1e12
        self.best_accu = -1e12
        self.params.n_total_iter = 0

    def lat_dis_step(self):
        """
        Train the latent discriminator.
        """
        data = self.data
        params = self.params
        self.ae.eval()
        self.lat_dis.train()
        bs = params.batch_size
        # batch / encode / discriminate
        batch_x, batch_y = data.train_batch(bs)
        enc_outputs = self.ae.encode(Variable(batch_x.data, volatile=True))
        preds = self.lat_dis(Variable(enc_outputs[-1 - params.n_skip].data))
        # loss / optimize
        loss = get_attr_loss(preds, batch_y, False, params)
        self.stats['lat_dis_costs'].append(loss.data[0])
        self.lat_dis_optimizer.zero_grad()
        loss.backward()
        if params.clip_grad_norm:
            clip_grad_norm(self.lat_dis.parameters(), params.clip_grad_norm)
        self.lat_dis_optimizer.step()

    def ptc_dis_step(self):
        """
        Train the patch discriminator.
        """
        data = self.data
        params = self.params
        self.ae.eval()
        self.ptc_dis.train()
        bs = params.batch_size
        # batch / encode / discriminate
        batch_x, batch_y = data.train_batch(bs)
        flipped = flip_attributes(batch_y, params, 'all')
        _, dec_outputs = self.ae(Variable(batch_x.data, volatile=True), flipped)
        real_preds = self.ptc_dis(batch_x)
        fake_preds = self.ptc_dis(Variable(dec_outputs[-1].data))
        y_fake = Variable(torch.FloatTensor(real_preds.size())
                               .fill_(params.smooth_label).cuda())
        # loss / optimize
        loss = F.binary_cross_entropy(real_preds, 1 - y_fake)
        loss += F.binary_cross_entropy(fake_preds, y_fake)
        self.stats['ptc_dis_costs'].append(loss.data[0])
        self.ptc_dis_optimizer.zero_grad()
        loss.backward()
        if params.clip_grad_norm:
            clip_grad_norm(self.ptc_dis.parameters(), params.clip_grad_norm)
        self.ptc_dis_optimizer.step()

    def clf_dis_step(self):
        """
        Train the classifier discriminator.
        """
        data = self.data
        params = self.params
        self.clf_dis.train()
        bs = params.batch_size
        # batch / predict
        batch_x, batch_y = data.train_batch(bs)
        preds = self.clf_dis(batch_x)
        # loss / optimize
        loss = get_attr_loss(preds, batch_y, False, params)
        self.stats['clf_dis_costs'].append(loss.data[0])
        self.clf_dis_optimizer.zero_grad()
        loss.backward()
        if params.clip_grad_norm:
            clip_grad_norm(self.clf_dis.parameters(), params.clip_grad_norm)
        self.clf_dis_optimizer.step()

    def autoencoder_step(self):
        """
        Train the autoencoder with the mean squared error reconstruction loss,
        plus the adversarial losses of the enabled discriminators.
        """
        data = self.data
        params = self.params
        self.ae.train()
        if params.n_lat_dis:
            self.lat_dis.eval()
        if params.n_ptc_dis:
            self.ptc_dis.eval()
        if params.n_clf_dis:
            self.clf_dis.eval()
        bs = params.batch_size
        # batch / encode / decode
        batch_x, batch_y = data.train_batch(bs)
        enc_outputs, dec_outputs = self.ae(batch_x, batch_y)
        # autoencoder loss from reconstruction
        loss = params.lambda_ae * ((batch_x - dec_outputs[-1]) ** 2).mean()
        self.stats['rec_costs'].append(loss.data[0])
        # encoder loss from the latent discriminator
        if params.lambda_lat_dis:
            lat_dis_preds = self.lat_dis(enc_outputs[-1 - params.n_skip])
            lat_dis_loss = get_attr_loss(lat_dis_preds, batch_y, True, params)
            loss = loss + get_lambda(params.lambda_lat_dis, params) * lat_dis_loss
        # decoding with random labels
        if params.lambda_ptc_dis + params.lambda_clf_dis > 0:
            flipped = flip_attributes(batch_y, params, 'all')
            dec_outputs_flipped = self.ae.decode(enc_outputs, flipped)
        # autoencoder loss from the patch discriminator
        if params.lambda_ptc_dis:
            ptc_dis_preds = self.ptc_dis(dec_outputs_flipped[-1])
            y_fake = Variable(torch.FloatTensor(ptc_dis_preds.size())
                                   .fill_(params.smooth_label).cuda())
            ptc_dis_loss = F.binary_cross_entropy(ptc_dis_preds, 1 - y_fake)
            loss = loss + get_lambda(params.lambda_ptc_dis, params) * ptc_dis_loss
        # autoencoder loss from the classifier discriminator
        if params.lambda_clf_dis:
            clf_dis_preds = self.clf_dis(dec_outputs_flipped[-1])
            clf_dis_loss = get_attr_loss(clf_dis_preds, flipped, False, params)
            loss = loss + get_lambda(params.lambda_clf_dis, params) * clf_dis_loss
        # check NaN
        if (loss != loss).data.any():
            logger.error("NaN detected")
            exit()
        # optimize
        self.ae_optimizer.zero_grad()
        loss.backward()
        if params.clip_grad_norm:
            clip_grad_norm(self.ae.parameters(), params.clip_grad_norm)
        self.ae_optimizer.step()

    def step(self, n_iter):
        """
        End training iteration / print training statistics.
        """
        # average loss
        if len(self.stats['rec_costs']) >= 25:
            mean_loss = [
                ('Latent discriminator', 'lat_dis_costs'),
                ('Patch discriminator', 'ptc_dis_costs'),
                ('Classifier discriminator', 'clf_dis_costs'),
                ('Reconstruction loss', 'rec_costs'),
            ]
            logger.info(('%06i - ' % n_iter) +
                        ' / '.join(['%s : %.5f' % (a, np.mean(self.stats[b]))
                                    for a, b in mean_loss if len(self.stats[b]) > 0]))
            del self.stats['rec_costs'][:]
            del self.stats['lat_dis_costs'][:]
            del self.stats['ptc_dis_costs'][:]
            del self.stats['clf_dis_costs'][:]

        self.params.n_total_iter += 1

    def save_model(self, name):
        """
        Save the model.
        """
        def save(model, filename):
            path = os.path.join(self.params.dump_path, '%s_%s.pth' % (name, filename))
            logger.info('Saving %s to %s ...' % (filename, path))
            torch.save(model, path)
        save(self.ae, 'ae')
        if self.params.n_lat_dis:
            save(self.lat_dis, 'lat_dis')
        if self.params.n_ptc_dis:
            save(self.ptc_dis, 'ptc_dis')
        if self.params.n_clf_dis:
            save(self.clf_dis, 'clf_dis')

    def save_best_periodic(self, to_log):
        """
        Save the best models / periodically save the models.
        """
        if to_log['ae_loss'] < self.best_loss:
            self.best_loss = to_log['ae_loss']
            logger.info('Best reconstruction loss: %.5f' % self.best_loss)
            self.save_model('best_rec')
        if self.params.eval_clf and np.mean(to_log['clf_accu']) > self.best_accu:
            self.best_accu = np.mean(to_log['clf_accu'])
            logger.info('Best evaluation accuracy: %.5f' % self.best_accu)
            self.save_model('best_accu')
        if to_log['n_epoch'] % 5 == 0 and to_log['n_epoch'] > 0:
            self.save_model('periodic-%i' % to_log['n_epoch'])


def classifier_step(classifier, optimizer, data, params, costs):
    """
    Train the classifier.
    """
    classifier.train()
    bs = params.batch_size

    # batch / classify
    batch_x, batch_y = data.train_batch(bs)
    preds = classifier(batch_x)
    # loss / optimize
    loss = get_attr_loss(preds, batch_y, False, params)
    costs.append(loss.data[0])
    optimizer.zero_grad()
    loss.backward()
    if params.clip_grad_norm:
        clip_grad_norm(classifier.parameters(), params.clip_grad_norm)
    optimizer.step()

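Putting `autoencoder_step` and `get_lambda` together gives the loss minimized by the autoencoder, written below. This is a reading of the code, not a formula quoted from the paper. The symbols map to the code as follows:

- $x$ is `batch_x` and $y$ is `batch_y`.
- $\tilde{y}$ is `flip_attributes(batch_y, params, 'all')`.
- $E$ and $D$ stand loosely for `ae.encode` and `ae.decode`.
- $s$ is `params.smooth_label`.
- $t$ is `params.n_total_iter` and $S$ is `params.lambda_schedule`.
- $\mathcal{L}_{\text{lat}}$ and $\mathcal{L}_{\text{clf}}$ are `get_attr_loss` with `flip=True` and `flip=False`.

The latent discriminator reads `enc_outputs[-1 - params.n_skip]`. Each adversarial term is present only when its `lambda_lat_dis`, `lambda_ptc_dis`, or `lambda_clf_dis` is non-zero, and $\lambda_{\text{ae}}$ is not scheduled.

```latex
\mathcal{L}_{\text{ae}}
  = \lambda_{\text{ae}} \operatorname{mean}\!\left[\big(x - D(E(x), y)\big)^2\right]
  + \lambda_{\text{lat}}(t)\, \mathcal{L}_{\text{lat}}\big(\mathrm{lat\_dis}(E(x)),\, y\big)
  + \lambda_{\text{ptc}}(t)\, \mathrm{BCE}\big(\mathrm{ptc\_dis}(D(E(x), \tilde{y})),\, 1 - s\big)
  + \lambda_{\text{clf}}(t)\, \mathcal{L}_{\text{clf}}\big(\mathrm{clf\_dis}(D(E(x), \tilde{y})),\, \tilde{y}\big),
\qquad
\lambda_{\bullet}(t) = \lambda_{\bullet} \cdot \frac{\min(t, S)}{S} \quad (S > 0)
```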

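`ptc_dis_step` and `autoencoder_step` both use label smoothing with `s = params.smooth_label`:

- The patch discriminator is trained toward `1 - s` on real images and `s` on decoded ones.
- The autoencoder is trained to push its decoded images toward `1 - s`.

Below is a numpy sketch of both losses; the function names are hypothetical. Clipping probabilities away from 0 and 1 is this sketch's own guard against `log(0)`; PyTorch's `F.binary_cross_entropy` handles that case differently.

```python
import numpy as np


def bce(p, target, eps=1e-12):
    # mean binary cross-entropy on probabilities (clipped to avoid log(0))
    p = np.clip(p, eps, 1 - eps)
    return float(np.mean(-(target * np.log(p) + (1 - target) * np.log(1 - p))))


def ptc_dis_loss(real_preds, fake_preds, smooth_label):
    # discriminator side (ptc_dis_step): real -> 1 - s, decoded -> s
    real = np.asarray(real_preds, dtype=np.float64)
    fake = np.asarray(fake_preds, dtype=np.float64)
    y_fake = np.full(real.shape, smooth_label)
    return bce(real, 1 - y_fake) + bce(fake, y_fake)


def ptc_ae_loss(fake_preds, smooth_label):
    # autoencoder side (autoencoder_step): decoded images should look real, target 1 - s
    fake = np.asarray(fake_preds, dtype=np.float64)
    return bce(fake, np.full(fake.shape, 1 - smooth_label))
```

With `s = 0.2`, for example, the discriminator loss is smallest when the discriminator outputs 0.8 on real images and 0.2 on decoded ones, rather than 1 and 0.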
# From classifier.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import json
import argparse
import numpy as np
import torch

# from src.loader import load_images, DataSampler
# from src.utils import initialize_exp, bool_flag, attr_flag, check_attr
# from src.utils import get_optimizer, reload_model, print_accuracies
# from src.model import Classifier
# from src.training import classifier_step
# from src.evaluation import compute_accuracy


# parse parameters
parser = argparse.ArgumentParser(description='Classifier')
parser.add_argument("--name", type=str, default="default",
                    help="Experiment name")
parser.add_argument("--img_sz", type=int, default=128,  # previously 256
                    help="Image sizes (images have to be squared)")
parser.add_argument("--img_fm", type=int, default=3,
                    help="Number of feature maps (1 for grayscale, 3 for RGB)")
parser.add_argument("--attr", type=attr_flag, default="Smiling",
                    help="Attributes to classify")
parser.add_argument("--init_fm", type=int, default=32,
                    help="Number of initial filters in the encoder")
parser.add_argument("--max_fm", type=int, default=512,
                    help="Maximum number of filters in the classifier")
parser.add_argument("--hid_dim", type=int, default=512,
                    help="Last hidden layer dimension")
parser.add_argument("--v_flip", type=bool_flag, default=False,
                    help="Random vertical flip for data augmentation")
parser.add_argument("--h_flip", type=bool_flag, default=True,
                    help="Random horizontal flip for data augmentation")
parser.add_argument("--batch_size", type=int, default=32,
                    help="Batch size")
parser.add_argument("--optimizer", type=str, default="adam",
                    help="Classifier optimizer (SGD / RMSprop / Adam, etc.)")
parser.add_argument("--clip_grad_norm", type=float, default=5,
                    help="Clip gradient norms (0 to disable)")
parser.add_argument("--n_epochs", type=int, default=1000,
                    help="Total number of epochs")
parser.add_argument("--epoch_size", type=int, default=50000,
                    help="Number of samples per epoch")
parser.add_argument("--reload", type=str, default="",
                    help="Reload a pretrained classifier")
parser.add_argument("--debug", type=bool_flag, default=False,
                    help="Debug mode (only load a subset of the whole dataset)")
# params = parser.parse_args()  # command-line usage
params = parser.parse_args([])  # notebook usage: ignore Jupyter's own argv, keep the defaults

# check parameters
check_attr(params)
assert len(params.name.strip()) > 0
assert not params.reload or os.path.isfile(params.reload)

# initialize experiment / load dataset
logger = initialize_exp(params)
data, attributes = load_images(params)
train_data = DataSampler(data[0], attributes[0], params)
valid_data = DataSampler(data[1], attributes[1], params)
test_data = DataSampler(data[2], attributes[2], params)

# build the model / reload / optimizer
classifier = Classifier(params).cuda()
if params.reload:
    reload_model(classifier, params.reload,
                 ['img_sz', 'img_fm', 'init_fm', 'hid_dim', 'attr', 'n_attr'])
optimizer = get_optimizer(classifier, params.optimizer)


def save_model(name):
    """
    Save the model.
    """
    path = os.path.join(params.dump_path, '%s.pth' % name)
    logger.info('Saving the classifier to %s ...' % path)
    torch.save(classifier, path)


# best accuracy
best_accu = -1e12


for n_epoch in range(params.n_epochs):

    logger.info('Starting epoch %i...' % n_epoch)
    costs = []

    classifier.train()

    for n_iter in range(0, params.epoch_size, params.batch_size):

        # classifier training
        classifier_step(classifier, optimizer, train_data, params, costs)

        # average loss
        if len(costs) >= 25:
            logger.info('%06i - Classifier loss: %.5f' % (n_iter, np.mean(costs)))
            del costs[:]

    # compute accuracy
    valid_accu = compute_accuracy(classifier, valid_data, params)
    test_accu = compute_accuracy(classifier, test_data, params)

    # log classifier accuracy
    log_accu = [('valid_accu', np.mean(valid_accu)), ('test_accu', np.mean(test_accu))]
    for accu, (name, _) in zip(valid_accu, params.attr):
        log_accu.append(('valid_accu_%s' % name, accu))
    for accu, (name, _) in zip(test_accu, params.attr):
        log_accu.append(('test_accu_%s' % name, accu))
    logger.info('Classifier accuracy:')
    print_accuracies(log_accu)

    # JSON log
    logger.debug("__log__:%s" % json.dumps(dict([('n_epoch', n_epoch)] + log_accu)))

    # save best or periodic model
    if np.mean(valid_accu) > best_accu:
        best_accu = np.mean(valid_accu)
        logger.info('Best validation average accuracy: %.5f' % best_accu)
        save_model('best')
    elif n_epoch % 10 == 0 and n_epoch > 0:
        save_model('periodic-%i' % n_epoch)

    logger.info('End of epoch %i.\n' % n_epoch)

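The checkpoint rule at the end of the classifier loop saves `best` whenever the mean validation accuracy improves, and otherwise writes `periodic-N` every tenth epoch (never at epoch 0). Below is a minimal, dependency-free sketch of that decision; `checkpoint_name` is a hypothetical helper, not part of the repository.

```python
def checkpoint_name(n_epoch, valid_accu, best_accu):
    """Hypothetical helper mirroring the save logic of the classifier loop.

    Returns (name or None, updated best accuracy).
    """
    mean_accu = sum(valid_accu) / len(valid_accu)
    if mean_accu > best_accu:
        # new best mean validation accuracy -> save as 'best'
        return 'best', mean_accu
    if n_epoch % 10 == 0 and n_epoch > 0:
        # otherwise, a periodic snapshot every 10 epochs (skipping epoch 0)
        return 'periodic-%i' % n_epoch, best_accu
    return None, best_accu


best = -1e12  # same sentinel as best_accu above
history = [(0, [0.80, 0.90]), (1, [0.85, 0.90]), (10, [0.70, 0.70]), (11, [0.90, 0.95])]
for epoch, accu in history:
    name, best = checkpoint_name(epoch, accu, best)
    print(epoch, name)
```

An epoch divisible by 10 that is also a new best only produces `best`, because the periodic save sits in an `elif`.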

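The inner loop runs `len(range(0, epoch_size, batch_size))` steps, so an "epoch" is only approximately `epoch_size` samples. The sketch below counts steps for the defaults. It assumes every step draws a full batch of `batch_size` images and that `classifier_step` appends exactly one loss to `costs` per call; neither is shown in this excerpt, and both helper names are hypothetical.

```python
import math


def iterations_per_epoch(epoch_size, batch_size):
    # Same count as the inner loop `range(0, params.epoch_size, params.batch_size)`
    return len(range(0, epoch_size, batch_size))


def loss_log_lines(n_steps, window=25):
    # Assumes one loss is appended per step; `costs` is flushed every `window`
    # entries and reset at the start of each epoch, so a trailing partial
    # window is never logged.
    return n_steps // window


n_steps = iterations_per_epoch(50000, 32)
samples = n_steps * 32  # assumes full batches on every step
print(n_steps, samples, loss_log_lines(n_steps))
```

With `epoch_size=50000` and `batch_size=32` this gives 1563 steps (50016 samples) and 62 loss lines; the last 13 losses of each epoch are discarded when `costs` is reset.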

# From train.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import argparse
import torch

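# Uncomment these imports when running train.py from the repository root;
# in this notebook the same helpers are assumed to be defined in earlier cells.
# from src.loader import load_images, DataSampler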
# from src.utils import initialize_exp, bool_flag, attr_flag, check_attr
# from src.model import AutoEncoder, LatentDiscriminator, PatchDiscriminator, Classifier
# from src.training import Trainer
# from src.evaluation import Evaluator


# parse parameters
parser = argparse.ArgumentParser(description='Images autoencoder')
parser.add_argument("--name", type=str, default="default",
                    help="Experiment name")
parser.add_argument("--img_sz", type=int, default=128,  # upstream default: 256
                    help="Image sizes (images have to be squared)")
parser.add_argument("--img_fm", type=int, default=3,
                    help="Number of feature maps (1 for grayscale, 3 for RGB)")
parser.add_argument("--attr", type=attr_flag, default="Smiling,Male",
                    help="Attributes to classify")
parser.add_argument("--instance_norm", type=bool_flag, default=False,
                    help="Use instance normalization instead of batch normalization")
parser.add_argument("--init_fm", type=int, default=32,
                    help="Number of initial filters in the encoder")
parser.add_argument("--max_fm", type=int, default=512,
                    help="Maximum number of filters in the autoencoder")
parser.add_argument("--n_layers", type=int, default=6,
                    help="Number of layers in the encoder / decoder")
parser.add_argument("--n_skip", type=int, default=0,
                    help="Number of skip connections")
parser.add_argument("--deconv_method", type=str, default="convtranspose",
                    help="Deconvolution method")
parser.add_argument("--hid_dim", type=int, default=512,
                    help="Last hidden layer dimension for discriminator / classifier")
parser.add_argument("--dec_dropout", type=float, default=0.,
                    help="Dropout in the decoder")
parser.add_argument("--lat_dis_dropout", type=float, default=0.3,
                    help="Dropout in the latent discriminator")
parser.add_argument("--n_lat_dis", type=int, default=1,
                    help="Number of latent discriminator training steps")
parser.add_argument("--n_ptc_dis", type=int, default=0,
                    help="Number of patch discriminator training steps")
parser.add_argument("--n_clf_dis", type=int, default=0,
                    help="Number of classifier discriminator training steps")
parser.add_argument("--smooth_label", type=float, default=0.2,
                    help="Smooth label for patch discriminator")
parser.add_argument("--lambda_ae", type=float, default=1,
                    help="Autoencoder loss coefficient")
parser.add_argument("--lambda_lat_dis", type=float, default=0.0001,
                    help="Latent discriminator loss feedback coefficient")
parser.add_argument("--lambda_ptc_dis", type=float, default=0,
                    help="Patch discriminator loss feedback coefficient")
parser.add_argument("--lambda_clf_dis", type=float, default=0,
                    help="Classifier discriminator loss feedback coefficient")
parser.add_argument("--lambda_schedule", type=float, default=500000,
                    help="Progressively increase discriminators' lambdas (0 to disable)")
parser.add_argument("--v_flip", type=bool_flag, default=False,
                    help="Random vertical flip for data augmentation")
parser.add_argument("--h_flip", type=bool_flag, default=True,
                    help="Random horizontal flip for data augmentation")
parser.add_argument("--batch_size", type=int, default=32,
                    help="Batch size")
parser.add_argument("--ae_optimizer", type=str, default="adam,lr=0.0002",
                    help="Autoencoder optimizer (SGD / RMSprop / Adam, etc.)")
parser.add_argument("--dis_optimizer", type=str, default="adam,lr=0.0002",
                    help="Discriminator optimizer (SGD / RMSprop / Adam, etc.)")
parser.add_argument("--clip_grad_norm", type=float, default=5,
                    help="Clip gradient norms (0 to disable)")
parser.add_argument("--n_epochs", type=int, default=1000,
                    help="Total number of epochs")
parser.add_argument("--epoch_size", type=int, default=50000,
                    help="Number of samples per epoch")
parser.add_argument("--ae_reload", type=str, default="",
                    help="Reload a pretrained autoencoder")
parser.add_argument("--lat_dis_reload", type=str, default="",
                    help="Reload a pretrained latent discriminator")
parser.add_argument("--ptc_dis_reload", type=str, default="",
                    help="Reload a pretrained patch discriminator")
parser.add_argument("--clf_dis_reload", type=str, default="",
                    help="Reload a pretrained classifier discriminator")
parser.add_argument("--eval_clf", type=str, default="",
                    help="Load an external classifier for evaluation")
parser.add_argument("--debug", type=bool_flag, default=False,
                    help="Debug mode (only load a subset of the whole dataset)")
# Empty argument list so the defaults are used inside a notebook;
# use parser.parse_args() when running as a script.
params = parser.parse_args([])

# check parameters
check_attr(params)
assert len(params.name.strip()) > 0
assert params.n_skip <= params.n_layers - 1
assert params.deconv_method in ['convtranspose', 'upsampling', 'pixelshuffle']
assert 0 <= params.smooth_label < 0.5
assert not params.ae_reload or os.path.isfile(params.ae_reload)
assert not params.lat_dis_reload or os.path.isfile(params.lat_dis_reload)
assert not params.ptc_dis_reload or os.path.isfile(params.ptc_dis_reload)
assert not params.clf_dis_reload or os.path.isfile(params.clf_dis_reload)
assert os.path.isfile(params.eval_clf), "--eval_clf must point to a trained classifier file"
assert params.lambda_lat_dis == 0 or params.n_lat_dis > 0
assert params.lambda_ptc_dis == 0 or params.n_ptc_dis > 0
assert params.lambda_clf_dis == 0 or params.n_clf_dis > 0

# initialize experiment / load dataset
logger = initialize_exp(params)
data, attributes = load_images(params)
train_data = DataSampler(data[0], attributes[0], params)
valid_data = DataSampler(data[1], attributes[1], params)

# build the model
ae = AutoEncoder(params).cuda()
lat_dis = LatentDiscriminator(params).cuda() if params.n_lat_dis else None
ptc_dis = PatchDiscriminator(params).cuda() if params.n_ptc_dis else None
clf_dis = Classifier(params).cuda() if params.n_clf_dis else None
eval_clf = torch.load(params.eval_clf).cuda().eval()

# trainer / evaluator
trainer = Trainer(ae, lat_dis, ptc_dis, clf_dis, train_data, params)
evaluator = Evaluator(ae, lat_dis, ptc_dis, clf_dis, eval_clf, valid_data, params)


for n_epoch in range(params.n_epochs):

    logger.info('Starting epoch %i...' % n_epoch)

    for n_iter in range(0, params.epoch_size, params.batch_size):

        # latent discriminator training
        for _ in range(params.n_lat_dis):
            trainer.lat_dis_step()

        # patch discriminator training
        for _ in range(params.n_ptc_dis):
            trainer.ptc_dis_step()

        # classifier discriminator training
        for _ in range(params.n_clf_dis):
            trainer.clf_dis_step()

        # autoencoder training
        trainer.autoencoder_step()

        # print training statistics
        trainer.step(n_iter)

    # run all evaluations / save best or periodic model
    to_log = evaluator.evaluate(n_epoch)
    trainer.save_best_periodic(to_log)
    logger.info('End of epoch %i.\n' % n_epoch)

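The `--lambda_schedule` flag is described only as "Progressively increase discriminators' lambdas (0 to disable)". A minimal sketch, assuming a linear warm-up from 0 to the target coefficient over `lambda_schedule` training iterations; `scheduled_lambda` is a hypothetical name, and the actual `Trainer` may implement the ramp differently.

```python
import math


def scheduled_lambda(lambda_max, n_total_iter, lambda_schedule):
    """Assumed linear warm-up of a discriminator feedback coefficient.

    lambda_schedule == 0 disables the ramp, as the flag's help text states.
    """
    if lambda_schedule == 0:
        return lambda_max
    # grows linearly with the iteration count, then stays at lambda_max
    return lambda_max * min(n_total_iter, lambda_schedule) / lambda_schedule


# Assuming the counter advances once per inner-loop iteration, the defaults
# (epoch_size=50000, batch_size=32) give 1563 iterations per epoch, so the
# 500000-iteration ramp spans roughly this many epochs:
iters_per_epoch = len(range(0, 50000, 32))
ramp_epochs = math.ceil(500000 / iters_per_epoch)

for it in (0, 125000, 250000, 500000, 1000000):
    print(it, scheduled_lambda(0.0001, it, 500000))
print("ramp epochs:", ramp_epochs)
```

Under these assumptions, `--lambda_lat_dis 0.0001` is only reached after about 320 epochs with the default settings, so early epochs are dominated by the reconstruction term.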

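The latent discriminator step and the autoencoder step pull in opposite directions: the discriminator learns to predict the attribute `y` from `E(x)`, while the encoder is rewarded when the discriminator predicts the flipped value `1 - y`, which is the adversarial objective given in the Fader Networks paper. The sketch below writes both losses for a single boolean attribute in plain Python on two discriminator logits. The function names are hypothetical, `rec_loss` stands in for the reconstruction error, and the weighting mirrors `--lambda_ae` and `--lambda_lat_dis`; the repository's `Trainer` works on batched tensors instead.

```python
import math


def log_softmax(logits):
    # numerically stable log-softmax over a short list of logits
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def discriminator_loss(logits, y):
    # latent discriminator: maximize log p(y | E(x)) for the true attribute y
    return -log_softmax(logits)[y]


def encoder_adversarial_loss(logits, y):
    # encoder: maximize log p(1 - y | E(x)), i.e. fool the discriminator
    return -log_softmax(logits)[1 - y]


def autoencoder_loss(rec_loss, logits, y, lambda_ae, lambda_lat_dis):
    # weighted sum mirroring --lambda_ae and --lambda_lat_dis
    return lambda_ae * rec_loss + lambda_lat_dis * encoder_adversarial_loss(logits, y)


# A discriminator that confidently predicts y = 0 from the latent code:
logits, y = [2.0, 0.0], 0
print(discriminator_loss(logits, y))        # small: the discriminator is right
print(encoder_adversarial_loss(logits, y))  # large: the encoder is penalized
print(autoencoder_loss(0.05, logits, y, lambda_ae=1.0, lambda_lat_dis=0.0001))
```

The encoder's adversarial term only shrinks when `E(x)` stops revealing `y`, which is what forces the decoder to rely on the attribute input.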

# From interpolate.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import argparse
import numpy as np
import torch
from torch.autograd import Variable
from torchvision.utils import make_grid
import matplotlib.image

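# Uncomment these imports when running interpolate.py from the repository root;
# in this notebook the same helpers are assumed to be defined in earlier cells.
# from src.logger import create_logger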
# from src.loader import load_images, DataSampler
# from src.utils import bool_flag


# parse parameters
parser = argparse.ArgumentParser(description='Attributes swapping')
parser.add_argument("--model_path", type=str, default=r"C:\GitHub\FaderNetworks\models\narrow_eyes.pth",
                    help="Trained model path")
parser.add_argument("--n_images", type=int, default=9,  # upstream default: 10
                    help="Number of images to modify")
parser.add_argument("--offset", type=int, default=0,
                    help="First image index")
parser.add_argument("--n_interpolations", type=int, default=10,
                    help="Number of interpolations per image")
parser.add_argument("--alpha_min", type=float, default=10.0,  # upstream default: 1
                    help="Interpolation starts at alpha = 1 - alpha_min")
parser.add_argument("--alpha_max", type=float, default=10.0,  # upstream default: 1
                    help="Interpolation ends at alpha = alpha_max")
parser.add_argument("--plot_size", type=int, default=5,
                    help="Size of images in the grid")
parser.add_argument("--row_wise", type=bool_flag, default=True,
                    help="Represent image interpolations horizontally")
parser.add_argument("--output_path", type=str, default='narrow_eyes.png',  # upstream default: "output.png"
                    help="Output path")
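# Empty argument list so the defaults are used inside a notebook;
# use parser.parse_args() when running as a script.
params = parser.parse_args([])

print(params.model_path)
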
# check parameters
assert os.path.isfile(params.model_path)
assert params.n_images >= 1 and params.n_interpolations >= 2

# create logger / load trained model
logger = create_logger(None)
ae = torch.load(params.model_path).eval()

# restore main parameters
params.debug = True
params.batch_size = 32
params.v_flip = False
params.h_flip = False
params.img_sz = ae.img_sz
params.attr = ae.attr
params.n_attr = ae.n_attr
if not (len(params.attr) == 1 and params.n_attr == 2):
    raise Exception("The model must use a single boolean attribute only.")

# load dataset
data, attributes = load_images(params)
test_data = DataSampler(data[2], attributes[2], params)


def get_interpolations(ae, images, attributes, params):
    """
    Reconstruct images / create interpolations
    """
    assert len(images) == len(attributes)
    enc_outputs = ae.encode(images)

    # interpolation values
    alphas = np.linspace(1 - params.alpha_min, params.alpha_max, params.n_interpolations)
    alphas = [torch.FloatTensor([1 - alpha, alpha]) for alpha in alphas]

    # original image / reconstructed image / interpolations
    outputs = []
    outputs.append(images)
    outputs.append(ae.decode(enc_outputs, attributes)[-1])
    for alpha in alphas:
        alpha = Variable(alpha.unsqueeze(0).expand((len(images), 2)).cuda())
        outputs.append(ae.decode(enc_outputs, alpha)[-1])

    # return stacked images
    return torch.cat([x.unsqueeze(1) for x in outputs], 1).data.cpu()


interpolations = []

for k in range(0, params.n_images, 100):
    # process at most 100 test images per call to eval_batch
    i = params.offset + k
    j = params.offset + min(params.n_images, k + 100)
    images, attributes = test_data.eval_batch(i, j)
    interpolations.append(get_interpolations(ae, images, attributes, params))
interpolations = torch.cat(interpolations, 0)
# expected shape: (n_images, original + reconstruction + n_interpolations, 3, img_sz, img_sz)
assert interpolations.size() == (params.n_images, 2 + params.n_interpolations,
                                 3, params.img_sz, params.img_sz)

def get_grid(images, row_wise, plot_size=5):
    """
    Create a grid with all images.
    """
    n_images, n_columns, img_fm, img_sz, _ = images.size()
    if not row_wise:
        images = images.transpose(0, 1).contiguous()
    images = images.view(n_images * n_columns, img_fm, img_sz, img_sz)
    images.add_(1).div_(2.0)  # map values from [-1, 1] to [0, 1] for display
    return make_grid(images, nrow=(n_columns if row_wise else n_images))


# generate the grid / save it to a PNG file
grid = get_grid(interpolations, params.row_wise, params.plot_size)
matplotlib.image.imsave(params.output_path, grid.numpy().transpose((1, 2, 0)))

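`get_interpolations` evaluates `np.linspace(1 - alpha_min, alpha_max, n_interpolations)` and decodes each value `alpha` with the two-value code `[1 - alpha, alpha]`. With the upstream defaults (`alpha_min = alpha_max = 1`) the sweep runs from 0 to 1, so its endpoints are the codes `[1, 0]` and `[0, 1]` of a boolean attribute. With this notebook's values (10 and 10) it runs from -9 to 10, well outside the codes seen during training, so those columns are extrapolations. The pure-Python sketch below reproduces the values; the helper names are hypothetical.

```python
def interpolation_alphas(alpha_min, alpha_max, n_interpolations):
    # Same evenly spaced values as np.linspace(1 - alpha_min, alpha_max, n)
    # in get_interpolations, up to floating-point rounding
    start = 1 - alpha_min
    step = (alpha_max - start) / (n_interpolations - 1)
    return [start + k * step for k in range(n_interpolations)]


def attribute_codes(alphas):
    # get_interpolations decodes each alpha with the code [1 - alpha, alpha]
    return [(1 - a, a) for a in alphas]


for a_min, a_max in ((1.0, 1.0), (10.0, 10.0)):
    alphas = interpolation_alphas(a_min, a_max, 10)
    codes = attribute_codes(alphas)
    print(a_min, a_max, [round(a, 3) for a in alphas])
    print("first code:", codes[0], "last code:", codes[-1])
```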

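`get_grid` flattens the `(n_images, n_columns, C, H, W)` tensor, where `n_columns = 2 + n_interpolations` (original, reconstruction, then each interpolation), and passes `nrow` to torchvision's `make_grid`, which sets how many images appear in each grid row. The dependency-free sketch below tracks only the tile order, using hypothetical helper names, and also shows the `(x + 1) / 2` rescaling, which assumes decoder outputs lie in `[-1, 1]`.

```python
def grid_layout(n_images, n_columns, row_wise):
    """Return (nrow, order): the nrow passed to make_grid and the
    (image_index, column_index) pairs in the order they are tiled."""
    # cells[i][c] is column c (original, reconstruction, interpolations) of image i
    cells = [[(i, c) for c in range(n_columns)] for i in range(n_images)]
    if not row_wise:
        # equivalent of images.transpose(0, 1): group by column first
        cells = [[cells[i][c] for i in range(n_images)] for c in range(n_columns)]
    # equivalent of images.view(n_images * n_columns, ...): flatten in order
    order = [cell for group in cells for cell in group]
    nrow = n_columns if row_wise else n_images
    return nrow, order


def to_display_range(x):
    # get_grid applies add_(1).div_(2.0), i.e. x -> (x + 1) / 2
    return (x + 1) / 2


print(grid_layout(2, 3, row_wise=True))
print(grid_layout(2, 3, row_wise=False))
print([to_display_range(v) for v in (-1.0, 0.0, 1.0)])
```

With `row_wise=True` each grid row holds one test image and its variants; with `row_wise=False` each grid column does. The `plot_size` argument is accepted by `get_grid` but not used in its body.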

