Skip to content

eyalbetzalel/pytorch-generative

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

pytorch-generative

pytorch-generative is a nascent project that aims to provide a simple, easy to use library for generative modeling in PyTorch.

The library makes generative model implementation and experimentation easier by abstracting common building blocks such as MaskedConv2d and MaskedAttention. It also provides clean, high quality reference implementations of recent State of the Art papers that are easy to read, understand, and extend. Finally, it provides utilities for training, debugging, and working with Google Colab.

So far, the library has primarily focues on Autoregressive modeling. The future goal is to also expand into VAEs, GANS, Flows, etc.

Example - ImageGPT

Supported models are implemented as PyTorch Modules and are easy to use:

from pytorch_generative import models

model = models.ImageGPT(in_channels=1, in_size=28)
...
model(data)

Alternatively, lower level building blocks in pytorch_generative.nn can be used to write models from scratch. For example, we implement a convolutional ImageGPT-like model below:

from torch import nn

from pytorch_generative import nn as pg_nn


class TransformerBlock(nn.Module):
  """An ImageGPT Transformer block."""

  def __init__(self, 
               n_channels, 
               n_attention_heads):
    """Initializes a new TransformerBlock instance.
    
    Args:
      n_channels: The number of input and output channels.
      n_attention_heads: The number of attention heads to use.
    """
    super().__init__()
    self._attn = pg_nn.MaskedAttention(
        in_channels=n_channels,
        embed_channels=n_channels,
        out_channels=n_channels,
        n_heads=n_attention_heads,
        is_causal=False)
    self._out = nn.Sequential(
        nn.Conv2d(
            in_channels=n_channels, 
            out_channels=4*n_channels, 
            kernel_size=1),
        nn.GELU(),
        nn.Conv2d(
            in_channels=4*n_channels, 
            out_channels=n_channels, 
            kernel_size=1))

  def forward(self, x):
    x = x + self._attn(x)
    return x + self._out(x)


class ImageGPT(nn.Module):
  """The ImageGPT Model.
  
  Note that we don't use LayerNorm because it would break the model's 
  autoregressive property.
  """
  
  def __init__(self,       
               in_channels,
               in_size,
               n_transformer_blocks=8,
               n_attention_heads=4,
               n_embedding_channels=16):
    """Initializes a new ImageGPT instance.
    
    Args:
      in_channels: The number of input channels.
      in_size: Size of the input images. Used to create positional encodings.
      n_transformer_blocks: Number of TransformerBlocks to use.
      n_attention_heads: Number of attention heads to use.
      n_embedding_channels: Number of attention embedding channels to use.
    """
    super().__init__()
    self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
    self._input = pg_nn.MaskedConv2d(
        is_causal=True,
        in_channels=in_channels,
        out_channels=n_embedding_channels,
        kernel_size=3,
        padding=1)
    self._transformer = nn.Sequential(
        *[TransformerBlock(n_channels=n_embedding_channels,
                         n_attention_heads=n_attention_heads)
          for _ in range(n_transformer_blocks)])
    self._out = nn.Conv2d(in_channels=n_embedding_channels,
                          out_channels=in_channels,
                          kernel_size=1)

  def forward(self, x):
    x = self._input(x + self._pos)
    x = self._transformer(x)
    return self._out(x)

Supported Algorithms

pytorch-generative supports the following algorithms. We train most algorithms on Binarized Mnist and either match or surpass the relevant papers.

Autoregressive Generative Models

Binary MNIST (NLL):

Algorithm Our Results Best Other Results Links
ImageGPT TODO N/A Code, Notebook
PixelSNAIL 78.61 N/A Code, Notebook
Gated PixelCNN 81.50 81.30 [1] Code, Notebook
PixelCNN 81.45 81.30 [1] Code, Notebook
MADE 84.867 88.04 [4] Code, Notebook
NADE 85.65 88.86 [5] Code, Notebook

Note: Our reported binary MNIST results may be optimistic. Instead of using a fixed dataset, we resample a new binary MNIST dataset on every epoch. We can think of this as using data augmentation which helps our models learn better.

References

  1. https://arxiv.org/pdf/1601.06759.pdf
  2. https://arxiv.org/abs/1606.05328
  3. http://www.scottreed.info/files/iclr2017.pdf
  4. https://arxiv.org/abs/1502.03509
  5. http://proceedings.mlr.press/v15/larochelle11a/larochelle11a.pdf

Neural Style Transfer

Blog: https://towardsdatascience.com/how-to-get-beautiful-results-with-neural-style-transfer-75d0c05d6489
Notebook: https://github.com/EugenHotaj/pytorch-generative/blob/master/notebooks/style_transfer.ipynb
Paper: https://arxiv.org/pdf/1508.06576.pdf

Compositional Pattern Producing Networks

Notebook: https://github.com/EugenHotaj/pytorch-generative/blob/master/notebooks/cppn.ipynb
Background: https://en.wikipedia.org/wiki/Compositional_pattern-producing_network

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published