In [1]:
import data
import torch
import torch.nn as nn

In [2]:
# `%cd` with no argument switches the kernel's working directory to the user's
# home directory (the cell output shows /home/kacper). Later cells depend on
# this: the relative checkpoint path "experiments_gan/..." resolves against it,
# and presumably so does the `DeepLearning.Project3...` import (IPython puts the
# cwd on sys.path). Note that `import data` in the first cell ran BEFORE this
# change, so it was resolved against the original working directory.
# NOTE(review): this ties the notebook to one machine's layout; consider an
# explicit, configurable project root instead.
%cd

/home/kacper


In [3]:
class Generator(nn.Module):
    """Decode latent vectors into 3-channel images.

    Pipeline:
      1. A two-layer bias-free MLP (latent_size -> latent_size, ReLU after each
         layer).
      2. The MLP output is viewed as a ``latent_size x 1 x 1`` feature map.
      3. Three ConvTranspose2d stages (kernel 4, stride 4, padding 0) each
         multiply height and width by 4: 1 -> 4 -> 16 -> 64.

    Channels go latent_size -> latent_size//2 -> latent_size//4 -> 3. The first
    two upsampling stages are followed by in-place ReLU, the last one by Tanh.

    Args:
        latent_size: length of each input latent vector.
        image_size: accepted for signature compatibility but not used; the
            layer stack always yields 64x64 output.
    """

    def __init__(self, latent_size, image_size=64):
        super().__init__()
        self.latent_size = latent_size

        self.mlp = nn.Sequential(
            nn.Linear(latent_size, latent_size, bias=False),
            nn.ReLU(),
            nn.Linear(latent_size, latent_size, bias=False),
            nn.ReLU(),
        )

        # Channel schedule of the upsampling path; each consecutive pair is the
        # (in_channels, out_channels) of one ConvTranspose2d stage.
        widths = (latent_size, latent_size // 2, latent_size // 4, 3)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.ConvTranspose2d(
                    c_in, c_out, kernel_size=4, stride=4, padding=0, bias=False
                ),
                nn.ReLU(True),
            ]
        # The last stage squashes pixel values into [-1, 1] instead of ReLU.
        stages[-1] = nn.Tanh()
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Map latents (batch, latent_size) to images (batch, 3, 64, 64)."""
        hidden = self.mlp(x)
        return self.conv(hidden.reshape(-1, self.latent_size, 1, 1))
    
class Discriminator(nn.Module):
    """Score RGB images with strided convolutions followed by an MLP.

    Each Conv2d uses kernel_size=4, stride=4, padding=0, so a 64x64 input is
    reduced 64 -> 16 -> 4 -> 1. This leaves a (batch, latent_size, 1, 1)
    feature map, which the MLP turns into one Sigmoid score per image.

    Args:
        latent_size: channel count of the last conv stage and input width of
            the MLP. Channel widths grow 3 -> latent_size//4 ->
            latent_size//2 -> latent_size.
        image_size: accepted for signature compatibility but not used; inputs
            must reduce to a 1x1 feature map (e.g. 64x64 images).
    """

    def __init__(self, latent_size, image_size=64):
        super(Discriminator, self).__init__()
        self.latent_size = latent_size

        self.mlp = nn.Sequential(
            nn.Linear(latent_size, latent_size, bias=False),
            nn.ReLU(),
            nn.Linear(latent_size, 1, bias=False),
            nn.Sigmoid(),
        )

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=latent_size // 4,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False),
            nn.ReLU(True),
            nn.Conv2d(
                in_channels=latent_size // 4,
                out_channels=latent_size // 2,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False),
            nn.ReLU(True),
            nn.Conv2d(
                in_channels=latent_size // 2,
                out_channels=latent_size,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Return a 1-D tensor with one score in [0, 1] per input image.

        Raises:
            ValueError: if the conv stack does not reduce the input to a 1x1
                feature map (e.g. a 128x128 image ends at 2x2). Without this
                check the reshape below would silently fold the leftover
                spatial positions into the batch dimension. It would then
                return the wrong number of scores, computed from scrambled
                features.
        """
        features = self.conv(x)
        if features.shape[-2:] != (1, 1):
            raise ValueError(
                "Discriminator expects inputs that reduce to a 1x1 feature map "
                f"(e.g. 64x64 images); got input shape {tuple(x.shape)} -> "
                f"conv output spatial size {tuple(features.shape[-2:])}"
            )
        features = features.reshape(-1, self.latent_size)
        return self.mlp(features).reshape(-1)

In [4]:
# Relative to the working directory selected by the `%cd` cell earlier.
GENERATOR_CHECKPOINT_PATH = "experiments_gan/simpleconvgan_2_generator.pt"

# Presumably a fully pickled Generator module (not a state_dict), which is why
# the Generator class is defined in this notebook before loading -- confirm.
# NOTE(review): torch.load unpickles, which can execute arbitrary code; only
# load checkpoints you trust.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which rejects full-module pickles; check the installed version's default.
model = torch.load(GENERATOR_CHECKPOINT_PATH)

In [5]:
from DeepLearning.Project3.frechet_metric import generate_images_to_path

In [8]:
# Write generated samples to disk, presumably as the "fake" image set for the
# Frechet (FID) metric.
# - latent_size comes from the loaded checkpoint instead of a hard-coded 128,
#   so it cannot drift from the width Generator.forward reshapes to.
# - img_size=64 because the Generator's three x4 upsampling stages always
#   produce 64x64 images.
# NOTE(review): confirm whether batch_size is the per-batch size or the total
# number of images generate_images_to_path writes.
generate_images_to_path(
    model,
    path="images/simpleconvgan/fid/0/",
    batch_size=20_000,
    latent_size=model.latent_size,
    img_size=64,
)